README.md
 - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128K vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
 - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single Triton kernel, with reduction done outside the kernel. It achieves a ~1.5X speedup and ~15% memory reduction for 128K vocab size.
-- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves a ~1.5X speedup and ~54% memory reduction for 128K vocab size.
-- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128K vocab size where batch size $\times$ sequence length is 8192.
+- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves a ~1.5X speedup and ~54% memory reduction for 128K vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
+- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128K vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
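
The `beta` notes added above are the behavioral contract for both kernels. For intuition, here is a minimal pure-PyTorch sketch of the generalized JSD from the linked paper, with the `beta = 0` and `beta = 1` endpoints special-cased to forward and reverse KL exactly as the notes state. The function name and the log-space mixture computation are illustrative only, not the Triton kernel's actual code.

```python
import math

import torch
import torch.nn.functional as F


def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    """Sketch of generalized JSD(beta); endpoints reduce to forward/reverse KL."""
    log_q = F.log_softmax(student_logits, dim=-1)  # student distribution Q
    log_p = F.log_softmax(teacher_logits, dim=-1)  # teacher distribution P
    if beta == 0.0:
        # Forward KL: KL(P || Q), per the NOTE above.
        return F.kl_div(log_q, log_p, log_target=True, reduction="batchmean")
    if beta == 1.0:
        # Reverse KL: KL(Q || P).
        return F.kl_div(log_p, log_q, log_target=True, reduction="batchmean")
    # Mixture M = beta * P + (1 - beta) * Q, formed stably in log space.
    log_m = torch.logsumexp(
        torch.stack([log_p + math.log(beta), log_q + math.log(1.0 - beta)]), dim=0
    )
    # JSD(beta) = beta * KL(P || M) + (1 - beta) * KL(Q || M)
    return beta * F.kl_div(log_m, log_p, log_target=True, reduction="batchmean") + (
        1.0 - beta
    ) * F.kl_div(log_m, log_q, log_target=True, reduction="batchmean")
```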
src/liger_kernel/ops/fused_linear_jsd.py

@@ -202,7 +202,7 @@ def forward(
 teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
 teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
 shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
 ignore_index (int): the index to ignore. Default: -100
 temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
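
The memory savings described in this docstring come from never materializing the full `(B*T, V)` logits tensor: hidden states are processed in blocks, and each block is projected through the head and immediately reduced to its loss contribution. Below is a hedged sketch of that chunking pattern, illustrative only; the real autograd `Function` also accumulates gradients block-wise inside this loop, and `loss_fn` stands in for any per-block loss such as the `generalized_jsd` sketch above.

```python
import torch


def chunked_head_loss(
    student_input,   # (B*T, H) student hidden states
    student_weight,  # (V, H) student lm_head weight
    teacher_input,   # (B*T, H) teacher hidden states
    teacher_weight,  # (V, H) teacher lm_head weight
    loss_fn,
    chunk_size=1024,
):
    n_rows = student_input.shape[0]
    total = student_input.new_zeros(())
    for start in range(0, n_rows, chunk_size):
        end = min(start + chunk_size, n_rows)
        # Only a (chunk, V) slice of the logits exists at any moment.
        student_logits = student_input[start:end] @ student_weight.t()
        with torch.no_grad():
            teacher_logits = teacher_input[start:end] @ teacher_weight.t()
        total = total + loss_fn(student_logits, teacher_logits) * (end - start)
    return total / n_rows
```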
src/liger_kernel/ops/jsd.py

 _input (torch.Tensor): predict values with shape (BT, V) in logspace
 target (torch.Tensor): ground truth values with shape (BT, V) in logspace
 shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
 ignore_index (int): the index to ignore. Default: -100
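
Note that `_input` and `target` above are expected already in log space. A typical way to produce them from raw logits, applying `temperature` before the log-softmax (shapes and values here are illustrative, not from the repo):

```python
import torch
import torch.nn.functional as F

BT, V, temperature = 8, 128, 1.0
student_logits = torch.randn(BT, V)
teacher_logits = torch.randn(BT, V)

student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)  # `_input`
teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)  # `target`
```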
src/liger_kernel/transformers/fused_linear_jsd.py

@@ -12,7 +12,7 @@ class LigerFusedLinearJSD(torch.nn.Module):
 the materialization of the large logits tensor.

 Args:
-jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
 ignore_index (int): The index to ignore in the target. Default: `-100`
 temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

@@ -70,9 +70,6 @@ class LigerFusedLinearJSD(torch.nn.Module):
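
For reference, a hedged usage sketch of `LigerFusedLinearJSD`. The `student_*` arguments are assumed to mirror the documented `teacher_*` ones and the sizes are arbitrary; consult the class itself for the exact forward signature.

```python
import torch

from liger_kernel.transformers import LigerFusedLinearJSD

B, T, H, V = 2, 128, 1024, 4096  # arbitrary sizes for illustration
fused_jsd = LigerFusedLinearJSD(jsd_beta=0.5, ignore_index=-100, temperature=1.0)

student_hidden = torch.randn(B * T, H, device="cuda", requires_grad=True)
student_head = torch.randn(V, H, device="cuda", requires_grad=True)
teacher_hidden = torch.randn(B * T, H, device="cuda")
teacher_head = torch.randn(V, H, device="cuda")

loss = fused_jsd(student_hidden, student_head, teacher_hidden, teacher_head)
loss.backward()  # gradients flow to the student inputs and head only
```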
src/liger_kernel/transformers/jsd.py

@@ -18,7 +18,7 @@ class LigerJSD(torch.nn.Module):
 :math:`P` denotes the teacher model and :math:`Q` denotes the student model.

 Args:
-beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
 ignore_index (int): The index to ignore in the target. Default: `-100`

 Shape:

@@ -58,9 +58,6 @@ class LigerJSD(torch.nn.Module):
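
Finally, a hedged sanity check tying `LigerJSD` to the new note: at `beta = 0.0` the loss should agree with forward KL. This assumes `forward` takes `(student_log_probs, teacher_log_probs)` in log space, matching the ops-level docstring earlier, and that the default reduction averages over the batch like `batchmean`; treat it as a sketch, not a guaranteed-exact equality.

```python
import torch
import torch.nn.functional as F

from liger_kernel.transformers import LigerJSD

BT, V = 16, 512
log_q = F.log_softmax(torch.randn(BT, V, device="cuda"), dim=-1)  # student
log_p = F.log_softmax(torch.randn(BT, V, device="cuda"), dim=-1)  # teacher

jsd = LigerJSD(beta=0.0)  # per the NOTE: forward KL, i.e. KL(P || Q)
kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

print(jsd(log_q, log_p))
print(kl(log_q, log_p))  # expected to be close to the value above
```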