Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup in aisle CASESession #26339

Merged
merged 3 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/credentials/FabricTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ FabricTable::AddOrUpdateInner(FabricIndex fabricIndex, bool isAddition, Crypto::
}
else
{
// Initialization for Upating fabric: setting up a shadow fabricInfo
// Initialization for Updating fabric: setting up a shadow fabricInfo
const FabricInfo * existingFabric = FindFabricWithIndex(fabricIndex);
VerifyOrReturnError(existingFabric != nullptr, CHIP_ERROR_INTERNAL);

Expand Down
4 changes: 2 additions & 2 deletions src/credentials/FabricTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class DLL_EXPORT FabricInfo

friend class FabricTable;

protected:
private:
struct InitParams
{
NodeId nodeId = kUndefinedNodeId;
Expand Down Expand Up @@ -1098,7 +1098,7 @@ class DLL_EXPORT FabricTable
*/
const FabricInfo * GetShadowPendingFabricEntry() const { return HasPendingFabricUpdate() ? &mPendingFabric : nullptr; }

// Returns true if we have a shadow entry pending for a fabruc update.
// Returns true if we have a shadow entry pending for a fabric update.
bool HasPendingFabricUpdate() const
{
return mPendingFabric.IsInitialized() &&
Expand Down
118 changes: 61 additions & 57 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,19 @@ class CASESession::WorkHelper
// The `status` value is the result of the work callback (called beforehand).
typedef CHIP_ERROR (CASESession::*AfterWorkCallback)(DATA & data, CHIP_ERROR status);

// Create a work helper using the specified session, work callback, after work callback, and data (template arg).
// Lifetime is not managed, see `Create` for that option.
WorkHelper(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) :
mSession(&session), mWorkCallback(workCallback), mAfterWorkCallback(afterWorkCallback)
{}

public:
// Create a work helper using the specified session, work callback, after work callback, and data (template arg).
// Lifetime is managed by sharing between the caller (typically the session) and the helper itself (while work is scheduled).
static Platform::SharedPtr<WorkHelper> Create(CASESession & session, WorkCallback workCallback,
AfterWorkCallback afterWorkCallback)
{
auto ptr = Platform::MakeShared<WorkHelper>(session, workCallback, afterWorkCallback);
struct EnableShared : public WorkHelper
mlepage-google marked this conversation as resolved.
Show resolved Hide resolved
{
EnableShared(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) :
WorkHelper(session, workCallback, afterWorkCallback)
{}
};
auto ptr = Platform::MakeShared<EnableShared>(session, workCallback, afterWorkCallback);
if (ptr)
{
ptr->mWeakPtr = ptr; // used by `ScheduleWork`
Expand All @@ -173,10 +174,7 @@ class CASESession::WorkHelper
// No scheduling, no outstanding work, no shared lifetime management.
CHIP_ERROR DoWork()
{
if (!mSession || !mWorkCallback || !mAfterWorkCallback)
{
return CHIP_ERROR_INCORRECT_STATE;
}
VerifyOrReturnError(mSession && mWorkCallback && mAfterWorkCallback, CHIP_ERROR_INCORRECT_STATE);
auto * helper = this;
bool cancel = false;
helper->mStatus = helper->mWorkCallback(helper->mData, cancel);
Expand All @@ -187,18 +185,17 @@ class CASESession::WorkHelper
return helper->mStatus;
}

// Schedule the work after configuring the data.
// Schedule the work for later execution.
// If lifetime is managed, the helper shares management while work is outstanding.
CHIP_ERROR ScheduleWork()
{
if (!mSession || !mWorkCallback || !mAfterWorkCallback)
{
return CHIP_ERROR_INCORRECT_STATE;
}
VerifyOrReturnError(mSession && mWorkCallback && mAfterWorkCallback, CHIP_ERROR_INCORRECT_STATE);
// Hold strong ptr while work is outstanding
mStrongPtr = mWeakPtr.lock(); // set in `Create`
auto status = DeviceLayer::PlatformMgr().ScheduleBackgroundWork(WorkHandler, reinterpret_cast<intptr_t>(this));
if (status != CHIP_NO_ERROR)
{
// Release strong ptr since scheduling failed
mStrongPtr.reset();
}
return status;
Expand All @@ -207,32 +204,47 @@ class CASESession::WorkHelper
// Cancel the work, by clearing the associated session.
void CancelWork() { mSession.store(nullptr); }

bool IsCancelled() const { return mSession.load() == nullptr; }

private:
// Create a work helper using the specified session, work callback, after work callback, and data (template arg).
// Lifetime is not managed, see `Create` for that option.
WorkHelper(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) :
mSession(&session), mWorkCallback(workCallback), mAfterWorkCallback(afterWorkCallback)
{}

// Handler for the work callback.
static void WorkHandler(intptr_t arg)
{
auto * helper = reinterpret_cast<WorkHelper *>(arg);
bool cancel = false;
VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`?
// Hold strong ptr while work is handled
auto strongPtr(std::move(helper->mStrongPtr));
VerifyOrReturn(!helper->IsCancelled());
bool cancel = false;
// Execute callback in background thread; data must be OK with this
helper->mStatus = helper->mWorkCallback(helper->mData, cancel);
VerifyOrExit(!cancel, ;); // canceled by `mWorkCallback`?
VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`?
SuccessOrExit(DeviceLayer::PlatformMgr().ScheduleWork(AfterWorkHandler, reinterpret_cast<intptr_t>(helper)));
return;
exit:
helper->mStrongPtr.reset();
VerifyOrReturn(!cancel && !helper->IsCancelled());
// Hold strong ptr while work is outstanding
helper->mStrongPtr.swap(strongPtr);
auto status = DeviceLayer::PlatformMgr().ScheduleWork(AfterWorkHandler, reinterpret_cast<intptr_t>(helper));
if (status != CHIP_NO_ERROR)
{
// Release strong ptr since scheduling failed
helper->mStrongPtr.reset();
}
}

// Handler for the after work callback.
static void AfterWorkHandler(intptr_t arg)
{
// Since this runs in the main Matter thread, the session shouldn't be otherwise used (messages, timers, etc.)
auto * helper = reinterpret_cast<WorkHelper *>(arg);
// Hold strong ptr while work is handled
auto strongPtr(std::move(helper->mStrongPtr));
if (auto * session = helper->mSession.load())
{
// Execute callback in Matter thread; session should be OK with this
(session->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus);
}
helper->mStrongPtr.reset();
}

private:
Expand Down Expand Up @@ -261,7 +273,7 @@ class CASESession::WorkHelper

struct CASESession::SendSigma3Data
{
std::atomic<FabricIndex> fabricIndex;
FabricIndex fabricIndex;

// Use one or the other
const FabricTable * fabricTable;
Expand Down Expand Up @@ -319,7 +331,6 @@ void CASESession::Clear()
// Cancel any outstanding work.
if (mSendSigma3Helper)
{
mSendSigma3Helper->mData.fabricIndex = kUndefinedFabricIndex;
mSendSigma3Helper->CancelWork();
mSendSigma3Helper.reset();
}
Expand Down Expand Up @@ -1359,40 +1370,37 @@ CHIP_ERROR CASESession::SendSigma3a()

CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel)
{
CHIP_ERROR err = CHIP_NO_ERROR;

// Generate a signature
if (data.keystore != nullptr)
{
// Recommended case: delegate to operational keystore
err = data.keystore->SignWithOpKeypair(data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len },
data.tbsData3Signature);
ReturnErrorOnFailure(data.keystore->SignWithOpKeypair(
data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, data.tbsData3Signature));
}
else
{
// Legacy case: delegate to fabric table fabric info
err = data.fabricTable->SignWithOpKeypair(data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len },
data.tbsData3Signature);
ReturnErrorOnFailure(data.fabricTable->SignWithOpKeypair(
data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, data.tbsData3Signature));
}
SuccessOrExit(err);

// Prepare Sigma3 TBE Data Blob
data.msg_r3_encrypted_len =
TLV::EstimateStructOverhead(data.nocCert.size(), data.icaCert.size(), data.tbsData3Signature.Length());

VerifyOrExit(data.msg_R3_Encrypted.Alloc(data.msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES),
err = CHIP_ERROR_NO_MEMORY);
VerifyOrReturnError(data.msg_R3_Encrypted.Alloc(data.msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES),
CHIP_ERROR_NO_MEMORY);

{
TLV::TLVWriter tlvWriter;
TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified;

tlvWriter.Init(data.msg_R3_Encrypted.Get(), data.msg_r3_encrypted_len);
SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType));
SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), data.nocCert));
ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), data.nocCert));
if (!data.icaCert.empty())
{
SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), data.icaCert));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), data.icaCert));
}

// We are now done with ICAC and NOC certs so we can release the memory.
Expand All @@ -1404,15 +1412,14 @@ CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel)
data.nocCert = MutableByteSpan{};
}

SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), data.tbsData3Signature.ConstBytes(),
static_cast<uint32_t>(data.tbsData3Signature.Length())));
SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType));
SuccessOrExit(err = tlvWriter.Finalize());
ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), data.tbsData3Signature.ConstBytes(),
static_cast<uint32_t>(data.tbsData3Signature.Length())));
ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriter.Finalize());
data.msg_r3_encrypted_len = static_cast<size_t>(tlvWriter.GetLengthWritten());
}

exit:
return err;
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::SendSigma3c(SendSigma3Data & data, CHIP_ERROR status)
Expand Down Expand Up @@ -1650,17 +1657,15 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg)

CHIP_ERROR CASESession::HandleSigma3b(HandleSigma3Data & data, bool & cancel)
{
CHIP_ERROR err = CHIP_NO_ERROR;

// Step 5/6
// Validate initiator identity located in msg->Start()
// Constructing responder identity
CompressedFabricId unused;
FabricId initiatorFabricId;
P256PublicKey initiatorPublicKey;
SuccessOrExit(err = FabricTable::VerifyCredentials(data.initiatorNOC, data.initiatorICAC, data.fabricRCAC, data.validContext,
unused, initiatorFabricId, data.initiatorNodeId, initiatorPublicKey));
VerifyOrExit(data.fabricId == initiatorFabricId, err = CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(FabricTable::VerifyCredentials(data.initiatorNOC, data.initiatorICAC, data.fabricRCAC, data.validContext,
unused, initiatorFabricId, data.initiatorNodeId, initiatorPublicKey));
VerifyOrReturnError(data.fabricId == initiatorFabricId, CHIP_ERROR_INVALID_CASE_PARAMETER);

// TODO - Validate message signature prior to validating the received operational credentials.
// The op cert check requires traversal of cert chain, that is a more expensive operation.
Expand All @@ -1672,16 +1677,15 @@ CHIP_ERROR CASESession::HandleSigma3b(HandleSigma3Data & data, bool & cancel)
{
P256PublicKeyHSM initiatorPublicKeyHSM;
memcpy(Uint8::to_uchar(initiatorPublicKeyHSM), initiatorPublicKey.Bytes(), initiatorPublicKey.Length());
SuccessOrExit(err = initiatorPublicKeyHSM.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len,
data.tbsData3Signature));
ReturnErrorOnFailure(initiatorPublicKeyHSM.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len,
data.tbsData3Signature));
}
#else
SuccessOrExit(err = initiatorPublicKey.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len,
data.tbsData3Signature));
ReturnErrorOnFailure(
initiatorPublicKey.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len, data.tbsData3Signature));
#endif

exit:
return err;
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Data & data, CHIP_ERROR status)
Expand Down