
[Perf] Fix slow hasattr in CUDAGraphWrapper.__getattr__#37425

Merged
Isotr0py merged 5 commits into vllm-project:main from ZeldaHuang:optimize_cudagraph_wrapper
Mar 19, 2026

Conversation

@ZeldaHuang
Contributor

@ZeldaHuang ZeldaHuang commented Mar 18, 2026

Purpose

Ref vllm-project/vllm-omni#1982

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: 智鸣 <hzm414167@alibaba-inc.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request addresses a performance issue in CUDAGraphWrapper.__getattr__ caused by slow exception string formatting when hasattr is used on non-existent attributes. The fix is correct in its goal, but I have provided a suggestion for a more idiomatic and performant approach using getattr with a default sentinel value. This alternative avoids the performance pitfall of hasattr more directly and is a common pattern for implementing efficient proxy objects.

Comment on lines 212 to 216:

```diff
 if hasattr(self.runnable, key):
     return getattr(self.runnable, key)
 raise AttributeError(
-    f"Attribute {key} not exists in the runnable of "
-    f"cudagraph wrapper: {self.runnable}"
+    f"Attribute {key} not exists in the runnable of cudagraph wrapper"
 )
```
Contributor


high

While removing the object representation from the error string fixes the performance issue with hasattr, a more idiomatic and robust way to implement a performant proxy __getattr__ is to use getattr with a default value. This avoids relying on exception handling for control flow, which is what makes hasattr slow on failure. This approach is generally faster and results in cleaner code.

To implement this, you would need to define a sentinel object at the module level, for example:
_sentinel = object()

Suggested change:

```diff
-if hasattr(self.runnable, key):
-    return getattr(self.runnable, key)
+value = getattr(self.runnable, key, _sentinel)
+if value is not _sentinel:
+    return value
 raise AttributeError(
     f"Attribute {key} not exists in the runnable of cudagraph wrapper"
 )
```
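For reference, the sentinel pattern suggested above can be sketched as a self-contained example; the class names here are illustrative, not the actual vLLM code:

```python
# Sketch of a proxy __getattr__ using a module-level sentinel instead of
# hasattr(): getattr with a default never raises for a missing attribute,
# so no exception object (or expensive message) is built on the probe.
_sentinel = object()


class _Wrapper:  # hypothetical stand-in for CUDAGraphWrapper
    def __init__(self, runnable):
        self.runnable = runnable

    def __getattr__(self, key):
        value = getattr(self.runnable, key, _sentinel)
        if value is not _sentinel:
            return value
        raise AttributeError(
            f"Attribute {key} not exists in the runnable of cudagraph wrapper"
        )


class _Model:  # hypothetical wrapped runnable
    def forward(self):
        return "ok"


w = _Wrapper(_Model())
print(w.forward())         # delegates to the wrapped object -> "ok"
print(hasattr(w, "nope"))  # False, without formatting a huge repr
```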

@ZJY0516
Member

ZJY0516 commented Mar 18, 2026

@ProExpertProg Could you please take a look at this?

Member

@Isotr0py Isotr0py left a comment


LGTM, but can we cache __repr__(self.runnable) to keep the message friendly for debugging?

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 18, 2026
@Isotr0py Isotr0py added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Signed-off-by: 智鸣 <hzm414167@alibaba-inc.com>
@ZeldaHuang
Contributor Author

LGTM, but can we cache __repr__(self.runnable) to keep the message friendly for debugging?

Added `self._runnable_str = str(runnable)`.

@Isotr0py Isotr0py enabled auto-merge (squash) March 18, 2026 14:48
```diff
 raise AttributeError(
     f"Attribute {key} not exists in the runnable of "
-    f"cudagraph wrapper: {self.runnable}"
+    f"cudagraph wrapper: {self._runnable_str}"
```
Collaborator


Wouldn't it be more performant to add some sort of hasattr override? (does Python support those?) instead of relying on raising and catching an exception

Collaborator


Looks like Python does not have a custom hasattr override.

@zou3519
Collaborator

zou3519 commented Mar 18, 2026

Sorry I want to take a closer look at this. This is going to regress the cold start time (but I agree it fixes the runtime overhead issue). Trying to see if we can avoid that as well

@ZJY0516
Member

ZJY0516 commented Mar 18, 2026

This is going to regress the cold start time

@zou3519 I was wondering why this affects cold start time?

@zou3519
Collaborator

zou3519 commented Mar 18, 2026

This PR moves the str(self.runnable) to the CUDAGraphWrapper constructor. If we're deploying with piecewise cudagraphs, an extreme case is 100 piecewise graphs. If str(self.runnable) is expensive (10ms), then this is one additional second of startup time. It would be better to just not do str(self.runnable) at all, unless we think the debuggability benefit it gives is worth it.

```diff
     cudagraph_options: CUDAGraphOptions | None = None,
 ) -> None:
     self.runnable = runnable
+    self._runnable_str = str(runnable)
```
Collaborator


If you want to ship this PR now and think later, then my vote is to remove this line and avoid putting str(runnable) into the AttributeError message.

Otherwise I want to learn why a hasattr call is in the hot-path

@ZJY0516
Member

ZJY0516 commented Mar 18, 2026

This PR moves the str(self.runnable) to the CUDAGraphWrapper constructor. If we're deploying with piecewise cudagraphs, an extreme case is 100 piecewise graphs. If str(self.runnable) is expensive (10ms), then this is one additional second of startup time. It would be better to just not do str(self.runnable) at all, unless we think the debuggability benefit it gives is worth it.

I see. It would be great if we could do str(self.runnable) only when the log level is set to debug.

@ZeldaHuang
Contributor Author

This PR moves the str(self.runnable) to the CUDAGraphWrapper constructor. If we're deploying with piecewise cudagraphs, an extreme case is 100 piecewise graphs. If str(self.runnable) is expensive (10ms), then this is one additional second of startup time. It would be better to just not do str(self.runnable) at all, unless we think the debuggability benefit it gives is worth it.

I think in piecewise CUDA graph mode there is still only one CUDAGraphWrapper, so this should still correspond to only one str(self.runnable) call rather than one per piecewise graph.

@zou3519
Collaborator

zou3519 commented Mar 18, 2026

Sorry to be clear: let's say we have llama-3-70b

  • There are 80 layers
  • We make 81 total split graphs out of these
  • Then, we wrap each of the 81 in a CUDAGraphWrapper
  • Then we record separate cudagraphs for each cudagraph size sometime later.

At the very least, we are creating 81 CUDAGraphWrappers for this model

Signed-off-by: 智鸣 <hzm414167@alibaba-inc.com>
@ZeldaHuang
Contributor Author

Sorry to be clear: let's say we have llama-3-70b

  • There are 80 layers
  • We make 81 total split graphs out of these
  • Then, we wrap each of the 81 in a CUDAGraphWrapper
  • Then we record separate cudagraphs for each cudagraph size sometime later.

At the very least, we are creating 81 CUDAGraphWrappers for this model

Thanks for clarifying.
I changed the implementation; now `str(self.runnable)` is built and included in the raised error only in debug mode.

```diff
-    f"Attribute {key} not exists in the runnable of "
-    f"cudagraph wrapper: {self._runnable_str}"
-)
+raise AttributeError
```
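A rough sketch of the debug-gated approach described in this thread. The logger check and class name here are assumptions for illustration, not necessarily the exact merged code:

```python
import logging

logger = logging.getLogger("sketch")


class Wrapper:  # hypothetical stand-in for CUDAGraphWrapper
    def __init__(self, runnable):
        self.runnable = runnable

    def __getattr__(self, key):
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        if logger.isEnabledFor(logging.DEBUG):
            # Pay for str(self.runnable) only when someone will read it.
            raise AttributeError(
                f"Attribute {key} not exists in the runnable of "
                f"cudagraph wrapper: {self.runnable}"
            )
        # Hot path: a bare AttributeError is enough for hasattr() to return False.
        raise AttributeError


w = Wrapper(object())
print(hasattr(w, "missing"))  # False, with no repr cost at default log level
```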
Collaborator


does this need to be AttributeError()? (I don't know how this works)

Contributor Author


Per https://docs.python.org/3/library/functions.html#getattr, the default implementation raises `AttributeError()` when the attribute does not exist.

Collaborator


Does it matter if it's "raise AttributeError" vs "raise AttributeError()" ?

Contributor Author


I think they are nearly the same when no error message is provided.
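They are indeed equivalent: raising a class implicitly instantiates it with no arguments. A quick check:

```python
# `raise AttributeError` and `raise AttributeError()` behave identically:
# when a class is raised, Python instantiates it with no arguments.
def bare():
    raise AttributeError


def called():
    raise AttributeError()


for fn in (bare, called):
    try:
        fn()
    except AttributeError as e:
        print(type(e).__name__, e.args)  # same type, empty args both times
```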

@zou3519 zou3519 enabled auto-merge (squash) March 18, 2026 16:26
auto-merge was automatically disabled March 19, 2026 03:33

Head branch was pushed to by a user without write access

@ZeldaHuang ZeldaHuang requested a review from njhill as a code owner March 19, 2026 03:33
@mergify mergify bot added the v1 label Mar 19, 2026
@Isotr0py Isotr0py merged commit d3cc379 into vllm-project:main Mar 19, 2026
60 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 19, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
chooper26 pushed a commit to intellistream/vllm-hust that referenced this pull request Mar 21, 2026
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Mar 23, 2026
### What this PR does / why we need it?

Follow vllm-project/vllm#37425,
vllm-project/vllm-omni#1982

Copied from them:

Notice that `hasattr(self.model, "flush_pending_metadata")` costs ~6 ms per
decode step when profiling Qwen3 Omni.

The original `CUDAGraphWrapper.__getattr__` raises:
```python
raise AttributeError(f"... cudagraph wrapper: {self.runnable}")
```
When `hasattr()` is called for a non-existent attribute, Python internally
calls `__getattr__`, which constructs this `AttributeError`. The
`{self.runnable}` triggers `__repr__()` on the underlying model (e.g.,
`Qwen3OmniMoeForConditionalGeneration`), which recursively traverses the
entire `nn.Module` tree to generate an 18,000+ character string. This
takes ~6-7 ms per call.

Since `hasattr(self.model, "flush_pending_metadata")` is called every
decode step in the Talker forward path, this adds ~6 ms of overhead per
step, severely impacting audio inter-chunk latency (ICL).

```Python
hasattr(self.model, "flush_pending_metadata")
  → getattr(self.model, "flush_pending_metadata")
    → not found in CUDAGraphWrapper.__dict__
    → not found in the CUDAGraphWrapper class hierarchy
    → triggers CUDAGraphWrapper.__getattr__("flush_pending_metadata")
      → hasattr(self.runnable, "flush_pending_metadata")  # runnable also doesn't have it
      → executes raise AttributeError(f"... {self.runnable}")
        → Python needs to construct the exception object
        → the f-string triggers self.runnable.__repr__()
        → Qwen3OmniMoeForConditionalGeneration.__repr__()
          → recursively traverses the entire nn.Module tree
          → generates a 18,000+ character string
          → takes ~6 ms
        → AttributeError object is created
    → hasattr catches the AttributeError and returns False
    → the 18,000-character string is immediately discarded (no one ever sees it)
```

### Does this PR introduce _any_ user-facing change?

NO.

### How was this patch tested?

See vllm-project/vllm-omni#1982


- vLLM version: v0.17.0
- vLLM main:
vllm-project/vllm@4497431

---------

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 25, 2026
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…#37425)

Signed-off-by: 智鸣 <hzm414167@alibaba-inc.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
lihaokun-2026 pushed a commit to lihaokun-2026/vllm-ascend that referenced this pull request Mar 29, 2026
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…#37425)

Signed-off-by: 智鸣 <hzm414167@alibaba-inc.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>

Labels

nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants