From cf27a6bb76f4dae2344d6061eb199e2b35eb754a Mon Sep 17 00:00:00 2001 From: Stephen Seo Date: Sat, 22 Jul 2023 16:29:19 +0900 Subject: [PATCH] Impl "RWLock" for use in TSLQueue This project supports C++11, and std::shared_lock was made available in C++17, thus a "shared_spin_lock" was created with similar functionality. This "shared_spin_lock" is used in TSLQueue. --- CMakeLists.txt | 2 + src/CXX11_shared_spin_lock.cpp | 74 +++++++++++++++++++ src/CXX11_shared_spin_lock.hpp | 128 +++++++++++++++++++++++++++++++++ src/TSLQueue.hpp | 56 +++++++-------- src/test/TestTSLQueue.cpp | 3 +- 5 files changed, 234 insertions(+), 29 deletions(-) create mode 100644 src/CXX11_shared_spin_lock.cpp create mode 100644 src/CXX11_shared_spin_lock.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5ebdd6f..87ac8a4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,7 @@ set(UDPC_VERSION 1.0) set(UDPC_SOURCES src/UDPConnection.cpp + src/CXX11_shared_spin_lock.cpp ) add_compile_options( @@ -62,6 +63,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") find_package(GTest QUIET) if(GTEST_FOUND) set(UDPC_UnitTest_SOURCES + src/CXX11_shared_spin_lock.cpp src/test/UDPC_UnitTest.cpp src/test/TestTSLQueue.cpp src/test/TestUDPC.cpp diff --git a/src/CXX11_shared_spin_lock.cpp b/src/CXX11_shared_spin_lock.cpp new file mode 100644 index 0000000..1c0ab84 --- /dev/null +++ b/src/CXX11_shared_spin_lock.cpp @@ -0,0 +1,74 @@ +#include "CXX11_shared_spin_lock.hpp" + +UDPC::Badge::Badge() : +isValid(true) +{} + +UDPC::SharedSpinLock::Ptr UDPC::SharedSpinLock::newInstance() { + Ptr sharedSpinLock = Ptr(new SharedSpinLock()); + sharedSpinLock->selfWeakPtr = sharedSpinLock; + return sharedSpinLock; +} + +UDPC::SharedSpinLock::SharedSpinLock() : +selfWeakPtr(), +mutex(), +read(0), +write(false) +{} + +UDPC::LockObj UDPC::SharedSpinLock::spin_read_lock() { + while (true) { + std::lock_guard lock(mutex); + if (!write.load()) { + ++read; + return LockObj(selfWeakPtr, Badge{}); + } + } +} + +UDPC::LockObj UDPC::SharedSpinLock::try_spin_read_lock() { + std::lock_guard lock(mutex); + if (!write.load()) { + ++read; + return LockObj(selfWeakPtr, Badge{}); + } + return LockObj(Badge{}); +} + +void UDPC::SharedSpinLock::read_unlock(UDPC::Badge &&badge) { + if (badge.isValid) { + std::lock_guard lock(mutex); + if (read.load() > 0) { + --read; + badge.isValid = false; + } + } +} + +UDPC::LockObj UDPC::SharedSpinLock::spin_write_lock() { + while (true) { + std::lock_guard lock(mutex); + if (!write.load() && read.load() == 0) { + write.store(true); + return LockObj(selfWeakPtr, Badge{}); + } + } +} + +UDPC::LockObj UDPC::SharedSpinLock::try_spin_write_lock() { + std::lock_guard lock(mutex); + if (!write.load() && read.load() == 0) { + write.store(true); + return LockObj(selfWeakPtr, Badge{}); + } + return LockObj(Badge{}); +} + +void UDPC::SharedSpinLock::write_unlock(UDPC::Badge &&badge) { + if (badge.isValid) { + std::lock_guard lock(mutex); + write.store(false); + badge.isValid = false; + } +} diff --git a/src/CXX11_shared_spin_lock.hpp b/src/CXX11_shared_spin_lock.hpp new file mode 100644 index 0000000..8b524ed --- /dev/null +++ b/src/CXX11_shared_spin_lock.hpp @@ -0,0 +1,128 @@ +#ifndef UDPC_CXX11_SHARED_SPIN_LOCK_H_ +#define UDPC_CXX11_SHARED_SPIN_LOCK_H_ + +#include +#include +#include + +namespace UDPC { + +// Forward declaration for LockObj. +class SharedSpinLock; + +class Badge { +public: + // Disallow copy. + Badge(const Badge&) = delete; + Badge& operator=(const Badge&) = delete; + + // Allow move. + Badge(Badge&&) = default; + Badge& operator=(Badge&&) = default; + +private: + friend class SharedSpinLock; + + // Can only be created by SharedSpinLock. + Badge(); + + bool isValid; +}; + +template +class LockObj { +public: + ~LockObj(); + + // Disallow copy. + LockObj(const LockObj&) = delete; + LockObj& operator=(const LockObj&) = delete; + + // Allow move. + LockObj(LockObj&&) = default; + LockObj& operator=(LockObj&&) = default; + + bool isValid() const; + +private: + friend class SharedSpinLock; + + // Only can be created by SharedSpinLock. + LockObj(Badge &&badge); + LockObj(std::weak_ptr lockPtr, Badge &&badge); + + std::weak_ptr weakPtrLock; + bool isLocked; + Badge badge; +}; + +class SharedSpinLock { +public: + using Ptr = std::shared_ptr; + using Weak = std::weak_ptr; + + static Ptr newInstance(); + + // Disallow copy. + SharedSpinLock(const SharedSpinLock&) = delete; + SharedSpinLock& operator=(const SharedSpinLock&) = delete; + + // Allow move. + SharedSpinLock(SharedSpinLock&&) = default; + SharedSpinLock& operator=(SharedSpinLock&&) = default; + + LockObj spin_read_lock(); + LockObj try_spin_read_lock(); + void read_unlock(Badge&&); + + LockObj spin_write_lock(); + LockObj try_spin_write_lock(); + void write_unlock(Badge&&); + +private: + SharedSpinLock(); + + Weak selfWeakPtr; + std::mutex mutex; + std::atomic_uint read; + std::atomic_bool write; + +}; + +template +LockObj::LockObj(Badge &&badge) : +weakPtrLock(), +isLocked(false), +badge(std::forward(badge)) +{} + +template +LockObj::LockObj(SharedSpinLock::Weak lockPtr, Badge &&badge) : +weakPtrLock(lockPtr), +isLocked(true), +badge(std::forward(badge)) +{} + +template +LockObj::~LockObj() { + if (!isLocked) { + return; + } + auto strongPtrLock = weakPtrLock.lock(); + if (strongPtrLock) { + if (IsWriteObj) { + strongPtrLock->write_unlock(std::move(badge)); + } else { + strongPtrLock->read_unlock(std::move(badge)); + } + } +} + +template +bool LockObj::isValid() const { + return isLocked; +} + +} // namespace UDPC + +#endif diff --git a/src/TSLQueue.hpp b/src/TSLQueue.hpp index d2c27f8..7a52c2c 100644 --- a/src/TSLQueue.hpp +++ b/src/TSLQueue.hpp @@ -10,6 +10,8 @@ #include #include +#include "CXX11_shared_spin_lock.hpp" + template class TSLQueue { public: @@ -62,7 +64,7 @@ class TSLQueue { class TSLQIter { public: - TSLQIter(std::mutex *mutex, + TSLQIter(UDPC::SharedSpinLock::Weak sharedSpinLockWeak, std::weak_ptr currentNode, unsigned long *msize); ~TSLQIter(); @@ -77,7 +79,8 @@ class TSLQueue { bool remove(); private: - std::mutex *mutex; + UDPC::SharedSpinLock::Weak sharedSpinLockWeak; + UDPC::LockObj readLock; std::weak_ptr currentNode; unsigned long *const msize; @@ -87,7 +90,7 @@ class TSLQueue { TSLQIter begin(); private: - std::mutex mutex; + UDPC::SharedSpinLock::Ptr sharedSpinLock; std::shared_ptr head; std::shared_ptr tail; unsigned long msize; @@ -95,7 +98,7 @@ class TSLQueue { template TSLQueue::TSLQueue() : - mutex(), + sharedSpinLock(UDPC::SharedSpinLock::newInstance()), head(std::shared_ptr(new TSLQNode())), tail(std::shared_ptr(new TSLQNode())), msize(0) @@ -121,8 +124,8 @@ TSLQueue::TSLQueue(TSLQueue &&other) template TSLQueue & TSLQueue::operator=(TSLQueue &&other) { - std::lock_guard lock(mutex); - std::lock_guard otherLock(other.mutex); + auto selfWriteLock = sharedSpinLock->spin_write_lock(); + auto otherWriteLock = other.sharedSpinLock->spin_write_lock(); head = std::move(other.head); tail = std::move(other.tail); msize = std::move(other.msize); @@ -130,7 +133,7 @@ TSLQueue & TSLQueue::operator=(TSLQueue &&other) { template void TSLQueue::push(const T &data) { - std::lock_guard lock(mutex); + auto writeLock = sharedSpinLock->spin_write_lock(); auto newNode = std::shared_ptr(new TSLQNode()); newNode->data = std::unique_ptr(new T(data)); @@ -146,7 +149,8 @@ void TSLQueue::push(const T &data) { template bool TSLQueue::push_nb(const T &data) { - if(mutex.try_lock()) { + auto writeLock = sharedSpinLock->try_spin_write_lock(); + if(writeLock.isValid()) { auto newNode = std::shared_ptr(new TSLQNode()); newNode->data = std::unique_ptr(new T(data)); @@ -159,7 +163,6 @@ bool TSLQueue::push_nb(const T &data) { tail->prev = newNode; ++msize; - mutex.unlock(); return true; } else { return false; @@ -168,7 +171,7 @@ bool TSLQueue::push_nb(const T &data) { template std::unique_ptr TSLQueue::top() { - std::lock_guard lock(mutex); + auto readLock = sharedSpinLock->spin_read_lock(); std::unique_ptr result; if(head->next != tail) { assert(head->next->data); @@ -181,20 +184,20 @@ std::unique_ptr TSLQueue::top() { template std::unique_ptr TSLQueue::top_nb() { std::unique_ptr result; - if(mutex.try_lock()) { + auto readLock = sharedSpinLock->try_spin_read_lock(); + if(readLock.isValid()) { if(head->next != tail) { assert(head->next->data); result = std::unique_ptr(new T); *result = *head->next->data; } - mutex.unlock(); } return result; } template bool TSLQueue::pop() { - std::lock_guard lock(mutex); + auto writeLock = sharedSpinLock->spin_write_lock(); if(head->next == tail) { return false; } else { @@ -211,7 +214,7 @@ bool TSLQueue::pop() { template std::unique_ptr TSLQueue::top_and_pop() { std::unique_ptr result; - std::lock_guard lock(mutex); + auto writeLock = sharedSpinLock->spin_write_lock(); if(head->next != tail) { assert(head->next->data); result = std::unique_ptr(new T); @@ -229,7 +232,7 @@ std::unique_ptr TSLQueue::top_and_pop() { template std::unique_ptr TSLQueue::top_and_pop_and_empty(bool *isEmpty) { std::unique_ptr result; - std::lock_guard lock(mutex); + auto writeLock = sharedSpinLock->spin_write_lock(); if(head->next == tail) { if(isEmpty) { *isEmpty = true; @@ -255,7 +258,7 @@ std::unique_ptr TSLQueue::top_and_pop_and_empty(bool *isEmpty) { template std::unique_ptr TSLQueue::top_and_pop_and_rsize(unsigned long *rsize) { std::unique_ptr result; - std::lock_guard lock(mutex); + auto writeLock = sharedSpinLock->spin_write_lock(); if(head->next == tail) { if(rsize) { *rsize = 0; @@ -280,7 +283,7 @@ std::unique_ptr TSLQueue::top_and_pop_and_rsize(unsigned long *rsize) { template void TSLQueue::clear() { - std::lock_guard lock(mutex); + auto writeLock = sharedSpinLock->spin_write_lock(); head->next = tail; tail->prev = head; @@ -289,13 +292,13 @@ void TSLQueue::clear() { template bool TSLQueue::empty() { - std::lock_guard lock(mutex); + auto readLock = sharedSpinLock->spin_read_lock(); return head->next == tail; } template unsigned long TSLQueue::size() { - std::lock_guard lock(mutex); + auto readLock = sharedSpinLock->spin_read_lock(); return msize; } @@ -313,20 +316,17 @@ bool TSLQueue::TSLQNode::isNormal() const { } template -TSLQueue::TSLQIter::TSLQIter(std::mutex *mutex, +TSLQueue::TSLQIter::TSLQIter(UDPC::SharedSpinLock::Weak lockWeak, std::weak_ptr currentNode, unsigned long *msize) : -mutex(mutex), +sharedSpinLockWeak(lockWeak), +readLock(lockWeak.lock()->spin_read_lock()), currentNode(currentNode), msize(msize) -{ - mutex->lock(); -} +{} template -TSLQueue::TSLQIter::~TSLQIter() { - mutex->unlock(); -} +TSLQueue::TSLQIter::~TSLQIter() {} template std::unique_ptr TSLQueue::TSLQIter::current() { @@ -389,7 +389,7 @@ bool TSLQueue::TSLQIter::remove() { template typename TSLQueue::TSLQIter TSLQueue::begin() { - return TSLQIter(&mutex, head->next, &msize); + return TSLQIter(sharedSpinLock, head->next, &msize); } #endif diff --git a/src/test/TestTSLQueue.cpp b/src/test/TestTSLQueue.cpp index f579712..10a32f7 100644 --- a/src/test/TestTSLQueue.cpp +++ b/src/test/TestTSLQueue.cpp @@ -143,7 +143,8 @@ TEST(TSLQueue, Iterator) { // test that lock is held by iterator EXPECT_FALSE(q.push_nb(10)); op = q.top_nb(); - EXPECT_FALSE(op); + // Getting top and iterator use the read lock, so this should be true. + EXPECT_TRUE(op); // backwards iteration EXPECT_TRUE(iter.prev());