
Commit 446d55b

Fix (#67894)
1 parent d722a24 commit 446d55b


10 files changed: +308 -72 lines changed


paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 7 deletions
@@ -84,13 +84,7 @@ op_library(quantize_linear_op DEPS phi common)
 op_library(save_combine_op DEPS string_array phi common)
 op_library(load_combine_op DEPS string_array)
 
-if (WITH_GPU OR WITH_ROCM)
-  op_library(activation_op SRCS activation_op.cc soft_relu_op.cu DEPS ${OP_HEADER_DEPS})
-elseif (WITH_XPU_KP)
-  op_library(activation_op SRCS activation_op.cc DEPS ${OP_HEADER_DEPS})
-else()
-  op_library(activation_op SRCS activation_op.cc DEPS ${OP_HEADER_DEPS})
-endif()
+op_library(activation_op SRCS activation_op.cc DEPS ${OP_HEADER_DEPS})
 
 if (WITH_GPU OR WITH_ROCM)
   op_library(sync_batch_norm_op DEPS phi common)

paddle/fluid/operators/activation_op.cc

Lines changed: 0 additions & 12 deletions
@@ -376,18 +376,6 @@ namespace ops = paddle::operators;
 
 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
 
-#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name)                \
-  PD_REGISTER_STRUCT_KERNEL(                                             \
-      act_type, CPU, ALL_LAYOUT, ops::op_name##Kernel, float, double) {} \
-  PD_REGISTER_STRUCT_KERNEL(act_type##_grad,                             \
-                            CPU,                                         \
-                            ALL_LAYOUT,                                  \
-                            ops::op_name##GradKernel,                    \
-                            float,                                       \
-                            double) {}
-
-REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
-
 REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
 
 /* ========================== register checkpoint ===========================*/
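For orientation (expanded by hand, not present in the diff): the removed REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu) invocation registered the old fluid CPU kernels roughly as follows; the new phi kernel files added below take over these registrations.

    PD_REGISTER_STRUCT_KERNEL(
        soft_relu, CPU, ALL_LAYOUT, ops::SoftReluKernel, float, double) {}
    PD_REGISTER_STRUCT_KERNEL(soft_relu_grad,
                              CPU,
                              ALL_LAYOUT,
                              ops::SoftReluGradKernel,
                              float,
                              double) {}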

paddle/fluid/operators/soft_relu_op.cu

Lines changed: 0 additions & 50 deletions
This file was deleted.
New file · Lines changed: 87 additions & 0 deletions
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <glog/logging.h>

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <type_traits>

#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

template <typename T>
struct SoftReluGradFunctor {
  float threshold;
  void SetAttrs(float threshold_) { threshold = threshold_; }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) {
    auto tmp = static_cast<T>(threshold);
    auto temp = ((out > -tmp) * (out < tmp)).template cast<T>();
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
  }
};

template <typename T, typename Context>
void SoftmaxGradKernel(const Context& dev_ctx,
                       const DenseTensor& x_in,
                       const DenseTensor& out_in,
                       const DenseTensor& out_grad,
                       float threshold,
                       DenseTensor* x_grad) {
  dev_ctx.template Alloc<T>(x_grad);
  auto dout = phi::EigenVector<T>::Flatten(out_grad);
  auto out = phi::EigenVector<T>::Flatten(out_in);
  auto dx = phi::EigenVector<T>::Flatten(*x_grad);
  auto x = phi::EigenVector<T>::Flatten(x_in);
  auto* eigen_dev = dev_ctx.eigen_device();
  SoftReluGradFunctor<T> functor;
  functor.SetAttrs(threshold);
  // use 32bit index to speed up computation
  bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
  bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
  if (use_32bit_index && is_gpu_place) {
    functor(*eigen_dev,
            To32BitIndex(x),
            To32BitIndex(out),
            To32BitIndex(dout),
            To32BitIndex(dx));
  } else {
    functor(*eigen_dev, x, out, dout, dx);
  }
}
} // namespace phi

PD_REGISTER_KERNEL(
    soft_relu_grad, CPU, ALL_LAYOUT, phi::SoftmaxGradKernel, float, double) {}
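For reference, the math implemented by the new functors, read directly off the code (here \theta is the threshold attribute and y the forward output):

    \mathrm{soft\_relu}(x) = \log\!\left(1 + e^{\mathrm{clip}(x,\,-\theta,\,\theta)}\right)

Inside the unclipped band, the chain rule gives

    \frac{dy}{dx} = \frac{e^{x}}{1 + e^{x}} = 1 - e^{-y}

which is the dout * (1 - exp(-out)) expression in SoftReluGradFunctor (and in the CUDA grad functor further below); outside the band the gradient is zeroed, with the band test applied to the forward output out rather than to x.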
New file · Lines changed: 77 additions & 0 deletions
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <glog/logging.h>

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <type_traits>

#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

template <typename T>
struct SoftReluFunctor {
  float threshold;
  void SetAttrs(float threshold_) { threshold = threshold_; }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) {
    auto tmp = static_cast<T>(threshold);
    auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
    out.device(d) = (static_cast<T>(1) + temp.exp()).log();
  }
};

template <typename T, typename Context>
void SoftmaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   float threshold,
                   DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);

  auto x_flatten = phi::EigenVector<T>::Flatten(x);
  auto out_flatten = phi::EigenVector<T>::Flatten(*out);
  auto* eigen_dev = dev_ctx.eigen_device();
  SoftReluFunctor<T> functor;
  functor.SetAttrs(threshold);
  // use 32bit index to speed up computation
  bool use_32bit_index = out_flatten.size() < Eigen::NumTraits<int>::highest();
  bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
  if (use_32bit_index && is_gpu_place) {
    functor(*eigen_dev, To32BitIndex(x_flatten), To32BitIndex(out_flatten));
  } else {
    functor(*eigen_dev, x_flatten, out_flatten);
  }
}

} // namespace phi

PD_REGISTER_KERNEL(
    soft_relu, CPU, ALL_LAYOUT, phi::SoftmaxKernel, float, double) {}
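As a quick illustration (not part of the commit), here is a minimal standalone C++ sketch of what SoftReluFunctor computes per element — clamp to [-threshold, threshold], then softplus — checked against a std::clamp/std::log1p reference; the threshold value is only illustrative:

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <cstdio>

    // Scalar mirror of SoftReluFunctor: log(1 + exp(clip(x, -threshold, threshold))).
    double soft_relu_ref(double x, double threshold) {
      return std::log1p(std::exp(std::clamp(x, -threshold, threshold)));
    }

    int main() {
      const double threshold = 20.0;  // illustrative attribute value
      for (double x : {-100.0, -1.0, 0.0, 0.5, 3.0, 100.0}) {
        // Same shape as the Eigen expression: cwiseMax/cwiseMin, then log(1 + exp(.)).
        double clipped = std::min(std::max(x, -threshold), threshold);
        double eigen_like = std::log(1.0 + std::exp(clipped));
        assert(std::abs(eigen_like - soft_relu_ref(x, threshold)) < 1e-12);
        std::printf("soft_relu(%g) = %g\n", x, eigen_like);
      }
      return 0;
    }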
New file · Lines changed: 71 additions & 0 deletions
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/elementwise/elementwise_op_impl.cu.h"

namespace phi {

template <typename T>
struct CudaSoftReluGradFunctor {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  void SetAttrs(float threshold_) { threshold = threshold_; }

  // dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0
  // threshold should not be negative
  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_out) {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType t = static_cast<MPType>(threshold);
    return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out)))
                                 : static_cast<T>(0.0f);
  }
};

template <typename T, typename Context>
void SoftReluGradCudaKernel(const Context& dev_ctx,
                            const DenseTensor& x_in UNUSED,
                            const DenseTensor& out_in,
                            const DenseTensor& out_grad,
                            float threshold,
                            DenseTensor* x_grad) {
  dev_ctx.template Alloc<T>(x_grad);
  CudaSoftReluGradFunctor<T> functor;
  functor.SetAttrs(threshold);

  std::vector<const phi::DenseTensor*> ins = {&out_grad};
  std::vector<phi::DenseTensor*> outs = {x_grad};

  // Only need forward output Out
  ins.push_back(&out_in);
  phi::funcs::LaunchSameDimsElementwiseCudaKernel<T>(
      dev_ctx, ins, &outs, functor);
}
} // namespace phi

PD_REGISTER_KERNEL(soft_relu_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SoftReluGradCudaKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
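A hedged numerical sanity check (standalone C++, not part of the commit): inside the unclipped band, the analytic factor both grad functors use, 1 - exp(-out), should match a central finite difference of the forward formula (dout = 1 here; the threshold is illustrative):

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    // Forward reference: log(1 + exp(clip(x, -threshold, threshold))).
    double soft_relu(double x, double threshold) {
      double clipped = std::fmin(std::fmax(x, -threshold), threshold);
      return std::log1p(std::exp(clipped));
    }

    int main() {
      const double threshold = 20.0;  // illustrative attribute value
      const double h = 1e-6;
      for (double x : {-3.0, -0.5, 0.0, 0.7, 2.5}) {  // points inside the band
        double out = soft_relu(x, threshold);
        double analytic = 1.0 - std::exp(-out);  // matches the grad functors
        double numeric =
            (soft_relu(x + h, threshold) - soft_relu(x - h, threshold)) / (2 * h);
        assert(std::abs(analytic - numeric) < 1e-5);
        std::printf("x=%5.2f  analytic=%.6f  numeric=%.6f\n", x, analytic, numeric);
      }
      return 0;
    }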
New file · Lines changed: 65 additions & 0 deletions
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/elementwise/elementwise_op_impl.cu.h"

namespace phi {

template <typename T>
struct CudaSoftReluFunctor {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  void SetAttrs(float threshold_) { threshold = threshold_; }

  // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
  // threshold should not be negative
  __device__ __forceinline__ T operator()(const T arg_x) {
    MPType x = static_cast<MPType>(arg_x);
    MPType t = static_cast<MPType>(threshold);
    MPType temp_min = x < t ? x : t;
    MPType temp_max = temp_min > -t ? temp_min : -t;
    return static_cast<T>(log(one + exp(temp_max)));
  }
};

template <typename T, typename Context>
void SoftReluCudaKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        float threshold,
                        DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  std::vector<const phi::DenseTensor*> ins = {&x};
  std::vector<phi::DenseTensor*> outs = {out};
  CudaSoftReluFunctor<T> functor;
  functor.SetAttrs(threshold);
  phi::funcs::LaunchSameDimsElementwiseCudaKernel<T>(
      dev_ctx, ins, &outs, functor);
}
} // namespace phi

PD_REGISTER_KERNEL(soft_relu,
                   GPU,
                   ALL_LAYOUT,
                   phi::SoftReluCudaKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
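One design note: the CUDA functors cast their arguments to MPType before calling exp/log, so the transcendental math runs at higher precision than the storage type when T is float16 or bfloat16 (assuming MPTypeTrait maps those to float, as its use here suggests), and only the final result is narrowed back to T. A rough plain-C++ illustration of the pattern, not part of the commit, with float standing in for the storage type and double for the compute type (names are hypothetical, not Paddle APIs):

    #include <cmath>
    #include <cstdio>

    // Hypothetical illustration: compute in a wider type, store in a narrower one.
    // float plays the role of T (storage), double the role of MPType (compute).
    float soft_relu_mixed(float x_storage, float threshold) {
      double x = static_cast<double>(x_storage);  // widen before transcendentals
      double t = static_cast<double>(threshold);
      double clipped = x < t ? x : t;             // same branchless clamp as the CUDA functor
      clipped = clipped > -t ? clipped : -t;
      return static_cast<float>(std::log(1.0 + std::exp(clipped)));  // narrow at the end
    }

    int main() {
      std::printf("soft_relu_mixed(0.5) = %f\n", soft_relu_mixed(0.5f, 20.0f));
      return 0;
    }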

paddle/phi/ops/yaml/inconsistent/onednn_ops_extra.yaml

Lines changed: 4 additions & 0 deletions
@@ -269,6 +269,10 @@
 
 - op : sigmoid_grad
 
+- op : soft_relu
+
+- op : soft_relu_grad
+
 - op : slice
   extra_args : str mkldnn_data_type="float32"
 
