Skip to content

Commit

Permalink
Merge pull request #28 from polyfact/feat/auth
Browse files Browse the repository at this point in the history
✨ Add front-end auth & projects
  • Loading branch information
lowczarc authored Aug 11, 2023
2 parents 69d7f4c + ebe133b commit 8825bb1
Show file tree
Hide file tree
Showing 14 changed files with 646 additions and 125 deletions.
5 changes: 3 additions & 2 deletions examples/repl.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import * as readline from "node:readline/promises";
import { stdin as input, stdout as output } from "node:process";
import { Chat } from "../lib/index";
import Polyfact from "../lib/index";

async function main() {
const chat = new Chat({ autoMemory: true });
const { Chat } = Polyfact.exec();
const chat = new Chat({ autoMemory: true, provider: "openai" });

const rl = readline.createInterface({ input, output });

Expand Down
82 changes: 45 additions & 37 deletions lib/chats/index.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import axios, { AxiosError } from "axios";
import * as t from "polyfact-io-ts";
import { Readable, PassThrough } from "stream";
import { Readable, PassThrough } from "readable-stream";
import {
generateStream,
generateStreamWithInfos,
generateWithTokenUsage,
GenerationOptions,
GenerationResult,
} from "../generate";
import { ClientOptions, defaultOptions } from "../clientOpts";
import { InputClientOptions, ClientOptions, defaultOptions } from "../clientOpts";
import { Memory } from "../memory";
import { ApiError, ErrorData } from "../helpers/error";

Expand All @@ -21,10 +21,10 @@ const Message = t.type({
});
export async function createChat(
systemPrompt?: string,
options: Partial<ClientOptions> = {},
options: InputClientOptions = {},
): Promise<string> {
try {
const { token, endpoint } = defaultOptions(options);
const { token, endpoint } = await defaultOptions(options);

const response = await axios.post(
`${endpoint}/chats`,
Expand All @@ -45,28 +45,27 @@ export async function createChat(
}
}

type ChatOptions = {
provider?: "openai" | "cohere" | "llama";
systemPrompt?: string;
autoMemory?: boolean;
};

export class Chat {
chatId: Promise<string>;

provider: "openai" | "cohere";
provider: "openai" | "cohere" | "llama";

clientOptions: ClientOptions;
clientOptions: Promise<ClientOptions>;

autoMemory?: Memory;
autoMemory?: Promise<Memory>;

constructor(
options: {
provider?: "openai" | "cohere";
systemPrompt?: string;
autoMemory?: boolean;
} = {},
clientOptions: Partial<ClientOptions> = {},
) {
constructor(options: ChatOptions = {}, clientOptions: InputClientOptions = {}) {
this.clientOptions = defaultOptions(clientOptions);
this.chatId = createChat(options.systemPrompt, this.clientOptions);
this.provider = options.provider || "openai";
if (options.autoMemory) {
this.autoMemory = new Memory(this.clientOptions);
this.autoMemory = this.clientOptions.then((co) => new Memory(co));
}
}

Expand All @@ -77,18 +76,18 @@ export class Chat {
const chatId = await this.chatId;

if (this.autoMemory && !options.memory && !options.memoryId) {
options.memory = this.autoMemory;
options.memory = await this.autoMemory;
}

const result = await generateWithTokenUsage(
message,
{ ...options, chatId },
{ provider: this.provider, ...options, chatId },
this.clientOptions,
);

if (this.autoMemory) {
this.autoMemory.add(`Human: ${message}`);
this.autoMemory.add(`AI: ${result.result}`);
(await this.autoMemory).add(`Human: ${message}`);
(await this.autoMemory).add(`AI: ${result.result}`);
}

return result;
Expand All @@ -111,13 +110,13 @@ export class Chat {
const chatId = await this.chatId;

if (this.autoMemory && !options.memory && !options.memoryId) {
options.memory = this.autoMemory;
options.memory = await this.autoMemory;
}

const result = generateStreamWithInfos(
message,
{ ...options, chatId },
this.clientOptions,
await this.clientOptions,
);

result.on("infos", (data) => {
Expand All @@ -130,11 +129,13 @@ export class Chat {
});
result.on("end", () => {
resultStream.push(null);
if (this.autoMemory) {
const totalResult = Buffer.concat(bufs).toString("utf8");
this.autoMemory.add(`Human: ${message}`);
this.autoMemory.add(`AI: ${totalResult}`);
}
(async () => {
if (this.autoMemory) {
const totalResult = Buffer.concat(bufs).toString("utf8");
(await this.autoMemory).add(`Human: ${message}`);
(await this.autoMemory).add(`AI: ${totalResult}`);
}
})();
});
})();

Expand All @@ -148,10 +149,14 @@ export class Chat {
const chatId = await this.chatId;

if (this.autoMemory && !options.memory && !options.memoryId) {
options.memory = this.autoMemory;
options.memory = await this.autoMemory;
}

const result = generateStream(message, { ...options, chatId }, this.clientOptions);
const result = generateStream(
message,
{ provider: this.provider, ...options, chatId },
await this.clientOptions,
);

result.pipe(resultStream);

Expand All @@ -166,21 +171,21 @@ export class Chat {
});

if (this.autoMemory) {
this.autoMemory.add(`Human: ${message}`);
this.autoMemory.add(`AI: ${totalResult}`);
(await this.autoMemory).add(`Human: ${message}`);
(await this.autoMemory).add(`AI: ${totalResult}`);
}
})();

return resultStream;
return resultStream as unknown as Readable;
}

async getMessages(): Promise<t.TypeOf<typeof Message>[]> {
try {
const response = await axios.get(
`${this.clientOptions.endpoint}/chat/${await this.chatId}/history`,
`${(await this.clientOptions).endpoint}/chat/${await this.chatId}/history`,
{
headers: {
"X-Access-Token": this.clientOptions.token,
"X-Access-Token": (await this.clientOptions).token,
},
},
);
Expand All @@ -197,10 +202,13 @@ export class Chat {
}
}

export default function client(clientOptions: Partial<ClientOptions> = {}) {
export default function client(clientOptions: InputClientOptions = {}) {
return {
createChat: (systemPrompt?: string) => createChat(systemPrompt, clientOptions),
Chat: (options?: { provider?: "openai" | "cohere"; systemPrompt?: string }) =>
new Chat(options, clientOptions),
Chat: class C extends Chat {
constructor(options: ChatOptions = {}) {
super(options, clientOptions);
}
},
};
}
14 changes: 10 additions & 4 deletions lib/clientOpts.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import { createClient, SupabaseClient } from "@supabase/supabase-js";
import { POLYFACT_ENDPOINT, POLYFACT_TOKEN } from "./utils";

export type ClientOptions = {
endpoint: string;
token: string;
};

export function defaultOptions(opts: Partial<ClientOptions>): ClientOptions {
if (!(opts.token || process?.env?.POLYFACT_TOKEN)) {
export type InputClientOptions = Partial<ClientOptions> | Promise<Partial<ClientOptions>>;

export async function defaultOptions(popts: InputClientOptions): Promise<ClientOptions> {
const opts = await popts;
if (!opts.token && !POLYFACT_TOKEN) {
throw new Error(
"Please put your polyfact token in the POLYFACT_TOKEN environment variable. You can get one at https://app.polyfact.com",
);
}

return {
endpoint: process?.env.POLYFACT_ENDPOINT || "https://api2.polyfact.com",
token: process?.env?.POLYFACT_TOKEN || "",
endpoint: POLYFACT_ENDPOINT || "https://api2.polyfact.com",
token: POLYFACT_TOKEN || "",
...opts,
};
}
28 changes: 16 additions & 12 deletions lib/generate.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import axios, { AxiosError } from "axios";
import * as t from "polyfact-io-ts";
import { Readable } from "stream";
import { Readable } from "readable-stream";
import WebSocket from "isomorphic-ws";
import { ClientOptions, defaultOptions } from "./clientOpts";
import { InputClientOptions, defaultOptions } from "./clientOpts";
import { Memory } from "./memory";
import { ApiError, ErrorData } from "./helpers/error";

Expand All @@ -21,7 +21,7 @@ const Required = t.type({
const GenerationAPIResponse = t.intersection([Required, PartialResultType]);

export type GenerationOptions = {
provider?: "openai" | "cohere";
provider?: "openai" | "cohere" | "llama";
chatId?: string;
memory?: Memory;
memoryId?: string;
Expand Down Expand Up @@ -53,9 +53,9 @@ export type GenerationResult = {
export async function generateWithTokenUsage(
task: string,
options: GenerationOptionsWithInfos = {},
clientOptions: Partial<ClientOptions> = {},
clientOptions: InputClientOptions = {},
): Promise<GenerationResult> {
const { token, endpoint } = defaultOptions(clientOptions);
const { token, endpoint } = await defaultOptions(clientOptions);
const requestBody: {
task: string;
// eslint-disable-next-line camelcase
Expand Down Expand Up @@ -105,7 +105,7 @@ export async function generateWithTokenUsage(
export async function generate(
task: string,
options: GenerationOptions = {},
clientOptions: Partial<ClientOptions> = {},
clientOptions: InputClientOptions = {},
): Promise<string> {
const res = await generateWithTokenUsage(task, options, clientOptions);

Expand All @@ -115,7 +115,7 @@ export async function generate(
export async function generateWithInfo(
task: string,
options: GenerationOptionsWithInfos = {},
clientOptions: Partial<ClientOptions> = {},
clientOptions: InputClientOptions = {},
): Promise<GenerationResult> {
options.infos = true;
const res = await generateWithTokenUsage(task, options, clientOptions);
Expand All @@ -126,7 +126,7 @@ export async function generateWithInfo(
function stream(
task: string,
options: GenerationOptionsWithInfos = {},
clientOptions: Partial<ClientOptions> = {},
clientOptions: InputClientOptions = {},
onMessage: (data: any, resultStream: Readable) => void,
): Readable {
const resultStream = new Readable({
Expand All @@ -151,7 +151,7 @@ function stream(
infos: options?.infos || false,
};

const { token, endpoint } = defaultOptions(clientOptions);
const { token, endpoint } = await defaultOptions(clientOptions);
const ws = new WebSocket(`${endpoint.replace("http", "ws")}/stream?token=${token}`);

ws.onopen = () => ws.send(JSON.stringify(requestBody));
Expand All @@ -163,7 +163,7 @@ function stream(
export function generateStreamWithInfos(
task: string,
options: GenerationOptions = {},
clientOptions: Partial<ClientOptions> = {},
clientOptions: InputClientOptions = {},
): Readable {
return stream(
task,
Expand All @@ -189,7 +189,7 @@ export function generateStreamWithInfos(
export function generateStream(
task: string,
options: GenerationOptions = {},
clientOptions: Partial<ClientOptions> = {},
clientOptions: InputClientOptions = {},
): Readable {
return stream(task, options, clientOptions, (data: any, resultStream: Readable) => {
if (data.data === "") {
Expand All @@ -200,13 +200,17 @@ export function generateStream(
});
}

export default function client(clientOptions: Partial<ClientOptions> = {}) {
export default function client(clientOptions: InputClientOptions = {}) {
return {
generateWithTokenUsage: (task: string, options: GenerationOptions = {}) =>
generateWithTokenUsage(task, options, clientOptions),
generate: (task: string, options: GenerationOptions = {}) =>
generate(task, options, clientOptions),
generateWithInfo: (task: string, options: GenerationOptions = {}) =>
generateWithInfo(task, options, clientOptions),
generateStream: (task: string, options: GenerationOptions = {}) =>
generateStream(task, options, clientOptions),
generateStreamWithInfos: (task: string, options: GenerationOptions = {}) =>
generateStreamWithInfos(task, options, clientOptions),
};
}
9 changes: 0 additions & 9 deletions lib/helpers/ensurePolyfactToken.ts

This file was deleted.

Loading

0 comments on commit 8825bb1

Please sign in to comment.