Skip to content

Commit

Permalink
feat: tools working
Browse files Browse the repository at this point in the history
  • Loading branch information
xavidop committed Dec 27, 2024
1 parent 6df7391 commit 9087bdf
Showing 1 changed file with 58 additions and 33 deletions.
91 changes: 58 additions & 33 deletions src/aws_bedrock_llms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
ModelResponseData,
ModelAction,
modelRef,
ToolDefinition,
//ToolDefinition,
} from "genkit/model";

Expand All @@ -48,6 +49,7 @@ import {
ContentBlockDelta,
ImageFormat,
StopReason,
Tool,
} from "@aws-sdk/client-bedrock-runtime";

export const amazonNovaProV1 = modelRef({
Expand Down Expand Up @@ -119,16 +121,17 @@ function toAwsBedrockbRole(role: Role): string {
}
}

// function toAwsBedrockTool(tool: ToolDefinition): ToolUseBlock {
// return {
// toolUseId: tool.name, // or any appropriate value for toolUseId
// name: tool.name,
// input: tool.inputSchema,
// };
// }
const regex = /data:.*base64,/
const getDataPart = (dataUrl: string) => dataUrl.replace(regex,"");

function toAwsBedrockTool(tool: ToolDefinition): Tool {
return {
toolSpec: {
name: tool.name,
description: tool.description,
inputSchema: tool.inputSchema ? { json: tool.inputSchema } : undefined,
},
};
}
const regex = /data:.*base64,/;
const getDataPart = (dataUrl: string) => dataUrl.replace(regex, "");

export function toAwsBedrockTextAndMedia(
part: Part,
Expand All @@ -139,7 +142,9 @@ export function toAwsBedrockTextAndMedia(
text: part.text,
};
} else if (part.media) {
const imageBuffer = new Uint8Array(Buffer.from(getDataPart(part.media.url), "base64"));
const imageBuffer = new Uint8Array(
Buffer.from(getDataPart(part.media.url), "base64"),
);

return {
image: {
Expand Down Expand Up @@ -202,6 +207,8 @@ export function toAwsBedrockMessages(
break;
}
case "assistant": {
// Request to call the tool

const toolCalls: ToolUseBlock[] = msg.content
.filter((part) => part.toolRequest)
.map((part) => {
Expand All @@ -213,7 +220,7 @@ export function toAwsBedrockMessages(
return {
toolUseId: part.toolRequest.ref || "",
name: part.toolRequest.name,
input: JSON.stringify(part.toolRequest.input),
input: part.toolRequest.input as any,
};
});
if (toolCalls?.length > 0) {
Expand All @@ -233,20 +240,34 @@ export function toAwsBedrockMessages(
}
break;
}
// case "tool": {
// const toolResponseParts = msg.toolResponseParts();
// toolResponseParts.map((part) => {
// awsBedrockMsgs.push({
// role: role,
// tool_call_id: part.toolResponse.ref || "",
// content:
// typeof part.toolResponse.output === "string"
// ? part.toolResponse.output
// : JSON.stringify(part.toolResponse.output),
// });
// });
// break;
// }
case "tool": {
// result of the tool
const toolResponseParts = msg.toolResponseParts();

toolResponseParts.map((part) => {
const toolresult: AwsMessge = {
role: "user",
content: [
{
toolResult: {
toolUseId: part.toolResponse.ref,
content: [
{
json: {
result: part.toolResponse.output as {
[key: string]: any;
},
},
},
],
},
},
],
};
awsBedrockMsgs.push(toolresult);
});
break;
}
default:
throw new Error("unrecognized role");
}
Expand All @@ -273,13 +294,15 @@ function fromAwsBedrockToolCall(toolCall: ToolUseBlock) {
);
}
const f = toolCall;
return {
toolRequest: {
name: f.name,
ref: toolCall.toolUseId,
input: f.input ? JSON.parse(f.input as string) : f.input,
return [
{
toolRequest: {
name: f.name,
ref: toolCall.toolUseId,
input: f.input,
},
},
};
];
}

function fromAwsBedrockChoice(
Expand Down Expand Up @@ -373,7 +396,9 @@ export function toAwsBedrockRequestBody(
const body: ConverseCommandInput | ConverseStreamCommandInput = {
messages: awsBedrockMessages,
system: awsBedrockSystemMessage as SystemContentBlock[] | undefined,
//toolConfig: request.tools?.map(toAwsBedrockTool),
toolConfig: request.tools
? { tools: request.tools.map(toAwsBedrockTool) }
: undefined,
modelId: modelString,
inferenceConfig: {
maxTokens: request.config?.maxOutputTokens,
Expand Down

0 comments on commit 9087bdf

Please sign in to comment.