
[GDN] Optimize b_dg computation in chunk_bwd_kernel_dqkwg#USE_G #823

Open
MzeroMiko wants to merge 2 commits into fla-org:main from MzeroMiko:main

Conversation


@MzeroMiko MzeroMiko commented Apr 11, 2026

Math behind modification

before

$$\begin{aligned} \delta \boldsymbol{\gamma}_{[t],\text{part1}} &= \delta \exp \boldsymbol{\gamma}_{[t],\text{part1}} \odot \exp\boldsymbol{\gamma}_{[t]} \\ &= \text{diag}\left( \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \right) \mathbf{Q}_{[t]}^\top \right) \\ &\quad - \text{diag}\left( \mathbf{K}_{[t]} \left( \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],\text{new}}} \right)^\top \right) \\ &\quad + \text{diag}\left( \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]} \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \right) \left( \mathbf{Q}_{[t]}\mathbf{K}_{[t]}^\top \right)^\top \right) \\ &\quad - \text{diag}\left( \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]} \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \right)^\top \left( \mathbf{Q}_{[t]}\mathbf{K}_{[t]}^\top \right) \right) \\ &\quad + [0,0,\dots,\delta \exp \boldsymbol{\gamma}_{[t]}^C \exp \boldsymbol{\gamma}_{[t]}^C]^\top \end{aligned}$$

after

$$\begin{aligned} \delta \boldsymbol{\gamma}_{[t],\text{part1}} &= \text{diag}\left( \delta \mathbf{Q}_{[t]} \mathbf{Q}_{[t]}^\top \right) - \text{diag}\left( \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/ } \mathbf{O}_{[t]} \text{ w/o } \mathbf{V}_{[t],\text{new}}} \mathbf{K}_{[t]}^\top \right) \\ &\quad + [0,0,\dots,\delta \exp \boldsymbol{\gamma}_{[t]}^C \exp \boldsymbol{\gamma}_{[t]}^C]^\top \end{aligned}$$

Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dqkwg
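The "after" form rests on the identity that the diagonal of a matrix product, diag(A Bᵀ), is the row-wise sum of the elementwise product A ⊙ B, so each diag(·) term collapses into a single elementwise multiply and row reduction. A minimal NumPy sketch of this identity (illustrative names and sizes only, not the kernel code):

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 64, 32  # chunk length and head dimension, illustrative sizes
dQ, Q = rng.standard_normal((T, D)), rng.standard_normal((T, D))
dK, K = rng.standard_normal((T, D)), rng.standard_normal((T, D))

# diag(dQ @ Q.T) - diag(dK @ K.T): the "after" form of the gate gradient
dg_diag = np.diag(dQ @ Q.T) - np.diag(dK @ K.T)

# equivalent row-sum form, which is what a per-row reduction like
# tl.sum(b_dq * b_q, 1) - tl.sum(b_dk * b_k, 1) computes in the kernel
dg_sum = (dQ * Q).sum(axis=1) - (dK * K).sum(axis=1)

assert np.allclose(dg_diag, dg_sum)
```

The row-sum form avoids materializing the T×T products whose diagonals are the only entries actually needed.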

Benchmark Results

Performance:
      B        T    H      D  origin (Execution Time (ms))  point1 (Execution Time (ms))
0  16.0    256.0  8.0  256.0                      1.240800                      1.170304
1  16.0    512.0  8.0  256.0                      2.102896                      1.962208
2  16.0   1024.0  8.0  256.0                      3.853152                      3.570560
3  16.0   2048.0  8.0  256.0                      7.273824                      6.776016
4  16.0   4096.0  8.0  256.0                     14.233936                     13.206304
5  16.0   8192.0  8.0  256.0                     27.992672                     25.957024
6  16.0  16384.0  8.0  256.0                     55.688545                     51.651104
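The gain is consistent across sequence lengths; computing the relative improvement directly from the timings in the table above (values copied verbatim):

```python
# origin vs. point1 execution times (ms) from the benchmark table
origin = [1.240800, 2.102896, 3.853152, 7.273824, 14.233936, 27.992672, 55.688545]
point1 = [1.170304, 1.962208, 3.570560, 6.776016, 13.206304, 25.957024, 51.651104]

speedups = [o / p for o, p in zip(origin, point1)]
# every row shows roughly a 6-8% end-to-end speedup
assert all(1.05 < s < 1.09 for s in speedups)
```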

Tests

  • tested on H100
  • only gated delta rule, comba, and simple-gla are tested, as they are the only ops that use chunk_bwd_dqkwg with g.

pytest tests/ops/test_gated_delta.py

=================================================== warnings summary ===================================================
../../venvs/fla/lib/python3.10/site-packages/torch/jit/_script.py:362: 14 warnings
../../venvs/fla/lib/python3.10/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
===================================== 56 passed, 14 warnings in 1487.51s (0:24:47) =====================================

pytest tests/ops/test_comba.py

====================================================== warnings summary ======================================================
../../venvs/fla/lib/python3.10/site-packages/torch/jit/_script.py:362: 14 warnings
../../venvs/fla/lib/python3.10/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================== 23 passed, 14 warnings in 544.00s (0:09:03) =========================================

pytest tests/ops/test_simple_gla.py

====================================================== warnings summary ======================================================
../../venvs/fla/lib/python3.10/site-packages/torch/jit/_script.py:362: 14 warnings
../../venvs/fla/lib/python3.10/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================== 42 passed, 4 skipped, 14 warnings in 862.48s (0:14:22) ===================================

Benchmark Code

from functools import partial

import torch
import torch.nn.functional as F
import triton

from fla.modules.l2norm import l2norm
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
from gdno import chunk_gated_delta_rule as gdno_chunk_gated_delta_rule
from gdnx import chunk_gated_delta_rule as gdnx_chunk_gated_delta_rule


def run_chunk_gated_delta_rule(
    _chunk_gated_delta_rule,
    q, k, v, g, beta, scale, h0, use_qk_l2norm_in_kernel,
    do, dht,
):
    tri, tri_ht = _chunk_gated_delta_rule(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        initial_state=h0,
        output_final_state=True,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
    )

    ((tri * do).sum() + (tri_ht * dht).sum()).backward()
    return


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['B', 'T', 'H', 'D'],
        # different possible values for `x_name`
        x_vals=[(16, 128 * 2 ** i, h, 2048//h) for h in [8,] for i in range(1, 8)],
        # x_vals=[(16, 128 * 2 ** i, h, 2048//h) for h in [16,] for i in range(1, 8)],
        # argument name whose value corresponds to a different line in the plot
        line_arg='provider',
        # possible values for `line_arg``
        line_vals=['origin', 'point1'],
        # label name for the lines
        line_names=['origin', 'point1'],
        # line styles
        styles=[('green', '-'), ('blue', '--'), ('red', '-.'),
                ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')],
        ylabel="Execution Time (ms)",  # label name for the y-axis
        # name for the plot. Used also as a file name for saving the plot.
        plot_name="Performance",
        args={},
    ),
)
def benchmark(B, H, D, T, provider):
    from fla.utils import device
    gate_logit_normalizer = 1.0
    mask_p = 0.0
    use_qk_l2norm_in_kernel = False
    scale = 1.0
    dtype=torch.float16

    torch.manual_seed(42)
    q = torch.rand(B, T, H, D, dtype=dtype)
    k = torch.rand(B, T, H, D, dtype=dtype)
    v = torch.rand(B, T, H, D, dtype=dtype)
    beta = torch.rand(B, T, H, dtype=torch.float).sigmoid()
    g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.float32))
    g = g / gate_logit_normalizer
    g = g * (torch.rand_like(g) > mask_p)
    h0 = torch.zeros(B, H, D, D, dtype=torch.float32)
    q, k, v, beta, g, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, beta, g, h0))
    q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone()
    k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone()
    do = torch.randn_like(v)
    dht = torch.randn_like(h0)

    def prepare_data():
        q1 = q.detach().clone().requires_grad_(True)
        k1 = k.detach().clone().requires_grad_(True)
        v1 = v.detach().clone().requires_grad_(True)
        g1 = g.detach().clone().requires_grad_(True)
        beta1 = beta.detach().clone().requires_grad_(True)
        h01 = h0.detach().clone().requires_grad_(True)
        do1 = do.detach()
        dht1 = dht.detach()
        return (
            q1, k1, v1, g1, beta1, scale, h01, use_qk_l2norm_in_kernel,
            do1, dht1,
        )

    # if True:
    #     data = prepare_data()
    #     def fn_point1():
    #         run_chunk_gated_delta_rule(
    #             gdnx_chunk_gated_delta_rule, *data,
    #         )
    #     fn_point1()
    #     breakpoint()

    quantiles = [0.5, 0.2, 0.8]
    results = 0, 0, 0
    if provider.startswith('origin'):
        data = prepare_data()
        def fn_origin():
            run_chunk_gated_delta_rule(
                gdno_chunk_gated_delta_rule, *data,
            )
        results = triton.testing.do_bench(fn_origin, quantiles=quantiles)
    if provider.startswith('point1'):
        data = prepare_data()
        def fn_point1():
            run_chunk_gated_delta_rule(
                gdnx_chunk_gated_delta_rule, *data,
            )
        results = triton.testing.do_bench(fn_point1, quantiles=quantiles)
    
    return results


if __name__ == '__main__':
    benchmark.run(print_data=True)

Summary by CodeRabbit

  • Refactor
    • Backward gradient aggregation now recomputes contributions from finalized intermediate values rather than incrementally accumulating partial terms. This shifts computation order to avoid reliance on transient state, clarifies data flow, and improves robustness and maintainability of gradient computations.


coderabbitai bot commented Apr 11, 2026

No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between 4f120e1 and 71a2bcc.

📒 Files selected for processing (1)
  • fla/ops/common/chunk_o.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/ops/common/chunk_o.py

Walkthrough

Refactors chunk_bwd_kernel_dqkwg (the USE_G branch) to remove incremental b_dg accumulation via b_ds2. Instead, after recomputing final b_dq/b_dk, b_dg is zeroed and computed directly from sums of b_dq * b_q and b_dk * b_k (with subtraction).

Changes

Cohort / File(s) Summary
Backward Gradient Computation
fla/ops/common/chunk_o.py
In chunk_bwd_kernel_dqkwg (the USE_G branch) removed b_ds2-based incremental accumulation into b_dg. After recomputing b_dq/b_dk from final b_ds, b_dg is re-initialized to zeros and computed via tl.sum(b_dq * b_q, axis=1) and tl.sum(b_dk * b_k, axis=1) (subtracted).

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐰 I hop through kernels, quick and neat,
I stop mid-sum and make things neat.
Finalize dq, finalize dk’s tune —
Then recompute dg beneath the moon. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly describes the optimization to the b_dg computation in the chunk_bwd_kernel_dqkwg function with the USE_G flag enabled, matching the primary change in the changeset.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the calculation of b_dg within the chunk_bwd_kernel_dqkwg Triton kernel by moving its initialization and consolidating its updates at the end of the block. A review comment suggests further simplifying the b_dg assignment into a single expression to improve code clarity and efficiency.

Comment thread fla/ops/common/chunk_o.py
Comment on lines +305 to +307
b_dg = tl.zeros([BT], dtype=tl.float32)
b_dg += tl.sum(b_dq * b_q, axis=1)
b_dg -= tl.sum(b_dk * b_k, axis=1)

Severity: medium

The initialization of b_dg to zero followed by incremental updates can be simplified into a single expression. This improves code clarity and avoids redundant operations.

Suggested change
b_dg = tl.zeros([BT], dtype=tl.float32)
b_dg += tl.sum(b_dq * b_q, axis=1)
b_dg -= tl.sum(b_dk * b_k, axis=1)
b_dg = tl.sum(b_dq * b_q, axis=1) - tl.sum(b_dk * b_k, axis=1)
