diff --git a/csrc/kernels/launch.cuh b/csrc/kernels/launch.cuh index 1f759396..26c8b924 100644 --- a/csrc/kernels/launch.cuh +++ b/csrc/kernels/launch.cuh @@ -62,6 +62,7 @@ cfg.dynamicSmemBytes = smem_size; #define SWITCH_RDMA_RANKS(case_macro) \ switch (num_ranks / NUM_MAX_NVL_PEERS) { \ case 2: case_macro(2); \ + case 3: case_macro(3); \ case 4: case_macro(4); \ case 6: case_macro(6); \ case 8: case_macro(8); \ diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 7bae5827..4a70d604 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -234,7 +234,7 @@ def get_dispatch_config(num_ranks: int) -> Config: 4: Config(Buffer.num_sms, 6, 256, 6, 128), 8: Config(Buffer.num_sms, 6, 256, 6, 128), 16: Config(Buffer.num_sms, 36, 288, 20, 128), - 24: Config(Buffer.num_sms, 8, 288, 32, 128), + 24: Config(Buffer.num_sms, 32, 288, 8, 128), 32: Config(Buffer.num_sms, 32, 288, 8, 128), 48: Config(Buffer.num_sms, 32, 288, 8, 128), 64: Config(Buffer.num_sms, 32, 288, 8, 128), diff --git a/tests/test_internode.py b/tests/test_internode.py index 384446b6..c46cff5f 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -92,7 +92,7 @@ def test_main(args: argparse.Namespace, num_sms: int, time.sleep(1) # Config - rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (48, 96, 144, 160) else 512) + rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (24, 48, 96, 144, 160) else 512) config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) # Test dispatch diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index b076f341..f94d596f 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -114,7 +114,8 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, if do_check: diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) assert torch.isnan(combined_x).sum().item() == 0 - assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}' + if not round_scale: + assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}' hash_value ^= hash_tensor(combined_x) # noinspection PyShadowingNames