Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class GRPOConfig(TrainingArguments):
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
tokens.
cache_implementation (`str`, *optional*, defaults to `None`):
Implementation of the cache method for faster generation when `use_vllm` is set to `False`.

> Parameters that control generation acceleration powered by vLLM

Expand Down Expand Up @@ -217,6 +219,10 @@ class GRPOConfig(TrainingArguments):
"to repeat tokens."
},
)
cache_implementation: Optional[str] = field(
default=None,
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
)

# Parameters that control generation acceleration powered by vLLM
use_vllm: Optional[bool] = field(
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ def new_group_context():
top_k=args.top_k,
min_p=args.min_p,
repetition_penalty=args.repetition_penalty,
cache_implementation=args.cache_implementation,
)

# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
Expand Down
Loading