56 changes: 56 additions & 0 deletions examples/grpo_trainer/run_qwen3moe-30b_megatron.sh
@@ -0,0 +1,56 @@
set -x

HF_MODEL_PATH=Qwen/Qwen3-30B-A3B
DIST_CKPT_PATH=${DIST_CKPT_PATH}

python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH

# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping

python3 -m verl.trainer.main_ppo --config-path=config \
--config-name='ppo_megatron_trainer.yaml'\
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=64 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$HF_MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \
actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k_math' \
trainer.experiment_name='qwen3_30b_moe_megatron' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=4 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
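
Not part of the diff, but a quick sanity check of the layout this script requests (32 GPUs across 4 nodes, TP=4, PP=2, EP=4). The 128-routed-expert count for Qwen3-30B-A3B is an assumption to be verified against the HF config:

```python
# Back-of-the-envelope check of the parallel layout configured above.
# Assumption: Qwen3-30B-A3B has 128 routed experts (verify against the HF config).
n_gpus = 4 * 8                      # trainer.nnodes * trainer.n_gpus_per_node
tp, pp, ep = 4, 2, 4                # tensor / pipeline / expert model parallel sizes
num_experts = 128

model_parallel = tp * pp            # GPUs per model replica
assert n_gpus % model_parallel == 0
assert num_experts % ep == 0

dp = n_gpus // model_parallel       # data-parallel replicas
experts_per_ep_rank = num_experts // ep
print(f"dp={dp}, experts per EP rank={experts_per_ep_rank}")  # dp=4, 32 experts per rank
```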
2 changes: 2 additions & 0 deletions verl/models/mcore/config_converter.py
@@ -65,6 +65,8 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
# Parallel configuration
"tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(),
"pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(),
"expert_model_parallel_size": mpu.get_expert_model_parallel_world_size(),
"expert_tensor_parallel_size": mpu.get_expert_tensor_parallel_world_size(),
"virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(),
"context_parallel_size": mpu.get_context_parallel_world_size(),
"overlap_p2p_comm": overlap_p2p_comm,
8 changes: 8 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -72,6 +72,8 @@ actor_rollout_ref:
grad_offload: False
optimizer_offload: False
tensor_model_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: null
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
@@ -95,6 +97,8 @@ actor_rollout_ref:
megatron:
param_offload: False
tensor_model_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: null
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
@@ -187,6 +191,8 @@ critic:
grad_offload: False
optimizer_offload: False
tensor_model_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: null
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
@@ -218,6 +224,8 @@ reward_model:
grad_offload: False
optimizer_offload: False
tensor_model_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: null
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
33 changes: 33 additions & 0 deletions verl/utils/megatron_utils.py
@@ -692,6 +692,10 @@ def per_tensor_generator(actor_module, model_config, weight_converter, layer_nam

pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
ep_size = mpu.get_expert_model_parallel_world_size()
etp_size = mpu.get_expert_tensor_parallel_world_size()
ep_group = mpu.get_expert_model_parallel_group()
etp_group = mpu.get_expert_tensor_parallel_group()
vpp_size = len(actor_module)
all_gather_group = mpu.get_tensor_model_parallel_group()
all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)
@@ -731,6 +735,35 @@ def tensor_generator():
while cur_name.startswith("module."):
cur_name = cur_name[len("module.") :]

# EP
if ".mlp.experts.linear_fc" in cur_name and ep_size > 1:
num_experts = weight_converter.mcore_config.num_moe_experts
num_experts_per_rank = num_experts // ep_size
infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(ep_size)]
torch.distributed.all_gather(infer_params, broad_pp_tensor, group=ep_group)

name_prefix, local_expert_id = cur_name.split(".weight")
local_expert_id = int(local_expert_id)
global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)]
global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids]

for name, param in zip(global_expert_names, infer_params):
if etp_size > 1:
# gather etp
etp_params = [torch.empty_like(param) for _ in range(etp_size)]
torch.distributed.all_gather(etp_params, param, group=etp_group)
params = etp_params
else:
params = [param]

merge_params = default_tp_concat_fn(name, broad_pp_tensor, params, model_config, convert_qkv_gate_up_by_simple_split)
if not isinstance(merge_params, list):
merge_params = [merge_params]
converted_names, converted_params = weight_converter.convert_param(name, merge_params)

yield from zip(converted_names, converted_params)
continue

# tp all gather
if tp_utils.is_tensor_parallel_param(broad_pp_tensor):
# allocate a new tensor with proper size
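The local-to-global expert naming used in the EP branch is easy to check in isolation. A minimal sketch, assuming (as the code above does) that each EP rank holds a contiguous block of `num_experts // ep_size` experts; the helper and the parameter name in the example are illustrative, not a verl API:

```python
def global_expert_names(cur_name: str, num_experts: int, ep_size: int) -> list[str]:
    """Mirror of the local->global expert id mapping above: each EP rank
    holds a contiguous block of num_experts // ep_size experts."""
    num_experts_per_rank = num_experts // ep_size
    name_prefix, local_expert_id = cur_name.split(".weight")
    local_expert_id = int(local_expert_id)
    return [
        f"{name_prefix}.weight{num_experts_per_rank * ep_rank + local_expert_id}"
        for ep_rank in range(ep_size)
    ]

# With 128 experts and EP=4, local expert 3 maps to global experts 3, 35, 67, 99:
print(global_expert_names("decoder.layers.0.mlp.experts.linear_fc1.weight3", 128, 4))
```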
9 changes: 6 additions & 3 deletions verl/workers/megatron_workers.py
@@ -97,7 +97,8 @@ def __init__(self, config: DictConfig, role: str):
pipeline_model_parallel_split_rank=None,
use_sharp=False,
context_parallel_size=self.config.actor.megatron.context_parallel_size,
-expert_model_parallel_size=1,
+expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,
+expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size,
nccl_communicator_config_path=None,
)

@@ -524,7 +525,8 @@ def __init__(self, config):
pipeline_model_parallel_split_rank=None,
use_sharp=False,
context_parallel_size=self.config.megatron.context_parallel_size,
-expert_model_parallel_size=1,
+expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,
+expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,
nccl_communicator_config_path=None,
)

@@ -723,7 +725,8 @@ def __init__(self, config):
pipeline_model_parallel_split_rank=None,
use_sharp=False,
context_parallel_size=self.config.megatron.context_parallel_size,
-expert_model_parallel_size=1,
+expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,
+expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,
nccl_communicator_config_path=None,
)

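All three workers now forward the new config fields straight into Megatron-Core's parallel-state setup. A minimal sketch of that call with the values from the example script; the behavior of `expert_tensor_parallel_size=None` (letting Megatron-Core pick its own default, which is why the YAML exposes it as `null`) is an assumption, not something shown in this diff:

```python
from megatron.core import parallel_state as mpu

# Sketch only: mirrors the kwargs added in megatron_workers.py.
mpu.initialize_model_parallel(
    tensor_model_parallel_size=4,        # actor.megatron.tensor_model_parallel_size
    pipeline_model_parallel_size=2,      # actor.megatron.pipeline_model_parallel_size
    context_parallel_size=1,             # actor.megatron.context_parallel_size
    expert_model_parallel_size=4,        # actor.megatron.expert_model_parallel_size
    expert_tensor_parallel_size=None,    # actor.megatron.expert_tensor_parallel_size
)
```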
35 changes: 35 additions & 0 deletions verl/workers/sharding_manager/megatron_vllm.py
@@ -297,6 +297,12 @@ def __init__(
self.train_tp_size = mpu.get_tensor_model_parallel_world_size()
self.train_tp_rank = mpu.get_tensor_model_parallel_rank()
self.train_tp_group = mpu.get_tensor_model_parallel_group()
self.train_ep_size = mpu.get_expert_model_parallel_world_size()
self.train_ep_rank = mpu.get_expert_model_parallel_rank()
self.train_ep_group = mpu.get_expert_model_parallel_group()
self.train_etp_size = mpu.get_expert_tensor_parallel_world_size()
self.train_etp_rank = mpu.get_expert_tensor_parallel_rank()
self.train_etp_group = mpu.get_expert_tensor_parallel_group()
self.need_tp_reshard = self.train_tp_size != self.infer_tp_size
self.train_tp_larger = self.train_tp_size > self.infer_tp_size

@@ -353,6 +359,35 @@ def tensor_generator():
while cur_name.startswith("module."):
cur_name = cur_name[len("module.") :]

# EP
if ".mlp.experts.linear_fc" in cur_name and self.train_ep_size > 1:
num_experts = self.weight_converter.mcore_config.num_moe_experts
num_experts_per_rank = num_experts // self.train_ep_size
infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(self.train_ep_size)]
torch.distributed.all_gather(infer_params, broad_pp_tensor, group=self.train_ep_group)

name_prefix, local_expert_id = cur_name.split(".weight")
local_expert_id = int(local_expert_id)
global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(self.train_ep_size)]
global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids]

for name, param in zip(global_expert_names, infer_params):
if self.train_etp_size > 1:
# gather etp
etp_params = [torch.empty_like(param) for _ in range(self.train_etp_size)]
torch.distributed.all_gather(etp_params, param, group=self.train_etp_group)
params = etp_params
else:
params = [param]

merge_params = self.default_tp_concat_fn(name, broad_pp_tensor, params, self.model_config, convert_qkv_gate_up_by_simple_split)
if not isinstance(merge_params, list):
merge_params = [merge_params]
converted_names, converted_params = self.weight_converter.convert_param(name, merge_params)

yield from zip(converted_names, converted_params)
continue

# tp all gather
if tp_utils.is_tensor_parallel_param(broad_pp_tensor):
# allocate a new tensor with proper size
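Both `per_tensor_generator` and this sharding manager use the same two-level gather before converting expert weights: gather the local expert tensor across the EP group, then, if expert tensor parallelism is in use, gather each expert's shards across the ETP group. A stripped-down sketch of just the collectives, assuming the process groups are already initialized; the helper name is illustrative, not a verl function, and merging (e.g. `default_tp_concat_fn`) is left to the caller:

```python
import torch
import torch.distributed as dist

def gather_expert_shards(local_tensor, ep_group, etp_group=None):
    """Gather one local expert tensor across EP, then each expert's
    shards across ETP; returns a list of shard lists per global expert."""
    ep_size = dist.get_world_size(group=ep_group)
    per_expert = [torch.empty_like(local_tensor) for _ in range(ep_size)]
    dist.all_gather(per_expert, local_tensor, group=ep_group)

    gathered = []
    for expert_tensor in per_expert:
        if etp_group is not None and dist.get_world_size(group=etp_group) > 1:
            etp_size = dist.get_world_size(group=etp_group)
            shards = [torch.empty_like(expert_tensor) for _ in range(etp_size)]
            dist.all_gather(shards, expert_tensor, group=etp_group)
            gathered.append(shards)
        else:
            gathered.append([expert_tensor])
    return gathered
```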