Skip to content

Commit

Permalink
Add Ai Gateway methods to AI Binding
Browse files Browse the repository at this point in the history
  • Loading branch information
G4brym authored and fhanau committed Nov 27, 2024
1 parent 564c8a2 commit 7750cb6
Show file tree
Hide file tree
Showing 10 changed files with 505 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/cloudflare/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ export {
AiOptions,
InferenceUpstreamError,
Ai,
AiGateway,
AiGatewayInternalError,
AiGatewayLogNotFound,
} from 'cloudflare-internal:ai-api';
135 changes: 134 additions & 1 deletion src/cloudflare/internal/ai-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ export class Ai {
};

const res = await this.fetcher.fetch(
'http://workers-binding.ai/run?version=3',
'https://workers-binding.ai/run?version=3',
fetchOptions
);

Expand Down Expand Up @@ -150,6 +150,139 @@ export class Ai {
return new InferenceUpstreamError(content);
}
}

public gateway(gatewayId: string): AiGateway {
return new AiGateway(this.fetcher, gatewayId);
}
}

//
// Ai Gateway
//

export type AiGatewayPatchLog = {
score?: number | null;
feedback?: -1 | 1 | '-1' | '1' | null;
metadata?: Record<string, number | string | boolean | null | bigint> | null;
};

export type AiGatewayLog = {
id: string;
provider: string;
model: string;
model_type?: string;
path: string;
duration: number;
request_type?: string;
request_content_type?: string;
status_code: number;
response_content_type?: string;
success: boolean;
cached: boolean;
tokens_in?: number;
tokens_out?: number;
metadata?: Record<string, number | string | boolean | null | bigint>;
step?: number;
cost?: number;
custom_cost?: boolean;
request_size: number;
request_head?: string;
request_head_complete: boolean;
response_size: number;
response_head?: string;
response_head_complete: boolean;
created_at: Date;
};

export class AiGatewayInternalError extends Error {
public constructor(message: string) {
super(message);
this.name = 'AiGatewayInternalError';
}
}

export class AiGatewayLogNotFound extends Error {
public constructor(message: string) {
super(message);
this.name = 'AiGatewayLogNotFound';
}
}

export class AiGateway {
private readonly fetcher: Fetcher;
private readonly gatewayId: string;

public constructor(fetcher: Fetcher, gatewayId: string) {
this.fetcher = fetcher;
this.gatewayId = gatewayId;
}

public async getLog(logId: string): Promise<AiGatewayLog> {
const res = await this.fetcher.fetch(
`https://workers-binding.ai/ai-gateway/gateways/${this.gatewayId}/logs/${logId}`,
{
method: 'GET',
}
);

switch (res.status) {
case 200: {
const data = (await res.json()) as { result: AiGatewayLog };

return {
...data.result,
created_at: new Date(data.result.created_at),
};
}
case 404: {
const data = (await res.json()) as { errors: { message: string }[] };

throw new AiGatewayLogNotFound(
data.errors[0]?.message || 'Log Not Found'
);
}
default: {
const data = (await res.json()) as { errors: { message: string }[] };

throw new AiGatewayInternalError(
data.errors[0]?.message || 'Internal Error'
);
}
}
}

public async patchLog(logId: string, data: AiGatewayPatchLog): Promise<void> {
const res = await this.fetcher.fetch(
`https://workers-binding.ai/ai-gateway/gateways/${this.gatewayId}/logs/${logId}`,
{
method: 'PATCH',
body: JSON.stringify(data),
headers: {
'content-type': 'application/json',
},
}
);

switch (res.status) {
case 200: {
return;
}
case 404: {
const data = (await res.json()) as { errors: { message: string }[] };

throw new AiGatewayLogNotFound(
data.errors[0]?.message || 'Log Not Found'
);
}
default: {
const data = (await res.json()) as { errors: { message: string }[] };

throw new AiGatewayInternalError(
data.errors[0]?.message || 'Internal Error'
);
}
}
}
}

export default function makeBinding(env: { fetcher: Fetcher }): Ai {
Expand Down
18 changes: 18 additions & 0 deletions src/cloudflare/internal/test/aig/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("//:build/wd_test.bzl", "wd_test")
load("//src/workerd/server/tests/python:py_wd_test.bzl", "py_wd_test")

wd_test(
src = "aig-api-test.wd-test",
args = ["--experimental"],
data = glob(["*.js"]),
)

py_wd_test(
size = "large",
src = "python-aig-api-test.wd-test",
args = ["--experimental"],
data = glob([
"*.js",
"*.py",
]),
)
111 changes: 111 additions & 0 deletions src/cloudflare/internal/test/aig/aig-api-test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) 2024 Cloudflare, Inc.
// Licensed under the Apache 2.0 license found in the LICENSE file or at:
// https://opensource.org/licenses/Apache-2.0

import * as assert from 'node:assert';

export const tests = {
async test(_, env) {
{
// Test gateway get log
const resp = await env.ai.gateway('my-gateway').getLog('my-log-123');
assert.deepEqual(resp, {
cached: false,
cost: 0,
created_at: new Date('2019-08-24T14:15:22Z'),
custom_cost: true,
duration: 0,
id: 'string',
metadata: 'string',
model: 'string',
model_type: 'string',
path: 'string',
provider: 'string',
request_content_type: 'string',
request_head: 'string',
request_head_complete: true,
request_size: 0,
request_type: 'string',
response_content_type: 'string',
response_head: 'string',
response_head_complete: true,
response_size: 0,
status_code: 0,
step: 0,
success: true,
tokens_in: 0,
tokens_out: 0,
});
}

{
// Test get log error responses
try {
await env.ai.gateway('my-gateway').getLog('404');
} catch (e) {
assert.deepEqual(
{
name: e.name,
message: e.message,
},
{
name: 'AiGatewayLogNotFound',
message: 'Not Found',
}
);
}
}

{
try {
await env.ai.gateway('my-gateway').getLog('500');
} catch (e) {
assert.deepEqual(
{
name: e.name,
message: e.message,
},
{
name: 'AiGatewayInternalError',
message: 'Internal Error',
}
);
}
}

{
// Test patch log error responses
try {
await env.ai.gateway('my-gateway').patchLog('404', { feedback: -1 });
} catch (e) {
assert.deepEqual(
{
name: e.name,
message: e.message,
},
{
name: 'AiGatewayLogNotFound',
message: 'Not Found',
}
);
}
}

{
try {
await env.ai.gateway('my-gateway').patchLog('500', { feedback: -1 });
} catch (e) {
assert.deepEqual(
{
name: e.name,
message: e.message,
},
{
name: 'AiGatewayInternalError',
message: 'Internal Error',
}
);
}
}
},
};
9 changes: 9 additions & 0 deletions src/cloudflare/internal/test/aig/aig-api-test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2024 Cloudflare, Inc.
# Licensed under the Apache 2.0 license found in the LICENSE file or at:
# https://opensource.org/licenses/Apache-2.0


async def test(context, env):
resp = await env.ai.gateway("my-gateway").getLog("my-log-123")
assert resp.cached is False
assert resp.model == "string"
36 changes: 36 additions & 0 deletions src/cloudflare/internal/test/aig/aig-api-test.wd-test
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Workerd = import "/workerd/workerd.capnp";

const unitTests :Workerd.Config = (
services = [
( name = "aig-api-test",
worker = (
modules = [
(name = "worker", esModule = embed "aig-api-test.js")
],
compatibilityDate = "2023-01-15",
compatibilityFlags = ["nodejs_compat"],
bindings = [
(
name = "ai",
wrapped = (
moduleName = "cloudflare-internal:ai-api",
innerBindings = [(
name = "fetcher",
service = "aig-mock"
)],
)
)
],
)
),
( name = "aig-mock",
worker = (
compatibilityDate = "2023-01-15",
compatibilityFlags = ["experimental", "nodejs_compat"],
modules = [
(name = "worker", esModule = embed "aig-mock.js")
],
)
)
]
);
Loading

0 comments on commit 7750cb6

Please sign in to comment.