feat: add Quasar Attention and standalone model implementation#805
troy12x wants to merge 4 commits into fla-org:main
Conversation
This PR introduces Quasar Attention, featuring significant kernel optimizations for modern GPU architectures. Key changes:
- Standalone Quasar Attention Triton kernels (fla/ops/quasar)
- QuasarAttention layer integration (fla/layers/quasar.py)
- Full Quasar model suite (fla/models/quasar)
- Benchmarks show Quasar outperforming KDA/GDA in throughput and memory efficiency.
Walkthrough
Adds a full Quasar Attention feature: Triton forward/backward kernels (chunk, intra-chunk, fused recurrent, gate, WY recompute, forward substitution), a QuasarAttention layer, a HuggingFace-style Quasar model/config/for-causal-LM, a distributed-compat shim, and registration of Quasar symbols in the package exports.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Input (hidden_states / input_ids)
    participant Proj as Projections (q/k/v/(g), conv)
    participant RoPE as RoPE
    participant Norm as L2Norm
    participant Gate as Quasar Gate (beta→alpha)
    participant Kernel as Triton Kernels (chunk / intra / fused_recurrent)
    participant Recomp as WY Recompute / ForwardSubstitution
    participant Model as QuasarModel / QuasarBlock
    participant Output as Output Projection / LM Head
    Input->>Proj: embed & project to q,k,v,(g)
    Proj->>RoPE: apply rotary embeddings (optional)
    RoPE->>Norm: L2-normalize q,k
    Norm->>Gate: compute beta → alpha
    Gate->>Kernel: send alpha, q,k,v,g and cache/state
    Kernel->>Recomp: request w/u, A, or forward-substitution (if needed)
    Recomp->>Kernel: return intermediates (w,u,A,...)
    Kernel->>Model: return attention output (padded back if unpadded)
    Model->>Output: project, apply gated norm, compute logits
```
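The flow in the diagram can be sketched end to end in plain PyTorch. This is a toy, single-head illustration only: the delta-rule-style state update, the gate formula, and all shapes are stand-ins for the PR's Triton kernels, not their actual implementation.

```python
import torch
import torch.nn.functional as F

def toy_quasar_flow(x, Wq, Wk, Wv, Wo, beta):
    # Toy sketch of the diagram's pipeline: project -> L2-norm -> gate ->
    # recurrent state update -> output projection. Illustrative only.
    q = F.normalize(x @ Wq, dim=-1)
    k_raw = x @ Wk
    k = F.normalize(k_raw, dim=-1)
    v = x @ Wv
    s = (k_raw * k_raw).sum(-1)                        # per-token ||k||^2
    alpha = (1 - torch.exp(-beta * s)) / (s + 1e-6)    # gate: beta -> alpha
    S = torch.zeros(q.shape[-1], v.shape[-1])          # recurrent state [D, Dv]
    outs = []
    for t in range(x.shape[0]):                        # fused_recurrent-style loop
        S = S + alpha[t] * torch.outer(k[t], v[t])     # gated state update
        outs.append(S.T @ q[t])                        # read out with the query
    return torch.stack(outs) @ Wo
```

The chunked kernels in the PR compute the same kind of recurrence blockwise rather than one token at a time.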
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: 2 passed ✅, 1 failed ❌ (1 warning)
Code Review
This pull request introduces Quasar Attention, a high-performance linear attention variant, along with its associated Triton kernels, standalone layers, and HuggingFace-compatible model classes. The review identified several critical issues in the implementation of the custom kernels, including inconsistent alpha recomputation in the backward pass, missing gradients for learnable parameters like A_log and dt_bias, and the omission of state-passing gradients. Additionally, feedback was provided regarding hardcoded execution modes that override user preferences, redundant normalization steps, and duplicated code blocks within the Triton kernels.
```python
eps = 1e-6
k_norm_sq = (k.float() * k.float()).sum(dim=-1)  # [B, T, H]
k_norm_sq = torch.clamp(k_norm_sq, min=0.1, max=10.0)

if beta.dim() == 1:
    beta_h = beta.view(1, 1, H).to(k_norm_sq.dtype)
else:
    beta_h = beta.to(k_norm_sq.dtype)

beta_h = torch.clamp(beta_h, min=0.01, max=10.0)
# Compute alpha with numerical stability
exp_term = torch.exp(-beta_h * k_norm_sq)
alpha = (1.0 - exp_term) / (k_norm_sq + eps)
beta_tok = alpha.to(dtype=q.dtype)
```
The recomputation of alpha (and beta_tok) in the backward pass is inconsistent with the forward pass. In forward (lines 66-95), alpha is computed using A_log and dt_bias, and beta_tok is the mean of alpha across the key dimension. Here, the recomputation uses a simplified formula that ignores A_log and dt_bias, and misses the mean(dim=-1) reduction. This will lead to incorrect gradients during training.
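One way to rule out this class of bug is to save the forward `alpha` with `ctx.save_for_backward` and derive the backward from the very same expression, so the two passes cannot drift apart. The sketch below is hypothetical (shapes assumed: `k` is `[T, D]`, per-token `beta` is `[T, 1]`; the gate formula is a simplified stand-in for the PR's), verified against autograd with `gradcheck`:

```python
import torch

EPS = 1e-6

class AlphaGate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, k, beta):
        s = (k * k).sum(-1, keepdim=True)            # ||k||^2 per token, [T, 1]
        e = torch.exp(-beta * s)
        alpha = (1.0 - e) / (s + EPS)
        ctx.save_for_backward(k, beta, s, e, alpha)  # reuse, don't re-derive
        return alpha * k

    @staticmethod
    def backward(ctx, g):
        k, beta, s, e, alpha = ctx.saved_tensors
        c = (g * k).sum(-1, keepdim=True)            # g . k per row
        # Exact derivatives of the *saved* forward formula:
        dalpha_ds = (beta * e * (s + EPS) - (1.0 - e)) / (s + EPS) ** 2
        dk = g * alpha + c * dalpha_ds * 2.0 * k     # chain rule through s = ||k||^2
        dbeta = c * s * e / (s + EPS)                # d(alpha)/d(beta) = s*e/(s+EPS)
        return dk, dbeta
```

The same pattern extends to the real formula: whatever `A_log`, `dt_bias`, and the `mean(dim=-1)` reduction contribute in forward must appear term for term in the backward recomputation.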
```python
batch_size, q_len, _ = hidden_states.shape
# Force chunk mode to avoid fused_recurrent BT conflict
mode = "chunk"
```
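Rather than hardcoding `mode = "chunk"`, the override could be made explicit. A hypothetical helper (names assumed, not from the PR) that honors the configured mode and warns on fallback:

```python
import warnings

def resolve_mode(requested: str, training: bool) -> str:
    # Hypothetical helper: honor the user's configured mode and fall back
    # with an explicit warning instead of silently forcing "chunk".
    supported = {"chunk", "fused_recurrent"}
    if requested not in supported:
        raise ValueError(f"unknown mode {requested!r}; expected one of {sorted(supported)}")
    if requested == "fused_recurrent" and training:
        # e.g. if the fused-recurrent backward is unimplemented
        warnings.warn(
            "fused_recurrent has no backward pass; falling back to 'chunk' for training",
            stacklevel=2,
        )
        return "chunk"
    return requested
```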
```python
g=None,
gk=None,
h0=initial_state_f32,
dht=None,
```
```python
# Token-wise gradient doesn't need / T normalization if it's fed to linear layer
dbeta = torch.clamp(dbeta, min=-1.0, max=1.0)

return dq, dk, dv, dbeta, None, None, None, None, None
```
The backward pass is missing gradients for A_log and dt_bias. Additionally, the gradient for initial_state (dh0, computed at line 274) is not being returned, which will break state-passing training.
```diff
-return dq, dk, dv, dbeta, None, None, None, None, None
+return dq, dk, dv, dbeta, None, None, dh0, None, None
```
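The positional contract the suggested fix relies on can be seen in a minimal `autograd.Function`: backward must return one gradient per forward input, in order, so a state input like `h0` gets its gradient in the matching slot instead of `None`. The toy recurrence below is illustrative only:

```python
import torch

class ScanWithState(torch.autograd.Function):
    # Toy model: out_t = x_t + h0, so d(out)/d(h0) is well defined. backward
    # must return one gradient per forward input, positionally; returning
    # None in h0's slot (as the review flags) would silently drop it.
    @staticmethod
    def forward(ctx, x, h0):
        return x + h0                      # h0 ([D]) broadcasts over time ([T, D])

    @staticmethod
    def backward(ctx, grad_out):
        dx = grad_out
        dh0 = grad_out.sum(dim=0)          # accumulate over the broadcast time dim
        return dx, dh0                     # (dx, dh0) aligns with forward's (x, h0)
```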
```python
if torch.isnan(q).any(): pass  # print("!!! NAN IN Q (before RoPE)", flush=True)
if torch.isnan(v).any(): pass  # print("!!! NAN IN V", flush=True)
```
```python
if USE_QK_L2NORM_IN_KERNEL:
    b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
    b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
```
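For reference, this in-kernel normalization adds `eps` under the square root, which is not quite the same as `F.normalize` (which clamps the norm). A torch-side equivalent for testing against the kernel might look like:

```python
import torch

def l2norm_eps(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mirrors the kernel's b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6): eps sits
    # *under* the square root, so near-zero vectors stay finite.
    return x / torch.sqrt((x * x).sum(dim=-1, keepdim=True) + eps)
```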
```python
if STORE_FINAL_STATE:
    p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))

if STORE_FINAL_STATE:
    p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
```
This block of code is redundant as it is duplicated immediately after.
```diff
 if STORE_FINAL_STATE:
     p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
     tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
-if STORE_FINAL_STATE:
-    p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
-    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
```
```python
# Backward pass: dL = -A^T @ dA @ A^T
# Simplified implementation for now
dL = torch.zeros_like(L)
```
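The adjoint rule the comment alludes to can be checked directly against autograd: for `L = A^{-1}`, the gradient with respect to `A` is `-L^T @ G @ L^T` for upstream gradient `G`, so a zero-filled placeholder drops the gradient entirely. A quick check on a small, well-conditioned matrix:

```python
import torch

# Verify the matrix-inverse adjoint rule: if L = inv(A) and G = dLoss/dL,
# then dLoss/dA = -L^T @ G @ L^T.
torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
A.requires_grad_(True)
L = torch.linalg.inv(A)
G = torch.randn(4, 4, dtype=torch.float64)   # stand-in for the upstream gradient
(L * G).sum().backward()                     # autograd gradient of sum(L * G)
Linv = L.detach()
manual = -Linv.T @ G @ Linv.T                # closed-form adjoint
assert torch.allclose(A.grad, manual, atol=1e-8)
```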
Actionable comments posted: 10
Note
Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.
🟡 Minor comments (15)
fla/distributed_compat.py-8-8 (1)
8-8: ⚠️ Potential issue | 🟡 Minor: Remove unused `torch` import.
The `torch` module is imported but never used in this file. All imports are from `torch.distributed` submodules, which don't require the parent module to be explicitly imported.
🧹 Proposed fix
```diff
-import torch
-
 # DeviceMesh
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/distributed_compat.py` at line 8, remove the unused top-level import of the `torch` module by deleting the solitary `import torch` statement; keep the existing `torch.distributed` submodule imports intact (they do not require the parent `torch` import) and run a quick search for any references to the plain `torch` symbol in this file to ensure no other code relies on it before committing.
fla/ops/quasar/fused_recurrent.py-212-214 (1)
212-214: ⚠️ Potential issue | 🟡 Minor: Backward pass not implemented.
The backward pass raises `NotImplementedError`. This means `fused_recurrent_quasar` cannot be used during training with gradient computation. Ensure this limitation is documented, and consider adding a warning when the mode is used in a training context.
Would you like me to help document this limitation or open an issue to track the backward implementation?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/fused_recurrent.py` around lines 212-214, the backward staticmethod in `fused_recurrent_quasar` currently raises `NotImplementedError`, preventing training; add clear documentation and a runtime warning instead of just the exception: update the `fused_recurrent_quasar` class/docstring to state that the backward pass is unimplemented and the op is inference-only, and in the forward (or apply) entrypoint detect gradient-required contexts (e.g., `torch.is_grad_enabled()` or any `input.requires_grad`) and emit a warning via `warnings.warn` or the project logger explaining that gradients are not supported and that training will fail; ensure the backward method still raises if actually called, and add a TODO/issue reference comment to track implementing `fused_recurrent_quasar.backward`.
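The suggested runtime warning can be sketched as a thin wrapper. The names mirror the review but are hypothetical, and the kernel launch itself is elided to a placeholder:

```python
import warnings
import torch

def fused_recurrent_quasar_sketch(q, k, v):
    # Hypothetical wrapper: warn up front when a gradient context is live,
    # since the real op's backward raises NotImplementedError (inference-only).
    if torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v)):
        warnings.warn(
            "fused_recurrent_quasar has no backward pass; use mode='chunk' for training",
            stacklevel=2,
        )
    return v  # placeholder for the actual Triton kernel launch
```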
fla/ops/quasar/chunk_bwd.py-9-9 (1)
9-9: ⚠️ Potential issue | 🟡 Minor: Remove unused import.
`IS_NVIDIA_BLACKWELL` is imported but never used.
🧹 Proposed fix
```diff
-from fla.utils import IS_NVIDIA_HOPPER, IS_NVIDIA_BLACKWELL, autotune_cache_kwargs, check_shared_mem
+from fla.utils import IS_NVIDIA_HOPPER, autotune_cache_kwargs, check_shared_mem
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_bwd.py` at line 9, the import line includes an unused symbol `IS_NVIDIA_BLACKWELL`; remove `IS_NVIDIA_BLACKWELL` from the from-import (leaving `IS_NVIDIA_HOPPER`, `autotune_cache_kwargs`, `check_shared_mem`) so the module no longer imports an unused name and linter warnings are resolved.
fla/models/quasar/modeling_quasar.py-206-208 (1)
206-208: ⚠️ Potential issue | 🟡 Minor: Add `stacklevel=2` to warning for correct caller attribution.
Without an explicit stacklevel, the warning will point to this line rather than the caller's location.
🔧 Proposed fix
```diff
-    warnings.warn("`QuasarModel` does not `output_attentions` now, setting it to `False`.")
+    warnings.warn("`QuasarModel` does not `output_attentions` now, setting it to `False`.", stacklevel=2)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/models/quasar/modeling_quasar.py` around lines 206-208, the warning in QuasarModel's handling of `output_attentions` uses `warnings.warn` without a stacklevel; update the call (the block that checks `output_attentions` in the QuasarModel code) to pass `stacklevel=2` to `warnings.warn` so the warning points to the caller rather than this function; keep the existing message and behavior but add the stacklevel argument to the `warnings.warn` invocation.
fla/ops/quasar/chunk_bwd.py-269-269 (1)
269-269: ⚠️ Potential issue | 🟡 Minor: Use explicit `| None` instead of implicit `Optional`.
PEP 484 prohibits implicit `Optional` for `scale: float = None`.
🔧 Proposed fix
```diff
-    scale: float = None,
+    scale: float | None = None,
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_bwd.py` at line 269, update the function parameter annotation that currently reads `scale: float = None` to use an explicit nullable type by changing it to `scale: float | None = None` (or `Optional[float]` if you prefer typing imports) wherever the parameter appears in chunk_bwd.py; locate the signature that contains the symbol `scale` in the function/method that handles chunk backward logic and replace the implicit Optional pattern with the explicit `| None` union type so it complies with PEP 484.
fla/layers/quasar.py-83-83 (1)
83-83: ⚠️ Potential issue | 🟡 Minor: Remove invalid return type annotation from `__init__`.
`__init__` methods should not have return type annotations other than `-> None`. The current `-> QuasarAttention` is invalid.
🔧 Proposed fix
```diff
-    ) -> QuasarAttention:
+    ) -> None:
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/layers/quasar.py` at line 83, the `__init__` of class QuasarAttention has an invalid return annotation (`-> QuasarAttention`); change the signature of `QuasarAttention.__init__` to have no return type or explicitly `-> None` so it conforms to Python constructor typing (remove or replace the existing `-> QuasarAttention` annotation).
fla/models/quasar/modeling_quasar.py-295-305 (1)
295-305: ⚠️ Potential issue | 🟡 Minor: Use exception chaining for proper traceback.
When re-raising a modified exception, use `raise ... from exception` to preserve the original traceback context.
🔧 Proposed fix
```diff
         if "past_key_values" in str(exception):
             raise AttributeError(
                 f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                 f"which is not supported for {self.__class__.__name__}. "
                 f"Try another generation strategy instead. "
                 f"For the available generation strategies, check this doc: "
                 f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies",
-            )
+            ) from exception
         else:
             raise exception
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/models/quasar/modeling_quasar.py` around lines 295-305, in the `except AttributeError as exception` block (the handler that checks if "past_key_values" is in `str(exception)`), change the re-raise to use exception chaining by raising the new AttributeError from the caught exception (i.e., use `raise AttributeError(...) from exception`) so the original traceback is preserved; keep the same message that references `self.__class__.__name__` and the generation strategies doc URL.
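A minimal standalone demonstration of the `raise ... from exception` pattern, with a hypothetical message, showing that the original error survives as `__cause__`:

```python
def reject_past_key_values():
    # Minimal demo of the suggested pattern: `raise ... from exception` keeps
    # the original error reachable as __cause__ (message text is hypothetical).
    try:
        raise AttributeError("object has no attribute 'past_key_values'")
    except AttributeError as exception:
        if "past_key_values" in str(exception):
            raise AttributeError(
                "decoding strategies that manipulate `past_key_values` are not supported"
            ) from exception
        raise
```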
fla/layers/quasar.py-236-237 (1)
236-237: ⚠️ Potential issue | 🟡 Minor: Remove debug code that causes pipeline failures.
These debug statements with `pass` on the same line violate E701 and serve no purpose since the prints are commented out. The pipeline is failing on these lines.
🧹 Proposed fix: remove debug lines
```diff
-        if torch.isnan(q).any(): pass  # print("!!! NAN IN Q (before RoPE)", flush=True)
-        if torch.isnan(v).any(): pass  # print("!!! NAN IN V", flush=True)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/layers/quasar.py` around lines 236-237, remove the two debug if-statements that combine a conditional and a pass/comment on one line (the `torch.isnan(q).any()` and `torch.isnan(v).any()` checks); these violate E701 and are unnecessary: delete those lines (or replace with a proper multiline check/logging if you need runtime validation) so only valid, non-commented statements remain around the q and v NaN checks.
fla/layers/quasar.py-80-80 (1)
80-80: ⚠️ Potential issue | 🟡 Minor: Use explicit `| None` instead of implicit `Optional`.
PEP 484 prohibits implicit `Optional`. The `layer_idx` parameter should use explicit union syntax.
🔧 Proposed fix
```diff
-    layer_idx: int = None,
+    layer_idx: int | None = None,
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/layers/quasar.py` at line 80, the parameter annotation uses an implicit Optional by writing `layer_idx: int = None`; update the function/method signature to use explicit union syntax by changing the annotation to `layer_idx: int | None = None` so it complies with PEP 484; locate the declaration of `layer_idx` in the function or constructor where it appears and apply this change.
fla/layers/quasar.py-6-17 (1)
6-17: ⚠️ Potential issue | 🟡 Minor: Remove unused imports to fix pipeline failures.
The following imports are flagged as unused by Flake8 and are causing pipeline failures:
- `math` (line 6)
- `repeat` from einops (line 11)
- `RMSNorm` from fla.modules (line 15)
- `fused_quasar_gate` from fla.ops.quasar.gate (line 17)
🧹 Proposed fix
```diff
-import math
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
-from einops import rearrange, repeat
+from einops import rearrange
 from torch.nn import functional as F

 from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
-from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
+from fla.modules import FusedRMSNormGated, ShortConvolution
 from fla.ops.quasar import chunk_quasar, fused_recurrent_quasar
-from fla.ops.quasar.gate import fused_quasar_gate
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/layers/quasar.py` around lines 6-17, remove the unused imports causing Flake8 failures: delete the top-level imports `math`, `repeat` (from einops), `RMSNorm` (from fla.modules) and `fused_quasar_gate` (from fla.ops.quasar.gate) from the import block so only actually used symbols like torch, nn, rearrange, get_unpad_data, index_first_axis, pad_input, FusedRMSNormGated, ShortConvolution, chunk_quasar and fused_recurrent_quasar remain; ensure no other code references those removed names and run the linter to confirm the pipeline passes.
fla/ops/quasar/chunk_bwd.py-240-243 (1)
240-243: ⚠️ Potential issue | 🟡 Minor: Remove unused local variables flagged by pipeline.
Variables `m_k` (line 190), `b_q` (line 241), and `b_kdk` (line 242) are assigned but never used, causing pipeline failures.
🧹 Proposed fix
```diff
 for i_k in range(tl.cdiv(K, BK)):
-    o_k = i_k * BK + tl.arange(0, BK)
-    m_k = o_k < K
+    # o_k and m_k removed - not used
 ...
-    p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
-    b_q = tl.load(p_q, boundary_check=(0, 1))
-    b_kdk = b_k * b_dk
     b_dk = b_dk + b_dkgb * b_beta[:, None]
```
Note: Review if `o_k` is actually needed elsewhere in the loop before removing.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_bwd.py` around lines 240-243, remove the unused local variables to fix the pipeline errors: delete or rename the unused assignment to `m_k`, and remove or replace the unused `b_q` and `b_kdk` assignments in the loop (the `tl.load` into `b_q` and the temporary `b_kdk = b_k * b_dk`); if the `tl.load` call must remain for side effects, assign its result to `_` instead of `b_q`. Also review the use of `o_k` in the same loop to ensure it's still needed before removing any related code.
fla/ops/quasar/chunk_intra.py-392-392 (1)
392-392: ⚠️ Potential issue | 🟡 Minor: Avoid shadowing Python builtin `all`.
The variable name `all` shadows the Python builtin function. Use a different name like `total_tokens` or `num_elements`.
🔧 Proposed fix
```diff
-    all = B * T
+    total_tokens = B * T
     if IS_VARLEN:
         ...
     ...
-    db += (i_k * all + bos) * H + i_h
+    db += (i_k * total_tokens + bos) * H + i_h
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_intra.py` at line 392, the variable name `all` (the assignment `all = B * T`) shadows the built-in `all()`; rename it to a non-conflicting identifier (e.g., `total_tokens` or `num_elements`) and update every reference in the same scope (and any return/formatting/logging that uses it) to the new name so behavior is unchanged but the builtin is no longer shadowed; ensure imports or other functions are not affected and run tests/static checks after renaming.
fla/ops/quasar/chunk.py-5-13 (1)
5-13: ⚠️ Potential issue | 🟡 Minor: Remove unused imports.
Multiple imports are unused according to static analysis and cause linting failures:
- `triton` (line 5)
- `chunk_gla_fwd_o_gk` (line 9)
- `fused_quasar_gate`, `fast_quasar_alpha` (line 11)
- `autotune_cache_kwargs` (line 12)
- `chunk_bwd_dv_local`, `chunk_bwd_dqkwg` (line 13)
🧹 Proposed fix
```diff
 import torch
-import triton

 from fla.ops.utils.index import prepare_chunk_indices
 from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
-from fla.ops.gla.chunk import chunk_gla_fwd_o_gk
 from fla.ops.quasar.chunk_intra import chunk_quasar_fwd_intra
-from fla.ops.quasar.gate import fused_quasar_gate, fast_quasar_alpha
-from fla.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, check_shared_mem, input_guard
+from fla.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard
-from fla.ops.common.chunk_o import chunk_fwd_o, chunk_bwd_dv_local, chunk_bwd_dqkwg
+from fla.ops.common.chunk_o import chunk_fwd_o
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk.py` around lines 5-13, remove the unused imports causing lint errors: delete the `triton` import and the unused symbols `chunk_gla_fwd_o_gk`, `fused_quasar_gate`, `fast_quasar_alpha`, `autotune_cache_kwargs`, `chunk_bwd_dv_local`, and `chunk_bwd_dqkwg` from the import lists in this file; specifically update the top imports that currently reference triton, fla.ops.gla.chunk (`chunk_gla_fwd_o_gk`), fla.ops.quasar.gate (`fused_quasar_gate`, `fast_quasar_alpha`), fla.utils (`autotune_cache_kwargs`), and fla.ops.common.chunk_o (`chunk_bwd_dv_local`, `chunk_bwd_dqkwg`) so only actually used names remain, and run tests/linter to confirm no other references to those symbols exist.
fla/ops/quasar/gate.py-5-5 (1)
5-5: ⚠️ Potential issue | 🟡 Minor: Remove the unused `torch.nn.functional` import.
Flake8 is already flagging line 5, so the unused `F` keeps the lint job red until it is dropped.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/gate.py` at line 5, the `import torch.nn.functional as F` in gate.py is unused and causes a Flake8 lint failure; remove the unused import (the `F` symbol) from the top of the file so only necessary torch imports remain, e.g., delete the line containing `import torch.nn.functional as F` or replace it with needed imports used elsewhere in the module.
fla/ops/quasar/forward_substitution.py-8-8 (1)
8-8: ⚠️ Potential issue | 🟡 Minor: Remove the unused `check_shared_mem` import.
Nothing in this file references it, so Flake8 will keep reporting F401 here.
✂️ Proposed fix
```diff
-from fla.utils import IS_AMD, autotune_cache_kwargs, check_shared_mem, input_guard
+from fla.utils import IS_AMD, autotune_cache_kwargs, input_guard
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/forward_substitution.py` at line 8, remove the unused import `check_shared_mem` from the module import list; update the `from fla.utils import` line to only import the actually used symbols (e.g., `IS_AMD`, `autotune_cache_kwargs`, `input_guard`) so Flake8 F401 is resolved and no other references to `check_shared_mem` remain.
🧹 Nitpick comments (7)
fla/distributed_compat.py (1)
48-57: Consider sorting `__all__` alphabetically.
Static analysis suggests applying alphabetical sorting to `__all__` for consistency.
📝 Proposed sort order
```diff
 __all__ = [
     'DeviceMesh',
     'DTensor',
+    'HAS_DISTRIBUTED',
+    'ParallelStyle',
     'Placement',
     'Replicate',
     'Shard',
     'distribute_module',
-    'ParallelStyle',
-    'HAS_DISTRIBUTED',
 ]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/distributed_compat.py` around lines 48-57, the `__all__` export list is unsorted; please reorder the entries alphabetically for consistency: locate the `__all__` definition and change the sequence of the symbols ('DeviceMesh', 'DTensor', 'Placement', 'Replicate', 'Shard', 'distribute_module', 'ParallelStyle', 'HAS_DISTRIBUTED') so they appear in ascending alphabetical order while preserving the exact identifier names and quotes.
fla/models/quasar/configuration_quasar.py (1)
1-5: Minor: Remove leading blank lines.
The file starts with two empty lines before the import statement. Consider removing them for consistency.
✨ Proposed fix
```diff
-
-
 from transformers.configuration_utils import PretrainedConfig
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/models/quasar/configuration_quasar.py` around lines 1-5, the file begins with unnecessary leading blank lines before the import; remove the empty lines so the first non-blank line is the import statement (`from transformers.configuration_utils import PretrainedConfig`) to keep the file tidy and consistent.
fla/ops/quasar/chunk.py (1)
104-117: Unused unpacked variables should be prefixed with underscore.
Variables `qg`, `Aqk`, and `Akk` (line 104) are unpacked but never used. The same applies in backward (line 226). Prefix with `_` to indicate they are intentionally unused.
♻️ Proposed fix for forward
```diff
-    w, u, qg, kg, Aqk, Akk = chunk_quasar_fwd_intra(
+    w, u, _qg, kg, _Aqk, _Akk = chunk_quasar_fwd_intra(
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk.py` around lines 104-117, the unpacked outputs `qg`, `Aqk`, and `Akk` returned from `chunk_quasar_fwd_intra` (and the corresponding unused outputs in the backward function, e.g., from `chunk_quasar_bwd_intra`) are never used; rename them to `_qg`, `_Aqk`, and `_Akk` (and the backward equivalents) to mark them as intentionally unused and avoid linter warnings: update the unpack targets where `chunk_quasar_fwd_intra` and `chunk_quasar_bwd_intra` are called to use the underscore-prefixed names.
fla/layers/quasar.py (1)
318-319: Use `contextlib.suppress` instead of bare `try-except-pass`.
The pipeline is flagging this pattern. Using `contextlib.suppress(TypeError)` is more Pythonic.
♻️ Proposed refactor
```diff
+import contextlib
 ...
-        try:
-            past_key_values.update(
-                recurrent_state=recurrent_state,
-                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
-                layer_idx=self.layer_idx,
-                offset=q_len,
-            )
-        except TypeError:
-            pass
+        with contextlib.suppress(TypeError):
+            past_key_values.update(
+                recurrent_state=recurrent_state,
+                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
+                layer_idx=self.layer_idx,
+                offset=q_len,
+            )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/layers/quasar.py` around lines 318-319, replace the bare `try/except TypeError: pass` block with `contextlib.suppress(TypeError)` to be more Pythonic: import contextlib at the top (if missing) and wrap the statement(s) that previously lived inside the try block with a `with contextlib.suppress(TypeError): ...` block; locate the existing try/except in the quasar-related function or method where the `except TypeError: pass` appears and convert that block to use the context manager (preserving the same inner statements and indentation).
fla/ops/quasar/chunk_intra.py (1)
253-272: Forward substitution loops have a hardcoded BC=16 assumption.
The loop ranges `range(2, min(BC, ...))`, `range(BC + 2, min(2*BC, ...))`, etc., assume `BC=16` based on the 4 sub-chunks pattern. However, `BC` is a compile-time constant passed to the kernel. If `BC` were changed, the indexing math (`i - BC`, `i - 2*BC`, `i - 3*BC`) would break.
Consider adding an assertion or documentation that `BC` must be 16 for this kernel.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_intra.py` around lines 253-272, the forward-substitution loops assume a 4-subchunk layout by hardcoding offsets like `i - BC`, `i - 2*BC`, `i - 3*BC` and loop starts (e.g., `range(2, min(BC, ...))`) which will break if the compile-time constant `BC` is not 16; update the kernel to either enforce `BC == 16` or generalize the loops: add a runtime assertion at kernel entry (reference `BC`) or loop over a `subchunk_idx` and replace the four explicit blocks (`b_a00`/`b_Ai00`, `b_a11`/`b_Ai11`, `b_a22`/`b_Ai22`, `b_a33`/`b_Ai33`) with a single parametric block that uses `offset = subchunk_idx * BC`, compares `o_i < i - offset`, and updates the corresponding `b_Ai` block by index; ensure the naming is replaced by indexed containers or mapped by `subchunk_idx` to avoid hardcoded 4-way logic.
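For comparison, a blocked forward substitution can be written with the sub-chunk size as an explicit, asserted parameter. The pure-PyTorch sketch below is a reference formulation, not the Triton kernel:

```python
import torch

def blocked_tril_inv(A: torch.Tensor, bc: int) -> torch.Tensor:
    # Reference (non-Triton) blocked forward substitution for a lower-
    # triangular inverse, with the sub-chunk size `bc` as a checked parameter
    # rather than an implicit 16.
    T = A.shape[0]
    assert T % bc == 0, "bc must divide the matrix dimension"
    n = T // bc
    blk = lambda i: slice(i * bc, (i + 1) * bc)
    X = torch.zeros_like(A)
    for i in range(n):
        Aii_inv = torch.linalg.inv(A[blk(i), blk(i)])
        X[blk(i), blk(i)] = Aii_inv
        for j in range(i):
            # A_ii X_ij + sum_{m=j..i-1} A_im X_mj = 0  =>  solve for X_ij
            acc = sum(A[blk(i), blk(m)] @ X[blk(m), blk(j)] for m in range(j, i))
            X[blk(i), blk(j)] = -Aii_inv @ acc
    return X
```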
fla/ops/quasar/gate.py (1)
16-39: Make the reference path match the kernel's compute precision.
`quasar_gate_fwd_kernel` promotes `beta` and `lambda_t` to fp32 before `exp`, but `naive_quasar_gate()` does the whole computation in the incoming dtype and only casts at the end. On bf16/fp16 inputs that means the "reference" path can disagree with the Triton path for dtype reasons alone.
♻️ Proposed alignment
```diff
 def naive_quasar_gate(
     beta: torch.Tensor,
     lambda_t: torch.Tensor,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
     Torch reference implementation for QuasarAttention gate computation.
     """
     eps = 1e-8
-    alpha = (1 - torch.exp(-beta.view(-1, 1) * lambda_t)) / (lambda_t + eps)
+    beta_f = beta.reshape(*([1] * (lambda_t.ndim - 2)), -1, 1).to(torch.float32)
+    lambda_f = lambda_t.to(torch.float32)
+    alpha = (1 - torch.exp(-beta_f * lambda_f)) / (lambda_f + eps)
     return alpha.to(output_dtype)
```
Also applies to: 65-75
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/gate.py` around lines 16-39, the reference implementation `naive_quasar_gate` runs the math in the incoming dtype, which diverges from `quasar_gate_fwd_kernel` which promotes inputs to fp32; to fix, cast `beta` and `lambda_t` to `torch.float32` (e.g., `beta.view(-1, 1).to(torch.float32)` and `lambda_t.to(torch.float32)`), perform the exp and division in fp32 (use an fp32 eps constant), then cast the resulting alpha back to `output_dtype` before returning; apply the same change to the corresponding reference backward function around lines 65-75 so both forward and backward reference paths match the kernel's compute precision.
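The suggested alignment boils down to one rule: promote to fp32 before the `exp` and division, cast only at the end. A hedged standalone sketch (simplified shapes, illustrative `eps`; not the library's function):

```python
import torch

def gate_ref(beta: torch.Tensor, lam: torch.Tensor,
             out_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Hypothetical fp32-promoted reference gate: alpha = (1 - exp(-beta*lam)) / (lam + eps),
    # computed in fp32 regardless of input dtype, cast to out_dtype at the end.
    eps = 1e-8
    beta_f = beta.to(torch.float32)
    lam_f = lam.to(torch.float32)
    alpha = (1.0 - torch.exp(-beta_f * lam_f)) / (lam_f + eps)
    return alpha.to(out_dtype)
```

With the promotion inside the function, bf16 and fp32 inputs differ only by the initial rounding of the inputs themselves, never by intermediate-precision drift.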
fla/ops/quasar/forward_substitution.py (1)
10-10: The autotune configs bypass `NUM_WARPS`.
Line 10 defines the architecture-specific warp candidates, but lines 15-18 rebuild that list as `[2, 4, 8]`, so the AMD/CUDA split never actually applies and CUDA never tries the 16/32-warp cases. Please use `NUM_WARPS` here, or delete it, so this tuning surface has one source of truth.
Based on learnings: align threshold constants used by `check_shared_mem` to a single source of truth to avoid semantic drift, rather than duplicating literals.
⚙️ Proposed fix
```diff
 @triton.autotune(
     configs=[
         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
-        for num_warps in [2, 4, 8]
+        for num_warps in NUM_WARPS
         for num_stages in [2, 3, 4]
     ],
```
Also applies to: 13-20
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/forward_substitution.py` at line 10, the autotune candidate list currently defined by `NUM_WARPS` is being overwritten later in the file (the autotune configs rebuild the list as `[2, 4, 8]`), so the AMD/CUDA split in `NUM_WARPS` is never used; change the autotune config code to reference `NUM_WARPS` instead of rebuilding a hardcoded list (or remove the duplicate list entirely) so there is a single source of truth, and likewise consolidate any threshold literals used by `check_shared_mem` into the same shared constant so `check_shared_mem` and the autotuner use the same value.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/layers/quasar.py`:
- Around line 54-66: QuasarBlock is passing expand_v, num_v_heads, and
allow_neg_eigval into QuasarAttention.__init__ but QuasarAttention's signature
lacks these parameters and they are swallowed by **kwargs; either add explicit
parameters expand_v, num_v_heads, and allow_neg_eigval to
QuasarAttention.__init__ (and use them inside the class where appropriate) or
stop passing them from QuasarBlock and ensure the intended behavior (e.g.,
implement the missing logic or rename/forward the correct config fields). Update
the QuasarAttention class signature and internal usage (or the QuasarBlock
invocation) so the three symbols are consistently defined and handled rather
than silently ignored.
In `@fla/models/quasar/modeling_quasar.py`:
- Around line 54-66: QuasarBlock is passing expand_v, num_v_heads, and
allow_neg_eigval into QuasarAttention but those args are not in
QuasarAttention.__init__, so they get swallowed by **kwargs; fix by either (A)
adding explicit parameters expand_v, num_v_heads, allow_neg_eigval to
QuasarAttention.__init__ in fla/layers/quasar.py and wiring them into the
attention behavior, or (B) removing those three arguments from the QuasarBlock
instantiation (the self.attn = QuasarAttention(...) call) so only the supported
parameters (hidden_size, head_dim, num_heads, mode, use_short_conv, conv_size,
conv_bias, layer_idx, norm_eps) are passed; choose the option consistent with
intended functionality and update any internal uses accordingly.
In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 278-283: The elif currently checks the function object instead of
calling it (elif check_shared_mem:) causing CONST_TILING to always become 64;
change that branch to call check_shared_mem with the device index (e.g., elif
check_shared_mem(k.device.index):) so it evaluates available shared memory
correctly and sets CONST_TILING to 64 only when the call returns true; keep the
first branch using check_shared_mem('hopper', k.device.index) unchanged.
- Around line 11-20: The safe_dot function currently uses NVIDIA-only PTX via
tl.inline_asm_elementwise (asm="mov.f32 ...") which will break on AMD; either
guard its use with the existing NVIDIA detection flags (e.g., check
IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL before calling safe_dot or raising a
clear error) or replace the inline asm with a device-agnostic implementation
(for example, return tl.dot(a, b) or an equivalent pure Triton expression) so
that safe_dot and callers (safe_dot) work on non‑NVIDIA devices; update
references to safe_dot and any call sites accordingly.
In `@fla/ops/quasar/chunk_intra_token_parallel.py`:
- Around line 109-110: The code unconditionally creates p_beta_out and calls
tl.store into beta_out even when beta_out can be None; wrap the creation of
p_beta_out and the tl.store call in the same USE_QUASAR_ALPHA guard (or check
beta_out is not None) so the pointer is only made and written when beta_out
exists; specifically, guard the block that references p_beta_out and b_alpha
(the tl.make_block_ptr and tl.store lines) with the USE_QUASAR_ALPHA condition
used earlier so no write occurs when beta_out is None.
In `@fla/ops/quasar/chunk.py`:
- Around line 203-217: The backward pass must reconstruct the exact forward
alpha formula using the saved A_log and dt_bias and compute gradients via the
chain rule; replace the simplified exp(-beta_h * k_norm_sq) path in the backward
method with the forward-equivalent steps: recover A = exp(A_log) and expand beta
to beta_expanded and dt_bias to dt_bias_full to match k_norm_sq shape, compute
g_quasar = -A * softplus(beta_expanded + dt_bias_full), set decay =
exp(g_quasar) and alpha = (1.0 - decay) / (k_norm_sq + eps) (same as forward),
then compute dalpha_dbeta using dalpha/dg * dg/dbeta = ( -decay /
(k_norm_sq+eps) ) * ( -A * sigmoid(beta_expanded + dt_bias_full) ) => A * decay
* sigmoid(...) / (k_norm_sq+eps) and use that in the existing beta gradient
calculations (replace the current dalpha_dbeta that uses k_norm_sq * exp_term /
(k_norm_sq + eps)); ensure you use the saved tensors A_log and dt_bias and the
same eps, softplus, and shapes as in forward (variables: A_log, dt_bias,
beta_h/beta_expanded, g_quasar, decay, alpha).
In `@fla/ops/quasar/forward_substitution.py`:
- Around line 107-119: The backward implementation for
quasar_forward_substitution currently returns zeros_like(L) which silently
blocks gradients; either implement the correct VJP using the saved L and A
(compute dL = -A.transpose(-2,-1) @ dA @ A.transpose(-2,-1) with appropriate
tensor shapes and returns) or, if this op is inference/recompute-only, replace
the zero return with an explicit error (e.g. raise RuntimeError) inside the
static backward(ctx, dA) of the custom autograd Function to prevent silent
gradient masking; locate the backward method where ctx.saved_tensors yields L
and A and update that method accordingly.
In `@fla/ops/quasar/fused_recurrent.py`:
- Around line 120-126: The code contains a duplicated STORE_FINAL_STATE block
that computes p_ht and calls tl.store twice; remove the redundant second block
to ensure final_state is stored only once. Locate the repeated conditional using
STORE_FINAL_STATE that computes p_ht from final_state, i_b, H, i_h, BK, o_k, o_v
and calls tl.store(p_ht, b_h.to(p_ht.dtype.element_ty)), and delete the
duplicate so only a single STORE_FINAL_STATE branch performs the tl.store of
b_h.
- Around line 76-78: The kernel is L2-normalizing b_q and b_k (controlled by
USE_QK_L2NORM_IN_KERNEL) while QuasarAttention.forward already normalizes q and
k with F.normalize; to avoid double normalization, update
QuasarAttention.forward so that when calling fused_recurrent_quasar it either
does not pre-normalize (remove the F.normalize(q, p=2, dim=-1)/F.normalize(k,
p=2, dim=-1) calls) if you want the kernel to handle normalization, or pass
use_qk_l2norm_in_kernel=False to fused_recurrent_quasar so the kernel skips its
own normalization; adjust the call site in QuasarAttention.forward and ensure
consistency with the USE_QK_L2NORM_IN_KERNEL flag and any related parameters.
In `@fla/ops/quasar/gate.py`:
- Around line 125-145: QuasarGateFunction.backward currently returns four
gradients despite forward taking three inputs and computes dbeta incorrectly and
with a wrong reduction; update backward in QuasarGateFunction to (1) return
exactly three gradients to match forward's inputs, (2) compute dbeta =
exp(-beta.view(...)*lambda_t) * dalpha * (lambda_t / (lambda_t + eps)) (i.e.,
multiply the existing dbeta term by lambda_t/(lambda_t+eps)), and (3) replace
the hardcoded .sum(dim=(0,1)) with a reduction that sums over all dimensions of
dalpha except the beta dimension so the resulting dbeta matches beta.shape
(similarly sum dlambda over the non-lambda dimensions so dlambda matches
lambda_t.shape); use ctx.saved_tensors (lambda_t, beta) to determine which dims
to reduce.
---
Minor comments:
In `@fla/distributed_compat.py`:
- Line 8: Remove the unused top-level import of the torch module by deleting the
solitary "import torch" statement; keep the existing torch.distributed submodule
imports intact (they do not require the parent torch import) and run a quick
search for any references to the plain "torch" symbol in this file to ensure no
other code relies on it before committing.
In `@fla/layers/quasar.py`:
- Line 83: The __init__ of class QuasarAttention has an invalid return
annotation (-> QuasarAttention); change the signature of
QuasarAttention.__init__ to have no return type or explicitly -> None so it
conforms to Python constructor typing (remove or replace the existing ->
QuasarAttention annotation).
- Around line 236-237: Remove the two debug if-statements that combine a
conditional and a pass/comment on one line (the torch.isnan(q).any() and
torch.isnan(v).any() checks) in quasar.py; these violate E701 and are
unnecessary — delete those lines (or replace with a proper multiline
check/logging if you need runtime validation) so only valid, non-commented
statements remain around the q and v NaN checks.
- Line 80: The parameter annotation uses an implicit Optional by writing
"layer_idx: int = None"; update the function/method signature to use explicit
union syntax by changing the annotation to "layer_idx: int | None = None" (i.e.,
replace the implicit Optional with int | None for the layer_idx parameter) so it
complies with PEP 484; locate the declaration of layer_idx in the function or
constructor where it appears and apply this change.
- Around line 6-17: Remove the unused imports causing Flake8 failures: delete
the top-level imports "math", "repeat" (from einops), "RMSNorm" (from
fla.modules) and "fused_quasar_gate" (from fla.ops.quasar.gate) from the import
block so only actually used symbols like torch, nn, rearrange, get_unpad_data,
index_first_axis, pad_input, FusedRMSNormGated, ShortConvolution, chunk_quasar
and fused_recurrent_quasar remain; ensure no other code references those removed
names and run linter to confirm the pipeline passes.
In `@fla/models/quasar/modeling_quasar.py`:
- Around line 206-208: The warning in QuasarModel's handling of
output_attentions uses warnings.warn without a stacklevel, so update the call in
modeling_quasar.py (the block that checks output_attentions in the QuasarModel
code) to pass stacklevel=2 to warnings.warn so the warning points to the caller
rather than this function; keep the existing message and behavior but add the
stacklevel argument to the warnings.warn invocation.
- Around line 295-305: In the except AttributeError as exception block in
modeling_quasar.py (the handler that checks if "past_key_values" is in
str(exception)), change the re-raise to use exception chaining by raising the
new AttributeError from the caught exception (i.e., use "raise
AttributeError(... ) from exception") so the original traceback is preserved;
keep the same message that references self.__class__.__name__ and the generation
strategies doc URL.
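Reduced to a self-contained sketch (a dict lookup stands in for the real attribute access), the `raise ... from` pattern preserves the original traceback via `__cause__`:

```python
def lookup(config: dict):
    try:
        return config["past_key_values"]
    except KeyError as exception:
        # Chaining keeps the original error attached, so debuggers and logs
        # still show where the failure actually started.
        raise AttributeError(
            "past_key_values is not supported; see the generation strategies docs"
        ) from exception

try:
    lookup({})
except AttributeError as err:
    cause = err.__cause__
assert isinstance(cause, KeyError)
```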
In `@fla/ops/quasar/chunk_bwd.py`:
- Line 9: The import line in chunk_bwd.py includes an unused symbol
IS_NVIDIA_BLACKWELL; remove IS_NVIDIA_BLACKWELL from the from-import (leaving
IS_NVIDIA_HOPPER, autotune_cache_kwargs, check_shared_mem) so the module no
longer imports an unused name and linter warnings are resolved.
- Line 269: Update the function parameter annotation that currently reads
"scale: float = None" to use an explicit nullable type by changing it to "scale:
float | None = None" (or "Optional[float]" if you prefer typing imports)
wherever the parameter appears in chunk_bwd.py; locate the signature that
contains the symbol "scale" in the function/method that handles chunk backward
logic and replace the implicit Optional pattern with the explicit "| None" union
type so it complies with PEP 484.
- Around line 240-243: Remove the unused local variables to fix the pipeline
errors: delete or rename the unused assignment to m_k, and remove or replace the
unused b_q and b_kdk assignments in the loop (the tl.load into b_q and the
temporary b_kdk = b_k * b_dk); if the tl.load call must remain for side-effects,
assign its result to _ instead of b_q. Also review use of o_k in the same loop
to ensure it’s still needed before removing any related code.
In `@fla/ops/quasar/chunk_intra.py`:
- Line 392: The variable name "all" in chunk_intra.py (the assignment "all = B *
T") shadows the built-in all(); rename it to a non-conflicting identifier (e.g.,
total_tokens or num_elements) and update every reference in the same scope (and
any return/formatting/logging that uses it) to the new name so behavior is
unchanged but the builtin is no longer shadowed; ensure imports or other
functions are not affected and run tests/static checks after renaming.
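A minimal illustration of the rename (names here are illustrative, not taken from the kernel):

```python
def count_tokens(B: int, T: int) -> int:
    # Was: all = B * T — which shadows the builtin all() for the rest
    # of the scope. A descriptive name keeps the builtin usable.
    total_tokens = B * T
    assert all(x > 0 for x in (B, T))  # builtin all() still available
    return total_tokens

assert count_tokens(4, 128) == 512
```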
In `@fla/ops/quasar/chunk.py`:
- Around line 5-13: Remove the unused imports causing lint errors: delete the
triton import and the unused symbols chunk_gla_fwd_o_gk, fused_quasar_gate,
fast_quasar_alpha, autotune_cache_kwargs, chunk_bwd_dv_local, and
chunk_bwd_dqkwg from the import lists in this file; specifically update the top
imports that currently reference triton, fla.ops.gla.chunk (chunk_gla_fwd_o_gk),
fla.ops.quasar.gate (fused_quasar_gate, fast_quasar_alpha), fla.utils
(autotune_cache_kwargs), and fla.ops.common.chunk_o (chunk_bwd_dv_local,
chunk_bwd_dqkwg) so only actually used names remain, and run tests/linter to
confirm no other references to those symbols exist.
In `@fla/ops/quasar/forward_substitution.py`:
- Line 8: Remove the unused import check_shared_mem from the module import list;
update the from fla.utils import line in forward_substitution.py to only import
the actually used symbols (e.g., IS_AMD, autotune_cache_kwargs, input_guard) so
Flake8 F401 is resolved and no other references to check_shared_mem remain.
In `@fla/ops/quasar/fused_recurrent.py`:
- Around line 212-214: The backward staticmethod in fused_recurrent_quasar
currently raises NotImplementedError preventing training — add clear
documentation and a runtime warning instead of just the exception: update the
fused_recurrent_quasar class/docstring to state that the backward pass is
unimplemented and the op is inference-only, and in the forward (or apply)
entrypoint detect gradient-required contexts (e.g., torch.is_grad_enabled() or
any input.requires_grad) and emit a warning via warnings.warn or the project
logger explaining gradients are not supported and that training will fail;
ensure the backward method still raises if actually called, and add a TODO/issue
reference comment to track implementing fused_recurrent_quasar.backward.
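The warn-early behavior suggested above can be sketched as follows — `fused_recurrent_quasar_stub` and the doubling are stand-ins for the real kernel call:

```python
import warnings
import torch

def fused_recurrent_quasar_stub(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper: warn at call time when gradients are required,
    # since the (unimplemented) backward would only fail later at .backward().
    if torch.is_grad_enabled() and x.requires_grad:
        warnings.warn(
            "fused_recurrent_quasar has no backward pass; use chunk mode for training.",
            stacklevel=2,
        )
    return x * 2  # placeholder for the real kernel

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = fused_recurrent_quasar_stub(torch.ones(3, requires_grad=True))
assert len(caught) == 1

# Under no_grad the warning is skipped, matching inference usage.
with warnings.catch_warnings(record=True) as silent, torch.no_grad():
    warnings.simplefilter("always")
    fused_recurrent_quasar_stub(torch.ones(3, requires_grad=True))
assert len(silent) == 0
```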
In `@fla/ops/quasar/gate.py`:
- Line 5: The import torch.nn.functional as F in gate.py is unused and causes a
Flake8 lint failure; remove the unused import (the "F" symbol) from the top of
the file so only necessary torch imports remain, e.g., delete the line
containing "import torch.nn.functional as F" or replace it with needed imports
used elsewhere in the module.
---
Nitpick comments:
In `@fla/distributed_compat.py`:
- Around line 48-57: The __all__ export list is unsorted; please reorder the
entries in the __all__ list alphabetically for consistency—locate the __all__
definition and change the sequence of the symbols ('DeviceMesh', 'DTensor',
'Placement', 'Replicate', 'Shard', 'distribute_module', 'ParallelStyle',
'HAS_DISTRIBUTED') so they appear in ascending alphabetical order (by name)
while preserving the exact identifier names and quotes.
In `@fla/layers/quasar.py`:
- Around line 318-319: Replace the bare try/except TypeError: pass block with
contextlib.suppress(TypeError) to be more Pythonic: import contextlib at the top
(if missing) and wrap the statement(s) that previously lived inside the try
block with a with contextlib.suppress(TypeError): ... block; locate the existing
try/except in the quasar-related function or method where the except TypeError:
pass appears and convert that block to use the context manager (preserving the
same inner statements and indentation).
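For reference, the suggested conversion looks like this in isolation (toy function, not the quasar code):

```python
import contextlib

state = {}

def maybe_call(fn, *args):
    # Reads the same as "try: ... except TypeError: pass", but states intent.
    with contextlib.suppress(TypeError):
        state["result"] = fn(*args)

maybe_call(len, 42)         # len(42) raises TypeError -> suppressed, state untouched
maybe_call(len, [1, 2, 3])  # succeeds
assert state == {"result": 3}
```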
In `@fla/models/quasar/configuration_quasar.py`:
- Around line 1-5: The file begins with unnecessary leading blank lines before
the import; remove the empty lines so the first non-blank line is the import
statement (from transformers.configuration_utils import PretrainedConfig) to
keep the file tidy and consistent.
In `@fla/ops/quasar/chunk_intra.py`:
- Around line 253-272: The forward-substitution loops in chunk_intra.py assume a
4-subchunk layout by hardcoding offsets like i-BC, i-2*BC, i-3*BC and loop
starts (e.g., range(2, min(BC, ...))) which will break if the compile-time
constant BC is not 16; update the kernel to either enforce BC==16 or generalize
the loops: add a runtime assertion at kernel entry (reference BC) or compute the
number of subchunks = ceil(H*BC/BC) (or use a loop over subchunk_idx and replace
the four explicit blocks (b_a00/b_Ai00, b_a11/b_Ai11, b_a22/b_Ai22,
b_a33/b_Ai33) with a single parametric block that uses offset = subchunk_idx*BC
and compares o_i < i - offset, and updates the corresponding b_Ai array by
index; ensure any naming (b_a00, b_Ai00, etc.) is replaced by indexed containers
or mapped by subchunk_idx to avoid hardcoded 4-way logic.
In `@fla/ops/quasar/chunk.py`:
- Around line 104-117: The unpacked outputs qg, Aqk, and Akk returned from
chunk_quasar_fwd_intra (and the corresponding unused outputs in the backward
function, e.g., from chunk_quasar_bwd_intra) are never used; rename them to _qg,
_Aqk, and _Akk (and the backward equivalents) to mark them as intentionally
unused and avoid linter warnings—update the unpack targets where
chunk_quasar_fwd_intra and chunk_quasar_bwd_intra are called to use the
underscore-prefixed names.
In `@fla/ops/quasar/forward_substitution.py`:
- Line 10: The autotune candidate list currently defined by NUM_WARPS is being
overwritten later in forward_substitution.py (the autotune configs that rebuild
the list as [2,4,8]), so the AMD/CUDA split in NUM_WARPS is never used; change
the autotune config code to reference NUM_WARPS instead of rebuilding a
hardcoded list (or remove the duplicate list entirely) so there is a single
source of truth, and likewise consolidate any threshold literals used by
check_shared_mem into the same shared constant so check_shared_mem and the
autotuner use the same value.
In `@fla/ops/quasar/gate.py`:
- Around line 16-39: The reference implementation naive_quasar_gate runs the
math in the incoming dtype which diverges from quasar_gate_fwd_kernel that
promotes inputs to fp32; to fix, cast beta and lambda_t to torch.float32 (e.g.,
beta.view(-1,1).to(torch.float32) and lambda_t.to(torch.float32)), perform the
exp and division in fp32 (use an fp32 eps constant), then cast the resulting
alpha back to output_dtype before returning; apply the same change to the
corresponding reference backward function used around lines 65-75 so both
forward and backward reference paths match the kernel's compute precision.
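The promote-compute-cast pattern can be sketched like so; the gate formula here is a simplified stand-in, not the exact quasar gate math:

```python
import torch

def naive_gate_fp32(beta: torch.Tensor, lambda_t: torch.Tensor, eps: float = 1e-6):
    # Do the exp and division in fp32 regardless of the incoming dtype,
    # then cast back — mirroring a kernel that promotes inputs to fp32.
    out_dtype = lambda_t.dtype
    beta32 = beta.view(-1, 1).to(torch.float32)
    lam32 = lambda_t.to(torch.float32)
    alpha = (1.0 - torch.exp(-beta32 * lam32)) / (lam32 + eps)
    return alpha.to(out_dtype)

beta = torch.rand(4).half()
lam = (torch.rand(4, 8) + 0.5).half()
alpha = naive_gate_fp32(beta, lam)
assert alpha.dtype == torch.float16
```

Keeping the reference path in fp32 makes tolerance checks against the kernel meaningful rather than dominated by fp16 rounding.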
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 3030bcab-3065-48c1-b724-442195645a70
📒 Files selected for processing (17)
- PR_DESCRIPTION.md
- fla/distributed_compat.py
- fla/layers/__init__.py
- fla/layers/quasar.py
- fla/models/__init__.py
- fla/models/quasar/__init__.py
- fla/models/quasar/configuration_quasar.py
- fla/models/quasar/modeling_quasar.py
- fla/ops/quasar/__init__.py
- fla/ops/quasar/chunk.py
- fla/ops/quasar/chunk_bwd.py
- fla/ops/quasar/chunk_intra.py
- fla/ops/quasar/chunk_intra_token_parallel.py
- fla/ops/quasar/forward_substitution.py
- fla/ops/quasar/fused_recurrent.py
- fla/ops/quasar/gate.py
- fla/ops/quasar/wy_fast.py
```
    The number of heads. Default: 16.
mode (str, Optional):
    Which QuasarAttention kernel to use.
    Currently available: `chunk` and `fused_recurrent`.
    Default: `chunk`.
use_short_conv (bool, Optional):
    Whether to use short convolutions. Default: `True`.
conv_size (int, Optional):
    The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
conv_bias (bool, Optional):
    Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
layer_idx (int, Optional):
    The index of the layer. Default: None.
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check QuasarAttention init signature
echo "=== QuasarAttention __init__ parameters ==="
rg -n "def __init__" -A 20 fla/layers/quasar.py | head -30
echo ""
echo "=== QuasarBlock instantiation of QuasarAttention ==="
rg -n "QuasarAttention\(" -A 15 fla/models/quasar/modeling_quasar.py

Repository: fla-org/flash-linear-attention
Length of output: 1585
🏁 Script executed:
# Check if expand_v, num_v_heads, allow_neg_eigval are used in QuasarAttention
echo "=== Checking usage of expand_v in QuasarAttention ==="
rg "expand_v" fla/layers/quasar.py
echo ""
echo "=== Checking usage of num_v_heads in QuasarAttention ==="
rg "num_v_heads" fla/layers/quasar.py
echo ""
echo "=== Checking usage of allow_neg_eigval in QuasarAttention ==="
rg "allow_neg_eigval" fla/layers/quasar.py
echo ""
echo "=== Full QuasarAttention class definition ==="
rg -n "class QuasarAttention" -A 150 fla/layers/quasar.py | head -160

Repository: fla-org/flash-linear-attention
Length of output: 7280
Parameter mismatch between QuasarBlock and QuasarAttention.
QuasarBlock passes expand_v, num_v_heads, and allow_neg_eigval to QuasarAttention.__init__, but these parameters are not defined in the signature and never used within the class. They are silently swallowed by **kwargs, indicating either incomplete implementation or incorrect configuration propagation from the model to the layer.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/layers/quasar.py` around lines 54 - 66, QuasarBlock is passing expand_v,
num_v_heads, and allow_neg_eigval into QuasarAttention.__init__ but
QuasarAttention's signature lacks these parameters and they are swallowed by
**kwargs; either add explicit parameters expand_v, num_v_heads, and
allow_neg_eigval to QuasarAttention.__init__ (and use them inside the class
where appropriate) or stop passing them from QuasarBlock and ensure the intended
behavior (e.g., implement the missing logic or rename/forward the correct config
fields). Update the QuasarAttention class signature and internal usage (or the
QuasarBlock invocation) so the three symbols are consistently defined and
handled rather than silently ignored.
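A minimal illustration of how unknown keyword arguments disappear into `**kwargs`, plus one way to surface them (hypothetical class, not the real QuasarAttention):

```python
class Attention:
    def __init__(self, hidden_size: int, num_heads: int = 4, **kwargs):
        # Unknown keyword arguments would vanish into **kwargs with no error —
        # the failure mode described above. Rejecting leftovers surfaces it.
        if kwargs:
            raise TypeError(f"unexpected arguments: {sorted(kwargs)}")
        self.hidden_size = hidden_size
        self.num_heads = num_heads

ok = Attention(hidden_size=64)
try:
    Attention(hidden_size=64, expand_v=2)
except TypeError as err:
    msg = str(err)
assert "expand_v" in msg
```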
```python
self.attn = QuasarAttention(
    mode=config.attn_mode,
    hidden_size=config.hidden_size,
    expand_v=config.expand_v,
    head_dim=config.head_dim,
    num_heads=config.num_heads,
    num_v_heads=config.num_v_heads,
    use_short_conv=config.use_short_conv,
    allow_neg_eigval=config.allow_neg_eigval,
    conv_size=config.conv_size,
    norm_eps=config.norm_eps,
    layer_idx=layer_idx,
)
```
QuasarAttention instantiated with undefined parameters.
QuasarBlock passes expand_v, num_v_heads, and allow_neg_eigval to QuasarAttention, but checking fla/layers/quasar.py, these parameters are not in the __init__ signature. They'll be captured by **kwargs and silently ignored.
Either add these parameters to QuasarAttention.__init__ or remove them from the instantiation here.
🔍 Expected QuasarAttention parameters vs. passed
QuasarAttention accepts:
hidden_size, head_dim, num_heads, mode, use_short_conv, conv_size, conv_bias, layer_idx, norm_eps
QuasarBlock passes (extra):
expand_v (line 57), num_v_heads (line 60), allow_neg_eigval (line 62)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/models/quasar/modeling_quasar.py` around lines 54 - 66, QuasarBlock is
passing expand_v, num_v_heads, and allow_neg_eigval into QuasarAttention but
those args are not in QuasarAttention.__init__, so they get swallowed by
**kwargs; fix by either (A) adding explicit parameters expand_v, num_v_heads,
allow_neg_eigval to QuasarAttention.__init__ in fla/layers/quasar.py and wiring
them into the attention behavior, or (B) removing those three arguments from the
QuasarBlock instantiation (the self.attn = QuasarAttention(...) call) so only
the supported parameters (hidden_size, head_dim, num_heads, mode,
use_short_conv, conv_size, conv_bias, layer_idx, norm_eps) are passed; choose
the option consistent with intended functionality and update any internal uses
accordingly.
```python
@triton.jit
def safe_dot(a, b):
    return tl.inline_asm_elementwise(
        asm="mov.f32 $0, $1;",
        constraints="=r,r",
        args=[tl.dot(a, b)],
        dtype=tl.float32,
        is_pure=True,
        pack=1,
    )
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if IS_AMD is used to guard this function
rg -n "IS_AMD|safe_dot" fla/ops/quasar/chunk_bwd.py

Repository: fla-org/flash-linear-attention
Length of output: 188
🏁 Script executed:
# Read the imports and first 50 lines to see if IS_AMD is imported
head -50 fla/ops/quasar/chunk_bwd.py

Repository: fla-org/flash-linear-attention
Length of output: 1276
🏁 Script executed:
# Check if IS_AMD is used anywhere in the module, including call sites
rg -n "IS_AMD" fla/ops/quasar/chunk_bwd.py

Repository: fla-org/flash-linear-attention
Length of output: 56
🏁 Script executed:
# Look for any conditional logic around safe_dot calls
rg -B5 -A5 "safe_dot" fla/ops/quasar/chunk_bwd.py | head -60

Repository: fla-org/flash-linear-attention
Length of output: 1107
safe_dot uses unguarded NVIDIA-specific inline assembly.
The tl.inline_asm_elementwise with mov.f32 PTX instruction will fail on AMD GPUs. The function has no guards despite the module importing other NVIDIA-specific checks (IS_NVIDIA_HOPPER, IS_NVIDIA_BLACKWELL). Either guard this function against AMD usage or use a device-agnostic approach.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/chunk_bwd.py` around lines 11 - 20, The safe_dot function
currently uses NVIDIA-only PTX via tl.inline_asm_elementwise (asm="mov.f32 ...")
which will break on AMD; either guard its use with the existing NVIDIA detection
flags (e.g., check IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL before calling
safe_dot or raising a clear error) or replace the inline asm with a
device-agnostic implementation (for example, return tl.dot(a, b) or an
equivalent pure Triton expression) so that safe_dot and callers (safe_dot) work
on non‑NVIDIA devices; update references to safe_dot and any call sites
accordingly.
```python
if check_shared_mem('hopper', k.device.index):
    CONST_TILING = 128
elif check_shared_mem:
    CONST_TILING = 64
else:
    CONST_TILING = 32
```
Bug: Missing parentheses in check_shared_mem call.
Line 280 uses elif check_shared_mem: which checks if the function object is truthy (always true), not the result of calling it. This causes CONST_TILING = 64 to always be selected on non-Hopper hardware, even when shared memory is insufficient.
🐛 Proposed fix

```diff
 if check_shared_mem('hopper', k.device.index):
     CONST_TILING = 128
-elif check_shared_mem:
+elif check_shared_mem(device_index=k.device.index):
     CONST_TILING = 64
 else:
     CONST_TILING = 32
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/chunk_bwd.py` around lines 278 - 283, The elif currently
checks the function object instead of calling it (elif check_shared_mem:)
causing CONST_TILING to always become 64; change that branch to call
check_shared_mem with the device index (e.g., elif
check_shared_mem(k.device.index):) so it evaluates available shared memory
correctly and sets CONST_TILING to 64 only when the call returns true; keep the
first branch using check_shared_mem('hopper', k.device.index) unchanged.
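The truthiness bug is easy to reproduce with a stand-in helper (thresholds and the 96 KB device are illustrative only, not the real `check_shared_mem` semantics):

```python
SHARED_MEM_KB = 96  # simulated device, below both thresholds

def check_shared_mem(threshold_kb: int = 100) -> bool:
    return SHARED_MEM_KB >= threshold_kb

# Buggy dispatch: the bare name tests the function object, which is always truthy,
# so the 32 fallback is unreachable.
if check_shared_mem(228):       # hopper-class requirement — False here
    buggy = 128
elif check_shared_mem:          # BUG: always True
    buggy = 64
else:
    buggy = 32
assert buggy == 64

# Fixed dispatch: actually call the helper, making every branch reachable.
if check_shared_mem(228):
    fixed = 128
elif check_shared_mem(100):     # False for a 96 KB device
    fixed = 64
else:
    fixed = 32
assert fixed == 32
```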
```python
p_beta_out = tl.make_block_ptr(beta_out + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
tl.store(p_beta_out, b_alpha.to(beta_out.dtype.element_ty), boundary_check=(0,))
```
Unconditional write to beta_out when it may be None.
Lines 109-110 create a block pointer to beta_out and store to it unconditionally, but the USE_QUASAR_ALPHA guard is only applied to the pointer offset adjustment at lines 84-85. When beta_out is None, this will cause a runtime error.
🐛 Proposed fix: Guard the store with USE_QUASAR_ALPHA
```diff
-p_beta_out = tl.make_block_ptr(beta_out + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
-tl.store(p_beta_out, b_alpha.to(beta_out.dtype.element_ty), boundary_check=(0,))
+if USE_QUASAR_ALPHA:
+    p_beta_out = tl.make_block_ptr(beta_out + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
+    tl.store(p_beta_out, b_alpha.to(beta_out.dtype.element_ty), boundary_check=(0,))
```

📝 Committable suggestion
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/chunk_intra_token_parallel.py` around lines 109 - 110, The
code unconditionally creates p_beta_out and calls tl.store into beta_out even
when beta_out can be None; wrap the creation of p_beta_out and the tl.store call
in the same USE_QUASAR_ALPHA guard (or check beta_out is not None) so the
pointer is only made and written when beta_out exists; specifically, guard the
block that references p_beta_out and b_alpha (the tl.make_block_ptr and tl.store
lines) with the USE_QUASAR_ALPHA condition used earlier so no write occurs when
beta_out is None.
```python
# Recompute alpha
eps = 1e-6
k_norm_sq = (k.float() * k.float()).sum(dim=-1)  # [B, T, H]
k_norm_sq = torch.clamp(k_norm_sq, min=0.1, max=10.0)

if beta.dim() == 1:
    beta_h = beta.view(1, 1, H).to(k_norm_sq.dtype)
else:
    beta_h = beta.to(k_norm_sq.dtype)

beta_h = torch.clamp(beta_h, min=0.01, max=10.0)
# Compute alpha with numerical stability
exp_term = torch.exp(-beta_h * k_norm_sq)
alpha = (1.0 - exp_term) / (k_norm_sq + eps)
beta_tok = alpha.to(dtype=q.dtype)
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Compare forward and backward alpha formulas
echo "=== Forward alpha computation ==="
rg -n "alpha|decay|softplus" fla/ops/quasar/chunk.py | head -20
echo ""
echo "=== Backward alpha computation ==="
rg -n "exp_term|alpha" fla/ops/quasar/chunk.py | tail -20

Repository: fla-org/flash-linear-attention
Length of output: 2600
🏁 Script executed:
# Read the forward pass around lines 84-95
sed -n '80,100p' fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention
Length of output: 998
🏁 Script executed:
# Read lines around 187 to see what is saved
sed -n '180,195p' fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention
Length of output: 586
🏁 Script executed:
# Read the full backward alpha recomputation section
sed -n '200,240p' fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention
Length of output: 1421
🏁 Script executed:
# Search for usage of A_log and dt_bias in the entire file
rg -n "A_log|dt_bias" fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention
Length of output: 1514
🏁 Script executed:
# Get the full backward method to see if A_log and dt_bias are used
sed -n '195,350p' fla/ops/quasar/chunk.py

Repository: fla-org/flash-linear-attention
Length of output: 5276
🏁 Script executed:
# Search for any usage of A_log or dt_bias after line 197 in the backward section
sed -n '197,400p' fla/ops/quasar/chunk.py | grep -n "A_log\|dt_bias"

Repository: fla-org/flash-linear-attention
Length of output: 532
Backward alpha recomputation uses a different formula than forward, and saved parameters are unused.
The forward pass (lines 84-91) computes alpha using:

```python
g_quasar = -A * softplus(beta_expanded + dt_bias_full)
decay = exp(g_quasar)
alpha = (1.0 - decay) / (k_norm_sq_expanded + 1e-6)
```

But the backward pass (lines 215-216) uses a simplified formula that omits A_log and dt_bias:

```python
exp_term = exp(-beta_h * k_norm_sq)
alpha = (1.0 - exp_term) / (k_norm_sq + eps)
```

Additionally, A_log and dt_bias are saved at line 187 and unpacked at line 197, but never used anywhere in the backward method. The gradient computation for beta (lines 326, 333) uses dalpha_dbeta = k_norm_sq * exp_term / (k_norm_sq + eps), which derives from the simplified formula, not the actual forward formula. This causes the computed gradients to be based on a different alpha formula than what was actually used in the forward pass.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/chunk.py` around lines 203 - 217, The backward pass must
reconstruct the exact forward alpha formula using the saved A_log and dt_bias
and compute gradients via the chain rule; replace the simplified exp(-beta_h *
k_norm_sq) path in the backward method with the forward-equivalent steps:
recover A = exp(A_log) and expand beta to beta_expanded and dt_bias to
dt_bias_full to match k_norm_sq shape, compute g_quasar = -A *
softplus(beta_expanded + dt_bias_full), set decay = exp(g_quasar) and alpha =
(1.0 - decay) / (k_norm_sq + eps) (same as forward), then compute dalpha_dbeta
using dalpha/dg * dg/dbeta = ( -decay / (k_norm_sq+eps) ) * ( -A *
sigmoid(beta_expanded + dt_bias_full) ) => A * decay * sigmoid(...) /
(k_norm_sq+eps) and use that in the existing beta gradient calculations (replace
the current dalpha_dbeta that uses k_norm_sq * exp_term / (k_norm_sq + eps));
ensure you use the saved tensors A_log and dt_bias and the same eps, softplus,
and shapes as in forward (variables: A_log, dt_bias, beta_h/beta_expanded,
g_quasar, decay, alpha).
    ctx.save_for_backward(L, A)
    return A

@staticmethod
@input_guard
def backward(ctx, dA):
    L, A = ctx.saved_tensors

    # Backward pass: dL = -A^T @ dA @ A^T
    # Simplified implementation for now
    dL = torch.zeros_like(L)

    return dL
Please don't ship a zero gradient from this custom backward.
The comment on Line 115 already states the right VJP form, but the implementation returns zeros_like(L). That makes quasar_forward_substitution() look differentiable while silently blocking every upstream update that flows through L. If this helper is inference/recompute-only for now, raising from backward() is much safer than returning a valid-looking zero tensor.
🧮 Proposed fix
def forward(
ctx,
L: torch.Tensor,
):
A = forward_substitution(L)
- ctx.save_for_backward(L, A)
+ ctx.save_for_backward(A)
return A
@@
def backward(ctx, dA):
- L, A = ctx.saved_tensors
-
- # Backward pass: dL = -A^T @ dA @ A^T
- # Simplified implementation for now
- dL = torch.zeros_like(L)
-
+ (A,) = ctx.saved_tensors
+ gA = torch.tril(dA, diagonal=-1)
+ dL = -(A.transpose(-1, -2) @ gA @ A.transpose(-1, -2))
+ dL = torch.tril(dL, diagonal=-1)
return dL🧰 Tools
🪛 Ruff (0.15.7)
[warning] 113-113: Unpacked variable A is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
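The VJP quoted in the comment can be sanity-checked against autograd. The sketch below assumes forward_substitution effectively inverts a lower-triangular L and uses torch.linalg.inv as a stand-in; for Y = X⁻¹ the gradient is -Yᵀ G Yᵀ:

```python
import torch

# Check dL = -A^T @ dA @ A^T against autograd, with torch.linalg.inv
# standing in for forward_substitution on a well-conditioned lower-triangular L.
torch.manual_seed(0)
n = 4
L = torch.tril(torch.randn(n, n, dtype=torch.float64), diagonal=-1) \
    + torch.diag(torch.rand(n, dtype=torch.float64) + 1.0)
L.requires_grad_(True)

A = torch.linalg.inv(L)
dA = torch.randn(n, n, dtype=torch.float64)
A.backward(dA)

expected = -(A.T @ dA @ A.T).detach()
assert torch.allclose(L.grad, expected, atol=1e-8)   # matches autograd
```

This is why a zero return is so dangerous here: autograd reports a gradient, but it is the wrong one everywhere the true VJP is nonzero.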
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/forward_substitution.py` around lines 107 - 119, The backward
implementation for quasar_forward_substitution currently returns zeros_like(L)
which silently blocks gradients; either implement the correct VJP using the
saved L and A (compute dL = -A.transpose(-2,-1) @ dA @ A.transpose(-2,-1) with
appropriate tensor shapes and returns) or, if this op is
inference/recompute-only, replace the zero return with an explicit error (e.g.
raise RuntimeError) inside the static backward(ctx, dA) of the custom autograd
Function to prevent silent gradient masking; locate the backward method where
ctx.saved_tensors yields L and A and update that method accordingly.
if USE_QK_L2NORM_IN_KERNEL:
    b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
    b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Check if QuasarAttention normalizes q/k before calling fused_recurrent_quasar
# Search for F.normalize calls followed by fused_recurrent_quasar with use_qk_l2norm_in_kernel=True
rg -n -A 30 'F\.normalize.*p=2' --type py | rg -A 20 'fused_recurrent_quasar'
Repository: fla-org/flash-linear-attention
Length of output: 590
🏁 Script executed:
# Get line count and read the relevant section in quasar.py
wc -l fla/layers/quasar.py
Repository: fla-org/flash-linear-attention
Length of output: 98
🏁 Script executed:
# Read the QuasarAttention.forward method around lines 265-295
sed -n '260,300p' fla/layers/quasar.py | cat -n
Repository: fla-org/flash-linear-attention
Length of output: 1869
🏁 Script executed:
# Read the kernel normalization code in fused_recurrent.py
sed -n '70,85p' fla/ops/quasar/fused_recurrent.py | cat -n
Repository: fla-org/flash-linear-attention
Length of output: 753
🏁 Script executed:
# Search for the use_qk_l2norm_in_kernel parameter being passed
rg -n 'use_qk_l2norm_in_kernel' fla/layers/quasar.py
Repository: fla-org/flash-linear-attention
Length of output: 173
Remove redundant L2-normalization in kernel when vectors are already normalized.
The code normalizes q and k with F.normalize(q, p=2, dim=-1) and F.normalize(k, p=2, dim=-1) (lines 264-266), then passes them to fused_recurrent_quasar with use_qk_l2norm_in_kernel=True. The kernel then applies L2-normalization again:
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
This results in double normalization. Either remove the pre-normalization in QuasarAttention.forward when using kernel normalization, or pass use_qk_l2norm_in_kernel=False to avoid the redundant kernel normalization.
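Double normalization is numerically harmless (re-normalizing a unit vector is a near no-op, as the sketch below checks) but pure wasted work on the hot path, which is why one of the two passes should be dropped:

```python
import torch
import torch.nn.functional as F

# Illustrative check, not the Triton kernel: L2-normalizing an already
# unit-norm tensor changes nothing beyond floating-point noise.
x = torch.randn(2, 8, 64)
once = F.normalize(x, p=2, dim=-1)
twice = F.normalize(once, p=2, dim=-1)
assert torch.allclose(once, twice, atol=1e-6)   # second pass is redundant
```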
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/fused_recurrent.py` around lines 76 - 78, The kernel is
L2-normalizing b_q and b_k (controlled by USE_QK_L2NORM_IN_KERNEL) while
QuasarAttention.forward already normalizes q and k with F.normalize; to avoid
double normalization, update QuasarAttention.forward so that when calling
fused_recurrent_quasar it either does not pre-normalize (remove the
F.normalize(q, p=2, dim=-1)/F.normalize(k, p=2, dim=-1) calls) if you want the
kernel to handle normalization, or pass use_qk_l2norm_in_kernel=False to
fused_recurrent_quasar so the kernel skips its own normalization; adjust the
call site in QuasarAttention.forward and ensure consistency with the
USE_QK_L2NORM_IN_KERNEL flag and any related parameters.
    if STORE_FINAL_STATE:
        p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))

    if STORE_FINAL_STATE:
        p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
Duplicate STORE_FINAL_STATE block.
The final state storage logic is duplicated at lines 120-122 and 124-126. This causes the final state to be written twice unnecessarily.
🐛 Proposed fix: Remove duplicate block
if STORE_FINAL_STATE:
p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
-
- if STORE_FINAL_STATE:
- p_ht = final_state + (i_b * H + i_h) * BK * BK + o_k[:, None] * BK + o_v[None, :]
-    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/fused_recurrent.py` around lines 120 - 126, The code contains
a duplicated STORE_FINAL_STATE block that computes p_ht and calls tl.store
twice; remove the redundant second block to ensure final_state is stored only
once. Locate the repeated conditional using STORE_FINAL_STATE that computes p_ht
from final_state, i_b, H, i_h, BK, o_k, o_v and calls tl.store(p_ht,
b_h.to(p_ht.dtype.element_ty)), and delete the duplicate so only a single
STORE_FINAL_STATE branch performs the tl.store of b_h.
    def backward(ctx, dalpha: torch.Tensor):
        lambda_t, beta = ctx.saved_tensors
        eps = 1e-8

        # dalpha/dlambda and dalpha/dbeta derivatives
        # alpha = (1 - exp(-beta * lambda)) / (lambda + eps)
        # dalpha/dbeta = exp(-beta * lambda)
        beta_exp = torch.exp(-beta.view(-1, 1) * lambda_t)
        lambda_plus_eps = lambda_t + eps

        # dalpha/dlambda = (beta * exp(-beta * lambda) * lambda - (1 - exp(-beta * lambda))) / lambda^2
        dlambda = (beta.view(-1, 1) * beta_exp * lambda_plus_eps - (1 - beta_exp)) / (lambda_plus_eps ** 2)

        # dalpha/dbeta = exp(-beta * lambda)
        dbeta = beta_exp

        dlambda = dlambda * dalpha
        # Sum over sequence and dimensions, but preserve head dimension
        dbeta = (dbeta * dalpha).sum(dim=(0, 1))

        return dlambda, dbeta, None, None
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify the custom autograd forward/backward contract in fla/ops/quasar/gate.py.
python - <<'PY'
import ast
from pathlib import Path
path = Path("fla/ops/quasar/gate.py")
tree = ast.parse(path.read_text())
for node in tree.body:
if isinstance(node, ast.ClassDef) and node.name == "QuasarGateFunction":
methods = {m.name: m for m in node.body if isinstance(m, ast.FunctionDef)}
fwd = methods["forward"]
bwd = methods["backward"]
forward_inputs = [arg.arg for arg in fwd.args.args[1:]] # exclude ctx
print("forward inputs:", forward_inputs)
for stmt in bwd.body:
if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Tuple):
print("backward return count:", len(stmt.value.elts))
print("backward return exprs:", [ast.unparse(e) for e in stmt.value.elts])
PY
Repository: fla-org/flash-linear-attention
Length of output: 211
🏁 Script executed:
cat -n fla/ops/quasar/gate.py | head -150
Repository: fla-org/flash-linear-attention
Length of output: 5423
Fix QuasarGateFunction.backward() before using in training.
forward() accepts three inputs, but backward() returns four gradients on line 145, violating torch.autograd.Function's contract. Additionally, dbeta omits the required lambda_t / (lambda_t + eps) factor in its derivative, and the hardcoded .sum(dim=(0, 1)) reduction incorrectly assumes 3D tensors and leaves a singleton dimension instead of matching beta.shape.
🐛 Proposed fix
def backward(ctx, dalpha: torch.Tensor):
lambda_t, beta = ctx.saved_tensors
eps = 1e-8
+ beta_view = beta.reshape(*([1] * (lambda_t.ndim - 2)), -1, 1)
- # dalpha/dlambda and dalpha/dbeta derivatives
- # alpha = (1 - exp(-beta * lambda)) / (lambda + eps)
- # dalpha/dbeta = exp(-beta * lambda)
- beta_exp = torch.exp(-beta.view(-1, 1) * lambda_t)
+ beta_exp = torch.exp(-beta_view * lambda_t)
lambda_plus_eps = lambda_t + eps
- # dalpha/dlambda = (beta * exp(-beta * lambda) * lambda - (1 - exp(-beta * lambda))) / lambda^2
- dlambda = (beta.view(-1, 1) * beta_exp * lambda_plus_eps - (1 - beta_exp)) / (lambda_plus_eps ** 2)
-
- # dalpha/dbeta = exp(-beta * lambda)
- dbeta = beta_exp
-
+ dlambda = (beta_view * beta_exp * lambda_plus_eps - (1 - beta_exp)) / (lambda_plus_eps ** 2)
dlambda = dlambda * dalpha
- # Sum over sequence and dimensions, but preserve head dimension
- dbeta = (dbeta * dalpha).sum(dim=(0, 1))
-
- return dlambda, dbeta, None, None
+ dbeta = (lambda_t * beta_exp / lambda_plus_eps) * dalpha
+ reduce_dims = tuple(i for i in range(dbeta.ndim) if i != dbeta.ndim - 2)
+ dbeta = dbeta.sum(dim=reduce_dims)
+
+    return dlambda, dbeta, None
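Both issues — the return-count contract and the missing lambda_t / (lambda_t + eps) factor in dbeta — can be checked on a scalar toy. ToyGate below is a hypothetical stand-in for QuasarGateFunction (unbatched, same alpha formula); gradcheck passes only with three returned gradients and the corrected dbeta:

```python
import torch

EPS = 1e-8

class ToyGate(torch.autograd.Function):
    """Scalar toy with alpha = (1 - exp(-beta * lam)) / (lam + EPS)."""

    @staticmethod
    def forward(ctx, lam, beta, unused):
        alpha = (1.0 - torch.exp(-beta * lam)) / (lam + EPS)
        ctx.save_for_backward(lam, beta)
        return alpha

    @staticmethod
    def backward(ctx, dalpha):
        lam, beta = ctx.saved_tensors
        e = torch.exp(-beta * lam)
        dlam = dalpha * (beta * e * (lam + EPS) - (1.0 - e)) / (lam + EPS) ** 2
        dbeta = dalpha * lam * e / (lam + EPS)   # note the lam/(lam+EPS) factor
        return dlam, dbeta, None                 # one gradient per forward input

lam = torch.tensor(0.9, dtype=torch.float64, requires_grad=True)
beta = torch.tensor(0.4, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(lambda a, b: ToyGate.apply(a, b, None), (lam, beta))
```

With the bare exp(-beta * lam) used as dbeta instead, the same gradcheck fails whenever lam is small relative to eps — exactly the regime where the omitted factor matters.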
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/gate.py` around lines 125 - 145, QuasarGateFunction.backward
currently returns four gradients despite forward taking three inputs and
computes dbeta incorrectly and with a wrong reduction; update backward in
QuasarGateFunction to (1) return exactly three gradients to match forward's
inputs, (2) compute dbeta = exp(-beta.view(...)*lambda_t) * dalpha * (lambda_t /
(lambda_t + eps)) (i.e., multiply the existing dbeta term by
lambda_t/(lambda_t+eps)), and (3) replace the hardcoded .sum(dim=(0,1)) with a
reduction that sums over all dimensions of dalpha except the beta dimension so
the resulting dbeta matches beta.shape (similarly sum dlambda over the
non-lambda dimensions so dlambda matches lambda_t.shape); use ctx.saved_tensors
(lambda_t, beta) to determine which dims to reduce.
Actionable comments posted: 1
♻️ Duplicate comments (2)
fla/ops/quasar/chunk_bwd.py (2)
11-20: ⚠️ Potential issue | 🟠 Major
Guard safe_dot for non-NVIDIA backends.
Line 13 uses NVIDIA PTX inline asm without a backend fallback. This can break AMD execution paths.
🐛 Proposed fix
`@triton.jit`
def safe_dot(a, b):
-    return tl.inline_asm_elementwise(
-        asm="mov.f32 $0, $1;",
-        constraints="=r,r",
-        args=[tl.dot(a, b)],
-        dtype=tl.float32,
-        is_pure=True,
-        pack=1,
-    )
+    if IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL:
+        return tl.inline_asm_elementwise(
+            asm="mov.f32 $0, $1;",
+            constraints="=r,r",
+            args=[tl.dot(a, b)],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+    return tl.dot(a, b)

#!/bin/bash
# Verify NVIDIA-specific inline PTX is present and currently unguarded in safe_dot.
rg -n -C2 'def safe_dot|inline_asm_elementwise|mov\.f32|IS_NVIDIA_HOPPER|IS_NVIDIA_BLACKWELL|IS_AMD' fla/ops/quasar/chunk_bwd.py

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_bwd.py` around lines 11 - 20, The safe_dot function uses NVIDIA PTX via tl.inline_asm_elementwise (asm="mov.f32 ...") with no backend guard, which will break non‑NVIDIA backends; update safe_dot to detect the backend (e.g., use existing flags like IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL or a runtime check) and only call tl.inline_asm_elementwise when on an NVIDIA backend, otherwise fall back to a pure Triton implementation (e.g., return tl.dot(a, b) or equivalent) so that tl.inline_asm_elementwise is never invoked on AMD/other backends.
275-279: ⚠️ Potential issue | 🔴 Critical
Fix callable misuse in shared-memory branch selection.
Line 277 checks the function object instead of calling it, so this branch is always truthy when Line 275 is false.
🐛 Proposed fix
if check_shared_mem('hopper', k.device.index):
    CONST_TILING = 128
-elif check_shared_mem:
+elif check_shared_mem(device_index=k.device.index):
    CONST_TILING = 64
else:
    CONST_TILING = 32

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_bwd.py` around lines 275 - 279, The elif branch is testing the function object instead of calling it, so replace the incorrect check with a proper call to check_shared_mem using the same arguments as the first branch (i.e., call check_shared_mem('hopper', k.device.index)); update the branch that sets CONST_TILING (and ensure CONST_TILING remains assigned 64 when the call returns True) to use check_shared_mem('hopper', k.device.index) instead of check_shared_mem so the shared-memory detection logic using check_shared_mem and k.device.index works as intended.
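The bug class is easy to demonstrate in plain Python: a bare function reference is always truthy, so the buggy elif can never fall through. The check_shared_mem below is a local stub, not the fla utility:

```python
# A function object is truthy regardless of what it returns when called,
# so `elif check_shared_mem:` silently shadows the else branch.
def check_shared_mem(arch=None, index=0):
    return False   # pretend the device check fails

if check_shared_mem('hopper', 0):
    tiling = 128
elif check_shared_mem:        # bug: tests the function object, not a call
    tiling = 64
else:
    tiling = 32               # unreachable with the buggy elif

assert tiling == 64
assert check_shared_mem('hopper', 0) is False   # the call itself says no
```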
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 188-190: The loop currently defines unused locals m_k and o_k
inside the inner K-loop (variables m_k and o_k in the for i_k in
range(tl.cdiv(K, BK)) block), causing a lint F841; remove m_k and o_k and update
any downstream uses to either compute the required index expression inline
(e.g., i_k * BK + tl.arange(0, BK)) or eliminate the computation entirely if
it’s not used; ensure the loop body only keeps the necessary expressions and
that BK, K, and i_k remain correct for any remaining logic in chunk_bwd.py.
---
Duplicate comments:
In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 11-20: The safe_dot function uses NVIDIA PTX via
tl.inline_asm_elementwise (asm="mov.f32 ...") with no backend guard, which will
break non‑NVIDIA backends; update safe_dot to detect the backend (e.g., use
existing flags like IS_NVIDIA_HOPPER or IS_NVIDIA_BLACKWELL or a runtime check)
and only call tl.inline_asm_elementwise when on an NVIDIA backend, otherwise
fall back to a pure Triton implementation (e.g., return tl.dot(a, b) or
equivalent) so that tl.inline_asm_elementwise is never invoked on AMD/other
backends.
- Around line 275-279: The elif branch is testing the function object instead of
calling it, so replace the incorrect check with a proper call to
check_shared_mem using the same arguments as the first branch (i.e., call
check_shared_mem('hopper', k.device.index)); update the branch that sets
CONST_TILING (and ensure CONST_TILING remains assigned 64 when the call returns
True) to use check_shared_mem('hopper', k.device.index) instead of
check_shared_mem so the shared-memory detection logic using check_shared_mem and
k.device.index works as intended.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: a5eb4037-cb62-4130-b919-e4ed3d0ef7b6
📒 Files selected for processing (1)
fla/ops/quasar/chunk_bwd.py
There was a problem hiding this comment.
Actionable comments posted: 4
♻️ Duplicate comments (2)
fla/ops/quasar/chunk_bwd.py (2)
272-277: ⚠️ Potential issue | 🔴 Critical
Call check_shared_mem(...) here instead of testing the function object.
`elif check_shared_mem:` is always truthy, so non-Hopper devices always end up with CONST_TILING = 64. That bypasses the actual shared-memory check and can select an invalid tile size.
🛠️ Suggested change
if check_shared_mem('hopper', k.device.index):
    CONST_TILING = 128
-elif check_shared_mem:
+elif check_shared_mem(device_index=k.device.index):
    CONST_TILING = 64
else:
    CONST_TILING = 32

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_bwd.py` around lines 272 - 277, The elif is testing the function object instead of calling it, so CONST_TILING is set to 64 for non-Hopper devices regardless of actual shared memory — call check_shared_mem with the same device identifier used in the first branch (e.g., check_shared_mem('hopper', k.device.index) vs check_shared_mem('some_tag', k.device.index) or simply check_shared_mem(k.device.index) depending on the function signature) to perform the real shared-memory test; update the elif to call check_shared_mem(...) (and pass the appropriate 'hopper' or device args consistent with the first call) so CONST_TILING is chosen based on the actual check_shared_mem result, not the function object.
11-20: ⚠️ Potential issue | 🟠 Major
Guard the PTX-only safe_dot path.
tl.inline_asm_elementwise here hardcodes a PTX mov.f32, so this kernel only compiles on NVIDIA backends. Without an AMD/ROCm guard or a plain Triton fallback, the new backward path is not portable across the backends FLA supports.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/quasar/chunk_bwd.py` around lines 11 - 20, safe_dot currently hardcodes PTX via tl.inline_asm_elementwise with asm="mov.f32 ..." which only works on NVIDIA; wrap that PTX-specific call in a guarded path and provide a portable fallback (e.g., return tl.dot(a, b) or tl.move-equivalent) when PTX inline asm fails or when the backend is not NVIDIA. Concretely, update safe_dot to attempt the tl.inline_asm_elementwise(asm="mov.f32 ...", constraints="=r,r", args=[tl.dot(a,b)], ...) inside a try/except (or behind a runtime check for NVIDIA) and on exception or non-NVIDIA return the plain Triton expression tl.dot(a, b) so the function works across AMD/ROCm and other backends. Ensure references to safe_dot, tl.inline_asm_elementwise, and the asm string are used so reviewers can find the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/layers/quasar.py`:
- Around line 6-18: The import list in fla.layers.quasar.py includes unused
symbols (math, repeat from einops, RMSNorm from fla.modules, and
fused_quasar_gate from fla.ops.quasar.gate) causing F401; remove these unused
imports from the top of the module so only used names (contextlib, torch, nn,
rearrange, F, get_unpad_data, index_first_axis, pad_input, FusedRMSNormGated,
ShortConvolution, chunk_quasar, fused_recurrent_quasar) remain; locate and edit
the import block at the top of quasar.py to drop the four unused identifiers.
- Around line 199-205: The no-unpadding fast path fails because when
attention_mask.all() is true we set indices=None but still run the
RoPE/unpad-dependent logic and pad_input with None; modify the branches that
call get_unpad_data, index_first_axis, the RoPE handling, and pad_input to only
execute when unpadding actually occurred (i.e., indices is not None or a
did_unpad flag is true). Concretely, in the blocks around attention_mask, the
get_unpad_data call that sets indices, cu_seqlens must gate subsequent uses
(index_first_axis, rearrange-hidden_states handling, RoPE code, and pad_input)
on indices is not None (or set a boolean did_unpad and check it) so dense masks
(attention_mask.all()) skip all unpadding/padding and RoPE branches.
- Around line 171-174: Don't unconditionally overwrite the local mode variable;
use the layer's configured mode (self.mode) instead of forcing mode = "chunk",
so the fused_recurrent branch can run in eval. Set mode = self.mode (or fall
back to "chunk" if absent), keep the training-time guard assert to enforce that
training requires "chunk" (i.e., assert self.mode == "chunk" when
self.training), and leave the fused_recurrent branch to select kernels based on
this mode. Reference: self.mode, mode variable, self.training, and the
fused_recurrent branch in fla/layers/quasar.py.
In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 262-263: The wrapper functions (e.g., chunk_quasar_bwd_dAv)
currently declare A: torch.Tensor | None = None and scale: float = None but pass
them directly into kernels that unconditionally read A and multiply by scale;
remove the spurious None defaults or normalize/validate them before any kernel
launch by making A and scale required parameters (no default) or adding an
explicit check at the start of the wrapper (raise a clear error if A is None or
scale is None) or set a valid default value for scale if intended; update any
other wrapper that uses the same pattern (the other launcher referenced around
the same area) to perform the same validation so kernels never receive None for
A or scale.
---
Duplicate comments:
In `@fla/ops/quasar/chunk_bwd.py`:
- Around line 272-277: The elif is testing the function object instead of
calling it, so CONST_TILING is set to 64 for non-Hopper devices regardless of
actual shared memory — call check_shared_mem with the same device identifier
used in the first branch (e.g., check_shared_mem('hopper', k.device.index) vs
check_shared_mem('some_tag', k.device.index) or simply
check_shared_mem(k.device.index) depending on the function signature) to perform
the real shared-memory test; update the elif to call check_shared_mem(...) (and
pass the appropriate 'hopper' or device args consistent with the first call) so
CONST_TILING is chosen based on the actual check_shared_mem result, not the
function object.
- Around line 11-20: safe_dot currently hardcodes PTX via
tl.inline_asm_elementwise with asm="mov.f32 ..." which only works on NVIDIA;
wrap that PTX-specific call in a guarded path and provide a portable fallback
(e.g., return tl.dot(a, b) or tl.move-equivalent) when PTX inline asm fails or
when the backend is not NVIDIA. Concretely, update safe_dot to attempt the
tl.inline_asm_elementwise(asm="mov.f32 ...", constraints="=r,r",
args=[tl.dot(a,b)], ...) inside a try/except (or behind a runtime check for
NVIDIA) and on exception or non-NVIDIA return the plain Triton expression
tl.dot(a, b) so the function works across AMD/ROCm and other backends. Ensure
references to safe_dot, tl.inline_asm_elementwise, and the asm string are used
so reviewers can find the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 28a852dc-bfb2-417f-88ad-77ce8bde2fcc
📒 Files selected for processing (2)
fla/layers/quasar.py
fla/ops/quasar/chunk_bwd.py
import contextlib
import math
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch.nn import functional as F

from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.quasar import chunk_quasar, fused_recurrent_quasar
from fla.ops.quasar.gate import fused_quasar_gate
Clean up the unused imports before merge.
math, repeat, RMSNorm, and fused_quasar_gate are still unused in this module, so Flake8 will keep reporting F401 here.
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 7-7: 'math' imported but unused
(F401)
[error] 12-12: 'einops.repeat' imported but unused
(F401)
[error] 16-16: 'fla.modules.RMSNorm' imported but unused
(F401)
[error] 18-18: 'fla.ops.quasar.gate.fused_quasar_gate' imported but unused
(F401)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/layers/quasar.py` around lines 6 - 18, The import list in
fla.layers.quasar.py includes unused symbols (math, repeat from einops, RMSNorm
from fla.modules, and fused_quasar_gate from fla.ops.quasar.gate) causing F401;
remove these unused imports from the top of the module so only used names
(contextlib, torch, nn, rearrange, F, get_unpad_data, index_first_axis,
pad_input, FusedRMSNormGated, ShortConvolution, chunk_quasar,
fused_recurrent_quasar) remain; locate and edit the import block at the top of
quasar.py to drop the four unused identifiers.
# Force chunk mode to avoid fused_recurrent BT conflict
mode = "chunk"
if self.training:
    assert mode == "chunk", "Only chunk mode is supported in training."
self.mode is ignored here.
mode is overwritten to "chunk" unconditionally, so the fused_recurrent branch below is dead even in eval. That silently ignores the layer config and defeats the decode-path kernel selection.
🛠️ Suggested change
- # Force chunk mode to avoid fused_recurrent BT conflict
- mode = "chunk"
- if self.training:
- assert mode == "chunk", "Only chunk mode is supported in training."
+ mode = "chunk" if self.training else self.mode
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/layers/quasar.py` around lines 171 - 174, Don't unconditionally overwrite
the local mode variable; use the layer's configured mode (self.mode) instead of
forcing mode = "chunk", so the fused_recurrent branch can run in eval. Set mode
= self.mode (or fall back to "chunk" if absent), keep the training-time guard
assert to enforce that training requires "chunk" (i.e., assert self.mode ==
"chunk" when self.training), and leave the fused_recurrent branch to select
kernels based on this mode. Reference: self.mode, mode variable, self.training,
and the fused_recurrent branch in fla/layers/quasar.py.
if attention_mask is not None:
    # Optimization: Skip unpadding if all tokens are valid (common in packed distillation)
    if attention_mask.all():
        indices, cu_seqlens = None, None
    else:
        indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
        hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
Dense masks break the no-unpadding fast path.
When attention_mask.all() is true, indices stays None, but the RoPE branch and the final pad_input(...) branch still execute solely because attention_mask is not None. That sends None into index_first_axis(...)/pad_input(...) and breaks the common “all tokens valid” path. Gate those branches on indices is not None (or an explicit did_unpad flag) instead.
Also applies to: 241-257, 333-334
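The gating pattern the review asks for can be sketched as below. forward_path and its internals are illustrative stand-ins for the fla unpad/repad helpers, not the actual layer code; the point is that the repad only runs when indices was actually set:

```python
from typing import Optional

import torch

# Hedged sketch: track whether unpadding happened and gate the repad on it,
# so a dense mask (attention_mask.all()) takes a true fast path.
def forward_path(x: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
    indices = None
    if mask is not None and not mask.bool().all():
        indices = mask.flatten().nonzero(as_tuple=True)[0]        # unpad
        packed = x.flatten(0, 1)[indices]
        # ... the attention kernel would consume the packed tokens here ...
        out = torch.zeros(mask.numel(), x.shape[-1], dtype=x.dtype)
        out[indices] = packed                                     # repad
        return out.view(*mask.shape, -1)
    # dense mask or no mask: no unpad/repad, and no None ever dereferenced
    return x

x = torch.randn(2, 3, 4)
assert forward_path(x, torch.ones(2, 3)).shape == x.shape          # fast path
assert forward_path(x, torch.tensor([[1, 1, 0], [1, 0, 0]])).shape == x.shape
```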
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/layers/quasar.py` around lines 199 - 205, The no-unpadding fast path
fails because when attention_mask.all() is true we set indices=None but still
run the RoPE/unpad-dependent logic and pad_input with None; modify the branches
that call get_unpad_data, index_first_axis, the RoPE handling, and pad_input to
only execute when unpadding actually occurred (i.e., indices is not None or a
did_unpad flag is true). Concretely, in the blocks around attention_mask, the
get_unpad_data call that sets indices, cu_seqlens must gate subsequent uses
(index_first_axis, rearrange-hidden_states handling, RoPE code, and pad_input)
on indices is not None (or set a boolean did_unpad and check it) so dense masks
(attention_mask.all()) skip all unpadding/padding and RoPE branches.
    A: torch.Tensor | None = None,
    scale: float = None,
The wrapper defaults advertise unsupported calls.
Both launchers pass scale straight into kernels that multiply by it, and chunk_quasar_bwd_dAv also forwards A to a kernel that unconditionally loads from it. None is not a real default here, so the next direct caller gets a launch-time failure unless these parameters are made required or normalized before launch.
🛠️ Suggested change
def chunk_quasar_bwd_dAv(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
do: torch.Tensor,
- A: torch.Tensor | None = None,
- scale: float = None,
+ A: torch.Tensor,
+ scale: float,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
chunk_indices: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
@@
def chunk_quasar_bwd_wy_dqkb_fused(
@@
- scale: float | None = None,
+ scale: float,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
chunk_indices: torch.LongTensor | None = None,
):

Also applies to: 318-318
🧰 Tools
🪛 Ruff (0.15.7)
[warning] 263-263: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/quasar/chunk_bwd.py` around lines 262 - 263, The wrapper functions
(e.g., chunk_quasar_bwd_dAv) currently declare A: torch.Tensor | None = None and
scale: float = None but pass them directly into kernels that unconditionally
read A and multiply by scale; remove the spurious None defaults or
normalize/validate them before any kernel launch by making A and scale required
parameters (no default) or adding an explicit check at the start of the wrapper
(raise a clear error if A is None or scale is None) or set a valid default value
for scale if intended; update any other wrapper that uses the same pattern (the
other launcher referenced around the same area) to perform the same validation
so kernels never receive None for A or scale.
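If making the parameters required would break existing callers, the "normalize/validate before launch" option from the prompt above gives equivalent safety. A minimal sketch (the helper name `require_kernel_args` is hypothetical; the real signatures live in fla/ops/quasar/chunk_bwd.py):

```python
def require_kernel_args(A, scale):
    # Hypothetical fail-fast helper: raise a clear Python error before
    # any Triton launch, instead of letting a kernel that unconditionally
    # loads A (or multiplies by scale) fail at launch time on None.
    if A is None:
        raise ValueError("chunk_quasar_bwd_dAv: `A` is required, got None")
    if scale is None:
        raise ValueError("chunk_quasar_bwd_dAv: `scale` is required, got None")
    return A, float(scale)
```

Calling this at the top of each wrapper keeps the advertised defaults honest: a direct caller who omits `A` or `scale` gets an immediate, named error rather than a launch-time failure.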
Pull Request: Add Quasar Attention and Standalone Model Implementation
Summary
This PR introduces Quasar Attention, a linear attention variant derived from Kimi Delta Attention (KDA) with architectural and kernel-level optimizations. Quasar achieves higher throughput and lower memory use, particularly at long context lengths.
This PR includes:
- Standalone Quasar Attention Triton kernels in fla/ops/quasar.
- A QuasarAttention layer in fla/layers/quasar.py.
- A full model suite in fla/models/quasar, including QuasarConfig, QuasarModel, and QuasarForCausalLM.
- Registration of the new symbols in the fla library root interfaces.
Benchmarks
Quasar demonstrates superior hardware efficiency compared to baseline linear attention architectures.
High-Throughput Performance
Setup: 8x NVIDIA B200, 2B Model, 64k Context Length
Scaling and Memory Efficiency
Setup: Single NVIDIA B200, 1B Model
References
Implementation Details
- All components use the quasar nomenclature to prevent symbol collisions with upstream KDA implementations.
Summary by CodeRabbit
New Features
Documentation