File tree Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change 1
1
# Copyright (c) Alibaba, Inc. and its affiliates.
2
2
from typing import List , Union
3
3
4
+ from swift .utils import get_logger
4
5
from ..argument import TrainArguments
5
6
from .sft import SwiftSft
6
7
8
+ logger = get_logger ()
9
+
7
10
8
11
class SwiftPt (SwiftSft ):
9
12
args_class = TrainArguments
10
13
args : args_class
11
14
12
15
def _prepare_template (self ) -> None :
13
16
self .args .use_chat_template = False
17
+ self .args .loss_scale = 'all'
18
+ logger .info ('Setting args.use_chat_template: False' )
19
+ logger .info ("Setting args.loss_scale: 'all'" )
14
20
super ()._prepare_template ()
15
- self .template .loss_scale = 'all'
16
21
17
22
18
23
def pt_main (args : Union [List [str ], TrainArguments , None ] = None ):
Original file line number Diff line number Diff line change 1
1
# Copyright (c) Alibaba, Inc. and its affiliates.
2
2
from typing import List , Union
3
3
4
+ from swift .utils import get_logger
4
5
from ..argument import MegatronTrainArguments
5
6
from .sft import MegatronSft
6
7
8
+ logger = get_logger ()
9
+
7
10
8
11
class MegatronPt (MegatronSft ):
9
12
args_class = MegatronTrainArguments
10
13
args : args_class
11
14
12
15
def _prepare_template (self ) -> None :
13
16
self .args .use_chat_template = False
17
+ self .args .loss_scale = 'all'
18
+ logger .info ('Setting args.use_chat_template: False' )
19
+ logger .info ("Setting args.loss_scale: 'all'" )
14
20
super ()._prepare_template ()
15
- self .template .loss_scale = 'all'
16
21
17
22
18
23
def megatron_pt_main (args : Union [List [str ], MegatronTrainArguments , None ] = None ):
You can’t perform that action at this time.
0 commit comments