diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index 0164e5c08247d6..5d3aca2620cf91 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -217,7 +217,7 @@ void OperationalDeviceProxy::HandleCASEConnectionFailure(void * context, CASECli device->DequeueConnectionSuccessCallbacks(/* executeCallback */ false); device->DequeueConnectionFailureCallbacks(error, /* executeCallback */ true); - device->DeferCloseCASESession(); + device->CloseCASESession(); } void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * client) @@ -238,7 +238,7 @@ void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * cl device->DequeueConnectionFailureCallbacks(CHIP_NO_ERROR, /* executeCallback */ false); device->DequeueConnectionSuccessCallbacks(/* executeCallback */ true); - device->DeferCloseCASESession(); + device->CloseCASESession(); } } @@ -276,22 +276,15 @@ void OperationalDeviceProxy::Clear() mInitParams = DeviceProxyInitParams(); } -void OperationalDeviceProxy::CloseCASESessionTask(System::Layer * layer, void * context) +void OperationalDeviceProxy::CloseCASESession() { - OperationalDeviceProxy * device = static_cast(context); - if (device->mCASEClient) + if (mCASEClient) { - device->mInitParams.clientPool->Release(device->mCASEClient); - device->mCASEClient = nullptr; + mInitParams.clientPool->Release(mCASEClient); + mCASEClient = nullptr; } } -void OperationalDeviceProxy::DeferCloseCASESession() -{ - // Defer the release for the pending Ack to be sent - mSystemLayer->ScheduleWork(CloseCASESessionTask, this); -} - void OperationalDeviceProxy::OnSessionReleased() { mState = State::Initialized; diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 91dd127706ebce..292c58c07728ea 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -235,7 +235,7 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, SessionReleaseDele static void CloseCASESessionTask(System::Layer * layer, void * context); - void DeferCloseCASESession(); + void CloseCASESession(); void EnqueueConnectionCallbacks(Callback::Callback * onConnection, Callback::Callback * onFailure); diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index ff9b25b9ca2750..a069263f371b50 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -122,6 +122,19 @@ void CASESession::CloseExchange() } } +void CASESession::DiscardExchange() +{ + if (mExchangeCtxt != nullptr) + { + // Make sure the exchange doesn't try to notify us when it closes, + // since we might be dead by then. + mExchangeCtxt->SetDelegate(nullptr); + // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // exchange will handle that. + mExchangeCtxt = nullptr; + } +} + CHIP_ERROR CASESession::ToCachable(CASESessionCachable & cachableSession) { const NodeId peerNodeId = GetPeerNodeId(); @@ -252,11 +265,12 @@ void CASESession::OnResponseTimeout(ExchangeContext * ec) VerifyOrReturn(mExchangeCtxt == ec, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match")); ChipLogError(SecureChannel, "CASESession timed out while waiting for a response from the peer. Current state was %" PRIu8, mState); - mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); - // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. - mExchangeCtxt = nullptr; + DiscardExchange(); Clear(); + // Do this last in case the delegate frees us. + mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); } CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) @@ -683,10 +697,12 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg) mCASESessionEstablished = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); // Call delegate to indicate session establishment is successful + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablished(); exit: @@ -1117,10 +1133,12 @@ CHIP_ERROR CASESession::HandleSigma3(System::PacketBufferHandle && msg) mCASESessionEstablished = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); // Call delegate to indicate session establishment is successful + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablished(); exit: @@ -1301,16 +1319,18 @@ void CASESession::OnSuccessStatusReport() ChipLogProgress(SecureChannel, "Success status report received. Session was established"); mCASESessionEstablished = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; - - // Call delegate to indicate pairing completion - mDelegate->OnSessionEstablished(); + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); mState = kInitialized; // TODO: Set timestamp on the new session, to allow selecting a least-recently-used session for eviction // on running out of session contexts. + + // Call delegate to indicate pairing completion. + // Do this last in case the delegate frees us. + mDelegate->OnSessionEstablished(); } CHIP_ERROR CASESession::OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) @@ -1522,10 +1542,11 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea // Call delegate to indicate session establishment failure. if (err != CHIP_NO_ERROR) { - // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. - mExchangeCtxt = nullptr; + DiscardExchange(); Clear(); + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablishmentError(err); } return err; diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index b266158c011f4d..a780c298a6dd9a 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -220,6 +220,12 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin void CloseExchange(); + /** + * Clear our reference to our exchange context pointer so that it can close + * itself at some later time. + */ + void DiscardExchange(); + // TODO: Remove this and replace with system method to retrieve current time CHIP_ERROR SetEffectiveTime(void); diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 4056ee1a44da31..3f1d29cadc146e 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -116,6 +116,19 @@ void PASESession::CloseExchange() } } +void PASESession::DiscardExchange() +{ + if (mExchangeCtxt != nullptr) + { + // Make sure the exchange doesn't try to notify us when it closes, + // since we might be dead by then. + mExchangeCtxt->SetDelegate(nullptr); + // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // exchange will handle that. + mExchangeCtxt = nullptr; + } +} + CHIP_ERROR PASESession::Serialize(PASESessionSerialized & output) { PASESessionSerializable serializable; @@ -349,11 +362,12 @@ void PASESession::OnResponseTimeout(ExchangeContext * ec) ChipLogError(SecureChannel, "PASESession timed out while waiting for a response from the peer. Expected message type was %" PRIu8, to_underlying(mNextExpectedMsg)); - mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); - // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. - mExchangeCtxt = nullptr; + DiscardExchange(); Clear(); + // Do this last in case the delegate frees us. + mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); } CHIP_ERROR PASESession::DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) @@ -829,10 +843,12 @@ CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg) mPairingComplete = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); // Call delegate to indicate pairing completion + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablished(); exit: @@ -848,10 +864,12 @@ void PASESession::OnSuccessStatusReport() { mPairingComplete = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); // Call delegate to indicate pairing completion + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablished(); } @@ -942,11 +960,12 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl // Call delegate to indicate pairing failure if (err != CHIP_NO_ERROR) { - // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. - mExchangeCtxt = nullptr; + DiscardExchange(); Clear(); ChipLogError(SecureChannel, "Failed during PASE session setup. %s", ErrorStr(err)); + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablishmentError(err); } return err; diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index 4bb43440b0cc37..f3ff1ae62f0df0 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -263,6 +263,12 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin void CloseExchange(); + /** + * Clear our reference to our exchange context pointer so that it can close + * itself at some later time. + */ + void DiscardExchange(); + SessionEstablishmentDelegate * mDelegate = nullptr; Protocols::SecureChannel::MsgType mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_PakeError;