Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/llm/src/discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod model_entry;
pub use model_entry::ModelEntry;

mod watcher;
pub use watcher::ModelWatcher;
pub use watcher::{ModelUpdate, ModelWatcher};

/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
106 changes: 103 additions & 3 deletions lib/llm/src/discovery/watcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;
use tokio::sync::mpsc::Sender;

use anyhow::Context as _;
use tokio::sync::{mpsc::Receiver, Notify};
Expand Down Expand Up @@ -36,14 +37,24 @@ use crate::{

use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};

/// Event emitted by the model watcher when the set of served models changes,
/// so listeners (e.g. the HTTP service) can enable or disable endpoints.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelUpdate {
    /// A model of the given type became available.
    Added(ModelType),
    /// A model of the given type is no longer served.
    Removed(ModelType),
}

/// Watches etcd for model registrations and keeps the [`ModelManager`]
/// in sync, optionally forwarding add/remove events to a channel.
pub struct ModelWatcher {
    manager: Arc<ModelManager>,
    drt: DistributedRuntime,
    router_mode: RouterMode,
    // Wakes tasks blocked waiting for a model to appear (notify_waiters on add).
    notify_on_model: Notify,
    // Set via set_notify_on_model_update(); receives ModelUpdate events.
    model_update_tx: Option<Sender<ModelUpdate>>,
    kv_router_config: Option<KvRouterConfig>,
}

/// Model types that map directly to a single API surface, iterated when
/// fanning out removal notifications.
/// NOTE(review): ModelType::Backend is absent here — elsewhere it appears to
/// be expanded to Chat + Completion; confirm that is intentional.
const ALL_MODEL_TYPES: &[ModelType] =
    &[ModelType::Chat, ModelType::Completion, ModelType::Embedding];

impl ModelWatcher {
pub fn new(
runtime: DistributedRuntime,
Expand All @@ -56,10 +67,15 @@ impl ModelWatcher {
drt: runtime,
router_mode,
notify_on_model: Notify::new(),
model_update_tx: None,
kv_router_config,
}
}

/// Registers a channel that will receive a [`ModelUpdate`] whenever a
/// model is added to or removed from the watcher's view.
pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
    self.model_update_tx = Some(tx);
}

/// Wait until we have at least one chat completions model and return its name.
pub async fn wait_for_chat_model(&self) -> String {
// Loop in case it gets added and immediately deleted
Expand Down Expand Up @@ -100,6 +116,12 @@ impl ModelWatcher {
};
self.manager.save_model_entry(key, model_entry.clone());

if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(model_entry.model_type))
.await
.ok();
}

if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(name = model_entry.name, "New endpoint for existing model");
self.notify_on_model.notify_waiters();
Expand Down Expand Up @@ -151,13 +173,91 @@ impl ModelWatcher {
.await
.with_context(|| model_name.clone())?;
if !active_instances.is_empty() {
let mut update_tx = true;
let mut model_type: ModelType = model_entry.model_type;
if model_entry.model_type == ModelType::Chat
&& self.manager.list_chat_completions_models().is_empty()
{
self.manager.remove_chat_completions_model(&model_name).ok();
Comment on lines +179 to +181
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abrarshivani sorry for commenting on old PR, can you help me understand the conditional here? remove_chat_completions_model will be run if list_chat_completions_models().is_empty(), but in that case, wouldn't the remove do nothing as the manager already has no chat completions models?

model_type = ModelType::Chat;
} else if model_entry.model_type == ModelType::Completion
&& self.manager.list_completions_models().is_empty()
{
self.manager.remove_completions_model(&model_name).ok();
model_type = ModelType::Completion;
} else if model_entry.model_type == ModelType::Embedding
&& self.manager.list_embeddings_models().is_empty()
{
self.manager.remove_embeddings_model(&model_name).ok();
model_type = ModelType::Embedding;
} else if model_entry.model_type == ModelType::Backend {
if self.manager.list_chat_completions_models().is_empty() {
self.manager.remove_chat_completions_model(&model_name).ok();
model_type = ModelType::Chat;
}
if self.manager.list_completions_models().is_empty() {
self.manager.remove_completions_model(&model_name).ok();
if model_type == ModelType::Chat {
model_type = ModelType::Backend;
} else {
model_type = ModelType::Completion;
}
}
} else {
tracing::debug!(
"Model {} is still active in other instances, not removing",
model_name
);
update_tx = false;
}
if update_tx {
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Removed(model_type)).await.ok();
}
}
return Ok(None);
}

// Ignore the errors because model could be either type
let _ = self.manager.remove_chat_completions_model(&model_name);
let _ = self.manager.remove_completions_model(&model_name);
let _ = self.manager.remove_embeddings_model(&model_name);
let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);

let mut chat_model_removed = false;
let mut completions_model_removed = false;
let mut embeddings_model_removed = false;

if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true;
}
if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
{
completions_model_removed = true;
}
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
embeddings_model_removed = true;
}

if !chat_model_removed && !completions_model_removed && !embeddings_model_removed {
tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}",
model_name,
chat_model_removed,
completions_model_removed,
embeddings_model_removed
);
} else {
for model_type in ALL_MODEL_TYPES {
if (chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completion)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)
{
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Removed(*model_type)).await.ok();
}
}
}
}

Ok(Some(model_name))
}
Expand Down
49 changes: 49 additions & 0 deletions lib/llm/src/endpoint_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use serde::{Deserialize, Serialize};
use strum::Display;

/// The kind of HTTP API surface a model can be exposed on.
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum EndpointType {
    /// Chat Completions API
    Chat,
    /// Older completions API
    Completion,
    /// Embeddings API
    Embedding,
    /// Responses API
    Responses,
}

impl EndpointType {
    /// Stable lowercase identifier for this endpoint type (e.g. for routes
    /// or configuration keys).
    ///
    /// Every arm is a string literal, so the result is `&'static str` rather
    /// than a borrow tied to `&self` — strictly more general for callers.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Chat => "chat",
            Self::Completion => "completion",
            Self::Embedding => "embedding",
            Self::Responses => "responses",
        }
    }

    /// All endpoint types, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::Chat,
            Self::Completion,
            Self::Embedding,
            Self::Responses,
        ]
    }
}
88 changes: 82 additions & 6 deletions lib/llm/src/entrypoint/input/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
use std::sync::Arc;

use crate::{
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
discovery::{ModelManager, ModelUpdate, ModelWatcher, MODEL_ROOT_PATH},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter,
entrypoint::{self, input::common, EngineConfig},
http::service::service_v2,
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
model_type::ModelType,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
Expand All @@ -22,9 +24,6 @@ use dynamo_runtime::{DistributedRuntime, Runtime};
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
let mut http_service_builder = service_v2::HttpService::builder()
.port(engine_config.local_model().http_port())
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.enable_embeddings_endpoints(true)
.with_request_template(engine_config.local_model().request_template());

let http_service = match engine_config {
Expand All @@ -45,6 +44,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
MODEL_ROOT_PATH,
router_config.router_mode,
Some(router_config.kv_router_config),
Arc::new(http_service.clone()),
)
.await?;
}
Expand Down Expand Up @@ -98,6 +98,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;

for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}

http_service
}
EngineConfig::StaticFull { engine, model, .. } => {
Expand All @@ -106,6 +112,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let manager = http_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?;
manager.add_chat_completions_model(model.service_name(), engine)?;

// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}
http_service
}
EngineConfig::StaticCore {
Expand All @@ -129,6 +142,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
>(model.card(), inner_engine)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}
http_service
}
};
Expand All @@ -154,13 +173,70 @@ async fn run_watcher(
network_prefix: &str,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
http_service: Arc<HttpService>,
) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config);
let mut watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config);
tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();

// Create a channel to receive model type updates
let (tx, mut rx) = tokio::sync::mpsc::channel(32);

watch_obj.set_notify_on_model_update(tx);

// Spawn a task to watch for model type changes and update HTTP service endpoints
let _endpoint_enabler_task = tokio::spawn(async move {
while let Some(model_type) = rx.recv().await {
tracing::debug!("Received model type update: {:?}", model_type);
update_http_endpoints(http_service.clone(), model_type).await;
}
});

// Pass the sender to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver).await;
});

Ok(())
}

/// Enables or disables HTTP service endpoints to reflect a model being
/// added to or removed from the watcher's view.
///
/// A `Backend` model serves both the chat and completions APIs, so it maps
/// to two endpoint types; every other model type maps to exactly one via
/// `as_endpoint_type()`.
async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    // Collapse Added/Removed into (type, on/off) so the endpoint mapping
    // below is written once instead of duplicated per direction.
    let (model_type, enable) = match model_type {
        ModelUpdate::Added(model_type) => (model_type, true),
        ModelUpdate::Removed(model_type) => (model_type, false),
    };
    let endpoint_types = match model_type {
        ModelType::Backend => vec![EndpointType::Chat, EndpointType::Completion],
        other => vec![other.as_endpoint_type()],
    };
    for endpoint_type in endpoint_types {
        service.enable_model_endpoint(endpoint_type, enable).await;
    }
}
Loading
Loading