Program Listing for File BaseSocket.cpp

Return to documentation for file (BackendLibfabric/BaseSocket.cpp)

#include "BaseSocket.hpp"

#include <memory>

#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>

#include <tracy/Tracy.hpp>

#include "Buffer.hpp"
#include "Helpers.hpp"
#include "Issues.hpp"
#include "netio3-backend/Issues.hpp"

// Opens the libfabric objects backing a socket endpoint and stores them in `ep`:
//   1. an Event Queue (fi_eq_open) configured with an FI_WAIT_FD wait object,
//   2. an endpoint (fi_endpoint) created from `domain` and `info`,
//   3. the binding of that Event Queue to the endpoint (fi_ep_bind).
// Each raw handle is wrapped in a FiCloseUniquePtr as soon as it is obtained, so it is
// already owned by `ep` if a later step throws.
// Throws FailedOpenEndpoint (carrying the address and port) if any step returns non-zero.
void netio3::libfabric::BaseSocket::open_endpoint(const EndPointAddress& address,
                                                  Endpoint& ep,
                                                  fid_fabric* fabric,
                                                  fid_domain* domain,
                                                  fi_info* info)
{
  ZoneScoped;

  // Step 1: Event Queue. All other attributes stay zero (provider defaults).
  fi_eq_attr queue_attr{};
  queue_attr.wait_obj = FI_WAIT_FD;
  fid_eq* raw_eq = nullptr;
  const int eq_rc = fi_eq_open(fabric, &queue_attr, &raw_eq, nullptr);
  if (eq_rc != 0) {
    throw FailedOpenEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format(
        "Failed to open Event Queue for send socket, error {} - {}", eq_rc, fi_strerror(-eq_rc)));
  }
  ep.eq = FiCloseUniquePtr<fid_eq>(
    raw_eq, FiCloseDeleter<fid_eq>(address, "Failed to close socket Event Queue"));

  // Step 2: endpoint.
  fid_ep* raw_ep = nullptr;
  const int ep_rc = fi_endpoint(domain, info, &raw_ep, nullptr);
  if (ep_rc != 0) {
    throw FailedOpenEndpoint(ERS_HERE,
                             address.address(),
                             address.port(),
                             std::format("Failed to open Endpoint for send socket, error {} - {}",
                                         ep_rc,
                                         fi_strerror(-ep_rc)));
  }
  ep.ep = FiCloseUniquePtr<fid_ep>(
    raw_ep, FiCloseDeleter<fid_ep>(address, "Failed to close socket endpoint"));

  // Step 3: attach the Event Queue to the endpoint (no flags).
  const int bind_rc = fi_ep_bind(ep.ep.get(), &ep.eq->fid, 0);
  if (bind_rc != 0) {
    throw FailedOpenEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind endpoint, error {} - {}", bind_rc, fi_strerror(-bind_rc)));
  }
}

// Enables the endpoint held in `ep` (fi_enable). It then fetches the wait object of the
// endpoint's Event Queue into `ep.eqfd` via fi_control(FI_GETWAIT).
// NOTE(review): ep.eq / ep.ep are presumably populated by open_endpoint beforehand;
// this function does not check them for null.
// Throws FailedOpenEndpoint (carrying the address and port) if either call returns non-zero.
void netio3::libfabric::BaseSocket::enable_endpoint(const EndPointAddress& address, Endpoint& ep)
{
  ZoneScoped;

  const int enable_rc = fi_enable(ep.ep.get());
  if (enable_rc != 0) {
    throw FailedOpenEndpoint(ERS_HERE,
                             address.address(),
                             address.port(),
                             std::format("Failed to enable endpoint for send socket, error {} - {}",
                                         enable_rc,
                                         fi_strerror(-enable_rc)));
  }

  // The wait object is written straight into ep.eqfd. It is expected to be a file
  // descriptor when the EQ was opened with FI_WAIT_FD.
  const int wait_rc = fi_control(&ep.eq->fid, FI_GETWAIT, &ep.eqfd);
  if (wait_rc != 0) {
    throw FailedOpenEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to retrive send socket Event Queue wait object, error {} - {}",
                  wait_rc,
                  fi_strerror(-wait_rc)));
  }
}

// Builds the Completion Queue attributes used by the sockets: MAX_CQ_ENTRIES entries,
// FI_CQ_FORMAT_DATA completions, and an FI_WAIT_FD wait object with no wait condition.
fi_cq_attr netio3::libfabric::BaseSocket::prepare_cq_attr()
{
  fi_cq_attr attr{};                 // zero every field first, then set each one explicitly
  attr.size = MAX_CQ_ENTRIES;        // number of CQ entries
  attr.flags = 0;                    // no operation flags
  attr.format = FI_CQ_FORMAT_DATA;   // completions are reported as fi_cq_data_entry
  attr.wait_obj = FI_WAIT_FD;        // wait object is a file descriptor
  attr.signaling_vector = 0;         // interrupt affinity
  attr.wait_cond = FI_CQ_COND_NONE;  // no threshold condition attached to waits
  attr.wait_set = nullptr;           // not part of a wait set
  return attr;
}

// Registers `buf`'s memory with the libfabric domain owned by `domain` (fi_mr_reg).
// The resulting memory-region handle is written to `buf.mr`.
// `access_flag` is forwarded unchanged as the fi_mr_reg access mask.
// Throws LibFabricBufferError if registration returns non-zero.
template<netio3::libfabric::BufferConcept BufferType>
void netio3::libfabric::BaseSocket::register_buffer(BufferType& buf,
                                                    DomainContext& domain,
                                                    const int access_flag)
{
  ZoneScoped;
  ERS_DEBUG(1,
            std::format("Registering buffer of size {} for domain: {}",
                        buf.get_size(),
                        static_cast<void*>(domain.get_domain())));

  // fi_mr_reg(domain, buf, len, access, offset, requested_key, flags, mr, context)
  if (const int rc = fi_mr_reg(domain.get_domain(),
                               buf.get_buffer(),
                               buf.get_size(),
                               access_flag,
                               0,  // offset
                               0,  // requested_key
                               0,  // flags
                               &buf.mr,
                               nullptr);
      rc != 0) {
    throw LibFabricBufferError(
      ERS_HERE,
      std::format("Failed to register buffer failed. Error {} - {}", rc, fi_strerror(-rc)));
  }
}

// Releases the memory region previously registered for `buffer` (fi_close on buffer.mr).
// A buffer with no memory region (buffer.mr == nullptr) is a no-op. Forming
// &buffer.mr->fid from a null pointer and passing it to fi_close would be undefined behaviour.
// NOTE(review): assumes Buffer/HeaderBuffer null-initialize `mr` until register_buffer
// succeeds — TODO confirm in Buffer.hpp. `buffer` is const, so `mr` is not reset here;
// callers must not close the same buffer twice.
// Throws LibFabricBufferError if fi_close returns non-zero.
template<netio3::libfabric::BufferConcept BufferType>
void netio3::libfabric::BaseSocket::close_buffer(const BufferType& buffer)
{
  ZoneScoped;
  if (buffer.mr == nullptr) {
    return;
  }
  const auto ret = fi_close(&buffer.mr->fid);
  if (ret != 0) {
    throw LibFabricBufferError(
      ERS_HERE,
      std::format("Failed to close buffer failed. Error {} - {}", ret, fi_strerror(-ret)));
  }
}

template void netio3::libfabric::BaseSocket::register_buffer<netio3::libfabric::Buffer>(
  netio3::libfabric::Buffer&,
  DomainContext&,
  int);
template void netio3::libfabric::BaseSocket::register_buffer<netio3::libfabric::HeaderBuffer>(
  netio3::libfabric::HeaderBuffer&,
  DomainContext&,
  int);

template void netio3::libfabric::BaseSocket::close_buffer<netio3::libfabric::Buffer>(
  const netio3::libfabric::Buffer&);
template void netio3::libfabric::BaseSocket::close_buffer<netio3::libfabric::HeaderBuffer>(
  const netio3::libfabric::HeaderBuffer&);