
Commit fb6bb2b

Got past RMSNorm errors, now hitting a segfault in all2allv coming from recomputing MoE during backward:

```
[rank4]: (Triggered internally at /data/users/whc/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
[rank4]:   return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank4]:Exception in thread Thread-2 (run_backward):
[rank4]:Traceback (most recent call last):
[rank4]:  File "/data/users/whc/pytorch/torch/distributed/pipelining/_backward.py", line 384, in stage_backward
[rank4]:    torch.autograd.backward(
[rank4]:  File "/data/users/whc/pytorch/torch/autograd/__init__.py", line 364, in backward
[rank4]:    _engine_run_backward(
[rank4]:  File "/data/users/whc/pytorch/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank4]:    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank4]:RuntimeError: NCCL Error 5: invalid usage (run with NCCL_DEBUG=WARN for details)
[rank4]:Exception raised from throw_nccl_error at /data/users/whc/pytorch/torch/csrc/cuda/nccl.cpp:259 (most recent call first):
[rank4]:C++ CapturedTraceback:
[rank4]:#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
[rank4]:#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
[rank4]:#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
[rank4]:#7 torch::cuda::nccl::detail::throw_nccl_error(torch::cuda::nccl::ncclResult) from ??:0
[rank4]:#8 torch::cuda::nccl::detail::NCCL_CHECK_TIMEOUT(torch::cuda::nccl::ncclResult, void*) from nccl.cpp:0
[rank4]:#9 torch::cuda::nccl::all2all_single_unequal_split(void*, unsigned long const*, unsigned long const*, void*, unsigned long const*, unsigned long const*, unsigned long, c10::ScalarType, void*, c10::cuda::CUDAStream&) from ??:0
[rank4]:#10 c10d::ProcessGroupNCCL::alltoall_base(at::Tensor&, at::Tensor&, std::vector<long, std::allocator<long> >&, std::vector<long, std::allocator<long> >&, c10d::AllToAllOptions const&) from ??:0
[rank4]:#11 c10d::ops::(anonymous namespace)::alltoall_base_CUDA(at::Tensor&, at::Tensor&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<long, std::allocator<long> >, std::vector<long, std::allocator<long> >, bool, long) from Ops.cpp:0
[rank4]:#12 c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > (*)(at::Tensor&, at::Tensor&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<long, std::allocator<long> >, std::vector<long, std::allocator<long> >, bool, long), c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> >, c10::guts::typelist::typelist<at::Tensor&, at::Tensor&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<long, std::allocator<long> >, std::vector<long, std::allocator<long> >, bool, long> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) from :0
[rank4]:#13 void c10::BoxedKernel::make_boxed_function<&torch::autograd::basicAutogradNotImplementedFallbackImpl>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) from autograd_not_implemented_fallback.cpp:0
[rank4]:#14 c10::impl::BoxedKernelWrapper<c10::intrusive_ptr<c10d::Work, c10::detail::intrusive_target_default_null_type<c10d::Work> > (at::Tensor&, at::Tensor&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<long, std::allocator<long> >, std::vector<long, std::allocator<long> >, bool, long), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor&, at::Tensor&, c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<long, std::allocator<long> >, std::vector<long, std::allocator<long> >, bool, long) from :0
[rank4]:#15 c10d::ProcessGroup::alltoall_base(at::Tensor&, at::Tensor&, std::vector<long, std::allocator<long> >&, std::vector<long, std::allocator<long> >&, c10d::AllToAllOptions const&) from :0
[rank4]:#16 c10d::all_to_all_single(at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
[rank4]:#17 c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) from :0
[rank4]:#18 void c10::BoxedKernel::make_boxed_function<&torch::autograd::basicAutogradNotImplementedFallbackImpl>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) from autograd_not_implemented_fallback.cpp:0
[rank4]:#19 c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from :0
[rank4]:#20 std::vector<at::Tensor, std::allocator<at::Tensor> > torch::autograd::CppNode_apply_functional<(anonymous namespace)::AllToAllSingle>(std::vector<at::Tensor, std::allocator<at::Tensor> >&&, torch::autograd::AutogradContext&, std::vector<bool, std::allocator<bool> > const&, std::vector<torch::autograd::VariableInfo, std::allocator<torch::autograd::VariableInfo> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from Functional.cpp:0
[rank4]:#21 torch::autograd::CppNode<(anonymous namespace)::AllToAllSingle>::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) from Functional.cpp:0
[rank4]:#22 torch::autograd::Node::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) from :0
[rank4]:#23 torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) from ??:0
[rank4]:#24 torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) from ??:0
[rank4]:#25 torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) from ??:0
[rank4]:#26 torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) from :0
[rank4]:#27 std::error_code::default_error_condition() const from ??:0
[rank4]:#28 start_thread from ??:0
[rank4]:#29 __clone3 from :0
[rank4]:
[rank4]:
[rank4]:The above exception was the direct cause of the following exception:
[rank4]:
[rank4]:Traceback (most recent call last):
[rank4]:  File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
[rank4]:    self.run()
[rank4]:  File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/threading.py", line 953, in run
[rank4]:    self._target(*self._args, **self._kwargs)
[rank4]:  File "/data/users/whc/torchtitan/torchtitan/distributed/dual_pipe_v.py", line 254, in run_backward
[rank4]:    backward_stage.backward_one_chunk(
[rank4]:  File "/data/users/whc/pytorch/torch/distributed/pipelining/stage.py", line 799, in backward_one_chunk
[rank4]:    grads_input, _ = self.backward_maybe_with_nosync(
[rank4]:  File "/data/users/whc/pytorch/torch/distributed/pipelining/stage.py", line 653, in backward_maybe_with_nosync
[rank4]:    result = perform_backward(backward_type)()
[rank4]:  File "/data/users/whc/pytorch/torch/distributed/pipelining/stage.py", line 607, in <lambda>
[rank4]:    stage_backward(
[rank4]:  File "/data/users/whc/pytorch/torch/distributed/pipelining/_backward.py", line 425, in stage_backward
[rank4]:    raise RuntimeError(exc_msg) from e
[rank4]:RuntimeError:
[rank4]: Failed to run stage backward:
[rank4]: Stage output: ('Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)',)
[rank4]: Output gradient: ('Tensor(torch.Size([1, 4096, 2048]), grad=False, dtype=torch.bfloat16)',)
[rank4]: Input: ['Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)']
```
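For reference, the failing frame is the variable-split all-to-all used for MoE token dispatch (`all2all_single_unequal_split`, reached through `ProcessGroupNCCL::alltoall_base`), re-issued while the MoE block is recomputed inside `stage_backward`. A minimal sketch of that style of collective, using the public `torch.distributed` API (the helper below is hypothetical; torchtitan's actual dispatch path may go through a functional-collective wrapper):

```python
# Hypothetical sketch of the kind of unequal-split all-to-all the trace points at.
import torch
import torch.distributed as dist


def moe_dispatch_like_all2all(tokens: torch.Tensor, send_counts: torch.Tensor, group=None):
    """All-to-all with per-rank variable split sizes, as used for MoE token dispatch.

    NCCL "invalid usage" at this call site usually means the split sizes on some
    rank do not cover the tensor, or ranks disagree on counts/communicator state,
    which can happen if the collective is replayed with stale metadata.
    """
    # Exchange how many tokens each rank will receive from every other rank.
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=group)

    in_splits = send_counts.tolist()
    out_splits = recv_counts.tolist()
    assert sum(in_splits) == tokens.shape[0], "split sizes must cover the input"

    output = tokens.new_empty((sum(out_splits), *tokens.shape[1:]))
    dist.all_to_all_single(
        output,
        tokens,
        output_split_sizes=out_splits,
        input_split_sizes=in_splits,
        group=group,
    )
    return output
```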
1 parent d8e7c7d commit fb6bb2b

1 file changed

torchtitan/distributed/pipeline_parallel.py

Lines changed: 15 additions & 15 deletions
```diff
@@ -149,7 +149,7 @@ def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
     return AddmmPassThrough.apply(bias_1, mat1_1, mat2_1, beta, alpha)
 
 # rms_norm operator: RMS normalization
-class RmsNormSeparateWeightGrad(torch.autograd.Function):
+class FusedRmsNormSeparateWeightGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, normalized_shape, weight, eps, real_output, rstd):
         ctx.save_for_backward(input, weight, rstd)
@@ -161,7 +161,7 @@ def backward(ctx, grad_output):
         input, weight, rstd = ctx.saved_tensors
         # Call _fused_rms_norm_backward with output_mask=[False, True]
         # We only want gradient w.r.t. weight (index 1)
-        _, grad_weight = torch._fused_rms_norm_backward(
+        _, grad_weight = torch.ops.aten._fused_rms_norm_backward(
             grad_output,
             input,
             ctx.normalized_shape,
@@ -171,7 +171,7 @@ def backward(ctx, grad_output):
         )
         return None, None, grad_weight, None, None, None
 
-class RmsNormSeparateInputGrad(torch.autograd.Function):
+class FusedRmsNormSeparateInputGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, normalized_shape, weight, eps, real_output, rstd):
         ctx.save_for_backward(input, weight, rstd)
@@ -183,7 +183,7 @@ def backward(ctx, grad_output):
         input, weight, rstd = ctx.saved_tensors
         # Call _fused_rms_norm_backward with output_mask=[True, False]
         # We only want gradient w.r.t. input (index 0)
-        grad_input, _ = torch._fused_rms_norm_backward(
+        grad_input, _ = torch.ops.aten._fused_rms_norm_backward(
             grad_output,
             input,
             ctx.normalized_shape,
```
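The two renamed Functions above each recompute only one of the two RMSNorm gradients, asking the fused backward for a single output via `output_mask`. The same "one gradient per branch" pattern on a toy op, where the math fits on one line, looks like this (a sketch assuming a simple elementwise scale `y = x * w` with `input` of shape `(N, D)` and `weight` of shape `(D,)`; none of these names exist in the repo):

```python
import torch


class ScaleSeparateWeightGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, real_output):
        # real_output is the result computed once elsewhere; this branch only
        # exists to produce the weight gradient during backward.
        ctx.save_for_backward(input)
        return real_output

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Only dL/dweight; the input slot deliberately gets no gradient.
        grad_weight = (grad_output * input).sum(dim=0)
        return None, grad_weight, None


class ScaleSeparateInputGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, real_output):
        ctx.save_for_backward(weight)
        return real_output

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # Only dL/dinput; the weight slot deliberately gets no gradient.
        return grad_output * weight, None, None
```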
```diff
@@ -193,18 +193,18 @@ def backward(ctx, grad_output):
         )
         return grad_input, None, None, None, None, None
 
-class RmsNormPassThrough(torch.autograd.Function):
+class FusedRmsNormPassThrough(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, real_output, fake_1, fake_2):
-        return real_output
+    def forward(ctx, real_output, real_std, fake_1, fake_2):
+        return real_output, real_std
 
     @staticmethod
-    def backward(ctx, gO):
+    def backward(ctx, gO, gStd):
         # Pass gradients to fake_1 and fake_2 to trigger their backward methods
-        # Return None for real_output since it's already detached
-        return None, gO, gO
+        # Return None for real_output/rstd since they are already detached
+        return None, None, gO, gO
 
-def split_rms_norm(input, normalized_shape, weight=None, eps=None):
+def split_fused_rms_norm(input, normalized_shape, weight=None, eps=None):
     # Compute the actual output using _fused_rms_norm which returns (output, rstd)
     with torch._C._AutoDispatchBelowAutograd():
         real_output, rstd = torch._fused_rms_norm(
```
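The pass-through signature change in this hunk follows the basic `autograd.Function` contract: backward receives one incoming gradient per forward output and must return one value per forward input, in order. Surfacing `rstd` as a second output therefore adds a `gStd` argument to backward and one more leading `None` in its return. A minimal illustration of that arity rule (generic, not the repo's code):

```python
import torch


class TwoInTwoOut(torch.autograd.Function):
    # Two forward inputs and two forward outputs: backward therefore receives
    # two incoming gradients and must return exactly two values, in the same
    # order as the forward inputs.
    @staticmethod
    def forward(ctx, a, b):
        return a * 2, b * 3

    @staticmethod
    def backward(ctx, grad_out_a, grad_out_b):
        return grad_out_a * 2, grad_out_b * 3
```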
```diff
@@ -217,18 +217,18 @@ def split_rms_norm(input, normalized_shape, weight=None, eps=None):
     rstd = rstd.detach()
     rstd2 = rstd.clone().detach()
 
-    weight_1 = RmsNormSeparateWeightGrad.apply(
+    weight_1 = FusedRmsNormSeparateWeightGrad.apply(
         input.detach(), normalized_shape, weight, eps, real_output, rstd
     )
-    input_1 = RmsNormSeparateInputGrad.apply(
+    input_1 = FusedRmsNormSeparateInputGrad.apply(
         input,
         normalized_shape,
         weight.detach() if weight is not None else None,
         eps,
         real_output,
         rstd2,
     )
-    return RmsNormPassThrough.apply(real_output, weight_1, input_1)
+    return FusedRmsNormPassThrough.apply(real_output, rstd, weight_1, input_1)
 
 # _grouped_mm operator: Grouped matrix multiplication for MoE
 class GroupedMmSeparateMat2Grad(torch.autograd.Function):
```
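The pass-through Function is what stitches the detached real output back onto both single-gradient branches, so one backward pass through the model output fires both branch backwards. Continuing the toy scale-op sketch from above (it reuses the hypothetical `ScaleSeparateWeightGrad` / `ScaleSeparateInputGrad` defined there; again illustrative, not the repo's code):

```python
import torch


class ScalePassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, real_output, fake_1, fake_2):
        # fake_1/fake_2 are only inputs so that their grad_fns join this
        # output's history; the value returned is the detached real result.
        return real_output

    @staticmethod
    def backward(ctx, gO):
        # No grad for the detached real_output; route gO into both branches.
        return None, gO, gO


def split_scale(input, weight):
    # Real result, computed once outside autograd (the repo does this under
    # torch._C._AutoDispatchBelowAutograd() and then detaches).
    with torch.no_grad():
        real_output = input * weight
    weight_branch = ScaleSeparateWeightGrad.apply(input.detach(), weight, real_output)
    input_branch = ScaleSeparateInputGrad.apply(input, weight.detach(), real_output)
    return ScalePassThrough.apply(real_output, weight_branch, input_branch)


# Sanity check against the ordinary op.
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
split_scale(x, w).sum().backward()
gx, gw = x.grad.clone(), w.grad.clone()
x.grad = w.grad = None
(x * w).sum().backward()
assert torch.allclose(gx, x.grad) and torch.allclose(gw, w.grad)
```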
```diff
@@ -287,7 +287,7 @@ def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
 
 lib.impl("mm", split_mm, "Autograd")
 lib.impl("addmm", split_addmm, "Autograd")
-# lib.impl("_fused_rms_norm", split_rms_norm, "Autograd")
+lib.impl("_fused_rms_norm", split_fused_rms_norm, "Autograd")
 lib.impl("_grouped_mm", split_grouped_mm, "Autograd")
 torch.autograd.set_detect_anomaly(True, check_nan=False)
 
```
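The uncommented registration at the bottom is what turns `split_fused_rms_norm` on: an IMPL library on `aten` binds the function as the op's Autograd-key kernel for the whole process. A minimal sketch of the mechanism (illustrative; the file already holds its own `Library` handle named `lib`, and a real override has to wire up custom autograd as the Functions above do):

```python
import torch

# Attach implementations to the existing "aten" operator library.
_lib = torch.library.Library("aten", "IMPL")


def mm_autograd_override(a, b):
    # Runs whenever aten::mm is dispatched at the Autograd key. A real override
    # (like split_fused_rms_norm) would build custom autograd.Function graphs
    # here; this sketch just forwards to the normal kernel, losing gradients.
    with torch._C._AutoDispatchBelowAutograd():
        return torch.mm(a, b)


_lib.impl("mm", mm_autograd_override, "Autograd")
```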
