49 changes: 33 additions & 16 deletions vllm_ascend/attention/attention_v1.py
@@ -592,22 +592,39 @@ def forward_fused_infer_attention(self, query: torch.Tensor,
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER:
            key = key[:num_tokens]
            value = value[:num_tokens]
        # Get workspace from cache or calculate it if not present.
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )
        if not attn_metadata.causal:
            # for dflash
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query=query,
                key=key,
                value=value,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=0,
            )
        else:
            # Get workspace from cache or calculate it if not present.
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query=query,
                key=key,
                value=value,
                atten_mask=attn_metadata.attn_mask,
                block_table=block_table,
                input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )
Comment on lines +595 to +627

There's significant code duplication between the if and else branches. Most arguments to torch_npu.npu_fused_infer_attention_score are the same. This can be refactored to improve readability and maintainability by creating a common dictionary of arguments.

        common_kwargs = {
            "query": query,
            "key": key,
            "value": value,
            "block_table": block_table,
            "input_layout": "TND",
            "block_size": block_size,
            "actual_seq_lengths": attn_metadata.actual_seq_lengths_q,
            "actual_seq_lengths_kv": actual_seq_lengths_kv,
            "num_key_value_heads": self.num_kv_heads,
            "num_heads": self.num_heads,
            "scale": self.scale,
        }
        if not attn_metadata.causal:
            # for dflash
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                **common_kwargs,
                sparse_mode=0,
            )
        else:
            # Get workspace from cache or calculate it if not present.
            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                **common_kwargs,
                atten_mask=attn_metadata.attn_mask,
                sparse_mode=3,
            )


        attn_output = attn_output.view(num_tokens, self.num_heads,
                                       self.head_size)
28 changes: 16 additions & 12 deletions vllm_ascend/ops/rotary_embedding.py
@@ -90,18 +90,22 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
    if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
        rope_dim = int(rope_dim *
                       model_config.hf_text_config.partial_rotary_factor)
    _cos = torch.ones(1,
                      max_num_batched_tokens,
                      1,
                      rope_dim,
                      dtype=dtype,
                      device=device)
    _sin = torch.zeros(1,
                       max_num_batched_tokens,
                       1,
                       rope_dim,
                       dtype=dtype,
                       device=device)
    if vllm_config.speculative_config.method == "dflash":
        _cos = torch.ones(1, max_num_batched_tokens * 2, 1, rope_dim, dtype=dtype, device=device)
        _sin = torch.zeros(1, max_num_batched_tokens * 2, 1, rope_dim, dtype=dtype, device=device)
    else:
        _cos = torch.ones(1,
                          max_num_batched_tokens,
                          1,
                          rope_dim,
                          dtype=dtype,
                          device=device)
        _sin = torch.zeros(1,
                           max_num_batched_tokens,
                           1,
                           rope_dim,
                           dtype=dtype,
                           device=device)
Comment on lines +93 to +108

The if/else block for creating _cos and _sin tensors contains duplicated code. This can be simplified by using a multiplier to determine the size, making the code more concise and easier to maintain.

        size_multiplier = 2 if vllm_config.speculative_config and vllm_config.speculative_config.method == "dflash" else 1
        shape = (1, max_num_batched_tokens * size_multiplier, 1, rope_dim)
        _cos = torch.ones(*shape, dtype=dtype, device=device)
        _sin = torch.zeros(*shape, dtype=dtype, device=device)



def get_cos_and_sin_mla(positions, use_cache=False):
2 changes: 1 addition & 1 deletion vllm_ascend/spec_decode/__init__.py
@@ -25,7 +25,7 @@
def get_spec_decode_method(method, vllm_config, device, runner):
    if method == "ngram":
        return NgramProposer(vllm_config, device, runner)
    elif method in ("eagle", "eagle3"):
    elif method in ("eagle", "eagle3", "dflash"):
        return EagleProposer(vllm_config, device, runner)
    elif method == "mtp":
        return MtpProposer(vllm_config, device, runner)
206 changes: 205 additions & 1 deletion vllm_ascend/spec_decode/eagle_proposer.py
@@ -196,7 +196,6 @@ def load_model(self, model: nn.Module) -> None:
        draft_attn_layer_names = draft_attn_layer - target_attn_layer_names
        draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
        draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
        assert len(draft_attn_layer_names) == 1
        self.attn_layer_names = list(sorted(draft_attn_layer_names))

The assertion assert len(draft_attn_layer_names) == 1 was removed. While this might be necessary for DFlash, it removes a safety check for other methods like 'eagle' and 'eagle3'. This could lead to unexpected behavior if the draft model for those methods changes. It would be safer to make this check conditional to avoid potential regressions.

        if self.method != "dflash":
            assert len(draft_attn_layer_names) == 1, (
                f"Expected 1 draft attention layer for method '{self.method}', "
                f"but found {len(draft_attn_layer_names)}."
            )
        self.attn_layer_names = list(sorted(draft_attn_layer_names))


        if supports_multimodal(model):
@@ -276,6 +275,8 @@ def load_model(self, model: nn.Module) -> None:
        # share lm_head with the target model if needed
        # some model definition do not define lm_head explicitly
        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
        if self.method == "dflash":
            self.model.lm_head = target_language_model.lm_head
        if self.method == "eagle" and hasattr(model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            if supports_multimodal(model):
@@ -325,6 +326,16 @@ def dummy_run(self,
                  batch_descriptor=None,
                  dummy_compute_logits=lambda hidden_states: None,
                  is_profile=False):
        # DFlash uses a different dummy_run path
        if self.method == "dflash":
            self._dummy_run_dflash(
                num_tokens=num_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                is_profile=is_profile,
            )
            return

        # update global cos, sin
        update_cos_sin(self.positions[:num_tokens])

@@ -416,6 +427,199 @@ def dummy_run(self,
            self._update_full_graph_params(forward_context, num_tokens,
                                           multi_steps_attn_metadata)

    @torch.inference_mode()
    def _dummy_run_dflash(
        self,
        num_tokens: int,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        is_profile: bool = False
    ):
        num_tokens = num_tokens // 2
        if not self.use_cuda_graph:
            aclgraph_runtime_mode = CUDAGraphMode.NONE
        num_iters_to_capture = 1
        for _ in range(num_iters_to_capture):
            (
                num_input_tokens,
                num_tokens_across_dp,
                _,
                _,
            ) = self.runner._sync_metadata_across_dp(num_tokens,
                                                     is_draft_model=True)
            if self.use_cuda_graph and \
                    num_input_tokens <= self.runner.cudagraph_batch_sizes[-1]:
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_input_tokens)
                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
            else:
                cudagraph_runtime_mode = CUDAGraphMode.NONE

            positions_len = 2 * num_input_tokens
            update_cos_sin(self.positions[:positions_len])

            with set_ascend_forward_context(
                    None,
                    self.vllm_config,
                    num_tokens=num_input_tokens,
                    num_tokens_across_dp=num_tokens_across_dp,
                    num_actual_tokens=0,
                    in_profile_run=is_profile,
                    aclgraph_runtime_mode=cudagraph_runtime_mode,
                    is_draft_model=True):

                input_ids = self.input_ids[:num_input_tokens]
                positions = self.positions[:positions_len]
                hidden_states = self.hidden_states[:num_input_tokens]

                self.model(
                    input_ids=input_ids,
                    positions=positions,
                    hidden_states=hidden_states,
                    inputs_embeds=None,
                )

    def _dflash_propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> torch.Tensor:
        batch_size = common_attn_metadata.num_reqs
        assert self.runner is not None
        target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
        assert target_hidden_states.shape[-1] == self.hidden_size

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
        query_start_loc = common_attn_metadata.query_start_loc[:batch_size + 1]
        num_context_tokens = target_token_ids.shape[0]
        num_query_tokens = 1 + self.num_speculative_tokens
        total_num_query_tokens = batch_size * num_query_tokens
        num_kv_tokens = num_context_tokens + total_num_query_tokens

        MASK_TOKEN_ID = 151669

The magic number 151669 is used for MASK_TOKEN_ID. This should be defined as a named constant at the top of the file or in a configuration file to improve readability and maintainability.

        MASK_TOKEN_ID = 151669  # TODO: Refactor to a named constant


        query_positions_list = []
        for i in range(batch_size):
            last_position = target_positions[last_token_indices[i]]
            query_positions = (
                torch.arange(
                    num_query_tokens,
                    device=target_positions.device,
                    dtype=target_positions.dtype,
                )
                + 1
                + last_position
            )
            query_positions_list.append(query_positions)
        position_ids = torch.cat([target_positions] + query_positions_list)
        assert position_ids.shape[0] == num_kv_tokens

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder
        block_size = attn_metadata_builder.kv_cache_spec.block_size
        block_table_tensor = common_attn_metadata.block_table_tensor
        query_slot_mapping_list = []
        for i in range(batch_size):
            query_position_ids_i = query_positions_list[i]
            block_numbers_i = query_position_ids_i // block_size
            block_ids_i = block_table_tensor[i].gather(
                dim=0, index=block_numbers_i.long()
            )
            slot_mapping_i = (
                block_ids_i * block_size + query_position_ids_i % block_size
            ).to(torch.int32)
            query_slot_mapping_list.append(slot_mapping_i)
        common_attn_metadata.slot_mapping = torch.cat(
            [common_attn_metadata.slot_mapping[:target_hidden_states.shape[0]]] + query_slot_mapping_list)
        common_attn_metadata.num_actual_tokens = num_kv_tokens
        common_attn_metadata.max_query_len = num_kv_tokens
        common_attn_metadata.query_start_loc = (
            self.arange[:batch_size + 1] * num_query_tokens
        )
        common_attn_metadata.query_start_loc_cpu = (
            torch.from_numpy(self.token_arange_np[:batch_size + 1]).clone()
            * num_query_tokens
        )
        common_attn_metadata.max_seq_len += num_query_tokens
        common_attn_metadata.seq_lens[:batch_size] = (
            common_attn_metadata.seq_lens[:batch_size] + num_query_tokens
        )
        common_attn_metadata.seq_lens_cpu = (
            common_attn_metadata.seq_lens.cpu()
        )
        common_attn_metadata.causal = False
        common_attn_metadata.attn_mask = None
        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill

        builder = self.runner.attn_groups[0][0].get_metadata_builder()
        attn_metadata = builder.build(0, common_attn_metadata,
                                      self.runner.get_model())
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        self.positions[:num_kv_tokens] = position_ids
        self.input_ids[:total_num_query_tokens] = MASK_TOKEN_ID
        for i in range(batch_size):
            self.input_ids[i * num_query_tokens] = next_token_ids[i]
        self.hidden_states[:num_context_tokens] = target_hidden_states

        update_cos_sin(position_ids)

        if self.use_cuda_graph and total_num_query_tokens <= self.runner.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                total_num_query_tokens
            )
            aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = total_num_query_tokens
            aclgraph_runtime_mode = CUDAGraphMode.NONE

        (
            num_input_tokens,
            num_tokens_across_dp,
            _,
            _,
        ) = self.runner._sync_metadata_across_dp(
            num_input_tokens, is_draft_model=True
        )

        with set_ascend_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_actual_tokens=num_kv_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                aclgraph_runtime_mode=CUDAGraphMode.NONE,
                is_draft_model=True):

            ret_hidden_states = self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self.positions[:num_kv_tokens],
                hidden_states=self.hidden_states[:num_context_tokens],
                inputs_embeds=None,
            )

        valid_hidden_list = []
        for i in range(batch_size):
            start = i * num_query_tokens + 1
            end = (i + 1) * num_query_tokens
            valid_hidden_list.append(ret_hidden_states[start:end])
        valid_hidden_states = torch.cat(valid_hidden_list, dim=0)
        logits = self.model.compute_logits(valid_hidden_states)
        draft_token_ids = logits.argmax(dim=-1)
        return draft_token_ids.view(batch_size, self.num_speculative_tokens)

    def _propose(
        self,
        # [num_tokens]
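For readers following the paged-KV arithmetic in _dflash_propose above, here is a minimal standalone sketch (not part of the PR) of how the appended query positions are mapped to KV-cache slots. All concrete values are made up for the example: a toy block_size of 4, a single request whose last accepted token sits at position 9, num_speculative_tokens = 3, and a hypothetical block-table row.

# Toy illustration of the query-position / slot-mapping computation above.
# Only the arithmetic mirrors _dflash_propose; the numbers are placeholders.
import torch

block_size = 4
num_speculative_tokens = 3
num_query_tokens = 1 + num_speculative_tokens        # 1 anchor token + 3 mask tokens
last_position = torch.tensor(9)                      # last accepted position for this request

# Positions of the appended query tokens: 10, 11, 12, 13
query_positions = torch.arange(num_query_tokens) + 1 + last_position

# Hypothetical block-table row: logical block index -> physical block id
block_table_row = torch.tensor([7, 3, 11, 5])

block_numbers = query_positions // block_size                 # logical blocks: 2, 2, 3, 3
block_ids = block_table_row.gather(0, block_numbers)          # physical blocks: 11, 11, 5, 5
slot_mapping = block_ids * block_size + query_positions % block_size
print(slot_mapping)  # tensor([46, 47, 20, 21])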
44 changes: 27 additions & 17 deletions vllm_ascend/worker/model_runner_v1.py
@@ -383,7 +383,7 @@ def _set_up_drafter(self):
        self.decode_token_per_req = 1 + spec_token_num
        if get_pp_group().is_last_rank:
            self.drafter = self._get_drafter()
            if self.speculative_config.method == "eagle3":
            if self.speculative_config.method in ("eagle3", "dflash"):
                assert isinstance(self.drafter, EagleProposer)
                self.use_aux_hidden_state_outputs = (
                    self.drafter.eagle3_use_aux_hidden_state)
@@ -1450,21 +1450,31 @@ def propose_draft_token_ids(
                else:
                    target_hidden_states = hidden_states[token_indices]
                assert self.drafter is not None
                draft_token_ids = self.drafter._propose(
                    target_token_ids=target_token_ids,
                    target_positions=target_positions,
                    target_hidden_states=target_hidden_states,
                    next_token_ids=next_token_ids,
                    last_token_indices=token_indices_to_sample,
                    common_attn_metadata=common_attn_metadata,
                    sampling_metadata=sampling_metadata,
                    req_scheduled_tokens=req_scheduled_tokens,
                    long_seq_metadata=long_seq_metadata,
                    num_prefill_reqs=num_prefill_reqs,
                    num_decode_reqs=num_decode_reqs,
                    scheduler_output=scheduler_output,
                    num_scheduled_tokens=num_scheduled_tokens,
                )
                if self.speculative_config.method == "dflash":
                    draft_token_ids = self.drafter._dflash_propose(
                        target_token_ids=target_token_ids,
                        target_positions=target_positions,
                        target_hidden_states=target_hidden_states,
                        next_token_ids=next_token_ids,
                        last_token_indices=token_indices_to_sample,
                        common_attn_metadata=common_attn_metadata
                    )
                else:
                    draft_token_ids = self.drafter._propose(
                        target_token_ids=target_token_ids,
                        target_positions=target_positions,
                        target_hidden_states=target_hidden_states,
                        next_token_ids=next_token_ids,
                        last_token_indices=token_indices_to_sample,
                        common_attn_metadata=common_attn_metadata,
                        sampling_metadata=sampling_metadata,
                        req_scheduled_tokens=req_scheduled_tokens,
                        long_seq_metadata=long_seq_metadata,
                        num_prefill_reqs=num_prefill_reqs,
                        num_decode_reqs=num_decode_reqs,
                        scheduler_output=scheduler_output,
                        num_scheduled_tokens=num_scheduled_tokens,
                    )
Comment on lines +1453 to +1477

There is significant code duplication in the if/else block for calling _dflash_propose and _propose. Many arguments are shared. This can be refactored to improve readability and maintainability by using a dictionary for common arguments.

                propose_kwargs = {
                    "target_token_ids": target_token_ids,
                    "target_positions": target_positions,
                    "target_hidden_states": target_hidden_states,
                    "next_token_ids": next_token_ids,
                    "last_token_indices": token_indices_to_sample,
                    "common_attn_metadata": common_attn_metadata,
                }
                if self.speculative_config.method == "dflash":
                    draft_token_ids = self.drafter._dflash_propose(**propose_kwargs)
                else:
                    propose_kwargs.update({
                        "sampling_metadata": sampling_metadata,
                        "req_scheduled_tokens": req_scheduled_tokens,
                        "long_seq_metadata": long_seq_metadata,
                        "num_prefill_reqs": num_prefill_reqs,
                        "num_decode_reqs": num_decode_reqs,
                        "scheduler_output": scheduler_output,
                        "num_scheduled_tokens": num_scheduled_tokens,
                    })
                    draft_token_ids = self.drafter._propose(**propose_kwargs)


            else:
                raise ValueError("Unknown speculative decoding method: "
@@ -2387,7 +2397,7 @@ def load_model(self) -> None:
            self.drafter.load_model(self.model)
            if self.use_aux_hidden_state_outputs:
                self.model.set_aux_hidden_state_layers(
                    self.model.get_eagle3_aux_hidden_state_layers())
                    self.model.get_eagle3_aux_hidden_state_layers(self.drafter.method))

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.vllm_config,
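Finally, a hypothetical usage sketch (not part of this PR) showing how the new "dflash" method string would be selected, assuming it is exposed through vLLM's standard speculative_config dict the same way as the existing "eagle3" path; the model paths and token count below are placeholders.

# Hypothetical launch sketch; paths are placeholders and the exact config
# wiring for "dflash" is an assumption based on the eagle3 path in this PR.
from vllm import LLM

llm = LLM(
    model="/path/to/target-model",                   # placeholder target model
    speculative_config={
        "method": "dflash",                          # routed to EagleProposer by get_spec_decode_method
        "model": "/path/to/dflash-draft-model",      # placeholder draft model
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["An example prompt"])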