Skip to content

Commit 2a39f0d

Browse files
authored
Add nn.Module support for chunked loss functions (#402)
## Summary
<!--- This is a required section; please describe the main purpose of this proposed code change. --->
Adds `torch.nn.Module` wrappers (`LigerFusedLinearCPOLoss`, `LigerFusedLinearDPOLoss`, `LigerFusedLinearORPOLoss`, `LigerFusedLinearSimPOLoss`) around the chunked fused-linear preference loss functions, and exposes functional aliases for each function's `.apply`.

<!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? --->

## Testing Done
<!--- This is a required section; please describe how this change was tested. --->
<!-- Replace BLANK with your device type. For example, A100-80G-PCIe. Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. -->
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 81d98ea commit 2a39f0d

12 files changed

+686
-38
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2+
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3+
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
4+
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401

src/liger_kernel/chunked_loss/cpo_loss.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
import torch.nn.functional as F
23

34
from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -46,10 +47,10 @@ def forward(
4647
target,
4748
bias,
4849
loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
49-
compute_nll_loss=compute_nll_loss,
5050
ignore_index=ignore_index,
5151
alpha=alpha,
5252
beta=beta,
53+
compute_nll_loss=compute_nll_loss,
5354
compiled=compiled,
5455
)
5556

@@ -59,3 +60,42 @@ def backward(ctx, grad_output):
5960
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
6061
# Return these gradients, followed by None for the remaining inputs
6162
return *grads, None, None, None, None, None
63+
64+
65+
class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO loss.

    ``torch.nn.Module`` wrapper that stores the loss hyperparameters at
    construction time and forwards them, together with the per-call
    arguments, to ``LigerFusedLinearCPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): CPO ``beta`` hyperparameter, forwarded unchanged to
                ``LigerFusedLinearCPOFunction`` (presumably the temperature
                applied to the chosen-vs-rejected log-probability gap --
                TODO confirm against ``preference_loss_fn``).
            alpha (float): ``alpha`` hyperparameter, forwarded unchanged
                (presumably the weight of the NLL term -- TODO confirm).
            compute_nll_loss (bool): Forwarded unchanged as
                ``compute_nll_loss`` (presumably toggles the NLL term --
                TODO confirm).
            compiled (bool): Forwarded unchanged as ``compiled`` (presumably
                selects a ``torch.compile``-d chunk computation -- TODO
                confirm).
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def extra_repr(self) -> str:
        """Show the stored hyperparameters in ``repr(module)`` / ``print(model)``."""
        return (
            f"ignore_index={self.ignore_index}, beta={self.beta}, "
            f"alpha={self.alpha}, compute_nll_loss={self.compute_nll_loss}, "
            f"compiled={self.compiled}"
        )

    def forward(self, lin_weight, _input, target, bias=None):
        """
        Apply the fused linear + CPO loss.

        Note the argument order: this method takes the linear weight first,
        while ``LigerFusedLinearCPOFunction.apply`` is called with the input
        first, so the two are swapped below.

        Args:
            lin_weight: Weight of the linear projection.
            _input: Input passed to the fused function.
            target: Target passed to the fused function.
            bias: Optional bias of the linear projection (default ``None``).

        Returns:
            Whatever ``LigerFusedLinearCPOFunction.apply`` returns.
        """
        return LigerFusedLinearCPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
        )

src/liger_kernel/chunked_loss/dpo_loss.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
import torch.nn.functional as F
23

34
from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -43,9 +44,9 @@ def forward(
4344
target=target,
4445
bias=bias,
4546
loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
46-
compute_nll_loss=compute_nll_loss,
4747
ignore_index=ignore_index,
4848
beta=beta,
49+
compute_nll_loss=compute_nll_loss,
4950
compiled=compiled,
5051
)
5152

@@ -55,3 +56,39 @@ def backward(ctx, grad_output):
5556
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
5657
# Return these gradients, followed by None for the remaining inputs
5758
return *grads, None, None, None, None
59+
60+
61+
class LigerFusedLinearDPOLoss(torch.nn.Module):
    """
    Fused linear layer with DPO loss.

    ``torch.nn.Module`` wrapper that stores the loss hyperparameters at
    construction time and forwards them, together with the per-call
    arguments, to ``LigerFusedLinearDPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): DPO ``beta`` hyperparameter, forwarded unchanged to
                ``LigerFusedLinearDPOFunction`` (presumably the temperature
                applied to the chosen-vs-rejected log-probability gap --
                TODO confirm against ``preference_loss_fn``).
            compute_nll_loss (bool): Forwarded unchanged as
                ``compute_nll_loss`` (presumably toggles the NLL term --
                TODO confirm).
            compiled (bool): Forwarded unchanged as ``compiled`` (presumably
                selects a ``torch.compile``-d chunk computation -- TODO
                confirm).
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def extra_repr(self) -> str:
        """Show the stored hyperparameters in ``repr(module)`` / ``print(model)``."""
        return (
            f"ignore_index={self.ignore_index}, beta={self.beta}, "
            f"compute_nll_loss={self.compute_nll_loss}, "
            f"compiled={self.compiled}"
        )

    def forward(self, lin_weight, _input, target, bias=None):
        """
        Apply the fused linear + DPO loss.

        Note the argument order: this method takes the linear weight first,
        while ``LigerFusedLinearDPOFunction.apply`` is called with the input
        first, so the two are swapped below.

        Args:
            lin_weight: Weight of the linear projection.
            _input: Input passed to the fused function.
            target: Target passed to the fused function.
            bias: Optional bias of the linear projection (default ``None``).

        Returns:
            Whatever ``LigerFusedLinearDPOFunction.apply`` returns.
        """
        return LigerFusedLinearDPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
        )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2+
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3+
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
4+
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
5+
6+
# Functional-style entry points: each alias below is bound to the `.apply`
# callable of the corresponding fused-linear loss Function imported above,
# so callers can invoke the loss without instantiating a module.
liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
7+
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
8+
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
9+
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply

src/liger_kernel/chunked_loss/fused_linear_preference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ def forward(
2727
bias=None,
2828
loss_fn=None,
2929
chunk_size=1,
30-
compute_nll_loss=True,
3130
ignore_index=-100,
3231
alpha=1.0,
3332
beta=0.1,
33+
compute_nll_loss=True,
3434
compiled=True,
3535
**loss_kwargs,
3636
):

src/liger_kernel/chunked_loss/orpo_loss.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def forward(
3434
ignore_index=-100,
3535
beta=0.1,
3636
compute_nll_loss=True,
37-
compiled=False,
37+
compiled=True,
3838
):
3939
"""
4040
Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
@@ -49,9 +49,9 @@ def forward(
4949
target=target,
5050
bias=bias,
5151
loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
52-
compute_nll_loss=compute_nll_loss,
5352
ignore_index=ignore_index,
5453
beta=beta,
54+
compute_nll_loss=compute_nll_loss,
5555
compiled=compiled,
5656
)
5757

@@ -61,3 +61,39 @@ def backward(ctx, grad_output):
6161
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
6262
# Return these gradients, followed by None for the remaining inputs
6363
return *grads, None, None, None, None
64+
65+
66+
class LigerFusedLinearORPOLoss(torch.nn.Module):
    """
    Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.

    ``torch.nn.Module`` wrapper that stores the loss hyperparameters at
    construction time and forwards them, together with the per-call
    arguments, to ``LigerFusedLinearORPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Forwarded unchanged as
                ``compute_nll_loss`` (presumably toggles the NLL term --
                TODO confirm).
            compiled (bool): Forwarded unchanged as ``compiled`` (presumably
                selects a ``torch.compile``-d chunk computation -- TODO
                confirm).
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def extra_repr(self) -> str:
        """Show the stored hyperparameters in ``repr(module)`` / ``print(model)``."""
        return (
            f"ignore_index={self.ignore_index}, beta={self.beta}, "
            f"compute_nll_loss={self.compute_nll_loss}, "
            f"compiled={self.compiled}"
        )

    def forward(self, lin_weight, _input, target, bias=None):
        """
        Apply the fused linear + ORPO loss.

        Note the argument order: this method takes the linear weight first,
        while ``LigerFusedLinearORPOFunction.apply`` is called with the input
        first, so the two are swapped below.

        Args:
            lin_weight: Weight of the linear projection.
            _input: Input passed to the fused function.
            target: Target passed to the fused function.
            bias: Optional bias of the linear projection (default ``None``).

        Returns:
            Whatever ``LigerFusedLinearORPOFunction.apply`` returns.
        """
        return LigerFusedLinearORPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
        )

src/liger_kernel/chunked_loss/simpo_loss.py

+43
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
import torch.nn.functional as F
23

34
from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -62,3 +63,45 @@ def backward(ctx, grad_output):
6263
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
6364
# Return these gradients, followed by None for the remaining inputs
6465
return *grads, None, None, None, None, None, None
66+
67+
68+
class LigerFusedLinearSimPOLoss(torch.nn.Module):
    """
    Fused linear layer with SimPO loss.

    ``torch.nn.Module`` wrapper that stores the loss hyperparameters at
    construction time and forwards them, together with the per-call
    arguments, to ``LigerFusedLinearSimPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        gamma: float = 0.5,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): SimPO ``beta`` hyperparameter, forwarded unchanged
                to ``LigerFusedLinearSimPOFunction`` (presumably the scaling
                applied to the chosen-vs-rejected log-probability gap --
                TODO confirm against ``preference_loss_fn``).
            alpha (float): ``alpha`` hyperparameter, forwarded unchanged
                (presumably the weight of the NLL term -- TODO confirm).
            compute_nll_loss (bool): Forwarded unchanged as
                ``compute_nll_loss`` (presumably toggles the NLL term --
                TODO confirm).
            compiled (bool): Forwarded unchanged as ``compiled`` (presumably
                selects a ``torch.compile``-d chunk computation -- TODO
                confirm).
            gamma (float): SimPO ``gamma`` hyperparameter, forwarded
                unchanged (presumably the target reward margin -- TODO
                confirm).
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.gamma = gamma

    def extra_repr(self) -> str:
        """Show the stored hyperparameters in ``repr(module)`` / ``print(model)``."""
        return (
            f"ignore_index={self.ignore_index}, beta={self.beta}, "
            f"alpha={self.alpha}, compute_nll_loss={self.compute_nll_loss}, "
            f"compiled={self.compiled}, gamma={self.gamma}"
        )

    def forward(self, lin_weight, _input, target, bias=None):
        """
        Apply the fused linear + SimPO loss.

        Note the argument order: this method takes the linear weight first,
        while ``LigerFusedLinearSimPOFunction.apply`` is called with the
        input first, so the two are swapped below.

        Args:
            lin_weight: Weight of the linear projection.
            _input: Input passed to the fused function.
            target: Target passed to the fused function.
            bias: Optional bias of the linear projection (default ``None``).

        Returns:
            Whatever ``LigerFusedLinearSimPOFunction.apply`` returns.
        """
        return LigerFusedLinearSimPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
            self.gamma,
        )

0 commit comments

Comments
 (0)