diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 51276d447..fecea0edc 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -50,8 +50,9 @@ def selective_log_softmax(logits, index): # Custom compiled GRPO loss - creates 3 Triton kernels def grpo_compute_loss( - old_logits, + ref_logits, new_logits, + old_logits, input_ids, mask, beta, @@ -65,20 +66,31 @@ def grpo_compute_loss( max_completion_length = kwargs.get("max_completion_length", 8192) delta = kwargs.get("delta", None) - old_logits = old_logits.to(torch.float32) + # All Unsloth Zoo code licensed under LGPLv3 new_logits = new_logits.to(torch.float32) input_ids = input_ids.unsqueeze(-1) # x_i - logsumexp(x_i) - old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) + + with torch.no_grad(): + if beta != 0.0: + assert ref_logits is not None, "ref_logits should not be None when beta != 0.0" + ref_logits = ref_logits.to(torch.float32) + ref_x = torch.gather(ref_logits, dim = -1, index = input_ids).squeeze(-1) + ref = ref_x - torch.logsumexp(ref_logits, dim = -1) + if old_logits is not None: + old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) + old = old_x - torch.logsumexp(old_logits, dim = -1) + + new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) - old = old_x - torch.logsumexp(old_logits, dim = -1) new = new_x - torch.logsumexp(new_logits, dim = -1) # Reverse KL # Note that this is a low variance low bias estimator for the KL divergence as used in GRPO paper if beta != 0.0: - kl_i = torch.exp(old - new) - (old - new) - 1.0 + kl_i = torch.exp(ref - new) - (ref - new) - 1.0 + else: kl_i = 0.0 # set it to 0 to not effect the downstream computation # Full correct reverse KL divergence?? Missing term maybe? @@ -86,8 +98,10 @@ def grpo_compute_loss( # Below is forward KL (normal KL) # kl_i = torch.exp(old) * (old - new) - - coef_1 = torch.exp(new - old) + if old_logits is not None: + coef_1 = torch.exp(new - old) + else: + coef_1 = torch.exp(new - new.detach()) coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high) if delta is not None: @@ -99,6 +113,7 @@ def grpo_compute_loss( # Must detach - otherwise gradients are not propagated correctly! # exp(x - x) == 1 # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_2 = coef_2 * advantages.unsqueeze(1) loss_i = -torch.min(loss_1, loss_2) @@ -126,6 +141,7 @@ def grpo_compute_loss( mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward mean_kl = mean_kl_per_reward.mean() pass + return loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss @@ -142,18 +158,30 @@ def grpo_compute_loss( class UnslothEfficientGRPO(torch.autograd.Function): # All Unsloth Zoo code licensed under LGPLv3 @staticmethod - def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, extra_kwargs=None): + def forward(ctx, _new_hidden_states, _old_hidden_states, _ref_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, extra_kwargs=None): if extra_kwargs is None: extra_kwargs = {} - print(f'Extra kwargs: {extra_kwargs}, beta = {beta}') - def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling): + def compute_loss(new_hidden_states, old_hidden_states, ref_hidden_states,input_ids, mask, advantages, scaling): new_logits = torch.matmul(new_hidden_states, lm_head.t()) new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - old_logits = torch.matmul(old_hidden_states, lm_head.t()) - old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + with torch.no_grad(): + ref_logits = torch.matmul(ref_hidden_states, lm_head.t()) + ref_logits = ref_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + old_logits = None + if old_hidden_states is not None: + old_logits = torch.matmul(old_hidden_states, lm_head.t()) + old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + else: + old_logits = None + # if old_hidden_states is not None: + # old_logits = torch.matmul(old_hidden_states, lm_head.t()) #last logit already excluded + # old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + # else: + # old_logits = None loss, completion_length, mean_kl = grpo_compute_loss( - old_logits, new_logits, input_ids, mask, beta, advantages, **extra_kwargs + ref_logits, new_logits,old_logits, input_ids, mask, beta, advantages, **extra_kwargs ) + # Scale loss if needed for mixed precision training scaled_loss = loss * scaling # Must add .loss.detach otherwise autograd uses 2x VRAM @@ -166,12 +194,12 @@ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantag accumulated_completion_length = torch.zeros(1, device = device) accumulated_mean_kl = torch.zeros(1, device = device) - def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling): + def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, ref_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling): (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value( compute_loss, argnums = (0,), has_aux = True, - )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling) + )(new_hidden_states_j, old_hidden_states_j, ref_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling) accumulated_loss .add_(unscaled_loss) accumulated_completion_length.add_(chunk_completion_length) accumulated_mean_kl .add_(chunk_mean_kl) @@ -186,7 +214,11 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0) new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0) - old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0) + if _old_hidden_states is not None: + old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0) + else: + old_hidden_states = [None] * n_chunks + ref_hidden_states = torch.chunk(_ref_hidden_states, chunks = n_chunks, dim = 0) input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0) mask = torch.chunk(_mask, chunks = n_chunks, dim = 0) advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0) @@ -197,17 +229,18 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask # Force torch.compile to use dynamic shapes for seqlen dim mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1) - for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \ - zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages): + for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, ref_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \ + zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, ref_hidden_states, input_ids, mask, advantages): mark_dynamic(new_hidden_states_j) - mark_dynamic(old_hidden_states_j) + mark_dynamic(ref_hidden_states_j) + if old_hidden_states_j is not None: + mark_dynamic(old_hidden_states_j) mark_dynamic(input_ids_j) mark_dynamic(mask_j) - grad_inputs_j.copy_( - accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling) - ) + + grad_inputs_j.copy_(accumulate_chunk(new_hidden_states_j, old_hidden_states_j,ref_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)) pass grad_inputs .div_(n_chunks) @@ -215,7 +248,6 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask accumulated_completion_length.div_(n_chunks) accumulated_mean_kl .div_(n_chunks) ctx.save_for_backward(grad_inputs) - return ( accumulated_loss, accumulated_completion_length, @@ -226,7 +258,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask @staticmethod def backward(ctx, grad_output, dcompletion_length, dmean_kl): (grad_input,) = ctx.saved_tensors - return (grad_input, None, None, None, None, None, None, None, None, None) + return (grad_input, None, None, None, None, None, None, None, None, None, None) pass pass RL_REPLACEMENTS["UnslothEfficientGRPO"] = UnslothEfficientGRPO @@ -238,11 +270,13 @@ def grpo_accumulated_loss( logits_to_keep, completion_mask, advantages, + old_hidden_states, n_chunks = -1, **kwargs, ): # All Unsloth Zoo code licensed under LGPLv3 bsz, qlen = input_ids.shape + # Find closest multiple factors = [i for i in range(1, bsz + 1) if bsz % i == 0] if n_chunks == -1: n_chunks = bsz @@ -255,18 +289,20 @@ def grpo_accumulated_loss( lm_head = trainer.model.get_output_embeddings().weight with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): + #breakpoint() with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): - old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits + ref_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits pass - + new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( - new_hidden_states, old_hidden_states, lm_head, + new_hidden_states, old_hidden_states ,ref_hidden_states, lm_head, completion_input_ids, completion_mask, advantages, trainer.beta, trainer.accelerator.scaler, n_chunks, kwargs # pass kwargs as a dict ) + return loss, completion_length, mean_kl # Old non efficient code path