// Program listing for file Helpers.cpp
//
// Return to documentation for file (BackendLibfabric/Helpers.cpp)

#include "Helpers.hpp"

#include <string>
#include <tracy/Tracy.hpp>

#include <ers/ers.h>

#include <rdma/fi_cm.h>

#include "Issues.hpp"

// Translate a network mode into the provider name string that is later passed
// to libfabric as the fabric provider hint.
//
// Returns "sockets" for TCP, "verbs" for RDMA and "verbs;ofi_rxm" for RDM.
// Throws LibFabricError for any other mode value.
std::string netio3::libfabric::get_provider(NetworkMode mode)
{
  if (mode == NetworkMode::TCP) {
    return "sockets";
  }
  if (mode == NetworkMode::RDMA) {
    return "verbs";
  }
  if (mode == NetworkMode::RDM) {
    return "verbs;ofi_rxm";
  }
  // Nothing matched: the caller handed us a mode this backend does not know.
  throw LibFabricError(ERS_HERE, "Unknown network mode");
}

// Pick the libfabric address format constant matching the given address:
// FI_SOCKADDR_IN when the address reports itself as IPv4, FI_SOCKADDR_IN6
// otherwise.
int netio3::libfabric::get_address_format(const EndPointAddress& address)
{
  ZoneScoped;
  return address.is_ipv4() ? FI_SOCKADDR_IN : FI_SOCKADDR_IN6;
}

// Select the libfabric endpoint type for a network mode: FI_EP_RDM for
// NetworkMode::RDM, FI_EP_MSG for every other mode.
fi_ep_type netio3::libfabric::get_ep_type(NetworkMode mode)
{
  ZoneScoped;
  return mode == NetworkMode::RDM ? FI_EP_RDM : FI_EP_MSG;
}

// Query the address of the peer connected to `ep` via fi_getpeer and convert
// it into an EndPointAddress.
//
// Failures are reported through ers::error instead of being thrown: if
// fi_getpeer returns non-zero, or the sockaddr conversion throws a
// std::exception, a default-constructed EndPointAddress is returned.
netio3::EndPointAddress netio3::libfabric::peer_address(fid_ep* ep)
{
  ZoneScoped;
  std::array<char, BUFSIZ> address{};
  auto addrlen = address.size();
  const auto ret = fi_getpeer(ep, address.data(), &addrlen);
  ERS_DEBUG(2, std::format("ret={} Peer address length={}", ret, addrlen));
  // libfabric may report a required length larger than the buffer (truncated
  // address, -FI_ETOOSMALL). Clamp the dump so address.at() cannot throw
  // std::out_of_range outside of any handler.
  const auto dump_len = addrlen < address.size() ? addrlen : address.size();
  for (size_t ch = 0; ch < dump_len; ++ch) {
    ERS_DEBUG(3, std::format("Peer address val={:#x}", address.at(ch)));
  }
  if (ret != 0) {
    ers::error(LibFabricError(ERS_HERE, std::format("fi_getpeer failed: {}", fi_strerror(-ret))));
    // The buffer does not hold a valid address (and addrlen may exceed the
    // buffer), so do not attempt to interpret it.
    return {};
  }
  try {
    return netio3::EndPointAddress{reinterpret_cast<const sockaddr*>(address.data()), addrlen};
  } catch (const std::exception& e) {
    ers::error(LibFabricError(ERS_HERE, std::format("Failed to convert sockaddr: {}", e.what())));
  }
  return {};
}

// Query the local address bound to `ep` via fi_getname and convert it into an
// EndPointAddress.
//
// Failures are reported through ers::error instead of being thrown: if
// fi_getname returns non-zero, or the sockaddr conversion throws a
// std::exception, a default-constructed EndPointAddress is returned.
[[nodiscard]] netio3::EndPointAddress netio3::libfabric::local_address(fid_ep* ep)
{
  ZoneScoped;
  std::array<char, BUFSIZ> address{};
  auto addrlen = address.size();
  const auto ret = fi_getname(&ep->fid, address.data(), &addrlen);
  ERS_DEBUG(2, std::format("ret={} Local address length={}", ret, addrlen));
  // libfabric may report a required length larger than the buffer (truncated
  // address, -FI_ETOOSMALL). Clamp the dump so address.at() cannot throw
  // std::out_of_range outside of any handler.
  const auto dump_len = addrlen < address.size() ? addrlen : address.size();
  for (size_t ch = 0; ch < dump_len; ++ch) {
    ERS_DEBUG(3, std::format("Local address val={:#x}", address.at(ch)));
  }
  if (ret != 0) {
    ers::error(LibFabricError(ERS_HERE, std::format("fi_getname failed: {}", fi_strerror(-ret))));
    // The buffer does not hold a valid address (and addrlen may exceed the
    // buffer), so do not attempt to interpret it.
    return {};
  }
  try {
    return netio3::EndPointAddress{reinterpret_cast<const sockaddr*>(address.data()), addrlen};
  } catch (const std::exception& e) {
    ers::error(LibFabricError(ERS_HERE, std::format("Failed to convert sockaddr: {}", e.what())));
  }
  return {};
}

// Run fi_getinfo for the given address and network mode.
//
// A hints structure is filled with the address format, endpoint type and
// provider derived from `address`/`mode`, plus fixed capability and domain
// settings, and is handed to fi_getinfo together with `info_flags`.
//
// Returns the fi_info pointer produced by fi_getinfo.
// Throws LibfabricFiInfoError when fi_getinfo returns a non-zero code.
fi_info* netio3::libfabric::get_fi_info(const EndPointAddress& address,
                                        NetworkMode mode,
                                        std::uint64_t info_flags)
{
  ZoneScoped;
  auto hints = FiInfoWrapper{};
  const auto wanted = hints.get();

  // Settings derived from the arguments.
  wanted->addr_format = get_address_format(address);
  wanted->ep_attr->type = get_ep_type(mode);

  // Fixed requirements of this backend.
  wanted->caps = FI_MSG;
  wanted->domain_attr->mr_mode = FI_MR_LOCAL | FI_MR_ALLOCATED;
  wanted->domain_attr->data_progress = FI_PROGRESS_AUTO;
  wanted->domain_attr->resource_mgmt = FI_RM_ENABLED;

  // prov_name receives a heap copy of the provider string.
  // NOTE(review): presumably released when the hints wrapper is destroyed — confirm.
  wanted->fabric_attr->prov_name = strdup(get_provider(mode).c_str());

  const auto& node = address.address();
  const auto service = std::to_string(address.port());

  struct fi_info* result = nullptr;
  const auto rc = fi_getinfo(FI_VERSION(LIBFABRIC_MAJOR_VERSION, LIBFABRIC_MINOR_VERSION),
                             node.c_str(),
                             service.c_str(),
                             info_flags,
                             wanted,
                             &result);
  if (rc != 0) {
    throw LibfabricFiInfoError(rc, fi_strerror(-rc));
  }
  return result;
}

// Derive the endpoint capability flags from the connection parameters.
//
// The result is initialised positionally with four values:
//   1. true when buffered-send buffers are configured (num_buf > 0) or shared
//      send buffers are requested in the buffered-send parameters,
//   2. always false,
//   3. true when a zero-copy memory region start is set or shared send buffers
//      are requested in the zero-copy-send parameters,
//   4. true when receive buffers are configured (num_buf > 0) or shared
//      receive buffers are requested.
// NOTE(review): the meaning of each position is defined by
// EndpointCapabilities, which is declared elsewhere — confirm against it.
netio3::EndpointCapabilities netio3::libfabric::get_endpoint_capabilities(
  const ConnectionParameters& connection_params)
{
  const auto& buffered = connection_params.send_buffered_params;
  const auto& zero_copy = connection_params.send_zero_copy_params;
  const auto& receive = connection_params.recv_params;

  const bool from_buffered = buffered.num_buf > 0 or buffered.use_shared_send_buffers;
  const bool from_zero_copy = zero_copy.mr_start != nullptr or zero_copy.use_shared_send_buffers;
  const bool from_receive = receive.num_buf > 0 or receive.use_shared_receive_buffers;

  return {from_buffered, false, from_zero_copy, from_receive};
}