]> git.seodisparate.com - UDPConnection/commitdiff
Add iterator to TSLQueue
authorStephen Seo <seo.disparate@gmail.com>
Sun, 3 Nov 2019 09:46:25 +0000 (18:46 +0900)
committerStephen Seo <seo.disparate@gmail.com>
Sun, 3 Nov 2019 09:46:25 +0000 (18:46 +0900)
Removed shared_ptr "locks", replaced by iterator owning a lock_guard.

cpp_impl/src/TSLQueue.hpp
cpp_impl/src/test/TestTSLQueue.cpp

index eaa73c2cae4afd600f1193b15d2563f62cb15ec5..5b7410feb1103758e7a5188b59e6d4597f31af4b 100644 (file)
@@ -7,7 +7,6 @@
 #include <chrono>
 #include <optional>
 #include <cassert>
-
 #include <list>
 #include <type_traits>
 
@@ -38,7 +37,7 @@ class TSLQueue {
 
   private:
     struct TSLQNode {
-        TSLQNode() = default;
+        TSLQNode();
         // disable copy
         TSLQNode(TSLQNode& other) = delete;
         TSLQNode& operator=(TSLQNode& other) = delete;
@@ -49,10 +48,37 @@ class TSLQueue {
         std::shared_ptr<TSLQNode> next;
         std::weak_ptr<TSLQNode> prev;
         std::unique_ptr<T> data;
+
+        enum TSLQN_Type {
+            TSLQN_NORMAL,
+            TSLQN_HEAD,
+            TSLQN_TAIL
+        };
+
+        TSLQN_Type type;
+        bool isNormal() const;
+    };
+
+    class TSLQIter {
+    public:
+        TSLQIter(std::mutex &mutex,
+                 std::weak_ptr<TSLQNode> currentNode);
+        ~TSLQIter();
+
+        std::optional<T> current();
+        bool next();
+        bool prev();
+        bool remove();
+
+    private:
+        std::lock_guard<std::mutex> lock;
+        std::weak_ptr<TSLQNode> currentNode;
     };
 
-    std::shared_ptr<char> iterValid;
-    std::shared_ptr<char> iterWrapperCount;
+  public:
+    TSLQIter begin();
+
+  private:
     std::mutex mutex;
     std::shared_ptr<TSLQNode> head;
     std::shared_ptr<TSLQNode> tail;
@@ -61,14 +87,14 @@ class TSLQueue {
 
 template <typename T>
 TSLQueue<T>::TSLQueue() :
-    iterValid(std::make_shared<char>()),
-    iterWrapperCount(std::make_shared<char>()),
     head(std::make_shared<TSLQNode>()),
     tail(std::make_shared<TSLQNode>()),
     msize(0)
 {
     head->next = tail;
     tail->prev = head;
+    head->type = TSLQNode::TSLQN_Type::TSLQN_HEAD;
+    tail->type = TSLQNode::TSLQN_Type::TSLQN_TAIL;
 }
 
 template <typename T>
@@ -76,9 +102,7 @@ TSLQueue<T>::~TSLQueue() {
 }
 
 template <typename T>
-TSLQueue<T>::TSLQueue(TSLQueue &&other) :
-    iterValid(std::make_shared<char>()),
-    iterWrapperCount(std::make_shared<char>())
+TSLQueue<T>::TSLQueue(TSLQueue &&other)
 {
     std::lock_guard lock(other.mutex);
     head = std::move(other.head);
@@ -88,8 +112,6 @@ TSLQueue<T>::TSLQueue(TSLQueue &&other) :
 
 template <typename T>
 TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) {
-    iterValid = std::make_shared<char>();
-    iterWrapperCount = std::make_shared<char>();
     std::scoped_lock lock(mutex, other.mutex);
     head = std::move(other.head);
     tail = std::move(other.tail);
@@ -98,9 +120,6 @@ TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) {
 
 template <typename T>
 void TSLQueue<T>::push(const T &data) {
-    while(iterWrapperCount.use_count() > 1) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-    }
     std::lock_guard lock(mutex);
     auto newNode = std::make_shared<TSLQNode>();
     newNode->data = std::make_unique<T>(data);
@@ -117,9 +136,7 @@ void TSLQueue<T>::push(const T &data) {
 
 template <typename T>
 bool TSLQueue<T>::push_nb(const T &data) {
-    if(iterWrapperCount.use_count() > 1) {
-        return false;
-    } else if(mutex.try_lock()) {
+    if(mutex.try_lock()) {
         auto newNode = std::make_shared<TSLQNode>();
         newNode->data = std::make_unique<T>(data);
 
@@ -141,9 +158,6 @@ bool TSLQueue<T>::push_nb(const T &data) {
 
 template <typename T>
 std::optional<T> TSLQueue<T>::top() {
-    while(iterWrapperCount.use_count() > 1) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-    }
     std::lock_guard lock(mutex);
     if(head->next != tail) {
         assert(head->next->data);
@@ -155,9 +169,7 @@ std::optional<T> TSLQueue<T>::top() {
 
 template <typename T>
 std::optional<T> TSLQueue<T>::top_nb() {
-    if(iterWrapperCount.use_count() > 1) {
-        return std::nullopt;
-    } else if(mutex.try_lock()) {
+    if(mutex.try_lock()) {
         std::optional<T> ret = std::nullopt;
         if(head->next != tail) {
             assert(head->next->data);
@@ -172,9 +184,6 @@ std::optional<T> TSLQueue<T>::top_nb() {
 
 template <typename T>
 bool TSLQueue<T>::pop() {
-    while(iterWrapperCount.use_count() > 1) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-    }
     std::lock_guard lock(mutex);
     if(head->next == tail) {
         return false;
@@ -185,8 +194,6 @@ bool TSLQueue<T>::pop() {
         assert(msize > 0);
         --msize;
 
-        iterValid = std::make_shared<char>();
-        iterWrapperCount = std::make_shared<char>();
         return true;
     }
 }
@@ -194,9 +201,6 @@ bool TSLQueue<T>::pop() {
 template <typename T>
 std::optional<T> TSLQueue<T>::top_and_pop() {
     std::optional<T> ret = std::nullopt;
-    while(iterWrapperCount.use_count() > 1) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-    }
     std::lock_guard lock(mutex);
     if(head->next != tail) {
         assert(head->next->data);
@@ -207,9 +211,6 @@ std::optional<T> TSLQueue<T>::top_and_pop() {
         head->next = newNext;
         assert(msize > 0);
         --msize;
-
-        iterValid = std::make_shared<char>();
-        iterWrapperCount = std::make_shared<char>();
     }
     return ret;
 }
@@ -217,9 +218,6 @@ std::optional<T> TSLQueue<T>::top_and_pop() {
 template <typename T>
 std::optional<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
     std::optional<T> ret = std::nullopt;
-    while(iterWrapperCount.use_count() > 1) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-    }
     std::lock_guard lock(mutex);
     if(head->next == tail) {
         if(isEmpty) {
@@ -235,8 +233,6 @@ std::optional<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
         assert(msize > 0);
         --msize;
 
-        iterValid = std::make_shared<char>();
-        iterWrapperCount = std::make_shared<char>();
         if(isEmpty) {
             *isEmpty = head->next == tail;
         }
@@ -246,35 +242,105 @@ std::optional<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
 
 template <typename T>
 void TSLQueue<T>::clear() {
-    while(iterWrapperCount.use_count() > 1) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-    }
     std::lock_guard lock(mutex);
 
     head->next = tail;
     tail->prev = head;
     msize = 0;
-
-    iterValid = std::make_shared<char>();
-    iterWrapperCount = std::make_shared<char>();
 }
 
 template <typename T>
 bool TSLQueue<T>::empty() {
-    while(iterWrapperCount.use_count() > 1) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-    }
     std::lock_guard lock(mutex);
     return head->next == tail;
 }
 
 template <typename T>
 unsigned long long TSLQueue<T>::size() {
-    while(iterWrapperCount.use_count() > 1) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-    }
     std::lock_guard lock(mutex);
     return msize;
 }
 
+template <typename T>
+TSLQueue<T>::TSLQNode::TSLQNode() :
+type(TSLQN_Type::TSLQN_NORMAL)
+{}
+
+template <typename T>
+bool TSLQueue<T>::TSLQNode::isNormal() const {
+    return type == TSLQN_Type::TSLQN_NORMAL;
+}
+
+template <typename T>
+TSLQueue<T>::TSLQIter::TSLQIter(std::mutex &mutex,
+                                std::weak_ptr<TSLQNode> currentNode) :
+lock(mutex),
+currentNode(currentNode)
+{
+}
+
+template <typename T>
+TSLQueue<T>::TSLQIter::~TSLQIter() {
+}
+
+template <typename T>
+std::optional<T> TSLQueue<T>::TSLQIter::current() {
+    std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
+    assert(currentNode);
+    if(currentNode->isNormal()) {
+        return *currentNode->data.get();
+    } else {
+        return std::nullopt;
+    }
+}
+
+template <typename T>
+bool TSLQueue<T>::TSLQIter::next() {
+    std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
+    assert(currentNode);
+    if(currentNode->type == TSLQNode::TSLQN_Type::TSLQN_TAIL) {
+        return false;
+    }
+
+    this->currentNode = currentNode->next;
+    return currentNode->next->type != TSLQNode::TSLQN_Type::TSLQN_TAIL;
+}
+
+template <typename T>
+bool TSLQueue<T>::TSLQIter::prev() {
+    std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
+    assert(currentNode);
+    if(currentNode->type == TSLQNode::TSLQN_Type::TSLQN_HEAD) {
+        return false;
+    }
+
+    auto parent = currentNode->prev.lock();
+    assert(parent);
+    this->currentNode = currentNode->prev;
+    return parent->type != TSLQNode::TSLQN_Type::TSLQN_HEAD;
+}
+
+template <typename T>
+bool TSLQueue<T>::TSLQIter::remove() {
+    std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
+    assert(currentNode);
+    if(!currentNode->isNormal()) {
+        return false;
+    }
+
+    this->currentNode = currentNode->next;
+    auto parent = currentNode->prev.lock();
+    assert(parent);
+
+    currentNode->next->prev = parent;
+    parent->next = currentNode->next;
+
+    return parent->next->isNormal();
+}
+
+template <typename T>
+typename TSLQueue<T>::TSLQIter TSLQueue<T>::begin() {
+    return TSLQIter(mutex, head->next);
+}
+
 #endif
index cd961047f4e718ea7021747eb0721f54525ae8d2..9fab4a40e812db63dfc8c1f44874db8d527ccf2e 100644 (file)
@@ -116,3 +116,76 @@ TEST(TSLQueue, Concurrent) {
     }
     EXPECT_EQ(q.size(), 0);
 }
+
+TEST(TSLQueue, Iterator) {
+    TSLQueue<int> q;
+
+    for(int i = 0; i < 10; ++i) {
+        q.push(i);
+    }
+
+    {
+        // iteration
+        auto iter = q.begin();
+        int i = 0;
+        auto op = iter.current();
+        while(op.has_value()) {
+            EXPECT_EQ(op.value(), i++);
+            if(i < 10) {
+                EXPECT_TRUE(iter.next());
+            } else {
+                EXPECT_FALSE(iter.next());
+            }
+            op = iter.current();
+        }
+
+        // test that lock is held by iterator
+        EXPECT_FALSE(q.push_nb(10));
+        op = q.top_nb();
+        EXPECT_FALSE(op.has_value());
+
+        // backwards iteration
+        EXPECT_TRUE(iter.prev());
+        op = iter.current();
+        while(op.has_value()) {
+            EXPECT_EQ(op.value(), --i);
+            if(i > 0) {
+                EXPECT_TRUE(iter.prev());
+            } else {
+                EXPECT_FALSE(iter.prev());
+            }
+            op = iter.current();
+        }
+    }
+
+    {
+        // iter remove
+        auto iter = q.begin();
+        EXPECT_TRUE(iter.next());
+        EXPECT_TRUE(iter.next());
+        EXPECT_TRUE(iter.next());
+        EXPECT_TRUE(iter.remove());
+
+        auto op = iter.current();
+        EXPECT_TRUE(op.has_value());
+        EXPECT_EQ(op.value(), 4);
+
+        EXPECT_TRUE(iter.prev());
+        op = iter.current();
+        EXPECT_TRUE(op.has_value());
+        EXPECT_EQ(op.value(), 2);
+    }
+
+    // check that "3" was removed from queue
+    int i = 0;
+    std::optional<int> op;
+    while(!q.empty()) {
+        op = q.top();
+        EXPECT_TRUE(op.has_value());
+        EXPECT_EQ(i++, op.value());
+        if(i == 3) {
+            ++i;
+        }
+        EXPECT_TRUE(q.pop());
+    }
+}