diff --git a/src/aws_bedrock_embedders.ts b/src/aws_bedrock_embedders.ts index 50fbfcc..5e8874a 100644 --- a/src/aws_bedrock_embedders.ts +++ b/src/aws_bedrock_embedders.ts @@ -1,98 +1,92 @@ -// /** -// * Copyright 2024 The Fire Company -// * -// * Licensed under the Apache License, Version 2.0 (the "License"); -// * you may not use this file except in compliance with the License. -// * You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// /* eslint-disable @typescript-eslint/no-explicit-any */ +/** + * Copyright 2024 The Fire Company + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* eslint-disable @typescript-eslint/no-explicit-any */ -// import { embedderRef, Genkit } from "genkit"; -// import ModelClient, { -// GetEmbeddings200Response, -// GetEmbeddingsParameters, -// } from "@azure-rest/ai-inference"; -// import { z } from "zod"; -// import { type PluginOptions } from "./index.js"; -// import { AzureKeyCredential } from "@azure/core-auth"; +import { embedderRef, Genkit } from "genkit"; -// export const TextEmbeddingConfigSchema = z.object({ -// dimensions: z.number().optional(), -// encodingFormat: z.union([z.literal("float"), z.literal("base64")]).optional(), -// }); +import { z } from "zod"; +import { + BedrockRuntimeClient, + InvokeModelCommand, + InvokeModelCommandInput, + InvokeModelCommandOutput, +} from "@aws-sdk/client-bedrock-runtime"; -// export type TextEmbeddingGeckoConfig = z.infer< -// typeof TextEmbeddingConfigSchema -// >; +export const TextEmbeddingConfigSchema = z.object({ + dimensions: z.number().optional(), +}); -// export const TextEmbeddingInputSchema = z.string(); +export type TextEmbeddingGeckoConfig = z.infer< + typeof TextEmbeddingConfigSchema +>; -// export const openAITextEmbedding3Small = embedderRef({ -// name: "github/text-embedding-3-small", -// configSchema: TextEmbeddingConfigSchema, -// info: { -// dimensions: 1536, -// label: "OpenAI - Text-embedding-3-small", -// supports: { -// input: ["text"], -// }, -// }, -// }); +export const TextEmbeddingInputSchema = z.string(); -// export const SUPPORTED_EMBEDDING_MODELS: Record = { -// "text-embedding-3-small": openAITextEmbedding3Small, -// }; +export const amazonTitanEmbedTextV2 = embedderRef({ + name: "aws-bedrock/amazon.titan-embed-text-v2:0", + configSchema: TextEmbeddingConfigSchema, + info: { + dimensions: 512, + label: "Amazon - titan-embed-text-v2:0", + supports: { + input: ["text"], + }, + }, +}); -// export function awsBedrockEmbedder( -// name: string, -// ai: Genkit, -// options?: PluginOptions, -// ) { -// const token = options?.githubToken || process.env.GITHUB_TOKEN; -// let endpoint = options?.endpoint || process.env.GITHUB_ENDPOINT; -// if (!token) { -// throw new Error( -// "Please pass in the TOKEN key or set the GITHUB_TOKEN environment variable", -// ); -// } -// if (!endpoint) { -// endpoint = "https://models.inference.ai.azure.com"; -// } +export const SUPPORTED_EMBEDDING_MODELS: Record = { + "amazon.titan-embed-text-v2:0": amazonTitanEmbedTextV2, +}; -// const client = ModelClient(endpoint, new AzureKeyCredential(token)); -// const model = SUPPORTED_EMBEDDING_MODELS[name]; +export function awsBedrockEmbedder( + name: string, + ai: Genkit, + client: BedrockRuntimeClient, +) { + const model = SUPPORTED_EMBEDDING_MODELS[name]; -// return ai.defineEmbedder( -// { -// info: model.info!, -// configSchema: TextEmbeddingConfigSchema, -// name: model.name, -// }, -// async (input, options) => { -// const body = { -// body: { -// model: name, -// input: input.map((d) => d.text), -// dimensions: options?.dimensions, -// encoding_format: options?.encodingFormat, -// }, -// } as GetEmbeddingsParameters; -// const embeddings = (await client -// .path("/embeddings") -// .post(body)) as GetEmbeddings200Response; -// return { -// embeddings: embeddings.body.data.map((d) => ({ -// embedding: Array.isArray(d.embedding) ? d.embedding : [], -// })), -// }; -// }, -// ); -// } + return ai.defineEmbedder( + { + info: model.info!, + configSchema: TextEmbeddingConfigSchema, + name: model.name, + }, + async (input, options) => { + const body: InvokeModelCommandInput = { + modelId: name, + contentType: "application/json", + body: JSON.stringify({ + inputText: input.map((d) => d.text).join(","), + dimensions: options?.dimensions, + }), + }; + + const command = new InvokeModelCommand(body); + + const response = (await client.send(command)) as InvokeModelCommandOutput; + const embeddings = new TextDecoder().decode(response.body) + ? JSON.parse(new TextDecoder().decode(response.body)) + : []; + return { + embeddings: [ + { + embedding: embeddings.embedding as number[], + }, + ], + }; + }, + ); +} diff --git a/src/aws_bedrock_llms.ts b/src/aws_bedrock_llms.ts index cbda432..0cf68f4 100644 --- a/src/aws_bedrock_llms.ts +++ b/src/aws_bedrock_llms.ts @@ -14,7 +14,6 @@ * limitations under the License. */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import * as fs from "fs"; import { Message, @@ -103,7 +102,7 @@ export function toAwsBedrockTextAndMedia( text: part.text, }; } else if (part.media) { - const imageBuffer = new Uint8Array(fs.readFileSync(part.media.url).buffer); + const imageBuffer = new Uint8Array(Buffer.from(part.media.url, "base64")); return { image: { diff --git a/src/index.ts b/src/index.ts index 48f1d29..f8db113 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,15 +10,15 @@ import { amazonNovaProV1, SUPPORTED_AWS_BEDROCK_MODELS, } from "./aws_bedrock_llms.js"; -// import { -// awsBedrockEmbedder, -// openAITextEmbedding3Small, -// SUPPORTED_EMBEDDING_MODELS, -// } from "./aws_bedrock_embedders.js"; +import { + awsBedrockEmbedder, + amazonTitanEmbedTextV2, + SUPPORTED_EMBEDDING_MODELS, +} from "./aws_bedrock_embedders.js"; export { amazonNovaProV1 }; -// export { openAITextEmbedding3Small }; +export { amazonTitanEmbedTextV2 }; export type PluginOptions = BedrockRuntimeClientConfig; @@ -30,9 +30,9 @@ export function awsBedrock(options?: PluginOptions) { awsBedrockModel(name, client, ai); }); - // Object.keys(SUPPORTED_EMBEDDING_MODELS).forEach((name) => - // awsBedrockEmbedder(name, ai, options), - // ); + Object.keys(SUPPORTED_EMBEDDING_MODELS).forEach((name) => + awsBedrockEmbedder(name, ai, client), + ); }); }