Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions vllm_spyre/v1/worker/spyre_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
SpyreAttentionMetadata,
SpyreCausalLM,
)
from vllm_spyre.perf_metrics import create_perf_metric_logger
from vllm_spyre.platform import SpyrePlatform
from vllm_spyre.utils import exact_div
from vllm_spyre.v1.sample.spyre_logits_processor import build_logitsprocs_for_cb
Expand Down Expand Up @@ -716,6 +717,9 @@ def __init__(

self.prefix_cache_stats = None

# Initialize performance metric logger for tracking embedding times
self.perf_logger = create_perf_metric_logger(rank=rank)

def load_model(self) -> None:
self._model = SpyreCausalLM(
vllm_config=self.vllm_config,
Expand Down Expand Up @@ -1040,9 +1044,16 @@ def _prepare_chunked_prefill(self, req_id: str) -> SamplingForwardInputs:
is_decode=False,
)

t1 = time.time() - t0
t_elapsed = time.time() - t0

logger.info("maybe_mm_embedding processing time: %.2fms", (t1 * 1000))
logger.info("maybe_mm_embedding processing time: %.2fms", (t_elapsed * 1000))
self.perf_logger.log(
"get_mm_embeddings_time_ms",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to be publishing a non-standard vllm metric here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to find a similar metric from vLLM that we could reuse, or push the numbers there, but I haven't found one yet.

The idea here is to have a way to measure impact of MM processing.

Open to suggestions if there is a better, more standard way to handle this.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joerunde actually usage of create_perf_metric_logger already handles the optional enablement of this metric. So this will only get printed if we pass VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED env variable.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

t_elapsed * 1000,
phase="prefill",
has_mm_features=True,
req_id=req_id,
)

# Cache the full embeddings for subsequent chunks
request.cached_mm_embeddings = full_embeds
Expand Down Expand Up @@ -1078,11 +1089,20 @@ def _prepare_chunked_prefill(self, req_id: str) -> SamplingForwardInputs:
)
else:
# Non-multimodal or decode: use standard token embedding
t0 = time.time()
input_embeds = self.model.get_maybe_mm_embeddings(
input_tokens,
mm_features=None,
is_decode=False,
)
t_elapsed = time.time() - t0
self.perf_logger.log(
"get_mm_embeddings_time_ms",
t_elapsed * 1000,
phase="prefill",
has_mm_features=False,
req_id=req_id,
)

model_inputs = SamplingForwardInputs(
input_tokens=input_tokens,
Expand Down Expand Up @@ -1175,11 +1195,22 @@ def _prepare_decode(

# None unless this model is multimodal; no mm_features since
# all multimodal features are merged in prefill.
t0 = time.time()
input_embeds = self.model.get_maybe_mm_embeddings(
input_tokens,
mm_features=None,
is_decode=True,
)
t_elapsed = time.time() - t0
# Log timing for each request in the decode batch
for req_id in cached_request_data.req_ids:
self.perf_logger.log(
"get_mm_embeddings_time_ms",
t_elapsed * 1000,
phase="decode",
has_mm_features=False,
req_id=req_id,
)

model_inputs = SamplingForwardInputs(
input_tokens=input_tokens,
Expand Down
Loading