Program Listing for File ReceiveSocket.cpp
↰ Return to documentation for file (BackendLibfabric/ReceiveSocket.cpp)
#include "ReceiveSocket.hpp"
#include <rdma/fi_cm.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>
#include <tracy/Tracy.hpp>
#include "BaseSocket.hpp"
#include "DomainManager.hpp"
#include "Helpers.hpp"
#include "Issues.hpp"
#include "netio3-backend/Issues.hpp"
// Accepts a pending connection request on `lsocket`: builds the endpoint from `info`,
// allocates/registers and posts the receive buffers, then calls fi_accept.
// On fi_accept failure the request is rejected, the buffers are closed and
// FailedOpenReceiveEndpoint is thrown.
netio3::libfabric::ReceiveSocket::ReceiveSocket(ListenSocket& lsocket,
                                                fid_fabric* fabric,
                                                DomainContext& domain,
                                                FiInfoUniquePtr&& info,
                                                BaseEventLoop* event_loop) :
  // `info` is moved into the endpoint here and is empty afterwards; the fi_info (including the
  // connection-request handle) remains reachable as get_endpoint().fi.
  BaseSocket{create_endpoint(lsocket.get_address(), fabric, domain.get_domain(), std::move(info))},
  m_conn_params{prepare_connection_parameters(lsocket.get_connection_parameters())},
  m_address{peer_address(get_endpoint().ep.get())},
  m_retry_post_signal{event_loop->create_signal(
    [this](int) { retry_post_buffers(); }, false)}
{
  init_buffers(domain);
  ERS_DEBUG(1,
            std::format("connection accepted. Lsocket address: {}:{}, rsocket EQ {}",
                        lsocket.get_address().address(),
                        lsocket.get_address().port(),
                        get_eq_ev_ctx().fd));
  ERS_DEBUG(2, std::format("Allocating memory regions number of buffers: {}", m_buffers.size()));
  // Reserve before posting: post_buffer() appends to this vector on -FI_EAGAIN.
  m_retry_post_buffers.reserve(m_conn_params.num_buf);
  post_buffers();
  const int ret = fi_accept(get_endpoint().ep.get(), nullptr, 0);
  if (ret != 0) {
    // Must not touch `info` here: it was moved from in the base-class initializer above.
    fi_reject(lsocket.get_pep(), get_endpoint().fi->handle, nullptr, 0);
    // The destructor does not run when a constructor throws, so release the buffers here.
    for (const auto& buf : m_buffers) {
      close_buffer(buf);
    }
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      m_address.address(),
      m_address.port(),
      std::format("Listen socket, connection rejected, error {} - {}", ret, fi_strerror(-ret)));
  }
  ERS_DEBUG(2, std::format("Connection done, fd: {}", get_eq_ev_ctx().fd));
}
// Releases every receive buffer owned by this socket via close_buffer().
netio3::libfabric::ReceiveSocket::~ReceiveSocket()
{
  ERS_DEBUG(2, "Entered");
  for (auto it = m_buffers.cbegin(); it != m_buffers.cend(); ++it) {
    close_buffer(*it);
  }
}
// Builds a receive endpoint from `info`: opens the endpoint, opens and binds its completion
// queues, then enables it. The endpoint takes ownership of `info`.
// Any FailedOpenEndpoint raised along the way is rethrown as FailedOpenReceiveEndpoint.
netio3::libfabric::Endpoint netio3::libfabric::ReceiveSocket::create_endpoint(
  const EndPointAddress& address,
  fid_fabric* fabric,
  fid_domain* domain,
  FiInfoUniquePtr&& info)
{
  ZoneScoped;
  Endpoint endpoint{};
  endpoint.fi = std::move(info);
  try {
    open_endpoint(address, endpoint, fabric, domain, endpoint.fi.get());
    open_cq(address, endpoint, domain);
    enable_endpoint(address, endpoint);
  } catch (const FailedOpenEndpoint& failure) {
    throw FailedOpenReceiveEndpoint(ERS_HERE, address.address(), address.port(), failure.message());
  }
  return endpoint;
}
// Opens the two completion queues of a receive endpoint and binds them to `ep`:
//  - ep.rcq: bound with FI_RECV, created with the attributes from prepare_cq_attr();
//  - ep.cq:  bound with FI_TRANSMIT, same attributes but FI_CQ_FORMAT_UNSPEC / FI_WAIT_NONE.
// Each CQ is wrapped in an owning FiCloseUniquePtr before binding, so it is closed if a later
// step throws. Throws FailedOpenReceiveEndpoint on any open/bind failure.
void netio3::libfabric::ReceiveSocket::open_cq(const EndPointAddress& address,
                                               Endpoint& ep,
                                               fid_domain* domain)
{
  ZoneScoped;
  auto cq_attr = prepare_cq_attr();
  // FI_RECV CQ
  fid_cq* rcq = nullptr;
  if (const auto ret = fi_cq_open(domain, &cq_attr, &rcq, nullptr)) {
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to open Completion Queue for receive socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
  ep.rcq = FiCloseUniquePtr<fid_cq>(
    rcq, FiCloseDeleter<fid_cq>(address, "Failed to close receive socket Completion Queue"));
  if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.rcq->fid, FI_RECV)) {
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind Completion Queue for receive socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
  // FI_TRANSMIT CQ: untyped entries, no wait object.
  cq_attr.format = FI_CQ_FORMAT_UNSPEC;
  cq_attr.wait_obj = FI_WAIT_NONE;
  fid_cq* cq = nullptr;
  if (const auto ret = fi_cq_open(domain, &cq_attr, &cq, nullptr)) {
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to open transmit Completion Queue for receive socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
  // This is the FI_TRANSMIT queue; the close message previously mislabelled it as FI_RECV.
  ep.cq = FiCloseUniquePtr<fid_cq>(
    cq, FiCloseDeleter<fid_cq>(address, "Failed to close FI_TRANSMIT receive socket Completion Queue"));
  if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.cq->fid, FI_TRANSMIT)) {
    throw FailedOpenReceiveEndpoint(
      ERS_HERE,
      address.address(),
      address.port(),
      std::format("Failed to bind transmit Completion Queue for receive socket, error {} - {}",
                  ret,
                  fi_strerror(-ret)));
  }
}
// Allocates m_conn_params.num_buf receive buffers of m_conn_params.buf_size bytes and registers
// each one with FI_RECV access.
// On a registration failure, the buffers registered so far are closed and
// FailedOpenReceiveEndpoint is thrown.
void netio3::libfabric::ReceiveSocket::init_buffers(DomainContext& domain)
{
  m_buffers.reserve(m_conn_params.num_buf);
  for (uint64_t i = 0; i < m_conn_params.num_buf; i++) {
    // Key does not matter for receive buffers
    m_buffers.emplace_back(domain, m_conn_params.buf_size, 0);
    try {
      register_buffer(m_buffers.back(), domain, FI_RECV);
    } catch (const LibFabricBufferError& e) {
      // This is called from the constructor, so ~ReceiveSocket() will not run if we throw.
      // Close the buffers registered before the failing one, mirroring the constructor's
      // fi_accept failure path, then drop them all so nothing closes them a second time.
      // NOTE(review): assumes register_buffer leaves nothing to release when it throws — confirm.
      m_buffers.pop_back();
      for (const auto& registered : m_buffers) {
        close_buffer(registered);
      }
      m_buffers.clear();
      throw FailedOpenReceiveEndpoint(ERS_HERE, m_address.address(), m_address.port(), e.message());
    }
  }
}
// Posts every receive buffer to the endpoint.
// Deferral on -FI_EAGAIN and error reporting are handled by post_buffer().
void netio3::libfabric::ReceiveSocket::post_buffers()
{
  ZoneScoped;
  for (auto it = m_buffers.begin(); it != m_buffers.end(); ++it) {
    post_buffer(&*it);
  }
}
// Posts a single receive buffer.
// If libfabric answers -FI_EAGAIN, the buffer is queued in m_retry_post_buffers and the retry
// signal is fired. Any other failure is reported via ers::error and the buffer is not retried.
void netio3::libfabric::ReceiveSocket::post_buffer(Buffer* buf)
{
  ZoneScoped;
  const auto status = do_post_buffer(buf);
  if (status == 0) {
    return;
  }
  if (status == -FI_EAGAIN) {
    // The provider cannot take the receive right now: defer it to retry_post_buffers().
    m_retry_post_buffers.push_back(buf);
    m_retry_post_signal.fire();
    return;
  }
  ers::error(LibFabricBufferError(
    ERS_HERE,
    std::format("Failed to post a buffer to receive inbound messages, error {} - {}",
                status,
                fi_strerror(-status))));
}
// Signal handler: re-posts every buffer previously deferred with -FI_EAGAIN.
// Buffers that get -FI_EAGAIN again stay queued, in their original order, and the retry signal
// is fired again. Other failures are logged via ers::error and the buffer is dropped from the
// queue.
void netio3::libfabric::ReceiveSocket::retry_post_buffers()
{
  ZoneScoped;
  // Compact still-pending buffers to the front of m_retry_post_buffers in place instead of
  // building a new vector and swapping. This keeps the capacity reserved in the constructor and
  // avoids an allocation on every retry.
  std::size_t still_pending = 0;
  for (auto* buf : m_retry_post_buffers) {
    const auto ret = do_post_buffer(buf);
    if (ret == -FI_EAGAIN) {
      // still_pending never exceeds the current index, so this only rewrites a visited slot.
      m_retry_post_buffers[still_pending++] = buf;
    } else if (ret != 0) [[unlikely]] {
      ers::error(LibFabricBufferError(
        ERS_HERE,
        std::format("Failed to post a buffer to receive inbound messages, error {} - {}",
                    ret,
                    fi_strerror(-ret))));
    }
  }
  m_retry_post_buffers.resize(still_pending);
  if (still_pending != 0) {
    m_retry_post_signal.fire();
  }
}
// Issues one fi_recvmsg for `buf` on this socket's endpoint and returns libfabric's status code
// unchanged (0 on success, a negative fi_errno otherwise).
// The Buffer pointer is passed as the operation context, which libfabric reports back in the
// corresponding completion entry.
int netio3::libfabric::ReceiveSocket::do_post_buffer(Buffer* buf) const
{
  ZoneScoped;
  ERS_DEBUG(2, "Entered");
  // One scatter-gather entry covering the whole buffer, plus its memory-registration descriptor.
  iovec entry{buf->data().data(), buf->size()};
  void* descriptor = fi_mr_desc(buf->mr);
  const fi_msg message{
    .msg_iov = &entry,
    .desc = &descriptor,
    .iov_count = 1,
    .addr = 0,
    .context = buf,
    .data = 0,
  };
  // FI_MULTI_RECV is intentionally not requested here.
  const std::uint64_t recv_flags = FI_REMOTE_CQ_DATA;
  return fi_recvmsg(get_endpoint().ep.get(), &message, recv_flags);
}
// Returns a copy of `requested` in which num_buf does not exceed the receive context size
// reported by the provider (fi_info::rx_attr->size). Emits a TooManyBuffersRequested warning
// when the request has to be clamped.
netio3::ConnectionParametersRecv netio3::libfabric::ReceiveSocket::prepare_connection_parameters(
  const ConnectionParametersRecv& requested) const
{
  const auto rx_limit = get_endpoint().fi->rx_attr->size;
  auto adjusted = requested;
  if (adjusted.num_buf <= rx_limit) {
    return adjusted;
  }
  ers::warning(TooManyBuffersRequested(adjusted.num_buf, rx_limit));
  adjusted.num_buf = rx_limit;
  return adjusted;
}