Program Listing for File AddressVectorManager.cpp

Return to documentation for file (BackendLibfabric/AddressVectorManager.cpp)

#include "AddressVectorManager.hpp"

#include <cstring>
#include <format>

#include <rdma/fabric.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_endpoint.h>

#include "Issues.hpp"

/**
 * @brief Open an address vector on @p domain and remember the network mode.
 *
 * @param domain Libfabric domain the address vector is opened on; create_av throws
 *               LibfabricAvError if it is null.
 * @param mode   Network mode later passed to get_fi_info when resolving endpoint addresses.
 */
netio3::libfabric::AddressVectorManager::AddressVectorManager(fid_domain* domain, NetworkMode mode) :
  m_av(create_av(domain)),
  m_mode(mode)
{
}

/**
 * @brief Insert @p address into the address vector, or return the handle it already has.
 *
 * Successfully inserted addresses are cached in m_address_map, so repeated calls for the
 * same endpoint return the cached handle without calling fi_av_insert again.
 *
 * @param address Remote endpoint (host and port) to register.
 * @return The fi_addr_t handle fi_av_insert produced for @p address.
 * @throws LibfabricAvError if the address vector is not open, no destination address could be
 *         resolved for @p address, or fi_av_insert fails.
 */
fi_addr_t netio3::libfabric::AddressVectorManager::add_address(const EndPointAddress& address)
{
  if (!m_av) {
    throw LibfabricAvError(ERS_HERE, "Address vector is not initialized");
  }

  // Fast path: a single lookup instead of contains() followed by at().
  if (const auto cached = m_address_map.find(address); cached != m_address_map.end()) {
    return cached->second;
  }

  const auto info = FiInfoUniquePtr{get_fi_info(address, m_mode, 0)};
  // Never hand a null address buffer to fi_av_insert.
  if (!info || info->dest_addr == nullptr) {
    throw LibfabricAvError(ERS_HERE,
                           std::format("No destination address resolved for {}:{}",
                                       address.address(),
                                       address.port()));
  }

  constexpr std::size_t count = 1;
  // Start from "no address" rather than 0, which is a valid AV index.
  fi_addr_t fi_addr{FI_ADDR_NOTAVAIL};
  // With FI_SYNC_ERR the context argument is an array of `count` ints receiving per-address errors.
  int error = 0;
  const int ret = fi_av_insert(m_av.get(), info->dest_addr, count, &fi_addr, FI_SYNC_ERR, &error);
  if (ret != static_cast<int>(count)) {
    // A negative return is a call-level failure, in which case the per-address code may not
    // have been written. Otherwise the per-address code explains the short count.
    const int code = ret < 0 ? ret : error;
    throw LibfabricAvError(ERS_HERE,
                           std::format("Failed to insert address {}:{} into AV, error {} - {}",
                                       address.address(),
                                       address.port(),
                                       code,
                                       fi_strerror(-code)));
  }

  // Reuse the iterator returned by try_emplace instead of looking the key up again.
  return m_address_map.try_emplace(address, fi_addr).first->second;
}

/**
 * @brief Remove @p address from the address vector and forget its cached handle.
 *
 * @param address Endpoint previously registered through add_address.
 * @throws LibfabricAvError if the address vector is not open, @p address was never added,
 *         or fi_av_remove reports an error.
 */
void netio3::libfabric::AddressVectorManager::remove_address(const EndPointAddress& address)
{
  if (!m_av) {
    throw LibfabricAvError(ERS_HERE, "Address vector is not initialized");
  }

  // Look up once and reuse the iterator for both the removal and the erase.
  const auto entry = m_address_map.find(address);
  if (entry == m_address_map.end()) {
    throw LibfabricAvError(ERS_HERE,
                           std::format("Address {}:{} is not registered in the AV",
                                       address.address(),
                                       address.port()));
  }

  // Unlike fi_av_insert, fi_av_remove does not return a count. It returns 0 (FI_SUCCESS) on
  // success and a negative fabric errno on failure.
  const int ret = fi_av_remove(m_av.get(), &entry->second, 1, 0);
  if (ret != 0) {
    throw LibfabricAvError(ERS_HERE,
                           std::format("Failed to remove address {}:{} from AV, error {} - {}",
                                       address.address(),
                                       address.port(),
                                       ret,
                                       fi_strerror(-ret)));
  }

  m_address_map.erase(entry);
}

/**
 * @brief Open a table-type (FI_AV_TABLE) address vector on @p domain.
 *
 * @param domain Libfabric domain to open the address vector on.
 * @return Raw handle from fi_av_open; the constructor stores it in m_av.
 * @throws LibfabricAvError if @p domain is null or fi_av_open returns a non-zero code.
 */
fid_av* netio3::libfabric::AddressVectorManager::create_av(fid_domain* domain)
{
  if (!domain) {
    throw LibfabricAvError(ERS_HERE, "Domain is null, cannot create address vector");
  }

  // Zero-initialise every attribute, then set only the AV type and the entry_count hint.
  fi_av_attr attributes{};
  attributes.type = FI_AV_TABLE;
  attributes.count = entry_count;

  fid_av* handle{nullptr};
  if (const int rc = fi_av_open(domain, &attributes, &handle, nullptr); rc != 0) {
    throw LibfabricAvError(
      ERS_HERE,
      std::format("Failed to create address vector, error {} - {}", rc, fi_strerror(-rc)));
  }
  return handle;
}