-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[0.13.0][Bugfix] Add synced_cudagraph_mode to limit mixed graph modes in dp ranks
#6011
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -426,23 +426,26 @@ def needs_mc2(num_tokens: int) -> bool: | |
| or self.ascend_config.recompute_scheduler_enable) | ||
|
|
||
| def _sync_metadata_across_dp( | ||
| self, num_tokens: int, | ||
| with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]: | ||
| self, | ||
| num_tokens: int, | ||
| with_prefill: bool, | ||
| cudagraph_mode: int = 0, | ||
| ) -> tuple[int, Optional[torch.Tensor], bool, int]: | ||
| # TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in | ||
| # our case, we still need to sync the other two flags as well. So we need to | ||
| # include them in the all_reduce operation, and moreover, we CANNOT skip it | ||
| # even if we are running in eager mode, which harms performance. | ||
| # FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here | ||
| # immediately once the other two flags are no longer needed. | ||
| if self.dp_size == 1: | ||
| return num_tokens, None, with_prefill | ||
| return num_tokens, None, with_prefill, cudagraph_mode | ||
|
|
||
| if self._skip_all_reduce_across_dp_group(): | ||
| num_tokens_after_padding = torch.tensor([num_tokens] * | ||
| self.dp_size, | ||
| device="cpu", | ||
| dtype=torch.int32) | ||
| return num_tokens, num_tokens_after_padding, with_prefill | ||
| return num_tokens, num_tokens_after_padding, with_prefill, cudagraph_mode | ||
|
Comment on lines
443
to
+448
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @jianzs Hi, as you mentioned in #5979, when we skip |
||
|
|
||
| # Sync num_tokens, with_prefill across dp ranks | ||
| num_tokens_tensor = torch.tensor([ | ||
|
|
@@ -455,24 +458,34 @@ def _sync_metadata_across_dp( | |
| dtype=torch.int32, | ||
| device="cpu") | ||
|
|
||
| packed_tensor = torch.cat([num_tokens_tensor, flags_tensor]) | ||
| cudagraph_mode_tensor = torch.tensor([ | ||
| cudagraph_mode if i == self.dp_rank else 0 | ||
| for i in range(self.dp_size) | ||
| ], | ||
| dtype=torch.int32, | ||
| device="cpu") | ||
|
|
||
| packed_tensor = torch.cat( | ||
| [num_tokens_tensor, flags_tensor, cudagraph_mode_tensor]) | ||
| # use cpu_group to avoid cpu synchronization issue. | ||
| # it can be overlapped with main model execution on npu. | ||
| dist.all_reduce(packed_tensor, group=get_dp_group().cpu_group) | ||
|
|
||
| # Unpack the results | ||
| num_tokens_across_dp = packed_tensor[:-1] | ||
| synced_flags = packed_tensor[-1:] | ||
| num_tokens_across_dp = packed_tensor[:self.dp_size] | ||
| synced_flags = packed_tensor[self.dp_size:self.dp_size + 1] | ||
| cudagraph_mode_across_dp = packed_tensor[self.dp_size + 1:] | ||
| max_tokens_across_dp = torch.max(num_tokens_across_dp).item() | ||
| global_with_prefill = bool(synced_flags[0]) | ||
| synced_cudagraph_mode = torch.min(cudagraph_mode_across_dp).item() | ||
|
|
||
| # Create a tensor for num_tokens_after_padding | ||
| num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * | ||
| self.dp_size, | ||
| device="cpu", | ||
| dtype=torch.int32) | ||
|
|
||
| return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill | ||
| return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, synced_cudagraph_mode | ||
|
|
||
| def get_model(self) -> nn.Module: | ||
| # get raw model out of the aclgraph wrapper. | ||
|
|
@@ -486,8 +499,8 @@ def _prepare_inputs( | |
| intermediate_tensors: Optional[IntermediateTensors] = None, | ||
| ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor, | ||
| int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], | ||
| Optional[torch.Tensor], Optional[torch.Tensor], int, dict[str, | ||
| Any]]: | ||
| Optional[torch.Tensor], Optional[torch.Tensor], int, int, dict[ | ||
| str, Any]]: | ||
| total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens | ||
| assert total_num_scheduled_tokens > 0 | ||
| num_reqs = self.input_batch.num_reqs | ||
|
|
@@ -567,30 +580,31 @@ def _prepare_inputs( | |
| out=positions_np, | ||
| ) | ||
| max_num_scheduled_tokens = max(tokens) | ||
| if (self.use_aclgraph and total_num_scheduled_tokens | ||
| <= self.cudagraph_batch_sizes[-1]): | ||
| # Add padding to the batch size. | ||
| num_input_tokens = self.vllm_config.pad_for_cudagraph( | ||
| total_num_scheduled_tokens) | ||
| elif enable_sp(self.vllm_config): | ||
| # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size, | ||
| # the model will fall back to running its FX graph in eager mode. | ||
| # In this case, when sequence parallelism is enabled, we need to pad tokens to align | ||
| # with tp_size because pad_size cannot be captured by the FX graph | ||
| uniform_decode = (max_num_scheduled_tokens == self.uniform_decode_query_len) \ | ||
| and (total_num_scheduled_tokens == max_num_scheduled_tokens * num_reqs) | ||
| has_lora = len(self.input_batch.lora_id_to_lora_request) > 0 | ||
| # the following process is corresponding to _pad_for_sequence_parallelism | ||
| # in gpu_model_runner | ||
| if enable_sp(self.vllm_config): | ||
| tp_size = self.vllm_config.parallel_config.tensor_parallel_size | ||
| num_input_tokens = math.ceil( | ||
| total_num_scheduled_tokens / tp_size) * tp_size | ||
| else: | ||
| # Eager mode. | ||
| num_input_tokens = total_num_scheduled_tokens | ||
| cudagraph_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( | ||
| num_tokens=num_input_tokens, | ||
| uniform_decode=uniform_decode, | ||
| has_lora=has_lora, | ||
| ) | ||
| num_input_tokens = batch_descriptor.num_tokens | ||
| self.query_lens = torch.from_numpy(num_scheduled_tokens) | ||
|
|
||
| # Get info across DP ranks. | ||
| # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, | ||
| # Otherwise, it's just max_tokens_across_dp_cpu | ||
| (maybe_padded_num_tokens, num_tokens_across_dp, | ||
| with_prefill) = self._sync_metadata_across_dp(num_input_tokens, | ||
| with_prefill) | ||
| (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, | ||
| synced_cudagraph_mode) = self._sync_metadata_across_dp( | ||
| num_input_tokens, with_prefill, cudagraph_mode.value) | ||
| self.with_prefill = with_prefill | ||
| # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens | ||
| # We should consider removing maybe_padded_num_tokens later | ||
|
|
@@ -953,7 +967,7 @@ def _prepare_inputs( | |
| # TODO: We should make this official ASAP. Also note that if we pad here, | ||
| # the builders won’t need to add any extra padding. | ||
| if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ | ||
| uniform_decode: | ||
| uniform_decode and synced_cudagraph_mode == CUDAGraphMode.FULL.value: | ||
| max_decode_tokens = min( | ||
| self.scheduler_config.max_num_seqs * | ||
| self.uniform_decode_query_len, | ||
|
|
@@ -1078,7 +1092,7 @@ def _prepare_inputs( | |
| num_input_tokens, num_tokens_across_dp, | ||
| maybe_padded_num_tokens, logits_indices, spec_decode_metadata, | ||
| input_ids, inputs_embeds, intermediate_tensors, | ||
| max_num_scheduled_tokens, model_kwargs) | ||
| max_num_scheduled_tokens, synced_cudagraph_mode, model_kwargs) | ||
|
|
||
| # all-gather one hidden-states in sp scene | ||
| @staticmethod | ||
|
|
@@ -1475,7 +1489,7 @@ def execute_model( | |
| (attn_metadata, positions, num_scheduled_tokens_np, | ||
| num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, | ||
| logits_indices, spec_decode_metadata, input_ids, inputs_embeds, | ||
| intermediate_tensors, max_query_len, | ||
| intermediate_tensors, max_query_len, synced_cudagraph_mode, | ||
| model_kwargs) = (self._prepare_inputs(scheduler_output, | ||
| intermediate_tensors)) | ||
|
|
||
|
|
@@ -1498,7 +1512,9 @@ def execute_model( | |
| == self.input_batch.num_reqs * max_query_len) | ||
| has_lora = len(self.input_batch.lora_id_to_lora_request) > 0 | ||
| aclgraph_runtime_mode, batch_descriptor = \ | ||
| self.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora) | ||
| self.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora, | ||
| disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value) | ||
| num_input_tokens = batch_descriptor.num_tokens | ||
|
|
||
| if self.ascend_config.enable_async_exponential: | ||
| self.sampler.do_async_exponential( | ||
|
|
@@ -2078,9 +2094,9 @@ def _dummy_run( | |
| self.cudagraph_dispatcher.dispatch(num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora) | ||
|
|
||
| # Padding for DP | ||
| (num_tokens, num_tokens_across_dp, | ||
| with_prefill) = self._sync_metadata_across_dp( | ||
| batch_descriptor.num_tokens, with_prefill) | ||
| (num_tokens, num_tokens_across_dp, with_prefill, | ||
| synced_cudagraph_mode) = self._sync_metadata_across_dp( | ||
| batch_descriptor.num_tokens, with_prefill, _ag_mode.value) | ||
|
|
||
| # If cudagraph_mode.decode_mode() == FULL and | ||
| # cudagraph_mode.seperate_routine(). This means that we are using | ||
|
|
@@ -2127,11 +2143,13 @@ def _dummy_run( | |
| if not is_profile and self.dynamic_eplb: | ||
| self.eplb_updator.forward_before() | ||
|
|
||
| if num_tokens != batch_descriptor.num_tokens: | ||
| if num_tokens_across_dp is not None: | ||
| _ag_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( | ||
| num_tokens=num_tokens, | ||
| uniform_decode=uniform_decode, | ||
| has_lora=has_lora) | ||
| has_lora=has_lora, | ||
| disable_full=synced_cudagraph_mode | ||
| <= CUDAGraphMode.PIECEWISE.value) | ||
|
|
||
| num_tokens_padded = batch_descriptor.num_tokens | ||
| num_reqs_padded = (batch_descriptor.num_reqs if | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When `_skip_all_reduce_across_dp_group()` is true, the `all_reduce` operation for syncing metadata is skipped. However, this also skips syncing `cudagraph_mode`, returning the local `cudagraph_mode` instead. This could lead to different ranks operating in different CUDAGraph modes, which is the exact issue this pull request aims to fix and could cause hangs.
The `cudagraph_mode` should be synced across all DP ranks regardless of whether other metadata syncing is skipped.