-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[Misc] Tidy up some spec decode logic in GPUModelRunner #31591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b225a73
f6229e0
99922e5
d3fd6d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -452,6 +452,11 @@ def __init__( | |
| self.num_spec_tokens = 0 | ||
| if self.speculative_config: | ||
| self.num_spec_tokens = self.speculative_config.num_speculative_tokens | ||
| draft_config = self.speculative_config.draft_model_config | ||
| if draft_config is not None and draft_config.max_model_len is not None: | ||
| self.effective_drafter_max_model_len = draft_config.max_model_len | ||
| else: | ||
| self.effective_drafter_max_model_len = self.max_model_len | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thoughts on making this `min(self.max_model_len, draft_max_model_len)`? We have been seeing some logs where the drafter has a very high max model len even when the base model doesn't. Also, if you do this clamping, you can move it into a helper fn to share the logic with the update function below.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I was just aiming to keep the existing logic; I'm not sure what makes the most sense — I would defer to your judgement. |
||
|
|
||
| # Request states. | ||
| self.requests: dict[str, CachedRequestState] = {} | ||
|
|
@@ -674,6 +679,13 @@ def __init__( | |
| self.kv_connector_output: KVConnectorOutput | None = None | ||
| self.layerwise_nvtx_hooks_registered = False | ||
|
|
||
| def update_max_model_len(self, max_model_len: int) -> None: | ||
| self.max_model_len = max_model_len | ||
| if self.speculative_config: | ||
| draft_config = self.speculative_config.draft_model_config | ||
| if draft_config is None or draft_config.max_model_len is None: | ||
| self.effective_drafter_max_model_len = self.max_model_len | ||
|
|
||
| def reset_mm_cache(self) -> None: | ||
| if self.mm_budget: | ||
| self.mm_budget.reset_cache() | ||
|
|
@@ -3423,54 +3435,41 @@ def propose_draft_token_ids(sampled_token_ids): | |
| self._copy_draft_token_ids_to_cpu(scheduler_output) | ||
|
|
||
| spec_config = self.speculative_config | ||
| use_padded_batch_for_eagle = ( | ||
| spec_config is not None | ||
| and spec_config.use_eagle() | ||
| and not spec_config.disable_padded_drafter_batch | ||
| ) | ||
| effective_drafter_max_model_len = self.max_model_len | ||
| if effective_drafter_max_model_len is None: | ||
| effective_drafter_max_model_len = self.model_config.max_model_len | ||
| if ( | ||
| spec_config is not None | ||
| and spec_config.draft_model_config is not None | ||
| and spec_config.draft_model_config.max_model_len is not None | ||
| ): | ||
| effective_drafter_max_model_len = ( | ||
| spec_config.draft_model_config.max_model_len | ||
| propose_drafts_after_bookkeeping = False | ||
| if spec_config is not None: | ||
| input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( | ||
| spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens | ||
| <= self.effective_drafter_max_model_len | ||
| ) | ||
| input_fits_in_drafter = spec_decode_common_attn_metadata and ( | ||
| spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens | ||
| <= effective_drafter_max_model_len | ||
| ) | ||
| if use_padded_batch_for_eagle: | ||
| assert self.speculative_config is not None | ||
| assert isinstance(self.drafter, EagleProposer) | ||
| sampled_token_ids = sampler_output.sampled_token_ids | ||
| if input_fits_in_drafter: | ||
| if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch: | ||
| # EAGLE speculative decoding can use the GPU sampled tokens | ||
| # as inputs, and does not need to wait for bookkeeping to finish. | ||
| propose_draft_token_ids(sampled_token_ids) | ||
| elif self.valid_sampled_token_count_event is not None: | ||
| assert spec_decode_common_attn_metadata is not None | ||
| next_token_ids, valid_sampled_tokens_count = ( | ||
| self.drafter.prepare_next_token_ids_padded( | ||
| spec_decode_common_attn_metadata, | ||
| sampled_token_ids, | ||
| self.requests, | ||
| self.input_batch, | ||
| self.discard_request_mask.gpu, | ||
| assert isinstance(self.drafter, EagleProposer) | ||
| sampled_token_ids = sampler_output.sampled_token_ids | ||
| if input_fits_in_drafter: | ||
| propose_draft_token_ids(sampled_token_ids) | ||
| elif self.valid_sampled_token_count_event is not None: | ||
| assert spec_decode_common_attn_metadata is not None | ||
| next_token_ids, valid_sampled_tokens_count = ( | ||
| self.drafter.prepare_next_token_ids_padded( | ||
| spec_decode_common_attn_metadata, | ||
| sampled_token_ids, | ||
| self.requests, | ||
| self.input_batch, | ||
| self.discard_request_mask.gpu, | ||
| ) | ||
| ) | ||
| ) | ||
| self._copy_valid_sampled_token_count( | ||
| next_token_ids, valid_sampled_tokens_count | ||
| ) | ||
| # Since we couldn't run the drafter, | ||
| # just use zeros for the draft tokens. | ||
| self._draft_token_ids = torch.zeros( | ||
| 1, device=self.device, dtype=torch.int32 | ||
| ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) | ||
| self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) | ||
| self._copy_valid_sampled_token_count( | ||
| next_token_ids, valid_sampled_tokens_count | ||
| ) | ||
| # Since we couldn't run the drafter, | ||
| # just use zeros for the draft tokens. | ||
| self._draft_token_ids = torch.zeros( | ||
| 1, device=self.device, dtype=torch.int32 | ||
| ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) | ||
| self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) | ||
| else: | ||
| propose_drafts_after_bookkeeping = input_fits_in_drafter | ||
|
|
||
| with record_function_or_nullcontext("gpu_model_runner: bookkeep"): | ||
| ( | ||
|
|
@@ -3490,17 +3489,14 @@ def propose_draft_token_ids(sampled_token_ids): | |
| spec_decode_metadata, | ||
| ) | ||
|
|
||
| if ( | ||
| self.speculative_config | ||
| and not use_padded_batch_for_eagle | ||
| and input_fits_in_drafter | ||
| ): | ||
| if propose_drafts_after_bookkeeping: | ||
| # ngram and other speculative decoding methods use the sampled | ||
| # tokens on the CPU, so they are run after bookkeeping. | ||
| propose_draft_token_ids(valid_sampled_token_ids) | ||
|
|
||
| with record_function_or_nullcontext("gpu_model_runner: eplb"): | ||
| self.eplb_step() | ||
|
|
||
| with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): | ||
| output = ModelRunnerOutput( | ||
| req_ids=req_ids_output_copy, | ||
|
|
@@ -3518,6 +3514,7 @@ def propose_draft_token_ids(sampled_token_ids): | |
|
|
||
| if not self.use_async_scheduling: | ||
| return output | ||
|
|
||
| with record_function_or_nullcontext( | ||
| "gpu_model_runner: AsyncGPUModelRunnerOutput" | ||
| ): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you just set `self.draft_config` and shortcut the multiple None checks when we need to access it?