Skip to content

Commit dd8864f

Browse files
authored
[megatron] feat: script of qwen3vl 235b (volcengine#3799)
an example script
1 parent ae5d850 commit dd8864f

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
set -x
2+
ENGINE=${1:-vllm}
3+
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
4+
5+
# VLLM version >= 0.11.0 for qwen3-vl support, recommend to use container docker://iseekyan/verl:nemo.gptoss_vllm0.11.0
6+
# pip install -U git+https://github.com/ISEEKYAN/mbridge.git # for latest mbridge
7+
# pip install -U transformers # for qwen3-vl support
8+
# pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.1 # for megatron-lm0.13.1
9+
10+
11+
export VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP
12+
13+
14+
HF_MODEL_PATH=${HF_MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-235B-A22B-Instruct"}
15+
16+
17+
train_path=$HOME/data/geo3k/train.parquet
18+
test_path=$HOME/data/geo3k/test.parquet
19+
20+
python3 -m verl.trainer.main_ppo --config-path=config \
21+
--config-name='ppo_megatron_trainer.yaml'\
22+
algorithm.adv_estimator=grpo \
23+
data.train_files="$train_path" \
24+
data.val_files="$test_path" \
25+
data.train_batch_size=512 \
26+
data.max_prompt_length=1024 \
27+
data.max_response_length=2048 \
28+
data.filter_overlong_prompts=True \
29+
data.truncation='error' \
30+
actor_rollout_ref.model.path=$HF_MODEL_PATH \
31+
actor_rollout_ref.actor.optim.lr=1e-6 \
32+
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
33+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
34+
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
35+
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=8 \
36+
actor_rollout_ref.actor.megatron.expert_model_parallel_size=8 \
37+
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=1 \
38+
actor_rollout_ref.actor.use_kl_loss=True \
39+
actor_rollout_ref.actor.kl_loss_coef=0.01 \
40+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
41+
actor_rollout_ref.actor.entropy_coeff=0 \
42+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
43+
actor_rollout_ref.rollout.tensor_model_parallel_size=16 \
44+
actor_rollout_ref.actor.use_dynamic_bsz=True \
45+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \
46+
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
47+
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=5120 \
48+
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
49+
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=5120 \
50+
actor_rollout_ref.rollout.name=$ENGINE \
51+
+actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
52+
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
53+
actor_rollout_ref.rollout.n=5 \
54+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
55+
actor_rollout_ref.actor.megatron.use_mbridge=True \
56+
actor_rollout_ref.actor.megatron.param_offload=True \
57+
actor_rollout_ref.actor.megatron.optimizer_offload=True \
58+
actor_rollout_ref.actor.megatron.grad_offload=True \
59+
actor_rollout_ref.ref.megatron.param_offload=True \
60+
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \
61+
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
62+
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \
63+
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \
64+
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
65+
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
66+
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
67+
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
68+
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
69+
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
70+
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
71+
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
72+
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True \
73+
+actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True \
74+
algorithm.use_kl_in_reward=False \
75+
trainer.critic_warmup=0 \
76+
trainer.logger='["console","wandb"]' \
77+
trainer.project_name='verl_grpo_example_geo3k' \
78+
trainer.experiment_name='qwen3_vl_235b_megatron' \
79+
trainer.n_gpus_per_node=8 \
80+
trainer.nnodes=8 \
81+
trainer.save_freq=20 \
82+
trainer.test_freq=5 \
83+
trainer.total_epochs=15 $@

0 commit comments

Comments
 (0)