diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 314feb5d2a..805086324e 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -26,7 +26,11 @@ import inspect import linecache from collections import defaultdict -from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding +from unsloth_zoo.rl_replacements import ( + RL_REPLACEMENTS, + left_pack_padding, + chunked_selective_log_softmax, +) from unsloth_zoo.utils import Version from trl import __version__ as trl_version_raw from importlib.metadata import version as importlib_version @@ -859,6 +863,18 @@ def chunk_optional(tensor, chunks): :, -(logits_to_keep + max_left_pad + 1) :, : ] logits_chunk = logits_chunk[:, :-1, :] + logprobs_chunk = ( + chunked_hidden_states_selective_log_softmax( + logits_chunk, + lm_head, + completion_input_ids_chunk, + chunks = input_ids_chunk.shape[0] * multiplier, + logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, + logit_softcapping = logit_softcapping, + temperature = temperature, + ) + ) else: # Essentially, for VLMs we do not go via the optimized path in models/, # so we don't encounter the Flash Attn left-padding issue. @@ -876,17 +892,27 @@ def chunk_optional(tensor, chunks): completion_input_ids_chunk = input_ids_chunk[ :, -logits_to_keep: ] - - logprobs_chunk = chunked_hidden_states_selective_log_softmax( - logits_chunk, - lm_head, - completion_input_ids_chunk, - chunks = input_ids_chunk.shape[0] * multiplier, - logit_scale_multiply = logit_scale_multiply, - logit_scale_divide = logit_scale_divide, - logit_softcapping = logit_softcapping, - temperature = temperature, - ) + # Guard: check if model returned hidden states or logits + if logits_chunk.shape[-1] == lm_head.shape[1]: + logprobs_chunk = ( + chunked_hidden_states_selective_log_softmax( + logits_chunk, + lm_head, + completion_input_ids_chunk, + chunks = input_ids_chunk.shape[0] * multiplier, + logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, + logit_softcapping = logit_softcapping, + temperature = temperature, + ) + ) + else: + # Model returned logits directly - scaling/softcapping already applied by model forward + logprobs_chunk = chunked_selective_log_softmax( + logits_chunk, + completion_input_ids_chunk, + temperature, + ) # This is needed to avoid race conditions with GPT OSS offload_embbed=True # However, it seems that this line does not slow down or disrupt models. device_synchronize()