Commit d8e7c7d

Fix fused_rms_norm
1 parent d0d2759 commit d8e7c7d

torchtitan/distributed/pipeline_parallel.py

Lines changed: 35 additions & 27 deletions
@@ -152,39 +152,46 @@ def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
 class RmsNormSeparateWeightGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, normalized_shape, weight, eps, real_output, rstd):
-        ctx.save_for_backward(input, rstd)
+        ctx.save_for_backward(input, weight, rstd)
+        ctx.normalized_shape = normalized_shape
         return real_output
 
     @staticmethod
     def backward(ctx, grad_output):
-        (input, rstd) = ctx.saved_tensors
-        normalized = input * rstd
-        # Gradient w.r.t. weight: sum over batch dimension
-        grad_weight = (grad_output * normalized).sum(
-            dim=tuple(range(grad_output.ndim - 1))
+        input, weight, rstd = ctx.saved_tensors
+        # Call _fused_rms_norm_backward with output_mask=[False, True]
+        # We only want gradient w.r.t. weight (index 1)
+        _, grad_weight = torch._fused_rms_norm_backward(
+            grad_output,
+            input,
+            ctx.normalized_shape,
+            rstd,
+            weight,
+            output_mask=[False, True],
         )
-        return None, None, grad_weight, None, None
+        return None, None, grad_weight, None, None, None
 
 class RmsNormSeparateInputGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, normalized_shape, weight, eps, real_output, rstd):
-        ctx.save_for_backward(weight, rstd)
+        ctx.save_for_backward(input, weight, rstd)
+        ctx.normalized_shape = normalized_shape
         return real_output
 
     @staticmethod
     def backward(ctx, grad_output):
-        weight, rstd = ctx.saved_tensors
-
-        # Gradient w.r.t. input
-        if weight is not None:
-            grad_input_unnorm = grad_output * weight
-        else:
-            grad_input_unnorm = grad_output
-
-        mean = (grad_input_unnorm * input).mean(-1, keepdim=True)
-        grad_input = (grad_input_unnorm - input * mean * rstd.pow(2)) * rstd
-
-        return grad_input, None, None, None, None
+        input, weight, rstd = ctx.saved_tensors
+        # Call _fused_rms_norm_backward with output_mask=[True, False]
+        # We only want gradient w.r.t. input (index 0)
+        grad_input, _ = torch._fused_rms_norm_backward(
+            grad_output,
+            input,
+            ctx.normalized_shape,
+            rstd,
+            weight,
+            output_mask=[True, False],
+        )
+        return grad_input, None, None, None, None, None
 
 class RmsNormPassThrough(torch.autograd.Function):
     @staticmethod
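For cross-checking, the hand-written backward removed above computes the standard RMSNorm gradients directly. Below is a minimal reference sketch (a hypothetical helper, not part of this commit) of the formulas that torch._fused_rms_norm_backward is expected to reproduce, assuming weight applies over the last dimension and rstd = 1/sqrt(mean(x**2, dim=-1) + eps):

import torch

def rms_norm_backward_reference(grad_output, input, weight, rstd):
    # Gradient w.r.t. weight: reduce the per-element products over all
    # leading (batch) dimensions.
    normalized = input * rstd
    grad_weight = (grad_output * normalized).sum(
        dim=tuple(range(grad_output.ndim - 1))
    )
    # Gradient w.r.t. input: g*w*rstd - x * rstd**3 * mean(g*w*x, dim=-1)
    gw = grad_output * weight if weight is not None else grad_output
    mean = (gw * input).mean(-1, keepdim=True)
    grad_input = (gw - input * mean * rstd.pow(2)) * rstd
    return grad_input, grad_weight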
@@ -198,16 +205,16 @@ def backward(ctx, gO):
         return None, gO, gO
 
 def split_rms_norm(input, normalized_shape, weight=None, eps=None):
-    # Compute the actual output first
+    # Compute the actual output using _fused_rms_norm which returns (output, rstd)
     with torch._C._AutoDispatchBelowAutograd():
-        real_output = torch.rms_norm(
+        real_output, rstd = torch._fused_rms_norm(
             input.detach(),
             normalized_shape,
             weight.detach() if weight is not None else None,
             eps,
-        ).detach()
-        variance = input.pow(2).mean(-1, keepdim=True)
-        rstd = torch.rsqrt(variance + eps).detach()
+        )
+    real_output = real_output.detach()
+    rstd = rstd.detach()
     rstd2 = rstd.clone().detach()
 
     weight_1 = RmsNormSeparateWeightGrad.apply(
@@ -219,7 +226,7 @@ def split_rms_norm(input, normalized_shape, weight=None, eps=None):
         weight.detach() if weight is not None else None,
         eps,
         real_output,
-        rstd2
+        rstd2,
     )
     return RmsNormPassThrough.apply(real_output, weight_1, input_1)
 
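As a sanity check for split_rms_norm, one can compare its output and gradients against plain torch.rms_norm under autograd. A sketch, not part of the commit: it assumes a PyTorch build that exposes torch.rms_norm and the private _fused_rms_norm ops used above, and that split_rms_norm is importable from this module.

import torch
from torchtitan.distributed.pipeline_parallel import split_rms_norm

dim = 64
x_ref = torch.randn(4, 16, dim, requires_grad=True)
w_ref = torch.randn(dim, requires_grad=True)
x_split = x_ref.detach().clone().requires_grad_(True)
w_split = w_ref.detach().clone().requires_grad_(True)

out_ref = torch.rms_norm(x_ref, [dim], w_ref, eps=1e-6)
out_split = split_rms_norm(x_split, [dim], w_split, eps=1e-6)

# Backpropagate the same cotangent through both paths and compare.
grad = torch.randn_like(out_ref)
out_ref.backward(grad)
out_split.backward(grad)

torch.testing.assert_close(out_split, out_ref)
torch.testing.assert_close(x_split.grad, x_ref.grad)
torch.testing.assert_close(w_split.grad, w_ref.grad)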
@@ -280,10 +287,11 @@ def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
 
 lib.impl("mm", split_mm, "Autograd")
 lib.impl("addmm", split_addmm, "Autograd")
-lib.impl("rms_norm", split_rms_norm, "Autograd")
+# lib.impl("_fused_rms_norm", split_rms_norm, "Autograd")
 lib.impl("_grouped_mm", split_grouped_mm, "Autograd")
 torch.autograd.set_detect_anomaly(True, check_nan=False)
 
+
 def pipeline_llm(
     model: nn.Module,
     parallel_dims: ParallelDims,
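For context on the registrations above: the lib handle is created earlier in pipeline_parallel.py and is not shown in this diff; presumably it is a torch.library.Library handle for the aten namespace, so that lib.impl(name, fn, "Autograd") routes each op through the Python implementation at the Autograd dispatch key (this commit disables the RMSNorm override by commenting it out). A minimal sketch of that pattern, using a toy op so it does not collide with the registrations in this file:

import torch

# Hypothetical handle; the real `lib` lives elsewhere in pipeline_parallel.py.
lib = torch.library.Library("aten", "IMPL")

def relu_override(x):
    # Dispatch below the Autograd key to reach the regular kernel and avoid
    # re-entering this override; the real implementations additionally wrap
    # the result in custom autograd.Functions to control gradient flow.
    with torch._C._AutoDispatchBelowAutograd():
        return torch.relu(x)

# Route aten::relu through the Python override at the Autograd dispatch key,
# mirroring lib.impl("mm", split_mm, "Autograd") above.
lib.impl("relu", relu_override, "Autograd")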