diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 35dd617eeba0..34ddd17f8620 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -57,10 +57,7 @@ from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager -from vllm.v1.worker.gpu.dp_utils import ( - get_cudagraph_and_dp_padding, - make_num_tokens_across_dp, -) +from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding from vllm.v1.worker.gpu.input_batch import ( InputBatch, InputBuffers, @@ -265,7 +262,7 @@ def load_model(self, *args, **kwargs) -> None: prepare_communication_buffer_for_model(self.model) if self.speculator is not None: - prepare_communication_buffer_for_model(self.speculator) + prepare_communication_buffer_for_model(self.speculator.model) # Initialize the components that require the model. self.model_state = init_model_state( @@ -382,8 +379,41 @@ def _dummy_run( return None, None assert self.execute_model_state is not None - input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state + ( + input_batch, + model_inputs, + attn_metadata, + slot_mappings_by_layer, + hidden_states, + aux_hidden_states, + kv_connector_output, + num_tokens_across_dp, + ) = self.execute_model_state self.execute_model_state = None + + # dummy run the eagle speculator's propose to ensure DP/EP sync. + if self.speculator is not None: + self.speculator.propose( + input_batch=input_batch, + attn_metadata=attn_metadata, + slot_mappings=slot_mappings_by_layer, + last_hidden_states=hidden_states, + aux_hidden_states=aux_hidden_states, + num_sampled=torch.ones( + input_batch.num_reqs, dtype=torch.int32, device=self.device + ), + num_rejected=torch.zeros( + input_batch.num_reqs, dtype=torch.int32, device=self.device + ), + last_sampled=self.req_states.last_sampled_tokens, + next_prefill_tokens=self.req_states.next_prefill_tokens, + temperature=self.sampler.sampling_states.temperature.gpu, + seeds=self.sampler.sampling_states.seeds.gpu, + num_tokens_across_dp=num_tokens_across_dp, + dummy_run=True, + skip_attn_for_dummy_run=skip_attn, + ) + assert hidden_states is not None # Last PP rank always has hidden_states sample_hidden_states = hidden_states[input_batch.logits_indices] return hidden_states, sample_hidden_states @@ -431,17 +461,6 @@ def profile_run(self) -> None: else: self._dummy_pooler_run(hidden_states) - if self.speculator is not None: - num_tokens_across_dp = make_num_tokens_across_dp( - self.parallel_config.data_parallel_size, self.max_num_tokens - ) - self.speculator.run_model( - self.max_num_tokens, - attn_metadata=None, - slot_mappings=None, - num_tokens_across_dp=num_tokens_across_dp, - ) - torch.cuda.synchronize() del hidden_states, sample_hidden_states gc.collect() @@ -977,6 +996,7 @@ def execute_model( hidden_states, aux_hidden_states, kv_connector_output, + num_tokens_across_dp, ) if not self.is_last_pp_rank: @@ -1003,6 +1023,7 @@ def sample_tokens( hidden_states, aux_hidden_states, kv_connector_output, + num_tokens_across_dp, ) = self.execute_model_state self.execute_model_state = None @@ -1076,6 +1097,7 @@ def sample_tokens( self.req_states.next_prefill_tokens, self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.seeds.gpu, + num_tokens_across_dp=num_tokens_across_dp, ) self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py index 77dddf3ada1c..157ed1182485 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py @@ -55,6 +55,26 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): def get_cudagraph_size(self, num_tokens: int) -> int | None: return self.cudagraph_sizes.get(num_tokens) + def get_cudagraph_runtime_mode( + self, num_tokens: int + ) -> tuple[CUDAGraphMode, int | None]: + cudagraph_size = self.get_cudagraph_size(num_tokens) + if cudagraph_size is None: + cudagraph_mode = CUDAGraphMode.NONE + else: + cudagraph_mode = self.cudagraph_mode + + if ( + cudagraph_mode == CUDAGraphMode.FULL + and cudagraph_size is not None + and cudagraph_size not in self.graphs + ): + # If graph wasn't captured yet, fall back to eager. + # This might happen when the dummy run is called before capture. + cudagraph_mode = CUDAGraphMode.NONE + cudagraph_size = None + return cudagraph_mode, cudagraph_size + def capture_graph( self, num_tokens: int, diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 9ea84386bdce..9185850dcb62 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -16,6 +16,7 @@ build_slot_mappings_by_layer, ) from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample @@ -48,6 +49,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.vocab_size = self.draft_model_config.get_vocab_size() self.dtype = vllm_config.model_config.dtype + # DP configuration + self.dp_size = vllm_config.parallel_config.data_parallel_size + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.input_buffers = InputBuffers( max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, @@ -122,8 +127,8 @@ def generate_draft( self, num_reqs: int, num_tokens_padded: int, - attn_metadata: dict[str, Any], - slot_mappings: dict[str, torch.Tensor], + attn_metadata: dict[str, Any] | None, + slot_mappings: dict[str, torch.Tensor] | None, num_tokens_across_dp: torch.Tensor | None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, ) -> None: @@ -164,9 +169,10 @@ def generate_draft( self.hidden_states, self.max_model_len, ) - self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, pos - ) + if attn_metadata is not None: + self.block_tables.compute_slot_mappings( + idx_mapping, query_start_loc, pos + ) def capture_model(self) -> None: if self.num_speculative_steps == 1: @@ -203,6 +209,9 @@ def propose( temperature: torch.Tensor, # [max_num_reqs] seeds: torch.Tensor, + num_tokens_across_dp: torch.Tensor | None = None, + dummy_run: bool = False, + skip_attn_for_dummy_run: bool = False, ) -> torch.Tensor: # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # number of rejected tokens, we maintain the size of eagle's input_ids and @@ -236,7 +245,7 @@ def propose( num_tokens, attn_metadata, slot_mappings, - num_tokens_across_dp=None, # FIXME + num_tokens_across_dp=num_tokens_across_dp, ) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -282,48 +291,64 @@ def propose( self.max_model_len, self.max_num_reqs, ) - query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] - slot_mappings = self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, pos - ) - cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) - cudagraph_mode = self.cudagraph_manager.cudagraph_mode - if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL: + if not (dummy_run and skip_attn_for_dummy_run): + query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] + slot_mappings = self.block_tables.compute_slot_mappings( + idx_mapping, query_start_loc, pos + ) + + cudagraph_mode, cudagraph_size = ( + self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs) + ) + num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = ( + get_cudagraph_and_dp_padding( + num_reqs, + cudagraph_size, + cudagraph_mode.value, + self.dp_size, + self.dp_rank, + ) + ) + cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode) + if cudagraph_mode == CUDAGraphMode.FULL: # Run full CUDA graph. - self.cudagraph_manager.run_fullgraph(cudagraph_size) + self.cudagraph_manager.run_fullgraph(num_tokens_padded) return self.draft_tokens[:num_reqs] # Run eager or piecewise CUDA graph. - num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs - query_start_loc_cpu = torch.arange( - num_reqs + 1, dtype=torch.int32, device="cpu" - ) - block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] - - # FIXME(woosuk): This is UNSAFE!! - attn_metadata = build_attn_metadata( - attn_groups=self.attn_groups, - num_reqs=num_reqs, - num_tokens=num_reqs, - query_start_loc_gpu=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - max_query_len=1, - seq_lens=self.input_buffers.seq_lens[:num_reqs], - max_seq_len=self.max_model_len, - block_tables=block_tables, - slot_mappings=slot_mappings, - kv_cache_config=self.kv_cache_config, - ) - slot_mappings_by_layer = build_slot_mappings_by_layer( - slot_mappings, self.kv_cache_config - ) + attn_metadata_updated = None + slot_mappings_updated = None + if not (dummy_run and skip_attn_for_dummy_run): + query_start_loc_cpu = torch.arange( + num_reqs + 1, dtype=torch.int32, device="cpu" + ) + block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] + + # FIXME(woosuk): This is UNSAFE!! + attn_metadata_updated = build_attn_metadata( + attn_groups=self.attn_groups, + num_reqs=num_reqs, + num_tokens=num_reqs, + query_start_loc_gpu=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + max_query_len=1, + seq_lens=self.input_buffers.seq_lens[:num_reqs], + max_seq_len=self.max_model_len, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + slot_mappings_updated = build_slot_mappings_by_layer( + slot_mappings, self.kv_cache_config + ) + self.generate_draft( num_reqs, num_tokens_padded, - attn_metadata, - slot_mappings_by_layer, - num_tokens_across_dp=None, # FIXME + attn_metadata_updated, + slot_mappings_updated, + num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_mode, ) return self.draft_tokens[:num_reqs]