.. _program_listing_file_BackendLibfabric_ConnectionManager.cpp:

Program Listing for File ConnectionManager.cpp
==============================================

|exhale_lsh| :ref:`Return to documentation for file <file_BackendLibfabric_ConnectionManager.cpp>` (``BackendLibfabric/ConnectionManager.cpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #include "ConnectionManager.hpp"

   // NOTE(review): the angle-bracket #include targets of this listing were lost when it was
   // extracted. The standard headers below are the ones this file visibly uses; any non-standard
   // ones (libfabric, ERS, Tracy, ...) must be restored from BackendLibfabric/ConnectionManager.cpp.
   #include <algorithm>
   #include <atomic>
   #include <chrono>
   #include <condition_variable>
   #include <cstdint>
   #include <format>
   #include <memory>
   #include <mutex>
   #include <stdexcept>
   #include <thread>
   #include <type_traits>
   #include <utility>

   #include "Issues.hpp"
   #include "ReceiveSocket.hpp"
   #include "SendSocketZeroCopy.hpp"
   #include "netio3-backend/EventLoop/BaseEventLoop.hpp"
   #include "netio3-backend/Netio3Backend.hpp"
   #include "netio3-backend/Issues.hpp"

   // Stores the configuration and the event loop. The send-socket maps are constructed with a flag
   // that is true when the configuration requests ThreadSafetyModel::SAFE.
   netio3::libfabric::ConnectionManager::ConnectionManager(BaseEventLoop* event_loop,
                                                           NetworkConfig config) :
     m_config{std::move(config)},
     m_event_loop{event_loop},
     // NOTE(review): if m_config is declared before the maps, `config` has already been moved from
     // here. Reading `thread_safety` is still fine as long as it is a trivially copyable enum (a
     // move leaves it untouched); reading `config` instead of `m_config` keeps this correct for any
     // member declaration order.
     m_ssockets_buffered{config.thread_safety == ThreadSafetyModel::SAFE},
     m_ssockets_zero_copy{config.thread_safety == ThreadSafetyModel::SAFE}
   {}

   // Unregisters every socket fd from the event loop. When called from a thread other than the
   // event-loop thread while the loop is running, the work is posted to the loop thread through a
   // signal and the destructor blocks until it has run, re-checking every second whether the loop
   // is still running so it cannot wait forever.
   netio3::libfabric::ConnectionManager::~ConnectionManager()
   {
     if (m_event_loop->get_thread_id() == std::this_thread::get_id()) {
       unregister_all();
     } else if (m_event_loop->is_running()) {
       std::mutex mutex;
       std::condition_variable cv;
       std::atomic_bool unregistered{false};
       const auto signal = m_event_loop->create_signal(
         [this, &mutex, &unregistered, &cv](int) {
           unregister_all();
           {
             // Publish the flag under the waiter's mutex: otherwise the notification can land
             // between the waiter's predicate check and its wait and be lost until the timeout.
             const std::lock_guard guard(mutex);
             unregistered = true;
           }
           cv.notify_one();
         },
         false);
       signal.fire();
       constexpr static auto timeout = std::chrono::seconds(1);
       std::unique_lock lock(mutex);
       while (m_event_loop->is_running() and !unregistered.load()) {
         if (cv.wait_for(lock, timeout, [&unregistered] { return unregistered.load(); })) {
           break;
         }
       }
     }
   }

   // Factory: constructs the manager inside a shared_ptr and then calls init(), which needs
   // weak_from_this() and therefore cannot run from the constructor.
   std::shared_ptr<netio3::libfabric::ConnectionManager> netio3::libfabric::ConnectionManager::create(
     BaseEventLoop* event_loop,
     NetworkConfig config)
   {
     auto instance = std::make_shared<ConnectionManager>(event_loop, std::move(config));
     instance->init();
     return instance;
   }

   // Creates the signal that drains the close-request queue on the event-loop thread. The signal
   // is stored in this object, so its callback only holds a weak reference (no ownership cycle) and
   // does nothing once the manager is gone.
   void netio3::libfabric::ConnectionManager::init()
   {
     m_close_signal = m_event_loop->create_signal(
       [weak_this = weak_from_this()](int) {
         if (auto shared_this = weak_this.lock()) {
           shared_this->handle_close_requests();
         }
       },
       false);
   }

   // Opens a listening endpoint for `address` and registers its event-queue fd with the event loop.
   // Returns the ListenSocket's own address (lsocket.get_address()), which is also the map key;
   // presumably the resolved/bound address (TODO confirm).
   // Throws ListenEndpointAlreadyExists if the endpoint is already open, FailedOpenListenEndpoint if
   // the libfabric domain could not be initialised.
   netio3::EndPointAddress netio3::libfabric::ConnectionManager::open_listen_endpoint(
     const EndPointAddress& address,
     ConnectionParametersRecv conn_params)
   {
     if (m_lsockets.contains(address)) {
       throw ListenEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
     }
     try {
       check_and_init(address, FI_RECV);
     } catch (const LibfabricDomainError& e) {
       throw FailedOpenListenEndpoint(ERS_HERE, address.address(), address.port(), e.message());
     }
     auto lsocket_to_insert =
       ListenSocket(address, conn_params, m_config.mode, m_domain_manager->get_fabric());
     const auto actual_address = lsocket_to_insert.get_address();
     const auto [entry, inserted] =
       m_lsockets.try_emplace(actual_address, std::move(lsocket_to_insert));
     if (not inserted) {
       // The socket's address differs from the requested one and is already in use: do not
       // re-register (and overwrite the callback of) the existing listen socket.
       throw ListenEndpointAlreadyExists(
         ERS_HERE, actual_address.address(), actual_address.port());
     }
     auto& lsocket = entry->second;
     // The callback refers to the map element; its fd is removed before the element is erased.
     const auto eq_ev_ctx = EventContext{
       lsocket.get_endpoint().eqfd, [this, &lsocket](int) { on_listen_socket_cm_event(lsocket); }};
     lsocket.set_eq_ev_ctx(eq_ev_ctx);
     m_event_loop->register_fd(eq_ev_ctx);
     return lsocket.get_address();
   }

   // Creates a buffered send socket towards `address` and registers its event-queue fd; connection
   // progress is then reported through on_send_socket_cm_event<SendSocketBuffered>.
   // Throws SendEndpointAlreadyExists or FailedOpenSendEndpoint.
   void netio3::libfabric::ConnectionManager::open_send_endpoint_buffered(
     const EndPointAddress& address,
     ConnectionParameters conn_params)
   {
     if (m_ssockets_buffered.contains(address)) {
       throw SendEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
     }
     try {
       check_and_init(address, FI_SEND);
     } catch (const LibfabricDomainError& e) {
       throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), e.message());
     }
     m_ssockets_buffered.try_emplace(address,
                                     address,
                                     conn_params,
                                     m_config.mode,
                                     m_domain_manager->get_fabric(),
                                     m_domain_manager->get_send_domain());
     const auto ctx =
       m_ssockets_buffered.apply(address, [this, &address](SendSocketBuffered& ssocket) {
         const auto eq_ev_ctx = EventContext{ssocket.get_endpoint().eqfd, [this, address](int) {
                                               on_send_socket_cm_event<SendSocketBuffered>(address);
                                             }};
         ssocket.set_eq_ev_ctx(eq_ev_ctx);
         return eq_ev_ctx;
       });
     m_event_loop->register_fd(ctx);
   }

   // Zero-copy counterpart of open_send_endpoint_buffered; events are handled by
   // on_send_socket_cm_event<SendSocketZeroCopy>.
   // Throws SendEndpointAlreadyExists or FailedOpenSendEndpoint.
   void netio3::libfabric::ConnectionManager::open_send_endpoint_zero_copy(
     const EndPointAddress& address,
     ConnectionParameters conn_params)
   {
     if (m_ssockets_zero_copy.contains(address)) {
       throw SendEndpointAlreadyExists(ERS_HERE, address.address(), address.port());
     }
     try {
       check_and_init(address, FI_SEND);
     } catch (const LibfabricDomainError& e) {
       throw FailedOpenSendEndpoint(ERS_HERE, address.address(), address.port(), e.message());
     }
     m_ssockets_zero_copy.try_emplace(address,
                                      address,
                                      conn_params,
                                      m_config.mode,
                                      m_domain_manager->get_fabric(),
                                      m_domain_manager->get_send_domain());
     const auto ctx =
       m_ssockets_zero_copy.apply(address, [this, &address](SendSocketZeroCopy& ssocket) {
         const auto eq_ev_ctx = EventContext{ssocket.get_endpoint().eqfd, [this, address](int) {
                                               on_send_socket_cm_event<SendSocketZeroCopy>(address);
                                             }};
         ssocket.set_eq_ev_ctx(eq_ev_ctx);
         return eq_ev_ctx;
       });
     m_event_loop->register_fd(ctx);
   }

   // Queues closing of a listen endpoint; the actual work happens in do_close_listen_endpoint on
   // the event-loop thread (via m_close_signal). Throws UnknownListenEndpoint.
   void netio3::libfabric::ConnectionManager::close_listen_endpoint(const EndPointAddress& address)
   {
     if (not m_lsockets.contains(address)) {
       throw UnknownListenEndpoint(ERS_HERE, address.address(), address.port());
     }
     m_close_queue.push({address, CloseQueueItem::Type::listen});
     m_close_signal->fire();
   }

   // Queues closing of a (buffered or zero-copy) send endpoint; the actual work happens in
   // do_close_send_endpoint on the event-loop thread. Throws UnknownSendEndpoint.
   void netio3::libfabric::ConnectionManager::close_send_endpoint(const EndPointAddress& address)
   {
     if (not m_ssockets_buffered.contains(address) and not m_ssockets_zero_copy.contains(address)) {
       throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
     }
     m_close_queue.push({address, CloseQueueItem::Type::send});
     m_close_signal->fire();
   }

   // Returns the number of available buffers of the send socket for `address` (buffered map is
   // checked first). Throws UnknownSendEndpoint if no send socket exists for it.
   std::size_t netio3::libfabric::ConnectionManager::get_num_available_buffers(
     const EndPointAddress& address)
   {
     if (m_ssockets_buffered.contains(address)) {
       return m_ssockets_buffered.apply(
         address, [](SendSocketBuffered& ssocket) { return ssocket.get_num_available_buffers(); });
     }
     if (m_ssockets_zero_copy.contains(address)) {
       return m_ssockets_zero_copy.apply(
         address, [](SendSocketZeroCopy& ssocket) { return ssocket.get_num_available_buffers(); });
     }
     throw UnknownSendEndpoint(ERS_HERE, address.address(), address.port());
   }

   // Lazily creates the domain manager and the completion-queue reactor on first use. Errors from
   // the domain manager constructor (LibfabricDomainError, per the callers) propagate.
   void netio3::libfabric::ConnectionManager::check_and_init(const EndPointAddress& address,
                                                             const std::uint64_t flags)
   {
     // The template arguments were lost in this listing; decltype keeps the exact member types.
     if (m_domain_manager == nullptr) {
       m_domain_manager =
         std::make_unique<decltype(m_domain_manager)::element_type>(m_config.mode, address, flags);
     }
     if (m_cq_reactor == nullptr) {
       m_cq_reactor = std::make_unique<decltype(m_cq_reactor)::element_type>(
         m_domain_manager->get_fabric(), m_config.on_data_cb);
     }
   }

   // Reads one entry from event queue `eq` with fi_eq_sread (timeout 0).
   // Returns {event, err_entry, info}:
   //  - on success `event` is the CM event type and `info` owns entry.info (set for FI_CONNREQ);
   //  - otherwise `event` is the negative libfabric return code (-FI_EAGAIN, -FI_EAVAIL, ...);
   //    for -FI_EAVAIL `err_entry` holds the details when fi_eq_readerr succeeded.
   // Errors are logged via ers::error; this function does not throw on read failures.
   netio3::libfabric::ConnectionManager::CmEvent netio3::libfabric::ConnectionManager::read_cm_event(
     fid_eq* eq)
   {
     ZoneScoped;
     std::uint32_t event{};
     fi_eq_cm_entry entry{};
     fi_info* info{nullptr};
     fi_eq_err_entry err_entry{};
     const auto rd = fi_eq_sread(eq, &event, &entry, sizeof(entry), 0, 0);
     if (rd == -FI_EAGAIN) {
       return {rd, err_entry, FiInfoUniquePtr{info}};
     }
     if (rd == -FI_EAVAIL) {
       const auto r = fi_eq_readerr(eq, &err_entry, 0);
       if (r < 0) {
         // err_entry was not filled in: do not log its (zeroed) contents as if they were real.
         ers::error(LibFabricCmError(
           ERS_HERE, std::format("Failed to retrieve details on Event Queue error {}", r)));
       } else {
         ers::error(LibFabricCmError(
           ERS_HERE,
           std::format("Event Queue error: {} (code: {}), provider specific: {} (code: {})",
                       fi_strerror(err_entry.err),
                       err_entry.err,
                       fi_eq_strerror(eq, err_entry.prov_errno, err_entry.err_data, nullptr, 0),
                       err_entry.prov_errno)));
       }
       return {rd, err_entry, FiInfoUniquePtr{info}};
     }
     if (rd < 0) {
       // Any other negative value is an error code. Returning it (instead of falling through with
       // event == 0) lets the callers report the real failure.
       ers::error(LibFabricCmError(ERS_HERE, std::format("Failed to read from Event Queue: {}", rd)));
       return {rd, err_entry, FiInfoUniquePtr{info}};
     }
     if (static_cast<std::size_t>(rd) != sizeof(entry)) {
       ers::error(LibFabricCmError(ERS_HERE, std::format("Failed to read from Event Queue: {}", rd)));
     }
     return {event, err_entry, FiInfoUniquePtr{entry.info}};
   }

   // Handles an event on a listen socket's event queue: FI_CONNREQ spawns a receive socket,
   // -FI_EAGAIN re-arms the wait object with fi_trywait, -FI_EAVAIL is logged, unknown events are
   // warned about. Throws LibFabricCmError for FI_CONNECTED / FI_SHUTDOWN, which a listen socket
   // must never see.
   void netio3::libfabric::ConnectionManager::on_listen_socket_cm_event(ListenSocket& lsocket)
   {
     ZoneScoped;
     ERS_DEBUG(1, "listen socket: connection event");
     auto event = read_cm_event(lsocket.get_endpoint().eq.get());
     switch (event.event) {
     case FI_CONNREQ: {
       ERS_DEBUG(2, ": FI_CONNREQ");
       handle_connection_request(lsocket, std::move(event.info));
       break;
     }
     case FI_CONNECTED:
       throw LibFabricCmError(ERS_HERE, "FI_CONNECTED received on listen socket");
     case FI_SHUTDOWN:
       throw LibFabricCmError(ERS_HERE, "FI_SHUTDOWN received on listen socket");
     case -FI_EAGAIN: {
       auto* fp = &lsocket.get_endpoint().eq->fid;
       fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
       break;
     }
     case -FI_EAVAIL:
       ers::error(LibFabricCmError(
         ERS_HERE,
         std::format("Unhandled error in listen socket EQ code: {} provider specific code: {}",
                     event.err_entry.err,
                     event.err_entry.prov_errno)));
       break;
     default:
       ers::warning(LibFabricCmError(
         ERS_HERE, std::format("Unexpected event {} in listen socket Event Queue", event.event)));
       break;
     }
   }

   // Handles an event on a receive socket's event queue:
   //  - FI_CONNECTED: fetches the CQ wait fd, registers it with the event loop and calls
   //    on_connection_established_cb;
   //  - FI_SHUTDOWN: calls on_connection_closed_cb, unregisters the socket's fds and erases it from
   //    m_rsockets (`rsocket` dangles afterwards);
   //  - -FI_EAGAIN: re-arms the wait object; -FI_EAVAIL: logged; MR/AV/JOIN completions: ignored.
   // Throws LibFabricCmError for any other event.
   void netio3::libfabric::ConnectionManager::on_recv_socket_cm_event(ReceiveSocket& rsocket)
   {
     ZoneScoped;
     ERS_DEBUG(2, "Entered");
     const auto event = read_cm_event(rsocket.get_endpoint().eq.get());
     switch (event.event) {
     case FI_CONNECTED: {
       ERS_DEBUG(2, "FI_CONNECTED");
       if (const auto ret = fi_control(&rsocket.get_endpoint().rcq->fid,
                                       FI_GETWAIT,
                                       &rsocket.get_endpoint().cqfd)) {
         ers::error(LibFabricError(
           ERS_HERE,
           std::format(
             "Failed to retrieve recv socket Completion Queue wait object: error {} cqfd: {}",
             ret,
             rsocket.get_endpoint().cqfd)));
       }
       rsocket.set_cq_ev_ctx({rsocket.get_endpoint().cqfd, [this, &rsocket](int) {
                                m_cq_reactor->on_recv_socket_cq_event(rsocket);
                              }});
       ERS_DEBUG(1,
                 std::format("recv_socket: EQ fd {} connected, CQ fd {}",
                             rsocket.get_endpoint().eqfd,
                             rsocket.get_endpoint().cqfd));
       m_event_loop->register_fd(rsocket.get_cq_ev_ctx());
       ERS_DEBUG(1,
                 std::format("Adding recv CQ polled fid {} {}",
                             rsocket.get_endpoint().cqfd,
                             static_cast<const void*>(&rsocket.get_endpoint().cq->fid)));
       ERS_DEBUG(1,
                 std::format("recv_socket: EQ fd {} CQ fd {} connected",
                             rsocket.get_endpoint().eqfd,
                             rsocket.get_endpoint().cqfd));
       const auto address = peer_address(rsocket.get_endpoint().ep.get());
       ERS_INFO(std::format("Opened receive connection to {}:{}", address.address(), address.port()));
       if (m_config.on_connection_established_cb != nullptr) {
         m_config.on_connection_established_cb(address);
       }
       break;
     }
     case FI_SHUTDOWN: {
       ERS_DEBUG(2, "FI_SHUTDOWN");
       const auto address = peer_address(rsocket.get_endpoint().ep.get());
       ERS_INFO(std::format("Closed receive connection to {}:{} because received FI_SHUTDOWN",
                            address.address(),
                            address.port()));
       if (m_config.on_connection_closed_cb != nullptr) {
         m_config.on_connection_closed_cb(address, {});
       }
       std::lock_guard lock(m_rsocket_mutex);
       // m_rsockets is keyed by the listen address, so locate this socket by its EQ fd.
       const auto entry = std::ranges::find_if(m_rsockets, [&rsocket](const auto& pair) {
         return pair.second.get_endpoint().eqfd == rsocket.get_endpoint().eqfd;
       });
       if (entry == std::cend(m_rsockets)) {
         ers::error(LibFabricError(ERS_HERE, "Failed to find ReceiveSocket in ConnectionManager"));
       } else {
         unregister_endpoint(rsocket.get_cq_cm_fds());
         m_rsockets.erase(entry);
       }
       break;
     }
     case FI_MR_COMPLETE:
     case FI_AV_COMPLETE:
     case FI_JOIN_COMPLETE:
       // Not implemented
       break;
     case -FI_EAGAIN: {
       auto* fp = &rsocket.get_endpoint().eq->fid;
       fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
       break;
     }
     case -FI_EAVAIL:
       ers::error(LibFabricCmError(
         ERS_HERE,
         std::format("Unhandled error in recv socket EQ code: {} provider specific code: {}",
                     event.err_entry.err,
                     event.err_entry.prov_errno)));
       break;
     default:
       throw LibFabricCmError(
         ERS_HERE, std::format("Unexpected event {} in recv socket Event Queue", event.event));
     }
   }

   // Handles an event on the event queue of the send socket stored under `address` in the map that
   // matches SocketType (SendSocketBuffered or SendSocketZeroCopy):
   //  - FI_SHUTDOWN: unregisters and erases the socket, then calls on_connection_closed_cb with the
   //    sends that were still pending;
   //  - FI_CONNECTED: registers the CQ wait fd (completions are forwarded to on_send_completed_cb)
   //    and calls on_connection_established_cb;
   //  - -FI_EAVAIL with FI_ECONNREFUSED: removes the EQ fd, erases the socket and calls
   //    on_connection_refused_cb; FI_ETIMEDOUT only warns; other errors are logged;
   //  - -FI_EAGAIN: re-arms the wait object; MR/AV/JOIN completions: ignored.
   // Throws LibFabricCmError for any other event.
   template <typename SocketType>
   void netio3::libfabric::ConnectionManager::on_send_socket_cm_event(const EndPointAddress& address)
   {
     static_assert(std::is_same_v<SocketType, SendSocketZeroCopy> or
                     std::is_same_v<SocketType, SendSocketBuffered>,
                   "Unknown socket type");
     ZoneScoped;
     ERS_DEBUG(1, "send socket: connection event");
     auto& map = [this]() -> auto& {
       if constexpr (std::is_same_v<SocketType, SendSocketZeroCopy>) {
         return m_ssockets_zero_copy;
       } else {
         return m_ssockets_buffered;
       }
     }();
     const auto event = map.apply(
       address, [](SocketType& ssocket) { return read_cm_event(ssocket.get_endpoint().eq.get()); });
     switch (event.event) {
     case FI_SHUTDOWN: {
       ERS_INFO(std::format("Closed send connection to {}:{} because received FI_SHUTDOWN",
                            address.address(),
                            address.port()));
       const auto fds = map.apply(address, [](SocketType& ssocket) { return ssocket.get_cq_cm_fds(); });
       unregister_endpoint(fds);
       const auto pending_sends =
         map.apply(address, [](SocketType& ssocket) { return ssocket.get_pending_sends(); });
       map.erase(address);
       if (m_config.on_connection_closed_cb != nullptr) {
         m_config.on_connection_closed_cb(address, pending_sends);
       }
       return;
     }
     case FI_CONNECTED: {
       const auto ctx = map.apply(address, [this, &map, &address](SocketType& ssocket) {
         if (const auto ret = fi_control(&ssocket.get_endpoint().cq->fid,
                                         FI_GETWAIT,
                                         &ssocket.get_endpoint().cqfd)) {
           ers::error(LibFabricError(
             ERS_HERE,
             std::format(
               "Failed to retrieve wait object for send socket Completion Queue, error {} - {}",
               ret,
               fi_strerror(-ret))));
         }
         // This callback outlives the current call, so it holds a pointer to the member map and a
         // copy of the address instead of references to locals of this function.
         ssocket.set_cq_ev_ctx({ssocket.get_endpoint().cqfd, [this, sockets = &map, address](int) {
                                  try {
                                    const auto keys =
                                      sockets->apply(address, [this](SocketType& ssocket_cb) {
                                        return m_cq_reactor->on_send_socket_cq_event(ssocket_cb);
                                      });
                                    if (m_config.on_send_completed_cb != nullptr) {
                                      for (const auto key : keys) {
                                        m_config.on_send_completed_cb(address, key);
                                      }
                                    }
                                  } catch (const std::out_of_range&) {
                                    // Presumably thrown by apply() when the socket was already
                                    // closed and erased - nothing left to complete.
                                    return;
                                  }
                                }});
         ERS_DEBUG(1,
                   std::format("send_socket: EQ fd {} connected, CQ fd {}",
                               ssocket.get_endpoint().eqfd,
                               ssocket.get_endpoint().cqfd));
         return ssocket.get_cq_ev_ctx();
       });
       ERS_INFO(std::format("Opened send connection to {}:{}", address.address(), address.port()));
       m_event_loop->register_fd(ctx);
       if (m_config.on_connection_established_cb != nullptr) {
         m_config.on_connection_established_cb(address);
       }
       return;
     }
     case FI_MR_COMPLETE:
     case FI_AV_COMPLETE:
     case FI_JOIN_COMPLETE:
       // Not implemented
       break;
     case -FI_EAVAIL:
       switch (event.err_entry.err) {
       case FI_ECONNREFUSED:
         ERS_DEBUG(1, "Connection refused (FI_ECONNREFUSED). Deallocating send_socket resources");
         // The CQ fd is only registered on FI_CONNECTED, so only the EQ fd needs removing here.
         m_event_loop->remove_fd(
           map.apply(address, [](SocketType& ssocket) { return ssocket.get_endpoint().eqfd; }));
         map.erase(address);
         if (m_config.on_connection_refused_cb != nullptr) {
           m_config.on_connection_refused_cb(address);
         }
         break;
       case FI_ETIMEDOUT:
         ers::warning(
           LibFabricCmError(ERS_HERE, "fi_verbs_process_send_socket_cm_event: FI_ETIMEDOUT"));
         // TODO: the socket is not cleaned up on FI_ETIMEDOUT. The legacy implementation shut down
         // the associated receive socket (if any) and then either released the send socket like on
         // FI_ECONNREFUSED (when its CQ was never set up) or invoked the connection-closed
         // callbacks.
         break;
       default:
         ers::error(LibFabricError(
           ERS_HERE,
           std::format(
             "Unhandled error in the Event Queue: {} (code: {}), provider specific: {} (code: {})",
             fi_strerror(event.err_entry.err),
             event.err_entry.err,
             fi_eq_strerror(
               map.apply(address,
                         [](SocketType& ssocket) { return ssocket.get_endpoint().eq.get(); }),
               event.err_entry.prov_errno,
               event.err_entry.err_data,
               nullptr,
               0),
             event.err_entry.prov_errno)));
       }
       return;
     case -FI_EAGAIN: {
       auto* fp =
         map.apply(address, [](SocketType& ssocket) { return &ssocket.get_endpoint().eq->fid; });
       fi_trywait(m_domain_manager->get_fabric(), &fp, 1);
       break;
     }
     default:
       throw LibFabricCmError(
         ERS_HERE, std::format("Unexpected event {} in send socket Event Queue", event.event));
     }
   }

   // Accepts a connection request on `lsocket`: creates a ReceiveSocket (stored in m_rsockets under
   // the listen address, guarded by m_rsocket_mutex) and registers its event-queue fd.
   void netio3::libfabric::ConnectionManager::handle_connection_request(ListenSocket& lsocket,
                                                                        FiInfoUniquePtr&& info)
   {
     ZoneScoped;
     // need to spawn new endpoint
     ERS_DEBUG(1, "Received connection request");
     std::lock_guard lock(m_rsocket_mutex);
     const auto it = m_rsockets.emplace(std::piecewise_construct,
                                        std::forward_as_tuple(lsocket.get_address()),
                                        std::forward_as_tuple(lsocket,
                                                              m_domain_manager->get_fabric(),
                                                              m_domain_manager->get_listen_domain(),
                                                              std::move(info),
                                                              m_event_loop));
     auto& rsocket = it->second;
     ERS_DEBUG(1, std::format("Created and connected endpoint. Eqfd: {}", rsocket.get_endpoint().eqfd));
     // The callback refers to the map element; its fds are removed before the element is erased.
     const auto eq_ev_ctx = EventContext{rsocket.get_endpoint().eqfd,
                                         [this, &rsocket](int) { on_recv_socket_cm_event(rsocket); }};
     rsocket.set_eq_ev_ctx(eq_ev_ctx);
     m_event_loop->register_fd(eq_ev_ctx);
   }

   // Removes both the event-queue fd and the completion-queue fd of a socket from the event loop.
   void netio3::libfabric::ConnectionManager::unregister_endpoint(const CqCmFds& fds)
   {
     m_event_loop->remove_fd(fds.cm_fd);
     m_event_loop->remove_fd(fds.cq_fd);
   }

   // Drains the close-request queue (runs on the event-loop thread via m_close_signal).
   void netio3::libfabric::ConnectionManager::handle_close_requests()
   {
     ZoneScoped;
     CloseQueueItem item;
     while (m_close_queue.try_pop(item)) {
       switch (item.type) {
       case CloseQueueItem::Type::listen:
         do_close_listen_endpoint(item.address);
         break;
       case CloseQueueItem::Type::send:
         do_close_send_endpoint(item.address);
         break;
       }
     }
   }

   // Closes a listen endpoint and every receive socket accepted through it. A no-op if the
   // endpoint is already gone (e.g. the close was requested twice).
   void netio3::libfabric::ConnectionManager::do_close_listen_endpoint(const EndPointAddress& address)
   {
     ZoneScoped;
     const auto lsocket = m_lsockets.find(address);
     if (lsocket == m_lsockets.end()) {
       return;
     }
     // Listen socket has no CQ
     m_event_loop->remove_fd(lsocket->second.get_cq_cm_fds().cm_fd);
     std::lock_guard lock(m_rsocket_mutex);
     const auto [begin, end] = m_rsockets.equal_range(address);
     for (auto it = begin; it != end; ++it) {
       unregister_endpoint(it->second.get_cq_cm_fds());
     }
     // Receive sockets are constructed from the listen socket (see handle_connection_request);
     // release them first in case they still refer to it.
     m_rsockets.erase(address);
     m_lsockets.erase(lsocket);
   }

   // Closes the send socket for `address` (buffered map checked first) and reports the sends that
   // were still pending through on_connection_closed_cb. If no socket exists any more (already
   // closed by FI_SHUTDOWN - which has reported it - or a duplicate request), nothing is reported.
   void netio3::libfabric::ConnectionManager::do_close_send_endpoint(const EndPointAddress& address)
   {
     ZoneScoped;
     // Unregisters and destroys the socket stored under `address` in `sockets`, then notifies.
     const auto close_socket = [this, &address](auto& sockets) {
       const auto fds = sockets.apply(address, [](auto& socket) { return socket.get_cq_cm_fds(); });
       unregister_endpoint(fds);
       const auto pending_sends =
         sockets.apply(address, [](auto& socket) { return socket.get_pending_sends(); });
       sockets.erase(address);
       if (m_config.on_connection_closed_cb != nullptr) {
         m_config.on_connection_closed_cb(address, pending_sends);
       }
     };
     if (m_ssockets_buffered.contains(address)) {
       close_socket(m_ssockets_buffered);
     } else if (m_ssockets_zero_copy.contains(address)) {
       close_socket(m_ssockets_zero_copy);
     }
   }

   // Removes the fds of every socket from the event loop (used on destruction); the sockets
   // themselves are left in their containers.
   void netio3::libfabric::ConnectionManager::unregister_all()
   {
     ZoneScoped;
     {
       // Every other access to m_rsockets holds m_rsocket_mutex.
       std::lock_guard lock(m_rsocket_mutex);
       for (const auto& [ep, rsocket] : m_rsockets) {
         unregister_endpoint(rsocket.get_cq_cm_fds());
       }
     }
     for (const auto& [ep, lsocket] : m_lsockets) {
       // Listen socket has no CQ (same as in do_close_listen_endpoint)
       m_event_loop->remove_fd(lsocket.get_cq_cm_fds().cm_fd);
     }
     m_ssockets_buffered.apply_all(
       [this](const auto& ssocket) { unregister_endpoint(ssocket.get_cq_cm_fds()); });
     m_ssockets_zero_copy.apply_all(
       [this](const auto& ssocket) { unregister_endpoint(ssocket.get_cq_cm_fds()); });
   }

   template void netio3::libfabric::ConnectionManager::on_send_socket_cm_event<
     netio3::libfabric::SendSocketBuffered>(const EndPointAddress&);
   template void netio3::libfabric::ConnectionManager::on_send_socket_cm_event<
     netio3::libfabric::SendSocketZeroCopy>(const EndPointAddress&);