diff --git a/cpp_impl/src/UDPC_Defines.hpp b/cpp_impl/src/UDPC_Defines.hpp index 40e53e6..180e8b0 100644 --- a/cpp_impl/src/UDPC_Defines.hpp +++ b/cpp_impl/src/UDPC_Defines.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -147,6 +148,8 @@ struct Context { std::chrono::steady_clock::time_point lastUpdated; // ipv4 address and port (as ConnectionIdentifier) to ConnectionData std::unordered_map conMap; + // ipv4 address to all connected ConnectionIdentifiers + std::unordered_map > addrConMap; // id to ipv4 address and port (as ConnectionIdentifier) std::unordered_map idMap; diff --git a/cpp_impl/src/UDPConnection.cpp b/cpp_impl/src/UDPConnection.cpp index 15104a8..a32879f 100644 --- a/cpp_impl/src/UDPConnection.cpp +++ b/cpp_impl/src/UDPConnection.cpp @@ -322,11 +322,25 @@ void UDPC_update(void *ctx) { } } for(auto iter = removed.begin(); iter != removed.end(); ++iter) { + auto addrConIter = c->addrConMap.find(iter->getAddr()); + assert(addrConIter != c->addrConMap.end() + && "addrConMap must have an entry for a current connection"); + auto addrConSetIter = addrConIter->second.find(*iter); + assert(addrConSetIter != addrConIter->second.end() + && "nested set in addrConMap must have an entry for a current connection"); + addrConIter->second.erase(addrConSetIter); + if(addrConIter->second.empty()) { + c->addrConMap.erase(addrConIter); + } + auto cIter = c->conMap.find(*iter); - assert(cIter != c->conMap.end()); + assert(cIter != c->conMap.end() + && "conMap must have the entry set to be removed"); + if(cIter->second.flags.test(4)) { c->idMap.erase(cIter->second.id); } + c->conMap.erase(cIter); } } @@ -584,6 +598,18 @@ void UDPC_update(void *ctx) { c->idMap.insert(std::make_pair(newConnection.id, identifier)); c->conMap.insert(std::make_pair(identifier, std::move(newConnection))); + auto addrConIter = c->addrConMap.find(identifier.getAddr()); + if(addrConIter == c->addrConMap.end()) { + auto insertResult = c->addrConMap.insert( + std::make_pair( + identifier.getAddr(), + std::unordered_set{} + )); + assert(insertResult.second + && "Must successfully insert into addrConMap"); + addrConIter = insertResult.first; + } + addrConIter->second.insert(identifier); // TODO trigger event server established connection with client } else if (c->flags.test(1)) { // is client @@ -809,6 +835,16 @@ int UDPC_drop_connection(void *ctx, uint32_t addr, uint16_t port) { auto iter = c->conMap.find(identifier); if(iter != c->conMap.end()) { + if(iter->second.flags.test(4)) { + c->idMap.erase(iter->second.id); + } + auto addrConIter = c->addrConMap.find(addr); + if(addrConIter != c->addrConMap.end()) { + addrConIter->second.erase(identifier); + if(addrConIter->second.empty()) { + c->addrConMap.erase(addrConIter); + } + } c->conMap.erase(iter); return 1; } @@ -816,6 +852,31 @@ int UDPC_drop_connection(void *ctx, uint32_t addr, uint16_t port) { return 0; } +int UDPC_drop_connection_addr(void *ctx, uint32_t addr) { + UDPC::Context *c = UDPC::verifyContext(ctx); + if(!c) { + return 0; + } + + auto addrConIter = c->addrConMap.find(addr); + if(addrConIter != c->addrConMap.end()) { + for(auto identIter = addrConIter->second.begin(); + identIter != addrConIter->second.end(); + ++identIter) { + auto conIter = c->conMap.find(*identIter); + assert(conIter != c->conMap.end()); + if(conIter->second.flags.test(4)) { + c->idMap.erase(conIter->second.id); + } + c->conMap.erase(conIter); + } + c->addrConMap.erase(addrConIter); + return 1; + } + + return 0; +} + uint32_t UDPC_set_protocol_id(void *ctx, uint32_t id) { UDPC::Context *c = UDPC::verifyContext(ctx); if(!c) { diff --git a/cpp_impl/src/UDPConnection.h b/cpp_impl/src/UDPConnection.h index 31af2b4..4dc10b2 100644 --- a/cpp_impl/src/UDPConnection.h +++ b/cpp_impl/src/UDPConnection.h @@ -86,6 +86,9 @@ int UDPC_set_accept_new_connections(void *ctx, int isAccepting); /// addr must be in network byte order (big-endian), port must be in native byte order int UDPC_drop_connection(void *ctx, uint32_t addr, uint16_t port); +/// addr must be in network byte order, drops all connections to specified addr +int UDPC_drop_connection_addr(void *ctx, uint32_t addr); + uint32_t UDPC_set_protocol_id(void *ctx, uint32_t id); UDPC_LoggingType set_logging_type(void *ctx, UDPC_LoggingType loggingType);