@@ -50,6 +50,7 @@
"mean",
"mean_all",
"meshgrid",
"numel",
"one_hot",
"p_norm",
"pow",
@@ -59,6 +60,7 @@
"sigmoid_cross_entropy_with_logits",
"silu",
"swiglu",
"swish",
"softmax",
"softsign",
"square",
@@ -101,6 +103,7 @@
"mean",
"mean_all",
"meshgrid",
"numel",
"p_norm",
"pow",
"reciprocal",
@@ -109,6 +112,7 @@
"sigmoid_cross_entropy_with_logits",
"silu",
"swiglu",
"swish",
"softmax",
"softsign",
"square",
3 changes: 3 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -89,6 +89,8 @@
'elementwise_pow_grad',
'maximum_grad',
'reduce_as_grad',
'fmax_grad',
'fmin_grad',
'dot_grad',
]

@@ -150,6 +152,7 @@
'sqrt_grad',
'stack_grad',
'swiglu',
'swish_grad',
] # custom vjp list of composite op

VJP_COMPS = PRIM_VJP + CUSTOM_VJP
20 changes: 20 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -1521,6 +1521,26 @@ std::vector<Tensor> unstack_decomp(const Tensor& x, int axis, const int num) {
return res;
}

template <typename T>
Tensor numel_decomp(const Tensor& x) {
auto x_shape = x.shape();
if (has_dynamic_shape(x_shape)) {
const Tensor x_shape_tensor = shape<T>(x);
Tensor value = full<T>({1}, 1, x_shape_tensor.dtype());
for (size_t i = 0; i < x_shape.size(); ++i) {
value = value * get_slice<T>(x_shape_tensor, i);
}
return cast<T>(reshape<T>(value, {}), DataType::INT64);
} else {
return full_scalar<T>(x.numel(), DataType::INT64);
}
}
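
A quick sanity check on the decomposition above, as a hedged NumPy sketch (not Paddle API): the element count is the running product of the shape entries, returned as an int64 scalar, which is what both branches of `numel_decomp` compute.

```python
# Minimal sketch, assuming plain NumPy arrays; mirrors the dynamic-shape
# branch that multiplies get_slice<T>(x_shape_tensor, i) over all dims.
import numpy as np

def numel_ref(x: np.ndarray) -> np.int64:
    count = np.int64(1)
    for dim in x.shape:
        count *= np.int64(dim)
    return count

assert numel_ref(np.zeros((2, 3, 4))) == 24
```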

template <typename T>
Tensor swish_decomp(const Tensor& x) {
return x * sigmoid<T>(x);
}

} // namespace details

} // namespace primitive
94 changes: 94 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -2750,6 +2750,100 @@ void atan_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
}
}

template <typename T>
void swish_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
const Tensor one = full_scalar<T>(1.0, x.dtype());
const Tensor sig = sigmoid<T>(x);
Tensor res = out_grad * sig * (one + x * (one - sig));
set_output<T>(res, x_grad);
}
}
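
For reference, a small NumPy sketch (plain float64, no Paddle) that checks the closed form used in `swish_grad` against a central finite difference of the `swish_decomp` formula `x * sigmoid(x)`; the gradient is `sigmoid(x) * (1 + x * (1 - sigmoid(x)))`.

```python
# Hedged sketch under the assumption that a NumPy reference is acceptable here.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swish(x):
    return x * sigmoid(x)  # same formula as swish_decomp

def swish_grad_ref(x):
    sig = sigmoid(x)
    return sig * (1.0 + x * (1.0 - sig))  # same factor as in swish_grad

x = np.linspace(-4.0, 4.0, 9)
eps = 1e-6
numeric = (swish(x + eps) - swish(x - eps)) / (2.0 * eps)
assert np.allclose(numeric, swish_grad_ref(x), atol=1e-5)
```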

template <typename T>
void fmax_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
Tensor* x_grad,
Tensor* y_grad) {
const Tensor nan_x = isnan<T>(x);
const Tensor nan_y = isnan<T>(y);
Tensor mask_x = backend::logical_or<T>(nan_y, greater_equal<T>(x, y));
Tensor mask_y = backend::logical_not<T>(mask_x);

if (x_grad) {
Tensor dx = cast<T>(mask_x, out_grad.dtype()) * out_grad;
if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape())) {
dx = reduce_as<T>(dx, x);
} else {
if (out_grad.dims() != x.dims()) {
auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
Tensor dx_reduce_res =
dx.sum(common::vectorize(reduce_dim), x.dtype(), false);
dx = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
}
}
set_output<T>(dx, x_grad);
}

if (y_grad) {
Tensor dy = cast<T>(mask_y, out_grad.dtype()) * out_grad;
if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape())) {
dy = reduce_as<T>(dy, y);
} else {
if (out_grad.dims() != y.dims()) {
auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), y.dims());
Tensor dy_reduce_res =
dy.sum(common::vectorize(reduce_dim), y.dtype(), false);
dy = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
}
}
set_output<T>(dy, y_grad);
}
}

template <typename T>
void fmin_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
Tensor* x_grad,
Tensor* y_grad) {
const Tensor nan_x = isnan<T>(x);
const Tensor nan_y = isnan<T>(y);
Tensor mask_x = backend::logical_or<T>(nan_y, less_equal<T>(x, y));
Tensor mask_y = backend::logical_not<T>(mask_x);

if (x_grad) {
Tensor dx = cast<T>(mask_x, out_grad.dtype()) * out_grad;
if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape())) {
dx = reduce_as<T>(dx, x);
} else {
if (out_grad.dims() != x.dims()) {
auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
Tensor dx_reduce_res =
dx.sum(common::vectorize(reduce_dim), x.dtype(), false);
dx = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
}
}
set_output<T>(dx, x_grad);
}

if (y_grad) {
Tensor dy = cast<T>(mask_y, out_grad.dtype()) * out_grad;
if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape())) {
dy = reduce_as<T>(dy, y);
} else {
if (out_grad.dims() != y.dims()) {
auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), y.dims());
Tensor dy_reduce_res =
dy.sum(common::vectorize(reduce_dim), y.dtype(), false);
dy = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
}
}
set_output<T>(dy, y_grad);
}
}
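
A hedged NumPy sketch of the masking rule shared by `fmax_grad` and `fmin_grad`: the gradient is routed to x where y is NaN or where x wins the comparison (>= for fmax, <= for fmin), and to y everywhere else. Broadcast handling (`reduce_as` / summing over reduced dims) is omitted.

```python
# Sketch only; same-shape inputs assumed, dtype casts and reshapes omitted.
import numpy as np

def fmax_grads_ref(x, y, out_grad):
    mask_x = np.isnan(y) | (x >= y)  # grad goes to x if y is NaN or x >= y
    mask_y = ~mask_x                 # otherwise (including NaN in x) it goes to y
    return np.where(mask_x, out_grad, 0.0), np.where(mask_y, out_grad, 0.0)

x = np.array([1.0, np.nan, 3.0])
y = np.array([2.0, 5.0, np.nan])
dx, dy = fmax_grads_ref(x, y, np.ones_like(x))
assert (dx == [0.0, 0.0, 1.0]).all() and (dy == [1.0, 1.0, 0.0]).all()
# fmin_grad is identical except that the comparison is x <= y.
```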

template <typename T>
void dot_grad(const Tensor& x,
const Tensor& y,
3 changes: 3 additions & 0 deletions python/paddle/autograd/backward_utils.py
@@ -46,6 +46,8 @@
"pd_op.exp",
"pd_op.expand",
"pd_op.floor",
"pd_op.fmax",
"pd_op.fmin",
"pd_op.gather",
"pd_op.gather_nd",
"pd_op.gelu",
@@ -83,6 +85,7 @@
"pd_op.subtract",
"pd_op.sum",
"pd_op.swiglu",
"pd_op.swish",
"pd_op.tanh",
"pd_op.topk",
"pd_op.unsqueeze",
8 changes: 7 additions & 1 deletion test/legacy_test/test_activation_op.py
@@ -5224,7 +5224,9 @@ def ref_swish(x):
class TestSwish(TestActivation):
def setUp(self):
self.op_type = "swish"
self.prim_op_type = "comp"
self.python_api = paddle.nn.functional.swish
self.public_python_api = paddle.nn.functional.swish
self.init_dtype()
self.init_shape()

@@ -5244,7 +5246,11 @@ def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(
['X'], 'Out', check_pir=True, check_pir_onednn=self.check_pir_onednn
['X'],
'Out',
check_pir=True,
check_pir_onednn=self.check_pir_onednn,
check_prim_pir=True,
)


18 changes: 14 additions & 4 deletions test/legacy_test/test_fmax_op.py
@@ -133,7 +133,9 @@ class TestElementwiseFmaxOp(OpTest):
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmax"
self.prim_op_type = "prim"
self.python_api = paddle.fmax
self.public_python_api = paddle.fmax
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
@@ -149,7 +151,7 @@ def test_check_output(self):

def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_pir=True)
self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True)

def test_check_grad_ignore_x(self):
"""test_check_grad_ignore_x"""
@@ -178,7 +180,9 @@ class TestElementwiseFmax2Op(OpTest):
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmax"
self.prim_op_type = "prim"
self.python_api = paddle.fmax
self.public_python_api = paddle.fmax
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
@@ -196,7 +200,7 @@ def test_check_output(self):

def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_pir=True)
self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True)

def test_check_grad_ignore_x(self):
"""test_check_grad_ignore_x"""
@@ -225,7 +229,9 @@ class TestElementwiseFmax3Op(OpTest):
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmax"
self.prim_op_type = "prim"
self.python_api = paddle.fmax
self.public_python_api = paddle.fmax
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
@@ -242,7 +248,7 @@ def test_check_output(self):

def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_pir=True)
self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True)


@unittest.skipIf(
@@ -253,7 +259,9 @@ def test_check_grad_normal(self):
class TestFmaxBF16OP(OpTest):
def setUp(self):
self.op_type = "elementwise_fmax"
self.prim_op_type = "prim"
self.python_api = paddle.fmax
self.public_python_api = paddle.fmax
self.dtype = np.uint16
x = np.random.uniform(0.1, 1, [13, 17]).astype("float32")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float32")
@@ -271,7 +279,9 @@ def test_check_output(self):

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Y'], 'Out', check_pir=True)
self.check_grad_with_place(
place, ['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True
)


if __name__ == "__main__":
18 changes: 14 additions & 4 deletions test/legacy_test/test_fmin_op.py
@@ -135,7 +135,9 @@ class TestElementwiseFminOp(OpTest):
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmin"
self.prim_op_type = "prim"
self.python_api = paddle.fmin
self.public_python_api = paddle.fmin
# If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
@@ -151,7 +153,7 @@ def test_check_output(self):

def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_pir=True)
self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True)

def test_check_grad_ignore_x(self):
"""test_check_grad_ignore_x"""
@@ -180,7 +182,9 @@ class TestElementwiseFmin2Op(OpTest):
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmin"
self.prim_op_type = "prim"
self.python_api = paddle.fmin
self.public_python_api = paddle.fmin
# If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
@@ -198,7 +202,7 @@ def test_check_output(self):

def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_pir=True)
self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True)

def test_check_grad_ignore_x(self):
"""test_check_grad_ignore_x"""
@@ -227,7 +231,9 @@ class TestElementwiseFmin3Op(OpTest):
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmin"
self.prim_op_type = "prim"
self.python_api = paddle.fmin
self.public_python_api = paddle.fmin
# If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
@@ -244,7 +250,7 @@ def test_check_output(self):

def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_pir=True)
self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True)


@unittest.skipIf(
@@ -255,7 +261,9 @@ def test_check_grad_normal(self):
class TestFminBF16OP(OpTest):
def setUp(self):
self.op_type = "elementwise_fmin"
self.prim_op_type = "prim"
self.python_api = paddle.fmin
self.public_python_api = paddle.fmin
self.dtype = np.uint16
x = np.random.uniform(1, 1, [13, 17]).astype("float32")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float32")
@@ -273,7 +281,9 @@ def test_check_output(self):

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Y'], 'Out', check_pir=True)
self.check_grad_with_place(
place, ['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True
)


if __name__ == "__main__":