diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index e2d212817b1..9d7ddb4faaa 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -4,25 +4,17 @@ import { NextRequest, NextResponse } from "next/server"; import { auth } from "./auth"; import { BedrockRuntimeClient, - InvokeModelCommand, + ConverseStreamOutput, ValidationException, + ModelStreamErrorException, + ThrottlingException, + ServiceUnavailableException, + InternalServerException, } from "@aws-sdk/client-bedrock-runtime"; import { validateModelId } from "./bedrock/utils"; -import { - ConverseRequest, - formatRequestBody, - parseModelResponse, -} from "./bedrock/models"; - -interface ContentItem { - type: string; - text?: string; - image_url?: { - url: string; - }; -} +import { ConverseRequest, createConverseStreamCommand } from "./bedrock/models"; -const ALLOWED_PATH = new Set(["invoke", "converse"]); +const ALLOWED_PATH = new Set(["converse"]); export async function handle( req: NextRequest, @@ -57,29 +49,10 @@ export async function handle( } try { - if (subpath === "converse") { - const response = await handleConverseRequest(req); - return response; - } else { - const response = await handleInvokeRequest(req); - return response; - } + const response = await handleConverseRequest(req); + return response; } catch (e) { console.error("[Bedrock] ", e); - - // Handle specific error cases - if (e instanceof ValidationException) { - return NextResponse.json( - { - error: true, - message: - "Model validation error. If using a Llama model, please provide a valid inference profile ARN.", - details: e.message, - }, - { status: 400 }, - ); - } - return NextResponse.json( { error: true, @@ -92,9 +65,7 @@ export async function handle( } async function handleConverseRequest(req: NextRequest) { - const controller = new AbortController(); - - const region = req.headers.get("X-Region") || "us-east-1"; + const region = req.headers.get("X-Region") || "us-west-2"; const accessKeyId = req.headers.get("X-Access-Key") || ""; const secretAccessKey = req.headers.get("X-Secret-Key") || ""; const sessionToken = req.headers.get("X-Session-Token"); @@ -111,8 +82,6 @@ async function handleConverseRequest(req: NextRequest) { ); } - console.log("[Bedrock] Using region:", region); - const client = new BedrockRuntimeClient({ region, credentials: { @@ -122,167 +91,171 @@ async function handleConverseRequest(req: NextRequest) { }, }); - const timeoutId = setTimeout( - () => { - controller.abort(); - }, - 10 * 60 * 1000, - ); - try { const body = (await req.json()) as ConverseRequest; const { modelId } = body; - // Validate model ID const validationError = validateModelId(modelId); if (validationError) { - throw new ValidationException({ - message: validationError, - $metadata: {}, - }); + throw new Error(validationError); } console.log("[Bedrock] Invoking model:", modelId); - console.log("[Bedrock] Messages:", body.messages); - - const requestBody = formatRequestBody(body); - const jsonString = JSON.stringify(requestBody); - const input = { - modelId, - contentType: "application/json", - accept: "application/json", - body: Uint8Array.from(Buffer.from(jsonString)), - }; - - console.log("[Bedrock] Request input:", { - ...input, - body: requestBody, - }); - - const command = new InvokeModelCommand(input); + const command = createConverseStreamCommand(body); const response = await client.send(command); - console.log("[Bedrock] Got response"); - - // Parse and format the response based on model type - const responseBody = new TextDecoder().decode(response.body); - const formattedResponse = parseModelResponse(responseBody, modelId); - - return NextResponse.json(formattedResponse); - } catch (e) { - console.error("[Bedrock] Request error:", e); - throw e; // Let the main error handler deal with it - } finally { - clearTimeout(timeoutId); - } -} - -async function handleInvokeRequest(req: NextRequest) { - const controller = new AbortController(); - - const region = req.headers.get("X-Region") || "us-east-1"; - const accessKeyId = req.headers.get("X-Access-Key") || ""; - const secretAccessKey = req.headers.get("X-Secret-Key") || ""; - const sessionToken = req.headers.get("X-Session-Token"); - - if (!accessKeyId || !secretAccessKey) { - return NextResponse.json( - { - error: true, - message: "Missing AWS credentials", - }, - { - status: 401, - }, - ); - } - - const client = new BedrockRuntimeClient({ - region, - credentials: { - accessKeyId, - secretAccessKey, - sessionToken: sessionToken || undefined, - }, - }); - - const timeoutId = setTimeout( - () => { - controller.abort(); - }, - 10 * 60 * 1000, - ); - - try { - const body = await req.json(); - const { messages, model } = body; - - // Validate model ID - const validationError = validateModelId(model); - if (validationError) { - throw new ValidationException({ - message: validationError, - $metadata: {}, - }); + if (!response.stream) { + throw new Error("No stream in response"); } - console.log("[Bedrock] Invoking model:", model); - console.log("[Bedrock] Messages:", messages); - - const requestBody = formatRequestBody({ - modelId: model, - messages, - inferenceConfig: { - maxTokens: 2048, - temperature: 0.7, - topP: 0.9, + // Create a ReadableStream for the response + const stream = new ReadableStream({ + async start(controller) { + try { + const responseStream = response.stream; + if (!responseStream) { + throw new Error("No stream in response"); + } + + for await (const event of responseStream) { + const output = event as ConverseStreamOutput; + + if ("messageStart" in output && output.messageStart?.role) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "messageStart", + role: output.messageStart.role, + })}\n\n`, + ); + } else if ( + "contentBlockStart" in output && + output.contentBlockStart + ) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "contentBlockStart", + index: output.contentBlockStart.contentBlockIndex, + start: output.contentBlockStart.start, + })}\n\n`, + ); + } else if ( + "contentBlockDelta" in output && + output.contentBlockDelta?.delta + ) { + if ("text" in output.contentBlockDelta.delta) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "text", + content: output.contentBlockDelta.delta.text, + })}\n\n`, + ); + } else if ("toolUse" in output.contentBlockDelta.delta) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "toolUse", + input: output.contentBlockDelta.delta.toolUse?.input, + })}\n\n`, + ); + } + } else if ( + "contentBlockStop" in output && + output.contentBlockStop + ) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "contentBlockStop", + index: output.contentBlockStop.contentBlockIndex, + })}\n\n`, + ); + } else if ("messageStop" in output && output.messageStop) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "messageStop", + stopReason: output.messageStop.stopReason, + additionalModelResponseFields: + output.messageStop.additionalModelResponseFields, + })}\n\n`, + ); + } else if ("metadata" in output && output.metadata) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "metadata", + usage: output.metadata.usage, + metrics: output.metadata.metrics, + trace: output.metadata.trace, + })}\n\n`, + ); + } + } + controller.close(); + } catch (error) { + if (error instanceof ValidationException) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "error", + error: "ValidationException", + message: error.message, + })}\n\n`, + ); + } else if (error instanceof ModelStreamErrorException) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "error", + error: "ModelStreamErrorException", + message: error.message, + originalStatusCode: error.originalStatusCode, + originalMessage: error.originalMessage, + })}\n\n`, + ); + } else if (error instanceof ThrottlingException) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "error", + error: "ThrottlingException", + message: error.message, + })}\n\n`, + ); + } else if (error instanceof ServiceUnavailableException) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "error", + error: "ServiceUnavailableException", + message: error.message, + })}\n\n`, + ); + } else if (error instanceof InternalServerException) { + controller.enqueue( + `data: ${JSON.stringify({ + type: "error", + error: "InternalServerException", + message: error.message, + })}\n\n`, + ); + } else { + controller.enqueue( + `data: ${JSON.stringify({ + type: "error", + error: "UnknownError", + message: + error instanceof Error ? error.message : "Unknown error", + })}\n\n`, + ); + } + controller.close(); + } }, }); - const jsonString = JSON.stringify(requestBody); - const input = { - modelId: model, - contentType: "application/json", - accept: "application/json", - body: Uint8Array.from(Buffer.from(jsonString)), - }; - - console.log("[Bedrock] Request input:", { - ...input, - body: requestBody, - }); - - const command = new InvokeModelCommand(input); - const response = await client.send(command); - - console.log("[Bedrock] Got response"); - - // Parse and format the response - const responseBody = new TextDecoder().decode(response.body); - const formattedResponse = parseModelResponse(responseBody, model); - - // Extract text content from the response - let textContent = ""; - if (formattedResponse.content && Array.isArray(formattedResponse.content)) { - textContent = formattedResponse.content - .filter((item: ContentItem) => item.type === "text") - .map((item: ContentItem) => item.text || "") - .join(""); - } else if (typeof formattedResponse.content === "string") { - textContent = formattedResponse.content; - } - - // Return plain text response - return new NextResponse(textContent, { + return new Response(stream, { headers: { - "Content-Type": "text/plain", + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", }, }); - } catch (e) { - console.error("[Bedrock] Request error:", e); - throw e; - } finally { - clearTimeout(timeoutId); + } catch (error) { + console.error("[Bedrock] Request error:", error); + throw error; } } diff --git a/app/api/bedrock/models.ts b/app/api/bedrock/models.ts index b9a0fee5099..f6bb297d268 100644 --- a/app/api/bedrock/models.ts +++ b/app/api/bedrock/models.ts @@ -1,280 +1,405 @@ import { - Message, - validateMessageOrder, - processDocumentContent, - BedrockTextBlock, - BedrockImageBlock, - BedrockDocumentBlock, -} from "./utils"; + ConverseStreamCommand, + type ConverseStreamCommandInput, + type Message, + type ContentBlock, + type SystemContentBlock, + type Tool, + type ToolChoice, + type ToolResultContentBlock, +} from "@aws-sdk/client-bedrock-runtime"; export interface ConverseRequest { modelId: string; - messages: Message[]; + messages: { + role: "user" | "assistant" | "system"; + content: string | ContentItem[]; + }[]; inferenceConfig?: { maxTokens?: number; temperature?: number; topP?: number; + stopSequences?: string[]; + }; + toolConfig?: { + tools: Tool[]; + toolChoice?: ToolChoice; }; - system?: string; - tools?: Array<{ - type: "function"; - function: { - name: string; - description: string; - parameters: { - type: string; - properties: Record; - required: string[]; - }; - }; - }>; } interface ContentItem { - type: string; + type: "text" | "image_url" | "document" | "tool_use" | "tool_result"; text?: string; image_url?: { - url: string; + url: string; // base64 data URL }; document?: { - format: string; + format: + | "pdf" + | "csv" + | "doc" + | "docx" + | "xls" + | "xlsx" + | "html" + | "txt" + | "md"; name: string; source: { - bytes: string; + bytes: string; // base64 }; }; + tool_use?: { + tool_use_id: string; + name: string; + input: any; + }; + tool_result?: { + tool_use_id: string; + content: ToolResultItem[]; + status: "success" | "error"; + }; } -type ProcessedContent = - | ContentItem - | BedrockTextBlock - | BedrockImageBlock - | BedrockDocumentBlock - | { - type: string; - source: { type: string; media_type: string; data: string }; +interface ToolResultItem { + type: "text" | "image" | "document" | "json"; + text?: string; + image?: { + format: "png" | "jpeg" | "gif" | "webp"; + source: { + bytes: string; // base64 }; + }; + document?: { + format: + | "pdf" + | "csv" + | "doc" + | "docx" + | "xls" + | "xlsx" + | "html" + | "txt" + | "md"; + name: string; + source: { + bytes: string; // base64 + }; + }; + json?: any; +} -// Helper function to format request body based on model type -export function formatRequestBody(request: ConverseRequest) { - const baseModel = request.modelId; - const messages = validateMessageOrder(request.messages).map((msg) => ({ - role: msg.role, - content: Array.isArray(msg.content) - ? msg.content.map((item: ContentItem) => { - if (item.type === "image_url" && item.image_url?.url) { - // If it's a base64 image URL - const base64Match = item.image_url.url.match( - /^data:image\/([a-zA-Z]*);base64,([^"]*)$/, - ); - if (base64Match) { - return { - type: "image", - source: { - type: "base64", - media_type: `image/${base64Match[1]}`, - data: base64Match[2], - }, - }; - } - // If it's not a base64 URL, return as is - return item; - } - if ("document" in item) { - try { - return processDocumentContent(item); - } catch (error) { - console.error("Error processing document:", error); - return { - type: "text", - text: `[Document: ${item.document?.name || "Unknown"}]`, - }; - } - } - return { type: "text", text: item.text }; - }) - : [{ type: "text", text: msg.content }], - })); - - const systemPrompt = request.system - ? [{ type: "text", text: request.system }] - : undefined; +function convertContentToAWSBlock(item: ContentItem): ContentBlock | null { + if (item.type === "text" && item.text) { + return { text: item.text }; + } - const baseConfig = { - max_tokens: request.inferenceConfig?.maxTokens || 2048, - temperature: request.inferenceConfig?.temperature || 0.7, - top_p: request.inferenceConfig?.topP || 0.9, - }; + if (item.type === "image_url" && item.image_url?.url) { + const base64Match = item.image_url.url.match( + /^data:image\/([a-zA-Z]*);base64,([^"]*)/, + ); + if (base64Match) { + const format = base64Match[1].toLowerCase(); + if ( + format === "png" || + format === "jpeg" || + format === "gif" || + format === "webp" + ) { + const base64Data = base64Match[2]; + return { + image: { + format: format as "png" | "jpeg" | "gif" | "webp", + source: { + bytes: Uint8Array.from(Buffer.from(base64Data, "base64")), + }, + }, + }; + } + } + } - if (baseModel.startsWith("anthropic.claude")) { + if (item.type === "document" && item.document) { return { - messages, - system: systemPrompt, - anthropic_version: "bedrock-2023-05-31", - ...baseConfig, - ...(request.tools && { tools: request.tools }), + document: { + format: item.document.format, + name: item.document.name, + source: { + bytes: Uint8Array.from( + Buffer.from(item.document.source.bytes, "base64"), + ), + }, + }, }; - } else if ( - baseModel.startsWith("meta.llama") || - baseModel.startsWith("mistral.") - ) { + } + + if (item.type === "tool_use" && item.tool_use) { return { - messages: messages.map((m) => ({ - role: m.role, - content: Array.isArray(m.content) - ? m.content.map((c: ProcessedContent) => { - if ("text" in c) return { type: "text", text: c.text || "" }; - if ("image_url" in c) - return { - type: "text", - text: `[Image: ${c.image_url?.url || "URL not provided"}]`, - }; - if ("document" in c) - return { - type: "text", - text: `[Document: ${c.document?.name || "Unknown"}]`, - }; - return { type: "text", text: "" }; - }) - : [{ type: "text", text: m.content }], - })), - ...baseConfig, - stop_sequences: ["\n\nHuman:", "\n\nAssistant:"], + toolUse: { + toolUseId: item.tool_use.tool_use_id, + name: item.tool_use.name, + input: item.tool_use.input, + }, }; - } else if (baseModel.startsWith("amazon.titan")) { - const formattedText = messages.map((m) => ({ - role: m.role, - content: [ - { - type: "text", - text: `${m.role === "user" ? "Human" : "Assistant"}: ${ - Array.isArray(m.content) - ? m.content - .map((c: ProcessedContent) => { - if ("text" in c) return c.text || ""; - if ("image_url" in c) - return `[Image: ${ - c.image_url?.url || "URL not provided" - }]`; - if ("document" in c) - return `[Document: ${c.document?.name || "Unknown"}]`; - return ""; - }) - .join("") - : m.content - }`, - }, - ], - })); + } + + if (item.type === "tool_result" && item.tool_result) { + const toolResultContent = item.tool_result.content + .map((resultItem) => { + if (resultItem.type === "text" && resultItem.text) { + return { text: resultItem.text } as ToolResultContentBlock; + } + if (resultItem.type === "image" && resultItem.image) { + return { + image: { + format: resultItem.image.format, + source: { + bytes: Uint8Array.from( + Buffer.from(resultItem.image.source.bytes, "base64"), + ), + }, + }, + } as ToolResultContentBlock; + } + if (resultItem.type === "document" && resultItem.document) { + return { + document: { + format: resultItem.document.format, + name: resultItem.document.name, + source: { + bytes: Uint8Array.from( + Buffer.from(resultItem.document.source.bytes, "base64"), + ), + }, + }, + } as ToolResultContentBlock; + } + if (resultItem.type === "json" && resultItem.json) { + return { json: resultItem.json } as ToolResultContentBlock; + } + return null; + }) + .filter((content): content is ToolResultContentBlock => content !== null); + + if (toolResultContent.length === 0) { + return null; + } return { - messages: formattedText, - textGenerationConfig: { - maxTokenCount: baseConfig.max_tokens, - temperature: baseConfig.temperature, - topP: baseConfig.top_p, - stopSequences: ["Human:", "Assistant:"], + toolResult: { + toolUseId: item.tool_result.tool_use_id, + content: toolResultContent, + status: item.tool_result.status, }, }; } - throw new Error(`Unsupported model: ${baseModel}`); + return null; } -// Helper function to parse and format response based on model type -export function parseModelResponse(responseBody: string, modelId: string): any { - const baseModel = modelId; +function convertContentToAWS(content: string | ContentItem[]): ContentBlock[] { + if (typeof content === "string") { + return [{ text: content }]; + } - try { - const response = JSON.parse(responseBody); + // Filter out null blocks and ensure each content block is valid + const blocks = content + .map(convertContentToAWSBlock) + .filter((block): block is ContentBlock => block !== null); - // Common response format for all models - const formatResponse = (content: string | any[]) => ({ - role: "assistant", - content: Array.isArray(content) - ? content.map((item) => { - if (typeof item === "string") { - return { type: "text", text: item }; - } - // Handle different content types - if ("text" in item) { - return { type: "text", text: item.text || "" }; - } - if ("image" in item) { - return { - type: "image_url", - image_url: { - url: `data:image/${ - item.source?.media_type || "image/png" - };base64,${item.source?.data || ""}`, - }, - }; - } - // Document responses are converted to text - if ("document" in item) { - return { - type: "text", - text: `[Document Content]\n${item.text || ""}`, - }; - } - return { type: "text", text: item.text || "" }; - }) - : [{ type: "text", text: content }], - stop_reason: response.stop_reason || response.stopReason || "end_turn", - usage: response.usage || { - input_tokens: 0, - output_tokens: 0, - total_tokens: 0, - }, - }); + // If no valid blocks, provide a default text block + if (blocks.length === 0) { + return [{ text: "" }]; + } - if (baseModel.startsWith("anthropic.claude")) { - // Handle the new Converse API response format - if (response.output?.message) { - return { - role: response.output.message.role, - content: response.output.message.content.map((item: any) => { - if ("text" in item) return { type: "text", text: item.text || "" }; - if ("image" in item) { - return { - type: "image_url", - image_url: { - url: `data:${item.source?.media_type || "image/png"};base64,${ - item.source?.data || "" - }`, - }, - }; - } - return { type: "text", text: item.text || "" }; - }), - stop_reason: response.stopReason, - usage: response.usage, - }; + return blocks; +} + +function formatMessages(messages: ConverseRequest["messages"]): { + messages: Message[]; + systemPrompt?: SystemContentBlock[]; +} { + // Extract system messages + const systemMessages = messages.filter((msg) => msg.role === "system"); + const nonSystemMessages = messages.filter((msg) => msg.role !== "system"); + + // Convert system messages to SystemContentBlock array + const systemPrompt = + systemMessages.length > 0 + ? systemMessages.map((msg) => { + if (typeof msg.content === "string") { + return { text: msg.content } as SystemContentBlock; + } + // For multimodal content, convert each content item + const blocks = convertContentToAWS(msg.content); + return blocks[0] as SystemContentBlock; // Take first block as system content + }) + : undefined; + + // Format remaining messages + const formattedMessages = nonSystemMessages.reduce( + (acc: Message[], curr, idx) => { + // Skip if same role as previous message + if (idx > 0 && curr.role === nonSystemMessages[idx - 1].role) { + return acc; } - // Fallback for older format - return formatResponse( - response.content || - (response.completion - ? [{ type: "text", text: response.completion }] - : []), - ); - } else if (baseModel.startsWith("meta.llama")) { - return formatResponse(response.generation || response.completion || ""); - } else if (baseModel.startsWith("amazon.titan")) { - return formatResponse(response.results?.[0]?.outputText || ""); - } else if (baseModel.startsWith("mistral.")) { - return formatResponse( - response.outputs?.[0]?.text || response.response || "", - ); - } - throw new Error(`Unsupported model: ${baseModel}`); - } catch (e) { - console.error("[Bedrock] Failed to parse response:", e); - // Return raw text as fallback + const content = convertContentToAWS(curr.content); + if (content.length > 0) { + acc.push({ + role: curr.role as "user" | "assistant", + content, + }); + } + return acc; + }, + [], + ); + + // Ensure conversation starts with user + if (formattedMessages.length === 0 || formattedMessages[0].role !== "user") { + formattedMessages.unshift({ + role: "user", + content: [{ text: "Hello" }], + }); + } + + // Ensure conversation ends with user + if (formattedMessages[formattedMessages.length - 1].role !== "user") { + formattedMessages.push({ + role: "user", + content: [{ text: "Continue" }], + }); + } + + return { messages: formattedMessages, systemPrompt }; +} + +export function formatRequestBody( + request: ConverseRequest, +): ConverseStreamCommandInput { + const { messages, systemPrompt } = formatMessages(request.messages); + const input: ConverseStreamCommandInput = { + modelId: request.modelId, + messages, + ...(systemPrompt && { system: systemPrompt }), + }; + + if (request.inferenceConfig) { + input.inferenceConfig = { + maxTokens: request.inferenceConfig.maxTokens, + temperature: request.inferenceConfig.temperature, + topP: request.inferenceConfig.topP, + stopSequences: request.inferenceConfig.stopSequences, + }; + } + + if (request.toolConfig) { + input.toolConfig = { + tools: request.toolConfig.tools, + toolChoice: request.toolConfig.toolChoice, + }; + } + + // Create a clean version of the input for logging + const logInput = { + ...input, + messages: messages.map((msg) => ({ + role: msg.role, + content: msg.content?.map((content) => { + if ("image" in content && content.image) { + return { + image: { + format: content.image.format, + source: { bytes: "[BINARY]" }, + }, + }; + } + if ("document" in content && content.document) { + return { + document: { ...content.document, source: { bytes: "[BINARY]" } }, + }; + } + return content; + }), + })), + }; + + console.log( + "[Bedrock] Formatted request:", + JSON.stringify(logInput, null, 2), + ); + return input; +} + +export function createConverseStreamCommand(request: ConverseRequest) { + const input = formatRequestBody(request); + return new ConverseStreamCommand(input); +} + +export interface StreamResponse { + type: + | "messageStart" + | "contentBlockStart" + | "contentBlockDelta" + | "contentBlockStop" + | "messageStop" + | "metadata" + | "error"; + role?: string; + index?: number; + start?: any; + delta?: any; + stopReason?: string; + additionalModelResponseFields?: any; + usage?: any; + metrics?: any; + trace?: any; + error?: string; + message?: string; + originalStatusCode?: number; + originalMessage?: string; +} + +export function parseStreamResponse(chunk: any): StreamResponse | null { + if (chunk.messageStart) { + return { type: "messageStart", role: chunk.messageStart.role }; + } + if (chunk.contentBlockStart) { + return { + type: "contentBlockStart", + index: chunk.contentBlockStart.contentBlockIndex, + start: chunk.contentBlockStart.start, + }; + } + if (chunk.contentBlockDelta) { + return { + type: "contentBlockDelta", + index: chunk.contentBlockDelta.contentBlockIndex, + delta: chunk.contentBlockDelta.delta, + }; + } + if (chunk.contentBlockStop) { + return { + type: "contentBlockStop", + index: chunk.contentBlockStop.contentBlockIndex, + }; + } + if (chunk.messageStop) { + return { + type: "messageStop", + stopReason: chunk.messageStop.stopReason, + additionalModelResponseFields: + chunk.messageStop.additionalModelResponseFields, + }; + } + if (chunk.metadata) { return { - role: "assistant", - content: [{ type: "text", text: responseBody }], + type: "metadata", + usage: chunk.metadata.usage, + metrics: chunk.metadata.metrics, + trace: chunk.metadata.trace, }; } + return null; } diff --git a/app/api/bedrock/utils.ts b/app/api/bedrock/utils.ts index 85cd517b439..c58808a01cd 100644 --- a/app/api/bedrock/utils.ts +++ b/app/api/bedrock/utils.ts @@ -11,43 +11,149 @@ export interface ImageSource { export interface DocumentSource { bytes: string; // base64 encoded document bytes + media_type?: string; // MIME type of the document } +export type DocumentFormat = + | "pdf" + | "csv" + | "doc" + | "docx" + | "xls" + | "xlsx" + | "html" + | "txt" + | "md"; +export type ImageFormat = "png" | "jpeg" | "gif" | "webp"; + export interface BedrockImageBlock { + type: "image"; image: { - format: "png" | "jpeg" | "gif" | "webp"; - source: ImageSource; + format: ImageFormat; + source: { + bytes: string; + }; }; } export interface BedrockDocumentBlock { + type: "document"; document: { - format: - | "pdf" - | "csv" - | "doc" - | "docx" - | "xls" - | "xlsx" - | "html" - | "txt" - | "md"; + format: string; name: string; - source: DocumentSource; + source: { + bytes: string; + media_type?: string; + }; }; } export interface BedrockTextBlock { + type: "text"; text: string; } -export type BedrockContentBlock = +export interface BedrockToolCallBlock { + type: "tool_calls"; + tool_calls: BedrockToolCall[]; +} + +export interface BedrockToolResultBlock { + type: "tool_result"; + tool_result: BedrockToolResult; +} + +export type BedrockContent = | BedrockTextBlock | BedrockImageBlock - | BedrockDocumentBlock; + | BedrockDocumentBlock + | BedrockToolCallBlock + | BedrockToolResultBlock; + +export interface BedrockToolSpec { + type: string; + function: { + name: string; + description: string; + parameters: Record; + }; +} + +export interface BedrockToolCall { + type: string; + function: { + name: string; + arguments: string; + }; +} + +export interface BedrockToolResult { + type: string; + output: string; +} + +export interface ContentItem { + type: string; + text?: string; + image_url?: { + url: string; + }; + document?: { + format: string; + name: string; + source: { + bytes: string; + media_type?: string; + }; + }; + tool_calls?: BedrockToolCall[]; + tool_result?: BedrockToolResult; +} + +export interface StreamEvent { + messageStart?: { role: string }; + contentBlockStart?: { index: number }; + contentBlockDelta?: { + delta: { + type?: string; + text?: string; + tool_calls?: BedrockToolCall[]; + tool_result?: BedrockToolResult; + }; + contentBlockIndex: number; + }; + contentBlockStop?: { index: number }; + messageStop?: { stopReason: string }; + metadata?: { + usage: { + inputTokens: number; + outputTokens: number; + totalTokens: number; + }; + metrics: { + latencyMs: number; + }; + }; +} + +export interface ConverseRequest { + modelId: string; + messages: Message[]; + inferenceConfig?: { + maxTokens?: number; + temperature?: number; + topP?: number; + stopSequences?: string[]; + stream?: boolean; + }; + system?: { text: string }[]; + tools?: BedrockToolSpec[]; + additionalModelRequestFields?: Record; + additionalModelResponseFieldPaths?: string[]; +} export interface BedrockResponse { - content?: any[]; + content: BedrockContent[]; completion?: string; stop_reason?: string; usage?: { @@ -55,7 +161,7 @@ export interface BedrockResponse { output_tokens: number; total_tokens: number; }; - tool_calls?: any[]; + tool_calls?: BedrockToolCall[]; } // Helper function to get the base model type from modelId @@ -79,8 +185,59 @@ export function validateModelId(modelId: string): string | null { return null; } +// Helper function to validate document name +export function validateDocumentName(name: string): boolean { + const validPattern = /^[a-zA-Z0-9\s\-\(\)\[\]]+$/; + const noMultipleSpaces = !/\s{2,}/.test(name); + return validPattern.test(name) && noMultipleSpaces; +} + +// Helper function to validate document format +export function validateDocumentFormat( + format: string, +): format is DocumentFormat { + const validFormats: DocumentFormat[] = [ + "pdf", + "csv", + "doc", + "docx", + "xls", + "xlsx", + "html", + "txt", + "md", + ]; + return validFormats.includes(format as DocumentFormat); +} + +// Helper function to validate image size and dimensions +export function validateImageSize(base64Data: string): boolean { + // Check size (3.75 MB limit) + const sizeInBytes = (base64Data.length * 3) / 4; // Approximate size of decoded base64 + const maxSize = 3.75 * 1024 * 1024; // 3.75 MB in bytes + + if (sizeInBytes > maxSize) { + throw new Error("Image size exceeds 3.75 MB limit"); + } + + return true; +} + +// Helper function to validate document size +export function validateDocumentSize(base64Data: string): boolean { + // Check size (4.5 MB limit) + const sizeInBytes = (base64Data.length * 3) / 4; // Approximate size of decoded base64 + const maxSize = 4.5 * 1024 * 1024; // 4.5 MB in bytes + + if (sizeInBytes > maxSize) { + throw new Error("Document size exceeds 4.5 MB limit"); + } + + return true; +} + // Helper function to process document content for Bedrock -export function processDocumentContent(content: any): BedrockContentBlock { +export function processDocumentContent(content: any): BedrockDocumentBlock { if ( !content?.document?.format || !content?.document?.name || @@ -90,70 +247,90 @@ export function processDocumentContent(content: any): BedrockContentBlock { } const format = content.document.format.toLowerCase(); - if ( - !["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"].includes( - format, - ) - ) { + if (!validateDocumentFormat(format)) { throw new Error(`Unsupported document format: ${format}`); } + if (!validateDocumentName(content.document.name)) { + throw new Error( + `Invalid document name: ${content.document.name}. Only alphanumeric characters, single spaces, hyphens, parentheses, and square brackets are allowed.`, + ); + } + + // Validate document size + if (!validateDocumentSize(content.document.source.bytes)) { + throw new Error("Document size validation failed"); + } + return { + type: "document", document: { - format: format as BedrockDocumentBlock["document"]["format"], - name: sanitizeDocumentName(content.document.name), + format: format, + name: content.document.name, source: { bytes: content.document.source.bytes, + media_type: content.document.source.media_type, }, }, }; } -// Helper function to format content for Bedrock -export function formatContent( - content: string | MultimodalContent[], -): BedrockContentBlock[] { - if (typeof content === "string") { - return [{ text: content }]; - } +// Helper function to process image content for Bedrock +export function processImageContent(content: any): BedrockImageBlock { + if (content.type === "image_url" && content.image_url?.url) { + const base64Match = content.image_url.url.match( + /^data:image\/([a-zA-Z]*);base64,([^"]*)$/, + ); + if (base64Match) { + const format = base64Match[1].toLowerCase(); + if (["png", "jpeg", "gif", "webp"].includes(format)) { + // Validate image size + if (!validateImageSize(base64Match[2])) { + throw new Error("Image size validation failed"); + } - const formattedContent: BedrockContentBlock[] = []; - - for (const item of content) { - if (item.type === "text" && item.text) { - formattedContent.push({ text: item.text }); - } else if (item.type === "image_url" && item.image_url?.url) { - // Extract base64 data from data URL - const base64Match = item.image_url.url.match( - /^data:image\/([a-zA-Z]*);base64,([^"]*)$/, - ); - if (base64Match) { - const format = base64Match[1].toLowerCase(); - if (["png", "jpeg", "gif", "webp"].includes(format)) { - formattedContent.push({ - image: { - format: format as "png" | "jpeg" | "gif" | "webp", - source: { - bytes: base64Match[2], - }, + return { + type: "image", + image: { + format: format as ImageFormat, + source: { + bytes: base64Match[2], }, - }); - } - } - } else if ("document" in item) { - try { - formattedContent.push(processDocumentContent(item)); - } catch (error) { - console.error("Error processing document:", error); - // Convert document to text as fallback - formattedContent.push({ - text: `[Document: ${(item as any).document?.name || "Unknown"}]`, - }); + }, + }; } } } + throw new Error("Invalid image content format"); +} - return formattedContent; +// Helper function to validate message content restrictions +export function validateMessageContent(message: Message): void { + if (Array.isArray(message.content)) { + // Count images and documents in user messages + if (message.role === "user") { + const imageCount = message.content.filter( + (item) => item.type === "image_url", + ).length; + const documentCount = message.content.filter( + (item) => item.type === "document", + ).length; + + if (imageCount > 20) { + throw new Error("User messages can include up to 20 images"); + } + + if (documentCount > 5) { + throw new Error("User messages can include up to 5 documents"); + } + } else if ( + message.role === "assistant" && + (message.content.some((item) => item.type === "image_url") || + message.content.some((item) => item.type === "document")) + ) { + throw new Error("Assistant messages cannot include images or documents"); + } + } } // Helper function to ensure messages alternate between user and assistant @@ -162,6 +339,9 @@ export function validateMessageOrder(messages: Message[]): Message[] { let lastRole = ""; for (const message of messages) { + // Validate content restrictions for each message + validateMessageContent(message); + if (message.role === lastRole) { // Skip duplicate roles to maintain alternation continue; @@ -173,16 +353,6 @@ export function validateMessageOrder(messages: Message[]): Message[] { return validatedMessages; } -// Helper function to sanitize document names according to Bedrock requirements -function sanitizeDocumentName(name: string): string { - // Remove any characters that aren't alphanumeric, whitespace, hyphens, or parentheses - let sanitized = name.replace(/[^a-zA-Z0-9\s\-\(\)\[\]]/g, ""); - // Replace multiple whitespace characters with a single space - sanitized = sanitized.replace(/\s+/g, " "); - // Trim whitespace from start and end - return sanitized.trim(); -} - // Helper function to convert Bedrock response back to MultimodalContent format export function convertBedrockResponseToMultimodal( response: BedrockResponse, @@ -196,23 +366,35 @@ export function convertBedrockResponseToMultimodal( } return response.content.map((block) => { - if ("text" in block) { + if (block.type === "text") { return { type: "text", text: block.text, }; - } else if ("image" in block) { + } else if (block.type === "image") { return { type: "image_url", image_url: { url: `data:image/${block.image.format};base64,${block.image.source.bytes}`, }, }; + } else if (block.type === "document") { + return { + type: "document", + document: { + format: block.document.format, + name: block.document.name, + source: { + bytes: block.document.source.bytes, + media_type: block.document.source.media_type, + }, + }, + }; } - // Document responses are converted to text content + // Fallback to text content return { type: "text", - text: block.text || "", + text: "", }; }); } diff --git a/app/client/api.ts b/app/client/api.ts index e547bea0a94..05ce8a236dc 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -31,11 +31,19 @@ export const TTSModels = ["tts-1", "tts-1-hd"] as const; export type ChatModel = ModelType; export interface MultimodalContent { - type: "text" | "image_url"; + type: "text" | "image_url" | "document"; text?: string; image_url?: { url: string; }; + document?: { + format: string; + name: string; + source: { + bytes: string; + media_type?: string; + }; + }; } export interface RequestMessage { diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index f8954f9d774..e2197565064 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -1,62 +1,167 @@ import { ApiPath } from "../../constant"; -import { ChatOptions, getHeaders, LLMApi, SpeechOptions } from "../api"; import { - useAccessStore, - useAppConfig, - useChatStore, - usePluginStore, -} from "../../store"; -import { preProcessImageContent, stream } from "../../utils/chat"; + ChatOptions, + getHeaders, + LLMApi, + LLMUsage, + MultimodalContent, + SpeechOptions, +} from "../api"; +import { useAccessStore, useAppConfig } from "../../store"; import Locale from "../../locales"; +import { + getMessageImages, + getMessageTextContent, + isVisionModel, +} from "../../utils"; +import { fetch } from "../../utils/stream"; -export interface BedrockChatRequest { - model: string; - messages: Array<{ - role: string; - content: - | string - | Array<{ - type: string; - text?: string; - image_url?: { url: string }; - document?: { - format: string; - name: string; - source: { - bytes: string; - }; - }; - }>; - }>; - temperature?: number; - top_p?: number; - max_tokens?: number; - stream?: boolean; -} +const MAX_IMAGE_SIZE = 1024 * 1024 * 4; // 4MB limit export class BedrockApi implements LLMApi { speech(options: SpeechOptions): Promise { - throw new Error("Method not implemented."); + throw new Error("Speech not implemented for Bedrock."); } extractMessage(res: any) { console.log("[Response] bedrock response: ", res); + if (Array.isArray(res?.content)) { + return res.content; + } return res; } - async chat(options: ChatOptions): Promise { - const shouldStream = !!options.config.stream; + async processDocument( + file: File, + ): Promise<{ display: string; content: MultimodalContent }> { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = async () => { + try { + const arrayBuffer = reader.result as ArrayBuffer; + const format = file.name.split(".").pop()?.toLowerCase(); + + if (!format) { + throw new Error("Could not determine file format"); + } + + // Format file size + const size = file.size; + let sizeStr = ""; + if (size < 1024) { + sizeStr = size + " B"; + } else if (size < 1024 * 1024) { + sizeStr = (size / 1024).toFixed(2) + " KB"; + } else { + sizeStr = (size / (1024 * 1024)).toFixed(2) + " MB"; + } + // Create display text + const displayText = `Document: ${file.name} (${sizeStr})`; + + // Create actual content + const content: MultimodalContent = { + type: "document", + document: { + format: format as + | "pdf" + | "csv" + | "doc" + | "docx" + | "xls" + | "xlsx" + | "html" + | "txt" + | "md", + name: file.name, + source: { + bytes: Buffer.from(arrayBuffer).toString("base64"), + }, + }, + }; + + resolve({ + display: displayText, + content: content, + }); + } catch (e) { + reject(e); + } + }; + reader.onerror = () => reject(reader.error); + reader.readAsArrayBuffer(file); + }); + } + + async processImage(url: string): Promise { + if (url.startsWith("data:")) { + const base64Match = url.match(/^data:image\/([a-zA-Z]*);base64,([^"]*)/); + if (base64Match) { + const format = base64Match[1].toLowerCase(); + const base64Data = base64Match[2]; + + // Check base64 size + const binarySize = atob(base64Data).length; + if (binarySize > MAX_IMAGE_SIZE) { + throw new Error( + `Image size (${(binarySize / (1024 * 1024)).toFixed( + 2, + )}MB) exceeds maximum allowed size of 4MB`, + ); + } + + return { + type: "image_url", + image_url: { + url: url, + }, + }; + } + throw new Error("Invalid data URL format"); + } + + // For non-data URLs, fetch and convert to base64 + try { + const response = await fetch(url); + if (!response.ok) { + throw new Error(`Failed to fetch image: ${response.statusText}`); + } + + const blob = await response.blob(); + if (blob.size > MAX_IMAGE_SIZE) { + throw new Error( + `Image size (${(blob.size / (1024 * 1024)).toFixed( + 2, + )}MB) exceeds maximum allowed size of 4MB`, + ); + } + + const reader = new FileReader(); + const base64 = await new Promise((resolve, reject) => { + reader.onloadend = () => resolve(reader.result as string); + reader.onerror = () => reject(new Error("Failed to read image data")); + reader.readAsDataURL(blob); + }); + + return { + type: "image_url", + image_url: { + url: base64, + }, + }; + } catch (error) { + console.error("[Bedrock] Image processing error:", error); + throw error; + } + } + + async chat(options: ChatOptions): Promise { + const accessStore = useAccessStore.getState(); const modelConfig = { ...useAppConfig.getState().modelConfig, - ...useChatStore.getState().currentSession().mask.modelConfig, - ...{ - model: options.config.model, - }, + ...options.config, }; - const accessStore = useAccessStore.getState(); - if ( !accessStore.awsRegion || !accessStore.awsAccessKeyId || @@ -71,70 +176,6 @@ export class BedrockApi implements LLMApi { return; } - // Process messages to handle image and document content - const messages = await Promise.all( - options.messages.map(async (v) => { - const content = await preProcessImageContent(v.content); - // If content is an array (multimodal), ensure each item is properly formatted - if (Array.isArray(content)) { - return { - role: v.role, - content: content.map((item) => { - if (item.type === "image_url" && item.image_url?.url) { - // If the URL is a base64 data URL, use it directly - if (item.image_url.url.startsWith("data:image/")) { - return item; - } - // Otherwise, it's a regular URL that needs to be converted to base64 - // The conversion should have been handled by preProcessImageContent - return item; - } - if ("document" in item) { - // Handle document content - const doc = item as any; - if ( - doc?.document?.format && - doc?.document?.name && - doc?.document?.source?.bytes - ) { - return { - type: "document", - document: { - format: doc.document.format, - name: doc.document.name, - source: { - bytes: doc.document.source.bytes, - }, - }, - }; - } - } - return item; - }), - }; - } - // If content is a string, return it as is - return { - role: v.role, - content, - }; - }), - ); - - const requestBody: BedrockChatRequest = { - messages, - stream: shouldStream, - model: modelConfig.model, - max_tokens: modelConfig.max_tokens, - temperature: modelConfig.temperature, - top_p: modelConfig.top_p, - }; - - console.log("[Bedrock] Request:", { - model: modelConfig.model, - messages: messages, - }); - const controller = new AbortController(); options.onController?.(controller); @@ -150,45 +191,196 @@ export class BedrockApi implements LLMApi { } try { - if (shouldStream) { - let responseText = ""; - const pluginStore = usePluginStore.getState(); - const currentSession = useChatStore.getState().currentSession(); - const [tools, funcs] = pluginStore.getAsTools( - currentSession.mask?.plugin || [], - ); + // Process messages to handle multimodal content + const messages = await Promise.all( + options.messages.map(async (msg) => { + if (Array.isArray(msg.content)) { + // For vision models, include both text and images + if (isVisionModel(options.config.model)) { + const images = getMessageImages(msg); + const content: MultimodalContent[] = []; - await stream( - `${ApiPath.Bedrock}/invoke`, - requestBody, - headers, - Array.isArray(tools) ? tools : [], - funcs || {}, - controller, - (chunk: string) => { - try { - responseText += chunk; - return chunk; - } catch (e) { - console.error("[Request] parse error", chunk, e); - return ""; + // Process documents first + for (const item of msg.content) { + // Check for document content + if (item && typeof item === "object") { + if ("file" in item && item.file instanceof File) { + try { + console.log( + "[Bedrock] Processing document:", + item.file.name, + ); + const { content: docContent } = + await this.processDocument(item.file); + content.push(docContent); + } catch (e) { + console.error("[Bedrock] Failed to process document:", e); + } + } else if ("document" in item && item.document) { + // If document content is already processed, include it directly + content.push(item as MultimodalContent); + } + } + } + + // Add text content if it's not a document display text + const text = getMessageTextContent(msg); + if (text && !text.startsWith("Document: ")) { + content.push({ type: "text", text }); + } + + // Process images with size check and error handling + for (const url of images) { + try { + const imageContent = await this.processImage(url); + content.push(imageContent); + } catch (e) { + console.error("[Bedrock] Failed to process image:", e); + // Add error message as text content + content.push({ + type: "text", + text: `Error processing image: ${e.message}`, + }); + } + } + + // Only return content if there is any + if (content.length > 0) { + return { ...msg, content }; + } } + // For non-vision models, only include text + return { ...msg, content: getMessageTextContent(msg) }; + } + return msg; + }), + ); + + // Filter out empty messages + const filteredMessages = messages.filter((msg) => { + if (Array.isArray(msg.content)) { + return msg.content.length > 0; + } + return msg.content !== ""; + }); + + const requestBody = { + messages: filteredMessages, + modelId: options.config.model, + inferenceConfig: { + maxTokens: modelConfig.max_tokens, + temperature: modelConfig.temperature, + topP: modelConfig.top_p, + stopSequences: [], + }, + }; + + console.log( + "[Bedrock] Request body:", + JSON.stringify( + { + ...requestBody, + messages: requestBody.messages.map((msg) => ({ + ...msg, + content: Array.isArray(msg.content) + ? msg.content.map((c) => ({ + type: c.type, + ...(c.document + ? { + document: { + format: c.document.format, + name: c.document.name, + }, + } + : {}), + ...(c.image_url ? { image_url: { url: "[BINARY]" } } : {}), + ...(c.text ? { text: c.text } : {}), + })) + : msg.content, + })), }, - ( - requestPayload: any, - toolCallMessage: any, - toolCallResult: any[], - ) => { - console.log("[Bedrock] processToolMessage", { - requestPayload, - toolCallMessage, - toolCallResult, - }); + null, + 2, + ), + ); + + const shouldStream = !!options.config.stream; + const conversePath = `${ApiPath.Bedrock}/converse`; + + if (shouldStream) { + let response = await fetch(conversePath, { + method: "POST", + headers: { + ...headers, + "X-Stream": "true", }, - options, - ); + body: JSON.stringify(requestBody), + signal: controller.signal, + }); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`Bedrock API error: ${error}`); + } + + let buffer = ""; + const reader = response.body?.getReader(); + if (!reader) { + throw new Error("No response body reader available"); + } + + let currentContent = ""; + let isFirstMessage = true; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + // Convert the chunk to text and add to buffer + const chunk = new TextDecoder().decode(value); + buffer += chunk; + + // Process complete messages from buffer + let newlineIndex; + while ((newlineIndex = buffer.indexOf("\n")) !== -1) { + const line = buffer.slice(0, newlineIndex).trim(); + buffer = buffer.slice(newlineIndex + 1); + + if (line.startsWith("data: ")) { + try { + const event = JSON.parse(line.slice(6)); + + if (event.type === "messageStart") { + if (isFirstMessage) { + isFirstMessage = false; + } + continue; + } + + if (event.type === "text" && event.content) { + currentContent += event.content; + options.onUpdate?.(currentContent, event.content); + } + + if (event.type === "messageStop") { + options.onFinish(currentContent); + return; + } + + if (event.type === "error") { + throw new Error(event.message || "Unknown error"); + } + } catch (e) { + console.error("[Bedrock] Failed to parse stream event:", e); + } + } + } + } + + // If we reach here without a messageStop event, finish with current content + options.onFinish(currentContent); } else { - const response = await fetch(`${ApiPath.Bedrock}/invoke`, { + const response = await fetch(conversePath, { method: "POST", headers, body: JSON.stringify(requestBody), @@ -197,12 +389,12 @@ export class BedrockApi implements LLMApi { if (!response.ok) { const error = await response.text(); - console.error("[Bedrock] Error response:", error); throw new Error(`Bedrock API error: ${error}`); } - const text = await response.text(); - options.onFinish(text); + const responseBody = await response.json(); + const content = this.extractMessage(responseBody); + options.onFinish(content); } } catch (e) { console.error("[Bedrock] Chat error:", e); @@ -210,7 +402,8 @@ export class BedrockApi implements LLMApi { } } - async usage() { + async usage(): Promise { + // Bedrock usage is tracked through AWS billing return { used: 0, total: 0, @@ -218,6 +411,7 @@ export class BedrockApi implements LLMApi { } async models() { + // Return empty array as models are configured through AWS console return []; } } diff --git a/app/components/chat-actions.tsx b/app/components/chat-actions.tsx new file mode 100644 index 00000000000..25cdfe16d8b --- /dev/null +++ b/app/components/chat-actions.tsx @@ -0,0 +1,188 @@ +import { ChatActions as Actions } from "./chat"; +import DocumentIcon from "../icons/document.svg"; +import LoadingButtonIcon from "../icons/loading.svg"; +import { ServiceProvider } from "../constant"; +import { useChatStore } from "../store"; +import { showToast } from "./ui-lib"; +import { MultimodalContent, MessageRole } from "../client/api"; +import { ChatMessage } from "../store/chat"; + +export function ChatActions(props: Parameters[0]) { + const chatStore = useChatStore(); + const currentProviderName = + chatStore.currentSession().mask.modelConfig?.providerName; + const isBedrockProvider = currentProviderName === ServiceProvider.Bedrock; + + async function uploadDocument() { + const fileInput = document.createElement("input"); + fileInput.type = "file"; + fileInput.accept = ".pdf,.csv,.doc,.docx,.xls,.xlsx,.html,.txt,.md"; + fileInput.onchange = async (event: any) => { + const file = event.target.files[0]; + if (!file) return; + + props.setUploading(true); + try { + // Get file extension and MIME type + const format = file.name.split(".").pop()?.toLowerCase() || ""; + const supportedFormats = [ + "pdf", + "csv", + "doc", + "docx", + "xls", + "xlsx", + "html", + "txt", + "md", + ]; + + if (!supportedFormats.includes(format)) { + throw new Error("Unsupported file format"); + } + + // Map file extensions to MIME types + const mimeTypes: { [key: string]: string } = { + pdf: "application/pdf", + csv: "text/csv", + doc: "application/msword", + docx: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + xls: "application/vnd.ms-excel", + xlsx: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + html: "text/html", + txt: "text/plain", + md: "text/markdown", + }; + + // Convert file to base64 + const base64 = await new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = (e) => { + if (!e.target?.result) return reject("Failed to read file"); + // Get just the base64 data without the data URL prefix + const base64 = (e.target.result as string).split(",")[1]; + resolve(base64); + }; + reader.onerror = reject; + reader.readAsDataURL(file); + }); + + // Format file size + const size = file.size; + let sizeStr = ""; + if (size < 1024) { + sizeStr = size + " B"; + } else if (size < 1024 * 1024) { + sizeStr = (size / 1024).toFixed(2) + " KB"; + } else { + sizeStr = (size / (1024 * 1024)).toFixed(2) + " MB"; + } + + // Create document content + const content: MultimodalContent[] = [ + { + type: "text", + text: `Document: ${file.name} (${sizeStr})`, + }, + { + type: "document", + document: { + format, + name: file.name, + source: { + bytes: base64, + media_type: mimeTypes[format] || `application/${format}`, + }, + }, + }, + ]; + + // Send content to Bedrock + const session = chatStore.currentSession(); + const modelConfig = session.mask.modelConfig; + const api = await import("../client/api").then((m) => + m.getClientApi(modelConfig.providerName), + ); + + // Create user message + const userMessage: ChatMessage = { + id: Date.now().toString(), + role: "user" as MessageRole, + content, + date: new Date().toLocaleString(), + isError: false, + }; + + // Create bot message + const botMessage: ChatMessage = { + id: (Date.now() + 1).toString(), + role: "assistant" as MessageRole, + content: "", + date: new Date().toLocaleString(), + streaming: true, + isError: false, + }; + + // Add messages to session + chatStore.updateCurrentSession((session) => { + session.messages.push(userMessage, botMessage); + }); + + // Make request + api.llm.chat({ + messages: [userMessage], + config: { ...modelConfig, stream: true }, + onUpdate(message) { + botMessage.streaming = true; + if (message) { + botMessage.content = message; + } + chatStore.updateCurrentSession((session) => { + session.messages = session.messages.concat(); + }); + }, + onFinish(message) { + botMessage.streaming = false; + if (message) { + botMessage.content = message; + chatStore.onNewMessage(botMessage); + } + }, + onError(error) { + botMessage.content = error.message; + botMessage.streaming = false; + userMessage.isError = true; + botMessage.isError = true; + chatStore.updateCurrentSession((session) => { + session.messages = session.messages.concat(); + }); + console.error("[Chat] failed ", error); + }, + }); + } catch (error) { + console.error("Failed to upload document:", error); + showToast("Failed to upload document"); + } finally { + props.setUploading(false); + } + }; + fileInput.click(); + } + + return ( +
+ {/* Original actions */} + + + {/* Document upload button (only for Bedrock) */} + {isBedrockProvider && ( +
+
+ {props.uploading ? : } +
+
Upload Document
+
+ )} +
+ ); +} diff --git a/app/components/chat.module.scss b/app/components/chat.module.scss index 73542fc67f1..7f168942b3b 100644 --- a/app/components/chat.module.scss +++ b/app/components/chat.module.scss @@ -75,6 +75,17 @@ pointer-events: none; } + .icon { + display: flex; + align-items: center; + justify-content: center; + + svg { + width: 16px; + height: 16px; + } + } + &:hover { --delay: 0.5s; width: var(--full-width); @@ -393,8 +404,8 @@ button { padding: 7px; - } } +} /* Specific styles for iOS devices */ @media screen and (max-device-width: 812px) and (-webkit-min-device-pixel-ratio: 2) { diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 3d5b6a4f2c4..c4d11e31876 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -8,7 +8,7 @@ import React, { Fragment, RefObject, } from "react"; - +import DocumentIcon from "../icons/document.svg"; import SendWhiteIcon from "../icons/send-white.svg"; import BrainIcon from "../icons/brain.svg"; import RenameIcon from "../icons/rename.svg"; @@ -548,6 +548,91 @@ export function ChatActions(props: { ); } }, [chatStore, currentModel, models]); + const isBedrockProvider = currentProviderName === ServiceProvider.Bedrock; + + // ... (rest of the existing state and functions) + + async function uploadDocument() { + const fileInput = document.createElement("input"); + fileInput.type = "file"; + fileInput.accept = ".pdf,.csv,.doc,.docx,.xls,.xlsx,.html,.txt,.md"; + fileInput.onchange = async (event: any) => { + const file = event.target.files[0]; + if (!file) return; + + props.setUploading(true); + try { + // Convert file to base64 + const base64 = await new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = (e) => { + if (!e.target?.result) return reject("Failed to read file"); + const base64 = (e.target.result as string).split(",")[1]; + resolve(base64); + }; + reader.onerror = reject; + reader.readAsDataURL(file); + }); + + // Get file extension + const format = file.name.split(".").pop()?.toLowerCase() || ""; + const supportedFormats = [ + "pdf", + "csv", + "doc", + "docx", + "xls", + "xlsx", + "html", + "txt", + "md", + ]; + + if (!supportedFormats.includes(format)) { + throw new Error("Unsupported file format"); + } + + // Format file size + const size = file.size; + let sizeStr = ""; + if (size < 1024) { + sizeStr = size + " B"; + } else if (size < 1024 * 1024) { + sizeStr = (size / 1024).toFixed(2) + " KB"; + } else { + sizeStr = (size / (1024 * 1024)).toFixed(2) + " MB"; + } + + // Create document content with only filename and size + const documentContent = { + type: "document", + document: { + format, + name: file.name, + size: sizeStr, + source: { + bytes: base64, + }, + }, + }; + + // Submit the document content as a JSON string but only display filename and size + const displayContent = `Document: ${file.name} (${sizeStr})`; + chatStore.onUserInput(displayContent); + + // Store the actual document content separately if needed + // chatStore.updateCurrentSession((session) => { + // session.lastDocument = documentContent; + // }); + } catch (error) { + console.error("Failed to upload document:", error); + showToast("Failed to upload document"); + } finally { + props.setUploading(false); + } + }; + fileInput.click(); + } return (
@@ -580,6 +665,14 @@ export function ChatActions(props: { icon={props.uploading ? : } /> )} + {/* Add document upload button for Bedrock */} + {isBedrockProvider && ( + : } + /> + )} + + + + + + diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 573969be7b1..cdf99f19791 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -81,6 +81,7 @@ const cn = { Clear: "清除聊天", Settings: "对话设置", UploadImage: "上传图片", + UploadDocument: "上传文档", }, Rename: "重命名对话", Typing: "正在输入…", diff --git a/app/locales/en.ts b/app/locales/en.ts index 9d3097ef822..12107ac7cda 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -82,6 +82,7 @@ const en: LocaleType = { Clear: "Clear Context", Settings: "Settings", UploadImage: "Upload Images", + UploadDocument: "Upload Documents", }, Rename: "Rename Chat", Typing: "Typing…",