@@ -591,22 +591,20 @@ def determine_expert_map(
     if ep_size == 1:
         return (global_num_experts, None)
 
-    local_num_experts = global_num_experts // ep_size
+    # Distribute experts as evenly as possible to each rank.
+    base_experts = global_num_experts // ep_size
+    remainder = global_num_experts % ep_size
+    if ep_rank < remainder:
+        local_num_experts = base_experts + 1
+    else:
+        local_num_experts = base_experts
 
     # Create a tensor of size num_experts filled with -1
     expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
     # Create a expert map for the local experts
-    if ep_rank < (ep_size - 1):
-        # Each non-last rank gets local_num_experts experts.
-        expert_map[ep_rank * local_num_experts:
-                   (ep_rank + 1) * local_num_experts] = \
-                       torch.arange(0, local_num_experts, dtype=torch.int32)
-    else:
-        # All remaining experts are assigned to the last rank.
-        local_num_experts = (global_num_experts - ep_rank * local_num_experts)
-
-        expert_map[-local_num_experts:] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
+    start_idx = ep_rank * base_experts + min(ep_rank, remainder)
+    expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
+        0, local_num_experts, dtype=torch.int32)
 
     return (local_num_experts, expert_map)
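As a sanity check, here is a self-contained sketch of the new balanced scheme; the `balanced_expert_map` helper name and the 10-experts/4-ranks sizes are illustrative, not part of the patch:

import torch

# Illustrative standalone version of the mapping in the diff above.
def balanced_expert_map(ep_rank: int, ep_size: int, global_num_experts: int):
    if ep_size == 1:
        return global_num_experts, None
    base_experts = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    # The first `remainder` ranks each take one extra expert.
    local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    # min(ep_rank, remainder) offsets for the extra expert held by each
    # earlier rank, so the per-rank slices tile [0, global_num_experts).
    start_idx = ep_rank * base_experts + min(ep_rank, remainder)
    expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
        0, local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map

# With 10 experts over 4 ranks the split is 3/3/2/2, where the old code
# produced 2/2/2/4 (floor division everywhere, remainder on the last rank).
counts = [balanced_expert_map(r, 4, 10)[0] for r in range(4)]
assert counts == [3, 3, 2, 2] and sum(counts) == 10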