Skip to content

Commit 880688e

Browse files
committed
add comments
1 parent 7432a21 commit 880688e

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

verl/trainer/ppo/ray_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from verl.utils.debug import marked_timer
5757
from verl.utils.metric import reduce_metrics
5858
from verl.utils.rollout_skip import RolloutSkip
59-
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
59+
from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
6060
from verl.utils.torch_functional import masked_mean
6161
from verl.utils.tracking import ValidationGenerationsLogger
6262

@@ -905,7 +905,7 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle
905905
attention_mask = batch.batch["attention_mask"]
906906
batch_size = attention_mask.shape[0]
907907
global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,)
908-
global_seqlen_lst = global_seqlen_lst**2 + global_seqlen_lst * 33024
908+
global_seqlen_lst = calculate_workload(global_seqlen_lst)
909909
world_size = self.actor_rollout_wg.world_size
910910
global_partition_lst = get_seqlen_balanced_partitions(
911911
global_seqlen_lst, k_partitions=world_size, equal_size=True

verl/utils/seqlen_balancing.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
from verl.utils.device import get_device_name
2525

2626

27+
def calculate_workload(seqlen_list: list[int]):
28+
"""
29+
Calculate the workload for a dense transformer block based on sequence length.
30+
FLOPs = 12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2
31+
Hardcodes the constants based on a 7B model (hidden_size=4096),
32+
so the FLOPs are proportional to (6 * 4096 * seqlen + seqlen^2).
33+
"""
34+
return 24576 * seqlen_list + seqlen_list**2
35+
36+
2737
def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
2838
# see: https://en.wikipedia.org/wiki/Largest_differencing_method
2939
class Set:
@@ -300,8 +310,8 @@ def rearrange_micro_batches(
300310

301311
assert num_micro_batches <= len(seq_len_effective)
302312

303-
# approximate the workload by Attention and MLP FLOPs
304-
workloads = seq_len_effective**2 + seq_len_effective * 33024
313+
# Approximate workload using transformer FLOPs model
314+
workloads = calculate_workload(seq_len_effective)
305315
micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)
306316

307317
if use_dynamic_bsz_balance:
@@ -313,6 +323,7 @@ def rearrange_micro_batches(
313323
),
314324
reverse=True,
315325
)
326+
# Place smaller micro-batches at both ends to reduce the bubbles exposed during the warm-up and cool-down.
316327
micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2]
317328

318329
micro_batches = []

0 commit comments

Comments
 (0)