Skip to content

Commit

Permalink
Reapply streaming (#434)
Browse files Browse the repository at this point in the history
* Stream LLM response when editing apps (#430)

* Streaming WIP

* More wip

* Fix up the streaming parser and the tests

* Remove extra debug log

* Remove unused file

* Fix more parser bugs

* Hook up frontend

---------

Co-authored-by: Nicholas Charriere <[email protected]>

* Add changes to bugfix

* Saving progress. Some things work but tags are wonky

* Fix streaming bugs

* Remove debug log

---------

Co-authored-by: Ben Reinhart <[email protected]>
  • Loading branch information
nichochar and benjreinhart authored Oct 31, 2024
1 parent 8b2f55d commit 51d5265
Show file tree
Hide file tree
Showing 13 changed files with 1,024 additions and 90 deletions.
42 changes: 26 additions & 16 deletions packages/api/ai/generate.mts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { generateText, type GenerateTextResult } from 'ai';
import { streamText, generateText, type GenerateTextResult } from 'ai';
import { getModel } from './config.mjs';
import {
type CodeLanguageType,
Expand All @@ -13,7 +13,7 @@ import Path from 'node:path';
import { PROMPTS_DIR } from '../constants.mjs';
import { encode, decodeCells } from '../srcmd.mjs';
import { buildProjectXml, type FileContent } from '../ai/app-parser.mjs';
import { type AppGenerationLog, logAppGeneration } from './logger.mjs';
import { logAppGeneration } from './logger.mjs';

const makeGenerateSrcbookSystemPrompt = () => {
return readFileSync(Path.join(PROMPTS_DIR, 'srcbook-generator.txt'), 'utf-8');
Expand Down Expand Up @@ -259,30 +259,40 @@ export async function generateApp(
return result.text;
}

export async function editApp(
export async function streamEditApp(
projectId: string,
files: FileContent[],
query: string,
appId: string,
planId: string,
): Promise<string> {
) {
const model = await getModel();

const systemPrompt = makeAppEditorSystemPrompt();
const userPrompt = makeAppEditorUserPrompt(projectId, files, query);
const result = await generateText({

let response = '';

const result = await streamText({
model,
system: systemPrompt,
prompt: userPrompt,
onChunk: (chunk) => {
if (chunk.chunk.type === 'text-delta') {
response += chunk.chunk.textDelta;
}
},
onFinish: () => {
if (process.env.SRCBOOK_DISABLE_ANALYTICS !== 'true') {
logAppGeneration({
appId,
planId,
llm_request: { model, system: systemPrompt, prompt: userPrompt },
llm_response: response,
});
}
},
});
const log: AppGenerationLog = {
appId,
planId,
llm_request: { model, system: systemPrompt, prompt: userPrompt },
llm_response: result,
};

if (process.env.SRCBOOK_DISABLE_ANALYTICS !== 'true') {
logAppGeneration(log);
}
return result.text;

return result.textStream;
}
110 changes: 110 additions & 0 deletions packages/api/ai/plan-parser.mts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { XMLParser } from 'fast-xml-parser';
import Path from 'node:path';
import { type App as DBAppType } from '../db/schema.mjs';
import { loadFile } from '../apps/disk.mjs';
import { StreamingXMLParser, TagType } from './stream-xml-parser.mjs';
import { ActionChunkType, DescriptionChunkType } from '@srcbook/shared';

// The ai proposes a plan that we expect to contain both files and commands
// Here is an example of a plan:
Expand Down Expand Up @@ -167,3 +169,111 @@ export function getPackagesToInstall(plan: Plan): string[] {
)
.flatMap((action) => action.packages);
}

export async function streamParsePlan(
stream: AsyncIterable<string>,
app: DBAppType,
_query: string,
planId: string,
) {
let parser: StreamingXMLParser;

return new ReadableStream({
async pull(controller) {
if (parser === undefined) {
parser = new StreamingXMLParser({
async onTag(tag) {
if (tag.name === 'planDescription' || tag.name === 'action') {
const chunk = await toStreamingChunk(app, tag, planId);
if (chunk) {
controller.enqueue(JSON.stringify(chunk) + '\n');
}
}
},
});
}

try {
for await (const chunk of stream) {
parser.parse(chunk);
}
controller.close();
} catch (error) {
console.error(error);
controller.enqueue(
JSON.stringify({
type: 'error',
data: { content: 'Error while parsing streaming response' },
}) + '\n',
);
controller.error(error);
}
},
});
}

async function toStreamingChunk(
app: DBAppType,
tag: TagType,
planId: string,
): Promise<DescriptionChunkType | ActionChunkType | null> {
switch (tag.name) {
case 'planDescription':
return {
type: 'description',
planId: planId,
data: { content: tag.content },
} as DescriptionChunkType;
case 'action': {
const descriptionTag = tag.children.find((t) => t.name === 'description');
const description = descriptionTag?.content ?? '';
const type = tag.attributes.type;

if (type === 'file') {
const fileTag = tag.children.find((t) => t.name === 'file')!;

const filePath = fileTag.attributes.filename as string;
let originalContent = null;

try {
const fileContent = await loadFile(app, filePath);
originalContent = fileContent.source;
} catch (error) {
// If the file doesn't exist, it's likely that it's a new file.
}

return {
type: 'action',
planId: planId,
data: {
type: 'file',
description,
path: filePath,
dirname: Path.dirname(filePath),
basename: Path.basename(filePath),
modified: fileTag.content,
original: originalContent,
},
} as ActionChunkType;
} else if (type === 'command') {
const commandTag = tag.children.find((t) => t.name === 'commandType')!;
const packageTags = tag.children.filter((t) => t.name === 'package');

return {
type: 'action',
planId: planId,
data: {
type: 'command',
description,
command: commandTag.content,
packages: packageTags.map((t) => t.content),
},
} as ActionChunkType;
} else {
return null;
}
}
default:
return null;
}
}
207 changes: 207 additions & 0 deletions packages/api/ai/stream-xml-parser.mts
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
export type NodeSchema = {
isContentNode?: boolean;
hasCdata?: boolean;
allowedChildren?: string[];
};

export const xmlSchema: Record<string, NodeSchema> = {
plan: { isContentNode: false, hasCdata: false },
action: { isContentNode: false, hasCdata: false },
description: { isContentNode: true, hasCdata: true },
file: { isContentNode: false, hasCdata: true },
commandType: { isContentNode: true, hasCdata: false },
package: { isContentNode: true, hasCdata: false },
planDescription: { isContentNode: true, hasCdata: true },
};

export type TagType = {
name: string;
attributes: Record<string, string>;
content: string;
children: TagType[];
};

export type TagCallbackType = (tag: TagType) => void;

export class StreamingXMLParser {
private buffer = '';
private currentTag: TagType | null = null;
private tagStack: TagType[] = [];
private isInCDATA = false;
private cdataBuffer = '';
private textBuffer = '';
private onTag: TagCallbackType;

constructor({ onTag }: { onTag: TagCallbackType }) {
this.onTag = onTag;
}

private parseAttributes(attributeString: string): Record<string, string> {
const attributes: Record<string, string> = {};
const matches = attributeString.match(/(\w+)="([^"]*?)"/g);

if (matches) {
matches.forEach((match) => {
const [key, value] = match.split('=') as [string, string];
attributes[key] = value.replace(/"/g, '');
});
}

return attributes;
}

private handleOpenTag(tagContent: string) {
// First, save any accumulated text content to the current tag
if (this.currentTag && this.textBuffer.trim()) {
this.currentTag.content = this.textBuffer.trim();
}
this.textBuffer = '';

const spaceIndex = tagContent.indexOf(' ');
const tagName = spaceIndex === -1 ? tagContent : tagContent.substring(0, spaceIndex);
const attributeString = spaceIndex === -1 ? '' : tagContent.substring(spaceIndex + 1);

const newTag: TagType = {
name: tagName,
attributes: this.parseAttributes(attributeString),
content: '',
children: [],
};

if (this.currentTag) {
// Push current tag to stack before moving to new tag
this.tagStack.push(this.currentTag);
this.currentTag.children.push(newTag);
}

this.currentTag = newTag;
}

private handleCloseTag(tagName: string) {
if (!this.currentTag) {
console.warn('Attempted to handle close tag with no current tag');
return;
}

// Save any remaining text content before closing
// Don't overwrite CDATA content, it's already been written
const schema = xmlSchema[this.currentTag.name];
const isCdataNode = schema ? schema.hasCdata : false;
if (!isCdataNode) {
this.currentTag.content = this.textBuffer.trim();
}
this.textBuffer = '';

if (this.currentTag.name !== tagName) {
return;
}

// Clean and emit the completed tag
this.currentTag = this.cleanNode(this.currentTag);
this.onTag(this.currentTag);

// Pop the parent tag from the stack
if (this.tagStack.length > 0) {
this.currentTag = this.tagStack.pop()!;
} else {
this.currentTag = null;
}
}

private cleanNode(node: TagType): TagType {
const schema = xmlSchema[node.name];

// If it's not in the schema, default to treating it as a content node
const isContentNode = schema ? schema.isContentNode : true;

// If it's not a content node and has children, remove its content
if (!isContentNode && node.children.length > 0) {
node.content = '';
}

// Recursively clean children
node.children = node.children.map((child) => this.cleanNode(child));

return node;
}

parse(chunk: string) {
this.buffer += chunk;

while (this.buffer.length > 0) {
// Handle CDATA sections
if (this.isInCDATA) {
const cdataEndIndex = this.cdataBuffer.indexOf(']]>');
if (cdataEndIndex === -1) {
this.cdataBuffer += this.buffer;
// Sometimes ]]> is in the next chunk, and we don't want to lose what's behind it
const nextCdataEnd = this.cdataBuffer.indexOf(']]>');
if (nextCdataEnd !== -1) {
this.buffer = this.cdataBuffer.substring(nextCdataEnd);
} else {
this.buffer = '';
}
return;
}

this.cdataBuffer = this.cdataBuffer.substring(0, cdataEndIndex);
if (this.currentTag) {
this.currentTag.content = this.cdataBuffer.trim();
}
this.isInCDATA = false;
this.buffer = this.cdataBuffer.substring(cdataEndIndex + 3) + this.buffer;
this.cdataBuffer = '';
continue;
}

// Look for the next tag
const openTagStartIdx = this.buffer.indexOf('<');
if (openTagStartIdx === -1) {
// No more tags in this chunk, save the rest as potential content
this.textBuffer += this.buffer;
this.buffer = '';
return;
}

// Save any text content before this tag
if (openTagStartIdx > 0) {
this.textBuffer += this.buffer.substring(0, openTagStartIdx);
this.buffer = this.buffer.substring(openTagStartIdx);
}

// Check for CDATA
if (this.sequenceExistsAt('<![CDATA[', 0)) {
this.isInCDATA = true;
const cdataStart = this.buffer.substring(9);
this.cdataBuffer = cdataStart;
this.buffer = '';
return;
}

const openTagEndIdx = this.buffer.indexOf('>');
if (openTagEndIdx === -1) {
return;
}

const tagContent = this.buffer.substring(1, openTagEndIdx);
this.buffer = this.buffer.substring(openTagEndIdx + 1);

if (tagContent.startsWith('/')) {
// Closing tag
this.handleCloseTag(tagContent.substring(1));
} else {
// Opening tag
this.handleOpenTag(tagContent);
}
}
}

private sequenceExistsAt(sequence: string, idx: number, buffer: string = this.buffer) {
for (let i = 0; i < sequence.length; i++) {
if (buffer[idx + i] !== sequence[i]) {
return false;
}
}
return true;
}
}
Loading

0 comments on commit 51d5265

Please sign in to comment.