diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml deleted file mode 100644 index 6c8f4545..00000000 --- a/.github/workflows/format.yml +++ /dev/null @@ -1,28 +0,0 @@ -name: Code Format Check - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - format-check: - runs-on: ubuntu-latest - - steps: - - name: Checkout source - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - name: Setup environment - run: | - sudo apt-get update - sudo apt-get install -y bash - - - name: Run format.sh - run: | - bash ./format.sh - - # If format.sh return non-zero, GitHub Actions will mark it as failure. \ No newline at end of file diff --git a/csrc/config.hpp b/csrc/config.hpp index 0e4f5b06..ecf12ccb 100644 --- a/csrc/config.hpp +++ b/csrc/config.hpp @@ -49,7 +49,7 @@ struct Config { EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); } - size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { + size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks, bool return_recv_hook) const { // Below are some assumptions // TODO: add assertions constexpr int kNumMaxTopK = 128; @@ -58,7 +58,7 @@ struct Config { EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0); const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); - const int num_channels = num_sms / 2; + const int num_channels = return_recv_hook ? num_sms : num_sms / 2; // one SM per channel for hook mode size_t num_bytes = 0; num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); @@ -73,7 +73,7 @@ struct Config { return num_bytes; } - size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { + size_t get_rdma_buffer_size_hint(int num_max_dispatch_tokens_per_rank, int64_t hidden_bytes, int num_ranks, bool decoupled_mode, bool return_recv_hook) const { #ifndef DISABLE_NVSHMEM // Legacy mode if (num_ranks <= NUM_MAX_NVL_PEERS) @@ -86,16 +86,17 @@ struct Config { EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); EP_HOST_ASSERT(num_sms % 2 == 0); const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; - const int num_channels = num_sms / 2; + const int num_channels = return_recv_hook ? num_sms : num_sms / 2; // one SM per channel for hook mode + int num_slots_per_rdma_chunk = decoupled_mode ? 
(num_max_dispatch_tokens_per_rank + num_channels - 1) / num_channels : num_max_rdma_chunked_recv_tokens; size_t num_bytes = 0; num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; - num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; + num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * hidden_bytes * 2; + num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * internode::get_source_meta_bytes() * 2; + num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * kNumMaxTopK * sizeof(topk_idx_t) * 2; + num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * kNumMaxTopK * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * kNumMaxScales * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * sizeof(int4) * 2; num_bytes = ((num_bytes + 127) / 128) * 128; return num_bytes; #else @@ -192,4 +193,54 @@ size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; } +uint64_t get_normal_hook_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_nodes, int num_sms, bool return_recv_hook) { + if (num_nodes <= 1) + return 0; + + // Below are some assumptions + // TODO: add assertions + int hidden_bytes = hidden * sizeof(nv_bfloat16); + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + EP_HOST_ASSERT(num_sms % 2 == 0); + const int num_channels = return_recv_hook ? 
num_sms : num_sms / 2; // one SM per channel for hook mode + uint64_t num_slots_per_rdma_chunk = (num_max_dispatch_tokens_per_rank + num_channels - 1) / num_channels; + + uint64_t num_bytes = 0; + num_bytes += num_channels * num_nodes * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); + num_bytes += num_channels * num_nodes * num_slots_per_rdma_chunk * hidden_bytes * 2; + num_bytes += num_channels * num_nodes * num_slots_per_rdma_chunk * internode::get_source_meta_bytes() * 2; + num_bytes += num_channels * num_nodes * num_slots_per_rdma_chunk * kNumMaxTopK * sizeof(int64_t) * 2; + num_bytes += num_channels * num_nodes * num_slots_per_rdma_chunk * kNumMaxTopK * sizeof(float) * 2; + num_bytes += num_channels * num_nodes * num_slots_per_rdma_chunk * kNumMaxScales * sizeof(float) * 2; + num_bytes += num_channels * num_nodes * num_slots_per_rdma_chunk * sizeof(int4) * 2; + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; +} + } // namespace deep_ep diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 954c9ffb..976f3575 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -143,9 +143,9 @@ Buffer::Buffer(int rank, comm_stream(at::cuda::getStreamFromPool(true)), shared_memory_allocator(use_fabric) { // Metadata memory - int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); - int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); - int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); + uint64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); + uint64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); + uint64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); // Common checks EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment"); @@ -392,7 +392,11 @@ void Buffer::sync(const std::vector& device_ids, std::tuple, torch::Tensor, torch::Tensor, std::optional> Buffer::get_dispatch_layout( - const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, bool async, bool allocate_on_comm_stream,
bool return_recv_hook) { + if (return_recv_hook) { + EP_HOST_ASSERT(not async); + } + EP_HOST_ASSERT(topk_idx.dim() == 2); EP_HOST_ASSERT(topk_idx.is_contiguous()); EP_HOST_ASSERT(num_experts > 0); @@ -400,16 +404,19 @@ Buffer::get_dispatch_layout( // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::cuda::getCurrentCUDAStream(); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished - if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); - } else { - stream_wait(comm_stream, compute_stream); + if(not return_recv_hook) { + if (previous_event.has_value()) { + stream_wait(launch_stream, previous_event.value()); + } else { + stream_wait(launch_stream, compute_stream); + } } auto num_tokens = static_cast(topk_idx.size(0)), num_topk = static_cast(topk_idx.size(1)); @@ -429,14 +436,14 @@ Buffer::get_dispatch_layout( num_topk, num_ranks, num_experts, - comm_stream); + launch_stream); // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream); + event = EventHandle(launch_stream); for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { - t.record_stream(comm_stream); + t.record_stream(launch_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } @@ -445,8 +452,8 @@ Buffer::get_dispatch_layout( if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } - } else { - stream_wait(compute_stream, comm_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); } // Switch back compute stream @@ -924,7 +931,9 @@ std::tuple, std::optional, std::optional, - std::optional> + int, + std::optional, + std::optional>> Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, @@ -944,16 +953,24 @@ Buffer::internode_dispatch(const torch::Tensor& x, const Config& config, std::optional& previous_event, bool async, - bool allocate_on_comm_stream) { + bool allocate_on_comm_stream, + bool decoupled_mode, + bool return_recv_hook, + int num_max_dispatch_tokens_per_rank) { #ifndef DISABLE_NVSHMEM // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long. // If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL // unless we release GIL here. pybind11::gil_scoped_release release; - const int num_channels = config.num_sms / 2; - EP_HOST_ASSERT(config.num_sms % 2 == 0); + if (return_recv_hook) { + EP_HOST_ASSERT((not async) and decoupled_mode); + } + + const int num_channels = return_recv_hook ? config.num_sms : config.num_sms / 2; // one SM per channel for hook mode + EP_HOST_ASSERT(return_recv_hook or config.num_sms % 2 == 0); EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); + EP_HOST_ASSERT((not decoupled_mode) or num_max_dispatch_tokens_per_rank > 0); bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); if (cached_mode) { @@ -1041,16 +1058,19 @@ Buffer::internode_dispatch(const torch::Tensor& x, // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! 
auto compute_stream = at::cuda::getCurrentCUDAStream(); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); - at::cuda::setCurrentCUDAStream(comm_stream); + at::cuda::setCurrentCUDAStream(launch_stream); } // Wait previous tasks to be finished - if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); - } else { - stream_wait(comm_stream, compute_stream); + if(not return_recv_hook) { + if (previous_event.has_value()) { + stream_wait(launch_stream, previous_event.value()); + } else { + stream_wait(launch_stream, compute_stream); + } } // Create handles (only return for non-cached mode) @@ -1088,11 +1108,14 @@ Buffer::internode_dispatch(const torch::Tensor& x, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, - comm_stream, - config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + launch_stream, + config.get_rdma_buffer_size_hint(num_max_dispatch_tokens_per_rank, hidden_int4 * sizeof(int4), num_ranks, decoupled_mode, return_recv_hook), num_nvl_bytes, true, - low_latency_mode); + low_latency_mode, + decoupled_mode, + return_recv_hook, + num_max_dispatch_tokens_per_rank); } else { rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); @@ -1129,10 +1152,12 @@ Buffer::internode_dispatch(const torch::Tensor& x, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, - comm_stream, - config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + launch_stream, + config.get_rdma_buffer_size_hint(num_max_dispatch_tokens_per_rank, hidden_int4 * sizeof(int4), num_ranks, decoupled_mode, return_recv_hook), num_nvl_bytes, - low_latency_mode); + low_latency_mode, + decoupled_mode, + num_max_dispatch_tokens_per_rank); // Synchronize total received tokens and tokens per expert if (num_worst_tokens > 0) { @@ -1201,49 +1226,53 @@ Buffer::internode_dispatch(const torch::Tensor& x, // Launch data dispatch // NOTES: the buffer size checks are moved into the `.cu` file - internode::dispatch(recv_x.data_ptr(), - recv_x_scales_ptr, - recv_topk_idx_ptr, - recv_topk_weights_ptr, - cached_mode ? nullptr : recv_src_meta->data_ptr(), - x.data_ptr(), - x_scales_ptr, - topk_idx_ptr, - topk_weights_ptr, - cached_mode ? nullptr : send_rdma_head->data_ptr(), - cached_mode ? nullptr : send_nvl_head->data_ptr(), - cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), - cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), - rdma_channel_prefix_matrix.data_ptr(), - recv_rdma_rank_prefix_sum.data_ptr(), - gbl_channel_prefix_matrix.data_ptr(), - recv_gbl_rank_prefix_sum.data_ptr(), - is_token_in_rank.data_ptr(), - num_tokens, - num_worst_tokens, + auto launcher = [=](int phases) { + internode::dispatch(recv_x.data_ptr(), + recv_x_scales_ptr, + recv_topk_idx_ptr, + recv_topk_weights_ptr, + cached_mode ? nullptr : recv_src_meta->data_ptr(), + x.data_ptr(), + x_scales_ptr, + topk_idx_ptr, + topk_weights_ptr, + cached_mode ? nullptr : send_rdma_head->data_ptr(), + cached_mode ? nullptr : send_nvl_head->data_ptr(), + cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), + cached_mode ? 
nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), + recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + recv_gbl_rank_prefix_sum.data_ptr(), + is_token_in_rank.data_ptr(), + num_tokens, + num_worst_tokens, hidden_int4, - num_scales, - num_topk, - num_experts, - scale_token_stride, - scale_hidden_stride, - rdma_buffer_ptr, - config.num_max_rdma_chunked_send_tokens, - config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, - config.num_max_nvl_chunked_send_tokens, - config.num_max_nvl_chunked_recv_tokens, - rank, - num_ranks, - cached_mode, - comm_stream, - num_channels, - low_latency_mode); + num_scales, + num_topk, + num_experts, + scale_token_stride, + scale_hidden_stride, + rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, + rank, + num_ranks, + cached_mode, + launch_stream, + num_channels, + low_latency_mode, decoupled_mode, return_recv_hook, phases, num_max_dispatch_tokens_per_rank); + }; + int phases = return_recv_hook ? NORMAL_DECOUPLED_SEND_PHASE : (NORMAL_DECOUPLED_SEND_PHASE | NORMAL_DECOUPLED_RECV_PHASE); + launcher(phases); // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream); + event = EventHandle(launch_stream); for (auto& t : {x, is_token_in_rank, recv_x, @@ -1251,7 +1280,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { - t.record_stream(comm_stream); + t.record_stream(launch_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } @@ -1277,14 +1306,21 @@ Buffer::internode_dispatch(const torch::Tensor& x, if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } - } else { - stream_wait(compute_stream, comm_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); + // Receiver callback + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) { + EP_HOST_ASSERT(decoupled_mode); + recv_hook = [=]() { launcher(NORMAL_DECOUPLED_RECV_PHASE); }; + } + // Return values return {recv_x, recv_x_scales, @@ -1300,14 +1336,16 @@ Buffer::internode_dispatch(const torch::Tensor& x, recv_src_meta, send_rdma_head, send_nvl_head, - event}; + num_max_dispatch_tokens_per_rank, + event, + recv_hook}; #else EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); return {}; #endif } -std::tuple, std::optional> Buffer::internode_combine( +std::tuple, std::optional, std::optional>> Buffer::internode_combine( const torch::Tensor& x, const std::optional& topk_weights, const std::optional& bias_0, @@ -1322,10 +1360,18 @@ std::tuple, std::optional& previous_event, bool async, - bool allocate_on_comm_stream) { + bool allocate_on_comm_stream, + bool decoupled_mode, + bool return_recv_hook, + int num_max_dispatch_tokens_per_rank) { #ifndef DISABLE_NVSHMEM - const int num_channels = config.num_sms / 2; - EP_HOST_ASSERT(config.num_sms % 2 == 0); + if (return_recv_hook) { + EP_HOST_ASSERT((not async) and decoupled_mode); + } + + const int num_channels = return_recv_hook ? 
config.num_sms : config.num_sms / 2; // one SM per channel for hook mode + EP_HOST_ASSERT(return_recv_hook or config.num_sms % 2 == 0); + EP_HOST_ASSERT((not decoupled_mode) or num_max_dispatch_tokens_per_rank > 0); // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); @@ -1358,16 +1404,19 @@ std::tuple, std::optional, std::optional>({bias_0, bias_1}); void* bias_ptrs[2] = {nullptr, nullptr}; @@ -1427,40 +1479,48 @@ std::tuple, std::optional(), - x.data_ptr(), - topk_weights_ptr, - bias_ptrs[0], - bias_ptrs[1], - combined_rdma_head.data_ptr(), - combined_nvl_head.data_ptr(), - src_meta.data_ptr(), - rdma_channel_prefix_matrix.data_ptr(), - rdma_rank_prefix_sum.data_ptr(), - gbl_channel_prefix_matrix.data_ptr(), - num_tokens, - num_combined_tokens, - hidden, - num_topk, - rdma_buffer_ptr, - config.num_max_rdma_chunked_send_tokens, - config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, - config.num_max_nvl_chunked_send_tokens, - config.num_max_nvl_chunked_recv_tokens, - rank, - num_ranks, - comm_stream, - num_channels, - low_latency_mode); + auto launcher = [=](int phases) { + internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), + combined_x.data_ptr(), + combined_topk_weights_ptr, + is_combined_token_in_rank.data_ptr(), + x.data_ptr(), + topk_weights_ptr, + bias_ptrs[0], + bias_ptrs[1], + combined_rdma_head.data_ptr(), + combined_nvl_head.data_ptr(), + src_meta.data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), + rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + num_tokens, + num_combined_tokens, + hidden, + num_topk, + rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, + rank, + num_ranks, + launch_stream, + num_channels, + low_latency_mode, + decoupled_mode, + return_recv_hook, + phases, + num_max_dispatch_tokens_per_rank); + }; + int phases = return_recv_hook ? 
NORMAL_DECOUPLED_SEND_PHASE : (NORMAL_DECOUPLED_SEND_PHASE | NORMAL_DECOUPLED_RECV_PHASE); + launcher(phases); // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream); + event = EventHandle(launch_stream); for (auto& t : {x, src_meta, is_combined_token_in_rank, @@ -1470,7 +1530,7 @@ std::tuple, std::optional, std::optionalrecord_stream(compute_stream) : void(); } - } else { - stream_wait(compute_stream, comm_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); + // Receiver callback + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) { + EP_HOST_ASSERT(decoupled_mode); + recv_hook = [=]() { launcher(NORMAL_DECOUPLED_RECV_PHASE); }; + } + // Return values - return {combined_x, combined_topk_weights, event}; + return {combined_x, combined_topk_weights, event, recv_hook}; #else EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); return {}; @@ -1856,6 +1923,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint) .def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint); m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint); + m.def("get_normal_hook_rdma_size_hint", &deep_ep::get_normal_hook_rdma_size_hint); pybind11::class_(m, "EventHandle") .def(pybind11::init<>()) diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 5fb90bff..1038f2f2 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -60,12 +60,12 @@ struct Buffer { bool low_latency_mode = false; // NVLink Buffer - int64_t num_nvl_bytes; + uint64_t num_nvl_bytes; void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void** buffer_ptrs_gpu = nullptr; // NVSHMEM Buffer - int64_t num_rdma_bytes; + uint64_t num_rdma_bytes; void* rdma_buffer_ptr = nullptr; // Shrink mode buffer @@ -115,8 +115,8 @@ struct Buffer { public: Buffer(int rank, int num_ranks, - int64_t num_nvl_bytes, - int64_t num_rdma_bytes, + uint64_t num_nvl_bytes, + uint64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy, bool enable_shrink, @@ -155,7 +155,7 @@ struct Buffer { int num_experts, std::optional& previous_event, bool async, - bool allocate_on_comm_stream); + bool allocate_on_comm_stream, bool return_recv_hook); std::tuple, @@ -213,7 +213,7 @@ struct Buffer { std::optional, std::optional, std::optional, - std::optional> + int, std::optional, std::optional>> internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, @@ -233,9 +233,12 @@ struct Buffer { const Config& config, std::optional& previous_event, bool async, - bool allocate_on_comm_stream); + bool allocate_on_comm_stream, + bool decoupled_mode, + bool return_recv_hook, + int num_max_dispatch_tokens_per_rank); - std::tuple, std::optional> internode_combine( + std::tuple, std::optional, std::optional>> internode_combine( const torch::Tensor& x, const std::optional& topk_weights, const std::optional& bias_0, @@ -250,7 +253,10 @@ struct Buffer { const Config& config, std::optional& previous_event, bool async, - bool allocate_on_comm_stream); + bool allocate_on_comm_stream, + bool decoupled_mode, + bool return_recv_hook, + int num_max_dispatch_tokens_per_rank); void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index 
9bbe096a..5ced5dfb 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -171,9 +171,11 @@ void notify_dispatch(const int* num_tokens_per_rank, int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, - bool low_latency_mode); + uint64_t num_rdma_bytes, + uint64_t num_nvl_bytes, + bool low_latency_mode, + bool decoupled_mode, + int num_max_dispatch_tokens_per_rank); void dispatch(void* recv_x, float* recv_x_scales, @@ -212,7 +214,11 @@ void dispatch(void* recv_x, bool is_cached_dispatch, cudaStream_t stream, int num_channels, - bool low_latency_mode); + bool low_latency_mode, + bool decoupled_mode, + bool return_recv_hook, + int phases, + int num_max_dispatch_tokens_per_rank); void cached_notify(int hidden_int4, int num_scales, @@ -232,10 +238,13 @@ void cached_notify(int hidden_int4, int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, + uint64_t num_rdma_bytes, + uint64_t num_nvl_bytes, bool is_cached_dispatch, - bool low_latency_mode); + bool low_latency_mode, + bool decoupled_mode, + bool return_recv_hook, + int num_max_dispatch_tokens_per_rank); void combine(cudaDataType_t type, void* combined_x, @@ -265,7 +274,11 @@ void combine(cudaDataType_t type, int num_ranks, cudaStream_t stream, int num_channels, - bool low_latency_mode); + bool low_latency_mode, + bool decoupled_mode, + bool return_recv_hook, + int phases, + int num_max_dispatch_tokens_per_rank); } // namespace internode diff --git a/csrc/kernels/buffer.cuh b/csrc/kernels/buffer.cuh index 222f42ac..788ef241 100644 --- a/csrc/kernels/buffer.cuh +++ b/csrc/kernels/buffer.cuh @@ -35,26 +35,26 @@ template struct AsymBuffer { private: uint8_t* ptrs[kNumRanks]; - int64_t num_bytes; + uint64_t num_bytes; public: - int64_t total_bytes; + uint64_t total_bytes; - __device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { + __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, uint64_t num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { EP_STATIC_ASSERT(kNumRanks == 1, ""); num_bytes = num_elems * sizeof(dtype_t); - int64_t per_channel_bytes = num_bytes * num_ranks; + uint64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; ptrs[0] = static_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; gbl_ptr = static_cast(gbl_ptr) + total_bytes; } - __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { + __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, uint64_t num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { EP_STATIC_ASSERT(kNumRanks > 1, ""); num_bytes = num_elems * sizeof(dtype_t); - int64_t per_channel_bytes = num_bytes * num_ranks; + uint64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; for (int i = 0; i < kNumRanks; ++i) { ptrs[i] = static_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; @@ -97,15 +97,15 @@ private: // NOTES: for non-decoupled case, `recv_ptr` is not used uint8_t* send_ptr; uint8_t* recv_ptr; - int64_t num_bytes; + uint64_t num_bytes; public: - int64_t total_bytes; + uint64_t total_bytes; - __device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) { + __device__ __forceinline__ SymBuffer(void*& gbl_ptr, uint64_t num_elems, int
num_ranks, int sm_id = 0, int num_sms = 1) { num_bytes = num_elems * sizeof(dtype_t); - int64_t per_channel_bytes = num_bytes * num_ranks; + uint64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); send_ptr = static_cast(gbl_ptr) + per_channel_bytes * sm_id; recv_ptr = static_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh index b26f298e..4ad55d3e 100644 --- a/csrc/kernels/configs.cuh +++ b/csrc/kernels/configs.cuh @@ -19,6 +19,8 @@ #define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_RECV_PHASE 2 +#define NORMAL_DECOUPLED_SEND_PHASE 1 +#define NORMAL_DECOUPLED_RECV_PHASE 2 // Make CLion CUDA indexing work #ifdef __CLION_IDE__ diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 48c6c001..35d13f15 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -1,6 +1,7 @@ #include #include +#include "configs.cuh" #include "buffer.cuh" #include "configs.cuh" #include "exception.cuh" @@ -45,21 +46,21 @@ __host__ __device__ __forceinline__ int get_num_bytes_per_token(int hidden_int4, sizeof(int4))); } -__host__ __device__ __forceinline__ std::pair get_rdma_clean_meta(int hidden_int4, - int num_scales, - int num_topk_idx, - int num_topk_weights, - int num_rdma_ranks, - int num_rdma_recv_buffer_tokens, - int num_channels) { - // Return `int32_t` offset and count to clean - return {(get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * - num_rdma_ranks * 2 * num_channels) / - sizeof(int), - (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels}; +__host__ __device__ __forceinline__ std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_channels) { + // in case of overflow + uint64_t rdma_meta_first = get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights); + rdma_meta_first *= num_rdma_recv_buffer_tokens; + rdma_meta_first *= num_rdma_ranks; + rdma_meta_first *= 2; + rdma_meta_first *= num_channels; + + return { + rdma_meta_first / sizeof(int), + (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels + }; } -__host__ __device__ __forceinline__ std::pair get_nvl_clean_meta(int hidden_int4, +__host__ __device__ __forceinline__ std::pair get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, @@ -71,10 +72,13 @@ __host__ __device__ __forceinline__ std::pair get_nvl_clean_meta(int h // Return `int32_t` offset and to clean EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); + uint64_t nvl_meta_first = get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights); + nvl_meta_first *= num_nvl_recv_buffer_tokens; + nvl_meta_first *= num_nvl_ranks; + nvl_meta_first *= num_channels; + return { - (num_nvl_recv_buffer_tokens * get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_nvl_ranks * - num_channels) / - sizeof(int), + nvl_meta_first / sizeof(int), num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels, }; } @@ -103,10 +107,10 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int num_worst_tokens, int num_channels, int expert_alignment, - const int rdma_clean_offset, - const int rdma_num_int_clean, - const int nvl_clean_offset, - const int nvl_num_int_clean, + const uint64_t 
rdma_clean_offset, + const uint64_t rdma_num_int_clean, + const uint64_t nvl_clean_offset, + const uint64_t nvl_num_int_clean, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, @@ -145,7 +149,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, // Send numbers of tokens per rank/expert to RDMA ranks auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); - auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks); + auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, static_cast(NUM_MAX_NVL_PEERS + num_rdma_experts + 1), kNumRDMARanks); // Clean up for later data dispatch EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int)); @@ -370,9 +374,11 @@ void notify_dispatch(const int* num_tokens_per_rank, int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, - bool low_latency_mode) { + uint64_t num_rdma_bytes, + uint64_t num_nvl_bytes, + bool low_latency_mode, + bool decoupled_mode, + int num_max_dispatch_tokens_per_rank) { #define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ { \ auto notify_dispatch_func = low_latency_mode ? notify_dispatch : notify_dispatch; \ @@ -411,8 +417,9 @@ void notify_dispatch(const int* num_tokens_per_rank, const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; // Get clean meta + int num_slots_per_rdma_chunk = decoupled_mode ? (num_max_dispatch_tokens_per_rank + num_channels - 1) / num_channels : num_max_rdma_chunked_recv_tokens; auto rdma_clean_meta = - get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); + get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_slots_per_rdma_chunk, num_channels); auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, @@ -424,8 +431,8 @@ void notify_dispatch(const int* num_tokens_per_rank, true); EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); - EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); - EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); // Launch kernel SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream); @@ -439,6 +446,7 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { } template (gridDim.x); const auto sm_id = static_cast(blockIdx.x); const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); - const auto num_channels = num_sms / 2, channel_id = sm_id / 2; + const auto num_channels = return_recv_hook ? num_sms : num_sms / 2, channel_id = return_recv_hook ? 
sm_id : sm_id / 2; const bool is_forwarder = sm_id % 2 == 0; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + EP_DEVICE_ASSERT((not return_recv_hook) or ((phases & NORMAL_DECOUPLED_SEND_PHASE) == 0) or ((phases & NORMAL_DECOUPLED_RECV_PHASE) == 0)); EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms); - const auto role_meta = [=]() -> std::pair { - if (is_forwarder) { - if (warp_id < NUM_MAX_NVL_PEERS) { - return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; - } else { - return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; - } - } else if (warp_id < kNumDispatchRDMASenderWarps) { - return {WarpRole::kRDMASender, -1}; - } else if (warp_id == kNumDispatchRDMASenderWarps) { - return {WarpRole::kRDMASenderCoordinator, -1}; - } else { - return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; - } - }(); + auto role_meta = get_warp_role_dispatch(is_forwarder, warp_id, channel_id, kNumDispatchRDMASenderWarps, return_recv_hook, phases); auto warp_role = role_meta.first; auto target_rank = role_meta.second; // Not applicable for RDMA senders EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS); @@ -519,7 +513,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV auto scale_bytes = num_scales * sizeof(float); auto num_bytes_per_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk); auto rdma_channel_data = SymBuffer( - rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels); + rdma_buffer_ptr, static_cast(num_max_rdma_chunked_recv_tokens) * static_cast(num_bytes_per_token), kNumRDMARanks, channel_id, num_channels); auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); @@ -529,10 +523,10 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV // Receivers" void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; int rs_wr_rank = 0, ws_rr_rank = 0; - if (warp_role == WarpRole::kRDMAAndNVLForwarder) + if (warp_role == WarpRole_Dispatch::kRDMAAndNVLForwarder) rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank; - if (warp_role == WarpRole::kNVLReceivers) + if (warp_role == WarpRole_Dispatch::kNVLReceivers) rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank; @@ -560,14 +554,17 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV __shared__ int rdma_send_channel_lock[kNumRDMARanks]; __shared__ int rdma_send_channel_tail[kNumRDMARanks]; __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks]; - auto sync_rdma_sender_smem = []() { asm volatile("barrier.sync 0, %0;" ::"r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; + // The real num of 'kRDMASender' warps per block is 'kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS' in decoupled mode + auto sync_rdma_sender_smem = kDecoupledMode ? 
[]() { asm volatile("barrier.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32)); } : + []() { asm volatile("barrier.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; // TMA stuffs extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp; auto tma_mbarrier = reinterpret_cast(tma_buffer + num_bytes_per_token); uint32_t tma_phase = 0; - if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and elect_one_sync()) { + // if ((warp_role == WarpRole_Dispatch::kRDMAAndNVLForwarder or warp_role == WarpRole_Dispatch::kNVLReceivers) and elect_one_sync()) { + if ((warp_role == WarpRole_Dispatch::kRDMAAndNVLForwarder or (warp_role == WarpRole_Dispatch::kNVLReceivers and (not kDecoupledMode))) and lane_id == 0) { mbarrier_init(tma_mbarrier, 1); fence_barrier_init(); EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerWarp); @@ -577,16 +574,17 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV // Forward warp synchronization __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; - auto sync_forwarder_smem = []() { asm volatile("barrier.sync 1, %0;" ::"r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; + auto sync_forwarder_smem = kDecoupledMode ? []() { asm volatile("barrier.sync 1, %0;" :: "r"(NUM_MAX_NVL_PEERS * 32)); } : + []() { asm volatile("barrier.sync 1, %0;" ::"r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; - if (warp_role == WarpRole::kRDMASender) { + if (warp_role == WarpRole_Dispatch::kRDMASender) { // Get tasks int token_start_idx, token_end_idx; get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); // Send number of tokens in this channel by `-value - 1` EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); - for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { + for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += (kDecoupledMode ? kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS : kNumDispatchRDMASenderWarps)) { auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); if (lane_id < NUM_MAX_NVL_PEERS) { @@ -635,25 +633,27 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV __syncwarp(); // Skip the token which does not belong to this warp - if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id) + if ((token_idx - token_start_idx) % (kDecoupledMode ? kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS : kNumDispatchRDMASenderWarps) != warp_id) continue; auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? 
-1 : global_rdma_tail_idx - 1; // Wait the remote buffer to be released auto start_time = clock64(); - while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) { - cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); + if (not kDecoupledMode) { + while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) { + cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); - // Timeout check - if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) { - printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n", - channel_id, - rdma_rank, - nvl_rank, - lane_id, - cached_rdma_channel_head, - rdma_tail_idx); - trap(); + // Timeout check + if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + cached_rdma_channel_head, + rdma_tail_idx); + trap(); + } } } __syncwarp(); @@ -750,9 +750,9 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV } __syncwarp(); } - } else if (warp_role == WarpRole::kRDMASenderCoordinator) { + } else if (warp_role == WarpRole_Dispatch::kRDMASenderCoordinator) { // NOTES: in case of splitting, the issued put at the end of the buffer - EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); + EP_DEVICE_ASSERT(kDecoupledMode or num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); // Clean shared memory EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); @@ -841,7 +841,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV __syncwarp(); } } - } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { + } else if (warp_role == WarpRole_Dispatch::kRDMAAndNVLForwarder) { // RDMA consumers and NVL producers const auto dst_nvl_rank = target_rank; @@ -998,8 +998,10 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV } // Sync head index - if (lane_id == src_rdma_rank) - forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); + if (lane_id == src_rdma_rank) { + if (kDecoupledMode) cached_rdma_channel_head = src_rdma_tail; + else forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); + } // Move tail index __syncwarp(); @@ -1009,9 +1011,11 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV // Retired __syncwarp(); - if (elect_one_sync()) + if ((not kDecoupledMode) and elect_one_sync()) forward_channel_retired[dst_nvl_rank] = true; - } else if (warp_role == WarpRole::kForwarderCoordinator) { + } else if (warp_role == WarpRole_Dispatch::kForwarderCoordinator) { + EP_DEVICE_ASSERT(not kDecoupledMode); + // Extra warps for forwarder coordinator should exit directly if (target_rank > 0) return; @@ -1129,8 +1133,9 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; - bool scale_aligned = (scale_bytes % 16 == 0); - auto tma_load_bytes = hidden_bytes + (scale_aligned ? 
scale_bytes : 0); + if (not kDecoupledMode) { + bool scale_aligned = (scale_bytes % 16 == 0); + auto tma_load_bytes = hidden_bytes + (scale_aligned ? scale_bytes : 0); // Copy data if (elect_one_sync()) { @@ -1160,28 +1165,63 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV } shifted += scale_bytes; - // Copy source meta - if (not kCachedMode and elect_one_sync()) - st_na_global(recv_src_meta + recv_token_idx, meta); - shifted += sizeof(SourceMeta); - - // Copy `topk_idx` and `topk_weights` - if (lane_id < num_topk) { - // Read - auto idx_value = static_cast(ld_nc_global(reinterpret_cast(shifted) + lane_id)); - auto weight_value = ld_nc_global(reinterpret_cast(shifted + sizeof(int) * num_topk) + lane_id); - auto recv_idx = recv_token_idx * num_topk + lane_id; - - // Transform and write - idx_value = (idx_value >= local_expert_begin and idx_value < local_expert_end) ? idx_value - local_expert_begin : -1; - weight_value = idx_value >= 0 ? weight_value : 0.0f; - st_na_global(recv_topk_idx + recv_idx, idx_value); - st_na_global(recv_topk_weights + recv_idx, weight_value); - } + // Copy source meta + if (not kCachedMode and elect_one_sync()) + st_na_global(recv_src_meta + recv_token_idx, meta); + shifted += sizeof(SourceMeta); + + // Copy `topk_idx` and `topk_weights` + if (lane_id < num_topk) { + // Read + auto idx_value = static_cast(ld_nc_global(reinterpret_cast(shifted) + lane_id)); + auto weight_value = ld_nc_global(reinterpret_cast(shifted + sizeof(int) * num_topk) + lane_id); + auto recv_idx = recv_token_idx * num_topk + lane_id; + + // Transform and write + idx_value = (idx_value >= local_expert_begin and idx_value < local_expert_end) ? idx_value - local_expert_begin : -1; + weight_value = idx_value >= 0 ? weight_value : 0.0f; + st_na_global(recv_topk_idx + recv_idx, idx_value); + st_na_global(recv_topk_weights + recv_idx, weight_value); + } - // Wait TMA to be finished - tma_store_wait<0>(); - __syncwarp(); + // Wait TMA to be finished + tma_store_wait<0>(); + __syncwarp(); + } + else { + // Copy data + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, + recv_x + recv_token_idx * hidden_int4, + reinterpret_cast(shifted), + ld_nc_global, st_na_global); + shifted += hidden_bytes; + + // Copy scales + UNROLLED_WARP_COPY(1, lane_id, num_scales, + recv_x_scales + recv_token_idx * num_scales, + reinterpret_cast(shifted), + ld_nc_global, st_na_global); + shifted += scale_bytes; + + // Copy source meta + if (lane_id == 0 and not kCachedMode) + st_na_global(recv_src_meta + recv_token_idx, meta); + shifted += sizeof(SourceMeta); + + // Copy `topk_idx` and `topk_weights` + if (lane_id < num_topk) { + // Read + auto idx_value = static_cast(ld_nc_global(reinterpret_cast(shifted) + lane_id)); + auto weight_value = ld_nc_global(reinterpret_cast(shifted + sizeof(int) * num_topk) + lane_id); + auto recv_idx = recv_token_idx * num_topk + lane_id; + + // Transform and write + idx_value = (idx_value >= local_expert_begin and idx_value < local_expert_end) ? idx_value - local_expert_begin : -1; + weight_value = idx_value >= 0 ? 
weight_value : 0.0f; + st_na_global(recv_topk_idx + recv_idx, idx_value); + st_na_global(recv_topk_weights + recv_idx, weight_value); + } + } } // Move queue @@ -1244,74 +1284,60 @@ void dispatch(void* recv_x, bool is_cached_dispatch, cudaStream_t stream, int num_channels, - bool low_latency_mode) { + bool low_latency_mode, + bool decoupled_mode, + bool return_recv_hook, + int phases, + int num_max_dispatch_tokens_per_rank) { constexpr int kNumDispatchRDMASenderWarps = 7; constexpr int kNumTMABytesPerWarp = 16384; constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; + auto num_sms = return_recv_hook ? num_channels : num_channels * 2; // one SM per channel for hook mode + // Make sure never OOB EP_HOST_ASSERT(static_cast(num_scales) * scale_hidden_stride < std::numeric_limits::max()); -#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ - { \ - auto dispatch_func = low_latency_mode \ - ? (is_cached_dispatch ? dispatch \ - : dispatch) \ - : (is_cached_dispatch ? dispatch \ - : dispatch); \ - SET_SHARED_MEMORY_FOR_TMA(dispatch_func); \ - LAUNCH_KERNEL(&cfg, \ - dispatch_func, \ - reinterpret_cast(recv_x), \ - recv_x_scales, \ - recv_topk_idx, \ - recv_topk_weights, \ - reinterpret_cast(recv_src_meta), \ - reinterpret_cast(x), \ - x_scales, \ - topk_idx, \ - topk_weights, \ - send_rdma_head, \ - send_nvl_head, \ - recv_rdma_channel_prefix_matrix, \ - recv_gbl_channel_prefix_matrix, \ - rdma_channel_prefix_matrix, \ - recv_rdma_rank_prefix_sum, \ - gbl_channel_prefix_matrix, \ - recv_gbl_rank_prefix_sum, \ - is_token_in_rank, \ - num_tokens, \ - num_worst_tokens, \ - hidden_int4, \ - num_scales, \ - num_topk, \ - num_experts, \ - scale_token_stride, \ - scale_hidden_stride, \ - rdma_buffer_ptr, \ - num_max_rdma_chunked_send_tokens, \ - num_max_rdma_chunked_recv_tokens, \ - buffer_ptrs, \ - num_max_nvl_chunked_send_tokens, \ - num_max_nvl_chunked_recv_tokens, \ - rank, \ - num_ranks); \ - } \ - break + int num_slots_per_rdma_chunk = decoupled_mode ? (num_max_dispatch_tokens_per_rank + num_channels - 1) / num_channels : num_max_rdma_chunked_recv_tokens; + +#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ + auto dispatch_func = low_latency_mode ? \ + (decoupled_mode ? ((is_cached_dispatch ? dispatch : \ + dispatch)) : \ + ((is_cached_dispatch ? dispatch : \ + dispatch))) : \ + (decoupled_mode ? ((is_cached_dispatch ? dispatch : \ + dispatch)) : \ + ((is_cached_dispatch ? 
dispatch : \ + dispatch))); \ + SET_SHARED_MEMORY_FOR_TMA(dispatch_func); \ + LAUNCH_KERNEL(&cfg, dispatch_func, \ + reinterpret_cast(recv_x), recv_x_scales, recv_topk_idx, recv_topk_weights, reinterpret_cast(recv_src_meta), \ + reinterpret_cast(x), x_scales, topk_idx, topk_weights, \ + send_rdma_head, send_nvl_head, \ + recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \ + rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ + is_token_in_rank, \ + num_tokens, hidden_int4, num_scales, num_topk, num_experts, \ + scale_token_stride, scale_hidden_stride, \ + rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_slots_per_rdma_chunk, \ + buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \ + rank, num_ranks, return_recv_hook, phases); } break EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr)); EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr)); - SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream); + SETUP_LAUNCH_CONFIG(num_sms, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream); SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE); #undef DISPATCH_LAUNCH_CASE } template -__global__ void cached_notify(const int rdma_clean_offset, - const int rdma_num_int_clean, - const int nvl_clean_offset, - const int nvl_num_int_clean, +__global__ void cached_notify(const uint64_t rdma_clean_offset, + const uint64_t rdma_num_int_clean, + const uint64_t nvl_clean_offset, + const uint64_t nvl_num_int_clean, int* combined_rdma_head, int num_combined_tokens, int num_channels, @@ -1324,7 +1350,9 @@ __global__ void cached_notify(const int rdma_clean_offset, int rank, int num_ranks, bool is_cached_dispatch, + bool return_recv_hook, const nvshmem_team_t rdma_team) { + auto num_sms = return_recv_hook ? 
num_channels : num_channels * 2; // one SM per channel for hook mode auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x); auto num_threads = static_cast(blockDim.x); @@ -1374,34 +1402,37 @@ __global__ void cached_notify(const int rdma_clean_offset, if (is_cached_dispatch) return; - EP_DEVICE_ASSERT(num_warps >= num_channels); + // EP_DEVICE_ASSERT(num_warps >= num_channels); EP_DEVICE_ASSERT(num_rdma_ranks <= 32); // Iterate in reverse order - if (lane_id < num_rdma_ranks and warp_id < num_channels) { - int token_start_idx, token_end_idx; - get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx); - - // NOTES: `1 << 25` is a heuristic large number - int last_head = 1 << 25; - for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { - auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); - if (current_head < 0) { - combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; - } else { - last_head = current_head; + if (lane_id < num_rdma_ranks) { + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) { + auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); + if (current_head < 0) { + combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; + } else { + last_head = current_head; + } } } + } } else { if (is_cached_dispatch) return; - EP_DEVICE_ASSERT(num_warps >= num_channels); + // EP_DEVICE_ASSERT(num_warps >= num_channels); EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers"); - if (warp_id < num_channels) { + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t); constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS; constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token; @@ -1409,6 +1440,7 @@ __global__ void cached_notify(const int rdma_clean_offset, // TMA stuffs extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; + // auto tma_buffer = smem_tma_buffer + channel_id * kNumTMABytesPerWarp; auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp; auto tma_mbarrier = reinterpret_cast(tma_buffer + tma_batch_size); uint32_t tma_phase = 0; @@ -1418,10 +1450,10 @@ __global__ void cached_notify(const int rdma_clean_offset, } __syncwarp(); - for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) { + for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_sms - 2) { // Iterate in reverse order - int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; - int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; + int token_start_idx = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; + int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; int shift = dst_rdma_rank == 0 ? 
0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; token_start_idx += shift, token_end_idx += shift; @@ -1485,19 +1517,25 @@ void cached_notify(int hidden_int4, int** barrier_signal_ptrs, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, - int64_t num_nvl_bytes, + uint64_t num_rdma_bytes, + uint64_t num_nvl_bytes, bool is_cached_dispatch, - bool low_latency_mode) { - const int num_threads = std::max(128, 32 * num_channels); + bool low_latency_mode, + bool decoupled_mode, + bool return_recv_hook, + int num_max_dispatch_tokens_per_rank) { + const int num_threads = std::min(std::max(128, 32 * num_channels), 512); const int num_warps = num_threads / 32; const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const int kNumTMABytesPerWarp = 8192; const int smem_size = kNumTMABytesPerWarp * num_warps; + auto num_sms = return_recv_hook ? num_channels : num_channels * 2; // one SM per channel for hook mode + // Get clean meta + int num_slots_per_rdma_chunk = decoupled_mode ? (num_max_dispatch_tokens_per_rank + num_channels - 1) / num_channels : num_max_rdma_chunked_recv_tokens; auto rdma_clean_meta = get_rdma_clean_meta( - hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); + hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_slots_per_rdma_chunk, num_channels); auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, @@ -1509,13 +1547,13 @@ void cached_notify(int hidden_int4, is_cached_dispatch); EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); - EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); - EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); - EP_HOST_ASSERT(num_channels * 2 > 3); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_sms > 3); // Launch kernel auto cached_notify_func = low_latency_mode ? cached_notify : cached_notify; - SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); + SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream); SET_SHARED_MEMORY_FOR_TMA(cached_notify_func); LAUNCH_KERNEL(&cfg, cached_notify_func, @@ -1535,9 +1573,11 @@ void cached_notify(int hidden_int4, rank, num_ranks, is_cached_dispatch, + return_recv_hook, cpu_rdma_team); } +template template (&bias_0_value_int4); + auto bias_1_values = reinterpret_cast(&bias_1_value_int4); + #pragma unroll + for (int j = 0; j < kDtypePerInt4; ++ j) + values[j] = static_cast(bias_0_values[j]) + static_cast(bias_1_values[j]); + } // Read buffers // TODO: maybe too many registers here int4 recv_value_int4[kMaxNumRanks]; @@ -1660,6 +1715,7 @@ __device__ int combine_token(bool is_token_in_rank, for (int j = 0; j < num_topk_ranks; ++j) recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i)); + // Clean // Reduce bias float values[kDtypePerInt4] = {0}; @@ -1703,48 +1759,33 @@ __device__ int combine_token(bool is_token_in_rank, return topk_ranks[0]; } -template 0) ? 
kNumCombineForwarderWarps / kNumRDMARanks : 1, - int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder, - int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS> -__global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* combined_x, - float* combined_topk_weights, - const bool* is_combined_token_in_rank, - const int4* x, - const float* topk_weights, - const int4* bias_0, - const int4* bias_1, - const int* combined_rdma_head, - const int* combined_nvl_head, - const SourceMeta* src_meta, - const int* rdma_channel_prefix_matrix, - const int* rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, - int num_tokens, - int num_combined_tokens, - int hidden, - int num_topk, - void* rdma_buffer_ptr, - int num_max_rdma_chunked_send_tokens, - int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, - int num_max_nvl_chunked_send_tokens, - int num_max_nvl_chunked_recv_tokens, - int rank, - int num_ranks) { - enum class WarpRole { kNVLSender, kNVLAndRDMAForwarder, kRDMAReceiver, kCoordinator }; - +template 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1, + int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, + int kNumRDMAReceivers = kDecoupledMode ? kNumForwarders + NUM_MAX_NVL_PEERS + 1 : kNumForwarders - NUM_MAX_NVL_PEERS> +__global__ void __launch_bounds__(kDecoupledMode ? (kNumForwarders + NUM_MAX_NVL_PEERS + 1) * 32 : (kNumForwarders + 1) * 32, 1) +combine(int4* combined_x, float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const int4* x, const float* topk_weights, + const int4* bias_0, const int4* bias_1, + const int* combined_rdma_head, const int* combined_nvl_head, + const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + int num_tokens, int num_combined_tokens, int hidden, int num_topk, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, bool return_recv_hook, int phases) { const auto sm_id = static_cast(blockIdx.x); const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); - const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; - const bool is_forwarder_sm = sm_id % 2 == 1; + const auto num_channels = return_recv_hook ? static_cast(gridDim.x) : static_cast(gridDim.x) / 2, channel_id = return_recv_hook ? sm_id : sm_id / 2; + const bool is_forwarder_sm = return_recv_hook ? 
true : sm_id % 2 == 1; EP_DEVICE_ASSERT(num_topk <= 32); EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); @@ -1754,34 +1795,19 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co // NOTES: we decouple a channel into 2 SMs const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - auto role_meta = [=]() -> std::pair { - auto warp_id = thread_id / 32; - if (not is_forwarder_sm) { - if (warp_id < NUM_MAX_NVL_PEERS) { - auto shuffled_warp_id = warp_id; - shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; - return {WarpRole::kNVLSender, shuffled_warp_id}; - } else if (warp_id < kNumForwarders) { - return {WarpRole::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS}; - } else { - return {WarpRole::kCoordinator, 0}; - } - } else { - if (warp_id < kNumForwarders) { - auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders; - return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; - } else { - return {WarpRole::kCoordinator, 0}; - } - } - }(); + auto warp_id = thread_id / 32; + auto role_meta = get_warp_role_combine(is_forwarder_sm, warp_id, channel_id, kNumForwarders, return_recv_hook, phases); auto warp_role = role_meta.first; - auto warp_id = role_meta.second; + warp_id = role_meta.second; - EP_DEVICE_ASSERT(num_warps == kNumForwarders + 1); + if (kDecoupledMode) { + EP_DEVICE_ASSERT(num_warps == kNumForwarders + NUM_MAX_NVL_PEERS + 1); + } else { + EP_DEVICE_ASSERT(num_warps == kNumForwarders + 1); + } auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; - if (warp_role == WarpRole::kNVLSender) { + if (warp_role == WarpRole_Combine::kNVLSender) { // NVL producers const auto dst_nvl_rank = warp_id; @@ -1924,9 +1950,9 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co // Combiners and coordinators // RDMA symmetric layout auto rdma_channel_data = SymBuffer( - rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels); - auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); - auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + rdma_buffer_ptr, static_cast(num_max_rdma_chunked_recv_tokens * num_bytes_per_token), kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, static_cast(1), kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, static_cast(1), kNumRDMARanks, channel_id, num_channels); // NVL layouts void* local_nvl_buffer = buffer_ptrs[nvl_rank]; @@ -1950,9 +1976,15 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; auto sync_forwarder_smem = [=]() { asm volatile("barrier.sync 0, %0;" ::"r"((kNumForwarders + 1) * 32)); }; - auto sync_rdma_receiver_smem = [=]() { asm volatile("barrier.sync 1, %0;" ::"r"((kNumRDMAReceivers + 1) * 32)); }; + auto sync_rdma_receiver_smem = [=]() { + if (kDecoupledMode) { + asm volatile("barrier.sync 1, %0;" :: "r"(kNumRDMAReceivers * 32)); + } else { + asm volatile("barrier.sync 1, %0;" ::"r"((kNumRDMAReceivers + 1) * 32)); + } + }; - if (warp_role == WarpRole::kNVLAndRDMAForwarder) { + if (warp_role == WarpRole_Combine::kNVLAndRDMAForwarder) { // Receive from NVL ranks and forward 
to RDMA ranks // NOTES: this part is using "large warps" for each RDMA ranks const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder; @@ -1980,8 +2012,8 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co return reinterpret_cast(smem_ptr + i * kNumTMABufferBytesPerStage + kNumTMALoadBytes * (NUM_MAX_NVL_PEERS + 1)); }; uint32_t tma_phase[kNumStages] = {0}; - if (lane_id < kNumStages) { - mbarrier_init(tma_mbarrier(lane_id), 32); + if (not kDecoupledMode and lane_id < kNumStages) { + mbarrier_init(tma_mbarrier(lane_id), 32); // in hook mode, 'kNVLAndRDMAForwarder' and 'kNVLSender' warps are on the same SM, in case the init here makes the data dirty fence_barrier_init(); } __syncwarp(); @@ -2011,26 +2043,28 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); auto num_chunked_tokens = token_end_idx - token_start_idx; auto start_time = clock64(); - while (sub_warp_id == 0 and lane_id == 0) { - // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` - // Here, `token_start_idx` is the actual tail - int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); - if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) - break; + if (not kDecoupledMode) { + while (sub_warp_id == 0 and lane_id == 0) { + // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` + // Here, `token_start_idx` is the actual tail + int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); + if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) + break; - // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf( + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf( "DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: " "%d, chunked: %d\n", - channel_id, + channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens); - trap(); + trap(); + } } } sync_large_warp(); @@ -2074,6 +2108,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co // Combine current token auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token; + auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx; }; auto get_addr_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4* { return reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * num_bytes_per_token) + hidden_int4_idx; @@ -2083,21 +2118,32 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co hidden_bytes + sizeof(SourceMeta)) + topk_idx); }; - combine_token( - expected_head >= 0, + if(not kDecoupledMode) { + combine_token( + expected_head >= 0, expected_head, - lane_id, + lane_id, hidden_int4, num_topk, - static_cast(shifted), - reinterpret_cast(static_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), - nullptr, + static_cast(shifted), + reinterpret_cast(static_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), + nullptr, nullptr, 
num_max_nvl_chunked_recv_tokens_per_rdma, - get_addr_fn, + get_addr_fn, recv_tw_fn, smem_ptr, tma_phase); + } + else { + combine_token(expected_head >= 0, + expected_head, lane_id, + hidden_int4, num_topk, + static_cast(shifted), + reinterpret_cast(static_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), + nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, get_addr_fn, recv_tw_fn, + nullptr, tma_phase); + } // Update head if (lane_id < NUM_MAX_NVL_PEERS) @@ -2142,12 +2188,14 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co __syncwarp(); if (elect_one_sync()) forwarder_retired[warp_id] = true; - } else if (warp_role == WarpRole::kRDMAReceiver) { + } else if (warp_role == WarpRole_Combine::kRDMAReceiver) { // Receive from RDMA ranks and write to the output tensor // Clean shared memory and sync EP_DEVICE_ASSERT(kNumRDMARanks <= 32); - lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; - lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0; + if (not kDecoupledMode) { + lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0; + } sync_rdma_receiver_smem(); // The same tokens as the dispatch process @@ -2162,8 +2210,10 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co int expected_head = -1; if (lane_id < kNumRDMARanks) { expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id); - (expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) - : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head); + if (not kDecoupledMode) { + (expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) + : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head); + } } // Wait lanes to be ready @@ -2218,7 +2268,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co // Retired __syncwarp(); - if (elect_one_sync()) + if ((not kDecoupledMode) and elect_one_sync()) rdma_receiver_retired[warp_id] = true; } else { // Coordinator @@ -2303,61 +2353,43 @@ void combine(cudaDataType_t type, int num_ranks, cudaStream_t stream, int num_channels, - bool low_latency_mode) { - constexpr int kNumCombineForwarderWarps = 24; + bool low_latency_mode, + bool decoupled_mode, + bool return_recv_hook, + int phases, + int num_max_dispatch_tokens_per_rank) { + const int kNumCombineForwarderWarps = decoupled_mode ? 16 : 24; constexpr int kNumTMABytesPerSenderWarp = 16384; constexpr int kNumTMABytesPerForwarderWarp = 9248; - constexpr int smem_size = + const int smem_size = std::max(kNumTMABytesPerSenderWarp * NUM_MAX_NVL_PEERS, kNumTMABytesPerForwarderWarp * kNumCombineForwarderWarps); -#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ - { \ - auto combine_func = low_latency_mode ? 
combine \ - : combine; \ - SET_SHARED_MEMORY_FOR_TMA(combine_func); \ - LAUNCH_KERNEL(&cfg, \ - combine_func, \ - reinterpret_cast(combined_x), \ - combined_topk_weights, \ - is_combined_token_in_rank, \ - reinterpret_cast(x), \ - topk_weights, \ - reinterpret_cast(bias_0), \ - reinterpret_cast(bias_1), \ - combined_rdma_head, \ - combined_nvl_head, \ - reinterpret_cast(src_meta), \ - rdma_channel_prefix_matrix, \ - rdma_rank_prefix_sum, \ - gbl_channel_prefix_matrix, \ - num_tokens, \ - num_combined_tokens, \ - hidden, \ - num_topk, \ - rdma_buffer_ptr, \ - num_max_rdma_chunked_send_tokens, \ - num_max_rdma_chunked_recv_tokens, \ - buffer_ptrs, \ - num_max_nvl_chunked_send_tokens, \ - num_max_nvl_chunked_recv_tokens, \ - rank, \ - num_ranks); \ - } \ - break + auto num_sms = return_recv_hook ? num_channels : num_channels * 2; // one SM per channel for hook mode + + int num_slots_per_rdma_chunk = decoupled_mode ? (num_max_dispatch_tokens_per_rank + num_channels - 1) / num_channels : num_max_rdma_chunked_recv_tokens; + +#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \ + auto combine_func = low_latency_mode ? \ + (decoupled_mode ? combine : \ + combine ) : \ + (decoupled_mode ? combine : \ + combine ); \ + SET_SHARED_MEMORY_FOR_TMA(combine_func); \ + LAUNCH_KERNEL(&cfg, combine_func, \ + reinterpret_cast(combined_x), combined_topk_weights, is_combined_token_in_rank, \ + reinterpret_cast(x), topk_weights, \ + reinterpret_cast(bias_0), reinterpret_cast(bias_1), \ + combined_rdma_head, combined_nvl_head, \ + reinterpret_cast(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \ + num_tokens, num_combined_tokens, hidden, num_topk, \ + rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_slots_per_rdma_chunk, \ + buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \ + rank, num_ranks, return_recv_hook, phases); } break int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; + auto num_warps_per_block = decoupled_mode ? 
num_forwarder_warps + NUM_MAX_NVL_PEERS + 1 : num_forwarder_warps + 1; // 25 warps per block for both cases EP_HOST_ASSERT(num_rdma_ranks <= kNumCombineForwarderWarps); EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS and num_forwarder_warps % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); @@ -2367,7 +2399,7 @@ void combine(cudaDataType_t type, EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens >= num_warps_per_forwarder); EP_HOST_ASSERT(type == CUDA_R_16BF); - SETUP_LAUNCH_CONFIG(num_channels * 2, (num_forwarder_warps + 1) * 32, stream); + SETUP_LAUNCH_CONFIG(num_sms, num_warps_per_block * 32, stream); SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); #undef COMBINE_LAUNCH_CASE } diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index 0c2eec02..6eae667c 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -637,4 +637,283 @@ __forceinline__ __device__ T warp_reduce_or(T value) { return warp_reduce(value, ReduceOr{}); } +enum class WarpRole_Dispatch { + kRDMASender, + kRDMASenderCoordinator, + kRDMAAndNVLForwarder, + kForwarderCoordinator, + kNVLReceivers, + kInvalidWarpRole +}; + +template +__forceinline__ __device__ std::pair get_warp_role_dispatch(bool is_forwarder, int warp_id, int channel_id, int kNumDispatchRDMASenderWarps, bool return_recv_hook, int phases) { + if (not kDecoupledMode) { // native mode + if (is_forwarder) { + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole_Dispatch::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; + } + } else if (warp_id < kNumDispatchRDMASenderWarps) { + return {WarpRole_Dispatch::kRDMASender, -1}; + } else if (warp_id == kNumDispatchRDMASenderWarps) { + return {WarpRole_Dispatch::kRDMASenderCoordinator, -1}; + } else { + return {WarpRole_Dispatch::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + } + + if (return_recv_hook) { // hook mode + EP_DEVICE_ASSERT(phases != 0); + if ((phases & NORMAL_DECOUPLED_RECV_PHASE) == 0) { // send phase + if (warp_id < kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMASender, -1}; + } else { + return {WarpRole_Dispatch::kRDMASenderCoordinator, -1}; + } + } + else if ((phases & NORMAL_DECOUPLED_SEND_PHASE) == 0) { // recv phase + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } + else { + return {WarpRole_Dispatch::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + } + else { + return {WarpRole_Dispatch::kInvalidWarpRole, -1}; + } + } + + // decoupled mode, but no hook + if (not is_forwarder) { // send warps + if (warp_id < kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMASender, -1}; + } else { + return {WarpRole_Dispatch::kRDMASenderCoordinator, -1}; + } + } + else { // recv warps + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } + else { + return {WarpRole_Dispatch::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + } +} + +enum class WarpRole_Combine { + kNVLSender, + kNVLAndRDMAForwarder, + kRDMAReceiver, + kCoordinator, + kInvalidWarpRole +}; + +template +__forceinline__ __device__ std::pair get_warp_role_combine(bool is_forwarder_sm, int warp_id, int 
channel_id, int kNumForwarders, bool return_recv_hook, int phases) { + if (not kDecoupledMode) { // native mode + if (not is_forwarder_sm) { + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole_Combine::kNVLSender, shuffled_warp_id}; + } else if (warp_id < kNumForwarders) { + return {WarpRole_Combine::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole_Combine::kCoordinator, 0}; + } + } else { + if (warp_id < kNumForwarders) { + auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders; + return {WarpRole_Combine::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole_Combine::kCoordinator, 0}; + } + } + } + + if (return_recv_hook) { // hook mode + EP_DEVICE_ASSERT(phases != 0); + if ((phases & NORMAL_DECOUPLED_RECV_PHASE) == 0) { // send phase + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole_Combine::kNVLSender, shuffled_warp_id}; + } else if (warp_id < kNumForwarders + NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; + shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; + return {WarpRole_Combine::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole_Combine::kCoordinator, 0}; + } + } + else if ((phases & NORMAL_DECOUPLED_SEND_PHASE) == 0) { // recv phase + return {WarpRole_Combine::kRDMAReceiver, warp_id}; + } + else { + return {WarpRole_Combine::kInvalidWarpRole, -1}; + } + } + + // decoupled mode, but no hook + if (is_forwarder_sm) { // send warps + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole_Combine::kNVLSender, shuffled_warp_id}; + } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; + shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; + return {WarpRole_Combine::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole_Combine::kCoordinator, 0}; + } + } else { // recv warps + return {WarpRole_Combine::kRDMAReceiver, warp_id}; + } +} + +enum class WarpRole_Dispatch { + kRDMASender, + kRDMASenderCoordinator, + kRDMAAndNVLForwarder, + kForwarderCoordinator, + kNVLReceivers, + kInvalidWarpRole +}; + +template +__forceinline__ __device__ std::pair get_warp_role(bool is_forwarder, int warp_id, int channel_id, int kNumDispatchRDMASenderWarps, bool return_recv_hook, int phases) { + if (not kDecoupledMode) { // native mode + if (is_forwarder) { + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole_Dispatch::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; + } + } else if (warp_id < kNumDispatchRDMASenderWarps) { + return {WarpRole_Dispatch::kRDMASender, -1}; + } else if (warp_id == kNumDispatchRDMASenderWarps) { + return {WarpRole_Dispatch::kRDMASenderCoordinator, -1}; + } else { + return {WarpRole_Dispatch::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + } + + else if (return_recv_hook) { // hook mode + EP_DEVICE_ASSERT(phases != 0); + if ((phases & NORMAL_DECOUPLED_RECV_PHASE) == 0) { // send phase + if (warp_id < kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS) { + return 
{WarpRole_Dispatch::kRDMASender, -1}; + } else { + return {WarpRole_Dispatch::kRDMASenderCoordinator, -1}; + } + } + else if ((phases & NORMAL_DECOUPLED_SEND_PHASE) == 0) { // recv phase + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } + else { + return {WarpRole_Dispatch::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + } + else { + return {WarpRole_Dispatch::kInvalidWarpRole, -1}; + } + } + + else { // decoupled mode, but no hook + if (not is_forwarder) { // send warps + if (warp_id < kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMASender, -1}; + } else { + return {WarpRole_Dispatch::kRDMASenderCoordinator, -1}; + } + } + + else { // recv warps + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole_Dispatch::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } + else { + return {WarpRole_Dispatch::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + } + } +} + +enum class WarpRole_Combine { + kNVLSender, + kNVLAndRDMAForwarder, + kRDMAReceiver, + kCoordinator, + kInvalidWarpRole +}; + +template +__forceinline__ __device__ std::pair get_warp_role_combine(bool is_forwarder_sm, int warp_id, int channel_id, int kNumForwarders, bool return_recv_hook, int phases) { + if (not kDecoupledMode) { // native mode + if (not is_forwarder_sm) { + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole_Combine::kNVLSender, shuffled_warp_id}; + } else if (warp_id < kNumForwarders) { + return {WarpRole_Combine::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole_Combine::kCoordinator, 0}; + } + } else { + if (warp_id < kNumForwarders) { + auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders; + return {WarpRole_Combine::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole_Combine::kCoordinator, 0}; + } + } + } + else if (return_recv_hook) { // hook mode + EP_DEVICE_ASSERT(phases != 0); + if ((phases & NORMAL_DECOUPLED_RECV_PHASE) == 0) { // send phase + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole_Combine::kNVLSender, shuffled_warp_id}; + } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; + shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; + return {WarpRole_Combine::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole_Combine::kCoordinator, 0}; + } + } + else if ((phases & NORMAL_DECOUPLED_SEND_PHASE) == 0) { // recv phase + return {WarpRole_Combine::kRDMAReceiver, warp_id}; + } + else { + return {WarpRole_Combine::kInvalidWarpRole, -1}; + } + } + else { // decoupled mode, but no hook + if (is_forwarder_sm) { // send warps + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole_Combine::kNVLSender, shuffled_warp_id}; + } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; + shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; + return {WarpRole_Combine::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return 
{WarpRole_Combine::kCoordinator, 0}; + } + } else { // recv warps + return {WarpRole_Combine::kRDMAReceiver, warp_id}; + } + } +} + } // namespace deep_ep diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 37512ee9..68be73a4 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -188,6 +188,22 @@ def get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden """ return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts) + @staticmethod + def get_normal_hook_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden: int, num_rdma_ranks: int, num_sms: int, return_recv_hook: bool) -> int: + """ + Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16. + + Arguments: + num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. + hidden: the hidden dimension of each token. + num_ranks: the number of EP group ranks. + num_experts: the number of all experts. + + Returns: + size: the RDMA buffer size recommended. + """ + return deep_ep_cpp.get_normal_hook_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_rdma_ranks, num_sms, return_recv_hook) + def get_comm_stream(self) -> torch.Stream: """ Get the communication stream. @@ -292,7 +308,7 @@ def get_combine_config(num_ranks: int) -> Config: # noinspection PyTypeChecker def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, - allocate_on_comm_stream: bool = False) -> \ + allocate_on_comm_stream: bool = False, return_recv_hook: bool = False) -> \ Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap]: """ Calculate the layout required for later communication. @@ -315,7 +331,7 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int, """ num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \ self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None), - async_finish, allocate_on_comm_stream) + async_finish, allocate_on_comm_stream, return_recv_hook) return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event) # noinspection PyTypeChecker @@ -327,9 +343,9 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], expert_alignment: int = 1, num_worst_tokens: int = 0, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, - allocate_on_comm_stream: bool = False) -> \ + allocate_on_comm_stream: bool = False, return_recv_hook: bool = False, num_max_dispatch_tokens_per_rank: int = -1) -> \ Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], List[int], Tuple, EventOverlap]: + Optional[torch.Tensor], List[int], Tuple, EventOverlap, Callable]: """ Dispatch tokens to different ranks, both intranode and internode settings are supported. Intranode kernels require all the ranks should be visible via NVLink. @@ -357,6 +373,9 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], previous_event: the event to wait before actually executing the kernel. async_finish: the current stream will not wait for the communication kernels to be finished if set. 
allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream. + decoupled_mode: whether to use large network buffer. + return_recv_hook: whether to return recv hook. if set, 'decoupled_mode' should also be set to True. + num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value, used in non-cached and decoupled mode. Returns: recv_x: received tokens, the same type and tuple as the input `x`, but the number of tokens equals to the @@ -369,6 +388,7 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], handle: the returned communication handle. event: the event after executing the kernel (valid only if `async_finish` is set). """ + decoupled_mode = return_recv_hook # This mode (decoupled_mode=True, return_recv_hook=False) is implemented to support large buffers without hooks, but offers no practical performance benefit and is not exposed to user for use. # Default config config = self.get_dispatch_config(self.group_size) if config is None else config @@ -376,7 +396,7 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], if self.runtime.get_num_rdma_ranks() > 1: return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, num_worst_tokens, config, - previous_event, async_finish, allocate_on_comm_stream) + previous_event, async_finish, allocate_on_comm_stream, decoupled_mode, return_recv_hook, num_max_dispatch_tokens_per_rank) # Launch the kernel with cached or non-cached mode x, x_scales = x if isinstance(x, tuple) else (x, None) @@ -387,7 +407,7 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch( x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix, expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) - return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event) + return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event), None else: assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \ @@ -399,7 +419,7 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], return ( recv_x, recv_x_scales ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap( - event) + event), None # noinspection PyTypeChecker def combine(self, x: torch.Tensor, handle: Tuple, @@ -407,8 +427,8 @@ def combine(self, x: torch.Tensor, handle: Tuple, bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, - allocate_on_comm_stream: bool = False) -> \ - Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]: + allocate_on_comm_stream: bool = False, return_recv_hook: bool = False) -> \ + Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap, 
Callable]: """ Combine (reduce) tokens (addition **without** weights) from different ranks, both intranode and internode settings are supported. @@ -425,18 +445,22 @@ def combine(self, x: torch.Tensor, handle: Tuple, previous_event: the event to wait before actually executing the kernel. async_finish: the current stream will not wait for the communication kernels to be finished if set. allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream. + decoupled_mode: whether to use large network buffer. + return_recv_hook: whether to return recv hook. if set, 'decoupled_mode' should also be set to True. Returns: recv_x: the reduced token from its dispatched ranks. recv_topk_weights: the reduced top-k weights from its dispatch ranks. event: the event after executing the kernel (valid only if `async_finish` is set). """ + decoupled_mode = return_recv_hook # This mode (decoupled_mode=True, return_recv_hook=False) is implemented to support large buffers without hooks, but offers no practical performance benefit and is not exposed to user for use. + # Default config config = self.get_combine_config(self.group_size) if config is None else config # Internode if self.runtime.get_num_rdma_ranks() > 1: - return self.internode_combine(x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream) + return self.internode_combine(x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream, decoupled_mode, return_recv_hook) # NOTES: the second `_` is for the sending side, so we should use the third one rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle @@ -447,7 +471,7 @@ def combine(self, x: torch.Tensor, handle: Tuple, channel_prefix_matrix, send_head, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) - return recv_x, recv_topk_weights, EventOverlap(event) + return recv_x, recv_topk_weights, EventOverlap(event), None # noinspection PyTypeChecker def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], @@ -457,9 +481,9 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1, num_worst_tokens: int = 0, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, - allocate_on_comm_stream: bool = False) -> \ + allocate_on_comm_stream: bool = False, decoupled_mode: bool = False, return_recv_hook: bool = False, num_max_dispatch_tokens_per_rank: int = -1) -> \ Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], List[int], Tuple, EventOverlap]: + Optional[torch.Tensor], List[int], Tuple, EventOverlap, Callable]: """ Internode dispatch implementation, for more details, please refer to the `dispatch` docs. Normally, you should not directly call this function. 
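Usage sketch for the new hook-mode ("decoupled") dispatch/combine API introduced by this patch — a minimal illustration, not part of the patch itself. It assumes an already-initialized multi-node `deep_ep.Buffer` (the intranode path returns `hook=None`), whose RDMA buffer was sized with the new `Buffer.get_normal_hook_rdma_size_hint(...)`, a `deep_ep.Config`, and layout tensors obtained from `buffer.get_dispatch_layout(...)` as in the tests below; `run_experts` is a hypothetical stand-in for the per-rank expert computation. The call pattern mirrors `tests/test_internode_hook.py`.

def moe_layer_with_recv_hook(buffer, config, x, topk_idx, topk_weights,
                             num_tokens_per_rank, num_tokens_per_rdma_rank,
                             is_token_in_rank, num_tokens_per_expert,
                             num_max_dispatch_tokens_per_rank, run_experts):
    # Send phase: with return_recv_hook=True, dispatch only issues the sends and
    # returns a hook; `decoupled_mode` is derived from `return_recv_hook` internally.
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event, hook = \
        buffer.dispatch(x=x, topk_idx=topk_idx, topk_weights=topk_weights,
                        num_tokens_per_rank=num_tokens_per_rank,
                        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                        is_token_in_rank=is_token_in_rank,
                        num_tokens_per_expert=num_tokens_per_expert,
                        config=config, return_recv_hook=True,
                        num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank)
    # Independent work can be scheduled here to overlap with the in-flight tokens.
    hook()  # receive phase: `recv_x` and friends are only valid after this returns
    expert_out = run_experts(recv_x, recv_topk_idx, num_recv_tokens_per_expert_list)
    # Combine follows the same two-phase pattern; `num_max_dispatch_tokens_per_rank`
    # is carried inside `handle` by this patch, so it is not passed again.
    combined_x, combined_topk_weights, event, hook = \
        buffer.combine(x=expert_out, handle=handle, topk_weights=recv_topk_weights,
                       config=config, return_recv_hook=True)
    hook()
    return combined_x, combined_topk_weights

Design note: as elsewhere in this patch, hook mode assigns one SM per channel (num_channels = num_sms rather than num_sms / 2), and the RDMA buffer is sized so each channel can hold roughly num_max_dispatch_tokens_per_rank / num_channels slots (num_slots_per_rdma_chunk), which is why the larger buffer from get_normal_hook_rdma_size_hint() is required in this mode.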
@@ -473,32 +497,33 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te is_token_in_rank, \ rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \ recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ - recv_src_meta, send_rdma_head, send_nvl_head = handle + recv_src_meta, send_rdma_head, send_nvl_head, num_max_dispatch_tokens_per_rank = handle num_recv_tokens = recv_src_meta.size(0) num_rdma_recv_tokens = send_nvl_head.size(0) - recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch( + recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, _, event, hook = self.runtime.internode_dispatch( x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, - expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) - return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event) + expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream, decoupled_mode, return_recv_hook, num_max_dispatch_tokens_per_rank) + return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event), hook else: + assert (not decoupled_mode) or num_max_dispatch_tokens_per_rank > 0 assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \ rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \ recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ - recv_src_meta, send_rdma_head, send_nvl_head, event = self.runtime.internode_dispatch( + recv_src_meta, send_rdma_head, send_nvl_head, num_max_dispatch_tokens_per_rank, event, hook = self.runtime.internode_dispatch( x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert, 0, 0, None, None, None, None, - expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) + expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream, decoupled_mode, return_recv_hook, num_max_dispatch_tokens_per_rank) handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head, - send_nvl_head) + send_nvl_head, num_max_dispatch_tokens_per_rank) return ( recv_x, recv_x_scales ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap( - event) + event), hook # noinspection PyTypeChecker def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list], @@ -506,8 +531,8 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list], bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, - allocate_on_comm_stream: bool = 
False) -> \ - Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]: + allocate_on_comm_stream: bool = False, decoupled_mode: bool = False, return_recv_hook: bool = False) -> \ + Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap, Callable]: """ Internode combine implementation, for more details, please refer to the `combine` docs. Normally, you should not directly call this function. @@ -518,17 +543,17 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list], is_combined_token_in_rank, \ _, _, \ rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \ - src_meta, send_rdma_head, send_nvl_head = handle + src_meta, send_rdma_head, send_nvl_head, num_max_dispatch_tokens_per_rank = handle bias_0, bias_1 = Buffer._unpack_bias(bias) # Launch the kernel - combined_x, combined_topk_weights, event = self.runtime.internode_combine(x, topk_weights, bias_0, bias_1, src_meta, + combined_x, combined_topk_weights, event, hook = self.runtime.internode_combine(x, topk_weights, bias_0, bias_1, src_meta, is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', - None), async_finish, allocate_on_comm_stream) - return combined_x, combined_topk_weights, EventOverlap(event) + None), async_finish, allocate_on_comm_stream, decoupled_mode, return_recv_hook, num_max_dispatch_tokens_per_rank) + return combined_x, combined_topk_weights, EventOverlap(event), hook def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None: """ diff --git a/tests/test_internode.py b/tests/test_internode.py index 6530669d..59749410 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -135,7 +135,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if is_rand else topk_weights}) if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) - recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch( + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event, _ = buffer.dispatch( **dispatch_args) event.current_stream_wait() if async_mode else () @@ -148,7 +148,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x # Checks - recv_gbl_rank_prefix_sum = handle[-4] + recv_gbl_rank_prefix_sum = handle[-5] assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), \ f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list @@ -191,7 +191,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode} if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) - recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) + recv_x, _, _, _, _, event, _ = buffer.dispatch(**dispatch_args) event.current_stream_wait() if async_mode else () recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x if not is_rand: @@ -205,7 +205,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): combine_args.update({'topk_weights': recv_topk_weights}) if previous_mode: combine_args.update({'previous_event': 
buffer.capture()}) - combined_x, combined_topk_weights, event = buffer.combine(**combine_args) + combined_x, combined_topk_weights, event, _ = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1) ref_x = x_pure_rand if is_rand else x @@ -280,7 +280,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): 'num_tokens_per_expert': num_tokens_per_expert, 'config': dispatch_config if dispatch_config is not None else config } - recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + recv_x, _, _, _, handle, _, _ = buffer.dispatch(**dispatch_args) # Tune combine performance best_time, best_results = 1e10, None diff --git a/tests/test_internode_hook.py b/tests/test_internode_hook.py new file mode 100644 index 00000000..8a1b3d08 --- /dev/null +++ b/tests/test_internode_hook.py @@ -0,0 +1,532 @@ +import os +import time +import torch +import torch.distributed as dist +from functools import partial + +# noinspection PyUnresolvedReferences +import deep_ep +from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back + +# Test compatibility with low latency functions +import test_low_latency + + +def test_main_decoupled(num_sms: int, num_tokens: int, num_max_dispatch_tokens_per_rank: int, hidden: int, num_topk_groups: int, num_topk: int, num_experts: int, + local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): + # Settings + assert num_experts % num_ranks == 0 and num_local_ranks == 8 + if local_rank == 0: + print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) + + # Random data + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank + x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + x_e4m3 = per_token_cast_to_fp8(x) + x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 + group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) + group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices + masked_scores = create_grouped_scores(scores, group_idx, num_nodes) + topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank + topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') + rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rank_idx, num_ranks) + rdma_rank_idx = rank_idx // num_local_ranks + rdma_rank_idx.masked_fill_(rank_idx == -1, -1) + inplace_unique(rdma_rank_idx, num_nodes) + + # RDMA dispatch counts + rdma_idx = topk_idx // (num_experts // num_nodes) + rdma_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rdma_idx, num_nodes) + num_rdma_token_sent = rdma_idx.ne(-1).sum().item() + + # Expert meta + num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda') + for i in range(num_experts): + num_tokens_per_expert[i] = (topk_idx == i).sum() + gbl_num_tokens_per_expert = num_tokens_per_expert.clone() + dist.all_reduce(gbl_num_tokens_per_expert, group=group) + + # Rank layout meta + 
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda') + num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda') + token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda') + for i in range(num_ranks): + num_tokens_per_rank[i] = (rank_idx == i).sum() + token_sel = (rank_idx == i).max(dim=-1)[0] + count = token_sel.sum().item() + tokens = torch.sort(token_sel.to(torch.int), descending=True)[1] + tokens[:count] = torch.sort(tokens[:count])[0] + token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda') + for i in range(num_nodes): + num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum() + token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int) + is_token_in_rank = token_idx_in_rank >= 0 + gbl_num_tokens_per_rank = num_tokens_per_rank.clone() + dist.all_reduce(gbl_num_tokens_per_rank, group=group) + + ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \ + buffer.get_dispatch_layout(topk_idx, num_experts) + assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) + assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank) + assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) + assert torch.allclose(ref_is_token_in_rank, is_token_in_rank) + t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] + if local_rank == 0: + print(f'[layout] get_dispatch_layout() Kernel performance: {t * 1000:.3f} ms', flush=True) + print('', flush=True) + group.barrier() + time.sleep(1) + + # Config + rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512) + nvl_buffer_size = 64 * num_nodes ## for my testing environment 2/4 nodes, I use this config + config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) + + # noinspection PyShadowingNames + def check_data(check_x, recv_gbl_rank_prefix_sum): + assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1)) + check_start = 0 + for i in range(num_ranks): + check_end = recv_gbl_rank_prefix_sum[i].item() + assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 + check_start = check_end + + # Test dispatch: all modes + for previous_mode in (False, True): + for async_mode, return_recv_hook in [(True, False), (False, False), (False, True)]: ## all modes + for current_x in (x_pure_rand, x, x_e4m3): + for with_topk in (False, True): + if local_rank == 0: + print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, return_recv_hook={return_recv_hook}, previous={previous_mode}) ...', flush=True, end='') + dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode, 'return_recv_hook': return_recv_hook} + if with_topk: + dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + if return_recv_hook: + dispatch_args.update({'num_max_dispatch_tokens_per_rank': num_max_dispatch_tokens_per_rank}) + + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event, hook = 
buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + if return_recv_hook: + hook() + + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + + # Checks + recv_gbl_rank_prefix_sum = handle[-5] + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' + assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list + if current_x is not x_pure_rand: + check_data(recv_x, recv_gbl_rank_prefix_sum) + if with_topk: + # Check `topk_idx` + assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() + for i, count in enumerate(recv_num_tokens_per_expert_list): + assert recv_topk_idx.eq(i).sum().item() == count + + # Check `topk_weights` + if current_x is not x_pure_rand: + recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] + check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) + + # Test cached dispatch (must without top-k staffs) + if not with_topk: + dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode, 'return_recv_hook': return_recv_hook} + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + recv_x, _, _, _, _, event, hook = buffer.dispatch(**dispatch_args) + + event.current_stream_wait() if async_mode else () + if return_recv_hook: + hook() + # torch.cuda.synchronize() + + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + if current_x is not x_pure_rand: + check_data(recv_x, recv_gbl_rank_prefix_sum) + + # Test combine + bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode, 'return_recv_hook': return_recv_hook} + if with_topk: + combine_args.update({'topk_weights': recv_topk_weights}) + if previous_mode: + combine_args.update({'previous_event': buffer.capture()}) + + combined_x, combined_topk_weights, event, hook = buffer.combine(**combine_args) + + event.current_stream_wait() if async_mode else () + if return_recv_hook: + hook() + + check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1) + ref_x = x_pure_rand if current_x is x_pure_rand else x + assert calc_diff(check_x, ref_x) < 5e-6 + if with_topk: + check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) + ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + + # For later tuning + dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 + dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 + combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes + combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes + + torch.cuda.synchronize() + + if local_rank == 0: + print(' passed', flush=True) + if local_rank == 0: + print('General Tests for All Normal Modes Complete', flush=True) + print('', flush=True) + + + # Border Test for dispatch: hook mode + def calc_pass(): + pass + + def calc_sleep_1ms(): + time.sleep(0.001) + + def 
calc_sleep_10ms(): + time.sleep(0.01) + + def calc_sleep_100ms(): + time.sleep(0.1) + + def calc_sync(): + torch.cuda.synchronize() + + torch.cuda.synchronize() + + calc_func_list = [calc_pass, calc_sleep_1ms, calc_sleep_10ms, calc_sleep_100ms, calc_sync] + calc_func_name_list = ["calc_pass", "calc_sleep_1ms", "calc_sleep_10ms", "calc_sleep_100ms", "calc_sync"] + + for i in range(len(calc_func_list)): + calc_func = calc_func_list[i] + calc_func_name = calc_func_name_list[i] + for previous_mode in (False, True): + for async_mode, return_recv_hook in [(False, True),]: ## hook mode + for current_x in (x_pure_rand, x, x_e4m3): + for with_topk in (False, True): + if local_rank == 0: + print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (calc_func={calc_func_name}, async={async_mode}, return_recv_hook={return_recv_hook}, previous={previous_mode}) ...', flush=True, end='') + dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode, 'return_recv_hook': return_recv_hook} + if with_topk: + dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + if return_recv_hook: + dispatch_args.update({'num_max_dispatch_tokens_per_rank': num_max_dispatch_tokens_per_rank}) + + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event, hook = buffer.dispatch(**dispatch_args) + calc_func() + hook() + + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + + # Checks + recv_gbl_rank_prefix_sum = handle[-5] + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' + assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list + if current_x is not x_pure_rand: + check_data(recv_x, recv_gbl_rank_prefix_sum) + if with_topk: + # Check `topk_idx` + assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() + for i, count in enumerate(recv_num_tokens_per_expert_list): + assert recv_topk_idx.eq(i).sum().item() == count + + # Check `topk_weights` + if current_x is not x_pure_rand: + recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] + check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) + + # Test cached dispatch (must without top-k staffs) + if not with_topk: + dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode, 'return_recv_hook': return_recv_hook} + if previous_mode: + dispatch_args.update({'previous_event': buffer.capture()}) + recv_x, _, _, _, _, event, hook = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + + calc_func() + hook() + + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + if current_x is not x_pure_rand: + check_data(recv_x, recv_gbl_rank_prefix_sum) + + # Test combine + combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode, 'return_recv_hook': return_recv_hook} + if with_topk: + 
combine_args.update({'topk_weights': recv_topk_weights}) + if previous_mode: + combine_args.update({'previous_event': buffer.capture()}) + combined_x, combined_topk_weights, event, hook = buffer.combine(**combine_args) + event.current_stream_wait() if async_mode else () + + calc_func() + hook() + + check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) + ref_x = x_pure_rand if current_x is x_pure_rand else x + assert calc_diff(check_x, ref_x) < 5e-6 + if with_topk: + check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) + ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + + torch.cuda.synchronize() + if local_rank == 0: + print(' passed', flush=True) + if local_rank == 0: + print('Border Tests for Hook Mode Complete', flush=True) + print('', flush=True) + + + ### Tune decoupled mode dispatch performance + # noinspection PyShadowingNames + def large_gemm_with_hook(hook): + mat_0 = torch.randn((4096, 4096), dtype=torch.float) + mat_1 = torch.randn((4096, 4096), dtype=torch.float) + mat_0 @ mat_1 + hook() + + # noinspection PyShadowingNames + def test_dispatch_hook(x, config, handle, return_recv_hook): + _, _, _, _, _, _, hook = \ + buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook) + large_gemm_with_hook(hook) if return_recv_hook else None + torch.cuda.synchronize() + + def test_combine_hook(x, config, handle, return_recv_hook): + _, _, _, hook = \ + buffer.combine(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook) + large_gemm_with_hook(hook) if return_recv_hook else None + torch.cuda.synchronize() + + def test_dispatch_combine_hook(x, config, handle, return_recv_hook): + recv_x, _, _, _, _, _, hook = \ + buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook) + large_gemm_with_hook(hook) if return_recv_hook else None + + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x + + _, _, _, hook = \ + buffer.combine(x=recv_x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook) + large_gemm_with_hook(hook) if return_recv_hook else None + torch.cuda.synchronize() + + + ## Hook mode Dispatch + torch.cuda.synchronize() + fp8_factor = (1 + 4 / 128) / 2 + for current_x in (x_e4m3, x): + best_time, best_results = 1e10, None + rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes + nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes + for nvl_chunk_size in range(2, 45, 2): + for rdma_chunk_size in range(4, 33, 4): + for return_recv_hook in (True,): + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) + dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': False, 'return_recv_hook': True, + 'num_max_dispatch_tokens_per_rank': num_max_dispatch_tokens_per_rank} + _, _, _, _, handle_hook_tuning, _, hook = buffer.dispatch(**dispatch_args) + hook() + torch.cuda.synchronize() + + # trace_path = 
+    ### Tune decoupled mode dispatch performance
+    # noinspection PyShadowingNames
+    def large_gemm_with_hook(hook):
+        mat_0 = torch.randn((4096, 4096), dtype=torch.float)
+        mat_1 = torch.randn((4096, 4096), dtype=torch.float)
+        mat_0 @ mat_1
+        hook()
+
+    # noinspection PyShadowingNames
+    def test_dispatch_hook(x, config, handle, return_recv_hook):
+        _, _, _, _, _, _, hook = \
+            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
+        large_gemm_with_hook(hook) if return_recv_hook else None
+        torch.cuda.synchronize()
+
+    def test_combine_hook(x, config, handle, return_recv_hook):
+        _, _, _, hook = \
+            buffer.combine(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
+        large_gemm_with_hook(hook) if return_recv_hook else None
+        torch.cuda.synchronize()
+
+    def test_dispatch_combine_hook(x, config, handle, return_recv_hook):
+        recv_x, _, _, _, _, _, hook = \
+            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
+        large_gemm_with_hook(hook) if return_recv_hook else None
+
+        recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
+
+        _, _, _, hook = \
+            buffer.combine(x=recv_x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
+        large_gemm_with_hook(hook) if return_recv_hook else None
+        torch.cuda.synchronize()
+
+    ## Hook mode Dispatch
+    torch.cuda.synchronize()
+    fp8_factor = (1 + 4 / 128) / 2
+    for current_x in (x_e4m3, x):
+        best_time, best_results = 1e10, None
+        rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
+        nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
+        for nvl_chunk_size in range(2, 45, 2):
+            for rdma_chunk_size in range(4, 33, 4):
+                for return_recv_hook in (True, ):
+                    config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
+                    dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
+                                     'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': False, 'return_recv_hook': True,
+                                     'num_max_dispatch_tokens_per_rank': num_max_dispatch_tokens_per_rank}
+                    _, _, _, _, handle_hook_tuning, _, hook = buffer.dispatch(**dispatch_args)
+                    hook()
+                    torch.cuda.synchronize()
+
+                    # trace_path = f'/home/nas/zhiyihu/traces/trace_rank_{rank}_nvl_chunk_size_{nvl_chunk_size}_rdma_chunk_size_{rdma_chunk_size}.json'
+                    # dispatch_t, gemm_t = bench_kineto(partial(test_dispatch_hook, x=current_x, config=config, handle=handle_hook_tuning, return_recv_hook=return_recv_hook),
+                    #                                   kernel_names=('dispatch', 'gemm'), trace_path=trace_path, barrier_comm_profiling=True, suppress_kineto_output=False)
+                    dispatch_t, gemm_t = bench_kineto(partial(test_dispatch_hook, x=current_x, config=config, handle=handle_hook_tuning, return_recv_hook=return_recv_hook),
+                                                      kernel_names=('dispatch', 'gemm'))
+
+                    if dispatch_t < best_time:
+                        best_time, best_results = dispatch_t, (num_sms, nvl_chunk_size, rdma_chunk_size)
+                    if local_rank == 0:
+                        print(f'[Tuning Decoupled Mode Dispatch, With Recv Hook] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: Dispatch send kernel time plus recv kernel time: {dispatch_t * 2 * 1e6:.2f} us, GEMM kernel time: {gemm_t * 1e6:.2f} us ', flush=True)
+
+        if local_rank == 0:
+            print(f'[Tuning Decoupled Mode Dispatch, With Recv Hook] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}, rdma_send_bytes {rdma_send_bytes}, nvl_recv_bytes {nvl_recv_bytes}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {best_time * 2 * 1e6:.2f} us, {rdma_send_bytes / 1e9 / (best_time * 2):.2f} GB/s (RDMA) ', flush=True)
+            print('', flush=True)
+
+    ## Hook mode Combine
+    torch.cuda.synchronize()
+    fp8_factor = (1 + 4 / 128) / 2
+    for current_x in (x_e4m3, x):
+        best_time, best_results = 1e10, None
+        rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
+        nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
+        for nvl_chunk_size in range(1, 13, 1):
+            for rdma_chunk_size in range(8, 33, 4):
+                for return_recv_hook in (True, ):
+                    config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
+                    dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
+                                     'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': False, 'return_recv_hook': True,
+                                     'num_max_dispatch_tokens_per_rank': num_max_dispatch_tokens_per_rank}
+                    recv_x, _, _, _, handle_hook_tuning, _, hook = buffer.dispatch(**dispatch_args)
+                    hook()
+                    torch.cuda.synchronize()
+
+                    recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
+
+                    combine_t, gemm_t = bench_kineto(partial(test_combine_hook, x=recv_x, config=config, handle=handle_hook_tuning, return_recv_hook=return_recv_hook),
+                                                     kernel_names=('combine', 'gemm'))
+
+                    if combine_t < best_time:
+                        best_time, best_results = combine_t, (num_sms, nvl_chunk_size, rdma_chunk_size)
+                    if local_rank == 0:
+                        print(f'[Tuning Decoupled Mode Combine, With Recv Hook] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: Combine send kernel time plus recv kernel time: {combine_t * 2 * 1e6:.2f} us, GEMM kernel time: {gemm_t * 1e6:.2f} us ', flush=True)
+
+        if local_rank == 0:
+            print(f'[Tuning Decoupled Mode Combine, With Recv Hook] Best combine (nvl_send_bytes {combine_bf16_nvl_send_bytes}, rdma_recv_bytes {combine_bf16_rdma_recv_bytes}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {best_time * 2 * 1e6:.2f} us, {combine_bf16_rdma_recv_bytes / 1e9 / (best_time * 2):.2f} GB/s (RDMA) ', flush=True)
+            print('', flush=True)
+
+    ## Hook mode Dispatch + Combine
+    torch.cuda.synchronize()
+    fp8_factor = (1 + 4 / 128) / 2
+    for current_x in (x_e4m3, x):
+        best_time, best_results = 1e10, None
+        rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
+        nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
+        for nvl_chunk_size in range(1, 13, 1):
+            for rdma_chunk_size in range(8, 33, 4):
+                for return_recv_hook in (True, ):
+                    config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
+                    dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
+                                     'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': False, 'return_recv_hook': True,
+                                     'num_max_dispatch_tokens_per_rank': num_max_dispatch_tokens_per_rank}
+                    _, _, _, _, handle_hook_tuning, _, hook = buffer.dispatch(**dispatch_args)
+                    hook()
+                    torch.cuda.synchronize()
+
+                    dispatch_t, combine_t, gemm_t = bench_kineto(partial(test_dispatch_combine_hook, x=current_x, config=config, handle=handle_hook_tuning, return_recv_hook=return_recv_hook),
+                                                                 kernel_names=('dispatch', 'combine', 'gemm'))
+
+                    if dispatch_t + combine_t < best_time:
+                        best_time, best_results = dispatch_t + combine_t, (num_sms, nvl_chunk_size, rdma_chunk_size)
+                    if local_rank == 0:
+                        print(f'[Tuning Decoupled Mode Dispatch + Combine, With Recv Hook] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: Dispatch + Combine send kernel time plus recv kernel time: {(dispatch_t + combine_t) * 2 * 1e6:.2f} us, Dispatch send kernel time plus recv kernel time: {dispatch_t * 2 * 1e6:.2f} us, Combine send kernel time plus recv kernel time: {combine_t * 2 * 1e6:.2f} us, GEMM kernel time: {gemm_t * 1e6:.2f} us ', flush=True)
+
+        if local_rank == 0:
+            print(f'[Tuning Decoupled Mode Dispatch + Combine, With Recv Hook] Best dispatch {"FP8" if isinstance(current_x, tuple) else "BF16"} + combine BF16 (rdma_send_bytes {rdma_send_bytes}, nvl_recv_bytes {nvl_recv_bytes}, nvl_send_bytes {combine_bf16_nvl_send_bytes}, rdma_recv_bytes {combine_bf16_rdma_recv_bytes}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {best_time * 2 * 1e6:.2f} us ', flush=True)
+            print('', flush=True)
+
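+    # For comparison, the loops below tune the original fused path: `return_recv_hook` is
+    # False, so no hook is returned and a plain `bench` timing of the dispatch/combine
+    # kernels is sufficient.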
+    ### Tune native (non-decoupled) mode dispatch performance
+    # noinspection PyShadowingNames
+    def test_func_native(x, config, handle):
+        _, _, _, _, _, event, _ = \
+            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=False)
+
+    torch.cuda.synchronize()
+    num_sms = 24
+    best_dispatch_results = None
+    fp8_factor = (1 + 4 / 128) / 2
+    for current_x in (x_e4m3, x):
+        best_time, best_results = 1e10, None
+        rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
+        nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
+        for nvl_chunk_size in range(4, 45, 4):
+            for rdma_chunk_size in range(4, 33, 4):
+                config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
+                dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
+                                 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode,
+                                 'num_max_dispatch_tokens_per_rank': num_max_dispatch_tokens_per_rank}
+                _, _, _, _, handle_native, _, _ = buffer.dispatch(**dispatch_args)
+                avg_t = bench(partial(test_func_native, x=current_x, config=config, handle=handle_native))[0]
+                if avg_t < best_time:
+                    best_time, best_results = avg_t, (num_sms, nvl_chunk_size, rdma_chunk_size)
+                if local_rank == 0:
+                    print(f'[Tuning Native Mode Dispatch] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: Dispatch kernel time: {avg_t * 1e6:.2f} us, Dispatch bandwidth: {rdma_send_bytes / 1e9 / avg_t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / avg_t:.2f} GB/s (NVL) ', flush=True)
+        if local_rank == 0:
+            print(f'[Tuning Native Mode Dispatch] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}, rdma_send_bytes {rdma_send_bytes}, nvl_recv_bytes {nvl_recv_bytes}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {best_time * 1e6:.2f} us, {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
+            print('', flush=True)
+
+        if isinstance(current_x, tuple):
+            # Gather the best FP8 config from rank 0
+            best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda')
+            all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
+            dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
+            best_dispatch_results = all_best_fp8_results_list[0].tolist()
+
+    dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size)
+    dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,
+                     'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
+                     'config': dispatch_config if dispatch_config is not None else config}
+    recv_x, _, _, _, handle_native, _, _ = buffer.dispatch(**dispatch_args)
+
+    # Tune combine performance
+    best_time, best_results = 1e10, None
+    for nvl_chunk_size in range(1, 13, 1):
+        for rdma_chunk_size in range(12, 33, 4):
+            config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
+            tune_args = {'x': recv_x, 'handle': handle_native, 'config': config}
+            avg_t = bench(lambda: buffer.combine(**tune_args))[0]
+            if local_rank == 0:
+                print(f'[Tuning Native Mode Combine] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: Combine kernel time: {avg_t * 1e6:.2f} us, Combine bandwidth: {combine_bf16_rdma_recv_bytes / 1e9 / avg_t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / avg_t:.2f} GB/s (NVL) ', flush=True)
+            if avg_t < best_time:
+                best_time, best_results = avg_t, (num_sms, nvl_chunk_size, rdma_chunk_size)
+
+    if local_rank == 0:
+        print(f'[Tuning Native Mode Combine] Best combine BF16: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {best_time * 1e6:.2f} us, {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
+        print('', flush=True)
+
+
+# noinspection PyUnboundLocalVariable
+def test_loop_decoupled(local_rank: int, num_local_ranks: int):
+    num_nodes = int(os.getenv('WORLD_SIZE', 1))
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+
+    num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks
+
+    # num_max_dispatch_tokens_per_rank = num_tokens + 100
+    num_max_dispatch_tokens_per_rank = num_tokens
+
+    test_ll_compatibility = True
+    if test_ll_compatibility:
+        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
+
+    num_sms = 64
+    num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if test_ll_compatibility else 0)
+
+    return_recv_hook = True
+    num_rdma_bytes = deep_ep.Buffer.get_normal_hook_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_nodes, num_sms, return_recv_hook)
+    if local_rank == 0:
+        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
+    buffer = deep_ep.Buffer(group, num_nvl_bytes=int(8e9), num_rdma_bytes=num_rdma_bytes, low_latency_mode=test_ll_compatibility,
+                            num_qps_per_rank=num_qps_per_rank)
+    assert num_local_ranks == 8 and num_ranks > 8
+    torch.manual_seed(rank)
+
+    for i in (num_sms, ):
+        test_main_decoupled(i, num_tokens, num_max_dispatch_tokens_per_rank, hidden, num_topk_groups, num_topk, num_experts,
+                            local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
+        if local_rank == 0:
+            print('', flush=True)
+
+    # Test compatibility with low latency functions
+    if test_ll_compatibility:
+        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
+        test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
+
+    # Destroy the communication group
+    dist.barrier()
+    dist.destroy_process_group()
+
+if __name__ == '__main__':
+    num_processes = 8
+    torch.multiprocessing.spawn(test_loop_decoupled, args=(num_processes, ), nprocs=num_processes)
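
For readers wiring the decoupled path into their own code, the sketch below restates the overlap pattern that `large_gemm_with_hook` exercises in the test above. It is a minimal illustration rather than code from this patch: the helper name `overlapped_dispatch` is invented, `buffer`, `config`, `handle`, and the input `x` are assumed to be prepared exactly as in `test_main_decoupled`, and the GEMM stands in for any compute that does not depend on the dispatched tokens.

    # Minimal sketch: overlap independent compute with the deferred receive side.
    import torch

    def overlapped_dispatch(buffer, x, config, handle):
        # Send side is launched here; the receive side is deferred until `hook()` is called.
        recv_x, _, _, _, _, _, hook = buffer.dispatch(x=x, config=config, handle=handle,
                                                      async_finish=False, return_recv_hook=True)
        # Anything independent of the dispatched tokens can run in between.
        a = torch.randn((4096, 4096), device='cuda', dtype=torch.bfloat16)
        b = torch.randn((4096, 4096), device='cuda', dtype=torch.bfloat16)
        _ = a @ b
        hook()  # launch the receive kernels; use recv_x only after this call
        return recv_x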