Program Listing for File SendSocketZeroCopy.cpp

Return to documentation for file (BackendLibfabric/SendSocketZeroCopy.cpp)

#include "SendSocketZeroCopy.hpp"

#include <algorithm>
#include <utility>

#include <tracy/Tracy.hpp>

#include <rdma/fi_domain.h>

#include "Issues.hpp"

// Construct a zero-copy send socket.
//
// Stores a copy of the connection parameters, creates the payload memory
// region (m_mr) and the header buffer on the given domain, registers both for
// FI_SEND via init_buffers(), and finally calls init().
//
// Throws FailedOpenSendEndpoint (from init_buffers) if buffer registration
// fails; any exception thrown by init() is propagated unchanged after the
// registered buffers have been closed.
netio3::libfabric::SendSocketZeroCopy::SendSocketZeroCopy(
  EndPointAddress address,
  const ConnectionParameters& connection_params,
  NetworkMode mode,
  fid_fabric* fabric,
  DomainContext& domain) :
  SendSocket{std::move(address), mode, fabric, domain.get_domain()},
  m_conn_parameters{connection_params},
  // Read sizes from the constructor argument, not from m_conn_parameters, so
  // this initializer is correct regardless of the member declaration order in
  // the header (members are initialised in declaration order, not list order).
  m_mr{domain, connection_params.buf_size, connection_params.mr_start},
  m_header_buffer{domain}
{
  init_buffers(domain);
  try {
    init();
  } catch (...) {
    // A constructor that throws never runs the destructor, so the buffers
    // registered by init_buffers() would otherwise never be closed.
    close_buffer(m_mr);
    close_buffer(m_header_buffer);
    throw;
  }
}

// Destroy the socket, closing the payload memory region and the header buffer
// that init_buffers() registered during construction.
netio3::libfabric::SendSocketZeroCopy::~SendSocketZeroCopy()
{
  ERS_DEBUG(2, "Entered");
  // Close explicitly here: the buffers are apparently not closed by their own
  // destructors. NOTE(review): confirm this, and confirm that close_buffer()
  // cannot throw — destructors are implicitly noexcept, so an escaping
  // exception would call std::terminate.
  close_buffer(m_mr);
  close_buffer(m_header_buffer);
  ERS_DEBUG(2, "Finished");
}

// Send a scatter-gather payload without copying it.
//
// A header slot is claimed from m_header_buffer (filled from header_data and
// key) and prepended to the caller's segments. Every payload segment is sent
// with the descriptor of m_mr; the header segment uses the header buffer's
// descriptor.
//
// Returns the status of claiming the header if that fails, otherwise the
// status of SendSocket::send_data. On a failed send the header slot is
// returned to m_header_buffer immediately.
netio3::NetioStatus netio3::libfabric::SendSocketZeroCopy::send_data(
  const std::span<const iovec> iov,
  const std::span<const std::uint8_t> header_data,
  const std::uint64_t key)
{
  ZoneScoped;
  ERS_DEBUG(2, std::format("Send zero copy data with key {}", key));

  // Claim a header slot first; propagate its status if that fails.
  const auto header = m_header_buffer.get_header(header_data, key);
  if (header.status != NetioStatus::OK) {
    return header.status;
  }

  // Scatter-gather list: element 0 is the header, then the caller's segments.
  std::vector<iovec> segments{};
  segments.reserve(1 + iov.size());
  segments.push_back(header.data);
  segments.insert(segments.end(), iov.begin(), iov.end());

  // One descriptor per segment: all payload segments share m_mr, only the
  // leading header segment uses the header buffer's region.
  std::vector<fid_mr*> descriptors(segments.size(), m_mr.mr);
  descriptors.front() = m_header_buffer.mr;

  const auto status = SendSocket::send_data(std::span{segments}, descriptors, header.bufnum);
  if (status == NetioStatus::OK) {
    return status;
  }

  // The send did not go out, so there is no buffer number to communicate:
  // hand the header slot straight back.
  std::ignore = m_header_buffer.return_header(header.bufnum);
  return status;
}

// Send a single contiguous buffer without copying it.
//
// Convenience overload: wraps `data` in one iovec and forwards to the
// scatter-gather overload, which prepends the header built from header_data
// and key. Returns that overload's status.
netio3::NetioStatus netio3::libfabric::SendSocketZeroCopy::send_data(
  const std::span<std::uint8_t> data,
  const std::span<const std::uint8_t> header_data,
  const std::uint64_t key)
{
  ZoneScoped;
  // The scatter-gather overload only needs a view of the segments, so a
  // stack-allocated iovec suffices; the previous temporary std::vector cost a
  // heap allocation on every send. `segment` outlives the forwarded call.
  const iovec segment{.iov_base = data.data(), .iov_len = data.size()};
  return send_data(std::span<const iovec>{&segment, 1}, header_data, key);
}

// Register the payload memory region (m_mr) and then the header buffer with
// the domain for FI_SEND operations. Called from the constructor.
//
// Throws FailedOpenSendEndpoint (carrying this socket's address, port and the
// underlying message) if either registration raises LibFabricBufferError.
// If the header registration fails after m_mr was registered, m_mr is closed
// first. The exception aborts construction, so the destructor that would
// normally close it never runs.
void netio3::libfabric::SendSocketZeroCopy::init_buffers(DomainContext& domain)
{
  ZoneScoped;
  try {
    ERS_DEBUG(1, std::format("Registering MR of size {}", m_conn_parameters.buf_size));
    register_buffer(m_mr, domain, FI_SEND);
  } catch (const LibFabricBufferError& e) {
    throw FailedOpenSendEndpoint(
      ERS_HERE, get_address().address(), get_address().port(), e.message());
  }

  try {
    ERS_DEBUG(1,
              std::format("Registering Header buffer of size {}",
                          ZERO_COPY_NUM_HEADER_SLOTS * ZERO_COPY_SIZE_HEADER));
    register_buffer(m_header_buffer, domain, FI_SEND);
  } catch (const LibFabricBufferError& e) {
    // m_mr is already registered; release it so a failed construction does
    // not leak it.
    close_buffer(m_mr);
    throw FailedOpenSendEndpoint(
      ERS_HERE, get_address().address(), get_address().port(), e.message());
  }
}

// Hand the header slot identified by `bufnum` back to m_header_buffer.
// Returns whatever m_header_buffer.return_header() reports, converted to
// std::uint64_t. NOTE(review): the meaning of that value is defined by the
// header buffer, not visible here.
std::uint64_t netio3::libfabric::SendSocketZeroCopy::release_buffer(const std::uint64_t bufnum)
{
  ERS_DEBUG(2, std::format("Releasing buffer {}", bufnum));
  const std::uint64_t released = m_header_buffer.return_header(bufnum);
  return released;
}

// Number of header slots m_header_buffer currently reports as available.
// Presumably this bounds how many more sends can be started before slots are
// released — TODO confirm against the header buffer implementation.
std::size_t netio3::libfabric::SendSocketZeroCopy::get_num_available_buffers()
{
  ZoneScoped;
  const std::size_t available = m_header_buffer.get_num_available_buffers();
  return available;
}

// Forward m_header_buffer's list of pending sends. Presumably these are the
// buffer numbers of header slots not yet returned — verify against the
// header buffer implementation.
std::vector<std::uint64_t> netio3::libfabric::SendSocketZeroCopy::get_pending_sends()
{
  ZoneScoped;
  auto pending = m_header_buffer.get_pending_sends();
  return pending;
}