Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Union

import transformers
from packaging import version
from transformers import TrainingArguments


Expand All @@ -34,7 +35,7 @@ class GRPOConfig(TrainingArguments):
Parameters:
> Parameters that control the model and reference model

model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
model_init_kwargs (`str`, `dict[str, Any]` or `None`, *optional*, defaults to `None`):
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
argument of the [`GRPOTrainer`] is provided as a string.

Expand Down Expand Up @@ -143,8 +144,15 @@ class GRPOConfig(TrainingArguments):
prompts are logged.
"""

if version.parse(transformers.__version__) <= version.parse("4.50.3"):
from transformers.training_args import _VALID_DICT_FIELDS

_VALID_DICT_FIELDS.append("model_init_kwargs")
else:
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]

# Parameters that control the model and reference model
model_init_kwargs: Optional[dict] = field(
model_init_kwargs: Optional[Union[dict, str]] = field(
default=None,
metadata={
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
Expand Down