Add iterator to TSLQueue

Removed shared_ptr "locks", replaced by iterator owning a lock_guard.
This commit is contained in:
Stephen Seo 2019-11-03 18:46:25 +09:00
parent 10899ffaab
commit 3830be6e2d
2 changed files with 190 additions and 51 deletions

View file

@ -7,7 +7,6 @@
#include <chrono> #include <chrono>
#include <optional> #include <optional>
#include <cassert> #include <cassert>
#include <list> #include <list>
#include <type_traits> #include <type_traits>
@ -38,7 +37,7 @@ class TSLQueue {
private: private:
struct TSLQNode { struct TSLQNode {
TSLQNode() = default; TSLQNode();
// disable copy // disable copy
TSLQNode(TSLQNode& other) = delete; TSLQNode(TSLQNode& other) = delete;
TSLQNode& operator=(TSLQNode& other) = delete; TSLQNode& operator=(TSLQNode& other) = delete;
@ -49,10 +48,37 @@ class TSLQueue {
std::shared_ptr<TSLQNode> next; std::shared_ptr<TSLQNode> next;
std::weak_ptr<TSLQNode> prev; std::weak_ptr<TSLQNode> prev;
std::unique_ptr<T> data; std::unique_ptr<T> data;
enum TSLQN_Type {
TSLQN_NORMAL,
TSLQN_HEAD,
TSLQN_TAIL
};
TSLQN_Type type;
bool isNormal() const;
}; };
std::shared_ptr<char> iterValid; class TSLQIter {
std::shared_ptr<char> iterWrapperCount; public:
TSLQIter(std::mutex &mutex,
std::weak_ptr<TSLQNode> currentNode);
~TSLQIter();
std::optional<T> current();
bool next();
bool prev();
bool remove();
private:
std::lock_guard<std::mutex> lock;
std::weak_ptr<TSLQNode> currentNode;
};
public:
TSLQIter begin();
private:
std::mutex mutex; std::mutex mutex;
std::shared_ptr<TSLQNode> head; std::shared_ptr<TSLQNode> head;
std::shared_ptr<TSLQNode> tail; std::shared_ptr<TSLQNode> tail;
@ -61,14 +87,14 @@ class TSLQueue {
template <typename T> template <typename T>
TSLQueue<T>::TSLQueue() : TSLQueue<T>::TSLQueue() :
iterValid(std::make_shared<char>()),
iterWrapperCount(std::make_shared<char>()),
head(std::make_shared<TSLQNode>()), head(std::make_shared<TSLQNode>()),
tail(std::make_shared<TSLQNode>()), tail(std::make_shared<TSLQNode>()),
msize(0) msize(0)
{ {
head->next = tail; head->next = tail;
tail->prev = head; tail->prev = head;
head->type = TSLQNode::TSLQN_Type::TSLQN_HEAD;
tail->type = TSLQNode::TSLQN_Type::TSLQN_TAIL;
} }
template <typename T> template <typename T>
@ -76,9 +102,7 @@ TSLQueue<T>::~TSLQueue() {
} }
template <typename T> template <typename T>
TSLQueue<T>::TSLQueue(TSLQueue &&other) : TSLQueue<T>::TSLQueue(TSLQueue &&other)
iterValid(std::make_shared<char>()),
iterWrapperCount(std::make_shared<char>())
{ {
std::lock_guard lock(other.mutex); std::lock_guard lock(other.mutex);
head = std::move(other.head); head = std::move(other.head);
@ -88,8 +112,6 @@ TSLQueue<T>::TSLQueue(TSLQueue &&other) :
template <typename T> template <typename T>
TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) { TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) {
iterValid = std::make_shared<char>();
iterWrapperCount = std::make_shared<char>();
std::scoped_lock lock(mutex, other.mutex); std::scoped_lock lock(mutex, other.mutex);
head = std::move(other.head); head = std::move(other.head);
tail = std::move(other.tail); tail = std::move(other.tail);
@ -98,9 +120,6 @@ TSLQueue<T> & TSLQueue<T>::operator=(TSLQueue &&other) {
template <typename T> template <typename T>
void TSLQueue<T>::push(const T &data) { void TSLQueue<T>::push(const T &data) {
while(iterWrapperCount.use_count() > 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
auto newNode = std::make_shared<TSLQNode>(); auto newNode = std::make_shared<TSLQNode>();
newNode->data = std::make_unique<T>(data); newNode->data = std::make_unique<T>(data);
@ -117,9 +136,7 @@ void TSLQueue<T>::push(const T &data) {
template <typename T> template <typename T>
bool TSLQueue<T>::push_nb(const T &data) { bool TSLQueue<T>::push_nb(const T &data) {
if(iterWrapperCount.use_count() > 1) { if(mutex.try_lock()) {
return false;
} else if(mutex.try_lock()) {
auto newNode = std::make_shared<TSLQNode>(); auto newNode = std::make_shared<TSLQNode>();
newNode->data = std::make_unique<T>(data); newNode->data = std::make_unique<T>(data);
@ -141,9 +158,6 @@ bool TSLQueue<T>::push_nb(const T &data) {
template <typename T> template <typename T>
std::optional<T> TSLQueue<T>::top() { std::optional<T> TSLQueue<T>::top() {
while(iterWrapperCount.use_count() > 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
if(head->next != tail) { if(head->next != tail) {
assert(head->next->data); assert(head->next->data);
@ -155,9 +169,7 @@ std::optional<T> TSLQueue<T>::top() {
template <typename T> template <typename T>
std::optional<T> TSLQueue<T>::top_nb() { std::optional<T> TSLQueue<T>::top_nb() {
if(iterWrapperCount.use_count() > 1) { if(mutex.try_lock()) {
return std::nullopt;
} else if(mutex.try_lock()) {
std::optional<T> ret = std::nullopt; std::optional<T> ret = std::nullopt;
if(head->next != tail) { if(head->next != tail) {
assert(head->next->data); assert(head->next->data);
@ -172,9 +184,6 @@ std::optional<T> TSLQueue<T>::top_nb() {
template <typename T> template <typename T>
bool TSLQueue<T>::pop() { bool TSLQueue<T>::pop() {
while(iterWrapperCount.use_count() > 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
if(head->next == tail) { if(head->next == tail) {
return false; return false;
@ -185,8 +194,6 @@ bool TSLQueue<T>::pop() {
assert(msize > 0); assert(msize > 0);
--msize; --msize;
iterValid = std::make_shared<char>();
iterWrapperCount = std::make_shared<char>();
return true; return true;
} }
} }
@ -194,9 +201,6 @@ bool TSLQueue<T>::pop() {
template <typename T> template <typename T>
std::optional<T> TSLQueue<T>::top_and_pop() { std::optional<T> TSLQueue<T>::top_and_pop() {
std::optional<T> ret = std::nullopt; std::optional<T> ret = std::nullopt;
while(iterWrapperCount.use_count() > 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
if(head->next != tail) { if(head->next != tail) {
assert(head->next->data); assert(head->next->data);
@ -207,9 +211,6 @@ std::optional<T> TSLQueue<T>::top_and_pop() {
head->next = newNext; head->next = newNext;
assert(msize > 0); assert(msize > 0);
--msize; --msize;
iterValid = std::make_shared<char>();
iterWrapperCount = std::make_shared<char>();
} }
return ret; return ret;
} }
@ -217,9 +218,6 @@ std::optional<T> TSLQueue<T>::top_and_pop() {
template <typename T> template <typename T>
std::optional<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) { std::optional<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
std::optional<T> ret = std::nullopt; std::optional<T> ret = std::nullopt;
while(iterWrapperCount.use_count() > 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
if(head->next == tail) { if(head->next == tail) {
if(isEmpty) { if(isEmpty) {
@ -235,8 +233,6 @@ std::optional<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
assert(msize > 0); assert(msize > 0);
--msize; --msize;
iterValid = std::make_shared<char>();
iterWrapperCount = std::make_shared<char>();
if(isEmpty) { if(isEmpty) {
*isEmpty = head->next == tail; *isEmpty = head->next == tail;
} }
@ -246,35 +242,105 @@ std::optional<T> TSLQueue<T>::top_and_pop_and_empty(bool *isEmpty) {
template <typename T> template <typename T>
void TSLQueue<T>::clear() { void TSLQueue<T>::clear() {
while(iterWrapperCount.use_count() > 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
head->next = tail; head->next = tail;
tail->prev = head; tail->prev = head;
msize = 0; msize = 0;
iterValid = std::make_shared<char>();
iterWrapperCount = std::make_shared<char>();
} }
template <typename T> template <typename T>
bool TSLQueue<T>::empty() { bool TSLQueue<T>::empty() {
while(iterWrapperCount.use_count() > 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
return head->next == tail; return head->next == tail;
} }
template <typename T> template <typename T>
unsigned long long TSLQueue<T>::size() { unsigned long long TSLQueue<T>::size() {
while(iterWrapperCount.use_count() > 1) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
return msize; return msize;
} }
template <typename T>
TSLQueue<T>::TSLQNode::TSLQNode() :
type(TSLQN_Type::TSLQN_NORMAL)
{}
template <typename T>
bool TSLQueue<T>::TSLQNode::isNormal() const {
return type == TSLQN_Type::TSLQN_NORMAL;
}
template <typename T>
TSLQueue<T>::TSLQIter::TSLQIter(std::mutex &mutex,
std::weak_ptr<TSLQNode> currentNode) :
lock(mutex),
currentNode(currentNode)
{
}
template <typename T>
TSLQueue<T>::TSLQIter::~TSLQIter() {
}
template <typename T>
std::optional<T> TSLQueue<T>::TSLQIter::current() {
std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
assert(currentNode);
if(currentNode->isNormal()) {
return *currentNode->data.get();
} else {
return std::nullopt;
}
}
template <typename T>
bool TSLQueue<T>::TSLQIter::next() {
std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
assert(currentNode);
if(currentNode->type == TSLQNode::TSLQN_Type::TSLQN_TAIL) {
return false;
}
this->currentNode = currentNode->next;
return currentNode->next->type != TSLQNode::TSLQN_Type::TSLQN_TAIL;
}
template <typename T>
bool TSLQueue<T>::TSLQIter::prev() {
std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
assert(currentNode);
if(currentNode->type == TSLQNode::TSLQN_Type::TSLQN_HEAD) {
return false;
}
auto parent = currentNode->prev.lock();
assert(parent);
this->currentNode = currentNode->prev;
return parent->type != TSLQNode::TSLQN_Type::TSLQN_HEAD;
}
template <typename T>
bool TSLQueue<T>::TSLQIter::remove() {
std::shared_ptr<TSLQNode> currentNode = this->currentNode.lock();
assert(currentNode);
if(!currentNode->isNormal()) {
return false;
}
this->currentNode = currentNode->next;
auto parent = currentNode->prev.lock();
assert(parent);
currentNode->next->prev = parent;
parent->next = currentNode->next;
return parent->next->isNormal();
}
template <typename T>
typename TSLQueue<T>::TSLQIter TSLQueue<T>::begin() {
return TSLQIter(mutex, head->next);
}
#endif #endif

View file

@ -116,3 +116,76 @@ TEST(TSLQueue, Concurrent) {
} }
EXPECT_EQ(q.size(), 0); EXPECT_EQ(q.size(), 0);
} }
TEST(TSLQueue, Iterator) {
TSLQueue<int> q;
for(int i = 0; i < 10; ++i) {
q.push(i);
}
{
// iteration
auto iter = q.begin();
int i = 0;
auto op = iter.current();
while(op.has_value()) {
EXPECT_EQ(op.value(), i++);
if(i < 10) {
EXPECT_TRUE(iter.next());
} else {
EXPECT_FALSE(iter.next());
}
op = iter.current();
}
// test that lock is held by iterator
EXPECT_FALSE(q.push_nb(10));
op = q.top_nb();
EXPECT_FALSE(op.has_value());
// backwards iteration
EXPECT_TRUE(iter.prev());
op = iter.current();
while(op.has_value()) {
EXPECT_EQ(op.value(), --i);
if(i > 0) {
EXPECT_TRUE(iter.prev());
} else {
EXPECT_FALSE(iter.prev());
}
op = iter.current();
}
}
{
// iter remove
auto iter = q.begin();
EXPECT_TRUE(iter.next());
EXPECT_TRUE(iter.next());
EXPECT_TRUE(iter.next());
EXPECT_TRUE(iter.remove());
auto op = iter.current();
EXPECT_TRUE(op.has_value());
EXPECT_EQ(op.value(), 4);
EXPECT_TRUE(iter.prev());
op = iter.current();
EXPECT_TRUE(op.has_value());
EXPECT_EQ(op.value(), 2);
}
// check that "3" was removed from queue
int i = 0;
std::optional<int> op;
while(!q.empty()) {
op = q.top();
EXPECT_TRUE(op.has_value());
EXPECT_EQ(i++, op.value());
if(i == 3) {
++i;
}
EXPECT_TRUE(q.pop());
}
}