Skip to content

Commit

Permalink
do generation
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasavila00 committed Mar 17, 2024
1 parent be925c0 commit b7a4b92
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 53 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ jobs:
- run: npm config set "@jsr:registry" https://npm.jsr.io
- run: npm config set "//registry.tiptap.dev/:_authToken" ${{ secrets.TIPTAP_TOKEN }}
- run: cd gui && npm ci && npm run static
- run: cd gui && npm ci && npm run test -- --run
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
TODO.md
TODO.md
node_modules/
25 changes: 25 additions & 0 deletions gui/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions gui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"react-dom": "^18.2.0",
"react-hook-form": "^7.51.0",
"react-select": "^5.8.0",
"recoil": "^0.7.7",
"tailwind-merge": "^2.2.1",
"tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4"
Expand Down
38 changes: 31 additions & 7 deletions gui/src/editor/components/LeftSidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,32 @@ import {
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { useForm } from "react-hook-form";
import { ALL_CHAT_TEMPLATES } from "@lmscript/client/chat-template";
const RunpodSglangConfigSchema = z.object({
url: z.string().min(4),
token: z.string(),
template: z.enum(ALL_CHAT_TEMPLATES),
});

const RunpodSglangConfig: FC<{
setBackend: (tag: Backend) => void;
}> = ({ setBackend }) => {
const UrlTokenTemplateConfig: FC<{
setBackend: (backend: Backend) => void;
tag: "runpod-serverless-sglang" | "sglang";
}> = ({ setBackend, tag }) => {
const form = useForm<z.infer<typeof RunpodSglangConfigSchema>>({
resolver: zodResolver(RunpodSglangConfigSchema),
defaultValues: {
url: "http://localhost:8000",
token: "",
template: "llama-2-chat",
},
});

function onSubmit(values: z.infer<typeof RunpodSglangConfigSchema>) {
setBackend({
tag: "runpod-serverless-sglang",
tag: tag,
url: values.url,
token: values.token,
template: values.template,
});
}

Expand Down Expand Up @@ -80,6 +85,20 @@ const RunpodSglangConfig: FC<{
</FormItem>
)}
/>
<FormField
control={form.control}
name="template"
render={({ field }) => (
<FormItem>
<FormLabel>Template</FormLabel>
<FormControl>
<Input placeholder="" {...field} />
</FormControl>
<FormDescription>TODO template desc.</FormDescription>
<FormMessage />
</FormItem>
)}
/>
<Button variant="outline" type="submit" className="w-full">
Save
</Button>
Expand All @@ -94,13 +113,18 @@ const BackendConfig: FC<{
if (backendTag == null) return <></>;
switch (backendTag) {
case "runpod-serverless-sglang": {
return <RunpodSglangConfig setBackend={setBackend} />;
return (
<UrlTokenTemplateConfig
setBackend={setBackend}
tag="runpod-serverless-sglang"
/>
);
}
case "runpod-serverless-vllm": {
return "todo";
return <>TODO vllm</>;
}
case "sglang": {
return "todo";
return <UrlTokenTemplateConfig setBackend={setBackend} tag="sglang" />;
}
default: {
return assertIsNever(backendTag);
Expand Down
164 changes: 128 additions & 36 deletions gui/src/editor/components/Play/Play.tsx
Original file line number Diff line number Diff line change
@@ -1,51 +1,143 @@
import { Backend } from "@/editor/hooks/useRunner";
import { getMessagesOfAuthor } from "@/editor/lib/playMessages";
import { EditorState } from "@/editor/lib/types";
import {
MessageOfAuthor,
getMessagesOfAuthor,
} from "@/editor/lib/playMessages";
import { EditorState, NamedVariable, SamplingParams } from "@/editor/lib/types";
import { assertIsNever } from "@/lib/utils";
import { FC } from "react";
import { atomFamily, useRecoilValueLoadable } from "recoil";
import { SGLangBackend } from "@lmscript/client/backends/sglang";
import { AbstractBackend } from "@lmscript/client/backends/abstract";
import { messagesToTasks } from "@/editor/lib/messageToTasks";
import { VllmBackend } from "@lmscript/client/backends/vllm";
import { RunpodServerlessBackend } from "@lmscript/client/backends/runpod-serverless-sglang";
const getBackendInstance = (backend: Backend): AbstractBackend => {
switch (backend.tag) {
case "runpod-serverless-sglang": {
return new RunpodServerlessBackend(backend.url, backend.token);
}
case "runpod-serverless-vllm": {
return new VllmBackend({
url: backend.url,
auth: backend.token,
model: backend.model,
});
}
case "sglang": {
return new SGLangBackend(backend.url);
}
default: {
return assertIsNever(backend);
}
}
};

const PlayStream: FC<{
editorState: EditorState;
captures: Record<string, string>;
}> = ({ editorState }) => {
const msgs = getMessagesOfAuthor(editorState);
if (msgs.tag === "success") {
return (
<>
{msgs.value.map((msg, i) => {
return (
<div key={i}>
<pre>{JSON.stringify(msg, null, 2)}</pre>
</div>
);
})}
</>
);
type Captures = Record<string, string>;
const generateAsyncAtom = atomFamily<
{
captures: Captures;
finalText: string | undefined;
},
{
backend: Backend;
messages: MessageOfAuthor[];
samplingParams: SamplingParams;
variables: NamedVariable[];
}
>({
key: "generateAsyncAtom",
default: (_param) => {
return {
captures: {},
finalText: undefined,
};
},
effects: (param) => [
(opts) => {
const instance = getBackendInstance(param.backend);
const tasks = messagesToTasks(
param.messages,
param.backend.template,
param.variables,
);

instance
.executeJSON(
{
tasks,
sampling_params: param.samplingParams,
initial_state: {
text: "",
captured: {},
},
},
{
onCapture: (cap) => {
opts.setSelf((prev) => {
if ("captures" in prev) {
return {
captures: {
...prev.captures,
[cap.name]: cap.value,
},
finalText: undefined,
};
}
return prev;
});
},
},
)
.then((out) => {
opts.setSelf((prev) => {
if ("captures" in prev) {
return {
captures: out.captured,
finalText: out.text,
};
}
return prev;
});
});
},
],
});

return (
<>
ERROR!!!!!!
<pre>{JSON.stringify(msgs, null, 2)}</pre>
</>
const PlayStream: FC<{
backend: Backend;
messages: MessageOfAuthor[];
samplingParams: SamplingParams;
variables: NamedVariable[];
}> = ({ variables, backend, messages, samplingParams }) => {
const generationAtom = useRecoilValueLoadable(
generateAsyncAtom({
samplingParams,
backend,
messages,
variables,
}),
);
};

const useCaptures = (_backend: Backend, _editorState: EditorState) => {
const captures: Record<string, string> = {};
console.log(generationAtom);

return {
captures,
};
return <></>;
};

export const Play: FC<{
backend: Backend;
editorState: EditorState;
}> = ({ editorState, backend }) => {
const { captures } = useCaptures(backend, editorState);
return (
<>
<PlayStream editorState={editorState} captures={captures} />
</>
);
const msgs = getMessagesOfAuthor(editorState);
if (msgs.tag === "success") {
return (
<PlayStream
backend={backend}
messages={msgs.value}
samplingParams={editorState.samplingParams}
variables={editorState.variables}
/>
);
}
return <>TODO: error {JSON.stringify(msgs)}</>;
};
9 changes: 9 additions & 0 deletions gui/src/editor/hooks/useRunner.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { ChatTemplate } from "@lmscript/client/chat-template";
import { useState } from "react";

export const ALL_BACKENDS_TAGS: BackendTag[] = [
Expand All @@ -14,16 +15,24 @@ export const BackendLabels: Record<BackendTag, string> = {

export type SGLangBackend = {
tag: "sglang";
url: string;
token: string;
template: ChatTemplate;
};

export type RunpodServerlessSGLangBackend = {
tag: "runpod-serverless-sglang";
url: string;
token: string;
template: ChatTemplate;
};

export type RunpodServerlessVLLMBackend = {
tag: "runpod-serverless-vllm";
url: string;
token: string;
model: string;
template: ChatTemplate;
};

export type Backend =
Expand Down
Loading

0 comments on commit b7a4b92

Please sign in to comment.