Program Listing for File ActiveEndpoint.cpp

Return to documentation for file (BackendLibfabric/ActiveEndpoint.cpp)

#include "ActiveEndpoint.hpp"

#include <memory>

#include <rdma/fi_cm.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>

#include <tracy/Tracy.hpp>

#include "Issues.hpp"
#include "Helpers.hpp"
#include "netio3-backend/Issues.hpp"
#include "netio3-backend/Netio3Backend.hpp"

// Constructs an active endpoint from an already available fi_info.
//
// address      - endpoint address, stored in m_address.
// mode         - network mode; only forwarded to create_endpoint().
// capabilities - endpoint capabilities, stored in m_capabilities.
// fabric       - fabric handle forwarded to create_endpoint().
// domain       - domain handle forwarded to create_endpoint().
// info         - fi_info whose ownership is handed to create_endpoint(); if it is
//                empty, create_endpoint() falls back to get_info() with flags 0.
// shx_ctx      - optional shared receive context, forwarded to create_endpoint().
// av           - optional address vector, forwarded to create_endpoint().
//
// Any exception raised by create_endpoint() propagates out of the constructor.
//
// NOTE(review): create_endpoint() reads m_address and m_capabilities, so both must be
// declared before m_ep in the class for this initializer list to be valid - confirm
// against the header.
netio3::libfabric::ActiveEndpoint::ActiveEndpoint(EndPointAddress address,
                                                  NetworkMode mode,
                                                  EndpointCapabilities capabilities,
                                                  fid_fabric* fabric,
                                                  fid_domain* domain,
                                                  FiInfoUniquePtr&& info,
                                                  fid_ep* shx_ctx,
                                                  fid_av* av) :
  m_address{std::move(address)},
  m_capabilities{std::move(capabilities)},
  m_ep{create_endpoint(fabric, domain, mode, std::move(info), 0, shx_ctx, av)}
{}

// Constructs an active endpoint without a pre-existing fi_info.
//
// address      - endpoint address, stored in m_address.
// mode         - network mode; only forwarded to create_endpoint().
// capabilities - endpoint capabilities, stored in m_capabilities.
// fabric       - fabric handle forwarded to create_endpoint().
// domain       - domain handle forwarded to create_endpoint().
// info_flags   - flags forwarded to create_endpoint(), which passes them to get_info()
//                because no fi_info is supplied here (nullptr is passed instead).
// shx_ctx      - optional shared receive context, forwarded to create_endpoint().
// av           - optional address vector, forwarded to create_endpoint().
//
// Any exception raised by create_endpoint() propagates out of the constructor.
//
// NOTE(review): create_endpoint() reads m_address and m_capabilities, so both must be
// declared before m_ep in the class for this initializer list to be valid - confirm
// against the header.
netio3::libfabric::ActiveEndpoint::ActiveEndpoint(EndPointAddress address,
                                                  NetworkMode mode,
                                                  EndpointCapabilities capabilities,
                                                  fid_fabric* fabric,
                                                  fid_domain* domain,
                                                  std::uint64_t info_flags,
                                                  fid_ep* shx_ctx,
                                                  fid_av* av) :
  m_address{std::move(address)},
  m_capabilities{std::move(capabilities)},
  m_ep{create_endpoint(fabric, domain, mode, nullptr, info_flags, shx_ctx, av)}
{}

// Finishes connection establishment for this endpoint.
//
// mode - ConnectionMode::Connect issues fi_connect() towards the destination address
//        stored in the endpoint's fi_info; any other mode accepts a pending connection
//        with fi_accept().
// pep  - passive endpoint used to reject the connection request when accepting fails;
//        it is not touched on the connect path.
//
// Throws FailedOpenActiveEndpoint if fi_connect() or fi_accept() reports an error.
void netio3::libfabric::ActiveEndpoint::complete_connection(ConnectionMode mode, fid_pep* pep)
{
  ZoneScoped;

  if (mode != ConnectionMode::Connect) {
    // Accept connection from local passive endpoint
    const int accept_status = fi_accept(m_ep.ep.get(), nullptr, 0);
    if (accept_status == 0) {
      return;
    }

    // Accepting failed: reject the request on the passive endpoint before reporting.
    // The result of fi_reject() is not checked.
    fi_reject(pep, m_ep.fi->handle, nullptr, 0);

    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Listen endpoint, connection rejected, error {} - {}",
                  accept_status,
                  fi_strerror(-accept_status)));
  }

  // Connect to remote passive endpoint
  const auto connect_status = fi_connect(m_ep.ep.get(), m_ep.fi->dest_addr, nullptr, 0);
  if (connect_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Connection to remote failed, error {} - {}",
                  connect_status,
                  fi_strerror(-connect_status)));
  }
}

// Refreshes the cached addresses from the underlying libfabric endpoint:
// m_address is replaced by the result of peer_address() and m_local_address by the
// result of local_address(), in that order.
void netio3::libfabric::ActiveEndpoint::update_addresses()
{
  ZoneScoped;
  auto* const endpoint = m_ep.ep.get();
  m_address = peer_address(endpoint);
  m_local_address = local_address(endpoint);
}

// Builds a fully configured Endpoint.
//
// fabric     - fabric handle used to open the event queue.
// domain     - domain handle used to open the endpoint and its completion queues.
// mode       - network mode, used only when fi_info has to be looked up.
// info       - fi_info to take ownership of; when empty, get_info() is called with
//              m_address, mode and info_flags instead.
// info_flags - flags for the get_info() lookup; unused when info is provided.
// shx_ctx    - optional shared receive context to bind (skipped when null).
// av         - optional address vector to bind (skipped when null).
//
// Returns the endpoint after it has been opened, bound and enabled. Exceptions from
// any of the setup steps propagate to the caller.
netio3::libfabric::Endpoint netio3::libfabric::ActiveEndpoint::create_endpoint(
  fid_fabric* fabric,
  fid_domain* domain,
  NetworkMode mode,
  FiInfoUniquePtr&& info,
  std::uint64_t info_flags,
  fid_ep* shx_ctx,
  fid_av* av)
{
  ZoneScoped;
  Endpoint endpoint{};

  // Prefer the caller-supplied fi_info; otherwise resolve one for m_address.
  endpoint.fi = info ? std::move(info) : get_info(m_address, mode, info_flags);

  // Setup sequence: open EP + EQ, bind AV, open/bind CQs, bind shared RX context,
  // and finally enable the endpoint.
  open_endpoint(endpoint, fabric, domain, endpoint.fi.get());
  bind_av(endpoint, av);
  open_cq(endpoint, domain);
  bind_srx_context(endpoint, shx_ctx);
  enable_endpoint(endpoint);

  return endpoint;
}

// Opens the event queue and the libfabric endpoint for `ep` and binds the two together.
//
// ep     - endpoint structure receiving the owning handles (ep.eq and ep.ep).
// fabric - fabric on which the event queue is opened.
// domain - domain on which the endpoint is opened.
// info   - fi_info describing the endpoint to open.
//
// Throws FailedOpenActiveEndpoint if opening the event queue, opening the endpoint,
// or binding the event queue to the endpoint fails.
void netio3::libfabric::ActiveEndpoint::open_endpoint(Endpoint& ep,
                                                      fid_fabric* fabric,
                                                      fid_domain* domain,
                                                      fi_info* info)
{
  ZoneScoped;

  // Event queue with FI_WAIT_FD as wait object; all other attributes are zeroed.
  fi_eq_attr event_queue_attr{.wait_obj = FI_WAIT_FD};
  fid_eq* raw_event_queue = nullptr;
  const auto eq_status = fi_eq_open(fabric, &event_queue_attr, &raw_event_queue, nullptr);
  if (eq_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to open Event Queue for active endpoint, error {} - {}",
                  eq_status,
                  fi_strerror(-eq_status)));
  }
  // Hand the raw handle to its owning pointer right away so later failures release it.
  ep.eq = FiCloseEndpointUniquePtr<fid_eq>(
    raw_event_queue, FiCloseEndpointDeleter<fid_eq>(m_address, "Failed to close Event Queue"));

  fid_ep* raw_endpoint = nullptr;
  const auto ep_status = fi_endpoint(domain, info, &raw_endpoint, nullptr);
  if (ep_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to open Endpoint for active endpoint, error {} - {}",
                  ep_status,
                  fi_strerror(-ep_status)));
  }
  ep.ep = FiCloseEndpointUniquePtr<fid_ep>(
    raw_endpoint, FiCloseEndpointDeleter<fid_ep>(m_address, "Failed to close enpoint"));

  // Associate the event queue with the endpoint.
  const auto bind_status = fi_ep_bind(ep.ep.get(), &ep.eq->fid, 0);
  if (bind_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind endpoint, error {} - {}", bind_status, fi_strerror(-bind_status)));
  }
}

// Binds an address vector to the endpoint, if one is given.
//
// ep - endpoint whose libfabric handle is bound.
// av - address vector to bind; a null pointer makes this a no-op.
//
// Throws FailedOpenActiveEndpoint if fi_ep_bind() reports an error.
void netio3::libfabric::ActiveEndpoint::bind_av(Endpoint& ep, fid_av* av) const
{
  ZoneScoped;
  if (av == nullptr) {
    return;
  }

  const auto status = fi_ep_bind(ep.ep.get(), &av->fid, 0);
  if (status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind address vector to endpoint, error {} - {}", status, fi_strerror(-status)));
  }
}

// Enables the endpoint and fetches the wait object of its event queue.
//
// ep - endpoint to enable; on success ep.eqfd holds the event queue's wait object
//      as returned by fi_control(FI_GETWAIT).
//
// Throws FailedOpenActiveEndpoint if fi_enable() or fi_control() reports an error.
void netio3::libfabric::ActiveEndpoint::enable_endpoint(Endpoint& ep)
{
  ZoneScoped;

  const auto enable_status = fi_enable(ep.ep.get());
  if (enable_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to enable endpoint for active endpoint, error {} - {}",
                  enable_status,
                  fi_strerror(-enable_status)));
  }

  // Retrieve the event queue's wait object so it can be polled by the caller.
  const auto wait_status = fi_control(&ep.eq->fid, FI_GETWAIT, &ep.eqfd);
  if (wait_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to retrive active endpoint Event Queue wait object, error {} - {}",
                  wait_status,
                  fi_strerror(-wait_status)));
  }
}

// Returns the baseline completion queue attributes shared by the transmit and
// receive completion queues; callers copy and adjust the result as needed.
fi_cq_attr netio3::libfabric::ActiveEndpoint::prepare_cq_attr()
{
  fi_cq_attr attr{};
  attr.size = MAX_CQ_ENTRIES;         // number of entries for the CQ
  attr.flags = 0;                     // operation flags
  attr.format = FI_CQ_FORMAT_DATA;    // completion format
  attr.wait_obj = FI_WAIT_FD;         // requested wait object
  attr.signaling_vector = 0;          // interrupt affinity
  attr.wait_cond = FI_CQ_COND_NONE;   // wait condition format: none
  attr.wait_set = nullptr;            // optional wait set
  return attr;
}

// Opens the transmit and receive completion queues and binds them to the endpoint.
//
// ep     - endpoint receiving the owning handles (ep.cq for FI_TRANSMIT, ep.rcq for
//          FI_RECV) and, when the matching capability is set, the wait objects
//          (ep.cqfd / ep.rcqfd).
// domain - domain on which both completion queues are opened.
//
// A queue whose direction is not enabled in m_capabilities is still opened and bound,
// but with FI_CQ_FORMAT_UNSPEC and FI_WAIT_NONE, and its wait object is not retrieved.
//
// Throws FailedOpenActiveEndpoint if opening or binding a queue, or retrieving a wait
// object, fails.
void netio3::libfabric::ActiveEndpoint::open_cq(Endpoint& ep, fid_domain* domain)
{
  ZoneScoped;
  const auto default_attr = prepare_cq_attr();
  const bool sending_enabled = m_capabilities.send_buffered or m_capabilities.send_zero_copy;
  const bool receiving_enabled = static_cast<bool>(m_capabilities.receive);

  // ---- FI_TRANSMIT CQ ----
  auto tx_attr = default_attr;
  if (not sending_enabled) {
    tx_attr.format = FI_CQ_FORMAT_UNSPEC;
    tx_attr.wait_obj = FI_WAIT_NONE;
  }

  fid_cq* raw_tx_cq = nullptr;
  const auto tx_open_status = fi_cq_open(domain, &tx_attr, &raw_tx_cq, nullptr);
  if (tx_open_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to open Completion Queue for active endpoint, error {} - {}",
                  tx_open_status,
                  fi_strerror(-tx_open_status)));
  }
  ep.cq = FiCloseEndpointUniquePtr<fid_cq>(
    raw_tx_cq,
    FiCloseEndpointDeleter<fid_cq>(m_address, "Failed to close active endpoint Completion Queue"));

  const auto tx_bind_status = fi_ep_bind(ep.ep.get(), &ep.cq->fid, FI_TRANSMIT);
  if (tx_bind_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind Completion Queue for active endpoint, error {} - {}",
                  tx_bind_status,
                  fi_strerror(-tx_bind_status)));
  }

  // A wait object only exists when the queue was opened with FI_WAIT_FD.
  if (sending_enabled) {
    const auto tx_wait_status = fi_control(&ep.cq->fid, FI_GETWAIT, &ep.cqfd);
    if (tx_wait_status != 0) {
      throw FailedOpenActiveEndpoint(
        ERS_HERE,
        m_address.address(),
        m_address.port(),
        std::format("Failed to retrieve wait object for send Completion Queue, error {} - {}",
                    tx_wait_status,
                    fi_strerror(-tx_wait_status)));
    }
  }

  // ---- FI_RECV CQ ----
  auto rx_attr = default_attr;
  if (not receiving_enabled) {
    rx_attr.format = FI_CQ_FORMAT_UNSPEC;
    rx_attr.wait_obj = FI_WAIT_NONE;
  }

  fid_cq* raw_rx_cq = nullptr;
  const auto rx_open_status = fi_cq_open(domain, &rx_attr, &raw_rx_cq, nullptr);
  if (rx_open_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to open Completion Queue for active endpoint, error {} - {}",
                  rx_open_status,
                  fi_strerror(-rx_open_status)));
  }
  ep.rcq = FiCloseEndpointUniquePtr<fid_cq>(
    raw_rx_cq,
    FiCloseEndpointDeleter<fid_cq>(m_address, "Failed to close FI_RECV active endpoint Completion Queue"));

  const auto rx_bind_status = fi_ep_bind(ep.ep.get(), &ep.rcq->fid, FI_RECV);
  if (rx_bind_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind Completion Queue for active endpoint, error {} - {}",
                  rx_bind_status,
                  fi_strerror(-rx_bind_status)));
  }

  if (receiving_enabled) {
    const auto rx_wait_status = fi_control(&ep.rcq->fid, FI_GETWAIT, &ep.rcqfd);
    if (rx_wait_status != 0) {
      throw FailedOpenActiveEndpoint(
        ERS_HERE,
        m_address.address(),
        m_address.port(),
        std::format("Failed to retrieve wait object for receive Completion Queue, error {} - {}",
                    rx_wait_status,
                    fi_strerror(-rx_wait_status)));
    }
  }
}

// Binds a shared receive context to the endpoint, if one is given.
//
// ep      - endpoint whose libfabric handle is bound.
// shx_ctx - shared receive context to bind; a null pointer makes this a no-op.
//
// Throws FailedOpenActiveEndpoint if fi_ep_bind() reports an error.
void netio3::libfabric::ActiveEndpoint::bind_srx_context(Endpoint& ep, fid_ep* shx_ctx) const
{
  ZoneScoped;
  if (shx_ctx == nullptr) {
    return;
  }

  const auto status = fi_ep_bind(ep.ep.get(), &shx_ctx->fid, 0);
  if (status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind shared receive context, error {} - {}",
                  status,
                  fi_strerror(-status)));
  }
}

// Looks up the fi_info for an endpoint address.
//
// address    - address handed to get_fi_info(); also used to describe the endpoint
//              in the error that is raised on failure.
// mode       - network mode handed to get_fi_info().
// info_flags - flags handed to get_fi_info().
//
// Returns an owning pointer wrapping the result of get_fi_info().
//
// Throws FailedOpenActiveEndpoint when get_fi_info() raises LibfabricFiInfoError; the
// original error text is embedded in the new message. Other exception types are not
// intercepted.
netio3::libfabric::FiInfoUniquePtr netio3::libfabric::ActiveEndpoint::get_info(
  const EndPointAddress& address,
  NetworkMode mode,
  std::uint64_t info_flags)
{
  ZoneScoped;
  try {
    return FiInfoUniquePtr{get_fi_info(address, mode, info_flags)};
  } catch (const LibfabricFiInfoError& e) {
    // Pass ERS_HERE as the issue context, matching every other
    // FailedOpenActiveEndpoint raised in this file.
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to get info on local interface {}", e.what()));
  }
}