91 changes: 87 additions & 4 deletions Cargo.lock


1 change: 1 addition & 0 deletions mistralrs-core/Cargo.toml
@@ -95,6 +95,7 @@ parking_lot = "0.12.3"
ahash = "0.8.12"
num-traits = "0.2.19"
libc = "0.2.172"
bm25 = "2.2.1"

[features]
pyo3_macros = ["pyo3"]
1 change: 0 additions & 1 deletion mistralrs-core/src/engine/add_request.rs
@@ -33,7 +33,6 @@ impl Engine {
request.messages,
RequestMessage::Chat { .. } | RequestMessage::VisionChat { .. }
) && request.web_search_options.is_some()
&& get_mut_arcmutex!(self.bert_pipeline).is_some()
{
search_request::search_request(self.clone(), *request).await;
} else {
149 changes: 106 additions & 43 deletions mistralrs-core/src/engine/search_request.rs
@@ -1,5 +1,6 @@
use std::{borrow::Cow, sync::Arc, time::Instant};

use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer};
use either::Either;
use indexmap::IndexMap;
use tokenizers::InputSequence;
@@ -72,52 +73,102 @@
SearchContextSize::Medium => 8192_usize,
SearchContextSize::Low => 4096_usize,
};
let mut results = tracing::dispatcher::with_default(&dispatch, || {
search::run_search_tool(&tool_call_params)
.unwrap()
.into_iter()
.map(|mut result| {
result = result
.cap_content_len(&tokenizer, max_results_budget_toks)
.unwrap();
let len = {
let inp = InputSequence::Raw(Cow::from(&result.content));
tokenizer
.encode_fast(inp, false)
.map(|x| x.len())
.unwrap_or(usize::MAX)
};
(result, len)
})
.collect::<Vec<_>>()
let mut results = tokio::task::block_in_place(|| {
tracing::dispatcher::with_default(&dispatch, || {
search::run_search_tool(&tool_call_params)
.unwrap()
.into_iter()
.map(|mut result| {
result = result
.cap_content_len(&tokenizer, max_results_budget_toks)
.unwrap();
let len = {
let inp = InputSequence::Raw(Cow::from(&result.content));
tokenizer
.encode_fast(inp, false)
.map(|x| x.len())
.unwrap_or(usize::MAX)
};
(result, len)
})
.collect::<Vec<_>>()
})
});
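A note on the new `tokio::task::block_in_place` wrapper: `search::run_search_tool` performs blocking network I/O, and calling it directly on an async worker would stall every other task scheduled on that thread. Below is a minimal sketch of the pattern, with a sleep standing in for the blocking search call (assumed: Tokio's multi-threaded runtime, since `block_in_place` panics on the current-thread flavor):

```rust
// Sketch: run blocking work from async code without starving the runtime.
// Requires tokio = { version = "1", features = ["rt-multi-thread", "macros"] }.
async fn search_blocking() -> String {
    // `block_in_place` tells the scheduler this worker is about to block,
    // so queued tasks are handed off to other workers first.
    tokio::task::block_in_place(|| {
        // Stand-in for search::run_search_tool(...), which blocks on HTTP.
        std::thread::sleep(std::time::Duration::from_millis(50));
        "search results".to_string()
    })
}

#[tokio::main(flavor = "multi_thread")]
async fn main() {
    println!("{}", search_blocking().await);
}
```

`spawn_blocking` would also work here, but it imposes a `'static` bound on the closure; `block_in_place` runs the closure in place, which fits this diff because the closure borrows `tool_call_params`, `tokenizer`, and the tracing dispatcher.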

// Sort by increasing tokenized length; entries whose encoding failed (len == usize::MAX) sort to the end.
results.sort_by_key(|(_, len)| *len);

{
let device = get_mut_arcmutex!(this.pipeline).device();
// Determine ranking: use embedding model if available, otherwise fallback to BM25
let decreasing_indexes: Vec<usize> = if let Some(bert_pipeline) =
&mut *get_mut_arcmutex!(this.bert_pipeline)
{
// Semantic reranking with embeddings
let device = get_mut_arcmutex!(this.pipeline).device();
search::rag::compute_most_similar(
&device,
&tool_call_params.query,
results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
bert_pipeline,
)
.unwrap()
} else {
tracing::warn!("No embedding model loaded; falling back to BM25 ranking for web search results.");

let Some(bert_pipeline) = &mut *get_mut_arcmutex!(this.bert_pipeline) else {
unreachable!()
};
// Build an Embedder over the corpus, fitting to the entire set of documents.
// - Language::English is chosen here
// - This computes an in-memory sparse embedding for each document.

let docs: Vec<String> =
results.iter().map(|(res, _)| res.content.clone()).collect();
let doc_refs: Vec<&str> = docs.iter().map(|s| s.as_str()).collect();

let embedder: Embedder =
EmbedderBuilder::with_fit_to_corpus(Language::English, &doc_refs).build();

let decreasing_indexes = search::rag::compute_most_similar(
&device,
&tool_call_params.query,
results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
bert_pipeline,
)
.unwrap();

// Rerank the results
let mut results_old = Vec::new();
std::mem::swap(&mut results_old, &mut results);
for &index in &decreasing_indexes {
let mut current_result: (SearchResult, usize) = Default::default();
std::mem::swap(&mut current_result, &mut results_old[index]);

results.push(current_result);
// Initialize a Scorer keyed by usize (document index type).
let mut scorer = Scorer::<usize>::new();

// For each document, compute its embedding and upsert into the scorer.
for (i, doc_text) in docs.iter().enumerate() {
let doc_embedding = embedder.embed(doc_text);
scorer.upsert(&i, doc_embedding);
}

// Embed the query string into the same sparse embedding space.
let query_embedding = embedder.embed(&tool_call_params.query);

// Score all documents individually
let mut scored_docs: Vec<ScoredDocument<usize>> = docs
.iter()
.enumerate()
.filter_map(|(i, _)| {
scorer
.score(&i, &query_embedding)
.map(|score| ScoredDocument { id: i, score })
})
.collect();

// Sort the scored documents by descending `score` (f32).
scored_docs.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});

// Extract only the document indices (usize) in ranked order.
let decreasing_indexes: Vec<usize> =
scored_docs.into_iter().map(|d| d.id).collect();

decreasing_indexes
};
// Reorder results according to ranking
let mut old = Vec::new();
std::mem::swap(&mut old, &mut results);
for &idx in &decreasing_indexes {
let mut item: (SearchResult, usize) = Default::default();
std::mem::swap(&mut item, &mut old[idx]);
results.push(item);
}
}
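For reference, here is the BM25 fallback distilled into a self-contained sketch using the same `bm25` crate API the diff introduces (`EmbedderBuilder`, `Scorer`, `ScoredDocument`); the corpus and query below are illustrative:

```rust
use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer};

/// Rank `docs` against `query`, best match first (returns indices into `docs`).
fn rank_bm25(query: &str, docs: &[&str]) -> Vec<usize> {
    // Fit term statistics to this corpus, as the diff does for the search results.
    let embedder: Embedder =
        EmbedderBuilder::with_fit_to_corpus(Language::English, docs).build();

    // Upsert each document's sparse embedding into the scorer, keyed by index.
    let mut scorer = Scorer::<usize>::new();
    for (i, doc) in docs.iter().enumerate() {
        scorer.upsert(&i, embedder.embed(doc));
    }

    // Embed the query in the same space and score every document against it.
    let query_embedding = embedder.embed(query);
    let mut scored: Vec<ScoredDocument<usize>> = (0..docs.len())
        .filter_map(|i| {
            scorer
                .score(&i, &query_embedding)
                .map(|score| ScoredDocument { id: i, score })
        })
        .collect();

    // Descending score; ties keep their relative order.
    scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
    scored.into_iter().map(|d| d.id).collect()
}
```

Called as `rank_bm25("rust async runtime", &["doc a", "doc b"])`, it returns document indices best-first, matching the `decreasing_indexes` ordering the reordering loop above expects.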

@@ -148,7 +199,15 @@ async fn do_search(
message.insert("role".to_string(), Either::Left("tool".to_string()));
message.insert(
"content".to_string(),
Either::Left(format!("{{\"output\": \"{tool_result}\"}}")),
Either::Left(
// Format the tool output as JSON and append both tool descriptions so they stay in context
format!(
"{{\"output\": \"{}\"}}\n\n{}\n\n{}",
tool_result,
search::SEARCH_DESCRIPTION,
search::EXTRACT_DESCRIPTION,
),
),
);
messages.push(message);
}
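The appended constants are the same `SEARCH_DESCRIPTION` and `EXTRACT_DESCRIPTION` made `pub(crate)` in `search/mod.rs` below, presumably so the model keeps the tools' usage instructions in context when deciding on follow-up calls. An abbreviated illustration of the resulting message content (values shortened; the real strings come from the constants):

```rust
// Abbreviated illustration of the tool message body after this change.
let tool_result = r#"[{"title": "Example", "url": "https://example.com"}]"#;
let search_desc = "This tool is used to search the web given a query. ...";
let extract_desc = "This tool is used to extract the content of a website. ...";
let content = format!(
    "{{\"output\": \"{}\"}}\n\n{}\n\n{}",
    tool_result, search_desc, extract_desc,
);
assert!(content.starts_with("{\"output\":"));
```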
@@ -214,12 +273,16 @@ async fn do_extraction(
SearchContextSize::Low => 4096_usize,
};

let res = tracing::dispatcher::with_default(&dispatch, || {
search::run_extract_tool(&tool_call_params)
.unwrap()
let res = {
let extract_result = tokio::task::block_in_place(|| {
tracing::dispatcher::with_default(&dispatch, || {
search::run_extract_tool(&tool_call_params).unwrap()
})
});
extract_result
.cap_content_len(&tokenizer, max_results_budget_toks)
.unwrap()
});
};

let tool_result = serde_json::to_string(&res)
.unwrap()
4 changes: 2 additions & 2 deletions mistralrs-core/src/search/mod.rs
@@ -22,7 +22,7 @@ pub(crate) const SEARCH_TOOL_NAME: &str = "search_the_web";
pub(crate) const EXTRACT_TOOL_NAME: &str = "website_content_extractor";

const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
If the user wants up-to-date information or you want to retrieve new information, call this tool.
If you call this tool, then you MUST complete your answer using the output.
The input can be a query. It should not be a URL. Either is fine.
@@ -41,7 +41,7 @@ You should expect output like this:
]
}
"#;
const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
The input must be a URL.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.