Skip to content

Commit

Permalink
Update local AI fetcher to forward method and pathname to upstream (#…
Browse files Browse the repository at this point in the history
…7315)

* Update local AI fetcher to forward method and pathname to upstream

* Add unit test to cover changes

* Lint files

* Rename x-forward-for to x-forward header name
  • Loading branch information
G4brym authored Nov 22, 2024
1 parent c650cc9 commit 31729ee
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 4 deletions.
5 changes: 5 additions & 0 deletions .changeset/stupid-moons-juggle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"wrangler": minor
---

Update local AI fetcher to forward method and url to upstream
69 changes: 69 additions & 0 deletions packages/wrangler/src/__tests__/ai.local.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { Request } from "miniflare";
import { HttpResponse } from "msw";
import { AIFetcher } from "../ai/fetcher";
import * as internal from "../cfetch/internal";
import * as user from "../user";
import type { RequestInit } from "undici";

describe("ai", () => {
describe("fetcher", () => {
afterEach(() => {
vi.restoreAllMocks();
});

describe("local", () => {
it("should send x-forwarded header", async () => {
vi.spyOn(user, "getAccountId").mockImplementation(async () => "123");
vi.spyOn(internal, "performApiFetch").mockImplementation(
async (resource: string, init: RequestInit = {}) => {
const headers = new Headers(init.headers);
return HttpResponse.json({
xForwarded: headers.get("X-Forwarded"),
method: init.method,
});
}
);

const url = "http://internal.ai/ai/test/path?version=123";
const resp = await AIFetcher(
new Request(url, {
method: "PATCH",
headers: {
"x-example": "test",
},
})
);

expect(await resp.json()).toEqual({
xForwarded: url,
method: "PATCH",
});
});

it("account id should be set", async () => {
vi.spyOn(user, "getAccountId").mockImplementation(async () => "123");
vi.spyOn(internal, "performApiFetch").mockImplementation(
async (resource: string) => {
return HttpResponse.json({
resource: resource,
});
}
);

const url = "http://internal.ai/ai/test/path?version=123";
const resp = await AIFetcher(
new Request(url, {
method: "PATCH",
headers: {
"x-example": "test",
},
})
);

expect(await resp.json()).toEqual({
resource: "/accounts/123/ai/run/proxy",
});
});
});
});
});
10 changes: 6 additions & 4 deletions packages/wrangler/src/ai/fetcher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ export default function (env) {
export async function AIFetcher(request: Request): Promise<Response> {
const accountId = await getAccountId();

request.headers.delete("Host");
request.headers.delete("Content-Length");
const reqHeaders = new Headers(request.headers);
reqHeaders.delete("Host");
reqHeaders.delete("Content-Length");
reqHeaders.set("X-Forwarded", request.url);

const res = await performApiFetch(`/accounts/${accountId}/ai/run/proxy`, {
method: "POST",
headers: Object.fromEntries(request.headers.entries()),
method: request.method,
headers: Object.fromEntries(reqHeaders.entries()),
body: request.body,
duplex: "half",
});
Expand Down

0 comments on commit 31729ee

Please sign in to comment.