Program Listing for File Helpers.cpp
↰ Return to documentation for file (BackendLibfabric/Helpers.cpp)
#include "Helpers.hpp"
#include <string>
#include <tracy/Tracy.hpp>
#include <ers/ers.h>
#include <rdma/fi_cm.h>
#include "Issues.hpp"
/// Translate a network mode into the libfabric provider name string.
///
/// @param mode Network mode to translate.
/// @return "sockets" for TCP, "verbs" for RDMA and "verbs;ofi_rxm" for RDM.
/// @throws LibFabricError if @p mode is none of the modes listed above.
std::string netio3::libfabric::get_provider(NetworkMode mode)
{
  if (mode == NetworkMode::TCP) {
    return "sockets";
  }
  if (mode == NetworkMode::RDMA) {
    return "verbs";
  }
  if (mode == NetworkMode::RDM) {
    return "verbs;ofi_rxm";
  }
  // Anything else is a mode this backend has no provider mapping for.
  throw LibFabricError(ERS_HERE, "Unknown network mode");
}
/// Select the libfabric address format constant for an endpoint address.
///
/// @param address Endpoint address to inspect.
/// @return FI_SOCKADDR_IN if `address.is_ipv4()` is true, otherwise FI_SOCKADDR_IN6.
int netio3::libfabric::get_address_format(const EndPointAddress& address)
{
  ZoneScoped;
  return address.is_ipv4() ? FI_SOCKADDR_IN : FI_SOCKADDR_IN6;
}
/// Select the libfabric endpoint type for a network mode.
///
/// @param mode Network mode to translate.
/// @return FI_EP_RDM for NetworkMode::RDM, FI_EP_MSG for every other mode.
fi_ep_type netio3::libfabric::get_ep_type(NetworkMode mode)
{
  ZoneScoped;
  const bool is_rdm = (mode == NetworkMode::RDM);
  return is_rdm ? FI_EP_RDM : FI_EP_MSG;
}
/// Query the address of the peer an endpoint is connected to.
///
/// Calls fi_getpeer() on @p ep and converts the returned socket address into
/// an EndPointAddress. Failures are reported through ers::error.
///
/// @param ep Endpoint to query.
/// @return The peer address, or a default-constructed EndPointAddress if
///         fi_getpeer() fails or the conversion throws a std::exception.
netio3::EndPointAddress netio3::libfabric::peer_address(fid_ep* ep)
{
  ZoneScoped;
  std::array<char, BUFSIZ> address{};
  auto addrlen = address.size();
  const auto ret = fi_getpeer(ep, address.data(), &addrlen);
  ERS_DEBUG(2, std::format("ret={} Peer address length={}", ret, addrlen));
  if (ret != 0) {
    // The buffer content is not a valid address after a failed call (and on
    // -FI_ETOOSMALL addrlen holds the *required* size, which exceeds the
    // buffer), so do not dump or parse it.
    ers::error(LibFabricError(ERS_HERE, std::format("fi_getpeer failed: {}", fi_strerror(-ret))));
    return {};
  }
  // Never read past the local buffer, whatever length the provider reported.
  const auto valid_len = addrlen < address.size() ? addrlen : address.size();
  for (std::size_t idx = 0; idx < valid_len; ++idx) {
    // Format as unsigned so bytes >= 0x80 are printed as 0x80..0xff.
    ERS_DEBUG(3, std::format("Peer address val={:#x}", static_cast<unsigned char>(address.at(idx))));
  }
  try {
    return netio3::EndPointAddress{reinterpret_cast<const sockaddr*>(address.data()), valid_len};
  } catch (const std::exception& e) {
    ers::error(LibFabricError(ERS_HERE, std::format("Failed to convert sockaddr: {}", e.what())));
  }
  return {};
}
/// Query the local address an endpoint is bound to.
///
/// Calls fi_getname() on the endpoint's fid and converts the returned socket
/// address into an EndPointAddress. Failures are reported through ers::error.
///
/// @param ep Endpoint to query.
/// @return The local address, or a default-constructed EndPointAddress if
///         fi_getname() fails or the conversion throws a std::exception.
[[nodiscard]] netio3::EndPointAddress netio3::libfabric::local_address(fid_ep* ep)
{
  ZoneScoped;
  std::array<char, BUFSIZ> address{};
  auto addrlen = address.size();
  const auto ret = fi_getname(&ep->fid, address.data(), &addrlen);
  ERS_DEBUG(2, std::format("ret={} Local address length={}", ret, addrlen));
  if (ret != 0) {
    // The buffer content is not a valid address after a failed call (and on
    // -FI_ETOOSMALL addrlen holds the *required* size, which exceeds the
    // buffer), so do not dump or parse it.
    ers::error(LibFabricError(ERS_HERE, std::format("fi_getname failed: {}", fi_strerror(-ret))));
    return {};
  }
  // Never read past the local buffer, whatever length the provider reported.
  const auto valid_len = addrlen < address.size() ? addrlen : address.size();
  for (std::size_t idx = 0; idx < valid_len; ++idx) {
    // Format as unsigned so bytes >= 0x80 are printed as 0x80..0xff.
    ERS_DEBUG(3, std::format("Local address val={:#x}", static_cast<unsigned char>(address.at(idx))));
  }
  try {
    return netio3::EndPointAddress{reinterpret_cast<const sockaddr*>(address.data()), valid_len};
  } catch (const std::exception& e) {
    ers::error(LibFabricError(ERS_HERE, std::format("Failed to convert sockaddr: {}", e.what())));
  }
  return {};
}
/// Build libfabric hints for an address/mode pair and resolve them with fi_getinfo().
///
/// The hints request the FI_MSG capability, the address format and endpoint
/// type derived from @p address and @p mode, FI_MR_LOCAL | FI_MR_ALLOCATED
/// memory registration, automatic data progress, enabled resource management
/// and the provider returned by get_provider().
///
/// @param address    Address whose host and port are passed to fi_getinfo() as node and service.
/// @param mode       Network mode selecting provider and endpoint type.
/// @param info_flags Flags forwarded unchanged to fi_getinfo().
/// @return The fi_info pointer produced by fi_getinfo().
///         NOTE(review): presumably the caller must release it with fi_freeinfo() - confirm.
/// @throws LibfabricFiInfoError if fi_getinfo() returns a non-zero code.
/// @throws LibFabricError if get_provider() rejects @p mode.
fi_info* netio3::libfabric::get_fi_info(const EndPointAddress& address,
                                        NetworkMode mode,
                                        std::uint64_t info_flags)
{
  ZoneScoped;
  auto hints = FiInfoWrapper{};
  auto* const raw_hints = hints.get();

  raw_hints->addr_format = get_address_format(address);
  raw_hints->ep_attr->type = get_ep_type(mode);
  raw_hints->caps = FI_MSG;

  auto* const domain = raw_hints->domain_attr;
  domain->mr_mode = FI_MR_LOCAL | FI_MR_ALLOCATED;
  domain->data_progress = FI_PROGRESS_AUTO;
  domain->resource_mgmt = FI_RM_ENABLED;

  // The hints structure takes a heap-allocated C string copy of the provider name.
  const std::string provider = get_provider(mode);
  raw_hints->fabric_attr->prov_name = strdup(provider.c_str());

  const auto& node = address.address();
  const std::string service = std::to_string(address.port());

  struct fi_info* result = nullptr;
  const auto ret = fi_getinfo(FI_VERSION(LIBFABRIC_MAJOR_VERSION, LIBFABRIC_MINOR_VERSION),
                              node.c_str(),
                              service.c_str(),
                              info_flags,
                              raw_hints,
                              &result);
  if (ret) {
    throw LibfabricFiInfoError(ret, fi_strerror(-ret));
  }
  return result;
}
/// Derive the EndpointCapabilities initializers from connection parameters.
///
/// The four values are passed in this order:
///  1. buffered send is configured (send buffers requested or shared send buffers used),
///  2. always false,
///  3. zero-copy send is configured (a memory region start is set or shared send buffers used),
///  4. receive is configured (receive buffers requested or shared receive buffers used).
///
/// @param connection_params Connection parameters to inspect.
/// @return EndpointCapabilities initialized from the four values above.
netio3::EndpointCapabilities netio3::libfabric::get_endpoint_capabilities(
  const ConnectionParameters& connection_params)
{
  const auto& buffered = connection_params.send_buffered_params;
  const bool has_buffered_send = buffered.num_buf > 0 || buffered.use_shared_send_buffers;

  const auto& zero_copy = connection_params.send_zero_copy_params;
  const bool has_zero_copy_send =
    zero_copy.mr_start != nullptr || zero_copy.use_shared_send_buffers;

  const auto& recv = connection_params.recv_params;
  const bool has_receive = recv.num_buf > 0 || recv.use_shared_receive_buffers;

  return {has_buffered_send, false, has_zero_copy_send, has_receive};
}