
Commit 317ff43

Enable keyword arguments for liger functional (#400)
## Summary

This PR enables keyword arguments for the liger functional API (#368).

## Details

1. Wrap the Liger operator functions (`torch.autograd.Function`) with an extra layer that accepts keyword arguments.
2. For each liger function, update its unit test `test_{operator_name}.py::test_correctness_functional` to check that keyword arguments are accepted.

## Testing Done

- Hardware Type: A10G
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Hongpeng Guo <[email protected]>
Co-authored-by: Byron Hsu <[email protected]>
1 parent 998f4e4 commit 317ff43
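For context, `torch.autograd.Function.apply` only accepts positional arguments, which is why the functional entry points could not take keyword arguments before this change. The commit wraps each operator in a thin plain function that names the parameters and forwards them positionally to `apply`. Below is a minimal sketch of that pattern; the `MyMulFunction` and `my_mul` names are hypothetical stand-ins for illustration, not part of the commit.

```python
import torch


class MyMulFunction(torch.autograd.Function):
    """Stand-in for a Liger operator; `apply` only takes positional args."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        return grad_output * b, grad_output * a


def my_mul(a, b):
    # Thin functional wrapper: callers can now write my_mul(a=x, b=y);
    # the arguments are still forwarded positionally to `apply`.
    return MyMulFunction.apply(a, b)


x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
out = my_mul(a=x, b=y)  # keyword arguments now work
out.sum().backward()
```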

10 files changed, +154 -28 lines changed

dev/modal/tests.py

+1-1
@@ -14,7 +14,7 @@
 repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")


-@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10)
+@app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
 def liger_tests():
     import subprocess


src/liger_kernel/transformers/functional.py

+127-12
@@ -15,18 +15,6 @@
 from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction

-liger_swiglu = LigerSiLUMulFunction.apply
-liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
-liger_geglu = LigerGELUMulFunction.apply
-liger_rms_norm = LigerRMSNormFunction.apply
-liger_rope = LigerRopeFunction.apply
-liger_qwen2vl_mrope = LigerQwen2VLMRopeFunction.apply
-liger_layer_norm = LigerLayerNormFunction.apply
-liger_kl_div = LigerKLDivLossFunction.apply
-liger_jsd = LigerJSDFunction.apply
-liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
-liger_group_norm = LigerGroupNormFunction.apply
-

 # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
 # `weight` and `size_average` are placeholders and not implemented yet
@@ -56,3 +44,130 @@ def liger_cross_entropy(
     if not return_z_loss:
         return loss
     return loss, z_loss
+
+
+def liger_fused_linear_cross_entropy(
+    input,
+    weight,
+    target,
+    bias=None,
+    ignore_index: int = -100,
+    lse_square_scale: float = 0.0,
+    label_smoothing: float = 0.0,
+    reduction: str = "mean",
+    softcap: Optional[float] = None,
+):
+    return LigerFusedLinearCrossEntropyFunction.apply(
+        input,
+        weight,
+        target,
+        bias,
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap,
+    )
+
+
+def liger_fused_linear_jsd(
+    student_input,
+    student_weight,
+    teacher_input,
+    teacher_weight,
+    shift_labels=None,
+    jsd_beta: float = 0.5,
+    ignore_index: int = -100,
+    temperature: float = 1.0,
+):
+    return LigerFusedLinearJSDFunction.apply(
+        student_input,
+        student_weight,
+        teacher_input,
+        teacher_weight,
+        shift_labels,
+        jsd_beta,
+        ignore_index,
+        temperature,
+    )
+
+
+def liger_geglu(a, b):
+    return LigerGELUMulFunction.apply(a, b)
+
+
+def liger_group_norm(
+    X,
+    affine_scaling_weight,
+    affine_shifting_bias,
+    num_channels,
+    num_groups,
+    eps,
+):
+    return LigerGroupNormFunction.apply(
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    )
+
+
+def liger_jsd(
+    input,
+    target,
+    shift_labels=None,
+    beta: float = 0.5,
+    ignore_index: int = -100,
+):
+    return LigerJSDFunction.apply(
+        input,
+        target,
+        shift_labels,
+        beta,
+        ignore_index,
+    )
+
+
+# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
+# `size_average` and `mean` are being deprecated in torch API and are placeholders here
+def liger_kl_div(
+    input,
+    target,
+    size_average: bool = True,
+    reduce: bool = True,
+    reduction: str = "mean",
+    log_target: bool = False,
+    eps: float = 1e-10,
+):
+    # Note: the default reduction in torch is `mean`, but being `batchmean` in Liger
+    return LigerKLDivLossFunction.apply(
+        input,
+        target,
+        reduction,
+        log_target,
+        eps,
+    )
+
+
+def liger_layer_norm(X, W, B, eps):
+    return LigerLayerNormFunction.apply(X, W, B, eps)
+
+
+def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
+
+
+def liger_rms_norm(
+    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
+):
+    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
+
+
+def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_swiglu(a, b):
+    return LigerSiLUMulFunction.apply(a, b)
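With these wrappers in place, the functional API can be called with keyword arguments, mirroring what the updated tests below exercise. An illustrative call (shapes are arbitrary, and a CUDA device is assumed since the underlying kernels are Triton-based):

```python
import torch
from liger_kernel.transformers.functional import liger_swiglu

a = torch.randn(2, 128, 256, device="cuda", requires_grad=True)
b = torch.randn(2, 128, 256, device="cuda", requires_grad=True)

# Keyword and positional calls are now equivalent.
out = liger_swiglu(a=a, b=b)
out.sum().backward()
```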

test/transformers/test_fused_linear_cross_entropy.py

+6-1
@@ -244,7 +244,12 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
     weight = torch.randn(V, H, device=device, dtype=dtype)
     bias = torch.randn(V, device=device, dtype=dtype) if bias else None

-    y1 = liger_fused_linear_cross_entropy(x1, weight, target, bias)
+    y1 = liger_fused_linear_cross_entropy(
+        input=x1,
+        weight=weight,
+        target=target,
+        bias=bias,
+    )
     y2 = LigerFusedLinearCrossEntropyFunction.apply(x2, weight, target, bias)

     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

test/transformers/test_fused_linear_jsd.py

+8-8
@@ -296,14 +296,14 @@ def test_correctness_functional(
     label[indices_to_assign] = ignore_index

     output1 = liger_fused_linear_jsd(
-        _input1,
-        _weight1,
-        teacher_input,
-        teacher_weight,
-        label,
-        beta,
-        ignore_index,
-        temperature,
+        student_input=_input1,
+        student_weight=_weight1,
+        teacher_input=teacher_input,
+        teacher_weight=teacher_weight,
+        shift_labels=label,
+        jsd_beta=beta,
+        ignore_index=ignore_index,
+        temperature=temperature,
     )
     output2 = LigerFusedLinearJSDFunction.apply(
         _input2,

test/transformers/test_geglu.py

+1-1
@@ -130,7 +130,7 @@ def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
     b1 = _b.clone().requires_grad_(True)
     b2 = _b.clone().requires_grad_(True)

-    y1 = liger_geglu(x1, b1)
+    y1 = liger_geglu(a=x1, b=b1)
     y2 = LigerGELUMulFunction.apply(x2, b2)

     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

test/transformers/test_jsd.py

+7-1
@@ -229,7 +229,13 @@ def _test_correctness_functional(
     label[indices_to_assign] = ignore_index

     output = LigerJSDFunction.apply(x1, target, label, beta, ignore_index)
-    output2 = liger_jsd(x2, target, label, beta, ignore_index)
+    output2 = liger_jsd(
+        input=x2,
+        target=target,
+        shift_labels=label,
+        beta=beta,
+        ignore_index=ignore_index,
+    )
     assert torch.allclose(output, output2, atol=atol, rtol=rtol)
     if (
         not is_last_layer

test/transformers/test_layer_norm.py

+1-1
@@ -83,7 +83,7 @@ def test_liger_layer_norm_functional(
     b1 = b.clone().requires_grad_(True)
     b2 = b.clone().requires_grad_(True)

-    y1 = liger_layer_norm(x1, w1, b1, 1e-6)
+    y1 = liger_layer_norm(X=x1, W=w1, B=b1, eps=1e-6)
     y2 = LigerLayerNormFunction.apply(x2, w2, b2, 1e-6)

     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

test/transformers/test_rms_norm.py

+1-1
@@ -182,7 +182,7 @@ def test_correctness_functional(

     w = torch.randn(hd, device=device, dtype=dtype)

-    y1 = liger_rms_norm(h1, w, 1e-6, offset, casting_mode)
+    y1 = liger_rms_norm(X=h1, W=w, eps=1e-6, offset=offset, casting_mode=casting_mode)
     y2 = LigerRMSNormFunction.apply(h2, w, 1e-6, offset, casting_mode)

     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

test/transformers/test_rope.py

+1-1
@@ -125,7 +125,7 @@ def test_functional_correctness(
     pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0)
     cos, sin = rotary_emb(k1, pos_ids)

-    functional_q, functional_k = liger_rope(q1, k1, cos, sin)
+    functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin)
     class_q, class_k = LigerRopeFunction.apply(q2, k2, cos, sin)

     assert torch.allclose(functional_q, class_q, atol=atol, rtol=rtol)

test/transformers/test_swiglu.py

+1-1
@@ -202,7 +202,7 @@ def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
     b1 = _b.clone().requires_grad_(True)
     b2 = _b.clone().requires_grad_(True)

-    y1 = liger_swiglu(x1, b1)
+    y1 = liger_swiglu(a=x1, b=b1)
     y2 = LigerSiLUMulFunction.apply(x2, b2)

     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
