Add a C++11 shared (reader/writer) spin lock and use it in TSLQueue so that
iterators hold read locks (multiple concurrent iterators allowed) and only
element removal upgrades to a write lock.

NOTE(review): this patch was reconstructed from a text extraction that
stripped every angle-bracket span (template arguments and #include targets)
and collapsed newlines. All template arguments below were restored from the
surrounding code, and the per-file hunk line counts were verified against the
original hunk headers. The exact headers in the three stripped `#include`
lines of CXX11_shared_spin_lock.hpp and the two context `#include` lines of
TSLQueue.hpp are best-effort guesses -- verify against the original commit
before applying.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 50a67cf..a504c7b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,6 +5,7 @@ set(UDPC_VERSION 1.0)
 
 set(UDPC_SOURCES
     src/UDPConnection.cpp
+    src/CXX11_shared_spin_lock.cpp
 )
 
 add_compile_options(
@@ -62,6 +63,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug")
     find_package(GTest QUIET)
     if(GTEST_FOUND)
         set(UDPC_UnitTest_SOURCES
+            src/CXX11_shared_spin_lock.cpp
             src/test/UDPC_UnitTest.cpp
             src/test/TestTSLQueue.cpp
             src/test/TestUDPC.cpp
diff --git a/src/CXX11_shared_spin_lock.cpp b/src/CXX11_shared_spin_lock.cpp
new file mode 100644
index 0000000..7778155
--- /dev/null
+++ b/src/CXX11_shared_spin_lock.cpp
@@ -0,0 +1,142 @@
+#include "CXX11_shared_spin_lock.hpp"
+
+UDPC::Badge UDPC::Badge::newInvalid() {
+    Badge badge;
+    badge.isValid = false;
+    return badge;
+}
+
+UDPC::Badge::Badge() :
+isValid(true)
+{}
+
+UDPC::SharedSpinLock::Ptr UDPC::SharedSpinLock::newInstance() {
+    Ptr sharedSpinLock = Ptr(new SharedSpinLock());
+    sharedSpinLock->selfWeakPtr = sharedSpinLock;
+    return sharedSpinLock;
+}
+
+UDPC::SharedSpinLock::SharedSpinLock() :
+selfWeakPtr(),
+mutex(),
+read(0),
+write(false)
+{}
+
+UDPC::LockObj<false> UDPC::SharedSpinLock::spin_read_lock() {
+    while (true) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!write) {
+            ++read;
+            return LockObj<false>(selfWeakPtr, Badge{});
+        }
+    }
+}
+
+UDPC::LockObj<false> UDPC::SharedSpinLock::try_spin_read_lock() {
+    std::lock_guard<std::mutex> lock(mutex);
+    if (!write) {
+        ++read;
+        return LockObj<false>(selfWeakPtr, Badge{});
+    }
+    return LockObj<false>(Badge{});
+}
+
+void UDPC::SharedSpinLock::read_unlock(UDPC::Badge &&badge) {
+    if (badge.isValid) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (read > 0) {
+            --read;
+            badge.isValid = false;
+        }
+    }
+}
+
+UDPC::LockObj<true> UDPC::SharedSpinLock::spin_write_lock() {
+    while (true) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!write && read == 0) {
+            write = true;
+            return LockObj<true>(selfWeakPtr, Badge{});
+        }
+    }
+}
+
+UDPC::LockObj<true> UDPC::SharedSpinLock::try_spin_write_lock() {
+    std::lock_guard<std::mutex> lock(mutex);
+    if (!write && read == 0) {
+        write = true;
+        return LockObj<true>(selfWeakPtr, Badge{});
+    }
+    return LockObj<true>(Badge{});
+}
+
+void UDPC::SharedSpinLock::write_unlock(UDPC::Badge &&badge) {
+    if (badge.isValid) {
+        std::lock_guard<std::mutex> lock(mutex);
+        write = false;
+        badge.isValid = false;
+    }
+}
+
+UDPC::LockObj<false> UDPC::SharedSpinLock::trade_write_for_read_lock(UDPC::LockObj<true> &lockObj) {
+    if (lockObj.isValid() && lockObj.badge.isValid) {
+        while (true) {
+            std::lock_guard<std::mutex> lock(mutex);
+            if (write && read == 0) {
+                read = 1;
+                write = false;
+                lockObj.isLocked = false;
+                lockObj.badge.isValid = false;
+                return LockObj<false>(selfWeakPtr, Badge{});
+            }
+        }
+    } else {
+        return LockObj<false>(Badge{});
+    }
+}
+
+UDPC::LockObj<false> UDPC::SharedSpinLock::try_trade_write_for_read_lock(UDPC::LockObj<true> &lockObj) {
+    if (lockObj.isValid() && lockObj.badge.isValid) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (write && read == 0) {
+            read = 1;
+            write = false;
+            lockObj.isLocked = false;
+            lockObj.badge.isValid = false;
+            return LockObj<false>(selfWeakPtr, Badge{});
+        }
+    }
+    return LockObj<false>(Badge{});
+}
+
+UDPC::LockObj<true> UDPC::SharedSpinLock::trade_read_for_write_lock(UDPC::LockObj<false> &lockObj) {
+    if (lockObj.isValid() && lockObj.badge.isValid) {
+        while (true) {
+            std::lock_guard<std::mutex> lock(mutex);
+            if (!write && read == 1) {
+                read = 0;
+                write = true;
+                lockObj.isLocked = false;
+                lockObj.badge.isValid = false;
+                return LockObj<true>(selfWeakPtr, Badge{});
+            }
+        }
+    } else {
+        return LockObj<true>(Badge{});
+    }
+}
+
+UDPC::LockObj<true> UDPC::SharedSpinLock::try_trade_read_for_write_lock(UDPC::LockObj<false> &lockObj) {
+    if (lockObj.isValid() && lockObj.badge.isValid) {
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!write && read == 1) {
+            read = 0;
+            write = true;
+            lockObj.isLocked = false;
+            lockObj.badge.isValid = false;
+            return LockObj<true>(selfWeakPtr, Badge{});
+        }
+    }
+    return LockObj<true>(Badge{});
+}
diff --git a/src/CXX11_shared_spin_lock.hpp b/src/CXX11_shared_spin_lock.hpp
new file mode 100644
index 0000000..67c2f22
--- /dev/null
+++ b/src/CXX11_shared_spin_lock.hpp
@@ -0,0 +1,154 @@
+#ifndef UDPC_CXX11_SHARED_SPIN_LOCK_H_
+#define UDPC_CXX11_SHARED_SPIN_LOCK_H_
+
+#include <memory>
+#include <mutex>
+#include <atomic>
+
+namespace UDPC {
+
+// Forward declaration for LockObj.
+class SharedSpinLock;
+
+class Badge {
+public:
+    static Badge newInvalid();
+
+    // Disallow copy.
+    Badge(const Badge&) = delete;
+    Badge& operator=(const Badge&) = delete;
+
+    // Allow move.
+    Badge(Badge&&) = default;
+    Badge& operator=(Badge&&) = default;
+
+private:
+    friend class SharedSpinLock;
+
+    // Can only be created by SharedSpinLock.
+    Badge();
+
+    bool isValid;
+};
+
+template <bool IsWriteObj>
+class LockObj {
+public:
+    // Invalid instance constructor.
+    LockObj();
+
+    ~LockObj();
+
+    // Explicit invalid instance constructor.
+    static LockObj newInvalid();
+
+    // Disallow copy.
+    LockObj(const LockObj&) = delete;
+    LockObj& operator=(const LockObj&) = delete;
+
+    // Allow move.
+    LockObj(LockObj&&) = default;
+    LockObj& operator=(LockObj&&) = default;
+
+    bool isValid() const;
+
+private:
+    friend class SharedSpinLock;
+
+    // Only can be created by SharedSpinLock.
+    LockObj(Badge &&badge);
+    LockObj(std::weak_ptr<SharedSpinLock> lockPtr, Badge &&badge);
+
+    std::weak_ptr<SharedSpinLock> weakPtrLock;
+    bool isLocked;
+    Badge badge;
+};
+
+class SharedSpinLock {
+public:
+    using Ptr = std::shared_ptr<SharedSpinLock>;
+    using Weak = std::weak_ptr<SharedSpinLock>;
+
+    static Ptr newInstance();
+
+    // Disallow copy.
+    SharedSpinLock(const SharedSpinLock&) = delete;
+    SharedSpinLock& operator=(const SharedSpinLock&) = delete;
+
+    // Allow move.
+    SharedSpinLock(SharedSpinLock&&) = default;
+    SharedSpinLock& operator=(SharedSpinLock&&) = default;
+
+    LockObj<false> spin_read_lock();
+    LockObj<false> try_spin_read_lock();
+    void read_unlock(Badge&&);
+
+    LockObj<true> spin_write_lock();
+    LockObj<true> try_spin_write_lock();
+    void write_unlock(Badge&&);
+
+    LockObj<false> trade_write_for_read_lock(LockObj<true>&);
+    LockObj<false> try_trade_write_for_read_lock(LockObj<true>&);
+
+    LockObj<true> trade_read_for_write_lock(LockObj<false>&);
+    LockObj<true> try_trade_read_for_write_lock(LockObj<false>&);
+
+private:
+    SharedSpinLock();
+
+    Weak selfWeakPtr;
+    std::mutex mutex;
+    unsigned int read;
+    bool write;
+
+};
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj>::LockObj() :
+weakPtrLock(),
+isLocked(false),
+badge(UDPC::Badge::newInvalid())
+{}
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj>::LockObj(Badge &&badge) :
+weakPtrLock(),
+isLocked(false),
+badge(std::forward<Badge>(badge))
+{}
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj>::LockObj(SharedSpinLock::Weak lockPtr, Badge &&badge) :
+weakPtrLock(lockPtr),
+isLocked(true),
+badge(std::forward<Badge>(badge))
+{}
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj>::~LockObj() {
+    if (!isLocked) {
+        return;
+    }
+    auto strongPtrLock = weakPtrLock.lock();
+    if (strongPtrLock) {
+        if (IsWriteObj) {
+            strongPtrLock->write_unlock(std::move(badge));
+        } else {
+            strongPtrLock->read_unlock(std::move(badge));
+        }
+    }
+}
+
+template <bool IsWriteObj>
+LockObj<IsWriteObj> LockObj<IsWriteObj>::newInvalid() {
+    return LockObj<IsWriteObj>{};
+}
+
+template <bool IsWriteObj>
+bool LockObj<IsWriteObj>::isValid() const {
+    return isLocked;
+}
+
+} // namespace UDPC
+
+#endif
diff --git a/src/TSLQueue.hpp b/src/TSLQueue.hpp
index d2c27f8..58b468f 100644
--- a/src/TSLQueue.hpp
+++ b/src/TSLQueue.hpp
@@ -10,6 +10,8 @@
 #include <memory>
 #include <mutex>
 
+#include "CXX11_shared_spin_lock.hpp"
+
 template <typename T>
 class TSLQueue {
 public:
@@ -62,7 +64,7 @@ class TSLQueue {
 
     class TSLQIter {
     public:
-        TSLQIter(std::mutex *mutex,
+        TSLQIter(UDPC::SharedSpinLock::Weak sharedSpinLockWeak,
                  std::weak_ptr<TSLQNode> currentNode,
                  unsigned long *msize);
         ~TSLQIter();
@@ -75,19 +77,24 @@ class TSLQueue {
         bool next();
         bool prev();
         bool remove();
+        bool try_remove();
 
     private:
-        std::mutex *mutex;
+        UDPC::SharedSpinLock::Weak sharedSpinLockWeak;
+        std::unique_ptr<UDPC::LockObj<false>> readLock;
+        std::unique_ptr<UDPC::LockObj<true>> writeLock;
         std::weak_ptr<TSLQNode> currentNode;
        unsigned long *const msize;
 
+        bool remove_impl();
+
     };
 
 public:
     TSLQIter begin();
 
 private:
-    std::mutex mutex;
+    UDPC::SharedSpinLock::Ptr sharedSpinLock;
     std::shared_ptr<TSLQNode> head;
     std::shared_ptr<TSLQNode> tail;
     unsigned long msize;
@@ -95,7 +102,7 @@ class TSLQueue {
 
 template <typename T>
 TSLQueue<T>::TSLQueue() :
-    mutex(),
+    sharedSpinLock(UDPC::SharedSpinLock::newInstance()),
     head(std::shared_ptr<TSLQNode>(new TSLQNode())),
     tail(std::shared_ptr<TSLQNode>(new TSLQNode())),
     msize(0)
@@ -121,8 +128,8 @@ TSLQueue<T>::TSLQueue(TSLQueue &&other)
 
 template <typename T>
 TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue<T> &&other) {
-    std::lock_guard<std::mutex> lock(mutex);
-    std::lock_guard<std::mutex> otherLock(other.mutex);
+    auto selfWriteLock = sharedSpinLock->spin_write_lock();
+    auto otherWriteLock = other.sharedSpinLock->spin_write_lock();
     head = std::move(other.head);
     tail = std::move(other.tail);
     msize = std::move(other.msize);
@@ -130,7 +137,7 @@ TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue<T> &&other) {
 
 template <typename T>
 void TSLQueue<T>::push(const T &data) {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     auto newNode = std::shared_ptr<TSLQNode>(new TSLQNode());
     newNode->data = std::unique_ptr<T>(new T(data));
 
@@ -146,7 +153,8 @@ void TSLQueue<T>::push(const T &data) {
 
 template <typename T>
 bool TSLQueue<T>::push_nb(const T &data) {
-    if(mutex.try_lock()) {
+    auto writeLock = sharedSpinLock->try_spin_write_lock();
+    if(writeLock.isValid()) {
         auto newNode = std::shared_ptr<TSLQNode>(new TSLQNode());
         newNode->data = std::unique_ptr<T>(new T(data));
 
@@ -159,7 +167,6 @@ bool TSLQueue<T>::push_nb(const T &data) {
         tail->prev = newNode;
 
         ++msize;
-        mutex.unlock();
         return true;
     } else {
         return false;
@@ -168,7 +175,7 @@ bool TSLQueue<T>::push_nb(const T &data) {
 
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto readLock = sharedSpinLock->spin_read_lock();
     std::unique_ptr<T> result;
     if(head->next != tail) {
         assert(head->next->data);
@@ -181,20 +188,20 @@ std::unique_ptr<T> TSLQueue<T>::top() {
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top_nb() {
     std::unique_ptr<T> result;
-    if(mutex.try_lock()) {
+    auto readLock = sharedSpinLock->try_spin_read_lock();
+    if(readLock.isValid()) {
         if(head->next != tail) {
             assert(head->next->data);
             result = std::unique_ptr<T>(new T);
             *result = *head->next->data;
         }
-        mutex.unlock();
     }
     return result;
 }
 
 template <typename T>
 bool TSLQueue<T>::pop() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     if(head->next == tail) {
         return false;
     } else {
@@ -211,7 +218,7 @@ bool TSLQueue<T>::pop() {
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top_and_pop() {
     std::unique_ptr<T> result;
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     if(head->next != tail) {
         assert(head->next->data);
         result = std::unique_ptr<T>(new T);
@@ -229,7 +236,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop() {
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
     std::unique_ptr<T> result;
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     if(head->next == tail) {
         if(isEmpty) {
             *isEmpty = true;
@@ -255,7 +262,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_rsize(unsigned long *rsize) {
     std::unique_ptr<T> result;
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     if(head->next == tail) {
         if(rsize) {
             *rsize = 0;
@@ -280,7 +287,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_rsize(unsigned long *rsize) {
 
 template <typename T>
 void TSLQueue<T>::clear() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto writeLock = sharedSpinLock->spin_write_lock();
     head->next = tail;
     tail->prev = head;
 
@@ -289,13 +296,13 @@ void TSLQueue<T>::clear() {
 
 template <typename T>
 bool TSLQueue<T>::empty() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto readLock = sharedSpinLock->spin_read_lock();
     return head->next == tail;
 }
 
 template <typename T>
 unsigned long TSLQueue<T>::size() {
-    std::lock_guard<std::mutex> lock(mutex);
+    auto readLock = sharedSpinLock->spin_read_lock();
     return msize;
 }
 
@@ -313,20 +320,20 @@ bool TSLQueue<T>::TSLQNode::isNormal() const {
 }
 
 template <typename T>
-TSLQueue<T>::TSLQIter::TSLQIter(std::mutex *mutex,
+TSLQueue<T>::TSLQIter::TSLQIter(UDPC::SharedSpinLock::Weak lockWeak,
                                 std::weak_ptr<TSLQNode> currentNode,
                                 unsigned long *msize) :
-mutex(mutex),
+sharedSpinLockWeak(lockWeak),
+readLock(std::unique_ptr<UDPC::LockObj<false>>(new UDPC::LockObj<false>{})),
+writeLock(),
 currentNode(currentNode),
 msize(msize)
 {
-    mutex->lock();
+    *readLock = lockWeak.lock()->spin_read_lock();
 }
 
 template <typename T>
-TSLQueue<T>::TSLQIter::~TSLQIter() {
-    mutex->unlock();
-}
+TSLQueue<T>::TSLQIter::~TSLQIter() {}
 
 template <typename T>
 std::unique_ptr<T> TSLQueue<T>::TSLQIter::current() {
@@ -368,9 +375,61 @@ bool TSLQueue<T>::TSLQIter::prev() {
 
 template <typename T>
 bool TSLQueue<T>::TSLQIter::remove() {
+    if (readLock && !writeLock && readLock->isValid()) {
+        auto sharedSpinLockStrong = sharedSpinLockWeak.lock();
+        if (!sharedSpinLockStrong) {
+            return false;
+        }
+
+        writeLock = std::unique_ptr<UDPC::LockObj<true>>(new UDPC::LockObj<true>{});
+        *writeLock = sharedSpinLockStrong->trade_read_for_write_lock(*readLock);
+        readLock.reset(nullptr);
+
+        return remove_impl();
+    } else {
+        return false;
+    }
+}
+
+template <typename T>
+bool TSLQueue<T>::TSLQIter::try_remove() {
+    if (readLock && !writeLock && readLock->isValid()) {
+        auto sharedSpinLockStrong = sharedSpinLockWeak.lock();
+        if (!sharedSpinLockStrong) {
+            return false;
+        }
+
+        writeLock = std::unique_ptr<UDPC::LockObj<true>>(new UDPC::LockObj<true>{});
+        *writeLock = sharedSpinLockStrong->try_trade_read_for_write_lock(*readLock);
+        if (writeLock->isValid()) {
+            readLock.reset(nullptr);
+            return remove_impl();
+        } else {
+            writeLock.reset(nullptr);
+            return false;
+        }
+    } else {
+        return false;
+    }
+}
+
+template <typename T>
+bool TSLQueue<T>::TSLQIter::remove_impl() {
+    const auto cleanupWriteLock = [this] () {
+        UDPC::SharedSpinLock::Ptr sharedSpinLockStrong = this->sharedSpinLockWeak.lock();
+        if (!sharedSpinLockStrong) {
+            writeLock.reset(nullptr);
+            return;
+        }
+        this->readLock = std::unique_ptr<UDPC::LockObj<false>>(new UDPC::LockObj<false>{});
+        (*this->readLock) = sharedSpinLockStrong->trade_write_for_read_lock(*(this->writeLock));
+        this->writeLock.reset(nullptr);
+    };
+
     std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
     assert(currentNode);
     if(!currentNode->isNormal()) {
+        cleanupWriteLock();
         return false;
     }
 
@@ -384,12 +443,13 @@ bool TSLQueue<T>::TSLQIter::remove() {
     assert(*msize > 0);
     --(*msize);
 
+    cleanupWriteLock();
     return parent->next->isNormal();
 }
 
 template <typename T>
 typename TSLQueue<T>::TSLQIter TSLQueue<T>::begin() {
-    return TSLQIter(&mutex, head->next, &msize);
+    return TSLQIter(sharedSpinLock, head->next, &msize);
 }
 
 #endif
diff --git a/src/test/TestTSLQueue.cpp b/src/test/TestTSLQueue.cpp
index f579712..b46de73 100644
--- a/src/test/TestTSLQueue.cpp
+++ b/src/test/TestTSLQueue.cpp
@@ -143,7 +143,8 @@ TEST(TSLQueue, Iterator) {
     // test that lock is held by iterator
     EXPECT_FALSE(q.push_nb(10));
    op = q.top_nb();
-    EXPECT_FALSE(op);
+    // Getting top and iterator both hold read locks so this should be true.
+    EXPECT_TRUE(op);
 
     // backwards iteration
     EXPECT_TRUE(iter.prev());
@@ -175,6 +176,21 @@ TEST(TSLQueue, Iterator) {
         op = iter.current();
         EXPECT_TRUE(op);
         EXPECT_EQ(*op, 2);
+
+        // second iterator
+        auto iter2 = q.begin();
+
+        // Still should be able to get top.
+        EXPECT_TRUE(iter2.current());
+
+        // Shouldn't be able to remove if 2 iterators exist.
+        EXPECT_FALSE(iter2.try_remove());
+
+        // This will never return since the first iterator has a "read" lock.
+        //EXPECT_FALSE(iter2.remove());
+
+        // Still should be able to get top.
+        EXPECT_TRUE(iter2.current());
     }
 
     EXPECT_EQ(q.size(), 9);