Program Listing for File ConnectionlessEndpointManager.cpp

Return to documentation for file (BackendLibfabric/ConnectionlessEndpointManager.cpp)

#include "ConnectionlessEndpointManager.hpp"

#include <mutex>

#include <rdma/fabric.h>

#include "Helpers.hpp"
#include "netio3-backend/Netio3Backend.hpp"

/// Construct the manager and all of its sub-components.
///
/// @param config  Network configuration; taken by value and moved into m_config.
/// @param address Address forwarded to the domain manager.
/// @param evloop  Event loop pointer stored in m_event_loop (not owned here).
/// @param flags   Flags forwarded to the domain manager.
netio3::libfabric::ConnectionlessEndpointManager::ConnectionlessEndpointManager(
  NetworkConfig config,
  const EndPointAddress& address,
  BaseEventLoop* evloop,
  std::uint64_t flags) :
  m_config{std::move(config)},
  m_event_loop{evloop},
  m_domain_manager{m_config, address, flags},
  m_cq_reactor{m_domain_manager.get_fabric(), m_config.callbacks.on_data_cb},
  m_shared_buffer_manager{m_config, m_domain_manager, m_event_loop},
  // `config` has been moved from above, so the mode must be read from the
  // stored copy (m_config) rather than from the moved-from parameter.
  m_address_vector_manager{m_domain_manager.get_domain(), m_config.mode},
  m_send_endpoint_buffered_addresses{m_config.thread_safety == ThreadSafetyModel::SAFE},
  m_send_endpoint_zero_copy_addresses{m_config.thread_safety == ThreadSafetyModel::SAFE}
{}

/// Factory function: builds a manager owned by a std::shared_ptr and runs init() on it.
///
/// @return The fully initialised manager.
std::shared_ptr<netio3::libfabric::ConnectionlessEndpointManager>
netio3::libfabric::ConnectionlessEndpointManager::create(const NetworkConfig& config,
                                                         EndPointAddress address,
                                                         BaseEventLoop* evloop,
                                                         std::uint64_t flags)
{
  // init() captures weak_from_this(), which is only non-empty once a
  // shared_ptr owns the object, hence the two-step construction.
  auto manager = std::make_shared<ConnectionlessEndpointManager>(config, address, evloop, flags);
  manager->init();
  return manager;
}

/// Open a send endpoint towards `address`.
///
/// The parameters are validated first (validate_capabilities may throw). A
/// buffered and/or zero-copy send endpoint is then opened depending on which
/// `use_shared_send_buffers` flags are set in `connection_params`.
void netio3::libfabric::ConnectionlessEndpointManager::open_send_endpoint(
  const EndPointAddress& address,
  const ConnectionParameters& connection_params)
{
  ZoneScoped;
  std::lock_guard lock(m_mutex);
  validate_capabilities(connection_params);

  const auto want_buffered = connection_params.send_buffered_params.use_shared_send_buffers;
  const auto want_zero_copy = connection_params.send_zero_copy_params.use_shared_send_buffers;

  if (want_buffered) {
    open_send_endpoint_buffered(address);
  }
  if (want_zero_copy) {
    open_send_endpoint_zero_copy(address);
  }
}

/// Open a receive endpoint on `address`.
///
/// @throws ActiveEndpointAlreadyExists if a receive endpoint for `address` is already registered.
/// Parameter validation (validate_capabilities) may throw as well.
void netio3::libfabric::ConnectionlessEndpointManager::open_receive_endpoint(
  const EndPointAddress& address,
  const ConnectionParameters& connection_params)
{
  ZoneScoped;
  std::lock_guard lock(m_mutex);

  const bool already_open = m_receive_endpoints.contains(address);
  if (already_open) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  validate_capabilities(connection_params);
  init_receive_endpoint(address, connection_params.recv_params);

  // Queue the open notification and fire the open signal; the queue is
  // drained by handle_open_requests().
  m_open_queue.push({address, {.receive = true}});
  m_open_signal->fire();
}

/// Request closing of the send endpoint towards `address`.
///
/// The request is queued and the close signal is fired; the actual teardown
/// happens in handle_close_requests().
///
/// @throws UnknownActiveEndpoint if `address` is neither a buffered nor a zero-copy send address.
void netio3::libfabric::ConnectionlessEndpointManager::close_send_endpoint(
  const EndPointAddress& address)
{
  ZoneScoped;
  std::lock_guard lock(m_mutex);

  const bool is_known_send_address = m_send_endpoint_buffered_addresses.contains(address) or
                                     m_send_endpoint_zero_copy_addresses.contains(address);
  if (not is_known_send_address) {
    throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
  }

  m_close_queue.push(address);
  m_close_signal->fire();
}

/// Request closing of the receive endpoint on `address`.
///
/// The request is queued and the close signal is fired; the actual teardown
/// happens in handle_close_requests().
///
/// @throws UnknownActiveEndpoint if no receive endpoint is registered for `address`.
void netio3::libfabric::ConnectionlessEndpointManager::close_receive_endpoint(
  const EndPointAddress& address)
{
  ZoneScoped;
  std::lock_guard lock(m_mutex);

  const bool is_open = m_receive_endpoints.contains(address);
  if (not is_open) {
    throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
  }

  m_close_queue.push(address);
  m_close_signal->fire();
}

/// Number of available buffers of the send endpoint serving `address`.
///
/// The buffered send endpoint is consulted first, then the zero-copy one.
///
/// @throws UnknownActiveEndpoint if `address` is not a registered send address.
std::size_t netio3::libfabric::ConnectionlessEndpointManager::get_num_available_buffers(
  const EndPointAddress& address)
{
  std::lock_guard lock(m_mutex);
  if (m_send_endpoint_buffered_addresses.contains(address)) {
    return m_send_endpoint_buffered->get_num_available_buffers();
  } else if (m_send_endpoint_zero_copy_addresses.contains(address)) {
    return m_send_endpoint_zero_copy->get_num_available_buffers();
  } else {
    throw UnknownActiveEndpoint(ERS_HERE, address.address(), address.port());
  }
}

void netio3::libfabric::ConnectionlessEndpointManager::init()
{
  m_close_signal = m_event_loop->create_signal(
    [weak_this = weak_from_this()](int) {
      if (auto shared_this = weak_this.lock()) {
        shared_this->handle_close_requests();
      }
    },
    false);
  m_open_signal = m_event_loop->create_signal(
    [weak_this = weak_from_this()](int) {
      if (auto shared_this = weak_this.lock()) {
        shared_this->handle_open_requests();
      }
    },
    false);
}

/// Register `address` as a destination of the buffered send endpoint.
///
/// The underlying buffered endpoint is created on first use. The address is
/// inserted into the address vector and an open notification is queued.
/// No locking is done here; the visible caller (open_send_endpoint) holds m_mutex.
///
/// @throws ActiveEndpointAlreadyExists if `address` is already registered for buffered sending.
void netio3::libfabric::ConnectionlessEndpointManager::open_send_endpoint_buffered(
  const EndPointAddress& address)
{
  ZoneScoped;
  const bool already_registered = m_send_endpoint_buffered_addresses.contains(address);
  if (already_registered) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  // Lazily create the buffered endpoint the first time an address is opened.
  const bool endpoint_missing = not m_active_endpoint_send_buffered;
  if (endpoint_missing) {
    init_buffered_send_endpoint(address);
  }

  const auto av_handle = m_address_vector_manager.add_address(address);
  m_send_endpoint_buffered_addresses.try_emplace(address, av_handle);

  m_open_queue.push({address, {.send_buffered = true}});
  m_open_signal->fire();
}

/// Register `address` as a destination of the zero-copy send endpoint.
///
/// The underlying zero-copy endpoint is created on first use. The address is
/// inserted into the address vector and an open notification is queued.
/// No locking is done here; the visible caller (open_send_endpoint) holds m_mutex.
///
/// @throws ActiveEndpointAlreadyExists if `address` is already registered for zero-copy sending.
void netio3::libfabric::ConnectionlessEndpointManager::open_send_endpoint_zero_copy(
  const EndPointAddress& address)
{
  ZoneScoped;
  const bool already_registered = m_send_endpoint_zero_copy_addresses.contains(address);
  if (already_registered) {
    throw ActiveEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  // Lazily create the zero-copy endpoint the first time an address is opened.
  const bool endpoint_missing = not m_active_endpoint_send_zero_copy;
  if (endpoint_missing) {
    init_zero_copy_send_endpoint(address);
  }

  const auto av_handle = m_address_vector_manager.add_address(address);
  m_send_endpoint_zero_copy_addresses.try_emplace(address, av_handle);

  m_open_queue.push({address, {.send_zero_copy = true}});
  m_open_signal->fire();
}

/// Create the buffered send endpoint and hook its completion queue into the event loop.
///
/// @param address Address used to construct the ActiveEndpoint; it is also the
///                address reported to on_send_completed_cb for every completion.
void netio3::libfabric::ConnectionlessEndpointManager::init_buffered_send_endpoint(
  const EndPointAddress& address)
{
  m_active_endpoint_send_buffered =
    std::make_unique<ActiveEndpoint>(address,
                                      m_config.mode,
                                      EndpointCapabilities{.send_buffered = true},
                                      m_domain_manager.get_fabric(),
                                      m_domain_manager.get_domain(),
                                      0,
                                      nullptr,
                                      m_address_vector_manager.get_av());

  m_send_endpoint_buffered =
    std::make_unique<SendEndpointBuffered>(*m_active_endpoint_send_buffered,
                                           m_config.conn_params.send_buffered_params,
                                           m_shared_buffer_manager.get_send_buffer_manager(),
                                           m_domain_manager);

  // Completion handler: lets the CQ reactor process the send CQ and forwards
  // every returned key to the user callback, if one is configured.
  auto on_cq_event = [this, address](int) {
    const auto completed_keys = m_cq_reactor.on_send_cq_event(*m_send_endpoint_buffered);
    if (m_config.callbacks.on_send_completed_cb == nullptr) {
      return;
    }
    for (const auto completed_key : completed_keys) {
      m_config.callbacks.on_send_completed_cb(address, completed_key);
    }
  };

  m_active_endpoint_send_buffered->set_cq_ev_ctx(
    {m_active_endpoint_send_buffered->get_endpoint().cqfd, std::move(on_cq_event)});
  m_event_loop->register_fd(m_active_endpoint_send_buffered->get_cq_ev_ctx());
}

/// Create the zero-copy send endpoint and hook its completion queue into the event loop.
///
/// @param address Address used to construct the ActiveEndpoint; it is also the
///                address reported to on_send_completed_cb for every completion.
void netio3::libfabric::ConnectionlessEndpointManager::init_zero_copy_send_endpoint(
  const EndPointAddress& address)
{
  m_active_endpoint_send_zero_copy =
    std::make_unique<ActiveEndpoint>(address,
                                      m_config.mode,
                                      EndpointCapabilities{.send_zero_copy = true},
                                      m_domain_manager.get_fabric(),
                                      m_domain_manager.get_domain(),
                                      0,
                                      nullptr,
                                      m_address_vector_manager.get_av());
  m_send_endpoint_zero_copy =
    std::make_unique<SendEndpointZeroCopy>(*m_active_endpoint_send_zero_copy,
                                           m_config.conn_params.send_zero_copy_params,
                                           m_shared_buffer_manager.get_zero_copy_buffer_manager(),
                                           m_domain_manager);
  // Completion handler: lets the CQ reactor process the send CQ and forwards
  // every returned key to the user callback, if one is configured.
  m_active_endpoint_send_zero_copy->set_cq_ev_ctx(
    {m_active_endpoint_send_zero_copy->get_endpoint().cqfd, [this, address](int) {
       const auto keys = m_cq_reactor.on_send_cq_event(*m_send_endpoint_zero_copy);
       if (m_config.callbacks.on_send_completed_cb != nullptr) {
         for (const auto key : keys) {
           m_config.callbacks.on_send_completed_cb(address, key);
         }
       }
     }});
  // Register the zero-copy endpoint's own CQ context. Registering the buffered
  // endpoint's context here would dereference a null pointer when no buffered
  // endpoint exists and would leave the zero-copy CQ unserviced.
  m_event_loop->register_fd(m_active_endpoint_send_zero_copy->get_cq_ev_ctx());
}

/// Create the receive endpoint for `address` and hook its receive completion
/// queue into the event loop.
///
/// @param address           Address the endpoint is created for; also the key in both endpoint maps.
/// @param connection_params Receive parameters forwarded to the receive endpoint.
void netio3::libfabric::ConnectionlessEndpointManager::init_receive_endpoint(
  const EndPointAddress& address,
  const ConnectionParametersRecv& connection_params)
{
  m_active_endpoint_receive.try_emplace(address,
                                        address,
                                        m_config.mode,
                                        EndpointCapabilities{.receive = true},
                                        m_domain_manager.get_fabric(),
                                        m_domain_manager.get_domain(),
                                        FI_SOURCE,
                                        nullptr,
                                        m_address_vector_manager.get_av());
  auto& active_endpoint = m_active_endpoint_receive.at(address);

  m_receive_endpoints.try_emplace(address,
                                  active_endpoint,
                                  connection_params,
                                  m_shared_buffer_manager.get_receive_context_manager(),
                                  m_domain_manager,
                                  m_event_loop);

  // The handler looks the receive endpoint up by address on every event.
  auto on_rcq_event = [this, address](int) {
    m_cq_reactor.on_recv_cq_event(m_receive_endpoints.at(address));
  };
  active_endpoint.set_rcq_ev_ctx({active_endpoint.get_endpoint().rcqfd, std::move(on_rcq_event)});
  m_event_loop->register_fd(active_endpoint.get_rcq_ev_ctx());
}

void netio3::libfabric::ConnectionlessEndpointManager::handle_close_requests()
{
  ZoneScoped;
  EndPointAddress address;
  while (m_close_queue.try_pop(address)) {
    do_close_endpoint(address);
  }
}

/// Tear down everything registered for `address` and notify the user.
///
/// Removes the receive endpoint (if any) and the buffered / zero-copy send
/// address registrations (if any), then invokes on_connection_closed_cb with
/// the pending sends of the zero-copy endpoint (empty if `address` was not a
/// zero-copy send address).
void netio3::libfabric::ConnectionlessEndpointManager::do_close_endpoint(
  const EndPointAddress& address)
{
  ZoneScoped;
  std::lock_guard lock(m_mutex);

  // Collect the pending sends before the zero-copy registration is erased below.
  std::vector<std::uint64_t> pending_sends;
  if (m_send_endpoint_zero_copy_addresses.contains(address)) {
    pending_sends = m_send_endpoint_zero_copy->get_pending_sends();
  }

  if (m_receive_endpoints.contains(address)) {
    auto& active_endpoint = m_active_endpoint_receive.at(address);
    m_event_loop->remove_fd(active_endpoint.get_rcq_ev_ctx().fd);
    m_receive_endpoints.erase(address);
    // NOTE(review): the matching entry in m_active_endpoint_receive is not
    // erased here — confirm whether keeping it is intentional.
  }

  if (m_send_endpoint_buffered_addresses.contains(address)) {
    m_address_vector_manager.remove_address(address);
    m_send_endpoint_buffered_addresses.erase(address);
  }

  if (m_send_endpoint_zero_copy_addresses.contains(address)) {
    m_address_vector_manager.remove_address(address);
    m_send_endpoint_zero_copy_addresses.erase(address);
  }

  // The user callback is invoked while m_mutex is still held.
  if (m_config.callbacks.on_connection_closed_cb == nullptr) {
    return;
  }
  m_config.callbacks.on_connection_closed_cb(address, pending_sends);
}

void netio3::libfabric::ConnectionlessEndpointManager::handle_open_requests()
{
  ZoneScoped;
  OpenQueueItem item;
  while (m_open_queue.try_pop(item)) {
    if (m_config.callbacks.on_connection_established_cb != nullptr) {
      if (item.capabilities.receive) {
        m_config.callbacks.on_connection_established_cb({}, item.address, item.capabilities);
      } else {
        m_config.callbacks.on_connection_established_cb(item.address, {}, item.capabilities);
      }
    }
  }
}

/// Check that `connection_params` can be served by the connectionless backend.
///
/// @throws InvalidConnectionParameters if
///   - shared receive buffers are requested,
///   - a zero-copy memory region is set without requesting shared send buffers,
///   - a number of buffered send buffers is set without requesting shared send buffers,
///   - buffered and zero-copy sending are requested together.
///
/// Emits an ers::warning (without throwing) when shared send buffers are
/// requested together with an explicit num_buf / mr_start.
void netio3::libfabric::ConnectionlessEndpointManager::validate_capabilities(
  const ConnectionParameters& connection_params) const
{
  const auto& recv = connection_params.recv_params;
  const auto& buffered = connection_params.send_buffered_params;
  const auto& zero_copy = connection_params.send_zero_copy_params;

  if (recv.use_shared_receive_buffers) {
    throw InvalidConnectionParameters(
      "Shared receive buffers are not supported in the connectionless libfabric backend");
  }
  if (not zero_copy.use_shared_send_buffers and zero_copy.mr_start != nullptr) {
    throw InvalidConnectionParameters(
      "Only shared zero-copy send buffers are supported in the connectionless libfabric backend");
  }
  if (not buffered.use_shared_send_buffers and (buffered.num_buf > 0)) {
    throw InvalidConnectionParameters(
      "Only shared buffered send buffers are supported in the connectionless libfabric backend");
  }
  if ((buffered.num_buf > 0 or buffered.use_shared_send_buffers) and
      (zero_copy.mr_start != nullptr or zero_copy.use_shared_send_buffers)) {
    throw InvalidConnectionParameters(
      "Libfabric does not support buffered and zero-copy sending on the same endpoint");
  }
  // No warning for "shared receive buffers with non-zero num_buf": shared
  // receive buffers are rejected by the first check, so that combination can
  // never reach this point.
  if (buffered.use_shared_send_buffers and buffered.num_buf > 0) {
    ers::warning(InvalidConnectionParameters(
      "Shared buffered send buffers requested, but the number of buffered send buffers is set to a "
      "non-zero value. Value ignored."));
  }
  if (zero_copy.use_shared_send_buffers and zero_copy.mr_start != nullptr) {
    ers::warning(InvalidConnectionParameters(
      "Shared zero-copy send buffers requested, but the zero-copy send memory region is set to a "
      "non-null value. Value ignored."));
  }
}