Skip to content

Commit

Permalink
feature: add score threshold, author, increase width, and other searc…
Browse files Browse the repository at this point in the history
…h component improvements
  • Loading branch information
skeptrunedev authored and cdxker committed Dec 11, 2024
1 parent cf7c198 commit 2cbac8f
Show file tree
Hide file tree
Showing 16 changed files with 383 additions and 103 deletions.
2 changes: 1 addition & 1 deletion clients/search-component/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"import": "./dist/vanilla/index.js"
}
},
"version": "0.2.12",
"version": "0.2.15",
"license": "MIT",
"homepage": "https://github.com/devflowinc/trieve/tree/main/clients/search-component",
"scripts": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ export const Message = ({
const [positive, setPositive] = React.useState<boolean | null>(null);
const [copied, setCopied] = React.useState<boolean>(false);
const { props } = useModalState();

const ecommerceItems = message.additional
?.filter(
(chunk) =>
Expand Down
6 changes: 5 additions & 1 deletion clients/search-component/src/TrieveModal/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ body {
}

.trieve-footer.ecommerce {
@apply py-0.5 -mx-4;
@apply py-0.5;
}

.commands.ecommerce {
Expand Down Expand Up @@ -574,6 +574,10 @@ body {
&.loading {
@apply animate-pulse;
}

&.empty-state-loading {
@apply animate-pulse;
}
}
}
}
Expand Down
20 changes: 11 additions & 9 deletions clients/search-component/src/utils/hooks/chat-context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
const [messages, setMessages] = useState<Messages>([]);
const [isLoading, setIsLoading] = useState(false);
const chatMessageAbortController = useRef<AbortController>(
new AbortController(),
new AbortController()
);
const isDoneReading = useRef<boolean>(true);
const createTopic = async ({ question }: { question: string }) => {
Expand All @@ -90,7 +90,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
setMessages([]);
};

const { currentTag, currentGroup } = useModalState();
const { currentTag, currentGroup, props } = useModalState();

useEffect(() => {
if (currentTag) {
Expand All @@ -100,7 +100,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {

const handleReader = async (
reader: ReadableStreamDefaultReader<Uint8Array>,
queryId: string | null,
queryId: string | null
) => {
setIsLoading(true);
isDoneReading.current = false;
Expand Down Expand Up @@ -192,7 +192,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
],
stream_response: true,
},
chatMessageAbortController.current.signal,
chatMessageAbortController.current.signal
);
handleReader(reader, queryId);
} else {
Expand All @@ -204,15 +204,17 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
llm_options: {
completion_first: false,
},
page_size: 5,
page_size: props.searchOptions?.page_size ?? 5,
score_threshold: props.searchOptions?.score_threshold || null,
use_group_search: props.useGroupSearch,
filters:
currentTag !== "all"
? {
must: [{ field: "tag_set", match_any: [currentTag] }], // Apply tag filter
must: [{ field: "tag_set", match_any: [currentTag] }],
}
: null,
},
chatMessageAbortController.current.signal,
chatMessageAbortController.current.signal
);
handleReader(reader, queryId);
}
Expand Down Expand Up @@ -244,7 +246,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
[
...messages.slice(0, -1),
messages[messages.length - 1]?.slice(0, -1),
].filter((a) => a.length),
].filter((a) => a.length)
);
}
};
Expand Down Expand Up @@ -292,7 +294,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {

const rateChatCompletion = async (
isPositive: boolean,
queryId: string | null,
queryId: string | null
) => {
if (queryId) {
trieveSDK.rateRagQuery({
Expand Down
73 changes: 39 additions & 34 deletions clients/search-component/src/utils/trieve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export const omit = (obj: object | null | undefined, keys: string[]) => {
if (!obj) return obj;

return Object.fromEntries(
Object.entries(obj).filter(([key]) => !keys.includes(key)),
Object.entries(obj).filter(([key]) => !keys.includes(key))
);
};

Expand All @@ -33,6 +33,13 @@ export const searchWithTrieve = async ({
tag?: string;
type?: ModalTypes;
}) => {
const scoreThreshold =
searchOptions.score_threshold ??
((searchOptions.search_type ?? "fulltext") === "fulltext" ||
searchOptions.search_type == "bm25"
? 2
: 0.3);

let results;
if (searchOptions.use_autocomplete === true) {
results = (await trieve.autocomplete(
Expand All @@ -44,12 +51,8 @@ export const searchWithTrieve = async ({
highlight_window: type === "ecommerce" ? 5 : 10,
},
extend_results: true,
score_threshold:
(searchOptions.search_type ?? "fulltext") === "fulltext" ||
searchOptions.search_type == "bm25"
? 2
: 0.3,
page_size: 20,
score_threshold: scoreThreshold,
page_size: searchOptions.page_size ?? 15,
...(tag && {
filters: {
must: [{ field: "tag_set", match_any: [tag] }],
Expand All @@ -61,7 +64,7 @@ export const searchWithTrieve = async ({
search_type: searchOptions.search_type ?? "fulltext",
...omit(searchOptions, ["use_autocomplete"]),
},
abortController?.signal,
abortController?.signal
)) as SearchResponseBody;
} else {
results = (await trieve.search(
Expand All @@ -72,12 +75,8 @@ export const searchWithTrieve = async ({
highlight_delimiters: ["?", ",", ".", "!", "\n"],
highlight_window: type === "ecommerce" ? 5 : 10,
},
score_threshold:
(searchOptions.search_type ?? "fulltext") === "fulltext" ||
searchOptions.search_type == "bm25"
? 2
: 0.3,
page_size: 20,
score_threshold: scoreThreshold,
page_size: searchOptions.page_size ?? 15,
...(tag && {
filters: {
must: [{ field: "tag_set", match_any: [tag] }],
Expand All @@ -89,7 +88,7 @@ export const searchWithTrieve = async ({
search_type: searchOptions.search_type ?? "fulltext",
...omit(searchOptions, ["use_autocomplete"]),
},
abortController?.signal,
abortController?.signal
)) as SearchResponseBody;
}

Expand Down Expand Up @@ -127,6 +126,13 @@ export const groupSearchWithTrieve = async ({
tag?: string;
type?: ModalTypes;
}) => {
const scoreThreshold =
searchOptions.score_threshold ??
((searchOptions.search_type ?? "fulltext") === "fulltext" ||
searchOptions.search_type == "bm25"
? 2
: 0.3);

const results = await trieve.searchOverGroups(
{
query,
Expand All @@ -135,22 +141,18 @@ export const groupSearchWithTrieve = async ({
highlight_delimiters: ["?", ",", ".", "!", "\n"],
highlight_window: type === "ecommerce" ? 5 : 10,
},
score_threshold:
(searchOptions.search_type ?? "fulltext") === "fulltext" ||
searchOptions.search_type == "bm25"
? 2
: 0.3,
page_size: 20,
score_threshold: scoreThreshold,
page_size: searchOptions.page_size ?? 15,
...(tag && {
filters: {
must: [{ field: "tag_set", match_any: [tag] }],
},
}),
group_size: 3,
group_size: 1,
search_type: searchOptions.search_type ?? "fulltext",
...omit(searchOptions, ["use_autocomplete"]),
},
abortController?.signal,
abortController?.signal
);

const resultsWithHighlight = results.results.map((group) => {
Expand Down Expand Up @@ -186,15 +188,18 @@ export const countChunks = async ({
tag?: string;
searchOptions?: Props["searchOptions"];
}) => {
const scoreThreshold =
searchOptions?.score_threshold ??
((searchOptions?.search_type ?? "fulltext") === "fulltext" ||
searchOptions?.search_type == "bm25"
? 2
: 0.3);

const results = await trieve.countChunksAboveThreshold(
{
query,
score_threshold:
(searchOptions?.search_type ?? "fulltext") === "fulltext" ||
searchOptions?.search_type == "bm25"
? 2
: 0.3,
limit: 10000,
score_threshold: scoreThreshold,
limit: 100,
...(tag && {
filters: {
must: [{ field: "tag_set", match_any: [tag] }],
Expand All @@ -203,7 +208,7 @@ export const countChunks = async ({
search_type: "fulltext",
...omit(searchOptions, ["search_type"]),
},
abortController?.signal,
abortController?.signal
);
return results;
};
Expand Down Expand Up @@ -245,7 +250,7 @@ export const getSuggestedQueries = async ({
search_type: "semantic",
context: "You are a user searching through a docs website",
},
abortController?.signal,
abortController?.signal
);
};

Expand All @@ -262,7 +267,7 @@ export const getSuggestedQuestions = async ({
search_type: "semantic",
context: "You are a user searching through a docs website",
},
abortController?.signal,
abortController?.signal
);
};

Expand All @@ -274,7 +279,7 @@ export type SimpleChunk = ChunkMetadata | ChunkMetadataStringTagSet;

export const getAllChunksForGroup = async (
groupId: string,
trieve: TrieveSDK,
trieve: TrieveSDK
): Promise<SimpleChunk[]> => {
let moreToFind = true;
let page = 1;
Expand All @@ -287,7 +292,7 @@ export const getAllChunksForGroup = async (
datasetId: trieve.datasetId,
groupId,
page,
},
}
);
if (results.chunks.length === 0) {
moreToFind = false;
Expand Down
17 changes: 14 additions & 3 deletions clients/ts-sdk/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -5705,11 +5705,10 @@
{
"name": "dataset_id",
"in": "path",
"description": "The id of the organization you want to fetch.",
"description": "The id or tracking_id of the dataset you want to get the demo page for.",
"required": true,
"schema": {
"type": "string",
"format": "uuid"
"type": "string"
}
}
],
Expand Down Expand Up @@ -12012,6 +12011,14 @@
"type": "boolean",
"nullable": true
},
"creatorLinkedInUrl": {
"type": "string",
"nullable": true
},
"creatorName": {
"type": "string",
"nullable": true
},
"currencyPosition": {
"type": "string",
"nullable": true
Expand Down Expand Up @@ -12048,6 +12055,10 @@
},
"nullable": true
},
"forBrandName": {
"type": "string",
"nullable": true
},
"heroPattern": {
"allOf": [
{
Expand Down
5 changes: 4 additions & 1 deletion clients/ts-sdk/src/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1968,13 +1968,16 @@ export type PublicPageParameters = {
brandLogoImgSrcUrl?: (string) | null;
brandName?: (string) | null;
chat?: (boolean) | null;
creatorLinkedInUrl?: (string) | null;
creatorName?: (string) | null;
currencyPosition?: (string) | null;
datasetId?: (string) | null;
debounceMs?: (number) | null;
defaultAiQuestions?: Array<(string)> | null;
defaultCurrency?: (string) | null;
defaultSearchMode?: (string) | null;
defaultSearchQueries?: Array<(string)> | null;
forBrandName?: (string) | null;
heroPattern?: ((HeroPattern) | null);
navLogoImgSrcUrl?: (string) | null;
openGraphMetadata?: ((OpenGraphMetadata) | null);
Expand Down Expand Up @@ -4507,7 +4510,7 @@ export type DeleteOrganizationResponse = (void);

export type PublicPageData = {
/**
* The id of the organization you want to fetch.
* The id or tracking_id of the dataset you want to get the demo page for.
*/
datasetId: string;
};
Expand Down
6 changes: 4 additions & 2 deletions frontends/dashboard/src/hooks/usePublicPageSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export const { use: usePublicPage, provider: PublicPageProvider } =
const [isPublic, setisPublic] = createSignal<boolean>(false);
const [hasLoaded, setHasLoaded] = createSignal(false);

const { datasetId } = useContext(DatasetContext);
const { dataset, datasetId } = useContext(DatasetContext);
const { selectedOrg } = useContext(UserContext);

const trieve = useTrieve();
Expand Down Expand Up @@ -208,7 +208,9 @@ export const { use: usePublicPage, provider: PublicPageProvider } =

const apiHost = import.meta.env.VITE_API_HOST as unknown as string;
const publicUrl = createMemo(() => {
return `${apiHost.slice(0, -4)}/public_page/${datasetId()}`;
return `${apiHost.slice(0, -4)}/demos/${
dataset()?.dataset.tracking_id ?? datasetId()
}`;
});

return {
Expand Down
Loading

0 comments on commit 2cbac8f

Please sign in to comment.