
Commit 85e24bc

kashif and qgallouedec authored
❤️‍🩹 [CI] fix transformers dev CI failure (huggingface#3176)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 5566f1b commit 85e24bc

File tree

3 files changed: +18 -4 lines changed


tests/test_online_dpo_trainer.py

Lines changed: 2 additions & 1 deletion
@@ -244,7 +244,7 @@ def test_training_with_judge(self, config_name):
     @require_torch_accelerator
     @unittest.skipIf(not is_vllm_available(), "vllm is not available")
     def test_training_with_vllm(self, config_name):
-        model_id = "trl-internal-testing/small-Qwen2ForCausalLM-2.5"  # We neeed a bigger model
+        model_id = "trl-internal-testing/small-Qwen2ForCausalLM-2.5"  # We need a bigger model
         model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
@@ -253,6 +253,7 @@ def test_training_with_vllm(self, config_name):
            training_args = OnlineDPOConfig(
                output_dir=tmp_dir,
                use_vllm=True,
+               gpu_memory_utilization=0.2,
                report_to="none",
            )
            dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
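
Note: the key change in this test is capping vLLM's memory reservation at 0.2 so the generation engine can share the CI GPU with the model under test. A minimal sketch of how the new option is wired in, assuming the OnlineDPOConfig API added in this commit (the output directory is illustrative):

from trl import OnlineDPOConfig

# Hedged sketch, not the full test: enable vLLM generation and cap how much
# GPU memory the vLLM engine may claim, leaving room for the policy model.
training_args = OnlineDPOConfig(
    output_dir="out",              # illustrative path
    use_vllm=True,                 # generate completions with vLLM
    gpu_memory_utilization=0.2,    # vLLM reserves ~20% of GPU memory
    report_to="none",
)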

trl/trainer/online_dpo_config.py

Lines changed: 8 additions & 0 deletions
@@ -64,6 +64,8 @@ class OnlineDPOConfig(TrainingArguments):
             Whether to disable dropout in the model and reference model.
         use_vllm (`bool`, *optional*, defaults to `False`):
             Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
+        gpu_memory_utilization (`float`, *optional*, defaults to `0.55`):
+            The vLLM memory utilization. The default value is 0.55.
         ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
             This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
             improving generation speed. However, disabling this option allows training models that exceed the VRAM
@@ -144,6 +146,12 @@ class OnlineDPOConfig(TrainingArguments):
             "(`pip install vllm`)."
         },
     )
+    gpu_memory_utilization: Optional[float] = field(
+        default=0.55,
+        metadata={
+            "help": "The vLLM memory utilization. The default value is 0.55.",
+        },
+    )
     ds3_gather_for_generation: bool = field(
         default=True,
         metadata={
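
Since gpu_memory_utilization is a regular dataclass field on OnlineDPOConfig, it also surfaces automatically through the standard HfArgumentParser flow used by TRL's example scripts. A hedged sketch of that flow (the script invocation shown in the comment is illustrative):

from transformers import HfArgumentParser
from trl import OnlineDPOConfig

# Sketch: dataclass fields become CLI flags, e.g.
#   python train.py --output_dir out --use_vllm --gpu_memory_utilization 0.3
parser = HfArgumentParser(OnlineDPOConfig)
(training_args,) = parser.parse_args_into_dataclasses()
print(training_args.gpu_memory_utilization)  # 0.55 unless overridden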

trl/trainer/online_dpo_trainer.py

Lines changed: 8 additions & 3 deletions
@@ -269,7 +269,7 @@ def __init__(
             # space for them. Setting gpu_memory_utilization to 0.55 seems to work well in practice.
             self.llm = LLM(
                 model=model.name_or_path,
-                gpu_memory_utilization=0.55,
+                gpu_memory_utilization=args.gpu_memory_utilization,
                 dtype=torch.float32,
                 # When release by vLLM, we would be able to distribute the model on multiple GPUs
                 # See https://github.com/vllm-project/vllm/pull/12071
@@ -695,7 +695,9 @@ def training_step(

     # Same as Trainer._maybe_log_save_evaluate but log our metrics
     # start_time defaults to None to allow compatibility with transformers<=4.46
-    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
+    def _maybe_log_save_evaluate(
+        self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, learning_rate=None
+    ):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             logs: dict[str, float] = {}

@@ -708,7 +710,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
-            logs["learning_rate"] = self._get_learning_rate()
+            if learning_rate is not None:
+                logs["learning_rate"] = learning_rate
+            else:
+                logs["learning_rate"] = self._get_learning_rate()

             # Add our metrics
             for key, val in self.stats.items():
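
The second and third hunks address the actual CI breakage: on the dev branch of transformers, Trainer._maybe_log_save_evaluate gained a learning_rate argument, so TRL's override must accept it while staying callable from older releases. A minimal, self-contained sketch of the compatibility pattern (class and method names below are illustrative, not TRL's code):

# Sketch: accept a newly added keyword with a None default so both old and
# new upstream call sites work against the same override.
class LoggingMixin:
    def _get_learning_rate(self):
        return 1e-5  # stand-in for the real scheduler lookup

    def maybe_log(self, loss, learning_rate=None):
        # Newer callers pass learning_rate explicitly; older ones omit it,
        # in which case we fall back to computing it ourselves.
        lr = learning_rate if learning_rate is not None else self._get_learning_rate()
        print({"loss": loss, "learning_rate": lr})

LoggingMixin().maybe_log(0.42)        # old-style call site
LoggingMixin().maybe_log(0.42, 2e-5)  # new-style call site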
