diff --git a/src/EC/Manager.hpp b/src/EC/Manager.hpp index 5557c97..51a981f 100644 --- a/src/EC/Manager.hpp +++ b/src/EC/Manager.hpp @@ -763,7 +763,11 @@ namespace EC private: - std::unordered_map > + std::unordered_map)> > > forMatchingFunctions; std::size_t functionIndex = 0; @@ -827,71 +831,131 @@ namespace EC forMatchingFunctions.emplace(std::make_pair( functionIndex, - [function, signatureBitset, helper, this] - (std::size_t threadCount) - { - if(threadCount <= 1) + std::make_tuple( + signatureBitset, + [function, helper, this] + (std::size_t threadCount, + std::vector matching) { - for(std::size_t i = 0; i < this->currentSize; ++i) + if(threadCount <= 1) { - if(!std::get(this->entities[i])) + for(auto eid : matching) { - continue; - } - if((signatureBitset - & std::get(this->entities[i])) - == signatureBitset) - { - helper.callInstancePtr(i, *this, &function); + helper.callInstancePtr(eid, *this, &function); } } - } - else - { - std::vector threads(threadCount); - std::size_t s = this->currentSize / threadCount; - for(std::size_t i = 0; i < threadCount; ++ i) + else { - std::size_t begin = s * i; - std::size_t end; - if(i == threadCount - 1) + std::vector threads(threadCount); + std::size_t s = matching.size() / threadCount; + for(std::size_t i = 0; i < threadCount; ++ i) { - end = this->currentSize; - } - else - { - end = s * (i + 1); - } - threads[i] = std::thread( - [this, &function, &signatureBitset, &helper] - (std::size_t begin, - std::size_t end) { - for(std::size_t i = begin; i < end; ++i) + std::size_t begin = s * i; + std::size_t end; + if(i == threadCount - 1) { - if(!std::get(this->entities[i])) - { - continue; - } - if((signatureBitset - & std::get(this->entities[i])) - == signatureBitset) + end = matching.size(); + } + else + { + end = s * (i + 1); + } + threads[i] = std::thread( + [this, &function, &helper] + (std::size_t begin, + std::size_t end) { + for(std::size_t i = begin; i < end; ++i) { helper.callInstancePtr(i, *this, &function); } - } - }, - begin, end); + }, + begin, end); + } + for(std::size_t i = 0; i < threadCount; ++i) + { + threads[i].join(); + } } - for(std::size_t i = 0; i < threadCount; ++i) - { - threads[i].join(); - } - } - })); + }))); return functionIndex++; } + private: + std::vector > getMatchingEntities( + std::vector bitsets, std::size_t threadCount = 1) + { + std::vector > matchingV(bitsets.size()); + + if(threadCount <= 1) + { + for(std::size_t i = 0; i < currentSize; ++i) + { + if(!isAlive(i)) + { + continue; + } + for(std::size_t j = 0; j < bitsets.size(); ++j) + { + if(((*bitsets[j]) & std::get(entities[i])) + == (*bitsets[j])) + { + matchingV[j].push_back(i); + } + } + } + } + else + { + std::vector threads(threadCount); + std::size_t s = currentSize / threadCount; + std::mutex mutex; + for(std::size_t i = 0; i < threadCount; ++i) + { + std::size_t begin = s * i; + std::size_t end; + if(i == threadCount - 1) + { + end = currentSize; + } + else + { + end = s * (i + 1); + } + threads[i] = std::thread( + [this, &matchingV, &bitsets, &mutex] + (std::size_t begin, std::size_t end) + { + for(std::size_t j = begin; j < end; ++j) + { + if(!isAlive(j)) + { + continue; + } + for(std::size_t k = 0; k < bitsets.size(); ++k) + { + if(((*bitsets[k]) & + std::get(entities[j])) + == (*bitsets[k])) + { + std::lock_guard guard(mutex); + matchingV[k].push_back(j); + } + } + } + }, begin, end); + } + for(std::size_t i = 0; i < threadCount; ++i) + { + threads[i].join(); + } + } + + return matchingV; + } + + public: + /*! \brief Call all stored functions. @@ -922,11 +986,23 @@ namespace EC */ void callForMatchingFunctions(std::size_t threadCount = 1) { - for(auto functionIter = forMatchingFunctions.begin(); - functionIter != forMatchingFunctions.end(); - ++functionIter) + std::vector bitsets; + for(auto iter = forMatchingFunctions.begin(); + iter != forMatchingFunctions.end(); + ++iter) { - functionIter->second(threadCount); + bitsets.push_back(&std::get(iter->second)); + } + + std::vector > matching = + getMatchingEntities(bitsets, threadCount); + + std::size_t i = 0; + for(auto iter = forMatchingFunctions.begin(); + iter != forMatchingFunctions.end(); + ++iter) + { + std::get<1>(iter->second)(threadCount, matching[i++]); } } @@ -966,7 +1042,10 @@ namespace EC { return false; } - iter->second(threadCount); + std::vector > matching = + getMatchingEntities(std::vector{ + &std::get(iter->second)}, threadCount); + std::get<1>(iter->second)(threadCount, matching[0]); return true; }