diff --git a/examples/repl.ts b/examples/repl.ts
index 129678e..ef106eb 100644
--- a/examples/repl.ts
+++ b/examples/repl.ts
@@ -4,7 +4,7 @@ import Polyfact from "../lib/index";
 
 async function main() {
     const { Chat } = Polyfact.exec();
-    const chat = new Chat({ autoMemory: true, provider: "openai" });
+    const chat = new Chat({ autoMemory: true, provider: "openai", model: "gpt-3.5-turbo" });
 
     const rl = readline.createInterface({ input, output });
 
diff --git a/lib/chats/index.ts b/lib/chats/index.ts
index ab53cea..7881f3c 100644
--- a/lib/chats/index.ts
+++ b/lib/chats/index.ts
@@ -47,6 +47,7 @@ export async function createChat(
 
 type ChatOptions = {
     provider?: "openai" | "cohere" | "llama";
+    model?: string;
     systemPrompt?: string;
     autoMemory?: boolean;
 };
@@ -56,6 +57,8 @@ export class Chat {
     provider: "openai" | "cohere" | "llama";
 
+    model?: string;
+
     clientOptions: Promise<ClientOptions>;
 
     autoMemory?: Promise<Memory>;
 
@@ -64,6 +67,7 @@ export class Chat {
         this.clientOptions = defaultOptions(clientOptions);
         this.chatId = createChat(options.systemPrompt, this.clientOptions);
         this.provider = options.provider || "openai";
+        this.model = options.model;
         if (options.autoMemory) {
             this.autoMemory = this.clientOptions.then((co) => new Memory(co));
         }
@@ -81,7 +85,7 @@ export class Chat {
 
         const result = await generateWithTokenUsage(
             message,
-            { provider: this.provider, ...options, chatId },
+            { provider: this.provider, model: this.model, ...options, chatId },
             this.clientOptions,
         );
 
@@ -154,7 +158,7 @@ export class Chat {
 
         const result = generateStream(
             message,
-            { provider: this.provider, ...options, chatId },
+            { provider: this.provider, model: this.model, ...options, chatId },
             await this.clientOptions,
         );
 
diff --git a/lib/generate.ts b/lib/generate.ts
index a6d6456..f6261cd 100644
--- a/lib/generate.ts
+++ b/lib/generate.ts
@@ -22,6 +22,7 @@ const GenerationAPIResponse = t.intersection([Required, PartialResultType]);
 
 export type GenerationOptions = {
     provider?: "openai" | "cohere" | "llama";
+    model?: string;
     chatId?: string;
     memory?: Memory;
     memoryId?: string;
@@ -63,11 +64,13 @@ export async function generateWithTokenUsage(
         // eslint-disable-next-line camelcase
         chat_id?: string;
         provider: GenerationOptions["provider"];
+        model?: string;
         stop: GenerationOptions["stop"];
         infos: boolean;
     } = {
         task,
         provider: options?.provider || "openai",
+        model: options?.model,
         memory_id: (await options?.memory?.memoryId) || options?.memoryId || "",
         chat_id: options?.chatId || "",
         stop: options?.stop || [],
@@ -140,11 +143,13 @@ function stream(
         // eslint-disable-next-line camelcase
         chat_id?: string;
         provider: GenerationOptions["provider"];
+        model?: string;
         stop: GenerationOptions["stop"];
         infos?: boolean;
     } = {
         task,
         provider: options?.provider || "openai",
+        model: options?.model,
         memory_id: (await options?.memory?.memoryId) || options?.memoryId || "",
         chat_id: options?.chatId || "",
         stop: options?.stop || [],