From 31729ee63df0fbaf34787ab9e5a53f7180d0ec8c Mon Sep 17 00:00:00 2001 From: Gabriel Massadas <5445926+G4brym@users.noreply.github.com> Date: Fri, 22 Nov 2024 13:26:04 +0000 Subject: [PATCH] Update local AI fetcher to forward method and pathname to upstream (#7315) * Update local AI fetcher to forward method and pathname to upstream * Add unit test to cover changes * Lint files * Rename x-forward-for to x-forward header name --- .changeset/stupid-moons-juggle.md | 5 ++ .../wrangler/src/__tests__/ai.local.test.ts | 69 +++++++++++++++++++ packages/wrangler/src/ai/fetcher.ts | 10 +-- 3 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 .changeset/stupid-moons-juggle.md create mode 100644 packages/wrangler/src/__tests__/ai.local.test.ts diff --git a/.changeset/stupid-moons-juggle.md b/.changeset/stupid-moons-juggle.md new file mode 100644 index 000000000000..529e9c6886d0 --- /dev/null +++ b/.changeset/stupid-moons-juggle.md @@ -0,0 +1,5 @@ +--- +"wrangler": minor +--- + +Update local AI fetcher to forward method and url to upstream diff --git a/packages/wrangler/src/__tests__/ai.local.test.ts b/packages/wrangler/src/__tests__/ai.local.test.ts new file mode 100644 index 000000000000..7451fbe507b7 --- /dev/null +++ b/packages/wrangler/src/__tests__/ai.local.test.ts @@ -0,0 +1,69 @@ +import { Request } from "miniflare"; +import { HttpResponse } from "msw"; +import { AIFetcher } from "../ai/fetcher"; +import * as internal from "../cfetch/internal"; +import * as user from "../user"; +import type { RequestInit } from "undici"; + +describe("ai", () => { + describe("fetcher", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("local", () => { + it("should send x-forwarded header", async () => { + vi.spyOn(user, "getAccountId").mockImplementation(async () => "123"); + vi.spyOn(internal, "performApiFetch").mockImplementation( + async (resource: string, init: RequestInit = {}) => { + const headers = new Headers(init.headers); + return HttpResponse.json({ + xForwarded: headers.get("X-Forwarded"), + method: init.method, + }); + } + ); + + const url = "http://internal.ai/ai/test/path?version=123"; + const resp = await AIFetcher( + new Request(url, { + method: "PATCH", + headers: { + "x-example": "test", + }, + }) + ); + + expect(await resp.json()).toEqual({ + xForwarded: url, + method: "PATCH", + }); + }); + + it("account id should be set", async () => { + vi.spyOn(user, "getAccountId").mockImplementation(async () => "123"); + vi.spyOn(internal, "performApiFetch").mockImplementation( + async (resource: string) => { + return HttpResponse.json({ + resource: resource, + }); + } + ); + + const url = "http://internal.ai/ai/test/path?version=123"; + const resp = await AIFetcher( + new Request(url, { + method: "PATCH", + headers: { + "x-example": "test", + }, + }) + ); + + expect(await resp.json()).toEqual({ + resource: "/accounts/123/ai/run/proxy", + }); + }); + }); + }); +}); diff --git a/packages/wrangler/src/ai/fetcher.ts b/packages/wrangler/src/ai/fetcher.ts index 6660bb048ade..9d8d8e32168d 100644 --- a/packages/wrangler/src/ai/fetcher.ts +++ b/packages/wrangler/src/ai/fetcher.ts @@ -16,12 +16,14 @@ export default function (env) { export async function AIFetcher(request: Request): Promise { const accountId = await getAccountId(); - request.headers.delete("Host"); - request.headers.delete("Content-Length"); + const reqHeaders = new Headers(request.headers); + reqHeaders.delete("Host"); + reqHeaders.delete("Content-Length"); + reqHeaders.set("X-Forwarded", request.url); const res = await performApiFetch(`/accounts/${accountId}/ai/run/proxy`, { - method: "POST", - headers: Object.fromEntries(request.headers.entries()), + method: request.method, + headers: Object.fromEntries(reqHeaders.entries()), body: request.body, duplex: "half", });