Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions apps/web/client/src/app/api/chat/helpers/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ export async function getModelFromType(chatType: ChatType) {
switch (chatType) {
case ChatType.CREATE:
case ChatType.FIX:
model = await initModel({
model = initModel({
provider: LLMProvider.OPENROUTER,
model: OPENROUTER_MODELS.OPEN_AI_GPT_5,
});
break;
case ChatType.ASK:
case ChatType.EDIT:
default:
model = await initModel({
model = initModel({
provider: LLMProvider.OPENROUTER,
model: OPENROUTER_MODELS.CLAUDE_4_SONNET,
});
Expand Down Expand Up @@ -60,7 +60,7 @@ export const repairToolCall = async ({ toolCall, tools, error }: { toolCall: Too
`Invalid parameter for tool ${toolCall.toolName} with args ${JSON.stringify(toolCall.input)}, attempting to fix`,
);

const { model } = await initModel({
const { model } = initModel({
provider: LLMProvider.OPENROUTER,
model: OPENROUTER_MODELS.OPEN_AI_GPT_5_NANO,
});
Expand Down
62 changes: 48 additions & 14 deletions apps/web/client/src/app/api/chat/route.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { api } from '@/trpc/server';
import { trackEvent } from '@/utils/analytics/server';
import { AgentStreamer, RootAgent } from '@onlook/ai';
import { AgentStreamer, BaseAgent, RootAgent, UserAgent } from '@onlook/ai';
import { toDbMessage } from '@onlook/db';
import { ChatType, type ChatMessage } from '@onlook/models';
import { AgentType, ChatType } from '@onlook/models';
import { type NextRequest } from 'next/server';
import { v4 as uuidv4 } from 'uuid';
import { checkMessageLimit, decrementUsage, errorHandler, getSupabaseUser, incrementUsage, repairToolCall } from './helpers';
import { z } from 'zod';

export async function POST(req: NextRequest) {
try {
Expand Down Expand Up @@ -51,14 +52,38 @@ export async function POST(req: NextRequest) {
}
}

/**
 * Request-body schema for the chat stream endpoint.
 *
 * Cross-field contract: `chatType` is required when `agentType` is ROOT
 * and must be omitted for every other agent type.
 */
const streamResponseSchema = z.object({
    agentType: z.enum(AgentType).optional().default(AgentType.ROOT),
    // NOTE(review): messages pass through unvalidated (z.any()) — consider a
    // real message schema if one becomes available.
    messages: z.array(z.any()),
    chatType: z.enum(ChatType).optional(),
    conversationId: z.string(),
    projectId: z.string(),
}).superRefine((data, ctx) => {
    const isRootAgent = data.agentType === AgentType.ROOT;
    const hasChatType = data.chatType !== undefined;

    // ROOT agents must carry a chatType.
    if (isRootAgent && !hasChatType) {
        ctx.addIssue({
            code: 'custom',
            message: "chatType is required when agentType is ROOT",
            path: ['chatType'],
        });
    }

    // Non-ROOT agents must not carry one.
    if (!isRootAgent && hasChatType) {
        ctx.addIssue({
            code: 'custom',
            message: "chatType is forbidden when agentType is not ROOT",
            path: ['chatType'],
        });
    }
});

export const streamResponse = async (req: NextRequest, userId: string) => {
const body = await req.json();
const { messages, chatType, conversationId, projectId } = body as {
messages: ChatMessage[],
chatType: ChatType,
conversationId: string,
projectId: string,
};
const { agentType, messages, chatType, conversationId, projectId } = streamResponseSchema.parse(body);

// Updating the usage record and rate limit is done here to avoid
// abuse in the case where a single user sends many concurrent requests.
// If the call below fails, the user will not be penalized.
Expand All @@ -71,12 +96,20 @@ export const streamResponse = async (req: NextRequest, userId: string) => {
const lastUserMessage = messages.findLast((message) => message.role === 'user');
const traceId = lastUserMessage?.id ?? uuidv4();

if (chatType === ChatType.EDIT) {
usageRecord = await incrementUsage(req, traceId);
}

// Create RootAgent instance
const agent = await RootAgent.create(chatType);
let agent: BaseAgent;
if (agentType === AgentType.ROOT) {
if (chatType === ChatType.EDIT) {
usageRecord = await incrementUsage(req, traceId);
}

agent = new RootAgent(chatType!);
} else if (agentType === AgentType.USER) {
agent = new UserAgent();
} else {
// agent = new WeatherAgent();
throw new Error('Agent type not supported');
}
const streamer = new AgentStreamer(agent, conversationId);

return streamer.streamText(messages, {
Expand All @@ -87,7 +120,8 @@ export const streamResponse = async (req: NextRequest, userId: string) => {
conversationId,
projectId,
userId,
chatType: chatType,
agentType: agentType ?? AgentType.ROOT,
chatType: chatType ?? "null",
tags: ['chat'],
langfuseTraceId: traceId,
sessionId: conversationId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { useEditorEngine } from '@/components/store/editor';
import { handleToolCall } from '@/components/tools';
import { api } from '@/trpc/client';
import { useChat as useAiChat } from '@ai-sdk/react';
import { ChatType, type ChatMessage, type ChatSuggestion } from '@onlook/models';
import { AgentType, ChatType, type ChatMessage, type ChatSuggestion } from '@onlook/models';
import { jsonClone } from '@onlook/utility';
import { DefaultChatTransport, lastAssistantMessageIsCompleteWithToolCalls } from 'ai';
import { usePostHog } from 'posthog-js/react';
Expand Down Expand Up @@ -32,6 +32,7 @@ interface UseChatProps {
projectId: string;
initialMessages: ChatMessage[];
}
const agentType = AgentType.ROOT;

export function useChat({ conversationId, projectId, initialMessages }: UseChatProps) {
const editorEngine = useEditorEngine();
Expand All @@ -41,6 +42,7 @@ export function useChat({ conversationId, projectId, initialMessages }: UseChatP
const [finishReason, setFinishReason] = useState<string | null>(null);
const [isExecutingToolCall, setIsExecutingToolCall] = useState(false);


const { addToolResult, messages, error, stop, setMessages, regenerate, status } =
useAiChat<ChatMessage>({
id: 'user-chat',
Expand All @@ -51,11 +53,12 @@ export function useChat({ conversationId, projectId, initialMessages }: UseChatP
body: {
conversationId,
projectId,
agentType,
},
}),
onToolCall: async (toolCall) => {
setIsExecutingToolCall(true);
void handleToolCall(toolCall.toolCall, editorEngine, addToolResult).then(() => {
void handleToolCall(agentType, toolCall.toolCall, editorEngine, addToolResult).then(() => {
setIsExecutingToolCall(false);
});
},
Expand Down Expand Up @@ -89,6 +92,7 @@ export function useChat({ conversationId, projectId, initialMessages }: UseChatP
chatType: type,
conversationId,
context,
agentType,
},
});
void editorEngine.chat.conversation.generateTitle(content);
Expand Down Expand Up @@ -137,6 +141,7 @@ export function useChat({ conversationId, projectId, initialMessages }: UseChatP
body: {
chatType,
conversationId,
agentType,
},
});

Expand Down
17 changes: 11 additions & 6 deletions apps/web/client/src/components/tools/tools.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import type { EditorEngine } from '@/components/store/editor/engine';
import type { ToolCall } from '@ai-sdk/provider-utils';
import { getToolClassesFromType } from '@onlook/ai';
import type { AbstractChat } from 'ai';
import { getAvailableTools, type OnToolCallHandler } from '@onlook/ai';
import { toast } from '@onlook/ui/sonner';
import type { AgentType } from '@onlook/models';

export async function handleToolCall(toolCall: ToolCall<string, unknown>, editorEngine: EditorEngine, addToolResult: (toolResult: { tool: string, toolCallId: string, output: any }) => Promise<void>) {
export async function handleToolCall(agentType: AgentType, toolCall: ToolCall<string, unknown>, editorEngine: EditorEngine, addToolResult: typeof AbstractChat.prototype.addToolResult) {
const toolName = toolCall.toolName;
const currentChatMode = editorEngine.state.chatMode;
const availableTools = getToolClassesFromType(currentChatMode);
const availableTools = getAvailableTools(agentType, currentChatMode) as any[];
let output: any = null;

try {
const tool = availableTools.find(tool => tool.toolName === toolName);
const tool = availableTools.find((tool: any) => tool.toolName === toolName);
if (!tool) {
toast.error(`Tool "${toolName}" not available in ask mode`, {
description: `Switch to build mode to use this tool.`,
duration: 2000,
});

throw new Error(`Tool "${toolName}" is not available in ${currentChatMode} mode`);
throw new Error(`Tool "${toolName}" is not available in ${currentChatMode} mode!!!!`);
}

if (!tool) {
Expand All @@ -26,8 +28,11 @@ export async function handleToolCall(toolCall: ToolCall<string, unknown>, editor
// Parse the input to the tool parameters. Throws if invalid.
const validatedInput = tool.parameters.parse(toolCall.input);
const toolInstance = new tool();
const getOnToolCall: OnToolCallHandler = (subAgentType, addSubAgentToolResult) => (toolCall) =>
void handleToolCall(subAgentType, toolCall.toolCall, editorEngine, addSubAgentToolResult);

// Can force type with as any because we know the input is valid.
output = await toolInstance.handle(validatedInput as any, editorEngine);
output = await toolInstance.handle(validatedInput as any, editorEngine, getOnToolCall);
} catch (error) {
output = 'error handling tool call ' + error;
} finally {
Expand Down
19 changes: 12 additions & 7 deletions apps/web/client/src/server/api/routers/chat/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
conversationUpdateSchema,
fromDbConversation
} from '@onlook/db';
import { LLMProvider, OPENROUTER_MODELS } from '@onlook/models';
import { AgentType, LLMProvider, OPENROUTER_MODELS } from '@onlook/models';
import { generateText } from 'ai';
import { eq } from 'drizzle-orm';
import { v4 as uuidv4 } from 'uuid';
Expand Down Expand Up @@ -36,7 +36,10 @@ export const conversationRouter = createTRPCRouter({
upsert: protectedProcedure
.input(conversationInsertSchema)
.mutation(async ({ ctx, input }) => {
const [conversation] = await ctx.db.insert(conversations).values(input).returning();
const [conversation] = await ctx.db.insert(conversations).values({
...input,
agentType: input.agentType as AgentType,
}).returning();
if (!conversation) {
throw new Error('Conversation not created');
}
Expand All @@ -45,10 +48,12 @@ export const conversationRouter = createTRPCRouter({
update: protectedProcedure
.input(conversationUpdateSchema)
.mutation(async ({ ctx, input }) => {
const [conversation] = await ctx.db.update({
...conversations,
updatedAt: new Date(),
}).set(input)
const [conversation] = await ctx.db.update(conversations)
.set({
...input,
agentType: input.agentType as AgentType,
updatedAt: new Date(),
})
.where(eq(conversations.id, input.id)).returning();
if (!conversation) {
throw new Error('Conversation not updated');
Expand All @@ -68,7 +73,7 @@ export const conversationRouter = createTRPCRouter({
content: z.string(),
}))
.mutation(async ({ ctx, input }) => {
const { model, providerOptions, headers } = await initModel({
const { model, providerOptions, headers } = initModel({
provider: LLMProvider.OPENROUTER,
model: OPENROUTER_MODELS.CLAUDE_3_5_HAIKU,
});
Expand Down
2 changes: 1 addition & 1 deletion apps/web/client/src/server/api/routers/chat/suggestion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export const suggestionsRouter = createTRPCRouter({
})),
}))
.mutation(async ({ ctx, input }) => {
const { model, headers } = await initModel({
const { model, headers } = initModel({
provider: LLMProvider.OPENROUTER,
model: OPENROUTER_MODELS.OPEN_AI_GPT_5_NANO,
});
Expand Down
2 changes: 1 addition & 1 deletion apps/web/client/src/server/api/routers/project/project.ts
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ export const projectRouter = createTRPCRouter({
}))
.mutation(async ({ ctx, input }): Promise<string> => {
try {
const { model, providerOptions, headers } = await initModel({
const { model, providerOptions, headers } = initModel({
provider: LLMProvider.OPENROUTER,
model: OPENROUTER_MODELS.OPEN_AI_GPT_5_NANO,
});
Expand Down
3 changes: 2 additions & 1 deletion packages/ai/src/agents/classes/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export { RootAgent } from './root';
export { RootAgent } from './root';
export { UserAgent } from './user';
35 changes: 18 additions & 17 deletions packages/ai/src/agents/classes/root.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
import { ChatType, LLMProvider, OPENROUTER_MODELS, type ModelConfig } from '@onlook/models';
import { initModel } from '../../chat/providers';
import { getAskModeSystemPrompt, getCreatePageSystemPrompt, getSystemPrompt } from '../../prompt';
import { getToolSetFromType } from '../../tools/toolset';
import { BaseAgent } from '../models/base';
import { getAskModeSystemPrompt, getCreatePageSystemPrompt, getSystemPrompt, initModel, UserAgentTool } from '../../index';
import { AgentType, ChatType, LLMProvider, OPENROUTER_MODELS, type ModelConfig } from '@onlook/models';
import { readOnlyRootTools, rootTools } from '../tool-lookup';

/**
 * Resolve the tool set for a chat mode: ASK gets the read-only root
 * tools; every other mode gets the full root tool set.
 */
export function getToolFromType(chatType: ChatType) {
    if (chatType === ChatType.ASK) {
        return readOnlyRootTools;
    }
    return rootTools;
}

export class RootAgent extends BaseAgent {
readonly id = 'root-agent';
readonly agentType = AgentType.ROOT;
private readonly chatType: ChatType;
readonly modelConfig: ModelConfig;

constructor(chatType: ChatType, modelConfig: ModelConfig) {
super(getToolSetFromType(chatType));

constructor(chatType: ChatType) {
super();
this.chatType = chatType;
this.modelConfig = modelConfig;
this.modelConfig = this.getModelFromType(chatType);
}

get systemPrompt(): string {
return this.getSystemPromptFromType(this.chatType);
}

get tools() {
return getToolFromType(this.chatType);
}

private getSystemPromptFromType(chatType: ChatType): string {
switch (chatType) {
case ChatType.CREATE:
Expand All @@ -32,23 +38,18 @@ export class RootAgent extends BaseAgent {
}
}

static async create(chatType: ChatType): Promise<RootAgent> {
const modelConfig = await RootAgent.getModelFromType(chatType);
return new RootAgent(chatType, modelConfig);
}

private static async getModelFromType(chatType: ChatType): Promise<ModelConfig> {
private getModelFromType(chatType: ChatType): ModelConfig {
switch (chatType) {
case ChatType.CREATE:
case ChatType.FIX:
return await initModel({
return initModel({
provider: LLMProvider.OPENROUTER,
model: OPENROUTER_MODELS.OPEN_AI_GPT_5,
});
case ChatType.ASK:
case ChatType.EDIT:
default:
return await initModel({
return initModel({
provider: LLMProvider.OPENROUTER,
model: OPENROUTER_MODELS.CLAUDE_4_SONNET,
});
Expand Down
17 changes: 17 additions & 0 deletions packages/ai/src/agents/classes/user.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { BaseAgent } from '../models/base';
import { initModel } from '../../index';
import { AgentType, LLMProvider, OPENROUTER_MODELS, type ModelConfig } from '@onlook/models';
import { userTools } from '../tool-lookup';

/**
 * Agent backing user-level chats. Resolves its model (Claude 3.5 Haiku
 * via OpenRouter) once at construction time, exposes the user tool set,
 * and runs with an empty system prompt.
 */
export class UserAgent extends BaseAgent {
    readonly agentType = AgentType.USER;
    readonly tools = userTools;

    // Field initializer runs when the instance is created; initModel is
    // synchronous here, matching its other call sites in this PR.
    readonly modelConfig: ModelConfig = initModel({
        provider: LLMProvider.OPENROUTER,
        model: OPENROUTER_MODELS.CLAUDE_3_5_HAIKU,
    });

    // Intentionally empty — user agents supply no system prompt.
    get systemPrompt(): string {
        return '';
    }
}
1 change: 1 addition & 0 deletions packages/ai/src/agents/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from './models';
export * from './classes';
export * from './tools';
export { AgentStreamer } from './streamer';
Loading