From 6a8902ad5106ccda803acd9139ae4039fb7572fd Mon Sep 17 00:00:00 2001 From: Stephen Seo Date: Tue, 7 Sep 2021 11:46:38 +0900 Subject: [PATCH] Allow ThreadPool to be created with < 2 ThreadCount --- src/EC/Manager.hpp | 85 +++++++++++------------ src/EC/ThreadPool.hpp | 130 ++++++++++++++++++++++-------------- src/test/ECTest.cpp | 27 ++++++++ src/test/ThreadPoolTest.cpp | 41 ++++++++++-- 4 files changed, 184 insertions(+), 99 deletions(-) diff --git a/src/EC/Manager.hpp b/src/EC/Manager.hpp index 4e5b57a..a4eb821 100644 --- a/src/EC/Manager.hpp +++ b/src/EC/Manager.hpp @@ -93,7 +93,7 @@ namespace EC std::size_t currentSize = 0; std::unordered_set deletedSet; - ThreadPool threadPool; + std::unique_ptr > threadPool; public: /*! @@ -105,6 +105,9 @@ namespace EC Manager() { resize(EC_INIT_ENTITIES_SIZE); + if(ThreadCount >= 2) { + threadPool = std::make_unique >(); + } } private: @@ -634,7 +637,7 @@ namespace EC BitsetType signatureBitset = BitsetType::template generateBitset(); - if(!useThreadPool) + if(!useThreadPool || !threadPool) { for(std::size_t i = 0; i < currentSize; ++i) { @@ -673,7 +676,7 @@ namespace EC std::get<2>(fnDataAr.at(i)) = &signatureBitset; std::get<3>(fnDataAr.at(i)) = {begin, end}; std::get<4>(fnDataAr.at(i)) = userData; - threadPool.queueFn([&function] (void *ud) { + threadPool->queueFn([&function] (void *ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<3>(*data).at(0); i < std::get<3>(*data).at(1); @@ -691,10 +694,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } @@ -753,7 +756,7 @@ namespace EC BitsetType signatureBitset = BitsetType::template generateBitset(); - if(!useThreadPool) + if(!useThreadPool || !threadPool) { for(std::size_t i = 0; i < currentSize; ++i) { @@ -792,7 +795,7 @@ namespace EC std::get<3>(fnDataAr.at(i)) = {begin, end}; std::get<4>(fnDataAr.at(i)) = userData; std::get<5>(fnDataAr.at(i)) = function; - threadPool.queueFn([] (void *ud) { + threadPool->queueFn([] (void *ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<3>(*data).at(0); i < std::get<3>(*data).at(1); @@ -810,10 +813,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } @@ -906,7 +909,7 @@ namespace EC std::vector matching, void* userData) { - if(!useThreadPool) + if(!useThreadPool || !threadPool) { for(auto eid : matching) { @@ -939,7 +942,7 @@ namespace EC std::get<2>(fnDataAr.at(i)) = {begin, end}; std::get<3>(fnDataAr.at(i)) = userData; std::get<4>(fnDataAr.at(i)) = &matching; - threadPool.queueFn([function, helper] (void* ud) { + threadPool->queueFn([function, helper] (void* ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<2>(*data).at(0); i < std::get<2>(*data).at(1); @@ -954,10 +957,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } }))); @@ -970,7 +973,7 @@ namespace EC { std::vector > matchingV(bitsets.size()); - if(!useThreadPool) + if(!useThreadPool || !threadPool) { for(std::size_t i = 0; i < currentSize; ++i) { @@ -1012,7 +1015,7 @@ namespace EC std::get<3>(fnDataAr.at(i)) = &bitsets; std::get<4>(fnDataAr.at(i)) = &entities; std::get<5>(fnDataAr.at(i)) = &mutex; - threadPool.queueFn([] (void *ud) { + threadPool->queueFn([] (void *ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<1>(*data).at(0); i < std::get<1>(*data).at(1); @@ -1033,10 +1036,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } return matchingV; @@ -1351,7 +1354,7 @@ namespace EC }); // find and store entities matching signatures - if(!useThreadPool) + if(!useThreadPool || !threadPool) { for(std::size_t eid = 0; eid < currentSize; ++eid) { @@ -1394,7 +1397,7 @@ namespace EC std::get<3>(fnDataAr.at(i)) = signatureBitsets; std::get<4>(fnDataAr.at(i)) = &mutex; - threadPool.queueFn([] (void *ud) { + threadPool->queueFn([] (void *ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<1>(*data).at(0); i < std::get<1>(*data).at(1); @@ -1412,10 +1415,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } // call functions on matching entities @@ -1432,7 +1435,7 @@ namespace EC EC::Meta::Morph< SignatureComponents, ForMatchingSignatureHelper<> >; - if(!useThreadPool) { + if(!useThreadPool || !threadPool) { for(const auto& id : multiMatchingEntities[index]) { if(isAlive(id)) { Helper::call(id, *this, func, userData); @@ -1459,7 +1462,7 @@ namespace EC std::get<2>(fnDataAr.at(i)) = {begin, end}; std::get<3>(fnDataAr.at(i)) = &multiMatchingEntities; std::get<4>(fnDataAr.at(i)) = index; - threadPool.queueFn([&func] (void *ud) { + threadPool->queueFn([&func] (void *ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<2>(*data).at(0); i < std::get<2>(*data).at(1); @@ -1474,10 +1477,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } ); @@ -1544,7 +1547,7 @@ namespace EC }); // find and store entities matching signatures - if(!useThreadPool) + if(!useThreadPool || !threadPool) { for(std::size_t eid = 0; eid < currentSize; ++eid) { @@ -1587,7 +1590,7 @@ namespace EC std::get<3>(fnDataAr.at(i)) = signatureBitsets; std::get<4>(fnDataAr.at(i)) = &mutex; - threadPool.queueFn([] (void *ud) { + threadPool->queueFn([] (void *ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<1>(*data).at(0); i < std::get<1>(*data).at(1); @@ -1605,10 +1608,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } // call functions on matching entities @@ -1625,7 +1628,7 @@ namespace EC EC::Meta::Morph< SignatureComponents, ForMatchingSignatureHelper<> >; - if(!useThreadPool) + if(!useThreadPool || !threadPool) { for(const auto& id : multiMatchingEntities[index]) { @@ -1657,7 +1660,7 @@ namespace EC std::get<2>(fnDataAr.at(i)) = {begin, end}; std::get<3>(fnDataAr.at(i)) = &multiMatchingEntities; std::get<4>(fnDataAr.at(i)) = index; - threadPool.queueFn([&func] (void *ud) { + threadPool->queueFn([&func] (void *ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<2>(*data).at(0); i < std::get<2>(*data).at(1); @@ -1672,10 +1675,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } ); @@ -1694,7 +1697,7 @@ namespace EC template void forMatchingSimple(ForMatchingFn fn, void *userData = nullptr, const bool useThreadPool = false) { const BitsetType signatureBitset = BitsetType::template generateBitset(); - if(!useThreadPool) { + if(!useThreadPool || !threadPool) { for(std::size_t i = 0; i < currentSize; ++i) { if(!std::get(entities[i])) { continue; @@ -1723,7 +1726,7 @@ namespace EC std::get<2>(fnDataAr.at(i)) = &signatureBitset; std::get<3>(fnDataAr.at(i)) = {begin, end}; std::get<4>(fnDataAr.at(i)) = userData; - threadPool.queueFn([&fn] (void *ud) { + threadPool->queueFn([&fn] (void *ud) { auto *data = static_cast(ud); for(std::size_t i = std::get<3>(*data).at(0); i < std::get<3>(*data).at(1); @@ -1736,10 +1739,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } @@ -1754,7 +1757,7 @@ namespace EC */ template void forMatchingIterable(Iterable iterable, ForMatchingFn fn, void* userData = nullptr, const bool useThreadPool = false) { - if(!useThreadPool) { + if(!useThreadPool || !threadPool) { bool isValid; for(std::size_t i = 0; i < currentSize; ++i) { if(!std::get(entities[i])) { @@ -1793,7 +1796,7 @@ namespace EC std::get<2>(fnDataAr.at(i)) = &iterable; std::get<3>(fnDataAr.at(i)) = {begin, end}; std::get<4>(fnDataAr.at(i)) = userData; - threadPool.queueFn([&fn] (void *ud) { + threadPool->queueFn([&fn] (void *ud) { auto *data = static_cast(ud); bool isValid; for(std::size_t i = std::get<3>(*data).at(0); @@ -1816,10 +1819,10 @@ namespace EC } }, &fnDataAr.at(i)); } - threadPool.wakeThreads(); + threadPool->wakeThreads(); do { std::this_thread::sleep_for(std::chrono::microseconds(200)); - } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting()); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } }; diff --git a/src/EC/ThreadPool.hpp b/src/EC/ThreadPool.hpp index aee681d..c7d078e 100644 --- a/src/EC/ThreadPool.hpp +++ b/src/EC/ThreadPool.hpp @@ -20,64 +20,66 @@ namespace Internal { using TPQueueType = std::queue; } // namespace Internal -template -class ThreadPool; - template -class ThreadPool= 2)>::type> { +class ThreadPool { public: using THREADCOUNT = std::integral_constant; ThreadPool() : waitCount(0) { isAlive.store(true); - for(unsigned int i = 0; i < SIZE; ++i) { - threads.emplace_back([] (std::atomic_bool *isAlive, - std::condition_variable *cv, - std::mutex *cvMutex, - Internal::TPQueueType *fnQueue, - std::mutex *queueMutex, - int *waitCount, - std::mutex *waitCountMutex) { - bool hasFn = false; - Internal::TPTupleType fnTuple; - while(isAlive->load()) { - hasFn = false; - { - std::lock_guard lock(*queueMutex); - if(!fnQueue->empty()) { - fnTuple = fnQueue->front(); - fnQueue->pop(); - hasFn = true; + if(SIZE >= 2) { + for(unsigned int i = 0; i < SIZE; ++i) { + threads.emplace_back([] (std::atomic_bool *isAlive, + std::condition_variable *cv, + std::mutex *cvMutex, + Internal::TPQueueType *fnQueue, + std::mutex *queueMutex, + int *waitCount, + std::mutex *waitCountMutex) { + bool hasFn = false; + Internal::TPTupleType fnTuple; + while(isAlive->load()) { + hasFn = false; + { + std::lock_guard lock(*queueMutex); + if(!fnQueue->empty()) { + fnTuple = fnQueue->front(); + fnQueue->pop(); + hasFn = true; + } + } + if(hasFn) { + std::get<0>(fnTuple)(std::get<1>(fnTuple)); + continue; + } + + { + std::lock_guard lock(*waitCountMutex); + *waitCount += 1; + } + { + std::unique_lock lock(*cvMutex); + cv->wait(lock); + } + { + std::lock_guard lock(*waitCountMutex); + *waitCount -= 1; } } - if(hasFn) { - std::get<0>(fnTuple)(std::get<1>(fnTuple)); - continue; - } - - { - std::lock_guard lock(*waitCountMutex); - *waitCount += 1; - } - { - std::unique_lock lock(*cvMutex); - cv->wait(lock); - } - { - std::lock_guard lock(*waitCountMutex); - *waitCount -= 1; - } - } - }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, &waitCountMutex); + }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, + &waitCountMutex); + } } } ~ThreadPool() { - isAlive.store(false); - std::this_thread::sleep_for(std::chrono::milliseconds(200)); - cv.notify_all(); - for(auto &thread : threads) { - thread.join(); + if(SIZE >= 2) { + isAlive.store(false); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + cv.notify_all(); + for(auto &thread : threads) { + thread.join(); + } } } @@ -87,10 +89,32 @@ public: } void wakeThreads(bool wakeAll = true) { - if(wakeAll) { - cv.notify_all(); + if(SIZE >= 2) { + // wake threads to pull functions from queue and run them + if(wakeAll) { + cv.notify_all(); + } else { + cv.notify_one(); + } } else { - cv.notify_one(); + // pull functions from queue and run them on main thread + Internal::TPTupleType fnTuple; + bool hasFn; + do { + { + std::lock_guard lock(queueMutex); + if(!fnQueue.empty()) { + hasFn = true; + fnTuple = fnQueue.front(); + fnQueue.pop(); + } else { + hasFn = false; + } + } + if(hasFn) { + std::get<0>(fnTuple)(std::get<1>(fnTuple)); + } + } while(hasFn); } } @@ -100,8 +124,12 @@ public: } bool isAllThreadsWaiting() { - std::lock_guard lock(waitCountMutex); - return waitCount == THREADCOUNT::value; + if(SIZE >= 2) { + std::lock_guard lock(waitCountMutex); + return waitCount == THREADCOUNT::value; + } else { + return true; + } } bool isQueueEmpty() { diff --git a/src/test/ECTest.cpp b/src/test/ECTest.cpp index 3105903..75e65ae 100644 --- a/src/test/ECTest.cpp +++ b/src/test/ECTest.cpp @@ -1370,3 +1370,30 @@ TEST(EC, MultiThreadedForMatching) { EXPECT_TRUE(manager.isAlive(first)); EXPECT_FALSE(manager.isAlive(second)); } + +TEST(EC, ManagerWithLowThreadCount) { + EC::Manager manager; + + std::array entities; + for(auto &id : entities) { + id = manager.addEntity(); + manager.addComponent(id); + } + + for(const auto &id : entities) { + auto *component = manager.getEntityComponent(id); + EXPECT_EQ(component->x, 0); + EXPECT_EQ(component->y, 0); + } + + manager.forMatchingSignature >([] (std::size_t /*id*/, void* /*ud*/, C0 *c) { + c->x += 1; + c->y += 1; + }, nullptr, true); + + for(const auto &id : entities) { + auto *component = manager.getEntityComponent(id); + EXPECT_EQ(component->x, 1); + EXPECT_EQ(component->y, 1); + } +} diff --git a/src/test/ThreadPoolTest.cpp b/src/test/ThreadPoolTest.cpp index 2cec50c..31c8529 100644 --- a/src/test/ThreadPoolTest.cpp +++ b/src/test/ThreadPoolTest.cpp @@ -2,15 +2,42 @@ #include -//using OneThreadPool = EC::ThreadPool<1>; +using OneThreadPool = EC::ThreadPool<1>; using ThreeThreadPool = EC::ThreadPool<3>; -//TEST(ECThreadPool, CannotCompile) { -// OneThreadPool tp; -//} - -TEST(ECThreadPool, Simple) { - ThreeThreadPool p{}; +TEST(ECThreadPool, CannotCompile) { + OneThreadPool p; + std::atomic_int data; + data.store(0); + const auto fn = [](void *ud) { + auto *data = static_cast(ud); + data->fetch_add(1); + }; + + p.queueFn(fn, &data); + + p.wakeThreads(); + + do { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while(!p.isQueueEmpty() && !p.isAllThreadsWaiting()); + + ASSERT_EQ(data.load(), 1); + + for(unsigned int i = 0; i < 10; ++i) { + p.queueFn(fn, &data); + } + p.wakeThreads(); + + do { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while(!p.isQueueEmpty() && !p.isAllThreadsWaiting()); + + ASSERT_EQ(data.load(), 11); +} + +TEST(ECThreadPool, Simple) { + ThreeThreadPool p; std::atomic_int data; data.store(0); const auto fn = [](void *ud) {