Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions unsloth_zoo/temporary_patches/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1907,6 +1907,27 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
F_softmax = torch.nn.functional.softmax
F_dropout = nn.functional.dropout
matmul = torch.matmul
def _align_kv_to_mask(key_states, value_states, attention_mask):
# Eager attention does `attn_weights += attention_mask`, which requires the
# key/value length to equal the mask's key dimension. On some
# transformers/torch combinations (e.g. transformers 5.x on torch < 2.11)
# the KV cache hands back more positions than the causal mask covers
# (pre-allocated cache slots), so a full-attention layer can see e.g. 161
# keys against a 128-wide mask and crash. The surplus positions are masked
# out anyway, so attend only the overlap by trimming KV (and the mask) to
# the shorter length. This keeps the path correct and shape-consistent
# across transformers/torch versions.
if attention_mask is None or not hasattr(attention_mask, "shape"):
return key_states, value_states, attention_mask
kvlen = key_states.shape[-2]
masklen = attention_mask.shape[-1]
if masklen < kvlen:
key_states = key_states[:, :, :masklen, :]
value_states = value_states[:, :, :masklen, :]
elif masklen > kvlen:
attention_mask = attention_mask[:, :, :, :kvlen]
return key_states, value_states, attention_mask

def inplace_eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
Expand All @@ -1919,6 +1940,9 @@ def inplace_eager_attention_forward(
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
key_states, value_states, attention_mask = _align_kv_to_mask(
key_states, value_states, attention_mask
)

bsz, n_heads, qlen, _ = query.shape
bsz, n_heads, kvlen, _ = key_states.shape
Expand Down Expand Up @@ -1961,6 +1985,9 @@ def eager_attention_forward(
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
key_states, value_states, attention_mask = _align_kv_to_mask(
key_states, value_states, attention_mask
)
attn_weights = matmul(query, key_states.transpose(2, 3))
attn_weights *= scaling
if attention_mask is not None:
Expand Down Expand Up @@ -2108,22 +2135,12 @@ def forward(
# Transformers >= 5.0 dropped `cache_position` from GptOssAttention.forward's
# signature, so the variants above fail the strict signature match and the
# attention patch silently does not apply (leaving stock attention, which
# then mismatches the sliding-window mask the model patch builds). Provide
# cache_position-free variants (it still arrives via **kwargs) so the patch
# installs on transformers 5.x too.
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
**kwargs: KWARGS_TYPE,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
cache_position = kwargs.pop("cache_position", None)
return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)

functions.append(forward)

# then mismatches the sliding-window mask the model patch builds). Add a
# `past_key_values` variant without `cache_position` (it still arrives via
# **kwargs) so the patch installs on transformers 5.x too. Only the
# `past_key_values` spelling is needed: the singular `past_key_value` naming
# predates transformers dropping `cache_position`, so a singular + no
# `cache_position` signature does not exist in any transformers release.
def forward(
self,
hidden_states: torch.Tensor,
Expand Down
Loading