@@ -100,6 +100,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
100100 int num_experts,
101101 const bool * is_token_in_rank,
102102 int num_tokens,
103+ int num_worst_tokens,
103104 int num_channels,
104105 int expert_alignment,
105106 const int rdma_clean_offset,
@@ -236,9 +237,11 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
236237 sum += rdma_recv_num_tokens_mixed.recv_buffer (i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
237238 recv_rdma_rank_prefix_sum[i] = sum;
238239 }
239- while (ld_volatile_global (moe_recv_rdma_counter_mapped) != -1 )
240- ;
241- *moe_recv_rdma_counter_mapped = sum;
240+ if (num_worst_tokens == 0 ) {
241+ while (ld_volatile_global (moe_recv_rdma_counter_mapped) != -1 )
242+ ;
243+ *moe_recv_rdma_counter_mapped = sum;
244+ }
242245 }
243246
244247 // Send numbers of tokens per rank/expert to NVL ranks
@@ -263,19 +266,23 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
263266 sum += nvl_recv_num_tokens_per_rank.buffer (src_nvl_rank)[src_rdma_rank];
264267 recv_gbl_rank_prefix_sum[i] = sum;
265268 }
266- while (ld_volatile_global (moe_recv_counter_mapped) != -1 )
267- ;
268- *moe_recv_counter_mapped = sum;
269+ if (num_worst_tokens == 0 ) {
270+ while (ld_volatile_global (moe_recv_counter_mapped) != -1 )
271+ ;
272+ *moe_recv_counter_mapped = sum;
273+ }
269274 }
270275 if (thread_id < num_nvl_experts) {
271276 int sum = 0 ;
272277 #pragma unroll
273278 for (int i = 0 ; i < NUM_MAX_NVL_PEERS; ++i)
274279 sum += nvl_recv_num_tokens_per_expert.buffer (i)[thread_id];
275280 sum = (sum + expert_alignment - 1 ) / expert_alignment * expert_alignment;
276- while (ld_volatile_global (moe_recv_expert_counter_mapped + thread_id) != -1 )
277- ;
278- moe_recv_expert_counter_mapped[thread_id] = sum;
281+ if (num_worst_tokens == 0 ) {
282+ while (ld_volatile_global (moe_recv_expert_counter_mapped + thread_id) != -1 )
283+ ;
284+ moe_recv_expert_counter_mapped[thread_id] = sum;
285+ }
279286 }
280287
281288 // Finally barrier
@@ -346,6 +353,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
346353 int num_experts,
347354 const bool * is_token_in_rank,
348355 int num_tokens,
356+ int num_worst_tokens,
349357 int num_channels,
350358 int hidden_int4,
351359 int num_scales,
@@ -380,6 +388,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
380388 num_experts, \
381389 is_token_in_rank, \
382390 num_tokens, \
391+ num_worst_tokens, \
383392 num_channels, \
384393 expert_alignment, \
385394 rdma_clean_meta.first , \
@@ -455,6 +464,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
455464 const int * recv_gbl_rank_prefix_sum,
456465 const bool * is_token_in_rank,
457466 int num_tokens,
467+ int num_worst_tokens,
458468 int hidden_int4,
459469 int num_scales,
460470 int num_topk,
@@ -1179,6 +1189,21 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
11791189 st_relaxed_sys_global (nvl_channel_head.buffer (), cached_channel_head_idx);
11801190 }
11811191 }
1192+
1193+ // Clean unused `recv_topk_idx` as -1
1194+ if (num_worst_tokens > 0 ) {
1195+ if (is_forwarder)
1196+ return ;
1197+ // get the actual number of num_recv_tokens on the current rank
1198+ int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1 ];
1199+ // some ForwarderCoordinator threads exit early, so we only use non-forwarder thread ids
1200+ const auto clean_start = num_recv_tokens * num_topk + (sm_id / 2 ) * num_threads;
1201+ const auto clean_end = num_worst_tokens * num_topk;
1202+ const auto clean_stride = num_sms * num_threads / 2 ;
1203+ #pragma unroll
1204+ for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
1205+ recv_topk_idx[i] = -1 ;
1206+ }
11821207}
11831208
11841209void dispatch (void * recv_x,
@@ -1200,6 +1225,7 @@ void dispatch(void* recv_x,
12001225 const int * recv_gbl_rank_prefix_sum,
12011226 const bool * is_token_in_rank,
12021227 int num_tokens,
1228+ int num_worst_tokens,
12031229 int hidden_int4,
12041230 int num_scales,
12051231 int num_topk,
@@ -1254,6 +1280,7 @@ void dispatch(void* recv_x,
12541280 recv_gbl_rank_prefix_sum, \
12551281 is_token_in_rank, \
12561282 num_tokens, \
1283+ num_worst_tokens, \
12571284 hidden_int4, \
12581285 num_scales, \
12591286 num_topk, \
@@ -1698,6 +1725,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
16981725 const int * rdma_channel_prefix_matrix,
16991726 const int * rdma_rank_prefix_sum,
17001727 const int * gbl_channel_prefix_matrix,
1728+ const int * gbl_rank_prefix_sum,
17011729 int num_tokens,
17021730 int num_combined_tokens,
17031731 int hidden,
@@ -1789,7 +1817,9 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* co
17891817 if (lane_id < kNumRDMARanks ) {
17901818 int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
17911819 token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
1792- token_end_idx = (prefix_idx == num_channels * num_ranks - 1 ) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1 ];
1820+ // if it is the last channel, set token_end_idx to actual recevied token count
1821+ token_end_idx = (prefix_idx == num_channels * num_ranks - 1 ) ? gbl_rank_prefix_sum[num_ranks - 1 ]
1822+ : gbl_channel_prefix_matrix[prefix_idx + 1 ];
17931823 }
17941824 __syncwarp ();
17951825
@@ -2261,6 +2291,7 @@ void combine(cudaDataType_t type,
22612291 const int * rdma_channel_prefix_matrix,
22622292 const int * rdma_rank_prefix_sum,
22632293 const int * gbl_channel_prefix_matrix,
2294+ const int * gbl_rank_prefix_sum,
22642295 int num_tokens,
22652296 int num_combined_tokens,
22662297 int hidden,
@@ -2312,6 +2343,7 @@ void combine(cudaDataType_t type,
23122343 rdma_channel_prefix_matrix, \
23132344 rdma_rank_prefix_sum, \
23142345 gbl_channel_prefix_matrix, \
2346+ gbl_rank_prefix_sum, \
23152347 num_tokens, \
23162348 num_combined_tokens, \
23172349 hidden, \
0 commit comments