Commit fc4dae2

kashif and qgallouedec authored
🫣 [GRPO] add cache_implementation option in GRPO (#3075)

* add cache_implementation option in GRPO
* add cache_implementation to config
* Update trl/trainer/grpo_config.py

Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent e4e5671 commit fc4dae2

File tree

2 files changed: 7 additions & 0 deletions

trl/trainer/grpo_config.py

Lines changed: 6 additions & 0 deletions

@@ -72,6 +72,8 @@ class GRPOConfig(TrainingArguments):
             Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
             Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
             tokens.
+        cache_implementation (`str` or `None`, *optional*, defaults to `None`):
+            Implementation of the cache method for faster generation when use_vllm is set to False.
 
         > Parameters that control generation acceleration powered by vLLM
 
@@ -217,6 +219,10 @@ class GRPOConfig(TrainingArguments):
             "to repeat tokens."
         },
     )
+    cache_implementation: Optional[str] = field(
+        default=None,
+        metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
+    )
 
     # Parameters that control generation acceleration powered by vLLM
     use_vllm: Optional[bool] = field(
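
For context, a minimal usage sketch (not part of this commit): the new field is set like any other GRPOConfig generation parameter and only takes effect when generation runs through transformers rather than vLLM. The output_dir and the "static" value below are illustrative assumptions; accepted values are the cache implementations understood by transformers' GenerationConfig.

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-output",       # illustrative path, not from the commit
    use_vllm=False,                 # cache_implementation only applies to transformers generation
    cache_implementation="static",  # one of the cache implementations GenerationConfig accepts
)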

trl/trainer/grpo_trainer.py

Lines changed: 1 addition & 0 deletions

@@ -548,6 +548,7 @@ def new_group_context():
                 top_k=args.top_k,
                 min_p=args.min_p,
                 repetition_penalty=args.repetition_penalty,
+                cache_implementation=args.cache_implementation,
             )
 
             # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
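
The one-line change threads the new config field into the generation config that GRPOTrainer builds for transformers-based generation. A self-contained sketch of that pass-through, assuming illustrative values for everything not visible in the hunk:

from transformers import GenerationConfig
from trl import GRPOConfig

args = GRPOConfig(output_dir="grpo-output", cache_implementation="static")  # illustrative values

# Mirrors the construction this hunk modifies; only the keyword arguments
# visible in the diff context are reproduced here.
generation_config = GenerationConfig(
    top_k=args.top_k,
    min_p=args.min_p,
    repetition_penalty=args.repetition_penalty,
    cache_implementation=args.cache_implementation,  # the new pass-through
)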
