Skip to content

Commit 7b7ef81

Browse files
committed
Make swiglu and rmsnorm svm optional
Modify swiglu and rmsnorm cl implementations to work with both svm allocated tensors and buffers **Self-evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Grzegorz Kisala <[email protected]>
1 parent 3f3bfc5 commit 7b7ef81

File tree

4 files changed

+233
-198
lines changed

4 files changed

+233
-198
lines changed

nntrainer/layers/cl_layers/swiglu_cl.cpp

Lines changed: 51 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -116,91 +116,66 @@ void SwiGLULayerCl::swigluProcess(Tensor const &in1, Tensor const &in2,
116116
}
117117

118118
void SwiGLULayerCl::swiglu_cl(float *matAdata, float *vecXdata, float *vecYdata,
119-
unsigned int dim1, unsigned int dim2, bool svm) {
120-
auto *global_cl_context =
119+
unsigned int dim1, unsigned int dim2,
120+
const bool use_svm) {
121+
auto cl_context =
121122
static_cast<ClContext *>(Engine::Global().getRegisteredContext("gpu"));
122-
auto &clbuffInstance = ClBufferManager::Global();
123123

124-
do {
125-
const auto &kernel_swiglu_ptr = getLayerKernelPtrs()[Kernels::SWIGLU_CL];
126-
int dim = int(dim1 * dim2);
124+
const auto &kernel = getLayerKernelPtrs()[Kernels::SWIGLU_CL];
125+
const size_t dim = dim1 * dim2;
127126

128-
if (!svm) {
129-
bool write_result = true;
130-
131-
write_result &= clbuffInstance.getInBufferA()->WriteDataRegion(
132-
global_cl_context->command_queue_inst_, dim * sizeof(float), matAdata);
133-
write_result &= clbuffInstance.getInBufferB()->WriteDataRegion(
134-
global_cl_context->command_queue_inst_, dim * sizeof(float), vecXdata);
135-
if (!write_result) {
136-
break;
137-
}
138-
139-
auto bufferInA = clbuffInstance.getInBufferA()->GetBuffer();
140-
auto bufferInB = clbuffInstance.getInBufferB()->GetBuffer();
141-
auto bufferOutA = clbuffInstance.getOutBufferA()->GetBuffer();
142-
143-
bool set_result = true;
144-
set_result &=
145-
kernel_swiglu_ptr->SetKernelArguments(0, &bufferInA, sizeof(cl_mem));
146-
set_result &=
147-
kernel_swiglu_ptr->SetKernelArguments(1, &bufferInB, sizeof(cl_mem));
148-
set_result &=
149-
kernel_swiglu_ptr->SetKernelArguments(2, &bufferOutA, sizeof(cl_mem));
150-
if (!set_result) {
151-
break;
152-
}
153-
} else {
154-
bool map_result = true;
155-
map_result &=
156-
global_cl_context->command_queue_inst_.enqueueSVMUnmap(matAdata);
157-
map_result &=
158-
global_cl_context->command_queue_inst_.enqueueSVMUnmap(vecXdata);
159-
if (!map_result) {
160-
ml_loge("Failed to map svm");
161-
break;
162-
}
163-
164-
bool set_svm_result = true;
165-
set_svm_result &= kernel_swiglu_ptr->SetKernelSVMArguments(0, matAdata);
166-
set_svm_result &= kernel_swiglu_ptr->SetKernelSVMArguments(1, vecXdata);
167-
set_svm_result &= kernel_swiglu_ptr->SetKernelSVMArguments(2, vecYdata);
168-
if (!set_svm_result) {
169-
ml_loge("Failed to set svm");
170-
break;
171-
}
172-
}
127+
if (!use_svm) {
128+
bool write_result = true;
129+
auto &clbuffInstance = ClBufferManager::Global();
173130

174-
// NOTE(mwlasiuk) : local size can not be larger than global
175-
const int32_t desired_local = 64;
176-
const bool can_use_desired = dim >= desired_local;
177-
const int32_t chosen_local = can_use_desired ? desired_local : dim;
131+
write_result &= clbuffInstance.getInBufferA()->WriteDataRegion(
132+
cl_context->command_queue_inst_, dim * sizeof(float), matAdata);
133+
write_result &= clbuffInstance.getInBufferB()->WriteDataRegion(
134+
cl_context->command_queue_inst_, dim * sizeof(float), vecXdata);
135+
if (!write_result) {
136+
return;
137+
}
178138

179-
const int work_groups_count[3] = {dim, 1, 1};
180-
/// @todo: create a group size by device & input
181-
const int work_group_size[3] = {chosen_local, 1, 1}; // test-value
139+
auto bufferInA = clbuffInstance.getInBufferA()->GetBuffer();
140+
auto bufferInB = clbuffInstance.getInBufferB()->GetBuffer();
141+
auto bufferOutA = clbuffInstance.getOutBufferA()->GetBuffer();
182142

183-
if (!global_cl_context->command_queue_inst_.DispatchCommand(
184-
kernel_swiglu_ptr, work_groups_count, work_group_size)) {
185-
ml_loge("Failed to run");
186-
break;
143+
bool set_result = true;
144+
set_result &= kernel->SetKernelArguments(0, &bufferInA, sizeof(cl_mem));
145+
set_result &= kernel->SetKernelArguments(1, &bufferInB, sizeof(cl_mem));
146+
set_result &= kernel->SetKernelArguments(2, &bufferOutA, sizeof(cl_mem));
147+
if (!set_result) {
148+
return;
187149
}
188-
189-
if (!svm) {
190-
if (!clbuffInstance.getOutBufferA()->ReadDataRegion(
191-
global_cl_context->command_queue_inst_, dim * sizeof(float),
192-
vecYdata)) {
193-
break;
194-
}
195-
} else {
196-
if (!global_cl_context->command_queue_inst_.enqueueSVMMap(
197-
vecYdata, dim * sizeof(float), true)) {
198-
ml_loge("Failed to unmap svm");
199-
break;
200-
}
150+
} else {
151+
bool set_svm_result = true;
152+
set_svm_result &= kernel->SetKernelSVMArguments(0, matAdata);
153+
set_svm_result &= kernel->SetKernelSVMArguments(1, vecXdata);
154+
set_svm_result &= kernel->SetKernelSVMArguments(2, vecYdata);
155+
if (!set_svm_result) {
156+
ml_loge("Failed to set svm");
157+
return;
201158
}
159+
}
202160

203-
} while (false);
161+
std::array<size_t, 3> global_work_size = {dim, 1, 1};
162+
163+
cl_event swiglu_wait;
164+
165+
if (!cl_context->command_queue_inst_.enqueueKernel(
166+
kernel->GetKernel(), global_work_size.size(), global_work_size.data(),
167+
nullptr, 0, nullptr, &swiglu_wait)) {
168+
}
169+
170+
cl_context->command_queue_inst_.waitForEvent(1, &swiglu_wait);
171+
172+
if (!use_svm) {
173+
auto &clbuffInstance = ClBufferManager::Global();
174+
if (!clbuffInstance.getOutBufferA()->ReadDataRegion(
175+
cl_context->command_queue_inst_, dim * sizeof(float), vecYdata)) {
176+
return;
177+
}
178+
}
204179
}
205180

206181
#ifdef ENABLE_FP16

nntrainer/layers/cl_layers/swiglu_cl.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class SwiGLULayerCl final : public LayerImplCl {
7777
* @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
7878
*/
7979
void exportTo(Exporter &exporter,
80-
const ml::train::ExportMethods &method) const override {};
80+
const ml::train::ExportMethods &method) const override{};
8181

8282
/**
8383
* @copydoc Layer::getType()
@@ -106,9 +106,11 @@ class SwiGLULayerCl final : public LayerImplCl {
106106
* @param[in] vecYdata float * for Output Vector Y
107107
* @param[in] dim1 number of elements in input vector A
108108
* @param[in] dim2 number of elements in input vector X
109+
* @param[in] use_svm input pointers allocated by OpenCL SVM
109110
*/
110111
void swiglu_cl(float *matAdata, float *vecXdata, float *vecYdata,
111-
unsigned int dim1, unsigned int dim2, bool svm = false);
112+
unsigned int dim1, unsigned int dim2,
113+
const bool use_svm = false);
112114

113115
#ifdef ENABLE_FP16
114116
/**

nntrainer/tensor/cl_operations/blas_kernels_templates.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -398,20 +398,25 @@ inline static void rmsnorm_cl_internal(ClContext::SharedPtrClKernel kernel,
398398
if (!kernel->SetKernelArguments(5, &width, sizeof(int))) {
399399
return;
400400
}
401+
401402
#ifdef __ANDROID__
402403
constexpr int SUBGROUP_SIZE = 64;
403404
#else
404405
constexpr int SUBGROUP_SIZE = 32;
405406
#endif
406-
const int work_groups_count[3] = {static_cast<int>(height) * SUBGROUP_SIZE, 1,
407-
1};
408407

409-
const int work_group_size[3] = {SUBGROUP_SIZE, 1, 1};
410-
if (!blas_cc->command_queue_inst_.DispatchCommand(kernel, work_groups_count,
411-
work_group_size)) {
412-
return;
408+
std::array<size_t, 3> global_work_size = {height * SUBGROUP_SIZE, 1, 1};
409+
std::array<size_t, 3> local_work_size = {SUBGROUP_SIZE, 1, 1};
410+
411+
cl_event rmsnorm_wait;
412+
413+
if (!blas_cc->command_queue_inst_.enqueueKernel(
414+
kernel->GetKernel(), global_work_size.size(), global_work_size.data(),
415+
local_work_size.data(), 0, nullptr, &rmsnorm_wait)) {
413416
}
414417

418+
blas_cc->command_queue_inst_.waitForEvent(1, &rmsnorm_wait);
419+
415420
if (!use_svm) {
416421
auto &clbuffInstance = ClBufferManager::Global();
417422
if (!clbuffInstance.getOutBufferA()->ReadDataRegion(

0 commit comments

Comments
 (0)