Merged
11 changes: 9 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -3298,12 +3298,19 @@ def _check_and_update_cudagraph_mode(
         with update_pass_config(self):
             super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups)

+        capture_descs = self.cudagraph_dispatcher.get_capture_descs()
+        capture_sizes = sorted({
+            desc.num_tokens
+            for _, descs in capture_descs
+            for desc in descs
+        })
Comment on lines +3301 to +3306

Contributor

high

The logic that extracts `capture_sizes` from `capture_descs` can be simplified by iterating the values directly in a single set comprehension, which improves readability and maintainability:

capture_sizes = sorted({desc.num_tokens for descs in capture_descs.values() for desc in descs})

Collaborator
@yiz-liu yiz-liu Apr 17, 2026

It has a point, but it's not that important.


         # NOTE: Since aclgraph_batch_sizes cannot be determined until here,
         # we set the graph params right before initializing the keys.
         if self.use_aclgraph:
-            set_graph_params(self.cudagraph_batch_sizes)
+            set_graph_params(capture_sizes)
             if self.speculative_config:
-                set_draft_graph_params(self.cudagraph_batch_sizes)
+                set_draft_graph_params(capture_sizes)

def capture_model(self) -> None:
gpu_model_runner_cls = next((cls for cls in self.__class__.__mro__ if cls.__name__ == "GPUModelRunner"), None)
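
The two comprehension styles discussed in the review thread produce the same result. A minimal sketch below uses a stand-in `Desc` dataclass and a plain dict in place of the real dispatcher output (the actual `get_capture_descs()` return type is not shown in this diff; since the patched code unpacks `for _, descs in capture_descs`, it may already be an items-like iterable of pairs rather than a dict):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Desc:
    # Stand-in for the real capture descriptor; only the field used here.
    num_tokens: int


# Hypothetical dispatcher output: mode name -> list of capture descriptors.
capture_descs = {
    "FULL": [Desc(1), Desc(8), Desc(16)],
    "PIECEWISE": [Desc(8), Desc(32)],
}

# Style from the diff: iterate (key, value) pairs and discard the key.
sizes_a = sorted({
    desc.num_tokens
    for _, descs in capture_descs.items()
    for desc in descs
})

# Reviewer's suggestion: iterate the values directly.
sizes_b = sorted({desc.num_tokens for descs in capture_descs.values() for desc in descs})

# Both deduplicate and sort the token counts identically.
assert sizes_a == sizes_b == [1, 8, 16, 32]
```

Either way, the set deduplicates sizes that appear under multiple capture modes before sorting, which is what makes the result suitable as a single list of graph batch sizes.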