-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat (ai/core): add embed function (#1575)
- Loading branch information
Showing
26 changed files
with
963 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
--- | ||
'@ai-sdk/provider': patch | ||
'@ai-sdk/mistral': patch | ||
'@ai-sdk/openai': patch | ||
'ai': patch | ||
--- | ||
|
||
feat (ai/core): add embed function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
--- | ||
title: Embeddings | ||
description: Learn how to embed values with the Vercel AI SDK. | ||
--- | ||
|
||
# Embeddings | ||
|
||
Embeddings are a way to represent words, phrases, or images as vectors in a high-dimensional space. | ||
In this space, similar words are close to each other, and the distance between words can be used to measure their similarity. | ||
|
||
## Embedding a Single Value | ||
|
||
The Vercel AI SDK provides the `embed` function to embed single values, which is useful for tasks such as finding similar words | ||
or phrases or clustering text. You can use it with embeddings models, e.g. `openai.embedding('text-embedding-3-large')` or `mistral.embedding('mistral-embed')`. | ||
|
||
```tsx | ||
import { embed } from 'ai'; | ||
import { openai } from '@ai-sdk/openai'; | ||
|
||
const { embedding } = await embed({ | ||
model: openai.embedding('text-embedding-3-small'), | ||
value: 'sunny day at the beach', | ||
}); | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import { mistral } from '@ai-sdk/mistral'; | ||
import { embed } from 'ai'; | ||
import dotenv from 'dotenv'; | ||
|
||
dotenv.config(); | ||
|
||
async function main() { | ||
const { embedding } = await embed({ | ||
model: mistral.embedding('mistral-embed'), | ||
value: 'sunny day at the beach', | ||
}); | ||
|
||
console.log(embedding); | ||
} | ||
|
||
main().catch(console.error); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import { openai } from '@ai-sdk/openai'; | ||
import { embed } from 'ai'; | ||
import dotenv from 'dotenv'; | ||
|
||
dotenv.config(); | ||
|
||
async function main() { | ||
const { embedding } = await embed({ | ||
model: openai.embedding('text-embedding-3-small'), | ||
value: 'sunny day at the beach', | ||
}); | ||
|
||
console.log(embedding); | ||
} | ||
|
||
main().catch(console.error); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import assert from 'node:assert'; | ||
import { MockEmbeddingModelV1 } from '../test/mock-embedding-model-v1'; | ||
import { embed } from './embed'; | ||
|
||
const dummyEmbedding = [0.1, 0.2, 0.3]; | ||
const testValue = 'sunny day at the beach'; | ||
|
||
describe('result.embedding', () => { | ||
it('should generate embedding', async () => { | ||
const result = await embed({ | ||
model: new MockEmbeddingModelV1({ | ||
doEmbed: async ({ values }) => { | ||
assert.deepStrictEqual(values, [testValue]); | ||
|
||
return { | ||
embeddings: [dummyEmbedding], | ||
}; | ||
}, | ||
}), | ||
value: testValue, | ||
}); | ||
|
||
assert.deepStrictEqual(result.embedding, dummyEmbedding); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import { Embedding, EmbeddingModel } from '../types'; | ||
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff'; | ||
|
||
/** | ||
Embed a value using an embedding model. The type of the value is defined by the embedding model. | ||
@param model - The embedding model to use. | ||
@param value - The value that should be embedded. | ||
@param maxRetries - Maximum number of retries. Set to 0 to disable retries. Default: 2. | ||
@param abortSignal - An optional abort signal that can be used to cancel the call. | ||
@returns A result object that contains the embedding, the value, and additional information. | ||
*/ | ||
export async function embed<VALUE>({ | ||
model, | ||
value, | ||
maxRetries, | ||
abortSignal, | ||
}: { | ||
/** | ||
The embedding model to use. | ||
*/ | ||
model: EmbeddingModel<VALUE>; | ||
|
||
/** | ||
The value that should be embedded. | ||
*/ | ||
value: VALUE; | ||
|
||
/** | ||
Maximum number of retries per embedding model call. Set to 0 to disable retries. | ||
@default 2 | ||
*/ | ||
maxRetries?: number; | ||
|
||
/** | ||
Abort signal. | ||
*/ | ||
abortSignal?: AbortSignal; | ||
}): Promise<EmbedResult<VALUE>> { | ||
const retry = retryWithExponentialBackoff({ maxRetries }); | ||
|
||
const modelResponse = await retry(() => | ||
model.doEmbed({ | ||
values: [value], | ||
abortSignal, | ||
}), | ||
); | ||
|
||
return new EmbedResult({ | ||
value, | ||
embedding: modelResponse.embeddings[0], | ||
rawResponse: modelResponse.rawResponse, | ||
}); | ||
} | ||
|
||
/** | ||
The result of a `embed` call. | ||
It contains the embedding, the value, and additional information. | ||
*/ | ||
export class EmbedResult<VALUE> { | ||
/** | ||
The value that was embedded. | ||
*/ | ||
readonly value: VALUE; | ||
|
||
/** | ||
The embedding of the value. | ||
*/ | ||
readonly embedding: Embedding; | ||
|
||
/** | ||
Optional raw response data. | ||
*/ | ||
readonly rawResponse?: { | ||
/** | ||
Response headers. | ||
*/ | ||
headers?: Record<string, string>; | ||
}; | ||
|
||
constructor(options: { | ||
value: VALUE; | ||
embedding: Embedding; | ||
rawResponse?: { | ||
headers?: Record<string, string>; | ||
}; | ||
}) { | ||
this.value = options.value; | ||
this.embedding = options.embedding; | ||
this.rawResponse = options.rawResponse; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
export * from './embed'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import { EmbeddingModelV1 } from '@ai-sdk/provider'; | ||
|
||
export class MockEmbeddingModelV1<VALUE> implements EmbeddingModelV1<VALUE> { | ||
readonly specificationVersion = 'v1'; | ||
|
||
readonly provider: EmbeddingModelV1<VALUE>['provider']; | ||
readonly modelId: EmbeddingModelV1<VALUE>['modelId']; | ||
readonly maxEmbeddingsPerCall: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall']; | ||
readonly supportsParallelCalls: EmbeddingModelV1<VALUE>['supportsParallelCalls']; | ||
|
||
doEmbed: EmbeddingModelV1<VALUE>['doEmbed']; | ||
|
||
constructor({ | ||
provider = 'mock-provider', | ||
modelId = 'mock-model-id', | ||
maxEmbeddingsPerCall = 1, | ||
supportsParallelCalls = false, | ||
doEmbed = notImplemented, | ||
}: { | ||
provider?: EmbeddingModelV1<VALUE>['provider']; | ||
modelId?: EmbeddingModelV1<VALUE>['modelId']; | ||
maxEmbeddingsPerCall?: EmbeddingModelV1<VALUE>['maxEmbeddingsPerCall']; | ||
supportsParallelCalls?: EmbeddingModelV1<VALUE>['supportsParallelCalls']; | ||
doEmbed?: EmbeddingModelV1<VALUE>['doEmbed']; | ||
}) { | ||
this.provider = provider; | ||
this.modelId = modelId; | ||
this.maxEmbeddingsPerCall = maxEmbeddingsPerCall; | ||
this.supportsParallelCalls = supportsParallelCalls; | ||
this.doEmbed = doEmbed; | ||
} | ||
} | ||
|
||
function notImplemented(): never { | ||
throw new Error('Not implemented'); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import { EmbeddingModelV1, EmbeddingModelV1Embedding } from '@ai-sdk/provider'; | ||
|
||
/** | ||
Embedding model that is used by the AI SDK Core functions. | ||
*/ | ||
export type EmbeddingModel<VALUE> = EmbeddingModelV1<VALUE>; | ||
|
||
/** | ||
Embedding. | ||
*/ | ||
export type Embedding = EmbeddingModelV1Embedding; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
export * from './embedding-model'; | ||
export * from './errors'; | ||
export * from './language-model'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import { EmbeddingModelV1Embedding } from '@ai-sdk/provider'; | ||
import { JsonTestServer } from '@ai-sdk/provider-utils/test'; | ||
import { createMistral } from './mistral-provider'; | ||
|
||
const dummyEmbeddings = [ | ||
[0.1, 0.2, 0.3, 0.4, 0.5], | ||
[0.6, 0.7, 0.8, 0.9, 1.0], | ||
]; | ||
const testValues = ['sunny day at the beach', 'rainy day in the city']; | ||
|
||
const provider = createMistral({ apiKey: 'test-api-key' }); | ||
const model = provider.embedding('mistral-embed'); | ||
|
||
describe('doEmbed', () => { | ||
const server = new JsonTestServer('https://api.mistral.ai/v1/embeddings'); | ||
|
||
server.setupTestEnvironment(); | ||
|
||
function prepareJsonResponse({ | ||
embeddings = dummyEmbeddings, | ||
}: { | ||
embeddings?: EmbeddingModelV1Embedding[]; | ||
} = {}) { | ||
server.responseBodyJson = { | ||
id: 'b322cfc2b9d34e2f8e14fc99874faee5', | ||
object: 'list', | ||
data: embeddings.map((embedding, i) => ({ | ||
object: 'embedding', | ||
embedding, | ||
index: i, | ||
})), | ||
model: 'mistral-embed', | ||
usage: { prompt_tokens: 8, total_tokens: 8, completion_tokens: 0 }, | ||
}; | ||
} | ||
|
||
it('should extract embedding', async () => { | ||
prepareJsonResponse(); | ||
|
||
const { embeddings } = await model.doEmbed({ values: testValues }); | ||
|
||
expect(embeddings).toStrictEqual(dummyEmbeddings); | ||
}); | ||
|
||
it('should expose the raw response headers', async () => { | ||
prepareJsonResponse(); | ||
|
||
server.responseHeaders = { | ||
'test-header': 'test-value', | ||
}; | ||
|
||
const { rawResponse } = await model.doEmbed({ values: testValues }); | ||
|
||
expect(rawResponse?.headers).toStrictEqual({ | ||
// default headers: | ||
'content-type': 'application/json', | ||
|
||
// custom header | ||
'test-header': 'test-value', | ||
}); | ||
}); | ||
|
||
it('should pass the model and the values', async () => { | ||
prepareJsonResponse(); | ||
|
||
await model.doEmbed({ values: testValues }); | ||
|
||
expect(await server.getRequestBodyJson()).toStrictEqual({ | ||
model: 'mistral-embed', | ||
input: testValues, | ||
encoding_format: 'float', | ||
}); | ||
}); | ||
|
||
it('should pass custom headers', async () => { | ||
prepareJsonResponse(); | ||
|
||
const provider = createMistral({ | ||
apiKey: 'test-api-key', | ||
headers: { | ||
'Custom-Header': 'test-header', | ||
}, | ||
}); | ||
|
||
await provider.embedding('mistral-embed').doEmbed({ | ||
values: testValues, | ||
}); | ||
|
||
const requestHeaders = await server.getRequestHeaders(); | ||
expect(requestHeaders.get('Custom-Header')).toStrictEqual('test-header'); | ||
}); | ||
|
||
it('should pass the api key as Authorization header', async () => { | ||
prepareJsonResponse(); | ||
|
||
const provider = createMistral({ apiKey: 'test-api-key' }); | ||
|
||
await provider.embedding('mistral-embed').doEmbed({ | ||
values: testValues, | ||
}); | ||
|
||
expect( | ||
(await server.getRequestHeaders()).get('Authorization'), | ||
).toStrictEqual('Bearer test-api-key'); | ||
}); | ||
}); |
Oops, something went wrong.