@@ -78,7 +78,7 @@ def setUp(self):

    def cal_composite_grad(self, inputs):
        paddle.enable_static()
-        core._set_prim_all_enabled(True)
+        core._set_prim_forward_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
@@ -109,7 +109,7 @@ def cal_composite_grad(self, inputs):
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
-        core._set_prim_all_enabled(False)
+        core._set_prim_forward_enabled(False)
        return res

    def compare_backward(self):
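
Editorial note: the two hunks above narrow this test from full prim lowering (`_set_prim_all_enabled`) to forward-only lowering (`_set_prim_forward_enabled`). A minimal hedged sketch of scoping these toggles per test body, assuming only that the three `core._set_prim_*` setters seen in this diff act as global switches (the context manager itself is illustrative, not part of the PR):

from contextlib import contextmanager

from paddle.fluid import core


@contextmanager
def prim_scope(forward=False, backward=False):
    # Flip the requested prim switches for one test body, then reset
    # everything so later tests start from a clean state.
    core._set_prim_forward_enabled(forward)
    core._set_prim_backward_enabled(backward)
    try:
        yield
    finally:
        core._set_prim_all_enabled(False)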
@@ -142,12 +142,13 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):

    def setUp(self):
        core._set_prim_backward_enabled(True)
-        self.dtypes = ["float32"]
+        self.dtypes = ["float32", "float64"]
        self.shapes = [[2, 3, 4], [2, 3]]
        self.axes = [-1, 0, 1]

    def cal_composite_grad(self, inputs):
        paddle.enable_static()
+        core._set_prim_all_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
@@ -164,6 +165,7 @@ def cal_composite_grad(self, inputs):
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
+        core._set_prim_all_enabled(False)
        return res

    def compare_backward(self):
@@ -7,3 +7,8 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
  py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
+
+if(WITH_CINN)
+  set_tests_properties(test_prim_flags_case PROPERTIES LABELS "RUN_TYPE=CINN")
+  set_tests_properties(test_prim_flags_case PROPERTIES TIMEOUT 300)
+endif()
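
Editorial note: the `RUN_TYPE=CINN` label and 300-second timeout added here are consumed by ctest. A hedged sketch of selecting just this test from Python, assuming a conventional CMake build directory named `build/` (an assumption; the label and test name come from the lines above):

import subprocess

# -R filters by test-name regex, -L by label regex; both are standard ctest flags.
subprocess.run(
    ["ctest", "-R", "test_prim_flags_case", "-L", "RUN_TYPE=CINN", "--output-on-failure"],
    cwd="build",  # assumption: the CMake build directory
    check=True,
)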
@@ -0,0 +1,173 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import platform
import unittest

import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def apply_to_static(net, use_cinn):
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(net, build_strategy=build_strategy)


class PrimeNet(paddle.nn.Layer):
    def __init__(self):
        super(PrimeNet, self).__init__()

    def forward(self, x):
        out = F.softmax(x)
        res = paddle.exp(out)
        return res


class TestPrimForwardAndBackward(unittest.TestCase):
    """
    Test PrimeNet with @to_static + prim forward + prim backward + CINN vs. dygraph.
    """

    def setUp(self):
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False
        self.flag = None

    def reset_env_flag(self):
        os.environ["FLAGS_prim_backward"] = "False"
        os.environ["FLAGS_prim_forward"] = "False"
        if os.getenv("FLAGS_prim_all"):
            del os.environ["FLAGS_prim_all"]
        core.check_and_set_prim_all_enabled()

    def train(self, use_cinn):
        net = PrimeNet()
        net = apply_to_static(net, use_cinn)

        out = net(self.x)
        loss = paddle.mean(out)
        loss.backward()

        self.check_prim(net)

        return

    def check_prim(self, net):
        ops = [
            op.type
            for op in net.forward.program_cache.last()[-1][-1]
            .train_program.block(0)
            .ops
        ]

        if self.flag in ["prim_all", "cinn_prim_all"]:
            self.assertTrue('softmax' not in ops)
            self.assertTrue('exp_grad' not in ops)
        elif self.flag in ["prim_forward", "cinn_prim_forward"]:
            self.assertTrue('softmax' not in ops)
            self.assertTrue('exp_grad' in ops)
        elif self.flag in ["prim_backward", "cinn_prim_backward"]:
            self.assertTrue('softmax' in ops)
            self.assertTrue('exp_grad' not in ops)
        elif self.flag == "cinn":
            self.assertTrue('softmax' in ops)
            self.assertTrue('exp_grad' in ops)
        else:
            raise TypeError

    def test_cinn_prim_all(self):
        """cinn + prim forward + prim backward"""
        self.reset_env_flag()
        os.environ["FLAGS_prim_all"] = "True"
        self.flag = "cinn_prim_all"
        plat = platform.system()
        if plat == "Linux":
            _ = self.train(use_cinn=True)
        else:
            pass

    def test_prim_all(self):
        """prim forward + prim backward"""
        self.reset_env_flag()
        os.environ["FLAGS_prim_all"] = "True"
        self.flag = "prim_all"
        plat = platform.system()
        if plat == "Linux":
            _ = self.train(use_cinn=False)
        else:
            pass

    def test_cinn_prim_forward(self):
        """cinn + prim forward"""
        self.reset_env_flag()
        os.environ["FLAGS_prim_forward"] = "True"
        self.flag = "cinn_prim_forward"
        plat = platform.system()
        if plat == "Linux":
            _ = self.train(use_cinn=True)
        else:
            pass

    def test_prim_forward(self):
        """only prim forward"""
        self.reset_env_flag()
        os.environ["FLAGS_prim_forward"] = "True"
        self.flag = "prim_forward"
        plat = platform.system()
        if plat == "Linux":
            _ = self.train(use_cinn=False)
        else:
            pass

    def test_cinn_prim_backward(self):
        """cinn + prim backward"""
        self.reset_env_flag()
        os.environ["FLAGS_prim_backward"] = "True"
        self.flag = "cinn_prim_backward"
        plat = platform.system()
        if plat == "Linux":
            _ = self.train(use_cinn=True)
        else:
            pass

    def test_prim_backward(self):
        """only prim backward"""
        self.reset_env_flag()
        os.environ["FLAGS_prim_backward"] = "True"
        self.flag = "prim_backward"
        plat = platform.system()
        if plat == "Linux":
            _ = self.train(use_cinn=False)
        else:
            pass

    def test_cinn(self):
        """only cinn"""
        self.reset_env_flag()
        self.flag = "cinn"
        plat = platform.system()
        if plat == "Linux":
            _ = self.train(use_cinn=True)
        else:
            pass


if __name__ == '__main__':
    unittest.main()
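
Editorial note: every test above funnels through `core.check_and_set_prim_all_enabled()`, which is expected to read the `FLAGS_*` environment variables set in `reset_env_flag` and in the test bodies. A minimal hedged sketch of the mapping the tests assume, written with the setters that appear elsewhere in this diff (an illustration of the expected semantics, not Paddle's actual implementation):

import os

from paddle.fluid import core


def sketch_check_and_set_prim_flags():
    # Assumed precedence: FLAGS_prim_all, when present, turns on both
    # directions; otherwise each direction follows its own flag.
    if os.getenv("FLAGS_prim_all") == "True":
        core._set_prim_all_enabled(True)
    else:
        core._set_prim_forward_enabled(os.getenv("FLAGS_prim_forward") == "True")
        core._set_prim_backward_enabled(os.getenv("FLAGS_prim_backward") == "True")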
python/paddle/jit/dy2static/partial_program.py (2 additions & 3 deletions)
@@ -571,9 +571,8 @@ def _append_backward_desc(self, main_program):
                targets.append(program.global_block().var(out.name))

        if targets:
-            if self._build_strategy.build_cinn_pass:
-                # TODO(Jiabin): Change this to True if we need this to be default option
-                core.check_and_set_prim_all_enabled()
+            # TODO(CZ): later, when CINN is used, set_prim_all_enabled and check_and_set_prim_all_enabled will be moved into the else branch.
+            core.check_and_set_prim_all_enabled()
            backward.gradients(targets=targets, inputs=[])

        start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist())
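
Editorial note: after this hunk the prim flags are honored for every `to_static` program, not only those compiled with the CINN pass. A hedged sketch of the user-visible effect (illustrative; the flag name and APIs come from this PR):

import os

import paddle
import paddle.nn.functional as F

os.environ["FLAGS_prim_all"] = "True"  # now takes effect even without build_cinn_pass

net = paddle.jit.to_static(lambda x: F.softmax(x))
x = paddle.randn([2, 4])
x.stop_gradient = False
paddle.mean(net(x)).backward()  # backward program is built with prim lowering enabled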
python/paddle/jit/dy2static/program_translator.py (2 additions & 3 deletions)
@@ -1140,9 +1140,8 @@ def _build_once(self, cache_key):
        enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
        # NOTE(xiongkun): Need a global FLAGS to enable/disable fallback
        enable_fallback = enable_prim
-        if enable_prim:
-            # TODO(Jiabin): Change this to True if we need this to be default option
-            core.check_and_set_prim_all_enabled()
+        # TODO(CZ): later, when CINN is used, set_prim_all_enabled and check_and_set_prim_all_enabled will be moved into the else branch.
+        core.check_and_set_prim_all_enabled()
        try:
            concrete_program = ConcreteProgram.from_func_spec(
                func_spec=cache_key.function_spec,