Impl ThreadPool
This commit is contained in:
parent
381e069ec2
commit
2e9e18a964
4 changed files with 171 additions and 3 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -4,4 +4,4 @@ build*/
|
|||
doxygen_html/
|
||||
compile_commands.json
|
||||
tags
|
||||
.clangd/
|
||||
.cache/
|
||||
|
|
|
@ -14,7 +14,9 @@ set(EntityComponentSystem_HEADERS
|
|||
EC/Meta/Meta.hpp
|
||||
EC/Bitset.hpp
|
||||
EC/Manager.hpp
|
||||
EC/EC.hpp)
|
||||
EC/EC.hpp
|
||||
EC/ThreadPool.hpp
|
||||
)
|
||||
|
||||
set(WillFailCompile_SOURCES
|
||||
test/WillFailCompileTest.cpp)
|
||||
|
@ -48,7 +50,9 @@ if(GTEST_FOUND)
|
|||
set(UnitTests_SOURCES
|
||||
test/MetaTest.cpp
|
||||
test/ECTest.cpp
|
||||
test/Main.cpp)
|
||||
test/ThreadPoolTest.cpp
|
||||
test/Main.cpp
|
||||
)
|
||||
|
||||
add_executable(UnitTests ${UnitTests_SOURCES})
|
||||
target_link_libraries(UnitTests EntityComponentSystem ${GTEST_LIBRARIES})
|
||||
|
|
123
src/EC/ThreadPool.hpp
Normal file
123
src/EC/ThreadPool.hpp
Normal file
|
@ -0,0 +1,123 @@
|
|||
#ifndef EC_META_SYSTEM_THREADPOOL_HPP
|
||||
#define EC_META_SYSTEM_THREADPOOL_HPP
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <queue>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <chrono>
|
||||
|
||||
namespace EC {
|
||||
|
||||
namespace Internal {
|
||||
using TPFnType = std::function<void(void*)>;
|
||||
using TPTupleType = std::tuple<TPFnType, void*>;
|
||||
using TPQueueType = std::queue<TPTupleType>;
|
||||
} // namespace Internal
|
||||
|
||||
template <unsigned int SIZE, typename = void>
|
||||
class ThreadPool;
|
||||
|
||||
template <unsigned int SIZE>
|
||||
class ThreadPool<SIZE, typename std::enable_if<(SIZE >= 2)>::type> {
|
||||
public:
|
||||
using THREADCOUNT = std::integral_constant<int, SIZE>;
|
||||
|
||||
ThreadPool() : waitCount(0) {
|
||||
isAlive.store(true);
|
||||
for(unsigned int i = 0; i < SIZE; ++i) {
|
||||
threads.emplace_back([] (std::atomic_bool *isAlive,
|
||||
std::condition_variable *cv,
|
||||
std::mutex *cvMutex,
|
||||
Internal::TPQueueType *fnQueue,
|
||||
std::mutex *queueMutex,
|
||||
int *waitCount,
|
||||
std::mutex *waitCountMutex) {
|
||||
bool hasFn = false;
|
||||
Internal::TPTupleType fnTuple;
|
||||
while(isAlive->load()) {
|
||||
hasFn = false;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*queueMutex);
|
||||
if(!fnQueue->empty()) {
|
||||
fnTuple = fnQueue->front();
|
||||
fnQueue->pop();
|
||||
hasFn = true;
|
||||
}
|
||||
}
|
||||
if(hasFn) {
|
||||
std::get<0>(fnTuple)(std::get<1>(fnTuple));
|
||||
continue;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*waitCountMutex);
|
||||
*waitCount += 1;
|
||||
}
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(*cvMutex);
|
||||
cv->wait(lock);
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*waitCountMutex);
|
||||
*waitCount -= 1;
|
||||
}
|
||||
}
|
||||
}, &isAlive, &cv, &cvMutex, &fnQueue, &queueMutex, &waitCount, &waitCountMutex);
|
||||
}
|
||||
}
|
||||
|
||||
~ThreadPool() {
|
||||
isAlive.store(false);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
cv.notify_all();
|
||||
for(auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
void queueFn(std::function<void(void*)>&& fn, void *ud = nullptr) {
|
||||
std::lock_guard<std::mutex> lock(queueMutex);
|
||||
fnQueue.emplace(std::make_tuple(fn, ud));
|
||||
}
|
||||
|
||||
void wakeThreads(bool wakeAll = true) {
|
||||
unsigned int counter = 0;
|
||||
if(wakeAll) {
|
||||
cv.notify_all();
|
||||
} else {
|
||||
cv.notify_one();
|
||||
}
|
||||
while(isAllThreadsWaiting() && counter++ < 10000) {}
|
||||
}
|
||||
|
||||
int getWaitCount() {
|
||||
std::lock_guard<std::mutex> lock(waitCountMutex);
|
||||
return waitCount;
|
||||
}
|
||||
|
||||
bool isAllThreadsWaiting() {
|
||||
std::lock_guard<std::mutex> lock(waitCountMutex);
|
||||
return waitCount == THREADCOUNT::value;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::thread> threads;
|
||||
std::atomic_bool isAlive;
|
||||
std::condition_variable cv;
|
||||
std::mutex cvMutex;
|
||||
Internal::TPQueueType fnQueue;
|
||||
std::mutex queueMutex;
|
||||
int waitCount;
|
||||
std::mutex waitCountMutex;
|
||||
|
||||
};
|
||||
|
||||
} // namespace EC
|
||||
|
||||
#endif
|
41
src/test/ThreadPoolTest.cpp
Normal file
41
src/test/ThreadPoolTest.cpp
Normal file
|
@ -0,0 +1,41 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <EC/ThreadPool.hpp>
|
||||
|
||||
//using OneThreadPool = EC::ThreadPool<1>;
|
||||
using ThreeThreadPool = EC::ThreadPool<3>;
|
||||
|
||||
//TEST(ECThreadPool, CannotCompile) {
|
||||
// OneThreadPool tp;
|
||||
//}
|
||||
|
||||
TEST(ECThreadPool, Simple) {
|
||||
ThreeThreadPool p{};
|
||||
std::atomic_int data;
|
||||
data.store(0);
|
||||
const auto fn = [](void *ud) {
|
||||
auto *data = static_cast<std::atomic_int*>(ud);
|
||||
data->fetch_add(1);
|
||||
};
|
||||
|
||||
p.queueFn(fn, &data);
|
||||
|
||||
p.wakeThreads();
|
||||
|
||||
while(!p.isAllThreadsWaiting()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
}
|
||||
|
||||
ASSERT_EQ(data.load(), 1);
|
||||
|
||||
for(unsigned int i = 0; i < 10; ++i) {
|
||||
p.queueFn(fn, &data);
|
||||
}
|
||||
p.wakeThreads();
|
||||
|
||||
while(!p.isAllThreadsWaiting()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
}
|
||||
|
||||
ASSERT_EQ(data.load(), 11);
|
||||
}
|
Loading…
Reference in a new issue