From 83059cb96b6a9e91ab02ed55965ffea6754b9f08 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 8 Jul 2025 16:42:17 -0700 Subject: [PATCH 1/6] =?UTF-8?q?=F0=9F=94=A5=20remove=20new=5Ftoken=5Fids?= =?UTF-8?q?=20from=20warmup=20decode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 95c60f249..90fc6a84e 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -359,21 +359,15 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): # one decode iteration across all sequences req_ids = [] - new_token_ids = [] new_block_ids = [] num_computed_tokens = [] for req in dummy_requests: req_ids.append(req.req_id) - new_token_ids.append([ - valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), (1, )).item()] - ]) # placeholder token - new_block_ids.append([req.block_ids]) num_computed_tokens.append(prompt_len) cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, - new_token_ids=new_token_ids, + new_token_ids=[[] for _ in range(len(dummy_requests))], new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, ) From a9f34c8b61b56b9c2c18688f7361b230953c4486 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 9 Jul 2025 11:06:12 -0700 Subject: [PATCH 2/6] =?UTF-8?q?=F0=9F=94=A5=20remove=20new=5Ftoken=5Fids?= =?UTF-8?q?=20from=20sb=20warmup=20too?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 90fc6a84e..ef716a6f3 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -540,22 +540,17 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, # Set up dummy cached_requests for decode steps req_ids = [] - new_token_ids = [] new_block_ids = [] num_computed_tokens = [] for req in dummy_requests: req_ids.append(req.req_id) - new_token_ids.append([ - valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), (1, )).item()] - ]) # placeholder token new_block_ids.append([req.block_ids]) num_computed_tokens.append(req.num_computed_tokens) cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, - new_token_ids=new_token_ids, + new_token_ids=[], new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, ) From 99aad58e3b4234d7689c497035641c9abb7b77f6 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 9 Jul 2025 11:08:24 -0700 Subject: [PATCH 3/6] =?UTF-8?q?=F0=9F=99=88=20ignore=20the=20random=20type?= =?UTF-8?q?=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index ef716a6f3..fc9df11fe 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -359,7 +359,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): # one decode iteration across all sequences 
req_ids = [] - new_block_ids = [] + new_block_ids = [] # type: ignore[var-annotated] num_computed_tokens = [] for req in dummy_requests: req_ids.append(req.req_id) From dcb8f0cefb62bfe4966a8bd8259df4e14d495885 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 9 Jul 2025 11:19:24 -0700 Subject: [PATCH 4/6] =?UTF-8?q?=F0=9F=8E=A8=20make=20it=20consistent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index fc9df11fe..91bcda231 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -550,7 +550,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, - new_token_ids=[], + new_token_ids=[[] for _ in range(len(dummy_requests))], new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, ) From a4a188f45d7ade47688bef98ff8d3784968b6e9c Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 30 Jul 2025 10:30:23 -0700 Subject: [PATCH 5/6] =?UTF-8?q?=F0=9F=9A=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 266 +++++++++++++-------------- 1 file changed, 129 insertions(+), 137 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 91bcda231..57e32e82e 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -5,7 +5,7 @@ import platform import signal import time -from typing import Any, Optional, Union, cast +from typing import Optional, Union, cast import torch import torch.distributed as dist @@ -16,6 +16,7 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -29,25 +30,44 @@ from vllm_spyre.model_executor.model_loader import spyre_setup from vllm_spyre.platform import SpyrePlatform from vllm_spyre.v1.worker.spyre_model_runner import ( - ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner) + ContinuousBatchingSpyreModelRunner, SpyrePoolingModelRunner, + StaticBatchingSpyreModelRunner, SupportedTask) logger = init_logger(__name__) +# var to make sure we always warmup with the right context +_inside_warmup_mode = False + @contextlib.contextmanager def _maybe_warmup_context(): + global _inside_warmup_mode warmup_context = contextlib.nullcontext if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn": from torch_sendnn import warmup_mode warmup_context = warmup_mode with warmup_context(): + _inside_warmup_mode = True yield + _inside_warmup_mode = False class SpyreWorker(WorkerBaseV1): """A worker class that executes the model on a group of Spyre cores. 
""" + @property + def is_pooling(self) -> bool: + return self.model_config.task == "embed" \ + if self.model_config.task else \ + "embed" in self.model_config.supported_tasks + + @property + def is_decoder(self) -> bool: + return self.model_config.task == "generate" \ + if self.model_config.task else \ + "generate" in self.model_config.supported_tasks + def get_kv_cache_spec(self) -> KVCacheSpec: """Get specifications for KV cache implementation. @@ -76,12 +96,12 @@ def compile_or_warm_up_model(self) -> None: (s["prompt_length"], s["new_tokens"], s["batch_size"]) for s in self.spyre_warmup_shapes ]): - if self.model_config.task != "embed": + if not self.is_pooling: # TODO: remove if spyre supports # lower number of output tokens - assert num_decode_tokens >= 3, ( + assert num_decode_tokens >= 2, ( "VLLM_SPYRE_WARMUP_NEW_TOKENS must be " - "at least 3 (spyre requirement).") + "at least 2 (spyre requirement).") # warmup individual combination logger.info( "[WARMUP] (%d/%d) for prompt length %d, decoding %d tokens " @@ -89,6 +109,9 @@ def compile_or_warm_up_model(self) -> None: prompt_len, num_decode_tokens, batch_size) self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens, self.restricted_tokens, batch_size) + + self.model_runner.complete_warmup() + all_warmup_end_t = time.time() all_warmup_total_t = all_warmup_end_t - all_warmup_start_t self.perf_metrics.log("total warmup time", all_warmup_total_t) @@ -98,7 +121,6 @@ def compile_or_warm_up_model(self) -> None: "[WARMUP] All %d prompt/decode/batchsize-shape " "combinations finished in %.3fs", num_shape_combinations, all_warmup_total_t) - self.model_runner.complete_warmup() def check_health(self) -> None: """Basic health check (override for device-specific checks).""" @@ -158,9 +180,12 @@ def __init__( init_cached_hf_modules() self.model_runner: \ Union[StaticBatchingSpyreModelRunner, - ContinuousBatchingSpyreModelRunner] - if self.model_config.task == "embed": - raise NotImplementedError + ContinuousBatchingSpyreModelRunner, SpyrePoolingModelRunner] + if self.is_pooling: + self.model_runner = SpyrePoolingModelRunner( + self.vllm_config, self.is_driver_worker) + self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes( + self.vllm_config.scheduler_config) else: if envs_spyre.VLLM_SPYRE_USE_CB: self.model_runner = ContinuousBatchingSpyreModelRunner( @@ -319,12 +344,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): warmup_tokens_tensor = valid_token_ids_tensor[torch.randint( 0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))] - # TODO temporary until 'pooling_params' makes it to a release version - # in vllm - extra_kwargs: dict[str, Any] = {} - if "pooling_params" in NewRequestData.__dataclass_fields__: - extra_kwargs["pooling_params"] = None - dummy_requests = [ + dummy_requests: list[NewRequestData] = [ NewRequestData( req_id="warmup-%d" % (i), prompt_token_ids=warmup_tokens_tensor[i].tolist(), @@ -332,65 +352,18 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(max_tokens=num_decode_tokens), + pooling_params=None, block_ids=[0], # not actually used num_computed_tokens=0, lora_request=None, - **extra_kwargs) for i in range(batch_size + 1) + ) for i in range(batch_size + 1) ] add_dummy_request = dummy_requests.pop(-1) with _maybe_warmup_context(): - for i, req in enumerate(dummy_requests): - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[req], - scheduled_cached_reqs=CachedRequestData.make_empty(), - 
num_scheduled_tokens={req.req_id: prompt_len}, - total_num_scheduled_tokens=prompt_len, - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None, - ) - logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size) - self.execute_model(scheduler_output) - - # one decode iteration across all sequences - req_ids = [] - new_block_ids = [] # type: ignore[var-annotated] - num_computed_tokens = [] - for req in dummy_requests: - req_ids.append(req.req_id) - num_computed_tokens.append(prompt_len) - cached_request_data = CachedRequestData( - req_ids=req_ids, - resumed_from_preemption=False, - new_token_ids=[[] for _ in range(len(dummy_requests))], - new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, - ) - - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=cached_request_data, - num_scheduled_tokens={ - f"warmup-{i}": 1 - for i in range(batch_size) - }, - total_num_scheduled_tokens=batch_size, - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_input_ids=[], - structured_output_request_ids={}, - grammar_bitmask=None, - ) - logger.info("[WARMUP] Decode...") - self.execute_model(scheduler_output) - self._cleanup_model_runner(request=dummy_requests) + self._dynamic_warmup(dummy_requests=dummy_requests, + prompt_len=prompt_len, + batch_size=batch_size) # warmup_mode completes the graph compilation, but we need to do # one additional prefill to deploy the compiled program to the device, @@ -413,16 +386,14 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): self.execute_model(scheduler_output) self._cleanup_model_runner(request=[add_dummy_request]) - # get the number or pages from the actual Spyre card after the warmup - # and set it accordingly in the model runner and the kv cache size - n_blocks_avail = self._get_num_blocks_available() - model_runner._set_free_blocks(num_blocks=n_blocks_avail) - model_runner.model.model._set_past_key_value_states( - num_blocks=n_blocks_avail) + model_runner.complete_warmup() warmup_end_t = time.time() warmup_total_t = warmup_end_t - warmup_start_t - logger.info("[WARMUP] Finished in %.3fs", warmup_total_t) + compile_cache_str = 'enabled' if int( + os.getenv("TORCH_SENDNN_CACHE_ENABLE", "0")) else 'disabled' + logger.info("[WARMUP] Finished in %.3fs (compilation cache %s)", + warmup_total_t, compile_cache_str) maybe_override_signals_handler() @@ -449,51 +420,6 @@ def _cleanup_model_runner(self, request) -> None: cast(ContinuousBatchingSpyreModelRunner, self.model_runner) model_runner.tkv = 0 - def _get_num_blocks_available(self) -> int: - """Function returns the number of available blocks/pages. 
- Will eventually contain a function in torch_sendnn which reads - the actual value provided by the compiler for backend sendnn""" - - max_batch_size = \ - self.model_runner.vllm_config.scheduler_config.max_num_seqs - max_model_len = \ - self.model_runner.vllm_config.scheduler_config.max_model_len - block_size = self.model_runner.block_size # type: ignore[union-attr] - - min_req_num_blocks = max_model_len // block_size - # min_req_num_blocks is not enough blocks for the following test: - # tests/e2e/test_spyre_cb.py::test_scheduler_cb_steps_tkv - # [seqs_max_tokens4-prompts_lengths4-steps_add_reqs4- - # checked_steps4-256-False-2-eager-llama-194m] - - if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn': - # TODO: replace num_blocks_spyre by calling a function in - # torch_sendnn which returns the value set by the Spyre compiler - num_blocks_spyre = max_batch_size * min_req_num_blocks - assert num_blocks_spyre >= min_req_num_blocks, ( - "Number of pages available on Spyre (%d) is not enough to " - "serve the current model (need at least %d pages)." % - (num_blocks_spyre, min_req_num_blocks)) - max_concurrency_spyre = num_blocks_spyre * block_size \ - / max_model_len - logger.info("Spyre KV cache size: %s tokens", - num_blocks_spyre * block_size) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - str(max_model_len), max_concurrency_spyre) - return num_blocks_spyre - else: # dynamo backend 'eager' - num_blocks_cpu = max_batch_size * min_req_num_blocks - assert num_blocks_cpu >= min_req_num_blocks, ( - "Number of pages available on CPU (%d) is not enough to " - "serve the current model (need at least %d pages)." % - (num_blocks_cpu, min_req_num_blocks)) - max_concurrency_cpu = num_blocks_cpu * block_size / max_model_len - logger.info("CPU KV cache size: %s tokens", - num_blocks_cpu * block_size) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - str(max_model_len), max_concurrency_cpu) - return num_blocks_cpu - def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, special_token_ids, batch_size): @@ -517,25 +443,24 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, warmup_tokens_tensor = valid_token_ids_tensor[torch.randint( 0, len(valid_token_ids_tensor), (batch_size, prompt_len))] - # TODO temporary until 'pooling_params' makes it to a release version - # in vllm - extra_kwargs: dict[str, Any] = {} - if "pooling_params" in NewRequestData.__dataclass_fields__: - extra_kwargs["pooling_params"] = None + sampling_params, pooling_params = None, None + if not self.is_pooling: + sampling_params = SamplingParams(max_tokens=num_decode_tokens) + else: + pooling_params = PoolingParams() # Set up dummy requests for prefill steps dummy_requests = [ - NewRequestData( - req_id="warmup", - prompt_token_ids=warmup_tokens_tensor[i].tolist(), - mm_inputs=[], - mm_hashes=[], - mm_positions=[], - sampling_params=SamplingParams(max_tokens=num_decode_tokens), - block_ids=[0], - num_computed_tokens=0, - lora_request=None, - **extra_kwargs) for i in range(batch_size) + NewRequestData(req_id="warmup", + prompt_token_ids=warmup_tokens_tensor[i].tolist(), + mm_inputs=[], + mm_hashes=[], + mm_positions=[], + sampling_params=sampling_params, + pooling_params=pooling_params, + block_ids=[0], + num_computed_tokens=0, + lora_request=None) for i in range(batch_size) ] # Set up dummy cached_requests for decode steps @@ -550,7 +475,6 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, cached_request_data = CachedRequestData( 
req_ids=req_ids, resumed_from_preemption=False, - new_token_ids=[[] for _ in range(len(dummy_requests))], new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, ) @@ -598,11 +522,76 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, batch_size=batch_size, max_tokens=num_decode_tokens, prompt_len=prompt_len) + compile_cache_str = 'enabled' if int( + os.getenv("TORCH_SENDNN_CACHE_ENABLE", "0")) else 'disabled' logger.info( "[WARMUP] Prompt length %d and max output tokens %d " - "finished in %.3fs", warmup_total_t, prompt_len, num_decode_tokens) + "finished in %.3fs (compilation cache %s)", prompt_len, + num_decode_tokens, warmup_total_t, compile_cache_str) maybe_override_signals_handler() + def _dynamic_warmup( + self, + dummy_requests: list[NewRequestData], + prompt_len: int, + batch_size: int, + ) -> None: + + assert ( + _inside_warmup_mode + ), "it looks like you are outside the warmup context for warmup" + + for i, req in enumerate(dummy_requests): + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[req], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={req.req_id: prompt_len}, + total_num_scheduled_tokens=prompt_len, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size) + + self.execute_model(scheduler_output) + + # one decode iteration across all sequences + req_ids = [] + new_block_ids = [] + num_computed_tokens = [] + for req in dummy_requests: + req_ids.append(req.req_id) + new_block_ids.append([req.block_ids]) + num_computed_tokens.append(prompt_len) + cached_request_data = CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=False, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_request_data, + num_scheduled_tokens={f"warmup-{i}": 1 + for i in range(batch_size)}, + total_num_scheduled_tokens=batch_size, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + logger.info("[WARMUP] Decode...") + self.execute_model(scheduler_output) + self._cleanup_model_runner(request=dummy_requests) + def _warmup_model_forward_pass( self, scheduler_output: SchedulerOutput, @@ -637,6 +626,9 @@ def do_metadata_broadcast(self) -> bool: def kv_cache(self) -> Optional[list[list[torch.Tensor]]]: return None + def get_supported_tasks(self) -> tuple[SupportedTask, ...]: + return self.model_runner.get_supported_tasks() + @SpyrePlatform.inference_mode() def execute_model( self, From fd699ef8b7c7e27f715393db4311c2245570ed57 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 30 Jul 2025 10:40:53 -0700 Subject: [PATCH 6/6] =?UTF-8?q?=F0=9F=90=9B=20whoops?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm_spyre/v1/worker/spyre_worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index 57e32e82e..313cce320 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -475,6 +475,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, 
num_decode_tokens, cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, + new_token_ids=[[] for _ in range(len(dummy_requests))], new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, ) @@ -570,6 +571,7 @@ def _dynamic_warmup( cached_request_data = CachedRequestData( req_ids=req_ids, resumed_from_preemption=False, + new_token_ids=[[] for _ in range(len(dummy_requests))], new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, )
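
Taken together, these patches converge on a warmup decode step that no longer samples a random placeholder token per sequence: each dummy request now contributes an empty new_token_ids list to CachedRequestData. The sketch below is a minimal, self-contained illustration of that pattern, not the real worker code. It uses a small dataclass stand-in instead of vllm.v1.core.sched.output.CachedRequestData, models only the fields visible in the hunks above, and the helper name build_warmup_decode_batch is invented for the example.

from dataclasses import dataclass


# Stand-in for vllm's CachedRequestData; only the fields used in the
# hunks above are modeled here.
@dataclass
class CachedRequestData:
    req_ids: list[str]
    resumed_from_preemption: bool
    new_token_ids: list[list[int]]
    new_block_ids: list[list[list[int]]]
    num_computed_tokens: list[int]


def build_warmup_decode_batch(dummy_requests: list[dict],
                              prompt_len: int) -> CachedRequestData:
    # One decode iteration across all warmup sequences: no placeholder
    # tokens are sampled; every request passes an empty new_token_ids list.
    req_ids: list[str] = []
    new_block_ids: list[list[list[int]]] = []
    num_computed_tokens: list[int] = []
    for req in dummy_requests:
        req_ids.append(req["req_id"])
        new_block_ids.append([req["block_ids"]])
        num_computed_tokens.append(prompt_len)
    return CachedRequestData(
        req_ids=req_ids,
        resumed_from_preemption=False,
        new_token_ids=[[] for _ in range(len(dummy_requests))],
        new_block_ids=new_block_ids,
        num_computed_tokens=num_computed_tokens,
    )


if __name__ == "__main__":
    reqs = [{"req_id": f"warmup-{i}", "block_ids": [0]} for i in range(4)]
    batch = build_warmup_decode_batch(reqs, prompt_len=64)
    assert batch.new_token_ids == [[], [], [], []]

In the actual worker this construction happens once per warmup pass (inside _dynamic_warmup for continuous batching and inside _warmup_spyre_fixed_size for static batching), which is why the final "whoops" patch restores the new_token_ids keyword in both call sites.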