Commit a80ed95

# [trainer] fix: batch size mismatch with n>1 when gen_max for ReMax (volcengine#3779)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Resolves volcengine#3408

We should not repeat directly on top of `gen_batch`: in `gen_max`, we need the original `gen_batch` so that we can disable `do_sample` for the rollout that calculates the reward baseline.

```log
File "verl/trainer/main_ppo.py", line 317, in run
    trainer.fit()
File "verl/trainer/ppo/ray_trainer.py", line 1065, in fit
    batch = batch.union(gen_baseline_output)
File "verl/protocol.py", line 802, in union
    self.batch = union_tensor_dict(self.batch, other.batch)
File "verl/protocol.py", line 110, in union_tensor_dict
    assert tensor_dict1.batch_size == tensor_dict2.batch_size, (
AssertionError: Two tensor dict must have identical batch size. Got torch.Size([4096]) and torch.Size([16384])
```

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

Signed-off-by: Hollow Man <[email protected]>
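The size mismatch can be reproduced in miniature. Below is a hedged sketch: plain Python lists stand in for verl's tensor dicts, and `repeat_interleave`/`union` are simplified stand-ins for `DataProto.repeat` and `union_tensor_dict`, not verl's actual API.

```python
def repeat_interleave(rows, n):
    """Repeat each row n consecutive times, like repeat(..., interleave=True)."""
    return [r for r in rows for _ in range(n)]

def union(a, b):
    # Simplified stand-in: verl's union_tensor_dict asserts identical batch sizes.
    assert len(a) == len(b), (
        f"Two tensor dict must have identical batch size. Got {len(a)} and {len(b)}"
    )
    return [dict(x, **y) for x, y in zip(a, b)]

batch = [{"prompt": p} for p in ["a", "b", "c", "d"]]
n = 4

# Buggy pattern: repeat in place, so the baseline is computed on the
# repeated batch and no longer matches the original batch size.
gen_batch = repeat_interleave(list(batch), n)               # size 16
baseline = [{"reward_baseline": 0.0} for _ in gen_batch]    # size 16, not 4
try:
    union(batch, baseline)
except AssertionError as e:
    print(e)  # the 4096-vs-16384 failure in miniature

# Fixed pattern: keep gen_batch intact; repeat into a new variable.
gen_batch = [{"prompt": p} for p in ["a", "b", "c", "d"]]
gen_batch_output = repeat_interleave(gen_batch, n)          # size 16, for sampling
baseline = [{"reward_baseline": 0.0} for _ in gen_batch]    # size 4, from original
merged = union(batch, baseline)                             # succeeds
print(len(gen_batch_output), len(merged))                   # 16 4
```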
1 parent 9078a53 commit a80ed95

File tree

5 files changed: +23 −15 lines


recipe/dapo/dapo_ray_trainer.py (4 additions, 2 deletions)

```diff
@@ -132,14 +132,16 @@ def fit(self):
                 batch_keys=["input_ids", "attention_mask", "position_ids"],
                 non_tensor_batch_keys=["raw_prompt_ids"],
             )
-            gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+            gen_batch_output = gen_batch.repeat(
+                repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
+            )

             is_last_step = self.global_steps >= self.total_training_steps

             with marked_timer("step", timing_raw):
                 # generate a batch
                 with marked_timer("gen", timing_raw, "red"):
-                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
+                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
                     timing_raw.update(gen_batch_output.meta_info["timing"])
                     gen_batch_output.meta_info.pop("timing", None)
```

recipe/entropy/entropy_ray_trainer.py (5 additions, 5 deletions)

```diff
@@ -108,19 +108,19 @@ def fit(self):
                 batch_keys=["input_ids", "attention_mask", "position_ids"],
                 non_tensor_batch_keys=["raw_prompt_ids"],
             )
-            gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+            gen_batch_output = gen_batch.repeat(
+                repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
+            )

             is_last_step = self.global_steps >= self.total_training_steps

             with simple_timer("step", timing_raw):
                 # generate a batch
-                # with simple_timer("gen", timing_raw):
-                #     gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                 with simple_timer("gen", timing_raw):
                     if not self.async_rollout_mode:
-                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
+                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
                     else:
-                        gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
+                        gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)

                 if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                     with simple_timer("gen_max", timing_raw):
```

recipe/prime/prime_ray_trainer.py (4 additions, 2 deletions)

```diff
@@ -373,12 +373,14 @@ def fit(self):

             # pop those keys for generation
             gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])
-            gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+            gen_batch_output = gen_batch.repeat(
+                repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
+            )

             with simple_timer("step", timing_raw):
                 # generate a batch
                 with simple_timer("gen", timing_raw):
-                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
+                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
                     timing_raw.update(gen_batch_output.meta_info["timing"])
                     gen_batch_output.meta_info.pop("timing", None)
```

recipe/sppo/sppo_ray_trainer.py (5 additions, 3 deletions)

```diff
@@ -182,17 +182,19 @@ def fit(self):
                 batch_keys=batch_keys_to_pop,
                 non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
             )
-            gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+            gen_batch_output = gen_batch.repeat(
+                repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
+            )

             is_last_step = self.global_steps >= self.total_training_steps

             with simple_timer("step", timing_raw):
                 # generate a batch
                 with simple_timer("gen", timing_raw):
                     if not self.async_rollout_mode:
-                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
+                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
                     else:
-                        gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
+                        gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
                     timing_raw.update(gen_batch_output.meta_info["timing"])
                     gen_batch_output.meta_info.pop("timing", None)
```

verl/trainer/ppo/ray_trainer.py (5 additions, 3 deletions)

```diff
@@ -1037,16 +1037,18 @@ def fit(self):

             # pass global_steps to trace
             gen_batch.meta_info["global_steps"] = self.global_steps
-            gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+            gen_batch_output = gen_batch.repeat(
+                repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
+            )

             is_last_step = self.global_steps >= self.total_training_steps
             with marked_timer("step", timing_raw):
                 # generate a batch
                 with marked_timer("gen", timing_raw, color="red"):
                     if not self.async_rollout_mode:
-                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
+                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
                     else:
-                        gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
+                        gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)

                     timing_raw.update(gen_batch_output.meta_info["timing"])
                     gen_batch_output.meta_info.pop("timing", None)
```
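Across all five files the change is the same: the repeated batch goes into a new variable so the original `gen_batch` survives for the ReMax baseline. A hedged sketch of the restored step order, using the sizes from the error log (`remax_step` is illustrative, not verl code):

```python
def remax_step(batch_size, n):
    """Toy model of one training step with the fix applied."""
    gen_batch = list(range(batch_size))  # stand-in for the original prompt batch

    # 1. ReMax baseline: one greedy (do_sample=False) rollout per ORIGINAL
    #    prompt, so its size matches `batch` and the union succeeds.
    gen_baseline_output = [{"reward_baseline": 0.0} for _ in gen_batch]
    assert len(gen_baseline_output) == batch_size

    # 2. Sampling rollouts: repeat into a NEW variable, n responses per prompt.
    gen_batch_output = [p for p in gen_batch for _ in range(n)]
    return len(gen_baseline_output), len(gen_batch_output)

print(remax_step(4096, 4))  # (4096, 16384)
```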
