Skip to content

Commit

Permalink
Ensure that we send MRP acks to incoming messages as needed. (#19398)
Browse files Browse the repository at this point in the history
There were two ways we could fail to send an ack to an incoming reliable message:

1) If we found no matching handler, and hence created an ephemeral exchange to
handle the message, but the message was unencrypted.  In this case our ephemeral
exchange would return true for IsEncryptionRequired(), because it would default
to an ApplicationExchangeDispatch, and we would never call into
ExchangeContext::HandleMessage.

2) If ExchangeMessageDispatch::MessagePermitted returned false for the message.
In particular, for an ApplicationExchangeDispatch, this would happen for all the
handshake messages except StatusReport.

The fix for issue 1 is to ensure we always call into HandleMEssage if we manage
to allocate an exchange.  If there is an encryption mismatch, which only matters
when the exchange is non-ephemeral, we close the exchange first to prevent event
delivery to the delegate.

The fix for issue 2 is to move the MRP processing out of ExchangeMessageDispatch
and into ExchangeContext, and to move the MessagePermitted check so the only
thing it prevents is delivery of the message to the delegate, not any other
processing by the exchange.

Fixes #10515
  • Loading branch information
bzbarsky-apple authored and pull[bot] committed Feb 22, 2024
1 parent 12de763 commit 1426828
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 62 deletions.
10 changes: 3 additions & 7 deletions src/messaging/ApplicationExchangeDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
namespace chip {
namespace Messaging {

bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type)
bool ApplicationExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8_t type)
{
// TODO: Change this check to only include the protocol and message types that are allowed
switch (protocol)
if (protocol == Protocols::SecureChannel::Id)
{
case Protocols::SecureChannel::Id.GetProtocolId():
switch (type)
{
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PBKDFParamRequest):
Expand All @@ -49,11 +48,8 @@ bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t ty
default:
break;
}
break;

default:
break;
}

return true;
}

Expand Down
2 changes: 1 addition & 1 deletion src/messaging/ApplicationExchangeDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ApplicationExchangeDispatch : public ExchangeMessageDispatch
~ApplicationExchangeDispatch() override {}

protected:
bool MessagePermitted(uint16_t protocol, uint8_t type) override;
bool MessagePermitted(Protocols::Id protocol, uint8_t type) override;
};

} // namespace Messaging
Expand Down
20 changes: 18 additions & 2 deletions src/messaging/ExchangeContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,21 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload
MessageHandled();
});

ReturnErrorOnFailure(mDispatch.OnMessageReceived(messageCounter, payloadHeader, msgFlags, GetReliableMessageContext()));
if (mDispatch.IsReliableTransmissionAllowed() && !IsGroupExchangeContext())
{
if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() &&
payloadHeader.GetAckMessageCounter().HasValue())
{
HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value());
}

if (payloadHeader.NeedsAck())
{
// An acknowledgment needs to be sent back to the peer for this message on this exchange,

HandleNeedsAck(messageCounter, msgFlags);
}
}

if (IsAckPending() && !mDelegate)
{
Expand Down Expand Up @@ -487,7 +501,9 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload
// is implicitly that response.
SetResponseExpected(false);

if (mDelegate != nullptr)
// Don't send messages on to our delegate if our dispatch does not allow
// those messages.
if (mDelegate != nullptr && mDispatch.MessagePermitted(payloadHeader.GetProtocolID(), payloadHeader.GetMessageType()))
{
return mDelegate->OnMessageReceived(this, payloadHeader, std::move(msgBuf));
}
Expand Down
27 changes: 1 addition & 26 deletions src/messaging/ExchangeMessageDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager,
bool isReliableTransmission, Protocols::Id protocol, uint8_t type,
System::PacketBufferHandle && message)
{
ReturnErrorCodeIf(!MessagePermitted(protocol.GetProtocolId(), type), CHIP_ERROR_INVALID_ARGUMENT);
ReturnErrorCodeIf(!MessagePermitted(protocol, type), CHIP_ERROR_INVALID_ARGUMENT);

PayloadHeader payloadHeader;
payloadHeader.SetExchangeID(exchangeId).SetMessageType(protocol, type).SetInitiator(isInitiator);
Expand Down Expand Up @@ -113,30 +113,5 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager,
return CHIP_NO_ERROR;
}

CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader,
MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext)
{
ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()),
CHIP_ERROR_INVALID_ARGUMENT);

if (IsReliableTransmissionAllowed() && !reliableMessageContext->GetExchangeContext()->IsGroupExchangeContext())
{
if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() &&
payloadHeader.GetAckMessageCounter().HasValue())
{
reliableMessageContext->HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value());
}

if (payloadHeader.NeedsAck())
{
// An acknowledgment needs to be sent back to the peer for this message on this exchange,

ReturnErrorOnFailure(reliableMessageContext->HandleNeedsAck(messageCounter, msgFlags));
}
}

return CHIP_NO_ERROR;
}

} // namespace Messaging
} // namespace chip
6 changes: 2 additions & 4 deletions src/messaging/ExchangeMessageDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#pragma once

#include <messaging/Flags.h>
#include <protocols/Protocols.h>
#include <transport/SessionManager.h>

namespace chip {
Expand All @@ -42,11 +43,8 @@ class ExchangeMessageDispatch
CHIP_ERROR SendMessage(SessionManager * sessionManager, const SessionHandle & session, uint16_t exchangeId, bool isInitiator,
ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol,
uint8_t type, System::PacketBufferHandle && message);
CHIP_ERROR OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, MessageFlags msgFlags,
ReliableMessageContext * reliableMessageContext);

protected:
virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0;
virtual bool MessagePermitted(Protocols::Id protocol, uint8_t type) = 0;

// TODO: remove IsReliableTransmissionAllowed, this function should be provided over session.
virtual bool IsReliableTransmissionAllowed() const { return true; }
Expand Down
18 changes: 15 additions & 3 deletions src/messaging/ExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,23 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const
ChipLogDetail(ExchangeManager, "Handling via exchange: " ChipLogFormatExchange ", Delegate: %p", ChipLogValueExchange(ec),
ec->GetDelegate());

if (ec->IsEncryptionRequired() != packetHeader.IsEncrypted())
// Make sure the exchange stays alive through the code below even if we
// close it before calling HandleMessage.
ExchangeHandle ref(*ec);

// Ignore encryption-required mismatches for emphemeral exchanges,
// because those never have delegates anyway.
if (matchingUMH != nullptr && ec->IsEncryptionRequired() != packetHeader.IsEncrypted())
{
ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_INVALID_MESSAGE_TYPE));
// We want to still to do MRP processing for this message, but we do
// not want to deliver it to the application. Just close the
// exchange (which will notify the delegate, null it out, etc), then
// go ahead and call HandleMessage() on it to do the MRP
// processing.null out the delegate on the exchange, pretend to
// matchingUMH that exchange creation failed, so it cleans up the
// delegate, then tell the exchagne to handle the message.
ChipLogProgress(ExchangeManager, "OnMessageReceived encryption mismatch");
ec->Close();
return;
}

CHIP_ERROR err = ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, msgFlags, std::move(msgBuf));
Expand Down
2 changes: 1 addition & 1 deletion src/messaging/tests/TestReliableMessageProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class MockSessionEstablishmentExchangeDispatch : public Messaging::ApplicationEx
public:
bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; }

bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; }
bool MessagePermitted(Protocols::Id protocol, uint8_t type) override { return true; }

bool IsEncryptionRequired() const override { return mRequireEncryption; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ namespace chip {

using namespace Messaging;

bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t type)
bool SessionEstablishmentExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8_t type)
{
switch (protocol)
if (protocol == Protocols::SecureChannel::Id)
{
case Protocols::SecureChannel::Id.GetProtocolId():
switch (type)
{
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::StandaloneAck):
Expand All @@ -52,11 +51,8 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, u
default:
break;
}
break;

default:
break;
}

return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi
~SessionEstablishmentExchangeDispatch() override {}

protected:
bool MessagePermitted(uint16_t protocol, uint8_t type) override;
bool MessagePermitted(Protocols::Id, uint8_t type) override;
bool IsEncryptionRequired() const override { return false; }
};

Expand Down
9 changes: 4 additions & 5 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,11 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
ctx.DrainAndServiceIO();

auto & loopback = ctx.GetLoopback();
NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 1);
// There should have been two message sent: Sigma1 and an ack.
NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 2);

// Clear pending packet in CRMP
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = context->GetReliableMessageContext();
rm->ClearRetransTable(rc);
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);

loopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST;

Expand Down
9 changes: 4 additions & 5 deletions src/protocols/secure_channel/tests/TestPASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,11 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
&delegate) == CHIP_NO_ERROR);
ctx.DrainAndServiceIO();

NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 1);
// There should have been two messages sent: PBKDFParamRequest and an ack.
NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == 2);

// Clear pending packet in CRMP
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = context->GetReliableMessageContext();
rm->ClearRetransTable(rc);
ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0);

loopback.Reset();
loopback.mSentMessageCount = 0;
Expand Down

0 comments on commit 1426828

Please sign in to comment.