Program Listing for File ConnectionManager.cpp

Return to documentation for file (BackendLibfabric/ConnectionManager.cpp)

#include "ConnectionManager.hpp"

#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <format>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#include <tracy/Tracy.hpp>

#include <rdma/fabric.h>
#include <rdma/fi_eq.h>

#include "Issues.hpp"
#include "ReceiveSocket.hpp"
#include "SendSocketZeroCopy.hpp"
#include "netio3-backend/EventLoop/BaseEventLoop.hpp"
#include "netio3-backend/Netio3Backend.hpp"
#include "netio3-backend/Issues.hpp"

// Constructs the manager from an event loop pointer and a network configuration.
//
// Stores the configuration and the event loop pointer, and initialises both
// send-socket maps with a flag that is true when the configured thread-safety
// model is ThreadSafetyModel::SAFE.
//
// NOTE(review): `config` is moved into m_config by the first initializer, yet
// the two send-socket map initializers below still read `config.thread_safety`.
// Whether that read happens before or after the move depends on the member
// declaration order in the header (not visible here) - TODO confirm this is
// intended.
netio3::libfabric::ConnectionManager::ConnectionManager(BaseEventLoop* event_loop,
                                                        NetworkConfig config) :
  m_config{std::move(config)},
  m_event_loop{event_loop},
  m_ssockets_buffered{config.thread_safety == ThreadSafetyModel::SAFE},
  m_ssockets_zero_copy{config.thread_safety == ThreadSafetyModel::SAFE}
{}

// Removes every file descriptor owned by this manager from the event loop.
//
// - When called on the event loop thread, the descriptors are unregistered
//   directly.
// - When called from another thread while the event loop is running, the
//   unregistration is delegated to the event loop through a one-off signal and
//   this thread blocks until it has completed or the event loop stopped.
// - When the event loop is not running and this is not its thread, nothing is
//   done.
netio3::libfabric::ConnectionManager::~ConnectionManager()
{
  if (m_event_loop->get_thread_id() == std::this_thread::get_id()) {
    unregister_all();
  } else if (m_event_loop->is_running()) {
    std::mutex mutex;
    std::condition_variable cv;
    std::atomic_bool unregistered{false};
    const auto signal = m_event_loop->create_signal(
      [this, &mutex, &unregistered, &cv](int) {
        unregister_all();
        // Publish completion and notify while holding the mutex:
        //  * the waiter cannot miss the notification between evaluating its
        //    predicate and blocking (which would cost a full timeout period);
        //  * the waiter cannot return from wait_for() - and destroy the
        //    stack-local cv/mutex - before notify_one() has finished.
        const std::lock_guard<std::mutex> guard(mutex);
        unregistered = true;
        cv.notify_one();
      },
      false);
    signal.fire();
    std::unique_lock<std::mutex> lock(mutex);
    // Wake up periodically so that a stopped event loop (which will never run
    // the signal handler) does not block destruction forever.
    while (m_event_loop->is_running() and !unregistered.load()) {
      constexpr static auto timeout = std::chrono::seconds(1);
      if (cv.wait_for(lock, timeout, [&unregistered] { return unregistered.load(); })) {
        break;
      }
    }
  }
}

// Factory: builds a shared ConnectionManager and runs init() on it.
//
// init() is run after construction because it uses weak_from_this(), which
// needs the object to be owned by a shared_ptr already.
//
// @param event_loop Event loop pointer forwarded to the constructor.
// @param config     Network configuration; taken by value and moved into the
//                   new instance to avoid an extra copy.
// @return The initialised instance.
std::shared_ptr<netio3::libfabric::ConnectionManager> netio3::libfabric::ConnectionManager::create(
  BaseEventLoop* event_loop,
  NetworkConfig config)
{
  auto instance = std::make_shared<ConnectionManager>(event_loop, std::move(config));
  instance->init();
  return instance;
}

void netio3::libfabric::ConnectionManager::init()
{
  m_close_signal = m_event_loop->create_signal(
    [weak_this = weak_from_this()](int) {
      if (auto shared_this = weak_this.lock()) {
        shared_this->handle_close_requests();
      }
    },
    false);
}

// Opens a listen endpoint and registers its event-queue descriptor with the
// event loop.
//
// @param address     Address to listen on.
// @param conn_params Receive-side connection parameters passed to the socket.
// @return The address reported by the created listen socket (this is the key
//         under which the socket is stored).
// @throws ListenEndpointAlreadyExists if a listen socket is already stored
//         under the requested address or under the address reported by the
//         newly created socket.
// @throws FailedOpenListenEndpoint if initialising the domain fails.
netio3::EndPointAddress netio3::libfabric::ConnectionManager::open_listen_endpoint(
  const EndPointAddress& address,
  ConnectionParametersRecv conn_params)
{
  if (m_lsockets.contains(address)) {
    throw ListenEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }
  try {
    check_and_init(address, FI_RECV);
  } catch (const LibfabricDomainError& e) {
    throw FailedOpenListenEndpoint(ERS_HERE, address.address(), address.port(), e.message());
  }
  auto lsocket_to_insert =
    ListenSocket(address, conn_params, m_config.mode, m_domain_manager->get_fabric());
  // Copy the key before the socket is moved so the key never aliases a
  // moved-from object.
  const EndPointAddress bound_address = lsocket_to_insert.get_address();
  const auto insertion = m_lsockets.try_emplace(bound_address, std::move(lsocket_to_insert));
  if (not insertion.second) {
    // The socket's reported address may differ from the requested one, so the
    // check at the top does not cover this case. Without this guard the
    // existing socket would be re-registered and the new one silently dropped.
    throw ListenEndpointAlreadyExists(ERS_HERE, bound_address.address(), bound_address.port());
  }
  auto& lsocket = insertion.first->second;
  // The handler keeps a reference to the socket stored inside m_lsockets.
  const auto eq_ev_ctx = EventContext{
    lsocket.get_endpoint().eqfd, [this, &lsocket](int) { on_listen_socket_cm_event(lsocket); }};
  lsocket.set_eq_ev_ctx(eq_ev_ctx);
  m_event_loop->register_fd(eq_ev_ctx);
  return lsocket.get_address();
}

// Opens a buffered send endpoint towards `address` and registers its
// event-queue descriptor with the event loop.
//
// Throws SendEndpointAlreadyExists when a buffered send socket is already
// stored for `address`, and FailedOpenSendEndpoint when the domain cannot be
// initialised.
void netio3::libfabric::ConnectionManager::open_send_endpoint_buffered(
  const EndPointAddress& address,
  ConnectionParameters conn_params)
{
  const bool already_open = m_ssockets_buffered.contains(address);
  if (already_open) {
    throw SendEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  try {
    check_and_init(address, FI_SEND);
  } catch (const LibfabricDomainError& e) {
    throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), e.message());
  }

  // The first argument is the map key, the remaining ones build the socket.
  m_ssockets_buffered.try_emplace(address,
                                  address,
                                  conn_params,
                                  m_config.mode,
                                  m_domain_manager->get_fabric(),
                                  m_domain_manager->get_send_domain());

  // Builds the connection-management event context, stores it in the socket
  // and hands a copy back so it can be registered with the event loop.
  const auto attach_cm_context = [this, &address](SendSocketBuffered& ssocket) {
    const auto cm_ctx = EventContext{ssocket.get_endpoint().eqfd, [this, address](int) {
                                       on_send_socket_cm_event<SendSocketBuffered>(address);
                                     }};
    ssocket.set_eq_ev_ctx(cm_ctx);
    return cm_ctx;
  };
  const auto registered_ctx = m_ssockets_buffered.apply(address, attach_cm_context);
  m_event_loop->register_fd(registered_ctx);
}

// Opens a zero-copy send endpoint towards `address` and registers its
// event-queue descriptor with the event loop.
//
// Throws SendEndpointAlreadyExists when a zero-copy send socket is already
// stored for `address`, and FailedOpenSendEndpoint when the domain cannot be
// initialised.
void netio3::libfabric::ConnectionManager::open_send_endpoint_zero_copy(
  const EndPointAddress& address,
  ConnectionParameters conn_params)
{
  const bool already_open = m_ssockets_zero_copy.contains(address);
  if (already_open) {
    throw SendEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  try {
    check_and_init(address, FI_SEND);
  } catch (const LibfabricDomainError& e) {
    throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), e.message());
  }

  // The first argument is the map key, the remaining ones build the socket.
  m_ssockets_zero_copy.try_emplace(address,
                                   address,
                                   conn_params,
                                   m_config.mode,
                                   m_domain_manager->get_fabric(),
                                   m_domain_manager->get_send_domain());

  // Builds the connection-management event context, stores it in the socket
  // and hands a copy back so it can be registered with the event loop.
  const auto attach_cm_context = [this, &address](SendSocketZeroCopy& ssocket) {
    const auto cm_ctx = EventContext{ssocket.get_endpoint().eqfd, [this, address](int) {
                                       on_send_socket_cm_event<SendSocketZeroCopy>(address);
                                     }};
    ssocket.set_eq_ev_ctx(cm_ctx);
    return cm_ctx;
  };
  const auto registered_ctx = m_ssockets_zero_copy.apply(address, attach_cm_context);
  m_event_loop->register_fd(registered_ctx);
}

void netio3::libfabric::ConnectionManager::close_listen_endpoint(const EndPointAddress& address)
{
  if (not m_lsockets.contains(address)) {
    throw UnknownListenEndpoint(ERS_HERE, address.address(), address.port());
  }
  m_close_queue.push({address, CloseQueueItem::Type::listen});
  m_close_signal->fire();
}

void netio3::libfabric::ConnectionManager::close_send_endpoint(const EndPointAddress& address)
{
  if (not m_ssockets_buffered.contains(address) and not m_ssockets_zero_copy.contains(address)) {
    throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
  }
  m_close_queue.push({address, CloseQueueItem::Type::send});
  m_close_signal->fire();
}

std::size_t netio3::libfabric::ConnectionManager::get_num_available_buffers(
  const EndPointAddress& address)
{
  if (m_ssockets_buffered.contains(address)) {
    return m_ssockets_buffered.apply(
      address, [](SendSocketBuffered& ssocket) { return ssocket.get_num_available_buffers(); });
  }
  if (m_ssockets_zero_copy.contains(address)) {
    return m_ssockets_zero_copy.apply(
      address, [](SendSocketZeroCopy& ssocket) { return ssocket.get_num_available_buffers(); });
  }
  throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}

void netio3::libfabric::ConnectionManager::check_and_init(const EndPointAddress& address,
                                                          const std::uint64_t flags)
{
  if (m_domain_manager == nullptr) {
    m_domain_manager = std::make_unique<DomainManager>(m_config.mode, address, flags);
  }
  if (m_cq_reactor == nullptr) {
    m_cq_reactor = std::make_unique<CqReactor>(m_domain_manager->get_fabric(), m_config.on_data_cb);
  }
}

// Reads one connection-management entry from the event queue `eq`.
//
// Result:
//  * -FI_EAGAIN: returned as the event code with an empty error entry.
//  * -FI_EAVAIL: the error entry is fetched with fi_eq_readerr(), logged, and
//    returned together with -FI_EAVAIL as the event code.
//  * otherwise: the event code written by fi_eq_sread() is returned together
//    with the fi_info pointer of the entry; a read size different from the
//    entry size is logged as an error first.
netio3::libfabric::ConnectionManager::CmEvent netio3::libfabric::ConnectionManager::read_cm_event(
  fid_eq* eq)
{
  ZoneScoped;
  std::uint32_t event_code{};
  fi_eq_cm_entry cm_entry{};
  fi_eq_err_entry error_entry{};
  fi_info* no_info{nullptr};

  const auto read_result = fi_eq_sread(eq, &event_code, &cm_entry, sizeof(cm_entry), 0, 0);

  if (read_result == -FI_EAGAIN) {
    return {read_result, error_entry, FiInfoUniquePtr{no_info}};
  }

  if (read_result == -FI_EAVAIL) {
    const auto readerr_result = fi_eq_readerr(eq, &error_entry, 0);
    if (readerr_result < 0) {
      ers::error(LibFabricCmError(
        ERS_HERE,
        std::format("Failed to retrieve details on Event Queue error {}", readerr_result)));
    }
    ers::error(LibFabricCmError(
      ERS_HERE,
      std::format("Event Queue error: {} (code: {}),  provider specific: {} (code: {})",
                  fi_strerror(error_entry.err),
                  error_entry.err,
                  fi_eq_strerror(eq, error_entry.prov_errno, error_entry.err_data, nullptr, 0),
                  error_entry.prov_errno)));
    return {read_result, error_entry, FiInfoUniquePtr{no_info}};
  }

  // Any other negative result, as well as a short read, ends up here.
  if (read_result != sizeof(cm_entry)) {
    ers::error(
      LibFabricCmError(ERS_HERE, std::format("Failed to read from Event Queue: {}", read_result)));
  }

  return {event_code, error_entry, FiInfoUniquePtr{cm_entry.info}};
}

// Handles a connection-management event raised on a listen socket's event
// queue.
//
// Connection requests are forwarded to handle_connection_request().
// FI_CONNECTED and FI_SHUTDOWN are not expected on a listen socket and raise
// LibFabricCmError. Queue errors and unknown events are only logged.
void netio3::libfabric::ConnectionManager::on_listen_socket_cm_event(ListenSocket& lsocket)
{
  ZoneScoped;
  ERS_DEBUG(1, "listen socket: connection event");
  auto cm_event = read_cm_event(lsocket.get_endpoint().eq.get());

  switch (cm_event.event) {
  case FI_CONNECTED:
    throw LibFabricCmError(ERS_HERE, "FI_CONNECTED received on listen socket");

  case FI_SHUTDOWN:
    throw LibFabricCmError(ERS_HERE, "FI_SHUTDOWN received on listen socket");

  case FI_CONNREQ: {
    ERS_DEBUG(2, ": FI_CONNREQ");
    // Ownership of the fi_info attached to the request moves to the new socket.
    handle_connection_request(lsocket, std::move(cm_event.info));
    break;
  }

  case -FI_EAVAIL:
    ers::error(LibFabricCmError(
      ERS_HERE,
      std::format("Unhandled error in listen socket EQ code: {} provider specific code: {}",
                  cm_event.err_entry.err,
                  cm_event.err_entry.prov_errno)));
    break;

  case -FI_EAGAIN: {
    // Nothing to read: call fi_trywait on the event queue fid.
    auto* eq_fid = &lsocket.get_endpoint().eq->fid;
    fi_trywait(m_domain_manager->get_fabric(), &eq_fid, 1);
    break;
  }

  default:
    ers::warning(LibFabricCmError(
      ERS_HERE, std::format("Unexpected event {} in listen socket Event Queue", cm_event.event)));
    break;
  }
}

// Handles a connection-management event raised on a receive socket's event
// queue.
//
//  * FI_CONNECTED: fetches the completion-queue wait descriptor, registers it
//    with the event loop and invokes the connection-established callback.
//  * FI_SHUTDOWN: invokes the connection-closed callback, then unregisters and
//    erases the socket.
//  * Queue errors are logged; any unlisted event raises LibFabricCmError.
void netio3::libfabric::ConnectionManager::on_recv_socket_cm_event(ReceiveSocket& rsocket)
{
  ZoneScoped;
  ERS_DEBUG(2, "Entered");
  auto& endpoint = rsocket.get_endpoint();
  const auto cm_event = read_cm_event(endpoint.eq.get());

  switch (cm_event.event) {
  case FI_CONNECTED: {
    ERS_DEBUG(2, "FI_CONNECTED");
    // Ask libfabric for the wait object of the receive completion queue.
    const auto control_result = fi_control(&endpoint.rcq->fid, FI_GETWAIT, &endpoint.cqfd);
    if (control_result != 0) {
      ers::error(LibFabricError(
        ERS_HERE,
        std::format(
          "Failed to retrieve recv socket Completion Queue wait object: error {} cqfd: {}",
          control_result,
          endpoint.cqfd)));
    }

    // The handler keeps a reference to the socket stored inside m_rsockets.
    rsocket.set_cq_ev_ctx(
      {endpoint.cqfd, [this, &rsocket](int) { m_cq_reactor->on_recv_socket_cq_event(rsocket); }});
    ERS_DEBUG(
      1, std::format("recv_socket: EQ fd {} connected, CQ fd {}", endpoint.eqfd, endpoint.cqfd));
    m_event_loop->register_fd(rsocket.get_cq_ev_ctx());

    ERS_DEBUG(1,
              std::format("Adding recv CQ polled fid {} {}",
                          endpoint.cqfd,
                          static_cast<void*>(&endpoint.cq->fid)));
    ERS_DEBUG(
      1, std::format("recv_socket: EQ fd {} CQ fd {} connected", endpoint.eqfd, endpoint.cqfd));

    const auto peer = peer_address(endpoint.ep.get());
    ERS_INFO(std::format("Opened receive connection to {}:{}", peer.address(), peer.port()));
    if (m_config.on_connection_established_cb != nullptr) {
      m_config.on_connection_established_cb(peer);
    }
    break;
  }

  case FI_SHUTDOWN: {
    ERS_DEBUG(2, "FI_SHUTDOWN");
    const auto peer = peer_address(endpoint.ep.get());
    ERS_INFO(std::format("Closed receive connection to {}:{} because received FI_SHUTDOWN",
                         peer.address(),
                         peer.port()));
    if (m_config.on_connection_closed_cb != nullptr) {
      m_config.on_connection_closed_cb(peer, {});
    }

    std::lock_guard lock(m_rsocket_mutex);
    // Receive sockets are located through their event-queue descriptor.
    const auto has_same_eqfd = [&rsocket](const auto& candidate) {
      return candidate.second.get_endpoint().eqfd == rsocket.get_endpoint().eqfd;
    };
    const auto position = std::ranges::find_if(m_rsockets, has_same_eqfd);
    if (position != std::cend(m_rsockets)) {
      unregister_endpoint(rsocket.get_cq_cm_fds());
      // `rsocket` and `endpoint` must not be used after this line.
      m_rsockets.erase(position);
    } else {
      ers::error(LibFabricError(ERS_HERE, "Failed to find ReceiveSocket in ConnectionManager"));
    }
    break;
  }

  case FI_MR_COMPLETE:
  case FI_AV_COMPLETE:
  case FI_JOIN_COMPLETE:
    // Not implemented
    break;

  case -FI_EAVAIL:
    ers::error(LibFabricCmError(
      ERS_HERE,
      std::format("Unhandled error in recv socket EQ code: {} provider specific code: {}",
                  cm_event.err_entry.err,
                  cm_event.err_entry.prov_errno)));
    break;

  case -FI_EAGAIN: {
    // Nothing to read: call fi_trywait on the event queue fid.
    auto* eq_fid = &endpoint.eq->fid;
    fi_trywait(m_domain_manager->get_fabric(), &eq_fid, 1);
    break;
  }

  default:
    throw LibFabricCmError(
      ERS_HERE, std::format("Unexpected event {} in recv socket Event Queue", cm_event.event));
  }
}

// Handles a connection-management event raised on the event queue of the send
// socket stored for `address`.
//
// The socket is always accessed through the map that matches SocketType and
// is looked up by `address` on every access.
//
//  * FI_SHUTDOWN: unregisters the socket's descriptors, erases the socket and
//    invokes the connection-closed callback with the socket's pending sends.
//  * FI_CONNECTED: fetches the completion-queue wait descriptor, registers a
//    completion handler with the event loop and invokes the
//    connection-established callback.
//  * -FI_EAVAIL with FI_ECONNREFUSED: removes the event-queue descriptor,
//    erases the socket and invokes the connection-refused callback.
//  * -FI_EAVAIL with any other error code: logged only.
//  * -FI_EAGAIN: calls fi_trywait on the event queue fid.
//
// @tparam SocketType SendSocketBuffered or SendSocketZeroCopy.
// @param address     Key of the send socket in the corresponding map.
// @throws LibFabricCmError for an event that is not listed above.
template<netio3::libfabric::SendSocketConcept SocketType>
void netio3::libfabric::ConnectionManager::on_send_socket_cm_event(const EndPointAddress& address)
{
  ZoneScoped;
  ERS_DEBUG(1, "send socket: connection event");
  // Select the socket map at compile time. Every branch is `if constexpr`, so
  // only the matching return statement is instantiated for a given SocketType.
  auto& map = [this]() -> ThreadSafeMap<EndPointAddress, SocketType>& {
    if constexpr (std::is_same_v<SocketType, SendSocketZeroCopy>) {
      return m_ssockets_zero_copy;
    } else if constexpr (std::is_same_v<SocketType, SendSocketBuffered>) {
      return m_ssockets_buffered;
    } else {
      throw LibFabricLogicError(ERS_HERE, "Unknown socket type");
    }
  }();
  const auto event = map.apply(
    address, [](SocketType& ssocket) { return read_cm_event(ssocket.get_endpoint().eq.get()); });

  switch (event.event) {
  case FI_SHUTDOWN: {
    ERS_INFO(std::format("Closed send connection to {}:{} because received FI_SHUTDOWN",
                         address.address(),
                         address.port()));
    const auto fds =
      map.apply(address, [](SocketType& ssocket) { return ssocket.get_cq_cm_fds(); });
    unregister_endpoint(fds);
    // Collect the pending sends before the socket is erased so they can be
    // handed to the user callback afterwards.
    const auto pending_sends =
      map.apply(address, [](SocketType& ssocket) { return ssocket.get_pending_sends(); });
    map.erase(address);
    if (m_config.on_connection_closed_cb != nullptr) {
      m_config.on_connection_closed_cb(address, pending_sends);
    }
    return;
  }

  case FI_CONNECTED: {
    const auto ctx = map.apply(address, [this, &map, &address](SocketType& ssocket) {
      // Ask libfabric for the wait object of the send completion queue.
      if (const auto ret =
            fi_control(&ssocket.get_endpoint().cq->fid, FI_GETWAIT, &ssocket.get_endpoint().cqfd)) {
        ers::error(LibFabricError(
          ERS_HERE,
          std::format(
            "Failed to retrieve wait object for send socket Completion Queue, error {} - {}",
            ret,
            fi_strerror(-ret))));
      }
      // The completion handler stores the address by value and looks the
      // socket up again on every invocation instead of holding a reference.
      ssocket.set_cq_ev_ctx({ssocket.get_endpoint().cqfd, [this, &map, address](int) {
                               try {
                                 const auto keys =
                                   map.apply(address, [this](SocketType& ssocket_cb) {
                                     return m_cq_reactor->on_send_socket_cq_event(ssocket_cb);
                                   });
                                 if (m_config.on_send_completed_cb != nullptr) {
                                   for (const auto key : keys) {
                                     m_config.on_send_completed_cb(address, key);
                                   }
                                 }
                               } catch (const std::out_of_range& e) {
                                 // NOTE(review): presumably raised by the map when the
                                 // socket is no longer stored - verify against
                                 // ThreadSafeMap. The event is dropped in that case.
                                 return;
                               }
                             }});
      ERS_DEBUG(1,
                std::format("send_socket: EQ fd {} connected, CQ fd {}",
                            ssocket.get_endpoint().eqfd,
                            ssocket.get_endpoint().cqfd));
      return ssocket.get_cq_ev_ctx();
    });
    ERS_INFO(std::format("Opened send connection to {}:{}", address.address(), address.port()));
    m_event_loop->register_fd(ctx);
    if (m_config.on_connection_established_cb != nullptr) {
      m_config.on_connection_established_cb(address);
    }
    return;
  }

  case FI_MR_COMPLETE:
  case FI_AV_COMPLETE:
  case FI_JOIN_COMPLETE:
    // Not implemented
    break;

  case -FI_EAVAIL:
    switch (event.err_entry.err) {
    case FI_ECONNREFUSED:
      ERS_DEBUG(1, "Connection refused (FI_ECONNREFUSED). Deallocating send_socket resources");
      // Only the event-queue descriptor is removed from the event loop here.
      m_event_loop->remove_fd(
        map.apply(address, [](SocketType& ssocket) { return ssocket.get_endpoint().eqfd; }));
      map.erase(address);
      if (m_config.on_connection_refused_cb != nullptr) {
        m_config.on_connection_refused_cb(address);
      }
      break;
    case FI_ETIMEDOUT:
      // Only a warning is emitted; the commented-out code below is kept as a
      // record of handling that is not implemented here.
      ers::warning(
        LibFabricCmError(ERS_HERE, "fi_verbs_process_send_socket_cm_event: FI_ETIMEDOUT"));
      // if (socket->eqfd < 0 ){
      //     log_info("Ignoring FI_SHUTDOWN on send_socket, invalid eqfd (socket already closed)");
      //     break;
      // }

      // // Need to take care of receive socket as well

      // if (socket->recv_socket != NULL){
      //     if (socket->recv_socket->eqfd < 0 ){
      //         log_info("Ignoring FI_ETIMEDOUT on recv_socket, invalid eqfd (socket already
      //         closed)"); return;
      //     }
      //     log_info("Shutting down receive socket on FI_ETIMEDOUT");
      //     handle_recv_socket_shutdown(socket->recv_socket);
      //     if(socket->recv_socket->lsocket->cb_connection_closed) {
      //         socket->recv_socket->lsocket->cb_connection_closed(socket->recv_socket);
      //     }
      // }

      // if(socket->cqfd < 0){ //cq not initalized yet
      //     handle_send_socket_shutdown_on_connetion_refused(socket);
      // } else {
      //     if(socket->cb_internal_connection_closed){
      //         socket->cb_internal_connection_closed(socket);
      //     }
      //     if(socket->cb_connection_closed) {
      //         socket->cb_connection_closed(socket);
      //     }
      // }
      break;

    default:
      ers::error(LibFabricError(
        ERS_HERE,
        std::format(
          "Unhandled error in the Event Queue: {} (code: {}), provider specific: {} (code: {})",
          fi_strerror(event.err_entry.err),
          event.err_entry.err,
          fi_eq_strerror(
            map.apply(address, [](SocketType& ssocket) { return ssocket.get_endpoint().eq.get(); }),
            event.err_entry.prov_errno,
            event.err_entry.err_data,
            nullptr,
            0),
          event.err_entry.prov_errno)));
    }
    return;

  case -FI_EAGAIN: {
    // Nothing to read: call fi_trywait on the event queue fid.
    auto* fp =
      map.apply(address, [](SocketType& ssocket) { return &ssocket.get_endpoint().eq->fid; });
    fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
    break;
  }
  default:
    throw LibFabricCmError(
      ERS_HERE, std::format("Unexpected event {} in send socket Event Queue", event.event));
  }
}

// Creates a receive socket for an incoming connection request and registers
// its event-queue descriptor with the event loop.
//
// The new socket is stored under the address of the listen socket that
// received the request. `info` is the fi_info attached to the request and is
// moved into the new socket.
void netio3::libfabric::ConnectionManager::handle_connection_request(ListenSocket& lsocket,
                                                                     FiInfoUniquePtr&& info)
{
  ZoneScoped;
  ERS_DEBUG(1, "Received connection request");

  std::lock_guard lock(m_rsocket_mutex);
  // A new endpoint is spawned for every request; it is constructed in place.
  const auto new_entry = m_rsockets.emplace(std::piecewise_construct,
                                            std::forward_as_tuple(lsocket.get_address()),
                                            std::forward_as_tuple(lsocket,
                                                                  m_domain_manager->get_fabric(),
                                                                  m_domain_manager->get_listen_domain(),
                                                                  std::move(info),
                                                                  m_event_loop));
  auto& rsocket = new_entry->second;
  ERS_DEBUG(1,
            std::format("Created and connected endpoint. Eqfd: {}", rsocket.get_endpoint().eqfd));

  // The handler keeps a reference to the socket stored inside m_rsockets.
  const EventContext cm_ctx{rsocket.get_endpoint().eqfd,
                            [this, &rsocket](int) { on_recv_socket_cm_event(rsocket); }};
  rsocket.set_eq_ev_ctx(cm_ctx);
  m_event_loop->register_fd(cm_ctx);
}

// Removes both descriptors of an endpoint from the event loop: first the
// connection-management descriptor (cm_fd), then the completion-queue
// descriptor (cq_fd).
void netio3::libfabric::ConnectionManager::unregister_endpoint(const CqCmFds& fds)
{
  m_event_loop->remove_fd(fds.cm_fd);
  m_event_loop->remove_fd(fds.cq_fd);
}

void netio3::libfabric::ConnectionManager::handle_close_requests()
{
  ZoneScoped;
  CloseQueueItem item;
  while (m_close_queue.try_pop(item)) {
    switch (item.type) {
    case CloseQueueItem::Type::listen:
      do_close_listen_endpoint(item.address);
      break;
    case CloseQueueItem::Type::send:
      do_close_send_endpoint(item.address);
      break;
    }
  }
}

// Closes the listen endpoint at `address` together with every receive socket
// stored under the same address. Does nothing when no listen socket is stored
// for `address`.
void netio3::libfabric::ConnectionManager::do_close_listen_endpoint(const EndPointAddress& address)
{
  ZoneScoped;
  const bool is_known = m_lsockets.contains(address);
  if (!is_known) {
    return;
  }

  // Listen socket has no CQ, so only its connection-management descriptor is
  // removed from the event loop.
  const auto listen_cm_fd = m_lsockets.at(address).get_cq_cm_fds().cm_fd;
  m_event_loop->remove_fd(listen_cm_fd);

  std::lock_guard lock(m_rsocket_mutex);
  const auto accepted = m_rsockets.equal_range(address);
  for (auto current = accepted.first; current != accepted.second; ++current) {
    unregister_endpoint(current->second.get_cq_cm_fds());
  }

  m_lsockets.erase(address);
  m_rsockets.erase(address);
}

void netio3::libfabric::ConnectionManager::do_close_send_endpoint(const EndPointAddress& address)
{
  ZoneScoped;
  std::vector<std::uint64_t> pending_sends{};
  if (m_ssockets_buffered.contains(address)) {
    const auto fds =
      m_ssockets_buffered.apply(address, [](auto& socket) { return socket.get_cq_cm_fds(); });
    unregister_endpoint(fds);
    pending_sends = m_ssockets_buffered.apply(address, [](auto& socket) {
      return socket.get_pending_sends();
    });
    m_ssockets_buffered.erase(address);
  } else if (m_ssockets_zero_copy.contains(address)) {
    const auto fds =
      m_ssockets_zero_copy.apply(address, [](auto& socket) { return socket.get_cq_cm_fds(); });
    unregister_endpoint(fds);
    pending_sends = m_ssockets_zero_copy.apply(address, [](SendSocketZeroCopy& socket) {
      return socket.get_pending_sends();
    });
    m_ssockets_zero_copy.erase(address);
  }

  if (m_config.on_connection_closed_cb != nullptr) {
    m_config.on_connection_closed_cb(address, pending_sends);
  }
}

void netio3::libfabric::ConnectionManager::unregister_all()
{
  ZoneScoped;
  for (const auto& [ep, rsocket] : m_rsockets) {
    unregister_endpoint(rsocket.get_cq_cm_fds());
  }
  for (const auto& [ep, lsocket] : m_lsockets) {
    unregister_endpoint(lsocket.get_cq_cm_fds());
  }
  m_ssockets_buffered.apply_all([this] (const auto& ssocket) {
    unregister_endpoint(ssocket.get_cq_cm_fds());
  });
  m_ssockets_zero_copy.apply_all([this] (const auto& ssocket) {
    unregister_endpoint(ssocket.get_cq_cm_fds());
  });
}

// Explicit instantiations of on_send_socket_cm_event for the two send socket
// types used in this file (buffered and zero-copy). The template is defined
// in this translation unit, so these emit its code for both types.
template void netio3::libfabric::ConnectionManager::on_send_socket_cm_event<
  netio3::libfabric::SendSocketBuffered>(const EndPointAddress&);
template void netio3::libfabric::ConnectionManager::on_send_socket_cm_event<
  netio3::libfabric::SendSocketZeroCopy>(const EndPointAddress&);