Skip to content

Commit

Permalink
Handle late-arriving m.room_key.withheld messages (#4310)
Browse files Browse the repository at this point in the history
* Restructure eventsPendingKey to remove sender key

For withheld notices, we don't necessarily receive the sender key, so we'll
jhave to do without it.

* Re-decrypt events when we receive a withheld notice

* Extend test to cover late-arriving withheld notices

* update unit tests
  • Loading branch information
richvdh authored Jul 29, 2024
1 parent d32f398 commit dc1cccf
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 27 deletions.
8 changes: 5 additions & 3 deletions spec/integ/crypto/crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2343,13 +2343,12 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
])(
"Decryption fails with withheld error if a withheld notice with code '%s' is received",
(withheldCode, expectedMessage, expectedErrorCode) => {
// TODO: test arrival after the event too.
it.each(["before"])("%s the event", async (when) => {
it.each(["before", "after"])("%s the event", async (when) => {
expectAliceKeyQuery({ device_keys: { "@alice:localhost": {} }, failures: {} });
await startClientAndAwaitFirstSync();

// A promise which resolves, with the MatrixEvent which wraps the event, once the decryption fails.
const awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);
let awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);

// Send Alice an encrypted room event which looks like it was encrypted with a megolm session
async function sendEncryptedEvent() {
Expand Down Expand Up @@ -2393,6 +2392,9 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
await sendEncryptedEvent();
} else {
await sendEncryptedEvent();
// Make sure that the first attempt to decrypt has happened before the withheld arrives
await awaitDecryption;
awaitDecryption = emitPromise(aliceClient, MatrixEventEvent.Decrypted);
await sendWithheldMessage();
}

Expand Down
1 change: 1 addition & 0 deletions spec/unit/rust-crypto/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ describe("initRustCrypto", () => {
deleteSecretsFromInbox: jest.fn(),
registerReceiveSecretCallback: jest.fn(),
registerDevicesUpdatedCallback: jest.fn(),
registerRoomKeysWithheldCallback: jest.fn(),
outgoingRequests: jest.fn(),
isBackupEnabled: jest.fn().mockResolvedValue(false),
verifyBackup: jest.fn().mockResolvedValue({ trusted: jest.fn().mockReturnValue(false) }),
Expand Down
3 changes: 3 additions & 0 deletions src/rust-crypto/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ async function initOlmMachine(
await olmMachine.registerRoomKeyUpdatedCallback((sessions: RustSdkCryptoJs.RoomKeyInfo[]) =>
rustCrypto.onRoomKeysUpdated(sessions),
);
await olmMachine.registerRoomKeysWithheldCallback((withheld: RustSdkCryptoJs.RoomKeyWithheldInfo[]) =>
rustCrypto.onRoomKeysWithheld(withheld),
);
await olmMachine.registerUserIdentityUpdatedCallback((userId: RustSdkCryptoJs.UserId) =>
rustCrypto.onUserIdentityUpdated(userId),
);
Expand Down
75 changes: 51 additions & 24 deletions src/rust-crypto/rust-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1486,7 +1486,7 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
this.logger.debug(
`Got update for session ${key.senderKey.toBase64()}|${key.sessionId} in ${key.roomId.toString()}`,
);
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(key);
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(key.roomId.toString(), key.sessionId);
if (pendingList.length === 0) return;

this.logger.debug(
Expand All @@ -1507,6 +1507,37 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
}
}

/**
* Callback for `OlmMachine.registerRoomKeyWithheldCallback`.
*
* Called by the rust sdk whenever we are told that a key has been withheld. We see if we had any events that
* failed to decrypt for the given session, and update their status if so.
*
* @param withheld - Details of the withheld sessions.
*/
public async onRoomKeysWithheld(withheld: RustSdkCryptoJs.RoomKeyWithheldInfo[]): Promise<void> {
for (const session of withheld) {
this.logger.debug(`Got withheld message for session ${session.sessionId} in ${session.roomId.toString()}`);
const pendingList = this.eventDecryptor.getEventsPendingRoomKey(
session.roomId.toString(),
session.sessionId,
);
if (pendingList.length === 0) return;

// The easiest way to update the status of the event is to have another go at decrypting it.
this.logger.debug(
"Retrying decryption on events:",
pendingList.map((e) => `${e.getId()}`),
);

for (const ev of pendingList) {
ev.attemptDecryption(this, { isRetry: true }).catch((_e) => {
// It's somewhat expected that we still can't decrypt here.
});
}
}
}

/**
* Callback for `OlmMachine.registerUserIdentityUpdatedCallback`
*
Expand Down Expand Up @@ -1683,7 +1714,7 @@ class EventDecryptor {
/**
* Events which we couldn't decrypt due to unknown sessions / indexes.
*
* Map from senderKey to sessionId to Set of MatrixEvents
* Map from roomId to sessionId to Set of MatrixEvents
*/
private eventsPendingKey = new MapWithDefault<string, MapWithDefault<string, Set<MatrixEvent>>>(
() => new MapWithDefault<string, Set<MatrixEvent>>(() => new Set()),
Expand Down Expand Up @@ -1843,54 +1874,50 @@ class EventDecryptor {
* Look for events which are waiting for a given megolm session
*
* Returns a list of events which were encrypted by `session` and could not be decrypted
*
* @param session -
*/
public getEventsPendingRoomKey(session: RustSdkCryptoJs.RoomKeyInfo): MatrixEvent[] {
const senderPendingEvents = this.eventsPendingKey.get(session.senderKey.toBase64());
if (!senderPendingEvents) return [];
public getEventsPendingRoomKey(roomId: string, sessionId: string): MatrixEvent[] {
const roomPendingEvents = this.eventsPendingKey.get(roomId);
if (!roomPendingEvents) return [];

const sessionPendingEvents = senderPendingEvents.get(session.sessionId);
const sessionPendingEvents = roomPendingEvents.get(sessionId);
if (!sessionPendingEvents) return [];

const roomId = session.roomId.toString();
return [...sessionPendingEvents].filter((ev) => ev.getRoomId() === roomId);
return [...sessionPendingEvents];
}

/**
* Add an event to the list of those awaiting their session keys.
*/
private addEventToPendingList(event: MatrixEvent): void {
const content = event.getWireContent();
const senderKey = content.sender_key;
const sessionId = content.session_id;
const roomId = event.getRoomId();
// We shouldn't have events without a room id here.
if (!roomId) return;

const senderPendingEvents = this.eventsPendingKey.getOrCreate(senderKey);
const sessionPendingEvents = senderPendingEvents.getOrCreate(sessionId);
const roomPendingEvents = this.eventsPendingKey.getOrCreate(roomId);
const sessionPendingEvents = roomPendingEvents.getOrCreate(event.getWireContent().session_id);
sessionPendingEvents.add(event);
}

/**
* Remove an event from the list of those awaiting their session keys.
*/
private removeEventFromPendingList(event: MatrixEvent): void {
const content = event.getWireContent();
const senderKey = content.sender_key;
const sessionId = content.session_id;
const roomId = event.getRoomId();
if (!roomId) return;

const senderPendingEvents = this.eventsPendingKey.get(senderKey);
if (!senderPendingEvents) return;
const roomPendingEvents = this.eventsPendingKey.getOrCreate(roomId);
if (!roomPendingEvents) return;

const sessionPendingEvents = senderPendingEvents.get(sessionId);
const sessionPendingEvents = roomPendingEvents.get(event.getWireContent().session_id);
if (!sessionPendingEvents) return;

sessionPendingEvents.delete(event);

// also clean up the higher-level maps if they are now empty
if (sessionPendingEvents.size === 0) {
senderPendingEvents.delete(sessionId);
if (senderPendingEvents.size === 0) {
this.eventsPendingKey.delete(senderKey);
roomPendingEvents.delete(event.getWireContent().session_id);
if (roomPendingEvents.size === 0) {
this.eventsPendingKey.delete(roomId);
}
}
}
Expand Down

0 comments on commit dc1cccf

Please sign in to comment.