diff --git a/.gitignore b/.gitignore index a087dc2..46ea9c4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ build*/ doxygen_html/ compile_commands.json tags -.clangd/ +.cache/ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 84ecbda..18d02e1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,13 +14,17 @@ set(EntityComponentSystem_HEADERS EC/Meta/Meta.hpp EC/Bitset.hpp EC/Manager.hpp - EC/EC.hpp) + EC/EC.hpp + EC/ThreadPool.hpp +) set(WillFailCompile_SOURCES test/WillFailCompileTest.cpp) +find_package(Threads REQUIRED) + add_library(EntityComponentSystem INTERFACE) -target_link_libraries(EntityComponentSystem INTERFACE pthread) +target_link_libraries(EntityComponentSystem INTERFACE ${CMAKE_THREAD_LIBS_INIT}) target_include_directories(EntityComponentSystem INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) @@ -48,7 +52,9 @@ if(GTEST_FOUND) set(UnitTests_SOURCES test/MetaTest.cpp test/ECTest.cpp - test/Main.cpp) + test/ThreadPoolTest.cpp + test/Main.cpp + ) add_executable(UnitTests ${UnitTests_SOURCES}) target_link_libraries(UnitTests EntityComponentSystem ${GTEST_LIBRARIES}) diff --git a/src/EC/Manager.hpp b/src/EC/Manager.hpp index 3ae090c..22d757f 100644 --- a/src/EC/Manager.hpp +++ b/src/EC/Manager.hpp @@ -35,6 +35,8 @@ #include "Meta/ForEachDoubleTuple.hpp" #include "Bitset.hpp" +#include "ThreadPool.hpp" + namespace EC { /*! @@ -45,12 +47,18 @@ namespace EC Note that all components must have a default constructor. + An optional third template parameter may be given, which is the size of + the number of threads in the internal ThreadPool, and should be at + least 2. If ThreadCount is 1 or less, then the ThreadPool will not be + created and it will never be used, even if the "true" parameter is given + for functions that enable its usage. + Example: \code{.cpp} EC::Manager, TypeList> manager; \endcode */ - template + template struct Manager { public: @@ -82,6 +90,8 @@ namespace EC std::size_t currentSize = 0; std::unordered_set deletedSet; + std::unique_ptr > threadPool; + public: /*! \brief Initializes the manager with a default capacity. @@ -92,6 +102,9 @@ namespace EC Manager() { resize(EC_INIT_ENTITIES_SIZE); + if(ThreadCount >= 2) { + threadPool = std::make_unique >(); + } } private: @@ -463,16 +476,16 @@ namespace EC } } - /*! - \brief Removes the given Tag from the given Entity. + /*! + \brief Removes the given Tag from the given Entity. - If the Entity does not have the Tag given, nothing will change. + If the Entity does not have the Tag given, nothing will change. - Example: - \code{.cpp} - manager.removeTag(entityID); - \endcode - */ + Example: + \code{.cpp} + manager.removeTag(entityID); + \endcode + */ template void removeTag(const std::size_t& entityID) { @@ -518,11 +531,11 @@ namespace EC const std::size_t& entityID, CType& ctype, Function&& function, - void* context = nullptr) + void* userData = nullptr) { function( entityID, - context, + userData, ctype.template getEntityData(entityID)... ); } @@ -532,11 +545,11 @@ namespace EC const std::size_t& entityID, CType& ctype, Function* function, - void* context = nullptr) + void* userData = nullptr) { (*function)( entityID, - context, + userData, ctype.template getEntityData(entityID)... ); } @@ -546,13 +559,13 @@ namespace EC const std::size_t& entityID, CType& ctype, Function&& function, - void* context = nullptr) const + void* userData = nullptr) const { ForMatchingSignatureHelper::call( entityID, ctype, std::forward(function), - context); + userData); } template @@ -560,13 +573,13 @@ namespace EC const std::size_t& entityID, CType& ctype, Function* function, - void* context = nullptr) const + void* userData = nullptr) const { ForMatchingSignatureHelper::callPtr( entityID, ctype, function, - context); + userData); } }; @@ -583,14 +596,15 @@ namespace EC The second parameter is default nullptr and will be passed to the function call as the second parameter as a means of providing - context (useful when the function is not a lambda function). The - third parameter is default 1 (not multi-threaded). If the third - parameter threadCount is set to a value greater than 1, then - threadCount threads will be used. Note that multi-threading is - based on splitting the task of calling the function across sections - of entities. Thus if there are only a small amount of entities in - the manager, then using multiple threads may not have as great of a - speed-up. + context (useful when the function is not a lambda function). + + The third parameter is default false (not multi-threaded). + Otherwise, if true, then the thread pool will be used to call the + given function in parallel across all entities. Note that + multi-threading is based on splitting the task of calling the + function across sections of entities. Thus if there are only a small + amount of entities in the manager, then using multiple threads may + not have as great of a speed-up. Example: \code{.cpp} @@ -611,8 +625,8 @@ namespace EC */ template void forMatchingSignature(Function&& function, - void* context = nullptr, - std::size_t threadCount = 1) + void* userData = nullptr, + const bool useThreadPool = false) { using SignatureComponents = typename EC::Meta::Matching::type; @@ -623,7 +637,7 @@ namespace EC BitsetType signatureBitset = BitsetType::template generateBitset(); - if(threadCount <= 1) + if(!useThreadPool || !threadPool) { for(std::size_t i = 0; i < currentSize; ++i) { @@ -636,53 +650,54 @@ namespace EC == signatureBitset) { Helper::call(i, *this, - std::forward(function), context); + std::forward(function), userData); } } } else { - std::vector threads(threadCount); - std::size_t s = currentSize / threadCount; - for(std::size_t i = 0; i < threadCount; ++i) - { + using TPFnDataType = std::tuple, void*>; + std::array fnDataAr; + + std::size_t s = currentSize / ThreadCount; + for(std::size_t i = 0; i < ThreadCount; ++i) { std::size_t begin = s * i; std::size_t end; - if(i == threadCount - 1) - { + if(i == ThreadCount - 1) { end = currentSize; - } - else - { + } else { end = s * (i + 1); } - threads[i] = std::thread( - [this, &function, &signatureBitset, &context] - (std::size_t begin, - std::size_t end) { - for(std::size_t i = begin; i < end; ++i) - { - if(!std::get(this->entities[i])) - { + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = &entities; + std::get<2>(fnDataAr.at(i)) = &signatureBitset; + std::get<3>(fnDataAr.at(i)) = {begin, end}; + std::get<4>(fnDataAr.at(i)) = userData; + threadPool->queueFn([&function] (void *ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<3>(*data).at(0); + i < std::get<3>(*data).at(1); + ++i) { + if(!std::get<0>(*data)->isAlive(i)) { continue; } - if((signatureBitset - & std::get(entities[i])) - == signatureBitset) - { - Helper::call(i, *this, - std::forward(function), context); + if((*std::get<2>(*data) + & std::get( + std::get<1>(*data)->at(i))) + == *std::get<2>(*data)) { + Helper::call(i, *std::get<0>(*data), std::forward(function), std::get<4>(*data)); } } - }, - begin, - end); - } - for(std::size_t i = 0; i < threadCount; ++i) - { - threads[i].join(); + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } @@ -698,14 +713,15 @@ namespace EC The second parameter is default nullptr and will be passed to the function call as the second parameter as a means of providing - context (useful when the function is not a lambda function). The - third parameter is default 1 (not multi-threaded). If the third - parameter threadCount is set to a value greater than 1, then - threadCount threads will be used. Note that multi-threading is based - on splitting the task of calling the function across sections of - entities. Thus if there are only a small amount of entities in the - manager, then using multiple threads may not have as great of a - speed-up. + context (useful when the function is not a lambda function). + + The third parameter is default false (not multi-threaded). + Otherwise, if true, then the thread pool will be used to call the + given function in parallel across all entities. Note that + multi-threading is based on splitting the task of calling the + function across sections of entities. Thus if there are only a small + amount of entities in the manager, then using multiple threads may + not have as great of a speed-up. Example: \code{.cpp} @@ -728,8 +744,8 @@ namespace EC */ template void forMatchingSignaturePtr(Function* function, - void* context = nullptr, - std::size_t threadCount = 1) + void* userData = nullptr, + const bool useThreadPool = false) { using SignatureComponents = typename EC::Meta::Matching::type; @@ -740,7 +756,7 @@ namespace EC BitsetType signatureBitset = BitsetType::template generateBitset(); - if(threadCount <= 1) + if(!useThreadPool || !threadPool) { for(std::size_t i = 0; i < currentSize; ++i) { @@ -752,52 +768,55 @@ namespace EC if((signatureBitset & std::get(entities[i])) == signatureBitset) { - Helper::callPtr(i, *this, function, context); + Helper::callPtr(i, *this, function, userData); } } } else { - std::vector threads(threadCount); - std::size_t s = currentSize / threadCount; - for(std::size_t i = 0; i < threadCount; ++i) - { + using TPFnDataType = std::tuple, void*, Function*>; + std::array fnDataAr; + + std::size_t s = currentSize / ThreadCount; + for(std::size_t i = 0; i < ThreadCount; ++i) { std::size_t begin = s * i; std::size_t end; - if(i == threadCount - 1) - { + if(i == ThreadCount - 1) { end = currentSize; - } - else - { + } else { end = s * (i + 1); } - threads[i] = std::thread( - [this, &function, &signatureBitset, &context] - (std::size_t begin, - std::size_t end) { - for(std::size_t i = begin; i < end; ++i) - { - if(!std::get(this->entities[i])) - { + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = &entities; + std::get<2>(fnDataAr.at(i)) = &signatureBitset; + std::get<3>(fnDataAr.at(i)) = {begin, end}; + std::get<4>(fnDataAr.at(i)) = userData; + std::get<5>(fnDataAr.at(i)) = function; + threadPool->queueFn([] (void *ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<3>(*data).at(0); + i < std::get<3>(*data).at(1); + ++i) { + if(!std::get<0>(*data)->isAlive(i)) { continue; } - if((signatureBitset - & std::get(entities[i])) - == signatureBitset) - { - Helper::callPtr(i, *this, function, context); + if((*std::get<2>(*data) + & std::get( + std::get<1>(*data)->at(i))) + == *std::get<2>(*data)) { + Helper::callPtr(i, *std::get<0>(*data), std::get<5>(*data), std::get<4>(*data)); } } - }, - begin, - end); - } - for(std::size_t i = 0; i < threadCount; ++i) - { - threads[i].join(); + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } @@ -861,7 +880,7 @@ namespace EC template std::size_t addForMatchingFunction( Function&& function, - void* context = nullptr) + void* userData = nullptr) { while(forMatchingFunctions.find(functionIndex) != forMatchingFunctions.end()) @@ -884,55 +903,64 @@ namespace EC functionIndex, std::make_tuple( signatureBitset, - context, + userData, [function, helper, this] - (std::size_t threadCount, + (const bool useThreadPool, std::vector matching, - void* context) + void* userData) { - if(threadCount <= 1 || matching.size() < threadCount) + if(!useThreadPool || !threadPool) { for(auto eid : matching) { if(isAlive(eid)) { helper.callInstancePtr( - eid, *this, &function, context); + eid, *this, &function, userData); } } } else { - std::vector threads(threadCount); - std::size_t s = matching.size() / threadCount; - for(std::size_t i = 0; i < threadCount; ++ i) - { + using TPFnDataType = std::tuple, void*, const std::vector*>; + std::array fnDataAr; + + std::size_t s = matching.size() / ThreadCount; + for(std::size_t i = 0; i < ThreadCount; ++i) { std::size_t begin = s * i; std::size_t end; - if(i == threadCount - 1) { + if(i == ThreadCount - 1) { end = matching.size(); } else { end = s * (i + 1); } - threads[i] = std::thread( - [this, &function, &helper, &context, &matching] - (std::size_t begin, - std::size_t end) { - for(std::size_t j = begin; j < end; ++j) - { - if(isAlive(matching[j])) - { + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = &entities; + std::get<2>(fnDataAr.at(i)) = {begin, end}; + std::get<3>(fnDataAr.at(i)) = userData; + std::get<4>(fnDataAr.at(i)) = &matching; + threadPool->queueFn([function, helper] (void* ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<2>(*data).at(0); + i < std::get<2>(*data).at(1); + ++i) { + if(std::get<0>(*data)->isAlive(std::get<4>(*data)->at(i))) { helper.callInstancePtr( - matching[j], *this, &function, context); + std::get<4>(*data)->at(i), + *std::get<0>(*data), + &function, + std::get<3>(*data)); } } - }, - begin, end); - } - for(std::size_t i = 0; i < threadCount; ++i) - { - threads[i].join(); + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } }))); @@ -941,11 +969,11 @@ namespace EC private: std::vector > getMatchingEntities( - std::vector bitsets, std::size_t threadCount = 1) + std::vector bitsets, const bool useThreadPool = false) { std::vector > matchingV(bitsets.size()); - if(threadCount <= 1 || currentSize <= threadCount) + if(!useThreadPool || !threadPool) { for(std::size_t i = 0; i < currentSize; ++i) { @@ -965,64 +993,53 @@ namespace EC } else { - std::vector threads(threadCount); - std::size_t s = currentSize / threadCount; - std::mutex mutex; + using TPFnDataType = std::tuple, std::vector >*, const std::vector*, EntitiesType*, std::mutex*>; + std::array fnDataAr; - if(s == 0) { - for(std::size_t i = 0; i < currentSize; ++i) { - threads[i] = std::thread( - [this, &matchingV, &bitsets, &mutex] (std::size_t idx) { - if(!isAlive(idx)) { - return; + std::size_t s = currentSize / ThreadCount; + std::mutex mutex; + for(std::size_t i = 0; i < ThreadCount; ++i) { + std::size_t begin = s * i; + std::size_t end; + if(i == ThreadCount - 1) { + end = currentSize; + } else { + end = s * (i + 1); + } + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = {begin, end}; + std::get<2>(fnDataAr.at(i)) = &matchingV; + std::get<3>(fnDataAr.at(i)) = &bitsets; + std::get<4>(fnDataAr.at(i)) = &entities; + std::get<5>(fnDataAr.at(i)) = &mutex; + threadPool->queueFn([] (void *ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<1>(*data).at(0); + i < std::get<1>(*data).at(1); + ++i) { + if(!std::get<0>(*data)->isAlive(i)) { + continue; } - for(std::size_t k = 0; k < bitsets.size(); ++k) - { - if(((*bitsets[k]) & - std::get(entities[idx])) - == (*bitsets[k])) - { - std::lock_guard guard(mutex); - matchingV[k].push_back(idx); + for(std::size_t j = 0; + j < std::get<3>(*data)->size(); + ++j) { + if(((*std::get<3>(*data)->at(j)) + & std::get(std::get<4>(*data)->at(i))) + == (*std::get<3>(*data)->at(j))) { + std::lock_guard lock(*std::get<5>(*data)); + std::get<2>(*data)->at(j).push_back(i); } } - }, i); - } - for(std::size_t i = 0; i < currentSize; ++i) { - threads[i].join(); - } - } else { - for (std::size_t i = 0; i < threadCount; ++i) { - std::size_t begin = s * i; - std::size_t end; - if (i == threadCount - 1) { - end = currentSize; - } else { - end = s * (i + 1); } - - threads[i] = std::thread( - [this, &matchingV, &bitsets, &mutex] - (std::size_t begin, std::size_t end) { - for (std::size_t j = begin; j < end; ++j) { - if (!isAlive(j)) { - continue; - } - for (std::size_t k = 0; k < bitsets.size(); ++k) { - if (((*bitsets[k]) & - std::get(entities[j])) - == (*bitsets[k])) { - std::lock_guard guard(mutex); - matchingV[k].push_back(j); - } - } - } - }, begin, end); - } - for (std::size_t i = 0; i < threadCount; ++i) { - threads[i].join(); - } + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } return matchingV; @@ -1033,12 +1050,13 @@ namespace EC /*! \brief Call all stored functions. - The first (and only) parameter can be optionally used to specify the - number of threads to use when calling the functions. Otherwise, this - function is by default not multi-threaded. - Note that multi-threading is based on splitting the task of calling - the functions across sections of entities. Thus if there are only - a small amount of entities in the manager, then using multiple + The first (and only) parameter can be optionally used to enable the + use of the internal ThreadPool to call all stored functions in + parallel. Using the value false (which is the default) will not use + the ThreadPool and run all stored functions sequentially on the main + thread. Note that multi-threading is based on splitting the task of + calling the functions across sections of entities. Thus if there are + only a small amount of entities in the manager, then using multiple threads may not have as great of a speed-up. Example: @@ -1060,7 +1078,7 @@ namespace EC manager.clearForMatchingFunctions(); \endcode */ - void callForMatchingFunctions(std::size_t threadCount = 1) + void callForMatchingFunctions(const bool useThreadPool = false) { std::vector bitsets; for(auto iter = forMatchingFunctions.begin(); @@ -1071,7 +1089,7 @@ namespace EC } std::vector > matching = - getMatchingEntities(bitsets, threadCount); + getMatchingEntities(bitsets, useThreadPool); std::size_t i = 0; for(auto iter = forMatchingFunctions.begin(); @@ -1079,20 +1097,21 @@ namespace EC ++iter) { std::get<2>(iter->second)( - threadCount, matching[i++], std::get<1>(iter->second)); + useThreadPool, matching[i++], std::get<1>(iter->second)); } } /*! \brief Call a specific stored function. - A second parameter can be optionally used to specify the number - of threads to use when calling the function. Otherwise, this - function is by default not multi-threaded. - Note that multi-threading is based on splitting the task of calling - the function across sections of entities. Thus if there are only - a small amount of entities in the manager, then using multiple - threads may not have as great of a speed-up. + The second parameter can be optionally used to enable the use of the + internal ThreadPool to call the stored function in parallel. Using + the value false (which is the default) will not use the ThreadPool + and run the stored function sequentially on the main thread. Note + that multi-threading is based on splitting the task of calling the + functions across sections of entities. Thus if there are only a + small amount of entities in the manager, then using multiple threads + may not have as great of a speed-up. Example: \code{.cpp} @@ -1112,7 +1131,7 @@ namespace EC \return False if a function with the given id does not exist. */ bool callForMatchingFunction(std::size_t id, - std::size_t threadCount = 1) + const bool useThreadPool = false) { auto iter = forMatchingFunctions.find(id); if(iter == forMatchingFunctions.end()) @@ -1121,9 +1140,9 @@ namespace EC } std::vector > matching = getMatchingEntities(std::vector{ - &std::get(iter->second)}, threadCount); + &std::get(iter->second)}, useThreadPool); std::get<2>(iter->second)( - threadCount, matching[0], std::get<1>(iter->second)); + useThreadPool, matching[0], std::get<1>(iter->second)); return true; } @@ -1265,12 +1284,12 @@ namespace EC \return True if id is valid and context was updated */ - bool changeForMatchingFunctionContext(std::size_t id, void* context) + bool changeForMatchingFunctionContext(std::size_t id, void* userData) { auto f = forMatchingFunctions.find(id); if(f != forMatchingFunctions.end()) { - std::get<1>(f->second) = context; + std::get<1>(f->second) = userData; return true; } return false; @@ -1300,6 +1319,14 @@ namespace EC The second parameter (default nullptr) will be provided to every function call as a void* (context). + The third parameter is default false (not multi-threaded). + Otherwise, if true, then the thread pool will be used to call the + given function in parallel across all entities. Note that + multi-threading is based on splitting the task of calling the + function across sections of entities. Thus if there are only a small + amount of entities in the manager, then using multiple threads may + not have as great of a speed-up. + This function was created for the use case where there are many entities in the system which can cause multiple calls to forMatchingSignature to be slow due to the overhead of iterating @@ -1319,11 +1346,11 @@ namespace EC template void forMatchingSignatures( FTuple fTuple, - void* context = nullptr, - const std::size_t threadCount = 1) + void* userData = nullptr, + const bool useThreadPool = false) { - std::vector > multiMatchingEntities( - SigList::size); + std::vector > + multiMatchingEntities(SigList::size); BitsetType signatureBitsets[SigList::size]; // generate bitsets for each signature @@ -1335,7 +1362,7 @@ namespace EC }); // find and store entities matching signatures - if(threadCount <= 1) + if(!useThreadPool || !threadPool) { for(std::size_t eid = 0; eid < currentSize; ++eid) { @@ -1356,56 +1383,57 @@ namespace EC } else { - std::vector threads(threadCount); - std::mutex mutexes[SigList::size]; - std::size_t s = currentSize / threadCount; - for(std::size_t i = 0; i < threadCount; ++i) - { + using TPFnDataType = std::tuple, std::vector >*, BitsetType*, std::mutex*>; + std::array fnDataAr; + + std::mutex mutex; + std::size_t s = currentSize / ThreadCount; + for(std::size_t i = 0; i < ThreadCount; ++i) { std::size_t begin = s * i; std::size_t end; - if(i == threadCount - 1) - { + if(i == ThreadCount - 1) { end = currentSize; - } - else - { + } else { end = s * (i + 1); } - threads[i] = std::thread( - [this, &mutexes, &multiMatchingEntities, &signatureBitsets] - (std::size_t begin, std::size_t end) - { - for(std::size_t j = begin; j < end; ++j) - { - if(!isAlive(j)) - { + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = {begin, end}; + std::get<2>(fnDataAr.at(i)) = &multiMatchingEntities; + std::get<3>(fnDataAr.at(i)) = signatureBitsets; + std::get<4>(fnDataAr.at(i)) = &mutex; + + threadPool->queueFn([] (void *ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<1>(*data).at(0); + i < std::get<1>(*data).at(1); + ++i) { + if(!std::get<0>(*data)->isAlive(i)) { continue; } - for(std::size_t k = 0; k < SigList::size; ++k) - { - if((signatureBitsets[k] - & std::get(entities[j])) - == signatureBitsets[k]) - { - std::lock_guard guard( - mutexes[k]); - multiMatchingEntities[k].push_back(j); + for(std::size_t j = 0; j < SigList::size; ++j) { + if((std::get<3>(*data)[j] & std::get(std::get<0>(*data)->entities[i])) + == std::get<3>(*data)[j]) { + std::lock_guard lock(*std::get<4>(*data)); + std::get<2>(*data)->at(j).push_back(i); } } } - }, begin, end); - } - for(std::size_t i = 0; i < threadCount; ++i) - { - threads[i].join(); + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } // call functions on matching entities EC::Meta::forEachDoubleTuple( EC::Meta::Morph >{}, fTuple, - [this, &multiMatchingEntities, &threadCount, &context] + [this, &multiMatchingEntities, useThreadPool, &userData] (auto sig, auto func, auto index) { using SignatureComponents = @@ -1415,56 +1443,52 @@ namespace EC EC::Meta::Morph< SignatureComponents, ForMatchingSignatureHelper<> >; - if(threadCount <= 1) - { - for(const auto& id : multiMatchingEntities[index]) - { - if(isAlive(id)) - { - Helper::call(id, *this, func, context); + if(!useThreadPool || !threadPool) { + for(const auto& id : multiMatchingEntities[index]) { + if(isAlive(id)) { + Helper::call(id, *this, func, userData); } } - } - else - { - std::vector threads(threadCount); + } else { + using TPFnType = std::tuple, std::vector > *, std::size_t>; + std::array fnDataAr; std::size_t s = multiMatchingEntities[index].size() - / threadCount; - for(std::size_t i = 0; i < threadCount; ++i) - { + / ThreadCount; + for(unsigned int i = 0; i < ThreadCount; ++i) { std::size_t begin = s * i; std::size_t end; - if(i == threadCount - 1) - { + if(i == ThreadCount - 1) { end = multiMatchingEntities[index].size(); - } - else - { + } else { end = s * (i + 1); } - threads[i] = std::thread( - [this, &multiMatchingEntities, &index, &func, - &context] - (std::size_t begin, std::size_t end) - { - for(std::size_t j = begin; j < end; - ++j) - { - if(isAlive(multiMatchingEntities[index][j])) - { + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = userData; + std::get<2>(fnDataAr.at(i)) = {begin, end}; + std::get<3>(fnDataAr.at(i)) = &multiMatchingEntities; + std::get<4>(fnDataAr.at(i)) = index; + threadPool->queueFn([&func] (void *ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<2>(*data).at(0); + i < std::get<2>(*data).at(1); + ++i) { + if(std::get<0>(*data)->isAlive(std::get<3>(*data)->at(std::get<4>(*data)).at(i))) { Helper::call( - multiMatchingEntities[index][j], - *this, + std::get<3>(*data)->at(std::get<4>(*data)).at(i), + *std::get<0>(*data), func, - context); + std::get<1>(*data)); } } - }, begin, end); - } - for(std::size_t i = 0; i < threadCount; ++i) - { - threads[i].join(); + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } ); @@ -1497,6 +1521,14 @@ namespace EC The second parameter (default nullptr) will be provided to every function call as a void* (context). + The third parameter is default false (not multi-threaded). + Otherwise, if true, then the thread pool will be used to call the + given function in parallel across all entities. Note that + multi-threading is based on splitting the task of calling the + function across sections of entities. Thus if there are only a small + amount of entities in the manager, then using multiple threads may + not have as great of a speed-up. + This function was created for the use case where there are many entities in the system which can cause multiple calls to forMatchingSignature to be slow due to the overhead of iterating @@ -1515,8 +1547,8 @@ namespace EC */ template void forMatchingSignaturesPtr(FTuple fTuple, - void* context = nullptr, - std::size_t threadCount = 1) + void* userData = nullptr, + const bool useThreadPool = false) { std::vector > multiMatchingEntities( SigList::size); @@ -1531,7 +1563,7 @@ namespace EC }); // find and store entities matching signatures - if(threadCount <= 1) + if(!useThreadPool || !threadPool) { for(std::size_t eid = 0; eid < currentSize; ++eid) { @@ -1552,56 +1584,57 @@ namespace EC } else { - std::vector threads(threadCount); - std::mutex mutexes[SigList::size]; - std::size_t s = currentSize / threadCount; - for(std::size_t i = 0; i < threadCount; ++i) - { + using TPFnDataType = std::tuple, std::vector >*, BitsetType*, std::mutex*>; + std::array fnDataAr; + + std::mutex mutex; + std::size_t s = currentSize / ThreadCount; + for(std::size_t i = 0; i < ThreadCount; ++i) { std::size_t begin = s * i; std::size_t end; - if(i == threadCount - 1) - { + if(i == ThreadCount - 1) { end = currentSize; - } - else - { + } else { end = s * (i + 1); } - threads[i] = std::thread( - [this, &mutexes, &multiMatchingEntities, &signatureBitsets] - (std::size_t begin, std::size_t end) - { - for(std::size_t j = begin; j < end; ++j) - { - if(!isAlive(j)) - { + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = {begin, end}; + std::get<2>(fnDataAr.at(i)) = &multiMatchingEntities; + std::get<3>(fnDataAr.at(i)) = signatureBitsets; + std::get<4>(fnDataAr.at(i)) = &mutex; + + threadPool->queueFn([] (void *ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<1>(*data).at(0); + i < std::get<1>(*data).at(1); + ++i) { + if(!std::get<0>(*data)->isAlive(i)) { continue; } - for(std::size_t k = 0; k < SigList::size; ++k) - { - if((signatureBitsets[k] - & std::get(entities[j])) - == signatureBitsets[k]) - { - std::lock_guard guard( - mutexes[k]); - multiMatchingEntities[k].push_back(j); + for(std::size_t j = 0; j < SigList::size; ++j) { + if((std::get<3>(*data)[j] & std::get(std::get<0>(*data)->entities[i])) + == std::get<3>(*data)[j]) { + std::lock_guard lock(*std::get<4>(*data)); + std::get<2>(*data)->at(j).push_back(i); } } } - }, begin, end); - } - for(std::size_t i = 0; i < threadCount; ++i) - { - threads[i].join(); + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } // call functions on matching entities EC::Meta::forEachDoubleTuple( EC::Meta::Morph >{}, fTuple, - [this, &multiMatchingEntities, &threadCount, &context] + [this, &multiMatchingEntities, useThreadPool, &userData] (auto sig, auto func, auto index) { using SignatureComponents = @@ -1611,56 +1644,57 @@ namespace EC EC::Meta::Morph< SignatureComponents, ForMatchingSignatureHelper<> >; - if(threadCount <= 1) + if(!useThreadPool || !threadPool) { for(const auto& id : multiMatchingEntities[index]) { if(isAlive(id)) { - Helper::callPtr(id, *this, func, context); + Helper::callPtr(id, *this, func, userData); } } } else { - std::vector threads(threadCount); + using TPFnType = std::tuple, std::vector > *, std::size_t>; + std::array fnDataAr; std::size_t s = multiMatchingEntities[index].size() - / threadCount; - for(std::size_t i = 0; i < threadCount; ++i) - { + / ThreadCount; + for(unsigned int i = 0; i < ThreadCount; ++i) { std::size_t begin = s * i; std::size_t end; - if(i == threadCount - 1) - { + if(i == ThreadCount - 1) { end = multiMatchingEntities[index].size(); - } - else - { + } else { end = s * (i + 1); } - threads[i] = std::thread( - [this, &multiMatchingEntities, &index, &func, - &context] - (std::size_t begin, std::size_t end) - { - for(std::size_t j = begin; j < end; - ++j) - { - if(isAlive(multiMatchingEntities[index][j])) - { + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = userData; + std::get<2>(fnDataAr.at(i)) = {begin, end}; + std::get<3>(fnDataAr.at(i)) = &multiMatchingEntities; + std::get<4>(fnDataAr.at(i)) = index; + threadPool->queueFn([&func] (void *ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<2>(*data).at(0); + i < std::get<2>(*data).at(1); + ++i) { + if(std::get<0>(*data)->isAlive(std::get<3>(*data)->at(std::get<4>(*data)).at(i))) { Helper::callPtr( - multiMatchingEntities[index][j], - *this, + std::get<3>(*data)->at(std::get<4>(*data)).at(i), + *std::get<0>(*data), func, - context); + std::get<1>(*data)); } } - }, begin, end); - } - for(std::size_t i = 0; i < threadCount; ++i) - { - threads[i].join(); + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } ); @@ -1669,17 +1703,26 @@ namespace EC typedef void ForMatchingFn(std::size_t, Manager*, void*); /*! - * \brief A simple version of forMatchingSignature() - * - * This function behaves like forMatchingSignature(), but instead of - * providing a function with each requested component as a parameter, - * the function receives a pointer to the manager itself, with which to - * query component/tag data. + \brief A simple version of forMatchingSignature() + + This function behaves like forMatchingSignature(), but instead of + providing a function with each requested component as a parameter, + the function receives a pointer to the manager itself, with which to + query component/tag data. + + The third parameter can be optionally used to enable the use of the + internal ThreadPool to call the function in parallel. Using the + value false (which is the default) will not use the ThreadPool and + run the function sequentially on all entities on the main thread. + Note that multi-threading is based on splitting the task of calling + the functions across sections of entities. Thus if there are only a + small amount of entities in the manager, then using multiple threads + may not have as great of a speed-up. */ template - void forMatchingSimple(ForMatchingFn fn, void *userData = nullptr, std::size_t threadCount = 1) { + void forMatchingSimple(ForMatchingFn fn, void *userData = nullptr, const bool useThreadPool = false) { const BitsetType signatureBitset = BitsetType::template generateBitset(); - if(threadCount <= 1) { + if(!useThreadPool || !threadPool) { for(std::size_t i = 0; i < currentSize; ++i) { if(!std::get(entities[i])) { continue; @@ -1688,52 +1731,67 @@ namespace EC } } } else { - std::vector threads(threadCount); - const std::size_t s = currentSize / threadCount; - for(std::size_t i = 0; i < threadCount; ++i) { - const std::size_t begin = s * i; - const std::size_t end = - i == threadCount - 1 ? - currentSize : - s * (i + 1); - threads[i] = std::thread( - [this] (const std::size_t begin, - const std::size_t end, - const BitsetType signatureBitset, - ForMatchingFn fn, - void *userData) { - for(std::size_t i = begin; i < end; ++i) { - if(!std::get(entities[i])) { - continue; - } else if((signatureBitset & std::get(entities[i])) == signatureBitset) { - fn(i, this, userData); - } + using TPFnDataType = std::tuple, void*>; + std::array fnDataAr; + + std::size_t s = currentSize / ThreadCount; + for(std::size_t i = 0; i < ThreadCount; ++i) { + std::size_t begin = s * i; + std::size_t end; + if(i == ThreadCount - 1) { + end = currentSize; + } else { + end = s * (i + 1); + } + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = &entities; + std::get<2>(fnDataAr.at(i)) = &signatureBitset; + std::get<3>(fnDataAr.at(i)) = {begin, end}; + std::get<4>(fnDataAr.at(i)) = userData; + threadPool->queueFn([&fn] (void *ud) { + auto *data = static_cast(ud); + for(std::size_t i = std::get<3>(*data).at(0); + i < std::get<3>(*data).at(1); + ++i) { + if(!std::get<0>(*data)->isAlive(i)) { + continue; + } else if((*std::get<2>(*data) & std::get(std::get<1>(*data)->at(i))) == *std::get<2>(*data)) { + fn(i, std::get<0>(*data), std::get<4>(*data)); } - }, - begin, - end, - signatureBitset, - fn, - userData); - } - for(std::size_t i = 0; i < threadCount; ++i) { - threads[i].join(); + } + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } /*! - * \brief Similar to forMatchingSimple(), but with a collection of Component/Tag indices - * - * This function works like forMatchingSimple(), but instead of - * providing template types that filter out non-matching entities, an - * iterable of indices must be provided which correlate to matching - * Component/Tag indices. The function given must match the previously - * defined typedef of type ForMatchingFn. + \brief Similar to forMatchingSimple(), but with a collection of Component/Tag indices + + This function works like forMatchingSimple(), but instead of + providing template types that filter out non-matching entities, an + iterable of indices must be provided which correlate to matching + Component/Tag indices. The function given must match the previously + defined typedef of type ForMatchingFn. + + The fourth parameter can be optionally used to enable the use of the + internal ThreadPool to call the function in parallel. Using the + value false (which is the default) will not use the ThreadPool and + run the function sequentially on all entities on the main thread. + Note that multi-threading is based on splitting the task of calling + the functions across sections of entities. Thus if there are only a + small amount of entities in the manager, then using multiple threads + may not have as great of a speed-up. */ template - void forMatchingIterable(Iterable iterable, ForMatchingFn fn, void* userPtr = nullptr, std::size_t threadCount = 1) { - if(threadCount <= 1) { + void forMatchingIterable(Iterable iterable, ForMatchingFn fn, void* userData = nullptr, const bool useThreadPool = false) { + if(!useThreadPool || !threadPool) { bool isValid; for(std::size_t i = 0; i < currentSize; ++i) { if(!std::get(entities[i])) { @@ -1749,42 +1807,56 @@ namespace EC } if(!isValid) { continue; } - fn(i, this, userPtr); + fn(i, this, userData); } } else { - std::vector threads(threadCount); - std::size_t s = currentSize / threadCount; - for(std::size_t i = 0; i < threadCount; ++i) { + using TPFnDataType = std::tuple, void*>; + std::array fnDataAr; + + std::size_t s = currentSize / ThreadCount; + for(std::size_t i = 0; i < ThreadCount; ++i) { std::size_t begin = s * i; - std::size_t end = - i == threadCount - 1 ? - currentSize : - s * (i + 1); - threads[i] = std::thread( - [this, &fn, &iterable, userPtr] (std::size_t begin, std::size_t end) { - bool isValid; - for(std::size_t i = begin; i < end; ++i) { - if(!std::get(this->entities[i])) { - continue; - } - - isValid = true; - for(const auto& integralValue : iterable) { - if(!std::get(entities[i]).getCombinedBit(integralValue)) { - isValid = false; - break; - } - } - if(!isValid) { continue; } - - fn(i, this, userPtr); + std::size_t end; + if(i == ThreadCount - 1) { + end = currentSize; + } else { + end = s * (i + 1); + } + if(begin == end) { + continue; + } + std::get<0>(fnDataAr.at(i)) = this; + std::get<1>(fnDataAr.at(i)) = &entities; + std::get<2>(fnDataAr.at(i)) = &iterable; + std::get<3>(fnDataAr.at(i)) = {begin, end}; + std::get<4>(fnDataAr.at(i)) = userData; + threadPool->queueFn([&fn] (void *ud) { + auto *data = static_cast(ud); + bool isValid; + for(std::size_t i = std::get<3>(*data).at(0); + i < std::get<3>(*data).at(1); + ++i) { + if(!std::get<0>(*data)->isAlive(i)) { + continue; } - }, - begin, end); - } - for(std::size_t i = 0; i < threadCount; ++i) { - threads[i].join(); + isValid = true; + for(const auto& integralValue : *std::get<2>(*data)) { + if(!std::get(std::get<1>(*data)->at(i)).getCombinedBit(integralValue)) { + isValid = false; + break; + } + } + if(!isValid) { continue; } + + fn(i, std::get<0>(*data), std::get<4>(*data)); + + } + }, &fnDataAr.at(i)); } + threadPool->wakeThreads(); + do { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting()); } } }; diff --git a/src/EC/ThreadPool.hpp b/src/EC/ThreadPool.hpp new file mode 100644 index 0000000..3580b5e --- /dev/null +++ b/src/EC/ThreadPool.hpp @@ -0,0 +1,191 @@ +#ifndef EC_META_SYSTEM_THREADPOOL_HPP +#define EC_META_SYSTEM_THREADPOOL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace EC { + +namespace Internal { + using TPFnType = std::function; + using TPTupleType = std::tuple; + using TPQueueType = std::queue; +} // namespace Internal + +/*! + \brief Implementation of a Thread Pool. + + Note that if SIZE is less than 2, then ThreadPool will not create threads and + run queued functions on the calling thread. +*/ +template +class ThreadPool { +public: + using THREADCOUNT = std::integral_constant; + + ThreadPool() : waitCount(0) { + isAlive.store(true); + if(SIZE >= 2) { + for(unsigned int i = 0; i < SIZE; ++i) { + threads.emplace_back([] (std::atomic_bool *isAlive, + std::condition_variable *cv, + std::mutex *cvMutex, + Internal::TPQueueType *fnQueue, + std::mutex *queueMutex, + int *waitCount, + std::mutex *waitCountMutex) { + bool hasFn = false; + Internal::TPTupleType fnTuple; + while(isAlive->load()) { + hasFn = false; + { + std::lock_guard lock(*queueMutex); + if(!fnQueue->empty()) { + fnTuple = fnQueue->front(); + fnQueue->pop(); + hasFn = true; + } + } + if(hasFn) { + std::get<0>(fnTuple)(std::get<1>(fnTuple)); + continue; + } + + { + std::lock_guard lock(*waitCountMutex); + *waitCount += 1; + } + { + std::unique_lock lock(*cvMutex); + cv->wait(lock); + } + { + std::lock_guard lock(*waitCountMutex); + *waitCount -= 1; + } + } + }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, + &waitCountMutex); + } + } + } + + ~ThreadPool() { + if(SIZE >= 2) { + isAlive.store(false); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + cv.notify_all(); + for(auto &thread : threads) { + thread.join(); + } + } + } + + /*! + \brief Queues a function to be called (doesn't start calling yet). + + To run the queued functions, wakeThreads() must be called to wake the + waiting threads which will start pulling functions from the queue to be + called. + */ + void queueFn(std::function&& fn, void *ud = nullptr) { + std::lock_guard lock(queueMutex); + fnQueue.emplace(std::make_tuple(fn, ud)); + } + + /*! + \brief Wakes waiting threads to start running queued functions. + + If SIZE is less than 2, then this function call will block until all the + queued functions have been executed on the calling thread. + + If SIZE is 2 or greater, then this function will return immediately after + waking one or all threads, depending on the given boolean parameter. + */ + void wakeThreads(bool wakeAll = true) { + if(SIZE >= 2) { + // wake threads to pull functions from queue and run them + if(wakeAll) { + cv.notify_all(); + } else { + cv.notify_one(); + } + } else { + // pull functions from queue and run them on main thread + Internal::TPTupleType fnTuple; + bool hasFn; + do { + { + std::lock_guard lock(queueMutex); + if(!fnQueue.empty()) { + hasFn = true; + fnTuple = fnQueue.front(); + fnQueue.pop(); + } else { + hasFn = false; + } + } + if(hasFn) { + std::get<0>(fnTuple)(std::get<1>(fnTuple)); + } + } while(hasFn); + } + } + + /*! + \brief Gets the number of waiting threads. + + If all threads are waiting, this should equal ThreadCount. + + If SIZE is less than 2, then this will always return 0. + */ + int getWaitCount() { + std::lock_guard lock(waitCountMutex); + return waitCount; + } + + /*! + \brief Returns true if all threads are waiting. + + If SIZE is less than 2, then this will always return true. + */ + bool isAllThreadsWaiting() { + if(SIZE >= 2) { + std::lock_guard lock(waitCountMutex); + return waitCount == THREADCOUNT::value; + } else { + return true; + } + } + + /*! + \brief Returns true if the function queue is empty. + */ + bool isQueueEmpty() { + std::lock_guard lock(queueMutex); + return fnQueue.empty(); + } + +private: + std::vector threads; + std::atomic_bool isAlive; + std::condition_variable cv; + std::mutex cvMutex; + Internal::TPQueueType fnQueue; + std::mutex queueMutex; + int waitCount; + std::mutex waitCountMutex; + +}; + +} // namespace EC + +#endif diff --git a/src/test/ECTest.cpp b/src/test/ECTest.cpp index 24b8b11..75e65ae 100644 --- a/src/test/ECTest.cpp +++ b/src/test/ECTest.cpp @@ -464,7 +464,7 @@ TEST(EC, MultiThreaded) c->y = 2; }, nullptr, - 2 + true ); for(unsigned int i = 0; i < 17; ++i) @@ -490,7 +490,7 @@ TEST(EC, MultiThreaded) c->y = 4; }, nullptr, - 8 + true ); for(unsigned int i = 0; i < 3; ++i) @@ -516,7 +516,7 @@ TEST(EC, MultiThreaded) } ); - manager.callForMatchingFunctions(2); + manager.callForMatchingFunctions(true); for(unsigned int i = 0; i < 17; ++i) { @@ -531,7 +531,7 @@ TEST(EC, MultiThreaded) } ); - manager.callForMatchingFunction(f1, 4); + manager.callForMatchingFunction(f1, true); for(unsigned int i = 0; i < 17; ++i) { @@ -544,7 +544,7 @@ TEST(EC, MultiThreaded) manager.deleteEntity(i); } - manager.callForMatchingFunction(f0, 8); + manager.callForMatchingFunction(f0, true); for(unsigned int i = 0; i < 4; ++i) { @@ -710,7 +710,7 @@ TEST(EC, ForMatchingSignatures) c0->y = 2; }), nullptr, - 3 + true ); for(auto iter = cx.begin(); iter != cx.end(); ++iter) @@ -850,7 +850,7 @@ TEST(EC, forMatchingPtrs) &func0 ); manager.forMatchingSignaturePtr >( - &func1 + &func1, nullptr, true ); for(auto eid : e) @@ -1098,7 +1098,7 @@ TEST(EC, forMatchingSimple) { C0 *c0 = manager->getEntityData(id); c0->x += 10; c0->y += 10; - }, nullptr, 3); + }, nullptr, true); // verify { @@ -1296,7 +1296,7 @@ TEST(EC, forMatchingIterableFn) c->x += 100; c->y += 100; }; - manager.forMatchingIterable(iterable, fn, nullptr, 3); + manager.forMatchingIterable(iterable, fn, nullptr, true); } { @@ -1322,7 +1322,7 @@ TEST(EC, forMatchingIterableFn) c->x += 1000; c->y += 1000; }; - manager.forMatchingIterable(iterable, fn, nullptr, 3); + manager.forMatchingIterable(iterable, fn, nullptr, true); } { @@ -1365,8 +1365,35 @@ TEST(EC, MultiThreadedForMatching) { EXPECT_TRUE(manager.isAlive(first)); EXPECT_TRUE(manager.isAlive(second)); - manager.callForMatchingFunction(fnIdx, 2); + manager.callForMatchingFunction(fnIdx, true); EXPECT_TRUE(manager.isAlive(first)); EXPECT_FALSE(manager.isAlive(second)); } + +TEST(EC, ManagerWithLowThreadCount) { + EC::Manager manager; + + std::array entities; + for(auto &id : entities) { + id = manager.addEntity(); + manager.addComponent(id); + } + + for(const auto &id : entities) { + auto *component = manager.getEntityComponent(id); + EXPECT_EQ(component->x, 0); + EXPECT_EQ(component->y, 0); + } + + manager.forMatchingSignature >([] (std::size_t /*id*/, void* /*ud*/, C0 *c) { + c->x += 1; + c->y += 1; + }, nullptr, true); + + for(const auto &id : entities) { + auto *component = manager.getEntityComponent(id); + EXPECT_EQ(component->x, 1); + EXPECT_EQ(component->y, 1); + } +} diff --git a/src/test/ThreadPoolTest.cpp b/src/test/ThreadPoolTest.cpp new file mode 100644 index 0000000..31c8529 --- /dev/null +++ b/src/test/ThreadPoolTest.cpp @@ -0,0 +1,68 @@ +#include + +#include + +using OneThreadPool = EC::ThreadPool<1>; +using ThreeThreadPool = EC::ThreadPool<3>; + +TEST(ECThreadPool, CannotCompile) { + OneThreadPool p; + std::atomic_int data; + data.store(0); + const auto fn = [](void *ud) { + auto *data = static_cast(ud); + data->fetch_add(1); + }; + + p.queueFn(fn, &data); + + p.wakeThreads(); + + do { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while(!p.isQueueEmpty() && !p.isAllThreadsWaiting()); + + ASSERT_EQ(data.load(), 1); + + for(unsigned int i = 0; i < 10; ++i) { + p.queueFn(fn, &data); + } + p.wakeThreads(); + + do { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while(!p.isQueueEmpty() && !p.isAllThreadsWaiting()); + + ASSERT_EQ(data.load(), 11); +} + +TEST(ECThreadPool, Simple) { + ThreeThreadPool p; + std::atomic_int data; + data.store(0); + const auto fn = [](void *ud) { + auto *data = static_cast(ud); + data->fetch_add(1); + }; + + p.queueFn(fn, &data); + + p.wakeThreads(); + + do { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while(!p.isQueueEmpty() && !p.isAllThreadsWaiting()); + + ASSERT_EQ(data.load(), 1); + + for(unsigned int i = 0; i < 10; ++i) { + p.queueFn(fn, &data); + } + p.wakeThreads(); + + do { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } while(!p.isQueueEmpty() && !p.isAllThreadsWaiting()); + + ASSERT_EQ(data.load(), 11); +}