-
Notifications
You must be signed in to change notification settings - Fork 995
fix(dcp_alltoall): require MNNVL workspace, drop broken plain-memory path #3210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4e11900
85f1d8b
b87aa37
e6e0a08
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -35,7 +35,7 @@ | |||||
|
|
||||||
| from flashinfer.comm import ( | ||||||
| decode_cp_a2a_alltoall, | ||||||
| decode_cp_a2a_allocate_workspace, | ||||||
| decode_cp_a2a_allocate_mnnvl_workspace, | ||||||
| decode_cp_a2a_init_workspace, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new public API
Suggested change
|
||||||
| decode_cp_a2a_workspace_size, | ||||||
| ) | ||||||
|
|
@@ -120,7 +120,7 @@ def _setup_rank(): | |||||
| _rank, _cp_size, _comm = _setup_rank() | ||||||
|
|
||||||
| def _allocate_mnnvl_workspace_once(): | ||||||
| """Allocate MNNVL workspace once at module level. | ||||||
| """Allocate MNNVL workspace once at module level via the public API. | ||||||
|
|
||||||
| MnnvlMemory uses a global bump allocator that doesn't support | ||||||
| individual frees. Allocating per-test causes segfaults when | ||||||
|
|
@@ -137,12 +137,7 @@ def _allocate_mnnvl_workspace_once(): | |||||
| tp_size=1, | ||||||
| pp_size=1, | ||||||
| ) | ||||||
|
|
||||||
| ws_bytes = decode_cp_a2a_workspace_size(_cp_size) | ||||||
| mnnvl_mem = MnnvlMemory(mapping, ws_bytes) | ||||||
| workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64) | ||||||
| workspace._mnnvl_mem = mnnvl_mem # prevent GC | ||||||
| return workspace | ||||||
| return decode_cp_a2a_allocate_mnnvl_workspace(mapping) | ||||||
|
|
||||||
| _mnnvl_workspace = _allocate_mnnvl_workspace_once() | ||||||
| else: | ||||||
|
|
@@ -319,32 +314,5 @@ def test_repeated_alltoall(self): | |||||
| _comm.Barrier() | ||||||
|
|
||||||
|
|
||||||
| class TestMnnvlDcpDeviceMemoryFallback: | ||||||
| """Test that non-MNNVL (device memory) path also works multi-GPU. | ||||||
|
|
||||||
| Uses decode_cp_a2a_allocate_workspace without MNNVL mapping. This only | ||||||
| works when all ranks are on the same GPU (single-GPU simulation) | ||||||
| or with IPC. Included here to verify the workspace API contract. | ||||||
| """ | ||||||
|
|
||||||
| @pytest.fixture(autouse=True) | ||||||
| def setup(self): | ||||||
| torch.manual_seed(0xA2A) | ||||||
| yield | ||||||
|
|
||||||
| def test_device_workspace_shape(self): | ||||||
| """Device workspace has correct shape [cp_size, ws_elems].""" | ||||||
| try: | ||||||
| workspace = decode_cp_a2a_allocate_workspace(_cp_size, cp_rank=_rank) | ||||||
| assert workspace.shape[0] == _cp_size | ||||||
|
|
||||||
| ws_bytes = decode_cp_a2a_workspace_size(_cp_size) | ||||||
| expected_elems = (ws_bytes + 7) // 8 | ||||||
| assert workspace.shape[1] == expected_elems | ||||||
| assert workspace.dtype == torch.int64 | ||||||
| finally: | ||||||
| _comm.Barrier() | ||||||
|
|
||||||
|
|
||||||
| if __name__ == "__main__": | ||||||
| pytest.main([__file__, "-v", "-s"]) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
cp_sizeandcp_rankarguments are redundant because themappingobject (which is now a required positional argument) already contains this information. As noted in the docstring,mappingcarries the authoritative rank info. Removing these redundant parameters simplifies the API and eliminates the risk of passing inconsistent values.