diff --git a/lib/bindings/python/rust/http.rs b/lib/bindings/python/rust/http.rs index 3a22092334b..397da921b1b 100644 --- a/lib/bindings/python/rust/http.rs +++ b/lib/bindings/python/rust/http.rs @@ -19,8 +19,8 @@ use pyo3::{exceptions::PyException, prelude::*}; use crate::{engine::*, to_pyerr, CancellationToken}; +pub use dynamo_llm::endpoint_type::EndpointType; pub use dynamo_llm::http::service::{error as http_error, service_v2}; - pub use dynamo_runtime::{ error, pipeline::{async_trait, AsyncEngine, Data, ManyOut, SingleIn}, @@ -92,6 +92,27 @@ impl HttpService { Ok(()) }) } + + fn enable_endpoint(&self, endpoint_type: String, enabled: bool) -> PyResult<()> { + let endpoint_type = EndpointType::all() + .iter() + .find(|&&ep_type| ep_type.as_str().to_lowercase() == endpoint_type.to_lowercase()) + .copied() + .ok_or_else(|| { + let valid_types = EndpointType::all() + .iter() + .map(|&ep_type| ep_type.as_str().to_string()) + .collect::>() + .join(", "); + to_pyerr(format!( + "Invalid endpoint type: '{}'. Valid types are: {}", + endpoint_type, valid_types + )) + })?; + + self.inner.enable_model_endpoint(endpoint_type, enabled); + Ok(()) + } } /// Python Exception for HTTP errors diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index 87b131d069e..0437ee902a0 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -125,9 +125,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul manager.add_completions_model(local_model.display_name(), completions_engine)?; for endpoint_type in EndpointType::all() { - http_service - .enable_model_endpoint(endpoint_type, true) - .await; + http_service.enable_model_endpoint(endpoint_type, true); } http_service @@ -141,9 +139,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul // Enable all endpoints for endpoint_type in EndpointType::all() { - http_service - .enable_model_endpoint(endpoint_type, true) - .await; + http_service.enable_model_endpoint(endpoint_type, true); } http_service } @@ -170,9 +166,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul manager.add_completions_model(model.service_name(), cmpl_pipeline)?; // Enable all endpoints for endpoint_type in EndpointType::all() { - http_service - .enable_model_endpoint(endpoint_type, true) - .await; + http_service.enable_model_endpoint(endpoint_type, true); } http_service } @@ -223,7 +217,7 @@ async fn run_watcher( let _endpoint_enabler_task = tokio::spawn(async move { while let Some(model_type) = rx.recv().await { tracing::debug!("Received model type update: {:?}", model_type); - update_http_endpoints(http_service.clone(), model_type).await; + update_http_endpoints(http_service.clone(), model_type); } }); @@ -236,7 +230,7 @@ async fn run_watcher( } /// Updates HTTP service endpoints based on available model types -async fn update_http_endpoints(service: Arc, model_type: ModelUpdate) { +fn update_http_endpoints(service: Arc, model_type: ModelUpdate) { tracing::debug!( "Updating HTTP service endpoints for model type: {:?}", model_type @@ -244,32 +238,20 @@ async fn update_http_endpoints(service: Arc, model_type: ModelUpdat match model_type { ModelUpdate::Added(model_type) => match model_type { ModelType::Backend => { - service - .enable_model_endpoint(EndpointType::Chat, true) - .await; - service - .enable_model_endpoint(EndpointType::Completion, true) - .await; + service.enable_model_endpoint(EndpointType::Chat, true); + service.enable_model_endpoint(EndpointType::Completion, true); } _ => { - service - .enable_model_endpoint(model_type.as_endpoint_type(), true) - .await; + service.enable_model_endpoint(model_type.as_endpoint_type(), true); } }, ModelUpdate::Removed(model_type) => match model_type { ModelType::Backend => { - service - .enable_model_endpoint(EndpointType::Chat, false) - .await; - service - .enable_model_endpoint(EndpointType::Completion, false) - .await; + service.enable_model_endpoint(EndpointType::Chat, false); + service.enable_model_endpoint(EndpointType::Completion, false); } _ => { - service - .enable_model_endpoint(model_type.as_endpoint_type(), false) - .await; + service.enable_model_endpoint(model_type.as_endpoint_type(), false); } }, } diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index e7d710332e5..c2173612991 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -262,7 +262,7 @@ impl HttpService { &self.route_docs } - pub async fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) { + pub fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) { self.state.flags.set(&endpoint_type, enable); tracing::info!( "{} endpoints {}",