[main][Feature] MoE alltoallv communication optimization for unquantized RL training scene #2088
Changes from 56 commits
Requirements update (adds the pytest-mock test dependency):

```diff
@@ -17,3 +17,4 @@ ray>=2.47.1
 protobuf==4.25.6
 librosa
 soundfile
+pytest_mock
```
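For context, `pytest_mock` supplies the `mocker` fixture that the new unit tests below rely on for patching. A minimal, self-contained sketch of how the fixture is typically used (the helper `get_world_size_or_one` is hypothetical and not part of this PR):

```python
# Minimal pytest-mock sketch; `get_world_size_or_one` is a hypothetical helper, not PR code.
import torch.distributed as dist


def get_world_size_or_one() -> int:
    # Hypothetical helper: report the distributed world size, defaulting to 1.
    return dist.get_world_size() if dist.is_initialized() else 1


def test_world_size_is_patchable(mocker):
    # mocker.patch wraps unittest.mock.patch and is torn down automatically after the test.
    mocker.patch("torch.distributed.is_initialized", return_value=True)
    mocker.patch("torch.distributed.get_world_size", return_value=4)
    assert get_world_size_or_one() == 4
```

This mirrors the patching style used throughout the tests below, which stub out `torch.distributed` and `torch.npu` so the communication helpers can be exercised on a single CPU process.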
New file: unit tests for vllm_ascend.distributed.tensor_parallel (@@ -0,0 +1,139 @@):

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import importlib

import pytest
import torch
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
from vllm_ascend.distributed.tensor_parallel import (
    _gather_along_first_dim, _gather_along_last_dim,
    _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
    all_to_all_hp2sp, all_to_all_sp2hp)


class TestDistributedCommunication(PytestBase):

    @pytest.fixture(autouse=True)
    def context(self, mocker: MockerFixture):
        mocker.patch("torch.npu.current_device", return_value="cpu")
        mocker.patch("torch.distributed.get_world_size", return_value=4)
        mocker.patch("torch.distributed.get_rank", return_value=0)

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16), (8, 16)),
                              (4, torch.randn(8, 16), (32, 16))])
    def test_gather_along_first_dim(self, test_tensor, expected, world_size,
                                    mocker: MockerFixture):
        """test _gather_along_first_dim"""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)

        result = _gather_along_first_dim(test_tensor, mocker.MagicMock())

        assert result.shape == expected

    @pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [
        (torch.randn(8, 16), [5, 10, 15, 2], (32, 16)),
    ])
    def test_gather_along_first_dim_unequal_split(self, test_tensor, expected,
                                                  output_split_sizes,
                                                  mocker: MockerFixture):
        """test _gather_along_first_dim"""
        result = _gather_along_first_dim(test_tensor, mocker.MagicMock(),
                                         output_split_sizes)

        assert result.shape == expected

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16, 32), (8, 16, 32)),
                              (4, torch.randn(8, 16, 32), (8, 16, 32 * 4))])
    def test_gather_along_last_dim(self, test_tensor, expected, world_size,
                                   mocker: MockerFixture):
        """test _gather_along_last_dim"""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)

        result = _gather_along_last_dim(test_tensor, mocker.MagicMock())

        assert result.shape == expected

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((32, 16), (8, 16)),
        ((40, 10), (10, 10)),
    ])
    def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape,
                                            mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_first_dim(input_tensor,
                                                 mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((8, 16, 32), (8, 16, 8)),
    ])
    def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape,
                                           mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_last_dim(input_tensor,
                                                mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize("func,input_shape,expected_shape", [
        ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
         (8, 16, 128)),
        ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
        ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
         (8, 16, 8)),
        ("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
    ])
    def test_wrapper_functions(self, func, input_shape, expected_shape,
                               mocker: MockerFixture):
        """test wrapper funcs"""
        mod = importlib.import_module(
            'vllm_ascend.distributed.tensor_parallel')
        globals = mod.__dict__
        test_func = globals[func]
        input_tensor = torch.randn(*input_shape)
        result = test_func(input_tensor, mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((8, 16), (32, 4)),  # [num_tokens/TP, H] -> [num_tokens, H/TP]
        ])
    def test_all_to_all_sp2hp(self, input_shape, output_shape,
                              mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_sp2hp(input_tensor, mocker.MagicMock())
        assert result.shape == output_shape

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((32, 4), (8, 16)),  # [num_tokens, H/TP] -> [num_tokens/TP, H]
        ])
    def test_all_to_all_hp2sp(self, input_shape, output_shape,
                              mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_hp2sp(input_tensor, mocker.MagicMock())
        assert result.shape == output_shape
```
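As a shape-level illustration of what these assertions encode (a local sketch under an assumed tensor-parallel world size of 4, not the distributed implementation): gathering along the first dim multiplies the token dimension by the world size, while `all_to_all_sp2hp` trades a token split for a hidden split and `all_to_all_hp2sp` reverses it.

```python
# Shape-only sketch of the collectives exercised above; assumes TP = 4 as in the tests.
import torch

TP = 4
x = torch.randn(8, 16)  # [num_tokens / TP, H] held by one rank

# _gather_along_first_dim: every rank contributes its token shard -> (32, 16).
gathered = torch.cat([x] * TP, dim=0)
assert gathered.shape == (32, 16)

# all_to_all_sp2hp: [num_tokens/TP, H] -> [num_tokens, H/TP]; locally this is
# "split the hidden dim into TP chunks and stack them along the token dim".
sp2hp_like = torch.cat(torch.chunk(x, TP, dim=-1), dim=0)
assert sp2hp_like.shape == (32, 4)

# all_to_all_hp2sp is the inverse shape transform: (32, 4) -> (8, 16).
hp2sp_like = torch.cat(torch.chunk(sp2hp_like, TP, dim=0), dim=-1)
assert hp2sp_like.shape == (8, 16)
```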
New file: unit tests for MoEAlltoAllSeqOverLapDispatcher (@@ -0,0 +1,65 @@):

Review comment pinned near the file header:
Collaborator: plz move this to
Author: ok

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import pytest
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
    MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.utils import adapt_patch  # noqa E402


class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):

    @pytest.fixture
    def config(self):
        config = MoEDispatcherConfig()
        config.set_num_local_experts(2)
        config.set_num_moe_experts(4)
        config.set_moe_pad_expert_input_to_capacity(False)
        config.set_moe_expert_capacity_factor(None)
        config.set_moe_router_topk(2)
        config.set_moe_grouped_gemm(False)
        config.set_group_topk(0)
        config.set_num_groups(1)
        config.set_is_fused(False)
        return config.build()

    def mock_ep_group(self, mocker):
        mock_group = mocker.MagicMock()
        mock_group.rank_in_group = 0
        mock_group.world_size = 2
        mock_group.device_group = "mock_group"
        return mock_group

    @pytest.fixture
    def dispatcher(self, config, mocker: MockerFixture):
        mocker.patch(
            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
            return_value=self.mock_ep_group(mocker))
        mocker.patch("torch.npu.current_device", return_value="cpu")
        mocker.patch("torch.npu.Stream", return_value=mocker.MagicMock)
        return MoEAlltoAllSeqOverLapDispatcher(config)

    def test_initialization(self, dispatcher, config):
        assert dispatcher.num_local_experts == config.num_local_experts
        assert dispatcher.num_experts == config.num_moe_experts
        assert dispatcher.local_expert_indices == [0, 1]
        assert dispatcher.ep_rank == 0
        assert dispatcher.ep_size == 2
        assert dispatcher.overlap_stream is not None
```
Review comment: This test is not run in CI. You should enable it here as well: https://github.com/vllm-project/vllm-ascend/blob/main/.github/workflows/vllm_ascend_test.yaml#L277-L281
Reply: fixed
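Since the dispatcher under test is built around variable-size all-to-all (alltoallv) exchanges, the following is a generic, hedged sketch of that primitive using `torch.distributed.all_to_all_single`. It is illustrative only and does not reproduce the dispatcher's actual code path; it assumes a multi-process launch (e.g. `torchrun --nproc_per_node=2`) and a backend that supports all-to-all (NCCL, HCCL, or MPI).

```python
# Generic alltoallv sketch (illustrative; not the MoEAlltoAllSeqOverLapDispatcher code).
# Launch with: torchrun --nproc_per_node=2 alltoallv_sketch.py
import torch
import torch.distributed as dist


def main() -> None:
    # torchrun provides rank/world-size env vars; the backend must support all-to-all
    # (e.g. NCCL, HCCL, or MPI), and tensors must live on that backend's device.
    dist.init_process_group()
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Each rank decides how many tokens it routes to every other rank,
    # e.g. tokens destined for experts hosted on that rank (uneven on purpose).
    input_splits = [rank + 1] * world_size
    tokens = torch.full((sum(input_splits), 4), float(rank))  # [num_local_tokens, hidden]

    # First exchange the split sizes so every rank knows how much it will receive.
    send_counts = torch.tensor(input_splits)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts)
    output_splits = recv_counts.tolist()

    # The variable-size token exchange itself (the "alltoallv" step).
    recv_tokens = torch.empty((sum(output_splits), tokens.shape[1]), dtype=tokens.dtype)
    dist.all_to_all_single(recv_tokens, tokens,
                           output_split_sizes=output_splits,
                           input_split_sizes=input_splits)

    print(f"rank {rank}: sent {input_splits}, received {output_splits}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```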