4
4
#include < stdexcept>
5
5
#include < algorithm>
6
6
7
- #if defined(__HIPCC__) && \
8
- (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
9
- #define __HIP__MI300__
7
+ #if defined(__HIPCC__) && (defined(__gfx90a) || defined(__gfx940__) || \
8
+ defined (__gfx941__) || defined(__gfx942__))
9
+ #define __HIP__MI300_MI250__
10
+ #endif
11
+
12
+ #if defined(NDEBUG)
13
+ #undef NDEBUG
14
+ #include < assert.h>
15
+ #define UNREACHABLE_CODE assert (false );
16
+ #define NDEBUG
17
+ #else
18
+ #define UNREACHABLE_CODE assert (false );
10
19
#endif
11
20
12
21
constexpr int WARP_SIZE = 64 ;
@@ -334,7 +343,7 @@ __device__ __forceinline__ T loadnt(T* addr) {
334
343
#define M 1
335
344
#define DTYPE half
336
345
337
- #if defined(__HIP__MI300__ ) // TODO: Add NAVI support
346
+ #if defined(__HIP__MI300_MI250__ ) // TODO: Add NAVI support
338
347
339
348
__global__ void wvSpltK_hf_m1_sml_ (const int K, const int N, const DTYPE* B,
340
349
const DTYPE* __restrict__ A, DTYPE* C,
@@ -463,17 +472,15 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
463
472
}
464
473
}
465
474
466
- #else // !defined(__HIP__MI300__ ) TODO: Add NAVI support
475
+ #else // !defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
467
476
468
477
__global__ void wvSpltK_hf_m1_sml_ (const int K, const int N, const DTYPE* B,
469
478
const DTYPE* __restrict__ A, DTYPE* C,
470
- const int CuCount) {
471
- assert (false );
472
- }
479
+ const int CuCount){UNREACHABLE_CODE}
473
480
474
- #endif // defined(__HIP__MI300__ ) TODO: Add NAVI support
481
+ #endif // defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
475
482
476
- #if defined(__HIP__MI300__ ) // TODO: Add NAVI support
483
+ #if defined(__HIP__MI300_MI250__ ) // TODO: Add NAVI support
477
484
478
485
__global__ void wvSpltK_hf_m1_ (const int K, const int N, const DTYPE* B,
479
486
const DTYPE* __restrict__ A, DTYPE* C,
@@ -820,15 +827,13 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
820
827
}
821
828
}
822
829
823
- #else // !defined(__HIP__MI300__ ) TODO: Add NAVI support
830
+ #else // !defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
824
831
825
832
__global__ void wvSpltK_hf_m1_ (const int K, const int N, const DTYPE* B,
826
833
const DTYPE* __restrict__ A, DTYPE* C,
827
- const int CuCount) {
828
- assert (false );
829
- }
834
+ const int CuCount){UNREACHABLE_CODE}
830
835
831
- #endif // defined(__HIP__MI300__ ) TODO: Add NAVI support
836
+ #endif // defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
832
837
833
838
#undef YTILE
834
839
#undef UNRL
@@ -838,7 +843,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
838
843
#define UNRL 2
839
844
#define M 2
840
845
841
- #if defined(__HIP__MI300__ ) // TODO: Add NAVI support
846
+ #if defined(__HIP__MI300_MI250__ ) // TODO: Add NAVI support
842
847
843
848
__global__ void wvSpltK_hf_m2_ (const int K, const int N, const DTYPE* B,
844
849
const DTYPE* __restrict__ A, DTYPE* C,
@@ -1185,15 +1190,13 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
1185
1190
}
1186
1191
}
1187
1192
1188
- #else // !defined(__HIP__MI300__ ) TODO: Add NAVI support
1193
+ #else // !defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
1189
1194
1190
1195
__global__ void wvSpltK_hf_m2_ (const int K, const int N, const DTYPE* B,
1191
1196
const DTYPE* __restrict__ A, DTYPE* C,
1192
- const int CuCount) {
1193
- assert (false );
1194
- }
1197
+ const int CuCount){UNREACHABLE_CODE}
1195
1198
1196
- #endif // defined(__HIP__MI300__ ) TODO: Add NAVI support
1199
+ #endif // defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
1197
1200
1198
1201
#undef YTILE
1199
1202
#undef UNRL
@@ -1203,7 +1206,7 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
1203
1206
#define UNRL 2
1204
1207
#define M 3
1205
1208
1206
- #if defined(__HIP__MI300__ ) // TODO: Add NAVI support
1209
+ #if defined(__HIP__MI300_MI250__ ) // TODO: Add NAVI support
1207
1210
1208
1211
__global__ void wvSpltK_hf_m3_ (const int K, const int N, const DTYPE* B,
1209
1212
const DTYPE* __restrict__ A, DTYPE* C,
@@ -1550,15 +1553,13 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
1550
1553
}
1551
1554
}
1552
1555
1553
- #else // !defined(__HIP__MI300__ ) TODO: Add NAVI support
1556
+ #else // !defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
1554
1557
1555
1558
__global__ void wvSpltK_hf_m3_ (const int K, const int N, const DTYPE* B,
1556
1559
const DTYPE* __restrict__ A, DTYPE* C,
1557
- const int CuCount) {
1558
- assert (false );
1559
- }
1560
+ const int CuCount){UNREACHABLE_CODE}
1560
1561
1561
- #endif // defined(__HIP__MI300__ ) TODO: Add NAVI support
1562
+ #endif // defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
1562
1563
1563
1564
#undef YTILE
1564
1565
#undef UNRL
@@ -1568,7 +1569,7 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
1568
1569
#define UNRL 1
1569
1570
#define M 4
1570
1571
1571
- #if defined(__HIP__MI300__ ) // TODO: Add NAVI support
1572
+ #if defined(__HIP__MI300_MI250__ ) // TODO: Add NAVI support
1572
1573
1573
1574
__global__ void wvSpltK_hf_m4_ (const int K, const int N, const DTYPE* B,
1574
1575
const DTYPE* __restrict__ A, DTYPE* C,
@@ -1915,15 +1916,15 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
1915
1916
}
1916
1917
}
1917
1918
1918
- #else // !defined(__HIP__MI300__ ) TODO: Add NAVI support
1919
+ #else // !defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
1919
1920
1920
1921
__global__ void wvSpltK_hf_m4_ (const int K, const int N, const DTYPE* B,
1921
1922
const DTYPE* __restrict__ A, DTYPE* C,
1922
1923
const int CuCount) {
1923
- assert ( false );
1924
+ UNREACHABLE_CODE
1924
1925
}
1925
1926
1926
- #endif // defined(__HIP__MI300__ ) TODO: Add NAVI support
1927
+ #endif // defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
1927
1928
1928
1929
void wvSpltK_ (void * in_a, void * in_b, void * out_c, const int M_in,
1929
1930
const int K_in, const int N_in, cudaStream_t stream,
0 commit comments