Skip to content

Commit 2bf4847

Browse files
📋 Allow calling trl cli in sft mode with config file (#3380)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent a8cfca6 commit 2bf4847

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

tests/test_cli.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
# limitations under the License.
1414

1515

16+
import os
1617
import sys
1718
import tempfile
1819
import unittest
1920
from io import StringIO
2021
from unittest.mock import patch
2122

23+
import yaml
24+
2225

2326
@unittest.skipIf(
2427
sys.version_info < (3, 10),
@@ -67,6 +70,33 @@ def test_sft(self):
6770
with patch("sys.argv", command.split(" ")):
6871
main()
6972

73+
def test_sft_config_file(self):
74+
from trl.cli import main
75+
76+
with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory
77+
output_dir = os.path.join(tmp_dir, "output")
78+
79+
# Create a temporary config file
80+
config_path = os.path.join(tmp_dir, "config.yaml")
81+
config_content = {
82+
"model_name_or_path": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
83+
"dataset_name": "trl-internal-testing/zen",
84+
"dataset_config": "standard_language_modeling",
85+
"report_to": "none",
86+
"output_dir": output_dir,
87+
"lr_scheduler_type": "cosine_with_restarts",
88+
}
89+
with open(config_path, "w") as config_file:
90+
yaml.dump(config_content, config_file)
91+
92+
# Test the CLI with config file
93+
command = f"trl sft --config {config_path}"
94+
with patch("sys.argv", command.split(" ")):
95+
main()
96+
97+
# Verify that output directory was created
98+
self.assertTrue(os.path.exists(output_dir))
99+
70100

71101
if __name__ == "__main__":
72102
unittest.main()

trl/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def main():
4646
make_vllm_serve_parser(subparsers)
4747

4848
# Parse the arguments
49-
args = parser.parse_args()
49+
args = parser.parse_args_and_config()[0]
5050

5151
if args.command == "chat":
5252
(chat_args,) = parser.parse_args_and_config()

trl/scripts/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class ScriptArguments:
5151
type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
5252
"""
5353

54-
dataset_name: str = field(metadata={"help": "Dataset name."})
54+
dataset_name: Optional[str] = field(default=None, metadata={"help": "Dataset name."})
5555
dataset_config: Optional[str] = field(
5656
default=None,
5757
metadata={

0 commit comments

Comments
 (0)