Program Listing for File BackendLibfabric.cpp

Return to documentation for file (BackendLibfabric/BackendLibfabric.cpp)

#include "BackendLibfabric.hpp"

#include <ers/ers.h>

#include "ConnectionManager.hpp"
#include "netio3-backend/Issues.hpp"

/**
 * @brief Construct the libfabric backend.
 *
 * The NetworkBackend base is initialised first with @p config and @p evloop; the
 * connection manager is then created from the event loop's raw pointer and m_config.
 *
 * @param config Network configuration forwarded to the NetworkBackend base.
 * @param evloop Event loop; a copy of the shared_ptr goes to the base, and only the raw
 *               pointer is handed to ConnectionManager::create.
 */
netio3::libfabric::BackendLibfabric::BackendLibfabric(const NetworkConfig& config,
                                                      std::shared_ptr<BaseEventLoop> evloop) :
  // evloop is passed as an lvalue (copied, not moved) so that evloop.get() below is still
  // valid when the member initialiser runs after the base.
  NetworkBackend(config, evloop),
  // m_config is not initialised here; presumably it is the base's stored copy of the
  // configuration (confirm in NetworkBackend) and outlives the caller's `config`.
  m_connection_manager(libfabric::ConnectionManager::create(evloop.get(), m_config))
{}

/**
 * @brief Open a send endpoint towards @p address.
 *
 * The socket flavour is chosen from @p connection_params: a null @c mr_start selects a
 * buffered send socket, anything else selects a zero-copy send socket.
 *
 * @param address           Remote address; must pass check_ip_address().
 * @param connection_params Parameters forwarded to the connection manager.
 * @throws InvalidEndpointAddress if check_ip_address() rejects @p address.
 */
void netio3::libfabric::BackendLibfabric::open_send_endpoint(
  const EndPointAddress& address,
  const ConnectionParameters& connection_params)
{
  if (!check_ip_address(address)) {
    throw InvalidEndpointAddress(ERS_HERE, address.address(), address.port());
  }

  const bool buffered = (connection_params.mr_start == nullptr);
  if (buffered) {
    m_connection_manager->open_send_endpoint_buffered(address, connection_params);
    ERS_DEBUG(1,
              std::format("Requested to open send endpoint for {}:{} with buffered socket",
                          address.address(),
                          address.port()));
    return;
  }

  m_connection_manager->open_send_endpoint_zero_copy(address, connection_params);
  ERS_DEBUG(1,
            std::format("Requested to open send endpoint for {}:{} with zero copy socket",
                        address.address(),
                        address.port()));
}

/**
 * @brief Open a listening endpoint on @p address.
 *
 * @param address           Local address to listen on; must pass check_ip_address().
 * @param connection_params Receive-side parameters forwarded to the connection manager.
 * @return The address returned by the connection manager (presumably the address that
 *         was actually bound — confirm in ConnectionManager).
 * @throws InvalidEndpointAddress if check_ip_address() rejects @p address.
 */
netio3::EndPointAddress netio3::libfabric::BackendLibfabric::open_listen_endpoint(
  const EndPointAddress& address,
  const ConnectionParametersRecv& connection_params)
{
  const bool address_ok = check_ip_address(address);
  if (!address_ok) {
    throw InvalidEndpointAddress(ERS_HERE, address.address(), address.port());
  }

  // Logged before the request is issued, unlike the open/close send paths.
  ERS_DEBUG(1,
            std::format("Requested to open listen endpoint for {}:{}",
                        address.address(),
                        address.port()));

  return m_connection_manager->open_listen_endpoint(address, connection_params);
}

/**
 * @brief Close the send endpoint associated with @p address.
 *
 * @param address Remote address of the endpoint; must pass check_ip_address().
 * @throws InvalidEndpointAddress if check_ip_address() rejects @p address.
 */
void netio3::libfabric::BackendLibfabric::close_send_endpoint(const EndPointAddress& address)
{
  const bool address_ok = check_ip_address(address);
  if (!address_ok) {
    throw InvalidEndpointAddress(ERS_HERE, address.address(), address.port());
  }

  m_connection_manager->close_send_endpoint(address);
  ERS_DEBUG(1,
            std::format("Requested to close send endpoint for {}:{}",
                        address.address(),
                        address.port()));
}

/**
 * @brief Close the listening endpoint associated with @p address.
 *
 * @param address Local address of the listening endpoint; must pass check_ip_address().
 * @throws InvalidEndpointAddress if check_ip_address() rejects @p address.
 */
void netio3::libfabric::BackendLibfabric::close_listen_endpoint(const EndPointAddress& address)
{
  const bool address_ok = check_ip_address(address);
  if (!address_ok) {
    throw InvalidEndpointAddress(ERS_HERE, address.address(), address.port());
  }

  m_connection_manager->close_listen_endpoint(address);
  ERS_DEBUG(1,
            std::format("Requested to close listen endpoint for {}:{}",
                        address.address(),
                        address.port()));
}

/**
 * @brief Send a contiguous payload through the zero-copy socket for @p address.
 *
 * The address is not validated here (unlike the open/close calls); the lookup and the
 * handling of a missing socket are delegated to apply_to_send_socket_zero_copy.
 *
 * @param address     Remote endpoint address.
 * @param data        Payload bytes.
 * @param header_data Header bytes forwarded alongside the payload.
 * @param key         Key forwarded to SendSocketZeroCopy::send_data.
 * @return Whatever SendSocketZeroCopy::send_data returns, as propagated by the
 *         connection manager.
 */
netio3::NetioStatus netio3::libfabric::BackendLibfabric::send_data(
  const EndPointAddress& address,
  const std::span<std::uint8_t> data,
  const std::span<const std::uint8_t> header_data,
  const std::uint64_t key)
{
  // Spans and the key are cheap to copy, so the callback captures them by value.
  return m_connection_manager->apply_to_send_socket_zero_copy(
    address, [data, header_data, key](libfabric::SendSocketZeroCopy& zc_socket) {
      return zc_socket.send_data(data, header_data, key);
    });
}

/**
 * @brief Send a scatter/gather payload through the zero-copy socket for @p address.
 *
 * The address is not validated here; the lookup and the handling of a missing socket are
 * delegated to apply_to_send_socket_zero_copy.
 *
 * @param address     Remote endpoint address.
 * @param iov         I/O vector describing the payload segments.
 * @param header_data Header bytes forwarded alongside the payload.
 * @param key         Key forwarded to SendSocketZeroCopy::send_data.
 * @return Whatever SendSocketZeroCopy::send_data returns, as propagated by the
 *         connection manager.
 */
netio3::NetioStatus netio3::libfabric::BackendLibfabric::send_data(
  const EndPointAddress& address,
  const std::span<const iovec> iov,
  const std::span<const std::uint8_t> header_data,
  const std::uint64_t key)
{
  // The span views and the key are trivially copyable; capture them by value.
  return m_connection_manager->apply_to_send_socket_zero_copy(
    address, [iov, header_data, key](libfabric::SendSocketZeroCopy& zc_socket) {
      return zc_socket.send_data(iov, header_data, key);
    });
}

/**
 * @brief Copying send (contiguous payload) — not available in this backend.
 *
 * All arguments are ignored.
 *
 * @throws NotSupported always.
 */
netio3::NetioStatus netio3::libfabric::BackendLibfabric::send_data_copy(
  const EndPointAddress& /* address */,
  const std::span<const std::uint8_t> /* data */,
  const std::span<const std::uint8_t> /* header_data */,
  const std::uint64_t /* key */)
{
  constexpr const char* reason = "send_data_copy is not supported in libfabric backend";
  throw NotSupported(ERS_HERE, reason);
}

/**
 * @brief Copying send (scatter/gather payload) — not available in this backend.
 *
 * All arguments are ignored.
 *
 * @throws NotSupported always.
 */
netio3::NetioStatus netio3::libfabric::BackendLibfabric::send_data_copy(
  const EndPointAddress& /* address */,
  const std::span<const iovec> /* iov */,
  const std::span<const std::uint8_t> /* header_data */,
  const std::uint64_t /* key */)
{
  constexpr const char* reason = "send_data_copy is not supported in libfabric backend";
  throw NotSupported(ERS_HERE, reason);
}

/**
 * @brief Obtain a buffer from the buffered send socket for @p address.
 *
 * @param address Remote endpoint address.
 * @return The pointer returned by SendSocketBuffered::get_buffer(), as propagated by
 *         apply_to_send_socket_buffered. Behaviour when no socket exists for @p address
 *         is defined by the connection manager (not visible here).
 */
netio3::NetworkBuffer* netio3::libfabric::BackendLibfabric::get_buffer(
  const EndPointAddress& address)
{
  return m_connection_manager->apply_to_send_socket_buffered(
    address,
    [](libfabric::SendSocketBuffered& buffered_socket) {
      return buffered_socket.get_buffer();
    });
}

/**
 * @brief Send a previously obtained buffer through the buffered socket for @p address.
 *
 * Inside the socket callback, @p buffer must be a libfabric::Buffer (checked with
 * dynamic_cast). Any other type, including a null pointer, is reported with
 * ers::error(InvalidBuffer) and yields NetioStatus::FAILED.
 *
 * @param address Remote endpoint address.
 * @param buffer  Buffer to send; its pos() is passed as the size argument
 *                (presumably the number of bytes written — confirm in NetworkBuffer).
 * @return The socket's send status, or NetioStatus::FAILED for a wrong buffer type.
 */
netio3::NetioStatus netio3::libfabric::BackendLibfabric::send_buffer(const EndPointAddress& address,
                                                                     NetworkBuffer* buffer)
{
  return m_connection_manager->apply_to_send_socket_buffered(
    address, [buffer](libfabric::SendSocketBuffered& buffered_socket) {
      if (auto* const lf_buffer = dynamic_cast<libfabric::Buffer*>(buffer);
          lf_buffer != nullptr) {
        return buffered_socket.send_buffer(lf_buffer, buffer->pos());
      }
      ers::error(InvalidBuffer(ERS_HERE, "libfabric::Buffer"));
      return NetioStatus::FAILED;
    });
}

/**
 * @brief Number of send buffers currently available for @p address.
 *
 * Pure delegation to the connection manager.
 *
 * @param address Remote endpoint address.
 * @return The count reported by ConnectionManager::get_num_available_buffers.
 */
std::size_t netio3::libfabric::BackendLibfabric::get_num_available_buffers(
  const EndPointAddress& address)
{
  auto available = m_connection_manager->get_num_available_buffers(address);
  return available;
}

/**
 * @brief Check that @p address holds a numeric IPv4 address.
 *
 * Only dotted-decimal IPv4 literals are accepted (parsed with inet_pton(AF_INET, ...)).
 * Host names and IPv6 literals are rejected. The port is not inspected.
 *
 * @param address Endpoint address to validate.
 * @return true if the host part parses as an IPv4 address, false otherwise.
 */
bool netio3::libfabric::BackendLibfabric::check_ip_address(const EndPointAddress& address)
{
  const auto& host = address.address();
  if (host.empty()) {
    return false;
  }
  // inet_pton returns 1 on success, 0 when the text is not a valid address for the
  // family, and -1 on error (e.g. unsupported family). Only 1 means "parsed", so compare
  // against 1 rather than testing for non-zero.
  in_addr parsed{};
  return inet_pton(AF_INET, host.c_str(), &parsed) == 1;
}