]> git.seodisparate.com - EntityComponentMetaSystem/commitdiff
Impl ThreadPool
authorStephen Seo <seo.disparate@gmail.com>
Mon, 6 Sep 2021 06:54:24 +0000 (15:54 +0900)
committerStephen Seo <seo.disparate@gmail.com>
Mon, 6 Sep 2021 06:54:24 +0000 (15:54 +0900)
.gitignore
src/CMakeLists.txt
src/EC/ThreadPool.hpp [new file with mode: 0644]
src/test/ThreadPoolTest.cpp [new file with mode: 0644]

index a087dc25db5d487b2c80d8d3b7d2988d70fbc2f5..46ea9c480d395be861f28b92dd36de62178705e3 100644 (file)
@@ -4,4 +4,4 @@ build*/
 doxygen_html/
 compile_commands.json
 tags
-.clangd/
+.cache/
index 0244e2b32f5426ac971e06b4926f563a32444b26..186d0daf0aefb5a9f86927f671d227543be5ed36 100644 (file)
@@ -14,7 +14,9 @@ set(EntityComponentSystem_HEADERS
     EC/Meta/Meta.hpp
     EC/Bitset.hpp
     EC/Manager.hpp
-    EC/EC.hpp)
+    EC/EC.hpp
+    EC/ThreadPool.hpp
+)
 
 set(WillFailCompile_SOURCES
     test/WillFailCompileTest.cpp)
@@ -48,7 +50,9 @@ if(GTEST_FOUND)
     set(UnitTests_SOURCES
         test/MetaTest.cpp
         test/ECTest.cpp
-        test/Main.cpp)
+        test/ThreadPoolTest.cpp
+        test/Main.cpp
+    )
 
     add_executable(UnitTests ${UnitTests_SOURCES})
     target_link_libraries(UnitTests EntityComponentSystem ${GTEST_LIBRARIES})
diff --git a/src/EC/ThreadPool.hpp b/src/EC/ThreadPool.hpp
new file mode 100644 (file)
index 0000000..30969e1
--- /dev/null
@@ -0,0 +1,123 @@
+#ifndef EC_META_SYSTEM_THREADPOOL_HPP
+#define EC_META_SYSTEM_THREADPOOL_HPP
+
+#include <type_traits>
+#include <vector>
+#include <thread>
+#include <atomic>
+#include <mutex>
+#include <condition_variable>
+#include <queue>
+#include <functional>
+#include <tuple>
+#include <chrono>
+
+namespace EC {
+
+namespace Internal {
+    using TPFnType = std::function<void(void*)>;
+    using TPTupleType = std::tuple<TPFnType, void*>;
+    using TPQueueType = std::queue<TPTupleType>;
+} // namespace Internal
+
+template <unsigned int SIZE, typename = void>
+class ThreadPool;
+
+template <unsigned int SIZE>
+class ThreadPool<SIZE, typename std::enable_if<(SIZE >= 2)>::type> {
+public:
+    using THREADCOUNT = std::integral_constant<int, SIZE>;
+
+    ThreadPool() : waitCount(0) {
+        isAlive.store(true);
+        for(unsigned int i = 0; i < SIZE; ++i) {
+            threads.emplace_back([] (std::atomic_bool *isAlive,
+                                     std::condition_variable *cv,
+                                     std::mutex *cvMutex,
+                                     Internal::TPQueueType *fnQueue,
+                                     std::mutex *queueMutex,
+                                     int *waitCount,
+                                     std::mutex *waitCountMutex) {
+                bool hasFn = false;
+                Internal::TPTupleType fnTuple;
+                while(isAlive->load()) {
+                    hasFn = false;
+                    {
+                        std::lock_guard<std::mutex> lock(*queueMutex);
+                        if(!fnQueue->empty()) {
+                            fnTuple = fnQueue->front();
+                            fnQueue->pop();
+                            hasFn = true;
+                        }
+                    }
+                    if(hasFn) {
+                        std::get<0>(fnTuple)(std::get<1>(fnTuple));
+                        continue;
+                    }
+
+                    {
+                        std::lock_guard<std::mutex> lock(*waitCountMutex);
+                        *waitCount += 1;
+                    }
+                    {
+                        std::unique_lock<std::mutex> lock(*cvMutex);
+                        cv->wait(lock);
+                    }
+                    {
+                        std::lock_guard<std::mutex> lock(*waitCountMutex);
+                        *waitCount -= 1;
+                    }
+                }
+            }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, &waitCountMutex);
+        }
+    }
+
+    ~ThreadPool() {
+        isAlive.store(false);
+        std::this_thread::sleep_for(std::chrono::milliseconds(200));
+        cv.notify_all();
+        for(auto &thread : threads) {
+            thread.join();
+        }
+    }
+
+    void queueFn(std::function<void(void*)>&& fn, void *ud = nullptr) {
+        std::lock_guard<std::mutex> lock(queueMutex);
+        fnQueue.emplace(std::make_tuple(fn, ud));
+    }
+
+    void wakeThreads(bool wakeAll = true) {
+        unsigned int counter = 0;
+        if(wakeAll) {
+            cv.notify_all();
+        } else {
+            cv.notify_one();
+        }
+        while(isAllThreadsWaiting() && counter++ < 10000) {}
+    }
+
+    int getWaitCount() {
+        std::lock_guard<std::mutex> lock(waitCountMutex);
+        return waitCount;
+    }
+
+    bool isAllThreadsWaiting() {
+        std::lock_guard<std::mutex> lock(waitCountMutex);
+        return waitCount == THREADCOUNT::value;
+    }
+
+private:
+    std::vector<std::thread> threads;
+    std::atomic_bool isAlive;
+    std::condition_variable cv;
+    std::mutex cvMutex;
+    Internal::TPQueueType fnQueue;
+    std::mutex queueMutex;
+    int waitCount;
+    std::mutex waitCountMutex;
+
+};
+
+} // namespace EC
+
+#endif
diff --git a/src/test/ThreadPoolTest.cpp b/src/test/ThreadPoolTest.cpp
new file mode 100644 (file)
index 0000000..ef7a695
--- /dev/null
@@ -0,0 +1,41 @@
+#include <gtest/gtest.h>
+
+#include <EC/ThreadPool.hpp>
+
+//using OneThreadPool = EC::ThreadPool<1>;
+using ThreeThreadPool = EC::ThreadPool<3>;
+
+//TEST(ECThreadPool, CannotCompile) {
+//    OneThreadPool tp;
+//}
+
+TEST(ECThreadPool, Simple) {
+    ThreeThreadPool p{};
+    std::atomic_int data;
+    data.store(0);
+    const auto fn = [](void *ud) {
+        auto *data = static_cast<std::atomic_int*>(ud);
+        data->fetch_add(1);
+    };
+
+    p.queueFn(fn, &data);
+
+    p.wakeThreads();
+
+    while(!p.isAllThreadsWaiting()) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+    }
+
+    ASSERT_EQ(data.load(), 1);
+
+    for(unsigned int i = 0; i < 10; ++i) {
+        p.queueFn(fn, &data);
+    }
+    p.wakeThreads();
+
+    while(!p.isAllThreadsWaiting()) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+    }
+
+    ASSERT_EQ(data.load(), 11);
+}