
[KDA] enable TMA on some simple kernels #647

Merged
zhiyuan1i merged 5 commits into main from kda-training-tma
Nov 17, 2025
Conversation

@zhiyuan1i
Collaborator

@zhiyuan1i zhiyuan1i commented Nov 17, 2025

Summary by CodeRabbit

  • New Features

    • Added optional Tensor Memory Accelerator (TMA) support to backward kernel operations for improved performance on compatible hardware.
  • Tests

    • Extended test configurations with Tril floating-point precision control and TMA parameters to validate new backward kernel execution paths and ensure consistent behavior.

@coderabbitai
Contributor

coderabbitai bot commented Nov 17, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds USE_TMA parameter to backward KDA kernels to conditionally use tensor-descriptor-based loads/stores instead of pointer-based ones. Modifies backward intra kernel BK sizing cap from 64 to 32. Updates tests with triltf32 parameter for controlling Tril precision via environment variable.

Changes

Cohort / File(s) Summary
Kernel TMA Support
fla/ops/kda/chunk_inter.py, fla/ops/kda/chunk_intra.py
Adds USE_TMA constexpr parameter enabling conditional branching between tensor-descriptor-based loads/stores (when true) and pointer-based loads/stores (when false). Modifies per-iteration data flow in backward kernels. Reduces BK sizing cap from 64 to 32 in chunk_intra. Updates kernel signatures and propagates USE_TMA through wrapper functions and autotuning.
Test Parameterization Updates
tests/ops/test_kda.py
Adds triltf32 boolean parameter to test_chunk, test_chunk_varlen, and test_fused_recurrent. Controls FLA_TRIL_PRECISION environment variable (tf32x3 when true, ieee when false). Updates pytest.mark.parametrize decorators and test identifiers to include new parameter.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Kernel logic branches: Verify USE_TMA conditional paths in chunk_inter.py and chunk_intra.py correctly route between tensor-descriptor and pointer-based operations; ensure descriptor creation and store operations are correctly implemented.
  • Parameter propagation: Check that USE_TMA is threaded through all kernel invocations, wrapper functions, and autotuning configurations consistently.
  • BK sizing change: Assess impact of reducing BK cap from 64 to 32 in chunk_intra on memory footprint and correctness.
  • Test coverage: Confirm triltf32 parameterization properly sets/resets FLA_TRIL_PRECISION environment variable and doesn't introduce test flakiness.
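For the BK point, the shared-memory effect of the cap change can be estimated with back-of-the-envelope arithmetic (a sketch assuming fp32 tiles and a BT of 64 for concreteness; actual footprints depend on the kernel's full tiling):

```python
def tile_bytes(bt, bk, itemsize=4):
    # Bytes needed to stage one (BT, BK) tile of fp32 elements.
    return bt * bk * itemsize

before = tile_bytes(64, 64)  # old cap: BK up to 64
after = tile_bytes(64, 32)   # new cap: BK up to 32
```

Halving the BK cap halves the per-tile staging cost (16384 vs 8192 bytes here), at the price of more iterations over the K dimension.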

Possibly related PRs

Suggested reviewers

  • yzhangcs

Poem

🐰 TMA paths bloom with grace,
Tensor descriptors find their place,
Kernels branch both ways with care,
Tests now know precision's snare,
Backward passes skip through air! 🚀

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and directly describes the main change: enabling TMA (Tensor Memory Accelerator) support in KDA kernels marked as 'simple', which aligns with the core modifications across multiple kernel files.

@gemini-code-assist
Contributor

Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant performance optimizations by enabling Tensor Memory Accelerator (TMA) support across several key chunked kernel operations within the GLA and KDA modules. The changes allow these kernels to leverage more efficient memory access patterns on compatible hardware, such as NVIDIA Hopper GPUs, through conditional use of make_tensor_descriptor. This enhancement is complemented by updated test cases to thoroughly validate the new TMA integration and precision configurations.

Highlights

  • TMA Integration: Implemented Tensor Memory Accelerator (TMA) support in chunk_gla_fwd_kernel_o, chunk_gla_bwd_kernel_dA, chunk_kda_bwd_kernel_inter, and chunk_kda_bwd_kernel_intra kernels for optimized memory access.
  • Conditional TMA Usage: Kernels now dynamically switch between standard block pointer loading and TMA-based make_tensor_descriptor loading/storing based on a USE_TMA constexpr and the is_tma_supported utility function.
  • Performance Tuning: Adjusted BK_LIST in fla/ops/gla/chunk.py to include 128 for 'hopper' architecture, and modified BK block size in fla/ops/kda/chunk_intra.py for potential performance improvements.
  • Enhanced Testing: Updated test_kda.py to include comprehensive testing for TMA functionality and tf32 precision settings, ensuring robustness and correctness of the new features.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables the Tensor Memory Accelerator (TMA) for several kernels to improve performance on newer GPU architectures like Hopper. The changes introduce TMA-specific code paths guarded by a USE_TMA flag. My review identified several critical issues in the TMA implementation within fla/ops/gla/chunk.py. Specifically, tensor descriptors for TMA are created with incorrect base pointers, shapes, and strides, which will lead to incorrect memory access and likely incorrect results. Additionally, I found a critical bug in tests/ops/test_kda.py where test cases for a parametrized test were not updated after adding a new parameter, which will cause the tests to fail. These issues need to be addressed to ensure correctness.

Comment thread fla/ops/gla/chunk.py Outdated
Comment thread fla/ops/gla/chunk.py Outdated
Comment thread fla/ops/gla/chunk.py Outdated
Comment thread tests/ops/test_kda.py
@zhiyuan1i zhiyuan1i marked this pull request as draft November 17, 2025 07:34
@zhiyuan1i zhiyuan1i marked this pull request as ready for review November 17, 2025 08:56
@zhiyuan1i zhiyuan1i merged commit 9e6769c into main Nov 17, 2025
5 of 6 checks passed
@zhiyuan1i zhiyuan1i deleted the kda-training-tma branch November 17, 2025 08:59
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
fla/ops/kda/chunk_intra.py (2)

73-105: TMA code path is non-functional and lacks boundary safety compared to pointer path

In chunk_kda_fwd_kernel_intra_sub_inter (and similarly in backward kernel):

  • The non-TMA branch uses tl.make_block_ptr() with boundary_check=(0, 1) and explicit masking (mask=m_k) to safely handle partial tiles when K is not divisible by BK.
  • The TMA branch calls make_tensor_descriptor(), which is a stub function returning None (fla/ops/utils/op.py:54), so .load() calls would fail at runtime.

Even if make_tensor_descriptor were properly implemented, the TMA branch provides no boundary or masking equivalent to the pointer path—descriptor loads execute without checking tail conditions.

Additionally, BK is computed without divisibility constraints:

  • Forward: BK = max(triton.next_power_of_2(K), 16)
  • Backward: BK = min(32, triton.next_power_of_2(K))

Neither guarantees K % BK == 0.

Recommended fixes:

  1. Either implement make_tensor_descriptor to return a working descriptor with boundary semantics, or remove the TMA branch entirely.
  2. Add guards to ensure K % BK == 0 and T % BC == 0 before enabling USE_TMA, or add explicit masking to the TMA loads.
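The non-divisibility claim is easy to confirm in plain Python; the sketch below reimplements next_power_of_2 so it runs without Triton, and mirrors the two BK formulas quoted above:

```python
def next_power_of_2(n):
    # Same semantics as triton.next_power_of_2 for positive n.
    return 1 << (n - 1).bit_length()

K = 48                                 # a head dim that is not a power of two
bk_fwd = max(next_power_of_2(K), 16)   # forward rule:  BK = 64, and 48 % 64 != 0
bk_bwd = min(32, next_power_of_2(K))   # backward rule: BK = 32, and 48 % 32 != 0
```

So for K = 48 neither rule yields a BK that divides K, which is exactly the tail-tile case the pointer path's boundary_check handles and the TMA path does not.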

272-301: Verify boundary safety in TMA descriptor.load() or apply explicit bounds, as shown in non-TMA branches

Verification confirms the asymmetry you identified. Across all code sections (272-301, 309-319, 338-356, 368-393):

  • Non-TMA path: All loads use tl.load(p_*, boundary_check=(0, 1)) for 2D tensors and explicit masks for 1D element-level loads (e.g., m_k = o_k < K)
  • TMA path: All loads use desc_*.load([coords]) with no boundary or mask parameters

The descriptor API does not expose boundary_check or mask parameters. Without documentation confirming automatic boundary enforcement in Triton's TMA descriptor.load(), this represents a genuine safety gap when tensor dimensions are not multiples of tile sizes (K % BK != 0, T % BC != 0).

Recommend either:

  1. Constraining TMA to fixed-size tensors matching tile boundaries, or
  2. Extending the TMA path to apply equivalent masks/boundary logic (if the descriptor API allows post-load masking)
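What "equivalent masks" would mean can be illustrated with a toy, pure-Python model of a boundary-checked tile load: zero-fill outside the logical bounds, which is roughly the behavior tl.load with boundary_check provides. This is an illustration of the semantics, not Triton code:

```python
def masked_tile_load(mat, row0, col0, bt, bk, fill=0.0):
    """Load a (bt, bk) tile starting at (row0, col0), zero-filling
    out-of-bounds elements the way a boundary-checked load would."""
    rows, cols = len(mat), len(mat[0])
    return [
        [mat[r][c] if (r < rows and c < cols) else fill
         for c in range(col0, col0 + bk)]
        for r in range(row0, row0 + bt)
    ]

# A 3x5 matrix loaded as a 2x4 tile at (2, 3): two in-bounds elements, rest filled.
m = [[float(r * 5 + c) for c in range(5)] for r in range(3)]
tile = masked_tile_load(m, 2, 3, 2, 4)
```

An unmasked descriptor load over the same tile would instead touch memory past the logical edge of the tensor, which is the safety gap flagged above.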
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3f675a2 and ac29748.

📒 Files selected for processing (3)
  • fla/ops/kda/chunk_inter.py (4 hunks)
  • fla/ops/kda/chunk_intra.py (7 hunks)
  • tests/ops/test_kda.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/ops/kda/chunk_inter.py (2)
fla/ops/utils/op.py (1)
  • make_tensor_descriptor (54-61)
fla/utils.py (1)
  • check_shared_mem (445-451)
fla/ops/kda/chunk_intra.py (2)
fla/ops/utils/op.py (1)
  • make_tensor_descriptor (54-61)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (114-119)
🔇 Additional comments (4)
tests/ops/test_kda.py (2)

138-169: Parametrization and Tril/TMA env gating for test_chunk look consistent

  • The pytest.mark.parametrize definition now includes tma and triltf32 in the argument list, and each test tuple supplies 11 values, so the earlier mismatch is resolved.
  • tma drives FLA_USE_TMA, and triltf32 drives FLA_TRIL_PRECISION (tf32x3 vs ieee), with both env vars reset at the end of the test.

This gives good coverage of both pointer and TMA paths across different Tril precisions. Only minor note: if the test fails before the final lines, env vars won’t be reset, but that’s a general pytest pattern and probably acceptable here.

Also applies to: 227-228


231-255: test_chunk_varlen parameterization and env handling look good

  • The parameter list now includes tma and triltf32, and each test case in the comprehension provides 7 values, so the signature matches.
  • The test sets FLA_USE_TMA and FLA_TRIL_PRECISION according to those booleans and resets both at the end.

This should exercise varlen chunking under both TMA/non‑TMA and TF32/IEEE Tril modes as intended.

Also applies to: 256-263, 326-327

fla/ops/kda/chunk_inter.py (1)

54-56: Review comment is incorrect. Code is already correct.

The imported is_tma_supported from fla/utils.py is a boolean variable, not a function. It's defined as is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ... (a direct assignment of a logical expression). Passing USE_TMA=is_tma_supported to the kernel invocation is correct; calling it as is_tma_supported() would fail since you cannot invoke a boolean.

Likely an incorrect or invalid review comment.

fla/ops/kda/chunk_intra.py (1)

441-452: The review comment is incorrect

is_tma_supported is defined in fla/utils.py:401-405 as a boolean variable computed at module import time, not a function:

is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) \
    and os.environ.get('FLA_USE_TMA', '0') == '1' and \
    (hasattr(triton.language, '_experimental_make_tensor_descriptor') or hasattr(triton.language, 'make_tensor_descriptor'))

The current code at lines 509 and 594 correctly passes the boolean variable directly: USE_TMA=is_tma_supported. This is the expected usage for a tl.constexpr parameter, which requires a compile-time constant. Calling it as is_tma_supported() would fail since booleans are not callable.

The BK capping change mentioned (line 557) is correctly identified as intentional and safe.

Likely an incorrect or invalid review comment.
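The structure of that module-level gate can be sketched in plain Python (the device_major and has_descriptor_api inputs here stand in for the torch.cuda and triton.language probes and are hypothetical; FLA_USE_TMA is the real env var):

```python
import os

def tma_gate(device_major, has_descriptor_api):
    # Mirrors the shape of is_tma_supported: Hopper-or-newer (SM >= 9),
    # an explicit env-var opt-in, and a descriptor API present in this build.
    return (
        device_major >= 9
        and os.environ.get("FLA_USE_TMA", "0") == "1"
        and has_descriptor_api
    )

os.environ["FLA_USE_TMA"] = "1"
on_hopper = tma_gate(9, True)   # gate opens
on_ampere = tma_gate(8, True)   # too old, gate stays closed
os.environ["FLA_USE_TMA"] = "0"
opted_out = tma_gate(9, True)   # capable hardware but no opt-in
```

Because the real value is computed once at import time, it behaves as a plain boolean constant, which is why passing it directly as a tl.constexpr argument is correct.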

Comment on lines 94 to 121
for i_v in range(tl.cdiv(V, BV)):
    if not USE_TMA:
        p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dv = tl.load(p_dv, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
    else:
        desc_v = make_tensor_descriptor(v, [T, V], [H*V, 1], [BT, BV])
        desc_do = make_tensor_descriptor(do, [T, V], [H*V, 1], [BT, BV])
        desc_h = make_tensor_descriptor(h, [V, K], [1, V], [BV, BK])
        desc_dh = make_tensor_descriptor(dh, [V, K], [1, V], [BV, BK])
        desc_dv = make_tensor_descriptor(dv, [T, V], [H*V, 1], [BT, BV])
        # [BT, BV]
        b_v = desc_v.load([i_t * BT, i_v * BV])
        b_do = desc_do.load([i_t * BT, i_v * BV])
        b_dv = desc_dv.load([i_t * BT, i_v * BV])
        # [BV, BK]
        b_h = desc_h.load([i_v * BV, i_k * BK])
        b_dh = desc_dh.load([i_v * BV, i_k * BK])

Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

TMA branch lacks boundary handling that exists in the pointer path

In the for i_v in range(tl.cdiv(V, BV)) loop and the final dw write:

  • The non‑TMA path uses tl.make_block_ptr(..., boundary_check=(0, 1)) and mask=m_k to safely handle tail tiles when T, V, or K are not multiples of BT, BV, or BK.
  • The TMA path builds full descriptors ([T, V], [V, K], [T, K] with block shapes [BT, BV] / [BV, BK] / [BT, BK]) and calls desc_*.load(...) / desc_dw.store(...) with no apparent boundary checks.

Given that:

  • BT is a chunk size that may not divide T (e.g., tests use T=63 with chunk_size=64),
  • V and K also aren’t guaranteed to be multiples of BV / BK,

the TMA path risks reading/writing beyond logical tensor bounds unless make_tensor_descriptor and its load/store intrinsically clip partial tiles.

Please either:

  • Ensure TMA descriptors/load/store are configured to respect the same tiling boundaries as the pointer path (including tail tiles), or
  • Explicitly guard the TMA path with extra conditions (e.g., only enabling it when T % BT == 0 and V % BV == 0 and K % BK == 0), or
  • Add equivalent boundary checks/masks for TMA loads and stores.

This should be validated on shapes where T, V, and K are not multiples of BT, BV, BK.

Also applies to: 129-135


(Collapsed analysis chain: a long series of shell scripts inspecting fla/ops/kda/chunk_inter.py for make_tensor_descriptor usage, boundary_check/mask handling, and USE_TMA propagation; the page capture is truncated mid-script.)
else
  echo "File not found, checking structure..."
  find . -name "chunk_inter.py" -o -name "*.py" | head -10
fi

Length of output: 1961


🏁 Script executed:

#!/bin/bash
# Since repo might not be cloned, let's try explicit clone
cd /tmp && git clone https://github.com/flashy/fla.git 2>&1 | head -10

Length of output: 178


🏁 Script executed:

#!/bin/bash
# Try to find the repo or work with what we have
ls -la /tmp/fla/ 2>/dev/null || echo "Clone not available"
find / -name "chunk_inter.py" 2>/dev/null | head -5

Length of output: 65


🏁 Script executed:

#!/bin/bash
# Check if we can access GitHub API for the file
curl -s https://raw.githubusercontent.com/flashy/fla/main/fla/ops/kda/chunk_inter.py 2>&1 | head -50

Length of output: 14


🏁 Script executed:

#!/bin/bash
# List available file references
git ls-files 2>/dev/null | grep chunk_inter || echo "Git not available"

Length of output: 26


🏁 Script executed:

#!/bin/bash
# Try simpler approach
test -f fla/ops/kda/chunk_inter.py && cat fla/ops/kda/chunk_inter.py || echo "File does not exist at that path"

Length of output: 7052


🏁 Script executed:

#!/bin/bash
# See what repository we're in
git log --oneline 2>/dev/null | head -5 || echo "Not a git repo"
git remote -v 2>/dev/null || echo "No remotes"

Length of output: 402


🏁 Script executed:

#!/bin/bash
# Check if files are cloned into current directory
git status 2>/dev/null | head -10

Length of output: 195


🏁 Script executed:

#!/bin/bash
# List repo root
git rev-parse --show-toplevel 2>/dev/null || pwd

Length of output: 18


🏁 Script executed:

#!/bin/bash
# Final diagnostic
echo "=== Directory Contents ===" && ls -la
echo "=== Git Info ===" && git config --get remote.origin.url 2>/dev/null || echo "No remote"
echo "=== File Check ===" && test -d fla && echo "fla dir exists" || echo "fla dir missing"

Length of output: 1512


🏁 Script executed:

#!/bin/bash
# Find all Python kernel files
find . -name "*.py" -exec grep -l "@triton" {} \; 2>/dev/null | head -5

Length of output: 217


🏁 Script executed:

#!/bin/bash
# One more try with full path resolution
python3 -c "import os; print('Files in current directory:'); [print(f) for f in os.listdir('.') if os.path.isfile(f)]"

Length of output: 151


🏁 Script executed:

#!/bin/bash
# Check if this is a pull request context
if [ -d ".git" ]; then
  git show --stat 2>/dev/null | head -50
fi

Length of output: 345


🏁 Script executed:

#!/bin/bash
# Verify we can access the file mentioned in the review
cat fla/ops/kda/chunk_inter.py 2>/dev/null | wc -l || echo "Cannot access file"

Length of output: 3


Key Finding from Web Search: descriptions of TMA claim the hardware handles predication (e.g., out-of-bounds checks) with "No bounds checking. No divergence. Hardware handles it all." — but that claim concerns the asynchronous bulk load/store engine itself. The Triton tensor-descriptor documentation makes no mention of automatic boundary clipping for partial tiles; it states instead that "The destination tensor type and shape must match the descriptor otherwise the result is undefined."

This is the critical insight: in Triton, TMA descriptors do NOT automatically clip partial tiles; they require exact shape matching.
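
The exact-shape requirement translates into a simple host-side divisibility gate before choosing the TMA path. A minimal sketch (the helper name `tma_path_is_safe` is hypothetical, not part of the codebase):

```python
def tma_path_is_safe(T, V, K, BT, BV, BK):
    """Return True only when every tensor dimension is an exact
    multiple of its block size, so descriptor loads/stores never
    touch a partial tile."""
    return T % BT == 0 and V % BV == 0 and K % BK == 0

# A trailing partial chunk (T=63 with chunk_size BT=64) must fall
# back to the boundary-checked non-TMA path.
assert not tma_path_is_safe(T=63, V=128, K=128, BT=64, BV=64, BK=32)
assert tma_path_is_safe(T=128, V=128, K=128, BT=64, BV=64, BK=32)
```

In a kernel launcher this check would sit alongside the existing hardware-capability check that sets `USE_TMA`.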

Rewritten Review Comment:

Add boundary handling to TMA descriptor loads/stores for non-aligned tensor dimensions

The TMA code path (lines 94-121 and 129-135) is missing boundary handling that protects the non-TMA path. Specifically:

  • The non-TMA path guards every tl.load() and tl.store() with boundary_check=(0, 1)
  • The TMA path uses make_tensor_descriptor() with .load() and .store() but passes no boundary parameters

When T, V, or K are not multiples of BT, BV, or BK (e.g., T=63 with chunk_size=64), the TMA path risks out-of-bounds memory access since TMA descriptors require "the destination tensor type and shape must match the descriptor otherwise the result is undefined."

Required fixes (lines 94-121 and 129-135):

  • Add boundary handling equivalent to the non-TMA path, or
  • Ensure TMA descriptors are only used when all tensor dimensions are aligned to block sizes, or
  • Explicitly guard loads/stores with bounds checks before descriptor operations

This must be validated on shapes where T, V, and K are not multiples of BT, BV, BK.
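
The "safe masked copy into temporary buffers filled with zeros" from the last fix option amounts to the following padding pattern, shown here as a pure-Python sketch (in the kernel this would be a masked tl.load into a zero-initialized tile; the helper name `pad_to_block` is illustrative):

```python
def pad_to_block(tile, n_rows, n_cols, BT, BV):
    """Copy an (n_rows x n_cols) partial tile into a zero-filled
    buffer whose dimensions are rounded up to multiples of BT and
    BV, so a shape-exact descriptor never reads past the tail."""
    padded_rows = ((n_rows + BT - 1) // BT) * BT
    padded_cols = ((n_cols + BV - 1) // BV) * BV
    out = [[0.0] * padded_cols for _ in range(padded_rows)]
    for i in range(n_rows):
        for j in range(n_cols):
            out[i][j] = tile[i][j]
    return out

tile = [[1.0] * 5 for _ in range(3)]        # 3x5 partial tile
buf = pad_to_block(tile, 3, 5, BT=4, BV=8)  # padded up to 4x8
assert len(buf) == 4 and len(buf[0]) == 8
assert buf[2][4] == 1.0 and buf[3][7] == 0.0
```

The zero fill keeps the out-of-bounds region inert for accumulating operations, which is the same effect boundary_check gives the pointer-based path.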

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_inter.py around lines 94-121 (and similarly for 129-135),
the TMA descriptor load/store path lacks the boundary handling present in the
non-TMA path and can perform undefined out-of-bounds accesses when T, V, or K
are not multiples of BT, BV, BK. Fix by either (A) gating the TMA path behind
an alignment check, using make_tensor_descriptor() and desc.load()/store()
only when T % BT == 0, V % BV == 0, and K % BK == 0, or (B) adding explicit
bounds guards for partial tiles: compute masked indices for the tail elements
and fall back to the non-TMA tl.load/tl.store with boundary_check for those
tiles (or perform a safe masked copy into zero-filled temporary buffers before
using the descriptor). Apply the same protection to both the loads (lines
94-121) and the stores (lines 129-135), and add unit tests for cases like
T=63 with chunk_size=64 to validate safety.
