diff --git a/package.json b/package.json index c7265e3..7bfa98e 100644 --- a/package.json +++ b/package.json @@ -26,6 +26,7 @@ "typescript": "^5.2.2" }, "dependencies": { - "zod": "^3.22.2" + "zod": "^3.22.2", + "zod-to-json-schema": "^3.21.4" } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 1f7318d..4106849 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -8,6 +8,9 @@ dependencies: zod: specifier: ^3.22.2 version: 3.22.2 + zod-to-json-schema: + specifier: ^3.21.4 + version: 3.21.4(zod@3.22.2) devDependencies: '@types/jest': @@ -2610,6 +2613,14 @@ packages: engines: {node: '>=10'} dev: true + /zod-to-json-schema@3.21.4(zod@3.22.2): + resolution: {integrity: sha512-fjUZh4nQ1s6HMccgIeE0VP4QG/YRGPmyjO9sAh890aQKPEk3nqbfUXhMFaC+Dr5KvYBm8BCyvfpZf2jY9aGSsw==} + peerDependencies: + zod: ^3.21.4 + dependencies: + zod: 3.22.2 + dev: false + /zod@3.22.2: resolution: {integrity: sha512-wvWkphh5WQsJbVk1tbx1l1Ly4yg+XecD+Mq280uBGt9wa5BKSWf4Mhp6GmrkPixhMxmabYY7RbzlwVP32pbGCg==} dev: false diff --git a/src/Chat.ts b/src/Chat.ts index 732a80c..1bd0ad9 100644 --- a/src/Chat.ts +++ b/src/Chat.ts @@ -1,8 +1,11 @@ import OpenAI from "openai"; import { PromptBuilder } from "./PromptBuilder"; import { ExtractArgs, ExtractChatArgs, ReplaceChatArgs } from "./types"; +import { ToolBuilder } from "./ToolBuilder"; +import { Tool, ToolType } from './Tool' export class Chat< + const ToolNames extends string, TMessages extends | [] | [ @@ -14,8 +17,15 @@ export class Chat< constructor( public messages: TMessages, public args: TSuppliedInputArgs, + public tools = {} as Record>, + public mustUseTool: boolean = false ) {} + toJSONSchema() { + const tools = Object.values(this.tools) as Tool[]; + return tools.reduce((acc, t) => ({ ...acc, ...t.toJSONSchema()}), {}) + } + toArray() { return this.messages.map((m) => ({ role: m.role, diff --git a/src/ChatBuilder.ts b/src/ChatBuilder.ts index 29e8cc2..1b0ff2f 100644 --- a/src/ChatBuilder.ts +++ b/src/ChatBuilder.ts @@ -8,6 +8,7 @@ import { TypeToZodShape, ReplaceChatArgs, } from "./types"; +import { ToolBuilder } from "./ToolBuilder"; export class ChatBuilder< TMessages extends @@ -90,7 +91,7 @@ export class ChatBuilder< build( args: TSuppliedInputArgs, ) { - return new Chat( + return new Chat<"", TMessages, TSuppliedInputArgs>( this.messages as any, args, ).toArray(); diff --git a/src/Tool.ts b/src/Tool.ts new file mode 100644 index 0000000..833199c --- /dev/null +++ b/src/Tool.ts @@ -0,0 +1,50 @@ +import { z, ZodType } from 'zod' +import { zodToJsonSchema } from 'zod-to-json-schema' + +export const ToolType = z.enum(["query", "mutation"]) +export type ToolType = z.infer + + +export class Tool< + TName extends string, + TType extends "query" | "mutation", + const TExpectedInput extends { [key: string]: string }, + TExpectedOutput +> { + constructor( + public name: TName, + public description: string, + public type: TType, + public use: (input: TExpectedInput) => TExpectedOutput, + public input?: ZodType, + public output?: ZodType, + ) {} + + toJSONSchema() { + if (!this.input) { + throw new Error('Tool has no input schema. Please use ToolBuilder.addZodInputValidation to set.') + } + const schema = zodToJsonSchema(this.input) as any; + delete schema.$schema; + if (!schema.additionalProperties) delete schema.additionalProperties; + return { + name: this.name, + description: this.description, + parameters: schema, + }; + } + + validateInput(args: unknown): args is TExpectedInput { + if (!this.input) { + throw new Error('Tool has no input schema. Please use ToolBuilder.addZodInputValidation to set.') + } + return this.input.safeParse(args).success + } + + validateOutput(args: unknown): args is TExpectedOutput { + if (!this.output) { + throw new Error('Tool has no output schema. Please use ToolBuilder.addZodInputValidation to set.') + } + return this.output.safeParse(args).success + } +} \ No newline at end of file diff --git a/src/ToolBuilder.ts b/src/ToolBuilder.ts new file mode 100644 index 0000000..c52a92d --- /dev/null +++ b/src/ToolBuilder.ts @@ -0,0 +1,94 @@ +import { z, AnyZodObject, infer as _infer, ZodType } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { TypeToZodShape } from "./types"; +import { Tool } from "./Tool"; + +export class ToolBuilder< + TName extends string, + TType extends "query" | "mutation", + const TExpectedInput extends Record, + TExpectedOutput +> { + constructor( + public name: TName, + public description: string = "", + public type: TType = "query" as TType, + public implementation?: (input: TExpectedInput) => TExpectedOutput + ) {} + + addZodInputValidation( + shape: TypeToZodShape + ): ToolBuilder { + const zodValidator = z.object(shape as any); + return new (class extends ToolBuilder< + TName, + TType, + TShape, + TExpectedOutput + > { + validate(args: unknown): args is TShape { + return zodValidator.safeParse(args).success; + } + + query(queryFunction: (input: TExpectedInput) => TShape) { + // zodValidator.parse(args); + return new Tool(this.name, "query", queryFunction); + } + + mutation(mutationFunction: (input: TExpectedInput) => TShape) { + return new Tool(this.name, "mutation", mutationFunction); + } + })(this.name, this.description, this.type, this.implementation); + } + + addZodOutputValidation(shape: ZodType) { + const zodValidator = z.object(shape as any); + return new (class extends ToolBuilder< + TName, + TType, + TExpectedInput, + TShape + > { + validateOutput(output: unknown): output is TShape { + return zodValidator.safeParse(output).success; + } + + query(queryFunction: (input: TExpectedInput) => TShape) { + return new Tool(this.name, this.description, "query", queryFunction); + } + + mutation(mutationFunction: (input: TExpectedInput) => TShape) { + return new Tool(this.name, this.description, "mutation", mutationFunction); + } + })(this.name, this.description, this.type, this.implementation as any); + } + + query(queryFunction: (input: any) => any) { + return new Tool(this.name, this.description, "query", queryFunction); + } + + mutation(mutationFunction: (input: any) => any) { + return new Tool(this.name, this.description, "mutation", mutationFunction); + } + + toJSONSchema() { + // // const fns: any[] = []; + // // const { params, ...rest } = this.implementation[key]; + const schema = zodToJsonSchema( + z.object({ + name: z.string(), + }) + ); + delete schema.$schema; + // if (!schema.additionalProperties) delete schema.additionalProperties; + // // fns.push(); + return { + name: this.name, + parameters: schema, + }; + } + + build(input: TShape) { + return new Tool(this.name, this.description, this.type, this.implementation!); + } +} diff --git a/src/__tests__/Chat.test.ts b/src/__tests__/Chat.test.ts index 1778805..2219b20 100644 --- a/src/__tests__/Chat.test.ts +++ b/src/__tests__/Chat.test.ts @@ -2,10 +2,13 @@ import { strict as assert } from "node:assert"; import { Chat } from "../Chat"; import { system, user, assistant } from "../ChatHelpers"; import { Equal, Expect } from "./types.test"; +import { ToolBuilder } from "../ToolBuilder"; +import { Tool } from '../Tool' +import { z } from "zod"; describe("Chat", () => { it("should allow empty array", () => { - const chat = new Chat([], {}).toArray(); + const chat = new Chat([], {}, {}).toArray(); type test = Expect>; assert.deepEqual(chat, []); }); @@ -14,7 +17,7 @@ describe("Chat", () => { const chat = new Chat( [user("Tell me a {{jokeType}} joke")], // @ts-expect-error - {}, + {} ).toArray(); type test = Expect< Equal @@ -31,6 +34,9 @@ describe("Chat", () => { assert.deepEqual(chat, [usrMsg]); }); + const usrMsg = user("Tell me a funny joke"); + const astMsg = assistant("foo joke?"); + const sysMsg = system("joke? bar"); it("should allow chat of all diffent types", () => { const chat = new Chat( [ @@ -42,11 +48,8 @@ describe("Chat", () => { jokeType1: "funny", var2: "foo", var3: "bar", - }, + } ).toArray(); - const usrMsg = user("Tell me a funny joke"); - const astMsg = assistant("foo joke?"); - const sysMsg = system("joke? bar"); type test = Expect< Equal >; @@ -54,16 +57,100 @@ describe("Chat", () => { }); it("should allow chat of all diffent types with no args", () => { - const chat = new Chat( - [user(`Tell me a joke`), assistant(`joke?`), system(`joke?`)], - {}, - ).toArray(); - const usrMsg = user("Tell me a joke"); - const astMsg = assistant("joke?"); - const sysMsg = system("joke?"); + const chat = new Chat([usrMsg, astMsg, sysMsg], {}).toArray(); type test = Expect< Equal >; assert.deepEqual(chat, [usrMsg, astMsg, sysMsg]); }); + + it("should allow me to pass in tools", () => { + const google = new ToolBuilder("google") + .addZodInputValidation({ query: z.string() }) + .addZodOutputValidation(z.object({ results: z.array(z.string()) })) + .query(({ query }) => { + return { + results: ["foo", "bar"], + }; + }); + const wikipedia = new ToolBuilder("wikipedia") + .addZodInputValidation({ page: z.string() }) + .addZodOutputValidation(z.object({ results: z.array(z.string()) })) + .query(({ page }) => { + return { + results: ["foo", "bar"], + }; + }); + + const sendEmail = new ToolBuilder("sendEmail") + .addZodInputValidation({ + to: z.string(), + subject: z.string(), + body: z.string(), + }) + .addZodOutputValidation(z.object({ success: z.boolean() })) + .mutation(({ to, subject, body }) => { + return { + success: true, + }; + }); + const tools = { + google, + wikipedia, + sendEmail, + }; + const chat = new Chat([usrMsg, astMsg, sysMsg], {}, tools); + + type tests = [ + Expect< + Equal< + typeof chat, + Chat< + keyof typeof tools, + [typeof usrMsg, typeof astMsg, typeof sysMsg], + {} + > + > + >, + Expect< + Equal< + typeof tools, + { + google: Tool< + "google", + "query", + { + query: string; + }, + { + results: string[]; + } + >; + wikipedia: Tool< + "wikipedia", + "query", + { + page: string; + }, + { + results: string[]; + } + >; + sendEmail: Tool< + "sendEmail", + "mutation", + { + to: string; + subject: string; + body: string; + }, + { + success: boolean; + } + >; + } + > + > + ]; + }); }); diff --git a/src/__tests__/ToolBuilder.test.ts b/src/__tests__/ToolBuilder.test.ts new file mode 100644 index 0000000..70df216 --- /dev/null +++ b/src/__tests__/ToolBuilder.test.ts @@ -0,0 +1,107 @@ +import { ToolBuilder } from "../ToolBuilder"; +import { Equal, Expect } from "./types.test"; +import { z } from "zod"; + +describe("ToolBuilder", () => { + describe("no args", () => { + const noop = () => ""; + it("ToolBuilder should allow me to create a query tool with any args", () => { + const noopQuery = new ToolBuilder("noop").query(noop); + type tests = [ + Expect< + Equal< + typeof noopQuery, + ToolBuilder<"noop", "query", Record, unknown> + > + > + ]; + }); + + it("ToolBuilder should allow me to create a mutation tool with any args", () => { + const noopMutation = new ToolBuilder("noop").mutation(noop); + type tests = [ + Expect< + Equal< + typeof noopMutation, + ToolBuilder<"noop", "mutation", Record, unknown> + > + > + ]; + }); + }); + +// describe("specific args using typescript", () => { +// it("Output as object mapping", () => { +// const toolName = "fooToBar"; +// type Input = { foo: string }; +// type Output = { bar: string }; +// const fooToBar = new ToolBuilder(toolName) +// .addInputValidation() +// .addOutputValidation() +// .query(((args) => ({ bar: "test" })) satisfies (obj: Input) => Output); + +// type tests = [ +// Expect< +// Equal< +// typeof fooToBar, +// ToolBuilder<"fooToBar", "query", Input, Output> +// > +// > +// ]; +// }); + +// it("Input should throw type error for non object types", () => { +// // @ts-expect-error +// new ToolBuilder("toolName").addInputValidation(); +// // @ts-expect-error +// new ToolBuilder("toolName").addInputValidation(); +// // @ts-expect-error +// new ToolBuilder("toolName").addInputValidation(); +// // @ts-expect-error +// new ToolBuilder("toolName").addInputValidation(); +// // @ts-expect-error +// new ToolBuilder("toolName").addInputValidation(); +// // @ts-expect-error +// new ToolBuilder("toolName").addInputValidation(); +// }); + +// it("Output should accept any type", () => { +// new ToolBuilder("toolName").addOutputValidation(); +// new ToolBuilder("toolName").addOutputValidation(); +// new ToolBuilder("toolName").addOutputValidation(); +// new ToolBuilder("toolName").addOutputValidation(); +// new ToolBuilder("toolName").addOutputValidation(); +// new ToolBuilder("toolName").addOutputValidation(); +// }); + +// it("should throw type error if query function does not satisfy input/output types", () => { +// new ToolBuilder("toolName") +// .addInputValidation<{ foo: string }>() +// .addOutputValidation<{ bar: string }>() +// // @ts-expect-error +// .query((args) => { +// type test = Expect>; +// return { asdf: "test" }; +// }); +// }); +// }); + + describe("addZodInputValidation", () => { + it("Output as object mapping", () => { + const toolName = "fooToBar"; + const fooToBar = new ToolBuilder(toolName) + .addZodInputValidation({ foo: z.string() }) + .addZodOutputValidation(z.object({ bar: z.string() })) + .query((args) => ({ bar: "test" })); + + type tests = [ + Expect< + Equal< + typeof fooToBar, + Tool<"fooToBar", "query", { foo: string }, { bar: string }> + > + > + ]; + }); + }); +});