@@ -152,39 +152,46 @@ def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
 class RmsNormSeparateWeightGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, normalized_shape, weight, eps, real_output, rstd):
-        ctx.save_for_backward(input, rstd)
+        ctx.save_for_backward(input, weight, rstd)
+        ctx.normalized_shape = normalized_shape
         return real_output
 
     @staticmethod
     def backward(ctx, grad_output):
-        (input, rstd) = ctx.saved_tensors
-        normalized = input * rstd
-        # Gradient w.r.t. weight: sum over batch dimension
-        grad_weight = (grad_output * normalized).sum(
-            dim=tuple(range(grad_output.ndim - 1))
+        input, weight, rstd = ctx.saved_tensors
+        # Call _fused_rms_norm_backward with output_mask=[False, True]
+        # We only want gradient w.r.t. weight (index 1)
+        _, grad_weight = torch._fused_rms_norm_backward(
+            grad_output,
+            input,
+            ctx.normalized_shape,
+            rstd,
+            weight,
+            output_mask=[False, True],
         )
-        return None, None, grad_weight, None, None
+        return None, None, grad_weight, None, None, None
 
 class RmsNormSeparateInputGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, normalized_shape, weight, eps, real_output, rstd):
-        ctx.save_for_backward(weight, rstd)
+        ctx.save_for_backward(input, weight, rstd)
+        ctx.normalized_shape = normalized_shape
         return real_output
 
     @staticmethod
     def backward(ctx, grad_output):
-        weight, rstd = ctx.saved_tensors
-
-        # Gradient w.r.t. input
-        if weight is not None:
-            grad_input_unnorm = grad_output * weight
-        else:
-            grad_input_unnorm = grad_output
-
-        mean = (grad_input_unnorm * input).mean(-1, keepdim=True)
-        grad_input = (grad_input_unnorm - input * mean * rstd.pow(2)) * rstd
-
-        return grad_input, None, None, None, None
+        input, weight, rstd = ctx.saved_tensors
+        # Call _fused_rms_norm_backward with output_mask=[True, False]
+        # We only want gradient w.r.t. input (index 0)
+        grad_input, _ = torch._fused_rms_norm_backward(
+            grad_output,
+            input,
+            ctx.normalized_shape,
+            rstd,
+            weight,
+            output_mask=[True, False],
+        )
+        return grad_input, None, None, None, None, None
 
 class RmsNormPassThrough(torch.autograd.Function):
     @staticmethod
@@ -198,16 +205,16 @@ def backward(ctx, gO):
         return None, gO, gO
 
 def split_rms_norm(input, normalized_shape, weight=None, eps=None):
-    # Compute the actual output first
+    # Compute the actual output using _fused_rms_norm which returns (output, rstd)
     with torch._C._AutoDispatchBelowAutograd():
-        real_output = torch.rms_norm(
+        real_output, rstd = torch._fused_rms_norm(
             input.detach(),
             normalized_shape,
             weight.detach() if weight is not None else None,
             eps,
-        ).detach()
-    variance = input.pow(2).mean(-1, keepdim=True)
-    rstd = torch.rsqrt(variance + eps).detach()
+        )
+    real_output = real_output.detach()
+    rstd = rstd.detach()
     rstd2 = rstd.clone().detach()
 
     weight_1 = RmsNormSeparateWeightGrad.apply(
@@ -219,7 +226,7 @@ def split_rms_norm(input, normalized_shape, weight=None, eps=None):
         weight.detach() if weight is not None else None,
         eps,
         real_output,
-        rstd2
+        rstd2,
     )
     return RmsNormPassThrough.apply(real_output, weight_1, input_1)
 
@@ -280,10 +287,11 @@ def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
 
 lib.impl("mm", split_mm, "Autograd")
 lib.impl("addmm", split_addmm, "Autograd")
-lib.impl("rms_norm", split_rms_norm, "Autograd")
+# lib.impl("_fused_rms_norm", split_rms_norm, "Autograd")
 lib.impl("_grouped_mm", split_grouped_mm, "Autograd")
 torch.autograd.set_detect_anomaly(True, check_nan=False)
 
+
 def pipeline_llm(
     model: nn.Module,
     parallel_dims: ParallelDims,
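
For readers following the trick in this patch: `split_rms_norm` computes the real RMSNorm output once below autograd, builds two shadow `autograd.Function` branches that each produce only one gradient (weight or input) from the shared `rstd`, and `RmsNormPassThrough` fans `grad_output` out to both branches. Below is a self-contained sketch of the same gradient-splitting pattern on a toy elementwise multiply, so it runs without the private fused RMSNorm ops; the names `GradAOnly`, `GradBOnly`, `PassThrough`, and `split_mul` are illustrative and not part of the patch.

```python
import torch


class GradAOnly(torch.autograd.Function):
    """Shadow branch: returns the precomputed output, backward only yields d/da."""

    @staticmethod
    def forward(ctx, a, b, real_output):
        ctx.save_for_backward(b)
        return real_output

    @staticmethod
    def backward(ctx, grad_output):
        (b,) = ctx.saved_tensors
        return grad_output * b, None, None


class GradBOnly(torch.autograd.Function):
    """Shadow branch: returns the precomputed output, backward only yields d/db."""

    @staticmethod
    def forward(ctx, a, b, real_output):
        ctx.save_for_backward(a)
        return real_output

    @staticmethod
    def backward(ctx, grad_output):
        (a,) = ctx.saved_tensors
        return None, grad_output * a, None


class PassThrough(torch.autograd.Function):
    """Forward returns the real output; backward fans grad_output to both branches."""

    @staticmethod
    def forward(ctx, real_output, branch_a, branch_b):
        return real_output

    @staticmethod
    def backward(ctx, gO):
        return None, gO, gO


def split_mul(a, b):
    # Compute the real result once, outside of autograd.
    real_output = (a.detach() * b.detach()).detach()
    # Each branch sees only the leaf it is responsible for.
    branch_a = GradAOnly.apply(a, b.detach(), real_output)
    branch_b = GradBOnly.apply(a.detach(), b, real_output)
    return PassThrough.apply(real_output, branch_a, branch_b)


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
split_mul(a, b).sum().backward()
# Matches the product rule: d(a*b)/da = b, d(a*b)/db = a.
print(torch.allclose(a.grad, b.detach()), torch.allclose(b.grad, a.detach()))
```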
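A hypothetical end-to-end sanity check for the rewrite itself, assuming `split_rms_norm` can be imported from the patched module and that `torch._fused_rms_norm` has a kernel for the chosen device and dtype; the shapes, tolerances, and default device are illustrative. It compares outputs and gradients against the public `torch.nn.functional.rms_norm`.

```python
import torch
import torch.nn.functional as F


def check_split_rms_norm(split_rms_norm, device="cuda", dtype=torch.float32, eps=1e-6):
    torch.manual_seed(0)
    x_ref = torch.randn(4, 16, 128, device=device, dtype=dtype, requires_grad=True)
    w_ref = torch.randn(128, device=device, dtype=dtype, requires_grad=True)
    x_split = x_ref.detach().clone().requires_grad_(True)
    w_split = w_ref.detach().clone().requires_grad_(True)

    # Reference path: the public RMSNorm functional.
    out_ref = F.rms_norm(x_ref, [128], w_ref, eps)
    # Patched path: gradients routed through the two shadow branches.
    out_split = split_rms_norm(x_split, [128], w_split, eps)

    grad_out = torch.randn_like(out_ref)
    out_ref.backward(grad_out)
    out_split.backward(grad_out)

    torch.testing.assert_close(out_split, out_ref)
    torch.testing.assert_close(x_split.grad, x_ref.grad, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(w_split.grad, w_ref.grad, rtol=1e-4, atol=1e-4)
```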