Skip to content

Commit b1cc950

Browse files
Add bfp16 and unit tests for the CK implicit gemm bwd and fwd xdlops solvers (#3521)
* Enable bfp16 for implicit gemm fwd and bwd xdlops solvers
* Add unit tests for implicit gemm fwd and bwd xdlops solvers
* Add int8 handling to the gtest unit conv solver
* Remove custom smoke test for ConvHipImplicitGemmFwdXdlops

---------

Co-authored-by: Bibek Ghimire <[email protected]>
1 parent b5867ae commit b1cc950

6 files changed

+338
-13
lines changed

src/solver/conv/conv_hip_implicit_gemm_bwd_data_xdlops.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ void PerformanceConfigHipImplicitGemmBwdXdlops::HeuristicInit(
181181
{
182182
case miopenHalf: Init<ck::half_t>(problem); break;
183183
case miopenFloat: Init<float>(problem); break;
184+
case miopenBFloat16: Init<ck::bhalf_t>(problem); break;
184185
case miopenFloat8_fnuz:
185186
case miopenBFloat8_fnuz:
186187
case miopenInt8:
187188
case miopenInt32:
188189
case miopenInt64:
189-
case miopenBFloat16:
190190
case miopenDouble: break;
191191
}
192192
#endif
@@ -223,12 +223,12 @@ bool PerformanceConfigHipImplicitGemmBwdXdlops::IsValid(
223223
{
224224
case miopenHalf: return CheckIsSupportCKArgs<ck::half_t>(problem);
225225
case miopenFloat: return CheckIsSupportCKArgs<float>(problem);
226+
case miopenBFloat16: return CheckIsSupportCKArgs<ck::bhalf_t>(problem);
226227
case miopenFloat8_fnuz:
227228
case miopenBFloat8_fnuz:
228229
case miopenInt8:
229230
case miopenInt32:
230231
case miopenInt64:
231-
case miopenBFloat16:
232232
case miopenDouble: break;
233233
}
234234
#endif
@@ -304,12 +304,12 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable(
304304
{
305305
case miopenHalf: return CheckCKApplicability<ck::half_t>(problem);
306306
case miopenFloat: return CheckCKApplicability<float>(problem);
307+
case miopenBFloat16: return CheckCKApplicability<ck::bhalf_t>(problem);
307308
case miopenFloat8_fnuz:
308309
case miopenBFloat8_fnuz:
309310
case miopenInt8:
310311
case miopenInt32:
311312
case miopenInt64:
312-
case miopenBFloat16:
313313
case miopenDouble: break;
314314
}
315315
#endif
@@ -334,10 +334,14 @@ ConvSolution ConvHipImplicitGemmBwdXdlops::GetSolution(
334334
CKArgs,
335335
miopen::conv::DataInvokeParams>(
336336
ctx, problem, config.kernel_id);
337+
case miopenBFloat16:
338+
return InitInvokerFactoryNHWC<DeviceOpBwdPtrs<ck::bhalf_t>,
339+
CKArgs,
340+
miopen::conv::DataInvokeParams>(
341+
ctx, problem, config.kernel_id);
337342
case miopenInt8:
338343
case miopenInt32:
339344
case miopenInt64:
340-
case miopenBFloat16:
341345
case miopenDouble:
342346
case miopenFloat8_fnuz:
343347
case miopenBFloat8_fnuz:

src/solver/conv/conv_hip_implicit_gemm_fwd_xdlops.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@ void PerformanceConfigHipImplicitGemmFwdXdlops::HeuristicInit(
182182
case miopenInt8: Init<int8_t>(problem); break;
183183
case miopenHalf: Init<ck::half_t>(problem); break;
184184
case miopenFloat: Init<float>(problem); break;
185+
case miopenBFloat16: Init<ck::bhalf_t>(problem); break;
185186
case miopenFloat8_fnuz:
186187
case miopenBFloat8_fnuz:
187188
case miopenInt64:
188189
case miopenInt32:
189-
case miopenBFloat16:
190190
case miopenDouble: break;
191191
}
192192
#endif
@@ -225,11 +225,11 @@ bool PerformanceConfigHipImplicitGemmFwdXdlops::IsValid(
225225
case miopenInt8: return CheckIsSupportCKArgs<int8_t>(problem);
226226
case miopenHalf: return CheckIsSupportCKArgs<ck::half_t>(problem);
227227
case miopenFloat: return CheckIsSupportCKArgs<float>(problem);
228+
case miopenBFloat16: return CheckIsSupportCKArgs<ck::bhalf_t>(problem);
228229
case miopenFloat8_fnuz:
229230
case miopenBFloat8_fnuz:
230231
case miopenInt64:
231232
case miopenInt32:
232-
case miopenBFloat16:
233233
case miopenDouble: break;
234234
}
235235
#endif
@@ -306,11 +306,11 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable(
306306
case miopenInt8: return CheckCKApplicability<int8_t>(problem);
307307
case miopenHalf: return CheckCKApplicability<ck::half_t>(problem);
308308
case miopenFloat: return CheckCKApplicability<float>(problem);
309+
case miopenBFloat16: return CheckCKApplicability<ck::bhalf_t>(problem);
309310
case miopenFloat8_fnuz:
310311
case miopenBFloat8_fnuz:
311312
case miopenInt64:
312313
case miopenInt32:
313-
case miopenBFloat16:
314314
case miopenDouble: break;
315315
}
316316
#endif
@@ -336,9 +336,13 @@ ConvSolution ConvHipImplicitGemmFwdXdlops::GetSolution(
336336
case miopenFloat:
337337
return InitInvokerFactoryNHWC<DeviceOpPtrs<float>, CKArgs, miopen::conv::DataInvokeParams>(
338338
ctx, problem, config.kernel_id);
339+
case miopenBFloat16:
340+
return InitInvokerFactoryNHWC<DeviceOpPtrs<ck::bhalf_t>,
341+
CKArgs,
342+
miopen::conv::DataInvokeParams>(
343+
ctx, problem, config.kernel_id);
339344
case miopenInt64:
340345
case miopenInt32:
341-
case miopenBFloat16:
342346
case miopenDouble:
343347
case miopenFloat8_fnuz:
344348
case miopenBFloat8_fnuz:

test/CMakeLists.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,6 @@ endif()
836836
# message output to the log, which happens if something is broken in the tuning machinery.
837837
# * Use MIOPEN_DEBUG_TUNING_ITERATIONS_MAX to save testing time.
838838

839-
add_custom_test(smoke_solver_ConvHipImplicitGemmFwdXdlops GFX900_DISABLED GFX906_DISABLED GFX90A_DISABLED GFX94X_ENABLED HALF_ENABLED INT8_ENABLED
840-
ENVIRONMENT MIOPEN_FIND_ENFORCE=SEARCH_DB_UPDATE MIOPEN_DEBUG_TUNING_ITERATIONS_MAX=5 MIOPEN_FIND_MODE=normal MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvHipImplicitGemmFwdXdlops
841-
COMMAND $<TARGET_FILE:test_conv2d> ${TEST_CONV_VERBOSE_F} --input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1 ${MIOPEN_TEST_CONV_INT8_OUTPUT_TYPE_INT8} --in_layout NHWC --fil_layout NHWC --out_layout NHWC ${MIOPEN_TEST_FLAGS_ARGS}
842-
)
843-
844839
# FP16 ALT attribute is disabled to enable the backward solver on MI200 for HALF.
845840
add_custom_test(smoke_solver_ConvWinograd3x3MultipassWrW_3x2 HALF_ENABLED BF16_ENABLED SKIP_XNACK_ON
846841
ENVIRONMENT MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL=0 MIOPEN_FIND_MODE=normal MIOPEN_DEBUG_FIND_ONLY_SOLVER='ConvWinograd3x3MultipassWrW<3-2>'

test/gtest/unit_conv_solver.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,9 @@ void RunSolver(const miopen::solver::conv::ConvSolverInterface& solver,
737737
case miopenBFloat16:
738738
RunSolver<bfloat16, bfloat16>(solver, params, direction, conv_config, algo);
739739
return;
740+
case miopenInt8:
741+
RunSolver<int8_t, int8_t>(solver, params, direction, conv_config, algo);
742+
return;
740743
default:
741744
throw std::runtime_error("handling of this data type is not yet implemented");
742745
}
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*******************************************************************************
2+
*
3+
* MIT License
4+
*
5+
* Copyright (c) 2025 Advanced Micro Devices, Inc.
6+
*
7+
* Permission is hereby granted, free of charge, to any person obtaining a copy
8+
* of this software and associated documentation files (the "Software"), to deal
9+
* in the Software without restriction, including without limitation the rights
10+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
* copies of the Software, and to permit persons to whom the Software is
12+
* furnished to do so, subject to the following conditions:
13+
*
14+
* The above copyright notice and this permission notice shall be included in all
15+
* copies or substantial portions of the Software.
16+
*
17+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23+
* SOFTWARE.
24+
*
25+
*******************************************************************************/
26+
27+
#include "unit_conv_solver.hpp"
28+
29+
namespace {
30+
31+
auto GetConvSmokeTestCases(miopenDataType_t datatype)
32+
{
33+
using TestCase = miopen::unit_tests::ConvTestCase;
34+
35+
return std::vector{
36+
// clang-format off
37+
TestCase{{datatype, miopenTensorNHWC, {1, 32, 8, 8}},
38+
{datatype, miopenTensorNHWC, {32, 32, 1, 1}},
39+
datatype, {{0, 0}, {1, 1}, {1, 1}}},
40+
// clang-format on
41+
};
42+
}
43+
44+
auto GetConvFullTestCases(miopenDataType_t datatype)
45+
{
46+
using TestCase = miopen::unit_tests::ConvTestCase;
47+
48+
return std::vector{
49+
// clang-format off
50+
TestCase{{datatype, miopenTensorNHWC, {1, 32, 8, 8}},
51+
{datatype, miopenTensorNHWC, {32, 32, 3, 3}},
52+
datatype, {{1, 1}, {1, 1}, {1, 1}}}, // non-zero padding
53+
TestCase{{datatype, miopenTensorNHWC, {1, 64, 24, 48}},
54+
{datatype, miopenTensorNHWC, {96, 64, 1, 1}},
55+
datatype, {{0, 0}, {2, 2}, {1, 1}}}, // stride > 1
56+
TestCase{{datatype, miopenTensorNHWC, {1, 32, 8, 8}},
57+
{datatype, miopenTensorNHWC, {32, 32, 3, 3}},
58+
datatype, {{0, 0}, {1, 1}, {3, 3}}}, // dilation > 1
59+
TestCase{{datatype, miopenTensorNHWC, {1, 64, 24, 48}},
60+
{datatype, miopenTensorNHWC, {96, 64, 1, 1}},
61+
datatype, {{0, 0}, {1, 1}, {1, 1}}}, // some different NCHW and k parameters
62+
// clang-format on
63+
};
64+
}
65+
66+
auto GetTestParams(miopenDataType_t datatype)
67+
{
68+
Gpu supportedDevices = Gpu::gfx908 | Gpu::gfx90A | Gpu::gfx94X;
69+
auto params = miopen::unit_tests::UnitTestConvSolverParams(supportedDevices);
70+
params.Tunable(5);
71+
if(datatype == miopenHalf)
72+
{
73+
// Enable the backward solver on MI200 for fp16 by disabling the alternate implementation
74+
params.SetConvAttrFp16Alt(0);
75+
}
76+
77+
return params;
78+
}
79+
80+
} // namespace
81+
82+
using GPU_UnitTestConvSolverImplicitGemmBwdXdlops_FP16 = GPU_UnitTestConvSolverBwd_FP16;
83+
using GPU_UnitTestConvSolverImplicitGemmBwdXdlops_BFP16 = GPU_UnitTestConvSolverBwd_BFP16;
84+
using GPU_UnitTestConvSolverImplicitGemmBwdXdlops_FP32 = GPU_UnitTestConvSolverBwd_FP32;
85+
using CPU_UnitTestConvSolverImplicitGemmBwdXdlopsDevApplicability_FP16 =
86+
CPU_UnitTestConvSolverDevApplicabilityBwd_NONE;
87+
88+
// Runs each FP16 parameter set through the fixture's RunTest with the
// ConvHipImplicitGemmBwdXdlops solver.
// No ';' after the body: TEST_P expands to a function definition, so a trailing
// semicolon is an empty declaration (flagged by -Wextra-semi).
TEST_P(GPU_UnitTestConvSolverImplicitGemmBwdXdlops_FP16, ConvHipImplicitGemmBwdXdlops)
{
    this->RunTest(miopen::solver::conv::ConvHipImplicitGemmBwdXdlops{});
}
92+
93+
// Runs each BF16 parameter set through the fixture's RunTest with the
// ConvHipImplicitGemmBwdXdlops solver.
// No ';' after the body: TEST_P expands to a function definition, so a trailing
// semicolon is an empty declaration (flagged by -Wextra-semi).
TEST_P(GPU_UnitTestConvSolverImplicitGemmBwdXdlops_BFP16, ConvHipImplicitGemmBwdXdlops)
{
    this->RunTest(miopen::solver::conv::ConvHipImplicitGemmBwdXdlops{});
}
97+
98+
// Runs each FP32 parameter set through the fixture's RunTest with the
// ConvHipImplicitGemmBwdXdlops solver.
// No ';' after the body: TEST_P expands to a function definition, so a trailing
// semicolon is an empty declaration (flagged by -Wextra-semi).
TEST_P(GPU_UnitTestConvSolverImplicitGemmBwdXdlops_FP32, ConvHipImplicitGemmBwdXdlops)
{
    this->RunTest(miopen::solver::conv::ConvHipImplicitGemmBwdXdlops{});
}
102+
103+
// Runs the device-applicability fixture's RunTest with the
// ConvHipImplicitGemmBwdXdlops solver.
// No ';' after the body: TEST_P expands to a function definition, so a trailing
// semicolon is an empty declaration (flagged by -Wextra-semi).
TEST_P(CPU_UnitTestConvSolverImplicitGemmBwdXdlopsDevApplicability_FP16,
       ConvHipImplicitGemmBwdXdlops)
{
    this->RunTest(miopen::solver::conv::ConvHipImplicitGemmBwdXdlops{});
}
108+
109+
// Smoke tests
110+
INSTANTIATE_TEST_SUITE_P(Smoke,
111+
GPU_UnitTestConvSolverImplicitGemmBwdXdlops_FP16,
112+
testing::Combine(testing::Values(GetTestParams(miopenHalf)),
113+
testing::Values(miopenConvolutionAlgoImplicitGEMM),
114+
testing::ValuesIn(GetConvSmokeTestCases(miopenHalf))));
115+
116+
INSTANTIATE_TEST_SUITE_P(
117+
Smoke,
118+
GPU_UnitTestConvSolverImplicitGemmBwdXdlops_BFP16,
119+
testing::Combine(testing::Values(GetTestParams(miopenBFloat16)),
120+
testing::Values(miopenConvolutionAlgoImplicitGEMM),
121+
testing::ValuesIn(GetConvSmokeTestCases(miopenBFloat16))));
122+
123+
INSTANTIATE_TEST_SUITE_P(Smoke,
124+
GPU_UnitTestConvSolverImplicitGemmBwdXdlops_FP32,
125+
testing::Combine(testing::Values(GetTestParams(miopenFloat)),
126+
testing::Values(miopenConvolutionAlgoImplicitGEMM),
127+
testing::ValuesIn(GetConvSmokeTestCases(miopenFloat))));
128+
129+
// Full tests
130+
INSTANTIATE_TEST_SUITE_P(Full,
131+
GPU_UnitTestConvSolverImplicitGemmBwdXdlops_FP16,
132+
testing::Combine(testing::Values(GetTestParams(miopenHalf)),
133+
testing::Values(miopenConvolutionAlgoImplicitGEMM),
134+
testing::ValuesIn(GetConvFullTestCases(miopenHalf))));
135+
136+
INSTANTIATE_TEST_SUITE_P(Full,
137+
GPU_UnitTestConvSolverImplicitGemmBwdXdlops_BFP16,
138+
testing::Combine(testing::Values(GetTestParams(miopenBFloat16)),
139+
testing::Values(miopenConvolutionAlgoImplicitGEMM),
140+
testing::ValuesIn(GetConvFullTestCases(miopenBFloat16))));
141+
142+
INSTANTIATE_TEST_SUITE_P(Full,
143+
GPU_UnitTestConvSolverImplicitGemmBwdXdlops_FP32,
144+
testing::Combine(testing::Values(GetTestParams(miopenFloat)),
145+
testing::Values(miopenConvolutionAlgoImplicitGEMM),
146+
testing::ValuesIn(GetConvFullTestCases(miopenFloat))));
147+
148+
// Device applicability tests
149+
INSTANTIATE_TEST_SUITE_P(Smoke,
150+
CPU_UnitTestConvSolverImplicitGemmBwdXdlopsDevApplicability_FP16,
151+
testing::Combine(testing::Values(GetTestParams(miopenHalf)),
152+
testing::Values(GetConvSmokeTestCases(miopenHalf)[0])));

0 commit comments

Comments
 (0)