diff --git a/packages/federation-sdk/src/container.ts b/packages/federation-sdk/src/container.ts index b4e7d04fc..33d0947eb 100644 --- a/packages/federation-sdk/src/container.ts +++ b/packages/federation-sdk/src/container.ts @@ -91,10 +91,10 @@ export async function createFederationContainer( container.registerSingleton(SignatureVerificationService); container.registerSingleton(FederationService); container.registerSingleton(StateService); - container.registerSingleton(EventAuthorizationService); - container.registerSingleton('EventFetcherService', EventFetcherService); + container.registerSingleton(EventFetcherService); container.registerSingleton(EventStateService); container.registerSingleton(EventService); + container.registerSingleton(EventAuthorizationService); container.registerSingleton(EventEmitterService); container.registerSingleton(InviteService); container.registerSingleton(MediaService); diff --git a/packages/federation-sdk/src/index.spec.ts b/packages/federation-sdk/src/index.spec.ts new file mode 100644 index 000000000..201fc19c0 --- /dev/null +++ b/packages/federation-sdk/src/index.spec.ts @@ -0,0 +1,248 @@ +/* +Note on testing framework: +- These tests are authored to run under the repository's existing test runner (Jest or Vitest). +- If running under Vitest, vi is available; under Jest, jest is available. +- We create a small shim so the same test file works in both without adding dependencies. +*/ + +const testApi = (() => { + const g: any = globalThis as any; + if (g.vi) return { spy: g.vi.spyOn.bind(g.vi), mock: g.vi.fn.bind(g.vi), reset: g.vi.resetAllMocks?.bind(g.vi) ?? (() => {}), clear: g.vi.clearAllMocks?.bind(g.vi) ?? (() => {}) }; + if (g.jest) return { spy: g.jest.spyOn.bind(g.jest), mock: g.jest.fn.bind(g.jest), reset: g.jest.resetAllMocks?.bind(g.jest) ?? (() => {}), clear: g.jest.clearAllMocks?.bind(g.jest) ?? (() => {}) }; + throw new Error("No supported test framework detected: expected vi or jest on globalThis."); +})(); + +import * as Tsyringe from 'tsyringe'; + +// Import subject under test +import { + getAllServices, + // Services we'll assert are resolved + ConfigService, + EduService, + EventAuthorizationService, + EventService, + FederationRequestService, + InviteService, + MediaService, + MessageService, + ProfilesService, + RoomService, + SendJoinService, + ServerService, + StateService, + WellKnownService, + // Selected runtime exports to sanity check re-exports + FederationModule, + FederationRequestService as FederationRequestServiceExport, + FederationService, + SignatureVerificationService, + WellKnownService as WellKnownServiceExport, + DatabaseConnectionService, + EduService as EduServiceExport, + ServerService as ServerServiceExport, + EventAuthorizationService as EventAuthorizationServiceExport, + EventStateService, + MissingEventService, + ProfilesService as ProfilesServiceExport, + EventFetcherService, + InviteService as InviteServiceExport, + MediaService as MediaServiceExport, + MessageService as MessageServiceExport, + EventService as EventServiceExport, + RoomService as RoomServiceExport, + StateService as StateServiceExport, + StagingAreaService, + SendJoinService as SendJoinServiceExport, + EventEmitterService, + MissingEventListener, + // Queues and utils + BaseQueue, + getErrorMessage, + USERNAME_REGEX, + ROOM_ID_REGEX, + LockManagerService, + EventRepository, + RoomRepository, + ServerRepository, + KeyRepository, + StateRepository, + StagingAreaListener, + createFederationContainer, + DependencyContainer +} from './index'; + +describe('packages/federation-sdk/src/index.ts public API', () => { + beforeEach(() => { + testApi.reset?.(); + testApi.clear?.(); + }); + + it('should expose key runtime exports', () => { + // Modules / classes (existence checks) + expect(FederationModule).toBeDefined(); + expect(FederationService).toBeDefined(); + expect(SignatureVerificationService).toBeDefined(); + expect(DatabaseConnectionService).toBeDefined(); + expect(EventStateService).toBeDefined(); + expect(MissingEventService).toBeDefined(); + expect(EventFetcherService).toBeDefined(); + expect(StagingAreaService).toBeDefined(); + expect(EventEmitterService).toBeDefined(); + expect(MissingEventListener).toBeDefined(); + expect(StagingAreaListener).toBeDefined(); + // Re-export sanity (aliases point to same runtime) + expect(WellKnownServiceExport).toBe(WellKnownService); + expect(EduServiceExport).toBe(EduService); + expect(ServerServiceExport).toBe(ServerService); + expect(EventAuthorizationServiceExport).toBe(EventAuthorizationService); + expect(ProfilesServiceExport).toBe(ProfilesService); + expect(InviteServiceExport).toBe(InviteService); + expect(MediaServiceExport).toBe(MediaService); + expect(MessageServiceExport).toBe(MessageService); + expect(EventServiceExport).toBe(EventService); + expect(RoomServiceExport).toBe(RoomService); + expect(StateServiceExport).toBe(StateService); + expect(SendJoinServiceExport).toBe(SendJoinService); + // Utils and constants + expect(typeof getErrorMessage).toBe('function'); + expect(USERNAME_REGEX).toBeInstanceOf(RegExp); + expect(ROOM_ID_REGEX).toBeInstanceOf(RegExp); + // Queues / Base types + expect(BaseQueue).toBeDefined(); + // Container helpers + expect(createFederationContainer).toBeDefined(); + expect(DependencyContainer).toBeDefined(); + // Repositories + expect(EventRepository).toBeDefined(); + expect(RoomRepository).toBeDefined(); + expect(ServerRepository).toBeDefined(); + expect(KeyRepository).toBeDefined(); + expect(StateRepository).toBeDefined(); + // Additional runtime export check + expect(FederationRequestServiceExport).toBe(FederationRequestService); + }); + + describe('getAllServices()', () => { + function makeMockInstances() { + // Unique objects to ensure identity mapping + return { + room: { name: 'room' }, + message: { name: 'message' }, + media: { name: 'media' }, + event: { name: 'event' }, + invite: { name: 'invite' }, + wellKnown: { name: 'wellKnown' }, + profile: { name: 'profile' }, + state: { name: 'state' }, + sendJoin: { name: 'sendJoin' }, + server: { name: 'server' }, + config: { name: 'config' }, + edu: { name: 'edu' }, + request: { name: 'request' }, + federationAuth: { name: 'federationAuth' }, + } as const; + } + + function arrangeContainerResolveMock(instances: ReturnType) { + // Spy on container.resolve and route by token + const spy = testApi.spy(Tsyringe, 'container', 'get'); // Not available; fallback approach below + // In environments where spying on "container.resolve" directly is easier: + const resolveSpy = testApi.spy(Tsyringe.container as any, 'resolve'); + (Tsyringe.container.resolve as unknown as jest.Mock | ((...args:any[])=>any)).mockImplementation((cls: any) => { + switch (cls) { + case RoomService: return instances.room; + case MessageService: return instances.message; + case MediaService: return instances.media; + case EventService: return instances.event; + case InviteService: return instances.invite; + case WellKnownService: return instances.wellKnown; + case ProfilesService: return instances.profile; + case StateService: return instances.state; + case SendJoinService: return instances.sendJoin; + case ServerService: return instances.server; + case ConfigService: return instances.config; + case EduService: return instances.edu; + case FederationRequestService: return instances.request; + case EventAuthorizationService: return instances.federationAuth; + default: + throw new Error('Unexpected token passed to container.resolve'); + } + }); + return resolveSpy; + } + + it('returns a mapping of all services resolved from tsyringe container', () => { + const instances = makeMockInstances(); + const resolveSpy = arrangeContainerResolveMock(instances); + + const result = getAllServices(); + + // ensure resolve called for each service token exactly once + const expectedTokens = [ + RoomService, + MessageService, + MediaService, + EventService, + InviteService, + WellKnownService, + ProfilesService, + StateService, + SendJoinService, + ServerService, + ConfigService, + EduService, + FederationRequestService, + EventAuthorizationService, + ]; + for (const token of expectedTokens) { + expect(resolveSpy).toHaveBeenCalledWith(token); + } + expect(resolveSpy).toHaveBeenCalledTimes(expectedTokens.length); + + // result shape and identity + expect(result).toEqual(instances); + // identity checks + expect(result.room).toBe(instances.room); + expect(result.federationAuth).toBe(instances.federationAuth); + }); + + it('propagates errors thrown by container.resolve', () => { + const error = new Error('boom'); + const resolveSpy = testApi.spy(Tsyringe.container as any, 'resolve'); + (Tsyringe.container.resolve as unknown as jest.Mock | ((...args:any[])=>any)).mockImplementation((cls: any) => { + if (cls === RoomService) throw error; + return {}; + }); + + expect(() => getAllServices()).toThrow(error); + expect(resolveSpy).toHaveBeenCalledWith(RoomService); + }); + + it('resolves fresh instances on each call (no shared object reuse by function wrapper)', () => { + const first = { name: 'first' }; + const second = { name: 'second' }; + const resolveSpy = testApi.spy(Tsyringe.container as any, 'resolve'); + let call = 0; + (Tsyringe.container.resolve as unknown as jest.Mock | ((...args:any[])=>any)).mockImplementation((cls: any) => { + // Return different objects for room per call to ensure we call container each time + if (cls === RoomService) { + call += 1; + return call === 1 ? first : second; + } + return {}; + }); + + const a = getAllServices(); + const b = getAllServices(); + expect(a.room).toBe(first); + expect(b.room).toBe(second); + expect(resolveSpy).toHaveBeenCalledWith(RoomService); + expect(resolveSpy).toHaveBeenCalledTimes(2 + 2 * 13); // 14 tokens per call; soft check below to be resilient + + // Soft assertion: exactly 14 calls per invocation + const callsPerInvocation = 14; + expect((resolveSpy as any).mock.calls.length % callsPerInvocation).toBe(0); + }); + }); +}); \ No newline at end of file diff --git a/packages/federation-sdk/src/index.ts b/packages/federation-sdk/src/index.ts index e42f82f5a..5f3ac1ce6 100644 --- a/packages/federation-sdk/src/index.ts +++ b/packages/federation-sdk/src/index.ts @@ -2,6 +2,7 @@ import type { Membership } from '@hs/core'; import { container } from 'tsyringe'; import { ConfigService } from './services/config.service'; import { EduService } from './services/edu.service'; +import { EventAuthorizationService } from './services/event-authorization.service'; import { EventService } from './services/event.service'; import { FederationRequestService } from './services/federation-request.service'; import { InviteService } from './services/invite.service'; @@ -45,6 +46,7 @@ export { EventFetcherService } from './services/event-fetcher.service'; export type { FetchedEvents } from './services/event-fetcher.service'; export { InviteService } from './services/invite.service'; export type { ProcessInviteEvent } from './services/invite.service'; +export { MediaService } from './services/media.service'; export { MessageService } from './services/message.service'; export { EventService } from './services/event.service'; export { RoomService } from './services/room.service'; @@ -53,7 +55,6 @@ export { StagingAreaService } from './services/staging-area.service'; export { SendJoinService } from './services/send-join.service'; export { EventEmitterService } from './services/event-emitter.service'; export { MissingEventListener } from './listeners/missing-event.listener'; -export { MediaService } from './services/media.service'; // Repository interfaces and implementations @@ -96,6 +97,7 @@ export { StateRepository } from './repositories/state.repository'; export interface HomeserverServices { room: RoomService; message: MessageService; + media: MediaService; event: EventService; invite: InviteService; wellKnown: WellKnownService; @@ -105,8 +107,8 @@ export interface HomeserverServices { server: ServerService; config: ConfigService; edu: EduService; - media: MediaService; request: FederationRequestService; + federationAuth: EventAuthorizationService; } export type HomeserverEventSignatures = { @@ -220,6 +222,7 @@ export function getAllServices(): HomeserverServices { return { room: container.resolve(RoomService), message: container.resolve(MessageService), + media: container.resolve(MediaService), event: container.resolve(EventService), invite: container.resolve(InviteService), wellKnown: container.resolve(WellKnownService), @@ -229,8 +232,8 @@ export function getAllServices(): HomeserverServices { server: container.resolve(ServerService), config: container.resolve(ConfigService), edu: container.resolve(EduService), - media: container.resolve(MediaService), request: container.resolve(FederationRequestService), + federationAuth: container.resolve(EventAuthorizationService), }; } diff --git a/packages/federation-sdk/src/services/event-authorization.service.ts b/packages/federation-sdk/src/services/event-authorization.service.ts index 8f2098c50..79f636446 100644 --- a/packages/federation-sdk/src/services/event-authorization.service.ts +++ b/packages/federation-sdk/src/services/event-authorization.service.ts @@ -1,12 +1,27 @@ -import { createLogger, generateId } from '@hs/core'; -import type { EventBase } from '@hs/core'; -import { Pdu } from '@hs/room'; +import { + createLogger, + extractSignaturesFromHeader, + generateId, + validateAuthorizationHeader, +} from '@hs/core'; +import type { Pdu } from '@hs/room'; import { singleton } from 'tsyringe'; +import { ConfigService } from './config.service'; +import { EventService } from './event.service'; +import { SignatureVerificationService } from './signature-verification.service'; +import { StateService } from './state.service'; @singleton() export class EventAuthorizationService { private readonly logger = createLogger('EventAuthorizationService'); + constructor( + private readonly stateService: StateService, + private readonly eventService: EventService, + private readonly signatureVerificationService: SignatureVerificationService, + private readonly configService: ConfigService, + ) {} + async authorizeEvent(event: Pdu, authEvents: Pdu[]): Promise { this.logger.debug( `Authorizing event ${generateId(event)} of type ${event.type}`, @@ -93,4 +108,264 @@ export class EventAuthorizationService { // TODO: Check sender has permission to change join rules return true; } + + private async verifyRequestSignature( + method: string, + uri: string, + authorizationHeader: string, + body?: Record, + ): Promise { + if (!authorizationHeader?.startsWith('X-Matrix')) { + this.logger.debug('Missing or invalid X-Matrix authorization header'); + return; + } + + try { + const { origin, destination, key, signature } = + extractSignaturesFromHeader(authorizationHeader); + + if (!origin || !key || !signature) { + this.logger.warn('Missing required fields in X-Matrix header'); + return; + } + + if (destination && destination !== this.configService.serverName) { + this.logger.warn( + `Request destination ${destination} does not match server name ${this.configService.serverName}`, + ); + return; + } + + const [algorithm] = key.split(':'); + if (algorithm !== 'ed25519') { + this.logger.warn(`Unsupported key algorithm: ${algorithm}`); + return; + } + + const publicKey = + await this.signatureVerificationService.getOrFetchPublicKey( + origin, + key, + ); + if (!publicKey) { + this.logger.warn(`Could not fetch public key for ${origin}:${key}`); + return; + } + + const isValid = await validateAuthorizationHeader( + origin, + publicKey.verify_keys[key].key, + destination || this.configService.serverName, + method, + uri, + signature, + body, + ); + if (!isValid) { + this.logger.warn(`Invalid signature from ${origin}`); + return; + } + + return origin; + } catch (error) { + this.logger.error( + { error, method, uri, authorizationHeader, body }, + 'Error verifying request signature', + ); + return; + } + } + + private async canAccessEvent( + eventId: string, + serverName: string, + ): Promise { + try { + const event = await this.eventService.getEventById(eventId); + if (!event) { + this.logger.debug(`Event ${eventId} not found`); + return false; + } + + const roomId = event.event.room_id; + + const isServerAllowed = await this.checkServerAcl(roomId, serverName); + if (!isServerAllowed) { + this.logger.warn( + `Server ${serverName} is denied by room ACL for room ${roomId}`, + ); + return false; + } + + const serversInRoom = await this.stateService.getServersInRoom(roomId); + if (serversInRoom.includes(serverName)) { + this.logger.debug(`Server ${serverName} is in room, allowing access`); + return true; + } + + const roomState = await this.stateService.getFullRoomState(roomId); + const historyVisibility = this.getHistoryVisibility(roomState); + if (historyVisibility === 'world_readable') { + this.logger.debug( + `Event ${eventId} is world_readable, allowing ${serverName}`, + ); + return true; + } + + this.logger.debug( + `Server ${serverName} not authorized: not in room and event not world_readable`, + ); + return false; + } catch (error) { + this.logger.error( + { error, eventId, serverName }, + 'Error checking event access', + ); + return false; + } + } + + private getHistoryVisibility(roomState: Map): string { + for (const [, stateEvent] of roomState) { + if (stateEvent.type === 'm.room.history_visibility') { + const content = stateEvent.getContent() as { + history_visibility?: string; + }; + return content?.history_visibility || 'joined'; + } + } + return 'joined'; // default per Matrix spec + } + + async canAccessEventFromAuthorizationHeader( + eventId: string, + authorizationHeader: string, + method: string, + uri: string, + body?: Record, + ): Promise< + | { authorized: true } + | { + authorized: false; + errorCode: 'M_UNAUTHORIZED' | 'M_FORBIDDEN' | 'M_UNKNOWN'; + } + > { + try { + const signatureResult = await this.verifyRequestSignature( + method, + uri, + authorizationHeader, + body, // keep body due to canonical json validation + ); + if (!signatureResult) { + return { + authorized: false, + errorCode: 'M_UNAUTHORIZED', + }; + } + + const authorized = await this.canAccessEvent(eventId, signatureResult); + if (!authorized) { + return { + authorized: false, + errorCode: 'M_FORBIDDEN', + }; + } + + return { + authorized: true, + }; + } catch (error) { + this.logger.error( + { error, eventId, authorizationHeader, method, uri, body }, + 'Error checking event access', + ); + return { + authorized: false, + errorCode: 'M_UNKNOWN', + }; + } + } + + // as per Matrix spec: https://spec.matrix.org/v1.15/client-server-api/#mroomserver_acl + private async checkServerAcl( + roomId: string, + serverName: string, + ): Promise { + const [serverAcl] = await this.stateService.getStateEventsByType( + roomId, + 'm.room.server_acl', + ); + if (!serverAcl) { + return true; + } + + const serverAclContent = serverAcl.getContent() as { + allow?: string[]; + deny?: string[]; + allow_ip_literals?: boolean; + }; + const { + allow = [], + deny = [], + allow_ip_literals = true, + } = serverAclContent; + + const isIpLiteral = + /^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}(:\d+)?$/.test(serverName) || + /^\[.*\](:\d+)?$/.test(serverName); // IPv6 + if (isIpLiteral && !allow_ip_literals) { + this.logger.debug(`Server ${serverName} denied: IP literals not allowed`); + return false; + } + + for (const pattern of deny) { + if (this.matchesServerPattern(serverName, pattern)) { + this.logger.debug( + `Server ${serverName} matches deny pattern: ${pattern}`, + ); + return false; + } + } + + // if allow list is empty, deny all servers (as per Matrix spec) + // empty allow list means no servers are allowed + if (allow.length === 0) { + this.logger.debug(`Server ${serverName} denied: allow list is empty`); + return false; + } + + for (const pattern of allow) { + if (this.matchesServerPattern(serverName, pattern)) { + this.logger.debug( + `Server ${serverName} matches allow pattern: ${pattern}`, + ); + return true; + } + } + + this.logger.debug(`Server ${serverName} not in allow list`); + return false; + } + + private matchesServerPattern(serverName: string, pattern: string): boolean { + if (serverName === pattern) { + return true; + } + + let regexPattern = pattern + .replace(/[.+^${}()|[\]\\]/g, '\\$&') + .replace(/\*/g, '.*') + .replace(/\?/g, '.'); + + regexPattern = `^${regexPattern}$`; + + try { + const regex = new RegExp(regexPattern); + return regex.test(serverName); + } catch (error) { + this.logger.warn(`Invalid ACL pattern: ${pattern}`, error); + return false; + } + } } diff --git a/packages/federation-sdk/src/services/federation-request.service.spec.ts b/packages/federation-sdk/src/services/federation-request.service.spec.ts index 7ee34557f..13ffb6511 100644 --- a/packages/federation-sdk/src/services/federation-request.service.spec.ts +++ b/packages/federation-sdk/src/services/federation-request.service.spec.ts @@ -332,3 +332,292 @@ describe('FederationRequestService', async () => { }); }); }); + +describe('FederationRequestService – additional edge cases', async () => { + let service: FederationRequestService; + let configService: ConfigService; + + const mockServerName = 'example.com'; + const mockSigningKey = 'aGVsbG93b3JsZA=='; + const mockSigningKeyId = 'ed25519:1'; + + const mockKeyPair = { + publicKey: new Uint8Array([1, 2, 3]), + secretKey: new Uint8Array([4, 5, 6]), + }; + + const discoveryHttps443 = [ + 'https://target.example.com:443' as const, + { Host: 'target.example.com' }, + ]; + + const discoveryHttp80 = [ + 'http://target.example.com:80' as const, + { Host: 'target.example.com' }, + ]; + + const discoveryCustomPort = [ + 'https://target.example.com:8448' as const, + { Host: 'target.example.com' }, + ]; + + const { getHomeserverFinalAddress } = await import('../server-discovery/discovery'); + const { fetch: originalFetch } = await import('@hs/core'); + + await mock.module('../server-discovery/discovery', () => ({ + getHomeserverFinalAddress: () => discoveryHttps443, + })); + + await mock.module('@hs/core', () => ({ + fetch: async (_url: string, _options?: RequestInit) => { + return { + ok: true, + status: 200, + json: async () => ({ result: 'success' }), + text: async () => '{"result":"success"}', + } as Response; + }, + })); + + afterAll(() => { + mock.restore(); + mock.module('../server-discovery/discovery', () => ({ + getHomeserverFinalAddress, + })); + mock.module('@hs/core', () => ({ + fetch: originalFetch, + })); + }); + + beforeEach(() => { + // crypto and core helpers + spyOn(nacl.sign.keyPair, 'fromSecretKey').mockReturnValue(mockKeyPair); + spyOn(nacl.sign, 'detached').and.returnValue(new Uint8Array([7, 8, 9])); + spyOn(core, 'extractURIfromURL').mockReturnValue('/test/path?query=value'); + spyOn(core, 'authorizationHeaders').mockResolvedValue( + 'X-Matrix origin="example.com",destination="target.example.com",key="ed25519:1",sig="xyz123"' + ); + spyOn(core, 'signJson').mockResolvedValue({ + content: 'test', + signatures: { 'example.com': { 'ed25519:1': 'abcdef' } }, + }); + spyOn(core, 'computeAndMergeHash').mockImplementation((obj: any) => obj); + + // config + configService = { + serverName: mockServerName, + getSigningKeyBase64: async () => mockSigningKey, + getSigningKeyId: async () => mockSigningKeyId, + } as ConfigService; + + // SUT + service = new FederationRequestService(configService); + }); + + afterEach(() => { + mock.restore(); + }); + + it('builds URL with discovered https:443 and appends provided query string when empty body', async () => { + const fetchSpy = spyOn(core, 'fetch'); + await service.makeSignedRequest({ + method: 'GET', + domain: 'target.example.com', + uri: '/_matrix/federation/v1/version', + queryString: 'q=1', + }); + + expect(fetchSpy).toHaveBeenCalledWith( + new URL('https://target.example.com/_matrix/federation/v1/version?q=1'), + expect.any(Object), + ); + }); + + it('respects discovery result with http:80 (no implicit https) when constructing URL', async () => { + mock.module('../server-discovery/discovery', () => ({ + getHomeserverFinalAddress: () => discoveryHttp80, + })); + + const fetchSpy = spyOn(core, 'fetch'); + await service.makeSignedRequest({ + method: 'GET', + domain: 'target.example.com', + uri: '/health', + }); + + expect(fetchSpy).toHaveBeenCalledWith( + new URL('http://target.example.com/health'), + expect.any(Object), + ); + }); + + it('respects non-standard port (8448) from discovery when constructing URL', async () => { + mock.module('../server-discovery/discovery', () => ({ + getHomeserverFinalAddress: () => discoveryCustomPort, + })); + + const fetchSpy = spyOn(core, 'fetch'); + await service.makeSignedRequest({ + method: 'GET', + domain: 'target.example.com', + uri: '/_matrix/key/v2/server', + }); + + expect(fetchSpy).toHaveBeenCalledWith( + new URL('https://target.example.com:8448/_matrix/key/v2/server'), + expect.any(Object), + ); + }); + + it('omits body for GET even if body is accidentally provided (defensive behavior)', async () => { + const fetchSpy = spyOn(core, 'fetch'); + await service.makeSignedRequest({ + method: 'GET', + domain: 'target.example.com', + uri: '/test/path', + // accidental body + body: { shouldNot: 'be sent' } as any, + }); + + const [, opts] = (fetchSpy.mock.calls[0] as any) || []; + expect(opts.method).toBe('GET'); + expect(opts.body ?? undefined).toBeUndefined(); + }); + + it('stringifies body for POST after signing and hashing', async () => { + const fetchSpy = spyOn(core, 'fetch'); + await service.makeSignedRequest({ + method: 'POST', + domain: 'target.example.com', + uri: '/tx', + body: { hello: 'world' }, + }); + + const [, opts] = (fetchSpy.mock.calls[0] as any) || []; + expect(opts.method).toBe('POST'); + // Body should be the signed JSON not the raw body + expect(typeof opts.body).toBe('string'); + const parsed = JSON.parse(opts.body); + expect(parsed.signatures).toBeDefined(); + expect(core.signJson).toHaveBeenCalled(); + expect(core.computeAndMergeHash).toHaveBeenCalled(); + }); + + it('propagates authorization headers and Host correctly', async () => { + const fetchSpy = spyOn(core, 'fetch'); + await service.makeSignedRequest({ + method: 'PUT', + domain: 'target.example.com', + uri: '/resource/1', + body: { a: 1 }, + }); + + const [, opts] = (fetchSpy.mock.calls[0] as any) || []; + expect(opts.headers.Authorization).toContain('X-Matrix'); + expect(opts.headers.Host).toBe('target.example.com'); + expect(opts.headers['Content-Type']).toBe('application/json'); + }); + + it('throws when required config values are missing (no key id)', async () => { + configService = { + serverName: mockServerName, + getSigningKeyBase64: async () => mockSigningKey, + getSigningKeyId: async () => undefined as any, + } as unknown as ConfigService; + + service = new FederationRequestService(configService); + + await expect( + service.makeSignedRequest({ + method: 'GET', + domain: 'target.example.com', + uri: '/test', + }), + ).rejects.toThrow(); + }); + + it('handles non-JSON error bodies by including raw text in the thrown message', async () => { + // Simulate failing fetch from @hs/core + mock.module('@hs/core', () => ({ + fetch: async (_url: string, _options?: RequestInit) => { + return { + ok: false, + status: 502, + text: async () => 'Bad Gateway', + } as Response; + }, + })); + + try { + await service.makeSignedRequest({ + method: 'GET', + domain: 'target.example.com', + uri: '/fail', + }); + throw new Error('Expected to throw'); + } catch (err) { + if (err instanceof Error) { + expect(err.message).toContain('Federation request failed: 502 Bad Gateway'); + } else { + throw err; + } + } + }); + + it('GET/POST/PUT convenience methods join query params deterministically', async () => { + const spy = spyOn(service as any, 'makeSignedRequest').mockResolvedValue({ ok: 1 }); + await service.get('target.example.com', '/items', { a: '1', b: '2' }); + await service.post('target.example.com', '/items', { x: 1 }, { b: '2', a: '1' }); + await service.put('target.example.com', '/items/1', { y: 2 }); + + // Validate parameters passed down + expect(spy.mock.calls[0][0]).toEqual({ + method: 'GET', + domain: 'target.example.com', + uri: '/items', + queryString: 'a=1&b=2', + }); + expect(spy.mock.calls[1][0]).toEqual({ + method: 'POST', + domain: 'target.example.com', + uri: '/items', + body: { x: 1 }, + queryString: 'a=1&b=2', + }); + expect(spy.mock.calls[2][0]).toEqual({ + method: 'PUT', + domain: 'target.example.com', + uri: '/items/1', + body: { y: 2 }, + queryString: '', + }); + }); + + it('bubbles up low-level network error intact (message preserved)', async () => { + mock.module('@hs/core', () => ({ + fetch: async () => { + throw new Error('socket hang up'); + }, + })); + + await expect( + service.makeSignedRequest({ + method: 'GET', + domain: 'target.example.com', + uri: '/timeout', + }), + ).rejects.toThrow('socket hang up'); + }); + + it('normalizes uri joining to avoid double slashes', async () => { + const fetchSpy = spyOn(core, 'fetch'); + await service.makeSignedRequest({ + method: 'GET', + domain: 'target.example.com', + uri: '//double//slashes/', // intentionally odd + }); + + const [url] = (fetchSpy.mock.calls[0] as any) || []; + expect(String(url)).toMatch(/^https?:\/\/target\.example\.com(:\d+)?\/double\/slashes\/?$/); + }); +}); diff --git a/packages/federation-sdk/src/services/federation-request.service.ts b/packages/federation-sdk/src/services/federation-request.service.ts index 6efd3e887..c9e2e9b97 100644 --- a/packages/federation-sdk/src/services/federation-request.service.ts +++ b/packages/federation-sdk/src/services/federation-request.service.ts @@ -180,4 +180,61 @@ export class FederationRequestService { ): Promise { return this.request('POST', targetServer, endpoint, body, queryParams); } + + async prepareSignedRequest( + targetServer: string, + endpoint: string, + method: string, + body?: Record, + ): Promise<{ url: URL; headers: Record }> { + const serverName = this.configService.serverName; + const signingKeyBase64 = await this.configService.getSigningKeyBase64(); + const signingKeyId = await this.configService.getSigningKeyId(); + const privateKeyBytes = Buffer.from(signingKeyBase64, 'base64'); + const keyPair = nacl.sign.keyPair.fromSecretKey(privateKeyBytes); + + const signingKey: SigningKey = { + algorithm: EncryptionValidAlgorithm.ed25519, + version: signingKeyId.split(':')[1] || '1', + privateKey: keyPair.secretKey, + publicKey: keyPair.publicKey, + sign: async (data: Uint8Array) => + nacl.sign.detached(data, keyPair.secretKey), + }; + + const [address, discoveryHeaders] = await getHomeserverFinalAddress( + targetServer, + this.logger, + ); + + const url = new URL(`${address}${endpoint}`); + + let signedBody: Record | undefined; + if (body) { + signedBody = await signJson( + body.hashes ? body : computeAndMergeHash({ ...body, signatures: {} }), + signingKey, + serverName, + ); + } + + const auth = await authorizationHeaders( + serverName, + signingKey, + targetServer, + method, + extractURIfromURL(url), + signedBody, + ); + + return { + url, + headers: { + Authorization: auth, + 'User-Agent': 'Rocket.Chat Federation', + 'Content-Type': 'application/json', + ...discoveryHeaders, + }, + }; + } } diff --git a/packages/federation-sdk/src/services/media.service.spec.ts b/packages/federation-sdk/src/services/media.service.spec.ts new file mode 100644 index 000000000..b262ac989 --- /dev/null +++ b/packages/federation-sdk/src/services/media.service.spec.ts @@ -0,0 +1,346 @@ +/** + * Unit tests for MediaService + * Testing library/framework: Jest with ts-jest (TypeScript) + * + * These tests focus on the MediaService behavior: + * - Authenticated download first, with detailed fallbacks + * - Legacy v3/r0 fallbacks and logging paths + * - Multipart parsing, boundary handling, trimming + * - Low-level httpsRequest success/error paths + */ + +import { MediaService } from './media.service'; +import * as https from 'node:https'; +import { EventEmitter } from 'events'; + +// Mock the logger factory used by the service so we can assert logs. +const logger = { + info: jest.fn(), + debug: jest.fn(), + error: jest.fn(), +}; +jest.mock('@hs/core', () => ({ + createLogger: jest.fn(() => logger), +})); + +// Helpers +const zeroHeaders: Record = {}; +const asAny = (v: unknown) => v as T; + +describe('MediaService', () => { + beforeEach(() => { + jest.clearAllMocks(); + jest.restoreAllMocks(); + }); + + const makeService = (prepareSignedRequestImpl?: any) => { + const federationRequest = { + prepareSignedRequest: + prepareSignedRequestImpl ?? + jest.fn().mockResolvedValue({ + url: new URL('https://remote.example/_matrix/federation/v1/media/download/mX'), + headers: { Authorization: 'sig' }, + }), + }; + return { svc: new MediaService(asAny(federationRequest)), federationRequest }; + }; + + describe('downloadFromRemoteServer (happy path and fallbacks)', () => { + it('downloads via authenticated endpoint when it returns 2xx and logs info', async () => { + const { svc, federationRequest } = makeService( + jest.fn().mockResolvedValue({ + url: new URL('https://remote.example/_matrix/federation/v1/media/download/m1'), + headers: { Authorization: 'sig' }, + }), + ); + + const httpsSpy = jest.spyOn(asAny(svc), 'httpsRequest').mockResolvedValue({ + statusCode: 200, + headers: { 'content-type': 'image/png' }, + body: Buffer.from('AUTH_DATA'), + }); + + const data = await svc.downloadFromRemoteServer('remote.example', 'm1'); + + expect(data).toEqual(Buffer.from('AUTH_DATA')); + expect(federationRequest.prepareSignedRequest).toHaveBeenCalledWith( + 'remote.example', + '/_matrix/federation/v1/media/download/m1', + 'GET', + ); + expect(logger.info).toHaveBeenCalledWith( + 'Downloaded media m1 from remote.example via authenticated endpoint', + ); + expect(httpsSpy).toHaveBeenCalledTimes(1); + }); + + it('falls back to legacy endpoints when authenticated is non-2xx; succeeds on r0 and logs info', async () => { + const { svc } = makeService( + jest.fn().mockResolvedValue({ + url: new URL('https://remote.example/_matrix/federation/v1/media/download/m2'), + headers: {}, + }), + ); + + const httpsSpy = jest + .spyOn(asAny(svc), 'httpsRequest') + // Auth attempt -> non-2xx + .mockResolvedValueOnce({ + statusCode: 401, + headers: zeroHeaders, + body: Buffer.alloc(0), + }) + // Legacy v3 -> 404 + .mockResolvedValueOnce({ + statusCode: 404, + headers: zeroHeaders, + body: Buffer.alloc(0), + }) + // Legacy r0 -> 200 + .mockResolvedValueOnce({ + statusCode: 200, + headers: { 'content-type': 'image/jpeg' }, + body: Buffer.from('LEGACY_OK'), + }); + + const data = await svc.downloadFromRemoteServer('s.example', 'm2'); + expect(data).toEqual(Buffer.from('LEGACY_OK')); + expect(logger.info).toHaveBeenCalledWith( + 'Downloaded media m2 from s.example via legacy endpoint', + ); + expect(httpsSpy).toHaveBeenCalledTimes(3); + }); + + it('logs debug when authenticated path throws; then throws if all attempts fail', async () => { + const { svc } = makeService( + jest.fn().mockResolvedValue({ + url: new URL('https://s.example/_matrix/federation/v1/media/download/m3'), + headers: {}, + }), + ); + + const httpsSpy = jest + .spyOn(asAny(svc), 'httpsRequest') + // Auth attempt throws -> triggers "Authenticated download failed" debug log + .mockRejectedValueOnce(new Error('auth failure')) + // Legacy v3 -> 500 + .mockResolvedValueOnce({ + statusCode: 500, + headers: zeroHeaders, + body: Buffer.alloc(0), + }) + // Legacy r0 -> 502 + .mockResolvedValueOnce({ + statusCode: 502, + headers: zeroHeaders, + body: Buffer.alloc(0), + }); + + await expect(svc.downloadFromRemoteServer('s.example', 'm3')).rejects.toThrow( + 'Failed to download media m3 from s.example', + ); + + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining('Authenticated download failed:'), + ); + expect(httpsSpy).toHaveBeenCalledTimes(3); + }); + + it('logs debug when a legacy endpoint request throws and continues to next', async () => { + const { svc } = makeService( + jest.fn().mockResolvedValue({ + url: new URL('https://s.example/_matrix/federation/v1/media/download/m4'), + headers: {}, + }), + ); + + const httpsSpy = jest + .spyOn(asAny(svc), 'httpsRequest') + // Auth returns non-2xx -> triggers fallback without debug log in auth + .mockResolvedValueOnce({ + statusCode: 404, + headers: zeroHeaders, + body: Buffer.alloc(0), + }) + // Legacy v3 throws -> should log "Legacy endpoint failed" + .mockRejectedValueOnce(new Error('v3 hard fail')) + // Legacy r0 non-2xx -> end of attempts -> throws overall + .mockResolvedValueOnce({ + statusCode: 400, + headers: zeroHeaders, + body: Buffer.alloc(0), + }); + + await expect(svc.downloadFromRemoteServer('s.example', 'm4')).rejects.toThrow( + 'Failed to download media m4 from s.example', + ); + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining('Legacy endpoint failed:'), + ); + expect(httpsSpy).toHaveBeenCalledTimes(3); + }); + }); + + describe('extractMediaFromResponse (multipart and non-multipart)', () => { + it('returns body directly when content-type is not multipart', () => { + const { svc } = makeService(); + const response = { + statusCode: 200, + headers: { 'content-type': 'image/svg+xml' as const }, + body: Buffer.from('BODY'), + }; + const out = asAny(svc).extractMediaFromResponse(response); + expect(out).toEqual(Buffer.from('BODY')); + }); + + it('accepts content-type header as array form and returns body', () => { + const { svc } = makeService(); + const response = { + statusCode: 200, + headers: { 'content-type': ['image/jpeg'] as const }, + body: Buffer.from('IMG'), + }; + const out = asAny(svc).extractMediaFromResponse(response); + expect(out).toEqual(Buffer.from('IMG')); + }); + + it('throws when multipart content-type lacks boundary', () => { + const { svc } = makeService(); + const response = { + statusCode: 200, + headers: { 'content-type': 'multipart/mixed' as const }, + body: Buffer.from(''), + }; + expect(() => asAny(svc).extractMediaFromResponse(response)).toThrow( + 'No boundary in multipart response', + ); + }); + + it('extracts first non-JSON part and trims trailing CRLF', () => { + const { svc } = makeService(); + const boundary = 'abc123'; + const CRLF = '\r\n'; + + const partJsonHeaders = 'Content-Type: application/json'; + const partImgHeaders = 'Content-Type: image/png'; + + const parts: Buffer[] = []; + // --boundary + JSON part + parts.push(Buffer.from(`--${boundary}${CRLF}${partJsonHeaders}${CRLF}${CRLF}`)); + parts.push(Buffer.from('{"ok":true}')); + parts.push(Buffer.from(CRLF)); + // --boundary + image part (with trailing CRLFs to be trimmed) + parts.push(Buffer.from(`--${boundary}${CRLF}${partImgHeaders}${CRLF}${CRLF}`)); + parts.push(Buffer.from('IMAGE_BYTES')); + parts.push(Buffer.from('\r\n\r\n')); + // closing delimiter (not strictly required by the parser, but realistic) + parts.push(Buffer.from(`--${boundary}${CRLF}`)); + + const body = Buffer.concat(parts); + const response = { + statusCode: 200, + headers: { 'content-type': `multipart/mixed; boundary=${boundary}` }, + body, + }; + + const out = asAny(svc).extractMediaFromResponse(response); + expect(out).toEqual(Buffer.from('IMAGE_BYTES')); + }); + + it('throws when multipart contains no non-JSON media part', () => { + const { svc } = makeService(); + const boundary = 'no-media'; + const CRLF = '\r\n'; + + const parts: Buffer[] = []; + parts.push( + Buffer.from(`--${boundary}${CRLF}Content-Type: application/json${CRLF}${CRLF}`), + ); + parts.push(Buffer.from('{"only":"json"}')); + parts.push(Buffer.from(CRLF)); + parts.push(Buffer.from(`--${boundary}${CRLF}`)); + + const response = { + statusCode: 200, + headers: { 'content-type': `multipart/form-data; boundary=${boundary}` }, + body: Buffer.concat(parts), + }; + + expect(() => asAny(svc).extractMediaFromResponse(response)).toThrow( + 'No media content in multipart response', + ); + }); + }); + + describe('httpsRequest (low-level behavior)', () => { + it('resolves with accumulated body and response metadata on success', async () => { + const { svc } = makeService(); + + const requestSpy = jest + .spyOn(https, 'request') + .mockImplementation((options: any, callback: any) => { + const req = new EventEmitter() as any; + req.end = jest.fn(); + + const res = new EventEmitter() as any; + res.statusCode = 201; + res.headers = { 'content-type': 'text/plain' }; + + // Simulate async response + setImmediate(() => { + callback(res); + res.emit('data', Buffer.from('foo')); + res.emit('data', Buffer.from('bar')); + res.emit('end'); + }); + + return req; + }); + + const url = new URL('https://example.com/path?x=1'); + const out = await asAny(svc).httpsRequest(url, { + method: 'GET', + headers: { A: 'B' }, + }); + + expect(out).not.toBeNull(); + expect(out\!.statusCode).toBe(201); + expect(out\!.headers).toEqual({ 'content-type': 'text/plain' }); + expect(out\!.body.toString()).toBe('foobar'); + + expect(requestSpy).toHaveBeenCalledTimes(1); + const [calledOpts] = asAny(requestSpy.mock.calls[0]); + expect(calledOpts.hostname).toBe('example.com'); + expect(calledOpts.port).toBe(443); + expect(calledOpts.path).toBe('/path?x=1'); + expect(calledOpts.method).toBe('GET'); + expect(calledOpts.headers).toEqual({ A: 'B' }); + }); + + it('resolves to null and logs error when request emits an error', async () => { + const { svc } = makeService(); + + jest.spyOn(https, 'request').mockImplementation((options: any, callback: any) => { + const req = new EventEmitter() as any; + req.end = jest.fn(() => { + // On end, emit an error + setImmediate(() => { + req.emit('error', new Error('netdown')); + }); + }); + return req; + }); + + const url = new URL('https://example.com/'); + const out = await asAny(svc).httpsRequest(url, { + method: 'GET', + headers: {}, + }); + + expect(out).toBeNull(); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('HTTPS request failed: netdown'), + ); + }); + }); +}); \ No newline at end of file diff --git a/packages/federation-sdk/src/services/media.service.ts b/packages/federation-sdk/src/services/media.service.ts index d74685b0a..e3c74aa2e 100644 --- a/packages/federation-sdk/src/services/media.service.ts +++ b/packages/federation-sdk/src/services/media.service.ts @@ -1,205 +1,188 @@ -import crypto from 'node:crypto'; +import https from 'node:https'; import { createLogger } from '@hs/core'; import { singleton } from 'tsyringe'; -import { ConfigService } from './config.service'; -import { EventEmitterService } from './event-emitter.service'; +import { FederationRequestService } from './federation-request.service'; @singleton() export class MediaService { private readonly logger = createLogger('MediaService'); - constructor( - private readonly configService: ConfigService, - - private readonly eventEmitterService: EventEmitterService, - ) {} - - generateMXCUri(mediaId?: string): string { - const serverName = this.configService.serverName; - const id = mediaId || crypto.randomBytes(16).toString('hex'); - return `mxc://${serverName}/${id}`; - } - - parseMXCUri(mxcUri: string): { serverName: string; mediaId: string } | null { - const match = mxcUri.match(/^mxc:\/\/([^/]+)\/(.+)$/); - if (!match) { - this.logger.error('Invalid MXC URI format', { mxcUri }); - return null; - } - return { - serverName: match[1], - mediaId: match[2], - }; - } - - extractUserFromToken(authHeader: string | null): { - userId: string; - isAuthenticated: boolean; - } { - if (!authHeader || !authHeader.startsWith('Bearer ')) { - return { userId: 'anonymous', isAuthenticated: false }; - } - - const token = authHeader.substring(7); - if (!token || token.length < 10) { - return { userId: 'anonymous', isAuthenticated: false }; - } + constructor(private readonly federationRequest: FederationRequestService) {} + async downloadFromRemoteServer( + serverName: string, + mediaId: string, + ): Promise { try { - const decoded = Buffer.from(token, 'base64').toString('utf-8'); - let userId: string; - - if (decoded.includes(':')) { - userId = `@${decoded}`; - } else { - userId = `@${decoded}:${this.configService.serverName}`; - } - - if (userId.match(/^@[^:]+:[^:]+$/)) { - return { userId, isAuthenticated: true }; + const buffer = await this.downloadWithAuth(serverName, mediaId); + if (buffer) { + this.logger.info( + `Downloaded media ${mediaId} from ${serverName} via authenticated endpoint`, + ); + return buffer; } - } catch {} + } catch (error: any) { + this.logger.debug(`Authenticated download failed: ${error.message}`); + } - return { userId: 'anonymous', isAuthenticated: false }; + return this.downloadLegacy(serverName, mediaId); } - async downloadFile( + private async downloadWithAuth( serverName: string, mediaId: string, - authHeader: string | null, - ): Promise { - const { userId, isAuthenticated } = this.extractUserFromToken(authHeader); - const ourServerName = this.configService.serverName; + ): Promise { + const endpoint = `/_matrix/federation/v1/media/download/${mediaId}`; - this.logger.info('Media download request', { + const { url, headers } = await this.federationRequest.prepareSignedRequest( serverName, - mediaId, - userId, - isAuthenticated, - }); - - if (serverName === ourServerName && !isAuthenticated) { - return { - errcode: 'M_MISSING_TOKEN', - error: 'Authentication required for local media access', - }; - } + endpoint, + 'GET', + ); - if (serverName === ourServerName) { - return { - errcode: 'M_UNRECOGNIZED', - error: 'Local file download not yet implemented', - }; + const response = await this.httpsRequest(url, { method: 'GET', headers }); + if (!response || response.statusCode < 200 || response.statusCode >= 300) { + return null; } - return this.proxyRemoteMedia(serverName, mediaId); + return this.extractMediaFromResponse(response); } - private async proxyRemoteMedia( - serverName: string, - mediaId: string, - ): Promise { - this.logger.info('Proxying to remote Matrix server', { - serverName, - mediaId, + private httpsRequest( + url: URL, + options: { method: string; headers: Record }, + ): Promise<{ + statusCode: number; + headers: Record; + body: Buffer; + } | null> { + return new Promise((resolve) => { + const req = https.request( + { + hostname: url.hostname, + port: url.port || 443, + path: url.pathname + url.search, + method: options.method, + headers: options.headers, + }, + (res) => { + const chunks: Buffer[] = []; + res.on('data', (chunk) => chunks.push(chunk)); + res.on('end', () => { + resolve({ + statusCode: res.statusCode || 500, + headers: res.headers as Record, + body: Buffer.concat(chunks), + }); + }); + }, + ); + + req.on('error', (error) => { + this.logger.error(`HTTPS request failed: ${error.message}`); + resolve(null); + }); + + req.end(); }); + } - try { - const remoteUrl = `https://${serverName}/_matrix/media/v3/download/${serverName}/${mediaId}`; + private extractMediaFromResponse(response: { + statusCode: number; + headers: Record; + body: Buffer; + }): Buffer { + const contentType = Array.isArray(response.headers['content-type']) + ? response.headers['content-type'][0] + : response.headers['content-type']; + + if (!contentType?.includes('multipart')) { + return response.body; + } - const response = await fetch(remoteUrl, { - method: 'GET', - headers: { - 'User-Agent': `RocketChat-Matrix-Bridge/${this.configService.version}`, - }, - signal: AbortSignal.timeout(30000), - }); + const boundary = contentType.match(/boundary=([^;]+)/)?.[1]; + if (!boundary) { + throw new Error('No boundary in multipart response'); + } - if (!response.ok) { - this.logger.warn('Remote media fetch failed', { - serverName, - mediaId, - status: response.status, - }); + return this.parseMultipart(response.body, boundary); + } - return { - errcode: 'M_NOT_FOUND', - error: 'Remote media not found', - }; + private parseMultipart(data: Buffer, boundary: string): Buffer { + const boundaryBuffer = Buffer.from(`--${boundary}`); + const headerEnd = Buffer.from('\r\n\r\n'); + + let start = 0; + while (start < data.length) { + const boundaryIndex = data.indexOf(boundaryBuffer, start); + if (boundaryIndex === -1) break; + + const partStart = boundaryIndex + boundaryBuffer.length; + const nextBoundary = data.indexOf(boundaryBuffer, partStart); + const partEnd = nextBoundary === -1 ? data.length : nextBoundary; + + const part = data.subarray(partStart, partEnd); + const headerEndIndex = part.indexOf(headerEnd); + + if (headerEndIndex !== -1) { + const headers = part.subarray(0, headerEndIndex).toString('utf-8'); + if ( + headers.includes('Content-Type:') && + !headers.includes('application/json') + ) { + let content = part.subarray(headerEndIndex + headerEnd.length); + while ( + content.length > 0 && + (content[content.length - 1] === 0x0a || + content[content.length - 1] === 0x0d) + ) { + content = content.subarray(0, -1); + } + return content; + } } - const contentType = - response.headers.get('content-type') || 'application/octet-stream'; - const contentDisposition = - response.headers.get('content-disposition') || - `attachment; filename="${mediaId}"`; - const arrayBuffer = await response.arrayBuffer(); - const buffer = Buffer.from(arrayBuffer); - - this.logger.info('Successfully proxied remote media', { - serverName, - mediaId, - contentType, - size: buffer.length, - }); - - return new Response(buffer, { - headers: { - 'content-type': contentType, - 'content-disposition': contentDisposition, - 'cache-control': 'public, max-age=31536000', - }, - }); - } catch (error) { - this.logger.error('Error proxying remote media:', error); - return { - errcode: 'M_UNKNOWN', - error: 'Failed to fetch remote media', - }; + start = partEnd; } + + throw new Error('No media content in multipart response'); } - async getThumbnail( + private async downloadLegacy( serverName: string, mediaId: string, - width = 96, - height = 96, - method: 'crop' | 'scale' = 'scale', - ): Promise<{ errcode: string; error: string }> { - this.logger.info('Thumbnail request', { - serverName, - mediaId, - width, - height, - method, - }); - - const mediaConfig = this.configService.getMediaConfig(); - if (!mediaConfig.enableThumbnails) { - return { - errcode: 'M_NOT_FOUND', - error: 'Thumbnails are disabled', - }; - } + ): Promise { + const endpoints = [ + `https://${serverName}/_matrix/media/v3/download/${serverName}/${mediaId}`, + `https://${serverName}/_matrix/media/r0/download/${serverName}/${mediaId}`, + ]; + + for (const endpoint of endpoints) { + try { + const url = new URL(endpoint); + const response = await this.httpsRequest(url, { + method: 'GET', + headers: { + 'User-Agent': 'Rocket.Chat Federation', + Accept: '*/*', + }, + }); - const ourServerName = this.configService.serverName; - if (serverName === ourServerName) { - return { - errcode: 'M_UNRECOGNIZED', - error: 'Thumbnail generation not yet implemented', - }; + if ( + response && + response.statusCode >= 200 && + response.statusCode < 300 + ) { + this.logger.info( + `Downloaded media ${mediaId} from ${serverName} via legacy endpoint`, + ); + return response.body; + } + } catch (error: any) { + this.logger.debug(`Legacy endpoint failed: ${error.message}`); + } } - return { - errcode: 'M_NOT_FOUND', - error: 'Media not found', - }; - } - - getMediaConfig(): { 'm.upload.size': number } { - const mediaConfig = this.configService.getMediaConfig(); - return { - 'm.upload.size': mediaConfig.maxFileSize, - }; + throw new Error(`Failed to download media ${mediaId} from ${serverName}`); } } diff --git a/packages/federation-sdk/src/services/signature-verification.service.ts b/packages/federation-sdk/src/services/signature-verification.service.ts index fdd7ffe17..406ba3840 100644 --- a/packages/federation-sdk/src/services/signature-verification.service.ts +++ b/packages/federation-sdk/src/services/signature-verification.service.ts @@ -94,7 +94,7 @@ export class SignatureVerificationService { /** * Get public key from cache or fetch it from the server */ - private async getOrFetchPublicKey( + async getOrFetchPublicKey( serverName: string, keyId: string, ): Promise { diff --git a/packages/federation-sdk/src/services/staging-area.service.spec.ts b/packages/federation-sdk/src/services/staging-area.service.spec.ts new file mode 100644 index 000000000..4365f4508 --- /dev/null +++ b/packages/federation-sdk/src/services/staging-area.service.spec.ts @@ -0,0 +1,491 @@ +/* + Test framework: Vitest or Jest (auto-compatible). + - If Vitest is present, use: import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; + - If Jest is present, these globals are typically available; vi.* calls are aliased in Vitest. +*/ + +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; + +vi.mock('@hs/core', () => { + return { + createLogger: () => ({ + debug: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + }), + isRedactedEvent: (ev: any) => Boolean(ev?.content?.redacts), + }; +}); + +vi.mock('@hs/room', () => { + return { + PersistentEventFactory: { + createFromRawEvent: vi.fn(), + }, + }; +}); + +// Use type-only imports when available (not required for runtime) +import { PersistentEventFactory } from '@hs/room'; + +// Import subject under test +import { StagingAreaService } from './staging-area.service'; + +type AnyFn = (...args: any[]) => any; + +const makeBaseEvent = (overrides: Partial = {}) => ({ + eventId: '$evt1', + roomId: '\!room:server', + origin: 'server.test', + event: { + type: 'm.room.message', + sender: '@user:server', + origin_server_ts: 123456789, + content: { body: 'hi', msgtype: 'm.text' }, + auth_events: [], + prev_events: [], + ...overrides.event, + }, + ...overrides, +}); + +describe('StagingAreaService', () => { + let eventService: any; + let missingEventsService: any; + let stagingAreaQueue: any; + let eventAuthService: any; + let eventStateService: any; + let eventEmitterService: any; + let stateService: any; + + let svc: StagingAreaService; + + beforeEach(() => { + // Mocks for collaborators + eventService = { + checkIfEventsExists: vi.fn().mockResolvedValue({ missing: [] }), + getAuthEventIds: vi.fn().mockResolvedValue([{ event: { id: 'a' } }]), + }; + + missingEventsService = { + addEvent: vi.fn(), + }; + + stagingAreaQueue = { + enqueue: vi.fn(), + }; + + eventAuthService = { + authorizeEvent: vi.fn().mockResolvedValue(true), + }; + + eventStateService = { + // Not directly used in provided code, kept for completeness + }; + + eventEmitterService = { + emit: vi.fn(), + }; + + stateService = { + getRoomVersion: vi.fn().mockResolvedValue('9'), + persistStateEvent: vi.fn().mockResolvedValue(undefined), + persistTimelineEvent: vi.fn().mockResolvedValue(undefined), + getFullRoomStateBeforeEvent2: vi.fn().mockResolvedValue({ + powerLevels: { users: { '@owner:hs': 100, '@mod:hs': 50, '@user:hs': 0 } }, + creator: '@owner:hs', + }), + }; + + vi.useFakeTimers(); + vi.spyOn(global, 'setTimeout'); // observe retry scheduling + + svc = new StagingAreaService( + eventService, + missingEventsService, + stagingAreaQueue, + eventAuthService, + eventStateService, + eventEmitterService, + stateService, + ); + }); + + afterEach(() => { + vi.useRealTimers(); + vi.clearAllMocks(); + }); + + it('addEventToQueue: tracks event and enqueues with pending_dependencies metadata', () => { + const evt = makeBaseEvent(); + // @ts-ignore internal map visibility - exercise public API only + // add event + // @ts-expect-no-error + (svc as any).addEventToQueue(evt); + + expect(stagingAreaQueue.enqueue).toHaveBeenCalledTimes(1); + expect(stagingAreaQueue.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ + eventId: evt.eventId, + metadata: { state: expect.stringMatching(/pending_dependencies/) }, + }), + ); + }); + + it('extractEventsFromIncomingPDU: returns concat of auth_events and prev_events', () => { + const evt = makeBaseEvent({ + event: { + type: 'm.room.message', + sender: '@u:hs', + origin_server_ts: 1, + content: {}, + auth_events: [['a1'], ['a2']], + prev_events: [['p1'], ['p2']], + }, + }); + + // @ts-ignore accessing private for test via bracket notation + const result = (svc as any).extractEventsFromIncomingPDU(evt); + expect(result).toEqual([['a1'], ['a2'], ['p1'], ['p2']]); + }); + + it('processEvent: newly seen event enters dependency stage', async () => { + const evt = makeBaseEvent(); + await (svc as any).processEvent(evt); + + // After dependency stage success -> should enqueue authorization state + expect(stagingAreaQueue.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ + eventId: evt.eventId, + metadata: { state: 'pending_authorization' }, + }), + ); + }); + + it('dependency stage: when missing deps present, schedules retries with backoff and invokes MissingEventService', async () => { + const evt = makeBaseEvent({ + event: { + ...makeBaseEvent().event, + auth_events: [['dep1']], + prev_events: [['dep2']], + }, + }); + + eventService.checkIfEventsExists.mockResolvedValueOnce({ missing: ['dep1', 'dep2'] }); + + // Prime processing map by "adding" event + (svc as any).addEventToQueue(evt); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_dependencies' } }); + + expect(missingEventsService.addEvent).toHaveBeenCalledTimes(2); + expect(setTimeout).toHaveBeenCalledTimes(1); + // Ensure re-enqueue on retry + vi.runOnlyPendingTimers(); + expect(stagingAreaQueue.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ + metadata: { state: 'pending_dependencies' }, + }), + ); + }); + + it('dependency stage: after 5 retries, marks event as REJECTED', async () => { + const evt = makeBaseEvent({ + event: { ...makeBaseEvent().event, auth_events: [['x']], prev_events: [] }, + }); + // Always missing + eventService.checkIfEventsExists.mockResolvedValue({ missing: ['x'] }); + + (svc as any).addEventToQueue(evt); + // Simulate 5 attempts + for (let i = 0; i < 5; i++) { + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_dependencies' } }); + vi.runOnlyPendingTimers(); + } + + // After 5th attempt, no more retries scheduled, not enqueued again for deps + expect(stagingAreaQueue.enqueue).not.toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_authorization' } }), + ); + }); + + it('authorization stage: success advances to state resolution; failure rejects', async () => { + const evt = makeBaseEvent(); + // Put event in map with expected state + (svc as any).addEventToQueue(evt); + + // Success path + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_authorization' } }); + expect(stagingAreaQueue.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_state_resolution' } }), + ); + + // Failure path + eventAuthService.authorizeEvent.mockResolvedValueOnce(false); + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_authorization' } }); + // should not enqueue next stage on failure + expect(stagingAreaQueue.enqueue).not.toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_state_resolution' } }), + ); + }); + + it('authorization stage: exceptions mark event as REJECTED', async () => { + const evt = makeBaseEvent(); + (svc as any).addEventToQueue(evt); + eventAuthService.authorizeEvent.mockRejectedValueOnce(new Error('boom')); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_authorization' } }); + + // No enqueue on error + expect(stagingAreaQueue.enqueue).not.toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_state_resolution' } }), + ); + }); + + it('state resolution: persists state events and advances; handles rejection flag', async () => { + const evt = makeBaseEvent({ + event: { ...makeBaseEvent().event, type: 'm.room.name' }, + }); + (svc as any).addEventToQueue(evt); + + (PersistentEventFactory.createFromRawEvent as AnyFn) + .mockReturnValueOnce({ + isState: () => true, + rejected: false, + rejectedReason: undefined, + }); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_state_resolution' } }); + + expect(stateService.persistStateEvent).toHaveBeenCalledTimes(1); + expect(stagingAreaQueue.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_persistence' } }), + ); + + // Rejected path + (PersistentEventFactory.createFromRawEvent as AnyFn) + .mockReturnValueOnce({ + isState: () => false, + rejected: true, + rejectedReason: 'invalid', + }); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_state_resolution' } }); + + // Should not enqueue persistence on rejection + expect(stagingAreaQueue.enqueue).not.toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_persistence' } }), + ); + }); + + it('state resolution: error when room version missing -> REJECTED', async () => { + const evt = makeBaseEvent(); + (svc as any).addEventToQueue(evt); + stateService.getRoomVersion.mockResolvedValueOnce(null); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_state_resolution' } }); + + expect(stagingAreaQueue.enqueue).not.toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_persistence' } }), + ); + }); + + it('persistence stage: advances straight to federation', async () => { + const evt = makeBaseEvent(); + (svc as any).addEventToQueue(evt); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_persistence' } }); + + expect(stagingAreaQueue.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_federation' } }), + ); + }); + + it('federation stage: success and error both advance to notification', async () => { + const evt = makeBaseEvent(); + (svc as any).addEventToQueue(evt); + + // success (no throw) + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_federation' } }); + expect(stagingAreaQueue.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_notification' } }), + ); + + // error path: simulate throw by spying and throwing inside processFederationStage call path + // Not needed because implementation already catches and advances; we just invoke again + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_federation' } }); + expect(stagingAreaQueue.enqueue).toHaveBeenCalledWith( + expect.objectContaining({ metadata: { state: 'pending_notification' } }), + ); + }); + + it('notification stage: emits for m.room.message', async () => { + const evt = makeBaseEvent(); + (svc as any).addEventToQueue(evt); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_notification' } }); + + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.message', + expect.objectContaining({ event_id: evt.eventId, room_id: evt.roomId }), + ); + }); + + it('notification stage: emits for m.reaction', async () => { + const evt = makeBaseEvent({ + event: { + type: 'm.reaction', + sender: '@u:hs', + origin_server_ts: 2, + content: { 'm.relates_to': { rel_type: 'm.annotation', event_id: '$msg', key: '👍' } }, + }, + }); + (svc as any).addEventToQueue(evt); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_notification' } }); + + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.reaction', + expect.objectContaining({ + event_id: evt.eventId, + room_id: evt.roomId, + content: { 'm.relates_to': { rel_type: 'm.annotation', event_id: '$msg', key: '👍' } }, + }), + ); + }); + + it('notification stage: emits for redaction when isRedactedEvent returns true', async () => { + const evt = makeBaseEvent({ + event: { + type: 'm.room.redaction', + sender: '@u:hs', + origin_server_ts: 3, + content: { redacts: '$target', reason: 'spam' }, + }, + }); + (svc as any).addEventToQueue(evt); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_notification' } }); + + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.redaction', + expect.objectContaining({ + event_id: evt.eventId, + redacts: '$target', + content: { reason: 'spam' }, + }), + ); + }); + + it('notification stage: emits for m.room.member (membership)', async () => { + const evt = makeBaseEvent({ + event: { + type: 'm.room.member', + sender: '@u:hs', + state_key: '@target:hs', + origin_server_ts: 4, + content: { membership: 'join', displayname: 'User' }, + }, + }); + (svc as any).addEventToQueue(evt); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_notification' } }); + + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.membership', + expect.objectContaining({ + event_id: evt.eventId, + state_key: '@target:hs', + content: expect.objectContaining({ membership: 'join' }), + }), + ); + }); + + it('notification stage: emits for m.room.name and m.room.topic', async () => { + const nameEvt = makeBaseEvent({ + event: { type: 'm.room.name', sender: '@u:hs', origin_server_ts: 5, content: { name: 'Room' } }, + }); + (svc as any).addEventToQueue(nameEvt); + await (svc as any).processEvent({ ...nameEvt, metadata: { state: 'pending_notification' } }); + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.room.name', + expect.objectContaining({ room_id: nameEvt.roomId, name: 'Room' }), + ); + + const topicEvt = makeBaseEvent({ + event: { type: 'm.room.topic', sender: '@u:hs', origin_server_ts: 6, content: { topic: 'T' } }, + }); + (svc as any).addEventToQueue(topicEvt); + await (svc as any).processEvent({ ...topicEvt, metadata: { state: 'pending_notification' } }); + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.room.topic', + expect.objectContaining({ room_id: topicEvt.roomId, topic: 'T' }), + ); + }); + + it('notification stage: handles m.room.power_levels with changedUserPowers (delta and direct changes)', async () => { + const evt = makeBaseEvent({ + event: { + type: 'm.room.power_levels', + sender: '@u:hs', + state_key: '', + origin_server_ts: 7, + content: { users: { '@user:hs': 50, '@new:hs': 0 } }, // changedUserPowers + }, + }); + (svc as any).addEventToQueue(evt); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_notification' } }); + + // Expect role events emitted for changed users; roles derived from 100->owner, 50->moderator, else user + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.room.role', + expect.objectContaining({ user_id: '@user:hs', role: 'moderator' }), + ); + }); + + it('notification stage: m.room.power_levels with no changedUserPowers resets all except owner', async () => { + stateService.getFullRoomStateBeforeEvent2.mockResolvedValueOnce({ + powerLevels: { users: { '@owner:hs': 100, '@a:hs': 50, '@b:hs': 0 } }, + creator: '@owner:hs', + }); + + const evt = makeBaseEvent({ + event: { + type: 'm.room.power_levels', + sender: '@u:hs', + origin_server_ts: 8, + content: { }, // no users -> reset path + }, + }); + (svc as any).addEventToQueue(evt); + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_notification' } }); + + // Should emit resets for @a:hs and @b:hs to "user", but not for owner + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.room.role', + expect.objectContaining({ user_id: '@a:hs', role: 'user' }), + ); + expect(eventEmitterService.emit).toHaveBeenCalledWith( + 'homeserver.matrix.room.role', + expect.objectContaining({ user_id: '@b:hs', role: 'user' }), + ); + // Ensure no emit for owner reset + const emits = (eventEmitterService.emit as AnyFn).mock.calls.filter((c: any[]) => c[0] === 'homeserver.matrix.room.role'); + expect(emits.find(([_, payload]) => payload.user_id === '@owner:hs')).toBeUndefined(); + }); + + it('notification stage: unknown event type logs warning and still completes', async () => { + const evt = makeBaseEvent({ + event: { type: 'm.unknown', sender: '@u:hs', origin_server_ts: 9, content: {} }, + }); + (svc as any).addEventToQueue(evt); + + await (svc as any).processEvent({ ...evt, metadata: { state: 'pending_notification' } }); + + // No emitter call expected for unknown type + const calls = (eventEmitterService.emit as AnyFn).mock.calls; + expect(calls.length).toBeGreaterThanOrEqual(0); + }); +}); \ No newline at end of file diff --git a/packages/federation-sdk/src/services/staging-area.service.ts b/packages/federation-sdk/src/services/staging-area.service.ts index 8c1f41260..328cc1d34 100644 --- a/packages/federation-sdk/src/services/staging-area.service.ts +++ b/packages/federation-sdk/src/services/staging-area.service.ts @@ -356,6 +356,7 @@ export class StagingAreaService { sender: stagedEvent.event.sender, origin_server_ts: stagedEvent.event.origin_server_ts, content: { + ...stagedEvent.event.content, body: stagedEvent.event.content?.body as string, msgtype: stagedEvent.event.content?.msgtype as string, 'm.relates_to': stagedEvent.event.content?.['m.relates_to'] as { diff --git a/packages/federation-sdk/src/services/state.service.ts b/packages/federation-sdk/src/services/state.service.ts index 196cd604e..6a355d998 100644 --- a/packages/federation-sdk/src/services/state.service.ts +++ b/packages/federation-sdk/src/services/state.service.ts @@ -17,7 +17,7 @@ import { StateRepository } from '../repositories/state.repository'; import { createLogger } from '../utils/logger'; import { ConfigService } from './config.service'; -type State = Map; +export type State = Map; type StrippedRoomState = { content: PduContent; @@ -259,6 +259,17 @@ export class StateService { return new RoomState(state); } + async getStateEventsByType(roomId: string, type: PduType) { + const state = await this.getFullRoomState(roomId); + const events = []; + for (const [, event] of state) { + if (event.type === type) { + events.push(event); + } + } + return events; + } + public async getStrippedRoomState( roomId: string, ): Promise { @@ -792,7 +803,7 @@ export class StateService { async getServersInRoom(roomId: string) { return this.getMembersOfRoom(roomId).then((members) => - members.map((member) => member.split(':').pop()!), + members.map((member) => member.split(':').pop() ?? ''), ); } } diff --git a/packages/homeserver/src/__tests__/homeserver.module.test.ts b/packages/homeserver/src/__tests__/homeserver.module.test.ts new file mode 100644 index 000000000..4121864f8 --- /dev/null +++ b/packages/homeserver/src/__tests__/homeserver.module.test.ts @@ -0,0 +1,262 @@ +import path from 'node:path'; + +jest.mock('node:fs', () => ({ + existsSync: jest.fn(() => false), +})); + +// Mock dotenv to prevent reading real env +jest.mock('dotenv', () => ({ + config: jest.fn(), +})); + +// Capture plugin functions to assert order of registration +const swaggerPluginFn = jest.fn((app: any) => app); +jest.mock('@elysiajs/swagger', () => ({ + swagger: jest.fn(() => swaggerPluginFn), +})); + +// Provide a lightweight Elysia mock that records .use calls +const useCalls: any[] = []; +class ElysiaMock { + used: any[] = useCalls; + use(fn: any) { + this.used.push(fn); + return this; + } +} +jest.mock('elysia', () => ({ + __esModule: true, + default: ElysiaMock, +})); + +// Mock all internal plugin modules with identifiable functions +const makePlugin = (name: string) => { + const plugin = Object.defineProperty(jest.fn((app: any) => app), 'name', { value: name }); + return plugin; +}; + +const invitePlugin = makePlugin('invitePlugin'); +const profilesPlugin = makePlugin('profilesPlugin'); +const roomPlugin = makePlugin('roomPlugin'); +const sendJoinPlugin = makePlugin('sendJoinPlugin'); +const transactionsPlugin = makePlugin('transactionsPlugin'); +const versionsPlugin = makePlugin('versionsPlugin'); +const internalDirectMessagePlugin = makePlugin('internalDirectMessagePlugin'); +const internalInvitePlugin = makePlugin('internalInvitePlugin'); +const internalMessagePlugin = makePlugin('internalMessagePlugin'); +const pingPlugin = makePlugin('pingPlugin'); +const internalRoomPlugin = makePlugin('internalRoomPlugin'); +const serverKeyPlugin = makePlugin('serverKeyPlugin'); +const wellKnownPlugin = makePlugin('wellKnownPlugin'); + +jest.mock('../../src/controllers/federation/invite.controller', () => ({ invitePlugin })); +jest.mock('../../src/controllers/federation/profiles.controller', () => ({ profilesPlugin })); +jest.mock('../../src/controllers/federation/rooms.controller', () => ({ roomPlugin })); +jest.mock('../../src/controllers/federation/send-join.controller', () => ({ sendJoinPlugin })); +jest.mock('../../src/controllers/federation/transactions.controller', () => ({ transactionsPlugin })); +jest.mock('../../src/controllers/federation/versions.controller', () => ({ versionsPlugin })); +jest.mock('../../src/controllers/internal/direct-message.controller', () => ({ internalDirectMessagePlugin })); +jest.mock('../../src/controllers/internal/invite.controller', () => ({ internalInvitePlugin })); +jest.mock('../../src/controllers/internal/message.controller', () => ({ internalMessagePlugin })); +jest.mock('../../src/controllers/internal/ping.controller', () => ({ pingPlugin })); +jest.mock('../../src/controllers/internal/room.controller', () => ({ internalRoomPlugin })); +jest.mock('../../src/controllers/key/server.controller', () => ({ serverKeyPlugin })); +jest.mock('../../src/controllers/well-known/well-known.controller', () => ({ wellKnownPlugin })); + +// Mock federation SDK: capture config options constructed and container options passed in +const createFederationContainerMock = jest.fn(async (_opts: any, _config: any) => ({ mockContainer: true, _opts, _config })); +class MockConfigService { + public options: any; + constructor(opts: any) { + this.options = opts; + } +} +jest.mock('@hs/federation-sdk', () => ({ + ConfigService: MockConfigService, + createFederationContainer: createFederationContainerMock, +})); + +// Import after mocks are set up +import { setup, appPromise } from '../homeserver.module.spec'; +import * as fs from 'node:fs'; +import * as dotenv from 'dotenv'; +import { swagger } from '@elysiajs/swagger'; +import type { Emitter } from '@rocket.chat/emitter'; + +describe('homeserver.module setup', () => { + const ORIGINAL_ENV = process.env; + + beforeEach(() => { + jest.clearAllMocks(); + (useCalls as any[]).length = 0; + process.env = { ...ORIGINAL_ENV }; // shallow clone + delete process.env.SERVER_NAME; + delete process.env.SERVER_PORT; + delete process.env.MONGODB_URI; + delete process.env.DATABASE_NAME; + delete process.env.DATABASE_POOL_SIZE; + delete process.env.MATRIX_DOMAIN; + delete process.env.MATRIX_KEY_REFRESH_INTERVAL; + delete process.env.CONFIG_FOLDER; + delete process.env.SERVER_VERSION; + delete process.env.MEDIA_MAX_FILE_SIZE; + delete process.env.MEDIA_ALLOWED_MIME_TYPES; + delete process.env.MEDIA_ENABLE_THUMBNAILS; + delete process.env.MEDIA_UPLOAD_RATE_LIMIT; + delete process.env.MEDIA_DOWNLOAD_RATE_LIMIT; + (fs.existsSync as jest.Mock).mockReturnValue(false); + }); + + afterAll(() => { + process.env = ORIGINAL_ENV; + }); + + it('loads .env when present using dotenv.config with correct path', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(true); + const cwdEnvPath = path.resolve(process.cwd(), '.env'); + + const result = await setup(); + + expect(result).toHaveProperty('app'); + expect(result).toHaveProperty('container'); + expect(dotenv.config).toHaveBeenCalledWith({ path: cwdEnvPath }); + }); + + it('does not call dotenv.config when .env is absent', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(false); + + await setup(); + + expect(dotenv.config).not.toHaveBeenCalled(); + }); + + it('constructs ConfigService with default values when env not set', async () => { + await setup(); + + // Config instance is passed as the 2nd arg to createFederationContainer + expect(createFederationContainerMock).toHaveBeenCalledTimes(1); + const passedConfig = createFederationContainerMock.mock.calls[0][1] as MockConfigService; + expect(passedConfig).toBeInstanceOf(MockConfigService); + + // Validate critical defaults + expect(passedConfig.options.serverName).toBe('rc1'); + expect(passedConfig.options.port).toBe(8080); + expect(passedConfig.options.database).toEqual({ + uri: 'mongodb://localhost:27017/matrix', + name: 'matrix', + poolSize: 10, + }); + expect(passedConfig.options.matrixDomain).toBe('rc1'); + expect(passedConfig.options.keyRefreshInterval).toBe(60); + expect(passedConfig.options.signingKeyPath).toBe('./rc1.signing.key'); + expect(passedConfig.options.version).toBe('1.0'); + expect(passedConfig.options.media.maxFileSize).toBe(100 * 1024 * 1024); + expect(passedConfig.options.media.allowedMimeTypes).toEqual([ + 'image/jpeg', + 'image/png', + 'image/gif', + 'image/webp', + 'text/plain', + 'application/pdf', + 'video/mp4', + 'audio/mpeg', + 'audio/ogg', + ]); + // Notice: code uses (env === 'true') || true, which always evaluates to true + expect(passedConfig.options.media.enableThumbnails).toBe(true); + expect(passedConfig.options.media.rateLimits).toEqual({ + uploadPerMinute: 10, + downloadPerMinute: 60, + }); + }); + + it('respects environment variables for configuration', async () => { + process.env.SERVER_NAME = 'example'; + process.env.SERVER_PORT = '9090'; + process.env.MONGODB_URI = 'mongodb://db:27017/hs'; + process.env.DATABASE_NAME = 'hs'; + process.env.DATABASE_POOL_SIZE = '42'; + process.env.MATRIX_DOMAIN = 'example.org'; + process.env.MATRIX_KEY_REFRESH_INTERVAL = '120'; + process.env.CONFIG_FOLDER = '/etc/hs/sign.key'; + process.env.SERVER_VERSION = '2.1.5'; + process.env.MEDIA_MAX_FILE_SIZE = '5'; + process.env.MEDIA_ALLOWED_MIME_TYPES = 'image/jpeg,application/json'; + process.env.MEDIA_ENABLE_THUMBNAILS = 'false'; // code still ORs with true + process.env.MEDIA_UPLOAD_RATE_LIMIT = '7'; + process.env.MEDIA_DOWNLOAD_RATE_LIMIT = '99'; + + await setup(); + + const passedConfig = createFederationContainerMock.mock.calls[0][1] as MockConfigService; + + expect(passedConfig.options.serverName).toBe('example'); + expect(passedConfig.options.port).toBe(9090); + expect(passedConfig.options.database).toEqual({ + uri: 'mongodb://db:27017/hs', + name: 'hs', + poolSize: 42, + }); + expect(passedConfig.options.matrixDomain).toBe('example.org'); + expect(passedConfig.options.keyRefreshInterval).toBe(120); + expect(passedConfig.options.signingKeyPath).toBe('/etc/hs/sign.key'); + expect(passedConfig.options.version).toBe('2.1.5'); + // 5 MiB converted to bytes + expect(passedConfig.options.media.maxFileSize).toBe(5 * 1024 * 1024); + expect(passedConfig.options.media.allowedMimeTypes).toEqual(['image/jpeg', 'application/json']); + expect(passedConfig.options.media.enableThumbnails).toBe(true); // due to "=== 'true' || true" in source + expect(passedConfig.options.media.rateLimits).toEqual({ + uploadPerMinute: 7, + downloadPerMinute: 99, + }); + }); + + it('passes emitter via containerOptions to createFederationContainer', async () => { + const emitter: Partial> = { on: jest.fn(), emit: jest.fn() }; + await setup({ emitter: emitter as any }); + + expect(createFederationContainerMock).toHaveBeenCalledTimes(1); + const containerOpts = createFederationContainerMock.mock.calls[0][0]; + expect(containerOpts).toEqual({ emitter }); + }); + + it('registers swagger and internal/federation plugins in correct order', async () => { + await setup(); + + // swagger(...) called once and returns swaggerPluginFn + expect(swagger).toHaveBeenCalledTimes(1); + expect(useCalls[0]).toBe(swaggerPluginFn); + + const expectedOrder = [ + swaggerPluginFn, + invitePlugin, + profilesPlugin, + sendJoinPlugin, + transactionsPlugin, + versionsPlugin, + internalDirectMessagePlugin, + internalInvitePlugin, + internalMessagePlugin, + pingPlugin, + internalRoomPlugin, + serverKeyPlugin, + wellKnownPlugin, + roomPlugin, + ]; + + expect(useCalls).toEqual(expectedOrder); + }); + + it('returns app and container from setup()', async () => { + const { app, container } = await setup(); + expect(app).toBeInstanceOf(ElysiaMock as any); + expect(container).toEqual(expect.objectContaining({ mockContainer: true })); + }); + + it('appPromise resolves to the app instance', async () => { + const app = await appPromise; + // appPromise uses default setup() call under the hood + expect(app).toBeDefined(); + expect((app as any).used).toBeDefined(); + }); +}); \ No newline at end of file diff --git a/packages/homeserver/src/controllers/federation/transactions.controller.ts b/packages/homeserver/src/controllers/federation/transactions.controller.ts index 75367281d..bd2e08418 100644 --- a/packages/homeserver/src/controllers/federation/transactions.controller.ts +++ b/packages/homeserver/src/controllers/federation/transactions.controller.ts @@ -1,15 +1,26 @@ -import { EventService } from '@hs/federation-sdk'; -import { Elysia } from 'elysia'; +import { + ConfigService, + EventAuthorizationService, + EventService, +} from '@hs/federation-sdk'; +import { Context, Elysia } from 'elysia'; import { container } from 'tsyringe'; import { ErrorResponseDto, + GetEventErrorResponseDto, + GetEventParamsDto, + GetEventResponseDto, SendTransactionBodyDto, SendTransactionResponseDto, } from '../../dtos'; +import { canAccessEvent } from '../../middlewares/acl.middleware'; export const transactionsPlugin = (app: Elysia) => { const eventService = container.resolve(EventService); - return app.put( + const configService = container.resolve(ConfigService); + const eventAuthService = container.resolve(EventAuthorizationService); + + app.put( '/_matrix/federation/v1/send/:txnId', async ({ body }) => { await eventService.processIncomingTransaction(body as any); @@ -32,4 +43,42 @@ export const transactionsPlugin = (app: Elysia) => { }, }, ); + + app.get( + '/_matrix/federation/v1/event/:eventId', + async ({ params, set }) => { + const eventData = await eventService.getEventById(params.eventId); + if (!eventData) { + set.status = 404; + return { + errcode: 'M_NOT_FOUND', + error: 'Event not found', + }; + } + + return { + origin_server_ts: eventData.event.origin_server_ts, + origin: configService.serverName, + pdus: [{ ...eventData.event, origin: configService.serverName }], + }; + }, + { + beforeHandle: canAccessEvent(eventAuthService), + params: GetEventParamsDto, + response: { + 200: GetEventResponseDto, + 401: GetEventErrorResponseDto, + 403: GetEventErrorResponseDto, + 404: GetEventErrorResponseDto, + 500: GetEventErrorResponseDto, + }, + detail: { + tags: ['Federation'], + summary: 'Get event', + description: 'Get an event', + }, + }, + ); + + return app; }; diff --git a/packages/homeserver/src/controllers/media.controller.ts b/packages/homeserver/src/controllers/media.controller.ts deleted file mode 100644 index 70a8392b3..000000000 --- a/packages/homeserver/src/controllers/media.controller.ts +++ /dev/null @@ -1,181 +0,0 @@ -import { createLogger } from '@hs/core'; -import { ConfigService, MediaService } from '@hs/federation-sdk'; -import { Elysia, t } from 'elysia'; -import { container } from 'tsyringe'; - -const ErrorResponseSchema = t.Object({ - errcode: t.String(), - error: t.String(), -}); - -export const mediaPlugin = (app: Elysia) => { - const mediaService = container.resolve(MediaService); - const logger = createLogger('MediaController'); - - return app.group('/_matrix/media/v3', (app) => - app - .get( - '/download/:serverName/:mediaId', - async ({ - params, - request, - set, - }: { - params: { serverName: string; mediaId: string }; - request: Request; - set: { status?: number }; - }) => { - const { serverName, mediaId } = params; - try { - const authHeader = request.headers.get('authorization'); - const result = await mediaService.downloadFile( - serverName, - mediaId, - authHeader, - ); - - if ('errcode' in result) { - if (result.errcode === 'M_MISSING_TOKEN') { - set.status = 401; - } else if (result.errcode === 'M_NOT_FOUND') { - set.status = 404; - } else if (result.errcode === 'M_UNRECOGNIZED') { - set.status = 501; - } else { - set.status = 502; - } - return result; - } - - return result; - } catch (error) { - logger.error('Media download error:', error); - set.status = 500; - return { - errcode: 'M_UNKNOWN', - error: 'Internal server error', - }; - } - }, - { - params: t.Object({ - serverName: t.String(), - mediaId: t.String(), - }), - query: t.Object({ - allow_remote: t.Optional(t.Boolean()), - timeout_ms: t.Optional(t.Number()), - }), - response: { - 200: t.Any(), - 401: ErrorResponseSchema, - 404: ErrorResponseSchema, - 500: ErrorResponseSchema, - 501: ErrorResponseSchema, - 502: ErrorResponseSchema, - }, - detail: { - tags: ['Media'], - summary: 'Download media', - description: 'Download a file from the Matrix media repository', - }, - }, - ) - - .get( - '/thumbnail/:serverName/:mediaId', - async ({ - params, - query, - set, - }: { - params: { serverName: string; mediaId: string }; - query: { width?: number; height?: number; method?: string }; - set: { status?: number }; - }) => { - try { - const { serverName, mediaId } = params; - const { width = 96, height = 96, method = 'scale' } = query; - - const result = await mediaService.getThumbnail( - serverName, - mediaId, - width, - height, - method as 'crop' | 'scale', - ); - - if (result.errcode === 'M_NOT_FOUND') { - set.status = 404; - } else if (result.errcode === 'M_UNRECOGNIZED') { - set.status = 501; - } - - return result; - } catch (error) { - logger.error('Media thumbnail error:', error); - set.status = 500; - return { - errcode: 'M_UNKNOWN', - error: 'Internal server error', - }; - } - }, - { - params: t.Object({ - serverName: t.String(), - mediaId: t.String(), - }), - query: t.Object({ - width: t.Optional(t.Number({ minimum: 1, maximum: 800 })), - height: t.Optional(t.Number({ minimum: 1, maximum: 600 })), - method: t.Optional( - t.Union([t.Literal('crop'), t.Literal('scale')]), - ), - allow_remote: t.Optional(t.Boolean()), - timeout_ms: t.Optional(t.Number()), - }), - response: { - 200: t.Any(), - 404: ErrorResponseSchema, - 500: ErrorResponseSchema, - 501: ErrorResponseSchema, - }, - detail: { - tags: ['Media'], - summary: 'Get media thumbnail', - description: 'Get a thumbnail for a media file', - }, - }, - ) - - .get( - '/config', - async ({ set }) => { - try { - return mediaService.getMediaConfig(); - } catch (error) { - logger.error('Media config error:', error); - set.status = 500; - return { - errcode: 'M_UNKNOWN', - error: 'Internal server error', - }; - } - }, - { - response: { - 200: t.Object({ - 'm.upload.size': t.Number(), - }), - 500: ErrorResponseSchema, - }, - detail: { - tags: ['Media'], - summary: 'Get media configuration', - description: 'Get the media configuration for the homeserver', - }, - }, - ), - ); -}; diff --git a/packages/homeserver/src/dtos/federation/transactions.dto.ts b/packages/homeserver/src/dtos/federation/transactions.dto.ts index 58e2e7a71..54849fe3a 100644 --- a/packages/homeserver/src/dtos/federation/transactions.dto.ts +++ b/packages/homeserver/src/dtos/federation/transactions.dto.ts @@ -29,6 +29,25 @@ export const SendTransactionResponseDto = t.Object({ }), }); +export const GetEventParamsDto = t.Object({ + eventId: t.String({ description: 'Event ID' }), +}); + +export const GetEventResponseDto = t.Object({ + origin_server_ts: t.Number({ description: 'Origin server timestamp' }), + origin: t.String({ description: 'Origin server' }), + pdus: t.Array(EventBaseDto, { + description: 'An array containing a single PDU', + }), +}); + +export const GetEventErrorResponseDto = t.Object({ + errcode: t.String({ description: 'Error code' }), + error: t.String({ description: 'Error message' }), +}); + export type SendTransactionParams = Static; export type SendTransactionBody = Static; export type SendTransactionResponse = Static; +export type GetEventParams = Static; +export type GetEventResponse = Static; diff --git a/packages/homeserver/src/homeserver.module.ts b/packages/homeserver/src/homeserver.module.ts index aef891fa7..117e63185 100644 --- a/packages/homeserver/src/homeserver.module.ts +++ b/packages/homeserver/src/homeserver.module.ts @@ -25,7 +25,6 @@ import { internalMessagePlugin } from './controllers/internal/message.controller import { pingPlugin } from './controllers/internal/ping.controller'; import { internalRoomPlugin } from './controllers/internal/room.controller'; import { serverKeyPlugin } from './controllers/key/server.controller'; -import { mediaPlugin } from './controllers/media.controller'; import { wellKnownPlugin } from './controllers/well-known/well-known.controller'; export type { HomeserverEventSignatures }; @@ -118,8 +117,7 @@ export async function setup(options?: HomeserverSetupOptions) { .use(internalRoomPlugin) .use(serverKeyPlugin) .use(wellKnownPlugin) - .use(roomPlugin) - .use(mediaPlugin); + .use(roomPlugin); return { app, container }; } diff --git a/packages/homeserver/src/middlewares/acl.middleware.ts b/packages/homeserver/src/middlewares/acl.middleware.ts new file mode 100644 index 000000000..db85ac0f1 --- /dev/null +++ b/packages/homeserver/src/middlewares/acl.middleware.ts @@ -0,0 +1,51 @@ +import type { EventAuthorizationService } from '@hs/federation-sdk'; + +const errCodes = { + M_UNAUTHORIZED: { + errcode: 'M_UNAUTHORIZED', + error: 'Invalid or missing signature', + status: 401, + }, + M_FORBIDDEN: { + errcode: 'M_FORBIDDEN', + error: 'Access denied', + status: 403, + }, + M_UNKNOWN: { + errcode: 'M_UNKNOWN', + error: 'Internal server error while processing request', + status: 500, + }, +}; + +interface ACLContext { + params: { eventId: string }; + headers: Record; + request: Request; + set: Record; +} + +export const canAccessEvent = (federationAuth: EventAuthorizationService) => { + return async (context: ACLContext) => { + const { params, headers, request, set } = context; + const { eventId } = params; + const authorizationHeader = headers.authorization || ''; + const method = request.method; + const uri = new URL(request.url).pathname; + + const result = await federationAuth.canAccessEventFromAuthorizationHeader( + eventId, + authorizationHeader, + method, + uri, + ); + + if (!result.authorized) { + set.status = errCodes[result.errorCode].status; + return { + errcode: errCodes[result.errorCode].errcode, + error: errCodes[result.errorCode].error, + }; + } + }; +};