Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions deploy/sdk/src/dynamo/sdk/lib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class ServiceConfig(dict):
"""Configuration store that inherits from dict for simpler access patterns"""

_instance = None
COMMON_CONFIG_SERVICE = "Common"
COMMON_CONFIG_KEY = "common-configs"

@classmethod
def get_instance(cls):
Expand All @@ -49,6 +51,33 @@ def require(self, service_name, key):
raise ValueError(f"{service_name}.{key} must be specified in configuration")
return self[service_name][key]

@classmethod
def get_parsed_config(cls, service_name):
    """Return the fully-resolved configuration dict for *service_name*.

    Resolution rules:
      * Unknown services yield an empty dict.
      * The ``ServiceArgs`` entry is dropped (deployment metadata, not
        engine configuration).
      * Keys listed in the service's ``common-configs`` entry are filled
        in from the ``Common`` section, but only when the service has not
        set them itself (service values win over common values).
      * The ``common-configs`` bookkeeping key itself is removed from the
        returned dict.

    Args:
        service_name: Name of the service section to resolve.

    Returns:
        dict: The merged configuration; never the shared singleton's own
        mapping (a shallow copy is returned, safe for callers to mutate).
    """
    instance = cls.get_instance()

    if service_name not in instance:
        return {}

    # Shallow-copy so callers can mutate the result without touching the
    # singleton's shared state.
    service_config = instance[service_name].copy()
    service_config.pop("ServiceArgs", None)

    # Pull in opted-in defaults from the Common section.  A key is only
    # copied when the service did not override it locally.
    common = instance.get(cls.COMMON_CONFIG_SERVICE)
    common_config_keys = service_config.get(cls.COMMON_CONFIG_KEY)
    if common is not None and common_config_keys is not None:
        for key in common_config_keys:
            if key in common and key not in service_config:
                service_config[key] = common[key]

    # The opt-in list is routing metadata, not a real config value.
    service_config.pop(cls.COMMON_CONFIG_KEY, None)

    return service_config

def as_args(self, service_name, prefix=""):
"""Extract configs as CLI args for a service, with optional prefix filtering.

Expand All @@ -57,8 +86,6 @@ def as_args(self, service_name, prefix=""):
the component's `common-configs` setting, and that key has not been overriden by the
component's config.
"""
COMMON_CONFIG_SERVICE = "Common"
COMMON_CONFIG_KEY = "common-configs"

if service_name not in self:
return []
Expand All @@ -69,7 +96,7 @@ def add_to_args(args: list[str], key: str, value):
if prefix and not key.startswith(prefix):
return

if key.endswith(COMMON_CONFIG_KEY):
if key.endswith(self.COMMON_CONFIG_KEY):
return

# Strip prefix if needed
Expand All @@ -90,8 +117,8 @@ def add_to_args(args: list[str], key: str, value):
if "ServiceArgs" in service_config:
del service_config["ServiceArgs"]

if (common := self.get(COMMON_CONFIG_SERVICE)) is not None and (
common_config_keys := service_config.get(COMMON_CONFIG_KEY)
if (common := self.get(self.COMMON_CONFIG_SERVICE)) is not None and (
common_config_keys := service_config.get(self.COMMON_CONFIG_KEY)
) is not None:
for key in common_config_keys:
if key in common and key not in service_config:
Expand Down
3 changes: 1 addition & 2 deletions examples/llm/components/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ class Frontend:

def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
self.frontend_config = frontend_config
self.process = None
self.setup_model()
Expand Down
3 changes: 1 addition & 2 deletions examples/sglang/components/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ class Frontend:

def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
self.frontend_config = frontend_config
self.process = None

Expand Down
3 changes: 1 addition & 2 deletions examples/tensorrt_llm/components/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class Frontend:
processor = depends(Processor)

def __init__(self):
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))

# Chat/completions Endpoint
subprocess.run(
Expand Down
9 changes: 7 additions & 2 deletions examples/vllm_v0/components/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ class Frontend:

def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
self.frontend_config = FrontendConfig(**config.get("Frontend", {}))
self.frontend_config = FrontendConfig(
**ServiceConfig.get_parsed_config("Frontend")
)
self.process = None

logger.warning(f"Frontend config: {self.frontend_config}")

self.start_ingress_and_processor()

def start_ingress_and_processor(self):
Expand All @@ -87,6 +90,8 @@ def start_ingress_and_processor(self):
self.frontend_config.router,
]

logger.info(f"Frontend cmd: {cmd}")

self.process = subprocess.Popen(
cmd,
stdout=None,
Expand Down
6 changes: 3 additions & 3 deletions examples/vllm_v0/components/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async def generate(self, request: PreprocessedRequest):
prefill_queue_size = await prefill_queue.get_queue_size()
disagg_router_decision = await self.disaggregated_router.prefill_remote(
len(request.token_ids),
0, # TODO: return prefix hit rate from dynamo-run router
request.estimated_prefix_hit_num_blocks * self.engine_args.block_size,
prefill_queue_size,
)
else:
Expand All @@ -225,12 +225,12 @@ async def generate(self, request: PreprocessedRequest):
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
logger.info(
f"Prefilling remotely for request {request_id} with length {len(request.token_ids)}"
f"Prefilling remotely for request {request_id} with length {len(request.token_ids)} (estimated prefix hit length {request.estimated_prefix_hit_num_blocks * self.engine_args.block_size})"
)
else:
remote_prefill_params = None
logger.info(
f"Prefilling locally for request {request_id} with length {len(request.token_ids)}"
f"Prefilling locally for request {request_id} with length {len(request.token_ids)} (estimated prefix hit length {request.estimated_prefix_hit_num_blocks * self.engine_args.block_size})"
)

sampling_params = SamplingParams(**self.default_sampling_params)
Expand Down
6 changes: 3 additions & 3 deletions examples/vllm_v0/configs/agg_kv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ Common:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64
max-model-len: 16384
router: kv

Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.VllmWorker.generate
port: 8000
router: kv
common-configs: [block-size]
common-configs: [block-size, router]

VllmWorker:
enforce-eager: true
Expand All @@ -32,4 +32,4 @@ VllmWorker:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len]
common-configs: [model, block-size, max-model-len, router]
6 changes: 3 additions & 3 deletions examples/vllm_v0/configs/disagg_kv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ Common:
block-size: 64
max-model-len: 16384
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
router: kv

Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.VllmWorker.generate
port: 8000
router: kv
common-configs: [block-size]
common-configs: [block-size, router]

VllmWorker:
remote-prefill: true
Expand All @@ -35,7 +35,7 @@ VllmWorker:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config]
common-configs: [model, block-size, max-model-len, kv-transfer-config, router]

PrefillWorker:
max-num-batched-tokens: 16384
Expand Down
1 change: 1 addition & 0 deletions examples/vllm_v0/utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class PreprocessedRequest(BaseModel):
eos_token_ids: List[TokenIdType] = Field(default_factory=list)
mdc_sum: Optional[str] = None
annotations: List[str] = Field(default_factory=list)
estimated_prefix_hit_num_blocks: Optional[int] = None


class DisaggPreprocessedRequest(BaseModel):
Expand Down
3 changes: 1 addition & 2 deletions examples/vllm_v1/components/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ class Frontend:

def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
frontend_config = FrontendConfig(**ServiceConfig.get_parsed_config("Frontend"))
self.frontend_config = frontend_config
self.process = None

Expand Down
21 changes: 15 additions & 6 deletions lib/llm/src/kv_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub mod recorder;
pub mod scheduler;
pub mod scoring;

use tracing;

use crate::{
kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
Expand Down Expand Up @@ -129,7 +131,8 @@ impl KvRouter {
}

/// Give these tokens, find the worker with the best match in it's KV cache.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<i64> {
/// Returned overlap amount is in number of blocks.
async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len();
let block_size = self.block_size;

Expand All @@ -141,8 +144,9 @@ impl KvRouter {
.map(|block| LocalBlockHash(block.block_hash()))
.collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
let worker_id = self.scheduler.schedule(overlap_scores.clone(), isl_tokens).await?;
let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
Ok((worker_id, overlap_amount))
}
}

Expand All @@ -153,7 +157,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts();
let worker_id = self.find_best_match(&request.tokens).await?;
let (worker_id, _) = self.find_best_match(&request.tokens).await?;

let response = RouterResponse { worker_id };
let response = Annotated::from_data(response);
Expand Down Expand Up @@ -187,8 +191,13 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er
match self.inner.client.instance_source.as_ref() {
InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => {
let instance_id = self.chooser.find_best_match(&request.token_ids).await?;
self.inner.direct(request, instance_id).await
let (instance_id, overlap_amount) =
self.chooser.find_best_match(&request.token_ids).await?;
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
self.inner.direct(updated_request, instance_id).await
}
}
}
Expand Down
1 change: 1 addition & 0 deletions lib/llm/src/preprocessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ impl OpenAIPreprocessor {
builder.stop_conditions(stop_conditions);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);

Ok((builder.build()?, annotations))
}
Expand Down
4 changes: 4 additions & 0 deletions lib/llm/src/protocols/common/preprocessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ pub struct PreprocessedRequest {
/// User requested annotations for the request
#[builder(default)]
pub annotations: Vec<String>,

/// Estimated number of prefix hit tokens (only used in kv aware routing)
#[builder(default)]
pub estimated_prefix_hit_num_blocks: Option<u32>,
}

impl PreprocessedRequest {
Expand Down
Loading