Skip to content

Commit c8b6b64

Browse files
authored
Re-enable fp8 types for all architectures. (#1470)
* re-enable fp8 and bf8 for all targets * restore the fp8 gemm instances * re-enable conv_3d fp8 on all architectures * diasble several fp8 gemm instances on all architectures except gfx94 * clang format fix
1 parent 79a5d9c commit c8b6b64

File tree

9 files changed

+28
-31
lines changed

9 files changed

+28
-31
lines changed

CMakeLists.txt

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,14 @@ if (DTYPES)
6262
endif()
6363
message("DTYPES macro set to ${DTYPES}")
6464
else()
65-
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
65+
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
6666
set(CK_ENABLE_INT8 "ON")
6767
set(CK_ENABLE_FP16 "ON")
6868
set(CK_ENABLE_FP32 "ON")
6969
set(CK_ENABLE_FP64 "ON")
7070
set(CK_ENABLE_BF16 "ON")
71-
if (GPU_TARGETS MATCHES "gfx94")
72-
add_definitions(-DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
73-
set(CK_ENABLE_FP8 "ON")
74-
set(CK_ENABLE_BF8 "ON")
75-
endif()
71+
set(CK_ENABLE_FP8 "ON")
72+
set(CK_ENABLE_BF8 "ON")
7673
endif()
7774

7875
#for f8/bf8_t type

library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,17 @@ list(APPEND GEMM_INSTANCES
100100
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
101101
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
102102

103-
if(GPU_TARGETS MATCHES "gfx94")
104-
list(APPEND GEMM_INSTANCES
105-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp
106-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp
107-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp
108-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp
109-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp
110-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp
111-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp
112-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
113-
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)
114-
endif()
103+
list(APPEND GEMM_INSTANCES
104+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp
105+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp
106+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp
107+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp
108+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp
109+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp
110+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp
111+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
112+
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)
113+
115114

116115
list(APPEND GEMM_INSTANCES
117116
device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp

library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm
5151
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
5252
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
5353

54-
if(GPU_TARGETS MATCHES "gfx94")
55-
list(APPEND GEMM_UNIVERSAL_INSTANCES
54+
list(APPEND GEMM_UNIVERSAL_INSTANCES
5655
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
5756
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp
5857
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp
@@ -123,6 +122,6 @@ if(GPU_TARGETS MATCHES "gfx94")
123122
set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
124123
set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
125124
set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
126-
endif()
125+
127126

128127
add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES})

library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
3535

3636
template <GemmSpecialization GemmSpec>
3737
using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple<
38-
// clang-format off
38+
// clang-format off
3939
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
4040
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
4141
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
4242
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
43-
43+
#ifdef __gfx94__
44+
//Only enable these instances on gfx94x
4445
// Compute friendly
4546
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
4647
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
@@ -55,6 +56,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple<
5556
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
5657
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 128, 16, 4, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
5758
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 128, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
59+
#endif
5860
// clang-format on
5961
>;
6062

library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ set(GROUPED_CONV3D_BWD_DATA
1515
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
1616
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp)
1717

18-
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
18+
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
1919
list(APPEND GROUPED_CONV3D_BWD_DATA
2020
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp)
2121
endif()

library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT
3030
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
3131
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp)
3232

33-
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
33+
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
3434
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
3535
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
3636
endif()

library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR
44
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
55
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
66

7-
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
7+
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
88
list(APPEND GROUPED_CONV3D_BWD_WEIGHT_BILINEAR
99
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
1010
endif()

library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_SCALE
44
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
55
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
66

7-
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
7+
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
88
list(APPEND GROUPED_CONV3D_BWD_WEIGHT_SCALE
99
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
1010
endif()

library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,22 @@ set(GROUPED_CONV3D_FWD
4343
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp
4444
)
4545

46-
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
46+
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
4747
list(APPEND GROUPED_CONV3D_FWD
4848
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp)
4949
endif()
5050

51-
if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
51+
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
5252
list(APPEND GROUPED_CONV3D_FWD
5353
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp)
5454
endif()
5555

56-
if((DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
56+
if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
5757
list(APPEND GROUPED_CONV3D_FWD
5858
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp)
5959
endif()
6060

61-
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
61+
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
6262
list(APPEND GROUPED_CONV3D_FWD
6363
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp)
6464
list(APPEND GROUPED_CONV3D_FWD

0 commit comments

Comments
 (0)