Program Listing for File BackendAsyncmsg.cpp

Return to documentation for file (BackendAsyncmsg/BackendAsyncmsg.cpp)

#include "BackendAsyncmsg.hpp"

#include <tracy/Tracy.hpp>

#include "Issues.hpp"
#include "SendMessageBuffered.hpp"
#include "SendMessageUnbuffered.hpp"
#include "SendMessageUnbufferedCopy.hpp"
#include "netio3-backend/EventLoop/AsioEventLoop.hpp"
#include "netio3-backend/Issues.hpp"

// Creates the asyncmsg (TCP) backend on top of the given event loop.
//
// Members are initialised in their declaration order from the header, not in the order written
// here. NOTE(review): keep this list in sync with the header when touching it.
//  - NetworkBackend base: takes the config and the event loop; m_evloop (used below) is
//    presumably the base's copy of `evloop`.
//  - m_ioService: the AsioEventLoop's io_service if `evloop` is one, otherwise the backend's own
//    m_ioServiceResource (see initIoService()).
//  - m_eventSignal: event-loop signal whose callback consumes one queued ConnectionEvent via
//    handleQueuedEvent(). NOTE(review): meaning of the trailing `true` flag is not visible
//    here; confirm against BaseEventLoop::create_signal.
//  - m_useAsioEventLoop: whether `evloop` is an AsioEventLoop.
// start() then sets up the polling timer or the io_service thread for the configured mode
// (nothing for an AsioEventLoop).
netio3::asyncmsg::BackendAsyncmsg::BackendAsyncmsg(const NetworkConfig& config,
                                                   std::shared_ptr<BaseEventLoop> evloop) :
  NetworkBackend(config, std::move(evloop)),
  m_ioService{initIoService()},
  m_eventSignal{m_evloop->create_signal([this](int /*fd*/) { handleQueuedEvent(); }, true)},
  m_useAsioEventLoop{dynamic_cast<AsioEventLoop*>(m_evloop.get()) != nullptr}
{
  start();
}

netio3::asyncmsg::BackendAsyncmsg::~BackendAsyncmsg()
{
  ERS_DEBUG(1, "Stopping TCP backend");
  // Shut down every listening server, then stop driving the io_service.
  //
  // Iterate over a snapshot taken under m_mutex rather than over m_serversReceive itself. Every
  // Server is created with onServerShutdown() as its shutdown callback, and that callback locks
  // m_mutex and erases from m_serversReceive. If it runs (e.g. on the io_service thread) while
  // this loop is still walking the live map, the map is modified concurrently and the loop's
  // iterator is invalidated. The snapshot also keeps each Server alive until shutdown() returns.
  const auto servers = [this] {
    std::lock_guard lock(m_mutex);
    return m_serversReceive;
  }();
  for (const auto& entry : servers) {
    entry.second->shutdown();
  }
  stop();
}

void netio3::asyncmsg::BackendAsyncmsg::open_send_endpoint(
  const EndPointAddress& address,
  const ConnectionParameters& connection_params)
{
  ZoneScopedC(0x70dbdb);
  // m_mutex serialises all endpoint management (open/close/closeSession), so the existence check
  // and the insertion below cannot interleave with another such call.
  std::lock_guard lock(m_mutex);
  const auto remote = getEndpoint(address.address(), address.port());

  // At most one send session per remote endpoint; the accessor only lives for this check.
  if (auto existing = SenderMap::const_accessor{}; m_sessionsSend.find(existing, remote)) {
    throw SendEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }

  // Session callbacks: a closed session is dropped from m_sessionsSend, and connection events are
  // queued for the event loop.
  const auto onSessionClosed = [this](const std::shared_ptr<felix::asyncmsg::Session>& closed) {
    closeSession(closed);
  };
  const auto onConnectionEvent = [this](const ConnectionEvent& event) { addEvent(event); };
  const auto session = std::make_shared<Session>(m_ioService.get(),
                                                 m_config,
                                                 m_evloop.get(),
                                                 onSessionClosed,
                                                 onConnectionEvent,
                                                 m_useAsioEventLoop,
                                                 m_mode,
                                                 connection_params);

  // Register first, then start connecting asynchronously.
  m_sessionsSend.insert({remote, session});
  session->asyncOpen(std::string{MYSELF}, remote);
  ERS_DEBUG(
    1,
    std::format("Requested to open send connection on {}:{}", address.address(), address.port()));
}

netio3::EndPointAddress netio3::asyncmsg::BackendAsyncmsg::open_listen_endpoint(
  const EndPointAddress& address,
  const ConnectionParametersRecv& /*connection_params*/)
{
  // Starts a Server listening on `address` and returns the endpoint it is actually bound to
  // (which can differ from the request, e.g. when port 0 is requested).
  // Throws ListenEndpointAlreadyExists if `address` is already being listened on,
  // FailedOpenListenEndpoint if listen() fails, InvalidEndpointAddress for a malformed address.
  ZoneScopedC(0x70dbdb);
  std::lock_guard lock(m_mutex);
  const auto endpoint = getEndpoint(address.address(), address.port());
  if (m_serversReceive.contains(endpoint)) {
    throw ListenEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
  }
  const auto server = std::make_shared<Server>(
    m_ioService.get(),
    m_config,
    [this](const EndPointAddress& address_cb) { onServerShutdown(address_cb); },
    [this](const ConnectionEvent& event) { addEvent(event); },
    m_evloop.get(),
    m_useAsioEventLoop,
    m_mode);
  try {
    server->listen(std::string{MYSELF}, endpoint);
  } catch (const boost::system::system_error& e) {
    throw FailedOpenListenEndpoint(ERS_HERE, address.address(), address.port(), e.what(), e);
  }
  server->startAccept();

  // Query the bound endpoint once, so the map key, the log line and the returned address all
  // use the same value instead of five separate localEndpoint() calls.
  const auto bound = server->localEndpoint();
  const auto boundAddress = bound.address().to_string();
  m_serversReceive.emplace(bound, server);
  ERS_INFO(std::format("Listening on {}:{}", boundAddress, bound.port()));

  return {boundAddress, bound.port()};
}

void netio3::asyncmsg::BackendAsyncmsg::close_listen_endpoint(const EndPointAddress& address)
{
  ZoneScopedC(0x29a3a3);
  std::lock_guard lock(m_mutex);
  // Look the server up once. Only shutdown() is requested here; the map entry is removed by
  // onServerShutdown(), the shutdown callback the server was created with (presumably invoked
  // once the server has shut down).
  const auto entry = m_serversReceive.find(getEndpoint(address.address(), address.port()));
  if (entry == m_serversReceive.end()) {
    throw UnknownListenEndpoint(ERS_HERE, address.address(), address.port());
  }
  entry->second->shutdown();
  ERS_INFO(std::format("Stop listening on {}:{}", address.address(), address.port()));
}

void netio3::asyncmsg::BackendAsyncmsg::close_send_endpoint(const EndPointAddress& address)
{
  // Requests an asynchronous close of the send session for `address`.
  // Throws UnknownSendEndpoint if no such session exists.
  ZoneScopedC(0x29a3a3);
  std::lock_guard lock(m_mutex);
  const auto endpoint = getEndpoint(address.address(), address.port());
  auto ac = SenderMap::const_accessor{};
  if (not m_sessionsSend.find(ac, endpoint)) {
    throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
  }
  // Keep the session alive through our own reference and release the accessor (a read lock on
  // the map element) before calling into the session, as handleQueuedEvent() does. The close
  // path presumably ends in closeSession(), which erases this very element. Assuming SenderMap
  // is a tbb::concurrent_hash_map, erase waits until all accessors on the element are released.
  const auto session = ac->second;
  ac.release();
  session->asyncClose();
  ERS_DEBUG(
    1,
    std::format("Requested to close send connection on {}:{}", address.address(), address.port()));
}

boost::asio::io_service& netio3::asyncmsg::BackendAsyncmsg::initIoService()
{
  // Selects the io_service the backend uses: the one belonging to an AsioEventLoop when the
  // event loop is of that type, otherwise the backend's own m_ioServiceResource.
  if (auto* const asioLoop = dynamic_cast<AsioEventLoop*>(m_evloop.get()); asioLoop == nullptr) {
    return m_ioServiceResource;
  } else {
    return asioLoop->get_io_service();
  }
}

void netio3::asyncmsg::BackendAsyncmsg::onServerShutdown(const EndPointAddress& address)
{
  std::lock_guard lock(m_mutex);
  m_serversReceive.erase(getEndpoint(address.address(), address.port()));
}

// Hands a connection event (reported through the Session/Server event callbacks) over to the
// event loop. The event is queued BEFORE the signal is fired, so that handleQueuedEvent(),
// which pops exactly one event per invocation, finds it in the queue.
void netio3::asyncmsg::BackendAsyncmsg::addEvent(const ConnectionEvent& event)
{
  m_eventQueue.push(event);
  m_eventSignal.fire();
}

void netio3::asyncmsg::BackendAsyncmsg::handleQueuedEvent()
{
  ConnectionEvent event{};
  const auto dequeued = m_eventQueue.try_pop(event);
  if (not dequeued) {
    ers::error(TcpFailedDequeueMessage(ERS_HERE));
    return;
  }
  switch (event.type) {
  case ConnectionEvent::Type::OPENED:
    if (m_config.on_connection_established_cb) {
      m_config.on_connection_established_cb(event.address);
    }
    break;
  case ConnectionEvent::Type::REFUSED: {
    if (m_config.on_connection_refused_cb) {
      m_config.on_connection_refused_cb(event.address);
    }
    const auto endpoint = getEndpoint(event.address.address(), event.address.port());
    auto ac = SenderMap::const_accessor{};
    if (m_sessionsSend.find(ac, endpoint)) {
      const auto session = ac->second;
      ac.release();
      closeSession(session);
    }
    break;
  }
  case ConnectionEvent::Type::CLOSED: {
    const auto endpoint = getEndpoint(event.address.address(), event.address.port());
    auto ac = SenderMap::const_accessor{};
    if (m_sessionsSend.find(ac, endpoint)) {
      const auto session = ac->second;
      ac.release();
      if (m_config.on_connection_closed_cb) {
        m_config.on_connection_closed_cb(event.address, session->getPendingSends());
      }
      closeSession(session);
    } else if (m_config.on_connection_closed_cb) {  // receive sessions
        m_config.on_connection_closed_cb(event.address, {});
    }
    break;
  }
  }
}

void netio3::asyncmsg::BackendAsyncmsg::start()
{
  // With an AsioEventLoop the io_service comes from that loop (see initIoService()); presumably
  // the loop runs it, so nothing is started here.
  if (m_useAsioEventLoop) {
    return;
  }
  if (m_mode == Mode::POLL || m_mode == Mode::POLL_ONE) {
    // Polling modes: start the periodic timer (presumably what drives poll()).
    m_timer.start(TIMER_INTERVAL);
  } else if (m_mode == Mode::STANDALONE || m_mode == Mode::DELEGATE) {
    // Threaded modes: the work guard keeps run() from returning while there is nothing to do,
    // and the io_service runs on its own thread until stop().
    m_work = std::make_unique<boost::asio::io_service::work>(m_ioService);
    m_ioServiceThread = std::jthread{[this] { m_ioService.get().run(); }};
  }
}

void netio3::asyncmsg::BackendAsyncmsg::stop()
{
  // Undoes start(): nothing to do for an AsioEventLoop; otherwise tear down what start() set up
  // for the configured mode.
  if (m_useAsioEventLoop) {
    return;
  }
  if (m_mode == Mode::POLL || m_mode == Mode::POLL_ONE) {
    m_timer.stop();
  } else if (m_mode == Mode::STANDALONE || m_mode == Mode::DELEGATE) {
    // Drop the work guard and stop the io_service so run() returns, then wait for its thread.
    m_work.reset();
    m_ioService.get().stop();
    m_ioServiceThread.join();
  }
}

void netio3::asyncmsg::BackendAsyncmsg::closeSession(
  std::shared_ptr<felix::asyncmsg::Session> session)
{
  // Forget the send session registered under this session's remote endpoint.
  const std::lock_guard guard(m_mutex);
  const auto remote = session->cachedRemoteEndpoint();
  m_sessionsSend.erase(remote);
}

void netio3::asyncmsg::BackendAsyncmsg::poll()
{
  ERS_DEBUG(2, "Polling");
  switch (m_mode) {
  case Mode::POLL:
    if (m_ioService.get().stopped()) {
      m_ioService.get().restart();
    }
    m_ioService.get().poll();
    break;
  case Mode::POLL_ONE:
    for (auto i = 0; i < MAX_POLLS; ++i) {
      const auto numPolled = m_ioService.get().poll_one();
      if (numPolled == 0) {
        break;
      }
    }
    break;
  default:
    ers::warning(TcpLogicError(ERS_HERE, "Polling is not supported in this mode"));
    break;
  }
}

boost::asio::ip::tcp::endpoint netio3::asyncmsg::BackendAsyncmsg::getEndpoint(
  std::string_view address,
  unsigned short port)
{
  // Builds a TCP endpoint from a numeric IPv4/IPv6 address string and a port. make_address does
  // not resolve host names, so anything else is reported as InvalidEndpointAddress.
  auto parseError = boost::system::error_code{};
  const auto ipAddress = boost::asio::ip::make_address(address, parseError);
  if (not parseError) {
    return boost::asio::ip::tcp::endpoint{ipAddress, port};
  }
  throw InvalidEndpointAddress(ERS_HERE, std::string{address}, port);
}

netio3::NetioStatus netio3::asyncmsg::BackendAsyncmsg::send_data(
  const EndPointAddress& address,
  const std::span<std::uint8_t> data,
  const std::span<const std::uint8_t> header_data,
  const std::uint64_t key)
{
  // Hands `data`, `header_data` and `key` to the send session for `address` as a
  // SendMessageUnbuffered (presumably referencing the caller's memory; send_data_copy is the
  // copying variant). Throws UnknownSendEndpoint if no send endpoint is open for `address`.
  if (auto sender = SenderMap::const_accessor{};
      m_sessionsSend.find(sender, getEndpoint(address.address(), address.port()))) {
    sender->second->asyncSend(std::make_unique<SendMessageUnbuffered>(data, header_data, key));
    return NetioStatus::OK;
  }
  throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}

netio3::NetioStatus netio3::asyncmsg::BackendAsyncmsg::send_data(
  const EndPointAddress& address,
  const std::span<const iovec> iov,
  const std::span<const std::uint8_t> header_data,
  const std::uint64_t key)
{
  // Scatter/gather overload: hands the `iov` pieces plus `header_data` and `key` to the send
  // session for `address` as a SendMessageUnbuffered.
  // Throws UnknownSendEndpoint if no send endpoint is open for `address`.
  if (auto sender = SenderMap::const_accessor{};
      m_sessionsSend.find(sender, getEndpoint(address.address(), address.port()))) {
    sender->second->asyncSend(std::make_unique<SendMessageUnbuffered>(iov, header_data, key));
    return NetioStatus::OK;
  }
  throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}

netio3::NetioStatus netio3::asyncmsg::BackendAsyncmsg::send_data_copy(
  const EndPointAddress& address,
  const std::span<const std::uint8_t> data,
  const std::span<const std::uint8_t> header_data,
  const std::uint64_t key)
{
  // Like send_data(), but wraps the payload in a SendMessageUnbufferedCopy (presumably taking
  // its own copy of `data`/`header_data`).
  // Throws UnknownSendEndpoint if no send endpoint is open for `address`.
  if (auto sender = SenderMap::const_accessor{};
      m_sessionsSend.find(sender, getEndpoint(address.address(), address.port()))) {
    sender->second->asyncSend(std::make_unique<SendMessageUnbufferedCopy>(data, header_data, key));
    return NetioStatus::OK;
  }
  throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}

netio3::NetioStatus netio3::asyncmsg::BackendAsyncmsg::send_data_copy(
  const EndPointAddress& address,
  const std::span<const iovec> iov,
  const std::span<const std::uint8_t> header_data,
  const std::uint64_t key)
{
  // Scatter/gather overload of send_data_copy(): the `iov` pieces are wrapped in a
  // SendMessageUnbufferedCopy and handed to the send session for `address`.
  // Throws UnknownSendEndpoint if no send endpoint is open for `address`.
  if (auto sender = SenderMap::const_accessor{};
      m_sessionsSend.find(sender, getEndpoint(address.address(), address.port()))) {
    sender->second->asyncSend(std::make_unique<SendMessageUnbufferedCopy>(iov, header_data, key));
    return NetioStatus::OK;
  }
  throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}

netio3::NetworkBuffer* netio3::asyncmsg::BackendAsyncmsg::get_buffer(const EndPointAddress& address)
{
  // Returns whatever buffer pointer the send session for `address` hands out via getBuffer().
  // Throws UnknownSendEndpoint if no send endpoint is open for `address`.
  if (auto sender = SenderMap::const_accessor{};
      m_sessionsSend.find(sender, getEndpoint(address.address(), address.port()))) {
    return sender->second->getBuffer();
  }
  throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}

netio3::NetioStatus netio3::asyncmsg::BackendAsyncmsg::send_buffer(const EndPointAddress& address,
                                                                   NetworkBuffer* buffer)
{
  // Sends a buffer obtained from get_buffer(). The endpoint is validated first (throws
  // UnknownSendEndpoint). A buffer that is not this backend's Buffer type (including nullptr)
  // is reported via ers::error and yields NetioStatus::FAILED.
  if (auto sender = SenderMap::const_accessor{};
      m_sessionsSend.find(sender, getEndpoint(address.address(), address.port()))) {
    const auto* const ownBuffer = dynamic_cast<Buffer*>(buffer);
    if (ownBuffer == nullptr) {
      ers::error(InvalidBuffer(ERS_HERE, "Buffer"));
      return NetioStatus::FAILED;
    }
    sender->second->asyncSend(std::make_unique<SendMessageBuffered>(ownBuffer));
    return NetioStatus::OK;
  }
  throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
}

std::size_t netio3::asyncmsg::BackendAsyncmsg::get_num_available_buffers(
  const EndPointAddress& address)
{
  const auto endpoint = getEndpoint(address.address(), address.port());
  auto ac = SenderMap::const_accessor{};
  if (not m_sessionsSend.find(ac, endpoint)) {
    throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
  }
  return ac->second->getNumAvailableBuffers();
}