Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 33 additions & 31 deletions unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,14 @@ def selective_log_softmax(logits, index):
# More memory efficient by chunking on (bsz+qlen) dimension
# Exactly equivalent to the above
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index, temperature: float = 1.0):
# Split into 4 chunks only
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
def chunked_selective_log_softmax(
logits,
index,
temperature: float = 1.0,
chunks: int = 4,
):
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = chunks, dim = 0)
chunked_index = torch.chunk(index.reshape(-1), chunks = chunks, dim = 0)
all_per_token_logps = []
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
Expand Down Expand Up @@ -928,6 +932,29 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32,
logit_softcapping, temperature
)

def compute_logprobs_chunk(new_hidden_states_chunk, completion_ids, input_ids_chunk):
# Hidden states -> lm_head matmul path; raw logits -> skip matmul and
# skip scale/softcap (model forward already applied them).
chunks = input_ids_chunk.shape[0] * multiplier
if new_hidden_states_chunk.shape[-1] == lm_head.shape[1]:
return efficient_log_softmax(
new_hidden_states_chunk,
lm_head,
completion_ids,
chunks = chunks,
logit_scale_multiply = logit_scale_multiply,
logit_scale_divide = logit_scale_divide,
logit_softcapping = logit_softcapping,
temperature = temperature,
batch_size = B,
)
return chunked_selective_log_softmax(
new_hidden_states_chunk,
completion_ids,
temperature = temperature,
chunks = chunks,
)


for (
input_ids_chunk,
Expand Down Expand Up @@ -959,17 +986,7 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32,

new_hidden_states_chunk = new_hidden_states_chunk[:, -(logits_to_keep + max_left_pad + 1): , :]
new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :]
logprobs_chunk = efficient_log_softmax(
new_hidden_states_chunk,
lm_head,
completion_ids,
chunks=input_ids_chunk.shape[0]*multiplier,
logit_scale_multiply=logit_scale_multiply,
logit_scale_divide=logit_scale_divide,
logit_softcapping=logit_softcapping,
temperature=temperature,
batch_size = B
)
logprobs_chunk = compute_logprobs_chunk(new_hidden_states_chunk, completion_ids, input_ids_chunk)
else:
new_hidden_states_chunk = unwrapped_model(
input_ids = input_ids_chunk,
Expand All @@ -983,22 +1000,7 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32,
).logits

new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :]
# Guard: check if model returned hidden states or logits
if new_hidden_states_chunk.shape[-1] == lm_head.shape[1]:
logprobs_chunk = efficient_log_softmax(
new_hidden_states_chunk,
lm_head,
completion_ids,
chunks=input_ids_chunk.shape[0]*multiplier,
logit_scale_multiply=logit_scale_multiply,
logit_scale_divide=logit_scale_divide,
logit_softcapping=logit_softcapping,
temperature=temperature,
batch_size = B
)
else:
# Model returned logits directly - scaling/softcapping already applied by model forward
logprobs_chunk = chunked_selective_log_softmax(new_hidden_states_chunk, completion_ids, temperature)
logprobs_chunk = compute_logprobs_chunk(new_hidden_states_chunk, completion_ids, input_ids_chunk)
#This is needed to avoid race conditions with GPT OSS offload_embbed=True
#However, it seems that this line does not slow down or disrupt models.
device_synchronize()
Expand Down