
Commit add17f0

ISEEKYAN and ETOgaosion authored
[megatron] support megatron expert parallel (#1467)
### Checklist Before Starting

### What does this PR do?

Support expert parallel in Megatron.

### High-Level Design

Introduce EP size and ETP size. ETP size is the TP size for the MoE parts; it is recommended to set it to 1, meaning the MoE parts do not use TP.

### Specific Changes

1. mcore model initialization
2. Megatron-vLLM parameter transfer

### API

### Usage Example

```bash
LLM=models/Qwen1.5-MoE-A2.7B-Chat
NODES=1
PP=2
TP=4
VLLM_TP=4
EP=4
ETP=1

python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer' \
    algorithm.adv_estimator=gae \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=128 \
    data.max_prompt_length=1024 \
    data.max_response_length=512 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=$LLM \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    critic.optim.lr=1e-5 \
    critic.model.path=$LLM \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=1 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_megatron_gsm8k_examples' \
    trainer.experiment_name='qwen_moe_instruct_1node_ep' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=$NODES \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
    critic.megatron.pipeline_model_parallel_size=$PP \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
    critic.megatron.tensor_model_parallel_size=$TP \
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
    critic.megatron.expert_model_parallel_size=$EP \
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
    critic.megatron.expert_tensor_parallel_size=$ETP \
    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
    critic.megatron.use_dist_checkpointing=True \
    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
    critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
    actor_rollout_ref.actor.megatron.param_offload=True \
    actor_rollout_ref.ref.megatron.param_offload=True \
    critic.megatron.param_offload=True \
    trainer.total_epochs=100 $@
```

### Test

### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none]

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.

---------

Co-authored-by: gaoziyuan <[email protected]>
1 parent 7b0426a commit add17f0

File tree: 6 files changed, +140 −3 lines

Lines changed: 56 additions & 0 deletions (new file)

```bash
set -x

HF_MODEL_PATH=Qwen/Qwen3-30B-A3B
DIST_CKPT_PATH=${DIST_CKPT_PATH}

python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH

# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping

python3 -m verl.trainer.main_ppo --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=64 \
    data.max_prompt_length=1024 \
    data.max_response_length=2048 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=$HF_MODEL_PATH \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \
    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
    actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \
    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
    actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k_math' \
    trainer.experiment_name='qwen3_30b_moe_megatron' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=4 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
```
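
To see how the parallel sizes in this script fit together, here is a minimal sanity-check sketch. It only verifies the basic divisibility constraints; the helper is hypothetical (not part of verl), and the 128 routed experts assumed for Qwen3-30B-A3B should be checked against the model config.

```python
# Hypothetical sanity check for the layout used above: 4 nodes x 8 GPUs, TP=4, PP=2, EP=4.
# Assumes Qwen3-30B-A3B has 128 routed experts; adjust num_experts if the config differs.

def check_parallel_layout(world_size: int, tp: int, pp: int, cp: int = 1, ep: int = 1, num_experts: int = 128):
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError(f"world_size {world_size} is not divisible by TP*PP*CP = {model_parallel}")
    if num_experts % ep != 0:
        raise ValueError(f"{num_experts} experts cannot be split evenly across EP = {ep} ranks")
    return {"dp": world_size // model_parallel, "experts_per_ep_rank": num_experts // ep}

print(check_parallel_layout(world_size=4 * 8, tp=4, pp=2, ep=4))
# {'dp': 4, 'experts_per_ep_rank': 32}
```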

verl/models/mcore/config_converter.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -65,6 +65,8 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
         # Parallel configuration
         "tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(),
         "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(),
+        "expert_model_parallel_size": mpu.get_expert_model_parallel_world_size(),
+        "expert_tensor_parallel_size": mpu.get_expert_tensor_parallel_world_size(),
         "virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(),
         "context_parallel_size": mpu.get_context_parallel_world_size(),
         "overlap_p2p_comm": overlap_p2p_comm,
```

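For reference, this is roughly what the parallel-configuration block of the returned kwargs now carries, with the mpu world-size queries replaced by placeholder integers so the fragment stands on its own; only the two expert-parallel entries are new.

```python
# Sketch of the parallel-configuration fragment, with mpu queries replaced by
# placeholder sizes (TP=4, PP=2, EP=4, ETP=1) purely for illustration.
parallel_config = {
    "tensor_model_parallel_size": 4,
    "pipeline_model_parallel_size": 2,
    "expert_model_parallel_size": 4,    # new: routed experts are split across 4 EP ranks
    "expert_tensor_parallel_size": 1,   # new: no tensor parallelism inside each expert
    "virtual_pipeline_model_parallel_size": None,
    "context_parallel_size": 1,
}
```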
verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 8 additions & 0 deletions
```diff
@@ -72,6 +72,8 @@ actor_rollout_ref:
       grad_offload: False
       optimizer_offload: False
       tensor_model_parallel_size: 1
+      expert_model_parallel_size: 1
+      expert_tensor_parallel_size: null
       pipeline_model_parallel_size: 1
       virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
       context_parallel_size: 1
@@ -95,6 +97,8 @@ actor_rollout_ref:
     megatron:
       param_offload: False
       tensor_model_parallel_size: 1
+      expert_model_parallel_size: 1
+      expert_tensor_parallel_size: null
       pipeline_model_parallel_size: 1
       virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
       context_parallel_size: 1
@@ -187,6 +191,8 @@ critic:
     grad_offload: False
     optimizer_offload: False
     tensor_model_parallel_size: 1
+    expert_model_parallel_size: 1
+    expert_tensor_parallel_size: null
     pipeline_model_parallel_size: 1
     virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
     context_parallel_size: 1
@@ -218,6 +224,8 @@ reward_model:
     grad_offload: False
     optimizer_offload: False
     tensor_model_parallel_size: 1
+    expert_model_parallel_size: 1
+    expert_tensor_parallel_size: null
     pipeline_model_parallel_size: 1
     virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
     context_parallel_size: 1
```
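
The new keys default to EP=1 (expert parallel disabled) and ETP=null. A small sketch of how those defaults come through the config layer (OmegaConf, which Hydra is built on, maps YAML `null` to Python `None`; how a `None` ETP is resolved downstream, for example by falling back to the TP size, is left to the training code):

```python
# Minimal check of the new defaults as loaded by OmegaConf (the config backend used by Hydra).
from omegaconf import OmegaConf

yaml_snippet = """
megatron:
  tensor_model_parallel_size: 1
  expert_model_parallel_size: 1
  expert_tensor_parallel_size: null
"""
cfg = OmegaConf.create(yaml_snippet)

assert cfg.megatron.expert_model_parallel_size == 1
assert cfg.megatron.expert_tensor_parallel_size is None  # YAML null -> Python None
```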

verl/utils/megatron_utils.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -692,6 +692,10 @@ def per_tensor_generator(actor_module, model_config, weight_converter, layer_nam
 
     pp_rank = mpu.get_pipeline_model_parallel_rank()
     pp_size = mpu.get_pipeline_model_parallel_world_size()
+    ep_size = mpu.get_expert_model_parallel_world_size()
+    etp_size = mpu.get_expert_tensor_parallel_world_size()
+    ep_group = mpu.get_expert_model_parallel_group()
+    etp_group = mpu.get_expert_tensor_parallel_group()
     vpp_size = len(actor_module)
     all_gather_group = mpu.get_tensor_model_parallel_group()
     all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)
@@ -731,6 +735,35 @@ def tensor_generator():
         while cur_name.startswith("module."):
             cur_name = cur_name[len("module.") :]
 
+        # EP
+        if ".mlp.experts.linear_fc" in cur_name and ep_size > 1:
+            num_experts = weight_converter.mcore_config.num_moe_experts
+            num_experts_per_rank = num_experts // ep_size
+            infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(ep_size)]
+            torch.distributed.all_gather(infer_params, broad_pp_tensor, group=ep_group)
+
+            name_prefix, local_expert_id = cur_name.split(".weight")
+            local_expert_id = int(local_expert_id)
+            global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)]
+            global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids]
+
+            for name, param in zip(global_expert_names, infer_params):
+                if etp_size > 1:
+                    # gather etp
+                    etp_params = [torch.empty_like(param) for _ in range(etp_size)]
+                    torch.distributed.all_gather(etp_params, param, group=etp_group)
+                    params = etp_params
+                else:
+                    params = [param]
+
+                merge_params = default_tp_concat_fn(name, broad_pp_tensor, params, model_config, convert_qkv_gate_up_by_simple_split)
+                if not isinstance(merge_params, list):
+                    merge_params = [merge_params]
+                converted_names, converted_params = weight_converter.convert_param(name, merge_params)
+
+                yield from zip(converted_names, converted_params)
+            continue
+
         # tp all gather
         if tp_utils.is_tensor_parallel_param(broad_pp_tensor):
             # allocate a new tensor with proper size
```

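The heart of the new EP branch is the renaming of locally numbered expert weights to their global expert ids on every EP rank. A stripped-down sketch of just that arithmetic (the all_gather and the weight converter are omitted; the parameter name below is illustrative of the mcore `...mlp.experts.linear_fc*.weight<idx>` pattern):

```python
# Sketch of the local -> global expert renaming done in the EP branch above.
# Only the name arithmetic is shown; gathering and conversion are omitted.

def global_expert_names(cur_name: str, num_experts: int, ep_size: int) -> list[str]:
    num_experts_per_rank = num_experts // ep_size
    name_prefix, local_expert_id = cur_name.split(".weight")
    local_expert_id = int(local_expert_id)
    return [
        f"{name_prefix}.weight{num_experts_per_rank * ep_rank + local_expert_id}"
        for ep_rank in range(ep_size)
    ]

# Illustrative name: 8 experts split over EP=4 -> 2 local experts per rank.
print(global_expert_names("decoder.layers.0.mlp.experts.linear_fc1.weight1", num_experts=8, ep_size=4))
# ['decoder.layers.0.mlp.experts.linear_fc1.weight1',
#  'decoder.layers.0.mlp.experts.linear_fc1.weight3',
#  'decoder.layers.0.mlp.experts.linear_fc1.weight5',
#  'decoder.layers.0.mlp.experts.linear_fc1.weight7']
```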
verl/workers/megatron_workers.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -97,7 +97,8 @@ def __init__(self, config: DictConfig, role: str):
             pipeline_model_parallel_split_rank=None,
             use_sharp=False,
             context_parallel_size=self.config.actor.megatron.context_parallel_size,
-            expert_model_parallel_size=1,
+            expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,
+            expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size,
             nccl_communicator_config_path=None,
         )
@@ -524,7 +525,8 @@ def __init__(self, config):
             pipeline_model_parallel_split_rank=None,
             use_sharp=False,
             context_parallel_size=self.config.megatron.context_parallel_size,
-            expert_model_parallel_size=1,
+            expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,
+            expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,
             nccl_communicator_config_path=None,
         )
@@ -723,7 +725,8 @@ def __init__(self, config):
             pipeline_model_parallel_split_rank=None,
             use_sharp=False,
             context_parallel_size=self.config.megatron.context_parallel_size,
-            expert_model_parallel_size=1,
+            expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,
+            expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,
             nccl_communicator_config_path=None,
         )
```

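These calls forward the configured EP/ETP sizes into Megatron-Core's model-parallel initialization instead of hard-coding EP=1. As a rough mental model (hedged, since the exact grouping is defined by Megatron-Core), the expert layers re-partition the TP×CP×DP ranks into ETP×EP×expert-DP, which is why EP=TP with ETP=1 fits on the single-node layout from the PR description. The helper below only checks that arithmetic; it is an illustration, not verl's implementation.

```python
# Hedged sketch: in Megatron-Core, expert layers re-partition the TP*CP*DP ranks
# into ETP * EP * expert-data-parallel, so the products must line up.

def expert_data_parallel_size(world_size: int, tp: int, pp: int, cp: int, ep: int, etp: int) -> int:
    dp = world_size // (tp * pp * cp)
    dense_ranks = tp * cp * dp  # ranks sharing one pipeline stage
    if dense_ranks % (etp * ep) != 0:
        raise ValueError(f"TP*CP*DP = {dense_ranks} is not divisible by ETP*EP = {etp * ep}")
    return dense_ranks // (etp * ep)

# Single-node example from the PR description: 8 GPUs, PP=2, TP=4, EP=4, ETP=1.
print(expert_data_parallel_size(world_size=8, tp=4, pp=2, cp=1, ep=4, etp=1))  # -> 1
```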
verl/workers/sharding_manager/megatron_vllm.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -297,6 +297,12 @@ def __init__(
         self.train_tp_size = mpu.get_tensor_model_parallel_world_size()
         self.train_tp_rank = mpu.get_tensor_model_parallel_rank()
         self.train_tp_group = mpu.get_tensor_model_parallel_group()
+        self.train_ep_size = mpu.get_expert_model_parallel_world_size()
+        self.train_ep_rank = mpu.get_expert_model_parallel_rank()
+        self.train_ep_group = mpu.get_expert_model_parallel_group()
+        self.train_etp_size = mpu.get_expert_tensor_parallel_world_size()
+        self.train_etp_rank = mpu.get_expert_tensor_parallel_rank()
+        self.train_etp_group = mpu.get_expert_tensor_parallel_group()
         self.need_tp_reshard = self.train_tp_size != self.infer_tp_size
         self.train_tp_larger = self.train_tp_size > self.infer_tp_size
@@ -353,6 +359,35 @@ def tensor_generator():
         while cur_name.startswith("module."):
             cur_name = cur_name[len("module.") :]
 
+        # EP
+        if ".mlp.experts.linear_fc" in cur_name and self.train_ep_size > 1:
+            num_experts = self.weight_converter.mcore_config.num_moe_experts
+            num_experts_per_rank = num_experts // self.train_ep_size
+            infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(self.train_ep_size)]
+            torch.distributed.all_gather(infer_params, broad_pp_tensor, group=self.train_ep_group)
+
+            name_prefix, local_expert_id = cur_name.split(".weight")
+            local_expert_id = int(local_expert_id)
+            global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(self.train_ep_size)]
+            global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids]
+
+            for name, param in zip(global_expert_names, infer_params):
+                if self.train_etp_size > 1:
+                    # gather etp
+                    etp_params = [torch.empty_like(param) for _ in range(self.train_etp_size)]
+                    torch.distributed.all_gather(etp_params, param, group=self.train_etp_group)
+                    params = etp_params
+                else:
+                    params = [param]
+
+                merge_params = self.default_tp_concat_fn(name, broad_pp_tensor, params, self.model_config, convert_qkv_gate_up_by_simple_split)
+                if not isinstance(merge_params, list):
+                    merge_params = [merge_params]
+                converted_names, converted_params = self.weight_converter.convert_param(name, merge_params)
+
+                yield from zip(converted_names, converted_params)
+            continue
+
         # tp all gather
         if tp_utils.is_tensor_parallel_param(broad_pp_tensor):
             # allocate a new tensor with proper size
```

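The sharding manager mirrors the `per_tensor_generator` logic above, with the EP/ETP groups cached on `self`. As a minimal, hedged illustration of the `torch.distributed.all_gather` list-of-buffers convention both code paths rely on, here is a single-process demo (gloo backend, world size 1; the real code gathers over the NCCL EP/ETP groups):

```python
# Single-process demo of the all_gather calling convention used in the EP branch.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

local_expert_shard = torch.randn(4, 8)  # stand-in for one expert weight shard
gathered = [torch.empty_like(local_expert_shard) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, local_expert_shard)  # every rank receives every rank's shard
assert torch.equal(gathered[0], local_expert_shard)

dist.destroy_process_group()
```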