.. _program_listing_file_BackendLibfabric_ReceiveSocket.cpp:

Program Listing for File ReceiveSocket.cpp
==========================================

|exhale_lsh| :ref:`Return to documentation for file <file_BackendLibfabric_ReceiveSocket.cpp>` (``BackendLibfabric/ReceiveSocket.cpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #include "ReceiveSocket.hpp"

   // NOTE(review): the names of the four system headers included here were lost
   // when this listing was extracted (angle-bracket contents stripped). The
   // headers below are reconstructed from the names this file uses
   // (std::format, std::vector, fi_accept/fi_reject, ZoneScoped) -- confirm
   // against the real ReceiveSocket.cpp.
   #include <format>
   #include <vector>

   #include <rdma/fi_cm.h>
   #include <tracy/Tracy.hpp>

   #include "BaseSocket.hpp"
   #include "DomainManager.hpp"
   #include "Helpers.hpp"
   #include "Issues.hpp"
   #include "netio3-backend/Issues.hpp"

   /**
    * @brief Accept an incoming connection request and set up the receive side.
    *
    * Opens and enables the endpoint from @p info, clamps the requested connection
    * parameters, registers and posts the receive buffers, then calls fi_accept.
    * If fi_accept fails, the connection request is rejected on the listen
    * socket's passive endpoint, the buffers are released and an exception is
    * thrown.
    *
    * @param lsocket    Listen socket that received the connection request.
    * @param fabric     Fabric the endpoint is opened on.
    * @param domain     Domain context used for the endpoint and for buffer registration.
    * @param info       fi_info of the connection request; moved into the endpoint.
    * @param event_loop Event loop used to create the buffer re-post signal.
    * @throws FailedOpenReceiveEndpoint if endpoint setup, buffer registration or fi_accept fails.
    */
   netio3::libfabric::ReceiveSocket::ReceiveSocket(ListenSocket& lsocket,
                                                   fid_fabric* fabric,
                                                   DomainContext& domain,
                                                   FiInfoUniquePtr&& info,
                                                   BaseEventLoop* event_loop) :
     BaseSocket{create_endpoint(lsocket.get_address(), fabric, domain.get_domain(), std::move(info))},
     m_conn_params{prepare_connection_parameters(lsocket.get_connection_parameters())},
     m_address{peer_address(get_endpoint().ep.get())},
     m_retry_post_signal{event_loop->create_signal([this](int) { retry_post_buffers(); }, false)}
   {
     init_buffers(domain);
     ERS_DEBUG(1,
               std::format("connection accepted. Lsocket address: {}:{}, rsocket EQ {}",
                           lsocket.get_address().address(),
                           lsocket.get_address().port(),
                           get_eq_ev_ctx().fd));
     ERS_DEBUG(2, std::format("Allocating memory regions number of buffers: {}", m_buffers.size()));
     // Reserve before posting: post_buffers() may already queue buffers for a retry.
     m_retry_post_buffers.reserve(m_conn_params.num_buf);
     post_buffers();
     const int ret = fi_accept(get_endpoint().ep.get(), nullptr, 0);
     if (ret != 0) {
       // `info` was moved into the endpoint by create_endpoint() and is empty here,
       // so the connection-request handle must come from the endpoint's own fi_info.
       fi_reject(lsocket.get_pep(), get_endpoint().fi->handle, nullptr, 0);
       // ~ReceiveSocket does not run when the constructor throws: release buffers here.
       for (const auto& buf : m_buffers) {
         close_buffer(buf);
       }
       throw FailedOpenReceiveEndpoint(
         ERS_HERE,
         m_address.address(),
         m_address.port(),
         std::format("Listen socket, connection rejected, error {} - {}", ret, fi_strerror(-ret)));
     }
     ERS_DEBUG(2, std::format("Connection done, fd: {}", get_eq_ev_ctx().fd));
   }

   /**
    * @brief Release every receive buffer owned by this socket via close_buffer().
    */
   netio3::libfabric::ReceiveSocket::~ReceiveSocket()
   {
     ERS_DEBUG(2, "Entered");
     for (const auto& buf : m_buffers) {
       close_buffer(buf);
     }
   }

   /**
    * @brief Build an enabled receive endpoint for a connection request.
    *
    * Takes ownership of @p info (stored as Endpoint::fi), then opens the endpoint,
    * opens and binds its completion queues, and enables it.
    *
    * @param address Listen-socket address, forwarded to the helpers and used in error reports.
    * @param fabric  Fabric to open the endpoint on.
    * @param domain  Domain to open the endpoint and completion queues on.
    * @param info    fi_info of the connection request; moved into the returned endpoint.
    * @return The enabled endpoint.
    * @throws FailedOpenReceiveEndpoint if any step fails (FailedOpenEndpoint is translated).
    */
   netio3::libfabric::Endpoint netio3::libfabric::ReceiveSocket::create_endpoint(
     const EndPointAddress& address,
     fid_fabric* fabric,
     fid_domain* domain,
     FiInfoUniquePtr&& info)
   {
     ZoneScoped;
     auto ep = Endpoint{};
     ep.fi = std::move(info);
     try {
       open_endpoint(address, ep, fabric, domain, ep.fi.get());
       open_cq(address, ep, domain);
       enable_endpoint(address, ep);
     } catch (const FailedOpenEndpoint& e) {
       throw FailedOpenReceiveEndpoint(ERS_HERE, address.address(), address.port(), e.message());
     }
     return ep;
   }

   /**
    * @brief Open the two completion queues of a receive endpoint and bind them to it.
    *
    * The first CQ uses the attributes from prepare_cq_attr(), is bound with FI_RECV
    * and stored in ep.rcq. The second CQ uses FI_CQ_FORMAT_UNSPEC and FI_WAIT_NONE,
    * is bound with FI_TRANSMIT and stored in ep.cq.
    *
    * @throws FailedOpenReceiveEndpoint if opening or binding either CQ fails.
    */
   void netio3::libfabric::ReceiveSocket::open_cq(const EndPointAddress& address,
                                                  Endpoint& ep,
                                                  fid_domain* domain)
   {
     ZoneScoped;
     auto cq_attr = prepare_cq_attr();

     // FI_RECV CQ
     fid_cq* rcq = nullptr;
     if (const auto ret = fi_cq_open(domain, &cq_attr, &rcq, nullptr)) {
       throw FailedOpenReceiveEndpoint(
         ERS_HERE,
         address.address(),
         address.port(),
         std::format("Failed to open Completion Queue for receive socket, error {} - {}",
                     ret,
                     fi_strerror(-ret)));
     }
     ep.rcq = FiCloseUniquePtr(
       rcq, FiCloseDeleter(address, "Failed to close receive socket Completion Queue"));
     if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.rcq->fid, FI_RECV)) {
       throw FailedOpenReceiveEndpoint(
         ERS_HERE,
         address.address(),
         address.port(),
         std::format("Failed to bind Completion Queue for receive socket, error {} - {}",
                     ret,
                     fi_strerror(-ret)));
     }

     // FI_TRANSMIT CQ
     cq_attr.format = FI_CQ_FORMAT_UNSPEC;
     cq_attr.wait_obj = FI_WAIT_NONE;
     fid_cq* cq = nullptr;
     if (const auto ret = fi_cq_open(domain, &cq_attr, &cq, nullptr)) {
       throw FailedOpenReceiveEndpoint(
         ERS_HERE,
         address.address(),
         address.port(),
         std::format("Failed to open transmit Completion Queue for receive socket, error {} - {}",
                     ret,
                     fi_strerror(-ret)));
     }
     ep.cq = FiCloseUniquePtr(
       cq, FiCloseDeleter(address, "Failed to close FI_TRANSMIT receive socket Completion Queue"));
     if (const auto ret = fi_ep_bind(ep.ep.get(), &ep.cq->fid, FI_TRANSMIT)) {
       throw FailedOpenReceiveEndpoint(
         ERS_HERE,
         address.address(),
         address.port(),
         std::format("Failed to bind transmit Completion Queue for receive socket, error {} - {}",
                     ret,
                     fi_strerror(-ret)));
     }
   }

   /**
    * @brief Allocate and register m_conn_params.num_buf receive buffers.
    *
    * @throws FailedOpenReceiveEndpoint if registering a buffer fails. Buffers that
    *         were already registered are closed first and m_buffers is emptied, so
    *         nothing is leaked or closed twice.
    */
   void netio3::libfabric::ReceiveSocket::init_buffers(DomainContext& domain)
   {
     m_buffers.reserve(m_conn_params.num_buf);
     for (uint64_t i = 0; i < m_conn_params.num_buf; i++) {
       // Key does not matter for receive buffers
       m_buffers.emplace_back(domain, m_conn_params.buf_size, 0);
       try {
         register_buffer(m_buffers.back(), domain, FI_RECV);
       } catch (const LibFabricBufferError& e) {
         // Called from the constructor, whose failure skips ~ReceiveSocket:
         // release the buffers registered so far (the last one never was).
         m_buffers.pop_back();
         for (const auto& buf : m_buffers) {
           close_buffer(buf);
         }
         m_buffers.clear();
         throw FailedOpenReceiveEndpoint(ERS_HERE, m_address.address(), m_address.port(), e.message());
       }
     }
   }

   /**
    * @brief Post every buffer in m_buffers as a receive.
    */
   void netio3::libfabric::ReceiveSocket::post_buffers()
   {
     ZoneScoped;
     for (auto& buf : m_buffers) {
       post_buffer(&buf);
     }
   }

   /**
    * @brief Post one buffer as a receive.
    *
    * On -FI_EAGAIN the buffer is queued in m_retry_post_buffers and the retry signal
    * is fired. Any other error is reported with ers::error and the buffer is not
    * queued for a retry.
    */
   void netio3::libfabric::ReceiveSocket::post_buffer(Buffer* buf)
   {
     ZoneScoped;
     const auto ret = do_post_buffer(buf);
     if (ret == -FI_EAGAIN) {
       m_retry_post_buffers.push_back(buf);
       m_retry_post_signal.fire();
       return;
     }
     if (ret != 0) {
       ers::error(LibFabricBufferError(
         ERS_HERE,
         std::format("Failed to post a buffer to receive inbound messages, error {} - {}",
                     ret,
                     fi_strerror(-ret))));
     }
   }

   /**
    * @brief Retry posting the buffers queued by post_buffer().
    *
    * Buffers that still get -FI_EAGAIN stay queued and the signal is fired again;
    * other errors are reported with ers::error and the buffer is dropped from the queue.
    */
   void netio3::libfabric::ReceiveSocket::retry_post_buffers()
   {
     ZoneScoped;
     auto still_needs_retry = std::vector<Buffer*>{};
     still_needs_retry.reserve(m_retry_post_buffers.size());
     for (auto* buf : m_retry_post_buffers) {
       const auto ret = do_post_buffer(buf);
       if (ret == -FI_EAGAIN) {
         still_needs_retry.push_back(buf);
       } else if (ret != 0) [[unlikely]] {
         ers::error(LibFabricBufferError(
           ERS_HERE,
           std::format("Failed to post a buffer to receive inbound messages, error {} - {}",
                       ret,
                       fi_strerror(-ret))));
       }
     }
     if (not still_needs_retry.empty()) {
       m_retry_post_buffers.swap(still_needs_retry);
       m_retry_post_signal.fire();
     } else {
       m_retry_post_buffers.clear();
     }
   }

   /**
    * @brief Issue fi_recvmsg for a single buffer.
    *
    * The whole buffer is described by one iovec with its memory-region descriptor,
    * and the Buffer pointer is passed as the operation context.
    *
    * @return The fi_recvmsg return code (0 on success, negative libfabric error otherwise).
    */
   int netio3::libfabric::ReceiveSocket::do_post_buffer(Buffer* buf) const
   {
     ZoneScoped;
     ERS_DEBUG(2, "Entered");
     iovec iov{buf->data().data(), buf->size()};
     void* desc = fi_mr_desc(buf->mr);
     fi_msg msg{};
     msg.msg_iov = &iov; /* scatter-gather array */
     msg.desc = &desc;
     msg.iov_count = 1;
     msg.addr = 0;
     msg.context = buf;
     msg.data = 0;
     const std::uint64_t flags = FI_REMOTE_CQ_DATA;  // FI_MULTI_RECV;
     return fi_recvmsg(get_endpoint().ep.get(), &msg, flags);
   }

   /**
    * @brief Copy the requested receive parameters, clamping num_buf to the endpoint's
    *        receive-context size (fi_info::rx_attr->size).
    *
    * A TooManyBuffersRequested warning is emitted when the request is clamped.
    */
   netio3::ConnectionParametersRecv netio3::libfabric::ReceiveSocket::prepare_connection_parameters(
     const ConnectionParametersRecv& requested) const
   {
     auto params = requested;
     const auto& info = get_endpoint().fi;
     if (params.num_buf > info->rx_attr->size) {
       ers::warning(TooManyBuffersRequested(params.num_buf, info->rx_attr->size));
       params.num_buf = info->rx_attr->size;
     }
     return params;
   }