2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_input_spec.py
@@ -349,7 +349,7 @@ def test_run(self):
)
x = paddle.randn([2, 10])
out = net(x)
np.testing.assert_equal(out.shape, [2, 5])
np.testing.assert_equal(net.forward._input_spec, None)


if __name__ == '__main__':
24 changes: 21 additions & 3 deletions python/paddle/jit/api.py
@@ -218,8 +218,23 @@ def ignore_module(modules: list[Any]):
add_ignore_module(modules)


def _check_and_set_backend(backend, build_strategy):
if backend not in ['CINN', None]:
raise ValueError(
"The backend of to_static should be 'CINN' or None, but received {}.".format(
backend
)
)
if backend == 'CINN':
build_strategy.build_cinn_pass = True


def to_static(
function=None, input_spec=None, build_strategy=None, property=False
function=None,
input_spec=None,
build_strategy=None,
backend=None,
**kwargs,
Contributor:

The `property` parameter could previously be passed either positionally or as a keyword; making it keyword-only is a backward-incompatible change. Has its impact been checked?

Contributor Author:

The `property` parameter is currently used only by the speech toolkit. We have synced with the speech team: they always pass it as a keyword (k=v), so there is no further impact on the main framework or the toolkits. The email announcing the incompatible upgrade has been sent and approved by the review committee.

):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
@@ -228,7 +243,6 @@ def to_static(
Tensor(s) to do imperative training, inference, or other operations. If the
decorated function calls other imperative function, the called one will be
converted into declarative function as well.

Args:
function (callable): callable imperative function.
input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specific the shape/dtype/name
Expand All @@ -238,7 +252,8 @@ def to_static(
in the computational graph and memory optimization during the execution
of the computational graph. For more information about build_strategy,
please refer to :code:`paddle.static.BuildStrategy`. The default is None.
property(bool, Optional): whether the fucntion is python property. The default is False.
backend(str, Optional): Specifies the compilation backend, which can be `CINN` or None. When backend is `CINN`, the CINN compiler will be used to speed up training and inference.
kwargs: Supported keys include `property`. Set `property` to True if the function is a Python property.


Returns:
@@ -263,6 +278,7 @@ def func(x):
print(x_v) # [[2. 2.]]

"""
property = kwargs.get("property", False)

def decorated(python_func):
"""
@@ -279,6 +295,7 @@ def decorated(python_func):
input_spec=input_spec,
build_strategy=build_strategy,
property=property,
backend=backend,
),
)

@@ -291,6 +308,7 @@ def decorated(python_func):
type(build_strategy).__name__
)
)
_check_and_set_backend(backend, build_strategy)

# for usage: `to_static(foo, ...)`
if function is not None:
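A minimal usage sketch of the api.py change above, assuming the hypothetical SimpleNet layer below (not part of this PR): backend='CINN' turns on build_strategy.build_cinn_pass, and property can now only be passed as a keyword.

```python
import paddle
from paddle import nn


class SimpleNet(nn.Layer):
    # Hypothetical example network; any nn.Layer would do here.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)


net = SimpleNet()
# Select the CINN compiler backend; per the diff above, this sets
# build_strategy.build_cinn_pass = True via _check_and_set_backend.
static_net = paddle.jit.to_static(net, backend='CINN')
out = static_net(paddle.randn([2, 10]))

# `property` is now keyword-only: to_static(func, property=True) still works,
# while passing it positionally would be interpreted as the backend argument
# and rejected by _check_and_set_backend.
```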
13 changes: 9 additions & 4 deletions python/paddle/jit/dy2static/partial_program.py
@@ -27,7 +27,12 @@
from paddle.optimizer.lr import LRScheduler

from . import logging_utils
from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
from .utils import (
RETURN_NO_VALUE_MAGIC_NUM,
_out_grad_names,
_param_grad_names,
dy2st_prim_guard,
)

__all__ = []

@@ -197,6 +202,7 @@ def __init__(
# program_id -> list(scope)
self._scope_cache = {}
self._hooker = None
self._backend = kwargs.get('backend', None)

def __call__(self, inputs):
"""
@@ -636,10 +642,9 @@ def _append_backward_desc(self, main_program):

start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
if targets:
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
core.check_and_set_prim_all_enabled()
start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
backward.gradients(targets=targets, inputs=[])
with dy2st_prim_guard(self._backend):
backward.gradients(targets=targets, inputs=[])

if self._hooker:
program, start_idx = self._hooker.after_append_backward(
68 changes: 38 additions & 30 deletions python/paddle/jit/dy2static/program_translator.py
@@ -47,6 +47,7 @@
ALREADY_D2S,
ast_to_func,
ast_to_source_code,
dy2st_prim_guard,
func_to_source_code,
input_specs_compatible,
is_paddle_func,
@@ -333,7 +334,7 @@ def __init__(self, function, input_spec=None, **kwargs):
self._class_instance = None

if input_spec is not None and prim_or_cinn_is_enabled(
kwargs.get("build_strategy", None)
kwargs.get("build_strategy", None), kwargs.get("backend", None)
):
from paddle.static import InputSpec

@@ -1183,11 +1184,9 @@ def __init__(self):
def _build_once(self, cache_key):
# TODO(Aurelius84): Need a global FLAGS to enable/disable to_prim
enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.

# NOTE(xiongkun): Need a global FLAGS to enable/disable fallback
enable_fallback = enable_prim
core.check_and_set_prim_all_enabled()
try:
concrete_program = ConcreteProgram.from_func_spec(
func_spec=cache_key.function_spec,
@@ -1215,7 +1214,8 @@ def _build_once(self, cache_key):
else:
raise

if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy']):
backend = cache_key.kwargs['backend']
if prim_or_cinn_is_enabled(cache_key.kwargs['build_strategy'], backend):
for var in concrete_program.main_program.list_vars():
if -1 in var.shape:
warnings.warn(
@@ -1227,10 +1227,11 @@
partial_program = partial_program_from(
concrete_program, cache_key.class_instance is not None
)
if core._is_fwd_prim_enabled():
partial_program.set_hooker(
PrimHooker(concrete_program.main_program)
)
with dy2st_prim_guard(backend):
if core._is_fwd_prim_enabled():
partial_program.set_hooker(
PrimHooker(concrete_program.main_program, backend)
)
return concrete_program, partial_program

def __getitem__(self, item):
@@ -1290,39 +1291,46 @@ def clear(self):


class PrimHooker(PartialProgramLayerHook):
def __init__(self, original_program):
def __init__(self, original_program, backend):
if len(original_program.blocks) > 1:
raise ValueError(
'The primitive mode only support one block currently.'
)
self.backend = backend
self.custom_vjps = set()
if core._is_all_prim_enabled():
self.custom_vjps = {
op.type
for op in original_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}
with dy2st_prim_guard(self.backend):
if core._is_all_prim_enabled():
self.custom_vjps = {
op.type
for op in original_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}

def before_append_backward(self, forward_program):
if core._is_fwd_prim_enabled():
_to_prim(forward_program.blocks, blacklist=self.custom_vjps)
return forward_program
with dy2st_prim_guard(self.backend):
if core._is_fwd_prim_enabled():
_to_prim(forward_program.blocks, blacklist=self.custom_vjps)
return forward_program

def after_append_backward(self, whole_program, backward_start_idx):
backward_length = len(whole_program.block(0).ops) - backward_start_idx
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
# only process backward part of block
_to_prim(whole_program.blocks, backward_length=backward_length)
new_start_index = len(whole_program.block(0).ops) - backward_length
if backward_length > 0:
# only process forward part of block
_to_prim(whole_program.blocks, start_idx=new_start_index)
return whole_program, new_start_index
with dy2st_prim_guard(self.backend):
backward_length = (
len(whole_program.block(0).ops) - backward_start_idx
)
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
# only process backward part of block
_to_prim(whole_program.blocks, backward_length=backward_length)
new_start_index = len(whole_program.block(0).ops) - backward_length
if backward_length > 0:
# only process forward part of block
_to_prim(whole_program.blocks, start_idx=new_start_index)
return whole_program, new_start_index

def after_infer(self, infer_program):
if core._is_fwd_prim_enabled():
_to_prim(infer_program.block(0))
return infer_program
with dy2st_prim_guard(self.backend):
if core._is_fwd_prim_enabled():
_to_prim(infer_program.block(0))
return infer_program


class ProgramTranslator:
21 changes: 20 additions & 1 deletion python/paddle/jit/dy2static/utils.py
@@ -35,6 +35,7 @@
from paddle.fluid import core, unique_name
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.utils import gast

from .ast_utils import ast_to_source_code
@@ -1491,7 +1492,10 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
return names


def prim_or_cinn_is_enabled(build_strategy):
def prim_or_cinn_is_enabled(build_strategy, backend):
if backend == 'CINN':
return True

if build_strategy is not None and build_strategy.build_cinn_pass:
return True

@@ -1527,3 +1531,18 @@ def name_judge():
return True
else:
return False


@signature_safe_contextmanager
def dy2st_prim_guard(backend):
Contributor:

backend_guard would be a more self-explanatory name.

Contributor Author (@0x45f, Apr 10, 2023):

Right now this guard only toggles the prim state, so isn't dy2st_prim_guard more specific?

Contributor:

I mean the intent of this part is to change state according to the backend, so it should be called backend_guard rather than dy2st_prim_guard.

Contributor Author:

Done, thanks!

core.check_and_set_prim_all_enabled()
orign_fwd = core._is_fwd_prim_enabled()
orign_bwd = core._is_bwd_prim_enabled()

if backend == 'CINN':
core._set_prim_all_enabled(True)
try:
yield
finally:
core._set_prim_forward_enabled(orign_fwd)
core._set_prim_backward_enabled(orign_bwd)
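
A minimal sketch of the guard's intended behavior, assuming the semantics shown in the diff above (the review thread suggests it was later renamed to backend_guard): with backend='CINN' both forward and backward prim modes are switched on inside the context, and the previously recorded flags are restored on exit.

```python
from paddle.fluid import core
from paddle.jit.dy2static.utils import dy2st_prim_guard  # review suggests renaming to backend_guard

with dy2st_prim_guard('CINN'):
    # Inside the guard the prim modes are force-enabled for the CINN backend.
    assert core._is_fwd_prim_enabled()
    assert core._is_bwd_prim_enabled()
# On exit the flags captured before entering the guard are restored.

with dy2st_prim_guard(None):
    # For any other backend the guard leaves the current prim state unchanged.
    pass
```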
15 changes: 15 additions & 0 deletions test/dygraph_to_static/test_cinn_prim.py
@@ -163,5 +163,20 @@ def test_cinn_prim(self):
)


class TestBackend(unittest.TestCase):
def test_backend(self):
x = paddle.randn([2, 4])
out1 = self.forward(x, 'CINN')
out2 = self.forward(x, None)
np.testing.assert_allclose(out1, out2, rtol=1e-6)

def forward(self, x, beckend=None):
Contributor:

typo

Contributor Author:

I'll fix it in a follow-up PR, thanks!

paddle.seed(2022)
net = PrimeNet()
net = paddle.jit.to_static(net, backend=beckend)
out = net(x)
return out


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion test/dygraph_to_static/test_partial_program_hook.py
@@ -44,7 +44,7 @@ def f():
f
).get_concrete_program()
self._hook = program_translator.PrimHooker(
concrete_program.main_program
concrete_program.main_program, None
)
self._forward = partial_program.forward_program
self._whole = partial_program._train_program