Skip to content

Commit 2aec645

Browse files
committed
support amp in pir dy2st mode.
1 parent 9b76eb2 commit 2aec645

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

paddlenlp/trainer/auto_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,7 @@ def _wrap_for_auto(self, model, train_dataloader):
130130

131131
def _wrap_amp_model(self, args, model):
132132
logger.info("Using half precision")
133-
if args.to_static:
134-
return
135-
self.enable_autocast_context_manager = True
136-
self.do_grad_scaling = True if self.args.fp16 else False
137-
self.amp_dtype = "float16" if self.args.fp16 else "bfloat16"
138-
self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))
133+
self.amp_dtype = "float16" if self.args.fp16 else "bfloat16"
139134
if self.args.fp16_opt_level == "O2":
140135
paddle.amp.decorate(
141136
models=model,
@@ -144,6 +139,11 @@ def _wrap_amp_model(self, args, model):
144139
master_grad=self.args.amp_master_grad,
145140
excluded_layers=QuantizationLinear,
146141
)
142+
if args.to_static:
143+
return
144+
self.enable_autocast_context_manager = True
145+
self.do_grad_scaling = True if self.args.fp16 else False
146+
self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))
147147

148148
def _get_item_from_loss(self, loss):
149149
if isinstance(loss, paddle.Tensor):

0 commit comments

Comments
 (0)