Skip to content

[Bugfix] Fix slow hasattr in ACLGraphWrapper.__getattr__#7442

Merged
wangxiyuan merged 13 commits intovllm-project:mainfrom
gcanlin:aclgraph-wrapper-perf
Mar 23, 2026
Merged

[Bugfix] Fix slow hasattr in ACLGraphWrapper.__getattr__#7442
wangxiyuan merged 13 commits intovllm-project:mainfrom
gcanlin:aclgraph-wrapper-perf

Conversation

@gcanlin
Copy link
Copy Markdown
Collaborator

@gcanlin gcanlin commented Mar 18, 2026

What this PR does / why we need it?

Follow vllm-project/vllm#37425, vllm-project/vllm-omni#1982

Copied from them:

Notice that hasattr(self.model, "flush_pending_metadata") cost 6ms per decode step when profiling Qwen3 Omni.

The original CUDAGraphWrapper.__getattr__ raises:

raise AttributeError(f"... cudagraph wrapper: {self.runnable}")

When hasattr() is called for a non-existent attribute, Python internally calls getattr which constructs this AttributeError. The {self.runnable} triggers __repr__() on the underlying model (e.g., Qwen3OmniMoeForConditionalGeneration), which recursivelytraverses the entire nn.Module tree to generate an 18,000+ character string. This takes ~6-7ms per call.

Since hasattr(self.model, "flush_pending_metadata") is called every decode step in the Talker forward path, this adds ~6ms overhead per step, severely impacting audio inter-chunk latency (ICL).

hasattr(self.model, "flush_pending_metadata")
  → getattr(self.model, "flush_pending_metadata")
    → not found in CUDAGraphWrapper.__dict__not found in the CUDAGraphWrapper class hierarchytriggers CUDAGraphWrapper.__getattr__("flush_pending_metadata")
      → hasattr(self.runnable, "flush_pending_metadata")  # runnable also doesn't have itexecutes raise AttributeError(f"... {self.runnable}")
        → Python needs to construct the exception objectthe f-string triggers self.runnable.__repr__()
        → Qwen3OmniMoeForConditionalGeneration.__repr__()
          → recursively traverses the entire nn.Module treegenerates a 18,000+ character stringtakes ~6 msAttributeError object is createdhasattr catches the AttributeError and returns Falsethe 18,000-character string is immediately discarded (no one ever sees it)

Does this PR introduce any user-facing change?

NO.

How was this patch tested?

See vllm-project/vllm-omni#1982

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin requested a review from yiz-liu as a code owner March 18, 2026 14:59
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a performance bottleneck within the ACLGraphWrapper class. The __getattr__ method, particularly when hasattr was used, was inadvertently triggering expensive str() calls on the wrapped runnable object. By pre-calculating and storing this string representation once during object initialization, the change eliminates these repeated, costly operations, leading to a more efficient attribute access mechanism.

Highlights

  • Performance Optimization: Optimized the ACLGraphWrapper.__getattr__ method to prevent slow str() calls on the runnable object, which were triggered by hasattr checks.
  • Attribute Caching: Introduced a cached string representation of the runnable object (_runnable_str) during initialization to improve performance when constructing AttributeError messages.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gcanlin gcanlin added ready read for review ready-for-test start test by label for PR labels Mar 18, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes the __getattr__ method in ACLGraphWrapper by caching the string representation of the runnable object. This avoids potentially expensive str() calls when an attribute is not found, which improves performance in error-handling paths. The change is correct and beneficial. I've added a suggestion to further improve the performance and readability of __getattr__ by replacing the hasattr/getattr pattern with a try/except block, which is more idiomatic and avoids a double attribute lookup.

@github-actions
Copy link
Copy Markdown
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

@gcanlin gcanlin requested a review from wangxiyuan as a code owner March 20, 2026 15:24
@wangxiyuan wangxiyuan merged commit e68464a into vllm-project:main Mar 23, 2026
38 checks passed
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 25, 2026
…t#7442)

### What this PR does / why we need it?

Follow vllm-project/vllm#37425,
vllm-project/vllm-omni#1982

Copied from them:

Notice that `hasattr(self.model, "flush_pending_metadata")` cost 6ms per
decode step when profiling Qwen3 Omni.

The original `CUDAGraphWrapper.__getattr__` raises:
```python
  raise AttributeError(f"... cudagraph wrapper: {self.runnable}")
  ```
When hasattr() is called for a non-existent attribute, Python internally
calls __getattr__ which constructs this AttributeError. The
{self.runnable} triggers `__repr__()` on the underlying model (e.g.,
`Qwen3OmniMoeForConditionalGeneration`), which recursivelytraverses the
entire nn.Module tree to generate an 18,000+ character string. This
takes ~6-7ms per call.
Since `hasattr(self.model, "flush_pending_metadata") ` is called every
decode step in the Talker forward path, this adds ~6ms overhead per
step, severely impacting audio inter-chunk latency (ICL).

```Python
hasattr(self.model, "flush_pending_metadata")
  → getattr(self.model, "flush_pending_metadata")
    → not found in CUDAGraphWrapper.__dict__
    → not found in the CUDAGraphWrapper class hierarchy
    → triggers CUDAGraphWrapper.__getattr__("flush_pending_metadata")
      → hasattr(self.runnable, "flush_pending_metadata")  # runnable also doesn't have it
      → executes raise AttributeError(f"... {self.runnable}")
        → Python needs to construct the exception object
        → the f-string triggers self.runnable.__repr__()
        → Qwen3OmniMoeForConditionalGeneration.__repr__()
          → recursively traverses the entire nn.Module tree
          → generates a 18,000+ character string
          → takes ~6 ms
        → AttributeError object is created
    → hasattr catches the AttributeError and returns False
    → the 18,000-character string is immediately discarded (no one ever sees it)
```

### Does this PR introduce _any_ user-facing change?

NO.

### How was this patch tested?

See vllm-project/vllm-omni#1982


- vLLM version: v0.17.0
- vLLM main:
vllm-project/vllm@4497431

---------

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
lihaokun-2026 pushed a commit to lihaokun-2026/vllm-ascend that referenced this pull request Mar 29, 2026
…t#7442)

### What this PR does / why we need it?

Follow vllm-project/vllm#37425,
vllm-project/vllm-omni#1982

Copied from them:

Notice that `hasattr(self.model, "flush_pending_metadata")` cost 6ms per
decode step when profiling Qwen3 Omni.

The original `CUDAGraphWrapper.__getattr__` raises:
```python
  raise AttributeError(f"... cudagraph wrapper: {self.runnable}")
  ```
When hasattr() is called for a non-existent attribute, Python internally
calls __getattr__ which constructs this AttributeError. The
{self.runnable} triggers `__repr__()` on the underlying model (e.g.,
`Qwen3OmniMoeForConditionalGeneration`), which recursivelytraverses the
entire nn.Module tree to generate an 18,000+ character string. This
takes ~6-7ms per call.
Since `hasattr(self.model, "flush_pending_metadata") ` is called every
decode step in the Talker forward path, this adds ~6ms overhead per
step, severely impacting audio inter-chunk latency (ICL).

```Python
hasattr(self.model, "flush_pending_metadata")
  → getattr(self.model, "flush_pending_metadata")
    → not found in CUDAGraphWrapper.__dict__
    → not found in the CUDAGraphWrapper class hierarchy
    → triggers CUDAGraphWrapper.__getattr__("flush_pending_metadata")
      → hasattr(self.runnable, "flush_pending_metadata")  # runnable also doesn't have it
      → executes raise AttributeError(f"... {self.runnable}")
        → Python needs to construct the exception object
        → the f-string triggers self.runnable.__repr__()
        → Qwen3OmniMoeForConditionalGeneration.__repr__()
          → recursively traverses the entire nn.Module tree
          → generates a 18,000+ character string
          → takes ~6 ms
        → AttributeError object is created
    → hasattr catches the AttributeError and returns False
    → the 18,000-character string is immediately discarded (no one ever sees it)
```

### Does this PR introduce _any_ user-facing change?

NO.

### How was this patch tested?

See vllm-project/vllm-omni#1982


- vLLM version: v0.17.0
- vLLM main:
vllm-project/vllm@4497431

---------

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Apr 1, 2026
…t#7442)

### What this PR does / why we need it?

Follow vllm-project/vllm#37425,
vllm-project/vllm-omni#1982

Copied from them:

Notice that `hasattr(self.model, "flush_pending_metadata")` cost 6ms per
decode step when profiling Qwen3 Omni.

The original `CUDAGraphWrapper.__getattr__` raises:
```python
  raise AttributeError(f"... cudagraph wrapper: {self.runnable}")
  ```
When hasattr() is called for a non-existent attribute, Python internally
calls __getattr__ which constructs this AttributeError. The
{self.runnable} triggers `__repr__()` on the underlying model (e.g.,
`Qwen3OmniMoeForConditionalGeneration`), which recursivelytraverses the
entire nn.Module tree to generate an 18,000+ character string. This
takes ~6-7ms per call.
Since `hasattr(self.model, "flush_pending_metadata") ` is called every
decode step in the Talker forward path, this adds ~6ms overhead per
step, severely impacting audio inter-chunk latency (ICL).

```Python
hasattr(self.model, "flush_pending_metadata")
  → getattr(self.model, "flush_pending_metadata")
    → not found in CUDAGraphWrapper.__dict__
    → not found in the CUDAGraphWrapper class hierarchy
    → triggers CUDAGraphWrapper.__getattr__("flush_pending_metadata")
      → hasattr(self.runnable, "flush_pending_metadata")  # runnable also doesn't have it
      → executes raise AttributeError(f"... {self.runnable}")
        → Python needs to construct the exception object
        → the f-string triggers self.runnable.__repr__()
        → Qwen3OmniMoeForConditionalGeneration.__repr__()
          → recursively traverses the entire nn.Module tree
          → generates a 18,000+ character string
          → takes ~6 ms
        → AttributeError object is created
    → hasattr catches the AttributeError and returns False
    → the 18,000-character string is immediately discarded (no one ever sees it)
```

### Does this PR introduce _any_ user-facing change?

NO.

### How was this patch tested?

See vllm-project/vllm-omni#1982


- vLLM version: v0.17.0
- vLLM main:
vllm-project/vllm@4497431

---------

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready read for review ready-for-test start test by label for PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants