@@ -116,91 +116,66 @@ void SwiGLULayerCl::swigluProcess(Tensor const &in1, Tensor const &in2,
116
116
}
117
117
118
118
void SwiGLULayerCl::swiglu_cl (float *matAdata, float *vecXdata, float *vecYdata,
119
- unsigned int dim1, unsigned int dim2, bool svm) {
120
- auto *global_cl_context =
119
+ unsigned int dim1, unsigned int dim2,
120
+ const bool use_svm) {
121
+ auto cl_context =
121
122
static_cast <ClContext *>(Engine::Global ().getRegisteredContext (" gpu" ));
122
- auto &clbuffInstance = ClBufferManager::Global ();
123
123
124
- do {
125
- const auto &kernel_swiglu_ptr = getLayerKernelPtrs ()[Kernels::SWIGLU_CL];
126
- int dim = int (dim1 * dim2);
124
+ const auto &kernel = getLayerKernelPtrs ()[Kernels::SWIGLU_CL];
125
+ const size_t dim = dim1 * dim2;
127
126
128
- if (!svm) {
129
- bool write_result = true ;
130
-
131
- write_result &= clbuffInstance.getInBufferA ()->WriteDataRegion (
132
- global_cl_context->command_queue_inst_ , dim * sizeof (float ), matAdata);
133
- write_result &= clbuffInstance.getInBufferB ()->WriteDataRegion (
134
- global_cl_context->command_queue_inst_ , dim * sizeof (float ), vecXdata);
135
- if (!write_result) {
136
- break ;
137
- }
138
-
139
- auto bufferInA = clbuffInstance.getInBufferA ()->GetBuffer ();
140
- auto bufferInB = clbuffInstance.getInBufferB ()->GetBuffer ();
141
- auto bufferOutA = clbuffInstance.getOutBufferA ()->GetBuffer ();
142
-
143
- bool set_result = true ;
144
- set_result &=
145
- kernel_swiglu_ptr->SetKernelArguments (0 , &bufferInA, sizeof (cl_mem));
146
- set_result &=
147
- kernel_swiglu_ptr->SetKernelArguments (1 , &bufferInB, sizeof (cl_mem));
148
- set_result &=
149
- kernel_swiglu_ptr->SetKernelArguments (2 , &bufferOutA, sizeof (cl_mem));
150
- if (!set_result) {
151
- break ;
152
- }
153
- } else {
154
- bool map_result = true ;
155
- map_result &=
156
- global_cl_context->command_queue_inst_ .enqueueSVMUnmap (matAdata);
157
- map_result &=
158
- global_cl_context->command_queue_inst_ .enqueueSVMUnmap (vecXdata);
159
- if (!map_result) {
160
- ml_loge (" Failed to map svm" );
161
- break ;
162
- }
163
-
164
- bool set_svm_result = true ;
165
- set_svm_result &= kernel_swiglu_ptr->SetKernelSVMArguments (0 , matAdata);
166
- set_svm_result &= kernel_swiglu_ptr->SetKernelSVMArguments (1 , vecXdata);
167
- set_svm_result &= kernel_swiglu_ptr->SetKernelSVMArguments (2 , vecYdata);
168
- if (!set_svm_result) {
169
- ml_loge (" Failed to set svm" );
170
- break ;
171
- }
172
- }
127
+ if (!use_svm) {
128
+ bool write_result = true ;
129
+ auto &clbuffInstance = ClBufferManager::Global ();
173
130
174
- // NOTE(mwlasiuk) : local size can not be larger than global
175
- const int32_t desired_local = 64 ;
176
- const bool can_use_desired = dim >= desired_local;
177
- const int32_t chosen_local = can_use_desired ? desired_local : dim;
131
+ write_result &= clbuffInstance.getInBufferA ()->WriteDataRegion (
132
+ cl_context->command_queue_inst_ , dim * sizeof (float ), matAdata);
133
+ write_result &= clbuffInstance.getInBufferB ()->WriteDataRegion (
134
+ cl_context->command_queue_inst_ , dim * sizeof (float ), vecXdata);
135
+ if (!write_result) {
136
+ return ;
137
+ }
178
138
179
- const int work_groups_count[ 3 ] = {dim, 1 , 1 } ;
180
- // / @todo: create a group size by device & input
181
- const int work_group_size[ 3 ] = {chosen_local, 1 , 1 }; // test-value
139
+ auto bufferInA = clbuffInstance. getInBufferA ()-> GetBuffer () ;
140
+ auto bufferInB = clbuffInstance. getInBufferB ()-> GetBuffer ();
141
+ auto bufferOutA = clbuffInstance. getOutBufferA ()-> GetBuffer ();
182
142
183
- if (!global_cl_context->command_queue_inst_ .DispatchCommand (
184
- kernel_swiglu_ptr, work_groups_count, work_group_size)) {
185
- ml_loge (" Failed to run" );
186
- break ;
143
+ bool set_result = true ;
144
+ set_result &= kernel->SetKernelArguments (0 , &bufferInA, sizeof (cl_mem));
145
+ set_result &= kernel->SetKernelArguments (1 , &bufferInB, sizeof (cl_mem));
146
+ set_result &= kernel->SetKernelArguments (2 , &bufferOutA, sizeof (cl_mem));
147
+ if (!set_result) {
148
+ return ;
187
149
}
188
-
189
- if (!svm) {
190
- if (!clbuffInstance.getOutBufferA ()->ReadDataRegion (
191
- global_cl_context->command_queue_inst_ , dim * sizeof (float ),
192
- vecYdata)) {
193
- break ;
194
- }
195
- } else {
196
- if (!global_cl_context->command_queue_inst_ .enqueueSVMMap (
197
- vecYdata, dim * sizeof (float ), true )) {
198
- ml_loge (" Failed to unmap svm" );
199
- break ;
200
- }
150
+ } else {
151
+ bool set_svm_result = true ;
152
+ set_svm_result &= kernel->SetKernelSVMArguments (0 , matAdata);
153
+ set_svm_result &= kernel->SetKernelSVMArguments (1 , vecXdata);
154
+ set_svm_result &= kernel->SetKernelSVMArguments (2 , vecYdata);
155
+ if (!set_svm_result) {
156
+ ml_loge (" Failed to set svm" );
157
+ return ;
201
158
}
159
+ }
202
160
203
- } while (false );
161
+ std::array<size_t , 3 > global_work_size = {dim, 1 , 1 };
162
+
163
+ cl_event swiglu_wait;
164
+
165
+ if (!cl_context->command_queue_inst_ .enqueueKernel (
166
+ kernel->GetKernel (), global_work_size.size (), global_work_size.data (),
167
+ nullptr , 0 , nullptr , &swiglu_wait)) {
168
+ }
169
+
170
+ cl_context->command_queue_inst_ .waitForEvent (1 , &swiglu_wait);
171
+
172
+ if (!use_svm) {
173
+ auto &clbuffInstance = ClBufferManager::Global ();
174
+ if (!clbuffInstance.getOutBufferA ()->ReadDataRegion (
175
+ cl_context->command_queue_inst_ , dim * sizeof (float ), vecYdata)) {
176
+ return ;
177
+ }
178
+ }
204
179
}
205
180
206
181
#ifdef ENABLE_FP16
0 commit comments