Program Listing for File ConnectionManager.cpp
↰ Return to documentation for file (BackendLibfabric/ConnectionManager.cpp)
#include "ConnectionManager.hpp"
#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <format>
#include <iterator>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include <tracy/Tracy.hpp>
#include <rdma/fabric.h>
#include <rdma/fi_eq.h>
#include <rdma/fi_errno.h>
#include "Issues.hpp"
#include "ReceiveSocket.hpp"
#include "SendSocketZeroCopy.hpp"
#include "netio3-backend/EventLoop/BaseEventLoop.hpp"
#include "netio3-backend/Netio3Backend.hpp"
#include "netio3-backend/Issues.hpp"
netio3::libfabric::ConnectionManager::ConnectionManager(BaseEventLoop* event_loop,
NetworkConfig config) :
m_config{std::move(config)},
m_event_loop{event_loop},
m_ssockets_buffered{m_config.thread_safety == ThreadSafetyModel::SAFE},
m_ssockets_zero_copy{m_config.thread_safety == ThreadSafetyModel::SAFE}
{}
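// If the manager is destroyed on the event-loop thread, endpoints are unregistered directly.
// Otherwise, while the loop is still running, a signal asks the loop to unregister everything
// and the destructor waits (re-checking once per second) until that has happened or the loop stops.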
netio3::libfabric::ConnectionManager::~ConnectionManager()
{
if (m_event_loop->get_thread_id() == std::this_thread::get_id()) {
unregister_all();
} else if (m_event_loop->is_running()) {
std::mutex mutex;
std::condition_variable cv;
std::atomic_bool unregistered{false};
const auto signal = m_event_loop->create_signal(
[this, &unregistered, &cv](int) {
unregister_all();
unregistered = true;
cv.notify_one();
},
false);
signal.fire();
std::unique_lock<std::mutex> lock(mutex);
while (m_event_loop->is_running() and !unregistered.load()) {
constexpr static auto timeout = std::chrono::seconds(1);
if (cv.wait_for(lock, timeout, [&unregistered] { return unregistered.load(); })) {
break;
}
}
}
}
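// Factory function: init() is kept separate from the constructor because it calls
// weak_from_this(), which only yields a valid pointer once the instance is owned by a shared_ptr.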
std::shared_ptr<netio3::libfabric::ConnectionManager> netio3::libfabric::ConnectionManager::create(
BaseEventLoop* event_loop,
NetworkConfig config)
{
auto instance = std::make_shared<ConnectionManager>(event_loop, config);
instance->init();
return instance;
}
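// Deferred initialisation: registers the signal that drains the close-request queue on the event-loop thread.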
void netio3::libfabric::ConnectionManager::init()
{
m_close_signal = m_event_loop->create_signal(
[weak_this = weak_from_this()](int) {
if (auto shared_this = weak_this.lock()) {
shared_this->handle_close_requests();
}
},
false);
}
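// Opens a listening endpoint: lazily initialises the libfabric domain, creates a ListenSocket,
// registers its event-queue fd with the event loop and returns the address reported by the socket.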
netio3::EndPointAddress netio3::libfabric::ConnectionManager::open_listen_endpoint(
const EndPointAddress& address,
ConnectionParametersRecv conn_params)
{
if (m_lsockets.contains(address)) {
throw ListenEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
}
try {
check_and_init(address, FI_RECV);
} catch (const LibfabricDomainError& e) {
throw FailedOpenListenEndpoint(ERS_HERE, address.address(), address.port(), e.message());
}
auto lsocket_to_insert =
ListenSocket(address, conn_params, m_config.mode, m_domain_manager->get_fabric());
const auto pair =
m_lsockets.try_emplace(lsocket_to_insert.get_address(), std::move(lsocket_to_insert));
auto& lsocket = pair.first->second;
const auto eq_ev_ctx = EventContext{
lsocket.get_endpoint().eqfd, [this, &lsocket](int) { on_listen_socket_cm_event(lsocket); }};
lsocket.set_eq_ev_ctx(eq_ev_ctx);
m_event_loop->register_fd(eq_ev_ctx);
return lsocket.get_address();
}
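// Opens a buffered send endpoint and registers its event-queue fd; the connection itself
// completes asynchronously in on_send_socket_cm_event.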
void netio3::libfabric::ConnectionManager::open_send_endpoint_buffered(
const EndPointAddress& address,
ConnectionParameters conn_params)
{
if (m_ssockets_buffered.contains(address)) {
throw SendEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
}
try {
check_and_init(address, FI_SEND);
} catch (const LibfabricDomainError& e) {
throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), e.message());
}
m_ssockets_buffered.try_emplace(address,
address,
conn_params,
m_config.mode,
m_domain_manager->get_fabric(),
m_domain_manager->get_send_domain());
const auto ctx =
m_ssockets_buffered.apply(address, [this, &address](SendSocketBuffered& ssocket) {
const auto eq_ev_ctx = EventContext{ssocket.get_endpoint().eqfd, [this, address](int) {
on_send_socket_cm_event<SendSocketBuffered>(address);
}};
ssocket.set_eq_ev_ctx(eq_ev_ctx);
return eq_ev_ctx;
});
m_event_loop->register_fd(ctx);
}
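// Same flow as the buffered variant above, but for zero-copy send sockets.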
void netio3::libfabric::ConnectionManager::open_send_endpoint_zero_copy(
const EndPointAddress& address,
ConnectionParameters conn_params)
{
if (m_ssockets_zero_copy.contains(address)) {
throw SendEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
}
try {
check_and_init(address, FI_SEND);
} catch (const LibfabricDomainError& e) {
throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), e.message());
}
m_ssockets_zero_copy.try_emplace(address,
address,
conn_params,
m_config.mode,
m_domain_manager->get_fabric(),
m_domain_manager->get_send_domain());
const auto ctx =
m_ssockets_zero_copy.apply(address, [this, &address](SendSocketZeroCopy& ssocket) {
const auto eq_ev_ctx = EventContext{ssocket.get_endpoint().eqfd, [this, address](int) {
on_send_socket_cm_event<SendSocketZeroCopy>(address);
}};
ssocket.set_eq_ev_ctx(eq_ev_ctx);
return eq_ev_ctx;
});
m_event_loop->register_fd(ctx);
}
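// Closing is asynchronous: the request is queued and processed on the event-loop thread via
// m_close_signal (see handle_close_requests); close_send_endpoint below works the same way.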
void netio3::libfabric::ConnectionManager::close_listen_endpoint(const EndPointAddress& address)
{
if (not m_lsockets.contains(address)) {
throw UnknownListenEndpoint(ERS_HERE, address.address(), address.port());
}
m_close_queue.push({address, CloseQueueItem::Type::listen});
m_close_signal->fire();
}
void netio3::libfabric::ConnectionManager::close_send_endpoint(const EndPointAddress& address)
{
if (not m_ssockets_buffered.contains(address) and not m_ssockets_zero_copy.contains(address)) {
throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}
m_close_queue.push({address, CloseQueueItem::Type::send});
m_close_signal->fire();
}
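// Returns the number of free send buffers for the endpoint, whichever map currently holds it.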
std::size_t netio3::libfabric::ConnectionManager::get_num_available_buffers(
const EndPointAddress& address)
{
if (m_ssockets_buffered.contains(address)) {
return m_ssockets_buffered.apply(
address, [](SendSocketBuffered& ssocket) { return ssocket.get_num_available_buffers(); });
}
if (m_ssockets_zero_copy.contains(address)) {
return m_ssockets_zero_copy.apply(
address, [](SendSocketZeroCopy& ssocket) { return ssocket.get_num_available_buffers(); });
}
throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}
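// Lazily creates the shared DomainManager and CqReactor the first time an endpoint is opened.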
void netio3::libfabric::ConnectionManager::check_and_init(const EndPointAddress& address,
const std::uint64_t flags)
{
if (m_domain_manager == nullptr) {
m_domain_manager = std::make_unique<DomainManager>(m_config.mode, address, flags);
}
if (m_cq_reactor == nullptr) {
m_cq_reactor = std::make_unique<CqReactor>(m_domain_manager->get_fabric(), m_config.on_data_cb);
}
}
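// Reads a single connection-management event from the given event queue. -FI_EAGAIN is passed
// back to the caller unchanged; -FI_EAVAIL is logged together with the error-entry details.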
netio3::libfabric::ConnectionManager::CmEvent netio3::libfabric::ConnectionManager::read_cm_event(
fid_eq* eq)
{
ZoneScoped;
std::uint32_t event{};
fi_eq_cm_entry entry{};
fi_info* info{nullptr};
fi_eq_err_entry err_entry{};
const auto rd = fi_eq_sread(eq, &event, &entry, sizeof(entry), 0, 0);
if (rd < 0) {
if (rd == -FI_EAGAIN) {
return {rd, err_entry, FiInfoUniquePtr{info}};
}
if (rd == -FI_EAVAIL) {
const auto r = fi_eq_readerr(eq, &err_entry, 0);
if (r < 0) {
ers::error(LibFabricCmError(
ERS_HERE, std::format("Failed to retrieve details on Event Queue error {}", r)));
}
ers::error(LibFabricCmError(
ERS_HERE,
std::format("Event Queue error: {} (code: {}), provider specific: {} (code: {})",
fi_strerror(err_entry.err),
err_entry.err,
fi_eq_strerror(eq, err_entry.prov_errno, err_entry.err_data, nullptr, 0),
err_entry.prov_errno)));
return {rd, err_entry, FiInfoUniquePtr{info}};
}
}
if (rd != sizeof(entry)) {
ers::error(LibFabricCmError(ERS_HERE, std::format("Failed to read from Event Queue: {}", rd)));
}
return {event, err_entry, FiInfoUniquePtr{entry.info}};
}
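// Handles event-queue events on a listen socket; only FI_CONNREQ is expected and results in a
// new receive endpoint being spawned.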
void netio3::libfabric::ConnectionManager::on_listen_socket_cm_event(ListenSocket& lsocket)
{
ZoneScoped;
ERS_DEBUG(1, "listen socket: connection event");
auto event = read_cm_event(lsocket.get_endpoint().eq.get());
switch (event.event) {
case FI_CONNREQ: {
ERS_DEBUG(2, ": FI_CONNREQ");
handle_connection_request(lsocket, std::move(event.info));
} break;
case FI_CONNECTED:
throw LibFabricCmError(ERS_HERE, "FI_CONNECTED received on listen socket");
case FI_SHUTDOWN:
throw LibFabricCmError(ERS_HERE, "FI_SHUTDOWN received on listen socket");
case -FI_EAGAIN: {
auto* fp = &lsocket.get_endpoint().eq->fid;
fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
break;
}
case -FI_EAVAIL:
ers::error(LibFabricCmError(
ERS_HERE,
std::format("Unhandled error in listen socket EQ code: {} provider specific code: {}",
event.err_entry.err,
event.err_entry.prov_errno)));
break;
default:
ers::warning(LibFabricCmError(
ERS_HERE, std::format("Unexpected event {} in listen socket Event Queue", event.event)));
break;
}
}
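// Handles event-queue events on a receive socket: FI_CONNECTED registers the completion-queue fd
// with the event loop, FI_SHUTDOWN unregisters the socket and erases it from m_rsockets.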
void netio3::libfabric::ConnectionManager::on_recv_socket_cm_event(ReceiveSocket& rsocket)
{
ZoneScoped;
ERS_DEBUG(2, "Entered");
const auto event = read_cm_event(rsocket.get_endpoint().eq.get());
switch (event.event) {
case FI_CONNECTED: {
ERS_DEBUG(2, "FI_CONNECTED");
if (const auto ret =
fi_control(&rsocket.get_endpoint().rcq->fid, FI_GETWAIT, &rsocket.get_endpoint().cqfd)) {
ers::error(LibFabricError(
ERS_HERE,
std::format(
"Failed to retrieve recv socket Completion Queue wait object: error {} cqfd: {}",
ret,
rsocket.get_endpoint().cqfd)));
}
rsocket.set_cq_ev_ctx({rsocket.get_endpoint().cqfd, [this, &rsocket](int) {
m_cq_reactor->on_recv_socket_cq_event(rsocket);
}});
ERS_DEBUG(1,
std::format("recv_socket: EQ fd {} connected, CQ fd {}",
rsocket.get_endpoint().eqfd,
rsocket.get_endpoint().cqfd));
m_event_loop->register_fd(rsocket.get_cq_ev_ctx());
ERS_DEBUG(1,
std::format("Adding recv CQ polled fid {} {}",
rsocket.get_endpoint().cqfd,
static_cast<void*>(&rsocket.get_endpoint().cq->fid)));
ERS_DEBUG(1,
std::format("recv_socket: EQ fd {} CQ fd {} connected",
rsocket.get_endpoint().eqfd,
rsocket.get_endpoint().cqfd));
const auto address = peer_address(rsocket.get_endpoint().ep.get());
ERS_INFO(std::format("Opened receive connection to {}:{}", address.address(), address.port()));
if (m_config.on_connection_established_cb != nullptr) {
m_config.on_connection_established_cb(address);
}
break;
}
case FI_SHUTDOWN: {
ERS_DEBUG(2, "FI_SHUTDOWN");
const auto address = peer_address(rsocket.get_endpoint().ep.get());
ERS_INFO(std::format("Closed receive connection to {}:{} because received FI_SHUTDOWN",
address.address(),
address.port()));
if (m_config.on_connection_closed_cb != nullptr) {
m_config.on_connection_closed_cb(address, {});
}
std::lock_guard lock(m_rsocket_mutex);
const auto entry = std::ranges::find_if(m_rsockets, [&rsocket](const auto& pair) {
return pair.second.get_endpoint().eqfd == rsocket.get_endpoint().eqfd;
});
if (entry == std::cend(m_rsockets)) {
ers::error(LibFabricError(ERS_HERE, "Failed to find ReceiveSocket in ConnectionManager"));
} else {
unregister_endpoint(rsocket.get_cq_cm_fds());
m_rsockets.erase(entry);
}
break;
}
case FI_MR_COMPLETE:
case FI_AV_COMPLETE:
case FI_JOIN_COMPLETE:
// Not implemented
break;
case -FI_EAGAIN: {
auto* fp = &rsocket.get_endpoint().eq->fid;
fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
break;
}
case -FI_EAVAIL:
ers::error(LibFabricCmError(
ERS_HERE,
std::format("Unhandled error in recv socket EQ code: {} provider specific code: {}",
event.err_entry.err,
event.err_entry.prov_errno)));
break;
default:
throw LibFabricCmError(
ERS_HERE, std::format("Unexpected event {} in recv socket Event Queue", event.event));
}
}
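// Handles event-queue events for a send socket of either flavour; the socket is looked up by
// address in the map that matches SocketType.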
template<netio3::libfabric::SendSocketConcept SocketType>
void netio3::libfabric::ConnectionManager::on_send_socket_cm_event(const EndPointAddress& address)
{
ZoneScoped;
ERS_DEBUG(1, "send socket: connection event");
auto& map = [this]() -> ThreadSafeMap<EndPointAddress, SocketType>& {
if constexpr (std::is_same_v<SocketType, SendSocketZeroCopy>) {
return m_ssockets_zero_copy;
} else if constexpr (std::is_same_v<SocketType, SendSocketBuffered>) {
return m_ssockets_buffered;
} else {
throw LibFabricLogicError(ERS_HERE, "Unknown socket type");
}
}();
const auto event = map.apply(
address, [](SocketType& ssocket) { return read_cm_event(ssocket.get_endpoint().eq.get()); });
switch (event.event) {
case FI_SHUTDOWN: {
ERS_INFO(std::format("Closed send connection to {}:{} because received FI_SHUTDOWN",
address.address(),
address.port()));
const auto fds =
map.apply(address, [](SocketType& ssocket) { return ssocket.get_cq_cm_fds(); });
unregister_endpoint(fds);
const auto pending_sends =
map.apply(address, [](SocketType& ssocket) { return ssocket.get_pending_sends(); });
map.erase(address);
if (m_config.on_connection_closed_cb != nullptr) {
m_config.on_connection_closed_cb(address, pending_sends);
}
return;
}
case FI_CONNECTED: {
const auto ctx = map.apply(address, [this, &map, &address](SocketType& ssocket) {
if (const auto ret =
fi_control(&ssocket.get_endpoint().cq->fid, FI_GETWAIT, &ssocket.get_endpoint().cqfd)) {
ers::error(LibFabricError(
ERS_HERE,
std::format(
"Failed to retrieve wait object for send socket Completion Queue, error {} - {}",
ret,
fi_strerror(-ret))));
}
ssocket.set_cq_ev_ctx({ssocket.get_endpoint().cqfd, [this, &map, address](int) {
try {
const auto keys =
map.apply(address, [this](SocketType& ssocket_cb) {
return m_cq_reactor->on_send_socket_cq_event(ssocket_cb);
});
if (m_config.on_send_completed_cb != nullptr) {
for (const auto key : keys) {
m_config.on_send_completed_cb(address, key);
}
}
} catch (const std::out_of_range& e) {
return;
}
}});
ERS_DEBUG(1,
std::format("send_socket: EQ fd {} connected, CQ fd {}",
ssocket.get_endpoint().eqfd,
ssocket.get_endpoint().cqfd));
return ssocket.get_cq_ev_ctx();
});
ERS_INFO(std::format("Opened send connection to {}:{}", address.address(), address.port()));
m_event_loop->register_fd(ctx);
if (m_config.on_connection_established_cb != nullptr) {
m_config.on_connection_established_cb(address);
}
return;
}
case FI_MR_COMPLETE:
case FI_AV_COMPLETE:
case FI_JOIN_COMPLETE:
// Not implemented
break;
case -FI_EAVAIL:
switch (event.err_entry.err) {
case FI_ECONNREFUSED:
ERS_DEBUG(1, "Connection refused (FI_ECONNREFUSED). Deallocating send_socket resources");
m_event_loop->remove_fd(
map.apply(address, [](SocketType& ssocket) { return ssocket.get_endpoint().eqfd; }));
map.erase(address);
if (m_config.on_connection_refused_cb != nullptr) {
m_config.on_connection_refused_cb(address);
}
break;
case FI_ETIMEDOUT:
ers::warning(
LibFabricCmError(ERS_HERE, "fi_verbs_process_send_socket_cm_event: FI_ETIMEDOUT"));
// if (socket->eqfd < 0 ){
// log_info("Ignoring FI_SHUTDOWN on send_socket, invalid eqfd (socket already closed)");
// break;
// }
// // Need to take care of receive socket as well
// if (socket->recv_socket != NULL){
// if (socket->recv_socket->eqfd < 0 ){
// log_info("Ignoring FI_ETIMEDOUT on recv_socket, invalid eqfd (socket already
// closed)"); return;
// }
// log_info("Shutting down receive socket on FI_ETIMEDOUT");
// handle_recv_socket_shutdown(socket->recv_socket);
// if(socket->recv_socket->lsocket->cb_connection_closed) {
// socket->recv_socket->lsocket->cb_connection_closed(socket->recv_socket);
// }
// }
// if(socket->cqfd < 0){ //cq not initialized yet
// handle_send_socket_shutdown_on_connetion_refused(socket);
// } else {
// if(socket->cb_internal_connection_closed){
// socket->cb_internal_connection_closed(socket);
// }
// if(socket->cb_connection_closed) {
// socket->cb_connection_closed(socket);
// }
// }
break;
default:
ers::error(LibFabricError(
ERS_HERE,
std::format(
"Unhandled error in the Event Queue: {} (code: {}), provider specific: {} (code: {})",
fi_strerror(event.err_entry.err),
event.err_entry.err,
fi_eq_strerror(
map.apply(address, [](SocketType& ssocket) { return ssocket.get_endpoint().eq.get(); }),
event.err_entry.prov_errno,
event.err_entry.err_data,
nullptr,
0),
event.err_entry.prov_errno)));
}
return;
case -FI_EAGAIN: {
auto* fp =
map.apply(address, [](SocketType& ssocket) { return &ssocket.get_endpoint().eq->fid; });
fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
break;
}
default:
throw LibFabricCmError(
ERS_HERE, std::format("Unexpected event {} in send socket Event Queue", event.event));
}
}
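// Accepts a connection request by constructing a ReceiveSocket keyed on the listen address and
// registering its event-queue fd with the event loop.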
void netio3::libfabric::ConnectionManager::handle_connection_request(ListenSocket& lsocket,
FiInfoUniquePtr&& info)
{
ZoneScoped;
// need to spawn new endpoint
ERS_DEBUG(1, "Received connection request");
std::lock_guard lock(m_rsocket_mutex);
const auto it = m_rsockets.emplace(
std::piecewise_construct,
std::forward_as_tuple(lsocket.get_address()),
std::forward_as_tuple(
lsocket, m_domain_manager->get_fabric(), m_domain_manager->get_listen_domain(), std::move(info), m_event_loop));
auto& rsocket = it->second;
ERS_DEBUG(1,
std::format("Created and connected endpoint. Eqfd: {}", rsocket.get_endpoint().eqfd));
const auto eq_ev_ctx = EventContext{rsocket.get_endpoint().eqfd,
[this, &rsocket](int) { on_recv_socket_cm_event(rsocket); }};
rsocket.set_eq_ev_ctx(eq_ev_ctx);
m_event_loop->register_fd(eq_ev_ctx);
}
void netio3::libfabric::ConnectionManager::unregister_endpoint(const CqCmFds& fds)
{
m_event_loop->remove_fd(fds.cm_fd);
m_event_loop->remove_fd(fds.cq_fd);
}
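// Runs on the event-loop thread (triggered by m_close_signal) and drains the close-request queue.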
void netio3::libfabric::ConnectionManager::handle_close_requests()
{
ZoneScoped;
CloseQueueItem item;
while (m_close_queue.try_pop(item)) {
switch (item.type) {
case CloseQueueItem::Type::listen:
do_close_listen_endpoint(item.address);
break;
case CloseQueueItem::Type::send:
do_close_send_endpoint(item.address);
break;
}
}
}
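// Unregisters the listen socket and every receive socket accepted through it, then erases both.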
void netio3::libfabric::ConnectionManager::do_close_listen_endpoint(const EndPointAddress& address)
{
ZoneScoped;
if (not m_lsockets.contains(address)) {
return;
}
// Listen socket has no CQ
m_event_loop->remove_fd(m_lsockets.at(address).get_cq_cm_fds().cm_fd);
std::lock_guard lock(m_rsocket_mutex);
const auto [begin, end] = m_rsockets.equal_range(address);
for (auto it = begin; it != end; ++it) {
unregister_endpoint(it->second.get_cq_cm_fds());
}
m_lsockets.erase(address);
m_rsockets.erase(address);
}
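// Unregisters and erases the send socket (buffered or zero-copy) and reports any pending sends
// through the closed-connection callback.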
void netio3::libfabric::ConnectionManager::do_close_send_endpoint(const EndPointAddress& address)
{
ZoneScoped;
std::vector<std::uint64_t> pending_sends{};
if (m_ssockets_buffered.contains(address)) {
const auto fds =
m_ssockets_buffered.apply(address, [](auto& socket) { return socket.get_cq_cm_fds(); });
unregister_endpoint(fds);
pending_sends = m_ssockets_buffered.apply(address, [](auto& socket) {
return socket.get_pending_sends();
});
m_ssockets_buffered.erase(address);
} else if (m_ssockets_zero_copy.contains(address)) {
const auto fds =
m_ssockets_zero_copy.apply(address, [](auto& socket) { return socket.get_cq_cm_fds(); });
unregister_endpoint(fds);
pending_sends = m_ssockets_zero_copy.apply(address, [](SendSocketZeroCopy& socket) {
return socket.get_pending_sends();
});
m_ssockets_zero_copy.erase(address);
}
if (m_config.on_connection_closed_cb != nullptr) {
m_config.on_connection_closed_cb(address, pending_sends);
}
}
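// Removes every fd managed by this instance from the event loop; called during destruction.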
void netio3::libfabric::ConnectionManager::unregister_all()
{
ZoneScoped;
for (const auto& [ep, rsocket] : m_rsockets) {
unregister_endpoint(rsocket.get_cq_cm_fds());
}
for (const auto& [ep, lsocket] : m_lsockets) {
unregister_endpoint(lsocket.get_cq_cm_fds());
}
m_ssockets_buffered.apply_all([this] (const auto& ssocket) {
unregister_endpoint(ssocket.get_cq_cm_fds());
});
m_ssockets_zero_copy.apply_all([this] (const auto& ssocket) {
unregister_endpoint(ssocket.get_cq_cm_fds());
});
}
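// Explicit instantiations for the two send-socket types handled by on_send_socket_cm_event.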
template void netio3::libfabric::ConnectionManager::on_send_socket_cm_event<
netio3::libfabric::SendSocketBuffered>(const EndPointAddress&);
template void netio3::libfabric::ConnectionManager::on_send_socket_cm_event<
netio3::libfabric::SendSocketZeroCopy>(const EndPointAddress&);