Conversation
Caution: Review failed. The pull request is closed.

Walkthrough: This pull request refactors several functions across multiple modules, removing the `input_precision` parameter, removing the `head_first` flag, and renaming `offsets` to `cu_seqlens`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Suite
    participant FRwkv7 as fused_recurrent_rwkv7
    participant RWkv7 as rwkv7_fn
    Test->>FRwkv7: Initiate forward pass
    FRwkv7->>RWkv7: Call without input_precision
    RWkv7-->>FRwkv7: Return computed tensor
    FRwkv7-->>Test: Pass along output
```

```mermaid
sequenceDiagram
    participant Caller
    participant Func as chunk_dplr_delta_rule
    Caller->>Func: Invoke function with tensor q (float32)
    Func->>Func: Check tensor data type
    alt Tensor is float32
        Func->>Caller: Emit warning ("ChunkDeltaRuleFunction does not support float32...")
    else
        Func->>Caller: Proceed with computation
    end
```
* [Attn] Remove `head_first` & rename `offsets` to `cu_seqlens`
* Delete 256 headdim tests
* [DeltaNet] Remove `head_first` & rename `offsets` to `cu_seqlens`
* Rename `offsets` to `cu_seqlens` across GLA/GSA/RWKV6
* Fix NSA tests
* Fix TTT checks
* Rename `offsets` to `cu_seqlens` in DPLR and IPLR implementations, updating related logic and heuristics for variable-length sequences.
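For context on the `offsets` → `cu_seqlens` rename above: `cu_seqlens` conventionally denotes cumulative sequence lengths for variable-length (packed) batches, i.e. a monotonically increasing array with a leading zero whose consecutive differences are the per-sequence lengths. A minimal sketch of how such an array is built (the helper name `build_cu_seqlens` is illustrative, not part of the fla API):

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Build cumulative sequence lengths from per-sequence lengths.

    For packed sequences of lengths [3, 5, 2], the boundaries of
    sequence i in the flattened batch are cu[i]:cu[i+1].
    """
    return [0] + list(accumulate(seq_lens))

# Example: three packed sequences of lengths 3, 5, and 2
cu = build_cu_seqlens([3, 5, 2])  # → [0, 3, 8, 10]
```

Kernels can then index sequence `i` as the slice `cu[i]:cu[i+1]` of the packed tensor, which is why the heuristics for variable-length sequences mentioned in the commit list key off this array.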
Copilot reviewed 13 out of 13 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (4)
tests/ops/test_rwkv7.py:18

- Consider adding tests for `torch.float16` if needed to cover alternative precision paths, or ensure that `bfloat16` is the exclusively supported type.

  `@pytest.mark.parametrize("dtype", [torch.bfloat16])`
fla/ops/rwkv7/fused_recurrent.py:22

- Ensure that removing the `input_precision` parameter does not break downstream calls expecting this parameter; update documentation if necessary.

  `input_precision: Optional[torch.dtype] = torch.bfloat16,`
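One hedged way to address the breakage concern above, shown as an illustrative sketch rather than the PR's actual code (the function body and signature here are stand-ins): accept and ignore the removed keyword for a deprecation period, warning callers instead of failing with a `TypeError`.

```python
import warnings

def fused_recurrent_rwkv7(q, k, v, **deprecated_kwargs):
    """Illustrative deprecation shim, not the real fla implementation.

    Callers that still pass input_precision get a DeprecationWarning
    instead of an immediate TypeError; any other unknown keyword still
    fails loudly.
    """
    if "input_precision" in deprecated_kwargs:
        warnings.warn(
            "`input_precision` has been removed and is now ignored; "
            "inputs are processed in their native dtype.",
            DeprecationWarning,
            stacklevel=2,
        )
        deprecated_kwargs.pop("input_precision")
    if deprecated_kwargs:
        raise TypeError(f"unexpected keyword arguments: {sorted(deprecated_kwargs)}")
    # ... the actual fused recurrent computation would go here ...
    return q, k, v
```

A shim like this lets downstream code migrate over one release cycle instead of breaking at import-distance call sites.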
fla/ops/rwkv7/fused_addcmul.py:55

- Confirm that replacing the explicit conversion to `tl.float32` with conversion to `DTYPE` maintains numerical stability and performance; verify that the `DTYPE` parameter is set appropriately at all call sites.

  `b_hiddn = tl.load(hidden_ptr + (xindex), xmask, other=0.).to(DTYPE)`
fla/ops/generalized_delta_rule/dplr/chunk.py:333

- [nitpick] The warning message for float32 support could be enhanced by providing explicit guidance on alternative behavior or raising an error for unsupported dtypes.

  `if q.dtype == torch.float32:`
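The nitpick above can be sketched concretely. The following is an illustrative pattern only, not the fla codebase's code (the helper name, the `strict` flag, and the supported-dtype list are all assumptions): the warning names the offending dtype and tells the caller what to do, with an opt-in strict mode that raises instead.

```python
import warnings

SUPPORTED = ("bfloat16", "float16")  # hypothetical whitelist for illustration

def check_dtype(dtype_name, strict=False):
    """Warn (or raise, when strict) on an unsupported input dtype.

    Mirrors the float32 check flagged in chunk_dplr_delta_rule, but with
    actionable guidance in the message, as the review comment suggests.
    """
    if dtype_name not in SUPPORTED:
        msg = (f"ChunkDeltaRuleFunction does not support {dtype_name}; "
               f"cast inputs to one of {SUPPORTED} before calling.")
        if strict:
            raise TypeError(msg)
        warnings.warn(msg)
        return False
    return True
```

Warning keeps existing float32 callers running (with silent downcasting or degraded behavior, whichever the kernel does), while `strict=True` turns the same check into a hard contract for users who prefer to fail fast.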
Summary by CodeRabbit

- Refactor: removed the `head_first` flag and the `input_precision` parameter, and renamed `offsets` to `cu_seqlens` across the affected modules.
- Tests: updated test parametrization for the new signatures; deleted the 256 headdim tests.