Merge pull request #30 from polyfact/feat/add-models
✨ Add model option
kevin-btc authored Aug 16, 2023
2 parents 0615b26 + 64d680e commit e841063
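
In short, this PR threads an optional `model` string from the `Chat` constructor and `GenerationOptions` through to the request payload. A minimal sketch of the new surface, mirroring the updated examples/repl.ts below (the model name is simply the one used in that example; any identifier the chosen provider accepts should work):

    import Polyfact from "../lib/index";

    const { Chat } = Polyfact.exec();

    // `model` rides alongside the existing `provider` option and is
    // forwarded as-is with each generation request.
    const chat = new Chat({ autoMemory: true, provider: "openai", model: "gpt-3.5-turbo" });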
Showing 3 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/repl.ts
@@ -4,7 +4,7 @@ import Polyfact from "../lib/index";
 
 async function main() {
     const { Chat } = Polyfact.exec();
-    const chat = new Chat({ autoMemory: true, provider: "openai" });
+    const chat = new Chat({ autoMemory: true, provider: "openai", model: "gpt-3.5-turbo" });
 
     const rl = readline.createInterface({ input, output });
 
8 changes: 6 additions & 2 deletions lib/chats/index.ts
@@ -47,6 +47,7 @@ export async function createChat(
 
 type ChatOptions = {
     provider?: "openai" | "cohere" | "llama";
+    model?: string;
     systemPrompt?: string;
     autoMemory?: boolean;
 };
@@ -56,6 +57,8 @@ export class Chat {
 
     provider: "openai" | "cohere" | "llama";
 
+    model?: string;
+
     clientOptions: Promise<ClientOptions>;
 
     autoMemory?: Promise<Memory>;
@@ -64,6 +67,7 @@ export class Chat {
         this.clientOptions = defaultOptions(clientOptions);
         this.chatId = createChat(options.systemPrompt, this.clientOptions);
         this.provider = options.provider || "openai";
+        this.model = options.model;
         if (options.autoMemory) {
             this.autoMemory = this.clientOptions.then((co) => new Memory(co));
         }
@@ -81,7 +85,7 @@
 
         const result = await generateWithTokenUsage(
             message,
-            { provider: this.provider, ...options, chatId },
+            { provider: this.provider, model: this.model, ...options, chatId },
             this.clientOptions,
         );
 
@@ -154,7 +158,7 @@ export class Chat {
 
         const result = generateStream(
             message,
-            { provider: this.provider, ...options, chatId },
+            { provider: this.provider, model: this.model, ...options, chatId },
             await this.clientOptions,
         );
 
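Note the spread order in both call sites: `model: this.model` comes before `...options`, so a caller can still override the chat's default model per message. A small sketch of that precedence rule (the option names echo the diff; the merge itself is plain object-spread semantics):

    // Later properties win in an object spread.
    const instanceDefaults = { provider: "openai", model: "gpt-3.5-turbo" };
    const perCallOptions = { model: "gpt-4" }; // hypothetical per-call override

    const merged = { ...instanceDefaults, ...perCallOptions };
    console.log(merged.model); // "gpt-4": the per-call value shadows the instance default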
5 changes: 5 additions & 0 deletions lib/generate.ts
@@ -22,6 +22,7 @@ const GenerationAPIResponse = t.intersection([Required, PartialResultType]);
 
 export type GenerationOptions = {
     provider?: "openai" | "cohere" | "llama";
+    model?: string;
     chatId?: string;
     memory?: Memory;
     memoryId?: string;
@@ -63,11 +64,13 @@ export async function generateWithTokenUsage(
         // eslint-disable-next-line camelcase
         chat_id?: string;
         provider: GenerationOptions["provider"];
+        model?: string;
         stop: GenerationOptions["stop"];
         infos: boolean;
     } = {
         task,
         provider: options?.provider || "openai",
+        model: options.model,
         memory_id: (await options?.memory?.memoryId) || options?.memoryId || "",
         chat_id: options?.chatId || "",
         stop: options?.stop || [],
@@ -140,11 +143,13 @@ function stream(
         // eslint-disable-next-line camelcase
         chat_id?: string;
         provider: GenerationOptions["provider"];
+        model?: string;
         stop: GenerationOptions["stop"];
         infos?: boolean;
     } = {
         task,
         provider: options?.provider || "openai",
+        model: options?.model,
         memory_id: (await options?.memory?.memoryId) || options?.memoryId || "",
         chat_id: options?.chatId || "",
         stop: options?.stop || [],
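Both request builders now carry `model` in the body they send, leaving it `undefined` when the caller sets none. A hedged sketch of the resulting payload for a chat-scoped call, with field names taken from the diff and placeholder values invented for illustration:

    // Approximate request body assembled by generateWithTokenUsage above.
    const body = {
        task: "Say hello",           // the user's message or prompt
        provider: "openai",
        model: "gpt-3.5-turbo",      // new in this PR; undefined when unset
        memory_id: "",
        chat_id: "chat_123",         // hypothetical id returned by createChat
        stop: [],
        infos: false,
    };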
