// Program listing for file ActiveEndpoint.cpp
// (documentation for this file: BackendLibfabric/ActiveEndpoint.cpp)
#include "ActiveEndpoint.hpp"
#include <memory>
#include <rdma/fi_cm.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>
#include <tracy/Tracy.hpp>
#include "Issues.hpp"
#include "Helpers.hpp"
#include "netio3-backend/Issues.hpp"
#include "netio3-backend/Netio3Backend.hpp"
// Construct an active endpoint from an fi_info supplied by the caller.
//
// Stores the address and capabilities, then builds the libfabric endpoint through
// create_endpoint(), handing over `info` and passing 0 as the info flags.
//
// NOTE(review): create_endpoint() reads m_address and m_capabilities, so this relies on both
// members being declared (and therefore initialised) before m_ep - confirm against the header.
//
// @param address       Address stored in m_address.
// @param mode          Network mode forwarded to create_endpoint().
// @param capabilities  Capabilities stored in m_capabilities.
// @param fabric        Fabric forwarded to create_endpoint().
// @param domain        Domain forwarded to create_endpoint().
// @param info          fi_info moved into create_endpoint().
// @param shx_ctx       Shared receive context forwarded to create_endpoint().
// @param av            Address vector forwarded to create_endpoint().
netio3::libfabric::ActiveEndpoint::ActiveEndpoint(EndPointAddress address,
NetworkMode mode,
EndpointCapabilities capabilities,
fid_fabric* fabric,
fid_domain* domain,
FiInfoUniquePtr&& info,
fid_ep* shx_ctx,
fid_av* av) :
m_address{std::move(address)},
m_capabilities{std::move(capabilities)},
m_ep{create_endpoint(fabric, domain, mode, std::move(info), 0, shx_ctx, av)}
{}
// Construct an active endpoint without a pre-existing fi_info.
//
// Stores the address and capabilities, then builds the libfabric endpoint through
// create_endpoint() with a null info, which makes create_endpoint() obtain the fi_info
// itself via get_info(m_address, mode, info_flags).
//
// NOTE(review): create_endpoint() reads m_address and m_capabilities, so this relies on both
// members being declared (and therefore initialised) before m_ep - confirm against the header.
//
// @param address       Address stored in m_address.
// @param mode          Network mode forwarded to create_endpoint().
// @param capabilities  Capabilities stored in m_capabilities.
// @param fabric        Fabric forwarded to create_endpoint().
// @param domain        Domain forwarded to create_endpoint().
// @param info_flags    Flags forwarded to create_endpoint() for the fi_info lookup.
// @param shx_ctx       Shared receive context forwarded to create_endpoint().
// @param av            Address vector forwarded to create_endpoint().
netio3::libfabric::ActiveEndpoint::ActiveEndpoint(EndPointAddress address,
NetworkMode mode,
EndpointCapabilities capabilities,
fid_fabric* fabric,
fid_domain* domain,
std::uint64_t info_flags,
fid_ep* shx_ctx,
fid_av* av) :
m_address{std::move(address)},
m_capabilities{std::move(capabilities)},
m_ep{create_endpoint(fabric, domain, mode, nullptr, info_flags, shx_ctx, av)}
{}
// Finish establishing the connection for this endpoint.
//
// ConnectionMode::Connect: calls fi_connect() towards the destination address stored in the
// endpoint's fi_info.
// Any other mode: calls fi_accept(); if that fails, the pending request is rejected on `pep`
// (when one was supplied) before the failure is reported.
//
// @param mode  Selects between connecting and accepting.
// @param pep   Passive endpoint used to reject the request when accepting fails; may be null,
//              in which case the rejection is skipped.
// @throws FailedOpenActiveEndpoint if fi_connect() or fi_accept() returns a non-zero code.
void netio3::libfabric::ActiveEndpoint::complete_connection(ConnectionMode mode, fid_pep* pep)
{
  ZoneScoped;
  if (mode == ConnectionMode::Connect) {
    // Connect to remote passive endpoint
    if (const auto ret = fi_connect(m_ep.ep.get(), m_ep.fi->dest_addr, nullptr, 0)) {
      throw FailedOpenActiveEndpoint(
        ERS_HERE,
        m_address.address(),
        m_address.port(),
        std::format("Connection to remote failed, error {} - {}", ret, fi_strerror(-ret)));
    }
    return;
  }

  // Accept connection from local passive endpoint
  if (const auto ret = fi_accept(m_ep.ep.get(), nullptr, 0)) {
    // Best-effort rejection (its result is intentionally not checked). fi_reject() needs a
    // valid passive endpoint, so skip it rather than crash when none was provided.
    if (pep != nullptr) {
      fi_reject(pep, m_ep.fi->handle, nullptr, 0);
    }
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Listen endpoint, connection rejected, error {} - {}", ret, fi_strerror(-ret)));
  }
}
// Refresh the cached addresses from the underlying endpoint: m_address takes the result of
// peer_address() and m_local_address the result of local_address().
void netio3::libfabric::ActiveEndpoint::update_addresses()
{
  ZoneScoped;
  auto* const endpoint = m_ep.ep.get();
  m_address = peer_address(endpoint);
  m_local_address = local_address(endpoint);
}
// Build a fully set-up Endpoint.
//
// Uses `info` when it is non-null; otherwise looks the fi_info up via
// get_info(m_address, mode, info_flags). The endpoint is then opened, bound to the address
// vector, given its completion queues, bound to the shared receive context and enabled,
// in that order.
//
// @param fabric      Fabric the event queue is opened on.
// @param domain      Domain the endpoint and completion queues are opened on.
// @param mode        Network mode used for the fi_info lookup.
// @param info        Optional fi_info to take over; may be null.
// @param info_flags  Flags used for the fi_info lookup when `info` is null.
// @param shx_ctx     Optional shared receive context; may be null.
// @param av          Optional address vector; may be null.
// @return The assembled Endpoint.
// @throws FailedOpenActiveEndpoint if any of the set-up steps fails.
netio3::libfabric::Endpoint netio3::libfabric::ActiveEndpoint::create_endpoint(
  fid_fabric* fabric,
  fid_domain* domain,
  NetworkMode mode,
  FiInfoUniquePtr&& info,
  std::uint64_t info_flags,
  fid_ep* shx_ctx,
  fid_av* av)
{
  ZoneScoped;
  Endpoint endpoint{};
  endpoint.fi = info ? std::move(info) : get_info(m_address, mode, info_flags);

  open_endpoint(endpoint, fabric, domain, endpoint.fi.get());
  bind_av(endpoint, av);
  open_cq(endpoint, domain);
  bind_srx_context(endpoint, shx_ctx);
  enable_endpoint(endpoint);

  return endpoint;
}
// Open the event queue and the libfabric endpoint for `ep`, then bind the two together.
//
// The event queue is requested with an FI_WAIT_FD wait object. Both handles are stored in
// `ep` wrapped in FiCloseEndpointUniquePtr as soon as they have been opened.
//
// @param ep      Endpoint whose eq and ep members are filled in.
// @param fabric  Fabric the event queue is opened on.
// @param domain  Domain the endpoint is opened on.
// @param info    fi_info describing the endpoint to open.
// @throws FailedOpenActiveEndpoint if opening the queue, opening the endpoint or binding fails.
void netio3::libfabric::ActiveEndpoint::open_endpoint(Endpoint& ep,
                                                      fid_fabric* fabric,
                                                      fid_domain* domain,
                                                      fi_info* info)
{
  ZoneScoped;

  // Event queue
  fi_eq_attr queue_attr{};
  queue_attr.wait_obj = FI_WAIT_FD;
  fid_eq* raw_queue = nullptr;
  const int queue_status = fi_eq_open(fabric, &queue_attr, &raw_queue, nullptr);
  if (queue_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to open Event Queue for active endpoint, error {} - {}",
                  queue_status,
                  fi_strerror(-queue_status)));
  }
  ep.eq = FiCloseEndpointUniquePtr<fid_eq>(
    raw_queue, FiCloseEndpointDeleter<fid_eq>(m_address, "Failed to close Event Queue"));

  // Endpoint
  fid_ep* raw_endpoint = nullptr;
  const int endpoint_status = fi_endpoint(domain, info, &raw_endpoint, nullptr);
  if (endpoint_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to open Endpoint for active endpoint, error {} - {}",
                  endpoint_status,
                  fi_strerror(-endpoint_status)));
  }
  ep.ep = FiCloseEndpointUniquePtr<fid_ep>(
    raw_endpoint, FiCloseEndpointDeleter<fid_ep>(m_address, "Failed to close enpoint"));

  // Attach the event queue to the endpoint
  const int bind_status = fi_ep_bind(ep.ep.get(), &ep.eq->fid, 0);
  if (bind_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind endpoint, error {} - {}", bind_status, fi_strerror(-bind_status)));
  }
}
// Bind the address vector `av` to the endpoint; does nothing when `av` is null.
//
// @param ep  Endpoint to bind to.
// @param av  Address vector to bind, or nullptr to skip the step.
// @throws FailedOpenActiveEndpoint if fi_ep_bind() returns a non-zero code.
void netio3::libfabric::ActiveEndpoint::bind_av(Endpoint& ep, fid_av* av) const
{
  ZoneScoped;
  if (av == nullptr) {
    return;
  }

  const int status = fi_ep_bind(ep.ep.get(), &av->fid, 0);
  if (status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind address vector to endpoint, error {} - {}", status, fi_strerror(-status)));
  }
}
// Enable the endpoint and fetch the wait object of its event queue into ep.eqfd.
//
// @param ep  Endpoint to enable; its eqfd member is written on success.
// @throws FailedOpenActiveEndpoint if fi_enable() or the FI_GETWAIT query fails.
void netio3::libfabric::ActiveEndpoint::enable_endpoint(Endpoint& ep)
{
  ZoneScoped;

  const int enable_status = fi_enable(ep.ep.get());
  if (enable_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to enable endpoint for active endpoint, error {} - {}",
                  enable_status,
                  fi_strerror(-enable_status)));
  }

  const int wait_status = fi_control(&ep.eq->fid, FI_GETWAIT, &ep.eqfd);
  if (wait_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to retrive active endpoint Event Queue wait object, error {} - {}",
                  wait_status,
                  fi_strerror(-wait_status)));
  }
}
// Produce the baseline completion queue attributes that open_cq() starts from.
//
// @return Attributes with MAX_CQ_ENTRIES entries, FI_CQ_FORMAT_DATA completions and an
//         FI_WAIT_FD wait object; every other field is zero / null / FI_CQ_COND_NONE.
fi_cq_attr netio3::libfabric::ActiveEndpoint::prepare_cq_attr()
{
  fi_cq_attr attr{};
  attr.size = MAX_CQ_ENTRIES;        // number of entries in the queue
  attr.flags = 0;                    // operation flags
  attr.format = FI_CQ_FORMAT_DATA;   // completion entry format
  attr.wait_obj = FI_WAIT_FD;        // requested wait object
  attr.signaling_vector = 0;         // interrupt affinity
  attr.wait_cond = FI_CQ_COND_NONE;  // wait condition format
  attr.wait_set = nullptr;           // optional wait set
  return attr;
}
// Open the transmit and receive completion queues of `ep` and bind them to the endpoint.
//
// Both queues start from prepare_cq_attr(). A queue that the configured capabilities do not
// use is opened with FI_CQ_FORMAT_UNSPEC / FI_WAIT_NONE and no wait object is fetched for it;
// otherwise the wait object is stored in ep.cqfd (transmit) or ep.rcqfd (receive).
// The transmit queue is handled completely before the receive queue.
//
// @param ep      Endpoint whose cq / rcq (and cqfd / rcqfd) members are filled in.
// @param domain  Domain the completion queues are opened on.
// @throws FailedOpenActiveEndpoint if opening, binding or querying either queue fails.
void netio3::libfabric::ActiveEndpoint::open_cq(Endpoint& ep, fid_domain* domain)
{
  ZoneScoped;
  const auto default_attr = prepare_cq_attr();
  const bool uses_tx = m_capabilities.send_buffered || m_capabilities.send_zero_copy;
  const bool uses_rx = m_capabilities.receive;

  // ---- Transmit (FI_TRANSMIT) completion queue ----
  fi_cq_attr tx_attr = default_attr;
  if (!uses_tx) {
    tx_attr.format = FI_CQ_FORMAT_UNSPEC;
    tx_attr.wait_obj = FI_WAIT_NONE;
  }
  fid_cq* tx_queue = nullptr;
  const int tx_open_status = fi_cq_open(domain, &tx_attr, &tx_queue, nullptr);
  if (tx_open_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to open Completion Queue for active endpoint, error {} - {}",
                  tx_open_status,
                  fi_strerror(-tx_open_status)));
  }
  ep.cq = FiCloseEndpointUniquePtr<fid_cq>(
    tx_queue,
    FiCloseEndpointDeleter<fid_cq>(m_address, "Failed to close active endpoint Completion Queue"));

  const int tx_bind_status = fi_ep_bind(ep.ep.get(), &ep.cq->fid, FI_TRANSMIT);
  if (tx_bind_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind Completion Queue for active endpoint, error {} - {}",
                  tx_bind_status,
                  fi_strerror(-tx_bind_status)));
  }

  if (uses_tx) {
    const int tx_wait_status = fi_control(&ep.cq->fid, FI_GETWAIT, &ep.cqfd);
    if (tx_wait_status != 0) {
      throw FailedOpenActiveEndpoint(
        ERS_HERE,
        m_address.address(),
        m_address.port(),
        std::format("Failed to retrieve wait object for send Completion Queue, error {} - {}",
                    tx_wait_status,
                    fi_strerror(-tx_wait_status)));
    }
  }

  // ---- Receive (FI_RECV) completion queue ----
  fi_cq_attr rx_attr = default_attr;
  if (!uses_rx) {
    rx_attr.format = FI_CQ_FORMAT_UNSPEC;
    rx_attr.wait_obj = FI_WAIT_NONE;
  }
  fid_cq* rx_queue = nullptr;
  const int rx_open_status = fi_cq_open(domain, &rx_attr, &rx_queue, nullptr);
  if (rx_open_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to open Completion Queue for active endpoint, error {} - {}",
                  rx_open_status,
                  fi_strerror(-rx_open_status)));
  }
  ep.rcq = FiCloseEndpointUniquePtr<fid_cq>(
    rx_queue,
    FiCloseEndpointDeleter<fid_cq>(m_address, "Failed to close FI_RECV active endpoint Completion Queue"));

  const int rx_bind_status = fi_ep_bind(ep.ep.get(), &ep.rcq->fid, FI_RECV);
  if (rx_bind_status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind Completion Queue for active endpoint, error {} - {}",
                  rx_bind_status,
                  fi_strerror(-rx_bind_status)));
  }

  if (uses_rx) {
    const int rx_wait_status = fi_control(&ep.rcq->fid, FI_GETWAIT, &ep.rcqfd);
    if (rx_wait_status != 0) {
      throw FailedOpenActiveEndpoint(
        ERS_HERE,
        m_address.address(),
        m_address.port(),
        std::format("Failed to retrieve wait object for receive Completion Queue, error {} - {}",
                    rx_wait_status,
                    fi_strerror(-rx_wait_status)));
    }
  }
}
// Bind the shared receive context `shx_ctx` to the endpoint; does nothing when it is null.
//
// @param ep       Endpoint to bind to.
// @param shx_ctx  Shared receive context to bind, or nullptr to skip the step.
// @throws FailedOpenActiveEndpoint if fi_ep_bind() returns a non-zero code.
void netio3::libfabric::ActiveEndpoint::bind_srx_context(Endpoint& ep, fid_ep* shx_ctx) const
{
  ZoneScoped;
  if (shx_ctx == nullptr) {
    return;
  }

  const int status = fi_ep_bind(ep.ep.get(), &shx_ctx->fid, 0);
  if (status != 0) {
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Failed to bind shared receive context, error {} - {}",
                  status,
                  fi_strerror(-status)));
  }
}
// Look up the fi_info for `address` and wrap it in a FiInfoUniquePtr.
//
// @param address     Address passed to get_fi_info(); also reported in the error.
// @param mode        Network mode passed to get_fi_info().
// @param info_flags  Flags passed to get_fi_info().
// @return The fi_info returned by get_fi_info(), owned by the returned pointer.
// @throws FailedOpenActiveEndpoint if get_fi_info() throws LibfabricFiInfoError; the original
//         error text is embedded in the message.
netio3::libfabric::FiInfoUniquePtr netio3::libfabric::ActiveEndpoint::get_info(
  const EndPointAddress& address,
  NetworkMode mode,
  std::uint64_t info_flags)
{
  ZoneScoped;
  try {
    return FiInfoUniquePtr{get_fi_info(address, mode, info_flags)};
  } catch (const LibfabricFiInfoError& e) {
    // Pass the ERS context first, like every other FailedOpenActiveEndpoint throw site here.
    throw FailedOpenActiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to get info on local interface {}", e.what()));
  }
}