remove asm mask type #2026

Merged
valarLip merged 2 commits into main from jim/dev/fa_bwd_merge_mask on Feb 12, 2026

@slippedJim
Contributor

Motivation

Turn the ASM mask type into an internal parameter instead of exposing it in the API.

Technical Details

Test Plan

Test Result

Submission Checklist

@slippedJim slippedJim requested review from a team and Copilot February 11, 2026 07:24
Contributor

Copilot AI left a comment

Pull request overview

Removes the explicit “ASM mask type” argument from MHA backward callsites and computes the ASM-specific mask selector internally, simplifying the public mha_bwd_args surface.

Changes:

  • Removes get_mask_type() plumbing from C++ benchmark and Python/CUDA interfaces, and adjusts mha_bwd_args initialization accordingly.
  • Reworks mha_bwd_args to keep a single mask type field and derives the ASM kernel mask selector inside fmha_v3_bwd.
  • Updates the backward smoke test script’s executed test set and reduces one SWA test invocation.
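
The second change above can be pictured with a minimal sketch. This is not the code from the PR: the struct, both enums, their numeric values, and the helper `derive_asm_mask` are hypothetical names chosen for illustration, and the actual mapping used inside fmha_v3_bwd may differ. It only shows the general shape of keeping one mask field in the args and deriving the ASM selector internally.

```cpp
#include <cstdint>

// Hypothetical stand-in for a CK-style mask enum. Names and values are assumptions.
enum class MaskEnum : int {
    NoMask        = 0,
    TopLeft       = 1,
    BottomRight   = 2,
    WindowGeneric = 3,
};

// Hypothetical, trimmed-down args struct: a single mask field and no
// separate ASM-specific field for callers to fill in.
struct BwdArgsSketch {
    int mask_type;          // holds a MaskEnum value
    int window_size_left;   // -1 means unbounded (assumption)
    int window_size_right;  // -1 means unbounded (assumption)
};

// Hypothetical ASM kernel selector values.
enum class AsmMask : int {
    None              = 0,
    CausalTopLeft     = 1,
    CausalBottomRight = 2,
    SlidingWindow     = 3,
};

// Derive the ASM selector from the unified mask field, so callers never
// compute or pass it themselves. Unknown values fall back to "no mask".
inline AsmMask derive_asm_mask(const BwdArgsSketch& args) {
    switch (static_cast<MaskEnum>(args.mask_type)) {
    case MaskEnum::NoMask:        return AsmMask::None;
    case MaskEnum::TopLeft:       return AsmMask::CausalTopLeft;
    case MaskEnum::BottomRight:   return AsmMask::CausalBottomRight;
    case MaskEnum::WindowGeneric: return AsmMask::SlidingWindow;
    }
    return AsmMask::None;
}
```

With this shape, a callsite only sets `mask_type` (for example `BwdArgsSketch{2, -1, 0}`), and the dispatch function calls `derive_asm_mask(args)` itself before picking an ASM kernel.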

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| op_tests/cpp/mha/smoke_test_bwd_v3.sh | Changes which test suites run by default and alters SWA test invocation coverage. |
| op_tests/cpp/mha/benchmark_mha_bwd.cpp | Removes mask-type derivation at callsite and updates mha_bwd_args construction. |
| csrc/py_itfs_cu/asm_mha_varlen_bwd.cu | Removes callsite mask-type derivation and updates args packing for varlen ASM path. |
| csrc/py_itfs_cu/asm_mha_bwd.cu | Removes callsite mask-type derivation and updates args packing for ASM path. |
| csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu | Removes callsite mask-type derivation and updates args packing for CK varlen path. |
| csrc/py_itfs_ck/mha_bwd_kernels.cu | Removes callsite mask-type derivation and updates args packing for CK path. |
| csrc/include/mha_bwd.h | Removes the old ASM mask_type field and renames/repurposes the remaining mask field. |
| csrc/cpp_itfs/mha_bwd.cu | Uses unified mask type for CK and computes ASM-specific mask selector internally. |

Comments suppressed due to low confidence (1)

csrc/include/mha_bwd.h:1

  • mask_type is now used as the CK/mask_enum value (and the ASM-specific selector is computed separately in fmha_v3_bwd). This is easy to misread given the previous mask_type meaning. Consider renaming this field to something explicit like ck_mask_type/mask_enum_value (and updating callsites), and add a short comment documenting the expected enum values.
#pragma once

Comment thread csrc/cpp_itfs/mha_bwd.cu
Comment thread csrc/cpp_itfs/mha_bwd.cu Outdated
Comment thread op_tests/cpp/mha/smoke_test_bwd_v3.sh
Comment thread op_tests/cpp/mha/smoke_test_bwd_v3.sh
@valarLip valarLip merged commit 035f5f3 into main Feb 12, 2026
23 of 26 checks passed
@valarLip valarLip deleted the jim/dev/fa_bwd_merge_mask branch February 12, 2026 03:41
valarLip pushed a commit that referenced this pull request Mar 18, 2026
* remove asm mask type in api

* refine quit condition
AMD-yanfeiwang pushed a commit to AMD-yanfeiwang/aiter that referenced this pull request Mar 18, 2026
* remove asm mask type in api

* refine quit condition