Program Listing for File SendSocket.cpp
↰ Return to documentation for file (BackendLibfabric/SendSocket.cpp)
#include "SendSocket.hpp"
#include <utility>
#include <tracy/Tracy.hpp>
#include <rdma/fabric.h>
#include <rdma/fi_cm.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_eq.h>
#include "BaseSocket.hpp"
#include "Helpers.hpp"
#include "Issues.hpp"
/**
 * Creates a send socket whose libfabric endpoint is opened immediately for `address`.
 *
 * @param address Remote address/port to connect to (stored in m_addr).
 * @param mode    Network mode used to select the libfabric provider.
 * @param fabric  Fabric the endpoint is opened on.
 * @param domain  Domain the endpoint and its completion queues are opened on.
 * @throws FailedOpenSendEndpoint (via create_endpoint) if the endpoint cannot be created.
 */
netio3::libfabric::SendSocket::SendSocket(EndPointAddress address,
                                          NetworkMode mode,
                                          fid_fabric* fabric,
                                          fid_domain* domain) :
  // Base classes are initialised before members regardless of the order written, so
  // `address` is still intact when create_endpoint() reads it, and only moved afterwards.
  BaseSocket{create_endpoint(address, mode, fabric, domain)},
  m_addr{std::move(address)}
{}
/**
 * Starts connecting the endpoint to the remote peer and fetches its Event Queue wait object.
 *
 * fi_connect() is issued towards the destination address held in the endpoint's fi_info
 * (resolved by get_info()). The Event Queue wait object is then stored in the endpoint's
 * `eqfd` via fi_control(FI_GETWAIT).
 *
 * @throws FailedOpenSendEndpoint if fi_connect() or fi_control() returns a non-zero code.
 */
void netio3::libfabric::SendSocket::init()
{
  ZoneScoped;
  auto& endpoint = get_endpoint();

  /* Connect to server */
  const auto connect_ret = fi_connect(endpoint.ep.get(), endpoint.fi->dest_addr, nullptr, 0);
  if (connect_ret != 0) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      m_addr.address(),
      m_addr.port(),
      std::format("Connection to remote failed, error {} - {}",
                  connect_ret,
                  fi_strerror(-connect_ret)));
  }

  // The wait object is written into endpoint.eqfd so the Event Queue can be monitored.
  const auto wait_ret = fi_control(&endpoint.eq->fid, FI_GETWAIT, &endpoint.eqfd);
  if (wait_ret != 0) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      m_addr.address(),
      m_addr.port(),
      std::format("Cannot retrieve the Event Queue wait object of send socket, error {} - {}",
                  wait_ret,
                  fi_strerror(-wait_ret)));
  }

  ERS_DEBUG(1, std::format("EV context with FD: {}", get_eq_ev_ctx().fd));
}
/**
 * Posts a scatter-gather send of `data` on the connected endpoint with fi_sendmsg().
 *
 * @param data Buffers to send, one iovec per element.
 * @param mrs  Memory region for each element of `data`; must have the same length as `data`.
 *             A null entry results in a null descriptor being passed for that buffer.
 * @param key  Caller-chosen value stored (reinterpret_cast to void*) as the fi_msg context.
 * @return NetioStatus::OK if fi_sendmsg() returned 0,
 *         NetioStatus::NO_RESOURCES if it returned -FI_EAGAIN,
 *         NetioStatus::FAILED otherwise (an ers::error is reported in that case).
 */
netio3::NetioStatus netio3::libfabric::SendSocket::send_data(const std::span<const iovec> data,
                                                             const std::span<fid_mr*> mrs,
                                                             std::uint64_t key) const
{
  ZoneScoped;
  if (data.size() != mrs.size()) {
    ers::error(FailedSend(ERS_HERE,
                          get_address().address(),
                          get_address().port(),
                          "Failed sending message because of mismatched data and memory regions."));
    return NetioStatus::FAILED;
  }
  // Check the endpoint before doing any per-message work (descriptor allocation).
  if (get_endpoint().ep == nullptr || get_endpoint().ep->msg == nullptr) {
    ers::error(FailedSend(ERS_HERE,
                          get_address().address(),
                          get_address().port(),
                          "Failed sending message because of null message or null endpoint."));
    return NetioStatus::FAILED;
  }

  std::vector<void*> descs{};
  descs.reserve(mrs.size());
  // fi_mr_desc() reads through the pointer, so guard against null memory regions.
  std::transform(mrs.begin(), mrs.end(), std::back_inserter(descs), [](fid_mr* mr) -> void* {
    return mr != nullptr ? fi_mr_desc(mr) : nullptr;
  });

  fi_msg msg{};
  msg.msg_iov = data.data(); /* scatter-gather array */
  msg.desc = descs.data();
  msg.iov_count = data.size();
  msg.addr = 0;
  msg.context = reinterpret_cast<void*>(key);
  msg.data = 0;

  ERS_DEBUG(2, std::format("sending iov message with key {}", key));
  const uint64_t flags = FI_INJECT_COMPLETE; // | FI_INJECT;
  const auto ret = fi_sendmsg(get_endpoint().ep.get(), &msg, flags);
  if (ret == -FI_EAGAIN) {
    ERS_DEBUG(1, "Send failed with result: EAGAIN");
    return NetioStatus::NO_RESOURCES;
  }
  ERS_DEBUG(1, std::format("Send completed with result: {}", ret));
  if (ret != 0) {
    // fi_sendmsg() returns ssize_t; fi_strerror() takes a positive int error code.
    ers::error(FailedSend(ERS_HERE,
                          get_address().address(),
                          get_address().port(),
                          std::format("Failed to send message, error {} (IOV count {}, key {}) - {}",
                                      ret,
                                      data.size(),
                                      key,
                                      fi_strerror(static_cast<int>(-ret)))));
    return NetioStatus::FAILED;
  }
  return NetioStatus::OK;
}
/**
 * Builds the endpoint used by a send socket.
 *
 * The steps are: resolve the fi_info for `address`/`mode` (get_info), open the endpoint,
 * open and bind its completion queues (open_cq), then enable it.
 *
 * @return The opened and enabled endpoint.
 * @throws FailedOpenSendEndpoint from get_info()/open_cq(). Any FailedOpenEndpoint raised
 *         while opening is rethrown as FailedOpenSendEndpoint carrying the same message.
 */
netio3::libfabric::Endpoint netio3::libfabric::SendSocket::create_endpoint(
  const EndPointAddress& address,
  NetworkMode mode,
  fid_fabric* fabric,
  fid_domain* domain)
{
  ZoneScoped;
  Endpoint endpoint{};
  endpoint.fi = get_info(address, mode);
  try {
    open_endpoint(address, endpoint, fabric, domain, endpoint.fi.get());
    open_cq(address, endpoint, domain);
    enable_endpoint(address, endpoint);
  } catch (const FailedOpenEndpoint& error) {
    // Re-raise with the send-socket specific issue type, keeping the original reason.
    throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), error.message());
  }
  return endpoint;
}
/**
 * Queries libfabric (fi_getinfo, API version 1.1) for a connection-oriented endpoint to `address`.
 *
 * The hints request an FI_EP_MSG endpoint with FI_MSG capability, FI_LOCAL_MR mode,
 * FI_PROGRESS_AUTO data progress and FI_RM_ENABLED resource management. They are
 * restricted to the provider returned by get_provider(mode).
 *
 * @return Owning pointer to the fi_info list produced by fi_getinfo().
 * @throws FailedOpenSendEndpoint if the provider cannot be determined or fi_getinfo() fails.
 */
netio3::libfabric::FiInfoUniquePtr netio3::libfabric::SendSocket::get_info(
  const EndPointAddress& address,
  NetworkMode mode)
{
  ZoneScoped;
  auto hints = FiInfoWrapper{};
  auto* const hint = hints.get();
  hint->addr_format = FI_FORMAT_UNSPEC;
  hint->ep_attr->type = FI_EP_MSG;
  hint->caps = FI_MSG;
  hint->mode = FI_LOCAL_MR;
  // As of libfabric 1.10 the tcp provider only supports FI_PROGRESS_MANUAL,
  // so requesting FI_PROGRESS_AUTO here prevents the tcp provider from being used.
  hint->domain_attr->data_progress = FI_PROGRESS_AUTO;
  hint->domain_attr->resource_mgmt = FI_RM_ENABLED;
  try {
    // NOTE(review): prov_name is a strdup'ed copy, presumably released together with the
    // hints by FiInfoWrapper — confirm.
    hint->fabric_attr->prov_name = strdup(get_provider(mode).c_str());
  } catch (const LibFabricError& error) {
    throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), error.message());
  }

  constexpr std::uint64_t no_flags = 0;
  const std::string service = std::to_string(address.port());
  fi_info* info = nullptr;
  const auto rc = fi_getinfo(
    FI_VERSION(1, 1), address.address().c_str(), service.c_str(), no_flags, hint, &info);
  if (rc != 0) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to initialise socket, error {} - {}", rc, fi_strerror(-rc)));
  }
  return FiInfoUniquePtr{info};
}
/**
 * Opens the two completion queues of a send socket and binds them to `ep.ep`.
 *
 * - FI_TRANSMIT CQ: attributes from prepare_cq_attr(), owned by `ep.cq`.
 * - FI_RECV CQ: the same attributes but with FI_CQ_FORMAT_UNSPEC and FI_WAIT_NONE,
 *   owned by `ep.rcq`.
 *
 * Each CQ is handed to its owning pointer in `ep` before it is bound. If a later step
 * throws, the CQ is therefore released together with `ep`.
 *
 * @throws FailedOpenSendEndpoint if opening or binding either completion queue fails.
 *         The messages name FI_RECV explicitly for the receive CQ so the failing step is
 *         identifiable.
 */
void netio3::libfabric::SendSocket::open_cq(const EndPointAddress& address,
                                           Endpoint& ep,
                                           fid_domain* domain)
{
  ZoneScoped;
  auto cq_attr = prepare_cq_attr();

  // FI_TRANSMIT CQ
  fid_cq* cq = nullptr;
  if (const auto ret = fi_cq_open(domain, &cq_attr, &cq, nullptr)) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to open Completion Queue for send socket, error {} - {}",
                  ret,
                  fi_strerror(static_cast<int>(-ret))));
  }
  ep.cq = FiCloseUniquePtr<fid_cq>(
    cq, FiCloseDeleter<fid_cq>(address, "Failed to close send socket Completion Queue"));
  if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.cq->fid, FI_TRANSMIT)) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind Completion Queue for send socket, error {} - {}",
                  ret,
                  fi_strerror(static_cast<int>(-ret))));
  }

  // FI_RECV CQ: reuse the transmit attributes, with an unspecified entry format and no wait object.
  cq_attr.format = FI_CQ_FORMAT_UNSPEC;
  cq_attr.wait_obj = FI_WAIT_NONE;
  fid_cq* rcq = nullptr;
  if (const auto ret = fi_cq_open(domain, &cq_attr, &rcq, nullptr)) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to open FI_RECV Completion Queue for send socket, error {} - {}",
                  ret,
                  fi_strerror(static_cast<int>(-ret))));
  }
  ep.rcq = FiCloseUniquePtr<fid_cq>(
    rcq, FiCloseDeleter<fid_cq>(address, "Failed to close FI_RECV send socket Completion Queue"));
  if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.rcq->fid, FI_RECV)) {
    throw FailedOpenSendEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind FI_RECV Completion Queue for send socket, error {} - {}",
                  ret,
                  fi_strerror(static_cast<int>(-ret))));
  }
}