]> git.seodisparate.com - EntityComponentMetaSystem/commitdiff
Allow ThreadPool to be created with < 2 ThreadCount
authorStephen Seo <seo.disparate@gmail.com>
Tue, 7 Sep 2021 02:46:38 +0000 (11:46 +0900)
committerStephen Seo <seo.disparate@gmail.com>
Tue, 7 Sep 2021 02:46:38 +0000 (11:46 +0900)
src/EC/Manager.hpp
src/EC/ThreadPool.hpp
src/test/ECTest.cpp
src/test/ThreadPoolTest.cpp

index 4e5b57ae3337690d01bfb083a4b7e38825d3213e..a4eb821b53e01a73f7a217fb35475375b9dbb5ed 100644 (file)
@@ -93,7 +93,7 @@ namespace EC
         std::size_t currentSize = 0;
         std::unordered_set<std::size_t> deletedSet;
 
-        ThreadPool<ThreadCount> threadPool;
+        std::unique_ptr<ThreadPool<ThreadCount> > threadPool;
 
     public:
         /*!
@@ -105,6 +105,9 @@ namespace EC
         Manager()
         {
             resize(EC_INIT_ENTITIES_SIZE);
+            if(ThreadCount >= 2) {
+                threadPool = std::make_unique<ThreadPool<ThreadCount> >();
+            }
         }
 
     private:
@@ -634,7 +637,7 @@ namespace EC
 
             BitsetType signatureBitset =
                 BitsetType::template generateBitset<Signature>();
-            if(!useThreadPool)
+            if(!useThreadPool || !threadPool)
             {
                 for(std::size_t i = 0; i < currentSize; ++i)
                 {
@@ -673,7 +676,7 @@ namespace EC
                     std::get<2>(fnDataAr.at(i)) = &signatureBitset;
                     std::get<3>(fnDataAr.at(i)) = {begin, end};
                     std::get<4>(fnDataAr.at(i)) = userData;
-                    threadPool.queueFn([&function] (void *ud) {
+                    threadPool->queueFn([&function] (void *ud) {
                         auto *data = static_cast<TPFnDataType*>(ud);
                         for(std::size_t i = std::get<3>(*data).at(0);
                                 i < std::get<3>(*data).at(1);
@@ -691,10 +694,10 @@ namespace EC
                         }
                     }, &fnDataAr.at(i));
                 }
-                threadPool.wakeThreads();
+                threadPool->wakeThreads();
                 do {
                     std::this_thread::sleep_for(std::chrono::microseconds(200));
-                } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
             }
         }
 
@@ -753,7 +756,7 @@ namespace EC
 
             BitsetType signatureBitset =
                 BitsetType::template generateBitset<Signature>();
-            if(!useThreadPool)
+            if(!useThreadPool || !threadPool)
             {
                 for(std::size_t i = 0; i < currentSize; ++i)
                 {
@@ -792,7 +795,7 @@ namespace EC
                     std::get<3>(fnDataAr.at(i)) = {begin, end};
                     std::get<4>(fnDataAr.at(i)) = userData;
                     std::get<5>(fnDataAr.at(i)) = function;
-                    threadPool.queueFn([] (void *ud) {
+                    threadPool->queueFn([] (void *ud) {
                         auto *data = static_cast<TPFnDataType*>(ud);
                         for(std::size_t i = std::get<3>(*data).at(0);
                                 i < std::get<3>(*data).at(1);
@@ -810,10 +813,10 @@ namespace EC
                         }
                     }, &fnDataAr.at(i));
                 }
-                threadPool.wakeThreads();
+                threadPool->wakeThreads();
                 do {
                     std::this_thread::sleep_for(std::chrono::microseconds(200));
-                } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
             }
         }
 
@@ -906,7 +909,7 @@ namespace EC
                         std::vector<std::size_t> matching,
                         void* userData)
                 {
-                    if(!useThreadPool)
+                    if(!useThreadPool || !threadPool)
                     {
                         for(auto eid : matching)
                         {
@@ -939,7 +942,7 @@ namespace EC
                             std::get<2>(fnDataAr.at(i)) = {begin, end};
                             std::get<3>(fnDataAr.at(i)) = userData;
                             std::get<4>(fnDataAr.at(i)) = &matching;
-                            threadPool.queueFn([function, helper] (void* ud) {
+                            threadPool->queueFn([function, helper] (void* ud) {
                                 auto *data = static_cast<TPFnDataType*>(ud);
                                 for(std::size_t i = std::get<2>(*data).at(0);
                                         i < std::get<2>(*data).at(1);
@@ -954,10 +957,10 @@ namespace EC
                                 }
                             }, &fnDataAr.at(i));
                         }
-                        threadPool.wakeThreads();
+                        threadPool->wakeThreads();
                         do {
                             std::this_thread::sleep_for(std::chrono::microseconds(200));
-                        } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                        } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
                     }
                 })));
 
@@ -970,7 +973,7 @@ namespace EC
         {
             std::vector<std::vector<std::size_t> > matchingV(bitsets.size());
 
-            if(!useThreadPool)
+            if(!useThreadPool || !threadPool)
             {
                 for(std::size_t i = 0; i < currentSize; ++i)
                 {
@@ -1012,7 +1015,7 @@ namespace EC
                     std::get<3>(fnDataAr.at(i)) = &bitsets;
                     std::get<4>(fnDataAr.at(i)) = &entities;
                     std::get<5>(fnDataAr.at(i)) = &mutex;
-                    threadPool.queueFn([] (void *ud) {
+                    threadPool->queueFn([] (void *ud) {
                         auto *data = static_cast<TPFnDataType*>(ud);
                         for(std::size_t i = std::get<1>(*data).at(0);
                                 i < std::get<1>(*data).at(1);
@@ -1033,10 +1036,10 @@ namespace EC
                         }
                     }, &fnDataAr.at(i));
                 }
-                threadPool.wakeThreads();
+                threadPool->wakeThreads();
                 do {
                     std::this_thread::sleep_for(std::chrono::microseconds(200));
-                } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
             }
 
             return matchingV;
@@ -1351,7 +1354,7 @@ namespace EC
             });
 
             // find and store entities matching signatures
-            if(!useThreadPool)
+            if(!useThreadPool || !threadPool)
             {
                 for(std::size_t eid = 0; eid < currentSize; ++eid)
                 {
@@ -1394,7 +1397,7 @@ namespace EC
                     std::get<3>(fnDataAr.at(i)) = signatureBitsets;
                     std::get<4>(fnDataAr.at(i)) = &mutex;
 
-                    threadPool.queueFn([] (void *ud) {
+                    threadPool->queueFn([] (void *ud) {
                         auto *data = static_cast<TPFnDataType*>(ud);
                         for(std::size_t i = std::get<1>(*data).at(0);
                                 i < std::get<1>(*data).at(1);
@@ -1412,10 +1415,10 @@ namespace EC
                         }
                     }, &fnDataAr.at(i));
                 }
-                threadPool.wakeThreads();
+                threadPool->wakeThreads();
                 do {
                     std::this_thread::sleep_for(std::chrono::microseconds(200));
-                } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
             }
 
             // call functions on matching entities
@@ -1432,7 +1435,7 @@ namespace EC
                         EC::Meta::Morph<
                             SignatureComponents,
                             ForMatchingSignatureHelper<> >;
-                    if(!useThreadPool) {
+                    if(!useThreadPool || !threadPool) {
                         for(const auto& id : multiMatchingEntities[index]) {
                             if(isAlive(id)) {
                                 Helper::call(id, *this, func, userData);
@@ -1459,7 +1462,7 @@ namespace EC
                             std::get<2>(fnDataAr.at(i)) = {begin, end};
                             std::get<3>(fnDataAr.at(i)) = &multiMatchingEntities;
                             std::get<4>(fnDataAr.at(i)) = index;
-                            threadPool.queueFn([&func] (void *ud) {
+                            threadPool->queueFn([&func] (void *ud) {
                                 auto *data = static_cast<TPFnType*>(ud);
                                 for(std::size_t i = std::get<2>(*data).at(0);
                                         i < std::get<2>(*data).at(1);
@@ -1474,10 +1477,10 @@ namespace EC
                                 }
                             }, &fnDataAr.at(i));
                         }
-                        threadPool.wakeThreads();
+                        threadPool->wakeThreads();
                         do {
                             std::this_thread::sleep_for(std::chrono::microseconds(200));
-                        } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                        } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
                     }
                 }
             );
@@ -1544,7 +1547,7 @@ namespace EC
             });
 
             // find and store entities matching signatures
-            if(!useThreadPool)
+            if(!useThreadPool || !threadPool)
             {
                 for(std::size_t eid = 0; eid < currentSize; ++eid)
                 {
@@ -1587,7 +1590,7 @@ namespace EC
                     std::get<3>(fnDataAr.at(i)) = signatureBitsets;
                     std::get<4>(fnDataAr.at(i)) = &mutex;
 
-                    threadPool.queueFn([] (void *ud) {
+                    threadPool->queueFn([] (void *ud) {
                         auto *data = static_cast<TPFnDataType*>(ud);
                         for(std::size_t i = std::get<1>(*data).at(0);
                                 i < std::get<1>(*data).at(1);
@@ -1605,10 +1608,10 @@ namespace EC
                         }
                     }, &fnDataAr.at(i));
                 }
-                threadPool.wakeThreads();
+                threadPool->wakeThreads();
                 do {
                     std::this_thread::sleep_for(std::chrono::microseconds(200));
-                } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
             }
 
             // call functions on matching entities
@@ -1625,7 +1628,7 @@ namespace EC
                         EC::Meta::Morph<
                             SignatureComponents,
                             ForMatchingSignatureHelper<> >;
-                    if(!useThreadPool)
+                    if(!useThreadPool || !threadPool)
                     {
                         for(const auto& id : multiMatchingEntities[index])
                         {
@@ -1657,7 +1660,7 @@ namespace EC
                             std::get<2>(fnDataAr.at(i)) = {begin, end};
                             std::get<3>(fnDataAr.at(i)) = &multiMatchingEntities;
                             std::get<4>(fnDataAr.at(i)) = index;
-                            threadPool.queueFn([&func] (void *ud) {
+                            threadPool->queueFn([&func] (void *ud) {
                                 auto *data = static_cast<TPFnType*>(ud);
                                 for(std::size_t i = std::get<2>(*data).at(0);
                                         i < std::get<2>(*data).at(1);
@@ -1672,10 +1675,10 @@ namespace EC
                                 }
                             }, &fnDataAr.at(i));
                         }
-                        threadPool.wakeThreads();
+                        threadPool->wakeThreads();
                         do {
                             std::this_thread::sleep_for(std::chrono::microseconds(200));
-                        } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                        } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
                     }
                 }
             );
@@ -1694,7 +1697,7 @@ namespace EC
         template <typename Signature>
         void forMatchingSimple(ForMatchingFn fn, void *userData = nullptr, const bool useThreadPool = false) {
             const BitsetType signatureBitset = BitsetType::template generateBitset<Signature>();
-            if(!useThreadPool) {
+            if(!useThreadPool || !threadPool) {
                 for(std::size_t i = 0; i < currentSize; ++i) {
                     if(!std::get<bool>(entities[i])) {
                         continue;
@@ -1723,7 +1726,7 @@ namespace EC
                     std::get<2>(fnDataAr.at(i)) = &signatureBitset;
                     std::get<3>(fnDataAr.at(i)) = {begin, end};
                     std::get<4>(fnDataAr.at(i)) = userData;
-                    threadPool.queueFn([&fn] (void *ud) {
+                    threadPool->queueFn([&fn] (void *ud) {
                         auto *data = static_cast<TPFnDataType*>(ud);
                         for(std::size_t i = std::get<3>(*data).at(0);
                                 i < std::get<3>(*data).at(1);
@@ -1736,10 +1739,10 @@ namespace EC
                         }
                     }, &fnDataAr.at(i));
                 }
-                threadPool.wakeThreads();
+                threadPool->wakeThreads();
                 do {
                     std::this_thread::sleep_for(std::chrono::microseconds(200));
-                } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
             }
         }
 
@@ -1754,7 +1757,7 @@ namespace EC
          */
         template <typename Iterable>
         void forMatchingIterable(Iterable iterable, ForMatchingFn fn, void* userData = nullptr, const bool useThreadPool = false) {
-            if(!useThreadPool) {
+            if(!useThreadPool || !threadPool) {
                 bool isValid;
                 for(std::size_t i = 0; i < currentSize; ++i) {
                     if(!std::get<bool>(entities[i])) {
@@ -1793,7 +1796,7 @@ namespace EC
                     std::get<2>(fnDataAr.at(i)) = &iterable;
                     std::get<3>(fnDataAr.at(i)) = {begin, end};
                     std::get<4>(fnDataAr.at(i)) = userData;
-                    threadPool.queueFn([&fn] (void *ud) {
+                    threadPool->queueFn([&fn] (void *ud) {
                         auto *data = static_cast<TPFnDataType*>(ud);
                         bool isValid;
                         for(std::size_t i = std::get<3>(*data).at(0);
@@ -1816,10 +1819,10 @@ namespace EC
                         }
                     }, &fnDataAr.at(i));
                 }
-                threadPool.wakeThreads();
+                threadPool->wakeThreads();
                 do {
                     std::this_thread::sleep_for(std::chrono::microseconds(200));
-                } while(!threadPool.isQueueEmpty() && !threadPool.isAllThreadsWaiting());
+                } while(!threadPool->isQueueEmpty() && !threadPool->isAllThreadsWaiting());
             }
         }
     };
index aee681d0e66bd784656f2b787ebcab7f1def9d6f..c7d078e59b8484b37933a1d2129887f2c70d19a2 100644 (file)
@@ -20,64 +20,66 @@ namespace Internal {
     using TPQueueType = std::queue<TPTupleType>;
 } // namespace Internal
 
-template <unsigned int SIZE, typename = void>
-class ThreadPool;
-
 template <unsigned int SIZE>
-class ThreadPool<SIZE, typename std::enable_if<(SIZE >= 2)>::type> {
+class ThreadPool {
 public:
     using THREADCOUNT = std::integral_constant<int, SIZE>;
 
     ThreadPool() : waitCount(0) {
         isAlive.store(true);
-        for(unsigned int i = 0; i < SIZE; ++i) {
-            threads.emplace_back([] (std::atomic_bool *isAlive,
-                                     std::condition_variable *cv,
-                                     std::mutex *cvMutex,
-                                     Internal::TPQueueType *fnQueue,
-                                     std::mutex *queueMutex,
-                                     int *waitCount,
-                                     std::mutex *waitCountMutex) {
-                bool hasFn = false;
-                Internal::TPTupleType fnTuple;
-                while(isAlive->load()) {
-                    hasFn = false;
-                    {
-                        std::lock_guard<std::mutex> lock(*queueMutex);
-                        if(!fnQueue->empty()) {
-                            fnTuple = fnQueue->front();
-                            fnQueue->pop();
-                            hasFn = true;
+        if(SIZE >= 2) {
+            for(unsigned int i = 0; i < SIZE; ++i) {
+                threads.emplace_back([] (std::atomic_bool *isAlive,
+                                         std::condition_variable *cv,
+                                         std::mutex *cvMutex,
+                                         Internal::TPQueueType *fnQueue,
+                                         std::mutex *queueMutex,
+                                         int *waitCount,
+                                         std::mutex *waitCountMutex) {
+                    bool hasFn = false;
+                    Internal::TPTupleType fnTuple;
+                    while(isAlive->load()) {
+                        hasFn = false;
+                        {
+                            std::lock_guard<std::mutex> lock(*queueMutex);
+                            if(!fnQueue->empty()) {
+                                fnTuple = fnQueue->front();
+                                fnQueue->pop();
+                                hasFn = true;
+                            }
+                        }
+                        if(hasFn) {
+                            std::get<0>(fnTuple)(std::get<1>(fnTuple));
+                            continue;
                         }
-                    }
-                    if(hasFn) {
-                        std::get<0>(fnTuple)(std::get<1>(fnTuple));
-                        continue;
-                    }
 
-                    {
-                        std::lock_guard<std::mutex> lock(*waitCountMutex);
-                        *waitCount += 1;
-                    }
-                    {
-                        std::unique_lock<std::mutex> lock(*cvMutex);
-                        cv->wait(lock);
-                    }
-                    {
-                        std::lock_guard<std::mutex> lock(*waitCountMutex);
-                        *waitCount -= 1;
+                        {
+                            std::lock_guard<std::mutex> lock(*waitCountMutex);
+                            *waitCount += 1;
+                        }
+                        {
+                            std::unique_lock<std::mutex> lock(*cvMutex);
+                            cv->wait(lock);
+                        }
+                        {
+                            std::lock_guard<std::mutex> lock(*waitCountMutex);
+                            *waitCount -= 1;
+                        }
                     }
-                }
-            }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, &waitCountMutex);
+                }, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount,
+                   &waitCountMutex);
+            }
         }
     }
 
     ~ThreadPool() {
-        isAlive.store(false);
-        std::this_thread::sleep_for(std::chrono::milliseconds(200));
-        cv.notify_all();
-        for(auto &thread : threads) {
-            thread.join();
+        if(SIZE >= 2) {
+            isAlive.store(false);
+            std::this_thread::sleep_for(std::chrono::milliseconds(20));
+            cv.notify_all();
+            for(auto &thread : threads) {
+                thread.join();
+            }
         }
     }
 
@@ -87,10 +89,32 @@ public:
     }
 
     void wakeThreads(bool wakeAll = true) {
-        if(wakeAll) {
-            cv.notify_all();
+        if(SIZE >= 2) {
+            // wake threads to pull functions from queue and run them
+            if(wakeAll) {
+                cv.notify_all();
+            } else {
+                cv.notify_one();
+            }
         } else {
-            cv.notify_one();
+            // pull functions from queue and run them on main thread
+            Internal::TPTupleType fnTuple;
+            bool hasFn;
+            do {
+                {
+                    std::lock_guard<std::mutex> lock(queueMutex);
+                    if(!fnQueue.empty()) {
+                        hasFn = true;
+                        fnTuple = fnQueue.front();
+                        fnQueue.pop();
+                    } else {
+                        hasFn = false;
+                    }
+                }
+                if(hasFn) {
+                    std::get<0>(fnTuple)(std::get<1>(fnTuple));
+                }
+            } while(hasFn);
         }
     }
 
@@ -100,8 +124,12 @@ public:
     }
 
     bool isAllThreadsWaiting() {
-        std::lock_guard<std::mutex> lock(waitCountMutex);
-        return waitCount == THREADCOUNT::value;
+        if(SIZE >= 2) {
+            std::lock_guard<std::mutex> lock(waitCountMutex);
+            return waitCount == THREADCOUNT::value;
+        } else {
+            return true;
+        }
     }
 
     bool isQueueEmpty() {
index 31059038759b57205feb94eefe9678a605f9eb18..75e65ae0646752947d6faf27531fd209b740a765 100644 (file)
@@ -1370,3 +1370,30 @@ TEST(EC, MultiThreadedForMatching) {
     EXPECT_TRUE(manager.isAlive(first));
     EXPECT_FALSE(manager.isAlive(second));
 }
+
+TEST(EC, ManagerWithLowThreadCount) {
+    EC::Manager<ListComponentsAll, ListTagsAll, 1> manager;
+
+    std::array<std::size_t, 10> entities;
+    for(auto &id : entities) {
+        id = manager.addEntity();
+        manager.addComponent<C0>(id);
+    }
+
+    for(const auto &id : entities) {
+        auto *component = manager.getEntityComponent<C0>(id);
+        EXPECT_EQ(component->x, 0);
+        EXPECT_EQ(component->y, 0);
+    }
+
+    manager.forMatchingSignature<EC::Meta::TypeList<C0> >([] (std::size_t /*id*/, void* /*ud*/, C0 *c) {
+        c->x += 1;
+        c->y += 1;
+    }, nullptr, true);
+
+    for(const auto &id : entities) {
+        auto *component = manager.getEntityComponent<C0>(id);
+        EXPECT_EQ(component->x, 1);
+        EXPECT_EQ(component->y, 1);
+    }
+}
index 2cec50ccf7d5681c7bb9970c6976761dceefc2e3..31c85293bcdb1910aff8619eaffbd5d670d6a241 100644 (file)
@@ -2,15 +2,42 @@
 
 #include <EC/ThreadPool.hpp>
 
-//using OneThreadPool = EC::ThreadPool<1>;
+using OneThreadPool = EC::ThreadPool<1>;
 using ThreeThreadPool = EC::ThreadPool<3>;
 
-//TEST(ECThreadPool, CannotCompile) {
-//    OneThreadPool tp;
-//}
+TEST(ECThreadPool, CannotCompile) {
+    OneThreadPool p;
+    std::atomic_int data;
+    data.store(0);
+    const auto fn = [](void *ud) {
+        auto *data = static_cast<std::atomic_int*>(ud);
+        data->fetch_add(1);
+    };
+
+    p.queueFn(fn, &data);
+
+    p.wakeThreads();
+
+    do {
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+    } while(!p.isQueueEmpty() && !p.isAllThreadsWaiting());
+
+    ASSERT_EQ(data.load(), 1);
+
+    for(unsigned int i = 0; i < 10; ++i) {
+        p.queueFn(fn, &data);
+    }
+    p.wakeThreads();
+
+    do {
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+    } while(!p.isQueueEmpty() && !p.isAllThreadsWaiting());
+
+    ASSERT_EQ(data.load(), 11);
+}
 
 TEST(ECThreadPool, Simple) {
-    ThreeThreadPool p{};
+    ThreeThreadPool p;
     std::atomic_int data;
     data.store(0);
     const auto fn = [](void *ud) {