Pull request overview
Removes the explicit “ASM mask type” argument from MHA backward callsites and computes the ASM-specific mask selector internally, simplifying the public mha_bwd_args surface.
Changes:
- Removes `get_mask_type()` plumbing from C++ benchmark and Python/CUDA interfaces, and adjusts `mha_bwd_args` initialization accordingly.
- Reworks `mha_bwd_args` to keep a single mask type field and derives the ASM kernel mask selector inside `fmha_v3_bwd`.
- Updates the backward smoke test script’s executed test set and reduces one SWA test invocation.
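To make the second change concrete, here is a minimal sketch of deriving an ASM-specific selector from a single caller-facing mask field. Everything in it is an assumption for illustration: `BwdArgsSketch`, `MaskKind`, `AsmMaskSel`, `derive_asm_mask_sel`, the enum values, and the window-size convention are hypothetical stand-ins, not the actual definitions in `csrc/include/mha_bwd.h` or the logic in `fmha_v3_bwd`.

```cpp
// Hypothetical stand-in for the unified mask type that callers set.
// The real enum values are not shown in this review.
enum class MaskKind : int {
    NoMask = 0,
    TopLeft = 1,
    BottomRight = 2,
    WindowGeneric = 3
};

// Hypothetical selector encoding consumed by the ASM kernel dispatch.
enum class AsmMaskSel : int {
    None = 0,
    CausalTopLeft = 1,
    CausalBottomRight = 2,
    SlidingWindow = 3
};

// Hypothetical, trimmed-down argument struct: one mask field only.
struct BwdArgsSketch {
    int mask_type;         // unified mask type (a MaskKind value)
    int window_size_left;  // assumed convention: < 0 means unbounded
    int window_size_right; // assumed convention: < 0 means unbounded
};

// Derive the ASM selector inside the backward entry point, so callers
// no longer compute or pass it.
inline AsmMaskSel derive_asm_mask_sel(const BwdArgsSketch& args) {
    const MaskKind kind = static_cast<MaskKind>(args.mask_type);
    if (kind == MaskKind::NoMask) {
        return AsmMaskSel::None;
    }
    // A bounded left window is treated as sliding-window attention (assumption).
    if (kind == MaskKind::WindowGeneric || args.window_size_left >= 0) {
        return AsmMaskSel::SlidingWindow;
    }
    return kind == MaskKind::TopLeft ? AsmMaskSel::CausalTopLeft
                                     : AsmMaskSel::CausalBottomRight;
}
```

With this shape, a callsite fills in only `mask_type` and the window sizes, and the selector becomes an implementation detail that can change without touching the benchmark or the Python/CUDA interfaces.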
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| op_tests/cpp/mha/smoke_test_bwd_v3.sh | Changes which test suites run by default and alters SWA test invocation coverage. |
| op_tests/cpp/mha/benchmark_mha_bwd.cpp | Removes mask-type derivation at callsite and updates mha_bwd_args construction. |
| csrc/py_itfs_cu/asm_mha_varlen_bwd.cu | Removes callsite mask-type derivation and updates args packing for varlen ASM path. |
| csrc/py_itfs_cu/asm_mha_bwd.cu | Removes callsite mask-type derivation and updates args packing for ASM path. |
| csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | Removes callsite mask-type derivation and updates args packing for CK varlen path. |
| csrc/py_itfs_ck/mha_bwd_kernels.cu | Removes callsite mask-type derivation and updates args packing for CK path. |
| csrc/include/mha_bwd.h | Removes the old ASM mask_type field and renames/repurposes the remaining mask field. |
| csrc/cpp_itfs/mha_bwd.cu | Uses unified mask type for CK and computes ASM-specific mask selector internally. |
Comments suppressed due to low confidence (1)
csrc/include/mha_bwd.h:1
`mask_type` is now used as the CK/`mask_enum` value (and the ASM-specific selector is computed separately in `fmha_v3_bwd`). This is easy to misread given the previous `mask_type` meaning. Consider renaming this field to something explicit like `ck_mask_type`/`mask_enum_value` (and updating callsites), and add a short comment documenting the expected enum values.
#pragma once
Merged
13 tasks
valarLip approved these changes Feb 12, 2026
valarLip pushed a commit that referenced this pull request Mar 18, 2026
* remove asm mask type in api
* refine quit condition
AMD-yanfeiwang pushed a commit to AMD-yanfeiwang/aiter that referenced this pull request Mar 18, 2026
* remove asm mask type in api
* refine quit condition
Motivation
Turn the ASM mask type into an internal parameter instead of exposing it in the API.
Technical Details
Test Plan
Test Result
Submission Checklist