Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 63 additions & 27 deletions unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def selective_log_softmax(logits, index):

# Custom compiled GRPO loss - creates 3 Triton kernels
def grpo_compute_loss(
old_logits,
ref_logits,
new_logits,
old_logits,
input_ids,
mask,
beta,
Expand All @@ -65,29 +66,42 @@ def grpo_compute_loss(
max_completion_length = kwargs.get("max_completion_length", 8192)
delta = kwargs.get("delta", None)

old_logits = old_logits.to(torch.float32)
# All Unsloth Zoo code licensed under LGPLv3
new_logits = new_logits.to(torch.float32)
input_ids = input_ids.unsqueeze(-1)

# x_i - logsumexp(x_i)
old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)

with torch.no_grad():
if beta != 0.0:
assert ref_logits is not None, "ref_logits should not be None when beta != 0.0"
ref_logits = ref_logits.to(torch.float32)
ref_x = torch.gather(ref_logits, dim = -1, index = input_ids).squeeze(-1)
ref = ref_x - torch.logsumexp(ref_logits, dim = -1)
if old_logits is not None:
old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
old = old_x - torch.logsumexp(old_logits, dim = -1)


new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
old = old_x - torch.logsumexp(old_logits, dim = -1)
new = new_x - torch.logsumexp(new_logits, dim = -1)

# Reverse KL
# Note that this is a low variance low bias estimator for the KL divergence as used in GRPO paper
if beta != 0.0:
kl_i = torch.exp(old - new) - (old - new) - 1.0
kl_i = torch.exp(ref - new) - (ref - new) - 1.0

else:
kl_i = 0.0 # set it to 0 to not effect the downstream computation
# Full correct reverse KL divergence?? Missing term maybe?
# kl_i = torch.exp(new) * kl_i

# Below is forward KL (normal KL)
# kl_i = torch.exp(old) * (old - new)

coef_1 = torch.exp(new - old)
if old_logits is not None:
coef_1 = torch.exp(new - old)
else:
coef_1 = torch.exp(new - new.detach())
coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)

if delta is not None:
Expand All @@ -99,6 +113,7 @@ def grpo_compute_loss(
# Must detach - otherwise gradients are not propagated correctly!
# exp(x - x) == 1
# loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)


loss_2 = coef_2 * advantages.unsqueeze(1)
loss_i = -torch.min(loss_1, loss_2)
Expand Down Expand Up @@ -126,6 +141,7 @@ def grpo_compute_loss(
mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
mean_kl = mean_kl_per_reward.mean()
pass

return loss, completion_length, mean_kl
pass
RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss
Expand All @@ -142,18 +158,30 @@ def grpo_compute_loss(
class UnslothEfficientGRPO(torch.autograd.Function):
# All Unsloth Zoo code licensed under LGPLv3
@staticmethod
def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, extra_kwargs=None):
def forward(ctx, _new_hidden_states, _old_hidden_states, _ref_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, extra_kwargs=None):
if extra_kwargs is None:
extra_kwargs = {}
print(f'Extra kwargs: {extra_kwargs}, beta = {beta}')
def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
def compute_loss(new_hidden_states, old_hidden_states, ref_hidden_states,input_ids, mask, advantages, scaling):
new_logits = torch.matmul(new_hidden_states, lm_head.t())
new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
old_logits = torch.matmul(old_hidden_states, lm_head.t())
old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
with torch.no_grad():
ref_logits = torch.matmul(ref_hidden_states, lm_head.t())
ref_logits = ref_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
old_logits = None
if old_hidden_states is not None:
old_logits = torch.matmul(old_hidden_states, lm_head.t())
old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
else:
old_logits = None
# if old_hidden_states is not None:
# old_logits = torch.matmul(old_hidden_states, lm_head.t()) #last logit already excluded
# old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
# else:
# old_logits = None
loss, completion_length, mean_kl = grpo_compute_loss(
old_logits, new_logits, input_ids, mask, beta, advantages, **extra_kwargs
ref_logits, new_logits,old_logits, input_ids, mask, beta, advantages, **extra_kwargs
)

# Scale loss if needed for mixed precision training
scaled_loss = loss * scaling
# Must add .loss.detach otherwise autograd uses 2x VRAM
Expand All @@ -166,12 +194,12 @@ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantag
accumulated_completion_length = torch.zeros(1, device = device)
accumulated_mean_kl = torch.zeros(1, device = device)

def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, ref_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
(chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
compute_loss,
argnums = (0,),
has_aux = True,
)(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
)(new_hidden_states_j, old_hidden_states_j, ref_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
accumulated_loss .add_(unscaled_loss)
accumulated_completion_length.add_(chunk_completion_length)
accumulated_mean_kl .add_(chunk_mean_kl)
Expand All @@ -186,7 +214,11 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask

grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
if _old_hidden_states is not None:
old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
else:
old_hidden_states = [None] * n_chunks
ref_hidden_states = torch.chunk(_ref_hidden_states, chunks = n_chunks, dim = 0)
input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
Expand All @@ -197,25 +229,25 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask
# Force torch.compile to use dynamic shapes for seqlen dim
mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)

for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, ref_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, ref_hidden_states, input_ids, mask, advantages):

mark_dynamic(new_hidden_states_j)
mark_dynamic(old_hidden_states_j)
mark_dynamic(ref_hidden_states_j)
if old_hidden_states_j is not None:
mark_dynamic(old_hidden_states_j)
mark_dynamic(input_ids_j)
mark_dynamic(mask_j)

grad_inputs_j.copy_(
accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
)

grad_inputs_j.copy_(accumulate_chunk(new_hidden_states_j, old_hidden_states_j,ref_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling))
pass

grad_inputs .div_(n_chunks)
accumulated_loss .div_(n_chunks)
accumulated_completion_length.div_(n_chunks)
accumulated_mean_kl .div_(n_chunks)
ctx.save_for_backward(grad_inputs)

return (
accumulated_loss,
accumulated_completion_length,
Expand All @@ -226,7 +258,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask
@staticmethod
def backward(ctx, grad_output, dcompletion_length, dmean_kl):
(grad_input,) = ctx.saved_tensors
return (grad_input, None, None, None, None, None, None, None, None, None)
return (grad_input, None, None, None, None, None, None, None, None, None, None)
pass
pass
RL_REPLACEMENTS["UnslothEfficientGRPO"] = UnslothEfficientGRPO
Expand All @@ -238,11 +270,13 @@ def grpo_accumulated_loss(
logits_to_keep,
completion_mask,
advantages,
old_hidden_states,
n_chunks = -1,
**kwargs,
):
# All Unsloth Zoo code licensed under LGPLv3
bsz, qlen = input_ids.shape

# Find closest multiple
factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
if n_chunks == -1: n_chunks = bsz
Expand All @@ -255,18 +289,20 @@ def grpo_accumulated_loss(
lm_head = trainer.model.get_output_embeddings().weight

with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
#breakpoint()
with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
ref_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
pass

new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits

loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
new_hidden_states, old_hidden_states, lm_head,
new_hidden_states, old_hidden_states ,ref_hidden_states, lm_head,
completion_input_ids, completion_mask, advantages, trainer.beta,
trainer.accelerator.scaler,
n_chunks, kwargs # pass kwargs as a dict
)

return loss, completion_length, mean_kl

# Old non efficient code path
Expand Down