torch.distributed.all_reduce not converted to stableHLO #8854

Open
AleksKnezevic opened this issue Mar 19, 2025 · 1 comment
Labels: bug, SPMD / Distributed, stablehlo

Comments

@AleksKnezevic

🐛 Bug

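torch.distributed.all_reduce is not converted to StableHLO. Exporting a module that calls it with torch.export and then lowering the exported program with torch_xla.stablehlo.exported_program_to_stablehlo fails during the MHLO -> StableHLO conversion.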

To Reproduce

I ran the following test:

```python
import os
import torch
from torch import nn
import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo


def test():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)

    model = Basic()
    prog = export(model, (torch.rand(20, 10), ))
    shlo = exported_program_to_stablehlo(prog)
    print(shlo.get_stablehlo_text())

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
torch.distributed.init_process_group(world_size=1, rank=0)


if __name__ == "__main__":
    test()
```
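For reference, the reduction the repro asks for is simple: after an all-reduce with ReduceOp.SUM, every rank holds the elementwise sum of all ranks' inputs. The stdlib-only sketch below simulates that on plain lists. It is a toy model under stated assumptions: simulate_all_reduce_sum is a hypothetical helper, not a torch.distributed API, and real collectives operate on tensors across processes.

```python
from typing import List


def simulate_all_reduce_sum(per_rank_inputs: List[List[float]]) -> List[List[float]]:
    """Toy simulation of a SUM all-reduce (hypothetical helper, not a torch API).

    per_rank_inputs[r] is the flattened data held by rank r. The return value
    is what each rank holds after the collective completes.
    """
    if not per_rank_inputs:
        raise ValueError("need at least one rank")
    length = len(per_rank_inputs[0])
    if any(len(x) != length for x in per_rank_inputs):
        raise ValueError("all ranks must contribute data of the same shape")
    # Elementwise sum across ranks.
    reduced = [sum(values) for values in zip(*per_rank_inputs)]
    # Every rank receives its own copy of the same reduced result.
    return [list(reduced) for _ in per_rank_inputs]


# Two ranks: both end up with the elementwise sum.
print(simulate_all_reduce_sum([[1.0, 2.0], [3.0, 4.0]]))  # [[4.0, 6.0], [4.0, 6.0]]
# world_size=1, as in the repro: the result equals the input.
print(simulate_all_reduce_sum([[1.0, 2.0]]))  # [[1.0, 2.0]]
```

With world_size=1, as in the repro, the reduced result equals the input, so the op is numerically an identity here; the error is raised in the HLO -> StableHLO conversion step, not by the reduction itself.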

Expected behavior

I would expect a StableHLO module containing an all_reduce op; instead, I get the following error:

```
WARNING:root:Defaulting to PJRT_DEVICE=CPU
loc("all-reduce.10"): error: failed to legalize operation 'mhlo.all_reduce' that was explicitly marked illegal
[rank0]: Traceback (most recent call last):
[rank0]:   File "/localdev/aknezevic/xt/test_mp.py", line 28, in <module>
[rank0]:     test()
[rank0]:   File "/localdev/aknezevic/xt/test_mp.py", line 19, in test
[rank0]:     shlo = exported_program_to_stablehlo(prog)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 626, in exported_program_to_stablehlo
[rank0]:     bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/stablehlo.py", line 405, in _exported_program_to_stablehlo_bundle
[rank0]:     stablehlo_content = xm.get_stablehlo_bytecode(res)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/localdev/aknezevic/xt/venv/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 1103, in get_stablehlo_bytecode
[rank0]:     return torch_xla._XLAC._get_stablehlo(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: torch_xla/csrc/runtime/stablehlo_helper.cc:109 : Check failed: status.ok() 
[rank0]: *** Begin stack trace ***
[rank0]:        tsl::CurrentStackTrace()
[rank0]:        torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
[rank0]:        torch_xla::hloToStablehlo(xla::HloModuleProto const*, bool)
[rank0]:        torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
[rank0]:        torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]:        _PyObject_MakeTpCall
[rank0]:        _PyEval_EvalFrameDefault
[rank0]: 
[rank0]:        PyEval_EvalCode
[rank0]: 
[rank0]: 
[rank0]: 
[rank0]:        _PyRun_SimpleFileObject
[rank0]:        _PyRun_AnyFileObject
[rank0]:        Py_RunMain
[rank0]:        Py_BytesMain
[rank0]: 
[rank0]:        __libc_start_main
[rank0]:        _start
[rank0]: *** End stack trace ***
[rank0]: MHLO -> StableHLO conversion failed.
[rank0]: StableHLO Module from MHLO -> StableHLO conversion is not leagal.Please open a github issue to PyTorch/XLA.
[rank0]: Original HLO dump:
[rank0]: HloModule IrToHlo.14, entry_computation_layout={(f32[], f32[20,10]{1,0})->(f32[20,10]{1,0}, f32[20,10]{1,0})}

[rank0]: %AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
[rank0]:   %x.7 = f32[] parameter(0)
[rank0]:   %y.8 = f32[] parameter(1)
[rank0]:   ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
[rank0]: }

[rank0]: ENTRY %IrToHlo.14 (p0.1: f32[], p1.2: f32[20,10]) -> (f32[20,10], f32[20,10]) {
[rank0]:   %p1.2 = f32[20,10]{1,0} parameter(1)
[rank0]:   %p0.1 = f32[] parameter(0)
[rank0]:   %tuple.3 = (f32[20,10]{1,0}, f32[]) tuple(f32[20,10]{1,0} %p1.2, f32[] %p0.1)
[rank0]:   %get-tuple-element.4 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=0
[rank0]:   %get-tuple-element.5 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %tuple.3), index=1
[rank0]:   %all-reduce.10 = (f32[20,10]{1,0}, f32[]) all-reduce(f32[20,10]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.6
[rank0]:   %get-tuple-element.12 = f32[] get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=1
[rank0]:   %get-tuple-element.11 = f32[20,10]{1,0} get-tuple-element((f32[20,10]{1,0}, f32[]) %all-reduce.10), index=0
[rank0]:   ROOT %tuple.13 = (f32[20,10]{1,0}, f32[20,10]{1,0}) tuple(f32[20,10]{1,0} %get-tuple-element.11, f32[20,10]{1,0} %get-tuple-element.11)
[rank0]: }
```
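In the HLO dump, the instruction that fails to legalize (all-reduce.10) has two operands: the f32[20,10] input and an extra scalar f32[] parameter (p0.1), and it returns a tuple of both. When comparing dumps across torch_xla versions, it can help to extract the all-reduce operand types mechanically. The sketch below is a hypothetical, regex-based helper (not part of torch_xla or XLA) and only handles the simple array shapes that appear in this dump.

```python
import re
from typing import List

# Matches an operand such as "f32[20,10]{1,0} %get-tuple-element.4" or
# "f32[] %get-tuple-element.5" and captures only the type, e.g. "f32[20,10]".
# Simplified assumption: no tuple-typed operands, no nested parentheses.
_OPERAND = re.compile(r"([a-z]+\d*\[[\d,]*\])(?:\{[\d,]*\})?\s+%[\w.\-]+")


def all_reduce_operand_types(hlo_line: str) -> List[str]:
    """Return the operand types of an all-reduce in one HLO text line (hypothetical helper)."""
    # "all-reduce(" only matches the op call, not the "%all-reduce.10" result name.
    match = re.search(r"all-reduce\((.*?)\)", hlo_line)
    if match is None:
        return []
    return _OPERAND.findall(match.group(1))


ALL_REDUCE_LINE = (
    "%all-reduce.10 = (f32[20,10]{1,0}, f32[]) all-reduce(f32[20,10]{1,0} "
    "%get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, "
    "constrain_layout=true, to_apply=%AddComputation.6"
)
print(all_reduce_operand_types(ALL_REDUCE_LINE))  # ['f32[20,10]', 'f32[]']
```

More than one returned type indicates a variadic all-reduce like the one in this dump; whether that is what the MHLO -> StableHLO legalization rejects is not established in this issue.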

Environment

  • Reproducible on XLA backend: CPU
  • torch_xla version: 2.5.0 and 2.6.0 (I tried both)
@ysiraichi added the bug, SPMD / Distributed, and stablehlo labels on Mar 19, 2025
@ysiraichi
Collaborator

Thank you for filing this issue. I can confirm this still happens on 76b0ce5.
cc @tengyifei @bhavya01
