// Program Listing for File ConnectionManager.cpp
// Source: BackendLibfabric/ConnectionManager.cpp
#include "ConnectionManager.hpp"
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <tuple>
#include <utility>
#include <tracy/Tracy.hpp>
#include <rdma/fabric.h>
#include <rdma/fi_eq.h>
#include "ActiveEndpoint.hpp"
#include "Issues.hpp"
#include "netio3-backend/EventLoop/BaseEventLoop.hpp"
#include "netio3-backend/Netio3Backend.hpp"
#include "netio3-backend/Issues.hpp"
#include "SendEndpointBuffered.hpp"
/// Constructs the connection manager.
/// @param event_loop event loop used for all fd registration and signals (not owned)
/// @param config backend-wide network configuration; taken by value and moved in
netio3::libfabric::ConnectionManager::ConnectionManager(BaseEventLoop* event_loop,
                                                        NetworkConfig config) :
    m_config{std::move(config)},
    m_event_loop{event_loop},
    m_close_signal{m_event_loop, [this]() { handle_close_requests(); }, false},
    // Read thread_safety from m_config, not from `config`: `config` was moved
    // from by the m_config initializer above, so reading it here relied on the
    // moved-from object happening to keep its value. m_config is initialized
    // before these members (member-init order follows declaration order, which
    // this list mirrors), so it is safe to read here.
    m_send_sendpoints_buffered{m_config.thread_safety == ThreadSafetyModel::SAFE},
    m_send_endpoints_zero_copy{m_config.thread_safety == ThreadSafetyModel::SAFE}
{}
// Destructor: ensures every endpoint fd is unregistered from the event loop
// before members are destroyed. The unregistration must happen on the
// event-loop thread, so the destructor behaves differently depending on the
// calling thread.
netio3::libfabric::ConnectionManager::~ConnectionManager()
{
if (m_event_loop->get_thread_id() == std::this_thread::get_id()) {
// Already on the event-loop thread: unregister directly.
unregister_all();
} else if (m_event_loop->is_running()) {
// Called from another thread while the loop is running: post a one-shot
// signal that performs the unregistration on the loop thread, then block
// until it has executed.
std::mutex mutex;
std::condition_variable cv;
std::atomic_bool unregistered{false};
const auto signal = m_event_loop->create_signal(
[this, &unregistered, &cv](int) {
unregister_all();
unregistered = true;
cv.notify_one();
},
false);
signal.fire();
std::unique_lock<std::mutex> lock(mutex);
// Re-check periodically: if the event loop stops before the signal handler
// runs, `unregistered` never becomes true and we must not wait forever.
while (m_event_loop->is_running() and !unregistered.load()) {
constexpr static auto timeout = std::chrono::seconds(1);
if (cv.wait_for(lock, timeout, [&unregistered] { return unregistered.load(); })) {
break;
}
}
}
// NOTE(review): if the loop is neither this thread nor running, no
// unregistration happens — presumably the loop already dropped its fds.
}
/// Opens a passive (listen) endpoint for `address` and registers its event
/// queue fd with the event loop.
/// @return the address reported by the created ListenSocket (used as map key)
/// @throws FailedOpenListenEndpoint on validation or domain-init failure
/// @throws ListenEndpointAlreadyExists if an endpoint for `address` exists
netio3::EndPointAddress netio3::libfabric::ConnectionManager::open_listen_endpoint(
const EndPointAddress& address,
ConnectionParameters conn_params)
{
// Re-throw capability violations with listen-endpoint context.
try {
validate_capabilities(conn_params);
} catch (const FailedOpenActiveEndpoint& e) {
throw FailedOpenListenEndpoint(ERS_HERE, address.address(), address.port(), e.message());
}
std::lock_guard lock(m_listen_endpoint_mutex);
if (m_listen_endpoints.contains(address)) {
throw ListenEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
}
// Lazily create domain/reactor/buffer infrastructure; FI_SOURCE marks the
// address as a local (bind) address.
try {
check_and_init(address, FI_SOURCE);
} catch (const LibfabricDomainError& e) {
throw FailedOpenListenEndpoint(ERS_HERE, address.address(), address.port(), e.message());
}
if (m_shared_buffer_manager->get_receive_context_manager().has_value()) {
ERS_INFO(std::format("Using shared receive context manager for listen endpoint at {}:{}. "
"Ignoring connection parameters passed to open_listen_endpoint",
address.address(),
address.port()));
}
auto lendpoint_to_insert =
ListenSocket(address, conn_params, m_config.mode, *m_domain_manager);
// Keyed on the socket's own address, which may differ from the requested one.
const auto pair =
m_listen_endpoints.try_emplace(lendpoint_to_insert.get_address(), std::move(lendpoint_to_insert));
auto& lendpoint = pair.first->second;
// NOTE(review): the callback captures a reference to the map element;
// assumes m_listen_endpoints gives stable references (node-based map) —
// confirm against the member's declared type.
const auto eq_ev_ctx = EventContext{
lendpoint.get_endpoint().eqfd, [this, &lendpoint](int) { on_listen_endpoint_cm_event(lendpoint); }};
m_event_loop->register_fd(eq_ev_ctx);
return lendpoint.get_address();
}
/// Opens an outgoing (connect-side) active endpoint for `address`, enables the
/// requested send/receive capabilities, starts the connection handshake and
/// registers the endpoint's event-queue fd with the event loop.
/// @throws FailedOpenActiveEndpoint / ActiveEndpointAlreadyExists /
///         InvalidConnectionParameters on failure
void netio3::libfabric::ConnectionManager::open_active_endpoint(
    const EndPointAddress& address,
    const ConnectionParameters& conn_params)
{
    validate_capabilities(conn_params);
    std::lock_guard lock(m_active_endpoint_mutex);
    create_active_endpoint(address, conn_params);
    std::lock_guard lock_receive_endpoint(m_receive_endpoint_mutex);
    enable_capabilities(address, conn_params);
    // Look the endpoint up once instead of five separate at() calls; the
    // reference stays valid while m_active_endpoint_mutex is held.
    auto& endpoint = m_active_endpoints.at(address);
    endpoint.complete_connection(ActiveEndpoint::ConnectionMode::Connect);
    // Capture only the fd (not a reference to the endpoint) so the callback
    // stays valid even if the map element moves.
    const auto ctx = EventContext{
        endpoint.get_endpoint().eqfd,
        [this, eqfd = endpoint.get_endpoint().eqfd](int) { on_active_endpoint_cm_event(eqfd); }};
    endpoint.set_eq_ev_ctx(ctx);
    m_event_loop->register_fd(ctx);
}
/// Requests closure of the listen endpoint at `address`. The teardown itself
/// is deferred to the event-loop thread via the close queue and signal.
/// @throws UnknownListenEndpoint if no such listen endpoint is registered
void netio3::libfabric::ConnectionManager::close_listen_endpoint(const EndPointAddress& address)
{
    const std::lock_guard guard(m_listen_endpoint_mutex);
    if (!m_listen_endpoints.contains(address)) {
        throw UnknownListenEndpoint(ERS_HERE, address.address(), address.port());
    }
    // Hand the actual close over to the event-loop thread.
    m_close_queue.push({address, CloseQueueItem::Type::listen});
    m_close_signal.fire();
}
/// Requests closure of the active endpoint at `address`. The teardown itself
/// is deferred to the event-loop thread via the close queue and signal.
/// @throws UnknownActiveEndpoint if no such active endpoint is registered
void netio3::libfabric::ConnectionManager::close_active_endpoint(const EndPointAddress& address)
{
    {
        // Only the existence check needs the lock; the queue is thread-safe.
        const std::lock_guard guard(m_active_endpoint_mutex);
        if (!m_active_endpoints.contains(address)) {
            throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
        }
    }
    m_close_queue.push({address, CloseQueueItem::Type::active});
    m_close_signal.fire();
}
/// Returns the number of free send buffers on the endpoint for `address`,
/// whichever send mode (buffered or zero-copy) the endpoint uses.
/// @throws UnknownActiveEndpoint if no send endpoint exists for the address
std::size_t netio3::libfabric::ConnectionManager::get_num_available_buffers(
    const EndPointAddress& address)
{
    if (m_send_sendpoints_buffered.contains(address)) {
        const auto query = [](std::unique_ptr<SendEndpointBuffered>& sendpoint) {
            return sendpoint->get_num_available_buffers();
        };
        return m_send_sendpoints_buffered.apply(address, query);
    }
    if (m_send_endpoints_zero_copy.contains(address)) {
        const auto query = [](std::unique_ptr<SendEndpointZeroCopy>& sendpoint) {
            return sendpoint->get_num_available_buffers();
        };
        return m_send_endpoints_zero_copy.apply(address, query);
    }
    throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
}
/// Lazily constructs the shared libfabric infrastructure on first use: the
/// domain manager, the completion-queue reactor and the shared buffer
/// manager. Subsequent calls are no-ops for components already created.
/// @param address seed address for the domain lookup
/// @param flags libfabric hints flags (e.g. FI_SOURCE for listen endpoints)
void netio3::libfabric::ConnectionManager::check_and_init(const EndPointAddress& address,
                                                          const std::uint64_t flags)
{
    if (!m_domain_manager) {
        m_domain_manager = std::make_unique<DomainManager>(m_config, address, flags);
    }
    if (!m_cq_reactor) {
        m_cq_reactor = std::make_unique<CqReactor>(m_domain_manager->get_fabric(),
                                                   m_config.callbacks.on_data_cb);
    }
    if (!m_shared_buffer_manager) {
        m_shared_buffer_manager =
            std::make_unique<SharedBufferManager>(m_config, *m_domain_manager, m_event_loop);
    }
}
/// Creates an ActiveEndpoint for `address` and inserts it into
/// m_active_endpoints. The caller must hold m_active_endpoint_mutex.
/// @throws ActiveEndpointAlreadyExists if one already exists for the address
/// @throws FailedOpenActiveEndpoint if the domain infrastructure fails to init
void netio3::libfabric::ConnectionManager::create_active_endpoint(
    const EndPointAddress& address,
    const ConnectionParameters& connection_params)
{
    if (m_active_endpoints.contains(address)) {
        throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
    }
    try {
        check_and_init(address, 0);
    } catch (const LibfabricDomainError& e) {
        throw FailedOpenActiveEndpoint(ERS_HERE, address.address(), address.port(), e.message());
    }
    // Use the shared receive context when one is configured; otherwise the
    // endpoint gets its own receive resources (nullptr).
    fid_ep* srx_context = nullptr;
    if (m_shared_buffer_manager->get_receive_context_manager().has_value()) {
        srx_context = m_shared_buffer_manager->get_receive_context_manager()->get_srx_context();
    }
    m_active_endpoints.try_emplace(address,
                                   address,
                                   m_config.mode,
                                   get_endpoint_capabilities(connection_params),
                                   m_domain_manager->get_fabric(),
                                   m_domain_manager->get_domain(),
                                   0,
                                   srx_context);
}
/// Validates user-supplied connection parameters against the backend-wide
/// configuration. Impossible combinations throw; combinations where a value
/// will merely be ignored produce a warning.
/// @throws InvalidConnectionParameters for parameters that cannot be honored
void netio3::libfabric::ConnectionManager::validate_capabilities(
    const ConnectionParameters& connection_params) const
{
    const auto& recv = connection_params.recv_params;
    const auto& buffered = connection_params.send_buffered_params;
    const auto& zero_copy = connection_params.send_zero_copy_params;
    const auto& backend = m_config.conn_params;

    // Shared-buffer features must be enabled globally before an endpoint may
    // request them.
    if (recv.use_shared_receive_buffers && !backend.recv_params.use_shared_receive_buffers) {
        throw InvalidConnectionParameters(
            "Shared receive buffers are not enabled in the backend configuration");
    }
    if (zero_copy.use_shared_send_buffers &&
        !backend.send_zero_copy_params.use_shared_send_buffers) {
        throw InvalidConnectionParameters(
            "Shared zero-copy send buffers are not enabled in the backend configuration");
    }
    if (buffered.use_shared_send_buffers &&
        !backend.send_buffered_params.use_shared_send_buffers) {
        throw InvalidConnectionParameters(
            "Shared buffered send buffers are not enabled in the backend configuration");
    }

    // Buffered and zero-copy sending are mutually exclusive on one endpoint.
    const bool wants_buffered = buffered.num_buf > 0 || buffered.use_shared_send_buffers;
    const bool wants_zero_copy =
        zero_copy.mr_start != nullptr || zero_copy.use_shared_send_buffers;
    if (wants_buffered && wants_zero_copy) {
        throw InvalidConnectionParameters(
            "Libfabric does not support buffered and zero-copy sending on the same endpoint");
    }

    // The remaining combinations are harmless: warn that a value is ignored.
    if (recv.use_shared_receive_buffers && recv.num_buf > 0) {
        ers::warning(
            InvalidConnectionParameters("Shared receive buffers requested, but the number of receive "
                                        "buffers is set to a non-zero value. Value ignored."));
    }
    if (buffered.use_shared_send_buffers && buffered.num_buf > 0) {
        ers::warning(InvalidConnectionParameters(
            "Shared buffered send buffers requested, but the number of buffered send buffers is set to a "
            "non-zero value. Value ignored."));
    }
    if (zero_copy.use_shared_send_buffers && zero_copy.mr_start != nullptr) {
        ers::warning(InvalidConnectionParameters(
            "Shared zero-copy send buffers requested, but the zero-copy send memory region is set to a "
            "non-null value. Value ignored."));
    }
}
/// Instantiates the send/receive helpers requested by `connection_params`
/// for the already-created active endpoint at `address`. Parameters were
/// validated earlier by validate_capabilities().
void netio3::libfabric::ConnectionManager::enable_capabilities(
    const EndPointAddress& address,
    const ConnectionParameters& connection_params)
{
    const auto& buffered = connection_params.send_buffered_params;
    if (buffered.num_buf > 0 || buffered.use_shared_send_buffers) {
        enable_buffered_sending(address, buffered);
    }
    const auto& zero_copy = connection_params.send_zero_copy_params;
    if (zero_copy.mr_start != nullptr) {
        enable_zero_copy_sending(address, zero_copy);
    }
    const auto& recv = connection_params.recv_params;
    if (recv.num_buf > 0 || recv.use_shared_receive_buffers) {
        enable_receiving(address, recv);
    }
}
/// Attaches a buffered send endpoint to the active endpoint at `address`.
/// @throws ActiveEndpointAlreadyExists if buffered sending is already enabled
/// @throws FailedOpenActiveEndpoint if no active endpoint exists for `address`
void netio3::libfabric::ConnectionManager::enable_buffered_sending(
    const EndPointAddress& address,
    ConnectionParametersSendBuffered conn_params)
{
    if (m_send_sendpoints_buffered.contains(address)) {
        throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
    }
    if (not m_active_endpoints.contains(address)) {
        throw FailedOpenActiveEndpoint(
            ERS_HERE, address.address(), address.port(), "No active endpoint found");
    }
    m_send_sendpoints_buffered.try_emplace(
        address,
        std::make_unique<SendEndpointBuffered>(m_active_endpoints.at(address),
                                               conn_params,
                                               m_shared_buffer_manager->get_send_buffer_manager(),
                                               // `*m_domain_manager` — the redundant
                                               // `.get()` was dropped for consistency
                                               // with enable_zero_copy_sending.
                                               *m_domain_manager));
}
/// Attaches a zero-copy send endpoint to the active endpoint at `address`.
/// @throws ActiveEndpointAlreadyExists if zero-copy sending is already enabled
/// @throws FailedOpenActiveEndpoint if no active endpoint exists for `address`
void netio3::libfabric::ConnectionManager::enable_zero_copy_sending(
    const EndPointAddress& address,
    ConnectionParametersSendZeroCopy conn_params)
{
    if (m_send_endpoints_zero_copy.contains(address)) {
        throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
    }
    if (!m_active_endpoints.contains(address)) {
        throw FailedOpenActiveEndpoint(
            ERS_HERE, address.address(), address.port(), "No active endpoint found");
    }
    auto sendpoint =
        std::make_unique<SendEndpointZeroCopy>(m_active_endpoints.at(address),
                                               conn_params,
                                               m_shared_buffer_manager->get_zero_copy_buffer_manager(),
                                               *m_domain_manager);
    m_send_endpoints_zero_copy.try_emplace(address, std::move(sendpoint));
}
/// Attaches a receive endpoint to the active endpoint at `address`.
/// @throws ActiveEndpointAlreadyExists if receiving is already enabled
/// @throws FailedOpenActiveEndpoint if no active endpoint exists for `address`
void netio3::libfabric::ConnectionManager::enable_receiving(const EndPointAddress& address,
                                                            ConnectionParametersRecv conn_params)
{
    if (m_receive_endpoints.contains(address)) {
        throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
    }
    if (!m_active_endpoints.contains(address)) {
        throw FailedOpenActiveEndpoint(
            ERS_HERE, address.address(), address.port(), "No active endpoint found");
    }
    // The receive endpoint is constructed in place, bound to the active
    // endpoint and (optionally) the shared receive context.
    m_receive_endpoints.try_emplace(address,
                                    m_active_endpoints.at(address),
                                    conn_params,
                                    m_shared_buffer_manager->get_receive_context_manager(),
                                    *m_domain_manager,
                                    m_event_loop);
}
/// Performs one non-blocking read from a connection-management event queue.
/// @param eq the libfabric event queue to read
/// @return the event code (or the negative fi errno on failure), the error
///         entry (populated only for -FI_EAVAIL) and the fi_info attached to
///         the event (owned by the returned unique pointer)
netio3::libfabric::ConnectionManager::CmEvent netio3::libfabric::ConnectionManager::read_cm_event(
    fid_eq* eq)
{
    ZoneScoped;
    std::uint32_t event{};
    fi_eq_cm_entry entry{};
    fi_info* info{nullptr};
    fi_eq_err_entry err_entry{};
    // Timeout 0 → non-blocking read.
    const auto rd = fi_eq_sread(eq, &event, &entry, sizeof(entry), 0, 0);
    if (rd < 0) {
        if (rd == -FI_EAGAIN) {
            // No event pending.
            return {rd, err_entry, FiInfoUniquePtr{info}};
        }
        if (rd == -FI_EAVAIL) {
            // An error event is available: fetch and log its details.
            const auto r = fi_eq_readerr(eq, &err_entry, 0);
            if (r < 0) {
                ers::error(LibFabricCmError(
                    ERS_HERE, std::format("Failed to retrieve details on Event Queue error {}", r)));
            }
            ers::error(LibFabricCmError(
                ERS_HERE,
                std::format("Event Queue error: {} (code: {}), provider specific: {} (code: {})",
                            fi_strerror(err_entry.err),
                            err_entry.err,
                            fi_eq_strerror(eq, err_entry.prov_errno, err_entry.err_data, nullptr, 0),
                            err_entry.prov_errno)));
            return {rd, err_entry, FiInfoUniquePtr{info}};
        }
        // Any other negative value is a plain read failure. Previously this
        // fell through and returned a default (0) event, which callers could
        // mistake for a valid event; return the error code instead.
        ers::error(LibFabricCmError(ERS_HERE, std::format("Failed to read from Event Queue: {}", rd)));
        return {rd, err_entry, FiInfoUniquePtr{info}};
    }
    if (rd != sizeof(entry)) {
        // Short read: log it, but still hand back the (possibly partial) event.
        ers::error(LibFabricCmError(ERS_HERE, std::format("Failed to read from Event Queue: {}", rd)));
    }
    return {event, err_entry, FiInfoUniquePtr{entry.info}};
}
// Event-loop callback for a listen endpoint's event queue. Reads one CM
// event and dispatches it: only FI_CONNREQ is expected on a passive
// endpoint; connection-state events indicate a protocol violation.
void netio3::libfabric::ConnectionManager::on_listen_endpoint_cm_event(ListenSocket& lendpoint)
{
ZoneScoped;
ERS_DEBUG(1, "listen endpoint: connection event");
auto event = read_cm_event(lendpoint.get_endpoint().eq.get());
switch (event.event) {
case FI_CONNREQ: {
// A peer wants to connect: spawn a pending active endpoint.
ERS_DEBUG(2, ": FI_CONNREQ");
handle_connection_request(lendpoint, std::move(event.info));
} break;
case FI_CONNECTED:
throw LibFabricCmError(ERS_HERE, "FI_CONNECTED received on listen endpoint");
case FI_SHUTDOWN:
throw LibFabricCmError(ERS_HERE, "FI_SHUTDOWN received on listen endpoint");
case -FI_EAGAIN: {
// Spurious wakeup: re-arm the fd by asking libfabric whether the EQ can
// be safely waited on again.
auto* fp = &lendpoint.get_endpoint().eq->fid;
fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
break;
}
case -FI_EAVAIL:
// read_cm_event already logged the details; report the summary here.
ers::error(LibFabricCmError(
ERS_HERE,
std::format("Unhandled error in listen endpoint EQ code: {} provider specific code: {}",
event.err_entry.err,
event.err_entry.prov_errno)));
break;
default:
ers::warning(LibFabricCmError(
ERS_HERE, std::format("Unexpected event {} in listen endpoint Event Queue", event.event)));
break;
}
}
// Event-loop callback for an active endpoint's event queue, identified by
// its EQ fd. Handles connection establishment (FI_CONNECTED), peer shutdown
// (FI_SHUTDOWN), EQ errors (-FI_EAVAIL) and spurious wakeups (-FI_EAGAIN).
void netio3::libfabric::ConnectionManager::on_active_endpoint_cm_event(const int eqfd)
{
ZoneScoped;
ERS_DEBUG(1, "active endpoint: connection event");
// Resolve the fd to an endpoint: first among established endpoints, then
// among endpoints still pending accept-side completion.
const auto [ep, is_pending] = std::invoke([this, eqfd]() -> std::tuple<ActiveEndpoint&, bool> {
std::lock_guard lock(m_active_endpoint_mutex);
const auto it_emplaced_ep = std::ranges::find_if(
m_active_endpoints, [eqfd](const auto& pair) { return pair.second.get_endpoint().eqfd == eqfd; });
if (it_emplaced_ep != m_active_endpoints.end()) {
return {it_emplaced_ep->second, false};
}
const auto it_pending_ep = std::ranges::find_if(
m_pending_active_endpoints, [eqfd](const auto& pair) { return pair.endpoint.get_endpoint().eqfd == eqfd; });
if (it_pending_ep != m_pending_active_endpoints.end()) {
return {it_pending_ep->endpoint, true};
}
throw LibFabricCmError(ERS_HERE, std::format("No active endpoint found for EQ fd {}", eqfd));
});
{
const auto event = std::invoke([this, &ep] {
std::lock_guard lock(m_active_endpoint_mutex);
return read_cm_event(ep.get_endpoint().eq.get());
});
switch (event.event) {
case FI_SHUTDOWN: {
// Peer closed the connection: tear down our side immediately.
const auto address = ep.get_address();
ERS_INFO(std::format("Closed connection to {}:{} because received FI_SHUTDOWN",
address.address(),
address.port()));
do_close_active_endpoint(address);
return;
}
case FI_CONNECTED: {
ep.update_addresses();
// If this was a pending (accept-side) endpoint it is now moved into
// m_active_endpoints; `ep` must not be used afterwards — use active_ep.
auto* active_ep = &ep;
if (is_pending) {
auto& moved_ep = handle_pending_active_endpoint(eqfd);
active_ep = &moved_ep;
}
std::unique_lock lock(m_active_endpoint_mutex);
const auto address = active_ep->get_address();
// Wire the send-completion CQ callback for whichever send mode (if any)
// was enabled on this endpoint.
if (m_send_sendpoints_buffered.contains(address)) {
m_active_endpoints.at(address).set_cq_ev_ctx(
{m_active_endpoints.at(address).get_endpoint().cqfd, [this, address](int) {
const auto keys =
m_send_sendpoints_buffered.apply(address, [this](std::unique_ptr<SendEndpointBuffered>& sendpoint_cb) {
return m_cq_reactor->on_send_cq_event(*sendpoint_cb);
});
if (m_config.callbacks.on_send_completed_cb != nullptr) {
for (const auto key : keys) {
m_config.callbacks.on_send_completed_cb(address, key);
}
}
}});
} else if (m_send_endpoints_zero_copy.contains(address)) {
m_active_endpoints.at(address).set_cq_ev_ctx(
{m_active_endpoints.at(address).get_endpoint().cqfd, [this, address](int) {
const auto keys =
m_send_endpoints_zero_copy.apply(address, [this](std::unique_ptr<SendEndpointZeroCopy>& sendpoint_cb) {
return m_cq_reactor->on_send_cq_event(*sendpoint_cb);
});
if (m_config.callbacks.on_send_completed_cb != nullptr) {
for (const auto key : keys) {
m_config.callbacks.on_send_completed_cb(address, key);
}
}
}});
}
{
// Wire and register the receive CQ callback if receiving was enabled.
std::lock_guard lock_receive_endpoint(m_receive_endpoint_mutex);
if (m_receive_endpoints.contains(address)) {
m_active_endpoints.at(address).set_rcq_ev_ctx(
{m_active_endpoints.at(address).get_endpoint().rcqfd, [this, address](int) {
m_cq_reactor->on_recv_cq_event(m_receive_endpoints.at(address));
}});
m_event_loop->register_fd(m_active_endpoints.at(address).get_rcq_ev_ctx());
}
}
ERS_DEBUG(1,
std::format("Active endpoint: EQ fd {} connected, CQ fd {}",
active_ep->get_endpoint().eqfd,
active_ep->get_endpoint().cqfd));
ERS_INFO(std::format("Opened send connection to {}:{}", address.address(), address.port()));
// A negative CQ fd means no send capability was enabled — nothing to watch.
const auto ctx = m_active_endpoints.at(address).get_cq_ev_ctx();
if (ctx.fd >= 0) {
m_event_loop->register_fd(ctx);
}
if (m_config.callbacks.on_connection_established_cb != nullptr) {
const auto local_address = m_active_endpoints.at(address).get_local_address();
const auto capabilities = m_active_endpoints.at(address).get_capabilities();
lock.unlock(); // Unlock before calling the callback to avoid deadlock
m_config.callbacks.on_connection_established_cb(
address,
local_address,
capabilities);
lock.lock(); // Re-lock after the callback
}
return;
}
case FI_MR_COMPLETE:
case FI_AV_COMPLETE:
case FI_JOIN_COMPLETE:
// Not implemented
break;
case -FI_EAVAIL:
// EQ delivered an error entry; dispatch on the specific error code.
switch (event.err_entry.err) {
case FI_ECONNREFUSED: {
// The peer refused the connection: free everything allocated for the
// attempt and notify the user if a callback is registered.
ERS_DEBUG(1, "Connection refused (FI_ECONNREFUSED). Deallocating active endpoint resources");
std::scoped_lock lock(m_active_endpoint_mutex, m_receive_endpoint_mutex);
const auto address = ep.get_address();
m_event_loop->remove_fd(ep.get_endpoint().eqfd);
m_send_endpoints_zero_copy.erase(address);
m_send_sendpoints_buffered.erase(address);
m_receive_endpoints.erase(address);
m_active_endpoints.erase(address);
if (m_config.callbacks.on_connection_refused_cb != nullptr) {
m_config.callbacks.on_connection_refused_cb(address);
}
break;
}
case FI_ETIMEDOUT:
ers::warning(
LibFabricCmError(ERS_HERE, "Active endpoint CM event error: FI_ETIMEDOUT"));
break;
default: {
std::lock_guard lock(m_active_endpoint_mutex);
ers::error(LibFabricCmError(
ERS_HERE,
std::format(
"Unhandled error in the Event Queue: {} (code: {}), provider specific: {} (code: {})",
fi_strerror(event.err_entry.err),
event.err_entry.err,
fi_eq_strerror(ep.get_endpoint().eq.get(),
event.err_entry.prov_errno,
event.err_entry.err_data,
nullptr,
0),
event.err_entry.prov_errno)));
}
}
return;
case -FI_EAGAIN: {
// Spurious wakeup: re-arm the EQ fd via fi_trywait.
std::lock_guard lock(m_active_endpoint_mutex);
auto* fp = &ep.get_endpoint().eq->fid;
fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
break;
}
default:
throw LibFabricCmError(
ERS_HERE, std::format("Unexpected event {} in active endpoint Event Queue", event.event));
}
}
}
// Handles FI_CONNREQ on a listen endpoint: creates a new active endpoint for
// the incoming peer, starts the accept handshake and parks the endpoint in
// m_pending_active_endpoints until FI_CONNECTED arrives.
// @param info the fi_info delivered with the connection request (ownership
//             transferred into the new endpoint)
void netio3::libfabric::ConnectionManager::handle_connection_request(ListenSocket& lendpoint,
FiInfoUniquePtr&& info)
{
ZoneScoped;
// need to spawn new endpoint
ERS_DEBUG(1, "Received connection request");
auto active_endpoint =
ActiveEndpoint{lendpoint.get_address(),
m_config.mode,
get_endpoint_capabilities(lendpoint.get_connection_parameters()),
m_domain_manager->get_fabric(),
m_domain_manager->get_domain(),
std::move(info),
// Shared receive context if configured, otherwise per-endpoint.
m_shared_buffer_manager->get_receive_context_manager().has_value()
? m_shared_buffer_manager->get_receive_context_manager()->get_srx_context()
: nullptr};
active_endpoint.complete_connection(ActiveEndpoint::ConnectionMode::Accept, lendpoint.get_pep());
// Capture only the fd so the callback survives the endpoint being moved
// into the pending list (and later into m_active_endpoints).
const auto eq_ev_ctx =
EventContext{active_endpoint.get_endpoint().eqfd,
[this, eqfd = active_endpoint.get_endpoint().eqfd](int) { on_active_endpoint_cm_event(eqfd); }};
active_endpoint.set_eq_ev_ctx(eq_ev_ctx);
m_event_loop->register_fd(eq_ev_ctx);
m_pending_active_endpoints.emplace_back(std::move(active_endpoint), lendpoint.get_connection_parameters());
}
// Promotes a pending (accept-side) endpoint, identified by its EQ fd, into
// m_active_endpoints once FI_CONNECTED has been received, enables its
// configured capabilities and removes it from the pending list.
// @return reference to the endpoint now stored in m_active_endpoints
// @throws LibFabricCmError if no pending endpoint matches the fd
netio3::libfabric::ActiveEndpoint& netio3::libfabric::ConnectionManager::handle_pending_active_endpoint(const int eqfd)
{
ZoneScoped;
const auto it = std::ranges::find_if(
m_pending_active_endpoints, [eqfd](const auto& pair) { return pair.endpoint.get_endpoint().eqfd == eqfd; });
if (it == m_pending_active_endpoints.end()) {
throw LibFabricCmError(ERS_HERE, std::format("No pending active endpoint found for EQ fd {}", eqfd));
}
auto& pending_endpoint = it->endpoint;
const auto remote_address = pending_endpoint.get_address();
std::scoped_lock lock(
m_active_endpoint_mutex, m_active_endpoint_by_passive_mutex, m_receive_endpoint_mutex);
// NOTE(review): `inserted` is never checked — if an endpoint for this
// remote address already exists, try_emplace leaves the map untouched and
// the pending endpoint is silently dropped below. Confirm this is intended.
const auto [pair, inserted] = m_active_endpoints.try_emplace(remote_address, std::move(pending_endpoint));
// Index by the passive (listen) address too, so closing a listen endpoint
// can also tear down the connections it accepted.
m_active_endpoints_by_passive_address.emplace(
std::piecewise_construct,
std::forward_as_tuple(remote_address),
std::forward_as_tuple(m_active_endpoints.at(remote_address)));
ERS_DEBUG(1,
std::format("Created and connected endpoint. Eqfd: {}",
m_active_endpoints.at(remote_address).get_endpoint().eqfd));
enable_capabilities(remote_address, it->connection_params);
m_pending_active_endpoints.erase(it);
return pair->second;
}
void netio3::libfabric::ConnectionManager::unregister_endpoint(const CqCmFds& fds)
{
m_event_loop->remove_fd(fds.cm_fd);
if (fds.cq_fd >= 0) {
m_event_loop->remove_fd(fds.cq_fd);
}
if (fds.rcq_fd >= 0) {
m_event_loop->remove_fd(fds.rcq_fd);
}
}
/// Signal handler running on the event-loop thread: drains the close queue
/// and performs the actual teardown for each queued endpoint.
void netio3::libfabric::ConnectionManager::handle_close_requests()
{
    ZoneScoped;
    CloseQueueItem item;
    while (m_close_queue.try_pop(item)) {
        if (item.type == CloseQueueItem::Type::listen) {
            do_close_listen_endpoint(item.address);
        } else if (item.type == CloseQueueItem::Type::active) {
            do_close_active_endpoint(item.address);
        }
    }
}
// Actually closes a listen endpoint (runs on the event-loop thread). Also
// tears down every active endpoint that was accepted through this listen
// address, via the passive-address index. No-op if already closed.
void netio3::libfabric::ConnectionManager::do_close_listen_endpoint(const EndPointAddress& address)
{
ZoneScoped;
std::scoped_lock lock(
m_active_endpoint_by_passive_mutex, m_listen_endpoint_mutex, m_active_endpoint_mutex);
// The endpoint may have been closed already (duplicate queue entries).
if (not m_listen_endpoints.contains(address)) {
return;
}
// Listen endpoint has no CQ
m_event_loop->remove_fd(m_listen_endpoints.at(address).get_cq_cm_fds().cm_fd);
// Close every connection accepted on this listen address.
const auto [begin, end] = m_active_endpoints_by_passive_address.equal_range(address);
for (auto it = begin; it != end; ++it) {
unregister_endpoint(it->second.get_cq_cm_fds());
m_active_endpoints.erase(it->second.get_address());
}
m_listen_endpoints.erase(address);
m_active_endpoints_by_passive_address.erase(address);
}
/// Actually closes an active endpoint (runs on the event-loop thread):
/// collects still-pending zero-copy sends for the closed callback, destroys
/// the send/receive helpers, unregisters the endpoint's fds and erases it.
void netio3::libfabric::ConnectionManager::do_close_active_endpoint(const EndPointAddress& address)
{
    ZoneScoped;
    // Capture sends still in flight on a zero-copy endpoint so they can be
    // reported to on_connection_closed_cb.
    auto pending_sends = std::invoke([this, &address]() -> std::vector<std::uint64_t> {
        if (m_send_endpoints_zero_copy.contains(address)) {
            return m_send_endpoints_zero_copy.apply(
                address,
                [](std::unique_ptr<SendEndpointZeroCopy>& sendpoint) { return sendpoint->get_pending_sends(); });
        }
        return {};
    });
    m_send_endpoints_zero_copy.erase(address);
    m_send_sendpoints_buffered.erase(address);
    {
        std::lock_guard lock_receive_endpoint(m_receive_endpoint_mutex);
        m_receive_endpoints.erase(address);
    }
    {
        std::lock_guard lock(m_active_endpoint_mutex);
        // The endpoint may already be gone: a close can be queued by the user
        // (close_active_endpoint) and triggered again by an FI_SHUTDOWN event
        // racing with it. The previous at() would throw std::out_of_range on
        // the second pass; treat a missing endpoint as already closed.
        const auto it = m_active_endpoints.find(address);
        if (it == m_active_endpoints.end()) {
            return;
        }
        unregister_endpoint(it->second.get_cq_cm_fds());
        m_active_endpoints.erase(it);
    }
    if (m_config.callbacks.on_connection_closed_cb != nullptr) {
        m_config.callbacks.on_connection_closed_cb(address, pending_sends);
    }
}
/// Removes every registered listen- and active-endpoint fd from the event
/// loop. Invoked from the destructor on the event-loop thread.
void netio3::libfabric::ConnectionManager::unregister_all()
{
    ZoneScoped;
    const std::scoped_lock lock(m_active_endpoint_mutex, m_listen_endpoint_mutex);
    for (const auto& [address, lendpoint] : m_listen_endpoints) {
        unregister_endpoint(lendpoint.get_cq_cm_fds());
    }
    for (const auto& [address, endpoint] : m_active_endpoints) {
        unregister_endpoint(endpoint.get_cq_cm_fds());
    }
}