diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 791d4d9c1..9c922a26b 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -44,10 +44,14 @@ def selective_log_softmax(logits, index): # More memory efficient by chunking on (bsz+qlen) dimension # Exactly equivalent to the above @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index, temperature: float = 1.0): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) +def chunked_selective_log_softmax( + logits, + index, + temperature: float = 1.0, + chunks: int = 4, +): + chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = chunks, dim = 0) + chunked_index = torch.chunk(index.reshape(-1), chunks = chunks, dim = 0) all_per_token_logps = [] # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): @@ -928,6 +932,29 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32, logit_softcapping, temperature ) + def compute_logprobs_chunk(new_hidden_states_chunk, completion_ids, input_ids_chunk): + # Hidden states -> lm_head matmul path; raw logits -> skip matmul and + # skip scale/softcap (model forward already applied them). + chunks = input_ids_chunk.shape[0] * multiplier + if new_hidden_states_chunk.shape[-1] == lm_head.shape[1]: + return efficient_log_softmax( + new_hidden_states_chunk, + lm_head, + completion_ids, + chunks = chunks, + logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, + logit_softcapping = logit_softcapping, + temperature = temperature, + batch_size = B, + ) + return chunked_selective_log_softmax( + new_hidden_states_chunk, + completion_ids, + temperature = temperature, + chunks = chunks, + ) + for ( input_ids_chunk, @@ -959,17 +986,7 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32, new_hidden_states_chunk = new_hidden_states_chunk[:, -(logits_to_keep + max_left_pad + 1): , :] new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :] - logprobs_chunk = efficient_log_softmax( - new_hidden_states_chunk, - lm_head, - completion_ids, - chunks=input_ids_chunk.shape[0]*multiplier, - logit_scale_multiply=logit_scale_multiply, - logit_scale_divide=logit_scale_divide, - logit_softcapping=logit_softcapping, - temperature=temperature, - batch_size = B - ) + logprobs_chunk = compute_logprobs_chunk(new_hidden_states_chunk, completion_ids, input_ids_chunk) else: new_hidden_states_chunk = unwrapped_model( input_ids = input_ids_chunk, @@ -983,22 +1000,7 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32, ).logits new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :] - # Guard: check if model returned hidden states or logits - if new_hidden_states_chunk.shape[-1] == lm_head.shape[1]: - logprobs_chunk = efficient_log_softmax( - new_hidden_states_chunk, - lm_head, - completion_ids, - chunks=input_ids_chunk.shape[0]*multiplier, - logit_scale_multiply=logit_scale_multiply, - logit_scale_divide=logit_scale_divide, - logit_softcapping=logit_softcapping, - temperature=temperature, - batch_size = B - ) - else: - # Model returned logits directly - scaling/softcapping already applied by model forward - logprobs_chunk = chunked_selective_log_softmax(new_hidden_states_chunk, completion_ids, temperature) + logprobs_chunk = compute_logprobs_chunk(new_hidden_states_chunk, completion_ids, input_ids_chunk) #This is needed to avoid race conditions with GPT OSS offload_embbed=True #However, it seems that this line does not slow down or disrupt models. device_synchronize()