Skip to content

Commit 7b91c10

Browse files
skeptrunedevcdxker
authored andcommitted
feature: add scoring options req param to search chunks route
1 parent 2b143b3 commit 7b91c10

File tree

6 files changed

+177
-53
lines changed

6 files changed

+177
-53
lines changed

server/src/data/models.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use crate::errors::ServiceError;
55
use crate::get_env;
66
use crate::handlers::analytics_handler::CTRDataRequestBody;
77
use crate::handlers::chunk_handler::{
8-
AutocompleteReqPayload, ChunkFilter, FullTextBoost, ParsedQuery, SearchChunksReqPayload,
9-
SemanticBoost,
8+
AutocompleteReqPayload, ChunkFilter, FullTextBoost, ParsedQuery, ScoringOptions,
9+
SearchChunksReqPayload, SemanticBoost,
1010
};
1111
use crate::handlers::file_handler::UploadFileReqPayload;
1212
use crate::handlers::group_handler::{SearchOverGroupsReqPayload, SearchWithinGroupReqPayload};
@@ -5455,6 +5455,7 @@ impl<'de> Deserialize<'de> for SearchChunksReqPayload {
54555455
get_total_pages: Option<bool>,
54565456
filters: Option<ChunkFilter>,
54575457
sort_options: Option<SortOptions>,
5458+
scoring_options: Option<ScoringOptions>,
54585459
highlight_options: Option<HighlightOptions>,
54595460
score_threshold: Option<f32>,
54605461
slim_chunks: Option<bool>,
@@ -5486,6 +5487,7 @@ impl<'de> Deserialize<'de> for SearchChunksReqPayload {
54865487
get_total_pages: helper.get_total_pages,
54875488
filters: helper.filters,
54885489
sort_options,
5490+
scoring_options: helper.scoring_options,
54895491
highlight_options,
54905492
score_threshold: helper.score_threshold,
54915493
slim_chunks: helper.slim_chunks,
@@ -5511,6 +5513,7 @@ impl<'de> Deserialize<'de> for AutocompleteReqPayload {
55115513
page_size: Option<u64>,
55125514
filters: Option<ChunkFilter>,
55135515
sort_options: Option<SortOptions>,
5516+
scoring_options: Option<ScoringOptions>,
55145517
highlight_options: Option<HighlightOptions>,
55155518
score_threshold: Option<f32>,
55165519
slim_chunks: Option<bool>,
@@ -5541,6 +5544,7 @@ impl<'de> Deserialize<'de> for AutocompleteReqPayload {
55415544
page_size: helper.page_size,
55425545
filters: helper.filters,
55435546
sort_options,
5547+
scoring_options: helper.scoring_options,
55445548
highlight_options,
55455549
score_threshold: helper.score_threshold,
55465550
slim_chunks: helper.slim_chunks,

server/src/handlers/chunk_handler.rs

+17-4
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ pub struct SemanticBoost {
5959
pub distance_factor: f32,
6060
}
6161

62+
/// Scoring options provides ways to modify the sparse or dense vector created for the query in order to change how potential matches are scored. If not specified, this defaults to no modifications.
63+
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
64+
pub struct ScoringOptions {
65+
/// Full text boost is useful for when you want to boost certain phrases in the fulltext (SPLADE) and BM25 search results. E.g. making sure that the listing for AirBNB itself ranks higher than companies who make software for AirBNB hosts by boosting the in-document-frequency of the AirBNB token (AKA word) for its official listing. Conceptually it multiplies the in-document-importance (the second value in each tuple of the SPLADE or BM25 sparse vector of the chunk_html innerText) for all tokens present in the boost phrase by the boost factor like so: (token, in-document-importance) -> (token, in-document-importance*boost_factor).
66+
pub fulltext_boost: Option<FullTextBoost>,
67+
/// Semantic boost is useful for moving the embedding vector of the chunk in the direction of the distance phrase. E.g. you can push a chunk with a chunk_html of "iphone" 25% closer to the term "flagship" by using the distance phrase "flagship" and a distance factor of 0.25. Conceptually it's drawing a line (Euclidean/L2 distance) between the vector for the innerText of the chunk_html and distance_phrase, then moving the vector of the chunk_html distance_factor*L2Distance closer to or away from the distance_phrase point along the line between the two points.
68+
pub semantic_boost: Option<SemanticBoost>,
69+
}
70+
6271
#[derive(Serialize, Deserialize, Debug, ToSchema, Clone)]
6372
#[schema(example = json!({
6473
"chunk_html": "<p>Some HTML content</p>",
@@ -946,12 +955,11 @@ pub struct SearchChunksReqPayload {
946955
pub filters: Option<ChunkFilter>,
947956
/// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks.
948957
pub sort_options: Option<SortOptions>,
958+
/// Scoring options provides ways to modify the sparse or dense vector created for the query in order to change how potential matches are scored. If not specified, this defaults to no modifications.
959+
pub scoring_options: Option<ScoringOptions>,
949960
/// Highlight Options lets you specify different methods to highlight the chunks in the result set. If not specified, this defaults to the score of the chunks.
950961
pub highlight_options: Option<HighlightOptions>,
951-
/// Set score_threshold to a float to filter out chunks with a score below the threshold for cosine distance metric
952-
/// For Manhattan Distance, Euclidean Distance, and Dot Product, it will filter out scores above the threshold distance
953-
/// This threshold applies before weight and bias modifications. If not specified, this defaults to no threshold
954-
/// A threshold of 0 will default to no threshold
962+
/// Set score_threshold to a float to filter out chunks with a score below the threshold for cosine distance metric. For Manhattan Distance, Euclidean Distance, and Dot Product, it will filter out scores above the threshold distance. This threshold applies before weight and bias modifications. If not specified, this defaults to no threshold. A threshold of 0 will default to no threshold.
955963
pub score_threshold: Option<f32>,
956964
/// Set slim_chunks to true to avoid returning the content and chunk_html of the chunks. This is useful for when you want to reduce the amount of data over the wire for latency improvement (typically 10-50ms). Default is false.
957965
pub slim_chunks: Option<bool>,
@@ -977,6 +985,7 @@ impl Default for SearchChunksReqPayload {
977985
page_size: Some(10),
978986
filters: None,
979987
sort_options: None,
988+
scoring_options: None,
980989
highlight_options: None,
981990
score_threshold: None,
982991
slim_chunks: None,
@@ -1329,6 +1338,8 @@ pub struct AutocompleteReqPayload {
13291338
pub filters: Option<ChunkFilter>,
13301339
/// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks.
13311340
pub sort_options: Option<SortOptions>,
1341+
/// Scoring options provides ways to modify the sparse or dense vector created for the query in order to change how potential matches are scored. If not specified, this defaults to no modifications.
1342+
pub scoring_options: Option<ScoringOptions>,
13321343
/// Highlight Options lets you specify different methods to highlight the chunks in the result set. If not specified, this defaults to the score of the chunks.
13331344
pub highlight_options: Option<HighlightOptions>,
13341345
/// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0.
@@ -1356,6 +1367,7 @@ impl From<AutocompleteReqPayload> for SearchChunksReqPayload {
13561367
page_size: autocomplete_data.page_size,
13571368
filters: autocomplete_data.filters,
13581369
sort_options: autocomplete_data.sort_options,
1370+
scoring_options: autocomplete_data.scoring_options,
13591371
highlight_options: autocomplete_data.highlight_options,
13601372
score_threshold: autocomplete_data.score_threshold,
13611373
slim_chunks: autocomplete_data.slim_chunks,
@@ -1653,6 +1665,7 @@ impl From<CountChunksReqPayload> for SearchChunksReqPayload {
16531665
page_size: count_data.limit,
16541666
filters: count_data.filters,
16551667
sort_options: None,
1668+
scoring_options: None,
16561669
highlight_options: None,
16571670
score_threshold: count_data.score_threshold,
16581671
slim_chunks: None,

server/src/handlers/group_handler.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,7 @@ impl From<SearchWithinGroupReqPayload> for SearchChunksReqPayload {
14261426
filters: search_within_group_data.filters,
14271427
search_type: search_within_group_data.search_type,
14281428
sort_options: search_within_group_data.sort_options,
1429+
scoring_options: None,
14291430
highlight_options: search_within_group_data.highlight_options,
14301431
score_threshold: search_within_group_data.score_threshold,
14311432
slim_chunks: search_within_group_data.slim_chunks,

server/src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ impl Modify for SecurityAddon {
140140
name = "BSL",
141141
url = "https://github.com/devflowinc/trieve/blob/main/LICENSE.txt",
142142
),
143-
version = "0.11.7",
143+
version = "0.11.8",
144144
),
145145
servers(
146146
(url = "https://api.trieve.ai",
@@ -277,6 +277,7 @@ impl Modify for SecurityAddon {
277277
handlers::chunk_handler::GetChunksData,
278278
handlers::chunk_handler::GetTrackingChunksData,
279279
handlers::chunk_handler::SemanticBoost,
280+
handlers::chunk_handler::ScoringOptions,
280281
handlers::chunk_handler::ChunkReturnTypes,
281282
handlers::chunk_handler::ScrollChunksReqPayload,
282283
handlers::chunk_handler::ScrollChunksResponseBody,

server/src/operators/model_operator.rs

+93-32
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub struct EmbeddingParameters {
2828
#[tracing::instrument]
2929
pub async fn get_dense_vector(
3030
message: String,
31-
distance_phrase: Option<SemanticBoost>,
31+
semantic_boost: Option<SemanticBoost>,
3232
embed_type: &str,
3333
dataset_config: DatasetConfiguration,
3434
) -> Result<Vec<f32>, ServiceError> {
@@ -91,32 +91,23 @@ pub async fn get_dense_vector(
9191
};
9292

9393
let clipped_message: String = message.chars().take(20000).collect();
94+
let mut messages = vec![format!(
95+
"{}{}",
96+
dataset_config.EMBEDDING_QUERY_PREFIX, &clipped_message
97+
)
98+
.to_string()];
99+
if let Some(semantic_boost) = semantic_boost.as_ref() {
100+
if semantic_boost.distance_factor == 0.0 || semantic_boost.phrase.is_empty() {
101+
return Err(ServiceError::BadRequest(
102+
"Semantic boost phrase is empty or distance factor is 0. Boost phrase must not be empty and distance factor must be greater than 0".to_string(),
103+
));
104+
}
94105

95-
let mut messages = vec![clipped_message.clone()];
96-
97-
if distance_phrase.is_some() {
98-
let clipped_boost: String = distance_phrase
99-
.as_ref()
100-
.unwrap()
101-
.phrase
102-
.chars()
103-
.take(20000)
104-
.collect();
106+
let clipped_boost: String = semantic_boost.phrase.chars().take(20000).collect();
105107
messages.push(clipped_boost);
106108
}
107109

108-
let input = match embed_type {
109-
"doc" => EmbeddingInput::StringArray(messages),
110-
"query" => EmbeddingInput::String(
111-
format!(
112-
"{}{}",
113-
dataset_config.EMBEDDING_QUERY_PREFIX, &clipped_message
114-
)
115-
.to_string(),
116-
),
117-
_ => EmbeddingInput::StringArray(messages),
118-
};
119-
110+
let input = EmbeddingInput::StringArray(messages);
120111
let parameters = EmbeddingParameters {
121112
model: dataset_config.EMBEDDING_MODEL_NAME.to_string(),
122113
input,
@@ -139,8 +130,8 @@ pub async fn get_dense_vector(
139130
))
140131
})?;
141132

142-
let embeddings: EmbeddingResponse = format_response(embeddings_resp.into_string().unwrap())
143-
.map_err(|e| {
133+
let embeddings: EmbeddingResponse =
134+
format_response(embeddings_resp.into_string().unwrap_or("".to_string())).map_err(|e| {
144135
log::error!("Failed to format response from embeddings server {:?}", e);
145136
ServiceError::InternalServerError(
146137
"Failed to format response from embeddings server".to_owned(),
@@ -165,10 +156,24 @@ pub async fn get_dense_vector(
165156
));
166157
}
167158

168-
if distance_phrase.is_some() {
169-
let distance_factor = distance_phrase.unwrap().distance_factor;
170-
let boost_vector = vectors.pop().unwrap();
171-
let embedding_vector = vectors.pop().unwrap();
159+
if let Some(semantic_boost) = semantic_boost {
160+
let distance_factor = semantic_boost.distance_factor;
161+
let boost_vector = match vectors.pop() {
162+
Some(v) => v,
163+
None => {
164+
return Err(ServiceError::InternalServerError(
165+
"No dense embedding returned from server for boost_vector".to_owned(),
166+
))
167+
}
168+
};
169+
let embedding_vector = match vectors.pop() {
170+
Some(v) => v,
171+
None => {
172+
return Err(ServiceError::InternalServerError(
173+
"No dense embedding returned from server for embedding_vector".to_owned(),
174+
))
175+
}
176+
};
172177

173178
return Ok(embedding_vector
174179
.iter()
@@ -190,6 +195,7 @@ pub async fn get_dense_vector(
190195
#[tracing::instrument]
191196
pub async fn get_sparse_vector(
192197
message: String,
198+
fulltext_boost: Option<FullTextBoost>,
193199
embed_type: &str,
194200
) -> Result<Vec<(u32, f32)>, ServiceError> {
195201
let origin_key = match embed_type {
@@ -206,11 +212,22 @@ pub async fn get_sparse_vector(
206212
origin_key
207213
)))?;
208214

209-
let clipped_message = message.chars().take(128000).collect();
215+
let clipped_message: String = message.chars().take(20000).collect();
216+
let mut inputs = vec![clipped_message.clone()];
217+
if let Some(fulltext_boost) = fulltext_boost.as_ref() {
218+
if fulltext_boost.phrase.is_empty() {
219+
return Err(ServiceError::BadRequest(
220+
"Fulltext boost phrase is empty. Non-empty phrase must be specified.".to_string(),
221+
));
222+
}
223+
224+
let clipped_boost: String = fulltext_boost.phrase.chars().take(20000).collect();
225+
inputs.push(clipped_boost);
226+
}
210227

211228
let embedding_server_call = format!("{}/embed_sparse", server_origin);
212229

213-
let sparse_vectors = ureq::post(&embedding_server_call)
230+
let mut sparse_vectors = ureq::post(&embedding_server_call)
214231
.set("Content-Type", "application/json")
215232
.set(
216233
"Authorization",
@@ -220,7 +237,7 @@ pub async fn get_sparse_vector(
220237
),
221238
)
222239
.send_json(CustomSparseEmbedData {
223-
inputs: vec![clipped_message],
240+
inputs,
224241
encode_type: embed_type.to_string(),
225242
truncate: true,
226243
})
@@ -242,6 +259,50 @@ pub async fn get_sparse_vector(
242259
)
243260
})?;
244261

262+
if let Some(fulltext_boost) = fulltext_boost {
263+
let boost_amt = fulltext_boost.boost_factor;
264+
let boost_vector = match sparse_vectors.pop() {
265+
Some(v) => v,
266+
None => {
267+
return Err(ServiceError::InternalServerError(
268+
"No sparse vector returned from server for boost_vector".to_owned(),
269+
))
270+
}
271+
};
272+
let query_vector = match sparse_vectors.pop() {
273+
Some(v) => v,
274+
None => {
275+
return Err(ServiceError::InternalServerError(
276+
"No sparse vector returned from server for embedding_vector".to_owned(),
277+
))
278+
}
279+
};
280+
281+
let boosted_query_vector = query_vector
282+
.iter()
283+
.map(|splade_indice| {
284+
if boost_vector
285+
.iter()
286+
.any(|boost_splade_indice| boost_splade_indice.index == splade_indice.index)
287+
{
288+
SpladeIndicies {
289+
index: splade_indice.index,
290+
value: splade_indice.value * (boost_amt as f32),
291+
}
292+
.into_tuple()
293+
} else {
294+
SpladeIndicies {
295+
index: splade_indice.index,
296+
value: splade_indice.value,
297+
}
298+
.into_tuple()
299+
}
300+
})
301+
.collect();
302+
303+
return Ok(boosted_query_vector);
304+
}
305+
245306
match sparse_vectors.first() {
246307
Some(v) => Ok(v
247308
.iter()

0 commit comments

Comments
 (0)