@@ -279,13 +279,33 @@ __global__ __launch_bounds__(
279279 rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
280280 slot_idx * num_bytes_per_msg;
281281 if (dst_rank != rank) {
282- nvshmemi_ibgda_put_nbi_warp (dst_ptr,
283- src_ptr,
284- num_bytes_per_msg,
285- dst_rank,
286- dst_expert_local_idx,
287- lane_id,
288- slot_idx);
282+ void * peer_base_addr = reinterpret_cast <void *>(
283+ __ldg (reinterpret_cast <const uint64_t *>(
284+ nvshmemi_device_state_d.peer_heap_base_p2p ) +
285+ dst_rank));
286+ if (peer_base_addr) {
287+ char * req_rptr_actual =
288+ reinterpret_cast <char *>(peer_base_addr) +
289+ (reinterpret_cast <char *>(dst_ptr) -
290+ reinterpret_cast <char *>(nvshmemi_device_state_d.heap_base ));
291+ const auto * src_int4_ptr = reinterpret_cast <const int4 *>(src_ptr);
292+ const auto * dst_int4_ptr = reinterpret_cast <int4 *>(req_rptr_actual);
293+ UNROLLED_WARP_COPY (8 ,
294+ lane_id,
295+ num_int4_per_msg,
296+ dst_int4_ptr,
297+ src_int4_ptr,
298+ ld_nc_global,
299+ st_na_global);
300+ } else {
301+ nvshmemi_ibgda_put_nbi_warp (dst_ptr,
302+ src_ptr,
303+ num_bytes_per_msg,
304+ dst_rank,
305+ dst_expert_local_idx,
306+ lane_id,
307+ slot_idx);
308+ }
289309 } else {
290310 // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
291311 const auto * src_int4_ptr = reinterpret_cast <const int4 *>(src_ptr);
@@ -367,11 +387,24 @@ __global__ __launch_bounds__(
367387 responsible_expert_idx) != FINISHED_SUM_TAG * 2 ) {
368388 }
369389 if (dst_rank != rank) {
370- nvshmemi_ibgda_amo_nonfetch_add (
371- rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
372- -num_tokens_sent - 1 ,
373- dst_rank,
374- dst_expert_local_idx);
390+ void * peer_base_addr = reinterpret_cast <void *>(
391+ __ldg (reinterpret_cast <const uint64_t *>(
392+ nvshmemi_device_state_d.peer_heap_base_p2p ) +
393+ dst_rank));
394+ if (peer_base_addr) { // P2P enabled
395+ int * rptr_actual = reinterpret_cast <int *>(
396+ reinterpret_cast <char *>(peer_base_addr) +
397+ (reinterpret_cast <char *>(rdma_recv_count +
398+ dst_expert_local_idx * num_ranks + rank) -
399+ reinterpret_cast <char *>(nvshmemi_device_state_d.heap_base )));
400+ st_na_release (rptr_actual, -num_tokens_sent - 1 );
401+ } else {
402+ nvshmemi_ibgda_amo_nonfetch_add (
403+ rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
404+ -num_tokens_sent - 1 ,
405+ dst_rank,
406+ dst_expert_local_idx);
407+ }
375408 } else {
376409 st_na_release (rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
377410 -num_tokens_sent - 1 );
@@ -691,13 +724,32 @@ __global__ __launch_bounds__(
691724 x_int4,
692725 ld_nc_global,
693726 st_na_global);
694- nvshmemi_ibgda_put_nbi_warp (dst_ptr,
695- buf_ptr,
696- hidden * sizeof (nv_bfloat16),
697- dst_rank,
698- local_expert_idx,
699- lane_id,
700- token_idx - offset);
727+ void * peer_base_addr = reinterpret_cast <void *>(
728+ __ldg (reinterpret_cast <const uint64_t *>(
729+ nvshmemi_device_state_d.peer_heap_base_p2p ) +
730+ dst_rank));
731+ if (peer_base_addr) {
732+ char * req_rptr_actual =
733+ reinterpret_cast <char *>(peer_base_addr) +
734+ (reinterpret_cast <char *>(dst_ptr) -
735+ reinterpret_cast <char *>(nvshmemi_device_state_d.heap_base ));
736+ const auto dst_int4_ptr = reinterpret_cast <int4 *>(req_rptr_actual);
737+ UNROLLED_WARP_COPY (7 ,
738+ lane_id,
739+ hidden_bf16_int4,
740+ dst_int4_ptr,
741+ x_int4,
742+ ld_nc_global,
743+ st_na_global);
744+ } else {
745+ nvshmemi_ibgda_put_nbi_warp (dst_ptr,
746+ buf_ptr,
747+ hidden * sizeof (nv_bfloat16),
748+ dst_rank,
749+ local_expert_idx,
750+ lane_id,
751+ token_idx - offset);
752+ }
701753 }
702754 }
703755
@@ -710,8 +762,22 @@ __global__ __launch_bounds__(
710762 while (ld_acquire_global (atomic_clean_flag) == 0 ) {
711763 }
712764 if (dst_rank != rank) {
713- nvshmemi_ibgda_amo_nonfetch_add (
714- rdma_recv_flag + global_expert_idx, 1 , dst_rank, local_expert_idx);
765+ void * peer_base_addr = reinterpret_cast <void *>(
766+ __ldg (reinterpret_cast <const uint64_t *>(
767+ nvshmemi_device_state_d.peer_heap_base_p2p ) +
768+ dst_rank));
769+ if (peer_base_addr) {
770+ int * req_rptr_actual = reinterpret_cast <int *>(
771+ reinterpret_cast <char *>(peer_base_addr) +
772+ (reinterpret_cast <char *>(rdma_recv_flag + global_expert_idx) -
773+ reinterpret_cast <char *>(nvshmemi_device_state_d.heap_base )));
774+ st_na_release (req_rptr_actual, 1 );
775+ } else {
776+ nvshmemi_ibgda_amo_nonfetch_add (rdma_recv_flag + global_expert_idx,
777+ 1 ,
778+ dst_rank,
779+ local_expert_idx);
780+ }
715781 } else {
716782 st_na_release (rdma_recv_flag + global_expert_idx, 1 );
717783 }
0 commit comments