Program Listing for File Helpers.cpp

Return to documentation for file (BackendLibfabric/Helpers.cpp)

#include "Helpers.hpp"

#include <cstring>
#include <optional>
#include <string>
#include <tracy/Tracy.hpp>

#include <ers/ers.h>

#include <rdma/fi_cm.h>

#include "Issues.hpp"

std::string netio3::libfabric::get_provider(NetworkMode mode)
{
  switch (mode) {
  case NetworkMode::TCP:
    return "sockets";
  case NetworkMode::RDMA:
    return "verbs";
  case NetworkMode::RDM:
    return "verbs;ofi_rxm";
  default:
    throw LibFabricError(ERS_HERE, "Unknown network mode");
  }
}

int netio3::libfabric::get_address_format(const EndPointAddress& address)
{
  ZoneScoped;
  if (address.is_ipv4()) {
    return FI_SOCKADDR_IN;
  } else {
    return FI_SOCKADDR_IN6;
  }
}

fi_ep_type netio3::libfabric::get_ep_type(NetworkMode mode)
{
  ZoneScoped;
  if (mode == NetworkMode::RDM) {
    return FI_EP_RDM;
  }
  return FI_EP_MSG;
}

netio3::EndPointAddress netio3::libfabric::peer_address(fid_ep* ep)
{
  ZoneScoped;
  std::array<char, BUFSIZ> address{};
  auto addrlen = address.size();
  const auto ret = fi_getpeer(ep, address.data(), &addrlen);
  ERS_DEBUG(2, std::format("ret={} Peer address length={}", ret, addrlen));
  for (size_t ch = 0; ch < addrlen; ++ch) {
    ERS_DEBUG(3, std::format("Peer address val={:#x}", address.at(ch)));
  }
  if (ret != 0) {
    ers::error(LibFabricError(ERS_HERE, std::format("fi_getpeer failed: {}", fi_strerror(-ret))));
  }
  try {
    return netio3::EndPointAddress{reinterpret_cast<const sockaddr*>(address.data()), addrlen};
  } catch (const std::exception& e) {
    ers::error(LibFabricError(ERS_HERE, std::format("Failed to convert sockaddr: {}", e.what())));
  }
  return {};
}

[[nodiscard]] netio3::EndPointAddress netio3::libfabric::local_address(fid_ep* ep)
{
  ZoneScoped;
  std::array<char, BUFSIZ> address{};
  auto addrlen = address.size();
  const auto ret = fi_getname(&ep->fid, address.data(), &addrlen);
  ERS_DEBUG(2, std::format("ret={} Local address length={}", ret, addrlen));
  for (size_t ch = 0; ch < addrlen; ++ch) {
    ERS_DEBUG(3, std::format("Local address val={:#x}", address.at(ch)));
  }
  if (ret != 0) {
    ers::error(LibFabricError(ERS_HERE, std::format("fi_getname failed: {}", fi_strerror(-ret))));
  }
  try {
    return netio3::EndPointAddress{reinterpret_cast<const sockaddr*>(address.data()), addrlen};
  } catch (const std::exception& e) {
    ers::error(LibFabricError(ERS_HERE, std::format("Failed to convert sockaddr: {}", e.what())));
  }
  return {};
}

namespace {
  netio3::libfabric::FiInfoWrapper make_hints(const netio3::EndPointAddress& addr,
                                              netio3::NetworkMode mode)
  {
    auto hints = netio3::libfabric::FiInfoWrapper{};
    hints.get()->addr_format = netio3::libfabric::get_address_format(addr);
    hints.get()->ep_attr->type = netio3::libfabric::get_ep_type(mode);
    hints.get()->caps = FI_MSG;
    hints.get()->domain_attr->mr_mode = FI_MR_LOCAL | FI_MR_ALLOCATED;
    hints.get()->domain_attr->data_progress = FI_PROGRESS_AUTO;
    hints.get()->domain_attr->resource_mgmt = FI_RM_ENABLED;
    hints.get()->fabric_attr->prov_name = strdup(netio3::libfabric::get_provider(mode).c_str());
    return hints;
  }
}  // namespace

fi_info* netio3::libfabric::get_fi_info_active(const EndPointAddress& remote_addr,
                                               NetworkMode mode,
                                               std::optional<EndPointAddress> local_addr)
{
  ZoneScoped;
  auto hints = make_hints(remote_addr, mode);

  if (local_addr.has_value()) {
    const auto src_storage = local_addr->to_sockaddr_storage();
    const std::size_t addrlen =
      local_addr->is_ipv4() ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
    // fi_freeinfo() calls free() on src_addr, so it must be malloc'd — new[] would be UB here.
    hints.get()->src_addr = std::malloc(addrlen);
    if (hints.get()->src_addr != nullptr) {
      std::memcpy(hints.get()->src_addr, &src_storage, addrlen);
      hints.get()->src_addrlen = addrlen;
    }
  }

  fi_info* fi_info = nullptr;
  if (auto ret = fi_getinfo(FI_VERSION(LIBFABRIC_MAJOR_VERSION, LIBFABRIC_MINOR_VERSION),
                            remote_addr.address().c_str(),
                            std::to_string(remote_addr.port()).c_str(),
                            0,
                            hints.get(),
                            &fi_info)) {
    throw LibfabricFiInfoError(ret, fi_strerror(-ret));
  }
  return fi_info;
}

fi_info* netio3::libfabric::get_fi_info_passive(const EndPointAddress& local_addr,
                                                NetworkMode mode)
{
  ZoneScoped;
  auto hints = make_hints(local_addr, mode);

  fi_info* fi_info = nullptr;
  if (auto ret = fi_getinfo(FI_VERSION(LIBFABRIC_MAJOR_VERSION, LIBFABRIC_MINOR_VERSION),
                            local_addr.address().c_str(),
                            std::to_string(local_addr.port()).c_str(),
                            FI_SOURCE,
                            hints.get(),
                            &fi_info)) {
    throw LibfabricFiInfoError(ret, fi_strerror(-ret));
  }
  return fi_info;
}

netio3::EndpointCapabilities netio3::libfabric::get_endpoint_capabilities(
  const ConnectionParameters& connection_params)
{
  return {connection_params.send_buffered_params.num_buf > 0 or
            connection_params.send_buffered_params.use_shared_send_buffers,
          false,
          connection_params.send_zero_copy_params.mr_start != nullptr or
            connection_params.send_zero_copy_params.use_shared_send_buffers,
          connection_params.recv_params.num_buf > 0 or
            connection_params.recv_params.use_shared_receive_buffers};
}