git.seodisparate.com - EntityComponentMetaSystem/commitdiff
WIP nestable threads via ThreadPool
author Stephen Seo <seo.disparate@gmail.com>
Wed, 15 Jun 2022 07:38:36 +0000 (16:38 +0900)
committer Stephen Seo <seo.disparate@gmail.com>
Wed, 15 Jun 2022 07:38:36 +0000 (16:38 +0900)
Currently has a race-condition-like memory corruption bug where function
contexts are deleted before child functions complete (probably).
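
The hazard comes from the Manager.hpp hunks below: each per-thread context is
now heap-allocated and freed by a raw "delete data;" at the end of the worker
lambda, so a nested child task that still holds the pointer can touch freed
memory. A minimal sketch (hypothetical names, not this library's API) of the
usual remedy, shared ownership via std::shared_ptr so the context outlives the
last task that references it:

    // Minimal sketch (not this library's API): shared ownership keeps the
    // task context alive until the last task referencing it has finished,
    // instead of an explicit "delete" at the end of the parent task.
    #include <functional>
    #include <iostream>
    #include <memory>
    #include <queue>

    struct TaskContext {
        int value = 42;
        ~TaskContext() { std::cout << "context destroyed\n"; }
    };

    std::queue<std::function<void()>> taskQueue;

    int main() {
        auto ctx = std::make_shared<TaskContext>();

        // Parent task: enqueues a child task that still needs the context.
        taskQueue.push([ctx] {
            taskQueue.push([ctx] { // the child shares ownership
                std::cout << "child sees " << ctx->value << '\n';
            });
            // No explicit delete here; the context lives until the child's
            // copy of the shared_ptr is destroyed.
        });
        ctx.reset();

        while (!taskQueue.empty()) {
            auto fn = std::move(taskQueue.front());
            taskQueue.pop();
            fn(); // runs the parent, then the child; destructor runs last
        }
    }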

src/EC/Manager.hpp
src/EC/ThreadPool.hpp
src/test/ECTest.cpp

diff --git a/src/EC/Manager.hpp b/src/EC/Manager.hpp
index 4c1c8e1af9dedfecf78408700bd060cf83acf668..d82031d8e85e37ddc322946aecbe1e951519c668 100644
@@ -114,7 +114,7 @@ namespace EC
             std::array<std::size_t, 2> range;
             Manager *manager;
             EntitiesType *entities;
-            const BitsetType *signature;
+            BitsetType signature;
             void *userData;
             std::unordered_set<std::size_t> dead;
         };
@@ -779,7 +779,7 @@ namespace EC
             }
             else
             {
-                std::array<TPFnDataStructZero, ThreadCount * 2> fnDataAr;
+                std::array<TPFnDataStructZero*, ThreadCount * 2> fnDataAr;
 
                 std::size_t s = currentSize / (ThreadCount * 2);
                 for(std::size_t i = 0; i < ThreadCount * 2; ++i) {
@@ -793,14 +793,15 @@ namespace EC
                     if(begin == end) {
                         continue;
                     }
-                    fnDataAr[i].range = {begin, end};
-                    fnDataAr[i].manager = this;
-                    fnDataAr[i].entities = &entities;
-                    fnDataAr[i].signature = &signatureBitset;
-                    fnDataAr[i].userData = userData;
+                    fnDataAr[i] = new TPFnDataStructZero{};
+                    fnDataAr[i]->range = {begin, end};
+                    fnDataAr[i]->manager = this;
+                    fnDataAr[i]->entities = &entities;
+                    fnDataAr[i]->signature = signatureBitset;
+                    fnDataAr[i]->userData = userData;
                     for(std::size_t j = begin; j < end; ++j) {
                         if(!isAlive(j)) {
-                            fnDataAr[i].dead.insert(j);
+                            fnDataAr[i]->dead.insert(j);
                         }
                     }
 
@@ -812,17 +813,18 @@ namespace EC
                                 continue;
                             }
 
-                            if(((*data->signature)
+                            if(((data->signature)
                                         & std::get<BitsetType>(
                                             data->entities->at(i)))
-                                    == *data->signature) {
+                                    == data->signature) {
                                 Helper::call(i,
                                              *data->manager,
                                              std::forward<Function>(function),
                                              data->userData);
                             }
                         }
-                    }, &fnDataAr[i]);
+                        delete data;
+                    }, fnDataAr[i]);
                 }
                 threadPool->easyWakeAndWait();
             }
@@ -1874,7 +1876,7 @@ namespace EC
             may not have as great of a speed-up.
          */
         template <typename Signature>
-        void forMatchingSimple(ForMatchingFn fn, 
+        void forMatchingSimple(ForMatchingFn fn,
                                void *userData = nullptr,
                                const bool useThreadPool = false) {
             deferringDeletions.fetch_add(1);
@@ -1891,7 +1893,7 @@ namespace EC
                     }
                 }
             } else {
-                std::array<TPFnDataStructZero, ThreadCount * 2> fnDataAr;
+                std::array<TPFnDataStructZero*, ThreadCount * 2> fnDataAr;
 
                 std::size_t s = currentSize / (ThreadCount * 2);
                 for(std::size_t i = 0; i < ThreadCount * 2; ++i) {
@@ -1905,14 +1907,15 @@ namespace EC
                     if(begin == end) {
                         continue;
                     }
-                    fnDataAr[i].range = {begin, end};
-                    fnDataAr[i].manager = this;
-                    fnDataAr[i].entities = &entities;
-                    fnDataAr[i].signature = &signatureBitset;
-                    fnDataAr[i].userData = userData;
+                    fnDataAr[i] = new TPFnDataStructZero{};
+                    fnDataAr[i]->range = {begin, end};
+                    fnDataAr[i]->manager = this;
+                    fnDataAr[i]->entities = &entities;
+                    fnDataAr[i]->signature = signatureBitset;
+                    fnDataAr[i]->userData = userData;
                     for(std::size_t j = begin; j < end; ++j) {
                         if(!isAlive(j)) {
-                            fnDataAr[i].dead.insert(j);
+                            fnDataAr[i]->dead.insert(j);
                         }
                     }
                     threadPool->queueFn([&fn] (void *ud) {
@@ -1921,14 +1924,15 @@ namespace EC
                                 ++i) {
                             if(data->dead.find(i) != data->dead.end()) {
                                 continue;
-                            } else if((*data->signature
+                            } else if((data->signature
                                         & std::get<BitsetType>(
                                             data->entities->at(i)))
-                                    == *data->signature) {
+                                    == data->signature) {
                                 fn(i, data->manager, data->userData);
                             }
                         }
-                    }, &fnDataAr[i]);
+                        delete data;
+                    }, fnDataAr[i]);
                 }
                 threadPool->easyWakeAndWait();
             }
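
Both threaded branches above split work the same way: currentSize entities
over ThreadCount * 2 chunks of size s, with the last chunk absorbing any
remainder. The begin/end arithmetic sits outside the hunk context, but
presumably follows this shape (standalone sketch):

    // Sketch of the range splitting used by the threaded branches above:
    // the last chunk's "end" is clamped to currentSize so no remainder
    // entities are dropped.
    #include <cstddef>
    #include <iostream>

    int main() {
        constexpr std::size_t ThreadCount = 2;
        constexpr std::size_t chunks = ThreadCount * 2;
        const std::size_t currentSize = 10;

        const std::size_t s = currentSize / chunks;
        for (std::size_t i = 0; i < chunks; ++i) {
            const std::size_t begin = i * s;
            const std::size_t end =
                (i == chunks - 1) ? currentSize : (i + 1) * s;
            if (begin == end) {
                continue; // skip empty chunks when currentSize < chunks
            }
            std::cout << "chunk " << i << ": ["
                      << begin << ", " << end << ")\n";
        }
    }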
diff --git a/src/EC/ThreadPool.hpp b/src/EC/ThreadPool.hpp
index e0dd21f7b5c758cf6539372186d777ca1a97fa88..3611954d4291088718ad487cec5e54df5b89fa90 100644
 #include <functional>
 #include <tuple>
 #include <chrono>
+#include <unordered_set>
+
+#ifndef NDEBUG
+# include <iostream>
+#endif
 
 namespace EC {
 
@@ -18,6 +23,39 @@ namespace Internal {
     using TPFnType = std::function<void(void*)>;
     using TPTupleType = std::tuple<TPFnType, void*>;
     using TPQueueType = std::queue<TPTupleType>;
+
+    template <unsigned int SIZE>
+    void thread_fn(std::atomic_bool *isAlive,
+                   std::condition_variable *cv,
+                   std::mutex *cvMutex,
+                   Internal::TPQueueType *fnQueue,
+                   std::mutex *queueMutex,
+                   std::atomic_int *waitCount) {
+        bool hasFn = false;
+        Internal::TPTupleType fnTuple;
+        while(isAlive->load()) {
+            hasFn = false;
+            {
+                std::lock_guard<std::mutex> lock(*queueMutex);
+                if(!fnQueue->empty()) {
+                    fnTuple = fnQueue->front();
+                    fnQueue->pop();
+                    hasFn = true;
+                }
+            }
+            if(hasFn) {
+                std::get<0>(fnTuple)(std::get<1>(fnTuple));
+                continue;
+            }
+
+            waitCount->fetch_add(1);
+            {
+                std::unique_lock<std::mutex> lock(*cvMutex);
+                cv->wait(lock);
+            }
+            waitCount->fetch_sub(1);
+        }
+    }
 } // namespace Internal
 
 /*!
@@ -29,49 +67,20 @@ namespace Internal {
 template <unsigned int SIZE>
 class ThreadPool {
 public:
-    ThreadPool() : waitCount(0) {
+    ThreadPool() {
+        waitCount.store(0);
+        extraThreadCount.store(0);
         isAlive.store(true);
         if(SIZE >= 2) {
             for(unsigned int i = 0; i < SIZE; ++i) {
-                threads.emplace_back([] (std::atomic_bool *isAlive,
-                                         std::condition_variable *cv,
-                                         std::mutex *cvMutex,
-                                         Internal::TPQueueType *fnQueue,
-                                         std::mutex *queueMutex,
-                                         int *waitCount,
-                                         std::mutex *waitCountMutex) {
-                    bool hasFn = false;
-                    Internal::TPTupleType fnTuple;
-                    while(isAlive->load()) {
-                        hasFn = false;
-                        {
-                            std::lock_guard<std::mutex> lock(*queueMutex);
-                            if(!fnQueue->empty()) {
-                                fnTuple = fnQueue->front();
-                                fnQueue->pop();
-                                hasFn = true;
-                            }
-                        }
-                        if(hasFn) {
-                            std::get<0>(fnTuple)(std::get<1>(fnTuple));
-                            continue;
-                        }
-
-                        {
-                            std::lock_guard<std::mutex> lock(*waitCountMutex);
-                            *waitCount += 1;
-                        }
-                        {
-                            std::unique_lock<std::mutex> lock(*cvMutex);
-                            cv->wait(lock);
-                        }
-                        {
-                            std::lock_guard<std::mutex> lock(*waitCountMutex);
-                            *waitCount -= 1;
-                        }
-                    }
-                }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount,
-                   &waitCountMutex);
+                threads.emplace_back(Internal::thread_fn<SIZE>,
+                                     &isAlive,
+                                     &cv,
+                                     &cvMutex,
+                                     &fnQueue,
+                                     &queueMutex,
+                                     &waitCount);
+                threadsIDs.insert(threads.back().get_id());
             }
         }
     }
@@ -84,6 +93,7 @@ public:
             for(auto &thread : threads) {
                 thread.join();
             }
+            std::this_thread::sleep_for(std::chrono::milliseconds(20));
         }
     }
 
@@ -108,7 +118,7 @@ public:
         If SIZE is 2 or greater, then this function will return immediately after
         waking one or all threads, depending on the given boolean parameter.
     */
-    void wakeThreads(bool wakeAll = true) {
+    void wakeThreads(const bool wakeAll = true) {
         if(SIZE >= 2) {
             // wake threads to pull functions from queue and run them
             if(wakeAll) {
@@ -116,6 +126,36 @@ public:
             } else {
                 cv.notify_one();
             }
+
+            // check if all threads are running a task, and spawn a new thread
+            // if this is the case
+            Internal::TPTupleType fnTuple;
+            bool hasFn = false;
+            if (waitCount.load(std::memory_order_relaxed) == 0) {
+                std::lock_guard<std::mutex> queueLock(queueMutex);
+                if (!fnQueue.empty()) {
+                    fnTuple = fnQueue.front();
+                    fnQueue.pop();
+                    hasFn = true;
+                }
+            }
+
+            if (hasFn) {
+#ifndef NDEBUG
+                std::cout << "Spawning extra thread...\n";
+#endif
+                extraThreadCount.fetch_add(1);
+                std::thread newThread = std::thread(
+                        [] (Internal::TPTupleType &&tuple, std::atomic_int *count) {
+                            std::get<0>(tuple)(std::get<1>(tuple));
+#ifndef NDEBUG
+                            std::cout << "Stopping extra thread...\n";
+#endif
+                            count->fetch_sub(1);
+                        },
+                        std::move(fnTuple), &extraThreadCount);
+                newThread.detach();
+            }
         } else {
             sequentiallyRunTasks();
         }
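
The block added above is the core of the nesting support: when no pool thread
is waiting (waitCount == 0) but work is queued, one task is pulled and run on
a freshly spawned detached thread, tracked only by the extraThreadCount
atomic. A condensed, self-contained sketch of that pattern (simplified names,
not the exact class):

    // Sketch (simplified from the wakeThreads() change above): if every pool
    // thread is busy, run one queued task on a detached thread and count it
    // so waiting/shutdown code can tell it is still running.
    #include <atomic>
    #include <chrono>
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <thread>

    std::queue<std::function<void()>> fnQueue;
    std::mutex queueMutex;
    std::atomic_int extraThreadCount{0};

    void maybeSpawnExtraThread(int waitingWorkers) {
        std::function<void()> fn;
        bool hasFn = false;
        if (waitingWorkers == 0) { // all pool threads are busy
            std::lock_guard<std::mutex> lock(queueMutex);
            if (!fnQueue.empty()) {
                fn = std::move(fnQueue.front());
                fnQueue.pop();
                hasFn = true;
            }
        }
        if (hasFn) {
            extraThreadCount.fetch_add(1);
            std::thread([fn = std::move(fn)] {
                fn(); // run the task off-pool
                extraThreadCount.fetch_sub(1);
            }).detach();
        }
    }

    int main() {
        {
            std::lock_guard<std::mutex> lock(queueMutex);
            fnQueue.push([] { std::cout << "ran on an extra thread\n"; });
        }
        maybeSpawnExtraThread(0); // pretend all pool threads are busy
        while (extraThreadCount.load() != 0) {
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
    }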
@@ -129,8 +169,7 @@ public:
         If SIZE is less than 2, then this will always return 0.
     */
     int getWaitCount() {
-        std::lock_guard<std::mutex> lock(waitCountMutex);
-        return waitCount;
+        return waitCount.load(std::memory_order_relaxed);
     }
 
     /*!
@@ -140,8 +179,7 @@ public:
     */
     bool isAllThreadsWaiting() {
         if(SIZE >= 2) {
-            std::lock_guard<std::mutex> lock(waitCountMutex);
-            return waitCount == SIZE;
+            return waitCount.load(std::memory_order_relaxed) == SIZE;
         } else {
             return true;
         }
@@ -173,10 +211,13 @@ public:
      */
     void easyWakeAndWait() {
         if(SIZE >= 2) {
-            wakeThreads();
             do {
+                wakeThreads();
                 std::this_thread::sleep_for(std::chrono::microseconds(150));
-            } while(!isQueueEmpty() || !isAllThreadsWaiting());
+            } while(!isQueueEmpty()
+                    || (threadsIDs.find(std::this_thread::get_id()) != threadsIDs.end()
+                         && extraThreadCount.load(std::memory_order_relaxed) != 0));
+//            } while(!isQueueEmpty() || !isAllThreadsWaiting());
         } else {
             sequentiallyRunTasks();
         }
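
The rewritten loop condition uses threadsIDs to detect nesting: a call to
easyWakeAndWait() made from one of the pool's own threads can never see
isAllThreadsWaiting() become true, because the calling thread itself is busy,
so the old condition (left commented out above) would spin forever. A tiny
sketch of the membership test (hypothetical helper, not part of the class):

    #include <iostream>
    #include <thread>
    #include <unordered_set>

    // Hypothetical helper: true when the current thread is one of the pool's
    // worker threads, i.e. the wait call is nested inside a pool task.
    bool calledFromPoolThread(
            const std::unordered_set<std::thread::id> &threadsIDs) {
        return threadsIDs.find(std::this_thread::get_id())
                != threadsIDs.end();
    }

    int main() {
        std::unordered_set<std::thread::id> threadsIDs;
        threadsIDs.insert(std::this_thread::get_id());
        std::cout << std::boolalpha
                  << calledFromPoolThread(threadsIDs) << '\n'; // true
    }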
@@ -184,13 +225,14 @@ public:
 
 private:
     std::vector<std::thread> threads;
+    std::unordered_set<std::thread::id> threadsIDs;
     std::atomic_bool isAlive;
     std::condition_variable cv;
     std::mutex cvMutex;
     Internal::TPQueueType fnQueue;
     std::mutex queueMutex;
-    int waitCount;
-    std::mutex waitCountMutex;
+    std::atomic_int waitCount;
+    std::atomic_int extraThreadCount;
 
     void sequentiallyRunTasks() {
         // pull functions from queue and run them on current thread
diff --git a/src/test/ECTest.cpp b/src/test/ECTest.cpp
index e0ff4bb9686866338beada67efb72e605c9c442a..7852154bfc0b7b0daba1b28b6da4a4c2312e1c4c 100644
@@ -1,7 +1,9 @@
 
 #include <gtest/gtest.h>
 
+#include <chrono>
 #include <iostream>
+#include <thread>
 #include <tuple>
 #include <memory>
 #include <unordered_map>
@@ -1431,3 +1433,33 @@ TEST(EC, ManagerDeferredDeletions) {
         }
     }
 }
+
+TEST(EC, NestedThreadPoolTasks) {
+    using ManagerType = EC::Manager<ListComponentsAll, ListTagsAll, 2>;
+    ManagerType manager;
+
+    std::array<std::size_t, 64> entities;
+    for (auto &entity : entities) {
+        entity = manager.addEntity();
+        manager.addComponent<C0>(entity, entity, entity);
+    }
+
+    manager.forMatchingSignature<EC::Meta::TypeList<C0>>([] (std::size_t id, void *data, C0 *c) {
+        ManagerType *manager = (ManagerType*)data;
+
+        manager->forMatchingSignature<EC::Meta::TypeList<C0>>([id] (std::size_t inner_id, void* data, C0 *inner_c) {
+            const C0 *const outer_c = (C0*)data;
+            EXPECT_EQ(id, outer_c->x);
+            EXPECT_EQ(inner_id, inner_c->x);
+            if (id == inner_id) {
+                EXPECT_EQ(outer_c->x, inner_c->x);
+                EXPECT_EQ(outer_c->y, inner_c->y);
+            } else {
+                EXPECT_NE(outer_c->x, inner_c->x);
+                EXPECT_NE(outer_c->y, inner_c->y);
+            }
+        }, c, false);
+    }, &manager, true);
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+}
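
The closing one-second sleep presumably gives any detached extra threads, and
the contexts they still reference, time to finish before the Manager and its
ThreadPool are destroyed; with the lifetime bug noted in the commit message
still open, tearing down earlier could corrupt memory. The 20-millisecond
sleep added to the ThreadPool destructor appears to serve the same purpose.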