Skip to content

[GDN] Remove safe_exp & add gate 1 tests#463

Merged
yzhangcs merged 1 commit intomainfrom
safe_exp
Jun 18, 2025
Merged

[GDN] Remove safe_exp & add gate 1 tests#463
yzhangcs merged 1 commit intomainfrom
safe_exp

Conversation

@yzhangcs
Copy link
Copy Markdown
Member

@yzhangcs yzhangcs commented Jun 18, 2025

Summary by CodeRabbit

  • Bug Fixes
    • Improved handling of gating and masking to ensure correct processing of boundary conditions in chunked operations.
  • Tests
    • Enhanced test coverage by introducing randomized gating sparsity, allowing for more robust validation of chunked gated delta rule implementations with varying levels of masked elements.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Jun 18, 2025

Walkthrough

The changes systematically remove all usage of the safe_exp function across several kernel implementations and tests, replacing it with direct use of exp combined with explicit masking via index calculations and tl.where. This update standardizes boundary handling and masking logic for gating and exponential computations in both forward and backward kernels, and augments tests with stochastic gating sparsity.

Changes

File(s) Change Summary
fla/ops/common/chunk_delta_h.py Replaced safe_exp with exp and explicit masking using tl.where in forward and backward kernels for gating exponential computations.
fla/ops/common/chunk_o.py Removed safe_exp, used exp and explicit index-based masks for boundary checking and masking in all kernels.
fla/ops/common/chunk_scaled_dot_kkt.py Switched from safe_exp to exp, updated offset and masking logic to combine chunk and sequence bounds in kernel.
fla/ops/gated_delta_rule/wy_fast.py Replaced safe_exp with exp, refined masking logic for b_dA using global index and sequence length checks in the backward kernel.
tests/ops/test_gated_delta.py Added mask_p parameter to tests, introducing stochastic gating sparsity by randomly masking elements in the gating tensor.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test Function
    participant Kernel as Kernel (e.g., chunk_delta_h, chunk_o)
    participant Exp as exp
    participant Mask as Masking Logic

    Test->>Kernel: Call forward/backward kernel
    Kernel->>Mask: Compute valid indices (e.g., o_t < T)
    Kernel->>Exp: Compute exp(gating_difference)
    Mask->>Kernel: Provide mask (valid/invalid positions)
    Kernel->>Kernel: Apply mask using tl.where (zero out invalid)
    Kernel-->>Test: Return result (with correct masking)
Loading

Possibly related PRs

  • fla-org/flash-linear-attention#433: Focuses on related gating exponential computations in chunk_delta_h.py, but retains safe_exp and differs in masking logic and kernel structure.

Poem

In kernels where exponents once slept,
The rabbit hopped in, and safe_exp was swept.
With masks and indices, so clever and neat,
Now boundaries are handled, no more defeat!
Tests with some randomness, gating anew—
A hop, a skip, and code that’s true.
🐇✨

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate Unit Tests
  • Create PR with Unit Tests
  • Post Copyable Unit Tests in Comment
  • Commit Unit Tests in branch safe_exp

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai auto-generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Nitpick comments (3)
fla/ops/common/chunk_scaled_dot_kkt.py (1)

51-53: Minor: avoid extra work on tail chunks

m_t is already computed; wrapping the subsequent loads / dots with it (e.g. early‐exit when tl.all(~m_t)) would save registers on the very last chunk when T % BT != 0.

fla/ops/common/chunk_delta_h.py (1)

147-154: Broadcast clarity

tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] works but is easy to mis-read. Wrapping the mask:

mask = m_t[:, None]
b_v_new *= mask * exp(b_g_last - b_g)[:, None]

improves readability without extra cost.

tests/ops/test_gated_delta.py (1)

229-230: Flaky test danger

torch.rand_like without a fixed seed inside the test makes the outcome non-deterministic across runs. Capture the RNG state or feed a local torch.manual_seed before masking to keep CI deterministic.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2002703 and 5f41d58.

📒 Files selected for processing (5)
  • fla/ops/common/chunk_delta_h.py (3 hunks)
  • fla/ops/common/chunk_o.py (8 hunks)
  • fla/ops/common/chunk_scaled_dot_kkt.py (3 hunks)
  • fla/ops/gated_delta_rule/wy_fast.py (2 hunks)
  • tests/ops/test_gated_delta.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (3)
fla/ops/common/chunk_scaled_dot_kkt.py (1)

70-71: Mask is fine — just confirm diagonal exclusion

m_A uses the strict “>`” test, so the diagonal is zeroed. If self-interaction should be kept, switch to “>=”.

fla/ops/common/chunk_o.py (2)

98-104: Triangular mask direction changed — please double-check

Forward path now keeps (row >= col) whereas backward kernels use (row <= col). This is probably intentional due to transposed A usage, but a mismatch would silently corrupt gradients. Verify with a quick unit test on a tiny T.


242-255: Potential off-by-one in gradient accumulation

b_dg_last is added only to the last element via

b_dg = tl.where(o_t < last_idx, b_dg, b_dg + b_dg_last)

If T is an exact multiple of BT, last_idx equals BT-1; the last row is updated, but earlier rows in the same chunk should also include the suffix-sum of b_dg_last. Please cross-validate against the reference implementation.


from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import safe_exp
from fla.ops.utils.op import exp
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Loss of overflow-protection after dropping safe_exp

safe_exp intentionally clipped extreme inputs; exp will overflow for large positive g differences, yielding inf and NaNs that propagate through the kernel. Please either (1) keep the numerically-stable helper, or (2) clamp the argument, e.g.

- b_A = b_A * exp(b_g_diff)
+ b_A = b_A * exp(tl.clip(b_g_diff, -20, 20))

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In fla/ops/common/chunk_scaled_dot_kkt.py at line 11, the import changed from a
numerically stable `safe_exp` to a plain `exp`, which can cause overflow for
large inputs. To fix this, either revert to importing and using the original
`safe_exp` function that clips extreme inputs or add explicit clamping of the
argument before calling `exp` to prevent overflow and NaNs in the kernel
calculations.


from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from fla.ops.utils.op import exp, safe_exp
from fla.ops.utils.op import exp
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Same overflow concern as above

Replacing safe_exp with bare exp risks inf. Consider clamping the argument or re-introducing the safe wrapper.

🤖 Prompt for AI Agents
In fla/ops/common/chunk_delta_h.py at line 11, the code imports and uses the
bare exp function instead of the previously used safe_exp, which can cause
overflow issues resulting in infinity values. To fix this, either revert to
using the safe_exp wrapper that handles large inputs safely or implement
argument clamping before calling exp to prevent overflow and ensure numerical
stability.

Comment on lines +99 to 106
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_dA = tl.where(m_A, b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)

b_dA *= safe_exp(b_g[:, None] - b_g[None, :])
b_dA = tl.where(m_A, -b_dA * exp(b_g[:, None] - b_g[None, :]), 0)
b_dA = b_dA.to(k.dtype.element_ty)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Numerical-stability & work done on invalid blocks

  1. Same overflow note: exp(b_g[:,None]-b_g[None,:]) can blow up.
  2. b_dA accumulates contributions for out-of-range rows/cols and is zeroed only afterwards. Guarding the dot updates with m_A would save compute & SRAM on tail chunks.
- b_dA += tl.dot(b_dw, tl.trans(b_k_beta_g))
+ valid = tl.where(m_A, 1., 0.)
+ b_dA += tl.dot(b_dw*valid, tl.trans(b_k_beta_g))

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In fla/ops/gated_delta_rule/wy_fast.py around lines 99 to 106, the expression
exp(b_g[:, None] - b_g[None, :]) can cause numerical overflow and b_dA
accumulates values for out-of-range indices before being zeroed, wasting compute
and memory. To fix this, apply a numerical stability technique such as
subtracting the max value from b_g before exponentiation to prevent overflow.
Also, modify the dot product computations to only include valid indices guarded
by m_A, avoiding unnecessary calculations on invalid blocks and reducing
resource usage.

Comment thread fla/ops/common/chunk_o.py
Comment on lines +351 to 358
b_dv *= tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None]

m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
b_A = tl.where(m_A, b_A * exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
else:
b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty)
p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Again, raw exp may overflow

Every instance of exp(...) that used to be safe_exp needs the same clamp or stable formulation.

🤖 Prompt for AI Agents
In fla/ops/common/chunk_o.py around lines 351 to 358, the use of raw exp(...)
can cause overflow errors. Replace all exp calls with a stable version by
clamping the input values to a safe range before applying exp, similar to the
previous safe_exp usage. This prevents overflow and ensures numerical stability
in the calculations.

@yzhangcs yzhangcs merged commit 1970ef7 into main Jun 18, 2025
3 of 6 checks passed
@yzhangcs yzhangcs deleted the safe_exp branch June 18, 2025 09:37
yzhangcs added a commit that referenced this pull request Jun 18, 2025
@yzhangcs yzhangcs linked an issue Jun 18, 2025 that may be closed by this pull request
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] safe_exp causes incorrect results when log-decay is zero

1 participant