
Commit 998f4e4

Generalize JSD to FKL/RKL (#393)
1 parent 2a39f0d commit 998f4e4

7 files changed, +42 −28 lines


README.md

+2 −2

@@ -256,8 +256,8 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 <!-- TODO: verify vocab sizes are accurate -->
 - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
 - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
-- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
-- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
+- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
+- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.


 ### Experimental Kernels
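For reference, a sketch of the math behind the two new NOTEs, matching the kernel branches in this diff (with $P$ the teacher and $Q$ the student distribution, as in the `LigerJSD` docstring below): the generalized JSD is

$$\mathrm{JSD}_\beta(P \Vert Q) = \beta\,\mathrm{KL}(P \Vert M) + (1-\beta)\,\mathrm{KL}(Q \Vert M), \qquad M = \beta P + (1-\beta)Q,$$

and at the endpoints the kernel computes $\mathrm{KL}(P \Vert Q)$ (forward KL) for $\beta = 0$ and $\mathrm{KL}(Q \Vert P)$ (reverse KL) for $\beta = 1$ directly, since the mixture formula itself degenerates to zero there.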

src/liger_kernel/ops/fused_linear_jsd.py

+1 −1

@@ -202,7 +202,7 @@ def forward(
         teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
         teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
         shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-        jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): the index to ignore. Default: -100
         temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

src/liger_kernel/ops/jsd.py

+19 −10

@@ -18,7 +18,7 @@ def _jsd_kernel(
     dX_ptr,
     dX_stride,
     label_ptr,
-    beta,
+    beta: tl.constexpr,
     n_non_ignore: int,
     ignore_index: tl.constexpr,
     n_cols,
@@ -50,17 +50,26 @@ def _jsd_kernel(
     X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

-    Q = tl.exp(X)
-    P = tl.exp(Y)
-    M = beta * P + (1 - beta) * Q
-    log_M = tl.log(M)
+    if beta == 0.0:  # forward KL
+        Y_prob = tl.exp(Y)
+        loss = Y_prob * (Y - X)
+        dX = -Y_prob
+    elif beta == 1.0:
+        X_prob = tl.exp(X)
+        loss = X_prob * (X - Y)
+        dX = loss + X_prob
+    else:
+        Q = tl.exp(X)
+        P = tl.exp(Y)
+        M = beta * P + (1 - beta) * Q
+        log_M = tl.log(M)
+
+        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
+        dX = (1 - beta) * Q * (X - log_M)

-    loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-    # reduction == "batchmean"
     loss = loss / n_non_ignore
+    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
-
-    dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
     tl.store(dX_ptr + offsets, dX, mask=mask)


@@ -142,7 +151,7 @@ def forward(
         _input (torch.Tensor): predict values with shape (BT, V) in logspace
         target (torch.Tensor): ground truth values with shape (BT, V) in logspace
         shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-        beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+        beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): the index to ignore. Default: -100

     Returns:
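To make the new kernel branches easier to check, here is a minimal unfused PyTorch sketch of the same per-element loss. The function name `jsd_loss_reference` and its arguments are illustrative, not part of the repo; `X`/`Y` follow the kernel's convention of student/teacher log-probabilities.

```python
import torch

def jsd_loss_reference(X: torch.Tensor, Y: torch.Tensor, beta: float, n_non_ignore: int) -> torch.Tensor:
    """Unfused sketch of the per-element loss computed by _jsd_kernel.

    X: student log-probs with shape (BT, V); Y: teacher log-probs, same shape.
    """
    if beta == 0.0:  # forward KL: KL(P || Q), with P = exp(Y)
        P = Y.exp()
        loss = P * (Y - X)
    elif beta == 1.0:  # reverse KL: KL(Q || P), with Q = exp(X)
        Q = X.exp()
        loss = Q * (X - Y)
    else:  # generalized JSD with mixture M = beta * P + (1 - beta) * Q
        Q, P = X.exp(), Y.exp()
        M = beta * P + (1 - beta) * Q
        loss = beta * P * Y + (1 - beta) * Q * X - M * M.log()
    # "batchmean"-style reduction, matching the kernel's division by n_non_ignore
    return loss.sum() / n_non_ignore
```

The `dX` expressions in the kernel are the analytic derivatives of these per-element losses with respect to `X`, divided by the same `n_non_ignore` count.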

src/liger_kernel/transformers/fused_linear_jsd.py

+1 −4

@@ -12,7 +12,7 @@ class LigerFusedLinearJSD(torch.nn.Module):
     the materialization of the large logits tensor.

     Args:
-        jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): The index to ignore in the target. Default: `-100`
         temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

@@ -70,9 +70,6 @@ class LigerFusedLinearJSD(torch.nn.Module):

     def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
         super().__init__()
-        assert (
-            jsd_beta > 0 and jsd_beta < 1
-        ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}"
         assert temperature != 0, "temperature cannot be 0."
         self.jsd_beta = jsd_beta
         self.temperature = temperature
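With the `(0, 1)` assert removed, `jsd_beta=0.0` (forward KL) and `jsd_beta=1.0` (reverse KL) are now accepted, so the fused module can serve directly as a fused KL distillation loss. A minimal sketch with illustrative shapes; the forward argument order `(student_input, student_weight, teacher_input, teacher_weight)` is an assumption suggested by the ops-level docstring above, not shown in this diff.

```python
import torch
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD

B, T, H, V = 2, 128, 512, 32000  # illustrative sizes
student_hidden = torch.randn(B * T, H, device="cuda", requires_grad=True)
teacher_hidden = torch.randn(B * T, H, device="cuda")
student_head = torch.randn(V, H, device="cuda", requires_grad=True)
teacher_head = torch.randn(V, H, device="cuda")

# jsd_beta=0.0 -> forward KL; jsd_beta=1.0 -> reverse KL (both previously rejected by the assert)
fused_fkl = LigerFusedLinearJSD(jsd_beta=0.0, temperature=1.0)
loss = fused_fkl(student_hidden, student_head, teacher_hidden, teacher_head)
loss.backward()
```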

src/liger_kernel/transformers/jsd.py

+1 −4

@@ -18,7 +18,7 @@ class LigerJSD(torch.nn.Module):
     :math:`P` denotes the teacher model and :math:`Q` denotes the student model.

     Args:
-        beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): The index to ignore in the target. Default: `-100`

     Shape:
@@ -58,9 +58,6 @@ class LigerJSD(torch.nn.Module):

     def __init__(self, beta: float = 0.5, ignore_index: int = -100):
         super().__init__()
-        assert (
-            beta > 0 and beta < 1
-        ), f"beta must be greater than 0 and less than 1. Got: {beta}"
         self.beta = beta
         self.ignore_index = ignore_index
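Likewise for the non-fused module: with `beta=0.0` and `beta=1.0` now allowed, `LigerJSD` doubles as a fused forward/reverse KL. A minimal sketch, assuming the call order `(student_log_probs, teacher_log_probs)` used by the reference `forward(log_q, log_p, ...)` in the test below:

```python
import torch
import torch.nn.functional as F
from liger_kernel.transformers.jsd import LigerJSD

BT, V = 16, 4096  # illustrative sizes
student_log_probs = F.log_softmax(torch.randn(BT, V, device="cuda"), dim=-1).requires_grad_()
teacher_log_probs = F.log_softmax(torch.randn(BT, V, device="cuda"), dim=-1)

fkl = LigerJSD(beta=0.0)  # forward KL: KL(teacher || student)
rkl = LigerJSD(beta=1.0)  # reverse KL: KL(student || teacher)

loss = fkl(student_log_probs, teacher_log_probs)
loss.backward()
```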

test/transformers/test_fused_linear_jsd.py

+4

@@ -105,6 +105,8 @@ def forward(self, student_input, teacher_input, label=None):
     [
         (1.0, 0.5),
         (2.0, 0.1),
+        (1.0, 0.0),  # FKL
+        (1.0, 1.0),  # RKL
     ],
 )
 def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol):
@@ -177,7 +179,9 @@ def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol):
     "temperature, beta, ignore_index",
     [
         (1.0, 0.5, 2),
+        (1.0, 0.0, 2),
         (2.0, 0.1, 42),
+        (1.0, 1.0, 2),
     ],
 )
 def test_correctness_with_ignore_index(

test/transformers/test_jsd.py

+14 −7

@@ -30,12 +30,19 @@ def forward(
         log_p: torch.Tensor,  # target
         label: Optional[torch.Tensor] = None,
     ):
-        log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
-        log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
-        m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
-        loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
-            1 - self.beta
-        ) * self.kl(torch.log(m), log_q).sum(dim=-1)
+        if self.beta == 0.0:
+            loss = self.kl(log_q, log_p).sum(dim=-1)
+        elif self.beta == 1.0:
+            loss = self.kl(log_p, log_q).sum(dim=-1)
+        else:
+            log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
+            log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(
+                -1, log_q.size(-1)
+            )
+            m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
+            loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
+                1 - self.beta
+            ) * self.kl(torch.log(m), log_q).sum(dim=-1)

         if label is not None:
             loss = torch.where(label != self.ignore_index, loss, 0.0)
@@ -251,7 +258,7 @@ def test_correctness_not_last(B, T, V, dtype, atol, rtol):

 @pytest.mark.parametrize(*_SHAPE_PARAMS)
 @pytest.mark.parametrize(*_DTYPE_PARAMS)
-@pytest.mark.parametrize("beta", [0.1, 0.5, 0.9])
+@pytest.mark.parametrize("beta", [0.0, 0.1, 0.5, 0.9, 1.0])
 def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol):
     liger_jsd = LigerJSD(beta=beta)
     _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol)
