Allow ThreadPool to be created with < 2 ThreadCount

This commit is contained in:
Stephen Seo 2021-09-07 11:46:38 +09:00
parent 16f410c8ef
commit 6a8902ad51
4 changed files with 184 additions and 99 deletions

View file

@ -93,7 +93,7 @@ namespace EC
std::size_t currentSize = 0;
std::unordered_set<std::size_t> deletedSet;
ThreadPool<ThreadCount> threadPool;
std::unique_ptr<ThreadPool<ThreadCount> > threadPool;
public:
/*!
@ -105,6 +105,9 @@ namespace EC
Manager()
{
resize(EC_INIT_ENTITIES_SIZE);
if(ThreadCount >= 2) {
threadPool = std::make_unique<ThreadPool<ThreadCount> >();
}
}
private:
@ -634,7 +637,7 @@ namespace EC
BitsetType signatureBitset =
BitsetType::template generateBitset<Signature>();
if(!useThreadPool)
if(!useThreadPool || !threadPool)
{
for(std::size_t i = 0; i < currentSize; ++i)
{
@ -673,7 +676,7 @@ namespace EC
std::get<2>(fnDataAr.at(i)) = &signatureBitset;
std::get<3>(fnDataAr.at(i)) = {begin, end};
std::get<4>(fnDataAr.at(i)) = userData;
threadPool.queueFn([&function] (void *ud) {
threadPool->queueFn([&function] (void *ud) {
auto *data = static_cast<TPFnDataType*>(ud);
for(std::size_t i = std::get<3>(*data).at(0);
i < std::get<3>(*data).at(1);
@ -691,10 +694,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
}
@ -753,7 +756,7 @@ namespace EC
BitsetType signatureBitset =
BitsetType::template generateBitset<Signature>();
if(!useThreadPool)
if(!useThreadPool || !threadPool)
{
for(std::size_t i = 0; i < currentSize; ++i)
{
@ -792,7 +795,7 @@ namespace EC
std::get<3>(fnDataAr.at(i)) = {begin, end};
std::get<4>(fnDataAr.at(i)) = userData;
std::get<5>(fnDataAr.at(i)) = function;
threadPool.queueFn([] (void *ud) {
threadPool->queueFn([] (void *ud) {
auto *data = static_cast<TPFnDataType*>(ud);
for(std::size_t i = std::get<3>(*data).at(0);
i < std::get<3>(*data).at(1);
@ -810,10 +813,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
}
@ -906,7 +909,7 @@ namespace EC
std::vector<std::size_t> matching,
void* userData)
{
if(!useThreadPool)
if(!useThreadPool || !threadPool)
{
for(auto eid : matching)
{
@ -939,7 +942,7 @@ namespace EC
std::get<2>(fnDataAr.at(i)) = {begin, end};
std::get<3>(fnDataAr.at(i)) = userData;
std::get<4>(fnDataAr.at(i)) = &matching;
threadPool.queueFn([function, helper] (void* ud) {
threadPool->queueFn([function, helper] (void* ud) {
auto *data = static_cast<TPFnDataType*>(ud);
for(std::size_t i = std::get<2>(*data).at(0);
i < std::get<2>(*data).at(1);
@ -954,10 +957,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
})));
@ -970,7 +973,7 @@ namespace EC
{
std::vector<std::vector<std::size_t> > matchingV(bitsets.size());
if(!useThreadPool)
if(!useThreadPool || !threadPool)
{
for(std::size_t i = 0; i < currentSize; ++i)
{
@ -1012,7 +1015,7 @@ namespace EC
std::get<3>(fnDataAr.at(i)) = &bitsets;
std::get<4>(fnDataAr.at(i)) = &entities;
std::get<5>(fnDataAr.at(i)) = &mutex;
threadPool.queueFn([] (void *ud) {
threadPool->queueFn([] (void *ud) {
auto *data = static_cast<TPFnDataType*>(ud);
for(std::size_t i = std::get<1>(*data).at(0);
i < std::get<1>(*data).at(1);
@ -1033,10 +1036,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
return matchingV;
@ -1351,7 +1354,7 @@ namespace EC
});
// find and store entities matching signatures
if(!useThreadPool)
if(!useThreadPool || !threadPool)
{
for(std::size_t eid = 0; eid < currentSize; ++eid)
{
@ -1394,7 +1397,7 @@ namespace EC
std::get<3>(fnDataAr.at(i)) = signatureBitsets;
std::get<4>(fnDataAr.at(i)) = &mutex;
threadPool.queueFn([] (void *ud) {
threadPool->queueFn([] (void *ud) {
auto *data = static_cast<TPFnDataType*>(ud);
for(std::size_t i = std::get<1>(*data).at(0);
i < std::get<1>(*data).at(1);
@ -1412,10 +1415,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
// call functions on matching entities
@ -1432,7 +1435,7 @@ namespace EC
EC::Meta::Morph<
SignatureComponents,
ForMatchingSignatureHelper<> >;
if(!useThreadPool) {
if(!useThreadPool || !threadPool) {
for(const auto& id : multiMatchingEntities[index]) {
if(isAlive(id)) {
Helper::call(id, *this, func, userData);
@ -1459,7 +1462,7 @@ namespace EC
std::get<2>(fnDataAr.at(i)) = {begin, end};
std::get<3>(fnDataAr.at(i)) = &multiMatchingEntities;
std::get<4>(fnDataAr.at(i)) = index;
threadPool.queueFn([&func] (void *ud) {
threadPool->queueFn([&func] (void *ud) {
auto *data = static_cast<TPFnType*>(ud);
for(std::size_t i = std::get<2>(*data).at(0);
i < std::get<2>(*data).at(1);
@ -1474,10 +1477,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
}
);
@ -1544,7 +1547,7 @@ namespace EC
});
// find and store entities matching signatures
if(!useThreadPool)
if(!useThreadPool || !threadPool)
{
for(std::size_t eid = 0; eid < currentSize; ++eid)
{
@ -1587,7 +1590,7 @@ namespace EC
std::get<3>(fnDataAr.at(i)) = signatureBitsets;
std::get<4>(fnDataAr.at(i)) = &mutex;
threadPool.queueFn([] (void *ud) {
threadPool->queueFn([] (void *ud) {
auto *data = static_cast<TPFnDataType*>(ud);
for(std::size_t i = std::get<1>(*data).at(0);
i < std::get<1>(*data).at(1);
@ -1605,10 +1608,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
// call functions on matching entities
@ -1625,7 +1628,7 @@ namespace EC
EC::Meta::Morph<
SignatureComponents,
ForMatchingSignatureHelper<> >;
if(!useThreadPool)
if(!useThreadPool || !threadPool)
{
for(const auto& id : multiMatchingEntities[index])
{
@ -1657,7 +1660,7 @@ namespace EC
std::get<2>(fnDataAr.at(i)) = {begin, end};
std::get<3>(fnDataAr.at(i)) = &multiMatchingEntities;
std::get<4>(fnDataAr.at(i)) = index;
threadPool.queueFn([&func] (void *ud) {
threadPool->queueFn([&func] (void *ud) {
auto *data = static_cast<TPFnType*>(ud);
for(std::size_t i = std::get<2>(*data).at(0);
i < std::get<2>(*data).at(1);
@ -1672,10 +1675,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
}
);
@ -1694,7 +1697,7 @@ namespace EC
template <typename Signature>
void forMatchingSimple(ForMatchingFn fn, void *userData = nullptr, const bool useThreadPool = false) {
const BitsetType signatureBitset = BitsetType::template generateBitset<Signature>();
if(!useThreadPool) {
if(!useThreadPool || !threadPool) {
for(std::size_t i = 0; i < currentSize; ++i) {
if(!std::get<bool>(entities[i])) {
continue;
@ -1723,7 +1726,7 @@ namespace EC
std::get<2>(fnDataAr.at(i)) = &signatureBitset;
std::get<3>(fnDataAr.at(i)) = {begin, end};
std::get<4>(fnDataAr.at(i)) = userData;
threadPool.queueFn([&fn] (void *ud) {
threadPool->queueFn([&fn] (void *ud) {
auto *data = static_cast<TPFnDataType*>(ud);
for(std::size_t i = std::get<3>(*data).at(0);
i < std::get<3>(*data).at(1);
@ -1736,10 +1739,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
}
@ -1754,7 +1757,7 @@ namespace EC
*/
template <typename Iterable>
void forMatchingIterable(Iterable iterable, ForMatchingFn fn, void* userData = nullptr, const bool useThreadPool = false) {
if(!useThreadPool) {
if(!useThreadPool || !threadPool) {
bool isValid;
for(std::size_t i = 0; i < currentSize; ++i) {
if(!std::get<bool>(entities[i])) {
@ -1793,7 +1796,7 @@ namespace EC
std::get<2>(fnDataAr.at(i)) = &iterable;
std::get<3>(fnDataAr.at(i)) = {begin, end};
std::get<4>(fnDataAr.at(i)) = userData;
threadPool.queueFn([&fn] (void *ud) {
threadPool->queueFn([&fn] (void *ud) {
auto *data = static_cast<TPFnDataType*>(ud);
bool isValid;
for(std::size_t i = std::get<3>(*data).at(0);
@ -1816,10 +1819,10 @@ namespace EC
}
}, &fnDataAr.at(i));
}
threadPool.wakeThreads();
threadPool->wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::microseconds(200));
} while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
} while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
}
}
};

View file

@ -20,16 +20,14 @@ namespace Internal {
using TPQueueType = std::queue<TPTupleType>;
} // namespace Internal
template <unsigned int SIZE, typename = void>
class ThreadPool;
template <unsigned int SIZE>
class ThreadPool<SIZE, typename std::enable_if<(SIZE >= 2)>::type> {
class ThreadPool {
public:
using THREADCOUNT = std::integral_constant<int, SIZE>;
ThreadPool() : waitCount(0) {
isAlive.store(true);
if(SIZE >= 2) {
for(unsigned int i = 0; i < SIZE; ++i) {
threads.emplace_back([] (std::atomic_bool *isAlive,
std::condition_variable *cv,
@ -68,18 +66,22 @@ public:
*waitCount -= 1;
}
}
}, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, &waitCountMutex);
}, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount,
&waitCountMutex);
}
}
}
~ThreadPool() {
if(SIZE >= 2) {
isAlive.store(false);
std::this_thread::sleep_for(std::chrono::milliseconds(200));
std::this_thread::sleep_for(std::chrono::milliseconds(20));
cv.notify_all();
for(auto &thread : threads) {
thread.join();
}
}
}
void queueFn(std::function<void(void*)>&& fn, void *ud = nullptr) {
std::lock_guard<std::mutex> lock(queueMutex);
@ -87,11 +89,33 @@ public:
}
void wakeThreads(bool wakeAll = true) {
if(SIZE >= 2) {
// wake threads to pull functions from queue and run them
if(wakeAll) {
cv.notify_all();
} else {
cv.notify_one();
}
} else {
// pull functions from queue and run them on main thread
Internal::TPTupleType fnTuple;
bool hasFn;
do {
{
std::lock_guard<std::mutex> lock(queueMutex);
if(!fnQueue.empty()) {
hasFn = true;
fnTuple = fnQueue.front();
fnQueue.pop();
} else {
hasFn = false;
}
}
if(hasFn) {
std::get<0>(fnTuple)(std::get<1>(fnTuple));
}
} while(hasFn);
}
}
int getWaitCount() {
@ -100,8 +124,12 @@ public:
}
bool isAllThreadsWaiting() {
if(SIZE >= 2) {
std::lock_guard<std::mutex> lock(waitCountMutex);
return waitCount == THREADCOUNT::value;
} else {
return true;
}
}
bool isQueueEmpty() {

View file

@ -1370,3 +1370,30 @@ TEST(EC, MultiThreadedForMatching) {
EXPECT_TRUE(manager.isAlive(first));
EXPECT_FALSE(manager.isAlive(second));
}
TEST(EC, ManagerWithLowThreadCount) {
EC::Manager<ListComponentsAll, ListTagsAll, 1> manager;
std::array<std::size_t, 10> entities;
for(auto &id : entities) {
id = manager.addEntity();
manager.addComponent<C0>(id);
}
for(const auto &id : entities) {
auto *component = manager.getEntityComponent<C0>(id);
EXPECT_EQ(component->x, 0);
EXPECT_EQ(component->y, 0);
}
manager.forMatchingSignature<EC::Meta::TypeList<C0> >([] (std::size_t /*id*/, void* /*ud*/, C0 *c) {
c->x += 1;
c->y += 1;
}, nullptr, true);
for(const auto &id : entities) {
auto *component = manager.getEntityComponent<C0>(id);
EXPECT_EQ(component->x, 1);
EXPECT_EQ(component->y, 1);
}
}

View file

@ -2,15 +2,42 @@
#include <EC/ThreadPool.hpp>
//using OneThreadPool = EC::ThreadPool<1>;
using OneThreadPool = EC::ThreadPool<1>;
using ThreeThreadPool = EC::ThreadPool<3>;
//TEST(ECThreadPool, CannotCompile) {
// OneThreadPool tp;
//}
TEST(ECThreadPool, Simple) {
ThreeThreadPool p{};
TEST(ECThreadPool, CannotCompile) {
OneThreadPool p;
std::atomic_int data;
data.store(0);
const auto fn = [](void *ud) {
auto *data = static_cast<std::atomic_int*>(ud);
data->fetch_add(1);
};
p.queueFn(fn, &data);
p.wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
} while(!p.isQueueEmpty() && !p.isAllThreadsWaiting());
ASSERT_EQ(data.load(), 1);
for(unsigned int i = 0; i < 10; ++i) {
p.queueFn(fn, &data);
}
p.wakeThreads();
do {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
} while(!p.isQueueEmpty() && !p.isAllThreadsWaiting());
ASSERT_EQ(data.load(), 11);
}
TEST(ECThreadPool, Simple) {
ThreeThreadPool p;
std::atomic_int data;
data.store(0);
const auto fn = [](void *ud) {