`torch.distributed.all_reduce` is not being converted to StableHLO when the exported program is passed through torch_xla.
I ran the following test:
```python
import os
import torch
from torch import nn
import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo


def test():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)

    model = Basic()
    prog = export(model, (torch.rand(20, 10), ))
    shlo = exported_program_to_stablehlo(prog)
    print(shlo.get_stablehlo_text())


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
torch.distributed.init_process_group(world_size=1, rank=0)

if __name__ == "__main__":
    test()
```
I would expect `exported_program_to_stablehlo` to produce a StableHLO module containing the all_reduce. Instead, I get the following error:
```
WARNING:root:Defaulting to PJRT_DEVICE=CPU
loc("all-reduce.10"): error: failed to legalize operation 'mhlo.all_reduce' that was explicitly marked illegal
[rank0]: Traceback (most recent call last):
[rank0]:   File "/localdev/aknezevic/xt/test_mp.py", line 28, in <module>
[rank0]:     test()
[rank0]:   File "/localdev/aknezevic/xt/test_mp.py", line 19, in test
[rank0]:     shlo = exported_program_to_stablehlo(prog)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 626, in exported_program_to_stablehlo
[rank0]:     bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 405, in _exported_program_to_stablehlo_bundle
[rank0]:     stablehlo_content = xm.get_stablehlo_bytecode(res)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 1103, in get_stablehlo_bytecode
[rank0]:     return torch_xla._XLAC._get_stablehlo(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: torch_xla/csrc/runtime/stablehlo_helper.cc:109 : Check failed: status.ok()
[rank0]: *** Begin stack trace ***
[rank0]: 	tsl::CurrentStackTrace()
[rank0]: 	torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
[rank0]: 	torch_xla::hloToStablehlo(xla::HloModuleProto const*, bool)
[rank0]: 	torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
[rank0]: 	torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]: 	_PyObject_MakeTpCall
[rank0]: 	_PyEval_EvalFrameDefault
[rank0]: 
[rank0]: 	PyEval_EvalCode
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]: 	_PyRun_SimpleFileObject
[rank0]: 	_PyRun_AnyFileObject
[rank0]: 	Py_RunMain
[rank0]: 	Py_BytesMain
[rank0]: 
[rank0]: 	__libc_start_main
[rank0]: 	_start
[rank0]: *** End stack trace ***
[rank0]: MHLO -> StableHLO conversion failed.
[rank0]: StableHLO Module from MHLO -> StableHLO conversion is not leagal.Please open a github issue to PyTorch/XLA.
[rank0]: Original HLO dump:
[rank0]: HloModule IrToHlo.14, entry_computation_layout={(f32[], f32[20,10]{1,0})->(f32[20,10]{1,0}, f32[20,10]{1,0})}
[rank0]: %AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
[rank0]:   %x.7 = f32[] parameter(0)
[rank0]:   %y.8 = f32[] parameter(1)
[rank0]:   ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
[rank0]: }
[rank0]: ENTRY %IrToHlo.14 (p0.1: f32[], p1.2: f32[20,10]) -> (f32[20,10], f32[20,10]) {
[rank0]:   %p1.2 = f32[20,10]{1,0} parameter(1)
[rank0]:   %p0.1 = f32[] parameter(0)
[rank0]:   %tuple.3 = (f32[20,10]{1,0}, f32[]) tuple(f32[20,10]{1,0} %p1.2, f32[] %p0.1)
[rank0]:   %get-tuple-element.4 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=0
[rank0]:   %get-tuple-element.5 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=1
[rank0]:   %all-reduce.10 = (f32[20,10]{1,0}, f32[]) all-reduce(f32[20,10]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.6
[rank0]:   %get-tuple-element.12 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=1
[rank0]:   %get-tuple-element.11 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=0
[rank0]:   ROOT %tuple.13 = (f32[20,10]{1,0}, f32[20,10]{1,0}) tuple(f32[20,10]{1,0} %get-tuple-element.11, f32[20,10]{1,0} %get-tuple-element.11)
[rank0]: }
```
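In the HLO dump, the collective is a tuple-shaped all-reduce over the f32[20,10] input plus an extra f32[] operand, with `replica_groups={}` and `constrain_layout=true`. When comparing dumps across torch_xla builds, that instruction can be pulled out with a small stdlib-only script. This is a hypothetical helper (`find_all_reduces` and `HLO_DUMP` are illustrative names, not torch_xla APIs). It is a naive regex sketch that assumes the one-instruction-per-line format shown above. It only extracts the result shape, operands and attributes, and does not determine why the MHLO to StableHLO legalization fails.

```python
import re

# ENTRY computation copied from the "Original HLO dump" above, with the
# "[rank0]: " log prefixes removed.
HLO_DUMP = """\
ENTRY %IrToHlo.14 (p0.1: f32[], p1.2: f32[20,10]) -> (f32[20,10], f32[20,10]) {
  %p1.2 = f32[20,10]{1,0} parameter(1)
  %p0.1 = f32[] parameter(0)
  %tuple.3 = (f32[20,10]{1,0}, f32[]) tuple(f32[20,10]{1,0} %p1.2, f32[] %p0.1)
  %get-tuple-element.4 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=0
  %get-tuple-element.5 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=1
  %all-reduce.10 = (f32[20,10]{1,0}, f32[]) all-reduce(f32[20,10]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.6
  %get-tuple-element.12 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=1
  %get-tuple-element.11 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=0
  ROOT %tuple.13 = (f32[20,10]{1,0}, f32[20,10]{1,0}) tuple(f32[20,10]{1,0} %get-tuple-element.11, f32[20,10]{1,0} %get-tuple-element.11)
}
"""

# Matches "%name = <result shape> all-reduce(<operands>), <attributes>".
# Lines that merely reference an all-reduce result (e.g. get-tuple-element
# of %all-reduce.10) do not contain "all-reduce(" and are skipped.
ALL_REDUCE_RE = re.compile(
    r"%(?P<name>[\w.-]+) = (?P<shape>.+?) all-reduce\((?P<operands>.*?)\), (?P<attrs>.*)$"
)


def find_all_reduces(hlo_text):
    """Return one dict per all-reduce instruction in an HLO text dump (hypothetical helper)."""
    results = []
    for line in hlo_text.splitlines():
        match = ALL_REDUCE_RE.search(line.strip())
        if match is None:
            continue
        # Naive split: assumes every attribute is "key=value" and that no
        # attribute value or operand contains ", " (true for this dump, where
        # shapes like f32[20,10]{1,0} use "," without a following space).
        attrs = dict(item.split("=", 1) for item in match.group("attrs").split(", "))
        results.append({
            "name": match.group("name"),
            "result_shape": match.group("shape"),
            "operands": match.group("operands").split(", "),
            "attrs": attrs,
        })
    return results


if __name__ == "__main__":
    for op in find_all_reduces(HLO_DUMP):
        print(op["name"], op["result_shape"], op["attrs"])
```

Pasting the HLO dump from a different torch_xla commit into `HLO_DUMP` makes it quick to see whether the operand list or attributes of the emitted all-reduce changed between builds.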
Thank you for filing this issue. I can confirm this still happens on 76b0ce5. cc @tengyifei @bhavya01