Reland C++11 "shared_lock" with iter remove fix

On iterator remove, the iterator will trade the read lock for a write
lock, and trade back for a read lock once the remove has been completed.
This commit is contained in:
Stephen Seo 2023-07-22 18:58:36 +09:00
parent 611287b377
commit 9b323eff55
5 changed files with 401 additions and 27 deletions

View file

@ -5,6 +5,7 @@ set(UDPC_VERSION 1.0)
set(UDPC_SOURCES set(UDPC_SOURCES
src/UDPConnection.cpp src/UDPConnection.cpp
src/CXX11_shared_spin_lock.cpp
) )
add_compile_options( add_compile_options(
@ -62,6 +63,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug")
find_package(GTest QUIET) find_package(GTest QUIET)
if(GTEST_FOUND) if(GTEST_FOUND)
set(UDPC_UnitTest_SOURCES set(UDPC_UnitTest_SOURCES
src/CXX11_shared_spin_lock.cpp
src/test/UDPC_UnitTest.cpp src/test/UDPC_UnitTest.cpp
src/test/TestTSLQueue.cpp src/test/TestTSLQueue.cpp
src/test/TestUDPC.cpp src/test/TestUDPC.cpp

View file

@ -0,0 +1,142 @@
#include "CXX11_shared_spin_lock.hpp"
UDPC::Badge UDPC::Badge::newInvalid() {
Badge badge;
badge.isValid = false;
return badge;
}
UDPC::Badge::Badge() :
isValid(true)
{}
UDPC::SharedSpinLock::Ptr UDPC::SharedSpinLock::newInstance() {
Ptr sharedSpinLock = Ptr(new SharedSpinLock());
sharedSpinLock->selfWeakPtr = sharedSpinLock;
return sharedSpinLock;
}
UDPC::SharedSpinLock::SharedSpinLock() :
selfWeakPtr(),
mutex(),
read(0),
write(false)
{}
UDPC::LockObj<false> UDPC::SharedSpinLock::spin_read_lock() {
while (true) {
std::lock_guard<std::mutex> lock(mutex);
if (!write) {
++read;
return LockObj<false>(selfWeakPtr, Badge{});
}
}
}
UDPC::LockObj<false> UDPC::SharedSpinLock::try_spin_read_lock() {
std::lock_guard<std::mutex> lock(mutex);
if (!write) {
++read;
return LockObj<false>(selfWeakPtr, Badge{});
}
return LockObj<false>(Badge{});
}
void UDPC::SharedSpinLock::read_unlock(UDPC::Badge &&badge) {
if (badge.isValid) {
std::lock_guard<std::mutex> lock(mutex);
if (read > 0) {
--read;
badge.isValid = false;
}
}
}
UDPC::LockObj<true> UDPC::SharedSpinLock::spin_write_lock() {
while (true) {
std::lock_guard<std::mutex> lock(mutex);
if (!write && read == 0) {
write = true;
return LockObj<true>(selfWeakPtr, Badge{});
}
}
}
UDPC::LockObj<true> UDPC::SharedSpinLock::try_spin_write_lock() {
std::lock_guard<std::mutex> lock(mutex);
if (!write && read == 0) {
write = true;
return LockObj<true>(selfWeakPtr, Badge{});
}
return LockObj<true>(Badge{});
}
void UDPC::SharedSpinLock::write_unlock(UDPC::Badge &&badge) {
if (badge.isValid) {
std::lock_guard<std::mutex> lock(mutex);
write = false;
badge.isValid = false;
}
}
UDPC::LockObj<false> UDPC::SharedSpinLock::trade_write_for_read_lock(UDPC::LockObj<true> &lockObj) {
if (lockObj.isValid() && lockObj.badge.isValid) {
while (true) {
std::lock_guard<std::mutex> lock(mutex);
if (write && read == 0) {
read = 1;
write = false;
lockObj.isLocked = false;
lockObj.badge.isValid = false;
return LockObj<false>(selfWeakPtr, Badge{});
}
}
} else {
return LockObj<false>(Badge{});
}
}
UDPC::LockObj<false> UDPC::SharedSpinLock::try_trade_write_for_read_lock(UDPC::LockObj<true> &lockObj) {
if (lockObj.isValid() && lockObj.badge.isValid) {
std::lock_guard<std::mutex> lock(mutex);
if (write && read == 0) {
read = 1;
write = false;
lockObj.isLocked = false;
lockObj.badge.isValid = false;
return LockObj<false>(selfWeakPtr, Badge{});
}
}
return LockObj<false>(Badge{});
}
UDPC::LockObj<true> UDPC::SharedSpinLock::trade_read_for_write_lock(UDPC::LockObj<false> &lockObj) {
if (lockObj.isValid() && lockObj.badge.isValid) {
while (true) {
std::lock_guard<std::mutex> lock(mutex);
if (!write && read == 1) {
read = 0;
write = true;
lockObj.isLocked = false;
lockObj.badge.isValid = false;
return LockObj<true>(selfWeakPtr, Badge{});
}
}
} else {
return LockObj<true>(Badge{});
}
}
UDPC::LockObj<true> UDPC::SharedSpinLock::try_trade_read_for_write_lock(UDPC::LockObj<false> &lockObj) {
if (lockObj.isValid() && lockObj.badge.isValid) {
std::lock_guard<std::mutex> lock(mutex);
if (!write && read == 1) {
read = 0;
write = true;
lockObj.isLocked = false;
lockObj.badge.isValid = false;
return LockObj<true>(selfWeakPtr, Badge{});
}
}
return LockObj<true>(Badge{});
}

View file

@ -0,0 +1,154 @@
#ifndef UDPC_CXX11_SHARED_SPIN_LOCK_H_
#define UDPC_CXX11_SHARED_SPIN_LOCK_H_
#include <memory>
#include <mutex>
#include <atomic>
namespace UDPC {
// Forward declaration for LockObj.
class SharedSpinLock;
class Badge {
public:
static Badge newInvalid();
// Disallow copy.
Badge(const Badge&) = delete;
Badge& operator=(const Badge&) = delete;
// Allow move.
Badge(Badge&&) = default;
Badge& operator=(Badge&&) = default;
private:
friend class SharedSpinLock;
// Can only be created by SharedSpinLock.
Badge();
bool isValid;
};
template <bool IsWriteObj>
class LockObj {
public:
// Invalid instance constructor.
LockObj();
~LockObj();
// Explicit invalid instance constructor.
static LockObj<IsWriteObj> newInvalid();
// Disallow copy.
LockObj(const LockObj&) = delete;
LockObj& operator=(const LockObj&) = delete;
// Allow move.
LockObj(LockObj&&) = default;
LockObj& operator=(LockObj&&) = default;
bool isValid() const;
private:
friend class SharedSpinLock;
// Only can be created by SharedSpinLock.
LockObj(Badge &&badge);
LockObj(std::weak_ptr<SharedSpinLock> lockPtr, Badge &&badge);
std::weak_ptr<SharedSpinLock> weakPtrLock;
bool isLocked;
Badge badge;
};
class SharedSpinLock {
public:
using Ptr = std::shared_ptr<SharedSpinLock>;
using Weak = std::weak_ptr<SharedSpinLock>;
static Ptr newInstance();
// Disallow copy.
SharedSpinLock(const SharedSpinLock&) = delete;
SharedSpinLock& operator=(const SharedSpinLock&) = delete;
// Allow move.
SharedSpinLock(SharedSpinLock&&) = default;
SharedSpinLock& operator=(SharedSpinLock&&) = default;
LockObj<false> spin_read_lock();
LockObj<false> try_spin_read_lock();
void read_unlock(Badge&&);
LockObj<true> spin_write_lock();
LockObj<true> try_spin_write_lock();
void write_unlock(Badge&&);
LockObj<false> trade_write_for_read_lock(LockObj<true>&);
LockObj<false> try_trade_write_for_read_lock(LockObj<true>&);
LockObj<true> trade_read_for_write_lock(LockObj<false>&);
LockObj<true> try_trade_read_for_write_lock(LockObj<false>&);
private:
SharedSpinLock();
Weak selfWeakPtr;
std::mutex mutex;
unsigned int read;
bool write;
};
template <bool IsWriteObj>
LockObj<IsWriteObj>::LockObj() :
weakPtrLock(),
isLocked(false),
badge(UDPC::Badge::newInvalid())
{}
template <bool IsWriteObj>
LockObj<IsWriteObj>::LockObj(Badge &&badge) :
weakPtrLock(),
isLocked(false),
badge(std::forward<Badge>(badge))
{}
template <bool IsWriteObj>
LockObj<IsWriteObj>::LockObj(SharedSpinLock::Weak lockPtr, Badge &&badge) :
weakPtrLock(lockPtr),
isLocked(true),
badge(std::forward<Badge>(badge))
{}
template <bool IsWriteObj>
LockObj<IsWriteObj>::~LockObj() {
if (!isLocked) {
return;
}
auto strongPtrLock = weakPtrLock.lock();
if (strongPtrLock) {
if (IsWriteObj) {
strongPtrLock->write_unlock(std::move(badge));
} else {
strongPtrLock->read_unlock(std::move(badge));
}
}
}
template <bool IsWriteObj>
LockObj<IsWriteObj> LockObj<IsWriteObj>::newInvalid() {
return LockObj<IsWriteObj>{};
}
template <bool IsWriteObj>
bool LockObj<IsWriteObj>::isValid() const {
return isLocked;
}
} // namespace UDPC
#endif

View file

@ -10,6 +10,8 @@
#include <list> #include <list>
#include <type_traits> #include <type_traits>
#include "CXX11_shared_spin_lock.hpp"
template <typename T> template <typename T>
class TSLQueue { class TSLQueue {
public: public:
@ -62,7 +64,7 @@ class TSLQueue {
class TSLQIter { class TSLQIter {
public: public:
TSLQIter(std::mutex *mutex, TSLQIter(UDPC::SharedSpinLock::Weak sharedSpinLockWeak,
std::weak_ptr<TSLQNode> currentNode, std::weak_ptr<TSLQNode> currentNode,
unsigned long *msize); unsigned long *msize);
~TSLQIter(); ~TSLQIter();
@ -75,19 +77,24 @@ class TSLQueue {
bool next(); bool next();
bool prev(); bool prev();
bool remove(); bool remove();
bool try_remove();
private: private:
std::mutex *mutex; UDPC::SharedSpinLock::Weak sharedSpinLockWeak;
std::unique_ptr<UDPC::LockObj<false>> readLock;
std::unique_ptr<UDPC::LockObj<true>> writeLock;
std::weak_ptr<TSLQNode> currentNode; std::weak_ptr<TSLQNode> currentNode;
unsigned long *const msize; unsigned long *const msize;
bool remove_impl();
}; };
public: public:
TSLQIter begin(); TSLQIter begin();
private: private:
std::mutex mutex; UDPC::SharedSpinLock::Ptr sharedSpinLock;
std::shared_ptr<TSLQNode> head; std::shared_ptr<TSLQNode> head;
std::shared_ptr<TSLQNode> tail; std::shared_ptr<TSLQNode> tail;
unsigned long msize; unsigned long msize;
@ -95,7 +102,7 @@ class TSLQueue {
template <typename T> template <typename T>
TSLQueue<T>::TSLQueue() : TSLQueue<T>::TSLQueue() :
mutex(), sharedSpinLock(UDPC::SharedSpinLock::newInstance()),
head(std::shared_ptr<TSLQNode>(new TSLQNode())), head(std::shared_ptr<TSLQNode>(new TSLQNode())),
tail(std::shared_ptr<TSLQNode>(new TSLQNode())), tail(std::shared_ptr<TSLQNode>(new TSLQNode())),
msize(0) msize(0)
@ -121,8 +128,8 @@ TSLQueue<T>::TSLQueue(TSLQueue &&other)
template <typename T> template <typename T>
TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) { TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) {
std::lock_guard<std::mutex> lock(mutex); auto selfWriteLock = sharedSpinLock->spin_write_lock();
std::lock_guard<std::mutex> otherLock(other.mutex); auto otherWriteLock = other.sharedSpinLock->spin_write_lock();
head = std::move(other.head); head = std::move(other.head);
tail = std::move(other.tail); tail = std::move(other.tail);
msize = std::move(other.msize); msize = std::move(other.msize);
@ -130,7 +137,7 @@ TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) {
template <typename T> template <typename T>
void TSLQueue<T>::push(const T &data) { void TSLQueue<T>::push(const T &data) {
std::lock_guard<std::mutex> lock(mutex); auto writeLock = sharedSpinLock->spin_write_lock();
auto newNode = std::shared_ptr<TSLQNode>(new TSLQNode()); auto newNode = std::shared_ptr<TSLQNode>(new TSLQNode());
newNode->data = std::unique_ptr<T>(new T(data)); newNode->data = std::unique_ptr<T>(new T(data));
@ -146,7 +153,8 @@ void TSLQueue<T>::push(const T &data) {
template <typename T> template <typename T>
bool TSLQueue<T>::push_nb(const T &data) { bool TSLQueue<T>::push_nb(const T &data) {
if(mutex.try_lock()) { auto writeLock = sharedSpinLock->try_spin_write_lock();
if(writeLock.isValid()) {
auto newNode = std::shared_ptr<TSLQNode>(new TSLQNode()); auto newNode = std::shared_ptr<TSLQNode>(new TSLQNode());
newNode->data = std::unique_ptr<T>(new T(data)); newNode->data = std::unique_ptr<T>(new T(data));
@ -159,7 +167,6 @@ bool TSLQueue<T>::push_nb(const T &data) {
tail->prev = newNode; tail->prev = newNode;
++msize; ++msize;
mutex.unlock();
return true; return true;
} else { } else {
return false; return false;
@ -168,7 +175,7 @@ bool TSLQueue<T>::push_nb(const T &data) {
template <typename T> template <typename T>
std::unique_ptr<T> TSLQueue<T>::top() { std::unique_ptr<T> TSLQueue<T>::top() {
std::lock_guard<std::mutex> lock(mutex); auto readLock = sharedSpinLock->spin_read_lock();
std::unique_ptr<T> result; std::unique_ptr<T> result;
if(head->next != tail) { if(head->next != tail) {
assert(head->next->data); assert(head->next->data);
@ -181,20 +188,20 @@ std::unique_ptr<T> TSLQueue<T>::top() {
template <typename T> template <typename T>
std::unique_ptr<T> TSLQueue<T>::top_nb() { std::unique_ptr<T> TSLQueue<T>::top_nb() {
std::unique_ptr<T> result; std::unique_ptr<T> result;
if(mutex.try_lock()) { auto readLock = sharedSpinLock->try_spin_read_lock();
if(readLock.isValid()) {
if(head->next != tail) { if(head->next != tail) {
assert(head->next->data); assert(head->next->data);
result = std::unique_ptr<T>(new T); result = std::unique_ptr<T>(new T);
*result = *head->next->data; *result = *head->next->data;
} }
mutex.unlock();
} }
return result; return result;
} }
template <typename T> template <typename T>
bool TSLQueue<T>::pop() { bool TSLQueue<T>::pop() {
std::lock_guard<std::mutex> lock(mutex); auto writeLock = sharedSpinLock->spin_write_lock();
if(head->next == tail) { if(head->next == tail) {
return false; return false;
} else { } else {
@ -211,7 +218,7 @@ bool TSLQueue<T>::pop() {
template <typename T> template <typename T>
std::unique_ptr<T> TSLQueue<T>::top_and_pop() { std::unique_ptr<T> TSLQueue<T>::top_and_pop() {
std::unique_ptr<T> result; std::unique_ptr<T> result;
std::lock_guard<std::mutex> lock(mutex); auto writeLock = sharedSpinLock->spin_write_lock();
if(head->next != tail) { if(head->next != tail) {
assert(head->next->data); assert(head->next->data);
result = std::unique_ptr<T>(new T); result = std::unique_ptr<T>(new T);
@ -229,7 +236,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop() {
template <typename T> template <typename T>
std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) { std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
std::unique_ptr<T> result; std::unique_ptr<T> result;
std::lock_guard<std::mutex> lock(mutex); auto writeLock = sharedSpinLock->spin_write_lock();
if(head->next == tail) { if(head->next == tail) {
if(isEmpty) { if(isEmpty) {
*isEmpty = true; *isEmpty = true;
@ -255,7 +262,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
template <typename T> template <typename T>
std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_rsize(unsigned long *rsize) { std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_rsize(unsigned long *rsize) {
std::unique_ptr<T> result; std::unique_ptr<T> result;
std::lock_guard<std::mutex> lock(mutex); auto writeLock = sharedSpinLock->spin_write_lock();
if(head->next == tail) { if(head->next == tail) {
if(rsize) { if(rsize) {
*rsize = 0; *rsize = 0;
@ -280,7 +287,7 @@ std::unique_ptr<T> TSLQueue<T>::top_and_pop_and_rsize(unsigned long *rsize) {
template <typename T> template <typename T>
void TSLQueue<T>::clear() { void TSLQueue<T>::clear() {
std::lock_guard<std::mutex> lock(mutex); auto writeLock = sharedSpinLock->spin_write_lock();
head->next = tail; head->next = tail;
tail->prev = head; tail->prev = head;
@ -289,13 +296,13 @@ void TSLQueue<T>::clear() {
template <typename T> template <typename T>
bool TSLQueue<T>::empty() { bool TSLQueue<T>::empty() {
std::lock_guard<std::mutex> lock(mutex); auto readLock = sharedSpinLock->spin_read_lock();
return head->next == tail; return head->next == tail;
} }
template <typename T> template <typename T>
unsigned long TSLQueue<T>::size() { unsigned long TSLQueue<T>::size() {
std::lock_guard<std::mutex> lock(mutex); auto readLock = sharedSpinLock->spin_read_lock();
return msize; return msize;
} }
@ -313,20 +320,20 @@ bool TSLQueue<T>::TSLQNode::isNormal() const {
} }
template <typename T> template <typename T>
TSLQueue<T>::TSLQIter::TSLQIter(std::mutex *mutex, TSLQueue<T>::TSLQIter::TSLQIter(UDPC::SharedSpinLock::Weak lockWeak,
std::weak_ptr<TSLQNode> currentNode, std::weak_ptr<TSLQNode> currentNode,
unsigned long *msize) : unsigned long *msize) :
mutex(mutex), sharedSpinLockWeak(lockWeak),
readLock(std::unique_ptr<UDPC::LockObj<false>>(new UDPC::LockObj<false>{})),
writeLock(),
currentNode(currentNode), currentNode(currentNode),
msize(msize) msize(msize)
{ {
mutex->lock(); *readLock = lockWeak.lock()->spin_read_lock();
} }
template <typename T> template <typename T>
TSLQueue<T>::TSLQIter::~TSLQIter() { TSLQueue<T>::TSLQIter::~TSLQIter() {}
mutex->unlock();
}
template <typename T> template <typename T>
std::unique_ptr<T> TSLQueue<T>::TSLQIter::current() { std::unique_ptr<T> TSLQueue<T>::TSLQIter::current() {
@ -368,9 +375,61 @@ bool TSLQueue<T>::TSLQIter::prev() {
template <typename T> template <typename T>
bool TSLQueue<T>::TSLQIter::remove() { bool TSLQueue<T>::TSLQIter::remove() {
if (readLock && !writeLock && readLock->isValid()) {
auto sharedSpinLockStrong = sharedSpinLockWeak.lock();
if (!sharedSpinLockStrong) {
return false;
}
writeLock = std::unique_ptr<UDPC::LockObj<true>>(new UDPC::LockObj<true>{});
*writeLock = sharedSpinLockStrong->trade_read_for_write_lock(*readLock);
readLock.reset(nullptr);
return remove_impl();
} else {
return false;
}
}
template <typename T>
bool TSLQueue<T>::TSLQIter::try_remove() {
if (readLock && !writeLock && readLock->isValid()) {
auto sharedSpinLockStrong = sharedSpinLockWeak.lock();
if (!sharedSpinLockStrong) {
return false;
}
writeLock = std::unique_ptr<UDPC::LockObj<true>>(new UDPC::LockObj<true>{});
*writeLock = sharedSpinLockStrong->try_trade_read_for_write_lock(*readLock);
if (writeLock->isValid()) {
readLock.reset(nullptr);
return remove_impl();
} else {
writeLock.reset(nullptr);
return false;
}
} else {
return false;
}
}
template <typename T>
bool TSLQueue<T>::TSLQIter::remove_impl() {
const auto cleanupWriteLock = [this] () {
UDPC::SharedSpinLock::Ptr sharedSpinLockStrong = this->sharedSpinLockWeak.lock();
if (!sharedSpinLockStrong) {
writeLock.reset(nullptr);
return;
}
this->readLock = std::unique_ptr<UDPC::LockObj<false>>(new UDPC::LockObj<false>{});
(*this->readLock) = sharedSpinLockStrong->trade_write_for_read_lock(*(this->writeLock));
this->writeLock.reset(nullptr);
};
std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock(); std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
assert(currentNode); assert(currentNode);
if(!currentNode->isNormal()) { if(!currentNode->isNormal()) {
cleanupWriteLock();
return false; return false;
} }
@ -384,12 +443,13 @@ bool TSLQueue<T>::TSLQIter::remove() {
assert(*msize > 0); assert(*msize > 0);
--(*msize); --(*msize);
cleanupWriteLock();
return parent->next->isNormal(); return parent->next->isNormal();
} }
template <typename T> template <typename T>
typename TSLQueue<T>::TSLQIter TSLQueue<T>::begin() { typename TSLQueue<T>::TSLQIter TSLQueue<T>::begin() {
return TSLQIter(&mutex, head->next, &msize); return TSLQIter(sharedSpinLock, head->next, &msize);
} }
#endif #endif

View file

@ -143,7 +143,8 @@ TEST(TSLQueue, Iterator) {
// test that lock is held by iterator // test that lock is held by iterator
EXPECT_FALSE(q.push_nb(10)); EXPECT_FALSE(q.push_nb(10));
op = q.top_nb(); op = q.top_nb();
EXPECT_FALSE(op); // Getting top and iterator both hold read locks so this should be true.
EXPECT_TRUE(op);
// backwards iteration // backwards iteration
EXPECT_TRUE(iter.prev()); EXPECT_TRUE(iter.prev());
@ -175,6 +176,21 @@ TEST(TSLQueue, Iterator) {
op = iter.current(); op = iter.current();
EXPECT_TRUE(op); EXPECT_TRUE(op);
EXPECT_EQ(*op, 2); EXPECT_EQ(*op, 2);
// second iterator
auto iter2 = q.begin();
// Still should be able to get top.
EXPECT_TRUE(iter2.current());
// Shouldn't be able to remove if 2 iterators exist.
EXPECT_FALSE(iter2.try_remove());
// This will never return since the first iterator has a "read" lock.
//EXPECT_FALSE(iter2.remove());
// Still should be able to get top.
EXPECT_TRUE(iter2.current());
} }
EXPECT_EQ(q.size(), 9); EXPECT_EQ(q.size(), 9);