diff --git a/src/EC/Manager.hpp b/src/EC/Manager.hpp index 4c1c8e1..d82031d 100644 --- a/src/EC/Manager.hpp +++ b/src/EC/Manager.hpp @@ -114,7 +114,7 @@ namespace EC std::array range; Manager *manager; EntitiesType *entities; - const BitsetType *signature; + BitsetType signature; void *userData; std::unordered_set dead; }; @@ -779,7 +779,7 @@ namespace EC } else { - std::array fnDataAr; + std::array fnDataAr; std::size_t s = currentSize / (ThreadCount * 2); for(std::size_t i = 0; i < ThreadCount * 2; ++i) { @@ -793,14 +793,15 @@ namespace EC if(begin == end) { continue; } - fnDataAr[i].range = {begin, end}; - fnDataAr[i].manager = this; - fnDataAr[i].entities = &entities; - fnDataAr[i].signature = &signatureBitset; - fnDataAr[i].userData = userData; + fnDataAr[i] = new TPFnDataStructZero{}; + fnDataAr[i]->range = {begin, end}; + fnDataAr[i]->manager = this; + fnDataAr[i]->entities = &entities; + fnDataAr[i]->signature = signatureBitset; + fnDataAr[i]->userData = userData; for(std::size_t j = begin; j < end; ++j) { if(!isAlive(j)) { - fnDataAr[i].dead.insert(j); + fnDataAr[i]->dead.insert(j); } } @@ -812,17 +813,18 @@ namespace EC continue; } - if(((*data->signature) + if(((data->signature) & std::get( data->entities->at(i))) - == *data->signature) { + == data->signature) { Helper::call(i, *data->manager, std::forward(function), data->userData); } } - }, &fnDataAr[i]); + delete data; + }, fnDataAr[i]); } threadPool->easyWakeAndWait(); } @@ -1874,7 +1876,7 @@ namespace EC may not have as great of a speed-up. */ template - void forMatchingSimple(ForMatchingFn fn, + void forMatchingSimple(ForMatchingFn fn, void *userData = nullptr, const bool useThreadPool = false) { deferringDeletions.fetch_add(1); @@ -1891,7 +1893,7 @@ namespace EC } } } else { - std::array fnDataAr; + std::array fnDataAr; std::size_t s = currentSize / (ThreadCount * 2); for(std::size_t i = 0; i < ThreadCount * 2; ++i) { @@ -1905,14 +1907,15 @@ namespace EC if(begin == end) { continue; } - fnDataAr[i].range = {begin, end}; - fnDataAr[i].manager = this; - fnDataAr[i].entities = &entities; - fnDataAr[i].signature = &signatureBitset; - fnDataAr[i].userData = userData; + fnDataAr[i] = new TPFnDataStructZero{}; + fnDataAr[i]->range = {begin, end}; + fnDataAr[i]->manager = this; + fnDataAr[i]->entities = &entities; + fnDataAr[i]->signature = signatureBitset; + fnDataAr[i]->userData = userData; for(std::size_t j = begin; j < end; ++j) { if(!isAlive(j)) { - fnDataAr[i].dead.insert(j); + fnDataAr[i]->dead.insert(j); } } threadPool->queueFn([&fn] (void *ud) { @@ -1921,14 +1924,15 @@ namespace EC ++i) { if(data->dead.find(i) != data->dead.end()) { continue; - } else if((*data->signature + } else if((data->signature & std::get( data->entities->at(i))) - == *data->signature) { + == data->signature) { fn(i, data->manager, data->userData); } } - }, &fnDataAr[i]); + delete data; + }, fnDataAr[i]); } threadPool->easyWakeAndWait(); } diff --git a/src/EC/ThreadPool.hpp b/src/EC/ThreadPool.hpp index e0dd21f..3611954 100644 --- a/src/EC/ThreadPool.hpp +++ b/src/EC/ThreadPool.hpp @@ -11,6 +11,11 @@ #include #include #include +#include + +#ifndef NDEBUG +# include +#endif namespace EC { @@ -18,6 +23,39 @@ namespace Internal { using TPFnType = std::function; using TPTupleType = std::tuple; using TPQueueType = std::queue; + + template + void thread_fn(std::atomic_bool *isAlive, + std::condition_variable *cv, + std::mutex *cvMutex, + Internal::TPQueueType *fnQueue, + std::mutex *queueMutex, + std::atomic_int *waitCount) { + bool hasFn = false; + Internal::TPTupleType fnTuple; + while(isAlive->load()) { + hasFn = false; + { + std::lock_guard lock(*queueMutex); + if(!fnQueue->empty()) { + fnTuple = fnQueue->front(); + fnQueue->pop(); + hasFn = true; + } + } + if(hasFn) { + std::get<0>(fnTuple)(std::get<1>(fnTuple)); + continue; + } + + waitCount->fetch_add(1); + { + std::unique_lock lock(*cvMutex); + cv->wait(lock); + } + waitCount->fetch_sub(1); + } + } } // namespace Internal /*! @@ -29,49 +67,20 @@ namespace Internal { template class ThreadPool { public: - ThreadPool() : waitCount(0) { + ThreadPool() { + waitCount.store(0); + extraThreadCount.store(0); isAlive.store(true); if(SIZE >= 2) { for(unsigned int i = 0; i < SIZE; ++i) { - threads.emplace_back([] (std::atomic_bool *isAlive, - std::condition_variable *cv, - std::mutex *cvMutex, - Internal::TPQueueType *fnQueue, - std::mutex *queueMutex, - int *waitCount, - std::mutex *waitCountMutex) { - bool hasFn = false; - Internal::TPTupleType fnTuple; - while(isAlive->load()) { - hasFn = false; - { - std::lock_guard lock(*queueMutex); - if(!fnQueue->empty()) { - fnTuple = fnQueue->front(); - fnQueue->pop(); - hasFn = true; - } - } - if(hasFn) { - std::get<0>(fnTuple)(std::get<1>(fnTuple)); - continue; - } - - { - std::lock_guard lock(*waitCountMutex); - *waitCount += 1; - } - { - std::unique_lock lock(*cvMutex); - cv->wait(lock); - } - { - std::lock_guard lock(*waitCountMutex); - *waitCount -= 1; - } - } - }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, - &waitCountMutex); + threads.emplace_back(Internal::thread_fn, + &isAlive, + &cv, + &cvMutex, + &fnQueue, + &queueMutex, + &waitCount); + threadsIDs.insert(threads.back().get_id()); } } } @@ -84,6 +93,7 @@ public: for(auto &thread : threads) { thread.join(); } + std::this_thread::sleep_for(std::chrono::milliseconds(20)); } } @@ -108,7 +118,7 @@ public: If SIZE is 2 or greater, then this function will return immediately after waking one or all threads, depending on the given boolean parameter. */ - void wakeThreads(bool wakeAll = true) { + void wakeThreads(const bool wakeAll = true) { if(SIZE >= 2) { // wake threads to pull functions from queue and run them if(wakeAll) { @@ -116,6 +126,36 @@ public: } else { cv.notify_one(); } + + // check if all threads are running a task, and spawn a new thread + // if this is the case + Internal::TPTupleType fnTuple; + bool hasFn = false; + if (waitCount.load(std::memory_order_relaxed) == 0) { + std::lock_guard queueLock(queueMutex); + if (!fnQueue.empty()) { + fnTuple = fnQueue.front(); + fnQueue.pop(); + hasFn = true; + } + } + + if (hasFn) { +#ifndef NDEBUG + std::cout << "Spawning extra thread...\n"; +#endif + extraThreadCount.fetch_add(1); + std::thread newThread = std::thread( + [] (Internal::TPTupleType &&tuple, std::atomic_int *count) { + std::get<0>(tuple)(std::get<1>(tuple)); +#ifndef NDEBUG + std::cout << "Stopping extra thread...\n"; +#endif + count->fetch_sub(1); + }, + std::move(fnTuple), &extraThreadCount); + newThread.detach(); + } } else { sequentiallyRunTasks(); } @@ -129,8 +169,7 @@ public: If SIZE is less than 2, then this will always return 0. */ int getWaitCount() { - std::lock_guard lock(waitCountMutex); - return waitCount; + return waitCount.load(std::memory_order_relaxed); } /*! @@ -140,8 +179,7 @@ public: */ bool isAllThreadsWaiting() { if(SIZE >= 2) { - std::lock_guard lock(waitCountMutex); - return waitCount == SIZE; + return waitCount.load(std::memory_order_relaxed) == SIZE; } else { return true; } @@ -173,10 +211,13 @@ public: */ void easyWakeAndWait() { if(SIZE >= 2) { - wakeThreads(); do { + wakeThreads(); std::this_thread::sleep_for(std::chrono::microseconds(150)); - } while(!isQueueEmpty() || !isAllThreadsWaiting()); + } while(!isQueueEmpty() + || (threadsIDs.find(std::this_thread::get_id()) != threadsIDs.end() + && extraThreadCount.load(std::memory_order_relaxed) != 0)); +// } while(!isQueueEmpty() || !isAllThreadsWaiting()); } else { sequentiallyRunTasks(); } @@ -184,13 +225,14 @@ public: private: std::vector threads; + std::unordered_set threadsIDs; std::atomic_bool isAlive; std::condition_variable cv; std::mutex cvMutex; Internal::TPQueueType fnQueue; std::mutex queueMutex; - int waitCount; - std::mutex waitCountMutex; + std::atomic_int waitCount; + std::atomic_int extraThreadCount; void sequentiallyRunTasks() { // pull functions from queue and run them on current thread diff --git a/src/test/ECTest.cpp b/src/test/ECTest.cpp index e0ff4bb..7852154 100644 --- a/src/test/ECTest.cpp +++ b/src/test/ECTest.cpp @@ -1,7 +1,9 @@ #include +#include #include +#include #include #include #include @@ -1431,3 +1433,33 @@ TEST(EC, ManagerDeferredDeletions) { } } } + +TEST(EC, NestedThreadPoolTasks) { + using ManagerType = EC::Manager; + ManagerType manager; + + std::array entities; + for (auto &entity : entities) { + entity = manager.addEntity(); + manager.addComponent(entity, entity, entity); + } + + manager.forMatchingSignature>([] (std::size_t id, void *data, C0 *c) { + ManagerType *manager = (ManagerType*)data; + + manager->forMatchingSignature>([id] (std::size_t inner_id, void* data, C0 *inner_c) { + const C0 *const outer_c = (C0*)data; + EXPECT_EQ(id, outer_c->x); + EXPECT_EQ(inner_id, inner_c->x); + if (id == inner_id) { + EXPECT_EQ(outer_c->x, inner_c->x); + EXPECT_EQ(outer_c->y, inner_c->y); + } else { + EXPECT_NE(outer_c->x, inner_c->x); + EXPECT_NE(outer_c->y, inner_c->y); + } + }, c, false); + }, &manager, true); + + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); +}