diff --git a/cpp_impl/src/TSLQueue.hpp b/cpp_impl/src/TSLQueue.hpp index eaa73c2..5b7410f 100644 --- a/cpp_impl/src/TSLQueue.hpp +++ b/cpp_impl/src/TSLQueue.hpp @@ -7,7 +7,6 @@ #include #include #include - #include #include @@ -38,7 +37,7 @@ class TSLQueue { private: struct TSLQNode { - TSLQNode() = default; + TSLQNode(); // disable copy TSLQNode(TSLQNode& other) = delete; TSLQNode& operator=(TSLQNode& other) = delete; @@ -49,10 +48,37 @@ class TSLQueue { std::shared_ptr next; std::weak_ptr prev; std::unique_ptr data; + + enum TSLQN_Type { + TSLQN_NORMAL, + TSLQN_HEAD, + TSLQN_TAIL + }; + + TSLQN_Type type; + bool isNormal() const; }; - std::shared_ptr iterValid; - std::shared_ptr iterWrapperCount; + class TSLQIter { + public: + TSLQIter(std::mutex &mutex, + std::weak_ptr currentNode); + ~TSLQIter(); + + std::optional current(); + bool next(); + bool prev(); + bool remove(); + + private: + std::lock_guard lock; + std::weak_ptr currentNode; + }; + + public: + TSLQIter begin(); + + private: std::mutex mutex; std::shared_ptr head; std::shared_ptr tail; @@ -61,14 +87,14 @@ class TSLQueue { template TSLQueue::TSLQueue() : - iterValid(std::make_shared()), - iterWrapperCount(std::make_shared()), head(std::make_shared()), tail(std::make_shared()), msize(0) { head->next = tail; tail->prev = head; + head->type = TSLQNode::TSLQN_Type::TSLQN_HEAD; + tail->type = TSLQNode::TSLQN_Type::TSLQN_TAIL; } template @@ -76,9 +102,7 @@ TSLQueue::~TSLQueue() { } template -TSLQueue::TSLQueue(TSLQueue &&other) : - iterValid(std::make_shared()), - iterWrapperCount(std::make_shared()) +TSLQueue::TSLQueue(TSLQueue &&other) { std::lock_guard lock(other.mutex); head = std::move(other.head); @@ -88,8 +112,6 @@ TSLQueue::TSLQueue(TSLQueue &&other) : template TSLQueue & TSLQueue::operator=(TSLQueue &&other) { - iterValid = std::make_shared(); - iterWrapperCount = std::make_shared(); std::scoped_lock lock(mutex, other.mutex); head = std::move(other.head); tail = std::move(other.tail); @@ -98,9 +120,6 @@ TSLQueue & TSLQueue::operator=(TSLQueue &&other) { template void TSLQueue::push(const T &data) { - while(iterWrapperCount.use_count() > 1) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } std::lock_guard lock(mutex); auto newNode = std::make_shared(); newNode->data = std::make_unique(data); @@ -117,9 +136,7 @@ void TSLQueue::push(const T &data) { template bool TSLQueue::push_nb(const T &data) { - if(iterWrapperCount.use_count() > 1) { - return false; - } else if(mutex.try_lock()) { + if(mutex.try_lock()) { auto newNode = std::make_shared(); newNode->data = std::make_unique(data); @@ -141,9 +158,6 @@ bool TSLQueue::push_nb(const T &data) { template std::optional TSLQueue::top() { - while(iterWrapperCount.use_count() > 1) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } std::lock_guard lock(mutex); if(head->next != tail) { assert(head->next->data); @@ -155,9 +169,7 @@ std::optional TSLQueue::top() { template std::optional TSLQueue::top_nb() { - if(iterWrapperCount.use_count() > 1) { - return std::nullopt; - } else if(mutex.try_lock()) { + if(mutex.try_lock()) { std::optional ret = std::nullopt; if(head->next != tail) { assert(head->next->data); @@ -172,9 +184,6 @@ std::optional TSLQueue::top_nb() { template bool TSLQueue::pop() { - while(iterWrapperCount.use_count() > 1) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } std::lock_guard lock(mutex); if(head->next == tail) { return false; @@ -185,8 +194,6 @@ bool TSLQueue::pop() { assert(msize > 0); --msize; - iterValid = std::make_shared(); - iterWrapperCount = std::make_shared(); return true; } } @@ -194,9 +201,6 @@ bool TSLQueue::pop() { template std::optional TSLQueue::top_and_pop() { std::optional ret = std::nullopt; - while(iterWrapperCount.use_count() > 1) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } std::lock_guard lock(mutex); if(head->next != tail) { assert(head->next->data); @@ -207,9 +211,6 @@ std::optional TSLQueue::top_and_pop() { head->next = newNext; assert(msize > 0); --msize; - - iterValid = std::make_shared(); - iterWrapperCount = std::make_shared(); } return ret; } @@ -217,9 +218,6 @@ std::optional TSLQueue::top_and_pop() { template std::optional TSLQueue::top_and_pop_and_empty(bool *isEmpty) { std::optional ret = std::nullopt; - while(iterWrapperCount.use_count() > 1) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } std::lock_guard lock(mutex); if(head->next == tail) { if(isEmpty) { @@ -235,8 +233,6 @@ std::optional TSLQueue::top_and_pop_and_empty(bool *isEmpty) { assert(msize > 0); --msize; - iterValid = std::make_shared(); - iterWrapperCount = std::make_shared(); if(isEmpty) { *isEmpty = head->next == tail; } @@ -246,35 +242,105 @@ std::optional TSLQueue::top_and_pop_and_empty(bool *isEmpty) { template void TSLQueue::clear() { - while(iterWrapperCount.use_count() > 1) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } std::lock_guard lock(mutex); head->next = tail; tail->prev = head; msize = 0; - - iterValid = std::make_shared(); - iterWrapperCount = std::make_shared(); } template bool TSLQueue::empty() { - while(iterWrapperCount.use_count() > 1) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } std::lock_guard lock(mutex); return head->next == tail; } template unsigned long long TSLQueue::size() { - while(iterWrapperCount.use_count() > 1) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } std::lock_guard lock(mutex); return msize; } +template +TSLQueue::TSLQNode::TSLQNode() : +type(TSLQN_Type::TSLQN_NORMAL) +{} + +template +bool TSLQueue::TSLQNode::isNormal() const { + return type == TSLQN_Type::TSLQN_NORMAL; +} + +template +TSLQueue::TSLQIter::TSLQIter(std::mutex &mutex, + std::weak_ptr currentNode) : +lock(mutex), +currentNode(currentNode) +{ +} + +template +TSLQueue::TSLQIter::~TSLQIter() { +} + +template +std::optional TSLQueue::TSLQIter::current() { + std::shared_ptr currentNode = this->currentNode.lock(); + assert(currentNode); + if(currentNode->isNormal()) { + return *currentNode->data.get(); + } else { + return std::nullopt; + } +} + +template +bool TSLQueue::TSLQIter::next() { + std::shared_ptr currentNode = this->currentNode.lock(); + assert(currentNode); + if(currentNode->type == TSLQNode::TSLQN_Type::TSLQN_TAIL) { + return false; + } + + this->currentNode = currentNode->next; + return currentNode->next->type != TSLQNode::TSLQN_Type::TSLQN_TAIL; +} + +template +bool TSLQueue::TSLQIter::prev() { + std::shared_ptr currentNode = this->currentNode.lock(); + assert(currentNode); + if(currentNode->type == TSLQNode::TSLQN_Type::TSLQN_HEAD) { + return false; + } + + auto parent = currentNode->prev.lock(); + assert(parent); + this->currentNode = currentNode->prev; + return parent->type != TSLQNode::TSLQN_Type::TSLQN_HEAD; +} + +template +bool TSLQueue::TSLQIter::remove() { + std::shared_ptr currentNode = this->currentNode.lock(); + assert(currentNode); + if(!currentNode->isNormal()) { + return false; + } + + this->currentNode = currentNode->next; + auto parent = currentNode->prev.lock(); + assert(parent); + + currentNode->next->prev = parent; + parent->next = currentNode->next; + + return parent->next->isNormal(); +} + +template +typename TSLQueue::TSLQIter TSLQueue::begin() { + return TSLQIter(mutex, head->next); +} + #endif diff --git a/cpp_impl/src/test/TestTSLQueue.cpp b/cpp_impl/src/test/TestTSLQueue.cpp index cd96104..9fab4a4 100644 --- a/cpp_impl/src/test/TestTSLQueue.cpp +++ b/cpp_impl/src/test/TestTSLQueue.cpp @@ -116,3 +116,76 @@ TEST(TSLQueue, Concurrent) { } EXPECT_EQ(q.size(), 0); } + +TEST(TSLQueue, Iterator) { + TSLQueue q; + + for(int i = 0; i < 10; ++i) { + q.push(i); + } + + { + // iteration + auto iter = q.begin(); + int i = 0; + auto op = iter.current(); + while(op.has_value()) { + EXPECT_EQ(op.value(), i++); + if(i < 10) { + EXPECT_TRUE(iter.next()); + } else { + EXPECT_FALSE(iter.next()); + } + op = iter.current(); + } + + // test that lock is held by iterator + EXPECT_FALSE(q.push_nb(10)); + op = q.top_nb(); + EXPECT_FALSE(op.has_value()); + + // backwards iteration + EXPECT_TRUE(iter.prev()); + op = iter.current(); + while(op.has_value()) { + EXPECT_EQ(op.value(), --i); + if(i > 0) { + EXPECT_TRUE(iter.prev()); + } else { + EXPECT_FALSE(iter.prev()); + } + op = iter.current(); + } + } + + { + // iter remove + auto iter = q.begin(); + EXPECT_TRUE(iter.next()); + EXPECT_TRUE(iter.next()); + EXPECT_TRUE(iter.next()); + EXPECT_TRUE(iter.remove()); + + auto op = iter.current(); + EXPECT_TRUE(op.has_value()); + EXPECT_EQ(op.value(), 4); + + EXPECT_TRUE(iter.prev()); + op = iter.current(); + EXPECT_TRUE(op.has_value()); + EXPECT_EQ(op.value(), 2); + } + + // check that "3" was removed from queue + int i = 0; + std::optional op; + while(!q.empty()) { + op = q.top(); + EXPECT_TRUE(op.has_value()); + EXPECT_EQ(i++, op.value()); + if(i == 3) { + ++i; + } + EXPECT_TRUE(q.pop()); + } +}