Skip to content

[Metal] fix Softmax image #8498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 172 additions & 4 deletions lite/backends/metal/metal_kernel/texture/Softmax.metal
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,23 @@ struct SoftmaxParam {
int K;
};

struct SoftmaxParam2 {
int N;
int C;
int H;
int W;
};

kernel void softmax(texture2d_array<ftype, access::read> inTexture[[texture(0)]],
texture2d_array<ftype, access::write> outTexture[[texture(1)]],
constant SoftmaxParam& sp[[buffer(0)]],
constant SoftmaxParam2& sp[[buffer(0)]],
uint3 gid[[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
int group = sp.K / 4;
int remain = sp.K % 4;
int group = sp.N * sp.C / 4;
int remain = sp.N * sp.C % 4;
#if LITE_WITH_METAL_FULL
ftype max_value = FLT_MIN;
#else
Expand Down Expand Up @@ -73,4 +80,165 @@ kernel void softmax(texture2d_array<ftype, access::read> inTexture[[texture(0)]]
v = exp(v - max_value);
outTexture.write(v / sum, gid.xy, z);
}
}
}

kernel void softmax_c_d3_common(texture2d_array<ftype, access::read> inTexture[[texture(0)]],
texture2d_array<ftype, access::write> outTexture[[texture(1)]],
constant SoftmaxParam2& sp[[buffer(0)]],
uint3 gid[[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size())
return;

int out_texture_array_size = outTexture.get_array_size();
int left = out_texture_array_size * 4 - sp.N * sp.C;

ftype4 max_vector = inTexture.read(uint2(gid.x, gid.y), 0);

// caculate max
int array_size = inTexture.get_array_size();
for (int z = 0; z < (array_size - 1); z++) {
ftype4 temp_value_vector = inTexture.read(uint2(gid.x, gid.y), z);
max_vector = max(temp_value_vector, max_vector);
}

ftype max_value = max_vector[0];
if (array_size > 1) {
for (int c = 0; c < 4; c++) {
max_value = max(max_vector[c], max_value);
}
}

ftype4 temp_value_vector = inTexture.read(uint2(gid.x, gid.y), array_size - 1);
ftype max_value_left = temp_value_vector[0];
for (int c = 0; c < left; c++) {
max_value_left = max(temp_value_vector[c], max_value_left);
}
max_value = max(max_value, max_value_left);

// caculate sum
ftype4 sum_vector = 0.0;
for (int z = 0; z < array_size - 1; z++) {
ftype4 temp_value_vector = inTexture.read(uint2(gid.x, gid.y), z);
sum_vector += exp(temp_value_vector - max_value);
}
ftype sum_value = 0.0;
if (array_size > 1) {
sum_value = sum_vector[0] + sum_vector[1] + sum_vector[2] + sum_vector[3];
}
ftype4 sum_vector_left = 0.0;
ftype4 temp_value_vector_left = inTexture.read(uint2(gid.x, gid.y), array_size - 1);
sum_vector_left += exp(temp_value_vector_left - max_value);

ftype sum_value_left = 0.0;
for (int i = 0; i < left; i++) {
sum_value_left += sum_vector_left[i];
}
sum_value += sum_value_left;

// calculate output
ftype4 result_vector = inTexture.read(gid.xy, gid.z);
result_vector = exp(result_vector - max_value) / sum_value;
outTexture.write(result_vector, gid.xy, gid.z);
}

kernel void softmax_w_d3_common(texture2d_array<ftype, access::read> inTexture[[texture(0)]],
texture2d_array<ftype, access::write> outTexture[[texture(1)]],
constant SoftmaxParam2& sp[[buffer(0)]],
uint3 gid[[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size())
return;

// caculate sum
ftype4 max_vector = inTexture.read(uint2(0, gid.y), gid.z);
int w = sp.W;
for (int x = 1; x < w; x++) {
ftype4 temp_value_vector = inTexture.read(uint2(x, gid.y), gid.z);
max_vector = max(temp_value_vector, max_vector);
}

// caculate sum
ftype4 sum_vector = 0.0;
for (int x = 0; x < w; x++) {
ftype4 temp_value_vector = inTexture.read(uint2(x, gid.y), gid.z);
sum_vector += exp(temp_value_vector - max_vector);
}

// calculate output
ftype4 result_vector = inTexture.read(gid.xy, gid.z);
result_vector = exp(result_vector - max_vector) / sum_vector;
outTexture.write(result_vector, gid.xy, gid.z);
}

kernel void softmax_h_d3_common(texture2d_array<ftype, access::read> inTexture[[texture(0)]],
texture2d_array<ftype, access::write> outTexture[[texture(1)]],
constant SoftmaxParam2& sp[[buffer(0)]],
uint3 gid[[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size())
return;

ftype4 max_vector = inTexture.read(uint2(gid.x, 0), gid.z);

// caculate max
int h = sp.H;
for (int y = 1; y < h; y++) {
ftype4 temp_value_vector = inTexture.read(uint2(gid.x, y), gid.z);
max_vector = max(temp_value_vector, max_vector);
}

// caculate sum
ftype4 sum_vector = 0.0;
for (int y = 0; y < h; y++) {
ftype4 temp_value_vector = inTexture.read(uint2(gid.x, y), gid.z);
sum_vector += exp(temp_value_vector - max_vector);
}

// calculate output
ftype4 result_vector = inTexture.read(gid.xy, gid.z);
result_vector = exp(result_vector - max_vector) / sum_vector;
outTexture.write(result_vector, gid.xy, gid.z);
}

kernel void softmax_dim2_common(texture2d_array<ftype, access::read> inTexture[[texture(0)]],
texture2d_array<ftype, access::write> outTexture[[texture(1)]],
constant SoftmaxParam2& sp[[buffer(0)]],
uint3 gid[[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size())
return;

// caculate sum
ftype4 max_vector = 0;
int w = sp.W % 4 == 0 ? sp.W / 4 : sp.W / 4 + 1;
int res = sp.W % 4 == 0 ? 4 : sp.W % 4;
ftype max_num = 0;
ftype sum_num = 0;
for (int z = 0; z < w; z++) {
ftype4 temp_value_vector = inTexture.read(uint2(gid.x, gid.y), z);
max_vector = max(temp_value_vector, max_vector);
}
for (int z = 0; z < 4; z++) {
max_num = max(max_vector[z], max_num);
}

// caculate sum
ftype4 sum_vector = 0.0;
for (int z = 0; z < w - 1; z++) {
ftype4 tem_value_vector = inTexture.read(uint2(gid.x, gid.y), z);
sum_vector += exp(tem_value_vector - max_num);
}
for (int z = 0; z < 4; z++) {
sum_num += sum_vector[z];
}
ftype4 tem_value_vector = inTexture.read(uint2(gid.x, gid.y), w - 1);
for (int z = 0; z < res; z++) {
sum_num += exp(tem_value_vector[z] - max_num);
}

// calculate output
ftype4 result_vector = inTexture.read(uint2(gid.x, gid.y), gid.z);
result_vector = exp(result_vector - max_num) / sum_num;
outTexture.write(result_vector, uint2(gid.x, gid.y), gid.z);
}
41 changes: 35 additions & 6 deletions lite/kernels/metal/image_op/softmax_image_compute.mm
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
metal_context_ = (MetalContext*)context.context();

const auto& param = this->Param<param_t>();
auto input_dims = param.x->dims();
auto output_dims = param.output->dims();

#ifdef LITE_WITH_METAL_FULL
Expand All @@ -37,13 +38,17 @@
output_buffer_ = param.output->mutable_data<MetalHalf, MetalImage>(metal_context_, output_dims);
#endif

auto axis = param.axis;
if (axis < 0) {
axis += input_dims.size();
}
// whether to use mps
bool should_use_mps = false;
if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
if (metal_context_->use_mps()) {
int input_c = static_cast<int>(input_buffer_->dim_[3]);
int output_c = static_cast<int>(output_buffer_->dim_[3]);
if (input_c >= 3 && output_c >= 3) {
if (input_c >= 3 && output_c >= 3 && input_dims.size() == 4 && axis == 1) {
should_use_mps = true;
}
}
Expand Down Expand Up @@ -86,18 +91,42 @@
const auto& param = this->Param<param_t>();
auto input_dims = param.x->dims();

if (input_dims.size() - param.axis != 3 && input_dims.size() != 2) {
LOG(FATAL) << "only support input with rank(dim)=2 or doing softmax in C channel";
if (input_dims.size() != 4 && input_dims.size() != 2) {
LOG(FATAL) << "only support input with rank(dim)=4 and 2";
return;
}

function_name_ = "softmax";
auto axis = param.axis;
if (axis < 0) {
axis += input_dims.size();
}

std::string function_name = "softmax";
if (input_dims.size() == 4) {
if (axis == 1) {
function_name = "softmax";
} else if (axis == 2) {
function_name = "softmax_h_d3_common";
} else if (axis == 3) {
function_name = "softmax_w_d3_common";
}
}
if (input_dims.size() == 2) {
function_name = "softmax_dim2_common";
}

function_name_ = function_name;

// pipline
auto backend = (__bridge MetalContextImp*)metal_context_->backend();
pipline_ = [backend pipline:function_name_];

SoftmaxMetalParam metal_param{
(int)input_buffer_->tensor_dim_[0], (int)input_buffer_->tensor_dim_[1]};
SoftmaxMetalParam2 metal_param{
(int)input_buffer_->pad_to_four_dim_[0],
(int)input_buffer_->pad_to_four_dim_[1],
(int)input_buffer_->pad_to_four_dim_[2],
(int)input_buffer_->pad_to_four_dim_[3],
};
params_buffer_ =
std::make_shared<MetalBuffer>(metal_context_, sizeof(metal_param), &metal_param);
}
Expand Down