diff --git a/src/aws_bedrock_llms.ts b/src/aws_bedrock_llms.ts index 6741ec2..7b79ab1 100644 --- a/src/aws_bedrock_llms.ts +++ b/src/aws_bedrock_llms.ts @@ -30,6 +30,7 @@ import { ModelResponseData, ModelAction, modelRef, + ToolDefinition, //ToolDefinition, } from "genkit/model"; @@ -48,6 +49,7 @@ import { ContentBlockDelta, ImageFormat, StopReason, + Tool, } from "@aws-sdk/client-bedrock-runtime"; export const amazonNovaProV1 = modelRef({ @@ -119,16 +121,17 @@ function toAwsBedrockbRole(role: Role): string { } } -// function toAwsBedrockTool(tool: ToolDefinition): ToolUseBlock { -// return { -// toolUseId: tool.name, // or any appropriate value for toolUseId -// name: tool.name, -// input: tool.inputSchema, -// }; -// } -const regex = /data:.*base64,/ -const getDataPart = (dataUrl: string) => dataUrl.replace(regex,""); - +function toAwsBedrockTool(tool: ToolDefinition): Tool { + return { + toolSpec: { + name: tool.name, + description: tool.description, + inputSchema: tool.inputSchema ? { json: tool.inputSchema } : undefined, + }, + }; +} +const regex = /data:.*base64,/; +const getDataPart = (dataUrl: string) => dataUrl.replace(regex, ""); export function toAwsBedrockTextAndMedia( part: Part, @@ -139,7 +142,9 @@ export function toAwsBedrockTextAndMedia( text: part.text, }; } else if (part.media) { - const imageBuffer = new Uint8Array(Buffer.from(getDataPart(part.media.url), "base64")); + const imageBuffer = new Uint8Array( + Buffer.from(getDataPart(part.media.url), "base64"), + ); return { image: { @@ -202,6 +207,8 @@ export function toAwsBedrockMessages( break; } case "assistant": { + // Request to call the tool + const toolCalls: ToolUseBlock[] = msg.content .filter((part) => part.toolRequest) .map((part) => { @@ -213,7 +220,7 @@ export function toAwsBedrockMessages( return { toolUseId: part.toolRequest.ref || "", name: part.toolRequest.name, - input: JSON.stringify(part.toolRequest.input), + input: part.toolRequest.input as any, }; }); if (toolCalls?.length > 0) { @@ -233,20 +240,34 @@ export function toAwsBedrockMessages( } break; } - // case "tool": { - // const toolResponseParts = msg.toolResponseParts(); - // toolResponseParts.map((part) => { - // awsBedrockMsgs.push({ - // role: role, - // tool_call_id: part.toolResponse.ref || "", - // content: - // typeof part.toolResponse.output === "string" - // ? part.toolResponse.output - // : JSON.stringify(part.toolResponse.output), - // }); - // }); - // break; - // } + case "tool": { + // result of the tool + const toolResponseParts = msg.toolResponseParts(); + + toolResponseParts.map((part) => { + const toolresult: AwsMessge = { + role: "user", + content: [ + { + toolResult: { + toolUseId: part.toolResponse.ref, + content: [ + { + json: { + result: part.toolResponse.output as { + [key: string]: any; + }, + }, + }, + ], + }, + }, + ], + }; + awsBedrockMsgs.push(toolresult); + }); + break; + } default: throw new Error("unrecognized role"); } @@ -273,13 +294,15 @@ function fromAwsBedrockToolCall(toolCall: ToolUseBlock) { ); } const f = toolCall; - return { - toolRequest: { - name: f.name, - ref: toolCall.toolUseId, - input: f.input ? JSON.parse(f.input as string) : f.input, + return [ + { + toolRequest: { + name: f.name, + ref: toolCall.toolUseId, + input: f.input, + }, }, - }; + ]; } function fromAwsBedrockChoice( @@ -373,7 +396,9 @@ export function toAwsBedrockRequestBody( const body: ConverseCommandInput | ConverseStreamCommandInput = { messages: awsBedrockMessages, system: awsBedrockSystemMessage as SystemContentBlock[] | undefined, - //toolConfig: request.tools?.map(toAwsBedrockTool), + toolConfig: request.tools + ? { tools: request.tools.map(toAwsBedrockTool) } + : undefined, modelId: modelString, inferenceConfig: { maxTokens: request.config?.maxOutputTokens,