diff --git a/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
index 42682d67e94ec..2d29135726839 100644
--- a/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
+++ b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
@@ -6,22 +6,28 @@
 import unittest
 
 import numpy as np
-from mpi4py import MPI
 from onnx import TensorProto, helper
 
 import onnxruntime
 
-np.random.seed(3)
+try:
+    from mpi4py import MPI
+
+    comm = MPI.COMM_WORLD
+except (ImportError, RuntimeError):
+    comm = None
 
-comm = MPI.COMM_WORLD
+has_mpi = comm is not None
+
+np.random.seed(3)
 
 
 def get_rank():
-    return comm.Get_rank()
+    return comm.Get_rank() if comm else 0
 
 
 def get_size():
-    return comm.Get_size()
+    return comm.Get_size() if comm else 0
 
 
 def print_out(*args):
@@ -254,7 +260,7 @@ def run_ort_with_parity_check(
     )
 
 
-def test_moe_with_tensor_parallelism(
+def run_moe_with_tensor_parallelism(
     hidden_size,
     inter_size,
     num_experts,
@@ -327,7 +333,7 @@ def get_fc2_tensor_shards(expert_weights):
     )
 
 
-def test_moe_with_expert_parallelism(
+def run_moe_with_expert_parallelism(
    hidden_size,
    inter_size,
    num_experts,
@@ -390,19 +396,22 @@ def test_moe_with_expert_parallelism(
 
 class TestMoE(unittest.TestCase):
     def test_moe_parallelism(self):
+        if not has_mpi:
+            self.skipTest("No MPI support")
+
         for hidden_size in [128, 1024]:
             for inter_size in [512, 2048]:
                 for num_experts in [64]:
                     for num_rows in [1024]:
                         print_out("EP")
-                        test_moe_with_expert_parallelism(
+                        run_moe_with_expert_parallelism(
                             hidden_size,
                             inter_size,
                             num_experts,
                             num_rows,
                         )
                         print_out("TP")
-                        test_moe_with_tensor_parallelism(
+                        run_moe_with_tensor_parallelism(
                             hidden_size,
                             inter_size,
                             num_experts,
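
Note for reviewers: the rename from `test_moe_with_*` to `run_moe_with_*` keeps pytest from collecting these parameterized helpers as standalone tests, and the `try`/`except` import makes `mpi4py` an optional dependency, so the suite degrades to a clean skip instead of an import-time crash. Below is a minimal standalone sketch of that pattern; the module name and demo test body are illustrative only, not taken from the ONNX Runtime tree.

    # optional_mpi_demo.py -- hypothetical file illustrating the pattern above
    import unittest

    try:
        # Import may raise ImportError (package missing) or RuntimeError
        # (MPI runtime unavailable at init), matching the diff's handling.
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
    except (ImportError, RuntimeError):
        comm = None  # fall back to single-process behavior

    has_mpi = comm is not None


    def get_rank():
        # Rank 0 is the single-process default when MPI is absent.
        return comm.Get_rank() if comm else 0


    def get_size():
        return comm.Get_size() if comm else 0


    class DemoTest(unittest.TestCase):
        def test_needs_mpi(self):
            if not has_mpi:
                self.skipTest("No MPI support")  # reported as skipped, not errored
            self.assertGreaterEqual(get_size(), 1)


    if __name__ == "__main__":
        unittest.main()

Run under `mpirun -n 2 python optional_mpi_demo.py`, the test executes on both ranks; without `mpi4py` installed, it is reported as skipped rather than failing the whole module at import.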