[Refactor][CI] Reduce sparse related test time #1637
Conversation
…pt parameters directly (tile-ai#1630)

* Modify main functions in example_custom_compress.py, example_gemm_sp.py, and example_vertical_slash_sparse_attn.py to accept parameters directly instead of using argparse for improved flexibility.
* Update corresponding calls to main functions in the script execution section.
* Ensure consistency in matrix dimensions and argument handling across examples.
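For context, the refactor follows a standard pattern: keep argparse in the `__main__` block with large benchmarking defaults, and give `main()` small, test-friendly parameter defaults. A minimal sketch (names and sizes are illustrative, not the PR's actual code):

```python
import argparse


def main(m: int = 1024, n: int = 1024, k: int = 1024) -> None:
    # Core logic takes plain parameters, so tests can call main()
    # directly with small, fast defaults instead of going through argv.
    print(f"running with M={m}, N={n}, K={k}")


if __name__ == "__main__":
    # The CLI keeps the larger benchmarking defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=16384)
    parser.add_argument("--n", type=int, default=16384)
    parser.add_argument("--k", type=int, default=16384)
    args = parser.parse_args()
    main(m=args.m, n=args.n, k=args.k)
```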
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Refactor example scripts to convert parameterless, argparse-driven `main()` entry points into parameterized functions with defaults, so tests can call them directly while the CLI retains larger benchmarking values.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

✅ Passed checks (2 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @examples/gemm_sp/example_custom_compress.py:
- Around lines 330-332: The argparse entry for accum_dtype is inconsistent: change parser.add_argument("--accum_dtype", type=str, choices=[T.float, T.float16], default=T.float, ...) to use string choices and a string default (e.g., type=str, choices=["float", "float16"], default="float") in the parser.add_argument call, and then map that string to the actual dtype object before use (e.g., create a mapping like accum_dtype_map = {"float": T.float, "float16": T.float16} and set accum_dtype = accum_dtype_map[args.accum_dtype] where the script uses accum_dtype).
🧹 Nitpick comments (3)
examples/minference/example_vertical_slash_sparse_attn.py (2)
566-566: Remove redundant self-assignment.

Line 566 `vertical_size, slash_size = vertical_size, slash_size` is a no-op left over from the refactoring. It was likely meant to replace the original argparse-based assignment, but it is unnecessary now that these are function parameters.

Proposed fix:

```diff
 def main(batch=1, heads=1, seq_len=4096, head_dim=64, vertical_size=1000, slash_size=200):
     BATCH, N_HEADS, SEQ_LEN, D_HEAD = batch, heads, seq_len, head_dim
-    vertical_size, slash_size = vertical_size, slash_size
     torch.manual_seed(0)
```
629-630: Duplicate sorting operations in `run_regression_perf`.

The `v_idx` and `s_idx` tensors are sorted twice with identical operations (lines 629-630 and lines 650-651). The second sorting is redundant since the tensors are already sorted.

Proposed fix - remove duplicate sorting:

```diff
     v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
     s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
     from torch.utils.cpp_extension import load
     import os
     current_dir = os.path.dirname(os.path.abspath(__file__))
     sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")]
     ops = load(name="convert", sources=sources, verbose=False)
     convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes
     batch_size, num_heads, context_size, head_dim = query.shape
     pad = (block_size_M - context_size) & (block_size_M - 1)
     if pad == block_size_M:
         pad = 0
     query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
     key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
     value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
     if head_dim not in [16, 32, 64, 128, 256, 512]:
         target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
         query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
         key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
         value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])
-    v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
-    s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
     seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device)
```

Also applies to: 650-651
examples/gemm_sp/example_gemm_sp.py (1)
100-101: Consider aligning default `cfg` values.

The function default is `cfg="h20"` (line 100), while the CLI default is `cfg="4090"` (line 131). This inconsistency means programmatic callers (like tests) will use the "h20" config while CLI users get "4090". If this is intentional for hardware-specific testing, consider documenting it; otherwise, align the defaults (see the sketch after this comment).

Also applies to: 131-131
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- examples/gemm_sp/example_custom_compress.py
- examples/gemm_sp/example_gemm_sp.py
- examples/minference/example_vertical_slash_sparse_attn.py
- examples/minference/test_vs_sparse_attn.py
🧰 Additional context used
🧬 Code graph analysis (4)
examples/minference/test_vs_sparse_attn.py (2)
- examples/minference/example_vertical_slash_sparse_attn.py (1)
  - main (563-602)
- tilelang/testing/__init__.py (1)
  - main (27-29)

examples/gemm_sp/example_custom_compress.py (2)
- examples/gemm_sp/example_gemm_sp.py (1)
  - main (100-122)
- tilelang/utils/sparse.py (1)
  - randn_semi_sparse (105-124)

examples/gemm_sp/example_gemm_sp.py (1)
- tilelang/utils/sparse.py (2)
  - randn_semi_sparse (105-124)
  - compress (77-102)

examples/minference/example_vertical_slash_sparse_attn.py (1)
- examples/elementwise/example_elementwise_add.py (1)
  - run_regression_perf (56-68)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (4)
examples/minference/test_vs_sparse_attn.py (1)

7-8: LGTM!

The test correctly adapts to the new parameterized API by calling `main()` without arguments, which uses the default parameters. This is consistent with the refactored signature in `example_vertical_slash_sparse_attn.py`, where `main(batch=1, heads=1, seq_len=4096, ...)` uses a smaller `seq_len=4096` compared to the CLI default of 16384, effectively reducing test time as intended by the PR (a sketch of this test pattern follows the minference comments below).

examples/gemm_sp/example_custom_compress.py (1)
294-322: LGTM on the parameterized `main()` function.

The refactoring correctly parameterizes the function with sensible defaults (M=1024, N=1024, K=1024) that are smaller than the CLI defaults (16384), which aligns with the PR goal of reducing test time for programmatic usage.
examples/minference/example_vertical_slash_sparse_attn.py (1)
563-565: LGTM on the parameterized API refactoring.

The conversion from argv-based parsing to parameterized functions with defaults is well-executed. The `main()` function now has smaller defaults (`seq_len=4096`) suitable for quick tests, while the CLI block preserves larger values (`seq_len=16384`) for manual benchmarking. This aligns with the PR goal of reducing CI test time.

Also applies to: 669-685
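For reference, a minimal sketch of the test pattern described in these comments (an assumed shape, not the exact contents of test_vs_sparse_attn.py; the import path is illustrative):

```python
# Assumed shape of the test module; import paths are illustrative.
import tilelang.testing

from example_vertical_slash_sparse_attn import main


def test_vertical_slash_sparse_attn():
    # Calling main() with no arguments picks up the small defaults
    # (e.g., seq_len=4096), which is what keeps CI time down.
    main()


if __name__ == "__main__":
    tilelang.testing.main()
```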
examples/gemm_sp/example_gemm_sp.py (1)
100-123: LGTM on the parameterized `main()` function.

The refactoring correctly converts to a parameterized API with appropriate defaults for programmatic testing (M=1024, N=1024, K=1024) while preserving larger values (16384) for CLI benchmarking.
| parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") | ||
| parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") | ||
| parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") |
Argparse type mismatch for accum_dtype.
The `--accum_dtype` argument is declared with `type=str`, but `choices=[T.float, T.float16]` contains type objects, not strings. Additionally, the default is `T.float` (a type object). This configuration is inconsistent: if a user passes `--accum_dtype float`, argparse will compare the string "float" against `T.float`, which will fail validation.
Proposed fix:

```diff
-    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
```

Then map the string to the type in the main call:

```diff
+    dtype_map = {"float": T.float, "float16": T.float16}
     main(
         M=args.m,
         N=args.n,
         K=args.k,
         use_cutlass_layout=args.use_cutlass_layout,
         use_torch_compressor=args.use_torch_compressor,
-        accum_dtype=args.accum_dtype,
+        accum_dtype=dtype_map[args.accum_dtype],
         cfg=args.cfg,
     )
```
)Committable suggestion skipped: line range outside the PR's diff.
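As a self-contained check of the mapping pattern, runnable without tilelang (`T_FLOAT` / `T_FLOAT16` are hypothetical stand-ins for `T.float` / `T.float16`):

```python
import argparse

# Hypothetical stand-ins for the real dtype objects.
T_FLOAT, T_FLOAT16 = object(), object()

parser = argparse.ArgumentParser()
# Choices and default are plain strings, consistent with type=str.
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"])
args = parser.parse_args(["--accum_dtype", "float16"])

# Map the validated string onto the real dtype object before use.
dtype_map = {"float": T_FLOAT, "float16": T_FLOAT16}
accum_dtype = dtype_map[args.accum_dtype]
assert accum_dtype is T_FLOAT16
```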
🤖 Prompt for AI Agents
In @examples/gemm_sp/example_custom_compress.py around lines 330-332, the argparse entry for accum_dtype is inconsistent: change parser.add_argument("--accum_dtype", type=str, choices=[T.float, T.float16], default=T.float, ...) to use string choices and a string default (e.g., type=str, choices=["float", "float16"], default="float") in the parser.add_argument call, and then map that string to the actual dtype object before use (e.g., create a mapping like accum_dtype_map = {"float": T.float, "float16": T.float16} and set accum_dtype = accum_dtype_map[args.accum_dtype] where the script uses accum_dtype).
| parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") | ||
| parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") |
Argparse type mismatch for accum_dtype.
Same issue as in example_custom_compress.py: the `--accum_dtype` argument uses `type=str` with `choices=[T.float, T.float16]` (type objects), which will cause validation failures when CLI arguments are passed.
Proposed fix:

```diff
-    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
```

Then map in the call:

```diff
+    dtype_map = {"float": T.float, "float16": T.float16}
-    main(M=args.m, N=args.n, K=args.k, accum_dtype=args.accum_dtype, cfg=args.cfg)
+    main(M=args.m, N=args.n, K=args.k, accum_dtype=dtype_map[args.accum_dtype], cfg=args.cfg)
```
+ main(M=args.m, N=args.n, K=args.k, accum_dtype=dtype_map[args.accum_dtype], cfg=args.cfg)Committable suggestion skipped: line range outside the PR's diff.
as title.
Summary by CodeRabbit
Refactor
Tests
✏️ Tip: You can customize this high-level summary in your review settings.