Program Listing for File ConnectionManager.cpp

Return to documentation for file (BackendLibfabric/ConnectionManager.cpp)

#include "ConnectionManager.hpp"

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <tuple>
#include <utility>

#include <tracy/Tracy.hpp>

#include <rdma/fabric.h>
#include <rdma/fi_eq.h>

#include "ActiveEndpoint.hpp"
#include "Issues.hpp"
#include "netio3-backend/EventLoop/BaseEventLoop.hpp"
#include "netio3-backend/Netio3Backend.hpp"
#include "netio3-backend/Issues.hpp"
#include "SendEndpointBuffered.hpp"

/// Construct the connection manager.
/// @param event_loop event loop used for fd registration and signals (not owned)
/// @param config     backend network configuration, taken by value and moved in
netio3::libfabric::ConnectionManager::ConnectionManager(BaseEventLoop* event_loop,
                                                        NetworkConfig config) :
  m_config{std::move(config)},
  m_event_loop{event_loop},
  m_close_signal{m_event_loop, [this] () { handle_close_requests(); }, false },
  // Read the thread-safety model from m_config, which is initialized before
  // these members (declaration order): `config` has already been moved-from
  // at this point, so reading it would be a use-after-move.
  m_send_sendpoints_buffered{m_config.thread_safety == ThreadSafetyModel::SAFE},
  m_send_endpoints_zero_copy{m_config.thread_safety == ThreadSafetyModel::SAFE}
{}

/// Destructor: remove all registered fds from the event loop before members die.
netio3::libfabric::ConnectionManager::~ConnectionManager()
{
  if (m_event_loop->get_thread_id() == std::this_thread::get_id()) {
    // Already on the event-loop thread: unregister directly.
    unregister_all();
  } else if (m_event_loop->is_running()) {
    // Hop onto the event-loop thread via a one-shot signal and block until
    // the unregistration has actually run.
    // NOTE(review): `unregistered`, `cv` (and `this`) are captured by
    // reference by the signal callback; if the loop stops before running the
    // callback we exit via the is_running()/timeout check below — confirm
    // the callback can never fire after this destructor returns.
    std::mutex mutex;
    std::condition_variable cv;
    std::atomic_bool unregistered{false};
    const auto signal = m_event_loop->create_signal(
      [this, &unregistered, &cv](int) {
        unregister_all();
        unregistered = true;
        cv.notify_one();
      },
      false);
    signal.fire();
    std::unique_lock<std::mutex> lock(mutex);
    // Wait in 1 s slices so a loop that stops without servicing the signal
    // cannot block this destructor forever.
    while (m_event_loop->is_running() and !unregistered.load()) {
      constexpr static auto timeout = std::chrono::seconds(1);
      if (cv.wait_for(lock, timeout, [&unregistered] { return unregistered.load(); })) {
        break;
      }
    }
  }
}

/// Open a passive (listening) endpoint on `address`, register its event-queue
/// fd with the event loop, and return the address the listen socket reports.
/// @throws FailedOpenListenEndpoint / ListenEndpointAlreadyExists on failure.
netio3::EndPointAddress netio3::libfabric::ConnectionManager::open_listen_endpoint(
  const EndPointAddress& address,
  ConnectionParameters conn_params)
{
  // Re-map capability-validation errors onto the listen-specific exception.
  // NOTE(review): validate_capabilities throws InvalidConnectionParameters;
  // this catch only fires if that type derives from FailedOpenActiveEndpoint
  // — confirm the exception hierarchy.
  try {
    validate_capabilities(conn_params);
  } catch (const FailedOpenActiveEndpoint& e) {
    throw FailedOpenListenEndpoint(ERS_HERE, address.address(), address.port(), e.message());
  }
  std::lock_guard lock(m_listen_endpoint_mutex);
  if (m_listen_endpoints.contains(address)) {
    throw ListenEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }
  // Lazily bring up the domain/CQ/shared-buffer infrastructure. FI_SOURCE
  // marks `address` as a local (bind) address for fi_getinfo.
  try {
    check_and_init(address, FI_SOURCE);
  } catch (const LibfabricDomainError& e) {
    throw FailedOpenListenEndpoint(ERS_HERE, address.address(), address.port(), e.message());
  }
  if (m_shared_buffer_manager->get_receive_context_manager().has_value()) {
    ERS_INFO(std::format("Using shared receive context manager for listen endpoint at {}:{}. "
                         "Ignoring connection parameters passed to open_listen_endpoint",
                         address.address(),
                         address.port()));
  }
  auto lendpoint_to_insert =
    ListenSocket(address, conn_params, m_config.mode, *m_domain_manager);
  // Key the map by the socket's own reported address (may differ from the
  // requested one), then register the EQ fd of the stored element.
  const auto pair =
    m_listen_endpoints.try_emplace(lendpoint_to_insert.get_address(), std::move(lendpoint_to_insert));
  auto& lendpoint = pair.first->second;
  // The callback captures a reference into m_listen_endpoints; this assumes
  // the container provides stable references across later insertions
  // (true for node-based maps) — TODO confirm the container type.
  const auto eq_ev_ctx = EventContext{
    lendpoint.get_endpoint().eqfd, [this, &lendpoint](int) { on_listen_endpoint_cm_event(lendpoint); }};
  m_event_loop->register_fd(eq_ev_ctx);
  return lendpoint.get_address();
}

/// Open an outgoing (connect-side) endpoint to `address`, enable the requested
/// send/receive capabilities, start the connection, and register the endpoint's
/// event-queue fd with the event loop.
/// @throws InvalidConnectionParameters / ActiveEndpointAlreadyExists /
///         FailedOpenActiveEndpoint on failure.
void netio3::libfabric::ConnectionManager::open_active_endpoint(
  const EndPointAddress& address,
  const ConnectionParameters& conn_params)
{
  validate_capabilities(conn_params);
  std::lock_guard lock(m_active_endpoint_mutex);
  create_active_endpoint(address, conn_params);
  std::lock_guard lock_receive_endpoint(m_receive_endpoint_mutex);
  enable_capabilities(address, conn_params);
  // Look the endpoint up once instead of repeating at() for every access;
  // nothing is inserted into m_active_endpoints between these uses, so the
  // reference stays valid.
  auto& endpoint = m_active_endpoints.at(address);
  endpoint.complete_connection(ActiveEndpoint::ConnectionMode::Connect);
  const auto eqfd = endpoint.get_endpoint().eqfd;
  const auto ctx = EventContext{eqfd, [this, eqfd](int) { on_active_endpoint_cm_event(eqfd); }};
  endpoint.set_eq_ev_ctx(ctx);
  m_event_loop->register_fd(ctx);
}

void netio3::libfabric::ConnectionManager::close_listen_endpoint(const EndPointAddress& address)
{
  std::lock_guard lock(m_listen_endpoint_mutex);
  if (not m_listen_endpoints.contains(address)) {
    throw UnknownListenEndpoint(ERS_HERE, address.address(), address.port());
  }
  m_close_queue.push({address, CloseQueueItem::Type::listen});
  m_close_signal.fire();
}

void netio3::libfabric::ConnectionManager::close_active_endpoint(const EndPointAddress& address)
{
  {
    std::lock_guard lock(m_active_endpoint_mutex);
    if (not m_active_endpoints.contains(address)) {
      throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
    }
  }
  m_close_queue.push({address, CloseQueueItem::Type::active});
  m_close_signal.fire();
}

/// Number of free send buffers for the endpoint at `address`. An endpoint is
/// either buffered or zero-copy, never both, so the buffered map is probed
/// first and the zero-copy map second.
/// @throws UnknownActiveEndpoint if the address has no send endpoint.
std::size_t netio3::libfabric::ConnectionManager::get_num_available_buffers(
  const EndPointAddress& address)
{
  const auto count_buffered = [](std::unique_ptr<SendEndpointBuffered>& ep) {
    return ep->get_num_available_buffers();
  };
  const auto count_zero_copy = [](std::unique_ptr<SendEndpointZeroCopy>& ep) {
    return ep->get_num_available_buffers();
  };
  if (m_send_sendpoints_buffered.contains(address)) {
    return m_send_sendpoints_buffered.apply(address, count_buffered);
  }
  if (m_send_endpoints_zero_copy.contains(address)) {
    return m_send_endpoints_zero_copy.apply(address, count_zero_copy);
  }
  throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
}

/// Lazily construct the shared libfabric infrastructure (domain, CQ reactor,
/// shared buffer manager) the first time an endpoint is opened.
/// @param address address used to resolve the fabric domain
/// @param flags   fi_getinfo-style flags (e.g. FI_SOURCE for listen sockets)
void netio3::libfabric::ConnectionManager::check_and_init(const EndPointAddress& address,
                                                          const std::uint64_t flags)
{
  if (!m_domain_manager) {
    m_domain_manager = std::make_unique<DomainManager>(m_config, address, flags);
  }
  if (!m_cq_reactor) {
    m_cq_reactor =
      std::make_unique<CqReactor>(m_domain_manager->get_fabric(), m_config.callbacks.on_data_cb);
  }
  if (!m_shared_buffer_manager) {
    m_shared_buffer_manager =
      std::make_unique<SharedBufferManager>(m_config, *m_domain_manager, m_event_loop);
  }
}

void netio3::libfabric::ConnectionManager::create_active_endpoint(
  const EndPointAddress& address,
  const ConnectionParameters& connection_params)
{
  if (m_active_endpoints.contains(address)) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }
  try {
    check_and_init(address, 0);
  } catch (const LibfabricDomainError& e) {
    throw FailedOpenActiveEndpoint(ERS_HERE, address.address(), address.port(), e.message());
  }
  auto* srx_context = std::invoke([this]() -> fid_ep* {
    if (m_shared_buffer_manager->get_receive_context_manager().has_value()) {
      return m_shared_buffer_manager->get_receive_context_manager()->get_srx_context();
    }
    return nullptr;
  });
  m_active_endpoints.try_emplace(address,
                                 address,
                                 m_config.mode,
                                 get_endpoint_capabilities(connection_params),
                                 m_domain_manager->get_fabric(),
                                 m_domain_manager->get_domain(),
                                 0,
                                 srx_context);
}

/// Validate requested per-connection parameters against the backend-wide
/// configuration. Hard conflicts throw; redundant-but-legal combinations
/// only emit a warning.
/// @throws InvalidConnectionParameters on an unsatisfiable request.
void netio3::libfabric::ConnectionManager::validate_capabilities(
  const ConnectionParameters& connection_params) const
{
  const auto& recv = connection_params.recv_params;
  const auto& buffered = connection_params.send_buffered_params;
  const auto& zero_copy = connection_params.send_zero_copy_params;
  const auto& backend = m_config.conn_params;

  // Shared-buffer features must be enabled backend-wide before a single
  // endpoint may request them.
  if (recv.use_shared_receive_buffers and
      not backend.recv_params.use_shared_receive_buffers) {
    throw InvalidConnectionParameters(
      "Shared receive buffers are not enabled in the backend configuration");
  }
  if (zero_copy.use_shared_send_buffers and
      not backend.send_zero_copy_params.use_shared_send_buffers) {
    throw InvalidConnectionParameters(
      "Shared zero-copy send buffers are not enabled in the backend configuration");
  }
  if (buffered.use_shared_send_buffers and
      not backend.send_buffered_params.use_shared_send_buffers) {
    throw InvalidConnectionParameters(
      "Shared buffered send buffers are not enabled in the backend configuration");
  }
  // Buffered and zero-copy sending are mutually exclusive on one endpoint.
  const bool wants_buffered = buffered.num_buf > 0 or buffered.use_shared_send_buffers;
  const bool wants_zero_copy =
    zero_copy.mr_start != nullptr or zero_copy.use_shared_send_buffers;
  if (wants_buffered and wants_zero_copy) {
    throw InvalidConnectionParameters(
      "Libfabric does not support buffered and zero-copy sending on the same endpoint");
  }
  // The remaining combinations are legal; warn that the explicit value is
  // overridden by the shared-buffer setting.
  if (recv.use_shared_receive_buffers and recv.num_buf > 0) {
    ers::warning(
      InvalidConnectionParameters("Shared receive buffers requested, but the number of receive "
                                  "buffers is set to a non-zero value. Value ignored."));
  }
  if (buffered.use_shared_send_buffers and buffered.num_buf > 0) {
    ers::warning(InvalidConnectionParameters(
      "Shared buffered send buffers requested, but the number of buffered send buffers is set to a "
      "non-zero value. Value ignored."));
  }
  if (zero_copy.use_shared_send_buffers and zero_copy.mr_start != nullptr) {
    ers::warning(InvalidConnectionParameters(
      "Shared zero-copy send buffers requested, but the zero-copy send memory region is set to a "
      "non-null value. Value ignored."));
  }
}

/// Attach the send/receive machinery an endpoint asked for in its connection
/// parameters. Each capability is enabled only when its parameters request it.
void netio3::libfabric::ConnectionManager::enable_capabilities(
  const EndPointAddress& address,
  const ConnectionParameters& connection_params)
{
  const auto& buffered = connection_params.send_buffered_params;
  const auto& zero_copy = connection_params.send_zero_copy_params;
  const auto& recv = connection_params.recv_params;
  if (buffered.use_shared_send_buffers or buffered.num_buf > 0) {
    enable_buffered_sending(address, buffered);
  }
  if (zero_copy.mr_start != nullptr) {
    enable_zero_copy_sending(address, zero_copy);
  }
  if (recv.use_shared_receive_buffers or recv.num_buf > 0) {
    enable_receiving(address, recv);
  }
}

/// Attach a buffered send endpoint to the active endpoint at `address`.
/// @throws ActiveEndpointAlreadyExists if buffered sending is already enabled
/// @throws FailedOpenActiveEndpoint   if no active endpoint exists yet
void netio3::libfabric::ConnectionManager::enable_buffered_sending(
  const EndPointAddress& address,
  ConnectionParametersSendBuffered conn_params)
{
  if (m_send_sendpoints_buffered.contains(address)) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }
  if (not m_active_endpoints.contains(address)) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE, address.address(), address.port(), "No active endpoint found");
  }
  m_send_sendpoints_buffered.try_emplace(
    address,
    std::make_unique<SendEndpointBuffered>(m_active_endpoints.at(address),
                                           conn_params,
                                           m_shared_buffer_manager->get_send_buffer_manager(),
                                           // Dereference the unique_ptr directly,
                                           // consistent with enable_zero_copy_sending.
                                           *m_domain_manager));
}

/// Attach a zero-copy send endpoint to the active endpoint at `address`.
/// @throws ActiveEndpointAlreadyExists if zero-copy sending is already enabled
/// @throws FailedOpenActiveEndpoint   if no active endpoint exists yet
void netio3::libfabric::ConnectionManager::enable_zero_copy_sending(
  const EndPointAddress& address,
  ConnectionParametersSendZeroCopy conn_params)
{
  if (m_send_endpoints_zero_copy.contains(address)) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }
  if (!m_active_endpoints.contains(address)) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE, address.address(), address.port(), "No active endpoint found");
  }
  auto sendpoint =
    std::make_unique<SendEndpointZeroCopy>(m_active_endpoints.at(address),
                                           conn_params,
                                           m_shared_buffer_manager->get_zero_copy_buffer_manager(),
                                           *m_domain_manager);
  m_send_endpoints_zero_copy.try_emplace(address, std::move(sendpoint));
}

void netio3::libfabric::ConnectionManager::enable_receiving(const EndPointAddress& address,
                                                            ConnectionParametersRecv conn_params)
{
  if (m_receive_endpoints.contains(address)) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }
  if (not m_active_endpoints.contains(address)) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE, address.address(), address.port(), "No active endpoint found");
  }
  m_receive_endpoints.try_emplace(address,
                                  m_active_endpoints.at(address),
                                  conn_params,
                                  m_shared_buffer_manager->get_receive_context_manager(),
                                  *m_domain_manager,
                                  m_event_loop);
}

/// Non-blocking read of one connection-management event from `eq`.
/// On -FI_EAGAIN and -FI_EAVAIL the negative code is propagated via
/// CmEvent::event so callers can switch on it; otherwise the libfabric event
/// id is returned. Ownership of any fi_info attached to the entry transfers
/// to the caller via FiInfoUniquePtr.
netio3::libfabric::ConnectionManager::CmEvent netio3::libfabric::ConnectionManager::read_cm_event(
  fid_eq* eq)
{
  ZoneScoped;
  std::uint32_t event{};
  fi_eq_cm_entry entry{};
  fi_info* info{nullptr};
  fi_eq_err_entry err_entry{};

  // timeout 0 => poll; flags 0 => consume the event.
  const auto rd = fi_eq_sread(eq, &event, &entry, sizeof(entry), 0, 0);
  if (rd < 0) {
    if (rd == -FI_EAGAIN) {
      // No event pending.
      return {rd, err_entry, FiInfoUniquePtr{info}};
    }
    if (rd == -FI_EAVAIL) {
      // An error entry is queued: fetch and log its details before returning.
      const auto r = fi_eq_readerr(eq, &err_entry, 0);
      if (r < 0) {
        ers::error(LibFabricCmError(
          ERS_HERE, std::format("Failed to retrieve details on Event Queue error {}", r)));
      }
      ers::error(LibFabricCmError(
        ERS_HERE,
        std::format("Event Queue error: {} (code: {}),  provider specific: {} (code: {})",
                    fi_strerror(err_entry.err),
                    err_entry.err,
                    fi_eq_strerror(eq, err_entry.prov_errno, err_entry.err_data, nullptr, 0),
                    err_entry.prov_errno)));
      return {rd, err_entry, FiInfoUniquePtr{info}};
    }
    // NOTE(review): any other negative code falls through to the short-read
    // check below and is returned with event == 0 — confirm this is intended.
  }
  if (rd != sizeof(entry)) {
    ers::error(LibFabricCmError(ERS_HERE, std::format("Failed to read from Event Queue: {}", rd)));
  }

  return {event, err_entry, FiInfoUniquePtr{entry.info}};
}

/// Event-loop callback for a listen socket's event queue: dispatch one
/// connection-management event.
void netio3::libfabric::ConnectionManager::on_listen_endpoint_cm_event(ListenSocket& lendpoint)
{
  ZoneScoped;
  ERS_DEBUG(1, "listen endpoint: connection event");
  auto event = read_cm_event(lendpoint.get_endpoint().eq.get());

  switch (event.event) {
  // A peer requested a connection: spawn a new (pending) active endpoint,
  // handing over ownership of the fi_info delivered with the request.
  case FI_CONNREQ: {
    ERS_DEBUG(2, ": FI_CONNREQ");
    handle_connection_request(lendpoint, std::move(event.info));
  } break;

  // These events are only meaningful on active endpoints.
  case FI_CONNECTED:
    throw LibFabricCmError(ERS_HERE, "FI_CONNECTED received on listen endpoint");

  case FI_SHUTDOWN:
    throw LibFabricCmError(ERS_HERE, "FI_SHUTDOWN received on listen endpoint");

  // No event was actually pending: re-arm fd-based waiting via fi_trywait.
  case -FI_EAGAIN: {
    auto* fp = &lendpoint.get_endpoint().eq->fid;
    fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
    break;
  }

  // Detailed error entry was already logged by read_cm_event().
  case -FI_EAVAIL:
    ers::error(LibFabricCmError(
      ERS_HERE,
      std::format("Unhandled error in listen endpoint EQ code: {} provider specific code: {}",
                  event.err_entry.err,
                  event.err_entry.prov_errno)));
    break;

  default:
    ers::warning(LibFabricCmError(
      ERS_HERE, std::format("Unexpected event {} in listen endpoint Event Queue", event.event)));
    break;
  }
}

/// Event-loop callback for an active endpoint's event queue: resolve the fd to
/// the owning endpoint (established or still pending accept), read one CM
/// event, and act on it (connection completion, shutdown, errors).
void netio3::libfabric::ConnectionManager::on_active_endpoint_cm_event(const int eqfd)
{
  ZoneScoped;
  ERS_DEBUG(1, "active endpoint: connection event");
  // Look up the endpoint by EQ fd: first among established endpoints, then
  // among pending (accept-side, not yet FI_CONNECTED) ones.
  const auto [ep, is_pending] = std::invoke([this, eqfd]() -> std::tuple<ActiveEndpoint&, bool> {
    std::lock_guard lock(m_active_endpoint_mutex);
    const auto it_emplaced_ep = std::ranges::find_if(
      m_active_endpoints, [eqfd](const auto& pair) { return pair.second.get_endpoint().eqfd == eqfd; });
    if (it_emplaced_ep != m_active_endpoints.end()) {
      return {it_emplaced_ep->second, false};
    }
    const auto it_pending_ep = std::ranges::find_if(
      m_pending_active_endpoints, [eqfd](const auto& pair) { return pair.endpoint.get_endpoint().eqfd == eqfd; });
    if (it_pending_ep != m_pending_active_endpoints.end()) {
      return {it_pending_ep->endpoint, true};
    }
    throw LibFabricCmError(ERS_HERE, std::format("No active endpoint found for EQ fd {}", eqfd));
  });
  {

    const auto event = std::invoke([this, &ep] {
      std::lock_guard lock(m_active_endpoint_mutex);
      return read_cm_event(ep.get_endpoint().eq.get());
    });

    switch (event.event) {
    // Remote side closed the connection: tear down our local state.
    case FI_SHUTDOWN: {
      const auto address = ep.get_address();
      ERS_INFO(std::format("Closed connection to {}:{} because received FI_SHUTDOWN",
                           address.address(),
                           address.port()));
      do_close_active_endpoint(address);
      return;
    }

    // Connection completed: promote pending endpoints, wire up CQ callbacks,
    // register the fds, and notify the user.
    case FI_CONNECTED: {
      ep.update_addresses();
      // NOTE(review): when is_pending, handle_pending_active_endpoint moves
      // the endpoint into m_active_endpoints — `ep` must not be used after
      // this point; `active_ep` is the valid reference.
      auto* active_ep = &ep;
      if (is_pending) {
        auto& moved_ep = handle_pending_active_endpoint(eqfd);
        active_ep = &moved_ep;
      }
      std::unique_lock lock(m_active_endpoint_mutex);
      const auto address = active_ep->get_address();
      // Hook up the send completion queue for whichever send mode (buffered
      // or zero-copy) was enabled for this endpoint, forwarding completion
      // keys to the user callback when one is configured.
      if (m_send_sendpoints_buffered.contains(address)) {
        m_active_endpoints.at(address).set_cq_ev_ctx(
          {m_active_endpoints.at(address).get_endpoint().cqfd, [this, address](int) {
             const auto keys =
               m_send_sendpoints_buffered.apply(address, [this](std::unique_ptr<SendEndpointBuffered>& sendpoint_cb) {
                 return m_cq_reactor->on_send_cq_event(*sendpoint_cb);
               });
             if (m_config.callbacks.on_send_completed_cb != nullptr) {
               for (const auto key : keys) {
                 m_config.callbacks.on_send_completed_cb(address, key);
               }
             }
           }});
      } else if (m_send_endpoints_zero_copy.contains(address)) {
        m_active_endpoints.at(address).set_cq_ev_ctx(
          {m_active_endpoints.at(address).get_endpoint().cqfd, [this, address](int) {
             const auto keys =
               m_send_endpoints_zero_copy.apply(address, [this](std::unique_ptr<SendEndpointZeroCopy>& sendpoint_cb) {
                 return m_cq_reactor->on_send_cq_event(*sendpoint_cb);
               });
             if (m_config.callbacks.on_send_completed_cb != nullptr) {
               for (const auto key : keys) {
                 m_config.callbacks.on_send_completed_cb(address, key);
               }
             }
           }});
      }
      // Hook up and register the receive completion queue when receiving is
      // enabled for this endpoint.
      {
        std::lock_guard lock_receive_endpoint(m_receive_endpoint_mutex);
        if (m_receive_endpoints.contains(address)) {
          m_active_endpoints.at(address).set_rcq_ev_ctx(
            {m_active_endpoints.at(address).get_endpoint().rcqfd, [this, address](int) {
               m_cq_reactor->on_recv_cq_event(m_receive_endpoints.at(address));
             }});
          m_event_loop->register_fd(m_active_endpoints.at(address).get_rcq_ev_ctx());
        }
      }
      ERS_DEBUG(1,
                std::format("Active endpoint: EQ fd {} connected, CQ fd {}",
                            active_ep->get_endpoint().eqfd,
                            active_ep->get_endpoint().cqfd));
      ERS_INFO(std::format("Opened send connection to {}:{}", address.address(), address.port()));
      // Only register the send CQ fd when a send endpoint was actually
      // attached (ctx.fd stays negative otherwise).
      const auto ctx = m_active_endpoints.at(address).get_cq_ev_ctx();
      if (ctx.fd >= 0) {
        m_event_loop->register_fd(ctx);
      }
      if (m_config.callbacks.on_connection_established_cb != nullptr) {
        const auto local_address = m_active_endpoints.at(address).get_local_address();
        const auto capabilities = m_active_endpoints.at(address).get_capabilities();
        lock.unlock(); // Unlock before calling the callback to avoid deadlock
        m_config.callbacks.on_connection_established_cb(
          address,
          local_address,
          capabilities);
        lock.lock(); // Re-lock after the callback
      }
      return;
    }

    case FI_MR_COMPLETE:
    case FI_AV_COMPLETE:
    case FI_JOIN_COMPLETE:
      // Not implemented
      break;

    // Error entry delivered via -FI_EAVAIL: distinguish the recoverable cases.
    case -FI_EAVAIL:
      switch (event.err_entry.err) {
      // Peer refused the connection: free all resources for this endpoint
      // and let the user know via the refused callback.
      case FI_ECONNREFUSED: {
        ERS_DEBUG(1, "Connection refused (FI_ECONNREFUSED). Deallocating active endpoint resources");
        std::scoped_lock lock(m_active_endpoint_mutex, m_receive_endpoint_mutex);
        const auto address = ep.get_address();
        m_event_loop->remove_fd(ep.get_endpoint().eqfd);
        m_send_endpoints_zero_copy.erase(address);
        m_send_sendpoints_buffered.erase(address);
        m_receive_endpoints.erase(address);
        m_active_endpoints.erase(address);
        if (m_config.callbacks.on_connection_refused_cb != nullptr) {
          m_config.callbacks.on_connection_refused_cb(address);
        }
        break;
      }
      case FI_ETIMEDOUT:
        ers::warning(
          LibFabricCmError(ERS_HERE, "Active endpoint CM event error: FI_ETIMEDOUT"));
        break;

      default: {
        std::lock_guard lock(m_active_endpoint_mutex);
        ers::error(LibFabricCmError(
          ERS_HERE,
          std::format(
            "Unhandled error in the Event Queue: {} (code: {}), provider specific: {} (code: {})",
            fi_strerror(event.err_entry.err),
            event.err_entry.err,
            fi_eq_strerror(ep.get_endpoint().eq.get(),
                           event.err_entry.prov_errno,
                           event.err_entry.err_data,
                           nullptr,
                           0),
            event.err_entry.prov_errno)));
      }
      }
      return;

    // Spurious wakeup: re-arm fd-based waiting via fi_trywait.
    case -FI_EAGAIN: {
      std::lock_guard lock(m_active_endpoint_mutex);
      auto* fp = &ep.get_endpoint().eq->fid;
      fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
      break;
    }
    default:
      throw LibFabricCmError(
        ERS_HERE, std::format("Unexpected event {} in active endpoint Event Queue", event.event));
    }
  }
}

/// Accept-side handling of FI_CONNREQ: build a new active endpoint from the
/// delivered fi_info, accept the connection, register its EQ fd, and park it
/// in m_pending_active_endpoints until FI_CONNECTED promotes it.
void netio3::libfabric::ConnectionManager::handle_connection_request(ListenSocket& lendpoint,
                                                                     FiInfoUniquePtr&& info)
{
  ZoneScoped;
  // need to spawn new endpoint
  ERS_DEBUG(1, "Received connection request");
  // The new endpoint reuses the listen socket's connection parameters and,
  // when configured, the shared receive context.
  auto active_endpoint =
    ActiveEndpoint{lendpoint.get_address(),
                   m_config.mode,
                   get_endpoint_capabilities(lendpoint.get_connection_parameters()),
                   m_domain_manager->get_fabric(),
                   m_domain_manager->get_domain(),
                   std::move(info),
                   m_shared_buffer_manager->get_receive_context_manager().has_value()
                     ? m_shared_buffer_manager->get_receive_context_manager()->get_srx_context()
                     : nullptr};
  // Accept now; completion arrives later as FI_CONNECTED on the endpoint's
  // own event queue (handled by on_active_endpoint_cm_event).
  active_endpoint.complete_connection(ActiveEndpoint::ConnectionMode::Accept, lendpoint.get_pep());
  const auto eq_ev_ctx =
    EventContext{active_endpoint.get_endpoint().eqfd,
                 [this, eqfd = active_endpoint.get_endpoint().eqfd](int) { on_active_endpoint_cm_event(eqfd); }};
  active_endpoint.set_eq_ev_ctx(eq_ev_ctx);
  m_event_loop->register_fd(eq_ev_ctx);
  m_pending_active_endpoints.emplace_back(std::move(active_endpoint), lendpoint.get_connection_parameters());
}

/// Promote a pending accept-side endpoint (identified by its EQ fd) into
/// m_active_endpoints once FI_CONNECTED arrived, enable its capabilities, and
/// return a reference to its new home in the map.
/// @throws LibFabricCmError if no pending endpoint matches `eqfd`.
netio3::libfabric::ActiveEndpoint& netio3::libfabric::ConnectionManager::handle_pending_active_endpoint(const int eqfd)
{
  ZoneScoped;
  const auto it = std::ranges::find_if(
    m_pending_active_endpoints, [eqfd](const auto& pair) { return pair.endpoint.get_endpoint().eqfd == eqfd; });
  if (it == m_pending_active_endpoints.end()) {
    throw LibFabricCmError(ERS_HERE, std::format("No pending active endpoint found for EQ fd {}", eqfd));
  }
  auto& pending_endpoint = it->endpoint;
  const auto remote_address = pending_endpoint.get_address();
  std::scoped_lock lock(
    m_active_endpoint_mutex, m_active_endpoint_by_passive_mutex, m_receive_endpoint_mutex);
  // NOTE(review): `inserted` is ignored — if an endpoint for this address
  // already exists, try_emplace leaves the map untouched and the pending
  // endpoint is silently dropped when erased below; confirm this is intended.
  const auto [pair, inserted] = m_active_endpoints.try_emplace(remote_address, std::move(pending_endpoint));
  // Secondary index consumed by do_close_listen_endpoint.
  // NOTE(review): keyed here by the endpoint's address after
  // update_addresses(), but queried there by the listen address — confirm
  // these coincide.
  m_active_endpoints_by_passive_address.emplace(
    std::piecewise_construct,
    std::forward_as_tuple(remote_address),
    std::forward_as_tuple(m_active_endpoints.at(remote_address)));
  ERS_DEBUG(1,
            std::format("Created and connected endpoint. Eqfd: {}",
                        m_active_endpoints.at(remote_address).get_endpoint().eqfd));
  enable_capabilities(remote_address, it->connection_params);
  m_pending_active_endpoints.erase(it);
  return pair->second;
}

void netio3::libfabric::ConnectionManager::unregister_endpoint(const CqCmFds& fds)
{
  m_event_loop->remove_fd(fds.cm_fd);
  if (fds.cq_fd >= 0) {
    m_event_loop->remove_fd(fds.cq_fd);
  }
  if (fds.rcq_fd >= 0) {
    m_event_loop->remove_fd(fds.rcq_fd);
  }
}

/// Drain every queued close request. Runs on the event-loop thread in
/// response to m_close_signal.
void netio3::libfabric::ConnectionManager::handle_close_requests()
{
  ZoneScoped;
  for (CloseQueueItem item; m_close_queue.try_pop(item);) {
    if (item.type == CloseQueueItem::Type::listen) {
      do_close_listen_endpoint(item.address);
    } else if (item.type == CloseQueueItem::Type::active) {
      do_close_active_endpoint(item.address);
    }
  }
}

/// Actually close a listen endpoint and every active endpoint spawned from
/// it. Runs on the event-loop thread (via handle_close_requests).
void netio3::libfabric::ConnectionManager::do_close_listen_endpoint(const EndPointAddress& address)
{
  ZoneScoped;
  std::scoped_lock lock(
    m_active_endpoint_by_passive_mutex, m_listen_endpoint_mutex, m_active_endpoint_mutex);
  if (not m_listen_endpoints.contains(address)) {
    // Already closed — queued close requests are best-effort.
    return;
  }
  // Listen endpoint has no CQ
  m_event_loop->remove_fd(m_listen_endpoints.at(address).get_cq_cm_fds().cm_fd);
  // Tear down the active endpoints associated with this listen address.
  // NOTE(review): unlike do_close_active_endpoint, the send/receive endpoint
  // maps are not cleaned up here — confirm whether that is intentional.
  const auto [begin, end] = m_active_endpoints_by_passive_address.equal_range(address);
  for (auto it = begin; it != end; ++it) {
    unregister_endpoint(it->second.get_cq_cm_fds());
    m_active_endpoints.erase(it->second.get_address());
  }
  m_listen_endpoints.erase(address);
  m_active_endpoints_by_passive_address.erase(address);
}

/// Actually close an active endpoint: drop its send/receive endpoints,
/// unregister its fds, erase it, and notify the user (reporting any sends
/// still pending on a zero-copy endpoint). Runs on the event-loop thread.
void netio3::libfabric::ConnectionManager::do_close_active_endpoint(const EndPointAddress& address)
{
  ZoneScoped;
  // Collect in-flight zero-copy sends before tearing the endpoint down so
  // they can be handed to the on_connection_closed callback.
  auto pending_sends = std::invoke([this, &address]() -> std::vector<std::uint64_t> {
    if (m_send_endpoints_zero_copy.contains(address)) {
      return m_send_endpoints_zero_copy.apply(
        address, [](std::unique_ptr<SendEndpointZeroCopy>& sendpoint) { return sendpoint->get_pending_sends(); });
    }
    return {};
  });
  m_send_endpoints_zero_copy.erase(address);
  m_send_sendpoints_buffered.erase(address);
  {
    std::lock_guard lock_receive_endpoint(m_receive_endpoint_mutex);
    m_receive_endpoints.erase(address);
  }
  {
    std::lock_guard lock(m_active_endpoint_mutex);
    // The endpoint may already be gone, e.g. when a queued close request
    // races with an FI_SHUTDOWN that closed it first; at() would then throw
    // std::out_of_range on the event-loop thread. Treat this as already done.
    const auto it = m_active_endpoints.find(address);
    if (it == m_active_endpoints.end()) {
      return;
    }
    unregister_endpoint(it->second.get_cq_cm_fds());
    m_active_endpoints.erase(it);
  }
  if (m_config.callbacks.on_connection_closed_cb != nullptr) {
    m_config.callbacks.on_connection_closed_cb(address, pending_sends);
  }
}

void netio3::libfabric::ConnectionManager::unregister_all()
{
  ZoneScoped;
  std::scoped_lock lock(m_active_endpoint_mutex, m_listen_endpoint_mutex);
  for (const auto& [ep, lendpoint] : m_listen_endpoints) {
    unregister_endpoint(lendpoint.get_cq_cm_fds());
  }
  for (const auto& [ep, endpoint] : m_active_endpoints) {
    unregister_endpoint(endpoint.get_cq_cm_fds());
  }
}