[Feature] add DFlash Support #7162
@@ -90,18 +90,22 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
    if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
        rope_dim = int(rope_dim *
                       model_config.hf_text_config.partial_rotary_factor)
-    _cos = torch.ones(1,
-                      max_num_batched_tokens,
-                      1,
-                      rope_dim,
-                      dtype=dtype,
-                      device=device)
-    _sin = torch.zeros(1,
-                       max_num_batched_tokens,
-                       1,
-                       rope_dim,
-                       dtype=dtype,
-                       device=device)
+    if vllm_config.speculative_config.method == "dflash":
+        _cos = torch.ones(1, max_num_batched_tokens * 2, 1, rope_dim, dtype=dtype, device=device)
+        _sin = torch.zeros(1, max_num_batched_tokens * 2, 1, rope_dim, dtype=dtype, device=device)
+    else:
+        _cos = torch.ones(1,
+                          max_num_batched_tokens,
+                          1,
+                          rope_dim,
+                          dtype=dtype,
+                          device=device)
+        _sin = torch.zeros(1,
+                           max_num_batched_tokens,
+                           1,
+                           rope_dim,
+                           dtype=dtype,
+                           device=device)
Comment on lines +93 to +108 (Contributor):
The two branches duplicate the cos/sin tensor creation, and the new branch assumes speculative_config is always set. This can be collapsed into a single shape computation that also guards against a missing speculative config:

size_multiplier = 2 if vllm_config.speculative_config and vllm_config.speculative_config.method == "dflash" else 1
shape = (1, max_num_batched_tokens * size_multiplier, 1, rope_dim)
_cos = torch.ones(*shape, dtype=dtype, device=device)
_sin = torch.zeros(*shape, dtype=dtype, device=device)
def get_cos_and_sin_mla(positions, use_cache=False):
@@ -196,7 +196,6 @@ def load_model(self, model: nn.Module) -> None:
        draft_attn_layer_names = draft_attn_layer - target_attn_layer_names
        draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
        draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
-        assert len(draft_attn_layer_names) == 1
        self.attn_layer_names = list(sorted(draft_attn_layer_names))
Contributor:
The assertion len(draft_attn_layer_names) == 1 is removed unconditionally, which drops a useful sanity check for the existing methods. Consider keeping it for everything except DFlash:

if self.method != "dflash":
    assert len(draft_attn_layer_names) == 1, (
        f"Expected 1 draft attention layer for method '{self.method}', "
        f"but found {len(draft_attn_layer_names)}."
    )
self.attn_layer_names = list(sorted(draft_attn_layer_names))
        if supports_multimodal(model):
@@ -276,6 +275,8 @@ def load_model(self, model: nn.Module) -> None:
        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
+        if self.method == "dflash":
+            self.model.lm_head = target_language_model.lm_head
        if self.method == "eagle" and hasattr(model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            if supports_multimodal(model):
@@ -325,6 +326,16 @@ def dummy_run(self,
                  batch_descriptor=None,
                  dummy_compute_logits=lambda hidden_states: None,
                  is_profile=False):
+        # DFlash uses a different dummy_run path
+        if self.method == "dflash":
+            self._dummy_run_dflash(
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                is_profile=is_profile,
+            )
+            return
+
        # update global cos, sin
        update_cos_sin(self.positions[:num_tokens])
@@ -416,6 +427,199 @@ def dummy_run(self,
            self._update_full_graph_params(forward_context, num_tokens,
                                           multi_steps_attn_metadata)

+    @torch.inference_mode()
+    def _dummy_run_dflash(
+        self,
+        num_tokens: int,
+        num_tokens_across_dp: Optional[torch.Tensor] = None,
+        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        is_profile: bool = False
+    ):
+        num_tokens = num_tokens // 2
+        if not self.use_cuda_graph:
+            aclgraph_runtime_mode = CUDAGraphMode.NONE
+        num_iters_to_capture = 1
+        for _ in range(num_iters_to_capture):
+            (
+                num_input_tokens,
+                num_tokens_across_dp,
+                _,
+                _,
+            ) = self.runner._sync_metadata_across_dp(num_tokens,
+                                                     is_draft_model=True)
+            if self.use_cuda_graph and \
+                    num_input_tokens <= self.runner.cudagraph_batch_sizes[-1]:
+                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_input_tokens)
+                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
+            else:
+                cudagraph_runtime_mode = CUDAGraphMode.NONE
+
+            positions_len = 2 * num_input_tokens
+            update_cos_sin(self.positions[:positions_len])
+
+            with set_ascend_forward_context(
+                    None,
+                    self.vllm_config,
+                    num_tokens=num_input_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp,
+                    num_actual_tokens=0,
+                    in_profile_run=is_profile,
+                    aclgraph_runtime_mode=cudagraph_runtime_mode,
+                    is_draft_model=True):
+
+                input_ids = self.input_ids[:num_input_tokens]
+                positions = self.positions[:positions_len]
+                hidden_states = self.hidden_states[:num_input_tokens]
+
+                self.model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    inputs_embeds=None,
+                )
+
+    def _dflash_propose(
+        self,
+        # [num_tokens]
+        target_token_ids: torch.Tensor,
+        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
+        target_positions: torch.Tensor,
+        # [num_tokens, hidden_size]
+        target_hidden_states: torch.Tensor,
+        # [batch_size]
+        next_token_ids: torch.Tensor,
+        last_token_indices: torch.Tensor | None,
+        common_attn_metadata: CommonAttentionMetadata,
+    ) -> torch.Tensor:
+        batch_size = common_attn_metadata.num_reqs
+        assert self.runner is not None
+        target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
+        assert target_hidden_states.shape[-1] == self.hidden_size
+
+        if last_token_indices is None:
+            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
+        query_start_loc = common_attn_metadata.query_start_loc[:batch_size + 1]
+        num_context_tokens = target_token_ids.shape[0]
+        num_query_tokens = 1 + self.num_speculative_tokens
+        total_num_query_tokens = batch_size * num_query_tokens
+        num_kv_tokens = num_context_tokens + total_num_query_tokens
+
+        MASK_TOKEN_ID = 151669
+        query_positions_list = []
+        for i in range(batch_size):
+            last_position = target_positions[last_token_indices[i]]
+            query_positions = (
+                torch.arange(
+                    num_query_tokens,
+                    device=target_positions.device,
+                    dtype=target_positions.dtype,
+                )
+                + 1
+                + last_position
+            )
+            query_positions_list.append(query_positions)
+        position_ids = torch.cat([target_positions] + query_positions_list)
+        assert position_ids.shape[0] == num_kv_tokens
+
+        if self.attn_metadata_builder is None:
+            attn_metadata_builder = self._get_attention_metadata_builder()
+        else:
+            attn_metadata_builder = self.attn_metadata_builder
+        block_size = attn_metadata_builder.kv_cache_spec.block_size
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        query_slot_mapping_list = []
+        for i in range(batch_size):
+            query_position_ids_i = query_positions_list[i]
+            block_numbers_i = query_position_ids_i // block_size
+            block_ids_i = block_table_tensor[i].gather(
+                dim=0, index=block_numbers_i.long()
+            )
+            slot_mapping_i = (
+                block_ids_i * block_size + query_position_ids_i % block_size
+            ).to(torch.int32)
+            query_slot_mapping_list.append(slot_mapping_i)
+        common_attn_metadata.slot_mapping = torch.cat(
+            [common_attn_metadata.slot_mapping[:target_hidden_states.shape[0]]] + query_slot_mapping_list)
+        common_attn_metadata.num_actual_tokens = num_kv_tokens
+        common_attn_metadata.max_query_len = num_kv_tokens
+        common_attn_metadata.query_start_loc = (
+            self.arange[:batch_size + 1] * num_query_tokens
+        )
+        common_attn_metadata.query_start_loc_cpu = (
+            torch.from_numpy(self.token_arange_np[:batch_size + 1]).clone()
+            * num_query_tokens
+        )
+        common_attn_metadata.max_seq_len += num_query_tokens
+        common_attn_metadata.seq_lens[:batch_size] = (
+            common_attn_metadata.seq_lens[:batch_size] + num_query_tokens
+        )
+        common_attn_metadata.seq_lens_cpu = (
+            common_attn_metadata.seq_lens.cpu()
+        )
+        common_attn_metadata.causal = False
+        common_attn_metadata.attn_mask = None
+        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
+
+        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        attn_metadata = builder.build(0, common_attn_metadata,
+                                      self.runner.get_model())
+        per_layer_attn_metadata = {}
+        for layer_name in self.attn_layer_names:
+            per_layer_attn_metadata[layer_name] = attn_metadata
+
+        self.positions[:num_kv_tokens] = position_ids
+        self.input_ids[:total_num_query_tokens] = MASK_TOKEN_ID
+        for i in range(batch_size):
+            self.input_ids[i * num_query_tokens] = next_token_ids[i]
+        self.hidden_states[:num_context_tokens] = target_hidden_states
+
+        update_cos_sin(position_ids)
+
+        if self.use_cuda_graph and total_num_query_tokens <= self.runner.cudagraph_batch_sizes[-1]:
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
+                total_num_query_tokens
+            )
+            aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
+        else:
+            num_input_tokens = total_num_query_tokens
+            aclgraph_runtime_mode = CUDAGraphMode.NONE
+
+        (
+            num_input_tokens,
+            num_tokens_across_dp,
+            _,
+            _,
+        ) = self.runner._sync_metadata_across_dp(
+            num_input_tokens, is_draft_model=True
+        )
+
+        with set_ascend_forward_context(
+                per_layer_attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_actual_tokens=num_kv_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=CUDAGraphMode.NONE,
+                is_draft_model=True):
+
+            ret_hidden_states = self.model(
+                input_ids=self.input_ids[:num_input_tokens],
+                positions=self.positions[:num_kv_tokens],
+                hidden_states=self.hidden_states[:num_context_tokens],
+                inputs_embeds=None,
+            )
+
+        valid_hidden_list = []
+        for i in range(batch_size):
+            start = i * num_query_tokens + 1
+            end = (i + 1) * num_query_tokens
+            valid_hidden_list.append(ret_hidden_states[start:end])
+        valid_hidden_states = torch.cat(valid_hidden_list, dim=0)
+        logits = self.model.compute_logits(valid_hidden_states)
+        draft_token_ids = logits.argmax(dim=-1)
+        return draft_token_ids.view(batch_size, self.num_speculative_tokens)
+
    def _propose(
        self,
        # [num_tokens]
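To make the query construction in _dflash_propose easier to follow, here is a small standalone sketch of the layout it builds for each request: a block of 1 + num_speculative_tokens drafted slots whose first slot holds the accepted next token, whose remaining slots hold the mask token, and whose positions continue from that request's last context position. The concrete numbers below are made up for illustration and are not part of the PR.

# Standalone illustration of the per-request query layout built by _dflash_propose
# (simplified: no KV cache, slot mapping, or attention metadata).
import torch

MASK_TOKEN_ID = 151669          # hard-coded draft mask token id from the PR
num_speculative_tokens = 3      # example value
num_query_tokens = 1 + num_speculative_tokens

next_token_ids = torch.tensor([11, 22])   # one accepted token per request
last_positions = torch.tensor([6, 41])    # position of each request's last context token

# Each request contributes num_query_tokens drafted slots: the accepted next
# token followed by mask placeholders that the draft model fills in.
query_ids = torch.full((2, num_query_tokens), MASK_TOKEN_ID)
query_ids[:, 0] = next_token_ids

# Their positions continue right after each request's last context position,
# matching the torch.arange(num_query_tokens) + 1 + last_position pattern above.
query_positions = last_positions[:, None] + 1 + torch.arange(num_query_tokens)

print(query_ids)        # [[11, 151669, 151669, 151669], [22, 151669, 151669, 151669]]
print(query_positions)  # [[ 7,  8,  9, 10], [42, 43, 44, 45]]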
@@ -383,7 +383,7 @@ def _set_up_drafter(self):
        self.decode_token_per_req = 1 + spec_token_num
        if get_pp_group().is_last_rank:
            self.drafter = self._get_drafter()
-            if self.speculative_config.method == "eagle3":
+            if self.speculative_config.method in ("eagle3", "dflash"):
                assert isinstance(self.drafter, EagleProposer)
                self.use_aux_hidden_state_outputs = (
                    self.drafter.eagle3_use_aux_hidden_state)
@@ -1450,21 +1450,31 @@ def propose_draft_token_ids(
            else:
                target_hidden_states = hidden_states[token_indices]
            assert self.drafter is not None
-            draft_token_ids = self.drafter._propose(
-                target_token_ids=target_token_ids,
-                target_positions=target_positions,
-                target_hidden_states=target_hidden_states,
-                next_token_ids=next_token_ids,
-                last_token_indices=token_indices_to_sample,
-                common_attn_metadata=common_attn_metadata,
-                sampling_metadata=sampling_metadata,
-                req_scheduled_tokens=req_scheduled_tokens,
-                long_seq_metadata=long_seq_metadata,
-                num_prefill_reqs=num_prefill_reqs,
-                num_decode_reqs=num_decode_reqs,
-                scheduler_output=scheduler_output,
-                num_scheduled_tokens=num_scheduled_tokens,
-            )
+            if self.speculative_config.method == "dflash":
+                draft_token_ids = self.drafter._dflash_propose(
+                    target_token_ids=target_token_ids,
+                    target_positions=target_positions,
+                    target_hidden_states=target_hidden_states,
+                    next_token_ids=next_token_ids,
+                    last_token_indices=token_indices_to_sample,
+                    common_attn_metadata=common_attn_metadata
+                )
+            else:
+                draft_token_ids = self.drafter._propose(
+                    target_token_ids=target_token_ids,
+                    target_positions=target_positions,
+                    target_hidden_states=target_hidden_states,
+                    next_token_ids=next_token_ids,
+                    last_token_indices=token_indices_to_sample,
+                    common_attn_metadata=common_attn_metadata,
+                    sampling_metadata=sampling_metadata,
+                    req_scheduled_tokens=req_scheduled_tokens,
+                    long_seq_metadata=long_seq_metadata,
+                    num_prefill_reqs=num_prefill_reqs,
+                    num_decode_reqs=num_decode_reqs,
+                    scheduler_output=scheduler_output,
+                    num_scheduled_tokens=num_scheduled_tokens,
+                )
Comment on lines +1453 to +1477 (Contributor):
There is significant code duplication in the two branches; the shared arguments can be collected into a common dictionary:

propose_kwargs = {
    "target_token_ids": target_token_ids,
    "target_positions": target_positions,
    "target_hidden_states": target_hidden_states,
    "next_token_ids": next_token_ids,
    "last_token_indices": token_indices_to_sample,
    "common_attn_metadata": common_attn_metadata,
}
if self.speculative_config.method == "dflash":
    draft_token_ids = self.drafter._dflash_propose(**propose_kwargs)
else:
    propose_kwargs.update({
        "sampling_metadata": sampling_metadata,
        "req_scheduled_tokens": req_scheduled_tokens,
        "long_seq_metadata": long_seq_metadata,
        "num_prefill_reqs": num_prefill_reqs,
        "num_decode_reqs": num_decode_reqs,
        "scheduler_output": scheduler_output,
        "num_scheduled_tokens": num_scheduled_tokens,
    })
    draft_token_ids = self.drafter._propose(**propose_kwargs)
        else:
            raise ValueError("Unknown speculative decoding method: "
@@ -2387,7 +2397,7 @@ def load_model(self) -> None:
            self.drafter.load_model(self.model)
            if self.use_aux_hidden_state_outputs:
                self.model.set_aux_hidden_state_layers(
-                    self.model.get_eagle3_aux_hidden_state_layers())
+                    self.model.get_eagle3_aux_hidden_state_layers(self.drafter.method))

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.vllm_config,
There's significant code duplication between the if and else branches. Most arguments to torch_npu.npu_fused_infer_attention_score are the same. This can be refactored to improve readability and maintainability by creating a common dictionary of arguments.
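As a rough illustration of the refactor suggested here, the shared arguments can be gathered into one dictionary and unpacked in both calls. The snippet below is only a sketch: the keyword names (num_heads, input_layout, atten_mask, sparse_mode) and the branch condition are placeholders standing in for whatever the real call sites pass, not the exact arguments used in this PR.

# Sketch only. Keyword names and the branch condition are placeholders; the
# point is the pattern: one dict for the shared arguments, per-branch extras on top.
common_kwargs = dict(
    num_heads=self.num_heads,
    input_layout="BSH",
    atten_mask=attn_mask,
)
if is_prefill:  # placeholder for the real branch condition
    attn_output = torch_npu.npu_fused_infer_attention_score(
        query, key, value, sparse_mode=2, **common_kwargs)
else:
    attn_output = torch_npu.npu_fused_infer_attention_score(
        query, key, value, **common_kwargs)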