Commit fbfd140: [AutoParallel] fix amp o1 (#46391)

1 parent: 5437bd9

2 files changed: +27 -16 lines

python/paddle/distributed/passes/auto_parallel_amp.py (24 additions, 16 deletions)
@@ -38,14 +38,18 @@ def __init__(self, block):
         self._op_fp16_dict = {
         }  # op_id --> True/False. 'True' means that the current op is in fp16 mode.
         self._var_name_dict = {}  # fwd_op_id --> {old_name: cast_name}
+        self.is_train = False
 
     def _is_fp16_op(self, op_id):
         return self._op_fp16_dict.get(op_id, None)
 
-    def _build_stats(self, amp_lists, dist_context):
+    def _build_state(self, amp_lists, dist_context):
         ops = self._block.ops
         dist_op_context = dist_context.dist_op_context
         for op in ops:
+            if int(op.attr('op_role')) == 257:
+                self.is_train = True
+
             if int(op.attr('op_role')) == int(OpRole.Forward):
                 self._mark_black_white_ops(amp_lists)
             elif int(op.attr('op_role')) == int(OpRole.Backward):
@@ -59,6 +63,8 @@ def _build_stats(self, amp_lists, dist_context):
             elif int(op.attr('op_role')) == int(OpRole.Optimize):
                 break
 
+        return self.is_train
+
     def _mark_black_white_ops(self, amp_lists):
         """
         this function is modified from paddle.fluid.contrib.mixed_precision
@@ -546,23 +552,25 @@ def _apply_single_impl(self, main_program, startup_program, context):
             set(self.get_attr("custom_black_list")),
             set(self.get_attr("custom_black_varnames")))
 
-        amp_state = AMPState(main_program.global_block())
-        amp_state._build_stats(amp_lists, self.dist_context)
-
         with paddle.static.program_guard(main_program, startup_program):
+            amp_state = AMPState(main_program.global_block())
+            is_train = amp_state._build_state(amp_lists, self.dist_context)
+
             amp_state.cast_forward_program(self.dist_context)
-            amp_state.cast_backward_program(params_grads, self.dist_context)
-            # TODO (JZ-LIANG)support cast forward program only when inference
-            self._init_amp_var()
-            self._scale_loss()
-
-            if self.get_attr("use_dynamic_loss_scaling"
-                             ) or self.get_attr("init_loss_scaling") != 1.0:
-                grads, found_inf = _check_and_update_gradient(
-                    params_grads, self._loss_scaling, self.dist_context)
-
-            if self.get_attr("use_dynamic_loss_scaling"):
-                self._update_loss_scaling(grads, found_inf)
+
+        if is_train:
+            with paddle.static.program_guard(main_program, startup_program):
+                amp_state.cast_backward_program(params_grads, self.dist_context)
+                self._init_amp_var()
+                self._scale_loss()
+
+                if self.get_attr("use_dynamic_loss_scaling"
+                                 ) or self.get_attr("init_loss_scaling") != 1.0:
+                    grads, found_inf = _check_and_update_gradient(
+                        params_grads, self._loss_scaling, self.dist_context)
+
+                if self.get_attr("use_dynamic_loss_scaling"):
+                    self._update_loss_scaling(grads, found_inf)
 
     def _init_amp_var(self):
         self._loss_scaling = paddle.static.create_global_var(
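
For context on the new check in `_build_state`: `op_role` is a bitmask, and 257 is 0x0001 | 0x0100, i.e. `OpRole.Backward | OpRole.Loss`, the role carried by the op that seeds backpropagation at the loss. A program containing such an op must be a training program, which is what `is_train` records, and which `_apply_single_impl` now uses to keep `cast_forward_program` unconditional while moving all backward-side work (backward cast, loss scaling, gradient check) under `if is_train:`. Below is a minimal, self-contained sketch of the detection idea; the constants mirror Paddle's OpRole bit values, and `program_is_train` is a hypothetical helper for illustration, not part of the pass:

# Sketch only: Paddle's OpRole values are bit flags.
OP_ROLE_FORWARD = 0x0000
OP_ROLE_BACKWARD = 0x0001
OP_ROLE_LOSS = 0x0100

# 0x0001 | 0x0100 == 0x0101 == 257: the op that starts backprop at the loss.
LOSS_GRAD_ROLE = OP_ROLE_BACKWARD | OP_ROLE_LOSS


def program_is_train(op_roles):
    """Hypothetical helper: a program trains iff some op carries both the
    Backward and Loss role bits (the loss-grad op)."""
    return any(role == LOSS_GRAD_ROLE for role in op_roles)


# An eval/predict program contains only forward ops, so no backward work applies.
assert not program_is_train([OP_ROLE_FORWARD, OP_ROLE_FORWARD])
# A training program contains the loss-grad op.
assert program_is_train([OP_ROLE_FORWARD, LOSS_GRAD_ROLE, OP_ROLE_BACKWARD])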

python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py (3 additions, 0 deletions)
@@ -97,6 +97,7 @@ def test_amp_pass(self):
                                         3,
                                         batch_size=self.batch_size)
         amp_o1_losses = np.array(amp_o1_losses["loss"])
+        amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
         # self.check_results(mp_losses, amp_o1_losses)
 
         # mp2 amp-o2 training
@@ -105,6 +106,7 @@ def test_amp_pass(self):
                                         3,
                                         batch_size=self.batch_size)
         amp_o2_losses = np.array(amp_o2_losses["loss"])
+        amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
         # self.check_results(mp_losses, amp_o2_losses)
 
         # mp2 amp-o3 training
@@ -113,6 +115,7 @@ def test_amp_pass(self):
                                         3,
                                         batch_size=self.batch_size)
         amp_o3_losses = np.array(amp_o3_losses["loss"])
+        amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
         # self.check_results(mp_losses, amp_o3_losses)
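
Note on the test change: each of the three new `evaluate` calls runs a forward-only program under the corresponding AMP level immediately after training. Before this commit, `_apply_single_impl` unconditionally called `cast_backward_program`, `_scale_loss`, and the loss-scaling update even when the program had no backward ops, which is what broke AMP O1 outside of training; the `is_train` gate above makes those calls conditional, and the `evaluate` calls pin that behavior down for O1, O2, and O3 alike.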
