Skip to content

Commit 6a6d434

Browse files
authored
Add paranthesis to correct the check. (#3658)
1 parent 79ec242 commit 6a6d434

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

trl/trainer/grpo_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def __post_init__(self):
558558
if self.generation_batch_size is None:
559559
self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
560560

561-
if self.generation_batch_size % self.per_device_train_batch_size * num_processes != 0:
561+
if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
562562
raise ValueError(
563563
f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
564564
f"({self.per_device_train_batch_size * num_processes})."

0 commit comments

Comments
 (0)