Skip to content

Commit ce1622b

Browse files
authored
[OPENCL] refine trigonometric opencl impl, test=develop (#4787)
1 parent d3291bc commit ce1622b

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

lite/kernels/opencl/trigonometric_image_compute.cc

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,35 +36,37 @@ class TrigonometricComputeImage2D
3636
PRECISION(kFP16),
3737
DATALAYOUT(kImageDefault)> {
3838
public:
39-
using param_t = operators::SinParam;
39+
using param_t = operators::TrigonometricParam;
4040

4141
std::string doc() const override { return "Sin using cl::Image2D, kFP16"; }
4242

4343
void PrepareForRun() override {
4444
auto& context = ctx_->As<OpenCLContext>();
45-
context.cl_context()->AddKernel(kernel_func_name_,
45+
46+
context.cl_context()->AddKernel(KernelFunctionName(),
4647
"image/trigonometric_kernel.cl",
4748
build_options_,
4849
time_stamp_);
49-
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
50+
51+
VLOG(1) << "kernel_func_name_:" << KernelFunctionName();
5052

5153
STL::stringstream kernel_key;
52-
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
54+
kernel_key << KernelFunctionName() << build_options_ << time_stamp_;
5355
kernel_ = context.cl_context()->GetKernel(kernel_key.str());
5456
}
5557

5658
void ReInitWhenNeeded() override {
57-
sin_param_ = param_.get_mutable<param_t>();
58-
auto x_dims = sin_param_->X->dims();
59+
trigonometric_param_ = param_.get_mutable<param_t>();
60+
auto x_dims = trigonometric_param_->X->dims();
5961
if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) ||
6062
first_epoch_for_reinit_) {
6163
last_x_dims_ = x_dims;
6264
first_epoch_for_reinit_ = false;
6365

6466
// compute image shape
6567
paddle::lite::CLImageConverterDefault default_convertor;
66-
out_img_shape_ =
67-
default_convertor.InitImageDimInfoWith(sin_param_->Out->dims());
68+
out_img_shape_ = default_convertor.InitImageDimInfoWith(
69+
trigonometric_param_->Out->dims());
6870

6971
// compute global work size
7072
GetGlobalWorkSize();
@@ -78,9 +80,10 @@ class TrigonometricComputeImage2D
7880
}
7981

8082
void Run() override {
81-
auto* x_img = sin_param_->X->data<half_t, cl::Image2D>();
82-
auto* out_img = sin_param_->Out->mutable_data<half_t, cl::Image2D>(
83-
out_img_shape_[0], out_img_shape_[1]);
83+
auto* x_img = trigonometric_param_->X->data<half_t, cl::Image2D>();
84+
auto* out_img =
85+
trigonometric_param_->Out->mutable_data<half_t, cl::Image2D>(
86+
out_img_shape_[0], out_img_shape_[1]);
8487

8588
auto& context = ctx_->As<OpenCLContext>();
8689
CHECK(context.cl_context() != nullptr);
@@ -102,20 +105,25 @@ class TrigonometricComputeImage2D
102105
CL_CHECK_FATAL(status);
103106
}
104107

108+
virtual std::string KernelFunctionName() {
109+
CHECK(
110+
"please extend TrigonometricComputeImage2D to support Trigonometric "
111+
"kernels");
112+
return "";
113+
}
105114
#ifdef LITE_WITH_PROFILE
106115
void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
107-
ch->kernel_func_name = kernel_func_name_;
116+
ch->kernel_func_name = KernelFunctionName();
108117
ch->cl_event =
109118
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
110119
}
111120
#endif
112121

113122
private:
114-
std::string kernel_func_name_{"trigonometric_sin"};
115123
std::string build_options_{"-DCL_DTYPE_half"};
116124
std::string time_stamp_{GetTimeStamp()};
117125

118-
param_t* sin_param_{nullptr};
126+
param_t* trigonometric_param_{nullptr};
119127
cl::Kernel kernel_;
120128
bool first_epoch_for_reinit_{true};
121129
DDim last_x_dims_;
@@ -125,6 +133,9 @@ class TrigonometricComputeImage2D
125133
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
126134
};
127135

136+
class SinComputeImage2D : public TrigonometricComputeImage2D {
137+
std::string KernelFunctionName() override { return "trigonometric_sin"; }
138+
};
128139
} // namespace opencl
129140
} // namespace kernels
130141
} // namespace lite
@@ -134,7 +145,7 @@ REGISTER_LITE_KERNEL(sin,
134145
kOpenCL,
135146
kFP16,
136147
kImageDefault,
137-
paddle::lite::kernels::opencl::TrigonometricComputeImage2D,
148+
paddle::lite::kernels::opencl::SinComputeImage2D,
138149
image2d)
139150
.BindInput("X",
140151
{LiteType::GetTensorTy(TARGET(kOpenCL),

lite/operators/op_params.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1977,15 +1977,13 @@ struct OneHotParam : ParamBase {
19771977
bool allow_out_of_range;
19781978
};
19791979

1980-
struct SinParam : ParamBase {
1980+
struct TrigonometricParam : ParamBase {
19811981
lite::Tensor* X{};
19821982
lite::Tensor* Out{};
19831983
};
19841984

1985-
struct CosParam : ParamBase {
1986-
lite::Tensor* X{};
1987-
lite::Tensor* Out{};
1988-
};
1985+
using SinParam = TrigonometricParam;
1986+
using CosParam = TrigonometricParam;
19891987

19901988
struct FlattenContiguousRangeParam : ParamBase {
19911989
lite::Tensor* x{};

0 commit comments

Comments
 (0)