Program Listing for File ConnectionlessEndpointManager.cpp
↰ Return to documentation for file (BackendLibfabric/ConnectionlessEndpointManager.cpp)
#include "ConnectionlessEndpointManager.hpp"
#include <mutex>
#include <rdma/fabric.h>
#include "Helpers.hpp"
#include "netio3-backend/Netio3Backend.hpp"
/**
 * @brief Construct the connectionless endpoint manager and its helper objects.
 *
 * @param config  Network configuration; moved into m_config and read from there afterwards.
 * @param address Local address handed to the domain manager.
 * @param evloop  Event loop used for fd registration and signals (not owned).
 * @param flags   Extra flags forwarded to the domain manager.
 *
 * The body is empty; call init() (done by create()) before use, because the open/close signals
 * are created there.
 */
netio3::libfabric::ConnectionlessEndpointManager::ConnectionlessEndpointManager(
  NetworkConfig config,
  const EndPointAddress& address,
  BaseEventLoop* evloop,
  std::uint64_t flags) :
  m_config{std::move(config)},
  m_event_loop{evloop},
  m_domain_manager{m_config, address, flags},
  m_cq_reactor{m_domain_manager.get_fabric(), m_config.callbacks.on_data_cb},
  m_shared_buffer_manager{m_config, m_domain_manager, m_event_loop},
  // `config` was moved into m_config above, so read the mode from the member (the earlier
  // member initializers already rely on m_config being initialised first).
  m_address_vector_manager{m_domain_manager.get_domain(), m_config.mode},
  // The address sets receive `true` when the configured model is ThreadSafetyModel::SAFE.
  // NOTE(review): presumably this enables internal locking in the set type — confirm.
  m_send_endpoint_buffered_addresses{m_config.thread_safety == ThreadSafetyModel::SAFE},
  m_send_endpoint_zero_copy_addresses{m_config.thread_safety == ThreadSafetyModel::SAFE}
{}
std::shared_ptr<netio3::libfabric::ConnectionlessEndpointManager>
netio3::libfabric::ConnectionlessEndpointManager::create(const NetworkConfig& config,
EndPointAddress address,
BaseEventLoop* evloop,
std::uint64_t flags)
{
auto instance = std::make_shared<ConnectionlessEndpointManager>(config, address, evloop, flags);
instance->init();
return instance;
}
/**
 * @brief Register @p address as a send destination on the shared buffered and/or zero-copy
 *        send endpoint, depending on @p connection_params.
 *
 * Holds m_mutex for the whole operation. validate_capabilities() throws on unsupported
 * combinations (including requesting both buffered and zero-copy sending), so at most one of the
 * two branches below runs.
 */
void netio3::libfabric::ConnectionlessEndpointManager::open_send_endpoint(
  const EndPointAddress& address,
  const ConnectionParameters& connection_params)
{
  ZoneScoped;
  const std::lock_guard guard{m_mutex};
  validate_capabilities(connection_params);

  const bool wants_buffered = connection_params.send_buffered_params.use_shared_send_buffers;
  const bool wants_zero_copy = connection_params.send_zero_copy_params.use_shared_send_buffers;

  if (wants_buffered) {
    open_send_endpoint_buffered(address);
  }
  if (wants_zero_copy) {
    open_send_endpoint_zero_copy(address);
  }
}
/**
 * @brief Create a receive endpoint bound to @p address and queue an "established" notification.
 *
 * @throws ActiveEndpointAlreadyExists if a receive endpoint for @p address is already open.
 * @throws InvalidConnectionParameters (from validate_capabilities) for unsupported parameters.
 *
 * The user callback is not invoked here. The request is pushed onto m_open_queue and
 * m_open_signal is fired; handle_open_requests() delivers the callback.
 */
void netio3::libfabric::ConnectionlessEndpointManager::open_receive_endpoint(
  const EndPointAddress& address,
  const ConnectionParameters& connection_params)
{
  ZoneScoped;
  const std::lock_guard guard{m_mutex};

  const bool already_open = m_receive_endpoints.contains(address);
  if (already_open) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  validate_capabilities(connection_params);
  init_receive_endpoint(address, connection_params.recv_params);

  // Defer the user notification to the signal handler.
  m_open_queue.push({address, {.receive = true}});
  m_open_signal->fire();
}
/**
 * @brief Request closing of the send destination @p address.
 *
 * @throws UnknownActiveEndpoint if @p address is registered with neither the buffered nor the
 *         zero-copy send endpoint.
 *
 * Only enqueues the request and fires m_close_signal; the teardown itself happens in
 * do_close_endpoint().
 */
void netio3::libfabric::ConnectionlessEndpointManager::close_send_endpoint(
  const EndPointAddress& address)
{
  ZoneScoped;
  const std::lock_guard guard{m_mutex};

  const bool is_known = m_send_endpoint_buffered_addresses.contains(address) or
                        m_send_endpoint_zero_copy_addresses.contains(address);
  if (not is_known) {
    throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
  }

  m_close_queue.push(address);
  m_close_signal->fire();
}
/**
 * @brief Request closing of the receive endpoint bound to @p address.
 *
 * @throws UnknownActiveEndpoint if no receive endpoint exists for @p address.
 *
 * Only enqueues the request and fires m_close_signal; the teardown itself happens in
 * do_close_endpoint().
 */
void netio3::libfabric::ConnectionlessEndpointManager::close_receive_endpoint(
  const EndPointAddress& address)
{
  ZoneScoped;
  const std::lock_guard guard{m_mutex};

  const bool is_known = m_receive_endpoints.contains(address);
  if (not is_known) {
    throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
  }

  m_close_queue.push(address);
  m_close_signal->fire();
}
/**
 * @brief Number of available send buffers for the send endpoint that @p address is registered
 *        with.
 *
 * The buffered endpoint is checked first, then the zero-copy endpoint. Both are shared by every
 * destination registered with them, so the count is per endpoint, not per address.
 *
 * @throws UnknownActiveEndpoint if @p address is not a registered send destination.
 */
std::size_t netio3::libfabric::ConnectionlessEndpointManager::get_num_available_buffers(
  const EndPointAddress& address)
{
  const std::lock_guard guard{m_mutex};

  if (not m_send_endpoint_buffered_addresses.contains(address)) {
    if (m_send_endpoint_zero_copy_addresses.contains(address)) {
      return m_send_endpoint_zero_copy->get_num_available_buffers();
    }
    throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
  }

  return m_send_endpoint_buffered->get_num_available_buffers();
}
void netio3::libfabric::ConnectionlessEndpointManager::init()
{
m_close_signal = m_event_loop->create_signal(
[weak_this = weak_from_this()](int) {
if (auto shared_this = weak_this.lock()) {
shared_this->handle_close_requests();
}
},
false);
m_open_signal = m_event_loop->create_signal(
[weak_this = weak_from_this()](int) {
if (auto shared_this = weak_this.lock()) {
shared_this->handle_open_requests();
}
},
false);
}
/**
 * @brief Register @p address as a destination of the shared buffered send endpoint.
 *
 * The shared endpoint is created on the first call. Every call inserts @p address into the
 * address vector and records the returned handle. The caller must hold m_mutex
 * (open_send_endpoint() does).
 *
 * @throws ActiveEndpointAlreadyExists if @p address is already a buffered destination.
 */
void netio3::libfabric::ConnectionlessEndpointManager::open_send_endpoint_buffered(
  const EndPointAddress& address)
{
  ZoneScoped;
  if (m_send_endpoint_buffered_addresses.contains(address)) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  // Lazily create the single endpoint that all buffered destinations share.
  if (m_active_endpoint_send_buffered == nullptr) {
    init_buffered_send_endpoint(address);
  }

  const auto av_handle = m_address_vector_manager.add_address(address);
  m_send_endpoint_buffered_addresses.try_emplace(address, av_handle);

  m_open_queue.push({address, {.send_buffered = true}});
  m_open_signal->fire();
}
/**
 * @brief Register @p address as a destination of the shared zero-copy send endpoint.
 *
 * The shared endpoint is created on the first call. Every call inserts @p address into the
 * address vector and records the returned handle. The caller must hold m_mutex
 * (open_send_endpoint() does).
 *
 * @throws ActiveEndpointAlreadyExists if @p address is already a zero-copy destination.
 */
void netio3::libfabric::ConnectionlessEndpointManager::open_send_endpoint_zero_copy(
  const EndPointAddress& address)
{
  ZoneScoped;
  if (m_send_endpoint_zero_copy_addresses.contains(address)) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  // Lazily create the single endpoint that all zero-copy destinations share.
  if (m_active_endpoint_send_zero_copy == nullptr) {
    init_zero_copy_send_endpoint(address);
  }

  const auto av_handle = m_address_vector_manager.add_address(address);
  m_send_endpoint_zero_copy_addresses.try_emplace(address, av_handle);

  m_open_queue.push({address, {.send_zero_copy = true}});
  m_open_signal->fire();
}
/**
 * @brief Create the shared buffered send endpoint and register its send completion queue with
 *        the event loop.
 *
 * Called by open_send_endpoint_buffered() only while m_active_endpoint_send_buffered is null.
 * Every later buffered destination reuses the objects created here.
 *
 * @param address Destination whose open request triggered the creation. It is passed to the
 *                ActiveEndpoint constructor and captured by the completion handler.
 *
 * NOTE(review): the completion handler reports every completed key against the captured
 * `address`, even for sends to other destinations sharing this endpoint. Confirm whether
 * on_send_completed_cb consumers rely on the address being the real destination.
 */
void netio3::libfabric::ConnectionlessEndpointManager::init_buffered_send_endpoint(
const EndPointAddress& address)
{
// Endpoint with only the send_buffered capability, attached to the shared address vector.
// `0` sits in the argument position that init_receive_endpoint() fills with FI_SOURCE.
m_active_endpoint_send_buffered =
std::make_unique<ActiveEndpoint>(address,
m_config.mode,
EndpointCapabilities{.send_buffered = true},
m_domain_manager.get_fabric(),
m_domain_manager.get_domain(),
0,
nullptr,
m_address_vector_manager.get_av());
// Must be built after the ActiveEndpoint above, which it takes by reference.
m_send_endpoint_buffered =
std::make_unique<SendEndpointBuffered>(*m_active_endpoint_send_buffered,
m_config.conn_params.send_buffered_params,
m_shared_buffer_manager.get_send_buffer_manager(),
m_domain_manager);
// When the CQ fd becomes ready, let the CQ reactor process completions and forward each
// returned key to the user's send-completed callback (if one is set).
m_active_endpoint_send_buffered->set_cq_ev_ctx(
{m_active_endpoint_send_buffered->get_endpoint().cqfd, [this, address](int) {
const auto keys = m_cq_reactor.on_send_cq_event(*m_send_endpoint_buffered);
if (m_config.callbacks.on_send_completed_cb != nullptr) {
for (const auto key : keys) {
m_config.callbacks.on_send_completed_cb(address, key);
}
}
}});
// Register only after the context has been set above.
m_event_loop->register_fd(m_active_endpoint_send_buffered->get_cq_ev_ctx());
}
/**
 * @brief Create the shared zero-copy send endpoint and register its send completion queue with
 *        the event loop.
 *
 * Called by open_send_endpoint_zero_copy() only while m_active_endpoint_send_zero_copy is null.
 * Every later zero-copy destination reuses the objects created here.
 *
 * @param address Destination whose open request triggered the creation. It is passed to the
 *                ActiveEndpoint constructor and captured by the completion handler.
 *
 * NOTE(review): the completion handler reports every completed key against the captured
 * `address`, even for sends to other destinations sharing this endpoint. Confirm whether
 * on_send_completed_cb consumers rely on the address being the real destination.
 */
void netio3::libfabric::ConnectionlessEndpointManager::init_zero_copy_send_endpoint(
  const EndPointAddress& address)
{
  // Endpoint with only the send_zero_copy capability, attached to the shared address vector.
  m_active_endpoint_send_zero_copy =
    std::make_unique<ActiveEndpoint>(address,
                                     m_config.mode,
                                     EndpointCapabilities{.send_zero_copy = true},
                                     m_domain_manager.get_fabric(),
                                     m_domain_manager.get_domain(),
                                     0,
                                     nullptr,
                                     m_address_vector_manager.get_av());
  // Must be built after the ActiveEndpoint above, which it takes by reference.
  m_send_endpoint_zero_copy =
    std::make_unique<SendEndpointZeroCopy>(*m_active_endpoint_send_zero_copy,
                                           m_config.conn_params.send_zero_copy_params,
                                           m_shared_buffer_manager.get_zero_copy_buffer_manager(),
                                           m_domain_manager);
  m_active_endpoint_send_zero_copy->set_cq_ev_ctx(
    {m_active_endpoint_send_zero_copy->get_endpoint().cqfd, [this, address](int) {
       const auto keys = m_cq_reactor.on_send_cq_event(*m_send_endpoint_zero_copy);
       if (m_config.callbacks.on_send_completed_cb != nullptr) {
         for (const auto key : keys) {
           m_config.callbacks.on_send_completed_cb(address, key);
         }
       }
     }});
  // Register this endpoint's own CQ context. m_active_endpoint_send_buffered may be null here and
  // in any case owns a different completion queue.
  m_event_loop->register_fd(m_active_endpoint_send_zero_copy->get_cq_ev_ctx());
}
/**
 * @brief Create the ActiveEndpoint and ReceiveEndpoint for @p address and register the receive
 *        completion queue with the event loop.
 *
 * The ActiveEndpoint is created with FI_SOURCE and attached to the shared address vector.
 * Both maps are filled with try_emplace, so an already-present entry for @p address is reused
 * rather than replaced.
 *
 * @param address           Local address the receive endpoint is bound to.
 * @param connection_params Receive parameters forwarded to the ReceiveEndpoint.
 */
void netio3::libfabric::ConnectionlessEndpointManager::init_receive_endpoint(
  const EndPointAddress& address,
  const ConnectionParametersRecv& connection_params)
{
  auto& active_endpoint = m_active_endpoint_receive
                            .try_emplace(address,
                                         address,
                                         m_config.mode,
                                         EndpointCapabilities{.receive = true},
                                         m_domain_manager.get_fabric(),
                                         m_domain_manager.get_domain(),
                                         FI_SOURCE,
                                         nullptr,
                                         m_address_vector_manager.get_av())
                            .first->second;

  m_receive_endpoints.try_emplace(address,
                                  active_endpoint,
                                  connection_params,
                                  m_shared_buffer_manager.get_receive_context_manager(),
                                  m_domain_manager,
                                  m_event_loop);

  // The handler looks the ReceiveEndpoint up by address each time it fires.
  const auto rcq_fd = active_endpoint.get_endpoint().rcqfd;
  active_endpoint.set_rcq_ev_ctx(
    {rcq_fd, [this, address](int) { m_cq_reactor.on_recv_cq_event(m_receive_endpoints.at(address)); }});
  m_event_loop->register_fd(active_endpoint.get_rcq_ev_ctx());
}
void netio3::libfabric::ConnectionlessEndpointManager::handle_close_requests()
{
ZoneScoped;
EndPointAddress address;
while (m_close_queue.try_pop(address)) {
do_close_endpoint(address);
}
}
/**
 * @brief Tear down every registration associated with @p address and notify the user.
 *
 * Invoked from handle_close_requests() while holding m_mutex. The steps are:
 * - Remove the receive endpoint (fd, ReceiveEndpoint and its ActiveEndpoint), if any.
 * - Drop @p address from the buffered and zero-copy send destination sets, if present.
 * - Call on_connection_closed_cb (if set) with the zero-copy pending sends.
 */
void netio3::libfabric::ConnectionlessEndpointManager::do_close_endpoint(
  const EndPointAddress& address)
{
  ZoneScoped;
  std::lock_guard lock(m_mutex);
  // NOTE(review): the zero-copy send endpoint is shared by all destinations, so this presumably
  // returns the pending sends of every destination, not only @p address. Confirm this is intended.
  auto pending_sends = std::invoke([this, &address]() -> std::vector<std::uint64_t> {
    if (m_send_endpoint_zero_copy_addresses.contains(address)) {
      return m_send_endpoint_zero_copy->get_pending_sends();
    }
    return {};
  });
  if (m_receive_endpoints.contains(address)) {
    m_event_loop->remove_fd(m_active_endpoint_receive.at(address).get_rcq_ev_ctx().fd);
    // Destroy the ReceiveEndpoint before the ActiveEndpoint it was constructed with.
    m_receive_endpoints.erase(address);
    // Also release the ActiveEndpoint. Otherwise it lingers until the manager is destroyed, and a
    // later open for the same address silently reuses this stale entry (try_emplace keeps it).
    m_active_endpoint_receive.erase(address);
  }
  // NOTE(review): if @p address is in both send sets, remove_address() is called twice. Verify
  // that this matches the address vector manager's bookkeeping for the two add_address() calls.
  if (m_send_endpoint_buffered_addresses.contains(address)) {
    m_address_vector_manager.remove_address(address);
    m_send_endpoint_buffered_addresses.erase(address);
  }
  if (m_send_endpoint_zero_copy_addresses.contains(address)) {
    m_address_vector_manager.remove_address(address);
    m_send_endpoint_zero_copy_addresses.erase(address);
  }
  if (m_config.callbacks.on_connection_closed_cb != nullptr) {
    m_config.callbacks.on_connection_closed_cb(address, pending_sends);
  }
}
void netio3::libfabric::ConnectionlessEndpointManager::handle_open_requests()
{
ZoneScoped;
OpenQueueItem item;
while (m_open_queue.try_pop(item)) {
if (m_config.callbacks.on_connection_established_cb != nullptr) {
if (item.capabilities.receive) {
m_config.callbacks.on_connection_established_cb({}, item.address, item.capabilities);
} else {
m_config.callbacks.on_connection_established_cb(item.address, {}, item.capabilities);
}
}
}
}
/**
 * @brief Reject connection parameters the connectionless libfabric backend cannot honour, and
 *        warn about values that will be ignored.
 *
 * @throws InvalidConnectionParameters if:
 *  - shared receive buffers are requested;
 *  - a zero-copy memory region is given without shared zero-copy send buffers;
 *  - buffered send buffers are requested without shared buffered send buffers;
 *  - buffered and zero-copy sending are both requested.
 */
void netio3::libfabric::ConnectionlessEndpointManager::validate_capabilities(
  const ConnectionParameters& connection_params) const
{
  // --- Hard errors -------------------------------------------------------------------------
  if (connection_params.recv_params.use_shared_receive_buffers) {
    throw InvalidConnectionParameters(
      "Shared receive buffers are not supported in the connectionless libfabric backend");
  }
  if (not connection_params.send_zero_copy_params.use_shared_send_buffers and
      connection_params.send_zero_copy_params.mr_start != nullptr) {
    throw InvalidConnectionParameters(
      "Only shared zero-copy send buffers are supported in the connectionless libfabric backend");
  }
  if (not connection_params.send_buffered_params.use_shared_send_buffers and
      (connection_params.send_buffered_params.num_buf > 0)) {
    throw InvalidConnectionParameters(
      "Only shared buffered send buffers are supported in the connectionless libfabric backend");
  }
  if ((connection_params.send_buffered_params.num_buf > 0 or
       connection_params.send_buffered_params.use_shared_send_buffers) and
      (connection_params.send_zero_copy_params.mr_start != nullptr or
       connection_params.send_zero_copy_params.use_shared_send_buffers)) {
    throw InvalidConnectionParameters(
      "Libfabric does not support buffered and zero-copy sending on the same endpoint");
  }

  // --- Warnings for ignored values ---------------------------------------------------------
  // No warning is needed for shared receive buffers: they are rejected above, so a
  // "num_buf ignored" warning for them could never be reached.
  if (connection_params.send_buffered_params.use_shared_send_buffers and
      connection_params.send_buffered_params.num_buf > 0) {
    ers::warning(InvalidConnectionParameters(
      "Shared buffered send buffers requested, but the number of buffered send buffers is set to a "
      "non-zero value. Value ignored."));
  }
  if (connection_params.send_zero_copy_params.use_shared_send_buffers and
      connection_params.send_zero_copy_params.mr_start != nullptr) {
    ers::warning(InvalidConnectionParameters(
      "Shared zero-copy send buffers requested, but the zero-copy send memory region is set to a "
      "non-null value. Value ignored."));
  }
}