@@ -6,22 +6,28 @@
 import unittest
 
 import numpy as np
-from mpi4py import MPI
 from onnx import TensorProto, helper
 
 import onnxruntime
 
-np.random.seed(3)
-
-comm = MPI.COMM_WORLD
+try:
+    from mpi4py import MPI
+
+    comm = MPI.COMM_WORLD
+except (ImportError, RuntimeError):
+    comm = None
+
+has_mpi = comm is not None
+
+np.random.seed(3)
 
 
 def get_rank():
-    return comm.Get_rank()
+    return comm.Get_rank() if comm else 0
 
 
 def get_size():
-    return comm.Get_size()
+    return comm.Get_size() if comm else 0
 
 
 def print_out(*args):
@@ -254,7 +260,7 @@ def run_ort_with_parity_check(
     )
 
 
-def test_moe_with_tensor_parallelism(
+def run_moe_with_tensor_parallelism(
     hidden_size,
     inter_size,
     num_experts,
@@ -327,7 +333,7 @@ def get_fc2_tensor_shards(expert_weights):
     )
 
 
-def test_moe_with_expert_parallelism(
+def run_moe_with_expert_parallelism(
     hidden_size,
     inter_size,
     num_experts,
@@ -390,19 +396,22 @@ def test_moe_with_expert_parallelism(
 
 class TestMoE(unittest.TestCase):
     def test_moe_parallelism(self):
+        if not has_mpi:
+            self.skipTest("No MPI support")
+
         for hidden_size in [128, 1024]:
             for inter_size in [512, 2048]:
                 for num_experts in [64]:
                     for num_rows in [1024]:
                         print_out("EP")
-                        test_moe_with_expert_parallelism(
+                        run_moe_with_expert_parallelism(
                             hidden_size,
                             inter_size,
                             num_experts,
                             num_rows,
                         )
                         print_out("TP")
-                        test_moe_with_tensor_parallelism(
+                        run_moe_with_tensor_parallelism(
                             hidden_size,
                             inter_size,
                             num_experts,
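
Taken together, the hunks make the MoE parallelism test degrade gracefully when mpi4py is missing or MPI fails to initialize, and rename the helpers from test_* to run_* so unittest no longer collects them as standalone test cases. A minimal standalone sketch of the same guard-and-skip pattern (illustrative test name and assertion, not taken from the PR):

import unittest

try:
    from mpi4py import MPI  # importing mpi4py.MPI initializes MPI; the PR catches RuntimeError from this

    comm = MPI.COMM_WORLD
except (ImportError, RuntimeError):
    comm = None  # mpi4py not installed, or the MPI runtime failed to start

has_mpi = comm is not None


def get_rank():
    # Fall back to rank 0 so the module still imports and runs without MPI.
    return comm.Get_rank() if comm else 0


class ExampleMpiTest(unittest.TestCase):
    def test_rank_is_valid(self):
        if not has_mpi:
            self.skipTest("No MPI support")
        self.assertGreaterEqual(get_rank(), 0)


if __name__ == "__main__":
    unittest.main()

With MPI available, a test like this would typically be launched across ranks with something like mpirun -n 4 python <test_file>.py; without it, the guarded import leaves comm as None and the test is skipped instead of failing at import time.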