diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 4dbb2273b..7d7fc0281 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -60,7 +60,66 @@ def chunked_selective_log_softmax(logits, index): return all_per_token_logps pass +def calculate_pad_tokens_in_prompt( + input_ids: torch.Tensor, + logits_to_keep: int, + pad_token_id: int + ) -> torch.Tensor: + """ + Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens + """ + if logits_to_keep >= input_ids.shape[1]: + raise ValueError("logits_to_keep must be smaller than the sequence length.") + + prompt_section = input_ids[:, :-logits_to_keep] + + padding_mask = (prompt_section == pad_token_id) + + pad_token_counts = padding_mask.sum(dim=1) + + return pad_token_counts + +def create_completion_attention_mask( + completion_input_ids: torch.Tensor, + left_pad_tokens_per_prompt: torch.Tensor, + max_left_pad: int, + pad_token_id: int + ) -> torch.Tensor: + """ + Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] + + Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens + and pad are pad tokens, this function would make a completion mask that would 0 out the pad + and p tokens. so in this example [0,0,0,1,1,1,0,0,0] + """ + batch_size, completion_len = completion_input_ids.shape + device = completion_input_ids.device + + num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt + + indices = torch.arange(completion_len, device=device).unsqueeze(0) + shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) + + non_padding_mask = (completion_input_ids != pad_token_id) + + final_mask = shift_mask & non_padding_mask + + return final_mask + +def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: + """ + Moves all padding tokens in each sequence of a batch to the right. + """ + mask = (tensor != pad_id) + sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) + packed_tensor = torch.gather(tensor, 1, sorted_indices) + + return packed_tensor + RL_REPLACEMENTS["selective_log_softmax"] = chunked_selective_log_softmax +RL_REPLACEMENTS["create_completion_attention_mask"] = create_completion_attention_mask +RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"] = calculate_pad_tokens_in_prompt +RL_REPLACEMENTS["left_pack_padding"] = left_pack_padding # Custom compiled GRPO loss - creates 3 Triton kernels @@ -85,6 +144,7 @@ def grpo_compute_loss( logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0) logit_scale_divide = kwargs.get("logit_scale_divide", 0.0) logit_softcapping = kwargs.get("logit_softcapping", 0.0) + importance_sampling_level = kwargs.get("importance_sampling_level", "token") input_ids = input_ids.unsqueeze(-1) @@ -98,7 +158,6 @@ def grpo_compute_loss( if temperature != 1.0: new_logits = new_logits / temperature new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) new = new_x - torch.logsumexp(new_logits, dim = -1) - # x_i - logsumexp(x_i) with torch.no_grad(): if beta != 0.0: @@ -143,9 +202,23 @@ def grpo_compute_loss( # Below is forward KL (normal KL) # kl_i = torch.exp(old) * (old - new) if old_logits is not None: - coef_1 = torch.exp(new - old) + log_ratio = new - old + else: + log_ratio = new - new.detach() + + if importance_sampling_level == "token": + log_importance_weights = log_ratio + elif importance_sampling_level == "sequence": + log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) else: - coef_1 = torch.exp(new - new.detach()) + raise ValueError( + f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' " + "and 'sequence'." + ) + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high) if delta is not None: @@ -178,15 +251,19 @@ def grpo_compute_loss( # loss = (loss_i * mask).sum() / mask.sum() - # Get metrics as well which are folded - with torch.inference_mode(): - completion_length = n_mask_per_reward.mean() - n_mask_per_reward = n_mask_per_reward.clamp(min = 1.0) # Counteracts division by 0 - mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward.mean() - pass + completion_length = n_mask_per_reward.mean() - return loss, completion_length, mean_kl + # Get metrics as well which are folded + def masked_batch_mean(x): + with torch.inference_mode(): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward + mean_kl = mean_kl_per_reward.mean() + return mean_kl + + return loss, completion_length, masked_batch_mean(kl_i) pass RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss RL_REPLACEMENTS["grpo_compute_loss_slow"] = \ @@ -283,7 +360,10 @@ def accumulate_chunk( old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0) else: old_hidden_states = [None] * n_chunks - ref_hidden_states = torch.chunk(_ref_hidden_states, chunks = n_chunks, dim = 0) + if _ref_hidden_states is not None: + ref_hidden_states = torch.chunk(_ref_hidden_states, chunks = n_chunks, dim = 0) + else: + ref_hidden_states = [None] * n_chunks input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0) mask = torch.chunk(_mask, chunks = n_chunks, dim = 0) advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0) @@ -347,12 +427,17 @@ def grpo_accumulated_loss( completion_mask, advantages, old_hidden_states, + ref_hidden_states, n_chunks = -1, **kwargs, ): # All Unsloth Zoo code licensed under LGPLv3 bsz, qlen = input_ids.shape + pixel_values = kwargs.get('pixel_values',None) + image_grid_thw = kwargs.get('image_grid_thw',None) + pixel_attention_mask = kwargs.get('pixel_attention_mask',None) + image_sizes = kwargs.get('image_sizes',None) # Find closest multiple factors = [i for i in range(1, bsz + 1) if bsz % i == 0] if n_chunks == -1: n_chunks = bsz @@ -364,41 +449,72 @@ def grpo_accumulated_loss( pass os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" - completion_input_ids = input_ids[:, -logits_to_keep:] lm_head = trainer.model.get_output_embeddings().weight - with torch.amp.autocast(device_type = trainer.model.device.type, dtype = trainer._autocast_dtype): - with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): - ref_hidden_states = trainer.model( + if pixel_values is None: + left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(input_ids, logits_to_keep, trainer.processing_class.pad_token_id) + + max_left_pad = max(left_pad_tokens_per_prompt).item() + + input_ids = left_pack_padding(input_ids, trainer.processing_class.pad_token_id) + + completion_input_ids = input_ids[:, -(logits_to_keep +max_left_pad):] + + completion_mask = create_completion_attention_mask(completion_input_ids, left_pad_tokens_per_prompt, max_left_pad, trainer.processing_class.pad_token_id).to(attention_mask.dtype) + attention_mask = input_ids != trainer.processing_class.pad_token_id + attention_mask = attention_mask.to(attention_mask.dtype) + else: + completion_input_ids = input_ids[:, -logits_to_keep:] + + unwrapped_model = trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False) + with torch.amp.autocast(device_type = trainer.model.device.type, dtype = trainer._autocast_dtype): + if pixel_values is None: + new_hidden_states = unwrapped_model( + input_ids = input_ids, + attention_mask = attention_mask, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, + pixel_attention_mask = pixel_attention_mask, + image_sizes = image_sizes, + #logits_to_keep = logits_to_keep + 1, + ).logits + + #keep extra logit as we generated a new token + new_hidden_states = new_hidden_states[:, -(logits_to_keep +max_left_pad+1): , :] + if ref_hidden_states is not None: + ref_hidden_states = ref_hidden_states[:, -(logits_to_keep +max_left_pad+1): , :] + if old_hidden_states is not None: + old_hidden_states = old_hidden_states[:, -(logits_to_keep +max_left_pad+1): , :] + else: + new_hidden_states = unwrapped_model( input_ids = input_ids, attention_mask = attention_mask, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, + pixel_attention_mask = pixel_attention_mask, + image_sizes = image_sizes, logits_to_keep = logits_to_keep + 1, ).logits - pass - new_hidden_states = trainer.model( - input_ids = input_ids, - attention_mask = attention_mask, - logits_to_keep = logits_to_keep + 1, - ).logits - - loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( - new_hidden_states, - old_hidden_states, - ref_hidden_states, - lm_head, - completion_input_ids, - completion_mask, - advantages, - trainer.beta, - trainer.accelerator.scaler, - n_chunks, - kwargs # pass kwargs as a dict - ) + + loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( + new_hidden_states, + old_hidden_states, + ref_hidden_states, + lm_head, + completion_input_ids, + completion_mask, + advantages, + trainer.beta, + trainer.accelerator.scaler, + n_chunks, + kwargs # pass kwargs as a dict + ) pass + # Must force not returning hidden states but logits otherwise gibberish os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" - return loss, completion_length, mean_kl + return loss, completion_length, mean_kl # Old non efficient code path new_logits = torch.matmul(new_hidden_states, lm_head.t()) new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred