Skip to content

Commit

Permalink
Merge pull request #36 from polyfact/feat/web-request
Browse files Browse the repository at this point in the history
feat: add web-request
  • Loading branch information
kevin-btc authored Aug 23, 2023
2 parents d583de1 + e0e48c8 commit ca550a0
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 49 deletions.
2 changes: 2 additions & 0 deletions .eslintrc.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"prettier"
],

"ignorePatterns": ["examples/**/*"],
"parser": "@typescript-eslint/parser",
"parserOptions": {
Expand All @@ -40,6 +41,7 @@
"import/no-unresolved": "off",
"import/prefer-default-export": "off",
"import/extensions": "off",
"no-else-return": "off",

"no-unused-vars": "off",
"no-restricted-imports": [
Expand Down
58 changes: 58 additions & 0 deletions examples/web.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { Readable } from "readable-stream";
import { generateStream, generateWithTokenUsage } from "../lib/generate";

const config = {
endpoint: "http://localhost:8080",
token: "<YOUR_TOKEN>",
};

function handleStreamData(stream: Readable): Promise<void> {
return new Promise((resolve) => {
let buffer = "";

stream.on("data", (data) => {
buffer += data.toString();

let lastSpaceIndex = buffer.lastIndexOf(" ");
if (lastSpaceIndex !== -1) {
let partToPrint = buffer.substring(0, lastSpaceIndex);
buffer = buffer.substring(lastSpaceIndex + 1);
process.stdout.write(partToPrint + " ");
}
});

stream.on("end", () => {
if (buffer.length) {
process.stdout.write(buffer + "\n");
}
resolve();
});
});
}

(async () => {
const response = await generateWithTokenUsage(
"When is the next Olympics Games ?",
{ web: true, model: "gpt-4", provider: "openai" },
config,
);
console.log(response);

console.log("\n", "-".repeat(80), "\n");

const weatherStream = generateStream(
"what is the weather like today in Paris?",
{ web: true },
config,
);
await handleStreamData(weatherStream);

console.log("\n", "-".repeat(80), "\n");

const websiteSummaryStream = generateStream(
"summarize this website : read:https://www.polyfact.com/",
{ web: true },
config,
);
await handleStreamData(websiteSummaryStream);
})();
3 changes: 2 additions & 1 deletion lib/chats/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
generateWithTokenUsage,
GenerationOptions,
GenerationResult,
SystemPrompt,
} from "../generate";
import { InputClientOptions, ClientOptions, defaultOptions } from "../clientOpts";
import { Memory } from "../memory";
Expand Down Expand Up @@ -52,7 +53,7 @@ type ChatOptions = {
provider?: "openai" | "cohere" | "llama";
model?: string;
autoMemory?: boolean;
} & Exclusive<{ systemPrompt?: string }, { systemPromptId?: UUID }>;
} & SystemPrompt;

export class Chat {
chatId: Promise<string>;
Expand Down
146 changes: 98 additions & 48 deletions lib/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@ const Required = t.type({

const GenerationAPIResponse = t.intersection([Required, PartialResultType]);

export type Exclusive<T, U = T> =
export type Exclusive<T, U> =
| (T & Partial<Record<Exclude<keyof U, keyof T>, never>>)
| (U & Partial<Record<Exclude<keyof T, keyof U>, never>>);

type SystemPrompt = Exclusive<{ systemPromptId?: UUID }, { systemPrompt?: string }>;
export type ExclusiveN<T extends Record<string, unknown>[]> = T extends [infer F, ...infer Rest]
? F extends Record<string, unknown>
? Exclusive<F, ExclusiveN<Extract<Rest, Record<string, unknown>[]>>>
: never
: unknown;

type ExclusiveProps = [{ systemPromptId?: UUID }, { systemPrompt?: string }];

export type SystemPrompt = ExclusiveN<ExclusiveProps>;

export type GenerationOptions = {
provider?: "openai" | "cohere" | "llama";
Expand All @@ -39,6 +47,11 @@ export type GenerationOptions = {
infos?: boolean;
} & SystemPrompt;

export type GenerationWithWebOptions = Omit<
GenerationOptions,
"chatId" | "memory" | "memoryId" | "stop" | "systemPromptId" | "systemPrompt"
> & { web: true };

export type TokenUsage = {
input: number;
output: number;
Expand All @@ -59,31 +72,11 @@ export type GenerationResult = {
ressources?: Ressource[];
};

export async function generateWithTokenUsage(
task: string,
options: GenerationOptions = {},
clientOptions: InputClientOptions = {},
async function generateRequest(
requestBody: Record<string, unknown>,
clientOptions: InputClientOptions,
): Promise<GenerationResult> {
const { token, endpoint } = await defaultOptions(clientOptions);
const requestBody: {
task: string;
memory_id?: string;
chat_id?: string;
provider: GenerationOptions["provider"];
model?: string;
stop: GenerationOptions["stop"];
infos: boolean;
system_prompt_id?: UUID;
} = {
task,
provider: options?.provider || "openai",
model: options.model,
memory_id: (await options?.memory?.memoryId) || options?.memoryId,
chat_id: options?.chatId,
stop: options?.stop || [],
infos: options?.infos || false,
system_prompt_id: options?.systemPromptId,
};

try {
const res = await axios.post(`${endpoint}/generate`, requestBody, {
Expand Down Expand Up @@ -113,6 +106,49 @@ export async function generateWithTokenUsage(
}
}

export async function generateWithTokenUsage(
task: string,
options: GenerationOptions,
clientOptions?: InputClientOptions,
): Promise<GenerationResult>;
export async function generateWithTokenUsage(
task: string,
options: GenerationWithWebOptions,
clientOptions?: InputClientOptions,
): Promise<GenerationResult>;

export async function generateWithTokenUsage(
task: string,
options: GenerationOptions | GenerationWithWebOptions = {},
clientOptions: InputClientOptions = {},
): Promise<GenerationResult> {
let requestBody = {};
if ("web" in options) {
requestBody = {
task,
provider: options.provider || "openai",
model: options.model || "gpt-3.5-turbo",
infos: options.infos || false,
web: options.web,
};
} else {
const genOptions = options as GenerationOptions;

requestBody = {
task,
provider: genOptions.provider || "openai",
model: genOptions.model || "gpt-3.5-turbo",
memory_id: (await genOptions.memory?.memoryId) || genOptions.memoryId,
chat_id: genOptions.chatId,
stop: genOptions.stop || [],
infos: genOptions.infos || false,
system_prompt_id: genOptions.systemPromptId,
};
}

return generateRequest(requestBody, clientOptions);
}

/**
* Generates a result based on provided options.
*
Expand Down Expand Up @@ -140,46 +176,60 @@ export async function generate(

function stream(
task: string,
options: GenerationOptions = {},
options: GenerationOptions | GenerationWithWebOptions = {},
clientOptions: InputClientOptions = {},
onMessage: (data: any, resultStream: Readable) => void,
onMessage: (data: unknown, resultStream: Readable) => void,
): Readable {
const resultStream = new Readable({
// eslint-disable-next-line @typescript-eslint/no-empty-function
read() {},
});
(async () => {
const requestBody: {
task: string;
memory_id?: string;
chat_id?: string;
provider: GenerationOptions["provider"];
model?: string;
stop: GenerationOptions["stop"];
infos?: boolean;
system_prompt_id?: UUID;
} = {
task,
provider: options?.provider || "openai",
model: options?.model,
memory_id: (await options?.memory?.memoryId) || options?.memoryId || "",
chat_id: options?.chatId || "",
stop: options?.stop || [],
infos: options?.infos || false,
system_prompt_id: options?.systemPromptId,
};
let requestBody = {};
if ("web" in options) {
requestBody = {
task,
provider: options.provider || "openai",
model: options.model || "gpt-3.5-turbo",
infos: options.infos || false,
web: options.web,
};
} else {
requestBody = {
task,
provider: options?.provider || "openai",
model: options?.model || "gpt-3.5-turbo",
memory_id: (await options?.memory?.memoryId) || options?.memoryId || "",
chat_id: options?.chatId || "",
stop: options?.stop || [],
infos: options?.infos || false,
system_prompt_id: options?.systemPromptId,
};
}

const { token, endpoint } = await defaultOptions(clientOptions);
const ws = new WebSocket(`${endpoint.replace("http", "ws")}/stream?token=${token}`);

ws.onopen = () => ws.send(JSON.stringify(requestBody));
ws.onmessage = (data: any) => onMessage(data, resultStream);
ws.onmessage = (data: unknown) => onMessage(data, resultStream);
})();
return resultStream;
}

export function generateStream(
task: string,
options: GenerationOptions = {},
options: GenerationOptions,
clientOptions: InputClientOptions,
): Readable;
export function generateStream(
task: string,
options: GenerationWithWebOptions,
clientOptions: InputClientOptions,
): Readable;

export function generateStream(
task: string,
options: GenerationOptions | GenerationWithWebOptions = {},
clientOptions: InputClientOptions = {},
): Readable {
if (options.infos) {
Expand Down

0 comments on commit ca550a0

Please sign in to comment.