
Commit fe56180

[MoE] More balanced expert sharding (vllm-project#21497)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 07d80d7 commit fe56180

File tree

  • vllm/model_executor/layers/fused_moe

1 file changed (+10, -12 lines)

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 10 additions & 12 deletions
@@ -591,22 +591,20 @@ def determine_expert_map(
     if ep_size == 1:
         return (global_num_experts, None)
 
-    local_num_experts = global_num_experts // ep_size
+    # Distribute experts as evenly as possible to each rank.
+    base_experts = global_num_experts // ep_size
+    remainder = global_num_experts % ep_size
+    if ep_rank < remainder:
+        local_num_experts = base_experts + 1
+    else:
+        local_num_experts = base_experts
 
     # Create a tensor of size num_experts filled with -1
     expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
     # Create a expert map for the local experts
-    if ep_rank < (ep_size - 1):
-        # Each non-last rank gets local_num_experts experts.
-        expert_map[ep_rank * local_num_experts:
-                   (ep_rank + 1) * local_num_experts] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
-    else:
-        # All remaining experts are assigned to the last rank.
-        local_num_experts = (global_num_experts - ep_rank * local_num_experts)
-
-        expert_map[-local_num_experts:] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
+    start_idx = ep_rank * base_experts + min(ep_rank, remainder)
+    expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
+        0, local_num_experts, dtype=torch.int32)
     return (local_num_experts, expert_map)
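For illustration, here is a self-contained sketch of the balanced sharding this commit introduces: each expert-parallel rank gets either base_experts or base_experts + 1 experts, with the first global_num_experts % ep_size ranks taking the extra one. The helper name determine_expert_map_balanced and the self-check under __main__ are illustrative only; the real logic is determine_expert_map in vllm/model_executor/layers/fused_moe/layer.py, exactly as shown in the diff above.

# Standalone sketch (illustrative): balanced expert sharding across EP ranks.
# Mirrors the logic added by this commit; names outside the diff are assumed.
import torch


def determine_expert_map_balanced(ep_size: int, ep_rank: int,
                                  global_num_experts: int):
    """Return (local_num_experts, expert_map) for one expert-parallel rank.

    expert_map[g] is the local index of global expert g on this rank,
    or -1 if that expert lives on another rank.
    """
    if ep_size == 1:
        return (global_num_experts, None)

    # Distribute experts as evenly as possible: the first `remainder`
    # ranks hold one extra expert each.
    base_experts = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    local_num_experts = base_experts + (1 if ep_rank < remainder else 0)

    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    # Ranks before this one hold `base_experts` experts each, plus one
    # extra for every earlier rank whose index is below `remainder`.
    start_idx = ep_rank * base_experts + min(ep_rank, remainder)
    expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
        0, local_num_experts, dtype=torch.int32)
    return (local_num_experts, expert_map)


if __name__ == "__main__":
    # 64 experts over 6 ranks: the old scheme gave 10 experts to each of the
    # first five ranks and 14 to the last; the new one gives 11,11,11,11,10,10.
    ep_size, global_num_experts = 6, 64
    counts = []
    owners = torch.zeros(global_num_experts, dtype=torch.int32)
    for rank in range(ep_size):
        n, emap = determine_expert_map_balanced(ep_size, rank,
                                                global_num_experts)
        counts.append(n)
        owners += (emap >= 0).to(torch.int32)
    print(counts)  # [11, 11, 11, 11, 10, 10]
    assert sum(counts) == global_num_experts
    assert bool(torch.all(owners == 1))  # each expert owned by exactly one rank

In this example the worst-case per-rank load drops from 14 experts to 11, which is the point of the change: no single rank absorbs the whole remainder anymore.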
0 commit comments