Program Listing for File ReceiveSocket.cpp

Return to documentation for file (BackendLibfabric/ReceiveSocket.cpp)

#include "ReceiveSocket.hpp"

#include <rdma/fi_cm.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>

#include <tracy/Tracy.hpp>

#include "BaseSocket.hpp"
#include "DomainManager.hpp"
#include "Helpers.hpp"
#include "Issues.hpp"
#include "netio3-backend/Issues.hpp"

/**
 * @brief Accept an incoming connection request on @p lsocket and set up the receive side.
 *
 * Creates and enables the endpoint (taking ownership of @p info), registers and posts the
 * receive buffers, then calls fi_accept(). If fi_accept() fails, the connection request is
 * rejected, the registered buffers are closed and FailedOpenReceiveEndpoint is thrown.
 *
 * @param lsocket    Listen socket that received the connection request.
 * @param fabric     Fabric used to open the endpoint.
 * @param domain     Domain context used for the endpoint, its CQs and buffer registration.
 * @param info       fi_info of the connection request; moved into the endpoint.
 * @param event_loop Event loop used to create the buffer re-post retry signal.
 * @throws FailedOpenReceiveEndpoint if endpoint creation, buffer registration or fi_accept fails.
 */
netio3::libfabric::ReceiveSocket::ReceiveSocket(ListenSocket& lsocket,
                                                fid_fabric* fabric,
                                                DomainContext& domain,
                                                FiInfoUniquePtr&& info,
                                                BaseEventLoop* event_loop) :
  BaseSocket{create_endpoint(lsocket.get_address(), fabric, domain.get_domain(), std::move(info))},
  m_conn_params{prepare_connection_parameters(lsocket.get_connection_parameters())},
  m_address{peer_address(get_endpoint().ep.get())},
  m_retry_post_signal{event_loop->create_signal(
    [this](int) { retry_post_buffers(); }, false)}
{
  init_buffers(domain);
  ERS_DEBUG(1,
            std::format("connection accepted. Lsocket address: {}:{}, rsocket EQ {}",
                        lsocket.get_address().address(),
                        lsocket.get_address().port(),
                        get_eq_ev_ctx().fd));
  ERS_DEBUG(2, std::format("Allocating memory regions number of buffers: {}", m_buffers.size()));
  // Reserve before posting: post_buffer() may already queue buffers for a retry.
  m_retry_post_buffers.reserve(m_conn_params.num_buf);
  post_buffers();

  const int ret = fi_accept(get_endpoint().ep.get(), nullptr, 0);
  if (ret != 0) {
    // `info` was moved into the endpoint by create_endpoint() and is empty here, so the
    // connection request handle must be read from the endpoint's own fi_info.
    fi_reject(lsocket.get_pep(), get_endpoint().fi->handle, nullptr, 0);

    // The destructor does not run when a constructor throws, so release the buffers here.
    for (const auto& buf : m_buffers) {
      close_buffer(buf);
    }

    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Listen socket, connection rejected, error {} - {}", ret, fi_strerror(-ret)));
  }

  ERS_DEBUG(2, std::format("Connection done, fd: {}", get_eq_ev_ctx().fd));
}

/**
 * @brief Destroy the socket, releasing every receive buffer via close_buffer().
 */
netio3::libfabric::ReceiveSocket::~ReceiveSocket()
{
  ERS_DEBUG(2, "Entered");
  // Buffers were registered for FI_RECV in init_buffers(); hand each one back.
  for (const Buffer& buffer : m_buffers) {
    close_buffer(buffer);
  }
}

/**
 * @brief Build, open and enable the endpoint used by a receive socket.
 *
 * Takes ownership of @p info (stored in the returned Endpoint), opens the endpoint, opens and
 * binds its completion queues, then enables it.
 *
 * @param address Local address, used for diagnostics.
 * @param fabric  Fabric on which the endpoint is opened.
 * @param domain  Domain on which the endpoint and its CQs are opened.
 * @param info    fi_info describing the endpoint; moved into the result.
 * @return The fully set up Endpoint.
 * @throws FailedOpenReceiveEndpoint if any step fails (FailedOpenEndpoint is translated).
 */
netio3::libfabric::Endpoint netio3::libfabric::ReceiveSocket::create_endpoint(
  const EndPointAddress& address,
  fid_fabric* fabric,
  fid_domain* domain,
  FiInfoUniquePtr&& info)
{
  ZoneScoped;
  Endpoint endpoint{};
  endpoint.fi = std::move(info);
  auto* const raw_info = endpoint.fi.get();
  try {
    open_endpoint(address, endpoint, fabric, domain, raw_info);
    open_cq(address, endpoint, domain);
    enable_endpoint(address, endpoint);
  } catch (const FailedOpenEndpoint& failure) {
    // Re-raise generic endpoint failures with the receive-specific issue type.
    throw FailedOpenReceiveEndpoint(
      ERS_HERE, address.address(), address.port(), failure.message());
  }
  return endpoint;
}

/**
 * @brief Open the two completion queues of a receive endpoint and bind them to @p ep.
 *
 * The receive CQ (ep.rcq, bound with FI_RECV) uses the attributes from prepare_cq_attr().
 * The transmit CQ (ep.cq, bound with FI_TRANSMIT) reuses those attributes, switched to
 * FI_CQ_FORMAT_UNSPEC and FI_WAIT_NONE.
 *
 * @param address Address used in error messages and by the CQ close deleters.
 * @param ep      Endpoint whose ep.ep is already open; receives the CQ handles.
 * @param domain  Domain on which the CQs are opened.
 * @throws FailedOpenReceiveEndpoint if opening or binding either CQ fails.
 */
void netio3::libfabric::ReceiveSocket::open_cq(const EndPointAddress& address,
                                               Endpoint& ep,
                                               fid_domain* domain)
{
  ZoneScoped;
  auto cq_attr = prepare_cq_attr();

  // FI_RECV CQ
  fid_cq* rcq = nullptr;
  if (const auto ret = fi_cq_open(domain, &cq_attr, &rcq, nullptr)) {
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to open Completion Queue for receive socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
  ep.rcq = FiCloseUniquePtr<fid_cq>(
    rcq, FiCloseDeleter<fid_cq>(address, "Failed to close receive socket Completion Queue"));

  if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.rcq->fid, FI_RECV)) {
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind Completion Queue for receive socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }

  cq_attr.format = FI_CQ_FORMAT_UNSPEC;
  cq_attr.wait_obj = FI_WAIT_NONE;

  // FI_TRANSMIT CQ (the diagnostics previously mislabelled this as FI_RECV / "send socket").
  fid_cq* cq = nullptr;
  if (const auto ret = fi_cq_open(domain, &cq_attr, &cq, nullptr)) {
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to open transmit Completion Queue for receive socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
  ep.cq = FiCloseUniquePtr<fid_cq>(
    cq,
    FiCloseDeleter<fid_cq>(address, "Failed to close FI_TRANSMIT receive socket Completion Queue"));

  if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.cq->fid, FI_TRANSMIT)) {
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind transmit Completion Queue for receive socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
}

/**
 * @brief Allocate m_conn_params.num_buf receive buffers and register each one for FI_RECV.
 *
 * Storage is reserved up front so element addresses stay stable. Those addresses are passed
 * to libfabric as the operation context in do_post_buffer().
 *
 * @param domain Domain context used to allocate and register the buffers.
 * @throws FailedOpenReceiveEndpoint if a buffer cannot be registered. Buffers registered
 *         before the failure are closed first, because this runs from the constructor and
 *         ~ReceiveSocket() will not run if it throws.
 */
void netio3::libfabric::ReceiveSocket::init_buffers(DomainContext& domain)
{
  m_buffers.reserve(m_conn_params.num_buf);
  for (uint64_t i = 0; i < m_conn_params.num_buf; i++) {
    // Key does not matter for receive buffers
    m_buffers.emplace_back(domain, m_conn_params.buf_size, 0);
    try {
      register_buffer(m_buffers.back(), domain, FI_RECV);
    } catch (const LibFabricBufferError& e) {
      // Every buffer except the last one (whose registration just failed) is registered.
      const std::size_t registered = m_buffers.size() - 1;
      for (std::size_t j = 0; j < registered; ++j) {
        close_buffer(m_buffers[j]);
      }
      throw FailedOpenReceiveEndpoint(ERS_HERE, m_address.address(), m_address.port(), e.message());
    }
  }
}

void netio3::libfabric::ReceiveSocket::post_buffers()
{
  ZoneScoped;
  for (auto& buf : m_buffers) {
    post_buffer(&buf);
  }
}

/**
 * @brief Post a single receive buffer.
 *
 * If do_post_buffer() returns -FI_EAGAIN, the buffer is appended to m_retry_post_buffers and
 * m_retry_post_signal is fired, whose callback is retry_post_buffers(). Any other error is
 * reported with ers::error and the buffer is not re-posted.
 *
 * @param buf Buffer to post; must stay alive while posted.
 */
void netio3::libfabric::ReceiveSocket::post_buffer(Buffer* buf)
{
  ZoneScoped;
  const auto ret = do_post_buffer(buf);
  if (ret == 0) {
    return;
  }

  if (ret == -FI_EAGAIN) {
    // Temporarily unable to post: defer to the retry signal.
    m_retry_post_buffers.push_back(buf);
    m_retry_post_signal.fire();
  } else {
    ers::error(LibFabricBufferError(
      ERS_HERE,
      std::format("Failed to post a buffer to receive inbound messages, error {} - {}",
                  ret,
                  fi_strerror(-ret))));
  }
}

/**
 * @brief Retry posting every buffer queued in m_retry_post_buffers.
 *
 * Buffers that again get -FI_EAGAIN stay queued and m_retry_post_signal is fired once more.
 * Other failures are reported with ers::error and the buffer is dropped from the queue.
 */
void netio3::libfabric::ReceiveSocket::retry_post_buffers()
{
  ZoneScoped;
  // Compact the queue in place: buffers that still need a retry are moved to the front.
  // This avoids a heap allocation per retry round. It also keeps the capacity reserved in the
  // constructor, which swapping in a freshly built vector used to throw away.
  auto keep = m_retry_post_buffers.begin();
  for (auto* buf : m_retry_post_buffers) {
    const auto ret = do_post_buffer(buf);
    if (ret == -FI_EAGAIN) {
      // `keep` never runs ahead of the loop position and the size does not change,
      // so this write cannot invalidate the iteration.
      *keep = buf;
      ++keep;
    } else if (ret != 0) [[unlikely]] {
      ers::error(LibFabricBufferError(
        ERS_HERE,
        std::format("Failed to post a buffer to receive inbound messages, error {} - {}",
                    ret,
                    fi_strerror(-ret))));
    }
  }
  m_retry_post_buffers.erase(keep, m_retry_post_buffers.end());

  if (not m_retry_post_buffers.empty()) {
    m_retry_post_signal.fire();
  }
}

/**
 * @brief Post @p buf as a single-segment receive on the endpoint with fi_recvmsg().
 *
 * @param buf Buffer to receive into; its address is used as the operation context.
 * @return The fi_recvmsg() result: 0 on success, a negative libfabric error code otherwise.
 */
int netio3::libfabric::ReceiveSocket::do_post_buffer(Buffer* buf) const
{
  ZoneScoped;
  ERS_DEBUG(2, "Entered");
  // One scatter-gather element covering the whole buffer.
  const iovec iov{.iov_base = buf->data().data(), .iov_len = buf->size()};
  void* desc = fi_mr_desc(buf->mr);
  const fi_msg msg{
    .msg_iov = &iov,
    .desc = &desc,
    .iov_count = 1,
    .addr = 0,
    .context = buf,  // operation context: the Buffer this receive targets
    .data = 0,
  };

  constexpr std::uint64_t flags = FI_REMOTE_CQ_DATA;  // FI_MULTI_RECV;
  return fi_recvmsg(get_endpoint().ep.get(), &msg, flags);
}

/**
 * @brief Return @p requested with num_buf clamped to the endpoint's receive context size.
 *
 * A TooManyBuffersRequested warning is emitted when the request exceeds rx_attr->size.
 *
 * @param requested Connection parameters asked for by the listen socket.
 * @return The parameters actually used by this socket.
 */
netio3::ConnectionParametersRecv netio3::libfabric::ReceiveSocket::prepare_connection_parameters(
  const ConnectionParametersRecv& requested) const
{
  auto adjusted = requested;
  const auto rx_limit = get_endpoint().fi->rx_attr->size;
  if (adjusted.num_buf <= rx_limit) {
    return adjusted;
  }
  ers::warning(TooManyBuffersRequested(adjusted.num_buf, rx_limit));
  adjusted.num_buf = rx_limit;
  return adjusted;
}