72 changes: 70 additions & 2 deletions mlx/backend/metal/conv.cpp
@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
}
}

void depthwise_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();

const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
const int ker_w = conv_params.wS[1];
const int str_h = conv_params.str[0];
const int str_w = conv_params.str[1];
const int tc = 8;
const int tw = 8;
const int th = 4;
const bool do_flip = conv_params.flip;

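// Bake the kernel geometry in as Metal function constants so each
// specialization compiles with fully unrolled loops; the indices here match
// the [[function_constant(...)]] declarations in conv.metal below.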
metal::MTLFCList func_consts = {
{&ker_h, MTL::DataType::DataTypeInt, 00},
{&ker_w, MTL::DataType::DataTypeInt, 01},
{&str_h, MTL::DataType::DataTypeInt, 10},
{&str_w, MTL::DataType::DataTypeInt, 11},
{&th, MTL::DataType::DataTypeInt, 100},
{&tw, MTL::DataType::DataTypeInt, 101},
{&do_flip, MTL::DataType::DataTypeBool, 200},
};

// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on

std::string hash_name = kname.str();

auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);

compute_encoder.set_bytes(conv_params, 3);

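// One thread per output element: each 8x8x4 threadgroup covers 8 channels of
// an 8-wide by 4-tall output tile, and the z grid dimension folds together
// the batch index and the row-tile index.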
MTL::Size group_dims = MTL::Size(tc, tw, th);
MTL::Size grid_dims = MTL::Size(
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);

compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void conv_2D_gpu(
const Stream& s,
metal::Device& d,
@@ -754,11 +813,20 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;

-if (groups > 1) {
+if (is_idil_one && groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;

-if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
+conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
+conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
+conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
+conv_params.wt_strides[1] == conv_params.wS[1] &&
+conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
+return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
+}
+
+if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
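
For reference, the dispatch guard above can be restated in Python. The helper below is an illustrative sketch only; its name and argument list are made up and are not part of the change:

# Illustrative only: mirrors the dispatch guard in conv_2D_gpu above.
def takes_depthwise_fast_path(C, O, groups, kernel_hw, strides, out_hw,
                              kernel_dilation, input_dilation, wt_row_stride):
    C_per_group, O_per_group = C // groups, O // groups
    return (
        input_dilation == (1, 1) and groups > 1
        and C_per_group == 1 and O_per_group == 1      # true depthwise
        and kernel_dilation == (1, 1)                  # no kernel dilation
        and kernel_hw[0] <= 7 and kernel_hw[1] <= 7    # small kernels only
        and strides[0] <= 2 and strides[1] <= 2
        and out_hw[0] % 8 == 0 and out_hw[1] % 8 == 0  # whole output tiles
        and wt_row_stride == kernel_hw[1]              # contiguous filter rows
        and C % 16 == 0 and C == O
    )
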
122 changes: 122 additions & 0 deletions mlx/backend/metal/kernels/conv.metal
@@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////

constant int ker_h [[function_constant(00)]];
constant int ker_w [[function_constant(01)]];
constant int str_h [[function_constant(10)]];
constant int str_w [[function_constant(11)]];
constant int tgp_h [[function_constant(100)]];
constant int tgp_w [[function_constant(101)]];
constant bool do_flip [[function_constant(200)]];

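// Input window staged per threadgroup: the output tile extent scaled by the
// stride, plus the kernel overhang.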
constant int span_h = tgp_h * str_h + ker_h - 1;
constant int span_w = tgp_w * str_w + ker_w - 1;
constant int span_hw = span_h * span_w;

template <typename T>
[[kernel]] void depthwise_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int tc = 8;
constexpr int tw = 8;
constexpr int th = 4;

constexpr int c_per_thr = 8;

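// Compile-time upper bounds for the staging buffer: the fast path only
// dispatches kernels up to 7x7 with strides up to 2 (see conv.cpp), so
// span_h <= TGH and span_w <= TGW always hold.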
constexpr int TGH = th * 2 + 6;
constexpr int TGW = tw * 2 + 6;
constexpr int TGC = tc;

threadgroup T ins[TGH * TGW * TGC];

const int n_tgblocks_h = params.oS[0] / th;
const int n = tid.z / n_tgblocks_h;
const int tghid = tid.z % n_tgblocks_h;
const int oh = tghid * th + lid.z;
const int ow = gid.y;
const int c = gid.x;

in += n * params.in_strides[0];

// Cooperatively stage the input window into threadgroup memory: each thread
// copies c_per_thr contiguous channels for a strided set of (h, w) positions,
// zero-filling pixels that fall into the padding.
{
constexpr int n_threads = th * tw * tc;
const int tg_oh = (tghid * th) * str_h - params.pad[0];
const int tg_ow = (tid.y * tw) * str_w - params.pad[1];
const int tg_c = tid.x * tc;

const int thread_idx = simd_gid * 32 + simd_lid;
constexpr int thr_per_hw = tc / c_per_thr;
constexpr int hw_per_group = n_threads / thr_per_hw;

const int thr_c = thread_idx % thr_per_hw;
const int thr_hw = thread_idx / thr_per_hw;

for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) {
const int h = hw / span_w;
const int w = hw % span_w;

const int ih = tg_oh + h;
const int iw = tg_ow + w;

const int in_s_offset = h * span_w * TGC + w * TGC;

if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const auto in_load =
in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c;

MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] =
in_load[c_per_thr * thr_c + cc];
}
} else {
MLX_MTL_PRAGMA_UNROLL
for (int cc = 0; cc < c_per_thr; ++cc) {
ins[in_s_offset + c_per_thr * thr_c + cc] = T(0);
}
}
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);
wt += c * params.wt_strides[0];

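// Each thread accumulates its single output element in float; with do_flip
// the filter taps are read in reverse order (true convolution rather than
// cross-correlation).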
const auto ins_ptr =
&ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x];
float o = 0.;
for (int h = 0; h < ker_h; ++h) {
for (int w = 0; w < ker_w; ++w) {
int wt_h = h;
int wt_w = w;
if (do_flip) {
wt_h = ker_h - h - 1;
wt_w = ker_w - w - 1;
}
auto inv = ins_ptr[h * span_w * TGC + w * TGC];
auto wtv = wt[wt_h * ker_w + wt_w];
o += inv * wtv;
}
}
threadgroup_barrier(mem_flags::mem_none);

out += n * params.out_strides[0] + oh * params.out_strides[1] +
ow * params.out_strides[2];
out[c] = static_cast<T>(o);
}

#define instantiate_depthconv2d(iname, itype) \
instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype)

instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////
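
As a semantic reference for what the new kernel computes (each channel convolved with its own single filter, NHWC input, (C, kH, kW, 1) weights), here is a minimal NumPy sketch. The function name is made up, and the sketch covers only the flip=False (cross-correlation) case:

import numpy as np

def depthwise_conv2d_ref(x, w, strides=(1, 1), padding=(0, 0)):
    # x: (N, H, W, C) input, w: (C, kH, kW, 1) weights, one filter per channel.
    N, H, W, C = x.shape
    _, kH, kW, _ = w.shape
    x = np.pad(x, ((0, 0), (padding[0],) * 2, (padding[1],) * 2, (0, 0)))
    oH = (H + 2 * padding[0] - kH) // strides[0] + 1
    oW = (W + 2 * padding[1] - kW) // strides[1] + 1
    out = np.zeros((N, oH, oW, C), x.dtype)
    for oh in range(oH):
        for ow in range(oW):
            ih, iw = oh * strides[0], ow * strides[1]
            patch = x[:, ih : ih + kH, iw : iw + kW, :]  # (N, kH, kW, C)
            # Sum each channel's patch against that channel's own filter taps.
            out[:, oh, ow, :] = np.einsum("nhwc,chw->nc", patch, w[..., 0])
    return out
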
42 changes: 40 additions & 2 deletions python/tests/test_conv.py
@@ -707,9 +707,11 @@ def __conv_general_test(
flip=flip,
np_dtype=np_dtype,
):
np.random.seed(0)
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
-in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
-wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
+scale = min(0.3, scale)
+in_np = np.random.normal(0, scale, in_shape).astype(np_dtype)
+wt_np = np.random.normal(0, scale, wt_shape).astype(np_dtype)

in_mx, wt_mx = map(mx.array, (in_np, wt_np))

@@ -1050,6 +1052,42 @@ def test_repeated_conv(self):
y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
self.assertTrue(mx.allclose(y1, y2))

@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_depthwise(self):

# fmt: off
shapes = (
# N, H, W, C, kH, kW, O, strides, padding, groups
( 2, 16, 16, 32, 1, 1, 32, (2, 2), (1, 1), 32),
( 1, 16, 16, 32, 3, 3, 32, (2, 2), (1, 1), 32),
( 1, 32, 32, 32, 7, 7, 32, (1, 1), (3, 3), 32),
( 3, 32, 32, 32, 5, 5, 32, (1, 2), (0, 0), 32),
( 1, 32, 32, 32, 7, 7, 32, (2, 1), (1, 3), 32),
)
# fmt: on

dtypes = [np.float32]
if mx.default_device() == mx.gpu:
dtypes += [np.float16]

for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
for dtype in dtypes:
for flip in [False, True]:
Cw = C // groups

self.__conv_general_test(
(N, H, W, C),
(O, kH, kW, Cw),
strides,
padding,
kernel_dilation=1,
input_dilation=1,
groups=groups,
flip=flip,
np_dtype=dtype,
atol=2e-5 if dtype == np.float32 else 5e-4,
)


if __name__ == "__main__":
unittest.main()
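
For completeness, a minimal usage sketch that satisfies the guard above; whether it actually reaches the specialized kernel also depends on the device checks in conv.cpp:

import mlx.core as mx

# Depthwise conv meeting the fast-path conditions: C == O == 32 (divisible by
# 16), 3x3 kernel, stride 1, and a 32x32 output (divisible by 8).
x = mx.random.normal((1, 32, 32, 32))  # NHWC input
w = mx.random.normal((32, 3, 3, 1))    # (O, kH, kW, C // groups)
y = mx.conv2d(x, w, stride=1, padding=1, groups=32)
print(y.shape)  # (1, 32, 32, 32)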