Merged
Changes from 19 commits
Commits
49 commits
f02c48e
Initial inclusion of new API in fwd as well as part 1 of refactor
Micky774 Feb 2, 2026
0b0ad93
Initial implementation of refactor/API update across ALL CK funcs
Micky774 Feb 3, 2026
c198cbd
Updated logging
Micky774 Feb 6, 2026
a52bb32
Add script for comparing AITER/TE API
Micky774 Feb 6, 2026
77f0a05
Reconcile new AITER mask type
Micky774 Feb 9, 2026
1637266
Updated API helper tool
Micky774 Feb 9, 2026
568e9b5
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 9, 2026
2cb6d82
Formatting
Micky774 Feb 9, 2026
cf4aa9e
Added sys exit to script
Micky774 Feb 9, 2026
e25cea8
Slightly better error message
Micky774 Feb 9, 2026
2122479
Updated AITER_ASM_DIR implementation
Micky774 Feb 11, 2026
4817e72
Update AITER
Micky774 Feb 11, 2026
837b827
Updated AITER_ASM_DIR logic to allow for hip-free use
Micky774 Feb 12, 2026
68ca0fe
Re-introduce setup AITER API check
Micky774 Feb 16, 2026
ae688ab
Update AITER to custom feature branch
Micky774 Feb 16, 2026
762b91b
Reduce AITER build verbosity
Micky774 Feb 16, 2026
0a7187d
Updated API
Micky774 Feb 16, 2026
39b27bc
Address PR comments
Micky774 Feb 17, 2026
29878cf
Updated bias stride calculations
Micky774 Feb 17, 2026
6846a27
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 18, 2026
47592ac
Reverted AITER feature branch use due to verbosity changes
Micky774 Feb 18, 2026
357b5ce
PR review comments
Micky774 Feb 18, 2026
1f080c1
Reintroduced warning suppression in AITER
Micky774 Feb 18, 2026
a657bdd
Removes auto-setting of AITER_LOG_MORE, corrects batch stride impl
Micky774 Feb 18, 2026
c225448
Removes AITER_LOG_MORE from CI runs
Micky774 Feb 18, 2026
4193158
Minor corrections
Micky774 Feb 18, 2026
dbb6106
PR feedback
Micky774 Feb 19, 2026
899162e
Formatting
Micky774 Feb 19, 2026
1081c5e
Copyright
Micky774 Feb 19, 2026
9514855
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 19, 2026
78f1d69
Updated ck_fused_attn lib build to include copying HSA
Micky774 Feb 20, 2026
f935956
Corrected AITER bug and moved to TE feature branch
Micky774 Mar 3, 2026
b90da33
Merge branch 'dev' into zain/aiter-api
Micky774 Mar 4, 2026
0475f85
Added back dropped code from merge conflict
Micky774 Mar 4, 2026
5db08ea
Downgrade to more conservative AITER commit for compat
Micky774 Mar 4, 2026
d5e5ec6
Removed python-level args check
Micky774 Mar 4, 2026
151e9ca
Removed old tools
Micky774 Mar 4, 2026
5a18d16
Corrected arg_size types manually
Micky774 Mar 5, 2026
db34177
Updated AITER commit and fixed API mismatch in group gemm
Micky774 Mar 5, 2026
9cd2833
Added build-time AITER API usage check
Micky774 Mar 5, 2026
a6b831e
PR review comments
Micky774 Mar 5, 2026
d1bd569
Undo extra import removal
Micky774 Mar 5, 2026
686e6a2
Adjusted python requirement in cmakelist
Micky774 Mar 5, 2026
52e1a1a
Updated group-gemm dispatch
Micky774 Mar 6, 2026
48c5839
Made AITER API check earlier
Micky774 Mar 6, 2026
9b2166e
Update AITER w/ Xinya's patch
Micky774 Mar 16, 2026
3dacb2a
Merge branch 'dev' into zain/aiter-api
Micky774 Mar 16, 2026
638c9e6
Revert and cherry-pick aiter subcommit
Micky774 Mar 18, 2026
c786c3b
Patched unordered map mgpu kernel collision
Micky774 Mar 20, 2026
1 change: 1 addition & 0 deletions .gitignore
@@ -55,3 +55,4 @@ artifacts/
**/times.csv
transformer_engine/build_info.txt
transformer_engine/common/util/hip_nvml.*
transformer_engine/lib/aiter
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 1180 files
12 changes: 12 additions & 0 deletions setup.py
@@ -8,6 +8,7 @@

from importlib import metadata
import os
import sys
import time
from pathlib import Path
from typing import List, Tuple
@@ -88,6 +89,17 @@ def setup_common_extension() -> CMakeExtension:
            cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF")
        elif os.getenv("NVTE_FUSED_ATTN_CK") or os.getenv("NVTE_FUSED_ATTN"):
            cmake_flags.append("-DUSE_FUSED_ATTN_CK=ON")
        # Dynamically set AITER_LOG_MORE based on PIP_VERBOSE to avoid excessive logging during build.
        os.environ["AITER_LOG_MORE"] = str(max(int(os.environ.get('PIP_VERBOSE', '0')) - 1, 0))
        # Explicitly checks the AITER API usage
        try:
            subprocess.run(
                sys.executable + " tools/check_aiter_mha_args_usage.py --mode both",
                shell=True, check=True
            )
        except subprocess.CalledProcessError:
            print("Error checking AITER mha_args usage.")
            sys.exit(1)

Contributor Author

Explicitly checks the AITER API usage

Collaborator

Actually, you can move the comments you left on the PR into comments in the source code.

        if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None:
            os.environ["NVTE_ENABLE_ROCSHMEM"] = '1'
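The setup.py hunk above does two things: it derives AITER_LOG_MORE from PIP_VERBOSE, and it runs the check script in a subprocess, aborting the build on a non-zero exit. A minimal sketch of both steps follows. The helper names (`aiter_log_level`, `build_check_command`, `run_check`) are hypothetical and not part of the PR; the one deliberate difference is that the command is passed as an argument list rather than a concatenated string with `shell=True`, which should keep an interpreter path containing spaces intact.

```python
import os
import subprocess
import sys
from typing import List, Mapping


def aiter_log_level(env: Mapping[str, str] = os.environ) -> int:
    """Same mapping as the hunk: PIP_VERBOSE of 0 or 1 gives 0, 2 gives 1, and so on."""
    # Like the original expression, this raises ValueError for a non-numeric PIP_VERBOSE.
    return max(int(env.get("PIP_VERBOSE", "0")) - 1, 0)


def build_check_command(script: str, mode: str = "both") -> List[str]:
    """Build an argument list so that no shell parsing is involved (hypothetical helper)."""
    return [sys.executable, script, "--mode", mode]


def run_check(command: List[str]) -> bool:
    """Return True when the command exits with status 0, False otherwise."""
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError:
        return False
    return True


if __name__ == "__main__":
    print("AITER_LOG_MORE would be", aiter_log_level())
    # Stand-in command that always succeeds; the PR runs the real check script here.
    ok = run_check([sys.executable, "-c", "import sys; sys.exit(0)"])
    print("check passed" if ok else "check failed")
```

In setup.py the caller would presumably pass `build_check_command("tools/check_aiter_mha_args_usage.py")` to `run_check` and call `sys.exit(1)` when it returns False, matching the behavior in the hunk.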
97 changes: 97 additions & 0 deletions tools/check_aiter_mha_args_usage.py
Contributor Author

This helper script scans TE source files at semi-hard-coded paths and directly compares the fields of AITER's internal argument structs with the fields we actually assign. It is run during setup via setup.py.

Collaborator

You can put this in a comment in this file.

Also, don't forget to add the copyright.

@@ -0,0 +1,97 @@
import argparse
import re
from pathlib import Path
from typing import List, Set
import sys

def parse_with_skip_comments(buffer, line, regex, outputs):
    # skip comments
    stripped = line.strip()
    if not stripped or stripped.startswith("//"):
        return
    line_no_comment = re.sub(r"//.*", "", line)
    buffer[0] += " " + line_no_comment.strip()
    if ";" not in line_no_comment:
        return
    match = regex.search(buffer[0])
    if match:
        outputs.append(match.group(1))
    buffer[0] = ""


def extract_fields_from_header(text: str, struct_name: str) -> List[str]:
    struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")
    struct_end_re = re.compile(r"^\s*};\s*$")

    struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b")
    lines = text.splitlines()
    in_struct = False
    fields: List[str] = []
    buffer = [""]
    for line in lines:
        if not in_struct:
            if struct_start_re.search(line):
                in_struct = True
            continue
        if struct_end_re.search(line):
            break
        parse_with_skip_comments(buffer, line, struct_field_re, fields)
    return fields


def extract_usage_from_source(text: str, var_name: str) -> Set[str]:
    assign_re = re.compile(rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=")
    assignments = []
    lines = text.splitlines()
    buffer = [""]
    for line in lines:
        parse_with_skip_comments(buffer, line, assign_re, assignments)
    return set(assignments)


def main() -> int:
    parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition")
    parser.add_argument("--mode", choices=["fwd", "bwd", "both"], required=True, help="Mode: fwd, bwd, or both")
    args = parser.parse_args()
    modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode]
    mismatch = 0
    for mode in modes:
        header_path = Path(f"3rdparty/aiter/csrc/include/mha_{mode}.h")
        source_path = Path(f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp")
        header_text = header_path.read_text(encoding="utf-8")
        source_text = source_path.read_text(encoding="utf-8")

        header_fields = extract_fields_from_header(header_text, f"mha_{mode}_args")
        header_set = set(header_fields)
        used_fields = extract_usage_from_source(source_text, f"fmha_args")

        missing_in_usage = sorted(header_set - used_fields)
        unknown_in_header = sorted(used_fields - header_set)
        mismatch += len(missing_in_usage) + len(unknown_in_header)

        print(f"\nAnalyzing mha_{mode}_args\n")
        print(f"mha_{mode}_args fields in header:", len(header_set))
        print(f"mha_{mode}_args fields referenced in source:", len(used_fields))

        if missing_in_usage:
            print("\nFields present in header but not referenced in source:")
            for name in missing_in_usage:
                print(f" - {name}")
        else:
            print("\nAll header fields are referenced in source.")

        if unknown_in_header:
            print("\nFields referenced in source but not in header:")
            for name in unknown_in_header:
                print(f" - {name}")
        else:
            print("\nNo unknown fields referenced in source.")

    if mismatch:
        print(f"\nTotal mismatched fields: {mismatch}")
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
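To illustrate what the script above reports, here is a condensed sketch of the same idea run on a toy header and a toy source snippet. It reuses the two regular expressions from the script, but the joining and struct-slicing helpers are simplified re-implementations, and the struct fields, the `p_drop` assignment, and the helper names are made up for the example; the real AITER structs differ.

```python
import re
from typing import List, Set

# Hypothetical miniature header; not the real AITER definition.
HEADER = """
struct mha_fwd_args
{
    // pointers
    const void* q_ptr;
    const void* k_ptr;
    float scale_s = 1.0f;
    int window_size_left;
};
"""

# Hypothetical usage in a source file; p_drop is deliberately absent from the header.
SOURCE = """
fmha_args.q_ptr = q;
fmha_args.k_ptr = k;  // key tensor
fmha_args.scale_s =
    scaling_factor;
fmha_args.p_drop = dropout;
"""

# Same patterns as the script: last identifier before ';', and 'fmha_args.<field> ='.
FIELD_RE = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")
ASSIGN_RE = re.compile(r"\bfmha_args\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=")


def statements(text: str) -> List[str]:
    """Drop // comments and join lines up to each ';' so multi-line statements match."""
    out, buf = [], ""
    for line in text.splitlines():
        code = re.sub(r"//.*", "", line).strip()
        if not code:
            continue
        buf += " " + code
        if ";" in code:
            out.append(buf)
            buf = ""
    return out


def struct_body(text: str, struct_name: str) -> str:
    """Return the text between 'struct <name>' and the first closing '};'."""
    return text.split(f"struct {struct_name}", 1)[1].split("};", 1)[0]


def matches(text: str, regex) -> Set[str]:
    """Collect the first capture group of regex over every joined statement."""
    found = set()
    for stmt in statements(text):
        m = regex.search(stmt)
        if m:
            found.add(m.group(1))
    return found


header = matches(struct_body(HEADER, "mha_fwd_args"), FIELD_RE)
used = matches(SOURCE, ASSIGN_RE)
missing_in_usage = sorted(header - used)   # defined in the header, never assigned
unknown_in_header = sorted(used - header)  # assigned in source, absent from the header
print(missing_in_usage)   # ['window_size_left']
print(unknown_in_header)  # ['p_drop']
```

Either list being non-empty corresponds to the non-zero exit status that makes setup.py abort, which is how an AITER submodule bump that adds or renames a field would surface at build time.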
1 change: 0 additions & 1 deletion transformer_engine/common/ck_fused_attn/aiter_build.sh
@@ -37,7 +37,6 @@ if [[ -z "${AITER_DIR}" || -z "${AITER_TEST_DIR}" || -z "${GPU_ARCHS_VAL}" ]]; t
fi

rm -rf "${AITER_DIR}/aiter/jit/build"
-AITER_LOG_MORE=1 \
CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT="${CK_TILE_BF16_DEFAULT}" \
GPU_ARCHS="${GPU_ARCHS_VAL}" \
python3 "${AITER_TEST_DIR}/compile.py"