[Refactor]: Change the params in pytest to avoid oom error during ci #1170
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Reduced test/workload sizes across multiple example tests; parameterized one example.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Test as pytest test
    participant Example as example_mha_inference.main
    participant Flop as FlopCalc
    Note over Test,Example: Test calls parameterized example with explicit args
    Test->>Example: main(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal)
    Example->>Flop: compute total_flops = f(BATCH,H,Q_CTX,KV_CTX,D_HEAD)
    alt causal == True
        Flop-->>Example: total_flops * 0.5
    else causal == False
        Flop-->>Example: total_flops
    end
    Example-->>Test: return profiling/metrics
```
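The FLOPs calculation in this flow can be sketched in plain Python. This is a hypothetical reconstruction based on the diagram's description, not the actual code in `examples/flash_decoding/example_mha_inference.py` (which also runs profiling):

```python
def total_attention_flops(batch, heads, q_ctx, kv_ctx, d_head, causal=False):
    """Estimate attention FLOPs for the shapes described in the diagram."""
    # Two GEMMs per attention op (Q @ K^T and P @ V), 2 FLOPs per multiply-add.
    flops_per_matmul = 2.0 * batch * heads * q_ctx * kv_ctx * d_head
    total_flops = 2 * flops_per_matmul
    if causal:
        # A causal mask visits roughly half of the score matrix.
        total_flops *= 0.5
    return total_flops
```

This mirrors the `alt causal` branch above: the same base count is computed either way, and the causal case simply halves it.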
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
- examples/flash_attention/test_example_flash_attention.py

🧰 Additional context used
🧬 Code graph analysis (1)
- examples/flash_attention/test_example_flash_attention.py (2)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
🔇 Additional comments (4)
Actionable comments posted: 0
🧹 Nitpick comments (1)
examples/blocksparse_attention/test_example_blocksparse_attention.py (1)
38-47: Consider reducing batch/heads parameters for consistency. The mask variant only reduces `max_cache_seqlen` (4096→1024) while keeping `batch=16`, `heads=16`, and `heads_kv=8` unchanged. In contrast, the indice variant (lines 27-35) reduces all these parameters. If both tests have similar memory profiles, consider reducing the batch and heads parameters here as well for consistency.
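A minimal sketch of the suggested change and why it matters for memory. Parameter names come from the review comment; the head dimension and the cache-size formula are illustrative assumptions, not taken from the actual test file:

```python
# Current mask-variant shapes per the review: only max_cache_seqlen was reduced.
mask_variant = dict(batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024)

# Suggested for consistency with the indice variant, which halves batch/heads too.
mask_variant_reduced = dict(batch=8, heads=8, heads_kv=4, max_cache_seqlen=1024)

def approx_kv_cache_elems(cfg, dim=128):
    """Rough K+V element count; dim is an assumed head dimension."""
    return 2 * cfg["batch"] * cfg["heads_kv"] * cfg["max_cache_seqlen"] * dim
```

Halving both `batch` and `heads_kv` cuts the estimated KV-cache footprint to a quarter, which is the kind of reduction the OOM fix is after.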
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/blocksparse_attention/test_example_blocksparse_attention.py (2 hunks)
- examples/cast/test_example_cast.py (1 hunks)
- examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (2 hunks)
- examples/flash_attention/test_example_flash_attention.py (2 hunks)
- examples/flash_decoding/example_mha_inference.py (1 hunks)
- examples/flash_decoding/test_example_flash_decoding.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
examples/flash_decoding/example_mha_inference.py (3)
- examples/flash_attention/example_mha_fwd_varlen.py (2)
  - main (94-206), main (211-283)
- examples/deepseek_mla/example_mla_decode.py (1)
  - main (283-305)
- examples/flash_decoding/example_gqa_decode.py (1)
  - main (442-489)
examples/flash_attention/test_example_flash_attention.py (2)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
  - main (275-321)
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (2)
  - main (121-164), main (194-238)
examples/flash_decoding/test_example_flash_decoding.py (1)
- examples/flash_decoding/example_mha_inference.py (1)
  - main (305-323)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (2)
- examples/deepseek_v32/fp8_lighting_indexer.py (1)
  - test_fp8_lighting_indexer (260-302)
- examples/deepseek_v32/sparse_mla_bwd.py (1)
  - test_sparse_mla_bwd (334-384)
examples/cast/test_example_cast.py (2)
- examples/cast/example_per_token_cast_to_fp8.py (1)
  - main (80-114)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (1)
  - main (164-204)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (12)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (3)
16-16: LGTM! Workload reduced to prevent OOM. The parameter reduction (S: 1024→512, SKV: 2048→1024) appropriately lowers memory usage for CI testing while maintaining test coverage.

32-32: LGTM! Workload reduced to prevent OOM. The parameter reduction (SKV: 1024→512) appropriately lowers memory usage for CI testing.

39-39: LGTM! Workload reduced to prevent OOM. The parameter reduction (SKV: 1024→512) appropriately lowers memory usage for CI testing.
examples/blocksparse_attention/test_example_blocksparse_attention.py (1)
27-35: LGTM! Workload reduced to prevent OOM. The parameter reductions (batch: 16→8, heads: 16→8, heads_kv: 8→4, max_cache_seqlen: 4096→2048) appropriately lower memory usage for CI testing.
examples/cast/test_example_cast.py (2)
7-7: Verify M=4196 is intentional. The value M=4196 is unusual: it is neither a power of 2 nor exactly half of the presumed original (8192). Most other reductions in this PR are clean halvings. Please confirm this is intentional (e.g., testing non-power-of-2 sizes or specific alignment) rather than a typo for 4096.

11-11: LGTM! Workload reduced to prevent OOM. The parameter reductions (M: 8192→2048, N: 2048→512, blk_m: 16→8) appropriately lower memory usage for CI testing.
examples/flash_attention/test_example_flash_attention.py (4)
36-42: LGTM! Explicit parameterization improves clarity. The change from implicit defaults to explicit parameters (BATCH=1, H=16, N_CTX=512, D_HEAD=64, causal=False) improves test readability and ensures appropriate workload sizing for CI.

47-53: LGTM! Explicit parameterization improves clarity. The change to explicit parameters ensures appropriate workload sizing for CI and improves test clarity.

59-65: LGTM! Explicit parameterization improves clarity. The change to explicit parameters ensures appropriate workload sizing for CI and improves test clarity.

105-105: LGTM! Explicit parameterization improves clarity. The change from no parameters to explicit parameters (batch=4, heads=16, seq_len=512, dim=64) improves test readability and ensures appropriate workload sizing for CI.
examples/flash_decoding/test_example_flash_decoding.py (1)
15-15: LGTM! Explicit parameterization with reduced KV context. The change to explicit parameters with KV_CTX reduced from the default 8192 to 2048 appropriately lowers memory usage for CI testing while maintaining test coverage.
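As a rough sanity check on why shrinking KV_CTX helps, the KV-cache footprint scales linearly with it. The estimate below is illustrative (assumed fp16 storage and assumed batch/head/dim shapes, not the test's exact configuration):

```python
def kv_cache_bytes(batch, heads, kv_ctx, d_head, bytes_per_elem=2):
    """Approximate K+V cache size: two tensors of [batch, kv_ctx, heads, d_head]."""
    return 2 * batch * heads * kv_ctx * d_head * bytes_per_elem

# Reducing KV_CTX 8192 -> 2048 cuts the cache to a quarter of its size.
default_bytes = kv_cache_bytes(batch=64, heads=8, kv_ctx=8192, d_head=128)
ci_bytes = kv_cache_bytes(batch=64, heads=8, kv_ctx=2048, d_head=128)
```

Since the cache is linear in KV_CTX, any further CI pressure can be relieved by halving it again without touching the other dimensions.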
examples/flash_decoding/example_mha_inference.py (1)
305-323: LGTM! Parameterization enables configurable testing. The extension of `main()` from no parameters to a parameterized signature with sensible defaults enables external configuration of batch size, dimensions, and causal behavior. The causal flag is correctly incorporated into the FLOPs calculation (lines 308-309), and backward compatibility is maintained through default values.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/blocksparse_attention/test_example_blocksparse_attention.py (2 hunks)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (2 hunks)
- examples/cast/test_example_cast.py (1 hunks)
- examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (2 hunks)
- examples/flash_attention/test_example_flash_attention.py (2 hunks)
- examples/flash_decoding/example_mha_inference.py (1 hunks)
- examples/flash_decoding/test_example_flash_decoding.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- examples/flash_decoding/test_example_flash_decoding.py
- examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
- examples/blocksparse_attention/test_example_blocksparse_attention.py
🧰 Additional context used
🧬 Code graph analysis (3)
examples/flash_decoding/example_mha_inference.py (2)
- examples/flash_attention/example_mha_fwd_varlen.py (2)
  - main (94-206), main (211-283)
- examples/flash_decoding/example_gqa_decode.py (1)
  - main (442-489)
examples/cast/test_example_cast.py (2)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (1)
  - main (164-204)
- examples/cast/example_per_token_cast_to_fp8.py (1)
  - main (80-114)
examples/flash_attention/test_example_flash_attention.py (3)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
  - main (275-321)
- examples/flash_attention/example_gqa_bwd.py (1)
  - main (464-522)
- examples/flash_attention/example_gqa_fwd_bshd.py (2)
  - main (154-192), main (222-264)
🪛 Ruff (0.14.2)
examples/cast/example_group_per_split_token_cast_to_fp8.py
164-164: Do not use mutable data structures for argument defaults
Replace with `None`; initialize within function
(B006)
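The flagged pattern and its conventional fix can be sketched as follows. The `batch_sizes` default `[128, 896]` comes from the test in this PR; the function body is illustrative, not the actual example code:

```python
# Flagged by Ruff B006: a mutable default is evaluated once at definition
# time and shared across every call, so in-place mutation leaks between calls.
#
#   def main(batch_sizes=[128, 896]): ...
#
# Conventional fix: default to None and initialize inside the function body.
def main(batch_sizes=None):
    if batch_sizes is None:
        batch_sizes = [128, 896]  # fresh list on every call
    return batch_sizes
```

With this pattern each call gets its own list, and callers can still pass smaller workloads explicitly for CI.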
🔇 Additional comments (7)
examples/cast/example_group_per_split_token_cast_to_fp8.py (1)
173-173: Good parameterization for test flexibility. The change from hardcoded values to parameterized `batch_sizes` enables tests to pass smaller workloads, aligning well with the PR objective to avoid OOM errors during CI.

examples/flash_attention/test_example_flash_attention.py (3)
36-42: LGTM: Reduced workload parameters for CI. The explicit parameters reduce memory requirements while maintaining test coverage. The smaller values (BATCH=1, H=16, N_CTX=512, D_HEAD=64) are appropriate for avoiding OOM errors during CI.

47-53: LGTM: Reduced workload parameters for CI. Consistent parameter reduction pattern applied to the BHSD variant, ensuring both test variants use appropriate memory footprints for CI environments.

99-99: LGTM: Reduced workload parameters for CI. The variable-length attention test now uses smaller dimensions (batch=4, heads=16, seq_len=512, dim=64) suitable for CI resource constraints.
examples/flash_decoding/example_mha_inference.py (1)
305-309: LGTM: Clean parameterization enables flexible test sizing. The function signature now accepts explicit configuration parameters, allowing tests to specify smaller workload sizes for CI. The `causal` parameter properly adjusts the `total_flops` calculation, and the implementation follows patterns established in related files.
7-8: LGTM: Significant workload reduction for CI. The parameter reduction (M: 8192→1024, N: 2048→1024, blk_m: 8→4) substantially decreases memory usage while maintaining test coverage. The `batch_sizes=[128, 896]` parameter aligns with the updated function signature.

12-12: LGTM: Reduced parameters for CI. The parameter reduction (M: 8192→2048, N: 2048→512) appropriately reduces memory footprint for CI environments while preserving test validity.
…ile-ai#1170)

* [Refactor]: Change the params in pytest to avoid oom error during ci
* format
* fix
* Update test_example_cast.py
* Update parameters in test_example_cast
* Update test_example_flash_attention.py
* update
* format
* fix
* fix
* format
* [Test] Add cp async to avoid register spill
* [BugFix] GQA fwd and bwd
  - Fix the undefined behavior of -inf in acc_s
  - Fix the causal loop range in varlen scenario
* [TMA] Move on to TMA and locate the register spill issue
* [Debug] Not the reason of zero-assignment. Probably the combination of Parallel op & conditional qkT
* [Debug] The SIMT copy in producer occupies too many registers
* [BugFix] Use 3D lse and delta to avoid illegal instruction
* [Perf] Relaxed order for dQ and SIMT store for dKdV
* [Feat] For atomic add version
* [Lint]
* [Bugfix] Enable code lowering with producer‑copy‑only program (#1168)
* bugfix
* lint fix
* Enhance warp group register allocation to handle missing consumer bodies gracefully. Updated logic to annotate producer side when consumer is absent, ensuring robustness in degenerate warp-specialized patterns.
* Refactor VisitExpr_ method in inject_tma_barrier.cc for improved readability. Adjusted formatting and spacing for clarity in barrier handling logic.
* Update barrier handling in inject_tma_barrier.cc to accommodate newly appended entries. Adjusted the size of the replace vector to ensure it covers the full needed length, and modified the logic for appending barriers based on the updated replace conditions.
* [Bugfix] Support 16bits shfl_sync (#1169)
* Add type-safe warp shuffle helpers for 16-bit float types in common.h
  - Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
  - Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
  - Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.
* lint fix
* [Testing] Move TMA 1D and test for its functionality (#1167)
* [Testing] Move TMA 1D and test for its functionality
* [Lint]
* [Refactor]: Change the params in pytest to avoid oom error during ci (#1170)
* [Refactor]: Change the params in pytest to avoid oom error during ci
* format
* fix
* Update test_example_cast.py
* Update parameters in test_example_cast
* Update test_example_flash_attention.py
* update
* format
* fix
* fix
* format
* [Bugfix] Fix tvm import path for editable build (#1172)
* [Language] Expose `T.warpgroup_fence_operand` for nvcc code motion (#986)
* remove debug print
* pipeline fix
* use the correct buffer access scope
* rs support
* warp warpgroup_fence_operand
* fix
* fp8 dtype ptx enhance
* mma fix
* TCGEN05 Interface
* tcgen05 support
* rebase
* update
* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory to shared-memory instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.
* lint fix
* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.
* wgmma fix

---------
Co-authored-by: Zhiwen Mo <[email protected]>

* [Language] Add Correctness and performance check scripts for V2 (#1174)
* fix
* lint fix
* fix
* lint fix
* fix
* upd
* [Bugfix] Legalize Datatype for mma intrinisc codegen (#1179)
* fix
* lint fix
* Enhance CUDA code generation by updating register type handling for float data types. Introduced a workaround for TF32 type compatibility and improved the registration of MMA register types for A and B operands.
* [Perf] Add layout and use_tma to boost performance
* [Lint]
* [Note]

---------
Co-authored-by: Lei Wang <[email protected]>
Co-authored-by: Yuqi Dong <[email protected]>
Co-authored-by: Zhiwen Mo <[email protected]>
Summary by CodeRabbit
Tests
Chores