Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
import inspect
import linecache
from collections import defaultdict
from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
from unsloth_zoo.rl_replacements import (
RL_REPLACEMENTS,
left_pack_padding,
chunked_selective_log_softmax,
)
from unsloth_zoo.utils import Version
from trl import __version__ as trl_version_raw
from importlib.metadata import version as importlib_version
Expand Down Expand Up @@ -859,6 +863,18 @@ def chunk_optional(tensor, chunks):
:, -(logits_to_keep + max_left_pad + 1) :, :
]
logits_chunk = logits_chunk[:, :-1, :]
logprobs_chunk = (
chunked_hidden_states_selective_log_softmax(
logits_chunk,
lm_head,
completion_input_ids_chunk,
chunks = input_ids_chunk.shape[0] * multiplier,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
logit_softcapping = logit_softcapping,
temperature = temperature,
)
)
else:
# Essentially, for VLMs we do not go via the optimized path in models/,
# so we don't encounter the Flash Attn left-padding issue.
Expand All @@ -876,17 +892,27 @@ def chunk_optional(tensor, chunks):
completion_input_ids_chunk = input_ids_chunk[
:, -logits_to_keep:
]

logprobs_chunk = chunked_hidden_states_selective_log_softmax(
logits_chunk,
lm_head,
completion_input_ids_chunk,
chunks = input_ids_chunk.shape[0] * multiplier,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
logit_softcapping = logit_softcapping,
temperature = temperature,
)
# Guard: check if model returned hidden states or logits
if logits_chunk.shape[-1] == lm_head.shape[1]:
logprobs_chunk = (
chunked_hidden_states_selective_log_softmax(
logits_chunk,
lm_head,
completion_input_ids_chunk,
chunks = input_ids_chunk.shape[0] * multiplier,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
logit_softcapping = logit_softcapping,
temperature = temperature,
)
)
else:
# Model returned logits directly - scaling/softcapping already applied by model forward
logprobs_chunk = chunked_selective_log_softmax(
logits_chunk,
completion_input_ids_chunk,
temperature,
)
# This is needed to avoid race conditions with GPT OSS offload_embbed=True
# However, it seems that this line does not slow down or disrupt models.
device_synchronize()
Expand Down
Loading