
[AMD] fix CI: workspace buffer OOM and tuned GEMM torchao compatibility#20890

Closed
michaelzhang-ai wants to merge 1 commit into main from fix/aiter-oom-and-torchao-compat

Conversation

@michaelzhang-ai
Collaborator

Motivation

Two aiter backend failures were surfaced by PR #20392, which defaults AMD HIP GPUs to the aiter backend:

Shard 8 (CI log): test_no_overlap_scheduler.py OOM during AiterAttnBackend.__init__:

torch.OutOfMemoryError: HIP out of memory. Tried to allocate 16.25 GiB.
GPU 0 has a total capacity of 255.98 GiB of which 4.23 GiB is free.

Shard 10 (CI log): test_torchao.py crash on quantized weights:

NotImplementedError: AffineQuantizedTensor dispatch: attempting to run
  unimplemented operator/function: func=<OpOverload(op='aiter.gemm_a16w16')>

Modifications

1. aiter_backend.py — workspace buffer OOM

The workspace buffer for paged_attention_ragged was sized using max_context_len (e.g. 131K for Llama 3.1 → 512 partitions → 16.25 GiB), but the CI GPU's KV cache only held 25K tokens. Since no single sequence can exceed max_total_num_tokens, we cap max_num_partitions accordingly:

  • Before: max_num_partitions = ceil(131072 / 256) = 512 → workspace ~16.25 GiB
  • After: max_num_partitions = ceil(25432 / 256) = 100 → workspace ~3.2 GiB
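The cap can be sketched as a small calculation. This is a dependency-free sketch, not the actual `aiter_backend.py` code: `PARTITION_SIZE = 256` and the function name are assumptions matching the arithmetic above.

```python
import math

PARTITION_SIZE = 256  # assumed partition size, matching the numbers above

def capped_num_partitions(max_context_len: int, max_total_num_tokens: int) -> int:
    # No single sequence can exceed the KV cache capacity, so the workspace
    # only needs to cover min(max_context_len, max_total_num_tokens).
    effective_max_seq_len = min(max_context_len, max_total_num_tokens)
    return math.ceil(effective_max_seq_len / PARTITION_SIZE)

print(capped_num_partitions(131072, 131072))  # 512 (old sizing, ~16.25 GiB)
print(capped_num_partitions(131072, 25432))   # 100 (capped sizing, ~3.2 GiB)
```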

2. unquant.py — tuned GEMM with torchao guard

Add aiter's tgemm.mm fast path for unquantized linear ops on AMD, guarded by type(layer.weight.data) is torch.Tensor. Torchao-quantized weights (AffineQuantizedTensor, a torch.Tensor subclass) fail the strict type() check and correctly fall through to F.linear.
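The guard's behavior can be illustrated without torch installed. `Tensor` and `AffineQuantizedTensor` below are hypothetical stand-ins for `torch.Tensor` and torchao's subclass; the dispatch logic is the point, not the class names.

```python
class Tensor:
    """Stand-in for torch.Tensor."""

class AffineQuantizedTensor(Tensor):
    """Stand-in for torchao's AffineQuantizedTensor, a Tensor subclass."""

def use_tuned_gemm(weight) -> bool:
    # Strict identity check: subclasses fail it and fall through to F.linear.
    return type(weight) is Tensor

plain, quant = Tensor(), AffineQuantizedTensor()
print(use_tuned_gemm(plain))      # True  -> tgemm.mm fast path
print(use_tuned_gemm(quant))      # False -> falls through to F.linear
print(isinstance(quant, Tensor))  # True  -> why isinstance() would be wrong here
```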

Additional Note for PR #20392

The PR also has a logic bug in its workspace buffer guard condition:

# Current (incorrect): skips allocation only when BOTH are true
if not (self.use_mla and self.use_triton_unified_attention):

# Should be: skips allocation when EITHER is true
if not (self.use_mla or self.use_triton_unified_attention):

workspace_buffer is only used by paged_attention_ragged, which is only called when both use_mla=False AND use_triton_unified_attention=False.
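The difference between the two guard conditions can be checked by enumerating all four flag combinations (plain Python, no project code assumed):

```python
# workspace_buffer is needed only when BOTH flags are False, so the guard
# should skip allocation when EITHER flag is True.
def needs_workspace(use_mla: bool, use_triton: bool) -> bool:
    return not use_mla and not use_triton

for use_mla in (False, True):
    for use_triton in (False, True):
        buggy = not (use_mla and use_triton)  # allocates in 3 of 4 cases
        fixed = not (use_mla or use_triton)   # allocates only when needed
        print(use_mla, use_triton, "buggy:", buggy, "fixed:", fixed)
# The buggy guard over-allocates for (True, False) and (False, True).
```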

Checklist

Two fixes for aiter backend failures surfaced by PR #20392:

1. aiter_backend.py: Cap max_num_partitions by min(max_context_len,
   max_total_num_tokens). The workspace buffer was sized for the model's
   theoretical max context (e.g. 131K = 512 partitions = 16 GiB) when
   the KV cache only held 25K tokens (100 partitions = 3 GiB), causing
   OOM on memory-constrained CI GPUs.

2. unquant.py: Add aiter tgemm.mm fast path for unquantized linear ops,
   guarded by type(layer.weight.data) is torch.Tensor. Torchao-quantized
   weights (AffineQuantizedTensor) fail the strict type() check and fall
   through to F.linear, preventing NotImplementedError on gemm_a16w16.

@github-actions github-actions bot added the quant LLM Quantization label Mar 18, 2026
@michaelzhang-ai michaelzhang-ai changed the title fix(aiter): workspace buffer OOM and tuned GEMM torchao compatibility [AMD] fix CI: workspace buffer OOM and tuned GEMM torchao compatibility Mar 18, 2026
@michaelzhang-ai
Collaborator Author

@amd-bot review

@michaelzhang-ai
Collaborator Author

Note: This is a simpler fix for the same root cause described in #18262 / #18263. Rather than jointly solving the memory partition equation upfront (as in #18263), this caps max_num_partitions by min(max_context_len, max_total_num_tokens) after KV cache allocation — same correctness guarantee with a minimal 7-line change in aiter_backend.py.

@michaelzhang-ai
Collaborator Author

@amd-bot review

@bingxche
Collaborator

@michaelzhang-ai requested a review

Claude Code Review

PR #20890: [AMD] fix CI: workspace buffer OOM and tuned GEMM torchao compatibility
Reviewed at 2026-03-19 01:41 UTC

Code Review: [AMD] fix CI: workspace buffer OOM and tuned GEMM torchao compatibility

1. Summary

This PR fixes two CI failures on AMD (ROCm) GPUs when using the aiter backend: (1) an OOM error caused by over-allocating the workspace buffer for paged attention based on max_context_len (model's max sequence length) rather than the actual KV cache capacity, and (2) a NotImplementedError when torchao-quantized weights (tensor subclasses) are dispatched through aiter's tuned GEMM path. The fixes cap the workspace buffer size and add a strict type guard to route only plain torch.Tensor weights through the tuned GEMM.

2. Code Quality

Bugs / Logic Errors

aiter_backend.py — Workspace buffer fix:

  • ✅ The logic is sound. Using min(max_context_len, max_total_num_tokens) correctly caps the partition count since no single sequence can exceed the total token budget.
  • ✅ The getattr fallback to self.max_context_len is defensive and correct for cases where model_runner doesn't have the attribute.
  • ⚠️ Minor concern: max_total_num_tokens represents the total KV cache across all sequences, not the max for a single sequence. In theory, if req_to_token_pool allows a single sequence to consume the entire pool (which it does in practice since there's no per-sequence cap enforced at the pool level), this is correct. However, if there were ever a scenario where max_total_num_tokens is less than the longest possible single sequence but the system still accepts it, the workspace buffer could be undersized. In practice, this seems safe because the scheduler won't schedule a sequence longer than available KV cache slots.

unquant.py — Tuned GEMM guard:

  • type(layer.weight.data) is torch.Tensor is the correct strict identity check — it will be False for AffineQuantizedTensor and other torch.Tensor subclasses, which is the desired behavior.
  • ⚠️ Potential issue with .data: Accessing .data bypasses autograd, which is fine for inference. However, is layer.weight.data always guaranteed to return the same type as layer.weight? For quantized tensors, .data might unwrap the subclass in some PyTorch versions. If so, the guard could incorrectly pass. Recommendation: Consider checking type(layer.weight) is torch.Tensor instead (without .data), or at minimum verify the behavior with the specific torchao version used:
# Potentially safer:
if _use_aiter and type(layer.weight) is torch.Tensor:
    return tgemm.mm(x, layer.weight, bias, otype=x.dtype)

Code Style and Readability

  • ✅ Comments are clear and explain the why behind the changes.
  • ✅ The changes are minimal and focused.
  • ✅ Follows existing code patterns in the codebase.

Error Handling

  • The getattr with fallback is good defensive coding.
  • No new exception paths introduced.

3. Performance

Workspace buffer change:

  • ✅ This is purely a memory optimization — reduces allocation without affecting computation. The workspace is still large enough for any sequence the system can actually process.
  • Net positive: avoids OOM on memory-constrained GPUs while maintaining correctness.

Tuned GEMM addition:

  • tgemm.mm is expected to be faster than F.linear for unquantized weights on AMD GPUs. This is a performance improvement for the common case.
  • ⚠️ The type() check adds negligible overhead (single pointer comparison).
  • Question: Has tgemm.mm been benchmarked against F.linear for various matrix sizes typical in LLM inference (small batch prefill, large batch decode, etc.)? If tgemm.mm is slower for some shapes, a more nuanced gating strategy might be needed. However, given that aiter is AMD's optimized kernel library, this is likely a net win.

4. Security

No security concerns. Changes are internal computation/allocation logic with no user-facing input handling changes.

5. Testing

  • The PR description indicates these fixes address specific CI failures in test_no_overlap_scheduler.py (shard 8) and test_torchao.py (shard 10).
  • ✅ Existing CI tests should validate both fixes when run on AMD hardware.
  • Suggested additional tests:
    1. A unit test that verifies effective_max_seq_len is correctly computed when max_total_num_tokens < max_context_len and when max_total_num_tokens >= max_context_len.
    2. A test that confirms UnquantizedLinearMethod.apply falls through to F.linear when layer.weight is a tensor subclass (mock AffineQuantizedTensor).

6. Suggestions

Suggestion 1: Fix the logic bug mentioned in PR description

The PR description identifies a logic bug in #20392's guard condition but doesn't fix it in this PR. If this PR is meant to land first or simultaneously, consider including the fix:

# In aiter_backend.py, the workspace allocation guard should be:
if not (self.use_mla or self.use_triton_unified_attention):
    # allocate workspace_buffer

If this is intentionally left for a follow-up, a TODO comment would be helpful.

Suggestion 2: Use type(layer.weight) instead of type(layer.weight.data)

# More robust against .data unwrapping tensor subclasses
if _use_aiter and type(layer.weight) is torch.Tensor:
    return tgemm.mm(x, layer.weight, bias, otype=x.dtype)

Suggestion 3: Add a log message for the effective cap

This would help with debugging future memory issues:

effective_max_seq_len = min(
    self.max_context_len,
    getattr(model_runner, "max_total_num_tokens", self.max_context_len),
)
if effective_max_seq_len < self.max_context_len:
    logger.info(
        f"Capping workspace buffer sequence length from {self.max_context_len} "
        f"to {effective_max_seq_len} (max_total_num_tokens)"
    )

Suggestion 4: Consider guarding the tgemm import

The import at module level (from aiter.tuned_gemm import tgemm) could fail if aiter is installed but tuned_gemm is not available in older versions. Since it's already inside a _is_hip / _use_aiter guarded block, this should be fine, but verify that tuned_gemm is available in all supported aiter versions.

7. Overall Assessment

Approve

The changes are well-motivated, minimal, and correctly address two real CI failures. Both fixes are logically sound:

  1. The workspace buffer cap is a safe optimization that prevents OOM without risking correctness.
  2. The type() is torch.Tensor guard is the textbook way to exclude tensor subclasses from a fast path.

The only actionable concern is the .data vs no-.data question in the type check (Suggestion 2), which I'd recommend addressing before merge. The logic bug in the guard condition mentioned in the PR description (Suggestion 1) should also be tracked, even if fixed in a separate PR.


Automated review by amd-bot using Claude. This is an AI-generated review — please use your judgment.

@yctseng0211
Collaborator

@amd-bot review

@bingxche
Collaborator

@yctseng0211 requested a review

Claude Code Review

PR #20890: [AMD] fix CI: workspace buffer OOM and tuned GEMM torchao compatibility
Reviewed at 2026-03-19 02:20 UTC

Code Review: [AMD] fix CI: workspace buffer OOM and tuned GEMM torchao compatibility

1. Summary

This PR fixes two AMD (ROCm) CI failures: (1) an OOM crash caused by oversized workspace buffer allocation in the aiter attention backend, where the buffer was sized based on max_context_len (model config, e.g., 131K) rather than actual KV cache capacity; (2) a NotImplementedError when torchao-quantized weights were dispatched through aiter's tuned GEMM path, which doesn't support AffineQuantizedTensor subclasses. The fixes are minimal and well-targeted.

2. Code Quality

Workspace buffer fix (aiter_backend.py)

Good:

  • The use of getattr(model_runner, "max_total_num_tokens", self.max_context_len) is defensive — it gracefully falls back if the attribute doesn't exist.
  • The comment clearly explains the rationale.
  • The min() approach is correct: no single sequence can exceed max_total_num_tokens.

Potential concerns:

  • max_total_num_tokens semantics: This is the total KV cache capacity across all sequences in a batch, not a per-sequence limit. In theory, a single very long sequence could consume the entire KV cache, so using it as a cap is technically correct (it's an upper bound on any single sequence length). However, it's worth noting that this is a practical upper bound, not a theoretical one — the actual limit is min(max_context_len, max_total_num_tokens), which is exactly what's computed. ✅

  • The and/or bug in PR #20392 (gpt-oss decode performance optimization): The PR description mentions a logic bug in the guard condition for workspace buffer allocation in the parent PR. This fix does not address that bug — it should either be fixed here or tracked as a follow-up issue. This is important because with the and condition, the workspace buffer could be needlessly allocated when only one of use_mla or use_triton_unified_attention is true, wasting memory.

Tuned GEMM guard (unquant.py)

Good:

  • Using type(layer.weight.data) is torch.Tensor (strict identity check) rather than isinstance() is the correct choice here. AffineQuantizedTensor is a torch.Tensor subclass, so isinstance would pass True and route quantized weights into the unquantized GEMM path. The strict type() check correctly excludes subclasses. ✅

Potential concerns:

  • layer.weight.data vs layer.weight: Accessing .data strips autograd tracking, but more importantly, for AffineQuantizedTensor, does .data return the subclass or the underlying tensor? This needs verification. If .data unwraps to a plain torch.Tensor, the guard would incorrectly pass. However, based on the PR description and torchao behavior, .data on an AffineQuantizedTensor returns itself (still an AffineQuantizedTensor), so the check should work.

  • Import at module level: from aiter.tuned_gemm import tgemm is imported at module scope inside the if _is_hip block. This is consistent with the existing import style for aiter modules. ✅

  • tgemm.mm API: The call tgemm.mm(x, layer.weight, bias, otype=x.dtype) — is otype the correct kwarg name? If the API changes, this would break silently or crash. A brief comment referencing the aiter API version or a try/except might be warranted.

3. Performance

  • Workspace buffer: The fix reduces memory allocation, which is strictly beneficial. No performance regression.

  • Tuned GEMM: Adding tgemm.mm for unquantized linear ops is a performance improvement on AMD GPUs — aiter's tuned GEMM kernels are optimized for specific GPU architectures. The type() check is essentially free (pointer comparison). The fallback to F.linear for quantized weights maintains correctness. ✅

  • Minor note: The type() is torch.Tensor check runs on every apply() call. This is negligible but could be hoisted to __init__ time by storing a flag like self._use_tgemm. However, weight types could theoretically change (e.g., lazy quantization), so runtime checking is safer.

4. Security

No security concerns. The changes are internal to tensor allocation and kernel dispatch.

5. Testing

  • The PR fixes are validated by existing CI tests (test_no_overlap_scheduler.py for shard 8, test_torchao.py for shard 10) that were previously failing.
  • Missing explicit tests: There's no unit test that specifically validates:
    1. Workspace buffer size is capped correctly when max_total_num_tokens < max_context_len
    2. tgemm.mm is called for plain tensors and F.linear is called for quantized tensors

Suggested tests:

# Test workspace buffer capping
def test_workspace_buffer_capped():
    """Verify max_num_partitions uses min(max_context_len, max_total_num_tokens)."""
    # Mock model_runner with max_total_num_tokens < max_context_len
    # Assert max_num_partitions == ceil(max_total_num_tokens / partition_size)
    pass

# Test GEMM dispatch
def test_unquant_dispatch_plain_tensor():
    """Plain torch.Tensor should use tgemm.mm on AMD."""
    pass

def test_unquant_dispatch_quantized_tensor():
    """AffineQuantizedTensor should fall through to F.linear."""
    pass

6. Suggestions

1. Address the and/or bug mentioned in the description

Either fix it in this PR or file a tracking issue. The current code in PR #20392:

# Incorrect
if not (self.use_mla and self.use_triton_unified_attention):

Should be:

# Correct
if not (self.use_mla or self.use_triton_unified_attention):

This is a real bug that wastes memory when only one flag is true.

2. Add a log message for the effective capping

effective_max_seq_len = min(
    self.max_context_len,
    getattr(model_runner, "max_total_num_tokens", self.max_context_len),
)
if effective_max_seq_len < self.max_context_len:
    logger.info(
        f"Capping workspace buffer: effective_max_seq_len={effective_max_seq_len} "
        f"(max_context_len={self.max_context_len}, "
        f"max_total_num_tokens={getattr(model_runner, 'max_total_num_tokens', 'N/A')})"
    )

This aids debugging on memory-constrained GPUs.

3. Consider robustness of the type() check

Add a comment explaining why type() is used instead of isinstance():

# Use strict type() check (not isinstance) to exclude torch.Tensor
# subclasses like torchao's AffineQuantizedTensor, which don't support
# aiter's tuned GEMM operators.
if _use_aiter and type(layer.weight.data) is torch.Tensor:
    return tgemm.mm(x, layer.weight, bias, otype=x.dtype)

4. Verify tgemm singleton behavior

The import from aiter.tuned_gemm import tgemm imports what appears to be a module-level singleton. Verify that tgemm.mm is thread-safe and doesn't require initialization (e.g., tuning database loading) that could fail silently.

7. Overall Assessment

Approve

The changes are minimal, well-motivated, and correctly fix two real CI failures. The workspace buffer fix uses a sound upper bound, and the tuned GEMM guard uses the correct type-checking pattern to distinguish plain tensors from quantized subclasses. Both changes are low-risk.

Minor action items before merge:

  1. Add a comment explaining the type() vs isinstance() rationale (Suggestion 3)
  2. Consider filing a follow-up issue for the and/or logic bug in PR #20392 (gpt-oss decode performance optimization)
  3. Adding a debug/info log for the workspace capping would be helpful but is not blocking

Automated review by amd-bot using Claude. This is an AI-generated review — please use your judgment.

@HaiShaw
Collaborator

HaiShaw commented Mar 19, 2026

Open again if needed

@HaiShaw HaiShaw closed this Mar 19, 2026