diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 20dbd6fdd..edd3b7849 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -33,6 +33,7 @@
     SpyreAttentionMetadata,
     SpyreCausalLM,
 )
+from vllm_spyre.perf_metrics import create_perf_metric_logger
 from vllm_spyre.platform import SpyrePlatform
 from vllm_spyre.utils import exact_div
 from vllm_spyre.v1.sample.spyre_logits_processor import build_logitsprocs_for_cb
@@ -716,6 +717,9 @@ def __init__(
 
         self.prefix_cache_stats = None
 
+        # Initialize performance metric logger for tracking embedding times
+        self.perf_logger = create_perf_metric_logger(rank=rank)
+
     def load_model(self) -> None:
         self._model = SpyreCausalLM(
             vllm_config=self.vllm_config,
@@ -1040,9 +1044,16 @@ def _prepare_chunked_prefill(self, req_id: str) -> SamplingForwardInputs:
                 is_decode=False,
             )
 
-            t1 = time.time() - t0
+            t_elapsed = time.time() - t0
 
-            logger.info("maybe_mm_embedding processing time: %.2fms", (t1 * 1000))
+            logger.info("maybe_mm_embedding processing time: %.2fms", (t_elapsed * 1000))
+            self.perf_logger.log(
+                "get_mm_embeddings_time_ms",
+                t_elapsed * 1000,
+                phase="prefill",
+                has_mm_features=True,
+                req_id=req_id,
+            )
 
             # Cache the full embeddings for subsequent chunks
             request.cached_mm_embeddings = full_embeds
@@ -1078,11 +1089,20 @@ def _prepare_chunked_prefill(self, req_id: str) -> SamplingForwardInputs:
             )
         else:
             # Non-multimodal or decode: use standard token embedding
+            t0 = time.time()
             input_embeds = self.model.get_maybe_mm_embeddings(
                 input_tokens,
                 mm_features=None,
                 is_decode=False,
             )
+            t_elapsed = time.time() - t0
+            self.perf_logger.log(
+                "get_mm_embeddings_time_ms",
+                t_elapsed * 1000,
+                phase="prefill",
+                has_mm_features=False,
+                req_id=req_id,
+            )
 
         model_inputs = SamplingForwardInputs(
             input_tokens=input_tokens,
@@ -1175,11 +1195,22 @@ def _prepare_decode(
 
         # None unless this model is multimodal; no mm_features since
         # all multimodal features are merged in prefill.
+        t0 = time.time()
         input_embeds = self.model.get_maybe_mm_embeddings(
             input_tokens,
             mm_features=None,
             is_decode=True,
         )
+        t_elapsed = time.time() - t0
+        # Log timing for each request in the decode batch
+        for req_id in cached_request_data.req_ids:
+            self.perf_logger.log(
+                "get_mm_embeddings_time_ms",
+                t_elapsed * 1000,
+                phase="decode",
+                has_mm_features=False,
+                req_id=req_id,
+            )
 
         model_inputs = SamplingForwardInputs(
             input_tokens=input_tokens,