diff --git a/assistant/openapi.yaml b/assistant/openapi.yaml index 7412c80eb73..44dbf3aad74 100644 --- a/assistant/openapi.yaml +++ b/assistant/openapi.yaml @@ -5893,6 +5893,10 @@ paths: required: - accepted additionalProperties: false + "400": + description: x-vellum-client-id header is missing for a targeted host bash request. + "403": + description: Submitting client does not match the targeted client for this request. requestBody: required: true content: diff --git a/assistant/src/__tests__/host-bash-routes.test.ts b/assistant/src/__tests__/host-bash-routes.test.ts new file mode 100644 index 00000000000..4927876443e --- /dev/null +++ b/assistant/src/__tests__/host-bash-routes.test.ts @@ -0,0 +1,291 @@ +/** + * Unit tests for the /v1/host-bash-result route handler. + * + * Covers the client-identity validation introduced by the targeted-host-proxy + * plan: when a pending interaction has a `targetClientId`, the submitting + * client must supply a matching `x-vellum-client-id` header or be rejected + * with 400 (missing) or 403 (mismatch). + */ +import { afterAll, beforeEach, describe, expect, mock, test } from "bun:test"; + +// ── Module mocks ───────────────────────────────────────────────────── + +mock.module("../config/env.js", () => ({ + isHttpAuthDisabled: () => true, + hasUngatedHttpAuthDisabled: () => false, +})); + +import type { PendingInteraction } from "../runtime/pending-interactions.js"; + +// Stored pending interactions keyed by requestId. +const pendingStore = new Map(); +const resolvedIds: string[] = []; + +mock.module("../runtime/pending-interactions.js", () => ({ + get: (requestId: string) => pendingStore.get(requestId), + resolve: (requestId: string) => { + const entry = pendingStore.get(requestId); + if (entry) { + pendingStore.delete(requestId); + resolvedIds.push(requestId); + } + return entry; + }, +})); + +interface ResolveCall { + requestId: string; + result: { stdout: string; stderr: string; exitCode: number | null; timedOut: boolean }; +} + +const resolveSpy: ResolveCall[] = []; + +mock.module("../daemon/host-bash-proxy.js", () => ({ + HostBashProxy: { + get instance() { + return { + resolve( + requestId: string, + result: { stdout: string; stderr: string; exitCode: number | null; timedOut: boolean }, + ) { + resolveSpy.push({ requestId, result }); + }, + }; + }, + }, +})); + +// ── Real imports (after mocks) ─────────────────────────────────────── + +import { + BadRequestError, + ConflictError, + ForbiddenError, + NotFoundError, +} from "../runtime/routes/errors.js"; +import { ROUTES } from "../runtime/routes/host-bash-routes.js"; + +afterAll(() => { + mock.restore(); +}); + +const handleHostBashResult = ROUTES.find( + (r) => r.endpoint === "host-bash-result", +)!.handler; + +// ── Helpers ────────────────────────────────────────────────────────── + +function registerPending( + requestId: string, + overrides: Partial = {}, +): void { + pendingStore.set(requestId, { + conversationId: "conv-1", + kind: "host_bash", + ...overrides, + }); +} + +function bashBody(requestId: string): Record { + return { + requestId, + stdout: "hello\n", + stderr: "", + exitCode: 0, + timedOut: false, + }; +} + +// ── Tests ──────────────────────────────────────────────────────────── + +describe("handleHostBashResult", () => { + beforeEach(() => { + pendingStore.clear(); + resolvedIds.length = 0; + resolveSpy.length = 0; + }); + + // ── Happy paths ──────────────────────────────────────────────────── + + describe("untargeted request (no targetClientId)", () => { + test("accepts when header is present", async () => { + const requestId = "req-untargeted-with-header"; + registerPending(requestId); + + const result = await handleHostBashResult({ + body: bashBody(requestId), + headers: { "x-vellum-client-id": "client-abc" }, + }); + + expect(result).toEqual({ accepted: true }); + expect(resolveSpy).toHaveLength(1); + expect(resolvedIds).toContain(requestId); + }); + + test("accepts when header is absent", async () => { + const requestId = "req-untargeted-no-header"; + registerPending(requestId); + + const result = await handleHostBashResult({ + body: bashBody(requestId), + }); + + expect(result).toEqual({ accepted: true }); + expect(resolveSpy).toHaveLength(1); + expect(resolvedIds).toContain(requestId); + }); + }); + + describe("targeted request (targetClientId set)", () => { + test("accepts when x-vellum-client-id matches targetClientId", async () => { + const requestId = "req-targeted-match"; + registerPending(requestId, { targetClientId: "client-abc" }); + + const result = await handleHostBashResult({ + body: bashBody(requestId), + headers: { "x-vellum-client-id": "client-abc" }, + }); + + expect(result).toEqual({ accepted: true }); + expect(resolveSpy).toHaveLength(1); + expect(resolveSpy[0].requestId).toBe(requestId); + expect(resolvedIds).toContain(requestId); + }); + + test("trims whitespace from x-vellum-client-id before comparing", async () => { + const requestId = "req-targeted-trim"; + registerPending(requestId, { targetClientId: "client-abc" }); + + const result = await handleHostBashResult({ + body: bashBody(requestId), + headers: { "x-vellum-client-id": " client-abc " }, + }); + + expect(result).toEqual({ accepted: true }); + }); + }); + + // ── Error: missing header on targeted request ────────────────────── + + describe("targeted request — missing x-vellum-client-id header", () => { + test("throws BadRequestError (400) when header is absent", () => { + const requestId = "req-targeted-no-header"; + registerPending(requestId, { targetClientId: "client-abc" }); + + expect(() => + handleHostBashResult({ body: bashBody(requestId) }), + ).toThrow(BadRequestError); + }); + + test("throws BadRequestError (400) when header is empty string", () => { + const requestId = "req-targeted-empty-header"; + registerPending(requestId, { targetClientId: "client-abc" }); + + expect(() => + handleHostBashResult({ + body: bashBody(requestId), + headers: { "x-vellum-client-id": " " }, + }), + ).toThrow(BadRequestError); + }); + + test("interaction is NOT resolved on 400 (still pending)", () => { + const requestId = "req-targeted-no-header-stays"; + registerPending(requestId, { targetClientId: "client-abc" }); + + try { + handleHostBashResult({ body: bashBody(requestId) }); + } catch { + // expected + } + + expect(resolvedIds).not.toContain(requestId); + expect(pendingStore.has(requestId)).toBe(true); + }); + }); + + // ── Error: wrong client ──────────────────────────────────────────── + + describe("targeted request — mismatched x-vellum-client-id", () => { + test("throws ForbiddenError (403) when client ID does not match", () => { + const requestId = "req-targeted-mismatch"; + registerPending(requestId, { targetClientId: "client-abc" }); + + expect(() => + handleHostBashResult({ + body: bashBody(requestId), + headers: { "x-vellum-client-id": "client-xyz" }, + }), + ).toThrow(ForbiddenError); + }); + + test("ForbiddenError message names both the submitting and expected client", () => { + const requestId = "req-targeted-mismatch-msg"; + registerPending(requestId, { targetClientId: "client-abc" }); + + let caught: unknown; + try { + handleHostBashResult({ + body: bashBody(requestId), + headers: { "x-vellum-client-id": "client-xyz" }, + }); + } catch (e) { + caught = e; + } + + expect(caught).toBeInstanceOf(ForbiddenError); + const msg = (caught as ForbiddenError).message; + expect(msg).toContain("client-xyz"); + expect(msg).toContain("client-abc"); + }); + + test("interaction is NOT resolved on 403 (still pending)", () => { + const requestId = "req-targeted-mismatch-stays"; + registerPending(requestId, { targetClientId: "client-abc" }); + + try { + handleHostBashResult({ + body: bashBody(requestId), + headers: { "x-vellum-client-id": "client-xyz" }, + }); + } catch { + // expected + } + + expect(resolvedIds).not.toContain(requestId); + expect(pendingStore.has(requestId)).toBe(true); + }); + }); + + // ── Other existing validations (regression) ──────────────────────── + + test("throws BadRequestError when body is missing", () => { + expect(() => handleHostBashResult({})).toThrow(BadRequestError); + }); + + test("throws BadRequestError when requestId is missing", () => { + expect(() => + handleHostBashResult({ body: { stdout: "x" } }), + ).toThrow(BadRequestError); + }); + + test("throws NotFoundError for unknown requestId", () => { + expect(() => + handleHostBashResult({ + body: bashBody("unknown-req-id"), + }), + ).toThrow(NotFoundError); + }); + + test("throws ConflictError when pending interaction is not host_bash kind", () => { + const requestId = "req-wrong-kind"; + pendingStore.set(requestId, { + conversationId: "conv-1", + kind: "confirmation", + }); + + expect(() => + handleHostBashResult({ body: bashBody(requestId) }), + ).toThrow(ConflictError); + }); +}); diff --git a/assistant/src/runtime/routes/host-bash-routes.ts b/assistant/src/runtime/routes/host-bash-routes.ts index 19a4fd48a41..f1eab46844b 100644 --- a/assistant/src/runtime/routes/host-bash-routes.ts +++ b/assistant/src/runtime/routes/host-bash-routes.ts @@ -11,6 +11,7 @@ import * as pendingInteractions from "../pending-interactions.js"; import { BadRequestError, ConflictError, + ForbiddenError, NotFoundError, } from "./errors.js"; import type { RouteDefinition, RouteHandlerArgs } from "./types.js"; @@ -19,7 +20,7 @@ import type { RouteDefinition, RouteHandlerArgs } from "./types.js"; // POST /v1/host-bash-result // --------------------------------------------------------------------------- -function handleHostBashResult({ body }: RouteHandlerArgs) { +function handleHostBashResult({ body, headers }: RouteHandlerArgs) { if (!body || typeof body !== "object") { throw new BadRequestError("Request body is required"); } @@ -36,6 +37,8 @@ function handleHostBashResult({ body }: RouteHandlerArgs) { throw new BadRequestError("requestId is required"); } + const submittingClientId = headers?.["x-vellum-client-id"]?.trim() || undefined; + const peeked = pendingInteractions.get(requestId); if (!peeked) { throw new NotFoundError( @@ -49,6 +52,20 @@ function handleHostBashResult({ body }: RouteHandlerArgs) { ); } + const { targetClientId } = peeked; + if (targetClientId) { + if (!submittingClientId) { + throw new BadRequestError( + "x-vellum-client-id header is required for targeted host bash requests", + ); + } + if (submittingClientId !== targetClientId) { + throw new ForbiddenError( + `Client "${submittingClientId}" is not the target for this request (expected "${targetClientId}"). The targeted client must submit the result.`, + ); + } + } + pendingInteractions.resolve(requestId); HostBashProxy.instance.resolve(requestId, { @@ -84,6 +101,16 @@ export const ROUTES: RouteDefinition[] = [ responseBody: z.object({ accepted: z.boolean(), }), + additionalResponses: { + "400": { + description: + "x-vellum-client-id header is missing for a targeted host bash request.", + }, + "403": { + description: + "Submitting client does not match the targeted client for this request.", + }, + }, handler: handleHostBashResult, }, ]; diff --git a/clients/shared/Network/GatewayHTTPClient.swift b/clients/shared/Network/GatewayHTTPClient.swift index a0ac9503128..7cd4dfa9337 100644 --- a/clients/shared/Network/GatewayHTTPClient.swift +++ b/clients/shared/Network/GatewayHTTPClient.swift @@ -106,15 +106,22 @@ public enum GatewayHTTPClient { /// - body: Optional HTTP body data. /// - params: Optional query parameters. Keys and values are percent-encoded /// using a restricted character set that escapes `&`, `=`, `+`, and `#`. + /// - contentType: Optional Content-Type override. Defaults to `application/json`. + /// - extraHeaders: Optional additional headers to include in the request. /// - timeout: Request timeout in seconds. Defaults to 30. /// - Returns: A `Response` with the raw data and HTTP status code. /// - Throws: `ClientError` if the request cannot be constructed, or network errors from `URLSession`. - public static func post(path: String, body: Data? = nil, params: [String: String]? = nil, contentType: String? = nil, timeout: TimeInterval = 30, unprefixed: Bool = false) async throws -> Response { + public static func post(path: String, body: Data? = nil, params: [String: String]? = nil, contentType: String? = nil, extraHeaders: [String: String]? = nil, timeout: TimeInterval = 30, unprefixed: Bool = false) async throws -> Response { return try await executeWithRetry(path: path, params: params, method: "POST", timeout: timeout, unprefixed: unprefixed) { request in request.httpBody = body if let contentType { request.setValue(contentType, forHTTPHeaderField: "Content-Type") } + if let extraHeaders { + for (k, v) in extraHeaders { + request.setValue(v, forHTTPHeaderField: k) + } + } } } diff --git a/clients/shared/Network/HostProxyClient.swift b/clients/shared/Network/HostProxyClient.swift index 1593324cd2f..9348e94510a 100644 --- a/clients/shared/Network/HostProxyClient.swift +++ b/clients/shared/Network/HostProxyClient.swift @@ -24,6 +24,7 @@ public struct HostProxyClient: HostProxyClientProtocol { let response = try await GatewayHTTPClient.post( path: "host-bash-result", body: body, + extraHeaders: ["X-Vellum-Client-Id": DeviceIdStore.getOrCreate()], timeout: 30 ) guard response.isSuccess else {