-
Notifications
You must be signed in to change notification settings - Fork 29
Update AITER subcommit and refactor internal AITER/CK FA API usage #446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
f02c48e
0b0ad93
c198cbd
a52bb32
77f0a05
1637266
568e9b5
2cb6d82
cf4aa9e
e25cea8
2122479
4817e72
837b827
68ca0fe
ae688ab
762b91b
0a7187d
39b27bc
29878cf
6846a27
47592ac
357b5ce
1f080c1
a657bdd
c225448
4193158
dbb6106
899162e
1081c5e
9514855
78f1d69
f935956
b90da33
0475f85
5db08ea
d5e5ec6
151e9ca
5a18d16
db34177
9cd2833
a6b831e
d1bd569
686e6a2
52e1a1a
48c5839
9b2166e
3dacb2a
638c9e6
c786c3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
|
|
||
| from importlib import metadata | ||
| import os | ||
| import sys | ||
|
ipanfilo marked this conversation as resolved.
Outdated
|
||
| import time | ||
| from pathlib import Path | ||
| from typing import List, Tuple | ||
|
|
@@ -88,6 +89,17 @@ def setup_common_extension() -> CMakeExtension: | |
| cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF") | ||
| elif os.getenv("NVTE_FUSED_ATTN_CK") or os.getenv("NVTE_FUSED_ATTN"): | ||
| cmake_flags.append("-DUSE_FUSED_ATTN_CK=ON") | ||
| # Dynamically set AITER_LOG_MORE based on PIP_VERBOSE to avoid excessive logging during build. | ||
| os.environ["AITER_LOG_MORE"] = str(max(int(os.environ.get('PIP_VERBOSE', '0')) - 1, 0)) | ||
| # Explicitly checks the AITER API usage | ||
| try: | ||
| subprocess.run( | ||
|
ipanfilo marked this conversation as resolved.
Outdated
|
||
| sys.executable + " tools/check_aiter_mha_args_usage.py --mode both", | ||
| shell=True, check=True | ||
| ) | ||
| except subprocess.CalledProcessError: | ||
| print("Error checking AITER mha_args usage.") | ||
| sys.exit(1) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Explicitly checks the AITER API usage
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually you can put your comment in PR to the comments in the source codes |
||
|
|
||
| if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None: | ||
| os.environ["NVTE_ENABLE_ROCSHMEM"] = '1' | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This helper script scans semi-hard-coded files wrt TE source-code in order to directly compare AITER's internal API and our attempt at utilizing it. This script is run during setup through
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can put this in the comment of this file. Also don't forget to add the copyright |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| import argparse | ||
|
wangye805 marked this conversation as resolved.
|
||
| import re | ||
| from pathlib import Path | ||
| from typing import List, Set | ||
| import sys | ||
|
|
||
| def parse_with_skip_comments(buffer, line, regex, outputs): | ||
| # skip comments | ||
| stripped = line.strip() | ||
| if not stripped or stripped.startswith("//"): | ||
| return | ||
| line_no_comment = re.sub(r"//.*", "", line) | ||
| buffer[0] += " " + line_no_comment.strip() | ||
| if ";" not in line_no_comment: | ||
| return | ||
| match = regex.search(buffer[0]) | ||
| if match: | ||
| outputs.append(match.group(1)) | ||
| buffer[0] = "" | ||
|
|
||
|
|
||
| def extract_fields_from_header(text: str, struct_name: str) -> List[str]: | ||
| struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") | ||
| struct_end_re = re.compile(r"^\s*};\s*$") | ||
|
|
||
| struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") | ||
| lines = text.splitlines() | ||
| in_struct = False | ||
| fields: List[str] = [] | ||
| buffer = [""] | ||
| for line in lines: | ||
| if not in_struct: | ||
| if struct_start_re.search(line): | ||
| in_struct = True | ||
| continue | ||
| if struct_end_re.search(line): | ||
| break | ||
| parse_with_skip_comments(buffer, line, struct_field_re, fields) | ||
| return fields | ||
|
|
||
|
|
||
| def extract_usage_from_source(text: str, var_name: str) -> Set[str]: | ||
| assign_re = re.compile(rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=") | ||
| assignments = [] | ||
| lines = text.splitlines() | ||
| buffer = [""] | ||
| for line in lines: | ||
| parse_with_skip_comments(buffer, line, assign_re, assignments) | ||
| return set(assignments) | ||
|
|
||
|
|
||
| def main() -> int: | ||
| parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") | ||
| parser.add_argument("--mode", choices=["fwd", "bwd", "both"], required=True, help="Mode: fwd, bwd, or both") | ||
|
ipanfilo marked this conversation as resolved.
Outdated
|
||
| args = parser.parse_args() | ||
| modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode] | ||
| mismatch = 0 | ||
| for mode in modes: | ||
| header_path = Path(f"3rdparty/aiter/csrc/include/mha_{mode}.h") | ||
|
ipanfilo marked this conversation as resolved.
Outdated
|
||
| source_path = Path(f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp") | ||
| header_text = header_path.read_text(encoding="utf-8") | ||
| source_text = source_path.read_text(encoding="utf-8") | ||
|
|
||
| header_fields = extract_fields_from_header(header_text, f"mha_{mode}_args") | ||
| header_set = set(header_fields) | ||
| used_fields = extract_usage_from_source(source_text, f"fmha_args") | ||
|
|
||
| missing_in_usage = sorted(header_set - used_fields) | ||
| unknown_in_header = sorted(used_fields - header_set) | ||
| mismatch += len(missing_in_usage) + len(unknown_in_header) | ||
|
|
||
| print(f"\nAnalyzing mha_{mode}_args\n") | ||
| print(f"mha_{mode}_args fields in header:", len(header_set)) | ||
| print(f"mha_{mode}_args fields referenced in source:", len(used_fields)) | ||
|
|
||
| if missing_in_usage: | ||
| print("\nFields present in header but not referenced in source:") | ||
| for name in missing_in_usage: | ||
| print(f" - {name}") | ||
| else: | ||
| print("\nAll header fields are referenced in source.") | ||
|
|
||
| if unknown_in_header: | ||
| print("\nFields referenced in source but not in header:") | ||
| for name in unknown_in_header: | ||
| print(f" - {name}") | ||
| else: | ||
| print("\nNo unknown fields referenced in source.") | ||
|
|
||
| if mismatch: | ||
| print(f"\nTotal mismatched fields: {mismatch}") | ||
| return 1 | ||
| return 0 | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| sys.exit(main()) | ||
Uh oh!
There was an error while loading. Please reload this page.