@@ -1,8 +1,5 @@
import { useState } from "react";

// We don't support all vectorDBs yet for reranking due to complexities of how each provider
// returns information. We need to normalize the response data so Reranker can be used for each provider.
const supportedVectorDBs = ["lancedb"];
const hint = {
default: {
title: "Default",
@@ -20,8 +17,7 @@ export default function VectorSearchMode({ workspace, setHasChanges }) {
const [selection, setSelection] = useState(
workspace?.vectorSearchMode ?? "default"
);
if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB))
return null;
if (!workspace?.vectorDB) return null;

return (
<div>
29 changes: 29 additions & 0 deletions server/utils/EmbeddingRerankers/rerank.js
@@ -0,0 +1,29 @@
const { getRerankerProvider } = require("../helpers");

// Rerank `documents` against `query` using the configured reranker provider.
async function rerank(query, documents, topN = 4) {
  const reranker = getRerankerProvider();
  return await reranker.rerank(query, documents, { topK: topN });
}

/**
 * For reranking, we want to work with a larger number of results than the topN.
 * This is because the reranker can only rerank the results it is given, and we
 * don't auto-expand the results. We want to give the reranker a larger number
 * of results to work with.
 *
 * However, we cannot make this boundless, as reranking is expensive and
 * time-consuming. So we limit the number of results to a maximum of 50 and a
 * minimum of 10. This is a good balance between the number of results to rerank
 * and the cost of reranking, and ensures workspaces with 10K embeddings will
 * still rerank within a reasonable timeframe on base-level hardware.
 *
 * Benchmarks:
 * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
 * @param {number} totalEmbeddings - Total number of embeddings in the namespace.
 * @returns {number} The number of search results to fetch for reranking.
 */
function getSearchLimit(totalEmbeddings = 0) {
  return Math.max(10, Math.min(50, Math.ceil(totalEmbeddings * 0.1)));
}

module.exports = {
  rerank,
  getSearchLimit,
};
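The clamp in getSearchLimit is easiest to see with a few worked values. A minimal sketch, assuming the require path below resolves from the server/utils directory:

const { getSearchLimit } = require("./EmbeddingRerankers/rerank");

// 10% of the namespace's embeddings, clamped to the [10, 50] window:
getSearchLimit(0);     // 10 (small workspaces still oversample to the minimum)
getSearchLimit(120);   // 12 (Math.ceil(120 * 0.1))
getSearchLimit(10000); // 50 (capped so large workspaces still rerank quickly)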
29 changes: 29 additions & 0 deletions server/utils/helpers/index.js
@@ -75,6 +75,11 @@
* @property {Function} embedChunks - Embeds multiple chunks of text.
*/

/**
* @typedef {Object} BaseRerankerProvider
* @property {function(string, {text: string}[], {topK: number}): Promise<any[]>} rerank - Reranks a list of documents.
*/

/**
* Gets the system's current vector database provider.
* @param {('pinecone' | 'chroma' | 'chromacloud' | 'lancedb' | 'weaviate' | 'qdrant' | 'milvus' | 'zilliz' | 'astra') | null} getExactly - If provided, this will return an explicit provider.
@@ -471,6 +476,29 @@ function toChunks(arr, size) {
);
}

/**
 * Returns the Reranker provider.
 * @returns {BaseRerankerProvider}
 */
function getRerankerProvider() {
  const rerankerSelection = process.env.RERANKING_PROVIDER ?? "native";
  switch (rerankerSelection) {
    case "native": {
      const {
        NativeEmbeddingReranker,
      } = require("../EmbeddingRerankers/native");
      return new NativeEmbeddingReranker();
    }
    default: {
      console.log(
        `[RERANKING] Reranker provider ${rerankerSelection} is not supported. Using native reranker as fallback.`
      );
      const {
        NativeEmbeddingReranker,
      } = require("../EmbeddingRerankers/native");
      return new NativeEmbeddingReranker();
    }
  }
}

module.exports = {
getEmbeddingEngineSelection,
maximumChunkLength,
@@ -479,4 +507,5 @@ module.exports = {
getBaseLLMProviderModel,
getLLMProvider,
toChunks,
getRerankerProvider,
};
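Any provider satisfying the BaseRerankerProvider typedef can be returned here. A minimal sketch of a conforming implementation, using a hypothetical keyword-overlap scorer; only the rerank(query, documents, { topK }) signature and the rerank_score field come from this changeset:

// Hypothetical reranker conforming to BaseRerankerProvider.
class KeywordOverlapReranker {
  async rerank(query, documents, { topK = 4 } = {}) {
    const terms = new Set(query.toLowerCase().split(/\s+/).filter(Boolean));
    return documents
      .map((doc) => ({
        ...doc,
        // Naive score: fraction of query terms found in the document text.
        rerank_score:
          [...terms].filter((t) => (doc.text ?? "").toLowerCase().includes(t))
            .length / Math.max(terms.size, 1),
      }))
      .sort((a, b) => b.rerank_score - a.rerank_score)
      .slice(0, topK);
  }
}

Note that getRerankerProvider only recognizes "native" today, so wiring in a provider like this would mean adding a case to its switch.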
85 changes: 70 additions & 15 deletions server/utils/vectorDbProviders/astra/index.js
@@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");
const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");

const sanitizeNamespace = (namespace) => {
// If namespace already starts with ns_, don't add it again
@@ -301,6 +302,7 @@ const AstraDB = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
@@ -319,17 +321,27 @@
}

const queryVector = await LLMConnector.embedTextInput(input);
const { contextTexts, sourceDocuments } = await this.similarityResponse({
client,
namespace: sanitizedNamespace,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
});
const { contextTexts, sourceDocuments } = rerank
? await this.rerankedSimilarityResponse({
client,
namespace: sanitizedNamespace,
query: input,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
})
: await this.similarityResponse({
client,
namespace: sanitizedNamespace,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
});

const sources = sourceDocuments.map((metadata, i) => {
return { ...metadata, text: contextTexts[i] };
const sources = sourceDocuments.map((doc, i) => {
return { metadata: doc, text: contextTexts[i] };
});
return {
contextTexts,
@@ -373,11 +385,55 @@ const AstraDB = {
return;
}
result.contextTexts.push(response.metadata.text);
result.sourceDocuments.push(response);
result.sourceDocuments.push({
...response.metadata,
score: response.$similarity,
});
result.scores.push(response.$similarity);
});
return result;
},
rerankedSimilarityResponse: async function ({
client,
namespace,
query,
queryVector,
topN = 4,
similarityThreshold = 0.25,
filterIdentifiers = [],
}) {
const totalEmbeddings = await this.namespaceCount(namespace);
const searchLimit = getSearchLimit(totalEmbeddings);
const { sourceDocuments } = await this.similarityResponse({
client,
namespace,
queryVector,
similarityThreshold,
topN: searchLimit,
filterIdentifiers,
});

const rerankedResults = await rerank(query, sourceDocuments, topN);
const result = {
contextTexts: [],
sourceDocuments: [],
scores: [],
};

rerankedResults.forEach((item) => {
if (item.rerank_score < similarityThreshold) return;
const { rerank_score, ...rest } = item;
if (filterIdentifiers.includes(sourceIdentifier(rest))) return;

result.contextTexts.push(rest.text);
result.sourceDocuments.push({
...rest,
score: rerank_score,
});
result.scores.push(rerank_score);
});
return result;
},
allNamespaces: async function (client) {
try {
let header = new Headers();
@@ -432,12 +488,11 @@ const AstraDB = {
curateSources: function (sources = []) {
const documents = [];
for (const source of sources) {
if (Object.keys(source).length > 0) {
const metadata = source.hasOwnProperty("metadata")
? source.metadata
: source;
const { metadata = {} } = source;
if (Object.keys(metadata).length > 0) {
documents.push({
...metadata,
...(source.text ? { text: source.text } : {}),
});
}
}
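Callers opt into the reranked path through the new rerank flag on performSimilaritySearch. A sketch, assuming an embedded workspace and an LLMConnector already in scope that exposes embedTextInput; the namespace and query values are hypothetical:

// Inside an async function:
const { contextTexts, sources } = await AstraDB.performSimilaritySearch({
  namespace: "workspace-main",        // hypothetical workspace slug
  input: "How are API keys rotated?", // hypothetical user query
  LLMConnector,
  similarityThreshold: 0.25,
  topN: 4,
  filterIdentifiers: [],
  rerank: true, // oversample via getSearchLimit(), then rerank down to topN
});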
78 changes: 68 additions & 10 deletions server/utils/vectorDbProviders/chroma/index.js
@@ -6,6 +6,7 @@ const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { parseAuthHeader } = require("../../http");
const { sourceIdentifier } = require("../../chats");
const { rerank, getSearchLimit } = require("../../EmbeddingRerankers/rerank");
const COLLECTION_REGEX = new RegExp(
/^(?!\d+\.\d+\.\d+\.\d+$)(?!.*\.\.)(?=^[a-zA-Z0-9][a-zA-Z0-9_-]{1,61}[a-zA-Z0-9]$).{3,63}$/
);
@@ -150,6 +151,51 @@ const Chroma = {

return result;
},
rerankedSimilarityResponse: async function ({
client,
namespace,
query,
queryVector,
topN = 4,
similarityThreshold = 0.25,
filterIdentifiers = [],
}) {
const totalEmbeddings = await this.namespaceCount(namespace);
const searchLimit = getSearchLimit(totalEmbeddings);
const { sourceDocuments, contextTexts } = await this.similarityResponse({
client,
namespace,
queryVector,
similarityThreshold,
topN: searchLimit,
filterIdentifiers,
});
const documentsForReranking = sourceDocuments.map((metadata, i) => ({
...metadata,
text: contextTexts[i],
}));

const rerankedResults = await rerank(query, documentsForReranking, topN);
const result = {
contextTexts: [],
sourceDocuments: [],
scores: [],
};

rerankedResults.forEach((item) => {
if (item.rerank_score < similarityThreshold) return;
const { vector: _, rerank_score, ...rest } = item;
if (filterIdentifiers.includes(sourceIdentifier(rest))) return;

result.contextTexts.push(rest.text);
result.sourceDocuments.push({
...rest,
score: rerank_score,
});
result.scores.push(rerank_score);
});
return result;
},
namespace: async function (client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided.");
const collection = await client
@@ -348,12 +394,14 @@ const Chroma = {
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
rerank = false,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");

const { client } = await this.connect();
if (!(await this.namespaceExists(client, this.normalize(namespace)))) {
const collectionName = this.normalize(namespace);
if (!(await this.namespaceExists(client, collectionName))) {
return {
contextTexts: [],
sources: [],
@@ -362,16 +410,26 @@
}

const queryVector = await LLMConnector.embedTextInput(input);
const { contextTexts, sourceDocuments, scores } =
await this.similarityResponse({
client,
namespace,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
});
const result = rerank
? await this.rerankedSimilarityResponse({
client,
namespace,
query: input,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
})
: await this.similarityResponse({
client,
namespace,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
});

const { contextTexts, sourceDocuments, scores } = result;
const sources = sourceDocuments.map((metadata, i) => ({
metadata: {
...metadata,
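Both providers consume the same reranker output: each oversampled document comes back annotated with a rerank_score, which is compared against similarityThreshold and surfaced as score on the source documents. An illustrative result shape, with made-up values:

// Illustrative return value of rerank(query, documents, topN):
[
  { text: "Keys can be rotated from workspace settings.", rerank_score: 0.91 },
  { text: "Each workspace scopes its own API keys.", rerank_score: 0.44 },
  // Anything below similarityThreshold (default 0.25) is dropped by the
  // caller, as is any document whose sourceIdentifier() is filtered out.
];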