Program Listing for File SendSocket.cpp

Return to documentation for file (BackendLibfabric/SendSocket.cpp)

#include "SendSocket.hpp"

#include <utility>

#include <tracy/Tracy.hpp>

#include <rdma/fabric.h>
#include <rdma/fi_cm.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_eq.h>

#include "BaseSocket.hpp"
#include "Helpers.hpp"
#include "Issues.hpp"

/// Constructs a send socket targeting @p address.
///
/// The libfabric endpoint is built eagerly by create_endpoint() and handed to the
/// BaseSocket base; the address is then stored in m_addr (init() uses it in its
/// error reports). No connection is attempted here — that happens in init().
///
/// Note: base classes are initialised before data members, so create_endpoint()
/// still sees the intact @p address before it is moved into m_addr.
netio3::libfabric::SendSocket::SendSocket(EndPointAddress address,
                                          NetworkMode mode,
                                          fid_fabric* fabric,
                                          fid_domain* domain) :
  BaseSocket{create_endpoint(address, mode, fabric, domain)},
  m_addr{std::move(address)}
{
}

/// Starts connecting the endpoint to its remote peer and retrieves the wait object
/// of the endpoint's Event Queue (written into the endpoint's eqfd member).
///
/// @throws FailedOpenSendEndpoint if fi_connect() or fi_control(FI_GETWAIT) fails.
void netio3::libfabric::SendSocket::init()
{
  ZoneScoped;
  auto& endpoint = get_endpoint();

  // Initiate the connection towards the destination address carried by the fi_info.
  const auto connect_rc = fi_connect(endpoint.ep.get(), endpoint.fi->dest_addr, nullptr, 0);
  if (connect_rc != 0) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      m_addr.address(),
      m_addr.port(),
      std::format(
        "Connection to remote failed, error {} - {}", connect_rc, fi_strerror(-connect_rc)));
  }

  // Fetch the EQ's wait object so the event loop can watch it.
  const auto wait_rc = fi_control(&endpoint.eq->fid, FI_GETWAIT, &endpoint.eqfd);
  if (wait_rc != 0) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      m_addr.address(),
      m_addr.port(),
      std::format("Cannot retrieve the Event Queue wait object of send socket, error {} - {}",
                  wait_rc,
                  fi_strerror(-wait_rc)));
  }

  ERS_DEBUG(1, std::format("EV context with FD: {}", get_eq_ev_ctx().fd));
}

/// Posts a scatter-gather send of @p data on the connected endpoint via fi_sendmsg().
///
/// @param data  iovec entries to send, one per buffer.
/// @param mrs   Registered memory region for each entry of @p data (same size and
///              order); their descriptors are passed to libfabric alongside the iovecs.
/// @param key   Opaque value stored as the operation context; presumably used to match
///              the completion later — TODO confirm against the CQ handler.
/// @return NetioStatus::OK if the send was posted,
///         NetioStatus::NO_RESOURCES if libfabric returned -FI_EAGAIN (retry later),
///         NetioStatus::FAILED otherwise (an ers::error is reported).
netio3::NetioStatus netio3::libfabric::SendSocket::send_data(const std::span<const iovec> data,
                                                             const std::span<fid_mr*> mrs,
                                                             std::uint64_t key) const
{
  ZoneScoped;
  if (data.size() != mrs.size()) {
    ers::error(FailedSend(ERS_HERE,
                          get_address().address(),
                          get_address().port(),
                          "Failed sending message because of mismatched data and memory regions."));
    return NetioStatus::FAILED;
  }

  // Validate the endpoint before doing any per-message work.
  if (get_endpoint().ep == nullptr || get_endpoint().ep->msg == nullptr) {
    ers::error(FailedSend(ERS_HERE,
                          get_address().address(),
                          get_address().port(),
                          "Failed sending message because of null message or null endpoint."));
    return NetioStatus::FAILED;
  }

  // fi_mr_desc() dereferences its argument, so a null region must be rejected here.
  if (std::any_of(mrs.begin(), mrs.end(), [](const auto* mr) { return mr == nullptr; })) {
    ers::error(FailedSend(ERS_HERE,
                          get_address().address(),
                          get_address().port(),
                          "Failed sending message because of null memory region."));
    return NetioStatus::FAILED;
  }

  std::vector<void*> descs{};
  descs.reserve(mrs.size());
  std::transform(
    mrs.begin(), mrs.end(), std::back_inserter(descs), [](auto* mr) { return fi_mr_desc(mr); });

  fi_msg msg{};
  msg.msg_iov = data.data(); /* scatter-gather array */
  msg.desc = descs.data();
  msg.iov_count = data.size();
  msg.addr = 0;
  msg.context = reinterpret_cast<void*>(key);
  msg.data = 0;

  ERS_DEBUG(2, std::format("sending iov message with key {}", key));
  // FI_INJECT_COMPLETE: per libfabric, the completion is generated once the source
  // buffers may be reused. FI_INJECT is deliberately not requested.
  const uint64_t flags = FI_INJECT_COMPLETE;
  const auto ret = fi_sendmsg(get_endpoint().ep.get(), &msg, flags);
  if (ret == -FI_EAGAIN) {
    ERS_DEBUG(1, "Send failed with result: EAGAIN");
    return NetioStatus::NO_RESOURCES;
  }
  ERS_DEBUG(1, std::format("Send completed with result: {}", ret));
  if (ret != 0) {
    ers::error(FailedSend(ERS_HERE,
                          get_address().address(),
                          get_address().port(),
                          std::format("Failed to send message error (IOV count {}, key {}) - {}",
                                      data.size(),
                                      key,
                                      fi_strerror(static_cast<int>(-ret)))));
    return NetioStatus::FAILED;
  }
  return NetioStatus::OK;
}

/// Builds the libfabric endpoint used by a send socket targeting @p address.
///
/// It does the following, in order:
/// 1. Resolves the fabric info with get_info().
/// 2. Calls open_endpoint().
/// 3. Calls open_cq() to open and bind the completion queues.
/// 4. Calls enable_endpoint().
///
/// @return The fully set-up Endpoint (returned by value).
/// @throws FailedOpenSendEndpoint on any failure. A FailedOpenEndpoint raised by the
///         helpers is re-thrown as FailedOpenSendEndpoint carrying the same message.
netio3::libfabric::Endpoint netio3::libfabric::SendSocket::create_endpoint(
  const EndPointAddress& address,
  NetworkMode mode,
  fid_fabric* fabric,
  fid_domain* domain)
{
  ZoneScoped;
  Endpoint endpoint{};
  // get_info() reports its own failures as FailedOpenSendEndpoint, so it sits
  // outside the translating try-block.
  endpoint.fi = get_info(address, mode);
  try {
    open_endpoint(address, endpoint, fabric, domain, endpoint.fi.get());
    open_cq(address, endpoint, domain);
    enable_endpoint(address, endpoint);
  } catch (const FailedOpenEndpoint& issue) {
    throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), issue.message());
  }
  return endpoint;
}

/// Queries libfabric (fi_getinfo, API version 1.1) for a connection-oriented
/// (FI_EP_MSG) messaging endpoint towards @p address.
///
/// The hints request:
/// - FI_MSG capability;
/// - the FI_LOCAL_MR mode;
/// - automatic data progress;
/// - enabled resource management;
/// - the provider named by get_provider(mode).
///
/// @return Owning pointer to the fi_info returned by fi_getinfo.
/// @throws FailedOpenSendEndpoint if the provider for @p mode cannot be resolved or
///         fi_getinfo() fails.
netio3::libfabric::FiInfoUniquePtr netio3::libfabric::SendSocket::get_info(
  const EndPointAddress& address,
  NetworkMode mode)
{
  ZoneScoped;
  auto hints = FiInfoWrapper{};
  const auto hint = hints.get();
  hint->addr_format = FI_FORMAT_UNSPEC;
  hint->ep_attr->type = FI_EP_MSG;
  hint->caps = FI_MSG;
  hint->mode = FI_LOCAL_MR;
  // Since libfabric 1.10 the tcp provider only supports FI_PROGRESS_MANUAL, so
  // requesting FI_PROGRESS_AUTO below rules the tcp provider out.
  hint->domain_attr->data_progress = FI_PROGRESS_AUTO;
  hint->domain_attr->resource_mgmt = FI_RM_ENABLED;
  try {
    hint->fabric_attr->prov_name = strdup(get_provider(mode).c_str());
  } catch (const LibFabricError& err) {
    throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), err.message());
  }

  constexpr std::uint64_t no_flags = 0;
  const std::string service = std::to_string(address.port());
  fi_info* result = nullptr;
  if (const auto rc = fi_getinfo(
        FI_VERSION(1, 1), address.address().c_str(), service.c_str(), no_flags, hint, &result);
      rc != 0) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to initialise socket, error {} - {}", rc, fi_strerror(-rc)));
  }
  return FiInfoUniquePtr{result};
}

/// Opens the two completion queues of a send endpoint and binds them to @p ep.ep.
///
/// - Transmit queue (ep.cq): uses the attributes from prepare_cq_attr() and is bound
///   with FI_TRANSMIT.
/// - Receive queue (ep.rcq): uses the same attributes, except format is
///   FI_CQ_FORMAT_UNSPEC and wait_obj is FI_WAIT_NONE. It is bound with FI_RECV.
///
/// NOTE(review): this socket never posts receives here. Presumably the RX CQ is bound
/// because the provider expects one — confirm.
///
/// @throws FailedOpenSendEndpoint if opening or binding either queue fails. The error
///         text names the failing queue (FI_TRANSMIT or FI_RECV). Queues opened before
///         the failure are already owned by ep.cq / ep.rcq (FiCloseUniquePtr).
void netio3::libfabric::SendSocket::open_cq(const EndPointAddress& address,
                                            Endpoint& ep,
                                            fid_domain* domain)
{
  ZoneScoped;
  auto cq_attr = prepare_cq_attr();

  // FI_TRANSMIT CQ
  fid_cq* cq = nullptr;
  if (const auto ret = fi_cq_open(domain, &cq_attr, &cq, nullptr)) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to open FI_TRANSMIT Completion Queue for send socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
  ep.cq = FiCloseUniquePtr<fid_cq>(
    cq, FiCloseDeleter<fid_cq>(address, "Failed to close send socket Completion Queue"));

  if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.cq->fid, FI_TRANSMIT)) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind FI_TRANSMIT Completion Queue for send socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }

  // The receive queue reuses the transmit attributes with these two overrides.
  cq_attr.format = FI_CQ_FORMAT_UNSPEC;
  cq_attr.wait_obj = FI_WAIT_NONE;

  // FI_RECV CQ
  fid_cq* rcq = nullptr;
  if (const auto ret = fi_cq_open(domain, &cq_attr, &rcq, nullptr)) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to open FI_RECV Completion Queue for send socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
  ep.rcq = FiCloseUniquePtr<fid_cq>(
    rcq, FiCloseDeleter<fid_cq>(address, "Failed to close FI_RECV send socket Completion Queue"));

  if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.rcq->fid, FI_RECV)) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind FI_RECV Completion Queue for send socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
}