@@ -36,35 +36,37 @@ class TrigonometricComputeImage2D
36
36
PRECISION (kFP16 ),
37
37
DATALAYOUT(kImageDefault )> {
38
38
public:
39
- using param_t = operators::SinParam ;
39
+ using param_t = operators::TrigonometricParam ;
40
40
41
41
std::string doc () const override { return " Sin using cl::Image2D, kFP16" ; }
42
42
43
43
void PrepareForRun () override {
44
44
auto & context = ctx_->As <OpenCLContext>();
45
- context.cl_context ()->AddKernel (kernel_func_name_,
45
+
46
+ context.cl_context ()->AddKernel (KernelFunctionName (),
46
47
" image/trigonometric_kernel.cl" ,
47
48
build_options_,
48
49
time_stamp_);
49
- VLOG (1 ) << " kernel_func_name_:" << kernel_func_name_;
50
+
51
+ VLOG (1 ) << " kernel_func_name_:" << KernelFunctionName ();
50
52
51
53
STL::stringstream kernel_key;
52
- kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
54
+ kernel_key << KernelFunctionName () << build_options_ << time_stamp_;
53
55
kernel_ = context.cl_context ()->GetKernel (kernel_key.str ());
54
56
}
55
57
56
58
void ReInitWhenNeeded () override {
57
- sin_param_ = param_.get_mutable <param_t >();
58
- auto x_dims = sin_param_ ->X ->dims ();
59
+ trigonometric_param_ = param_.get_mutable <param_t >();
60
+ auto x_dims = trigonometric_param_ ->X ->dims ();
59
61
if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) ||
60
62
first_epoch_for_reinit_) {
61
63
last_x_dims_ = x_dims;
62
64
first_epoch_for_reinit_ = false ;
63
65
64
66
// compute image shape
65
67
paddle::lite::CLImageConverterDefault default_convertor;
66
- out_img_shape_ =
67
- default_convertor. InitImageDimInfoWith (sin_param_ ->Out ->dims ());
68
+ out_img_shape_ = default_convertor. InitImageDimInfoWith (
69
+ trigonometric_param_ ->Out ->dims ());
68
70
69
71
// compute global work size
70
72
GetGlobalWorkSize ();
@@ -78,9 +80,10 @@ class TrigonometricComputeImage2D
78
80
}
79
81
80
82
void Run () override {
81
- auto * x_img = sin_param_->X ->data <half_t , cl::Image2D>();
82
- auto * out_img = sin_param_->Out ->mutable_data <half_t , cl::Image2D>(
83
- out_img_shape_[0 ], out_img_shape_[1 ]);
83
+ auto * x_img = trigonometric_param_->X ->data <half_t , cl::Image2D>();
84
+ auto * out_img =
85
+ trigonometric_param_->Out ->mutable_data <half_t , cl::Image2D>(
86
+ out_img_shape_[0 ], out_img_shape_[1 ]);
84
87
85
88
auto & context = ctx_->As <OpenCLContext>();
86
89
CHECK (context.cl_context () != nullptr );
@@ -102,20 +105,25 @@ class TrigonometricComputeImage2D
102
105
CL_CHECK_FATAL (status);
103
106
}
104
107
108
+ virtual std::string KernelFunctionName () {
109
+ CHECK (
110
+ " please extend TrigonometricComputeImage2D to support Trigonometric "
111
+ " kernels" );
112
+ return " " ;
113
+ }
105
114
#ifdef LITE_WITH_PROFILE
106
115
void SetProfileRuntimeKernelInfo (paddle::lite::profile::OpCharacter* ch) {
107
- ch->kernel_func_name = kernel_func_name_ ;
116
+ ch->kernel_func_name = KernelFunctionName () ;
108
117
ch->cl_event =
109
118
event_; // `event_` defined in `kernel.h`, valid after kernel::Run
110
119
}
111
120
#endif
112
121
113
122
private:
114
- std::string kernel_func_name_{" trigonometric_sin" };
115
123
std::string build_options_{" -DCL_DTYPE_half" };
116
124
std::string time_stamp_{GetTimeStamp ()};
117
125
118
- param_t * sin_param_ {nullptr };
126
+ param_t * trigonometric_param_ {nullptr };
119
127
cl::Kernel kernel_;
120
128
bool first_epoch_for_reinit_{true };
121
129
DDim last_x_dims_;
@@ -125,6 +133,9 @@ class TrigonometricComputeImage2D
125
133
static_cast <size_t >(1 ), static_cast <size_t >(1 ), static_cast <size_t >(1 )};
126
134
};
127
135
136
+ class SinComputeImage2D : public TrigonometricComputeImage2D {
137
+ std::string KernelFunctionName () override { return " trigonometric_sin" ; }
138
+ };
128
139
} // namespace opencl
129
140
} // namespace kernels
130
141
} // namespace lite
@@ -134,7 +145,7 @@ REGISTER_LITE_KERNEL(sin,
134
145
kOpenCL ,
135
146
kFP16 ,
136
147
kImageDefault ,
137
- paddle::lite::kernels::opencl::TrigonometricComputeImage2D ,
148
+ paddle::lite::kernels::opencl::SinComputeImage2D ,
138
149
image2d)
139
150
.BindInput(" X" ,
140
151
{LiteType::GetTensorTy (TARGET (kOpenCL ),
0 commit comments