
Commit 57d6e12

Web search improvements (bm25, web chat) (#1420)
* Fix web search blocking case
* Web search support in web chat
* Tweak ui
* Support fallback to bm25
* Clippy
* Reinject descriptions

1 parent 6547156 commit 57d6e12

File tree

12 files changed: 448 additions & 93 deletions


Cargo.lock

Lines changed: 87 additions & 4 deletions
Some generated files are not rendered by default.

mistralrs-core/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ parking_lot = "0.12.3"
 ahash = "0.8.12"
 num-traits = "0.2.19"
 libc = "0.2.172"
+bm25 = "2.2.1"

 [features]
 pyo3_macros = ["pyo3"]

mistralrs-core/src/engine/add_request.rs

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ impl Engine {
             request.messages,
             RequestMessage::Chat { .. } | RequestMessage::VisionChat { .. }
         ) && request.web_search_options.is_some()
-            && get_mut_arcmutex!(self.bert_pipeline).is_some()
         {
             search_request::search_request(self.clone(), *request).await;
         } else {

mistralrs-core/src/engine/search_request.rs

Lines changed: 106 additions & 43 deletions
@@ -1,5 +1,6 @@
 use std::{borrow::Cow, sync::Arc, time::Instant};

+use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer};
 use either::Either;
 use indexmap::IndexMap;
 use tokenizers::InputSequence;
@@ -72,52 +73,102 @@ async fn do_search(
         SearchContextSize::Medium => 8192_usize,
         SearchContextSize::Low => 4096_usize,
     };
-    let mut results = tracing::dispatcher::with_default(&dispatch, || {
-        search::run_search_tool(&tool_call_params)
-            .unwrap()
-            .into_iter()
-            .map(|mut result| {
-                result = result
-                    .cap_content_len(&tokenizer, max_results_budget_toks)
-                    .unwrap();
-                let len = {
-                    let inp = InputSequence::Raw(Cow::from(&result.content));
-                    tokenizer
-                        .encode_fast(inp, false)
-                        .map(|x| x.len())
-                        .unwrap_or(usize::MAX)
-                };
-                (result, len)
-            })
-            .collect::<Vec<_>>()
+    let mut results = tokio::task::block_in_place(|| {
+        tracing::dispatcher::with_default(&dispatch, || {
+            search::run_search_tool(&tool_call_params)
+                .unwrap()
+                .into_iter()
+                .map(|mut result| {
+                    result = result
+                        .cap_content_len(&tokenizer, max_results_budget_toks)
+                        .unwrap();
+                    let len = {
+                        let inp = InputSequence::Raw(Cow::from(&result.content));
+                        tokenizer
+                            .encode_fast(inp, false)
+                            .map(|x| x.len())
+                            .unwrap_or(usize::MAX)
+                    };
+                    (result, len)
+                })
+                .collect::<Vec<_>>()
+        })
     });

     // Sort increasing by tokenized length, if it fails, put it at the end.
     results.sort_by_key(|(_, len)| *len);

     {
-        let device = get_mut_arcmutex!(this.pipeline).device();
+        // Determine ranking: use embedding model if available, otherwise fallback to BM25
+        let decreasing_indexes: Vec<usize> = if let Some(bert_pipeline) =
+            &mut *get_mut_arcmutex!(this.bert_pipeline)
+        {
+            // Semantic reranking with embeddings
+            let device = get_mut_arcmutex!(this.pipeline).device();
+            search::rag::compute_most_similar(
+                &device,
+                &tool_call_params.query,
+                results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
+                bert_pipeline,
+            )
+            .unwrap()
+        } else {
+            tracing::warn!("No embedding model loaded; falling back to BM25 ranking for web search results.");

-        let Some(bert_pipeline) = &mut *get_mut_arcmutex!(this.bert_pipeline) else {
-            unreachable!()
-        };
+            // Build an Embedder over the corpus, fitting to the entire set of documents.
+            // - Language::English is chosen here
+            // - This computes an in-memory sparse embedding for each document.
+
+            let docs: Vec<String> =
+                results.iter().map(|(res, _)| res.content.clone()).collect();
+            let doc_refs: Vec<&str> = docs.iter().map(|s| s.as_str()).collect();
+
+            let embedder: Embedder =
+                EmbedderBuilder::with_fit_to_corpus(Language::English, &doc_refs).build();

-        let decreasing_indexes = search::rag::compute_most_similar(
-            &device,
-            &tool_call_params.query,
-            results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
-            bert_pipeline,
-        )
-        .unwrap();
-
-        // Rerank the results
-        let mut results_old = Vec::new();
-        std::mem::swap(&mut results_old, &mut results);
-        for &index in &decreasing_indexes {
-            let mut current_result: (SearchResult, usize) = Default::default();
-            std::mem::swap(&mut current_result, &mut results_old[index]);
-
-            results.push(current_result);
+            // Initialize a Scorer keyed by usize (document index type).
+            let mut scorer = Scorer::<usize>::new();
+
+            // For each document, compute its embedding and upsert into the scorer.
+            for (i, doc_text) in docs.iter().enumerate() {
+                let doc_embedding = embedder.embed(doc_text);
+                scorer.upsert(&i, doc_embedding);
+            }
+
+            // Embed the query string into the same sparse embedding space.
+            let query_embedding = embedder.embed(&tool_call_params.query);
+
+            // Score all documents individually
+            let mut scored_docs: Vec<ScoredDocument<usize>> = docs
+                .iter()
+                .enumerate()
+                .filter_map(|(i, _)| {
+                    scorer
+                        .score(&i, &query_embedding)
+                        .map(|score| ScoredDocument { id: i, score })
+                })
+                .collect();
+
+            // Sort the scored documents by descending `score` (f32).
+            scored_docs.sort_by(|a, b| {
+                b.score
+                    .partial_cmp(&a.score)
+                    .unwrap_or(std::cmp::Ordering::Equal)
+            });
+
+            // Extract only the document indices (usize) in ranked order.
+            let decreasing_indexes: Vec<usize> =
+                scored_docs.into_iter().map(|d| d.id).collect();
+
+            decreasing_indexes
+        };
+        // Reorder results according to ranking
+        let mut old = Vec::new();
+        std::mem::swap(&mut old, &mut results);
+        for &idx in &decreasing_indexes {
+            let mut item: (SearchResult, usize) = Default::default();
+            std::mem::swap(&mut item, &mut old[idx]);
+            results.push(item);
         }
     }

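For reference, here is the BM25 fallback path above condensed into a standalone helper. It uses only the bm25 crate calls that appear in the diff (EmbedderBuilder::with_fit_to_corpus, Embedder::embed, Scorer::upsert, Scorer::score); the function name rank_by_bm25 and its inputs are illustrative, not part of the commit:

use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer};

// Rank documents against a query with BM25; returns indices, best match first.
fn rank_by_bm25(query: &str, docs: &[String]) -> Vec<usize> {
    // Fit the sparse embedder to the whole corpus (English tokenization).
    let doc_refs: Vec<&str> = docs.iter().map(|s| s.as_str()).collect();
    let embedder: Embedder =
        EmbedderBuilder::with_fit_to_corpus(Language::English, &doc_refs).build();

    // Embed every document and register it with the scorer under its index.
    let mut scorer = Scorer::<usize>::new();
    for (i, doc) in docs.iter().enumerate() {
        scorer.upsert(&i, embedder.embed(doc));
    }

    // Score each document against the query embedding, then sort descending by score.
    let query_embedding = embedder.embed(query);
    let mut scored: Vec<ScoredDocument<usize>> = docs
        .iter()
        .enumerate()
        .filter_map(|(i, _)| {
            scorer
                .score(&i, &query_embedding)
                .map(|score| ScoredDocument { id: i, score })
        })
        .collect();
    scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
    scored.into_iter().map(|d| d.id).collect()
}

The reorder loop in the diff then consumes these indices exactly as it consumes the embedding-based ranking, so both branches yield the same decreasing_indexes shape.
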
@@ -148,7 +199,15 @@ async fn do_search(
         message.insert("role".to_string(), Either::Left("tool".to_string()));
         message.insert(
             "content".to_string(),
-            Either::Left(format!("{{\"output\": \"{tool_result}\"}}")),
+            Either::Left(
+                // Format the tool output JSON and append the search tool description for context
+                format!(
+                    "{{\"output\": \"{}\"}}\n\n{}\n\n{}",
+                    tool_result,
+                    search::SEARCH_DESCRIPTION,
+                    search::EXTRACT_DESCRIPTION,
+                ),
+            ),
         );
         messages.push(message);
     }
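
The hunk above re-injects the tool descriptions after the JSON payload so the model keeps the tool contract in context on follow-up turns (the "Reinject descriptions" item in the commit message). A small illustration of the resulting message content; the constants are abbreviated here and tool_message_content is just an illustrative helper, not part of the commit:

// Abbreviated stand-ins for search::SEARCH_DESCRIPTION / search::EXTRACT_DESCRIPTION.
const SEARCH_DESCRIPTION: &str = "This tool is used to search the web given a query. ...";
const EXTRACT_DESCRIPTION: &str = "This tool is used to extract the content of a website. ...";

// Build the tool message content: a JSON-style "output" payload, then both descriptions.
fn tool_message_content(tool_result: &str) -> String {
    format!(
        "{{\"output\": \"{}\"}}\n\n{}\n\n{}",
        tool_result, SEARCH_DESCRIPTION, EXTRACT_DESCRIPTION
    )
}
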
@@ -214,12 +273,16 @@ async fn do_extraction(
         SearchContextSize::Low => 4096_usize,
     };

-    let res = tracing::dispatcher::with_default(&dispatch, || {
-        search::run_extract_tool(&tool_call_params)
-            .unwrap()
+    let res = {
+        let extract_result = tokio::task::block_in_place(|| {
+            tracing::dispatcher::with_default(&dispatch, || {
+                search::run_extract_tool(&tool_call_params).unwrap()
+            })
+        });
+        extract_result
             .cap_content_len(&tokenizer, max_results_budget_toks)
             .unwrap()
-    });
+    };

     let tool_result = serde_json::to_string(&res)
         .unwrap()
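
Both do_search and do_extraction now wrap the synchronous tool call in tokio::task::block_in_place, which hands the current worker thread over to blocking work instead of stalling the async engine loop; this appears to be the "Fix web search blocking case" item in the commit message. A minimal sketch of the pattern, with run_blocking_tool as a hypothetical stand-in for search::run_search_tool / run_extract_tool:

use std::time::Duration;

// Hypothetical stand-in for the blocking search/extract tool call.
fn run_blocking_tool(query: &str) -> String {
    std::thread::sleep(Duration::from_millis(250)); // simulate blocking network I/O
    format!("results for {query}")
}

// block_in_place requires the multi-threaded runtime; it panics on a current-thread runtime.
#[tokio::main(flavor = "multi_thread")]
async fn main() {
    // Other tasks keep making progress on the remaining worker threads while this closure runs.
    let results = tokio::task::block_in_place(|| run_blocking_tool("latest release notes"));
    println!("{results}");
}
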

mistralrs-core/src/search/mod.rs

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ pub(crate) const SEARCH_TOOL_NAME: &str = "search_the_web";
 pub(crate) const EXTRACT_TOOL_NAME: &str = "website_content_extractor";

 const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
-const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
+pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
 If the user wants up-to-date information or you want to retrieve new information, call this tool.
 If you call this tool, then you MUST complete your answer using the output.
 The input can be a query. It should not be a URL. Either is fine.
@@ -41,7 +41,7 @@ You should expect output like this:
     ]
 }
 "#;
-const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
+pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
 If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
 The input must be a URL.
 Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.
