
Conversation

@LeiWang1999 (Member) commented Jan 8, 2026

as title.

Summary by CodeRabbit

  • Refactor

    • Example scripts refactored to support programmatic invocation with explicit parameters while maintaining CLI compatibility.
    • Added precision validation checks to verify computation accuracy.
  • Tests

    • Updated test invocation to use refactored example API.


…pt parameters directly (tile-ai#1630)

* Modify main functions in example_custom_compress.py, example_gemm_sp.py, and example_vertical_slash_sparse_attn.py to accept parameters directly instead of using argparse for improved flexibility.
* Update corresponding calls to main functions in the script execution section.
* Ensure consistency in matrix dimensions and argument handling across examples.
@github-actions bot commented Jan 8, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai bot commented Jan 8, 2026

📝 Walkthrough

Example scripts are refactored to convert parameterless `main()` functions into parameterized signatures with default values, enabling programmatic reuse. CLI argument parsing moves to the `__main__` blocks while maintaining backward compatibility.
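As a rough illustration of the pattern, here is a minimal sketch with made-up parameter names and defaults, not the exact code from the examples:

```python
# Sketch of the refactoring pattern: small defaults on main() for fast
# programmatic/test calls, larger CLI defaults kept in the __main__ block.
import argparse


def main(M: int = 1024, N: int = 1024, K: int = 1024) -> None:
    # Callable directly from tests without touching argparse.
    print(f"running with M={M}, N={N}, K={K}")


if __name__ == "__main__":
    # CLI parsing lives only here, so importing the module never parses argv;
    # larger defaults suit manual benchmarking.
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=16384)
    parser.add_argument("--n", type=int, default=16384)
    parser.add_argument("--k", type=int, default=16384)
    args = parser.parse_args()
    main(M=args.m, N=args.n, K=args.k)
```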

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| GEMM-SP Examples: `examples/gemm_sp/example_custom_compress.py`, `examples/gemm_sp/example_gemm_sp.py` | Converted `main()` from CLI-driven to parameterized signatures: `main(M, N, K, ...)` with defaults. Argument parsing relocated to `__main__` blocks. Added precision validation and a NaN assertion in `example_gemm_sp.py` after kernel computation (see the sketch after this table). |
| Minference Examples: `examples/minference/example_vertical_slash_sparse_attn.py` | Refactored `main(argv=None)` and `run_regression_perf(argv=None)` to explicit parameter-based signatures with defaults. Removed internal argparse; CLI parsing now occurs in the `__main__` block. |
| Test Update: `examples/minference/test_vs_sparse_attn.py` | Updated the call to `example_vertical_slash_sparse_attn.main()` to omit the empty `argv` argument, relying on the function defaults. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes


Poem

🐰 A rabbit hops through refactored code,
Where defaults now pave the road,
No more argv in main's dark hall,
Just parameters—clean and small!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title describes refactoring to reduce test time, which aligns with the parameterized `main()` functions enabling programmatic invocation instead of CLI-only usage. However, the specific mechanism (parameter-driven refactoring) is not explicit. |



@coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (3)
examples/minference/example_vertical_slash_sparse_attn.py (2)

566-566: Remove redundant self-assignment.

Line 566 `vertical_size, slash_size = vertical_size, slash_size` is a no-op left over from the refactoring. It was likely meant to mirror the original argparse-based assignment but is now unnecessary, since these are function parameters.

Proposed fix

```diff
 def main(batch=1, heads=1, seq_len=4096, head_dim=64, vertical_size=1000, slash_size=200):
     BATCH, N_HEADS, SEQ_LEN, D_HEAD = batch, heads, seq_len, head_dim
 
-    vertical_size, slash_size = vertical_size, slash_size
-
     torch.manual_seed(0)
```

629-630: Duplicate sorting operations in `run_regression_perf`.

The `v_idx` and `s_idx` tensors are sorted twice with identical operations (lines 629-630 and lines 650-651). The second sort is redundant since the tensors are already sorted.

Proposed fix - remove the duplicate sorting

```diff
     v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
     s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
     from torch.utils.cpp_extension import load
     import os
 
     current_dir = os.path.dirname(os.path.abspath(__file__))
     sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")]
     ops = load(name="convert", sources=sources, verbose=False)
     convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes
     batch_size, num_heads, context_size, head_dim = query.shape
     pad = (block_size_M - context_size) & (block_size_M - 1)
     if pad == block_size_M:
         pad = 0
     query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
     key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
     value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
     if head_dim not in [16, 32, 64, 128, 256, 512]:
         target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
         query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
         key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
         value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])
-    v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
-    s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
     seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device)
```

Also applies to: 650-651

examples/gemm_sp/example_gemm_sp.py (1)

100-101: Consider aligning the default `cfg` values.

The function default is `cfg="h20"` (line 100), while the CLI default is `cfg="4090"` (line 131). This inconsistency means programmatic callers (such as tests) will use the "h20" config while CLI users get "4090". If this is intentional for hardware-specific testing, consider documenting it; otherwise, align the defaults.

Also applies to: 131-131

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d5503cd and 45b3a3a.

📒 Files selected for processing (4)
  • examples/gemm_sp/example_custom_compress.py
  • examples/gemm_sp/example_gemm_sp.py
  • examples/minference/example_vertical_slash_sparse_attn.py
  • examples/minference/test_vs_sparse_attn.py
🧰 Additional context used
🧬 Code graph analysis (4)
  • examples/minference/test_vs_sparse_attn.py (2)
      • examples/minference/example_vertical_slash_sparse_attn.py: main (563-602)
      • tilelang/testing/__init__.py: main (27-29)
  • examples/gemm_sp/example_custom_compress.py (2)
      • examples/gemm_sp/example_gemm_sp.py: main (100-122)
      • tilelang/utils/sparse.py: randn_semi_sparse (105-124)
  • examples/gemm_sp/example_gemm_sp.py (1)
      • tilelang/utils/sparse.py: randn_semi_sparse (105-124), compress (77-102)
  • examples/minference/example_vertical_slash_sparse_attn.py (1)
      • examples/elementwise/example_elementwise_add.py: run_regression_perf (56-68)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (4)
examples/minference/test_vs_sparse_attn.py (1)

7-8: LGTM!

The test correctly adapts to the new parameterized API by calling `main()` without arguments, which picks up the default parameters. This is consistent with the refactored signature in `example_vertical_slash_sparse_attn.py`, where `main(batch=1, heads=1, seq_len=4096, ...)` uses a smaller `seq_len=4096` than the CLI default of 16384, effectively reducing test time as the PR intends.
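For context, a hypothetical reconstruction of the updated call site (the exact test contents are not shown in this review; the import path and test name are assumptions):

```python
# Hypothetical sketch of the updated test after the refactoring.
import example_vertical_slash_sparse_attn


def test_vs_sparse_attn():
    # No argv argument: the call now relies on the function defaults
    # (batch=1, heads=1, seq_len=4096, ...) instead of CLI parsing.
    example_vertical_slash_sparse_attn.main()
```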

examples/gemm_sp/example_custom_compress.py (1)

294-322: LGTM on the parameterized `main()` function.

The refactoring correctly parameterizes the function with sensible defaults (M=1024, N=1024, K=1024) that are smaller than the CLI defaults (16384), which aligns with the PR goal of reducing test time for programmatic usage.

examples/minference/example_vertical_slash_sparse_attn.py (1)

563-565: LGTM on the parameterized API refactoring.

The conversion from argv-based parsing to parameterized functions with defaults is well executed. The `main()` function now has smaller defaults (`seq_len=4096`) suitable for quick tests, while the CLI block preserves larger values (`seq_len=16384`) for manual benchmarking. This aligns with the PR goal of reducing CI test time.

Also applies to: 669-685

examples/gemm_sp/example_gemm_sp.py (1)

100-123: LGTM on the parameterized `main()` function.

The refactoring correctly converts to a parameterized API with appropriate defaults for programmatic testing (M=1024, N=1024, K=1024) while preserving the larger values (16384) for CLI benchmarking.

Comment on lines +330 to +332

```python
parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
```

⚠️ Potential issue | 🟡 Minor

Argparse type mismatch for `accum_dtype`.

The `--accum_dtype` argument is declared with `type=str`, but `choices=[T.float, T.float16]` contains dtype objects, not strings, and the default `T.float` is also an object. If a user passes `--accum_dtype float`, argparse converts the token to the string "float" and then rejects it, because "float" is not among the object choices.
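This failure mode is plain argparse behavior and reproduces standalone; in the sketch below, `T` is a fake stand-in class, not the real tilelang module:

```python
# Standalone repro of the mismatch; T mimics a module exposing dtype objects.
import argparse


class T:
    float = object()
    float16 = object()


parser = argparse.ArgumentParser()
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16])

try:
    # type=str turns the CLI token into the string "float", which is then
    # checked against choices containing only the two objects -> rejected.
    parser.parse_args(["--accum_dtype", "float"])
except SystemExit:
    print("argparse rejected 'float': it is not among the object choices")
```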

Proposed fix

```diff
-    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
```

Then map the string to the dtype in the `main` call:

```diff
+    dtype_map = {"float": T.float, "float16": T.float16}
     main(
         M=args.m,
         N=args.n,
         K=args.k,
         use_cutlass_layout=args.use_cutlass_layout,
         use_torch_compressor=args.use_torch_compressor,
-        accum_dtype=args.accum_dtype,
+        accum_dtype=dtype_map[args.accum_dtype],
         cfg=args.cfg,
     )
```

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/gemm_sp/example_custom_compress.py around lines 330-332, the argparse entry for accum_dtype is inconsistent: change the parser.add_argument("--accum_dtype", type=str, choices=[T.float, T.float16], default=T.float, ...) call to use string choices and a string default (type=str, choices=["float", "float16"], default="float"), then map that string to the actual dtype before use (e.g., accum_dtype_map = {"float": T.float, "float16": T.float16}; accum_dtype = accum_dtype_map[args.accum_dtype]) wherever the script uses accum_dtype.

Comment on lines +130 to +131

```python
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
```

⚠️ Potential issue | 🟡 Minor

Argparse type mismatch for `accum_dtype`.

Same issue as in `example_custom_compress.py`: the `--accum_dtype` argument uses `type=str` with `choices=[T.float, T.float16]` (dtype objects), which will cause validation failures when CLI arguments are passed.

Proposed fix

```diff
-    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
```

Then map in the call:

```diff
+    dtype_map = {"float": T.float, "float16": T.float16}
-    main(M=args.m, N=args.n, K=args.k, accum_dtype=args.accum_dtype, cfg=args.cfg)
+    main(M=args.m, N=args.n, K=args.k, accum_dtype=dtype_map[args.accum_dtype], cfg=args.cfg)
```

Committable suggestion skipped: line range outside the PR's diff.

@LeiWang1999 merged commit ce68b51 into tile-ai:main on Jan 8, 2026 (8 of 11 checks passed).