Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/bindings/python/examples/openai_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def worker(runtime: DistributedRuntime):
host: str = "localhost"
port: int = 8000
service: HttpService = HttpService(port=port)
service.add_chat_completions_model(served_model_name, engine)
service.add_chat_completions_model(served_model_name, "mdcsum", engine)

print("Starting service...")
shutdown_signal = service.run(runtime.child_token())
Expand Down
12 changes: 9 additions & 3 deletions lib/bindings/python/rust/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,29 @@ impl HttpService {
Ok(Self { inner })
}

pub fn add_completions_model(&self, model: String, engine: HttpAsyncEngine) -> PyResult<()> {
pub fn add_completions_model(
&self,
model: String,
checksum: String,
engine: HttpAsyncEngine,
) -> PyResult<()> {
let engine = Arc::new(engine);
self.inner
.model_manager()
.add_completions_model(&model, engine)
.add_completions_model(&model, &checksum, engine)
.map_err(to_pyerr)
}

pub fn add_chat_completions_model(
&self,
model: String,
checksum: String,
engine: HttpAsyncEngine,
) -> PyResult<()> {
let engine = Arc::new(engine);
self.inner
.model_manager()
.add_chat_completions_model(&model, engine)
.add_chat_completions_model(&model, &checksum, engine)
.map_err(to_pyerr)
}

Expand Down
3 changes: 2 additions & 1 deletion lib/bindings/python/tests/test_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async def http_server(runtime: DistributedRuntime):
model_name = "test_model"
start_done = asyncio.Event()
child_token = runtime.child_token()
checksum = "abc123"  # Checksum of ModelDeploymentCard for that model

async def worker():
"""The server worker task."""
Expand All @@ -94,7 +95,7 @@ async def worker():
engine = HttpAsyncEngine(python_engine.generate, loop)

service = HttpService(port=port)
service.add_chat_completions_model(model_name, engine)
service.add_chat_completions_model(model_name, checksum, engine)
service.enable_endpoint("chat", True)

shutdown_signal = service.run(child_token)
Expand Down
11 changes: 10 additions & 1 deletion lib/llm/src/common/checked_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct Checksum {
algorithm: CryptographicHashMethods,
}

#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Eq, PartialEq)]
pub enum CryptographicHashMethods {
#[serde(rename = "blake3")]
BLAKE3,
Expand Down Expand Up @@ -259,6 +259,15 @@ impl TryFrom<&str> for Checksum {
}
}

/// Default is an empty BLAKE3 checksum.
///
/// NOTE(review): an empty `hash` can never equal a real digest — presumably
/// this serves as an "unset" placeholder; confirm call sites rely on that.
impl Default for Checksum {
    fn default() -> Self {
        Self {
            // `String::new()` is the idiomatic, non-allocating empty string
            // (the original `"".to_string()` is equivalent but noisier).
            hash: String::new(),
            algorithm: CryptographicHashMethods::BLAKE3,
        }
    }
}

impl FromStr for CryptographicHashMethods {
type Err = String;

Expand Down
6 changes: 0 additions & 6 deletions lib/llm/src/discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,8 @@
mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError};

mod model_entry;
pub use model_entry::ModelEntry;

mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher};

/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";

/// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers";
30 changes: 0 additions & 30 deletions lib/llm/src/discovery/model_entry.rs

This file was deleted.

80 changes: 69 additions & 11 deletions lib/llm/src/discovery/model_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use parking_lot::{Mutex, RwLock};
use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider;

use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
use crate::{
kv_router::KvRouter,
Expand All @@ -21,6 +20,10 @@ use crate::{
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
};
use crate::{
kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector},
model_type::ModelType,
};

#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
Expand All @@ -39,7 +42,7 @@ pub struct ModelManager {
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,

// These two are Mutex because we read and write rarely and equally
// These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}
Expand All @@ -62,6 +65,43 @@ impl ModelManager {
}
}

/// Check `candidate_checksum` against the checksums recorded for `model_name`.
///
/// `model_type` is a bitflag; every flagged unit that has a registered
/// checksum for this model participates in the check.
///
/// Returns:
/// - `None` if no flagged unit has a checksum recorded for this model.
/// - `Some(true)` only if the candidate matches the recorded checksum for
///   every flagged unit that has one.
pub fn is_valid_checksum(
    &self,
    model_type: ModelType,
    model_name: &str,
    candidate_checksum: &str,
) -> Option<bool> {
    let mut checked_any = false;
    let mut all_match = true;
    for unit in model_type.units() {
        let recorded = match unit {
            ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
            ModelType::Completions => self.completion_engines.read().checksum(model_name),
            ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
            ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
            // Flags without an engine registry do not participate.
            _ => continue,
        };
        // A unit with no recorded checksum for this model is skipped, not failed.
        let Some(valid_checksum) = recorded else {
            continue;
        };
        tracing::debug!(
            model_name,
            valid_checksum,
            candidate_checksum,
            "is_valid_checksum: check case"
        );
        checked_any = true;
        all_match &= valid_checksum == candidate_checksum;
    }
    // Valid only if the candidate matched for every unit we could check.
    checked_any.then_some(all_match)
}

/// Snapshot of every saved ModelDeploymentCard, in arbitrary order.
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
    let guard = self.cards.lock();
    guard.values().cloned().collect()
}
Expand Down Expand Up @@ -99,37 +139,41 @@ impl ModelManager {
pub fn add_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}

pub fn add_chat_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}

pub fn add_embeddings_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}

pub fn add_tensor_model(
&self,
model: &str,
card_checksum: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}

pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
Expand Down Expand Up @@ -196,10 +240,11 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}

/// Save a ModelDeploymentCard from an instance's etcd `models/` key so we can fetch it later when the key is
/// deleted from etcd.
pub fn save_model_card(&self, key: &str, entry: ModelDeploymentCard) {
self.cards.lock().insert(key.to_string(), entry);
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), card);
Ok(())
}

/// Remove and return model card for this instance's etcd key. We do this when the instance stops.
Expand Down Expand Up @@ -291,13 +336,17 @@ pub struct ModelEngines<E> {
/// Optional default model name
default: Option<String>,
engines: HashMap<String, E>,
/// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
/// same card.
checksums: HashMap<String, String>,
}

impl<E> Default for ModelEngines<E> {
fn default() -> Self {
Self {
default: None,
engines: HashMap::new(),
checksums: HashMap::new(),
}
}
}
Expand All @@ -313,18 +362,21 @@ impl<E> ModelEngines<E> {
self.default = None;
}

fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> {
/// Register an engine, plus the checksum of its ModelDeploymentCard, under `model`.
///
/// Fails with `ModelAlreadyExists` when an engine is already registered for
/// this name; the existing entry is left untouched.
fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
    if self.engines.contains_key(model) {
        Err(ModelManagerError::ModelAlreadyExists(model.to_string()))
    } else {
        // The two maps are independent, so insertion order does not matter.
        self.checksums
            .insert(model.to_string(), checksum.to_string());
        self.engines.insert(model.to_string(), engine);
        Ok(())
    }
}

/// Drop the engine registered under `model`, along with its recorded checksum.
/// Fails with `ModelNotFound` if nothing was registered under that name.
fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
    match self.engines.remove(model) {
        Some(_engine) => {
            // The checksum may legitimately be absent; ignore the result.
            self.checksums.remove(model);
            Ok(())
        }
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}

Expand All @@ -339,4 +391,10 @@ impl<E> ModelEngines<E> {
/// Names of all registered models, in arbitrary (HashMap) order.
pub fn list(&self) -> Vec<String> {
    // `.cloned()` is the idiomatic form of `.map(|k| k.to_owned())`.
    self.engines.keys().cloned().collect()
}

/// Checksum of the ModelDeploymentCard registered under `model`, if any.
///
/// Returns a newly allocated `String` for caller convenience — every current
/// call site needs an owned value.
pub fn checksum(&self, model: &str) -> Option<String> {
    // `Option::cloned` replaces the hand-rolled `.map(|s| s.to_string())`.
    self.checksums.get(model).cloned()
}
}
Loading
Loading