]> git.seodisparate.com - UDPConnection/commitdiff
Reland C++11 "shared_lock" with iter remove fix
authorStephen Seo <seo.disparate@gmail.com>
Sat, 22 Jul 2023 09:58:36 +0000 (18:58 +0900)
committerStephen Seo <seo.disparate@gmail.com>
Fri, 12 Jan 2024 04:32:05 +0000 (13:32 +0900)
On iterator remove, the iterator will trade the read lock for a write
lock, and trade back for a read lock once the remove has been completed.

CMakeLists.txt
src/CXX11_shared_spin_lock.cpp [new file with mode: 0644]
src/CXX11_shared_spin_lock.hpp [new file with mode: 0644]
src/TSLQueue.hpp
src/test/TestTSLQueue.cpp

index 50a67cf0ec49fe5f87731538ab332158134fedd8..a504c7b89bee989ec8777cdb91c1e6ba8e672e22 100644 (file)
@@ -5,6 +5,7 @@ set(UDPC_VERSION 1.0)
 
 set(UDPC_SOURCES
     src/UDPConnection.cpp
+    src/CXX11_shared_spin_lock.cpp
 )
 
 add_compile_options(
@@ -62,6 +63,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug")
     find_package(GTest QUIET)
     if(GTEST_FOUND)
         set(UDPC_UnitTest_SOURCES
+            src/CXX11_shared_spin_lock.cpp
             src/test/UDPC_UnitTest.cpp
             src/test/TestTSLQueue.cpp
             src/test/TestUDPC.cpp
diff --git a/src/CXX11_shared_spin_lock.cpp b/src/CXX11_shared_spin_lock.cpp
new file mode 100644 (file)
index 0000000..7778155
--- /dev/null
@@ -0,0 +1,142 @@
+#include "CXX11_shared_spin_lock.hpp"
+
+UDPC::Badge UDPC::Badge::newInvalid() {
+    Badge badge;
+    badge.isValid = false;
+    return badge;
+}
+
+UDPC::Badge::Badge() :
+isValid(true)
+{}
+
+UDPC::SharedSpinLock::Ptr UDPC::SharedSpinLock::newInstance() {
+    Ptr sharedSpinLock = Ptr(new SharedSpinLock());
+    sharedSpinLock->selfWeakPtr = sharedSpinLock;
+    return sharedSpinLock;
+}
+
+UDPC::SharedSpinLock::SharedSpinLock() :
+selfWeakPtr(),
+mutex(),
+read(0),
+write(false)
+{}
+
+UDPC::LockObj<false> UDPC::SharedSpinLock::spin_read_lock() {
+    while (true) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!write) {
+            ++read;
+            return LockObj<false>(selfWeakPtr, Badge{});
+        }
+    }
+}
+
+UDPC::LockObj<false> UDPC::SharedSpinLock::try_spin_read_lock() {
+    std::lock_guard<std::mutex> lock(mutex);
+    if (!write) {
+        ++read;
+        return LockObj<false>(selfWeakPtr, Badge{});
+    }
+    return LockObj<false>(Badge{});
+}
+
+void UDPC::SharedSpinLock::read_unlock(UDPC::Badge &&badge) {
+    if (badge.isValid) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (read > 0) {
+            --read;
+            badge.isValid = false;
+        }
+    }
+}
+
+UDPC::LockObj<true> UDPC::SharedSpinLock::spin_write_lock() {
+    while (true) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!write && read == 0) {
+            write = true;
+            return LockObj<true>(selfWeakPtr, Badge{});
+        }
+    }
+}
+
+UDPC::LockObj<true> UDPC::SharedSpinLock::try_spin_write_lock() {
+    std::lock_guard<std::mutex> lock(mutex);
+    if (!write && read == 0) {
+        write = true;
+        return LockObj<true>(selfWeakPtr, Badge{});
+    }
+    return LockObj<true>(Badge{});
+}
+
+void UDPC::SharedSpinLock::write_unlock(UDPC::Badge &&badge) {
+    if (badge.isValid) {
+        std::lock_guard<std::mutex> lock(mutex);
+        write = false;
+        badge.isValid = false;
+    }
+}
+
+UDPC::LockObj<false> UDPC::SharedSpinLock::trade_write_for_read_lock(UDPC::LockObj<true> &lockObj) {
+    if (lockObj.isValid() && lockObj.badge.isValid) {
+        while (true) {
+            std::lock_guard<std::mutex> lock(mutex);
+            if (write && read == 0) {
+                read = 1;
+                write = false;
+                lockObj.isLocked = false;
+                lockObj.badge.isValid = false;
+                return LockObj<false>(selfWeakPtr, Badge{});
+            }
+        }
+    } else {
+        return LockObj<false>(Badge{});
+    }
+}
+
+UDPC::LockObj<false> UDPC::SharedSpinLock::try_trade_write_for_read_lock(UDPC::LockObj<true> &lockObj) {
+    if (lockObj.isValid() && lockObj.badge.isValid) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (write && read == 0) {
+            read = 1;
+            write = false;
+            lockObj.isLocked = false;
+            lockObj.badge.isValid = false;
+            return LockObj<false>(selfWeakPtr, Badge{});
+        }
+    }
+    return LockObj<false>(Badge{});
+}
+
+UDPC::LockObj<true> UDPC::SharedSpinLock::trade_read_for_write_lock(UDPC::LockObj<false> &lockObj) {
+    if (lockObj.isValid() && lockObj.badge.isValid) {
+        while (true) {
+            std::lock_guard<std::mutex> lock(mutex);
+            if (!write && read == 1) {
+                read = 0;
+                write = true;
+                lockObj.isLocked = false;
+                lockObj.badge.isValid = false;
+                return LockObj<true>(selfWeakPtr, Badge{});
+            }
+        }
+    } else {
+        return LockObj<true>(Badge{});
+    }
+}
+
+UDPC::LockObj<true> UDPC::SharedSpinLock::try_trade_read_for_write_lock(UDPC::LockObj<false> &lockObj) {
+    if (lockObj.isValid() && lockObj.badge.isValid) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!write && read == 1) {
+            read = 0;
+            write = true;
+            lockObj.isLocked = false;
+            lockObj.badge.isValid = false;
+            return LockObj<true>(selfWeakPtr, Badge{});
+        }
+    }
+    return LockObj<true>(Badge{});
+}
diff --git a/src/CXX11_shared_spin_lock.hpp b/src/CXX11_shared_spin_lock.hpp
new file mode 100644 (file)
index 0000000..67c2f22
--- /dev/null
@@ -0,0 +1,154 @@
+#ifndef UDPC_CXX11_SHARED_SPIN_LOCK_H_
+#define UDPC_CXX11_SHARED_SPIN_LOCK_H_
+
+#include <memory>
+#include <mutex>
+#include <atomic>
+
+namespace UDPC {
+
+// Forward declaration for LockObj.
+class SharedSpinLock;
+
+class Badge {
+public:
+    static Badge newInvalid();
+
+    // Disallow copy.
+    Badge(const Badge&) = delete;
+    Badge& operator=(const Badge&) = delete;
+
+    // Allow move.
+    Badge(Badge&&) = default;
+    Badge& operator=(Badge&&) = default;
+
+private:
+    friend class SharedSpinLock;
+
+    // Can only be created by SharedSpinLock.
+    Badge();
+
+    bool isValid;
+};
+
+template <bool IsWriteObj>
+class LockObj {
+public:
+    // Invalid instance constructor.
+    LockObj();
+
+    ~LockObj();
+
+    // Explicit invalid instance constructor.
+    static LockObj<IsWriteObj> newInvalid();
+
+    // Disallow copy.
+    LockObj(const LockObj&) = delete;
+    LockObj& operator=(const LockObj&) = delete;
+
+    // Allow move.
+    LockObj(LockObj&&) = default;
+    LockObj& operator=(LockObj&&) = default;
+
+    bool isValid() const;
+
+private:
+    friend class SharedSpinLock;
+
+    // Only can be created by SharedSpinLock.
+    LockObj(Badge &&badge);
+    LockObj(std::weak_ptr<SharedSpinLock> lockPtr, Badge &&badge);
+
+    std::weak_ptr<SharedSpinLock> weakPtrLock;
+    bool isLocked;
+    Badge badge;
+};
+
+class SharedSpinLock {
+public:
+    using Ptr = std::shared_ptr<SharedSpinLock>;
+    using Weak = std::weak_ptr<SharedSpinLock>;
+
+    static Ptr newInstance();
+
+    // Disallow copy.
+    SharedSpinLock(const SharedSpinLock&) = delete;
+    SharedSpinLock& operator=(const SharedSpinLock&) = delete;
+
+    // Allow move.
+    SharedSpinLock(SharedSpinLock&&) = default;
+    SharedSpinLock& operator=(SharedSpinLock&&) = default;
+
+    LockObj<false> spin_read_lock();
+    LockObj<false> try_spin_read_lock();
+    void read_unlock(Badge&&);
+
+    LockObj<true> spin_write_lock();
+    LockObj<true> try_spin_write_lock();
+    void write_unlock(Badge&&);
+
+    LockObj<false> trade_write_for_read_lock(LockObj<true>&);
+    LockObj<false> try_trade_write_for_read_lock(LockObj<true>&);
+
+    LockObj<true> trade_read_for_write_lock(LockObj<false>&);
+    LockObj<true> try_trade_read_for_write_lock(LockObj<false>&);
+
+private:
+    SharedSpinLock();
+
+    Weak selfWeakPtr;
+    std::mutex mutex;
+    unsigned int read;
+    bool write;
+
+};
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj>::LockObj() :
+weakPtrLock(),
+isLocked(false),
+badge(UDPC::Badge::newInvalid())
+{}
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj>::LockObj(Badge &&badge) :
+weakPtrLock(),
+isLocked(false),
+badge(std::forward<Badge>(badge))
+{}
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj>::LockObj(SharedSpinLock::Weak lockPtr, Badge &&badge) :
+weakPtrLock(lockPtr),
+isLocked(true),
+badge(std::forward<Badge>(badge))
+{}
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj>::~LockObj() {
+    if (!isLocked) {
+        return;
+    }
+    auto strongPtrLock = weakPtrLock.lock();
+    if (strongPtrLock) {
+        if (IsWriteObj) {
+            strongPtrLock->write_unlock(std::move(badge));
+        } else {
+            strongPtrLock->read_unlock(std::move(badge));
+        }
+    }
+}
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj> LockObj<IsWriteObj>::newInvalid() {
+    return LockObj<IsWriteObj>{};
+}
+
+template <bool IsWriteObj>
+bool LockObj<IsWriteObj>::isValid() const {
+    return isLocked;
+}
+
+} // namespace UDPC
+
+#endif
index d2c27f8b13efae996f6862b6643450e9667d8f62..58b468f9d4b97eac3e012d95e18e3c5d2d56ac87 100644 (file)
@@ -10,6 +10,8 @@
 #include <list>
 #include <type_traits>
 
+#include "CXX11_shared_spin_lock.hpp"
+
 template <typename T>
 class TSLQueue {
   public:
@@ -62,7 +64,7 @@ class TSLQueue {
 
     class TSLQIter {
     public:
-        TSLQIter(std::mutex *mutex,
+        TSLQIter(UDPC::SharedSpinLock::Weak sharedSpinLockWeak,
                  std::weak_ptr<TSLQNode> currentNode,
                  unsigned long *msize);
         ~TSLQIter();
@@ -75,19 +77,24 @@ class TSLQueue {
         bool next();
         bool prev();
         bool remove();
+        bool try_remove();
 
     private:
-        std::mutex *mutex;
+        UDPC::SharedSpinLock::Weak sharedSpinLockWeak;
+        std::unique_ptr<UDPC::LockObj<false>> readLock;
+        std::unique_ptr<UDPC::LockObj<true>> writeLock;
         std::weak_ptr<TSLQNode> currentNode;
         unsigned long *const msize;
 
+        bool remove_impl();
+
     };
 
   public:
     TSLQIter begin();
 
   private:
-    std::mutex mutex;
+    UDPC::SharedSpinLock::Ptr sharedSpinLock;
     std::shared_ptr<TSLQNode> head;
     std::shared_ptr<TSLQNode> tail;
     unsigned long msize;
@@ -95,7 +102,7 @@ class TSLQueue {
 
 template <typename T>
 TSLQueue<T>::TSLQueue() :
-    mutex(),
+    sharedSpinLock(UDPC::SharedSpinLock::newInstance()),
     head(std::shared_ptr<TSLQNode>(new TSLQNode())),
     tail(std::shared_ptr<TSLQNode>(new TSLQNode())),
     msize(0)
@@ -121,8 +128,8 @@ TSLQueue<T>::TSLQueue(TSLQueue &&other)
 
 template <typename T>
 TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) {
-    std::lock_guard<std::mutex> lock(mutex);
-    std::lock_guard<std::mutex> otherLock(other.mutex);
+    auto selfWriteLock = sharedSpinLock->spin_write_lock();
+    auto otherWriteLock = other.sharedSpinLock->spin_write_lock();
     head = std::move(other.head);
     tail = std::move(other.tail);
     msize = std::move(other.msize);
@@ -130,7 +137,7 @@ TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) {
 
 template <typename T>
 void TSLQueue<T>::push(const T &data) {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     auto newNode = std::shared_ptr<TSLQNode>(new TSLQNode());
     newNode->data = std::unique_ptr<T>(new T(data));
 
@@ -146,7 +153,8 @@ void TSLQueue<T>::push(const T &data) {
 
 template <typename T>
 bool TSLQueue<T>::push_nb(const T &data) {
-    if(mutex.try_lock()) {
+    auto writeLock = sharedSpinLock->try_spin_write_lock();
+    if(writeLock.isValid()) {
         auto newNode = std::shared_ptr<TSLQNode>(new TSLQNode());
         newNode->data = std::unique_ptr<T>(new T(data));
 
@@ -159,7 +167,6 @@ bool TSLQueue<T>::push_nb(const T &data) {
         tail->prev = newNode;
         ++msize;
 
-        mutex.unlock();
         return true;
     } else {
         return false;
@@ -168,7 +175,7 @@ bool TSLQueue<T>::push_nb(const T &data) {
 
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto readLock = sharedSpinLock->spin_read_lock();
     std::unique_ptr<T> result;
     if(head->next != tail) {
         assert(head->next->data);
@@ -181,20 +188,20 @@ std::unique_ptr<T> TSLQueue<T>::top() {
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top_nb() {
     std::unique_ptr<T> result;
-    if(mutex.try_lock()) {
+    auto readLock = sharedSpinLock->try_spin_read_lock();
+    if(readLock.isValid()) {
         if(head->next != tail) {
             assert(head->next->data);
             result = std::unique_ptr<T>(new T);
             *result = *head->next->data;
         }
-        mutex.unlock();
     }
     return result;
 }
 
 template <typename T>
 bool TSLQueue<T>::pop() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     if(head->next == tail) {
         return false;
     } else {
@@ -211,7 +218,7 @@ bool TSLQueue<T>::pop() {
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top_and_pop() {
     std::unique_ptr<T> result;
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     if(head->next != tail) {
         assert(head->next->data);
         result = std::unique_ptr<T>(new T);
@@ -229,7 +236,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop() {
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
     std::unique_ptr<T> result;
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     if(head->next == tail) {
         if(isEmpty) {
             *isEmpty = true;
@@ -255,7 +262,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_rsize(unsigned long *rsize) {
     std::unique_ptr<T> result;
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     if(head->next == tail) {
         if(rsize) {
             *rsize = 0;
@@ -280,7 +287,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_rsize(unsigned long *rsize) {
 
 template <typename T>
 void TSLQueue<T>::clear() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
 
     head->next = tail;
     tail->prev = head;
@@ -289,13 +296,13 @@ void TSLQueue<T>::clear() {
 
 template <typename T>
 bool TSLQueue<T>::empty() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto readLock = sharedSpinLock->spin_read_lock();
     return head->next == tail;
 }
 
 template <typename T>
 unsigned long TSLQueue<T>::size() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto readLock = sharedSpinLock->spin_read_lock();
     return msize;
 }
 
@@ -313,20 +320,20 @@ bool TSLQueue<T>::TSLQNode::isNormal() const {
 }
 
 template <typename T>
-TSLQueue<T>::TSLQIter::TSLQIter(std::mutex *mutex,
+TSLQueue<T>::TSLQIter::TSLQIter(UDPC::SharedSpinLock::Weak lockWeak,
                                 std::weak_ptr<TSLQNode> currentNode,
                                 unsigned long *msize) :
-mutex(mutex),
+sharedSpinLockWeak(lockWeak),
+readLock(std::unique_ptr<UDPC::LockObj<false>>(new UDPC::LockObj<false>{})),
+writeLock(),
 currentNode(currentNode),
 msize(msize)
 {
-    mutex->lock();
+    *readLock = lockWeak.lock()->spin_read_lock();
 }
 
 template <typename T>
-TSLQueue<T>::TSLQIter::~TSLQIter() {
-    mutex->unlock();
-}
+TSLQueue<T>::TSLQIter::~TSLQIter() {}
 
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::TSLQIter::current() {
@@ -368,9 +375,61 @@ bool TSLQueue<T>::TSLQIter::prev() {
 
 template <typename T>
 bool TSLQueue<T>::TSLQIter::remove() {
+    if (readLock && !writeLock && readLock->isValid()) {
+        auto sharedSpinLockStrong = sharedSpinLockWeak.lock();
+        if (!sharedSpinLockStrong) {
+            return false;
+        }
+
+        writeLock = std::unique_ptr<UDPC::LockObj<true>>(new UDPC::LockObj<true>{});
+        *writeLock = sharedSpinLockStrong->trade_read_for_write_lock(*readLock);
+        readLock.reset(nullptr);
+
+        return remove_impl();
+    } else {
+        return false;
+    }
+}
+
+template <typename T>
+bool TSLQueue<T>::TSLQIter::try_remove() {
+    if (readLock && !writeLock && readLock->isValid()) {
+        auto sharedSpinLockStrong = sharedSpinLockWeak.lock();
+        if (!sharedSpinLockStrong) {
+            return false;
+        }
+
+        writeLock = std::unique_ptr<UDPC::LockObj<true>>(new UDPC::LockObj<true>{});
+        *writeLock = sharedSpinLockStrong->try_trade_read_for_write_lock(*readLock);
+        if (writeLock->isValid()) {
+            readLock.reset(nullptr);
+            return remove_impl();
+        } else {
+            writeLock.reset(nullptr);
+            return false;
+        }
+    } else {
+        return false;
+    }
+}
+
+template <typename T>
+bool TSLQueue<T>::TSLQIter::remove_impl() {
+    const auto cleanupWriteLock = [this] () {
+        UDPC::SharedSpinLock::Ptr sharedSpinLockStrong = this->sharedSpinLockWeak.lock();
+        if (!sharedSpinLockStrong) {
+            writeLock.reset(nullptr);
+            return;
+        }
+        this->readLock = std::unique_ptr<UDPC::LockObj<false>>(new UDPC::LockObj<false>{});
+        (*this->readLock) = sharedSpinLockStrong->trade_write_for_read_lock(*(this->writeLock));
+        this->writeLock.reset(nullptr);
+    };
+
     std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
     assert(currentNode);
     if(!currentNode->isNormal()) {
+        cleanupWriteLock();
         return false;
     }
 
@@ -384,12 +443,13 @@ bool TSLQueue<T>::TSLQIter::remove() {
     assert(*msize > 0);
     --(*msize);
 
+    cleanupWriteLock();
     return parent->next->isNormal();
 }
 
 template <typename T>
 typename TSLQueue<T>::TSLQIter TSLQueue<T>::begin() {
-    return TSLQIter(&mutex, head->next, &msize);
+    return TSLQIter(sharedSpinLock, head->next, &msize);
 }
 
 #endif
index f579712237fcf492ded1bc3dcea81e9530641a1d..b46de736ed8846798be165ae673b5289bc7ca4e0 100644 (file)
@@ -143,7 +143,8 @@ TEST(TSLQueue, Iterator) {
         // test that lock is held by iterator
         EXPECT_FALSE(q.push_nb(10));
         op = q.top_nb();
-        EXPECT_FALSE(op);
+        // Getting top and iterator both hold read locks so this should be true.
+        EXPECT_TRUE(op);
 
         // backwards iteration
         EXPECT_TRUE(iter.prev());
@@ -175,6 +176,21 @@ TEST(TSLQueue, Iterator) {
         op = iter.current();
         EXPECT_TRUE(op);
         EXPECT_EQ(*op, 2);
+
+        // second iterator
+        auto iter2 = q.begin();
+
+        // Still should be able to get top.
+        EXPECT_TRUE(iter2.current());
+
+        // Shouldn't be able to remove if 2 iterators exist.
+        EXPECT_FALSE(iter2.try_remove());
+
+        // This will never return since the first iterator has a "read" lock.
+        //EXPECT_FALSE(iter2.remove());
+
+        // Still should be able to get top.
+        EXPECT_TRUE(iter2.current());
     }
     EXPECT_EQ(q.size(), 9);