Skip to content

Commit 54cbdf2

Browse files
authored
fix loss_scale (#4229)
1 parent ebcc2a2 commit 54cbdf2

File tree

2 files changed

+12 -2 lines changed

2 files changed

+12 -2 lines changed

swift/llm/train/pt.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,23 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
from typing import List, Union
33

4+
from swift.utils import get_logger
45
from ..argument import TrainArguments
56
from .sft import SwiftSft
67

8+
logger = get_logger()
9+
710

811
class SwiftPt(SwiftSft):
912
args_class = TrainArguments
1013
args: args_class
1114

1215
def _prepare_template(self) -> None:
1316
self.args.use_chat_template = False
17+
self.args.loss_scale = 'all'
18+
logger.info('Setting args.use_chat_template: False')
19+
logger.info("Setting args.loss_scale: 'all'")
1420
super()._prepare_template()
15-
self.template.loss_scale = 'all'
1621

1722

1823
def pt_main(args: Union[List[str], TrainArguments, None] = None):

swift/megatron/train/pt.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,23 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
from typing import List, Union
33

4+
from swift.utils import get_logger
45
from ..argument import MegatronTrainArguments
56
from .sft import MegatronSft
67

8+
logger = get_logger()
9+
710

811
class MegatronPt(MegatronSft):
912
args_class = MegatronTrainArguments
1013
args: args_class
1114

1215
def _prepare_template(self) -> None:
1316
self.args.use_chat_template = False
17+
self.args.loss_scale = 'all'
18+
logger.info('Setting args.use_chat_template: False')
19+
logger.info("Setting args.loss_scale: 'all'")
1420
super()._prepare_template()
15-
self.template.loss_scale = 'all'
1621

1722

1823
def megatron_pt_main(args: Union[List[str], MegatronTrainArguments, None] = None):

0 commit comments

Comments
 (0)