Skip to content

Commit 6a8793d

Browse files
maleksan85 and Aleksandr Malyshev authored
Adding UNREACHABLE_CODE macro for non MI300 and MI250 cards (vllm-project#138)
* Adding UNREACHABLE_CODE macro
* clang format fixes
* clang formatting fix
* minor updates in syntax
* clang format update
* clang format fix one more try
* clang format one more try
* clang format fix one more try
---------
Co-authored-by: Aleksandr Malyshev <[email protected]>
1 parent 5945822 commit 6a8793d

File tree

2 files changed

+49
-41
lines changed

2 files changed

+49
-41
lines changed

csrc/custom/custom_kernels.cu

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,18 @@
44
#include <stdexcept>
55
#include <algorithm>
66

7-
#if defined(__HIPCC__) && \
8-
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
9-
#define __HIP__MI300__
7+
// Enable the real kernel bodies only when hipcc is compiling device code for
// one of the supported targets: gfx90a or gfx940/gfx941/gfx942 (the MI250 and
// MI300 parts this macro is named after). All other targets compile the
// UNREACHABLE_CODE stub kernels in the matching #else branches.
//
// The compiler predefines the per-target macro as __gfx<arch>__, i.e. WITH
// trailing underscores. Testing `__gfx90a` (no trailing underscores) is never
// true, which silently routes gfx90a builds to the stubs.
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
                           defined(__gfx941__) || defined(__gfx942__))
  #define __HIP__MI300_MI250__
#endif
11+
12+
// UNREACHABLE_CODE marks code that must never execute; in this file it is the
// whole body of the stub kernels compiled for unsupported GPU targets. It
// expands to an assertion that always fails.
//
// <assert.h> is included with NDEBUG temporarily undefined so that assert()
// stays active in release builds too; NDEBUG is re-defined afterwards for the
// rest of the translation unit. The header is included on both paths so the
// macro never depends on a transitive include.
//
// The expansion carries its own trailing semicolon because call sites are
// written as `{UNREACHABLE_CODE}`.
//
// NOTE(review): assert is re-defined on every inclusion of <assert.h>. A later
// include made while NDEBUG is defined would turn this macro into a no-op;
// confirm no header pulled in after this point does that.
#if defined(NDEBUG)
  #undef NDEBUG
  #include <assert.h>
  #define NDEBUG
#else
  #include <assert.h>
#endif
#define UNREACHABLE_CODE assert(false);
1120

1221
constexpr int WARP_SIZE = 64;
@@ -334,7 +343,7 @@ __device__ __forceinline__ T loadnt(T* addr) {
334343
#define M 1
335344
#define DTYPE half
336345

337-
#if defined(__HIP__MI300__) // TODO: Add NAVI support
346+
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
338347

339348
__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
340349
const DTYPE* __restrict__ A, DTYPE* C,
@@ -463,17 +472,15 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
463472
}
464473
}
465474

466-
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
475+
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
467476

468477
__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
469478
const DTYPE* __restrict__ A, DTYPE* C,
470-
const int CuCount) {
471-
assert(false);
472-
}
479+
const int CuCount){UNREACHABLE_CODE}
473480

474-
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
481+
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
475482

476-
#if defined(__HIP__MI300__) // TODO: Add NAVI support
483+
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
477484

478485
__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
479486
const DTYPE* __restrict__ A, DTYPE* C,
@@ -820,15 +827,13 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
820827
}
821828
}
822829

823-
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
830+
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
824831

825832
__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
826833
const DTYPE* __restrict__ A, DTYPE* C,
827-
const int CuCount) {
828-
assert(false);
829-
}
834+
const int CuCount){UNREACHABLE_CODE}
830835

831-
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
836+
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
832837

833838
#undef YTILE
834839
#undef UNRL
@@ -838,7 +843,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
838843
#define UNRL 2
839844
#define M 2
840845

841-
#if defined(__HIP__MI300__) // TODO: Add NAVI support
846+
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
842847

843848
__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
844849
const DTYPE* __restrict__ A, DTYPE* C,
@@ -1185,15 +1190,13 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
11851190
}
11861191
}
11871192

1188-
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
1193+
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
11891194

11901195
__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
11911196
const DTYPE* __restrict__ A, DTYPE* C,
1192-
const int CuCount) {
1193-
assert(false);
1194-
}
1197+
const int CuCount){UNREACHABLE_CODE}
11951198

1196-
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
1199+
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
11971200

11981201
#undef YTILE
11991202
#undef UNRL
@@ -1203,7 +1206,7 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
12031206
#define UNRL 2
12041207
#define M 3
12051208

1206-
#if defined(__HIP__MI300__) // TODO: Add NAVI support
1209+
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
12071210

12081211
__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
12091212
const DTYPE* __restrict__ A, DTYPE* C,
@@ -1550,15 +1553,13 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
15501553
}
15511554
}
15521555

1553-
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
1556+
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
15541557

15551558
__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
15561559
const DTYPE* __restrict__ A, DTYPE* C,
1557-
const int CuCount) {
1558-
assert(false);
1559-
}
1560+
const int CuCount){UNREACHABLE_CODE}
15601561

1561-
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
1562+
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
15621563

15631564
#undef YTILE
15641565
#undef UNRL
@@ -1568,7 +1569,7 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
15681569
#define UNRL 1
15691570
#define M 4
15701571

1571-
#if defined(__HIP__MI300__) // TODO: Add NAVI support
1572+
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
15721573

15731574
__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
15741575
const DTYPE* __restrict__ A, DTYPE* C,
@@ -1915,15 +1916,15 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
19151916
}
19161917
}
19171918

1918-
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
1919+
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
19191920

19201921
__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
19211922
const DTYPE* __restrict__ A, DTYPE* C,
19221923
const int CuCount) {
1923-
assert(false);
1924+
UNREACHABLE_CODE
19241925
}
19251926

1926-
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
1927+
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
19271928

19281929
void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in,
19291930
const int K_in, const int N_in, cudaStream_t stream,

csrc/custom/paged_attention/attention_ll4mi.cu

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,26 @@
66

77
#include <algorithm>
88

9-
#if defined(__HIPCC__) && \
10-
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
11-
#define __HIP__MI300__
9+
// Enable the real paged-attention kernels only when hipcc is compiling device
// code for one of the supported targets: gfx90a or gfx940/gfx941/gfx942 (the
// MI250 and MI300 parts this macro is named after). All other targets compile
// the UNREACHABLE_CODE stub kernels in the #else branch.
//
// The compiler predefines the per-target macro as __gfx<arch>__, i.e. WITH
// trailing underscores. Testing `__gfx90a` (no trailing underscores) is never
// true, which silently routes gfx90a builds to the stubs.
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
                           defined(__gfx941__) || defined(__gfx942__))
  #define __HIP__MI300_MI250__
#endif
13+
14+
// UNREACHABLE_CODE marks code that must never execute; in this file it is the
// whole body of the stub kernels compiled for unsupported GPU targets. It
// expands to an assertion that always fails.
//
// <assert.h> is included with NDEBUG temporarily undefined so that assert()
// stays active in release builds too; NDEBUG is re-defined afterwards for the
// rest of the translation unit. The header is included on both paths so the
// macro never depends on a transitive include.
//
// The expansion carries its own trailing semicolon because call sites are
// written as `{UNREACHABLE_CODE}`.
//
// NOTE(review): assert is re-defined on every inclusion of <assert.h>. A later
// include made while NDEBUG is defined would turn this macro into a no-op;
// confirm no header pulled in after this point does that.
#if defined(NDEBUG)
  #undef NDEBUG
  #include <assert.h>
  #define NDEBUG
#else
  #include <assert.h>
#endif
#define UNREACHABLE_CODE assert(false);
1322

1423
#define MAX(a, b) ((a) > (b) ? (a) : (b))
1524
#define MIN(a, b) ((a) < (b) ? (a) : (b))
1625
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
1726
#define WARP_SIZE 64
1827

19-
#if defined(__HIP__MI300__) // TODO: Add NAVI support
28+
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
2029

2130
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
2231
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
@@ -863,7 +872,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
863872
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
864873
}
865874

866-
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
875+
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
867876

868877
template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
869878
int GQA_RATIO>
@@ -889,7 +898,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
889898
scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size]
890899
#endif
891900
int max_ctx_blocks) {
892-
assert(false);
901+
UNREACHABLE_CODE
893902
}
894903

895904
// Grid: (num_heads, num_seqs).
@@ -905,11 +914,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
905914
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
906915
// max_num_partitions, head_size]
907916
const int* __restrict__ context_lens, // [num_seqs]
908-
const int max_num_partitions) {
909-
assert(false);
910-
}
917+
const int max_num_partitions){UNREACHABLE_CODE}
911918

912-
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
919+
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
913920

914921
#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
915922
paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \

0 commit comments

Comments
 (0)