diff --git a/Cargo.lock b/Cargo.lock
index 669e99292c..0c9ec0dfc2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -356,6 +356,20 @@ dependencies = [
  "generic-array",
 ]
 
+[[package]]
+name = "bm25"
+version = "2.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9874599901ae2aaa19b1485145be2fa4e9af42d1b127672a03a7099ab6350bac"
+dependencies = [
+ "cached",
+ "deunicode",
+ "fxhash",
+ "rust-stemmers",
+ "stop-words",
+ "unicode-segmentation",
+]
+
 [[package]]
 name = "bumpalo"
 version = "3.17.0"
@@ -400,6 +414,39 @@ version = "1.10.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
 
+[[package]]
+name = "cached"
+version = "0.55.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0839c297f8783316fcca9d90344424e968395413f0662a5481f79c6648bbc14"
+dependencies = [
+ "ahash",
+ "cached_proc_macro",
+ "cached_proc_macro_types",
+ "hashbrown 0.14.5",
+ "once_cell",
+ "thiserror 2.0.12",
+ "web-time",
+]
+
+[[package]]
+name = "cached_proc_macro"
+version = "0.24.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "673992d934f0711b68ebb3e1b79cdc4be31634b37c98f26867ced0438ca5c603"
+dependencies = [
+ "darling 0.20.11",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.101",
+]
+
+[[package]]
+name = "cached_proc_macro_types"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ade8366b8bd5ba243f0a58f036cc0ca8a2f069cff1a2351ef1cac6b083e16fc0"
+
 [[package]]
 name = "candle-core"
 version = "0.8.0"
@@ -992,11 +1039,17 @@ dependencies = [
  "anyhow",
  "bytemuck",
  "bytemuck_derive",
- "hashbrown",
+ "hashbrown 0.15.3",
  "regex-syntax 0.8.5",
  "strum 0.27.1",
 ]
 
+[[package]]
+name = "deunicode"
+version = "1.6.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "abd57806937c9cc163efc8ea3910e00a62e2aeb0b8119f1793a978088f8f6b04"
+
 [[package]]
 name = "digest"
 version = "0.10.7"
@@ -1670,6 +1723,16 @@ dependencies = [
  "rand_distr",
 ]
 
+[[package]]
+name = "hashbrown"
+version = "0.14.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
+dependencies = [
+ "ahash",
+ "allocator-api2",
+]
+
 [[package]]
 name = "hashbrown"
 version = "0.15.3"
@@ -2067,7 +2130,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
 dependencies = [
  "equivalent",
- "hashbrown",
+ "hashbrown 0.15.3",
  "serde",
 ]
@@ -2539,6 +2602,7 @@ dependencies = [
  "async-trait",
  "base64 0.22.1",
  "bindgen_cuda 0.1.5",
+ "bm25",
  "bytemuck",
  "bytemuck_derive",
  "candle-core",
@@ -2557,7 +2621,7 @@ dependencies = [
  "futures",
  "galil-seiferas",
  "half",
- "hashbrown",
+ "hashbrown 0.15.3",
  "hf-hub",
  "html2text",
  "image",
@@ -3311,7 +3375,7 @@ dependencies = [
  "chrono-tz",
  "either",
  "eyre",
- "hashbrown",
+ "hashbrown 0.15.3",
  "indexmap",
  "indoc",
  "libc",
@@ -3781,6 +3845,16 @@ dependencies = [
  "walkdir",
 ]
 
+[[package]]
+name = "rust-stemmers"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e46a2036019fdb888131db7a4c847a1063a7493f971ed94ea82c67eada63ca54"
+dependencies = [
+ "serde",
+ "serde_derive",
+]
+
 [[package]]
 name = "rust_decimal"
 version = "1.37.1"
@@ -4239,6 +4313,15 @@ version = "1.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
 
+[[package]]
+name = "stop-words"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3c6a86be9f7fa4559b7339669e72026eb437f5e9c5a85c207fe1033079033a17"
+dependencies = [
+ "serde_json",
+]
+
 [[package]]
 name = "string_cache"
 version = "0.8.9"
diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml
index 8840c2fcf0..ca76f18c25 100644
--- a/mistralrs-core/Cargo.toml
+++ b/mistralrs-core/Cargo.toml
@@ -95,6 +95,7 @@ parking_lot = "0.12.3"
 ahash = "0.8.12"
 num-traits = "0.2.19"
 libc = "0.2.172"
+bm25 = "2.2.1"
 
 [features]
 pyo3_macros = ["pyo3"]
diff --git a/mistralrs-core/src/engine/add_request.rs b/mistralrs-core/src/engine/add_request.rs
index 820c7e6227..9fee96da2e 100644
--- a/mistralrs-core/src/engine/add_request.rs
+++ b/mistralrs-core/src/engine/add_request.rs
@@ -33,7 +33,6 @@ impl Engine {
             request.messages,
             RequestMessage::Chat { .. } | RequestMessage::VisionChat { .. }
         ) && request.web_search_options.is_some()
-            && get_mut_arcmutex!(self.bert_pipeline).is_some()
         {
             search_request::search_request(self.clone(), *request).await;
         } else {
diff --git a/mistralrs-core/src/engine/search_request.rs b/mistralrs-core/src/engine/search_request.rs
index b1f7309784..a3bba21a07 100644
--- a/mistralrs-core/src/engine/search_request.rs
+++ b/mistralrs-core/src/engine/search_request.rs
@@ -1,5 +1,6 @@
 use std::{borrow::Cow, sync::Arc, time::Instant};
 
+use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer};
 use either::Either;
 use indexmap::IndexMap;
 use tokenizers::InputSequence;
@@ -72,52 +73,102 @@ async fn do_search(
         SearchContextSize::Medium => 8192_usize,
         SearchContextSize::Low => 4096_usize,
     };
-    let mut results = tracing::dispatcher::with_default(&dispatch, || {
-        search::run_search_tool(&tool_call_params)
-            .unwrap()
-            .into_iter()
-            .map(|mut result| {
-                result = result
-                    .cap_content_len(&tokenizer, max_results_budget_toks)
-                    .unwrap();
-                let len = {
-                    let inp = InputSequence::Raw(Cow::from(&result.content));
-                    tokenizer
-                        .encode_fast(inp, false)
-                        .map(|x| x.len())
-                        .unwrap_or(usize::MAX)
-                };
-                (result, len)
-            })
-            .collect::<Vec<_>>()
+    let mut results = tokio::task::block_in_place(|| {
+        tracing::dispatcher::with_default(&dispatch, || {
+            search::run_search_tool(&tool_call_params)
+                .unwrap()
+                .into_iter()
+                .map(|mut result| {
+                    result = result
+                        .cap_content_len(&tokenizer, max_results_budget_toks)
+                        .unwrap();
+                    let len = {
+                        let inp = InputSequence::Raw(Cow::from(&result.content));
+                        tokenizer
+                            .encode_fast(inp, false)
+                            .map(|x| x.len())
+                            .unwrap_or(usize::MAX)
+                    };
+                    (result, len)
+                })
+                .collect::<Vec<_>>()
+        })
     });
 
     // Sort increasing by tokenized length; if tokenization fails, put the result at the end.
     results.sort_by_key(|(_, len)| *len);
 
     {
-        let device = get_mut_arcmutex!(this.pipeline).device();
+        // Determine the ranking: use the embedding model if available, otherwise fall back to BM25
+        let decreasing_indexes: Vec<usize> = if let Some(bert_pipeline) =
+            &mut *get_mut_arcmutex!(this.bert_pipeline)
+        {
+            // Semantic reranking with embeddings
+            let device = get_mut_arcmutex!(this.pipeline).device();
+            search::rag::compute_most_similar(
+                &device,
+                &tool_call_params.query,
+                results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
+                bert_pipeline,
+            )
+            .unwrap()
+        } else {
+            tracing::warn!("No embedding model loaded; falling back to BM25 ranking for web search results.");
 
-        let Some(bert_pipeline) = &mut *get_mut_arcmutex!(this.bert_pipeline) else {
-            unreachable!()
-        };
+            // Build an Embedder over the corpus, fitting it to the entire set of documents.
+            // - Language::English drives tokenization, stemming, and stop-word removal.
+            // - This computes an in-memory sparse embedding for each document.
+            let docs: Vec<String> =
+                results.iter().map(|(res, _)| res.content.clone()).collect();
+            let doc_refs: Vec<&str> = docs.iter().map(|s| s.as_str()).collect();
+
+            let embedder: Embedder =
+                EmbedderBuilder::with_fit_to_corpus(Language::English, &doc_refs).build();
 
-        let decreasing_indexes = search::rag::compute_most_similar(
-            &device,
-            &tool_call_params.query,
-            results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
-            bert_pipeline,
-        )
-        .unwrap();
-
-        // Rerank the results
-        let mut results_old = Vec::new();
-        std::mem::swap(&mut results_old, &mut results);
-        for &index in &decreasing_indexes {
-            let mut current_result: (SearchResult, usize) = Default::default();
-            std::mem::swap(&mut current_result, &mut results_old[index]);
-
-            results.push(current_result);
+            // Initialize a Scorer keyed by usize (the document index type).
+            let mut scorer = Scorer::<usize>::new();
+
+            // For each document, compute its embedding and upsert it into the scorer.
+            for (i, doc_text) in docs.iter().enumerate() {
+                let doc_embedding = embedder.embed(doc_text);
+                scorer.upsert(&i, doc_embedding);
+            }
+
+            // Embed the query string into the same sparse embedding space.
+            let query_embedding = embedder.embed(&tool_call_params.query);
+
+            // Score every document against the query.
+            let mut scored_docs: Vec<ScoredDocument<usize>> = docs
+                .iter()
+                .enumerate()
+                .filter_map(|(i, _)| {
+                    scorer
+                        .score(&i, &query_embedding)
+                        .map(|score| ScoredDocument { id: i, score })
+                })
+                .collect();
+
+            // Sort the scored documents by descending score (f32).
+            scored_docs.sort_by(|a, b| {
+                b.score
+                    .partial_cmp(&a.score)
+                    .unwrap_or(std::cmp::Ordering::Equal)
+            });
+
+            // Extract only the document indices (usize) in ranked order.
+            let decreasing_indexes: Vec<usize> =
+                scored_docs.into_iter().map(|d| d.id).collect();
+
+            decreasing_indexes
+        };
+
+        // Reorder the results according to the ranking
+        let mut old = Vec::new();
+        std::mem::swap(&mut old, &mut results);
+        for &idx in &decreasing_indexes {
+            let mut item: (SearchResult, usize) = Default::default();
+            std::mem::swap(&mut item, &mut old[idx]);
+            results.push(item);
         }
     }
 
@@ -148,7 +199,15 @@ async fn do_search(
         message.insert("role".to_string(), Either::Left("tool".to_string()));
         message.insert(
             "content".to_string(),
-            Either::Left(format!("{{\"output\": \"{tool_result}\"}}")),
+            Either::Left(
+                // Format the tool output JSON and append the search tool descriptions for context
+                format!(
+                    "{{\"output\": \"{}\"}}\n\n{}\n\n{}",
+                    tool_result,
+                    search::SEARCH_DESCRIPTION,
+                    search::EXTRACT_DESCRIPTION,
+                ),
+            ),
         );
         messages.push(message);
     }
@@ -214,12 +273,16 @@ async fn do_extraction(
         SearchContextSize::Low => 4096_usize,
     };
 
-    let res = tracing::dispatcher::with_default(&dispatch, || {
-        search::run_extract_tool(&tool_call_params)
-            .unwrap()
+    let res = {
+        let extract_result = tokio::task::block_in_place(|| {
+            tracing::dispatcher::with_default(&dispatch, || {
+                search::run_extract_tool(&tool_call_params).unwrap()
+            })
+        });
+        extract_result
             .cap_content_len(&tokenizer, max_results_budget_toks)
             .unwrap()
-    });
+    };
 
     let tool_result = serde_json::to_string(&res)
         .unwrap()
diff --git a/mistralrs-core/src/search/mod.rs b/mistralrs-core/src/search/mod.rs
index 42ea8f7768..fbdeddb004 100644
--- a/mistralrs-core/src/search/mod.rs
+++ b/mistralrs-core/src/search/mod.rs
@@ -22,7 +22,7 @@ pub(crate) const SEARCH_TOOL_NAME: &str = "search_the_web";
 pub(crate) const EXTRACT_TOOL_NAME: &str = "website_content_extractor";
 const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
 
-const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
+pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
 If the user wants up-to-date information or you want to retrieve new information, call this tool.
 If you call this tool, then you MUST complete your answer using the output.
 The input must be a query. It should not be a URL.
@@ -41,7 +41,7 @@ You should expect output like this:
     ]
 }
 "#;
-const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
+pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
 If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
 The input must be a URL.
 Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.
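
For context, the BM25 fallback introduced above reduces to the following standalone sketch of the `bm25` crate's fit-to-corpus API; the `rank_by_bm25` helper and its sample corpus are illustrative only, not part of the change:

```rust
use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer};

/// Rank `docs` against `query`, best match first, returning document indices.
fn rank_by_bm25(query: &str, docs: &[&str]) -> Vec<usize> {
    // Fit the embedder's token statistics to the corpus being ranked.
    let embedder: Embedder =
        EmbedderBuilder::with_fit_to_corpus(Language::English, docs).build();

    // Index each document in a scorer keyed by its position in `docs`.
    let mut scorer = Scorer::<usize>::new();
    for (i, doc) in docs.iter().enumerate() {
        scorer.upsert(&i, embedder.embed(doc));
    }

    // Embed the query into the same sparse space and score every document.
    let query_embedding = embedder.embed(query);
    let mut scored: Vec<ScoredDocument<usize>> = (0..docs.len())
        .filter_map(|i| {
            scorer
                .score(&i, &query_embedding)
                .map(|score| ScoredDocument { id: i, score })
        })
        .collect();

    // Highest BM25 score first; incomparable scores keep their relative order.
    scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
    scored.into_iter().map(|d| d.id).collect()
}

fn main() {
    let docs = [
        "Rust's borrow checker enforces aliasing rules",
        "BM25 is a lexical ranking function for search",
        "Tokio is an async runtime for Rust",
    ];
    // Index 1 should rank first for this query.
    println!("{:?}", rank_by_bm25("bm25 search ranking", &docs));
}
```

Unlike the semantic path, this ranking needs no model or device; it is purely lexical, which is why it can serve as the default when no embedding model was loaded.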
diff --git a/mistralrs-web-chat/src/handlers/websocket.rs b/mistralrs-web-chat/src/handlers/websocket.rs
index 201de8a15d..789f13317f 100644
--- a/mistralrs-web-chat/src/handlers/websocket.rs
+++ b/mistralrs-web-chat/src/handlers/websocket.rs
@@ -5,7 +5,9 @@ use axum::{
 };
 use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
 use futures_util::stream::StreamExt;
-use mistralrs::{Model, TextMessageRole, TextMessages, VisionMessages};
+use mistralrs::{
+    Model, RequestBuilder, TextMessageRole, TextMessages, VisionMessages, WebSearchOptions,
+};
 use serde_json::Value;
 use std::io::Cursor;
 use std::mem;
@@ -25,6 +27,15 @@ pub struct VisionContext<'a> {
     pub image_buffer: &'a mut Vec<DynamicImage>,
 }
 
+/// Aggregates frequently used parameters so that helper functions stay below
+/// Clippy's `too_many_arguments` threshold.
+pub struct HandlerParams<'a> {
+    pub socket: &'a mut WebSocket,
+    pub app: &'a Arc<AppState>,
+    pub streaming: &'a mut bool,
+    pub active_chat_id: &'a Option<String>,
+}
+
 /// Upgrades an HTTP request to a WebSocket connection.
 pub async fn ws_handler(
     ws: WebSocketUpgrade,
@@ -141,7 +152,7 @@ pub async fn handle_socket(mut socket: WebSocket, app: Arc<AppState>) {
             }
             continue;
         }
-        // Handle front‑end replay helper messages without triggering inference
+        // Handle front-end replay helper messages without triggering inference
        if user_msg.trim_start().starts_with("{\"restore\":") {
             handle_restore_message(
                 &user_msg,
@@ -153,6 +164,90 @@ pub async fn handle_socket(mut socket: WebSocket, app: Arc<AppState>) {
             .await;
             continue;
         }
+        // Handle chat messages with optional web search options provided as JSON
+        if let Ok(val) = serde_json::from_str::<Value>(&user_msg) {
+            if let Some(content) = val.get("content").and_then(|v| v.as_str()) {
+                // Extract web search options if provided
+                let web_search_opts = if let Some(opts_val) = val.get("web_search_options") {
+                    match serde_json::from_value::<WebSearchOptions>(opts_val.clone()) {
+                        Ok(opts) => Some(opts),
+                        Err(e) => {
+                            let _ = socket
+                                .send(Message::Text(format!(
+                                    "Error parsing web_search_options: {}",
+                                    e
+                                )))
+                                .await;
+                            None
+                        }
+                    }
+                } else {
+                    None
+                };
+                // Determine the selected model
+                let model_name_opt = { app.current.read().await.clone() };
+                let Some(model_name) = model_name_opt else {
+                    let _ = socket
+                        .send(Message::Text(
+                            "No model selected. Choose one in the sidebar.".into(),
+                        ))
+                        .await;
+                    continue;
+                };
+                let Some(model_loaded) = app.models.get(&model_name).cloned() else {
+                    let _ = socket
+                        .send(Message::Text("Selected model not found.".into()))
+                        .await;
+                    continue;
+                };
+                match model_loaded {
+                    LoadedModel::Text(model) => {
+                        let mut params = HandlerParams {
+                            socket: &mut socket,
+                            app: &app,
+                            streaming: &mut streaming,
+                            active_chat_id: &active_chat_id,
+                        };
+                        handle_text_model(
+                            &model,
+                            content,
+                            web_search_opts.clone(),
+                            &mut text_msgs,
+                            &mut params,
+                        )
+                        .await;
+                    }
+                    LoadedModel::Vision(model) => {
+                        let mut vision_ctx = VisionContext {
+                            msgs: &mut vision_msgs,
+                            image_buffer: &mut image_buffer,
+                        };
+                        let mut params = HandlerParams {
+                            socket: &mut socket,
+                            app: &app,
+                            streaming: &mut streaming,
+                            active_chat_id: &active_chat_id,
+                        };
+                        handle_vision_model(
+                            &model,
+                            content,
+                            web_search_opts.clone(),
+                            &mut vision_ctx,
+                            &mut params,
+                        )
+                        .await;
+                    }
+                    LoadedModel::Speech(_) => {
+                        let _ = socket
+                            .send(Message::Text(
+                                "Speech models are not supported over WebSocket".into(),
+                            ))
+                            .await;
+                    }
+                }
+                continue;
+            }
+        }
 
         let model_name_opt = { app.current.read().await.clone() };
         let Some(model_name) = model_name_opt else {
@@ -172,32 +267,26 @@ pub async fn handle_socket(mut socket: WebSocket, app: Arc<AppState>) {
 
         match model_loaded {
             LoadedModel::Text(model) => {
-                handle_text_model(
-                    &model,
-                    &user_msg,
-                    &mut text_msgs,
-                    &mut socket,
-                    &app,
-                    &mut streaming,
-                    &active_chat_id,
-                )
-                .await;
+                let mut params = HandlerParams {
+                    socket: &mut socket,
+                    app: &app,
+                    streaming: &mut streaming,
+                    active_chat_id: &active_chat_id,
+                };
+                handle_text_model(&model, &user_msg, None, &mut text_msgs, &mut params).await;
             }
             LoadedModel::Vision(model) => {
                 let mut vision_ctx = VisionContext {
                     msgs: &mut vision_msgs,
                     image_buffer: &mut image_buffer,
                 };
-                handle_vision_model(
-                    &model,
-                    &user_msg,
-                    &mut vision_ctx,
-                    &mut socket,
-                    &app,
-                    &mut streaming,
-                    &active_chat_id,
-                )
-                .await;
+                let mut params = HandlerParams {
+                    socket: &mut socket,
+                    app: &app,
+                    streaming: &mut streaming,
+                    active_chat_id: &active_chat_id,
+                };
+                handle_vision_model(&model, &user_msg, None, &mut vision_ctx, &mut params).await;
             }
             // Speech models should use the HTTP endpoint; not handled here
             LoadedModel::Speech(_) => {
@@ -305,12 +394,15 @@ async fn handle_restore_message(
 async fn handle_text_model(
     model: &Arc<Model>,
     user_msg: &str,
+    web_search_opts: Option<WebSearchOptions>,
     text_msgs: &mut TextMessages,
-    socket: &mut WebSocket,
-    app: &Arc<AppState>,
-    streaming: &mut bool,
-    active_chat_id: &Option<String>,
+    params: &mut HandlerParams<'_>,
 ) {
+    // Local aliases keep the original body unchanged.
+    let socket = &mut *params.socket;
+    let app = params.app;
+    let streaming = &mut *params.streaming;
+    let active_chat_id = params.active_chat_id;
     *text_msgs = text_msgs
         .clone()
         .add_message(TextMessageRole::User, user_msg);
@@ -321,11 +413,15 @@ async fn handle_text_model(
     }
     let mut assistant_content = String::new();
     let msgs_snapshot = text_msgs.clone();
+    let mut request_builder = RequestBuilder::from(msgs_snapshot);
+    if let Some(opts) = web_search_opts {
+        request_builder = request_builder.with_web_search_options(opts);
+    }
     *streaming = true;
 
     let stream_res = stream_and_forward(
         model,
-        msgs_snapshot,
+        request_builder,
         socket,
         |tok| {
             assistant_content = tok.to_string();
@@ -348,12 +444,14 @@ async fn handle_text_model(
 async fn handle_vision_model(
     model: &Arc<Model>,
     user_msg: &str,
+    web_search_opts: Option<WebSearchOptions>,
     vision_ctx: &mut VisionContext<'_>,
-    socket: &mut WebSocket,
-    app: &Arc<AppState>,
-    streaming: &mut bool,
-    active_chat_id: &Option<String>,
+    params: &mut HandlerParams<'_>,
 ) {
+    let socket = &mut *params.socket;
+    let app = params.app;
+    let streaming = &mut *params.streaming;
+    let active_chat_id = params.active_chat_id;
     // Track the exact set of messages that will be sent *this* turn.
     let mut msgs_for_stream: Option<VisionMessages> = None;
     // --- Vision input routing ---
@@ -460,11 +558,16 @@ async fn handle_vision_model(
         }
     }
 
+    let msgs = msgs_for_stream.expect("msgs_for_stream must be set");
+    let mut request_builder = RequestBuilder::from(msgs);
+    if let Some(opts) = web_search_opts {
+        request_builder = request_builder.with_web_search_options(opts);
+    }
     *streaming = true;
     let mut assistant_content = String::new();
     let stream_res = stream_and_forward(
         model,
-        msgs_for_stream.expect("msgs_for_stream must be set"),
+        request_builder,
         socket,
         |tok| {
             assistant_content = tok.to_string();
diff --git a/mistralrs-web-chat/src/main.rs b/mistralrs-web-chat/src/main.rs
index 16c2e05160..6654654e9a 100644
--- a/mistralrs-web-chat/src/main.rs
+++ b/mistralrs-web-chat/src/main.rs
@@ -12,8 +12,8 @@ use hyper::Uri;
 use include_dir::{include_dir, Dir};
 use indexmap::IndexMap;
 use mistralrs::{
-    best_device, parse_isq_value, IsqType, SpeechLoaderType, SpeechModelBuilder, TextModelBuilder,
-    VisionModelBuilder,
+    best_device, parse_isq_value, BertEmbeddingModel, IsqType, SpeechLoaderType,
+    SpeechModelBuilder, TextModelBuilder, VisionModelBuilder,
 };
 use std::{net::SocketAddr, sync::Arc};
 use tokio::{fs, net::TcpListener};
@@ -71,6 +71,15 @@ async fn main() -> Result<()> {
         .as_ref()
         .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
 
+    // Determine the embedding model for web search, if enabled
+    let search_embedding_model: Option<BertEmbeddingModel> = if cli.enable_search {
+        Some(match &cli.search_bert_model {
+            Some(model_id) => BertEmbeddingModel::Custom(model_id.clone()),
+            None => BertEmbeddingModel::default(),
+        })
+    } else {
+        None
+    };
     let mut models: IndexMap<String, LoadedModel> = IndexMap::new();
 
     // Insert text models first
@@ -81,12 +90,14 @@ async fn main() -> Result<()> {
             .unwrap_or("text-model")
             .to_string();
         println!("📝 Loading text model: {name}");
-        let m = TextModelBuilder::new(path)
+        let mut builder = TextModelBuilder::new(path)
             .with_isq(isq.unwrap_or(default_isq))
             .with_logging()
-            .with_throughput_logging()
-            .build()
-            .await?;
+            .with_throughput_logging();
+        if let Some(ref bert_model) = search_embedding_model {
+            builder = builder.with_search(bert_model.clone());
+        }
+        let m = builder.build().await?;
         models.insert(name, LoadedModel::Text(Arc::new(m)));
     }
 
@@ -98,12 +109,14 @@ async fn main() -> Result<()> {
.unwrap_or("vision-model") .to_string(); println!("🖼️ Loading vision model: {name}"); - let m = VisionModelBuilder::new(path) + let mut builder = VisionModelBuilder::new(path) .with_isq(isq.unwrap_or(default_isq)) .with_logging() - .with_throughput_logging() - .build() - .await?; + .with_throughput_logging(); + if let Some(ref bert_model) = search_embedding_model { + builder = builder.with_search(bert_model.clone()); + } + let m = builder.build().await?; models.insert(name, LoadedModel::Vision(Arc::new(m))); } diff --git a/mistralrs-web-chat/src/types.rs b/mistralrs-web-chat/src/types.rs index 53c136e813..385867dffc 100644 --- a/mistralrs-web-chat/src/types.rs +++ b/mistralrs-web-chat/src/types.rs @@ -23,6 +23,12 @@ pub struct Cli { /// Repeated flag for speech models #[arg(long = "speech-model")] pub speech_models: Vec, + /// Enable web search tool (requires embedding model) + #[arg(long)] + pub enable_search: bool, + /// Hugging Face model ID for search embeddings (default: SnowflakeArcticEmbedL if --enable-search) + #[arg(long = "search-bert-model")] + pub search_bert_model: Option, /// Port to listen on (default: 8080) #[arg(long = "port")] diff --git a/mistralrs-web-chat/static/index.html b/mistralrs-web-chat/static/index.html index d3c3e79def..c62d8d3b75 100644 --- a/mistralrs-web-chat/static/index.html +++ b/mistralrs-web-chat/static/index.html @@ -14,6 +14,21 @@

Control panel

+
+
diff --git a/mistralrs-web-chat/static/js/ui.js b/mistralrs-web-chat/static/js/ui.js
index 69e033f41d..f74d49c4a0 100644
--- a/mistralrs-web-chat/static/js/ui.js
+++ b/mistralrs-web-chat/static/js/ui.js
@@ -294,6 +294,18 @@ function initStopButton() {
   });
 }
 
+/**
+ * Initialize web search controls: toggle the visibility of the search options
+ */
+function initWebSearchControls() {
+  const checkbox = document.getElementById('enableSearch');
+  const options = document.getElementById('webSearchOptions');
+  if (!checkbox || !options) return;
+  checkbox.addEventListener('change', () => {
+    options.hidden = !checkbox.checked;
+  });
+}
+
 /**
  * Initialize all UI interactions
  */
@@ -302,5 +314,6 @@ function initUI() {
   initImageUpload();
   initTextUpload();
   initDragAndDrop();
+  initWebSearchControls();
   initStopButton();
 }
diff --git a/mistralrs-web-chat/static/js/websocket.js b/mistralrs-web-chat/static/js/websocket.js
index c13fd1d8a0..fa45329ef1 100644
--- a/mistralrs-web-chat/static/js/websocket.js
+++ b/mistralrs-web-chat/static/js/websocket.js
@@ -211,7 +211,21 @@ ${content}
 
   showSpinner();
 
-  ws.send(msg);
+  // Send the message, optionally with web search options
+  const enableSearch = document.getElementById('enableSearch')?.checked;
+  if (enableSearch) {
+    const opts = {};
+    // Include the selected context size (default: medium)
+    const sizeSelect = document.getElementById('searchContextSize');
+    if (sizeSelect) {
+      const sizeValue = sizeSelect.value;
+      if (sizeValue) opts.search_context_size = sizeValue;
+    }
+    const payload = { content: msg, web_search_options: opts };
+    ws.send(JSON.stringify(payload));
+  } else {
+    ws.send(msg);
+  }
   input.value = '';
 
   // Clear uploaded files after sending
diff --git a/mistralrs-web-chat/static/styles.css b/mistralrs-web-chat/static/styles.css
index 719794c3af..7c62f03c2f 100644
--- a/mistralrs-web-chat/static/styles.css
+++ b/mistralrs-web-chat/static/styles.css
@@ -426,3 +426,48 @@ pre:hover .copy-btn {
 @keyframes spin {
   to { transform: rotate(360deg); }
 }
+
+/* ----- sidebar card for Web Search ----- */
+.sidebar-card {
+  background: var(--chat-bg);
+  border: 1px solid var(--border-color);
+  border-radius: var(--radius);
+  padding: 0.75rem;
+}
+.sidebar-card h3 {
+  margin: 0 0 0.5rem;
+  font-size: 1rem;
+}
+.sidebar-card-content {
+  display: flex;
+  flex-direction: column;
+  gap: 0.75rem;
+}
+.toggle-label {
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  cursor: pointer;
+  font-size: 0.95rem;
+}
+/* Web Search options within the card */
+#webSearchOptions {
+  display: none;
+  flex-direction: column;
+  gap: 0.5rem;
+}
+#webSearchOptions:not([hidden]) {
+  display: flex;
+}
+#webSearchOptions label {
+  font-size: 0.9rem;
+}
+#webSearchOptions select {
+  width: 100%;
+  padding: 0.45rem;
+  border: 1px solid var(--border-color);
+  border-radius: var(--radius);
+  background: var(--chat-bg);
+  color: var(--text-color);
+  font-family: inherit;
+}
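
For reference, when the sidebar toggle is on, the front end sends a single text frame shaped like the payload below, which the new handler branch deserializes into `mistralrs::WebSearchOptions`. A minimal sketch using `serde_json`; the query string and the chosen context size are illustrative:

```rust
use serde_json::json;

fn main() {
    // Mirrors the payload assembled in websocket.js: a `content` string plus
    // optional `web_search_options` forwarded to the request builder.
    let frame = json!({
        "content": "What is the latest stable Rust release?",
        "web_search_options": {
            // One of "low", "medium", or "high", matching the sidebar <select>.
            "search_context_size": "medium"
        }
    });
    // The web chat sends this serialized object as one WebSocket text message.
    println!("{frame}");
}
```

Plain (non-JSON) frames still work unchanged; the handler only takes the new branch when the frame parses as JSON with a `content` field, so older clients fall through to the original code path.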