From a042db4d78f6f4dfe3caf0e3dc38aee009922f77 Mon Sep 17 00:00:00 2001 From: Guilherme Gazzo Date: Wed, 14 Jan 2026 11:03:30 -0300 Subject: [PATCH] refactor: improve type safety in invite and state services --- .../src/services/invite.service.ts | 9 +---- .../src/services/state.service.ts | 7 ++-- .../room/src/authorizartion-rules/errors.ts | 4 +- .../room/src/authorizartion-rules/rules.ts | 10 ++--- packages/room/src/manager/event-wrapper.ts | 37 ++++++++++++++++--- 5 files changed, 42 insertions(+), 25 deletions(-) diff --git a/packages/federation-sdk/src/services/invite.service.ts b/packages/federation-sdk/src/services/invite.service.ts index 9a2a5aff2..dc3e1af15 100644 --- a/packages/federation-sdk/src/services/invite.service.ts +++ b/packages/federation-sdk/src/services/invite.service.ts @@ -89,20 +89,13 @@ export class InviteService { // SPEC: Invites a remote user to a room. Once the event has been signed by both the inviting homeserver and the invited homeserver, it can be sent to all of the servers in the room by the inviting homeserver. - const invitedServer = extractDomainFromId(inviteEvent.stateKey ?? ''); - if (!invitedServer) { - throw new Error( - `invalid state_key ${inviteEvent.stateKey}, no server_name part`, - ); - } - await this.federationValidationService.validateOutboundInvite( userId, roomId, ); // if user invited belongs to our server - if (invitedServer === this.configService.serverName) { + if (inviteEvent.stateKeyDomain === this.configService.serverName) { await stateService.handlePdu(inviteEvent); // let all servers know of this state change diff --git a/packages/federation-sdk/src/services/state.service.ts b/packages/federation-sdk/src/services/state.service.ts index c1361e524..7aa6e0d72 100644 --- a/packages/federation-sdk/src/services/state.service.ts +++ b/packages/federation-sdk/src/services/state.service.ts @@ -294,7 +294,7 @@ export class StateService { return instance; } - private async addAuthEvents(event: PersistentEventBase) { + private async addAuthEvents(event: PersistentEventBase) { const state = await this.getLatestRoomState(event.roomId); const eventsNeeded = event.getAuthEventStateKeys(); @@ -307,7 +307,7 @@ export class StateService { } } - async addPrevEvents(event: PersistentEventBase) { + async addPrevEvents(event: PersistentEventBase) { const roomVersion = await this.getRoomVersion(event.roomId); if (!roomVersion) { throw new Error('Room version not found while filling prev events'); @@ -330,7 +330,7 @@ export class StateService { event.addPrevEvents(events); } - public async signEvent(event: T) { + public async signEvent>(event: T) { if (process.env.NODE_ENV === 'test') return event; const signingKey = await this.configService.getSigningKey(); @@ -511,7 +511,6 @@ export class StateService { previousStateId, ); await this.addToRoomGraph(event, previousStateId); - await this.eventService.notify(event); } diff --git a/packages/room/src/authorizartion-rules/errors.ts b/packages/room/src/authorizartion-rules/errors.ts index e6e57b2ef..d3d91cae8 100644 --- a/packages/room/src/authorizartion-rules/errors.ts +++ b/packages/room/src/authorizartion-rules/errors.ts @@ -23,9 +23,9 @@ class StateResolverAuthorizationError extends Error { reason, rejectedBy, }: { - rejectedEvent: PersistentEventBase; + rejectedEvent: PersistentEventBase; reason: string; - rejectedBy?: PersistentEventBase; + rejectedBy?: PersistentEventBase; }, ) { // build the message diff --git a/packages/room/src/authorizartion-rules/rules.ts b/packages/room/src/authorizartion-rules/rules.ts index 16d84f3bd..50a249ca0 100644 --- a/packages/room/src/authorizartion-rules/rules.ts +++ b/packages/room/src/authorizartion-rules/rules.ts @@ -174,11 +174,11 @@ async function isMembershipChangeAllowed( if (previousEvents.length === 1) { const [event] = previousEvents; - if ( - event.isCreateEvent() && - event.getContent().creator === membershipEventToCheck.stateKey - ) { - return; + if (event.isCreateEvent()) { + const content = event.getContent() as { creator?: string }; + if (content.creator === membershipEventToCheck.stateKey) { + return; + } } } diff --git a/packages/room/src/manager/event-wrapper.ts b/packages/room/src/manager/event-wrapper.ts index d24c76542..054e89b2d 100644 --- a/packages/room/src/manager/event-wrapper.ts +++ b/packages/room/src/manager/event-wrapper.ts @@ -42,16 +42,22 @@ export type PduWithHashesAndSignaturesOptional = Prettify< MakeOptional >; +export type PduTypeWithoutStateKey = { + [K in PduType]: PduForType extends { state_key: unknown } ? never : K; +}[PduType]; + export const REDACT_ALLOW_ALL_KEYS: unique symbol = Symbol.for('all'); -export interface State extends Map { +export interface State + extends Omit, 'get'> { + get(key: StateMapKey): PersistentEventBase | undefined; get( key: T, ): T extends `${infer I}:${string}` ? I extends PduType ? PersistentEventBase | undefined - : never - : never; + : PersistentEventBase | undefined + : PersistentEventBase | undefined; } // convinient wrapper to manage schema differences when working with same algorithms across different versions @@ -67,7 +73,7 @@ export abstract class PersistentEventBase< private signatures: Signature = {}; - protected rawEvent: PduWithHashesAndSignaturesOptional; + protected rawEvent: PduWithHashesAndSignaturesOptional>; private authEventsIds: Set = new Set(); private prevEventsIds: Set = new Set(); @@ -134,8 +140,27 @@ export abstract class PersistentEventBase< return residentServer; } - get stateKey() { - return 'state_key' in this.rawEvent ? this.rawEvent.state_key : undefined; + get stateKey(): Type extends PduTypeWithoutStateKey + ? undefined + : PduForType extends { state_key: string } + ? PduForType['state_key'] + : undefined { + return ( + 'state_key' in this.rawEvent ? this.rawEvent.state_key : undefined + ) as Type extends PduTypeWithoutStateKey + ? undefined + : PduForType extends { state_key: string } + ? PduForType['state_key'] + : undefined; + } + + get stateKeyDomain(): Type extends PduTypeWithoutStateKey ? never : string { + if (this.stateKey === undefined) { + throw new Error('stateKey is undefined'); + } + return extractDomainFromId( + this.stateKey, + ) as Type extends PduTypeWithoutStateKey ? never : string; } get originServerTs() {