diff --git a/src/transport/SecureMessageCodec.cpp b/src/transport/SecureMessageCodec.cpp index 4664985b1a3960..3fb6775311fd3d 100644 --- a/src/transport/SecureMessageCodec.cpp +++ b/src/transport/SecureMessageCodec.cpp @@ -38,34 +38,24 @@ using System::PacketBufferHandle; namespace SecureMessageCodec { -CHIP_ERROR Encrypt(Transport::SecureSession * state, PayloadHeader & payloadHeader, PacketHeader & packetHeader, - System::PacketBufferHandle & msgBuf, MessageCounter & counter) +CHIP_ERROR Encrypt(Transport::SecureSession * session, PayloadHeader & payloadHeader, PacketHeader & packetHeader, + System::PacketBufferHandle & msgBuf) { VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); VerifyOrReturnError(!msgBuf->HasChainedBuffer(), CHIP_ERROR_INVALID_MESSAGE_LENGTH); VerifyOrReturnError(msgBuf->TotalLength() <= kMaxAppMessageLen, CHIP_ERROR_MESSAGE_TOO_LONG); - uint32_t messageCounter = counter.Value(); - static_assert(std::is_sameTotalLength()), uint16_t>::value, "Addition to generate payloadLength might overflow"); - packetHeader - .SetMessageCounter(messageCounter) // - .SetSessionId(state->GetPeerSessionId()); - - // TODO set Session Type (Unicast or Group) - // packetHeader.SetSessionType(Header::SessionType::kUnicastSession); - ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(msgBuf)); uint8_t * data = msgBuf->Start(); uint16_t totalLen = msgBuf->TotalLength(); CHIP_TRACE_MESSAGE(payloadHeader, packetHeader, data, totalLen); - MessageAuthenticationCode mac; - ReturnErrorOnFailure(state->EncryptBeforeSend(data, totalLen, data, packetHeader, mac)); + ReturnErrorOnFailure(session->EncryptBeforeSend(data, totalLen, data, packetHeader, mac)); uint16_t taglen = 0; ReturnErrorOnFailure(mac.Encode(packetHeader, &data[totalLen], msgBuf->AvailableDataLength(), &taglen)); @@ -73,11 +63,10 @@ CHIP_ERROR Encrypt(Transport::SecureSession * state, PayloadHeader & payloadHead VerifyOrReturnError(CanCastTo(totalLen + taglen), CHIP_ERROR_INTERNAL); msgBuf->SetDataLength(static_cast(totalLen + taglen)); - ReturnErrorOnFailure(counter.Advance()); return CHIP_NO_ERROR; } -CHIP_ERROR Decrypt(Transport::SecureSession * state, PayloadHeader & payloadHeader, const PacketHeader & packetHeader, +CHIP_ERROR Decrypt(Transport::SecureSession * session, PayloadHeader & payloadHeader, const PacketHeader & packetHeader, System::PacketBufferHandle & msg) { ReturnErrorCodeIf(msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); @@ -107,7 +96,7 @@ CHIP_ERROR Decrypt(Transport::SecureSession * state, PayloadHeader & payloadHead msg->SetDataLength(len); uint8_t * plainText = msg->Start(); - ReturnErrorOnFailure(state->DecryptOnReceive(data, len, plainText, packetHeader, mac)); + ReturnErrorOnFailure(session->DecryptOnReceive(data, len, plainText, packetHeader, mac)); ReturnErrorOnFailure(payloadHeader.DecodeAndConsume(msg)); return CHIP_NO_ERROR; diff --git a/src/transport/SecureMessageCodec.h b/src/transport/SecureMessageCodec.h index 455466f374f9ec..2fa0700bb9a101 100644 --- a/src/transport/SecureMessageCodec.h +++ b/src/transport/SecureMessageCodec.h @@ -36,38 +36,38 @@ namespace SecureMessageCodec { /** * @brief * Attach payload header to the message and encrypt the message buffer using - * key from the connection state. + * key from the secure session. * - * @param state The connection state with peer node + * @param session The secure session context with the peer node * @param payloadHeader Reference to the payload header that should be inserted in * the message * @param packetHeader Reference to the packet header that contains unencrypted * portion of the message header * @param msgBuf The message buffer that contains the unencrypted message. If - * the operation is successuful, this buffer will contain the - * encrypted message. - * @param counter The local counter object to be used - * @ return CHIP_ERROR The result of the encode operation + * the operation is successful, this buffer will be mutated to contain + * the encrypted message. + * @return A CHIP_ERROR value consistent with the result of the encryption operation */ -CHIP_ERROR Encrypt(Transport::SecureSession * state, PayloadHeader & payloadHeader, PacketHeader & packetHeader, - System::PacketBufferHandle & msgBuf, MessageCounter & counter); +CHIP_ERROR Encrypt(Transport::SecureSession * session, PayloadHeader & payloadHeader, PacketHeader & packetHeader, + System::PacketBufferHandle & msgBuf); /** * @brief - * Decrypt the message, perform message integrity check, and decode the payload header. + * Decrypt the message, perform message integrity check, and decode the payload header, + * consuming the header from the packet in doing so. * - * @param state The connection state with peer node - * @param payloadHeader Reference to the payload header that should be inserted in - * the message + * @param session The secure session context with the peer node + * @param payloadHeader Reference to the payload header that will be recovered from the message * @param packetHeader Reference to the packet header that contains unencrypted * portion of the message header * @param msgBuf The message buffer that contains the encrypted message. If - * the operation is successuful, this buffer will contain the - * unencrypted message. - * @ return CHIP_ERROR The result of the decode operation + * the operation is successful, this buffer will be mutated to contain + * the decrypted message. + * @return A CHIP_ERROR value consistent with the result of the decryption operation */ -CHIP_ERROR Decrypt(Transport::SecureSession * state, PayloadHeader & payloadHeader, const PacketHeader & packetHeader, +CHIP_ERROR Decrypt(Transport::SecureSession * session, PayloadHeader & payloadHeader, const PacketHeader & packetHeader, System::PacketBufferHandle & msgBuf); + } // namespace SecureMessageCodec } // namespace chip diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 29841f0cde4c2e..25a195ff3fca9e 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -150,8 +150,16 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P { return CHIP_ERROR_NOT_CONNECTED; } + MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *session); - ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message, counter)); + uint32_t messageCounter = counter.Value(); + packetHeader + .SetMessageCounter(messageCounter) // + .SetSessionId(session->GetPeerSessionId()) // + .SetSessionType(Header::SessionType::kUnicastSession); + + ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message)); + ReturnErrorOnFailure(counter.Advance()); #if CHIP_PROGRESS_LOGGING destination = session->GetPeerNodeId(); @@ -420,11 +428,11 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr } const SessionHandle & session = optionalSession.Value(); + Transport::UnauthenticatedSession * unsecuredSession = session->AsUnauthenticatedSession(); SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No; // Verify message counter - CHIP_ERROR err = - session->AsUnauthenticatedSession()->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageCounter()); + CHIP_ERROR err = unsecuredSession->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageCounter()); if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED) { isDuplicate = SessionMessageDelegate::DuplicateMessage::Yes; @@ -432,7 +440,7 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr } VerifyOrDie(err == CHIP_NO_ERROR); - session->AsUnauthenticatedSession()->MarkActive(); + unsecuredSession->MarkActive(); PayloadHeader payloadHeader; ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); @@ -445,11 +453,11 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr packetHeader.GetMessageCounter(), ChipLogValueExchangeIdFromReceivedHeader(payloadHeader)); } - session->AsUnauthenticatedSession()->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); + unsecuredSession->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); if (mCB != nullptr) { - mCB->OnMessageReceived(packetHeader, payloadHeader, optionalSession.Value(), peerAddress, isDuplicate, std::move(msg)); + mCB->OnMessageReceived(packetHeader, payloadHeader, session, peerAddress, isDuplicate, std::move(msg)); } } @@ -599,10 +607,11 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade { Optional session = CreateGroupSession(packetHeader.GetDestinationGroupId().Value()); VerifyOrReturn(session.HasValue(), ChipLogError(Inet, "Error when creating group session handle.")); + Transport::GroupSession * groupSession = session.Value()->AsGroupSession(); mCB->OnMessageReceived(packetHeader, payloadHeader, session.Value(), peerAddress, isDuplicate, std::move(msg)); - RemoveGroupSession(session.Value()->AsGroupSession()); + RemoveGroupSession(groupSession); } }