Skip to content

Commit ef0ebe3

Browse files
authored
Relaxing requirements for trl (#1342)
1 parent 18b9451 commit ef0ebe3

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

requirements/framework.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ tensorboard
1818
tqdm
1919
transformers>=4.33,<4.43
2020
transformers_stream_generator
21-
trl>=0.9.6
21+
trl>=0.9.4

swift/llm/rlhf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020

2121

2222
def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
23+
if args.rlhf_type == 'simpo':
24+
import trl
25+
from packaging import version
26+
assert version.parse(trl.__version__) <= version.parse('0.9.4'), \
27+
'Please ensure to update `trl` to the latest version using the following command:' \
28+
'pip install trl==0.9.6 --index-url https://pypi.org/simple.'
29+
2330
logger.info(f'args: {args}')
2431
seed_everything(args.seed)
2532
training_args = args.training_args

swift/llm/utils/argument.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,11 @@ def check_loss_type(self):
14951495
'cpo': ['sigmoid', 'hinge', 'ipo', 'simpo'],
14961496
'kto': ['kto', 'bco']
14971497
}
1498+
if self.loss_type == 'kto_pair':
1499+
import trl
1500+
from packaging import version
1501+
if version.parse(trl.__version__) <= version.parse('0.9.4'):
1502+
return
14981503
if self.rlhf_type in supported_loss_types:
14991504
assert self.loss_type in supported_loss_types.get(self.rlhf_type), \
15001505
f"algo {self.rlhf_type} doesn't support loss type {self.loss_type}"

0 commit comments

Comments
 (0)