|
1 | 1 | use std::{borrow::Cow, sync::Arc, time::Instant}; |
2 | 2 |
|
| 3 | +use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer}; |
3 | 4 | use either::Either; |
4 | 5 | use indexmap::IndexMap; |
5 | 6 | use tokenizers::InputSequence; |
@@ -72,52 +73,102 @@ async fn do_search( |
72 | 73 | SearchContextSize::Medium => 8192_usize, |
73 | 74 | SearchContextSize::Low => 4096_usize, |
74 | 75 | }; |
75 | | - let mut results = tracing::dispatcher::with_default(&dispatch, || { |
76 | | - search::run_search_tool(&tool_call_params) |
77 | | - .unwrap() |
78 | | - .into_iter() |
79 | | - .map(|mut result| { |
80 | | - result = result |
81 | | - .cap_content_len(&tokenizer, max_results_budget_toks) |
82 | | - .unwrap(); |
83 | | - let len = { |
84 | | - let inp = InputSequence::Raw(Cow::from(&result.content)); |
85 | | - tokenizer |
86 | | - .encode_fast(inp, false) |
87 | | - .map(|x| x.len()) |
88 | | - .unwrap_or(usize::MAX) |
89 | | - }; |
90 | | - (result, len) |
91 | | - }) |
92 | | - .collect::<Vec<_>>() |
| 76 | + let mut results = tokio::task::block_in_place(|| { |
| 77 | + tracing::dispatcher::with_default(&dispatch, || { |
| 78 | + search::run_search_tool(&tool_call_params) |
| 79 | + .unwrap() |
| 80 | + .into_iter() |
| 81 | + .map(|mut result| { |
| 82 | + result = result |
| 83 | + .cap_content_len(&tokenizer, max_results_budget_toks) |
| 84 | + .unwrap(); |
| 85 | + let len = { |
| 86 | + let inp = InputSequence::Raw(Cow::from(&result.content)); |
| 87 | + tokenizer |
| 88 | + .encode_fast(inp, false) |
| 89 | + .map(|x| x.len()) |
| 90 | + .unwrap_or(usize::MAX) |
| 91 | + }; |
| 92 | + (result, len) |
| 93 | + }) |
| 94 | + .collect::<Vec<_>>() |
| 95 | + }) |
93 | 96 | }); |
94 | 97 |
|
95 | 98 | // Sort increasing by tokenized length, if it fails, put it at the end. |
96 | 99 | results.sort_by_key(|(_, len)| *len); |
97 | 100 |
|
98 | 101 | { |
99 | | - let device = get_mut_arcmutex!(this.pipeline).device(); |
| 102 | + // Determine ranking: use embedding model if available, otherwise fallback to BM25 |
| 103 | + let decreasing_indexes: Vec<usize> = if let Some(bert_pipeline) = |
| 104 | + &mut *get_mut_arcmutex!(this.bert_pipeline) |
| 105 | + { |
| 106 | + // Semantic reranking with embeddings |
| 107 | + let device = get_mut_arcmutex!(this.pipeline).device(); |
| 108 | + search::rag::compute_most_similar( |
| 109 | + &device, |
| 110 | + &tool_call_params.query, |
| 111 | + results.iter().map(|(res, _)| res).collect::<Vec<_>>(), |
| 112 | + bert_pipeline, |
| 113 | + ) |
| 114 | + .unwrap() |
| 115 | + } else { |
| 116 | + tracing::warn!("No embedding model loaded; falling back to BM25 ranking for web search results."); |
100 | 117 |
|
101 | | - let Some(bert_pipeline) = &mut *get_mut_arcmutex!(this.bert_pipeline) else { |
102 | | - unreachable!() |
103 | | - }; |
| 118 | + // Build an Embedder over the corpus, fitting to the entire set of documents. |
| 119 | + // - Language::English is chosen here |
| 120 | + // - This computes an in‑memory sparse embedding for each document. |
| 121 | + |
| 122 | + let docs: Vec<String> = |
| 123 | + results.iter().map(|(res, _)| res.content.clone()).collect(); |
| 124 | + let doc_refs: Vec<&str> = docs.iter().map(|s| s.as_str()).collect(); |
| 125 | + |
| 126 | + let embedder: Embedder = |
| 127 | + EmbedderBuilder::with_fit_to_corpus(Language::English, &doc_refs).build(); |
104 | 128 |
|
105 | | - let decreasing_indexes = search::rag::compute_most_similar( |
106 | | - &device, |
107 | | - &tool_call_params.query, |
108 | | - results.iter().map(|(res, _)| res).collect::<Vec<_>>(), |
109 | | - bert_pipeline, |
110 | | - ) |
111 | | - .unwrap(); |
112 | | - |
113 | | - // Rerank the results |
114 | | - let mut results_old = Vec::new(); |
115 | | - std::mem::swap(&mut results_old, &mut results); |
116 | | - for &index in &decreasing_indexes { |
117 | | - let mut current_result: (SearchResult, usize) = Default::default(); |
118 | | - std::mem::swap(&mut current_result, &mut results_old[index]); |
119 | | - |
120 | | - results.push(current_result); |
| 129 | + // Initialize a Scorer keyed by usize (document index type). |
| 130 | + let mut scorer = Scorer::<usize>::new(); |
| 131 | + |
| 132 | + // For each document, compute its embedding and upsert into the scorer. |
| 133 | + for (i, doc_text) in docs.iter().enumerate() { |
| 134 | + let doc_embedding = embedder.embed(doc_text); |
| 135 | + scorer.upsert(&i, doc_embedding); |
| 136 | + } |
| 137 | + |
| 138 | + // Embed the query string into the same sparse embedding space. |
| 139 | + let query_embedding = embedder.embed(&tool_call_params.query); |
| 140 | + |
| 141 | + // Score all documents individually |
| 142 | + let mut scored_docs: Vec<ScoredDocument<usize>> = docs |
| 143 | + .iter() |
| 144 | + .enumerate() |
| 145 | + .filter_map(|(i, _)| { |
| 146 | + scorer |
| 147 | + .score(&i, &query_embedding) |
| 148 | + .map(|score| ScoredDocument { id: i, score }) |
| 149 | + }) |
| 150 | + .collect(); |
| 151 | + |
| 152 | + // Sort the scored documents by descending `score` (f32). |
| 153 | + scored_docs.sort_by(|a, b| { |
| 154 | + b.score |
| 155 | + .partial_cmp(&a.score) |
| 156 | + .unwrap_or(std::cmp::Ordering::Equal) |
| 157 | + }); |
| 158 | + |
| 159 | + // Extract only the document indices (usize) in ranked order. |
| 160 | + let decreasing_indexes: Vec<usize> = |
| 161 | + scored_docs.into_iter().map(|d| d.id).collect(); |
| 162 | + |
| 163 | + decreasing_indexes |
| 164 | + }; |
| 165 | + // Reorder results according to ranking |
| 166 | + let mut old = Vec::new(); |
| 167 | + std::mem::swap(&mut old, &mut results); |
| 168 | + for &idx in &decreasing_indexes { |
| 169 | + let mut item: (SearchResult, usize) = Default::default(); |
| 170 | + std::mem::swap(&mut item, &mut old[idx]); |
| 171 | + results.push(item); |
121 | 172 | } |
122 | 173 | } |
123 | 174 |
|
@@ -148,7 +199,15 @@ async fn do_search( |
148 | 199 | message.insert("role".to_string(), Either::Left("tool".to_string())); |
149 | 200 | message.insert( |
150 | 201 | "content".to_string(), |
151 | | - Either::Left(format!("{{\"output\": \"{tool_result}\"}}")), |
| 202 | + Either::Left( |
| 203 | + // Format the tool output JSON and append the search tool description for context |
| 204 | + format!( |
| 205 | + "{{\"output\": \"{}\"}}\n\n{}\n\n{}", |
| 206 | + tool_result, |
| 207 | + search::SEARCH_DESCRIPTION, |
| 208 | + search::EXTRACT_DESCRIPTION, |
| 209 | + ), |
| 210 | + ), |
152 | 211 | ); |
153 | 212 | messages.push(message); |
154 | 213 | } |
@@ -214,12 +273,16 @@ async fn do_extraction( |
214 | 273 | SearchContextSize::Low => 4096_usize, |
215 | 274 | }; |
216 | 275 |
|
217 | | - let res = tracing::dispatcher::with_default(&dispatch, || { |
218 | | - search::run_extract_tool(&tool_call_params) |
219 | | - .unwrap() |
| 276 | + let res = { |
| 277 | + let extract_result = tokio::task::block_in_place(|| { |
| 278 | + tracing::dispatcher::with_default(&dispatch, || { |
| 279 | + search::run_extract_tool(&tool_call_params).unwrap() |
| 280 | + }) |
| 281 | + }); |
| 282 | + extract_result |
220 | 283 | .cap_content_len(&tokenizer, max_results_budget_toks) |
221 | 284 | .unwrap() |
222 | | - }); |
| 285 | + }; |
223 | 286 |
|
224 | 287 | let tool_result = serde_json::to_string(&res) |
225 | 288 | .unwrap() |
|
0 commit comments