Skip to content

Commit

Permalink
feat (ai/core): add embed function (#1575)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored May 14, 2024
1 parent 1009594 commit 0f6bc4e
Show file tree
Hide file tree
Showing 26 changed files with 963 additions and 22 deletions.
8 changes: 8 additions & 0 deletions .changeset/witty-beds-sell.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
'@ai-sdk/provider': patch
'@ai-sdk/mistral': patch
'@ai-sdk/openai': patch
'ai': patch
---

feat (ai/core): add embed function
24 changes: 24 additions & 0 deletions content/docs/03-ai-sdk-core/30-embeddings.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
---
title: Embeddings
description: Learn how to embed values with the Vercel AI SDK.
---

# Embeddings

Embeddings are a way to represent words, phrases, or images as vectors in a high-dimensional space.
In this space, similar values are close to each other, and the distance between values can be used to measure their similarity.

## Embedding a Single Value

The Vercel AI SDK provides the `embed` function to embed single values, which is useful for tasks such as finding similar words or phrases, or for clustering text. You can use it with embeddings models, e.g. `openai.embedding('text-embedding-3-large')` or `mistral.embedding('mistral-embed')`.

```tsx
import { embed } from 'ai';
import { openai } from '@ai-sdk/openai';

const { embedding } = await embed({
model: openai.embedding('text-embedding-3-small'),
value: 'sunny day at the beach',
});
```
16 changes: 16 additions & 0 deletions examples/ai-core/src/embed/mistral.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { mistral } from '@ai-sdk/mistral';
import { embed } from 'ai';
import dotenv from 'dotenv';

// Load MISTRAL_API_KEY (and any other settings) from a local .env file.
dotenv.config();

/**
 * Embeds a sample sentence with Mistral's `mistral-embed` model and
 * prints the resulting embedding vector to stdout.
 */
async function main() {
  const result = await embed({
    model: mistral.embedding('mistral-embed'),
    value: 'sunny day at the beach',
  });

  console.log(result.embedding);
}

main().catch(console.error);
16 changes: 16 additions & 0 deletions examples/ai-core/src/embed/openai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { openai } from '@ai-sdk/openai';
import { embed } from 'ai';
import dotenv from 'dotenv';

// Load OPENAI_API_KEY (and any other settings) from a local .env file.
dotenv.config();

/**
 * Embeds a sample sentence with OpenAI's `text-embedding-3-small` model
 * and prints the resulting embedding vector to stdout.
 */
async function main() {
  const result = await embed({
    model: openai.embedding('text-embedding-3-small'),
    value: 'sunny day at the beach',
  });

  console.log(result.embedding);
}

main().catch(console.error);
25 changes: 25 additions & 0 deletions packages/core/core/embed/embed.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import assert from 'node:assert';
import { MockEmbeddingModelV1 } from '../test/mock-embedding-model-v1';
import { embed } from './embed';

// Fixture data shared by the tests below.
const dummyEmbedding = [0.1, 0.2, 0.3];
const testValue = 'sunny day at the beach';

describe('result.embedding', () => {
  it('should generate embedding', async () => {
    // Mock model that also verifies embed() forwards the single value
    // to the provider as a one-element batch.
    const model = new MockEmbeddingModelV1({
      doEmbed: async ({ values }) => {
        assert.deepStrictEqual(values, [testValue]);
        return { embeddings: [dummyEmbedding] };
      },
    });

    const result = await embed({ model, value: testValue });

    assert.deepStrictEqual(result.embedding, dummyEmbedding);
  });
});
95 changes: 95 additions & 0 deletions packages/core/core/embed/embed.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import { Embedding, EmbeddingModel } from '../types';
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';

/**
Embed a value using an embedding model. The type of the value is defined by the embedding model.

@param model - The embedding model to use.
@param value - The value that should be embedded.
@param maxRetries - Maximum number of retries. Set to 0 to disable retries. Default: 2.
@param abortSignal - An optional abort signal that can be used to cancel the call.

@returns A result object that contains the embedding, the value, and additional information.
 */
export async function embed<VALUE>({
  model,
  value,
  maxRetries,
  abortSignal,
}: {
  /**
The embedding model to use.
   */
  model: EmbeddingModel<VALUE>;

  /**
The value that should be embedded.
   */
  value: VALUE;

  /**
Maximum number of retries per embedding model call. Set to 0 to disable retries.

@default 2
   */
  maxRetries?: number;

  /**
Abort signal.
   */
  abortSignal?: AbortSignal;
}): Promise<EmbedResult<VALUE>> {
  // Wrap the provider call so transient failures are retried with
  // exponential backoff (maxRetries defaults to 2 inside the helper).
  const retry = retryWithExponentialBackoff({ maxRetries });

  // The provider interface is batch-oriented: embed a one-element batch
  // and unwrap the single resulting embedding.
  const { embeddings, rawResponse } = await retry(() =>
    model.doEmbed({ values: [value], abortSignal }),
  );

  return new EmbedResult({ value, embedding: embeddings[0], rawResponse });
}

/**
The result of an `embed` call.
It contains the embedding, the value, and additional information.
 */
export class EmbedResult<VALUE> {
  /**
The value that was embedded.
   */
  readonly value: VALUE;

  /**
The embedding of the value.
   */
  readonly embedding: Embedding;

  /**
Optional raw response data.
   */
  readonly rawResponse?: {
    /**
Response headers.
     */
    headers?: Record<string, string>;
  };

  constructor({
    value,
    embedding,
    rawResponse,
  }: {
    value: VALUE;
    embedding: Embedding;
    rawResponse?: {
      headers?: Record<string, string>;
    };
  }) {
    this.value = value;
    this.embedding = embedding;
    this.rawResponse = rawResponse;
  }
}
1 change: 1 addition & 0 deletions packages/core/core/embed/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// Barrel file for the embed module (re-exports `embed` and `EmbedResult`).
export * from './embed';
1 change: 1 addition & 0 deletions packages/core/core/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// AI SDK Core entry point: re-exports the public API of each submodule.
export * from './embed';
export * from './generate-object';
export * from './generate-text';
export * from './prompt';
Expand Down
36 changes: 36 additions & 0 deletions packages/core/core/test/mock-embedding-model-v1.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { EmbeddingModelV1 } from '@ai-sdk/provider';

/**
Configurable test double for the `EmbeddingModelV1` provider interface.
Every setting has a default, so tests only need to supply the parts they
exercise (typically `doEmbed`).
 */
export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> {
  readonly specificationVersion = 'v1';

  readonly provider: EmbeddingModelV1<VALUE>['provider'];
  readonly modelId: EmbeddingModelV1<VALUE>['modelId'];
  readonly maxEmbeddingsPerCall: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall'];
  readonly supportsParallelCalls: EmbeddingModelV1<VALUE>['supportsParallelCalls'];

  doEmbed: EmbeddingModelV1<VALUE>['doEmbed'];

  constructor(options: {
    provider?: EmbeddingModelV1<VALUE>['provider'];
    modelId?: EmbeddingModelV1<VALUE>['modelId'];
    maxEmbeddingsPerCall?: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall'];
    supportsParallelCalls?: EmbeddingModelV1<VALUE>['supportsParallelCalls'];
    doEmbed?: EmbeddingModelV1<VALUE>['doEmbed'];
  }) {
    // Apply inert defaults; doEmbed throws until a test overrides it.
    const {
      provider = 'mock-provider',
      modelId = 'mock-model-id',
      maxEmbeddingsPerCall = 1,
      supportsParallelCalls = false,
      doEmbed = notImplemented,
    } = options;

    this.provider = provider;
    this.modelId = modelId;
    this.maxEmbeddingsPerCall = maxEmbeddingsPerCall;
    this.supportsParallelCalls = supportsParallelCalls;
    this.doEmbed = doEmbed;
  }
}

/**
Default `doEmbed` implementation: fails loudly when a test forgets to
provide one.
 */
function notImplemented(): never {
  const error = new Error('Not implemented');
  throw error;
}
11 changes: 11 additions & 0 deletions packages/core/core/types/embedding-model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { EmbeddingModelV1, EmbeddingModelV1Embedding } from '@ai-sdk/provider';

/**
Embedding model that is used by the AI SDK Core functions.

Currently an alias for the V1 provider specification; the indirection keeps
core code decoupled from the provider spec version.
 */
export type EmbeddingModel<VALUE> = EmbeddingModelV1<VALUE>;

/**
Embedding.

The concrete representation is defined by the provider spec
(`EmbeddingModelV1Embedding`) — presumably a numeric vector; see the
`@ai-sdk/provider` package for the exact shape.
 */
export type Embedding = EmbeddingModelV1Embedding;
1 change: 1 addition & 0 deletions packages/core/core/types/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
// Barrel file: re-exports all core type modules.
export * from './embedding-model';
export * from './errors';
export * from './language-model';
106 changes: 106 additions & 0 deletions packages/mistral/src/mistral-embedding-model.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import { EmbeddingModelV1Embedding } from '@ai-sdk/provider';
import { JsonTestServer } from '@ai-sdk/provider-utils/test';
import { createMistral } from './mistral-provider';

// Fixture data: two deterministic embedding vectors and the two input
// strings they stand in for.
const dummyEmbeddings = [
  [0.1, 0.2, 0.3, 0.4, 0.5],
  [0.6, 0.7, 0.8, 0.9, 1.0],
];
const testValues = ['sunny day at the beach', 'rainy day in the city'];

// Default provider/model under test; individual tests construct their own
// provider when they need different configuration (custom headers, api key).
const provider = createMistral({ apiKey: 'test-api-key' });
const model = provider.embedding('mistral-embed');

describe('doEmbed', () => {
  // Intercepts HTTP calls to the Mistral embeddings endpoint and serves
  // canned JSON responses. setupTestEnvironment registers the per-test
  // lifecycle hooks, so it must run before any `it` block executes.
  const server = new JsonTestServer('https://api.mistral.ai/v1/embeddings');

  server.setupTestEnvironment();

  // Queues a Mistral-shaped embeddings response on the test server.
  // Callers may override the embedding vectors that are returned.
  function prepareJsonResponse({
    embeddings = dummyEmbeddings,
  }: {
    embeddings?: EmbeddingModelV1Embedding[];
  } = {}) {
    server.responseBodyJson = {
      id: 'b322cfc2b9d34e2f8e14fc99874faee5',
      object: 'list',
      data: embeddings.map((embedding, i) => ({
        object: 'embedding',
        embedding,
        index: i,
      })),
      model: 'mistral-embed',
      usage: { prompt_tokens: 8, total_tokens: 8, completion_tokens: 0 },
    };
  }

  // The embeddings in the response body must come back in order.
  it('should extract embedding', async () => {
    prepareJsonResponse();

    const { embeddings } = await model.doEmbed({ values: testValues });

    expect(embeddings).toStrictEqual(dummyEmbeddings);
  });

  // rawResponse.headers must surface both default and custom headers.
  it('should expose the raw response headers', async () => {
    prepareJsonResponse();

    server.responseHeaders = {
      'test-header': 'test-value',
    };

    const { rawResponse } = await model.doEmbed({ values: testValues });

    expect(rawResponse?.headers).toStrictEqual({
      // default headers:
      'content-type': 'application/json',

      // custom header
      'test-header': 'test-value',
    });
  });

  // The outgoing request body must follow the Mistral embeddings schema.
  it('should pass the model and the values', async () => {
    prepareJsonResponse();

    await model.doEmbed({ values: testValues });

    expect(await server.getRequestBodyJson()).toStrictEqual({
      model: 'mistral-embed',
      input: testValues,
      encoding_format: 'float',
    });
  });

  // Headers configured on the provider must be forwarded on the request.
  it('should pass custom headers', async () => {
    prepareJsonResponse();

    const provider = createMistral({
      apiKey: 'test-api-key',
      headers: {
        'Custom-Header': 'test-header',
      },
    });

    await provider.embedding('mistral-embed').doEmbed({
      values: testValues,
    });

    const requestHeaders = await server.getRequestHeaders();
    expect(requestHeaders.get('Custom-Header')).toStrictEqual('test-header');
  });

  // The api key must be sent as a Bearer token.
  it('should pass the api key as Authorization header', async () => {
    prepareJsonResponse();

    const provider = createMistral({ apiKey: 'test-api-key' });

    await provider.embedding('mistral-embed').doEmbed({
      values: testValues,
    });

    expect(
      (await server.getRequestHeaders()).get('Authorization'),
    ).toStrictEqual('Bearer test-api-key');
  });
});
Loading

0 comments on commit 0f6bc4e

Please sign in to comment.