diff --git a/.gitignore b/.gitignore index a087dc2..46ea9c4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ build*/ doxygen_html/ compile_commands.json tags -.clangd/ +.cache/ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0244e2b..186d0da 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,7 +14,9 @@ set(EntityComponentSystem_HEADERS EC/Meta/Meta.hpp EC/Bitset.hpp EC/Manager.hpp - EC/EC.hpp) + EC/EC.hpp + EC/ThreadPool.hpp +) set(WillFailCompile_SOURCES test/WillFailCompileTest.cpp) @@ -48,7 +50,9 @@ if(GTEST_FOUND) set(UnitTests_SOURCES test/MetaTest.cpp test/ECTest.cpp - test/Main.cpp) + test/ThreadPoolTest.cpp + test/Main.cpp + ) add_executable(UnitTests ${UnitTests_SOURCES}) target_link_libraries(UnitTests EntityComponentSystem ${GTEST_LIBRARIES}) diff --git a/src/EC/ThreadPool.hpp b/src/EC/ThreadPool.hpp new file mode 100644 index 0000000..30969e1 --- /dev/null +++ b/src/EC/ThreadPool.hpp @@ -0,0 +1,123 @@ +#ifndef EC_META_SYSTEM_THREADPOOL_HPP +#define EC_META_SYSTEM_THREADPOOL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace EC { + +namespace Internal { + using TPFnType = std::function; + using TPTupleType = std::tuple; + using TPQueueType = std::queue; +} // namespace Internal + +template +class ThreadPool; + +template +class ThreadPool= 2)>::type> { +public: + using THREADCOUNT = std::integral_constant; + + ThreadPool() : waitCount(0) { + isAlive.store(true); + for(unsigned int i = 0; i < SIZE; ++i) { + threads.emplace_back([] (std::atomic_bool *isAlive, + std::condition_variable *cv, + std::mutex *cvMutex, + Internal::TPQueueType *fnQueue, + std::mutex *queueMutex, + int *waitCount, + std::mutex *waitCountMutex) { + bool hasFn = false; + Internal::TPTupleType fnTuple; + while(isAlive->load()) { + hasFn = false; + { + std::lock_guard lock(*queueMutex); + if(!fnQueue->empty()) { + fnTuple = fnQueue->front(); + fnQueue->pop(); + hasFn = true; + } + } + if(hasFn) { + std::get<0>(fnTuple)(std::get<1>(fnTuple)); + continue; + } + + { + std::lock_guard lock(*waitCountMutex); + *waitCount += 1; + } + { + std::unique_lock lock(*cvMutex); + cv->wait(lock); + } + { + std::lock_guard lock(*waitCountMutex); + *waitCount -= 1; + } + } + }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, &waitCountMutex); + } + } + + ~ThreadPool() { + isAlive.store(false); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + cv.notify_all(); + for(auto &thread : threads) { + thread.join(); + } + } + + void queueFn(std::function&& fn, void *ud = nullptr) { + std::lock_guard lock(queueMutex); + fnQueue.emplace(std::make_tuple(fn, ud)); + } + + void wakeThreads(bool wakeAll = true) { + unsigned int counter = 0; + if(wakeAll) { + cv.notify_all(); + } else { + cv.notify_one(); + } + while(isAllThreadsWaiting() && counter++ < 10000) {} + } + + int getWaitCount() { + std::lock_guard lock(waitCountMutex); + return waitCount; + } + + bool isAllThreadsWaiting() { + std::lock_guard lock(waitCountMutex); + return waitCount == THREADCOUNT::value; + } + +private: + std::vector threads; + std::atomic_bool isAlive; + std::condition_variable cv; + std::mutex cvMutex; + Internal::TPQueueType fnQueue; + std::mutex queueMutex; + int waitCount; + std::mutex waitCountMutex; + +}; + +} // namespace EC + +#endif diff --git a/src/test/ThreadPoolTest.cpp b/src/test/ThreadPoolTest.cpp new file mode 100644 index 0000000..ef7a695 --- /dev/null +++ b/src/test/ThreadPoolTest.cpp @@ -0,0 +1,41 @@ +#include + +#include + +//using OneThreadPool = EC::ThreadPool<1>; +using ThreeThreadPool = EC::ThreadPool<3>; + +//TEST(ECThreadPool, CannotCompile) { +// OneThreadPool tp; +//} + +TEST(ECThreadPool, Simple) { + ThreeThreadPool p{}; + std::atomic_int data; + data.store(0); + const auto fn = [](void *ud) { + auto *data = static_cast(ud); + data->fetch_add(1); + }; + + p.queueFn(fn, &data); + + p.wakeThreads(); + + while(!p.isAllThreadsWaiting()) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + ASSERT_EQ(data.load(), 1); + + for(unsigned int i = 0; i < 10; ++i) { + p.queueFn(fn, &data); + } + p.wakeThreads(); + + while(!p.isAllThreadsWaiting()) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + ASSERT_EQ(data.load(), 11); +}