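Merged PR diff for python/paddle/jit/dy2static/partial_program.py: the repeated function-local imports of `_in_amp_guard` and `_in_pure_fp16_guard` are replaced by a single module-level import from `paddle.amp.auto_cast`.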
13 changes: 1 addition & 12 deletions python/paddle/jit/dy2static/partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import paddle
from paddle import _legacy_C_ops
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.dygraph import layers
Expand Down Expand Up @@ -455,8 +456,6 @@ def program_id(self):
"""
Return current train or eval program hash id.
"""
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

if self.training:
if _in_amp_guard():
return self._train_amp_program_id
Expand All @@ -474,8 +473,6 @@ def program_id(self):

@property
def train_program(self):
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
Expand All @@ -485,8 +482,6 @@ def train_program(self):

@property
def infer_program(self):
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
Expand All @@ -496,8 +491,6 @@ def infer_program(self):

@property
def forward_program(self):
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

if self.training:
if _in_amp_guard():
progs = self._train_amp_forward_backward_program
Expand All @@ -511,8 +504,6 @@ def forward_program(self):

@property
def backward_program(self):
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard

if self.training:
if _in_amp_guard():
progs = self._train_amp_forward_backward_program
Expand Down Expand Up @@ -708,8 +699,6 @@ def _get_double_grads(self, program):
return self._valid_vars(double_grads)

def _cast_fp16_if_pure_fp16(self, in_vars):
from paddle.amp.auto_cast import _in_pure_fp16_guard

if _in_pure_fp16_guard():
for i, var in enumerate(in_vars):
name = var.name
Expand Down
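The sketch below illustrates the pattern this diff applies: a helper imported once at module level instead of re-imported inside every property. It is not Paddle's actual `PartialProgramLayer`; `ProgramCache` and its attributes are hypothetical stand-ins, and running it assumes paddle is installed so that the private guard helpers exist at the path the diff itself imports from.

```python
# Minimal sketch of the refactor in this diff, not Paddle's real class.
# Assumption: paddle is installed; _in_amp_guard/_in_pure_fp16_guard are the
# same private helpers the diff hoists to module scope.
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard


class ProgramCache:  # hypothetical stand-in for PartialProgramLayer
    def __init__(self, amp_program, pure_fp16_program, fp32_program):
        self._train_amp_program = amp_program
        self._train_pure_fp16_program = pure_fp16_program
        self._train_program = fp32_program

    @property
    def train_program(self):
        # The guard helpers are resolved once at module import time; the
        # per-call "from paddle.amp.auto_cast import ..." lines are gone.
        if _in_amp_guard():
            return self._train_amp_program
        elif _in_pure_fp16_guard():
            return self._train_pure_fp16_program
        return self._train_program
```

Function-local imports like the ones removed here are usually a workaround for circular imports; hoisting them to a single module-level import removes the duplication and the per-call import-machinery lookup, and is presumably safe here because no import cycle exists between partial_program.py and paddle.amp.auto_cast at module load time.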