Program Listing for File Helpers.cpp
↰ Return to documentation for file (BackendLibfabric/Helpers.cpp)
#include "Helpers.hpp"
#include <cstring>
#include <optional>
#include <string>
#include <tracy/Tracy.hpp>
#include <ers/ers.h>
#include <rdma/fi_cm.h>
#include "Issues.hpp"
/**
 * @brief Map a network mode to the libfabric provider name requested for it.
 *
 * @param mode Network mode to translate
 * @return Provider name string for the given mode
 * @throws LibFabricError if the mode is not one of the handled values
 */
std::string netio3::libfabric::get_provider(NetworkMode mode)
{
  if (mode == NetworkMode::TCP) {
    return "sockets";
  }
  if (mode == NetworkMode::RDMA) {
    return "verbs";
  }
  if (mode == NetworkMode::RDM) {
    return "verbs;ofi_rxm";
  }
  // Any other enumerator value has no provider mapping.
  throw LibFabricError(ERS_HERE, "Unknown network mode");
}
/**
 * @brief Select the libfabric address format matching an endpoint address.
 *
 * @param address Address whose IP family is inspected
 * @return FI_SOCKADDR_IN if the address reports IPv4, FI_SOCKADDR_IN6 otherwise
 */
int netio3::libfabric::get_address_format(const EndPointAddress& address)
{
  ZoneScoped;
  return address.is_ipv4() ? FI_SOCKADDR_IN : FI_SOCKADDR_IN6;
}
/**
 * @brief Select the libfabric endpoint type for a network mode.
 *
 * @param mode Network mode to translate
 * @return FI_EP_RDM for NetworkMode::RDM, FI_EP_MSG for every other mode
 */
fi_ep_type netio3::libfabric::get_ep_type(NetworkMode mode)
{
  ZoneScoped;
  return (mode == NetworkMode::RDM) ? FI_EP_RDM : FI_EP_MSG;
}
/**
 * @brief Query the address of the peer an endpoint is connected to.
 *
 * Calls fi_getpeer() into a stack buffer and converts the returned socket address into an
 * EndPointAddress. Failures are reported through ers::error and are not propagated as
 * exceptions.
 *
 * @param ep Endpoint to query
 * @return The peer address, or a default-constructed EndPointAddress if fi_getpeer() failed or
 *         the returned bytes could not be converted
 */
netio3::EndPointAddress netio3::libfabric::peer_address(fid_ep* ep)
{
  ZoneScoped;
  std::array<char, BUFSIZ> address{};
  auto addrlen = address.size();
  const auto ret = fi_getpeer(ep, address.data(), &addrlen);
  ERS_DEBUG(2, std::format("ret={} Peer address length={}", ret, addrlen));
  if (ret != 0) {
    // Nothing useful is in the buffer after a failure; converting it would only produce a
    // second, misleading error. On -FI_ETOOSMALL addrlen may also exceed the buffer size.
    ers::error(LibFabricError(ERS_HERE, std::format("fi_getpeer failed: {}", fi_strerror(-ret))));
    return {};
  }
  // Defensive clamp so the debug dump can never index past the buffer.
  const auto dump_len = addrlen < address.size() ? addrlen : address.size();
  for (size_t ch = 0; ch < dump_len; ++ch) {
    ERS_DEBUG(3, std::format("Peer address val={:#x}", address[ch]));
  }
  try {
    return netio3::EndPointAddress{reinterpret_cast<const sockaddr*>(address.data()), addrlen};
  } catch (const std::exception& e) {
    ers::error(LibFabricError(ERS_HERE, std::format("Failed to convert sockaddr: {}", e.what())));
  }
  return {};
}
/**
 * @brief Query the local address an endpoint is bound to.
 *
 * Calls fi_getname() into a stack buffer and converts the returned socket address into an
 * EndPointAddress. Failures are reported through ers::error and are not propagated as
 * exceptions.
 *
 * @param ep Endpoint to query
 * @return The local address, or a default-constructed EndPointAddress if fi_getname() failed or
 *         the returned bytes could not be converted
 */
[[nodiscard]] netio3::EndPointAddress netio3::libfabric::local_address(fid_ep* ep)
{
  ZoneScoped;
  std::array<char, BUFSIZ> address{};
  auto addrlen = address.size();
  const auto ret = fi_getname(&ep->fid, address.data(), &addrlen);
  ERS_DEBUG(2, std::format("ret={} Local address length={}", ret, addrlen));
  if (ret != 0) {
    // Nothing useful is in the buffer after a failure; converting it would only produce a
    // second, misleading error. On -FI_ETOOSMALL addrlen may also exceed the buffer size.
    ers::error(LibFabricError(ERS_HERE, std::format("fi_getname failed: {}", fi_strerror(-ret))));
    return {};
  }
  // Defensive clamp so the debug dump can never index past the buffer.
  const auto dump_len = addrlen < address.size() ? addrlen : address.size();
  for (size_t ch = 0; ch < dump_len; ++ch) {
    ERS_DEBUG(3, std::format("Local address val={:#x}", address[ch]));
  }
  try {
    return netio3::EndPointAddress{reinterpret_cast<const sockaddr*>(address.data()), addrlen};
  } catch (const std::exception& e) {
    ers::error(LibFabricError(ERS_HERE, std::format("Failed to convert sockaddr: {}", e.what())));
  }
  return {};
}
namespace {
/**
 * @brief Build the fi_getinfo() hints used by both the active and the passive lookup.
 *
 * @param addr Address whose IP family selects the hinted address format
 * @param mode Network mode selecting the endpoint type and the provider name
 * @return Wrapper holding the populated hints
 */
netio3::libfabric::FiInfoWrapper make_hints(const netio3::EndPointAddress& addr,
                                            netio3::NetworkMode mode)
{
  netio3::libfabric::FiInfoWrapper hints{};
  auto& info = *hints.get();

  info.addr_format = netio3::libfabric::get_address_format(addr);
  info.ep_attr->type = netio3::libfabric::get_ep_type(mode);
  info.caps = FI_MSG;

  auto& domain = *info.domain_attr;
  domain.mr_mode = FI_MR_LOCAL | FI_MR_ALLOCATED;
  domain.data_progress = FI_PROGRESS_AUTO;
  domain.resource_mgmt = FI_RM_ENABLED;

  // The provider name is duplicated with strdup() so the hints hold their own C string.
  // NOTE(review): presumably released together with the hints by the wrapper — confirm.
  const std::string provider = netio3::libfabric::get_provider(mode);
  info.fabric_attr->prov_name = strdup(provider.c_str());
  return hints;
}
} // namespace
/**
 * @brief Resolve fabric information for an actively connecting endpoint.
 *
 * Builds hints from the remote address and network mode, optionally attaches a source address,
 * and calls fi_getinfo() with the remote address and port as node and service.
 *
 * @param remote_addr Address and port of the remote endpoint
 * @param mode Network mode selecting provider and endpoint type
 * @param local_addr Optional local address copied into the hints as source address
 * @return The fi_info pointer produced by fi_getinfo()
 * @throws LibfabricFiInfoError if fi_getinfo() returns a non-zero code
 */
fi_info* netio3::libfabric::get_fi_info_active(const EndPointAddress& remote_addr,
                                               NetworkMode mode,
                                               std::optional<EndPointAddress> local_addr)
{
  ZoneScoped;
  auto hints = make_hints(remote_addr, mode);
  if (local_addr.has_value()) {
    const auto src_storage = local_addr->to_sockaddr_storage();
    const std::size_t addrlen =
      local_addr->is_ipv4() ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
    // fi_freeinfo() calls free() on src_addr, so it must be malloc'd — new[] would be UB here.
    hints.get()->src_addr = std::malloc(addrlen);
    if (hints.get()->src_addr != nullptr) {
      std::memcpy(hints.get()->src_addr, &src_storage, addrlen);
      hints.get()->src_addrlen = addrlen;
    } else {
      // Do not drop the requested source address silently: the lookup still proceeds, but
      // without the local address, so make the degradation visible.
      ers::error(LibFabricError(
        ERS_HERE,
        "Failed to allocate source address for fi_getinfo hints; local address is not applied"));
    }
  }
  fi_info* info = nullptr;
  const std::string node = remote_addr.address();
  const std::string service = std::to_string(remote_addr.port());
  const int ret = fi_getinfo(FI_VERSION(LIBFABRIC_MAJOR_VERSION, LIBFABRIC_MINOR_VERSION),
                             node.c_str(),
                             service.c_str(),
                             0,
                             hints.get(),
                             &info);
  if (ret != 0) {
    throw LibfabricFiInfoError(ret, fi_strerror(-ret));
  }
  return info;
}
/**
 * @brief Resolve fabric information for a passive (listening) endpoint.
 *
 * Builds hints from the local address and network mode and calls fi_getinfo() with the
 * FI_SOURCE flag, passing the local address and port as node and service.
 *
 * @param local_addr Address and port to listen on
 * @param mode Network mode selecting provider and endpoint type
 * @return The fi_info pointer produced by fi_getinfo()
 * @throws LibfabricFiInfoError if fi_getinfo() returns a non-zero code
 */
fi_info* netio3::libfabric::get_fi_info_passive(const EndPointAddress& local_addr,
                                                NetworkMode mode)
{
  ZoneScoped;
  auto hints = make_hints(local_addr, mode);

  const auto node = local_addr.address();
  const auto service = std::to_string(local_addr.port());

  fi_info* info = nullptr;
  const int ret = fi_getinfo(FI_VERSION(LIBFABRIC_MAJOR_VERSION, LIBFABRIC_MINOR_VERSION),
                             node.c_str(),
                             service.c_str(),
                             FI_SOURCE,
                             hints.get(),
                             &info);
  if (ret != 0) {
    throw LibfabricFiInfoError(ret, fi_strerror(-ret));
  }
  return info;
}
/**
 * @brief Derive the endpoint capabilities implied by a set of connection parameters.
 *
 * Each flag is true when the corresponding parameter group either provides its own resources
 * (buffers or a memory region) or requests shared buffers. The second field of the result is
 * always false.
 *
 * @param connection_params Connection parameters to inspect
 * @return Capabilities initialised, in order, from: buffered send, `false`, zero-copy send,
 *         receive
 */
netio3::EndpointCapabilities netio3::libfabric::get_endpoint_capabilities(
  const ConnectionParameters& connection_params)
{
  const auto& buffered = connection_params.send_buffered_params;
  const auto& zero_copy = connection_params.send_zero_copy_params;
  const auto& recv = connection_params.recv_params;

  const bool has_buffered_send = buffered.num_buf > 0 || buffered.use_shared_send_buffers;
  const bool has_zero_copy_send =
    zero_copy.mr_start != nullptr || zero_copy.use_shared_send_buffers;
  const bool has_receive = recv.num_buf > 0 || recv.use_shared_receive_buffers;

  // NOTE(review): the meaning of the second, hard-coded field is not visible here — see the
  // EndpointCapabilities declaration.
  return {has_buffered_send, false, has_zero_copy_send, has_receive};
}