Skip to content

Commit

Permalink
Add ICD Client Info management support with persistent storage
Browse files Browse the repository at this point in the history
  • Loading branch information
yunhanw-google committed Oct 17, 2023
1 parent 4df48ef commit 9134568
Show file tree
Hide file tree
Showing 7 changed files with 620 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/app/icd/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@
import("//build_overrides/chip.gni")
import("icd.gni")

# ICD Client sources and configurations
source_set("client") {
sources = [
"ICDClientInfoManagement.cpp",
"ICDClientInfoManagement.h",
"ICDClientInfoPersistentStorage.h"
]

deps = [ "${chip_root}/src/lib/core" ]
public_deps = [
"${chip_root}/src/app:app_config",
"${chip_root}/src/crypto",
]
}

# ICD Server sources and configurations
source_set("observer") {
sources = [ "ICDStateObserver.h" ]
Expand Down
368 changes: 368 additions & 0 deletions src/app/icd/ICDClientInfoManagement.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,368 @@
/*
* Copyright (c) 2023 Project CHIP Authors
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* @file
* This file defines a basic implementation of ICDClientPersistentStorage that
* persists clientInfo in a flat list in TLV.
*/

#include <app/icd/ICDClientInfoManagement.h>

#include <lib/support/Base64.h>
#include <lib/support/CodeUtils.h>
#include <lib/support/SafeInt.h>
#include <lib/support/logging/CHIPLogging.h>
#include <lib/core/Global.h>

namespace chip {
namespace app {

Global<ICDClientInfoManagement> sICDClientInfoManagement;

ICDClientInfoManagement * ICDClientInfoManagement::GetInstance()
{
return &sICDClientInfoManagement.get();
}

ICDClientInfoManagement::ICDClientInfoIteratorImpl::ICDClientInfoIteratorImpl(
ICDClientInfoManagement & aStorage) :
mStorage(aStorage)
{
mNextIndex = 0;
}

size_t ICDClientInfoManagement::ICDClientInfoIteratorImpl::Count()
{
return static_cast<size_t>(mStorage.Count());
}

bool ICDClientInfoManagement::ICDClientInfoIteratorImpl::Next(ICDClientInfo & aOutput)
{
for (; mNextIndex < CHIP_IM_MAX_NUM_ICD_CLIENTS; mNextIndex++)
{
CHIP_ERROR err = mStorage.Load(mNextIndex, aOutput);
if (err == CHIP_NO_ERROR)
{
mNextIndex++;
return true;
}

if (err != CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND)
{
ChipLogError(DataManagement, "Failed to load ICDClient Info at index %u error %" CHIP_ERROR_FORMAT,
static_cast<unsigned>(mNextIndex), err.Format());
mStorage.Delete(mNextIndex);
}
}

return false;
}

void ICDClientInfoManagement::ICDClientInfoIteratorImpl::Release()
{
mStorage.mICDClientInfoIterators.ReleaseObject(this);
}

CHIP_ERROR ICDClientInfoManagement::Init(PersistentStorageDelegate * apStorage, Crypto::SymmetricKeystore * apSymmetricKeystore)
{
VerifyOrReturnError(apStorage != nullptr && apSymmetricKeystore != nullptr, CHIP_ERROR_INVALID_ARGUMENT);
mpStorage = apStorage;
mpSymmetricKeystore = apSymmetricKeystore;

uint16_t countMax;
uint16_t len = sizeof(countMax);
CHIP_ERROR err =
mpStorage->SyncGetKeyValue(DefaultStorageKeyAllocator::ICDClientInfoMaxCount().KeyName(), &countMax, len);
// If there's a previous countMax and it's larger than CHIP_IM_MAX_NUM_ICD,
// clean up ICDClientInfo beyond the limit
if ((err == CHIP_NO_ERROR) && (countMax != CHIP_IM_MAX_NUM_ICD_CLIENTS))
{
for (uint16_t index = CHIP_IM_MAX_NUM_ICD_CLIENTS; index < countMax; index++)
{
Delete(index);
}
}

// Always save the current CHIP_IM_MAX_NUM_ICD_CLIENTS
uint16_t countMaxToSave = CHIP_IM_MAX_NUM_ICD_CLIENTS;
ReturnErrorOnFailure(mpStorage->SyncSetKeyValue(DefaultStorageKeyAllocator::ICDClientInfoMaxCount().KeyName(),
&countMaxToSave, sizeof(uint16_t)));

return CHIP_NO_ERROR;
}

ICDClientInfoManagement::ICDClientInfoIterator * ICDClientInfoManagement::IterateICDClientInfo()
{
return mICDClientInfoIterators.CreateObject(*this);
}

uint16_t ICDClientInfoManagement::Count()
{
uint16_t count = 0;
for (uint16_t index = 0; index < CHIP_IM_MAX_NUM_ICD_CLIENTS; index++)
{
if (mpStorage->SyncDoesKeyExist(DefaultStorageKeyAllocator::ICDClientInfo(index).KeyName()))
{
count++;
}
}

return count;
}

CHIP_ERROR ICDClientInfoManagement::Delete(uint16_t aIndex)
{
return mpStorage->SyncDeleteKeyValue(DefaultStorageKeyAllocator::ICDClientInfo(aIndex).KeyName());
}

CHIP_ERROR ICDClientInfoManagement::Load(uint16_t aIndex, ICDClientInfo & aICDClientInfo)
{
Platform::ScopedMemoryBuffer<uint8_t> backingBuffer;
backingBuffer.Calloc(MaxICDClientInfoSize());
ReturnErrorCodeIf(backingBuffer.Get() == nullptr, CHIP_ERROR_NO_MEMORY);

uint16_t len = static_cast<uint16_t>(MaxICDClientInfoSize());
ReturnErrorOnFailure(mpStorage->SyncGetKeyValue(DefaultStorageKeyAllocator::ICDClientInfo(aIndex).KeyName(),
backingBuffer.Get(), len));

TLV::ScopedBufferTLVReader reader(std::move(backingBuffer), len);

ReturnErrorOnFailure(reader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag()));

TLV::TLVType ICDClientInfoType;
ReturnErrorOnFailure(reader.EnterContainer(ICDClientInfoType));

// Peer Node ID
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(Tag::kPeerNodeId)));
ReturnErrorOnFailure(reader.Get(aICDClientInfo.mPeerNodeId));

// Fabric index
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(Tag::kFabricIndex)));
ReturnErrorOnFailure(reader.Get(aICDClientInfo.mFabricIndex));

// Start ICD Counter
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(Tag::kStartICDCounter)));
ReturnErrorOnFailure(reader.Get(aICDClientInfo.mStartICDCounter));

// Offset
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(Tag::kOffset)));
ReturnErrorOnFailure(reader.Get(aICDClientInfo.mOffset));

// MonitorSubject
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(Tag::kMonitorSubject)));
ReturnErrorOnFailure(reader.Get(aICDClientInfo.mMonitoredSubject));

//shared key
ReturnErrorOnFailure(reader.Next(TLV::ContextTag(Tag::kSharedKey)));
ByteSpan buf(aICDClientInfo.mSharedKey.AsMutable<Crypto::Aes128KeyByteArray>());
ReturnErrorOnFailure(reader.Get(buf));
ReturnErrorOnFailure(SetKey(buf, aICDClientInfo.mSharedKey));

ReturnErrorOnFailure(reader.ExitContainer(ICDClientInfoType));

return CHIP_NO_ERROR;
}

CHIP_ERROR ICDClientInfoManagement::SetKey(const ByteSpan & aKeyData, Crypto::Aes128KeyHandle & aSharedKey)
{
VerifyOrReturnError(aKeyData.size() == sizeof(Crypto::Aes128KeyByteArray), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(mpSymmetricKeystore != nullptr, CHIP_ERROR_INTERNAL);

DeleteKey(aSharedKey);

Crypto::Aes128KeyByteArray keyMaterial;
memcpy(keyMaterial, aKeyData.data(), sizeof(Crypto::Aes128KeyByteArray));

ReturnErrorOnFailure(mpSymmetricKeystore->CreateKey(keyMaterial, aSharedKey));

return CHIP_NO_ERROR;
}

CHIP_ERROR ICDClientInfoManagement::Save(TLV::TLVWriter & aWriter, ICDClientInfo & aICDClientInfo)
{
TLV::TLVType ICDClientInfoContainerType;
ReturnErrorOnFailure(aWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, ICDClientInfoContainerType));
ReturnErrorOnFailure(aWriter.Put(TLV::ContextTag(Tag::kPeerNodeId), aICDClientInfo.mPeerNodeId));
ReturnErrorOnFailure(aWriter.Put(TLV::ContextTag(Tag::kFabricIndex), aICDClientInfo.mFabricIndex));
ReturnErrorOnFailure(aWriter.Put(TLV::ContextTag(Tag::kStartICDCounter), aICDClientInfo.mStartICDCounter));
ReturnErrorOnFailure(aWriter.Put(TLV::ContextTag(Tag::kOffset), aICDClientInfo.mOffset));
ReturnErrorOnFailure(aWriter.Put(TLV::ContextTag(Tag::kMonitorSubject), aICDClientInfo.mMonitoredSubject));
ByteSpan buf(aICDClientInfo.mSharedKey.As<Crypto::Aes128KeyByteArray>());
ReturnErrorOnFailure(aWriter.Put(TLV::ContextTag(Tag::kSharedKey), buf));
ReturnErrorOnFailure(aWriter.EndContainer(ICDClientInfoContainerType));
return CHIP_NO_ERROR;
}

CHIP_ERROR ICDClientInfoManagement::Save(ICDClientInfo & aICDClientInfo)
{
// Find empty index or duplicate if exists
uint16_t index;
uint16_t firstEmptyIndex = CHIP_IM_MAX_NUM_ICD_CLIENTS; // initialize to out of bounds as "not set"
for (index = 0; index < CHIP_IM_MAX_NUM_ICD_CLIENTS; index++)
{
ICDClientInfo currentICDClientInfo;
CHIP_ERROR err = Load(index, currentICDClientInfo);

// if empty and firstEmptyIndex isn't set yet, then mark empty spot
if ((firstEmptyIndex == CHIP_IM_MAX_NUM_ICD_CLIENTS) && (err == CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND))
{
firstEmptyIndex = index;
}

// delete duplicate
if (err == CHIP_NO_ERROR)
{
if ((aICDClientInfo.mPeerNodeId == currentICDClientInfo.mPeerNodeId) &&
(aICDClientInfo.mFabricIndex == currentICDClientInfo.mFabricIndex))
{
Delete(index);
// if duplicate is the first empty spot, then also set it
if (firstEmptyIndex == CHIP_IM_MAX_NUM_ICD_CLIENTS)
{
firstEmptyIndex = index;
}
}
}
}

// Fail if no empty space
if (firstEmptyIndex == CHIP_IM_MAX_NUM_ICD_CLIENTS)
{
return CHIP_ERROR_NO_MEMORY;
}

// Now construct ICD ClientInfo and save
Platform::ScopedMemoryBuffer<uint8_t> backingBuffer;
backingBuffer.Calloc(MaxICDClientInfoSize());
ReturnErrorCodeIf(backingBuffer.Get() == nullptr, CHIP_ERROR_NO_MEMORY);

TLV::ScopedBufferTLVWriter writer(std::move(backingBuffer), MaxICDClientInfoSize());

ReturnErrorOnFailure(Save(writer, aICDClientInfo));

const auto len = writer.GetLengthWritten();
VerifyOrReturnError(CanCastTo<uint16_t>(len), CHIP_ERROR_BUFFER_TOO_SMALL);

writer.Finalize(backingBuffer);

ReturnErrorOnFailure(
mpStorage->SyncSetKeyValue(DefaultStorageKeyAllocator::ICDClientInfo(firstEmptyIndex).KeyName(),
backingBuffer.Get(), static_cast<uint16_t>(len)));

return CHIP_NO_ERROR;
}

CHIP_ERROR ICDClientInfoManagement::Delete(NodeId aPeerNodeId, FabricIndex aFabricIndex)
{
bool found = false;
CHIP_ERROR lastDeleteErr = CHIP_NO_ERROR;

uint16_t remainingCount = 0;
for (uint16_t index = 0; index < CHIP_IM_MAX_NUM_ICD_CLIENTS; index++)
{
ICDClientInfo ICDClientInfo;
CHIP_ERROR err = Load(index, ICDClientInfo);

// delete match
if (err == CHIP_NO_ERROR)
{
if ((aPeerNodeId == ICDClientInfo.mPeerNodeId) && (aFabricIndex == ICDClientInfo.mFabricIndex))
{
found = true;
CHIP_ERROR deleteErr = Delete(index);
if (deleteErr != CHIP_NO_ERROR)
{
lastDeleteErr = deleteErr;
}
}
else
{
remainingCount++;
}
}
}

// if there are no persisted ICD client Info, the MaxCount can also be deleted
if (remainingCount == 0)
{
DeleteMaxCount();
}

if (lastDeleteErr != CHIP_NO_ERROR)
{
return lastDeleteErr;
}

return found ? CHIP_NO_ERROR : CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND;
}

CHIP_ERROR ICDClientInfoManagement::DeleteKey(Crypto::Aes128KeyHandle & aKey)
{
VerifyOrReturnError(mpSymmetricKeystore != nullptr, CHIP_ERROR_INTERNAL);
mpSymmetricKeystore->DestroyKey(aKey);
return CHIP_NO_ERROR;
}

CHIP_ERROR ICDClientInfoManagement::DeleteMaxCount()
{
return mpStorage->SyncDeleteKeyValue(DefaultStorageKeyAllocator::ICDClientInfoMaxCount().KeyName());
}

CHIP_ERROR ICDClientInfoManagement::DeleteAll(FabricIndex fabricIndex)
{
CHIP_ERROR deleteErr = CHIP_NO_ERROR;

uint16_t count = 0;
for (uint16_t index = 0; index < CHIP_IM_MAX_NUM_ICD_CLIENTS; index++)
{
ICDClientInfo clientInfo;
CHIP_ERROR err = Load(index, clientInfo);

if (err == CHIP_NO_ERROR)
{
if (fabricIndex == clientInfo.mFabricIndex)
{
err = Delete(index);
if ((err != CHIP_NO_ERROR) && (err != CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND))
{
deleteErr = err;
}
}
else
{
count++;
}
}
}

// if there are no persisted ICD ClientInfo, the MaxCount can also be deleted
if (count == 0)
{
CHIP_ERROR err = DeleteMaxCount();

if ((err != CHIP_NO_ERROR) && (err != CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND))
{
deleteErr = err;
}
}

return deleteErr;
}

} // namespace app
} // namespace chip
Loading

0 comments on commit 9134568

Please sign in to comment.