Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions fla/ops/abc/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,16 +1082,16 @@ def chunk_abc(
s: torch.Tensor,
initial_state: Optional[Tuple[torch.Tensor]] = None,
output_final_state: bool = False,
head_first: bool = True
head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`
k (torch.Tensor):
keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`
v (torch.Tensor):
values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`
s (torch.Tensor):
slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`
initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]):
Expand All @@ -1100,11 +1100,11 @@ def chunk_abc(
Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format.
Default: `True`.
Default: `False`.

Returns:
o (torch.Tensor):
Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
final_state (torch.Tensor):
Final state of shape `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True` else `None`.
"""
Expand Down
2 changes: 1 addition & 1 deletion fla/ops/based/fused_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def fused_chunk_based(
v: torch.Tensor,
scale: Optional[float] = None,
use_norm: bool = True,
head_first: bool = True
head_first: bool = False
):
assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
if scale is None:
Expand Down
2 changes: 1 addition & 1 deletion fla/ops/based/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def parallel_based(
v: torch.Tensor,
scale: Optional[float] = None,
use_norm: bool = True,
head_first: bool = True
head_first: bool = False
):
assert q.shape[-1] <= 128, "only support feature dim up to 128"
if scale is None:
Expand Down
8 changes: 4 additions & 4 deletions fla/ops/common/chunk_delta_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ def chunk_gated_delta_rule_fwd_h(
output_final_state: bool = False,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True
save_new_value: bool = True,
head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if head_first:
B, H, T, K, V = *k.shape, u.shape[-1]
Expand Down Expand Up @@ -490,8 +490,8 @@ def chunk_gated_delta_rule_bwd_dhu(
scale: float,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
chunk_size: int = 64 # SY: remove this argument and force chunk size 64?
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if head_first:
B, H, T, K, V = *q.shape, do.shape[-1]
Expand Down
4 changes: 2 additions & 2 deletions fla/ops/common/chunk_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def chunk_fwd_h(
h0: torch.Tensor,
output_final_state: bool,
offsets: Optional[torch.Tensor] = None,
head_first: bool = True,
head_first: bool = False,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Default parameter value changed from True to False

The default value of the head_first parameter has been changed from True to False. This affects how tensor dimensions are interpreted in the chunk_fwd_h function. When head_first is False, the function expects input tensors to have shape [B, T, H, K] rather than [B, H, T, K].

This is part of a larger change across the codebase standardizing on head_first=False as the default format.

This change may impact existing code that relies on the default value. Ensure that all callers either explicitly set this parameter or are updated to work with the new tensor dimension ordering.

chunk_size: int = 64,
split_size: Optional[int] = None,
states_in_fp32: bool = False
Expand Down Expand Up @@ -364,7 +364,7 @@ def chunk_bwd_dh(
dht: torch.Tensor,
scale: float,
offsets: Optional[torch.Tensor] = None,
head_first: bool = True,
head_first: bool = False,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Default parameter value changed from True to False

The default value of the head_first parameter has been changed from True to False. This affects how tensor dimensions are interpreted in the chunk_bwd_dh function. When head_first is False, the function expects input tensors to have shape [B, T, H, K] rather than [B, H, T, K].

This is part of a larger change across the codebase standardizing on head_first=False as the default format.

This change may impact existing code that relies on the default value. Ensure that all callers either explicitly set this parameter or are updated to work with the new tensor dimension ordering.

chunk_size: int = 64,
split_size: Optional[int] = None,
states_in_fp32: bool = False
Expand Down
4 changes: 2 additions & 2 deletions fla/ops/common/chunk_h_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def chunk_fwd_h(
states_in_fp32: bool = False,
offsets: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
if head_first:
Expand Down Expand Up @@ -569,7 +569,7 @@ def chunk_bwd_dh(
states_in_fp32: bool = False,
offsets: Optional[torch.Tensor] = None,
indices: Optional[torch.Tensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
if head_first:
Expand Down
4 changes: 2 additions & 2 deletions fla/ops/common/chunk_h_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def chunk_fwd_h(
offsets: Optional[torch.LongTensor] = None,
split_offsets: Optional[torch.LongTensor] = None,
split_indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64,
split_size: int = 256,
states_in_fp32: bool = True
Expand Down Expand Up @@ -590,7 +590,7 @@ def chunk_bwd_dh(
offsets: Optional[torch.Tensor] = None,
split_offsets: Optional[torch.Tensor] = None,
split_indices: Optional[torch.Tensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64,
split_size: int = 256,
states_in_fp32: bool = True
Expand Down
8 changes: 4 additions & 4 deletions fla/ops/common/chunk_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def chunk_fwd_o(
scale: Optional[float] = None,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
) -> torch.Tensor:
if head_first:
Expand Down Expand Up @@ -506,7 +506,7 @@ def chunk_bwd_dv(
scale: float,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
) -> torch.Tensor:
if head_first:
Expand Down Expand Up @@ -559,7 +559,7 @@ def chunk_bwd_dv_local(
scale: float,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
) -> torch.Tensor:
if head_first:
Expand Down Expand Up @@ -615,7 +615,7 @@ def chunk_bwd_dqkwg(
indices: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
scale: float = 1.0,
head_first: bool = True,
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

if head_first:
Expand Down
8 changes: 4 additions & 4 deletions fla/ops/common/fused_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def fused_recurrent_fwd(
output_final_state: bool = False,
reverse: bool = False,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False
):
if head_first:
B, H, T, K, V = *k.shape, v.shape[-1]
Expand Down Expand Up @@ -393,7 +393,7 @@ def fused_recurrent_bwd(
initial_state: Optional[torch.Tensor] = None,
reverse: bool = False,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False
):
if head_first:
B, H, T, K, V = *k.shape, v.shape[-1]
Expand Down Expand Up @@ -487,7 +487,7 @@ def forward(
output_final_state: bool = False,
reverse: bool = False,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False
):
o, ht = fused_recurrent_fwd(
q=q,
Expand Down Expand Up @@ -555,7 +555,7 @@ def fused_recurrent(
output_final_state: bool = False,
reverse: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False
):
if scale is None:
scale = k.shape[-1] ** -0.5
Expand Down
6 changes: 3 additions & 3 deletions fla/ops/delta_rule/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def chunk_delta_rule_fwd(
output_final_state: bool,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
):
T = q.shape[2] if head_first else q.shape[1]
Expand Down Expand Up @@ -80,7 +80,7 @@ def chunk_delta_rule_bwd(
dht: torch.Tensor,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
):
T = q.shape[2] if head_first else q.shape[1]
Expand Down Expand Up @@ -181,7 +181,7 @@ def forward(
initial_state: torch.Tensor,
output_final_state: bool,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = True
):
T = q.shape[2] if head_first else q.shape[1]
Expand Down
8 changes: 4 additions & 4 deletions fla/ops/delta_rule/fused_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def fused_recurrent_delta_rule_fwd(
initial_state: torch.Tensor,
output_final_state: bool,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Default value changed from True to False

The default value of the head_first parameter for fused_recurrent_delta_rule_fwd has been changed from True to False. This is a breaking change for code that relies on the default behavior without explicitly specifying this parameter.


🏁 Script executed:

#!/bin/bash
# Identify potential usage of fused_recurrent_delta_rule_fwd without explicitly specifying head_first
rg "fused_recurrent_delta_rule_fwd\s*\(" --type py | grep -v "head_first\s*="

Length of output: 251


Action: Address breaking change due to updated default for head_first

In the file fla/ops/delta_rule/fused_recurrent.py (line 304), the default value for the head_first parameter in fused_recurrent_delta_rule_fwd was changed from True to False. Our recent grep search indicates that at least one call to this function does not explicitly specify head_first, meaning it will now default to False. This change can break existing behavior in parts of the codebase that expect the previous default of True.

  • Location: fla/ops/delta_rule/fused_recurrent.py
  • Line 304: head_first: bool = False
  • Issue: Non-explicit usage of fused_recurrent_delta_rule_fwd may cause unintended behavior due to this default change.

Please review and either update the call sites to explicitly pass the correct head_first value or revert the default if the change was unintended.

) -> Tuple[torch.Tensor, torch.Tensor]:
if head_first:
B, H, T, K, V = *k.shape, v.shape[-1]
Expand Down Expand Up @@ -359,7 +359,7 @@ def fused_recurrent_delta_rule_bwd(
scale: float,
initial_state: torch.Tensor,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Default value changed from True to False

The default value of the head_first parameter for fused_recurrent_delta_rule_bwd has been changed from True to False. This is a breaking change for code that relies on the default behavior without explicitly specifying this parameter.


🏁 Script executed:

#!/bin/bash
# Identify potential usage of fused_recurrent_delta_rule_bwd without explicitly specifying head_first
rg "fused_recurrent_delta_rule_bwd\s*\(" --type py | grep -v "head_first\s*="

Length of output: 253


Action Required: Explicitly Specify head_first in fused_recurrent_delta_rule_bwd Calls

The verification confirms that the call sites for fused_recurrent_delta_rule_bwd (e.g. in fla/ops/delta_rule/fused_recurrent.py) are not providing an explicit value for head_first. Since the default has changed from True to False, this will alter behavior in these cases, potentially breaking backward compatibility.

  • Location: fla/ops/delta_rule/fused_recurrent.py (definition at line 362, plus its call sites)
  • Issue: Call sites of fused_recurrent_delta_rule_bwd rely on the old default behavior (True), but now default to False.
  • Recommendation: Update all calls to explicitly provide the head_first parameter to match the intended behavior (e.g., append head_first=True if the previous behavior is desired).

) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if head_first:
B, H, T, K, V = *k.shape, v.shape[-1]
Expand Down Expand Up @@ -438,7 +438,7 @@ def forward(
initial_state: torch.Tensor,
output_final_state: bool,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False
):
q_orig = q
Expand Down Expand Up @@ -501,7 +501,7 @@ def fused_recurrent_delta_rule(
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Default value changed from True to False

The default value of the head_first parameter for fused_recurrent_delta_rule has been changed from True to False. This is a breaking change for code that relies on the default behavior without explicitly specifying this parameter.


🏁 Script executed:

#!/bin/bash
# Identify potential usage of fused_recurrent_delta_rule without explicitly specifying head_first
rg "fused_recurrent_delta_rule\s*\(" --type py | grep -v "head_first\s*="

Length of output: 830


Critical: Adjust Default Value for head_first in fused_recurrent_delta_rule

The default for the head_first parameter in fused_recurrent_delta_rule has been changed from True to False (line 504 in fla/ops/delta_rule/fused_recurrent.py). Our search shows that multiple calls—specifically in tests/ops/test_delta.py (and one in fla/layers/delta_net.py)—invoke this function without explicitly setting the head_first parameter, which means they unintentionally rely on its default value.

Action Required:

  • Review & Update Tests: Ensure that test cases invoking fused_recurrent_delta_rule explicitly pass the expected head_first value or adjust their expectations based on the new default.
  • Backward Compatibility: If the change was not intended, consider reverting the default or providing a migration plan to mitigate breaking changes for existing users.

use_qk_l2norm_in_kernel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Expand Down
12 changes: 6 additions & 6 deletions fla/ops/delta_rule/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,18 +305,18 @@ def parallel_delta_rule(
beta: torch.Tensor,
scale: float = None,
output_attentions: bool = False,
head_first: bool = True
head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
k (torch.Tensor):
keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
v (torch.Tensor):
values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
beta (torch.Tensor):
betas of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`.
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
scale (Optional[int]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
Expand All @@ -328,7 +328,7 @@ def parallel_delta_rule(

Returns:
o (torch.Tensor):
Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
attn (torch.Tensor):
Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`.
"""
Expand Down
6 changes: 3 additions & 3 deletions fla/ops/gated_delta_rule/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def chunk_gated_delta_rule_fwd(
output_final_state: bool,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
):
g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
Expand Down Expand Up @@ -85,7 +85,7 @@ def chunk_gated_delta_rule_bwd(
dht: torch.Tensor,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
chunk_size: int = 64
):
T = q.shape[2] if head_first else q.shape[1]
Expand Down Expand Up @@ -193,7 +193,7 @@ def forward(
initial_state: torch.Tensor,
output_final_state: bool,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False
):
chunk_size = 64
Expand Down
2 changes: 1 addition & 1 deletion fla/ops/gated_delta_rule/wy_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def fwd_prepare_wy_repr(
beta: torch.Tensor,
offsets: Optional[torch.LongTensor],
indices: Optional[torch.LongTensor],
head_first: bool = True,
head_first: bool = False,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Default parameter value changed from True to False

The default value of the head_first parameter has been changed from True to False. This affects how tensor dimensions are interpreted in the fwd_prepare_wy_repr function. When head_first is False, the function expects input tensors to have shape [B, T, H, K] rather than [B, H, T, K].

This is part of a larger change across the codebase standardizing on head_first=False as the default format.

This change may impact existing code that relies on the default value. Ensure that all callers either explicitly set this parameter or are updated to work with the new tensor dimension ordering.

chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if head_first:
Expand Down
Loading