From a8cfbac07730816dbc68cd797332ea81227892d5 Mon Sep 17 00:00:00 2001 From: Pradip De Date: Tue, 26 Mar 2024 20:59:15 -0700 Subject: [PATCH] TCP connection setup/management and CASESession association. Add TCPConnect()/TCPDisconnect() API for explicit connection setup. Currently, connecting to a peer is coupled with sending a message to the peer. This decouples the two and creates a clear API for connecting to a peer address. Goes along with the existing Disconnect() API. This would be essential during activation of retained sessions by solely connecting to the peer and associating with the retained session. Surface Connection completion and Closure callbacks and hook them through SessionManager(TransportMgr delegate) and CASESession. Mark SecureSession as defunct on connection closures. Modify ActiveConnectionState in TCPBase to hold state for each connection, so that it is able to handle the various control flow paths. Associate a session with a connection object. Associate the PeerAddress with the session early. Pass the PeerAddress in the Find APIs. This helps check against the correct TransportType when searching for a Sesssion in the SessionTable. Add a `large payload` flag in EstablishSession() and Session lookup functions to create/associate with the correct session and transport. Have default configurations for TCP in a separate TCPConfig.h. Refactor echo_requester.cpp and echo_responder.cpp to use the session associated with the connection. Handle Connection closure at ExchangeMgr and uplevel to corresponding ExchangeContext using the corresponding session handle. Add tests around connection establishment in TestTCP. --- examples/shell/shell_common/include/Globals.h | 2 + src/app/CASESessionManager.cpp | 65 ++- src/app/CASESessionManager.h | 35 +- src/app/OperationalSessionSetup.cpp | 28 +- src/app/OperationalSessionSetup.h | 17 +- src/app/server/Server.cpp | 9 + src/app/server/Server.h | 10 + src/app/tests/TestCommissionManager.cpp | 4 + .../CHIPDeviceControllerFactory.cpp | 6 + .../CHIPDeviceControllerSystemState.h | 23 +- .../internal/GenericPlatformManagerImpl.ipp | 14 + src/messaging/ExchangeContext.cpp | 7 + src/messaging/ExchangeContext.h | 3 + src/messaging/ExchangeMgr.cpp | 16 + src/messaging/ExchangeMgr.h | 7 + src/messaging/tests/echo/common.cpp | 5 - src/messaging/tests/echo/echo_requester.cpp | 146 +++++- src/messaging/tests/echo/echo_responder.cpp | 2 + src/protocols/secure_channel/CASESession.cpp | 130 +++++- src/protocols/secure_channel/CASESession.h | 17 + .../secure_channel/PairingSession.cpp | 13 + .../UserDirectedCommissioning.h | 6 +- .../UserDirectedCommissioningClient.cpp | 4 +- .../UserDirectedCommissioningServer.cpp | 4 +- src/transport/BUILD.gn | 1 + src/transport/Session.h | 21 + src/transport/SessionConnectionDelegate.h | 46 ++ src/transport/SessionDelegate.h | 4 + src/transport/SessionManager.cpp | 289 ++++++++++-- src/transport/SessionManager.h | 81 +++- src/transport/TransportMgr.h | 23 +- src/transport/TransportMgrBase.cpp | 79 +++- src/transport/TransportMgrBase.h | 23 +- src/transport/UnauthenticatedSessionTable.h | 42 +- src/transport/raw/ActiveTCPConnectionState.h | 125 ++++++ src/transport/raw/BUILD.gn | 12 +- src/transport/raw/Base.h | 63 ++- src/transport/raw/TCP.cpp | 424 ++++++++++++------ src/transport/raw/TCP.h | 155 ++++--- src/transport/raw/TCPConfig.h | 127 ++++++ src/transport/raw/Tuple.h | 79 +++- src/transport/raw/tests/BUILD.gn | 7 +- src/transport/raw/tests/TestTCP.cpp | 253 ++++++++++- src/transport/raw/tests/TestUDP.cpp | 3 +- 44 files changed, 2062 insertions(+), 368 deletions(-) create mode 100644 src/transport/SessionConnectionDelegate.h create mode 100644 src/transport/raw/ActiveTCPConnectionState.h create mode 100644 src/transport/raw/TCPConfig.h diff --git a/examples/shell/shell_common/include/Globals.h b/examples/shell/shell_common/include/Globals.h index d91620ca0ad171..69162be92f5204 100644 --- a/examples/shell/shell_common/include/Globals.h +++ b/examples/shell/shell_common/include/Globals.h @@ -24,7 +24,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include #if INET_CONFIG_ENABLE_TCP_ENDPOINT diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp index c15ddcabe01c04..162ae7021a9f9f 100644 --- a/src/app/CASESessionManager.cpp +++ b/src/app/CASESessionManager.cpp @@ -30,62 +30,55 @@ CHIP_ERROR CASESessionManager::Init(chip::System::Layer * systemLayer, const CAS } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onFailure + Callback::Callback * onFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES -) + TransportPayloadCapability transportPayloadCapability) { - FindOrEstablishSessionHelper(peerId, onConnection, onFailure, nullptr + FindOrEstablishSessionHelper(peerId, onConnection, onFailure, nullptr, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - attemptCount, onRetry + attemptCount, onRetry, #endif - ); + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onSetupFailure + Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { - FindOrEstablishSessionHelper(peerId, onConnection, nullptr, onSetupFailure + FindOrEstablishSessionHelper(peerId, onConnection, nullptr, onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - attemptCount, onRetry + attemptCount, onRetry, #endif - ); + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - std::nullptr_t + std::nullptr_t, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { - FindOrEstablishSessionHelper(peerId, onConnection, nullptr, nullptr + FindOrEstablishSessionHelper(peerId, onConnection, nullptr, nullptr, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - attemptCount, onRetry + attemptCount, onRetry, #endif - ); + transportPayloadCapability); } void CASESessionManager::FindOrEstablishSessionHelper(const ScopedNodeId & peerId, Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure + Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif -) + TransportPayloadCapability transportPayloadCapability) { ChipLogDetail(CASESessionManager, "FindOrEstablishSession: PeerId = [%d:" ChipLogFormatX64 "]", peerId.GetFabricIndex(), ChipLogValueX64(peerId.GetNodeId())); @@ -124,12 +117,12 @@ void CASESessionManager::FindOrEstablishSessionHelper(const ScopedNodeId & peerI if (onFailure != nullptr) { - session->Connect(onConnection, onFailure); + session->Connect(onConnection, onFailure, transportPayloadCapability); } if (onSetupFailure != nullptr) { - session->Connect(onConnection, onSetupFailure); + session->Connect(onConnection, onSetupFailure, transportPayloadCapability); } } @@ -143,10 +136,11 @@ void CASESessionManager::ReleaseAllSessions() mConfig.sessionSetupPool->ReleaseAllSessionSetup(); } -CHIP_ERROR CASESessionManager::GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr) +CHIP_ERROR CASESessionManager::GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr, + TransportPayloadCapability transportPayloadCapability) { ReturnErrorOnFailure(mConfig.sessionInitParams.Validate()); - auto optionalSessionHandle = FindExistingSession(peerId); + auto optionalSessionHandle = FindExistingSession(peerId, transportPayloadCapability); ReturnErrorCodeIf(!optionalSessionHandle.HasValue(), CHIP_ERROR_NOT_CONNECTED); addr = optionalSessionHandle.Value()->AsSecureSession()->GetPeerAddress(); return CHIP_NO_ERROR; @@ -182,10 +176,11 @@ OperationalSessionSetup * CASESessionManager::FindExistingSessionSetup(const Sco return mConfig.sessionSetupPool->FindSessionSetup(peerId, forAddressUpdate); } -Optional CASESessionManager::FindExistingSession(const ScopedNodeId & peerId) const +Optional CASESessionManager::FindExistingSession(const ScopedNodeId & peerId, + const TransportPayloadCapability transportPayloadCapability) const { - return mConfig.sessionInitParams.sessionManager->FindSecureSessionForNode(peerId, - MakeOptional(Transport::SecureSession::Type::kCASE)); + return mConfig.sessionInitParams.sessionManager->FindSecureSessionForNode( + peerId, MakeOptional(Transport::SecureSession::Type::kCASE), transportPayloadCapability); } void CASESessionManager::ReleaseSession(OperationalSessionSetup * session) diff --git a/src/app/CASESessionManager.h b/src/app/CASESessionManager.h index e78478852b640e..38b39108b43b7e 100644 --- a/src/app/CASESessionManager.h +++ b/src/app/CASESessionManager.h @@ -26,6 +26,7 @@ #include #include #include +#include #include namespace chip { @@ -78,12 +79,11 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * setup is not successful. */ void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onFailure + Callback::Callback * onFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr + uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /** * Find an existing session for the given node ID or trigger a new session request. @@ -106,14 +106,14 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * @param onSetupFailure A callback to be called upon an extended device connection failure. * @param attemptCount The number of retry attempts if session setup fails (default is 1). * @param onRetry A callback to be called on a retry attempt (enabled by a config flag). + * @param transportPayloadCapability An indicator of what payload types the session needs to be able to transport. */ void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, - Callback::Callback * onSetupFailure + Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr + uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /** * Find an existing session for the given node ID or trigger a new session request. @@ -134,13 +134,13 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * @param onConnection A callback to be called upon successful connection establishment. * @param attemptCount The number of retry attempts if session setup fails (default is 1). * @param onRetry A callback to be called on a retry attempt (enabled by a config flag). + * @param transportPayloadCapability An indicator of what payload types the session needs to be able to transport. */ - void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, std::nullptr_t + void FindOrEstablishSession(const ScopedNodeId & peerId, Callback::Callback * onConnection, std::nullptr_t, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - , - uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr + uint8_t attemptCount = 1, Callback::Callback * onRetry = nullptr, #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - ); + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); void ReleaseSessionsForFabric(FabricIndex fabricIndex); @@ -154,7 +154,8 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess * an ongoing session with the peer node. If the session doesn't exist, the API will return * `CHIP_ERROR_NOT_CONNECTED` error. */ - CHIP_ERROR GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr); + CHIP_ERROR GetPeerAddress(const ScopedNodeId & peerId, Transport::PeerAddress & addr, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); //////////// OperationalSessionReleaseDelegate Implementation /////////////// void ReleaseSession(OperationalSessionSetup * device) override; @@ -165,15 +166,17 @@ class CASESessionManager : public OperationalSessionReleaseDelegate, public Sess private: OperationalSessionSetup * FindExistingSessionSetup(const ScopedNodeId & peerId, bool forAddressUpdate = false) const; - Optional FindExistingSession(const ScopedNodeId & peerId) const; + Optional FindExistingSession( + const ScopedNodeId & peerId, + const TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload) const; void FindOrEstablishSessionHelper(const ScopedNodeId & peerId, Callback::Callback * onConnection, Callback::Callback * onFailure, Callback::Callback * onSetupFailure, #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES - uint8_t attemptCount, Callback::Callback * onRetry + uint8_t attemptCount, Callback::Callback * onRetry, #endif - ); + TransportPayloadCapability transportPayloadCapability); CASESessionManagerConfig mConfig; }; diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp index 157de953116ef5..9d91b7b573a59f 100644 --- a/src/app/OperationalSessionSetup.cpp +++ b/src/app/OperationalSessionSetup.cpp @@ -76,8 +76,8 @@ bool OperationalSessionSetup::AttachToExistingSecureSession() mState == State::WaitingForRetry, false); - auto sessionHandle = - mInitParams.sessionManager->FindSecureSessionForNode(mPeerId, MakeOptional(Transport::SecureSession::Type::kCASE)); + auto sessionHandle = mInitParams.sessionManager->FindSecureSessionForNode( + mPeerId, MakeOptional(Transport::SecureSession::Type::kCASE), mTransportPayloadCapability); if (!sessionHandle.HasValue()) return false; @@ -93,11 +93,13 @@ bool OperationalSessionSetup::AttachToExistingSecureSession() void OperationalSessionSetup::Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure) + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability) { CHIP_ERROR err = CHIP_NO_ERROR; bool isConnected = false; + mTransportPayloadCapability = transportPayloadCapability; // // Always enqueue our user provided callbacks into our callback list. // If anything goes wrong below, we'll trigger failures (including any queued from @@ -180,15 +182,17 @@ void OperationalSessionSetup::Connect(Callback::Callback * on } void OperationalSessionSetup::Connect(Callback::Callback * onConnection, - Callback::Callback * onFailure) + Callback::Callback * onFailure, + TransportPayloadCapability transportPayloadCapability) { - Connect(onConnection, onFailure, nullptr); + Connect(onConnection, onFailure, nullptr, transportPayloadCapability); } void OperationalSessionSetup::Connect(Callback::Callback * onConnection, - Callback::Callback * onSetupFailure) + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability) { - Connect(onConnection, nullptr, onSetupFailure); + Connect(onConnection, nullptr, onSetupFailure, transportPayloadCapability); } void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & addr, const ReliableMessageProtocolConfig & config) @@ -288,6 +292,16 @@ void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & ad CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessageProtocolConfig & config) { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // TODO: Combine LargePayload flag with DNS-SD advertisements from peer. + // Issue #32348. + if (mTransportPayloadCapability == TransportPayloadCapability::kLargePayload) + { + // Set the transport type for carrying large payloads + mDeviceAddress.SetTransportType(chip::Transport::Type::kTcp); + } +#endif + mCASEClient = mClientPool->Allocate(); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h index 1de8c305353314..5955dbab0713bd 100644 --- a/src/app/OperationalSessionSetup.h +++ b/src/app/OperationalSessionSetup.h @@ -210,8 +210,12 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * `onFailure` may be called before the Connect call returns, for error * cases that are detected synchronously (e.g. inability to start an address * lookup). + * + * `transportPayloadCapability` is set to kLargePayload when the session needs to be established + * over a transport that allows large payloads to be transferred, e.g., TCP. */ - void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure); + void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); /* * This function can be called to establish a secure session with the device. @@ -227,8 +231,12 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * * `onSetupFailure` may be called before the Connect call returns, for error cases that are detected synchronously * (e.g. inability to start an address lookup). + * + * `transportPayloadCapability` is set to kLargePayload when the session needs to be established + * over a transport that allows large payloads to be transferred, e.g., TCP. */ - void Connect(Callback::Callback * onConnection, Callback::Callback * onSetupFailure); + void Connect(Callback::Callback * onConnection, Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); bool IsForAddressUpdate() const { return mPerformingAddressUpdate; } @@ -318,6 +326,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, System::Clock::Milliseconds16 mRequestedBusyDelay = System::Clock::kZero; #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES || CHIP_CONFIG_ENABLE_BUSY_HANDLING_FOR_OPERATIONAL_SESSION_SETUP + TransportPayloadCapability mTransportPayloadCapability = TransportPayloadCapability::kMRPPayload; + #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES // When we TryNextResult on the resolver, it will synchronously call back // into our OnNodeAddressResolved when it succeeds. We need to track @@ -351,7 +361,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, void CleanupCASEClient(); void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, - Callback::Callback * onSetupFailure); + Callback::Callback * onSetupFailure, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); void EnqueueConnectionCallbacks(Callback::Callback * onConnection, Callback::Callback * onFailure, diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index cd9c56f9f021f7..1ad98d271262da 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -71,6 +71,9 @@ using chip::Transport::BleListenParameters; #endif using chip::Transport::PeerAddress; using chip::Transport::UdpListenParameters; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +using chip::Transport::TcpListenParameters; +#endif namespace { @@ -201,6 +204,12 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) #if CONFIG_NETWORK_LAYER_BLE , BleListenParameters(DeviceLayer::ConnectivityMgr().GetBleLayer()) +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + TcpListenParameters(DeviceLayer::TCPEndPointManager()) + .SetAddressType(IPAddressType::kIPv6) + .SetListenPort(mOperationalServicePort) #endif ); diff --git a/src/app/server/Server.h b/src/app/server/Server.h index 8f6fcd5abecc66..d649e0fc923896 100644 --- a/src/app/server/Server.h +++ b/src/app/server/Server.h @@ -77,6 +77,12 @@ namespace chip { inline constexpr size_t kMaxBlePendingPackets = 1; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +inline constexpr size_t kMaxTcpActiveConnectionCount = CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS; + +inline constexpr size_t kMaxTcpPendingPackets = CHIP_CONFIG_MAX_TCP_PENDING_PACKETS; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // // NOTE: Please do not alter the order of template specialization here as the logic // in the Server impl depends on this. @@ -89,6 +95,10 @@ using ServerTransportMgr = chip::TransportMgr +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + chip::Transport::TCP #endif >; diff --git a/src/app/tests/TestCommissionManager.cpp b/src/app/tests/TestCommissionManager.cpp index 0076fd6a55718f..188f51aee6a609 100644 --- a/src/app/tests/TestCommissionManager.cpp +++ b/src/app/tests/TestCommissionManager.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -104,6 +105,9 @@ void InitializeChip(nlTestSuite * suite) static chip::SimpleTestEventTriggerDelegate sSimpleTestEventTriggerDelegate; initParams.testEventTriggerDelegate = &sSimpleTestEventTriggerDelegate; (void) initParams.InitializeStaticResourcesBeforeServerInit(); + // Set a randomized server port(slightly shifted from CHIP_PORT) for testing + initParams.operationalServicePort = static_cast(initParams.operationalServicePort + chip::Crypto::GetRandU16() % 20); + err = chip::Server::GetInstance().Init(initParams); NL_TEST_ASSERT(suite, err == CHIP_NO_ERROR); diff --git a/src/controller/CHIPDeviceControllerFactory.cpp b/src/controller/CHIPDeviceControllerFactory.cpp index 198d0c47c9f614..eb52e389a6b9d6 100644 --- a/src/controller/CHIPDeviceControllerFactory.cpp +++ b/src/controller/CHIPDeviceControllerFactory.cpp @@ -161,6 +161,12 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) #if CONFIG_NETWORK_LAYER_BLE , Transport::BleListenParameters(stateParams.bleLayer) +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + Transport::TcpListenParameters(stateParams.tcpEndPointManager) + .SetAddressType(IPAddressType::kIPv6) + .SetListenPort(params.listenPort) #endif )); diff --git a/src/controller/CHIPDeviceControllerSystemState.h b/src/controller/CHIPDeviceControllerSystemState.h index 389bb557f6c0cc..1ea0593d594993 100644 --- a/src/controller/CHIPDeviceControllerSystemState.h +++ b/src/controller/CHIPDeviceControllerSystemState.h @@ -57,16 +57,27 @@ namespace chip { inline constexpr size_t kMaxDeviceTransportBlePendingPackets = 1; -using DeviceTransportMgr = TransportMgr /* BLE */ + , + Transport::BLE /* BLE */ +#endif +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + Transport::TCP #endif - >; + >; namespace Controller { diff --git a/src/include/platform/internal/GenericPlatformManagerImpl.ipp b/src/include/platform/internal/GenericPlatformManagerImpl.ipp index d6f29ed515e034..64878f5928cd9b 100644 --- a/src/include/platform/internal/GenericPlatformManagerImpl.ipp +++ b/src/include/platform/internal/GenericPlatformManagerImpl.ipp @@ -89,6 +89,16 @@ CHIP_ERROR GenericPlatformManagerImpl::_InitChipStack() } SuccessOrExit(err); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Initialize the CHIP TCP layer. + err = TCPEndPointManager()->Init(SystemLayer()); + if (err != CHIP_NO_ERROR) + { + ChipLogError(DeviceLayer, "TCP initialization failed: %" CHIP_ERROR_FORMAT, err.Format()); + } + SuccessOrExit(err); +#endif + // TODO Perform dynamic configuration of the core CHIP objects based on stored settings. // Initialize the CHIP BLE manager. @@ -132,6 +142,10 @@ void GenericPlatformManagerImpl::_Shutdown() ChipLogError(DeviceLayer, "Inet Layer shutdown"); UDPEndPointManager()->Shutdown(); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + TCPEndPointManager()->Shutdown(); +#endif + #if CHIP_DEVICE_CONFIG_ENABLE_CHIPOBLE ChipLogError(DeviceLayer, "BLE shutdown"); BLEMgr().Shutdown(); diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index f36a274be1da2a..bca044d91f230d 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -670,5 +670,12 @@ void ExchangeContext::ExchangeSessionHolder::GrabExpiredSession(const SessionHan GrabUnchecked(session); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void ExchangeContext::OnSessionConnectionClosed(CHIP_ERROR conErr) +{ + // TODO: Handle connection closure at the ExchangeContext level. +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index fc20a5aace5273..47cf0ddbef2783 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -86,6 +86,9 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, NewSessionHandlingPolicy GetNewSessionHandlingPolicy() override { return NewSessionHandlingPolicy::kStayAtOldSession; } void OnSessionReleased() override; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void OnSessionConnectionClosed(CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Send a CHIP message on this exchange. * diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 3971864bb7e06e..a184723726781e 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -77,6 +77,9 @@ CHIP_ERROR ExchangeManager::Init(SessionManager * sessionManager) sessionManager->SetMessageDelegate(this); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + sessionManager->SetConnectionDelegate(this); +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT mReliableMessageMgr.Init(sessionManager->SystemLayer()); mState = State::kState_Initialized; @@ -413,5 +416,18 @@ void ExchangeManager::CloseAllContextsForDelegate(const ExchangeDelegate * deleg }); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void ExchangeManager::OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) +{ + mContextPool.ForEachActiveObject([&](auto * ec) { + if (ec->HasSessionHandle() && ec->GetSessionHandle() == session) + { + ec->OnSessionConnectionClosed(conErr); + } + return Loop::Continue; + }); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace Messaging } // namespace chip diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index 48c6f8df673fb6..b6e416cdbf74e7 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -49,6 +49,10 @@ static constexpr int16_t kAnyMessageType = -1; * handling the registration/unregistration of unsolicited message handlers. */ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + , + public SessionConnectionDelegate +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT { friend class ExchangeContext; @@ -242,6 +246,9 @@ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override; void SendStandaloneAckIfNeeded(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session, MessageFlags msgFlags, System::PacketBufferHandle && msgBuf); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; } // namespace Messaging diff --git a/src/messaging/tests/echo/common.cpp b/src/messaging/tests/echo/common.cpp index 80befc92d27ae6..10eeeb67738024 100644 --- a/src/messaging/tests/echo/common.cpp +++ b/src/messaging/tests/echo/common.cpp @@ -51,10 +51,6 @@ void InitializeChip() err = chip::DeviceLayer::PlatformMgr().InitChipStack(); SuccessOrExit(err); - // Initialize TCP. - err = chip::DeviceLayer::TCPEndPointManager()->Init(chip::DeviceLayer::SystemLayer()); - SuccessOrExit(err); - exit: if (err != CHIP_NO_ERROR) { @@ -68,6 +64,5 @@ void ShutdownChip() gMessageCounterManager.Shutdown(); gExchangeManager.Shutdown(); gSessionManager.Shutdown(); - (void) chip::DeviceLayer::TCPEndPointManager()->Shutdown(); chip::DeviceLayer::PlatformMgr().Shutdown(); } diff --git a/src/messaging/tests/echo/echo_requester.cpp b/src/messaging/tests/echo/echo_requester.cpp index 1bff8e70f55d8d..b06edb359b989f 100644 --- a/src/messaging/tests/echo/echo_requester.cpp +++ b/src/messaging/tests/echo/echo_requester.cpp @@ -35,7 +35,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include #include @@ -49,6 +51,9 @@ namespace { // Max value for the number of EchoRequests sent. constexpr size_t kMaxEchoCount = 3; +// Max value for the number of tcp connect attempts. +constexpr size_t kMaxTCPConnectAttempts = 3; + // The CHIP Echo interval time. constexpr chip::System::Clock::Timeout gEchoInterval = chip::System::Clock::Seconds16(1); @@ -62,9 +67,21 @@ chip::TransportMgrAsSecureSession()->SetTCPConnection(conn); + } + + printf("Successfully established secure session with peer at %s\n", peerAddrBuf); } return err; } +void CloseConnection() +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + chip::Transport::PeerAddress peerAddr = chip::Transport::PeerAddress::TCP(gDestAddr, CHIP_PORT); + + gSessionManager.TCPDisconnect(peerAddr); + + peerAddr.ToString(peerAddrBuf); + printf("Connection closed to peer at %s\n", peerAddrBuf); + + gClientConEstablished = false; + gClientConInProgress = false; +} + +void HandleConnectionAttemptComplete(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR err) +{ + chip::DeviceLayer::PlatformMgr().StopEventLoopTask(); + + if (err != CHIP_NO_ERROR) + { + printf("Connection FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + gTCPConnAttemptCount++; + return; + } + + err = EstablishSecureSession(conn); + if (err != CHIP_NO_ERROR) + { + printf("Secure session FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + return; + } + + gClientConEstablished = true; + gClientConInProgress = false; +} + +void HandleConnectionClosed(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + CloseConnection(); +} + +void EstablishTCPConnection() +{ + CHIP_ERROR err = CHIP_NO_ERROR; + // Previous connection attempt underway. + if (gClientConInProgress) + { + return; + } + + gClientConEstablished = false; + + chip::Transport::PeerAddress peerAddr = chip::Transport::PeerAddress::TCP(gDestAddr, CHIP_PORT); + + // Connect to the peer + err = gSessionManager.TCPConnect(peerAddr, &gAppTCPConnCbCtxt, &gActiveTCPConnState); + if (err != CHIP_NO_ERROR) + { + printf("Connection FAILED with err: %s\n", chip::ErrorStr(err)); + + gLastEchoTime = chip::System::SystemClock().GetMonotonicTimestamp(); + CloseConnection(); + gTCPConnAttemptCount++; + return; + } + + gClientConInProgress = true; +} + void HandleEchoResponseReceived(chip::Messaging::ExchangeContext * ec, chip::System::PacketBufferHandle && payload) { chip::System::Clock::Timestamp respTime = chip::System::SystemClock().GetMonotonicTimestamp(); @@ -236,6 +342,10 @@ int main(int argc, char * argv[]) err = gSessionManager.Init(&chip::DeviceLayer::SystemLayer(), &gTCPManager, &gMessageCounterManager, &gStorage, &gFabricTable, gSessionKeystore); SuccessOrExit(err); + + gAppTCPConnCbCtxt.appContext = nullptr; + gAppTCPConnCbCtxt.connCompleteCb = HandleConnectionAttemptComplete; + gAppTCPConnCbCtxt.connClosedCb = HandleConnectionClosed; } else { @@ -255,9 +365,29 @@ int main(int argc, char * argv[]) err = gMessageCounterManager.Init(&gExchangeManager); SuccessOrExit(err); - // Start the CHIP connection to the CHIP echo responder. - err = EstablishSecureSession(); - SuccessOrExit(err); + if (gUseTCP) + { + + while (!gClientConEstablished) + { + // For TCP transport, attempt to establish the connection to the CHIP echo responder. + // On Connection completion, call EstablishSecureSession(conn); + EstablishTCPConnection(); + + chip::DeviceLayer::PlatformMgr().RunEventLoop(); + + if (gTCPConnAttemptCount > kMaxTCPConnectAttempts) + { + ExitNow(); + } + } + } + else + { + // Start the CHIP session to the CHIP echo responder. + err = EstablishSecureSession(nullptr); + SuccessOrExit(err); + } err = gEchoClient.Init(&gExchangeManager, gSession.Get().Value()); SuccessOrExit(err); @@ -274,14 +404,14 @@ int main(int argc, char * argv[]) if (gUseTCP) { - gTCPManager.Disconnect(chip::Transport::PeerAddress::TCP(gDestAddr)); + gTCPManager.TCPDisconnect(chip::Transport::PeerAddress::TCP(gDestAddr)); } gTCPManager.Close(); Shutdown(); exit: - if ((err != CHIP_NO_ERROR) || (gEchoRespCount != kMaxEchoCount)) + if ((err != CHIP_NO_ERROR) || (gEchoRespCount != kMaxEchoCount) || (gTCPConnAttemptCount > kMaxTCPConnectAttempts)) { printf("ChipEchoClient failed: %s\n", chip::ErrorStr(err)); exit(EXIT_FAILURE); diff --git a/src/messaging/tests/echo/echo_responder.cpp b/src/messaging/tests/echo/echo_responder.cpp index 140aab00b0ff23..0cd2366a1f6c2b 100644 --- a/src/messaging/tests/echo/echo_responder.cpp +++ b/src/messaging/tests/echo/echo_responder.cpp @@ -35,7 +35,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include namespace { diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 745bd3898fa865..ccc2bc2188d176 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -419,6 +419,21 @@ void CASESession::Clear() mPeerNodeId = kUndefinedNodeId; mFabricsTable = nullptr; mFabricIndex = kUndefinedFabricIndex; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Clear the context object. + mTCPConnCbCtxt.appContext = nullptr; + mTCPConnCbCtxt.connCompleteCb = nullptr; + mTCPConnCbCtxt.connClosedCb = nullptr; + mTCPConnCbCtxt.connReceivedCb = nullptr; + + if (mPeerConnState && mPeerConnState->mConnectionState != Transport::TCPState::kConnected) + { + // Abort the connection if the CASESession is being destroyed and the + // connection is in the middle of being set up. + mSessionManager->TCPDisconnect(mPeerConnState, /* shouldAbort = */ true); + mPeerConnState = nullptr; + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT } void CASESession::InvalidateIfPendingEstablishmentOnFabric(FabricIndex fabricIndex) @@ -446,7 +461,9 @@ CHIP_ERROR CASESession::Init(SessionManager & sessionManager, Credentials::Certi ReturnErrorOnFailure(mCommissioningHash.Begin()); - mDelegate = delegate; + mDelegate = delegate; + mSessionManager = &sessionManager; + ReturnErrorOnFailure(AllocateSecureSession(sessionManager, sessionEvictionHint)); mValidContext.Reset(); @@ -454,6 +471,11 @@ CHIP_ERROR CASESession::Init(SessionManager & sessionManager, Credentials::Certi mValidContext.mRequiredKeyPurposes.Set(KeyPurposeFlags::kServerAuth); mValidContext.mValidityPolicy = policy; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + mTCPConnCbCtxt.appContext = this; + mTCPConnCbCtxt.connCompleteCb = HandleConnectionAttemptComplete; + mTCPConnCbCtxt.connClosedCb = HandleConnectionClosed; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT return CHIP_NO_ERROR; } @@ -516,12 +538,18 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric // This is to make sure the exchange will get closed if Init() returned an error. mExchangeCtxt.Emplace(*exchangeCtxt); + Transport::PeerAddress peerAddress = mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress(); + // From here onwards, let's go to exit on error, as some state might have already // been initialized SuccessOrExit(err); SuccessOrExit(err = fabricTable->AddFabricDelegate(this)); + // Set the PeerAddress in the secure session up front to indicate the + // Transport Type of the session that is being set up. + mSecureSessionHolder->AsSecureSession()->SetPeerAddress(peerAddress); + mFabricsTable = fabricTable; mFabricIndex = fabricInfo->GetFabricIndex(); mSessionResumptionStorage = sessionResumptionStorage; @@ -534,8 +562,18 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric ChipLogProgress(SecureChannel, "Initiating session on local FabricIndex %u from 0x" ChipLogFormatX64 " -> 0x" ChipLogFormatX64, static_cast(mFabricIndex), ChipLogValueX64(mLocalNodeId), ChipLogValueX64(mPeerNodeId)); - err = SendSigma1(); - SuccessOrExit(err); + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + err = sessionManager.TCPConnect(peerAddress, &mTCPConnCbCtxt, &mPeerConnState); + SuccessOrExit(err); +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } + else + { + err = SendSigma1(); + SuccessOrExit(err); + } exit: if (err != CHIP_NO_ERROR) @@ -648,6 +686,78 @@ CHIP_ERROR CASESession::RecoverInitiatorIpk() return CHIP_NO_ERROR; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void CASESession::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR err) +{ + VerifyOrReturn(conn != nullptr); + // conn->mAppState should not be NULL. SessionManager has already checked + // before calling this callback. + VerifyOrDie(conn->mAppState != nullptr); + + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + conn->mPeerAddr.ToString(peerAddrBuf); + + CASESession * caseSession = reinterpret_cast(conn->mAppState->appContext); + VerifyOrReturn(caseSession != nullptr); + + // Exit and disconnect if connection setup encountered an error. + SuccessOrExit(err); + + ChipLogDetail(SecureChannel, "TCP Connection established with %s before session establishment", peerAddrBuf); + + // Associate the connection with the current unauthenticated session for the + // CASE exchange. + caseSession->mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetTCPConnection(conn); + + // Associate the connection with the current secure session that is being + // set up. + caseSession->mSecureSessionHolder.Get().Value()->AsSecureSession()->SetTCPConnection(conn); + + // Send Sigma1 after connection is established for sessions over TCP + err = caseSession->SendSigma1(); + SuccessOrExit(err); + +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Connection establishment failed with peer at %s: %" CHIP_ERROR_FORMAT, peerAddrBuf, + err.Format()); + + // Close the underlying connection and ensure that the CASESession is + // not holding on to a stale ActiveTCPConnectionState. We call + // TCPDisconnect() here explicitly in order to abort the connection + // even after it establishes successfully, but SendSigma1() fails for + // some reason. + caseSession->mSessionManager->TCPDisconnect(conn, /* shouldAbort = */ true); + caseSession->mPeerConnState = nullptr; + + caseSession->Clear(); + } +} + +void CASESession::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + VerifyOrReturn(conn != nullptr); + // conn->mAppState should not be NULL. SessionManager has already checked + // before calling this callback. + VerifyOrDie(conn->mAppState != nullptr); + + CASESession * caseSession = reinterpret_cast(conn->mAppState->appContext); + VerifyOrReturn(caseSession != nullptr); + + // Drop our pointer to the now-invalid connection state. + // + // Since the connection is closed, message sends over the ExchangeContext + // will just fail and be handled like normal send errors. + // + // Additionally, SessionManager notifies (via ExchangeMgr) all ExchangeContexts on the + // connection closures for the attached sessions and the ExchangeContexts + // can close proactively if that's appropriate. + caseSession->mPeerConnState = nullptr; + ChipLogDetail(SecureChannel, "TCP Connection for this session has closed"); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR CASESession::SendSigma1() { MATTER_TRACE_SCOPE("SendSigma1", "CASESession"); @@ -2143,10 +2253,16 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST #if CHIP_CONFIG_SLOW_CRYPTO - if (msgType == Protocols::SecureChannel::MsgType::CASE_Sigma1 || msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2 || - msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume || - msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3) - { + if ((msgType == Protocols::SecureChannel::MsgType::CASE_Sigma1 || msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2 || + msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume || + msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3) && + mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress().GetTransportType() != + Transport::Type::kTcp) + { + // TODO: Rename FlushAcks() to something more semantically correct and + // call unconditionally for TCP or MRP from here. Inside, the + // PeerAddress type could be consulted to selectively flush MRP Acks + // when transport is not TCP. Issue #33183 SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks()); } #endif // CHIP_CONFIG_SLOW_CRYPTO diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index b7c6b429b950ad..045d1982dd723c 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -286,6 +286,22 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, void InvalidateIfPendingEstablishmentOnFabric(FabricIndex fabricIndex); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + static void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + static void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + + // Context to pass down when connecting to peer + Transport::AppTCPConnectionCallbackCtxt mTCPConnCbCtxt; + // Pointer to the underlying TCP connection state. Returned by the + // TCPConnect() method (on the connection Initiator side) when an + // ActiveTCPConnectionState object is allocated. This connection + // context is used on the CASE Initiator side to facilitate the + // invocation of the callbacks when the connection is established/closed. + // + // This pointer must be nulled out when the connection is closed. + Transport::ActiveTCPConnectionState * mPeerConnState = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + #if CONFIG_BUILD_FOR_HOST_UNIT_TEST void SetStopSigmaHandshakeAt(Optional state) { mStopHandshakeAtState = state; } #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST @@ -301,6 +317,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, uint8_t mIPK[kIPKSize]; SessionResumptionStorage * mSessionResumptionStorage = nullptr; + SessionManager * mSessionManager = nullptr; FabricTable * mFabricsTable = nullptr; FabricIndex mFabricIndex = kUndefinedFabricIndex; diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp index 1f7874bdf115dc..ae4ca272858a78 100644 --- a/src/protocols/secure_channel/PairingSession.cpp +++ b/src/protocols/secure_channel/PairingSession.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace chip { @@ -58,6 +59,18 @@ void PairingSession::Finish() { Transport::PeerAddress address = mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress(); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (address.GetTransportType() == Transport::Type::kTcp) + { + // Fetch the connection for the unauthenticated session used to set up + // the secure session. + Transport::ActiveTCPConnectionState * conn = + mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetTCPConnection(); + + // Associate the connection with the secure session being activated. + mSecureSessionHolder->AsSecureSession()->SetTCPConnection(conn); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT // Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that. DiscardExchange(); diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h b/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h index 86fd4d51ed4e48..7a496e5a2f0a33 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioning.h @@ -539,7 +539,8 @@ class DLL_EXPORT UserDirectedCommissioningClient : public TransportMgrDelegate } private: - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; CommissionerDeclarationHandler * mCommissionerDeclarationHandler = nullptr; }; @@ -652,7 +653,8 @@ class DLL_EXPORT UserDirectedCommissioningServer : public TransportMgrDelegate void HandleNewUDC(const Transport::PeerAddress & source, IdentificationDeclaration & id); void HandleUDCCancel(IdentificationDeclaration & id); void HandleUDCCommissionerPasscodeReady(IdentificationDeclaration & id); - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; UDCClients mUdcClients; // < Active UDC clients diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp index 9fc43634d7ec7d..6d7c315ffb5308 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp @@ -24,6 +24,7 @@ */ #include "UserDirectedCommissioning.h" +#include #ifdef __ZEPHYR__ #include @@ -235,7 +236,8 @@ CHIP_ERROR CommissionerDeclaration::ReadPayload(uint8_t * udcPayload, size_t pay return CHIP_NO_ERROR; } -void UserDirectedCommissioningClient::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg) +void UserDirectedCommissioningClient::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { char addrBuffer[chip::Transport::PeerAddress::kMaxToStringSize]; source.ToString(addrBuffer); diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp index 2efd8d0a33de28..a06bffcec62b3a 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp @@ -26,6 +26,7 @@ #include "UserDirectedCommissioning.h" #include #include +#include #include @@ -33,7 +34,8 @@ namespace chip { namespace Protocols { namespace UserDirectedCommissioning { -void UserDirectedCommissioningServer::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg) +void UserDirectedCommissioningServer::OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { char addrBuffer[chip::Transport::PeerAddress::kMaxToStringSize]; source.ToString(addrBuffer); diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index fa556d40afd5b0..c4649c9b92d27c 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -37,6 +37,7 @@ static_library("transport") { "SecureSessionTable.h", "Session.cpp", "Session.h", + "SessionConnectionDelegate.h", "SessionDelegate.h", "SessionHolder.cpp", "SessionHolder.h", diff --git a/src/transport/Session.h b/src/transport/Session.h index 0b6048d5c077c0..d9840ec3ee33f1 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -28,6 +28,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { namespace Transport { @@ -225,6 +228,15 @@ class Session bool IsUnauthenticatedSession() const { return GetSessionType() == SessionType::kUnauthenticated; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // This API is used to associate the connection with the session when the + // latter is about to be marked active. It is also used to reset the + // connection to a nullptr when the connection is lost and the session + // is marked as Defunct. + ActiveTCPConnectionState * GetTCPConnection() const { return mTCPConnection; } + void SetTCPConnection(ActiveTCPConnectionState * conn) { mTCPConnection = conn; } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + void DispatchSessionEvent(SessionDelegate::Event event) { // Holders might remove themselves when notified. @@ -264,6 +276,15 @@ class Session private: FabricIndex mFabricIndex = kUndefinedFabricIndex; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // The underlying TCP connection object over which the session is + // established. + // The lifetime of this member connection pointer is, essentially, the same + // as that of the underlying connection with the peer. + // It would remain as a nullptr for all sessions that are not set up over + // a TCP connection. + ActiveTCPConnectionState * mTCPConnection = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; // diff --git a/src/transport/SessionConnectionDelegate.h b/src/transport/SessionConnectionDelegate.h new file mode 100644 index 00000000000000..4557d0ae107e68 --- /dev/null +++ b/src/transport/SessionConnectionDelegate.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace chip { + +/** + * @brief + * This class defines a delegate that will be called by the SessionManager on + * specific connection-related (e.g. for TCP) events. If the user of SessionManager + * is interested in receiving these callbacks, they can specialize this class and + * handle each trigger in their implementation of this class. + */ +class DLL_EXPORT SessionConnectionDelegate +{ +public: + virtual ~SessionConnectionDelegate() {} + + /** + * @brief + * Called when the underlying connection for the session is closed. + * + * @param session The handle to the secure session + * @param conErr The connection error code + */ + virtual void OnTCPConnectionClosed(const SessionHandle & session, CHIP_ERROR conErr) = 0; +}; + +} // namespace chip diff --git a/src/transport/SessionDelegate.h b/src/transport/SessionDelegate.h index b9e0a7b8b38b0c..503aaa2b0c4f5c 100644 --- a/src/transport/SessionDelegate.h +++ b/src/transport/SessionDelegate.h @@ -66,6 +66,10 @@ class DLL_EXPORT SessionDelegate * SessionManager to allocate a new session. If they desire to do so, it MUST be done asynchronously. */ virtual void OnSessionHang() {} + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + virtual void OnSessionConnectionClosed(CHIP_ERROR conErr) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; } // namespace chip diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 5adfde18b47ff4..d49e66fb6f5116 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -147,6 +147,11 @@ CHIP_ERROR SessionManager::Init(System::Layer * systemLayer, TransportMgrBase * mTransportMgr->SetSessionManager(this); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + mConnCompleteCb = nullptr; + mConnClosedCb = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + return CHIP_NO_ERROR; } @@ -602,7 +607,8 @@ CHIP_ERROR SessionManager::InjectCaseSessionWithTestKey(SessionHolder & sessionH return CHIP_NO_ERROR; } -void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg) +void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { PacketHeader partialPacketHeader; @@ -621,20 +627,151 @@ void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System:: } else { - SecureUnicastMessageDispatch(partialPacketHeader, peerAddress, std::move(msg)); + SecureUnicastMessageDispatch(partialPacketHeader, peerAddress, std::move(msg), ctxt); } } else { - UnauthenticatedMessageDispatch(partialPacketHeader, peerAddress, std::move(msg)); + UnauthenticatedMessageDispatch(partialPacketHeader, peerAddress, std::move(msg), ctxt); + } +} + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void SessionManager::HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + + VerifyOrReturn(conn != nullptr); + conn->mPeerAddr.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Received TCP connection request from %s.", peerAddrBuf); + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + if (appTCPConnCbCtxt != nullptr && appTCPConnCbCtxt->connReceivedCb != nullptr) + { + appTCPConnCbCtxt->connReceivedCb(conn); + } +} + +void SessionManager::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + VerifyOrReturn(conn != nullptr); + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + if (appTCPConnCbCtxt == nullptr) + { + TCPDisconnect(conn, /* shouldAbort = */ true); + return; + } + + if (appTCPConnCbCtxt->connCompleteCb != nullptr) + { + appTCPConnCbCtxt->connCompleteCb(conn, conErr); + } + else + { + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + conn->mPeerAddr.ToString(peerAddrBuf); + + ChipLogProgress(Inet, "TCP Connection established with peer %s, but no registered handler. Disconnecting.", peerAddrBuf); + + // Close the connection + TCPDisconnect(conn, /* shouldAbort = */ true); + } +} + +void SessionManager::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + VerifyOrReturn(conn != nullptr); + + // Mark the corresponding secure sessions as defunct + mSecureSessions.ForEachSession([&](auto session) { + if (session->IsActiveSession() && session->GetTCPConnection() == conn) + { + SessionHandle handle(*session); + // Notify the SessionConnection delegate of the connection + // closure. + if (mConnDelegate != nullptr) + { + mConnDelegate->OnTCPConnectionClosed(handle, conErr); + } + + // Dis-associate the connection from session by setting it to a + // nullptr. + session->SetTCPConnection(nullptr); + // Mark session as defunct + session->MarkAsDefunct(); + } + + return Loop::Continue; + }); + + // TODO: A mechanism to mark an unauthenticated session as unusable when + // the underlying connection is broken. Issue #32323 + + Transport::AppTCPConnectionCallbackCtxt * appTCPConnCbCtxt = conn->mAppState; + VerifyOrReturn(appTCPConnCbCtxt != nullptr); + + if (appTCPConnCbCtxt->connClosedCb != nullptr) + { + appTCPConnCbCtxt->connClosedCb(conn, conErr); + } +} + +CHIP_ERROR SessionManager::TCPConnect(const PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) +{ + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(peerAddrBuf); + if (mTransportMgr != nullptr) + { + ChipLogProgress(Inet, "Connecting over TCP with peer at %s.", peerAddrBuf); + return mTransportMgr->TCPConnect(peerAddress, appState, peerConnState); + } + + ChipLogError(Inet, "The transport manager is not initialized. Unable to connect to peer at %s.", peerAddrBuf); + + return CHIP_ERROR_INCORRECT_STATE; +} + +CHIP_ERROR SessionManager::TCPDisconnect(const PeerAddress & peerAddress) +{ + if (mTransportMgr != nullptr) + { + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Disconnecting TCP connection from peer at %s.", peerAddrBuf); + mTransportMgr->TCPDisconnect(peerAddress); } + + return CHIP_NO_ERROR; } +void SessionManager::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) +{ + if (mTransportMgr != nullptr && conn != nullptr) + { + char peerAddrBuf[chip::Transport::PeerAddress::kMaxToStringSize]; + conn->mPeerAddr.ToString(peerAddrBuf); + ChipLogProgress(Inet, "Disconnecting TCP connection from peer at %s.", peerAddrBuf); + mTransportMgr->TCPDisconnect(conn, shouldAbort); + } +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partialPacketHeader, - const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) + const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { MATTER_TRACE_SCOPE("Unauthenticated Message Dispatch", "SessionManager"); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (peerAddress.GetTransportType() == Transport::Type::kTcp && ctxt->conn == nullptr) + { + ChipLogError(Inet, "Connection object is missing for received message."); + return; + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // Drop unsecured messages with privacy enabled. if (partialPacketHeader.HasPrivacyFlag()) { @@ -660,7 +797,7 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial if (source.HasValue()) { // Assume peer is the initiator, we are the responder. - optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), GetDefaultMRPConfig()); + optionalSession = mUnauthenticatedSessions.FindOrAllocateResponder(source.Value(), GetDefaultMRPConfig(), peerAddress); if (!optionalSession.HasValue()) { ChipLogError(Inet, "UnauthenticatedSession exhausted"); @@ -670,7 +807,7 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial else { // Assume peer is the responder, we are the initiator. - optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value()); + optionalSession = mUnauthenticatedSessions.FindInitiator(destination.Value(), peerAddress); if (!optionalSession.HasValue()) { ChipLogProgress(Inet, "Received unknown unsecure packet for initiator 0x" ChipLogFormatX64, @@ -685,6 +822,25 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial CorrectPeerAddressInterfaceID(mutablePeerAddress); unsecuredSession->SetPeerAddress(mutablePeerAddress); SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Associate the unauthenticated session with the connection, if not done already. + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { + Transport::ActiveTCPConnectionState * sessionConn = unsecuredSession->GetTCPConnection(); + if (sessionConn == nullptr) + { + unsecuredSession->SetTCPConnection(ctxt->conn); + } + else + { + if (sessionConn != ctxt->conn) + { + ChipLogError(Inet, "Data received over wrong connection %p. Dropping it!", ctxt->conn); + return; + } + } + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT unsecuredSession->MarkActiveRx(); @@ -723,13 +879,55 @@ void SessionManager::UnauthenticatedMessageDispatch(const PacketHeader & partial } void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & partialPacketHeader, - const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) + const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { MATTER_TRACE_SCOPE("Secure Unicast Message Dispatch", "SessionManager"); CHIP_ERROR err = CHIP_NO_ERROR; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (peerAddress.GetTransportType() == Transport::Type::kTcp && ctxt->conn == nullptr) + { + ChipLogError(Inet, "Connection object is missing for received message."); + return; + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + Optional session = mSecureSessions.FindSecureSessionByLocalKey(partialPacketHeader.GetSessionId()); + if (!session.HasValue()) + { + ChipLogError(Inet, "Data received on an unknown session (LSID=%d). Dropping it!", partialPacketHeader.GetSessionId()); + return; + } + + Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); + Transport::PeerAddress mutablePeerAddress = peerAddress; + CorrectPeerAddressInterfaceID(mutablePeerAddress); + if (secureSession->GetPeerAddress() != mutablePeerAddress) + { + secureSession->SetPeerAddress(mutablePeerAddress); + } + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Associate the secure session with the connection, if not done already. + if (peerAddress.GetTransportType() == Transport::Type::kTcp) + { + Transport::ActiveTCPConnectionState * sessionConn = secureSession->GetTCPConnection(); + if (sessionConn == nullptr) + { + secureSession->SetTCPConnection(ctxt->conn); + } + else + { + if (sessionConn != ctxt->conn) + { + ChipLogError(Inet, "Data received over wrong connection %p. Dropping it!", ctxt->conn); + return; + } + } + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT PayloadHeader payloadHeader; @@ -751,14 +949,6 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & partialPa return; } - if (!session.HasValue()) - { - ChipLogError(Inet, "Data received on an unknown session (LSID=%d). Dropping it!", packetHeader.GetSessionId()); - return; - } - - Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); - // We need to allow through messages even on sessions that are pending // evictions, because for some cases (UpdateNOC, RemoveFabric, etc) there // can be a single exchange alive on the session waiting for a MRP ack, and @@ -816,13 +1006,6 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & partialPa secureSession->GetSessionMessageCounter().GetPeerMessageCounter().CommitEncryptedUnicast(packetHeader.GetMessageCounter()); } - Transport::PeerAddress mutablePeerAddress = peerAddress; - CorrectPeerAddressInterfaceID(mutablePeerAddress); - if (secureSession->GetPeerAddress() != mutablePeerAddress) - { - secureSession->SetPeerAddress(mutablePeerAddress); - } - if (mCB != nullptr) { MATTER_LOG_MESSAGE_RECEIVED(chip::Tracing::IncomingMessageType::kSecureUnicast, &payloadHeader, &packetHeader, @@ -1057,27 +1240,69 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack } Optional SessionManager::FindSecureSessionForNode(ScopedNodeId peerNodeId, - const Optional & type) + const Optional & type, + TransportPayloadCapability transportPayloadCapability) { - SecureSession * found = nullptr; - - mSecureSessions.ForEachSession([&peerNodeId, &type, &found](auto session) { + SecureSession * mrpSession = nullptr; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + SecureSession * tcpSession = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + + mSecureSessions.ForEachSession([&peerNodeId, &type, &mrpSession, +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + &tcpSession, +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + &transportPayloadCapability](auto session) { if (session->IsActiveSession() && session->GetPeer() == peerNodeId && (!type.HasValue() || type.Value() == session->GetSecureSessionType())) { - // - // Select the active session with the most recent activity to return back to the caller. - // - if ((found == nullptr) || (found->GetLastActivityTime() < session->GetLastActivityTime())) +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if ((transportPayloadCapability == TransportPayloadCapability::kMRPOrTCPCompatiblePayload || + transportPayloadCapability == TransportPayloadCapability::kLargePayload) && + session->GetTCPConnection() != nullptr) { - found = session; + // Set up a TCP transport based session as standby + if ((tcpSession == nullptr) || (tcpSession->GetLastActivityTime() < session->GetLastActivityTime())) + { + tcpSession = session; + } + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + + if ((mrpSession == nullptr) || (mrpSession->GetLastActivityTime() < session->GetLastActivityTime())) + { + mrpSession = session; } } return Loop::Continue; }); - return found != nullptr ? MakeOptional(*found) : Optional::Missing(); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + if (transportPayloadCapability == TransportPayloadCapability::kLargePayload) + { + return tcpSession != nullptr ? MakeOptional(*tcpSession) : Optional::Missing(); + } + + if (transportPayloadCapability == TransportPayloadCapability::kMRPOrTCPCompatiblePayload) + { + // If MRP-based session is available, use it. + if (mrpSession != nullptr) + { + return MakeOptional(*mrpSession); + } + + // Otherwise, look for a tcp-based session + if (tcpSession != nullptr) + { + return MakeOptional(*tcpSession); + } + + return Optional::Missing(); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + + return mrpSession != nullptr ? MakeOptional(*mrpSession) : Optional::Missing(); } /** diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 5f2e6f7603cad0..b7a1630b3d2851 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -52,8 +52,30 @@ #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + namespace chip { +/* + * This enum indicates whether a session needs to be established over a + * suitable transport that meets certain payload size requirements for + * transmitted messages. + * + */ +enum class TransportPayloadCapability : uint8_t +{ + kMRPPayload, // Transport requires the maximum payload size to fit within a single + // IPv6 packet(1280 bytes). + kLargePayload, // Transport needs to handle payloads larger than the single IPv6 + // packet, as supported by MRP. The transport of choice, in this + // case, is TCP. + kMRPOrTCPCompatiblePayload // This option provides the ability to use MRP + // as the preferred transport, but use a large + // payload transport if that is already + // available. +}; /** * @brief * Tracks ownership of a encrypted packet buffer. @@ -151,6 +173,10 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl /// ExchangeManager) void SetMessageDelegate(SessionMessageDelegate * cb) { mCB = cb; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + void SetConnectionDelegate(SessionConnectionDelegate * cb) { mConnDelegate = cb; } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + // Test-only: create a session on the fly. CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabricIndex, @@ -413,8 +439,34 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * * @param source the source address of the package * @param msgBuf the buffer containing a full CHIP message (except for the optional length field). + * @param ctxt pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) override; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const Transport::PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState); + + CHIP_ERROR TCPDisconnect(const Transport::PeerAddress & peerAddress); + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0); + + void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) override; + + void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + // Functors for callbacks into higher layers + using OnTCPConnectionReceivedCallback = void (*)(Transport::ActiveTCPConnectionState * conn); + + using OnTCPConnectionCompleteCallback = void (*)(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + + using OnTCPConnectionClosedCallback = void (*)(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config) @@ -436,8 +488,9 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl // is returned. Otherwise, an Optional with no value set is returned. // // - Optional FindSecureSessionForNode(ScopedNodeId peerNodeId, - const Optional & type = NullOptional); + Optional + FindSecureSessionForNode(ScopedNodeId peerNodeId, const Optional & type = NullOptional, + TransportPayloadCapability transportPayloadCapability = TransportPayloadCapability::kMRPPayload); using SessionHandleCallback = bool (*)(void * context, SessionHandle & sessionHandle); CHIP_ERROR ForEachSessionHandle(void * context, SessionHandleCallback callback); @@ -477,8 +530,22 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl State mState; // < Initialization state of the object chip::Transport::GroupOutgoingCounters mGroupClientCounter; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + OnTCPConnectionReceivedCallback mConnReceivedCb = nullptr; + OnTCPConnectionCompleteCallback mConnCompleteCb = nullptr; + OnTCPConnectionClosedCallback mConnClosedCb = nullptr; + + // Hold the TCPConnection callback context for the receiver application in the SessionManager. + // On receipt of a connection from a peer, the SessionManager + Transport::AppTCPConnectionCallbackCtxt * mServerTCPConnCbCtxt = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + SessionMessageDelegate * mCB = nullptr; +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + SessionConnectionDelegate * mConnDelegate = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + TransportMgrBase * mTransportMgr = nullptr; Transport::MessageCounterManagerInterface * mMessageCounterManager = nullptr; @@ -491,9 +558,11 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * If the message decrypts successfully, this will be filled with a fully decoded PacketHeader. * @param[in] peerAddress The PeerAddress of the message as provided by the receiving Transport Endpoint. * @param msg The full message buffer, including header fields. + * @param ctxt The pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ void SecureUnicastMessageDispatch(const PacketHeader & partialPacketHeader, const Transport::PeerAddress & peerAddress, - System::PacketBufferHandle && msg); + System::PacketBufferHandle && msg, Transport::MessageTransportContext * ctxt = nullptr); /** * @brief Parse, decrypt, validate, and dispatch a secure group message. @@ -511,9 +580,11 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate, public FabricTabl * @param partialPacketHeader The partial PacketHeader of the message after processing with DecodeFixed. * @param peerAddress The PeerAddress of the message as provided by the receiving Transport Endpoint. * @param msg The full message buffer, including header fields. + * @param ctxt The pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ void UnauthenticatedMessageDispatch(const PacketHeader & partialPacketHeader, const Transport::PeerAddress & peerAddress, - System::PacketBufferHandle && msg); + System::PacketBufferHandle && msg, Transport::MessageTransportContext * ctxt = nullptr); void OnReceiveError(CHIP_ERROR error, const Transport::PeerAddress & source); diff --git a/src/transport/TransportMgr.h b/src/transport/TransportMgr.h index 494db39de964fd..6e5f03ab4a44a6 100644 --- a/src/transport/TransportMgr.h +++ b/src/transport/TransportMgr.h @@ -34,6 +34,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { @@ -49,8 +52,26 @@ class TransportMgrDelegate * * @param source the source address of the package * @param msgBuf the buffer containing a full CHIP message (except for the optional length field). + * @param ctxt the pointer to additional context on the underlying transport. For TCP, it is a pointer + * to the underlying connection object. */ - virtual void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) = 0; + virtual void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * ctxt = nullptr) = 0; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * @brief + * Handle connection attempt completion. + * + * @param conn the connection object + * @param conErr the connection error on the attempt, or CHIP_NO_ERROR. + */ + virtual void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + + virtual void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + + virtual void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn){}; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; template diff --git a/src/transport/TransportMgrBase.cpp b/src/transport/TransportMgrBase.cpp index bde54a96c195b6..98997102eb92ca 100644 --- a/src/transport/TransportMgrBase.cpp +++ b/src/transport/TransportMgrBase.cpp @@ -28,11 +28,24 @@ CHIP_ERROR TransportMgrBase::SendMessage(const Transport::PeerAddress & address, return mTransport->SendMessage(address, std::move(msgBuf)); } -void TransportMgrBase::Disconnect(const Transport::PeerAddress & address) +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +CHIP_ERROR TransportMgrBase::TCPConnect(const Transport::PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) { - mTransport->Disconnect(address); + return mTransport->TCPConnect(address, appState, peerConnState); } +void TransportMgrBase::TCPDisconnect(const Transport::PeerAddress & address) +{ + mTransport->TCPDisconnect(address); +} + +void TransportMgrBase::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) +{ + mTransport->TCPDisconnect(conn, shouldAbort); +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TransportMgrBase::Init(Transport::Base * transport) { if (mTransport != nullptr) @@ -41,6 +54,7 @@ CHIP_ERROR TransportMgrBase::Init(Transport::Base * transport) } mTransport = transport; mTransport->SetDelegate(this); + ChipLogDetail(Inet, "TransportMgr initialized"); return CHIP_NO_ERROR; } @@ -56,7 +70,8 @@ CHIP_ERROR TransportMgrBase::MulticastGroupJoinLeave(const Transport::PeerAddres return mTransport->MulticastGroupJoinLeave(address, join); } -void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) +void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt) { // This is the first point all incoming messages funnel through. Ensure // that our message receipts are all synchronized correctly. @@ -73,7 +88,7 @@ void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peer if (mSessionManager != nullptr) { - mSessionManager->OnMessageReceived(peerAddress, std::move(msg)); + mSessionManager->OnMessageReceived(peerAddress, std::move(msg), ctxt); } else { @@ -83,4 +98,60 @@ void TransportMgrBase::HandleMessageReceived(const Transport::PeerAddress & peer } } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +void TransportMgrBase::HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionReceived(conn); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + + // Close connection here since no upper layer is interested in the + // connection. + if (tcp) + { + tcp->TCPDisconnect(conn, /* shouldAbort = */ true); + } + } +} + +void TransportMgrBase::HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionAttemptComplete(conn, conErr); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + + // Close connection here since no upper layer is interested in the + // connection. + if (tcp) + { + tcp->TCPDisconnect(conn, /* shouldAbort = */ true); + } + } +} + +void TransportMgrBase::HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) +{ + if (mSessionManager != nullptr) + { + mSessionManager->HandleConnectionClosed(conn, conErr); + } + else + { + Transport::TCPBase * tcp = reinterpret_cast(conn->mEndPoint->mAppState); + if (tcp) + { + tcp->TCPDisconnect(conn, /* shouldAbort = */ true); + } + } +} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT + } // namespace chip diff --git a/src/transport/TransportMgrBase.h b/src/transport/TransportMgrBase.h index e4942ca6ecd90b..2b0f33bfb9e3a7 100644 --- a/src/transport/TransportMgrBase.h +++ b/src/transport/TransportMgrBase.h @@ -21,6 +21,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { @@ -39,13 +42,29 @@ class TransportMgrBase : public Transport::RawTransportDelegate void Close(); - void Disconnect(const Transport::PeerAddress & address); +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const Transport::PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState); + + void TCPDisconnect(const Transport::PeerAddress & address); + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0); + + void HandleConnectionReceived(Transport::ActiveTCPConnectionState * conn) override; + + void HandleConnectionAttemptComplete(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; + + void HandleConnectionClosed(Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT void SetSessionManager(TransportMgrDelegate * sessionManager) { mSessionManager = sessionManager; } + TransportMgrDelegate * GetSessionManager() { return mSessionManager; }; + CHIP_ERROR MulticastGroupJoinLeave(const Transport::PeerAddress & address, bool join); - void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) override; + void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + Transport::MessageTransportContext * ctxt = nullptr) override; private: TransportMgrDelegate * mSessionManager = nullptr; diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 1c3ff7ed55586a..913058229e1a6b 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -45,9 +45,10 @@ class UnauthenticatedSession : public Session, public ReferenceCounted & sessionTable) : - UnauthenticatedSession(sessionRole, ephemeralInitiatorNodeID, config) + UnauthenticatedSession(sessionRole, ephemeralInitiatorNodeID, peerAddress, config) #if CHIP_SYSTEM_CONFIG_POOL_USE_HEAP , mSessionTable(sessionTable) @@ -224,13 +225,16 @@ class UnauthenticatedSessionTable * @return the session found or allocated, or Optional::Missing if not found and allocation failed. */ CHECK_RETURN_VALUE - Optional FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config) + Optional FindOrAllocateResponder(NodeId ephemeralInitiatorNodeID, const ReliableMessageProtocolConfig & config, + const Transport::PeerAddress & peerAddress) { - UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID); + UnauthenticatedSession * result = + FindEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, peerAddress); if (result != nullptr) return MakeOptional(*result); - CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, config, result); + CHIP_ERROR err = + AllocEntry(UnauthenticatedSession::SessionRole::kResponder, ephemeralInitiatorNodeID, peerAddress, config, result); if (err == CHIP_NO_ERROR) { return MakeOptional(*result); @@ -239,9 +243,11 @@ class UnauthenticatedSessionTable return Optional::Missing(); } - CHECK_RETURN_VALUE Optional FindInitiator(NodeId ephemeralInitiatorNodeID) + CHECK_RETURN_VALUE Optional FindInitiator(NodeId ephemeralInitiatorNodeID, + const Transport::PeerAddress & peerAddress) { - UnauthenticatedSession * result = FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID); + UnauthenticatedSession * result = + FindEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, peerAddress); if (result != nullptr) { return MakeOptional(*result); @@ -254,7 +260,8 @@ class UnauthenticatedSessionTable const ReliableMessageProtocolConfig & config) { UnauthenticatedSession * result = nullptr; - CHIP_ERROR err = AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, config, result); + CHIP_ERROR err = + AllocEntry(UnauthenticatedSession::SessionRole::kInitiator, ephemeralInitiatorNodeID, peerAddress, config, result); if (err == CHIP_NO_ERROR) { result->SetPeerAddress(peerAddress); @@ -276,9 +283,10 @@ class UnauthenticatedSessionTable */ CHECK_RETURN_VALUE CHIP_ERROR AllocEntry(UnauthenticatedSession::SessionRole sessionRole, NodeId ephemeralInitiatorNodeID, - const ReliableMessageProtocolConfig & config, UnauthenticatedSession *& entry) + const PeerAddress & peerAddress, const ReliableMessageProtocolConfig & config, + UnauthenticatedSession *& entry) { - auto entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config, *this); + auto entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, peerAddress, config, *this); if (entryToUse != nullptr) { entry = entryToUse; @@ -294,7 +302,7 @@ class UnauthenticatedSessionTable // Drop the least recent entry to allow for a new alloc. mEntries.ReleaseObject(entryToUse); - entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, config, *this); + entryToUse = mEntries.CreateObject(sessionRole, ephemeralInitiatorNodeID, peerAddress, config, *this); if (entryToUse == nullptr) { @@ -308,11 +316,13 @@ class UnauthenticatedSessionTable } CHECK_RETURN_VALUE UnauthenticatedSession * FindEntry(UnauthenticatedSession::SessionRole sessionRole, - NodeId ephemeralInitiatorNodeID) + NodeId ephemeralInitiatorNodeID, + const Transport::PeerAddress & peerAddress) { UnauthenticatedSession * result = nullptr; mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) { - if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID) + if (entry->GetSessionRole() == sessionRole && entry->GetEphemeralInitiatorNodeID() == ephemeralInitiatorNodeID && + entry->GetPeerAddress().GetTransportType() == peerAddress.GetTransportType()) { result = entry; return Loop::Break; diff --git a/src/transport/raw/ActiveTCPConnectionState.h b/src/transport/raw/ActiveTCPConnectionState.h new file mode 100644 index 00000000000000..0f53d4479e9300 --- /dev/null +++ b/src/transport/raw/ActiveTCPConnectionState.h @@ -0,0 +1,125 @@ +/* + * + * Copyright (c) 2023 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP Active Connection object that maintains TCP connections. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace chip { +namespace Transport { + +/** + * The State of the TCP connection + */ +enum class TCPState +{ + kNotReady = 0, /**< State before initialization. */ + kInitialized = 1, /**< State after class is listening and ready. */ + kConnecting = 3, /**< Connection with peer has been initiated. */ + kConnected = 4, /**< Connected with peer and ready for Send/Receive. */ + kClosed = 5, /**< Connection is closed. */ +}; + +struct AppTCPConnectionCallbackCtxt; +/** + * State for each active TCP connection + */ +struct ActiveTCPConnectionState +{ + + void Init(Inet::TCPEndPoint * endPoint, const PeerAddress & peerAddr) + { + mEndPoint = endPoint; + mPeerAddr = peerAddr; + mReceived = nullptr; + mAppState = nullptr; + } + + void Free() + { + mEndPoint->Free(); + mPeerAddr = PeerAddress::Uninitialized(); + mEndPoint = nullptr; + mReceived = nullptr; + mAppState = nullptr; + } + + bool InUse() const { return mEndPoint != nullptr; } + + bool IsConnected() const { return (mEndPoint != nullptr && mConnectionState == TCPState::kConnected); } + + bool IsConnecting() const { return (mEndPoint != nullptr && mConnectionState == TCPState::kConnecting); } + + // Associated endpoint. + Inet::TCPEndPoint * mEndPoint; + + // Peer Node Address + PeerAddress mPeerAddr; + + // Buffers received but not yet consumed. + System::PacketBufferHandle mReceived; + + // Current state of the connection + TCPState mConnectionState; + + // A pointer to an application-specific state object. It should + // represent an object that is at a layer above the SessionManager. The + // SessionManager would accept this object at the time of connecting to + // the peer, and percolate it down to the TransportManager that then, + // should store this state in the corresponding connection object that + // is created. + // At various connection events, this state is passed back to the + // corresponding application. + AppTCPConnectionCallbackCtxt * mAppState = nullptr; + + // KeepAlive interval in seconds + uint16_t mTCPKeepAliveIntervalSecs = CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS; + uint16_t mTCPMaxNumKeepAliveProbes = CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES; +}; + +// Functors for callbacks into higher layers +using OnTCPConnectionReceivedCallback = void (*)(ActiveTCPConnectionState * conn); + +using OnTCPConnectionCompleteCallback = void (*)(ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +using OnTCPConnectionClosedCallback = void (*)(ActiveTCPConnectionState * conn, CHIP_ERROR conErr); + +/* + * Application callback state that is passed down at connection establishment + * stage. + * */ +struct AppTCPConnectionCallbackCtxt +{ + void * appContext = nullptr; // A pointer to an application context object. + OnTCPConnectionReceivedCallback connReceivedCb = nullptr; + OnTCPConnectionCompleteCallback connCompleteCb = nullptr; + OnTCPConnectionClosedCallback connClosedCb = nullptr; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/raw/BUILD.gn b/src/transport/raw/BUILD.gn index 736b16cdb08477..3e8d3d4761dca3 100644 --- a/src/transport/raw/BUILD.gn +++ b/src/transport/raw/BUILD.gn @@ -14,6 +14,7 @@ import("//build_overrides/chip.gni") import("${chip_root}/src/ble/ble.gni") +import("${chip_root}/src/inet/inet.gni") static_library("raw") { output_name = "libRawTransport" @@ -23,13 +24,20 @@ static_library("raw") { "MessageHeader.cpp", "MessageHeader.h", "PeerAddress.h", - "TCP.cpp", - "TCP.h", "Tuple.h", "UDP.cpp", "UDP.h", ] + if (chip_inet_config_enable_tcp_endpoint) { + sources += [ + "ActiveTCPConnectionState.h", + "TCP.cpp", + "TCP.h", + "TCPConfig.h", + ] + } + if (chip_config_network_layer_ble) { sources += [ "BLE.cpp", diff --git a/src/transport/raw/Base.h b/src/transport/raw/Base.h index 66c01a5fcc3d5d..920f932b82be41 100644 --- a/src/transport/raw/Base.h +++ b/src/transport/raw/Base.h @@ -24,20 +24,38 @@ #pragma once #include +#include #include #include #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT +#include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT namespace chip { namespace Transport { +struct MessageTransportContext +{ +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + ActiveTCPConnectionState * conn = nullptr; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT +}; + class RawTransportDelegate { public: virtual ~RawTransportDelegate() {} - virtual void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) = 0; + virtual void HandleMessageReceived(const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg, + MessageTransportContext * ctxt = nullptr) = 0; + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + virtual void HandleConnectionReceived(ActiveTCPConnectionState * conn){}; + virtual void HandleConnectionAttemptComplete(ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; + virtual void HandleConnectionClosed(ActiveTCPConnectionState * conn, CHIP_ERROR conErr){}; +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT }; /** @@ -77,10 +95,26 @@ class Base */ virtual bool CanListenMulticast() { return false; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * Connect to the specified peer. + */ + virtual CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return CHIP_NO_ERROR; + } + /** * Handle disconnection from the specified peer if currently connected to it. */ - virtual void Disconnect(const PeerAddress & address) {} + virtual void TCPDisconnect(const PeerAddress & address) {} + + /** + * Disconnect on the active connection that is passed in. + */ + virtual void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Enable Listening for multicast messages ( IPV6 UDP only) @@ -97,12 +131,31 @@ class Base * Method used by subclasses to notify that a packet has been received after * any associated headers have been decoded. */ - void HandleMessageReceived(const PeerAddress & source, System::PacketBufferHandle && buffer) + void HandleMessageReceived(const PeerAddress & source, System::PacketBufferHandle && buffer, + MessageTransportContext * ctxt = nullptr) + { + mDelegate->HandleMessageReceived(source, std::move(buffer), ctxt); + } + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + // Handle an incoming connection request from a peer. + void HandleConnectionReceived(ActiveTCPConnectionState * conn) { mDelegate->HandleConnectionReceived(conn); } + + // Callback during connection establishment to notify of success or any + // error. + void HandleConnectionAttemptComplete(ActiveTCPConnectionState * conn, CHIP_ERROR conErr) + { + mDelegate->HandleConnectionAttemptComplete(conn, conErr); + } + + // Callback to notify the higher layer of an unexpected connection closure. + void HandleConnectionClosed(ActiveTCPConnectionState * conn, CHIP_ERROR conErr) { - mDelegate->HandleMessageReceived(source, std::move(buffer)); + mDelegate->HandleConnectionClosed(conn, conErr); } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT - RawTransportDelegate * mDelegate; + RawTransportDelegate * mDelegate = nullptr; }; } // namespace Transport diff --git a/src/transport/raw/TCP.cpp b/src/transport/raw/TCP.cpp index a1b1df1fe45a2d..e33590a63a6fdf 100644 --- a/src/transport/raw/TCP.cpp +++ b/src/transport/raw/TCP.cpp @@ -67,8 +67,7 @@ void TCPBase::CloseActiveConnections() { if (mActiveConnections[i].InUse()) { - mActiveConnections[i].Free(); - mUsedEndPointCount--; + CloseConnectionInternal(&mActiveConnections[i], CHIP_NO_ERROR, SuppressCallback::Yes); } } } @@ -77,7 +76,7 @@ CHIP_ERROR TCPBase::Init(TcpListenParameters & params) { CHIP_ERROR err = CHIP_NO_ERROR; - VerifyOrExit(mState == State::kNotReady, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(mState == TCPState::kNotReady, err = CHIP_ERROR_INCORRECT_STATE); #if INET_CONFIG_ENABLE_TCP_ENDPOINT err = params.GetEndPointManager()->NewEndPoint(&mListenSocket); @@ -90,23 +89,21 @@ CHIP_ERROR TCPBase::Init(TcpListenParameters & params) params.GetInterfaceId().IsPresent()); SuccessOrExit(err); + mListenSocket->mAppState = reinterpret_cast(this); + mListenSocket->OnConnectionReceived = HandleIncomingConnection; + mListenSocket->OnAcceptError = HandleAcceptError; + + mEndpointType = params.GetAddressType(); + err = mListenSocket->Listen(kListenBacklogSize); SuccessOrExit(err); - mListenSocket->mAppState = reinterpret_cast(this); - mListenSocket->OnDataReceived = OnTcpReceive; - mListenSocket->OnConnectComplete = OnConnectionComplete; - mListenSocket->OnConnectionClosed = OnConnectionClosed; - mListenSocket->OnConnectionReceived = OnConnectionReceived; - mListenSocket->OnAcceptError = OnAcceptError; - mEndpointType = params.GetAddressType(); - - mState = State::kInitialized; + mState = TCPState::kInitialized; exit: if (err != CHIP_NO_ERROR) { - ChipLogError(Inet, "Failed to initialize TCP transport: %s", ErrorStr(err)); + ChipLogError(Inet, "Failed to initialize TCP transport: %" CHIP_ERROR_FORMAT, err.Format()); if (mListenSocket) { mListenSocket->Free(); @@ -124,10 +121,24 @@ void TCPBase::Close() mListenSocket->Free(); mListenSocket = nullptr; } - mState = State::kNotReady; + mState = TCPState::kNotReady; +} + +ActiveTCPConnectionState * TCPBase::AllocateConnection() +{ + for (size_t i = 0; i < mActiveConnectionsSize; i++) + { + if (!mActiveConnections[i].InUse()) + { + return &mActiveConnections[i]; + } + } + + return nullptr; } -TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const PeerAddress & address) +// Find an ActiveTCPConnectionState corresponding to a peer address +ActiveTCPConnectionState * TCPBase::FindActiveConnection(const PeerAddress & address) { if (address.GetTransportType() != Type::kTcp) { @@ -136,7 +147,7 @@ TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const PeerAddress for (size_t i = 0; i < mActiveConnectionsSize; i++) { - if (!mActiveConnections[i].InUse()) + if (!mActiveConnections[i].IsConnected()) { continue; } @@ -153,8 +164,26 @@ TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const PeerAddress return nullptr; } -TCPBase::ActiveConnectionState * TCPBase::FindActiveConnection(const Inet::TCPEndPoint * endPoint) +// Find the ActiveTCPConnectionState for a given TCPEndPoint +ActiveTCPConnectionState * TCPBase::FindActiveConnection(const Inet::TCPEndPoint * endPoint) { + for (size_t i = 0; i < mActiveConnectionsSize; i++) + { + if (mActiveConnections[i].mEndPoint == endPoint && mActiveConnections[i].IsConnected()) + { + return &mActiveConnections[i]; + } + } + return nullptr; +} + +ActiveTCPConnectionState * TCPBase::FindInUseConnection(const Inet::TCPEndPoint * endPoint) +{ + if (endPoint == nullptr) + { + return nullptr; + } + for (size_t i = 0; i < mActiveConnectionsSize; i++) { if (mActiveConnections[i].mEndPoint == endPoint) @@ -172,7 +201,7 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: // - actual data VerifyOrReturnError(address.GetTransportType() == Type::kTcp, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(mState == State::kInitialized, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mState == TCPState::kInitialized, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(kPacketSizeBytes + msgBuf->DataLength() <= std::numeric_limits::max(), CHIP_ERROR_INVALID_ARGUMENT); @@ -186,7 +215,7 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: // Reuse existing connection if one exists, otherwise a new one // will be established - ActiveConnectionState * connection = FindActiveConnection(address); + ActiveTCPConnectionState * connection = FindActiveConnection(address); if (connection != nullptr) { @@ -196,8 +225,46 @@ CHIP_ERROR TCPBase::SendMessage(const Transport::PeerAddress & address, System:: return SendAfterConnect(address, std::move(msgBuf)); } +CHIP_ERROR TCPBase::StartConnect(const PeerAddress & addr, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** outPeerConnState) +{ +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + ActiveTCPConnectionState * activeConnection = nullptr; + Inet::TCPEndPoint * endPoint = nullptr; + *outPeerConnState = nullptr; + ReturnErrorOnFailure(mListenSocket->GetEndPointManager().NewEndPoint(&endPoint)); + + auto EndPointDeletor = [](Inet::TCPEndPoint * e) { e->Free(); }; + std::unique_ptr endPointHolder(endPoint, EndPointDeletor); + + endPoint->mAppState = reinterpret_cast(this); + endPoint->OnConnectComplete = HandleTCPEndPointConnectComplete; + endPoint->SetConnectTimeout(mConnectTimeout); + + activeConnection = AllocateConnection(); + VerifyOrReturnError(activeConnection != nullptr, CHIP_ERROR_NO_MEMORY); + activeConnection->Init(endPoint, addr); + activeConnection->mAppState = appState; + activeConnection->mConnectionState = TCPState::kConnecting; + // Set the return value of the peer connection state to the allocated + // connection. + *outPeerConnState = activeConnection; + + ReturnErrorOnFailure(endPoint->Connect(addr.GetIPAddress(), addr.GetPort(), addr.GetInterface())); + + mUsedEndPointCount++; + + endPointHolder.release(); + + return CHIP_NO_ERROR; +#else + return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; +#endif +} + CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBufferHandle && msg) { +#if INET_CONFIG_ENABLE_TCP_ENDPOINT // This will initiate a connection to the specified peer bool alreadyConnecting = false; @@ -224,28 +291,13 @@ CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBuf // Ensures sufficient active connections size exist VerifyOrReturnError(mUsedEndPointCount < mActiveConnectionsSize, CHIP_ERROR_NO_MEMORY); -#if INET_CONFIG_ENABLE_TCP_ENDPOINT - Inet::TCPEndPoint * endPoint = nullptr; - ReturnErrorOnFailure(mListenSocket->GetEndPointManager().NewEndPoint(&endPoint)); - auto EndPointDeletor = [](Inet::TCPEndPoint * e) { e->Free(); }; - std::unique_ptr endPointHolder(endPoint, EndPointDeletor); - - endPoint->mAppState = reinterpret_cast(this); - endPoint->OnDataReceived = OnTcpReceive; - endPoint->OnConnectComplete = OnConnectionComplete; - endPoint->OnConnectionClosed = OnConnectionClosed; - endPoint->OnConnectionReceived = OnConnectionReceived; - endPoint->OnAcceptError = OnAcceptError; - endPoint->OnPeerClose = OnPeerClosed; - - ReturnErrorOnFailure(endPoint->Connect(addr.GetIPAddress(), addr.GetPort(), addr.GetInterface())); + Transport::ActiveTCPConnectionState * peerConnState = nullptr; + ReturnErrorOnFailure(StartConnect(addr, nullptr, &peerConnState)); // enqueue the packet once the connection succeeds VerifyOrReturnError(mPendingPackets.CreateObject(addr, std::move(msg)) != nullptr, CHIP_ERROR_NO_MEMORY); mUsedEndPointCount++; - endPointHolder.release(); - return CHIP_NO_ERROR; #else return CHIP_ERROR_UNSUPPORTED_CHIP_FEATURE; @@ -255,7 +307,7 @@ CHIP_ERROR TCPBase::SendAfterConnect(const PeerAddress & addr, System::PacketBuf CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const PeerAddress & peerAddress, System::PacketBufferHandle && buffer) { - ActiveConnectionState * state = FindActiveConnection(endPoint); + ActiveTCPConnectionState * state = FindActiveConnection(endPoint); VerifyOrReturnError(state != nullptr, CHIP_ERROR_INTERNAL); state->mReceived.AddToEnd(std::move(buffer)); @@ -275,6 +327,7 @@ CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const Pe uint16_t messageSize = LittleEndian::Get16(messageSizeBuf); if (messageSize >= kMaxMessageSize) { + // This message is too long for upper layers. return CHIP_ERROR_MESSAGE_TOO_LONG; } @@ -291,12 +344,15 @@ CHIP_ERROR TCPBase::ProcessReceivedBuffer(Inet::TCPEndPoint * endPoint, const Pe return CHIP_NO_ERROR; } -CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, ActiveConnectionState * state, uint16_t messageSize) +CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, ActiveTCPConnectionState * state, uint16_t messageSize) { // We enter with `state->mReceived` containing at least one full message, perhaps in a chain. // `state->mReceived->Start()` currently points to the message data. // On exit, `state->mReceived` will have had `messageSize` bytes consumed, no matter what. System::PacketBufferHandle message; + MessageTransportContext msgContext; + msgContext.conn = state; + if (state->mReceived->DataLength() == messageSize) { // In this case, the head packet buffer contains exactly the message. @@ -321,23 +377,53 @@ CHIP_ERROR TCPBase::ProcessSingleMessage(const PeerAddress & peerAddress, Active message->SetDataLength(messageSize); } - HandleMessageReceived(peerAddress, std::move(message)); + HandleMessageReceived(peerAddress, std::move(message), &msgContext); return CHIP_NO_ERROR; } -void TCPBase::ReleaseActiveConnection(Inet::TCPEndPoint * endPoint) +void TCPBase::CloseConnectionInternal(ActiveTCPConnectionState * connection, CHIP_ERROR err, SuppressCallback suppressCallback) { - for (size_t i = 0; i < mActiveConnectionsSize; i++) + TCPState prevState; + + if (connection == nullptr) { - if (mActiveConnections[i].mEndPoint == endPoint) + return; + } + + if (connection->mConnectionState != TCPState::kClosed && connection->mEndPoint) + { + if (err == CHIP_NO_ERROR) { - mActiveConnections[i].Free(); - mUsedEndPointCount--; + connection->mEndPoint->Close(); } + else + { + connection->mEndPoint->Abort(); + } + + prevState = connection->mConnectionState; + connection->mConnectionState = TCPState::kClosed; + + if (suppressCallback == SuppressCallback::No) + { + if (prevState == TCPState::kConnecting) + { + // Call upper layer connection attempt complete handler + HandleConnectionAttemptComplete(connection, err); + } + else + { + // Call upper layer connection closed handler + HandleConnectionClosed(connection, err); + } + } + + connection->Free(); + mUsedEndPointCount--; } } -CHIP_ERROR TCPBase::OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer) +CHIP_ERROR TCPBase::HandleTCPEndPointDataReceived(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer) { Inet::IPAddress ipAddress; uint16_t port; @@ -353,13 +439,13 @@ CHIP_ERROR TCPBase::OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBuf if (err != CHIP_NO_ERROR) { // Connection could need to be closed at this point - ChipLogError(Inet, "Failed to accept received TCP message: %s", ErrorStr(err)); + ChipLogError(Inet, "Failed to accept received TCP message: %" CHIP_ERROR_FORMAT, err.Format()); return CHIP_ERROR_UNEXPECTED_EVENT; } return CHIP_NO_ERROR; } -void TCPBase::OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR inetErr) +void TCPBase::HandleTCPEndPointConnectComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR conErr) { CHIP_ERROR err = CHIP_NO_ERROR; bool foundPendingPacket = false; @@ -367,157 +453,229 @@ void TCPBase::OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR inet Inet::IPAddress ipAddress; uint16_t port; Inet::InterfaceId interfaceId; + ActiveTCPConnectionState * activeConnection = nullptr; endPoint->GetPeerInfo(&ipAddress, &port); endPoint->GetInterfaceId(&interfaceId); + char addrStr[Transport::PeerAddress::kMaxToStringSize]; PeerAddress addr = PeerAddress::TCP(ipAddress, port, interfaceId); + addr.ToString(addrStr); - // Send any pending packets - tcp->mPendingPackets.ForEachActiveObject([&](PendingPacket * pending) { - if (pending->mPeerAddress == addr) + if (conErr == CHIP_NO_ERROR) + { + // Set the Data received handler when connection completes + endPoint->OnDataReceived = HandleTCPEndPointDataReceived; + endPoint->OnDataSent = nullptr; + endPoint->OnConnectionClosed = HandleTCPEndPointConnectionClosed; + + activeConnection = tcp->FindInUseConnection(endPoint); + VerifyOrDie(activeConnection != nullptr); + + // Set to Connected state + activeConnection->mConnectionState = TCPState::kConnected; + + // Disable TCP Nagle buffering by setting TCP_NODELAY socket option to true. + // This is to expedite transmission of payload data and not rely on the + // network stack's configuration of collating enough data in the TCP + // window to begin transmission. + err = endPoint->EnableNoDelay(); + if (err != CHIP_NO_ERROR) { - foundPendingPacket = true; - System::PacketBufferHandle buffer = std::move(pending->mPacketBuffer); - tcp->mPendingPackets.ReleaseObject(pending); + tcp->CloseConnectionInternal(activeConnection, err, SuppressCallback::No); + return; + } - if ((inetErr == CHIP_NO_ERROR) && (err == CHIP_NO_ERROR)) + // Send any pending packets that are queued for this connection + tcp->mPendingPackets.ForEachActiveObject([&](PendingPacket * pending) { + if (pending->mPeerAddress == addr) { - err = endPoint->Send(std::move(buffer)); + foundPendingPacket = true; + System::PacketBufferHandle buffer = std::move(pending->mPacketBuffer); + tcp->mPendingPackets.ReleaseObject(pending); + + if ((conErr == CHIP_NO_ERROR) && (err == CHIP_NO_ERROR)) + { + err = endPoint->Send(std::move(buffer)); + } } - } - return Loop::Continue; - }); + return Loop::Continue; + }); - if (err == CHIP_NO_ERROR) - { - err = inetErr; - } + // Set the TCPKeepalive configurations on the established connection + endPoint->EnableKeepAlive(activeConnection->mTCPKeepAliveIntervalSecs, activeConnection->mTCPMaxNumKeepAliveProbes); - if (!foundPendingPacket && (err == CHIP_NO_ERROR)) - { - // Force a close: new connections are only expected when a - // new buffer is being sent. - ChipLogError(Inet, "Connection accepted without pending buffers"); - err = CHIP_ERROR_CONNECTION_CLOSED_UNEXPECTEDLY; - } + ChipLogProgress(Inet, "Connection established successfully with %s.", addrStr); - // cleanup packets or mark as free - if (err != CHIP_NO_ERROR) - { - ChipLogError(Inet, "Connection complete encountered an error: %s", ErrorStr(err)); - endPoint->Free(); - tcp->mUsedEndPointCount--; + // Let higher layer/delegate know that connection is successfully + // established + tcp->HandleConnectionAttemptComplete(activeConnection, CHIP_NO_ERROR); } else { - bool connectionStored = false; - for (size_t i = 0; i < tcp->mActiveConnectionsSize; i++) - { - if (!tcp->mActiveConnections[i].InUse()) - { - tcp->mActiveConnections[i].Init(endPoint); - connectionStored = true; - break; - } - } - - // since we track end points counts, we always expect to store the - // connection. - if (!connectionStored) - { - endPoint->Free(); - ChipLogError(Inet, "Internal logic error: insufficient space to store active connection"); - } + ChipLogError(Inet, "Connection establishment with %s encountered an error: %" CHIP_ERROR_FORMAT, addrStr, err.Format()); + endPoint->Free(); + tcp->mUsedEndPointCount--; } } -void TCPBase::OnConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) +void TCPBase::HandleTCPEndPointConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) { - TCPBase * tcp = reinterpret_cast(endPoint->mAppState); + TCPBase * tcp = reinterpret_cast(endPoint->mAppState); + ActiveTCPConnectionState * activeConnection = tcp->FindInUseConnection(endPoint); - ChipLogProgress(Inet, "Connection closed."); + if (activeConnection == nullptr) + { + endPoint->Free(); + return; + } - ChipLogProgress(Inet, "Freeing closed connection."); - tcp->ReleaseActiveConnection(endPoint); + if (err == CHIP_NO_ERROR && activeConnection->IsConnected()) + { + err = CHIP_ERROR_CONNECTION_CLOSED_UNEXPECTEDLY; + } + + tcp->CloseConnectionInternal(activeConnection, err, SuppressCallback::No); } -void TCPBase::OnConnectionReceived(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, - const Inet::IPAddress & peerAddress, uint16_t peerPort) +// Handler for incoming connection requests from peer nodes +void TCPBase::HandleIncomingConnection(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, + const Inet::IPAddress & peerAddress, uint16_t peerPort) { - TCPBase * tcp = reinterpret_cast(listenEndPoint->mAppState); + TCPBase * tcp = reinterpret_cast(listenEndPoint->mAppState); + ActiveTCPConnectionState * activeConnection = nullptr; + Inet::InterfaceId interfaceId; + Inet::IPAddress ipAddress; + uint16_t port; + + endPoint->GetPeerInfo(&ipAddress, &port); + endPoint->GetInterfaceId(&interfaceId); + PeerAddress addr = PeerAddress::TCP(ipAddress, port, interfaceId); if (tcp->mUsedEndPointCount < tcp->mActiveConnectionsSize) { - // have space to use one more (even if considering pending connections) - for (size_t i = 0; i < tcp->mActiveConnectionsSize; i++) - { - if (!tcp->mActiveConnections[i].InUse()) - { - tcp->mActiveConnections[i].Init(endPoint); - tcp->mUsedEndPointCount++; - break; - } - } + activeConnection = tcp->AllocateConnection(); + + endPoint->mAppState = listenEndPoint->mAppState; + endPoint->OnDataReceived = HandleTCPEndPointDataReceived; + endPoint->OnDataSent = nullptr; + endPoint->OnConnectionClosed = HandleTCPEndPointConnectionClosed; + + // By default, disable TCP Nagle buffering by setting TCP_NODELAY socket option to true + endPoint->EnableNoDelay(); - endPoint->mAppState = listenEndPoint->mAppState; - endPoint->OnDataReceived = OnTcpReceive; - endPoint->OnConnectComplete = OnConnectionComplete; - endPoint->OnConnectionClosed = OnConnectionClosed; - endPoint->OnConnectionReceived = OnConnectionReceived; - endPoint->OnAcceptError = OnAcceptError; - endPoint->OnPeerClose = OnPeerClosed; + // Update state for the active connection + activeConnection->Init(endPoint, addr); + tcp->mUsedEndPointCount++; + activeConnection->mConnectionState = TCPState::kConnected; + + char addrStr[Transport::PeerAddress::kMaxToStringSize]; + peerAddress.ToString(addrStr); + ChipLogProgress(Inet, "Incoming connection established with peer at %s.", addrStr); + + // Call the upper layer handler for incoming connection received. + tcp->HandleConnectionReceived(activeConnection); } else { - ChipLogError(Inet, "Insufficient connection space to accept new connections"); + ChipLogError(Inet, "Insufficient connection space to accept new connections."); endPoint->Free(); + listenEndPoint->OnAcceptError(endPoint, CHIP_ERROR_TOO_MANY_CONNECTIONS); } } -void TCPBase::OnAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) +void TCPBase::HandleAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err) +{ + endPoint->Free(); + ChipLogError(Inet, "Accept error: %" CHIP_ERROR_FORMAT, err.Format()); +} + +CHIP_ERROR TCPBase::TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** outPeerConnState) { - ChipLogError(Inet, "Accept error: %s", ErrorStr(err)); + VerifyOrReturnError(mState == TCPState::kInitialized, CHIP_ERROR_INCORRECT_STATE); + + // Verify that PeerAddress AddressType is TCP + VerifyOrReturnError(address.GetTransportType() == Transport::Type::kTcp, CHIP_ERROR_INVALID_ARGUMENT); + + VerifyOrReturnError(mUsedEndPointCount < mActiveConnectionsSize, CHIP_ERROR_NO_MEMORY); + + char addrStr[Transport::PeerAddress::kMaxToStringSize]; + address.ToString(addrStr); + ChipLogProgress(Inet, "Connecting to peer %s.", addrStr); + + ReturnErrorOnFailure(StartConnect(address, appState, outPeerConnState)); + + return CHIP_NO_ERROR; } -void TCPBase::Disconnect(const PeerAddress & address) +void TCPBase::TCPDisconnect(const PeerAddress & address) { + CHIP_ERROR err = CHIP_NO_ERROR; // Closes an existing connection for (size_t i = 0; i < mActiveConnectionsSize; i++) { - if (mActiveConnections[i].InUse()) + if (mActiveConnections[i].IsConnected()) { Inet::IPAddress ipAddress; uint16_t port; Inet::InterfaceId interfaceId; - mActiveConnections[i].mEndPoint->GetPeerInfo(&ipAddress, &port); - mActiveConnections[i].mEndPoint->GetInterfaceId(&interfaceId); - if (address == PeerAddress::TCP(ipAddress, port, interfaceId)) + err = mActiveConnections[i].mEndPoint->GetPeerInfo(&ipAddress, &port); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Inet, "TCPDisconnect: GetPeerInfo error: %" CHIP_ERROR_FORMAT, err.Format()); + return; + } + + err = mActiveConnections[i].mEndPoint->GetInterfaceId(&interfaceId); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Inet, "TCPDisconnect: GetInterfaceId error: %" CHIP_ERROR_FORMAT, err.Format()); + return; + } + // if (address == PeerAddress::TCP(ipAddress, port, interfaceId)) + if (ipAddress == address.GetIPAddress() && port == address.GetPort()) { + char addrStr[Transport::PeerAddress::kMaxToStringSize]; + address.ToString(addrStr); + ChipLogProgress(Inet, "Disconnecting with peer %s.", addrStr); + // NOTE: this leaves the socket in TIME_WAIT. // Calling Abort() would clean it since SO_LINGER would be set to 0, // however this seems not to be useful. - mActiveConnections[i].Free(); - mUsedEndPointCount--; + CloseConnectionInternal(&mActiveConnections[i], CHIP_NO_ERROR, SuppressCallback::Yes); } } } } -void TCPBase::OnPeerClosed(Inet::TCPEndPoint * endPoint) +void TCPBase::TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort) { - TCPBase * tcp = reinterpret_cast(endPoint->mAppState); - ChipLogProgress(Inet, "Freeing connection: connection closed by peer"); + if (conn == nullptr) + { + ChipLogError(Inet, "Failed to Disconnect. Passed in Connection is null."); + return; + } - tcp->ReleaseActiveConnection(endPoint); + // This call should be able to disconnect the connection either when it is + // already established, or when it is being set up. + if ((conn->IsConnected() && shouldAbort) || conn->IsConnecting()) + { + CloseConnectionInternal(conn, CHIP_ERROR_CONNECTION_ABORTED, SuppressCallback::Yes); + } + + if (conn->IsConnected() && !shouldAbort) + { + CloseConnectionInternal(conn, CHIP_NO_ERROR, SuppressCallback::Yes); + } } bool TCPBase::HasActiveConnections() const { for (size_t i = 0; i < mActiveConnectionsSize; i++) { - if (mActiveConnections[i].InUse()) + if (mActiveConnections[i].IsConnected()) { return true; } diff --git a/src/transport/raw/TCP.h b/src/transport/raw/TCP.h index d9f78be1771b0f..bb4671215b96c8 100644 --- a/src/transport/raw/TCP.h +++ b/src/transport/raw/TCP.h @@ -34,7 +34,9 @@ #include #include #include +#include #include +#include namespace chip { namespace Transport { @@ -96,45 +98,23 @@ struct PendingPacket /** Implements a transport using TCP. */ class DLL_EXPORT TCPBase : public Base { - /** - * The State of the TCP connection - */ - enum class State - { - kNotReady = 0, /**< State before initialization. */ - kInitialized = 1, /**< State after class is listening and ready. */ - }; protected: - /** - * State for each active connection - */ - struct ActiveConnectionState + enum class ShouldAbort : uint8_t { - void Init(Inet::TCPEndPoint * endPoint) - { - mEndPoint = endPoint; - mReceived = nullptr; - } - - void Free() - { - mEndPoint->Free(); - mEndPoint = nullptr; - mReceived = nullptr; - } - bool InUse() const { return mEndPoint != nullptr; } - - // Associated endpoint. - Inet::TCPEndPoint * mEndPoint; + Yes, + No + }; - // Buffers received but not yet consumed. - System::PacketBufferHandle mReceived; + enum class SuppressCallback : uint8_t + { + Yes, + No }; public: using PendingPacketPoolType = PoolInterface; - TCPBase(ActiveConnectionState * activeConnectionsBuffer, size_t bufferSize, PendingPacketPoolType & packetBuffers) : + TCPBase(ActiveTCPConnectionState * activeConnectionsBuffer, size_t bufferSize, PendingPacketPoolType & packetBuffers) : mActiveConnections(activeConnectionsBuffer), mActiveConnectionsSize(bufferSize), mPendingPackets(packetBuffers) { // activeConnectionsBuffer must be initialized by the caller. @@ -153,6 +133,13 @@ class DLL_EXPORT TCPBase : public Base */ CHIP_ERROR Init(TcpListenParameters & params); + /** + * Set the timeout (in milliseconds) for the node to wait for the TCP + * connection attempt to complete. + * + */ + void SetConnectTimeout(const uint32_t connTimeoutMsecs) { mConnectTimeout = connTimeoutMsecs; } + /** * Close the open endpoint without destroying the object */ @@ -160,14 +147,46 @@ class DLL_EXPORT TCPBase : public Base CHIP_ERROR SendMessage(const PeerAddress & address, System::PacketBufferHandle && msgBuf) override; - void Disconnect(const PeerAddress & address) override; + /* + * Connect to the given peerAddress over TCP. + * + * @param address The address of the peer. + * + * @param appState Context passed in by the application to be sent back + * via the connection attempt complete callback when + * connection attempt with peer completes. + * + * @param outPeerConnState Pointer to pointer to the active TCP connection state. This is + * an output parameter that is allocated by the + * transport layer and held by the caller object. + * This allows the caller object to abort the + * connection attempt if the caller object dies + * before the attempt completes. + * + */ + CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** outPeerConnState) override; + + void TCPDisconnect(const PeerAddress & address) override; + + // Close an active connection (corresponding to the passed + // ActiveTCPConnectionState object) + // and release from the pool. + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = false) override; bool CanSendToPeer(const PeerAddress & address) override { - return (mState == State::kInitialized) && (address.GetTransportType() == Type::kTcp) && + return (mState == TCPState::kInitialized) && (address.GetTransportType() == Type::kTcp) && (address.GetIPAddress().Type() == mEndpointType); } + const Optional GetConnectionPeerAddress(const Inet::TCPEndPoint * con) + { + ActiveTCPConnectionState * activeConState = FindActiveConnection(con); + + return activeConState != nullptr ? MakeOptional(activeConState->mPeerAddr) : Optional::Missing(); + } + /** * Helper method to determine if IO processing is still required for a TCP transport * before everything is cleaned up (socket closing is async, so after calling 'Close' on @@ -183,12 +202,22 @@ class DLL_EXPORT TCPBase : public Base private: friend class TCPTest; + /** + * Allocate an unused connection from the pool + * + */ + ActiveTCPConnectionState * AllocateConnection(); /** * Find an active connection to the given peer or return nullptr if * no active connection exists. */ - ActiveConnectionState * FindActiveConnection(const PeerAddress & addr); - ActiveConnectionState * FindActiveConnection(const Inet::TCPEndPoint * endPoint); + ActiveTCPConnectionState * FindActiveConnection(const PeerAddress & addr); + ActiveTCPConnectionState * FindActiveConnection(const Inet::TCPEndPoint * endPoint); + + /** + * Find an allocated connection that matches the corresponding TCPEndPoint. + */ + ActiveTCPConnectionState * FindInUseConnection(const Inet::TCPEndPoint * endPoint); /** * Sends the specified message once a connection has been established. @@ -223,46 +252,63 @@ class DLL_EXPORT TCPBase : public Base * is no other data). * @param[in] messageSize Size of the single message. */ - CHIP_ERROR ProcessSingleMessage(const PeerAddress & peerAddress, ActiveConnectionState * state, uint16_t messageSize); + CHIP_ERROR ProcessSingleMessage(const PeerAddress & peerAddress, ActiveTCPConnectionState * state, uint16_t messageSize); - // Release an active connection (corresponding to the passed TCPEndPoint) - // from the pool. - void ReleaseActiveConnection(Inet::TCPEndPoint * endPoint); + /** + * Initiate a connection to the given peer. On connection completion, + * HandleTCPConnectComplete callback would be called. + * + */ + CHIP_ERROR StartConnect(const PeerAddress & addr, AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** outPeerConnState); + + /** + * Gracefully Close or Abort a given connection. + * + */ + void CloseConnectionInternal(ActiveTCPConnectionState * connection, CHIP_ERROR err, SuppressCallback suppressCallback); + + // Close the listening socket endpoint + void CloseListeningSocket(); // Callback handler for TCPEndPoint. TCP message receive handler. // @see TCPEndpoint::OnDataReceivedFunct - static CHIP_ERROR OnTcpReceive(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer); + static CHIP_ERROR HandleTCPEndPointDataReceived(Inet::TCPEndPoint * endPoint, System::PacketBufferHandle && buffer); // Callback handler for TCPEndPoint. Called when a connection has been completed. // @see TCPEndpoint::OnConnectCompleteFunct - static void OnConnectionComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); + static void HandleTCPEndPointConnectComplete(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); // Callback handler for TCPEndPoint. Called when a connection has been closed. // @see TCPEndpoint::OnConnectionClosedFunct - static void OnConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); - - // Callback handler for TCPEndPoint. Callend when a peer closes the connection. - // @see TCPEndpoint::OnPeerCloseFunct - static void OnPeerClosed(Inet::TCPEndPoint * endPoint); + static void HandleTCPEndPointConnectionClosed(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); // Callback handler for TCPEndPoint. Called when a connection is received on the listening port. // @see TCPEndpoint::OnConnectionReceivedFunct - static void OnConnectionReceived(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, - const Inet::IPAddress & peerAddress, uint16_t peerPort); + static void HandleIncomingConnection(Inet::TCPEndPoint * listenEndPoint, Inet::TCPEndPoint * endPoint, + const Inet::IPAddress & peerAddress, uint16_t peerPort); - // Called on accept error + // Callback handler for handling accept error // @see TCPEndpoint::OnAcceptErrorFunct - static void OnAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); + static void HandleAcceptError(Inet::TCPEndPoint * endPoint, CHIP_ERROR err); Inet::TCPEndPoint * mListenSocket = nullptr; ///< TCP socket used by the transport Inet::IPAddressType mEndpointType = Inet::IPAddressType::kUnknown; ///< Socket listening type - State mState = State::kNotReady; ///< State of the TCP transport + TCPState mState = TCPState::kNotReady; ///< State of the TCP transport + + // The configured timeout for the connection attempt to the peer, before + // giving up. + uint32_t mConnectTimeout = CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS; + + // The max payload size of data over a TCP connection that is transmissible + // at a time. + uint32_t mMaxTCPPayloadSize = CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES; // Number of active and 'pending connection' endpoints size_t mUsedEndPointCount = 0; // Currently active connections - ActiveConnectionState * mActiveConnections; + ActiveTCPConnectionState * mActiveConnections; const size_t mActiveConnectionsSize; // Data to be sent when connections succeed @@ -277,14 +323,15 @@ class TCP : public TCPBase { for (size_t i = 0; i < kActiveConnectionsSize; ++i) { - mConnectionsBuffer[i].Init(nullptr); + mConnectionsBuffer[i].Init(nullptr, PeerAddress::Uninitialized()); } } + ~TCP() override { mPendingPackets.ReleaseAll(); } private: friend class TCPTest; - TCPBase::ActiveConnectionState mConnectionsBuffer[kActiveConnectionsSize]; + ActiveTCPConnectionState mConnectionsBuffer[kActiveConnectionsSize]; PoolImpl mPendingPackets; }; diff --git a/src/transport/raw/TCPConfig.h b/src/transport/raw/TCPConfig.h new file mode 100644 index 00000000000000..d54a9466b4d294 --- /dev/null +++ b/src/transport/raw/TCPConfig.h @@ -0,0 +1,127 @@ +/* + * + * Copyright (c) 2023 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines default compile-time configuration constants + * for CHIP. + * + * Package integrators that wish to override these values should + * either use preprocessor definitions or create a project- + * specific chipProjectConfig.h header and then assert + * HAVE_CHIPPROJECTCONFIG_H via the package configuration tool + * via --with-chip-project-includes=DIR where DIR is the + * directory that contains the header. + * + * + */ + +#pragma once + +#include + +namespace chip { + +/** + * @def CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS + * + * @brief Maximum Number of TCP connections a device can simultaneously have + */ +#ifndef CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS +#define CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS 4 +#endif + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT && CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS < 1 +#error "If TCP is enabled, the device needs to support at least 1 TCP connection" +#endif + +#if INET_CONFIG_ENABLE_TCP_ENDPOINT && CHIP_CONFIG_MAX_ACTIVE_TCP_CONNECTIONS > INET_CONFIG_NUM_TCP_ENDPOINTS +#error "If TCP is enabled, the maximum number of connections cannot exceed the number of tcp endpoints" +#endif + +/** + * @def CHIP_CONFIG_MAX_TCP_PENDING_PACKETS + * + * @brief Maximum Number of outstanding pending packets in the queue before a TCP connection + * needs to be established + */ +#ifndef CHIP_CONFIG_MAX_TCP_PENDING_PACKETS +#define CHIP_CONFIG_MAX_TCP_PENDING_PACKETS 4 +#endif + +/** + * @def CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES + * + * @brief Maximum payload size of a message over a TCP connection + */ +#ifndef CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES +#define CHIP_CONFIG_MAX_TCP_PAYLOAD_SIZE_BYTES 1000000 +#endif + +/** + * @def CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS + * + * @brief + * This defines the default timeout for the TCP connect + * attempt to either succeed or notify the caller of an + * error. + * + */ +#ifndef CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS +#define CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS (10000) +#endif // CHIP_CONFIG_TCP_CONNECT_TIMEOUT_MSECS + +/** + * @def CHIP_CONFIG_KEEPALIVE_INTERVAL_SECS + * + * @brief + * This defines the default interval (in seconds) between + * keepalive probes for a TCP connection. + * This value also controls the time between last data + * packet sent and the transmission of the first keepalive + * probe. + * + */ +#ifndef CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS +#define CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS (25) +#endif // CHIP_CONFIG_TCP_KEEPALIVE_INTERVAL_SECS + +/** + * @def CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES + * + * @brief + * This defines the default value for the maximum number of + * keepalive probes for a TCP connection. + * + */ +#ifndef CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES +#define CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES (5) +#endif // CHIP_CONFIG_MAX_TCP_KEEPALIVE_PROBES + +/** + * @def CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS + * + * @brief + * This defines the default value for the maximum timeout + * of unacknowledged data for a TCP connection. + * + */ +#ifndef CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS +#define CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS (30) +#endif // CHIP_CONFIG_MAX_UNACKED_DATA_TIMEOUT_SECS + +} // namespace chip diff --git a/src/transport/raw/Tuple.h b/src/transport/raw/Tuple.h index 743b9b9b0e09aa..e3e52a171abcda 100644 --- a/src/transport/raw/Tuple.h +++ b/src/transport/raw/Tuple.h @@ -93,7 +93,20 @@ class Tuple : public Base bool CanSendToPeer(const PeerAddress & address) override { return CanSendToPeerImpl<0>(address); } - void Disconnect(const PeerAddress & address) override { return DisconnectImpl<0>(address); } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + CHIP_ERROR TCPConnect(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) override + { + return TCPConnectImpl<0>(address, appState, peerConnState); + } + + void TCPDisconnect(const PeerAddress & address) override { return TCPDisconnectImpl<0>(address); } + + void TCPDisconnect(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) override + { + return TCPDisconnectImpl<0>(conn, shouldAbort); + } +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT void Close() override { return CloseImpl<0>(); } @@ -138,26 +151,78 @@ class Tuple : public Base return false; } +#if INET_CONFIG_ENABLE_TCP_ENDPOINT + /** + * Recursive TCPConnect implementation iterating through transport members. + * + * @tparam N the index of the underlying transport to send disconnect to + * + * @param address what address to connect to. + */ + template ::type * = nullptr> + CHIP_ERROR TCPConnectImpl(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + Base * base = &std::get(mTransports); + if (base->CanSendToPeer(address)) + { + return base->TCPConnect(address, appState, peerConnState); + } + return TCPConnectImpl(address, appState, peerConnState); + } + + /** + * TCPConnectImpl template for out of range N. + */ + template = sizeof...(TransportTypes))>::type * = nullptr> + CHIP_ERROR TCPConnectImpl(const PeerAddress & address, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return CHIP_ERROR_NO_MESSAGE_HANDLER; + } + /** * Recursive disconnect implementation iterating through transport members. * * @tparam N the index of the underlying transport to send disconnect to * - * @param address what address to check. + * @param address what address to disconnect from. + */ + template ::type * = nullptr> + void TCPDisconnectImpl(const PeerAddress & address) + { + std::get(mTransports).TCPDisconnect(address); + TCPDisconnectImpl(address); + } + + /** + * TCPDisconnectImpl template for out of range N. + */ + template = sizeof...(TransportTypes))>::type * = nullptr> + void TCPDisconnectImpl(const PeerAddress & address) + {} + + /** + * Recursive disconnect implementation iterating through transport members. + * + * @tparam N the index of the underlying transport to send disconnect to + * + * @param conn pointer to the connection to the peer. */ template ::type * = nullptr> - void DisconnectImpl(const PeerAddress & address) + void TCPDisconnectImpl(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) { - std::get(mTransports).Disconnect(address); - DisconnectImpl(address); + std::get(mTransports).TCPDisconnect(conn, shouldAbort); + TCPDisconnectImpl(conn, shouldAbort); } /** - * DisconnectImpl template for out of range N. + * TCPDisconnectImpl template for out of range N. */ template = sizeof...(TransportTypes))>::type * = nullptr> - void DisconnectImpl(const PeerAddress & address) + void TCPDisconnectImpl(Transport::ActiveTCPConnectionState * conn, bool shouldAbort = 0) {} +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT /** * Recursive disconnect implementation iterating through transport members. diff --git a/src/transport/raw/tests/BUILD.gn b/src/transport/raw/tests/BUILD.gn index c655586c5a0e35..973626af7f340b 100644 --- a/src/transport/raw/tests/BUILD.gn +++ b/src/transport/raw/tests/BUILD.gn @@ -14,7 +14,9 @@ import("//build_overrides/build.gni") import("//build_overrides/chip.gni") +import("//build_overrides/nlunit_test.gni") import("//build_overrides/pigweed.gni") +import("${chip_root}/src/inet/inet.gni") import("${chip_root}/build/chip/chip_test_suite.gni") static_library("helpers") { @@ -40,10 +42,13 @@ chip_test_suite("tests") { test_sources = [ "TestMessageHeader.cpp", "TestPeerAddress.cpp", - "TestTCP.cpp", "TestUDP.cpp", ] + if (chip_inet_config_enable_tcp_endpoint) { + test_sources += [ "TestTCP.cpp" ] + } + public_deps = [ ":helpers", "${chip_root}/src/inet/tests:helpers", diff --git a/src/transport/raw/tests/TestTCP.cpp b/src/transport/raw/tests/TestTCP.cpp index 93414e01e3d653..6a60fb330c70f7 100644 --- a/src/transport/raw/tests/TestTCP.cpp +++ b/src/transport/raw/tests/TestTCP.cpp @@ -23,6 +23,7 @@ #include "NetworkTestHelpers.h" +#include #include #include #include @@ -30,7 +31,9 @@ #include #include #include +#if INET_CONFIG_ENABLE_TCP_ENDPOINT #include +#endif // INET_CONFIG_ENABLE_TCP_ENDPOINT #include @@ -47,6 +50,9 @@ namespace { constexpr size_t kMaxTcpActiveConnectionCount = 4; constexpr size_t kMaxTcpPendingPackets = 4; constexpr uint16_t kPacketSizeBytes = static_cast(sizeof(uint16_t)); +uint16_t gChipTCPPort = static_cast(CHIP_PORT + chip::Crypto::GetRandU16() % 100); +chip::Transport::AppTCPConnectionCallbackCtxt gAppTCPConnCbCtxt; +chip::Transport::ActiveTCPConnectionState * gActiveTCPConnState = nullptr; using TCPImpl = Transport::TCP; @@ -71,7 +77,8 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate mCallback = callback; mCallbackData = callback_data; } - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * transCtxt = nullptr) override { PacketHeader packetHeader; @@ -82,12 +89,54 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate EXPECT_EQ(mCallback(msgBuf->Start(), msgBuf->DataLength(), mReceiveHandlerCallCount, mCallbackData), 0); } + ChipLogProgress(Inet, "Message Receive Handler called"); + mReceiveHandlerCallCount++; } + void HandleConnectionAttemptComplete(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override + { + chip::Transport::AppTCPConnectionCallbackCtxt * appConnCbCtxt = nullptr; + VerifyOrReturn(conn != nullptr); + + mHandleConnectionCompleteCalled = true; + appConnCbCtxt = conn->mAppState; + VerifyOrReturn(appConnCbCtxt != nullptr); + + if (appConnCbCtxt->connCompleteCb != nullptr) + { + appConnCbCtxt->connCompleteCb(conn, conErr); + } + else + { + ChipLogProgress(Inet, "Connection established. App callback missing."); + } + } + + void HandleConnectionClosed(chip::Transport::ActiveTCPConnectionState * conn, CHIP_ERROR conErr) override + { + chip::Transport::AppTCPConnectionCallbackCtxt * appConnCbCtxt = nullptr; + VerifyOrReturn(conn != nullptr); + + mHandleConnectionCloseCalled = true; + appConnCbCtxt = conn->mAppState; + VerifyOrReturn(appConnCbCtxt != nullptr); + + if (appConnCbCtxt->connClosedCb != nullptr) + { + appConnCbCtxt->connClosedCb(conn, conErr); + } + else + { + ChipLogProgress(Inet, "Connection Closed. App callback missing."); + } + } + void InitializeMessageTest(TCPImpl & tcp, const IPAddress & addr) { - CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(mContext->GetTCPEndPointManager()).SetAddressType(addr.Type())); + CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(mContext->GetTCPEndPointManager()) + .SetAddressType(addr.Type()) + .SetListenPort(gChipTCPPort)); // retry a few times in case the port is somehow in use. // this is a WORKAROUND for flaky testing if we run tests very fast after each other. @@ -106,7 +155,9 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate { ChipLogProgress(NotSpecified, "RETRYING tcp initialization"); chip::test_utils::SleepMillis(100); - err = tcp.Init(Transport::TcpListenParameters(mContext->GetTCPEndPointManager()).SetAddressType(addr.Type())); + err = tcp.Init(Transport::TcpListenParameters(mContext->GetTCPEndPointManager()) + .SetAddressType(addr.Type()) + .SetListenPort(gChipTCPPort)); } EXPECT_EQ(err, CHIP_NO_ERROR); @@ -114,7 +165,14 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate mTransportMgrBase.SetSessionManager(this); mTransportMgrBase.Init(&tcp); - mReceiveHandlerCallCount = 0; + mReceiveHandlerCallCount = 0; + mHandleConnectionCompleteCalled = false; + mHandleConnectionCloseCalled = false; + + gAppTCPConnCbCtxt.appContext = nullptr; + gAppTCPConnCbCtxt.connReceivedCb = nullptr; + gAppTCPConnCbCtxt.connCompleteCb = nullptr; + gAppTCPConnCbCtxt.connClosedCb = nullptr; } void SingleMessageTest(TCPImpl & tcp, const IPAddress & addr) @@ -132,7 +190,7 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate EXPECT_EQ(err, CHIP_NO_ERROR); // Should be able to send a message to itself by just calling send. - err = tcp.SendMessage(Transport::PeerAddress::TCP(addr), std::move(buffer)); + err = tcp.SendMessage(Transport::PeerAddress::TCP(addr, gChipTCPPort), std::move(buffer)); EXPECT_EQ(err, CHIP_NO_ERROR); mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mReceiveHandlerCallCount != 0; }); @@ -141,39 +199,114 @@ class MockTransportMgrDelegate : public chip::TransportMgrDelegate SetCallback(nullptr); } - void FinalizeMessageTest(TCPImpl & tcp, const IPAddress & addr) + void ConnectTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection + CHIP_ERROR err = tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + EXPECT_EQ(err, CHIP_NO_ERROR); + + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return tcp.HasActiveConnections(); }); + EXPECT_EQ(tcp.HasActiveConnections(), true); + } + + void HandleConnectCompleteCbCalledTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection and connection complete + // handler being called. + CHIP_ERROR err = tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + EXPECT_EQ(err, CHIP_NO_ERROR); + + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mHandleConnectionCompleteCalled; }); + EXPECT_EQ(mHandleConnectionCompleteCalled, true); + } + + void HandleConnectCloseCbCalledTest(TCPImpl & tcp, const IPAddress & addr) + { + // Connect and wait for seeing active connection and connection complete + // handler being called. + CHIP_ERROR err = tcp.TCPConnect(Transport::PeerAddress::TCP(addr, gChipTCPPort), &gAppTCPConnCbCtxt, &gActiveTCPConnState); + EXPECT_EQ(err, CHIP_NO_ERROR); + + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [this]() { return mHandleConnectionCompleteCalled; }); + EXPECT_EQ(mHandleConnectionCompleteCalled, true); + + tcp.TCPDisconnect(Transport::PeerAddress::TCP(addr, gChipTCPPort)); + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return !tcp.HasActiveConnections(); }); + EXPECT_EQ(mHandleConnectionCloseCalled, true); + } + + void DisconnectTest(TCPImpl & tcp, chip::Transport::ActiveTCPConnectionState * conn) { // Disconnect and wait for seeing peer close - tcp.Disconnect(Transport::PeerAddress::TCP(addr)); + tcp.TCPDisconnect(conn, true); mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return !tcp.HasActiveConnections(); }); + EXPECT_EQ(tcp.HasActiveConnections(), false); + } + + void DisconnectTest(TCPImpl & tcp, const IPAddress & addr) + { + // Disconnect and wait for seeing peer close + tcp.TCPDisconnect(Transport::PeerAddress::TCP(addr, gChipTCPPort)); + mContext->DriveIOUntil(chip::System::Clock::Seconds16(5), [&tcp]() { return !tcp.HasActiveConnections(); }); + EXPECT_EQ(tcp.HasActiveConnections(), false); + } + + CHIP_ERROR TCPConnect(const Transport::PeerAddress & peerAddress, Transport::AppTCPConnectionCallbackCtxt * appState, + Transport::ActiveTCPConnectionState ** peerConnState) + { + return mTransportMgrBase.TCPConnect(peerAddress, appState, peerConnState); + } + + using OnTCPConnectionReceivedCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn); + + using OnTCPConnectionCompleteCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn, + CHIP_ERROR conErr); + + using OnTCPConnectionClosedCallback = void (*)(void * context, chip::Transport::ActiveTCPConnectionState * conn, + CHIP_ERROR conErr); + + void SetConnectionCallbacks(OnTCPConnectionCompleteCallback connCompleteCb, OnTCPConnectionClosedCallback connClosedCb, + OnTCPConnectionReceivedCallback connReceivedCb) + { + mConnCompleteCb = connCompleteCb; + mConnClosedCb = connClosedCb; + mConnReceivedCb = connReceivedCb; } int mReceiveHandlerCallCount = 0; + bool mHandleConnectionCompleteCalled = false; + + bool mHandleConnectionCloseCalled = false; + private: TestContext * mContext; MessageReceivedCallback mCallback; void * mCallbackData; TransportMgrBase mTransportMgrBase; + OnTCPConnectionCompleteCallback mConnCompleteCb = nullptr; + OnTCPConnectionClosedCallback mConnClosedCb = nullptr; + OnTCPConnectionReceivedCallback mConnReceivedCb = nullptr; }; -/////////////////////////// Init test - class TestTCP : public ::testing::Test, public chip::Test::IOContext { protected: void SetUp() { ASSERT_EQ(Init(), CHIP_NO_ERROR); } void TearDown() { Shutdown(); } + /////////////////////////// Init test void CheckSimpleInitTest(Inet::IPAddressType type) { TCPImpl tcp; - CHIP_ERROR err = tcp.Init(Transport::TcpListenParameters(GetTCPEndPointManager()).SetAddressType(type)); + CHIP_ERROR err = + tcp.Init(Transport::TcpListenParameters(GetTCPEndPointManager()).SetAddressType(type).SetListenPort(gChipTCPPort)); EXPECT_EQ(err, CHIP_NO_ERROR); } + /////////////////////////// Messaging test void CheckMessageTest(const IPAddress & addr) { TCPImpl tcp; @@ -181,7 +314,48 @@ class TestTCP : public ::testing::Test, public chip::Test::IOContext MockTransportMgrDelegate gMockTransportMgrDelegate(this); gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); - gMockTransportMgrDelegate.FinalizeMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); + } + + void ConnectToSelfTest(const IPAddress & addr) + { + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(this); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.ConnectTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); + } + + void ConnectSendMessageThenCloseTest(const IPAddress & addr) + { + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(this); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.ConnectTest(tcp, addr); + gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); + } + + void HandleConnCompleteTest(const IPAddress & addr) + { + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(this); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.HandleConnectCompleteCbCalledTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); + } + + void HandleConnCloseTest(const IPAddress & addr) + { + TCPImpl tcp; + + MockTransportMgrDelegate gMockTransportMgrDelegate(this); + gMockTransportMgrDelegate.InitializeMessageTest(tcp, addr); + gMockTransportMgrDelegate.HandleConnectCloseCbCalledTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); } }; @@ -211,6 +385,57 @@ TEST_F(TestTCP, CheckMessageTest6) CheckMessageTest(addr); } +#if INET_CONFIG_ENABLE_IPV4 +TEST_F(TestTCP, ConnectToSelfTest4) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + ConnectToSelfTest(addr); +} + +TEST_F(TestTCP, ConnectSendMessageThenCloseTest4) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + ConnectSendMessageThenCloseTest(addr); +} + +TEST_F(TestTCP, HandleConnCompleteCalledTest4) +{ + IPAddress addr; + IPAddress::FromString("127.0.0.1", addr); + HandleConnCompleteTest(addr); +} +#endif // INET_CONFIG_ENABLE_IPV4 + +TEST_F(TestTCP, ConnectToSelfTest6) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + ConnectToSelfTest(addr); +} + +TEST_F(TestTCP, ConnectSendMessageThenCloseTest6) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + ConnectSendMessageThenCloseTest(addr); +} + +TEST_F(TestTCP, HandleConnCompleteCalledTest6) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + HandleConnCompleteTest(addr); +} + +TEST_F(TestTCP, HandleConnCloseCalledTest6) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + HandleConnCloseTest(addr); +} + // Generates a packet buffer or a chain of packet buffers for a single message. struct TestData { @@ -381,8 +606,8 @@ class TCPTest // (The current TCPEndPoint implementation is not effectively mockable.) gMockTransportMgrDelegate.SingleMessageTest(tcp, addr); - Transport::PeerAddress lPeerAddress = Transport::PeerAddress::TCP(addr); - TCPBase::ActiveConnectionState * state = tcp.FindActiveConnection(lPeerAddress); + Transport::PeerAddress lPeerAddress = Transport::PeerAddress::TCP(addr, gChipTCPPort); + chip::Transport::ActiveTCPConnectionState * state = tcp.FindActiveConnection(lPeerAddress); ASSERT_NE(state, nullptr); Inet::TCPEndPoint * lEndPoint = state->mEndPoint; ASSERT_NE(lEndPoint, nullptr); @@ -433,7 +658,7 @@ class TCPTest EXPECT_EQ(err, CHIP_ERROR_MESSAGE_TOO_LONG); EXPECT_EQ(gMockTransportMgrDelegate.mReceiveHandlerCallCount, 0); - gMockTransportMgrDelegate.FinalizeMessageTest(tcp, addr); + gMockTransportMgrDelegate.DisconnectTest(tcp, addr); } }; } // namespace Transport diff --git a/src/transport/raw/tests/TestUDP.cpp b/src/transport/raw/tests/TestUDP.cpp index 70077a6fa297e9..6a22a821fdfc94 100644 --- a/src/transport/raw/tests/TestUDP.cpp +++ b/src/transport/raw/tests/TestUDP.cpp @@ -52,7 +52,8 @@ class MockTransportMgrDelegate : public TransportMgrDelegate MockTransportMgrDelegate() {} ~MockTransportMgrDelegate() override {} - void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override + void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf, + Transport::MessageTransportContext * transCtxt = nullptr) override { PacketHeader packetHeader;