From 0ac4e9b643e145032ce46c5948da74ac3a65e895 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:02:24 +0000 Subject: [PATCH 1/6] Initial commit --- benchmark/data/all_benchmark_data.csv | 24 +++ benchmark/scripts/benchmark_cpo_loss.py | 191 ++++++++++++++++++ src/liger_kernel/chunked_loss/cpo_loss.py | 61 ++++++ .../chunked_loss/fused_linear_preference.py | 108 +++++++++- src/liger_kernel/chunked_loss/orpo_loss.py | 106 +++------- test/chunked_loss/test_cpo_loss.py | 132 ++++++++++++ test/chunked_loss/test_orpo_loss.py | 134 ++---------- test/utils.py | 131 +++++++++++- 8 files changed, 682 insertions(+), 205 deletions(-) create mode 100644 benchmark/scripts/benchmark_cpo_loss.py create mode 100644 src/liger_kernel/chunked_loss/cpo_loss.py create mode 100644 test/chunked_loss/test_cpo_loss.py diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index a5126f1dd..6e5fd4ce0 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -667,3 +667,27 @@ fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.3144 fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0 fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0 fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,16,33418.421875,33418.421875,33418.421875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0 
+fused_linear_cpo_loss,liger,forward,speed,ms,B,B,2,31.536447525024414,31.457439422607422,31.543052673339844,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1 +fused_linear_cpo_loss,liger,forward,speed,ms,B,B,4,62.407745361328125,62.407745361328125,62.407745361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1 +fused_linear_cpo_loss,liger,forward,speed,ms,B,B,8,123.64259338378906,123.64259338378906,123.64259338378906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1 +fused_linear_cpo_loss,liger,forward,speed,ms,B,B,16,245.66575622558594,245.66575622558594,245.66575622558594,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1 +fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,2,14.516239166259766,14.514080047607422,14.52575969696045,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1 +fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,4,26.087743759155273,25.943340301513672,26.269376754760742,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1 +fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,8,51.85932922363281,51.85932922363281,51.85932922363281,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1 +fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,16,104.99673461914062,104.99673461914062,104.99673461914062,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": 
""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1 +fused_linear_cpo_loss,liger,full,speed,ms,B,B,2,33.309967041015625,33.21604919433594,33.40388488769531,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1 +fused_linear_cpo_loss,liger,full,speed,ms,B,B,4,63.053470611572266,63.053470611572266,63.053470611572266,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1 +fused_linear_cpo_loss,liger,full,speed,ms,B,B,8,125.53849792480469,125.53849792480469,125.53849792480469,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1 +fused_linear_cpo_loss,liger,full,speed,ms,B,B,16,250.22178649902344,250.22178649902344,250.22178649902344,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1 +fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,2,39.45849609375,39.33102798461914,39.58596420288086,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1 +fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,4,77.00272369384766,77.00272369384766,77.00272369384766,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1 +fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,8,154.28419494628906,154.28419494628906,154.28419494628906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1 +fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,16,309.23162841796875,309.23162841796875,309.23162841796875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, 
""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1 +fused_linear_cpo_loss,liger,full,memory,MB,B,B,2,8161.34619140625,8161.34619140625,8161.34619140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1 +fused_linear_cpo_loss,liger,full,memory,MB,B,B,4,8209.361328125,8209.361328125,8209.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1 +fused_linear_cpo_loss,liger,full,memory,MB,B,B,8,8305.392578125,8305.392578125,8305.392578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1 +fused_linear_cpo_loss,liger,full,memory,MB,B,B,16,8497.455078125,8497.455078125,8497.455078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1 +fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": 
""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py new file mode 100644 index 000000000..d10c8da8a --- /dev/null +++ b/benchmark/scripts/benchmark_cpo_loss.py @@ -0,0 +1,191 @@ +import os +import sys + +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchLMHeadCPO(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based cross entropy loss. + + :param H: hidden size + :param V: vocab size + :param ignore_index: index to ignore + :param reduction: reduction method + """ + + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): + from test.chunked_loss.test_cpo_loss import HFCPOLoss + + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype + ) + self.cpo_loss = HFCPOLoss().get_batch_loss_metrics + + def forward(self, x, y): + return self.cpo_loss(x, self.lin.weight, y) + + +class LigerLMHeadCPO(torch.nn.Module): + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype + ) + self.cpo_loss = LigerFusedLinearCPOFunction.apply + + def forward(self, x, y): + return self.cpo_loss(x, self.lin.weight, y) + + +############################################################################# +# Test the memory consumption of the linear fused cross entropy loss +############################################################################# + + +def bench_memory_fused_linear_cpo_loss( + input: SingleBenchmarkRunInput, +) 
-> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_cpo(_input, target) + elif provider == "huggingface": + return torch_lm_head_cpo(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear cross entropy loss +# ############################################################################# + + +def bench_speed_fused_linear_cpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + device = "cuda" + + torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_cpo(_input, target) + elif provider == "huggingface": + return torch_lm_head_cpo(_input, target) + + if 
mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_cpo_loss", + "x_name": "B", + "x_label": "B", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 1024, + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_cpo_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_cpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py new file mode 100644 index 000000000..cc8bd44ef --- /dev/null +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -0,0 +1,61 @@ +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import ( + LigerFusedLinearPreferenceBase, +) + + +class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase): + + @staticmethod + def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1): + """ + Compute the CPO preference loss: mean log-sigmoid of the beta-scaled chosen/rejected log-probability difference. + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). 
+ rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + beta (float): Scaling factor applied to the chosen/rejected log-probability difference. + """ + logits = beta * (chosen_logps - rejected_logps) + loss = F.logsigmoid(logits).mean() + return loss + + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + beta=0.1, + alpha=1.0, + compute_nll_loss=True, + compiled=True, + ): + """ + Fused linear layer with CPO (Contrastive Preference Optimization) loss. + Handles both the forward and backward pass of the final linear layer with CPO loss. + Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss. + """ + + return LigerFusedLinearPreferenceBase.forward( + ctx, + _input, + weight, + target, + bias, + loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn, + compute_nll_loss=compute_nll_loss, + ignore_index=ignore_index, + alpha=alpha, + beta=beta, + compiled=compiled, + ) + + @staticmethod + def backward(ctx, grad_output): + # Get gradients for _input, weight, target, and bias from the base class + grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] + # Return these gradients, followed by None for the remaining inputs + return *grads, None, None, None, None, None diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index c95aa40ed..c43caf839 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -1,7 +1,23 @@ +from abc import abstractmethod +from functools import partial + import torch +from torch.nn import functional as F class LigerFusedLinearPreferenceBase(torch.autograd.Function): + + @abstractmethod + def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1): + """ + Compute preference loss. 
+ Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + beta (float): Weight for the odds ratio loss. + """ + raise NotImplementedError("Preference loss function must be implemented.") + @staticmethod def forward( ctx, @@ -11,6 +27,10 @@ def forward( bias=None, loss_fn=None, chunk_size=1, + compute_nll_loss=True, + ignore_index=-100, + alpha=1.0, + beta=0.1, compiled=True, ): """ @@ -24,6 +44,10 @@ def forward( bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). loss_fn (callable): Loss function to compute the loss on a chunk of input/target. chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs). + compute_nll_loss (bool): Whether to compute NLL loss. + ignore_index (int): Index to ignore for loss computation. + alpha (float): Weight for the NLL loss. + beta (float): Weight for the odds ratio loss. compiled (bool): Whether to use torch compile for chunk accumulation. 
""" # TODO: Tune CHUNK_SIZE to fully utilize the GPU @@ -36,13 +60,24 @@ def forward( loss_acc = torch.zeros((), device=_input.device) chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) + loss_func_to_call = partial( + LigerFusedLinearPreferenceBase._compute_loss, + preference_loss_fn=loss_fn, + ignore_index=ignore_index, + alpha=alpha, + beta=beta, + compute_nll_loss=compute_nll_loss, + full_target=target, + ) def accumulate_chunk(input_chunk, target_chunk): if bias is not None: (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( chunk_loss, (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps), - ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1, 3), has_aux=True)( + ) = torch.func.grad_and_value( + loss_func_to_call, argnums=(0, 1, 3), has_aux=True + )( input_chunk, weight, target_chunk, bias ) grad_bias.add_(chunk_grad_bias) @@ -50,7 +85,9 @@ def accumulate_chunk(input_chunk, target_chunk): (chunk_grad_input, chunk_grad_weight), ( chunk_loss, (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps), - ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1), has_aux=True)( + ) = torch.func.grad_and_value( + loss_func_to_call, argnums=(0, 1), has_aux=True + )( input_chunk, weight, target_chunk ) grad_weight.add_(chunk_grad_weight) @@ -105,3 +142,70 @@ def backward(ctx, grad_output): grad_bias = grad_bias * grad_output if grad_bias is not None else None return grad_input, grad_weight, None, grad_bias, None, None, None + + @staticmethod + def _compute_loss( + input_chunk, + weight, + target_chunk, + bias=None, + preference_loss_fn=None, + full_target=None, + ignore_index=-100, + alpha=1.0, + beta=0.1, + compute_nll_loss=True, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an alignment/preference loss function. + Args: + preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + input_chunk (torch.Tensor): Chunk of input tensor. 
Shape: (2 * chunk_size, sequence_length, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). + ignore_index (int): Index to ignore for loss computation. + alpha (float): Weight for the NLL loss. + beta (float): Weight for the odds ratio loss. + loss_kwargs (dict): Additional arguments for the loss function. + """ + len_chosen_chunk = target_chunk.shape[0] // 2 + + logits_chunk = input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) + + chosen_nll_loss = 0.0 + if compute_nll_loss: + chosen_nll_loss = F.nll_loss( + log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), + target_chunk[:len_chosen_chunk].view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) + + loss_mask = target_chunk != ignore_index + label_chunk = torch.where(loss_mask, target_chunk, 0) + + per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( + -1 + ) + average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + + chosen_logps = average_log_prob[:len_chosen_chunk] + rejected_logps = average_log_prob[len_chosen_chunk:] + + alignment_loss = preference_loss_fn( + chosen_logps, rejected_logps, beta=beta, **loss_kwargs + ) + alignment_loss = alignment_loss / (full_target.shape[0] // 2) + + loss = alpha * chosen_nll_loss - alignment_loss + return loss, (alignment_loss, chosen_logps, rejected_logps) diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index 1cd6fe21e..dd8409dfb 100644 
--- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -1,5 +1,3 @@ -from functools import partial - import torch import torch.nn.functional as F @@ -8,79 +6,24 @@ ) -def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1): - """ - Compute odds-ratio loss. - Args: - chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). - rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). - beta (float): Weight for the odds ratio loss. - """ - log_odds = (chosen_logps - rejected_logps) - ( - torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps)) - ) - ratio = F.logsigmoid(log_odds) - return beta * ratio.sum() - - -def _compute_orpo_loss( - input_chunk, - weight, - target_chunk, - bias=None, - full_target=None, - ignore_index=-100, - beta=0.1, - compute_nll_loss=True, -): - """ - Compute ORPO loss for a chunk of input and target. - Args: - input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). - weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). - target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). - bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). - ignore_index (int): Index to ignore for loss computation. - beta (float): Weight for the odds ratio loss. 
- """ - len_chosen_chunk = target_chunk.shape[0] // 2 - - logits_chunk = input_chunk @ weight.t() # chunk_size x V - if bias is not None: - logits_chunk = logits_chunk + bias - log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) +class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase): - chosen_nll_loss = 0.0 - if compute_nll_loss: - chosen_nll_loss = F.nll_loss( - log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), - reduction="sum", - ignore_index=ignore_index, - ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + @staticmethod + def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1): + """ + Compute odds-ratio loss. + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + beta (float): Weight for the odds ratio loss. 
+ """ + log_odds = (chosen_logps - rejected_logps) - ( + torch.log1p(-torch.exp(chosen_logps)) + - torch.log1p(-torch.exp(rejected_logps)) ) + ratio = F.logsigmoid(log_odds) + return beta * ratio.sum() - loss_mask = target_chunk != ignore_index - label_chunk = torch.where(loss_mask, target_chunk, 0) - - per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1) - average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - - chosen_logps = average_log_prob[:len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk:] - - or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta) - or_loss = or_loss / (full_target.shape[0] // 2) - - loss = chosen_nll_loss - or_loss - return loss, (or_loss, chosen_logps, rejected_logps) - - -class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase): @staticmethod def forward( ctx, @@ -91,22 +34,25 @@ def forward( ignore_index=-100, beta=0.1, compute_nll_loss=True, - compiled=True, + compiled=False, ): """ Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss. Handles both the forward and backward pass of the final linear layer with ORPO loss. Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss. 
""" - orpo_loss_fn = partial( - _compute_orpo_loss, - full_target=target, + + return LigerFusedLinearPreferenceBase.forward( + ctx, + _input, + weight, + target, + bias, + loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn, + compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, - compute_nll_loss=compute_nll_loss, - ) - return LigerFusedLinearPreferenceBase.forward( - ctx, _input, weight, target, bias, loss_fn=orpo_loss_fn + compiled=compiled, ) @staticmethod diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py new file mode 100644 index 000000000..9211f98fd --- /dev/null +++ b/test/chunked_loss/test_cpo_loss.py @@ -0,0 +1,132 @@ +from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed +from typing import Tuple + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction + +# set random seed globally +set_seed() + + +class HFCPOLoss(HFAlignmentLoss): + """ + HF's implementation of CPO loss in TRL. https://github.com/huggingface/trl/blob/main/trl/trainer/cpo_trainer.py + """ + + def __init__( + self, + alpha: float = 1.0, + beta: float = 0.1, + ignore_index: int = -100, + label_smoothing: float = 0.0, + ): + super().__init__(alpha=alpha, beta=beta, ignore_index=ignore_index) + # Sigmoid defaults to the CPO loss defined in the paper listed above. + self.loss_type = "sigmoid" + self.label_smoothing = label_smoothing + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the CPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. 
Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the CPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + logits = policy_chosen_logps - policy_rejected_logps + + # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and + # calculates a conservative CPO loss. + if self.loss_type == "sigmoid": + # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. + losses = ( + F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + + F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid']" + ) + + return losses + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + # (1, 2, 12, 128), + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] +) +def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha +): + B = 2 * B # cpo loss requires B to be even + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + 
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + _weight = torch.randn(V, H, device="cuda", dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + loss1 = HFCPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( + input1, weight1, target, bias1, alpha=alpha + ) + loss2 = LigerFusedLinearCPOFunction.apply( + input2, weight2, target, bias2, ignore_index, beta, alpha, True + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 8bd960c84..5e532938b 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -1,9 +1,8 @@ -from test.utils import assert_verbose_allclose, set_seed +from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed from typing import Tuple import pytest import torch -import torch.nn as nn import torch.nn.functional as F from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction @@ -12,7 +11,7 @@ set_seed() -class HF_ORPO_Loss: +class HFORPOLoss(HFAlignmentLoss): """ Implementation of the Odds Ratio Preference Optimization (ORPO) loss, adapted from Hugging Face's implementation. 
@@ -20,46 +19,9 @@ class HF_ORPO_Loss: """ def __init__(self, ignore_index: int = -100, beta: float = 0.1): - self.ignore_index = ignore_index - self.beta = beta + super().__init__(beta=beta, ignore_index=ignore_index) - def get_batch_logps( - self, - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length) - average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. - is_encoder_decoder: Whether the model is an encoder-decoder model. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. - """ - if logits.shape[:-1] != labels.shape: - raise ValueError( - "Logits (batch and sequence length dim) and labels must have the same shape." 
- ) - - loss_mask = labels != self.ignore_index - - # dummy token; we'll ignore the losses on these tokens later - labels = torch.where(labels == self.ignore_index, 0, labels) - - per_token_logps = torch.gather( - logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2) - ).squeeze(2) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def odds_ratio_loss( + def alignment_loss( self, policy_chosen_logps: torch.FloatTensor, policy_rejected_logps: torch.FloatTensor, @@ -94,84 +56,6 @@ def odds_ratio_loss( return losses - def concatenated_forward( - self, - _input: torch.FloatTensor, - weight: torch.FloatTensor, - target: torch.LongTensor, - bias: torch.FloatTensor = None, - ) -> Tuple[ - torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor - ]: - """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - - We do this to avoid doing two forward passes, because it's faster for FSDP. 
- """ - len_chosen = _input.shape[0] // 2 - - outputs = _input @ weight.t() - if bias is not None: - outputs = outputs + bias - all_logits = outputs.float() - - def cross_entropy_loss(logits, labels): - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - # Enable model parallelism - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - - labels = target - chosen_nll_loss = cross_entropy_loss( - all_logits[:len_chosen], labels[:len_chosen] - ) - - all_logps = self.get_batch_logps( - all_logits, - target, - average_log_prob=True, - ) - - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] - - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] - - return ( - chosen_logps, - rejected_logps, - chosen_logits, - rejected_logits, - chosen_nll_loss, - ) - - def get_batch_loss_metrics( - self, - _input: torch.FloatTensor, - weight: torch.FloatTensor, - target: torch.LongTensor, - bias: torch.FloatTensor = None, - ): - """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" - - forward_output = self.concatenated_forward(_input, weight, target, bias) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss, - ) = forward_output[:5] - - losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps) - # full ORPO loss - loss = policy_nll_loss - losses.mean() - return loss - @pytest.mark.parametrize( "B, T, H, V", @@ -219,11 +103,17 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1 = HF_ORPO_Loss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( + loss1 = 
HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( input1, weight1, target, bias1 ) loss2 = LigerFusedLinearORPOFunction.apply( - input2, weight2, target, bias2, ignore_index, beta, True + input2, + weight2, + target, + bias2, + ignore_index, + beta, + True, ) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index ac9a13190..92b15a0da 100644 --- a/test/utils.py +++ b/test/utils.py @@ -2,10 +2,12 @@ import json import os import random +from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple import torch +import torch.nn as nn from tokenizers import AddedToken, Tokenizer from tokenizers.models import BPE from tokenizers.pre_tokenizers import Whitespace @@ -304,3 +306,130 @@ def revert_liger_kernel_to_phi3(): importlib.reload(modeling_phi3) print("Liger kernel patches have been reverted.") + + +class HFAlignmentLoss: + + def __init__(self, alpha: float = 1.0, beta: float = 0.1, ignore_index: int = -100): + self.alpha = alpha + self.beta = beta + self.ignore_index = ignore_index + + @abstractmethod + def alignment_loss(self): + pass + + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + is_encoder_decoder: Whether the model is an encoder-decoder model. 
+ + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError( + "Logits (batch and sequence length dim) and labels must have the same shape." + ) + + loss_mask = labels != self.ignore_index + + # dummy token; we'll ignore the losses on these tokens later + labels = torch.where(labels == self.ignore_index, 0, labels) + + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2) + ).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, + _input: torch.FloatTensor, + weight: torch.FloatTensor, + target: torch.LongTensor, + bias: torch.FloatTensor = None, + ) -> Tuple[ + torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor + ]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. 
+ """ + len_chosen = _input.shape[0] // 2 + + outputs = _input @ weight.t() + if bias is not None: + outputs = outputs + bias + all_logits = outputs.float() + + def cross_entropy_loss(logits, labels): + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = target + chosen_nll_loss = cross_entropy_loss( + all_logits[:len_chosen], labels[:len_chosen] + ) + + all_logps = self.get_batch_logps( + all_logits, + target, + average_log_prob=True, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + ) + + def get_batch_loss_metrics( + self, + _input: torch.FloatTensor, + weight: torch.FloatTensor, + target: torch.LongTensor, + bias: torch.FloatTensor = None, + alpha: float = 1.0, + ): + """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" + + forward_output = self.concatenated_forward(_input, weight, target, bias) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + + losses = self.alignment_loss(policy_chosen_logps, policy_rejected_logps) + # full ORPO loss + loss = policy_nll_loss * alpha - losses.mean() + return loss From 6357eed39e347dbde504edd185e329e05e05a55d Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:43:16 +0000 Subject: [PATCH 2/6] Add SimPO Loss --- benchmark/data/all_benchmark_data.csv | 24 +++ benchmark/scripts/benchmark_simpo_loss.py | 191 ++++++++++++++++++ .../chunked_loss/fused_linear_preference.py | 3 + 
src/liger_kernel/chunked_loss/simpo_loss.py | 64 ++++++ test/chunked_loss/test_cpo_loss.py | 12 +- test/chunked_loss/test_simpo_loss.py | 78 +++++++ test/utils.py | 10 +- 7 files changed, 377 insertions(+), 5 deletions(-) create mode 100644 benchmark/scripts/benchmark_simpo_loss.py create mode 100644 src/liger_kernel/chunked_loss/simpo_loss.py create mode 100644 test/chunked_loss/test_simpo_loss.py diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 6e5fd4ce0..ed25905cd 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -691,3 +691,27 @@ fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.31445 fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,2,30.28438377380371,30.107013702392578,30.284786224365234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,4,58.80876922607422,58.80876922607422,58.80876922607422,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 
+fused_linear_simpo_loss,liger,forward,speed,ms,B,B,8,117.96163177490234,117.96163177490234,117.96163177490234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,16,235.60794067382812,235.60794067382812,235.60794067382812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,2,14.513839721679688,14.510687828063965,14.517855644226074,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,4,28.78099250793457,28.72719383239746,28.792186737060547,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,8,52.5733757019043,52.5733757019043,52.5733757019043,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,16,104.44764709472656,104.44764709472656,104.44764709472656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,2,31.566062927246094,31.457612991333008,31.674514770507812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,4,61.4403190612793,61.4403190612793,61.4403190612793,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": 
""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,8,119.97705841064453,119.97705841064453,119.97705841064453,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,16,238.13417053222656,238.13417053222656,238.13417053222656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,2,39.811119079589844,39.65474319458008,39.96749496459961,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,4,77.20928192138672,77.20928192138672,77.20928192138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,8,153.6952667236328,153.6952667236328,153.6952667236328,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,16,307.7382507324219,307.7382507324219,307.7382507324219,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,2,7675.3291015625,7675.3291015625,7675.3291015625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,4,7723.3447265625,7723.3447265625,7723.3447265625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, 
""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,8,7819.3759765625,7819.3759765625,7819.3759765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,16,8011.4384765625,8011.4384765625,8011.4384765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 diff --git a/benchmark/scripts/benchmark_simpo_loss.py b/benchmark/scripts/benchmark_simpo_loss.py new file mode 100644 index 000000000..457f6f2e8 --- /dev/null +++ b/benchmark/scripts/benchmark_simpo_loss.py @@ -0,0 +1,191 @@ +import os +import sys + +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + 
parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchLMHeadSimPO(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based cross entropy loss. + + :param H: hidden size + :param V: vocab size + :param ignore_index: index to ignore + :param reduction: reduction method + """ + + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): + from test.chunked_loss.test_cpo_loss import HFCPOLoss + + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype + ) + self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics + + def forward(self, x, y): + return self.simpo_loss(x, self.lin.weight, y) + + +class LigerLMHeadSimPO(torch.nn.Module): + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype + ) + self.simpo_loss = LigerFusedLinearSimPOFunction.apply + + def forward(self, x, y): + return self.simpo_loss(x, self.lin.weight, y) + + +############################################################################# +# Test the memory consumption of the linear fused cross entropy loss +############################################################################# + + +def bench_memory_fused_linear_simpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) 
+ + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_simpo(_input, target) + elif provider == "huggingface": + return torch_lm_head_simpo(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear cross entropy loss +# ############################################################################# + + +def bench_speed_fused_linear_simpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + device = "cuda" + + torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_simpo(_input, target) + elif provider == "huggingface": + return torch_lm_head_simpo(_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, 
ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_simpo_loss", + "x_name": "B", + "x_label": "B", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 1024, + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_simpo_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_simpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index c43caf839..65183f837 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -32,6 +32,7 @@ def forward( alpha=1.0, beta=0.1, compiled=True, + **loss_kwargs, ): """ Base class for fused linear layer with preference loss. @@ -49,6 +50,7 @@ def forward( alpha (float): Weight for the NLL loss. beta (float): Weight for the odds ratio loss. compiled (bool): Whether to use torch compile for chunk accumulation. 
+ loss_kwargs (dict): Other possible arguments that a loss function might need """ # TODO: Tune CHUNK_SIZE to fully utilize the GPU CHUNK_SIZE = chunk_size @@ -68,6 +70,7 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, full_target=target, + **loss_kwargs, ) def accumulate_chunk(input_chunk, target_chunk): diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py new file mode 100644 index 000000000..eff581406 --- /dev/null +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -0,0 +1,64 @@ +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import ( + LigerFusedLinearPreferenceBase, +) + + +class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase): + + @staticmethod + def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5): + """ + Compute the SimPO preference term (mean log-sigmoid of the margin-adjusted log-probability difference). + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + beta (float): Scaling factor applied to the chosen-minus-rejected log-probability difference. + gamma (float): The simpo gamma, margin term. + """ + logits = beta * (chosen_logps - rejected_logps) - gamma + loss = F.logsigmoid(logits).mean() + return loss + + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + beta=0.1, + alpha=1.0, + compute_nll_loss=False, + compiled=True, + gamma=0.5, + ): + """ + Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734 + Handles both the forward and backward pass of the final linear layer with SimPO loss. + Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
+ """ + + return LigerFusedLinearPreferenceBase.forward( + ctx, + _input, + weight, + target, + bias, + loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn, + compute_nll_loss=compute_nll_loss, + ignore_index=ignore_index, + alpha=alpha, + beta=beta, + compiled=compiled, + gamma=gamma, + ) + + @staticmethod + def backward(ctx, grad_output): + # Get gradients for _input, weight, bias, and target from the base class + grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] + # Return these gradients, followed by None for the remaining inputs + return *grads, None, None, None, None, None, None diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index 9211f98fd..b8fce9e06 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -22,11 +22,14 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, label_smoothing: float = 0.0, + simpo_gamma: float = 0.5, + loss_type: str = "sigmoid", ): super().__init__(alpha=alpha, beta=beta, ignore_index=ignore_index) # Sigmoid defaults to the CPO loss defined in the paper listed above. - self.loss_type = "sigmoid" + self.loss_type = loss_type self.label_smoothing = label_smoothing + self.simpo_gamma = simpo_gamma def alignment_loss( self, @@ -55,6 +58,12 @@ def alignment_loss( F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + F.logsigmoid(-self.beta * logits) * self.label_smoothing ) + elif self.loss_type == "simpo": + logits = logits - (self.simpo_gamma / self.beta) + losses = ( + F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + + F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) else: raise ValueError( f"Unknown loss type: {self.loss_type}. 
Should be one of ['sigmoid']" @@ -66,7 +75,6 @@ def alignment_loss( @pytest.mark.parametrize( "B, T, H, V", [ - # (1, 2, 12, 128), (8, 128, 1024, 4096), (3, 47, 31, 123), # random shape ], diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py new file mode 100644 index 000000000..727aaa56e --- /dev/null +++ b/test/chunked_loss/test_simpo_loss.py @@ -0,0 +1,78 @@ +from test.chunked_loss.test_cpo_loss import HFCPOLoss +from test.utils import assert_verbose_allclose, set_seed + +import pytest +import torch + +from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction + +# set random seed globally +set_seed() + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)] +) +def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma +): + B = 2 * B # SimPO loss requires B to be even + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + _weight = torch.randn(V, H, device="cuda", dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None 
+ bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + loss1 = HFCPOLoss( + ignore_index=ignore_index, beta=beta, simpo_gamma=gamma, loss_type="simpo" + ).get_batch_loss_metrics(input1, weight1, target, bias1) + loss2 = LigerFusedLinearSimPOFunction.apply( + input2, weight2, target, bias2, ignore_index, beta, 1.0, True, True, gamma + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index 92b15a0da..5e0bdbffb 100644 --- a/test/utils.py +++ b/test/utils.py @@ -361,6 +361,7 @@ def concatenated_forward( weight: torch.FloatTensor, target: torch.LongTensor, bias: torch.FloatTensor = None, + average_log_prob: bool = True, ) -> Tuple[ torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor ]: @@ -393,7 +394,7 @@ def cross_entropy_loss(logits, labels): all_logps = self.get_batch_logps( all_logits, target, - average_log_prob=True, + average_log_prob=average_log_prob, ) chosen_logps = all_logps[:len_chosen] @@ -417,10 +418,13 @@ def get_batch_loss_metrics( target: torch.LongTensor, bias: torch.FloatTensor = None, alpha: float = 1.0, + average_log_prob: bool = True, ): """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" - forward_output = self.concatenated_forward(_input, weight, target, bias) + forward_output = self.concatenated_forward( + _input, weight, target, bias, average_log_prob + ) ( policy_chosen_logps, policy_rejected_logps, @@ -430,6 +434,6 @@ def get_batch_loss_metrics( ) = forward_output[:5] losses = self.alignment_loss(policy_chosen_logps, policy_rejected_logps) - # full ORPO loss 
+ # full loss loss = policy_nll_loss * alpha - losses.mean() return loss From 965ee5585d879908226c36de83b994713a6f0faf Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:59:22 +0000 Subject: [PATCH 3/6] Fix merge --- src/liger_kernel/chunked_loss/fused_linear_preference.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 898d1609c..65183f837 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -71,9 +71,6 @@ def forward( compute_nll_loss=compute_nll_loss, full_target=target, **loss_kwargs, - beta=beta, - compute_nll_loss=compute_nll_loss, - full_target=target, ) def accumulate_chunk(input_chunk, target_chunk): From 3ef9bade6008a9e4943bca085f43a63e8da42fda Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:04:07 +0000 Subject: [PATCH 4/6] Fix checkstyle --- test/chunked_loss/test_cpo_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index 7d2a08586..b8fce9e06 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -31,7 +31,6 @@ def __init__( self.label_smoothing = label_smoothing self.simpo_gamma = simpo_gamma - def alignment_loss( self, policy_chosen_logps: torch.FloatTensor, From 0534bb344e0d0954d04b8b982955728dba3189f2 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:05:13 +0000 Subject: [PATCH 5/6] compile just once. 
--- src/liger_kernel/chunked_loss/fused_linear_preference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 65183f837..93cc38c8a 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -96,6 +96,9 @@ def accumulate_chunk(input_chunk, target_chunk): grad_weight.add_(chunk_grad_weight) loss_acc.add_(chunk_loss) return chunk_grad_input + + if compiled: + accumulate_chunk = torch.compile(accumulate_chunk) len_chosen = target.shape[0] // 2 _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) @@ -119,8 +122,6 @@ def accumulate_chunk(input_chunk, target_chunk): [chosen_target_chunk, rejected_target_chunk], dim=0 ) - if compiled: - accumulate_chunk = torch.compile(accumulate_chunk) grad_input = accumulate_chunk(input_chunk, target_chunk) grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]]) From 7422b7ea7c449d189e088f6e88ac3e1b66543363 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:13:45 +0000 Subject: [PATCH 6/6] fix checkstyle --- src/liger_kernel/chunked_loss/fused_linear_preference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 93cc38c8a..73981dff4 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -96,7 +96,7 @@ def accumulate_chunk(input_chunk, target_chunk): grad_weight.add_(chunk_grad_weight) loss_acc.add_(chunk_loss) return chunk_grad_input - + if compiled: accumulate_chunk = torch.compile(accumulate_chunk)