Skip to content

Commit

Permalink
Implement SessionHolder auto shifting (#18107)
Browse files Browse the repository at this point in the history
* Implement SessionHolder auto shifting

* Resolve comments from Jerry

* Resolve comments

* Apply suggestions from code review

Co-authored-by: Boris Zbarsky <[email protected]>

* Restyle

Co-authored-by: Boris Zbarsky <[email protected]>
  • Loading branch information
2 people authored and pull[bot] committed Feb 9, 2024
1 parent a1eefed commit 532791f
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 37 deletions.
8 changes: 6 additions & 2 deletions src/lib/core/CASEAuthTag.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#pragma once

#include <array>

#include <lib/core/CHIPConfig.h>
#include <lib/core/CHIPEncoding.h>
#include <lib/core/NodeId.h>
Expand All @@ -35,11 +37,11 @@ static constexpr size_t kMaxSubjectCATAttributeCount = CHIP_CONFIG_CERT_MAX_RDN_

struct CATValues
{
CASEAuthTag values[kMaxSubjectCATAttributeCount] = { kUndefinedCAT };
std::array<CASEAuthTag, kMaxSubjectCATAttributeCount> values = { kUndefinedCAT };

/* @brief Returns size of the CAT values array.
*/
static constexpr size_t size() { return ArraySize(values); }
static constexpr size_t size() { return std::tuple_size<decltype(values)>::value; }

/* @brief Returns true if subject input checks against one of the CATs in the values array.
*/
Expand All @@ -58,6 +60,8 @@ struct CATValues
return false;
}

bool operator==(const CATValues & that) const { return values == that.values; }

static constexpr size_t kSerializedLength = kMaxSubjectCATAttributeCount * sizeof(CASEAuthTag);
typedef uint8_t Serialized[kSerializedLength];

Expand Down
47 changes: 47 additions & 0 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,36 @@ void SecureSessionDeleter::Release(SecureSession * entry)
entry->mTable.ReleaseSession(entry);
}

void SecureSession::Activate(const ScopedNodeId & localNode, const ScopedNodeId & peerNode, CATValues peerCATs,
uint16_t peerSessionId, const ReliableMessageProtocolConfig & config)
{
VerifyOrDie(mState == State::kEstablishing);
VerifyOrDie(peerNode.GetFabricIndex() == localNode.GetFabricIndex());

// PASE sessions must always start unassociated with a Fabric!
VerifyOrDie(!((mSecureSessionType == Type::kPASE) && (peerNode.GetFabricIndex() != kUndefinedFabricIndex)));
// CASE sessions must always start "associated" a given Fabric!
VerifyOrDie(!((mSecureSessionType == Type::kCASE) && (peerNode.GetFabricIndex() == kUndefinedFabricIndex)));
// CASE sessions can only be activated against operational node IDs!
VerifyOrDie(!((mSecureSessionType == Type::kCASE) &&
(!IsOperationalNodeId(peerNode.GetNodeId()) || !IsOperationalNodeId(localNode.GetNodeId()))));

mPeerNodeId = peerNode.GetNodeId();
mLocalNodeId = localNode.GetNodeId();
mPeerCATs = peerCATs;
mPeerSessionId = peerSessionId;
mMRPConfig = config;
SetFabricIndex(peerNode.GetFabricIndex());

Retain(); // This ref is released inside MarkForEviction
MoveToState(State::kActive);

if (mSecureSessionType == Type::kCASE)
mTable.NewerSessionAvailable(this);

ChipLogDetail(Inet, "SecureSession[%p]: Activated - Type:%d LSID:%d", this, to_underlying(mSecureSessionType), mLocalSessionId);
}

const char * SecureSession::StateToString(State state) const
{
switch (state)
Expand Down Expand Up @@ -200,5 +230,22 @@ void SecureSession::Release()
ReferenceCounted<SecureSession, SecureSessionDeleter, 0, uint16_t>::Release();
}

void SecureSession::NewerSessionAvailable(const SessionHandle & session)
{
// Shift to the new session, checks are performed by the the caller SecureSessionTable::NewerSessionAvailable.
IntrusiveList<SessionHolder>::Iterator iter = mHolders.begin();
while (iter != mHolders.end())
{
// The iterator can be invalid once the session holder is migrated to another session. So we store its next value before
// notifying the holder.
IntrusiveList<SessionHolder>::Iterator next = iter;
++next;

iter->ShiftToSession(session);

iter = next;
}
}

} // namespace Transport
} // namespace chip
34 changes: 8 additions & 26 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,8 @@ class SecureSession : public Session, public ReferenceCounted<SecureSession, Sec
* discovered during session establishment.
*/
void Activate(const ScopedNodeId & localNode, const ScopedNodeId & peerNode, CATValues peerCATs, uint16_t peerSessionId,
const ReliableMessageProtocolConfig & config)
{
VerifyOrDie(mState == State::kEstablishing);
VerifyOrDie(peerNode.GetFabricIndex() == localNode.GetFabricIndex());

// PASE sessions must always start unassociated with a Fabric!
VerifyOrDie(!((mSecureSessionType == Type::kPASE) && (peerNode.GetFabricIndex() != kUndefinedFabricIndex)));
// CASE sessions must always start "associated" a given Fabric!
VerifyOrDie(!((mSecureSessionType == Type::kCASE) && (peerNode.GetFabricIndex() == kUndefinedFabricIndex)));
// CASE sessions can only be activated against operational node IDs!
VerifyOrDie(!((mSecureSessionType == Type::kCASE) &&
(!IsOperationalNodeId(peerNode.GetNodeId()) || !IsOperationalNodeId(localNode.GetNodeId()))));

mPeerNodeId = peerNode.GetNodeId();
mLocalNodeId = localNode.GetNodeId();
mPeerCATs = peerCATs;
mPeerSessionId = peerSessionId;
mMRPConfig = config;
SetFabricIndex(peerNode.GetFabricIndex());

Retain(); // This ref is released inside MarkForEviction
MoveToState(State::kActive);
ChipLogDetail(Inet, "SecureSession[%p]: Activated - Type:%d LSID:%d", this, to_underlying(mSecureSessionType),
mLocalSessionId);
}
const ReliableMessageProtocolConfig & config);

~SecureSession() override
{
ChipLogDetail(Inet, "SecureSession[%p]: Released - Type:%d LSID:%d", this, to_underlying(mSecureSessionType),
Expand Down Expand Up @@ -213,7 +190,7 @@ class SecureSession : public Session, public ReferenceCounted<SecureSession, Sec
NodeId GetPeerNodeId() const { return mPeerNodeId; }
NodeId GetLocalNodeId() const { return mLocalNodeId; }

CATValues GetPeerCATs() const { return mPeerCATs; }
const CATValues & GetPeerCATs() const { return mPeerCATs; }

void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; }

Expand Down Expand Up @@ -262,6 +239,11 @@ class SecureSession : public Session, public ReferenceCounted<SecureSession, Sec

SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; }

// This should be a private API, only meant to be called by SecureSessionTable
// Session holders to this session may shift to the target session regarding SessionDelegate::GetNewSessionHandlingPolicy.
// It requires that the target sessoin is also a CASE session, having the same peer and CATs as this session.
void NewerSessionAvailable(const SessionHandle & session);

private:
enum class State : uint8_t
{
Expand Down
26 changes: 26 additions & 0 deletions src/transport/SecureSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,32 @@ class SecureSessionTable
CHECK_RETURN_VALUE
Optional<SessionHandle> FindSecureSessionByLocalKey(uint16_t localSessionId);

// Select SessionHolders which are pointing to a session with the same peer as the given session. Shift them to the given
// session.
// This is an internal API, using raw pointer to a session is allowed here.
void NewerSessionAvailable(SecureSession * session)
{
VerifyOrDie(session->GetSecureSessionType() == SecureSession::Type::kCASE);
mEntries.ForEachActiveObject([&](SecureSession * oldSession) {
if (session == oldSession)
return Loop::Continue;

SessionHandle ref(*oldSession);

// This will give all SessionHolders pointing to oldSession a chance to switch to the provided session
//
// See documentation for SessionDelegate::GetNewSessionHandlingPolicy about how session auto-shifting works, and how
// to disable it for a specific SessionHolder in a specific scenario.
if (oldSession->GetSecureSessionType() == SecureSession::Type::kCASE && oldSession->GetPeer() == session->GetPeer() &&
oldSession->GetPeerCATs() == session->GetPeerCATs())
{
oldSession->NewerSessionAvailable(SessionHandle(*session));
}

return Loop::Continue;
});
}

private:
friend class TestSecureSessionTable;

Expand Down
5 changes: 3 additions & 2 deletions src/transport/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,15 @@ class Session
SessionHandle session(*this);
while (!mHolders.Empty())
{
mHolders.begin()->OnSessionReleased(); // OnSessionReleased must remove the item from the linked list
mHolders.begin()->SessionReleased(); // SessionReleased must remove the item from the linked list
}
}

void SetFabricIndex(FabricIndex index) { mFabricIndex = index; }

private:
IntrusiveList<SessionHolder> mHolders;

private:
FabricIndex mFabricIndex = kUndefinedFabricIndex;
};

Expand Down
10 changes: 9 additions & 1 deletion src/transport/SessionDelegate.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@ class DLL_EXPORT SessionDelegate
* Called when a new secure session to the same peer is established, over the delegate of SessionHolderWithDelegate object. It
* is suggested to shift to the newly created session.
*
* Our security model is built upon Exchanges and Sessions, but not SessionHolders, such that SessionHolders should be able to
* shift to a new session freely. If an application is holding a session which is not intended to be shifted, it can provide
* its shifting policy by overriding GetNewSessionHandlingPolicy in SessionDelegate. For example SessionHolders inside
* ExchangeContext and PairingSession are not eligible for auto-shifting.
*
* Note: the default implementation orders shifting to the new session, it should be fine for all users, unless the
* SessionHolder object is expected to be sticky to a specified session.
* SessionHolder object is expected to be sticky to a specified session.
*
* Note: the implementation MUST NOT modify the session pool or the state of session holders (eg, adding new session, removing
* old session) from inside this callback.
*/
virtual NewSessionHandlingPolicy GetNewSessionHandlingPolicy() { return NewSessionHandlingPolicy::kShiftToNewSession; }

Expand Down
22 changes: 16 additions & 6 deletions src/transport/SessionHolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,23 @@ namespace chip {
* released when the underlying session is released. One must verify it is available before use. The object can be
* created using SessionHandle.Grab()
*/
class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase<>
class SessionHolder : public IntrusiveListNodeBase<>
{
public:
SessionHolder() {}
~SessionHolder() override;
virtual ~SessionHolder();

SessionHolder(const SessionHolder &);
SessionHolder(SessionHolder && that);
SessionHolder & operator=(const SessionHolder &);
SessionHolder & operator=(SessionHolder && that);

// Implement SessionDelegate
void OnSessionReleased() override { Release(); }
virtual void SessionReleased() { Release(); }
virtual void ShiftToSession(const SessionHandle & session)
{
Release();
Grab(session);
}

bool Contains(const SessionHandle & session) const
{
Expand All @@ -51,7 +55,7 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase<>
bool Grab(const SessionHandle & session);
void Release();

operator bool() const { return mSession.HasValue(); }
explicit operator bool() const { return mSession.HasValue(); }
Optional<SessionHandle> Get() const
{
//
Expand Down Expand Up @@ -81,14 +85,20 @@ class SessionHolderWithDelegate : public SessionHolder
SessionHolderWithDelegate(const SessionHandle & handle, SessionDelegate & delegate) : mDelegate(delegate) { Grab(handle); }
operator bool() const { return SessionHolder::operator bool(); }

void OnSessionReleased() override
void SessionReleased() override
{
Release();

// Note, the session is already cleared during mDelegate.OnSessionReleased
mDelegate.OnSessionReleased();
}

void ShiftToSession(const SessionHandle & session) override
{
if (mDelegate.GetNewSessionHandlingPolicy() == SessionDelegate::NewSessionHandlingPolicy::kShiftToNewSession)
SessionHolder::ShiftToSession(session);
}

void DispatchSessionEvent(SessionDelegate::Event event) override { (mDelegate.*event)(); }

private:
Expand Down
21 changes: 21 additions & 0 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,27 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH
return CHIP_NO_ERROR;
}

CHIP_ERROR SessionManager::InjectCaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId,
uint16_t peerSessionId, NodeId localNodeId, NodeId peerNodeId,
FabricIndex fabric, const Transport::PeerAddress & peerAddress,
CryptoContext::SessionRole role, const CATValues & cats)
{
Optional<SessionHandle> session =
mSecureSessions.CreateNewSecureSessionForTest(chip::Transport::SecureSession::Type::kCASE, localSessionId, localNodeId,
peerNodeId, cats, peerSessionId, fabric, GetLocalMRPConfig());
VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY);
SecureSession * secureSession = session.Value()->AsSecureSession();
secureSession->SetPeerAddress(peerAddress);

size_t secretLen = strlen(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE);
ByteSpan secret(reinterpret_cast<const uint8_t *>(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE), secretLen);
ReturnErrorOnFailure(secureSession->GetCryptoContext().InitFromSecret(
secret, ByteSpan(nullptr, 0), CryptoContext::SessionInfoType::kSessionEstablishment, role));
secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(Transport::PeerMessageCounter::kInitialSyncValue);
sessionHolder.Grab(session.Value());
return CHIP_NO_ERROR;
}

void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg)
{
CHIP_TRACE_PREPARED_MESSAGE_RECEIVED(&peerAddress, &msg);
Expand Down
5 changes: 5 additions & 0 deletions src/transport/SessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId,
uint16_t peerSessionId, FabricIndex fabricIndex,
const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role);
CHIP_ERROR InjectCaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, uint16_t peerSessionId,
NodeId localNodeId, NodeId peerNodeId, FabricIndex fabric,
const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role,
const CATValues & cats = CATValues{});

/**
* @brief
Expand Down Expand Up @@ -210,6 +214,7 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
void FabricRemoved(FabricIndex fabricIndex);

TransportMgrBase * GetTransportManager() const { return mTransportMgr; }
Transport::SecureSessionTable & GetSecureSessions() { return mSecureSessions; }

/**
* @brief
Expand Down
Loading

0 comments on commit 532791f

Please sign in to comment.