
Commit 610b076

Enable CUDA Graph for internode dispatch
1 parent: 2be0d4f

File tree: 6 files changed (+107 −38 lines)
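Context for the change: CUDA Graph capture records a fixed sequence of device work, so it cannot tolerate the host polling device-written counters between kernel launches — those counters are only filled in during a real run, and capture would hang on the busy-wait. This commit adds a `num_worst_tokens` argument: when positive, the host skips the polling loop and sizes all receive buffers for the worst case, turning internode dispatch into a fixed launch sequence that can be captured and replayed. A minimal sketch of the capture pattern this enables (placeholder kernels and sizes, standard CUDA Graph API; not DeepEP code):

#include <cuda_runtime.h>

// Placeholder kernels standing in for notify_dispatch/dispatch.
__global__ void fake_notify(int* counts) { if (threadIdx.x == 0) counts[0] = 0; }
__global__ void fake_dispatch(float* recv_x, int num_worst_tokens) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_worst_tokens) recv_x[i] = 0.0f;
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    int* counts;
    float* recv_x;
    cudaMalloc(&counts, sizeof(int));
    cudaMalloc(&recv_x, 4096 * sizeof(float));

    // With worst-case sizing there is no host-side polling between launches,
    // so the whole sequence is capturable.
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    fake_notify<<<1, 32, 0, stream>>>(counts);
    fake_dispatch<<<16, 256, 0, stream>>>(recv_x, /*num_worst_tokens=*/4096);
    cudaStreamEndCapture(stream, &graph);

    cudaGraphExec_t exec;
    cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
    cudaGraphLaunch(exec, stream);  // replay with no CPU work per step
    cudaStreamSynchronize(stream);
    return 0;
}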

csrc/deep_ep.cpp

Lines changed: 34 additions & 22 deletions
@@ -825,6 +825,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
     const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
     int expert_alignment,
+    int num_worst_tokens,
     const Config& config,
     std::optional<EventHandle>& previous_event,
     bool async,
@@ -997,6 +998,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     num_experts,
     is_token_in_rank.data_ptr<bool>(),
     num_tokens,
+    num_worst_tokens,
     num_channels,
     hidden_int4,
     num_scales,
@@ -1018,30 +1020,35 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     low_latency_mode);

     // Synchronize total received tokens and tokens per expert
-    auto start_time = std::chrono::high_resolution_clock::now();
-    while (true) {
-        // Read total count
-        num_recv_tokens = static_cast<int>(*moe_recv_counter);
-        num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
-
-        // Read per-expert count
-        bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
-        for (int i = 0; i < num_local_experts and ready; ++i)
-            ready &= moe_recv_expert_counter[i] >= 0;
-
-        if (ready)
-            break;
-
-        // Timeout check
-        if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >
-            NUM_CPU_TIMEOUT_SECS) {
-            printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
-            for (int i = 0; i < num_local_experts; ++i)
-                printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
-            throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
+    if (num_worst_tokens > 0) {
+        num_recv_tokens = num_worst_tokens;
+        num_rdma_recv_tokens = num_worst_tokens;
+    } else {
+        auto start_time = std::chrono::high_resolution_clock::now();
+        while (true) {
+            // Read total count
+            num_recv_tokens = static_cast<int>(*moe_recv_counter);
+            num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
+
+            // Read per-expert count
+            bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
+            for (int i = 0; i < num_local_experts and ready; ++i)
+                ready &= moe_recv_expert_counter[i] >= 0;
+
+            if (ready)
+                break;
+
+            // Timeout check
+            if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >
+                NUM_CPU_TIMEOUT_SECS) {
+                printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
+                for (int i = 0; i < num_local_experts; ++i)
+                    printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
+                throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
+            }
         }
+        num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
     }
-    num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
 }

 // Allocate new tensors
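The host-side fork above is the heart of the change: with `num_worst_tokens > 0` the CPU never reads `moe_recv_counter` or `moe_recv_rdma_counter`, so no host/device synchronization sneaks in; the cost is that receive tensors are allocated at the worst-case size, and `num_recv_tokens_per_expert_list` is not populated on this path. A sketch of how a caller might pick the bound (names and values are hypothetical, not from this diff):

// Hypothetical caller-side bound: in the worst case, every rank routes
// all of its tokens to this rank.
int max_tokens_per_rank = 4096;  // assumed per-rank token cap
int num_ranks = 16;              // assumed world size
int num_worst_tokens = max_tokens_per_rank * num_ranks;
// num_worst_tokens > 0 selects the graph-capturable path (no CPU polling,
// worst-case allocation); 0 keeps the original dynamically-sized path.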
@@ -1098,6 +1105,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     recv_gbl_rank_prefix_sum.data_ptr<int>(),
     is_token_in_rank.data_ptr<bool>(),
     num_tokens,
+    num_worst_tokens,
     hidden_int4,
     num_scales,
     num_topk,
@@ -1194,6 +1202,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     const torch::Tensor& rdma_channel_prefix_matrix,
     const torch::Tensor& rdma_rank_prefix_sum,
     const torch::Tensor& gbl_channel_prefix_matrix,
+    const torch::Tensor& gbl_rank_prefix_sum,
     const torch::Tensor& combined_rdma_head,
     const torch::Tensor& combined_nvl_head,
     const Config& config,

@@ -1228,6 +1237,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels);
     EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);
     EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels);
+    EP_HOST_ASSERT(gbl_rank_prefix_sum.size(0) == num_ranks);
     EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and
                    combined_rdma_head.size(1) == num_rdma_ranks);
     EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS);

@@ -1318,6 +1328,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     rdma_channel_prefix_matrix.data_ptr<int>(),
     rdma_rank_prefix_sum.data_ptr<int>(),
     gbl_channel_prefix_matrix.data_ptr<int>(),
+    gbl_rank_prefix_sum.data_ptr<int>(),
     num_tokens,
     num_combined_tokens,
     hidden,

@@ -1344,6 +1355,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     rdma_channel_prefix_matrix,
     rdma_rank_prefix_sum,
     gbl_channel_prefix_matrix,
+    gbl_rank_prefix_sum,
     combined_x,
     combined_rdma_head,
     combined_nvl_head}) {
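The combine path gains `gbl_rank_prefix_sum` because, under worst-case sizing, `num_tokens` no longer equals the number of tokens actually received; the prefix sum carries the real count. The invariant, sketched with illustrative values (not from the diff):

// gbl_rank_prefix_sum[r] is the inclusive prefix sum of tokens received
// from ranks 0..r, so its last entry is the actual receive count,
// regardless of any worst-case padding.
int num_ranks = 3;
int gbl_rank_prefix_sum[] = {3, 3, 8};  // 3, 0, and 5 tokens from ranks 0..2
int actual_recv_tokens = gbl_rank_prefix_sum[num_ranks - 1];  // 8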

csrc/deep_ep.hpp

Lines changed: 2 additions & 0 deletions
@@ -198,6 +198,7 @@ struct Buffer {
     const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
     const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
     int expert_alignment,
+    int num_worst_tokens,
     const Config& config,
     std::optional<EventHandle>& previous_event,
     bool async,

@@ -213,6 +214,7 @@ struct Buffer {
     const torch::Tensor& rdma_channel_prefix_matrix,
     const torch::Tensor& rdma_rank_prefix_sum,
     const torch::Tensor& gbl_channel_prefix_matrix,
+    const torch::Tensor& gbl_rank_prefix_sum,
     const torch::Tensor& combined_rdma_head,
     const torch::Tensor& combined_nvl_head,
     const Config& config,

csrc/kernels/api.cuh

Lines changed: 3 additions & 0 deletions
@@ -154,6 +154,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
     int num_experts,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int num_channels,
     int hidden_int4,
     int num_scales,

@@ -193,6 +194,7 @@ void dispatch(void* recv_x,
     const int* recv_gbl_rank_prefix_sum,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int hidden_int4,
     int num_scales,
     int num_topk,

@@ -249,6 +251,7 @@ void combine(cudaDataType_t type,
     const int* rdma_channel_prefix_matrix,
     const int* rdma_rank_prefix_sum,
     const int* gbl_channel_prefix_matrix,
+    const int* gbl_rank_prefix_sum,
     int num_tokens,
     int num_combined_tokens,
     int hidden,

csrc/kernels/internode.cu

Lines changed: 42 additions & 10 deletions
@@ -100,6 +100,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
     int num_experts,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int num_channels,
     int expert_alignment,
     const int rdma_clean_offset,
@@ -236,9 +237,11 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
         sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
         recv_rdma_rank_prefix_sum[i] = sum;
     }
-    while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
-        ;
-    *moe_recv_rdma_counter_mapped = sum;
+    if (num_worst_tokens == 0) {
+        while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
+            ;
+        *moe_recv_rdma_counter_mapped = sum;
+    }
 }

 // Send numbers of tokens per rank/expert to NVL ranks

@@ -263,19 +266,23 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
         sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
         recv_gbl_rank_prefix_sum[i] = sum;
     }
-    while (ld_volatile_global(moe_recv_counter_mapped) != -1)
-        ;
-    *moe_recv_counter_mapped = sum;
+    if (num_worst_tokens == 0) {
+        while (ld_volatile_global(moe_recv_counter_mapped) != -1)
+            ;
+        *moe_recv_counter_mapped = sum;
+    }
 }
 if (thread_id < num_nvl_experts) {
     int sum = 0;
 #pragma unroll
     for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
         sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
     sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
-    while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
-        ;
-    moe_recv_expert_counter_mapped[thread_id] = sum;
+    if (num_worst_tokens == 0) {
+        while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
+            ;
+        moe_recv_expert_counter_mapped[thread_id] = sum;
+    }
 }

 // Finally barrier
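Each of the three guards above disables one leg of the CPU/GPU handshake. On the normal path the host is expected to reset a mapped counter to -1 before launch, the kernel waits for that -1 and then publishes its total, and the host polls until the value turns non-negative. Under graph replay the host does neither, so an unguarded kernel could spin on a stale counter indefinitely. Simplified shape of the handshake, condensed from the code above (not a new API):

// Device side: wait for the host to have reset the mapped counter to -1,
// then publish the total. Skipped entirely under worst-case sizing, since
// the host will neither reset nor poll the counter.
if (num_worst_tokens == 0) {
    while (ld_volatile_global(moe_recv_counter_mapped) != -1)
        ;                            // host has not consumed/reset yet
    *moe_recv_counter_mapped = sum;  // host busy-waits for this value
}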
@@ -346,6 +353,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
     int num_experts,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int num_channels,
     int hidden_int4,
     int num_scales,

@@ -380,6 +388,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
     num_experts, \
     is_token_in_rank, \
     num_tokens, \
+    num_worst_tokens, \
     num_channels, \
     expert_alignment, \
     rdma_clean_meta.first, \

@@ -455,6 +464,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS
     const int* recv_gbl_rank_prefix_sum,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int hidden_int4,
     int num_scales,
     int num_topk,
@@ -1179,6 +1189,21 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS
             st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
         }
     }
+
+    // Clean unused `recv_topk_idx` entries to -1
+    if (num_worst_tokens > 0) {
+        if (is_forwarder)
+            return;
+        // Get the actual number of received tokens on the current rank
+        int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1];
+        // Some `ForwarderCoordinator` threads exit early, so only non-forwarder thread IDs are used
+        const auto clean_start = num_recv_tokens * num_topk + (sm_id / 2) * num_threads;
+        const auto clean_end = num_worst_tokens * num_topk;
+        const auto clean_stride = num_sms * num_threads / 2;
+        #pragma unroll
+        for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
+            recv_topk_idx[i] = -1;
+    }
 }

 void dispatch(void* recv_x,
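The cleanup block above matters because `recv_topk_idx` is now allocated for `num_worst_tokens` tokens while only `recv_gbl_rank_prefix_sum[num_ranks - 1]` of them are real, and downstream consumers need the padded tail marked invalid. Stripped of the forwarder-warp bookkeeping, the idea is a plain grid-stride fill — a standalone sketch, with the index type assumed to be int64_t (the diff does not show the declaration):

// Standalone sketch: mark every top-k slot past the tokens actually received
// as -1 ("no expert"). The in-tree version interleaves this into the dispatch
// kernel and skips forwarder warps, which exit early.
__global__ void clean_recv_topk_idx(int64_t* recv_topk_idx,  // type assumed
                                    int num_recv_tokens,     // actual tokens
                                    int num_worst_tokens,    // padded tokens
                                    int num_topk) {
    const int start = num_recv_tokens * num_topk;
    const int end = num_worst_tokens * num_topk;
    const int stride = gridDim.x * blockDim.x;
    for (int i = start + blockIdx.x * blockDim.x + threadIdx.x; i < end; i += stride)
        recv_topk_idx[i] = -1;
}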
@@ -1200,6 +1225,7 @@ void dispatch(void* recv_x,
     const int* recv_gbl_rank_prefix_sum,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int hidden_int4,
     int num_scales,
     int num_topk,

@@ -1254,6 +1280,7 @@ void dispatch(void* recv_x,
     recv_gbl_rank_prefix_sum, \
     is_token_in_rank, \
     num_tokens, \
+    num_worst_tokens, \
     hidden_int4, \
     num_scales, \
     num_topk, \
@@ -1698,6 +1725,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* combined_x,
     const int* rdma_channel_prefix_matrix,
     const int* rdma_rank_prefix_sum,
     const int* gbl_channel_prefix_matrix,
+    const int* gbl_rank_prefix_sum,
     int num_tokens,
     int num_combined_tokens,
     int hidden,

@@ -1789,7 +1817,9 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* combined_x,
     if (lane_id < kNumRDMARanks) {
         int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
         token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
-        token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
+        // For the last channel, set token_end_idx to the actual received token count
+        token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? gbl_rank_prefix_sum[num_ranks - 1]
+                                                                     : gbl_channel_prefix_matrix[prefix_idx + 1];
     }
     __syncwarp();

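The one-line bound fix above is easy to miss but load-bearing: the last (rank, channel) slot has no next prefix-matrix entry, and the old fallback `num_tokens` can now be the padded worst-case count, so the warp would walk into the invalid tail. Worked through with illustrative numbers (not from the diff):

// 2 ranks x 2 channels, 8 tokens actually received, buffers padded to 16.
int num_ranks = 2, num_channels = 2, num_tokens = 16;  // padded worst case
int gbl_rank_prefix_sum[] = {5, 8};                    // actual total: 8
int prefix_idx = num_channels * num_ranks - 1;         // last slot: 3
// Old bound: num_tokens = 16 -> the warp would read 8 padded (garbage) slots.
// New bound: gbl_rank_prefix_sum[num_ranks - 1] = 8 -> stops at real data.
int token_end_idx = gbl_rank_prefix_sum[num_ranks - 1];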
@@ -2261,6 +2291,7 @@ void combine(cudaDataType_t type,
     const int* rdma_channel_prefix_matrix,
     const int* rdma_rank_prefix_sum,
     const int* gbl_channel_prefix_matrix,
+    const int* gbl_rank_prefix_sum,
     int num_tokens,
     int num_combined_tokens,
     int hidden,

@@ -2312,6 +2343,7 @@ void combine(cudaDataType_t type,
     rdma_channel_prefix_matrix, \
     rdma_rank_prefix_sum, \
     gbl_channel_prefix_matrix, \
+    gbl_rank_prefix_sum, \
     num_tokens, \
     num_combined_tokens, \
     hidden, \
