Skip to content

Commit b3ae496

Browse files
rocking5566 and poyenc authored
[AMD ROCm] Fix intrinsic for ROCm7 (Dao-AILab#1729)
* Use more reasonable splitkv heuristic
* update CK
* Pass logits soft-capping arguments
* Revert "Merge pull request #147 from ROCm/poyenc/fix-ck-tile-splitkv-heuristic"

This reverts commit 12857ce, reversing changes made to e64b970.

---------

Co-authored-by: Po Yen Chen <[email protected]>
1 parent 3ba6f82 commit b3ae496

File tree

4 files changed

+14 -6 lines changed

4 files changed

+14 -6 lines changed

csrc/composable_kernel

Submodule composable_kernel updated 837 files

csrc/flash_attn_ck/mha_fwd.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
1919
dtype,
2020
false, // is_group_mode
2121
true, // is_v_rowmajor
22+
false, // has_logits_soft_cap
2223
mask.type,
2324
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
2425
has_lse,
@@ -111,6 +112,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
111112
softmax_scale, // scale_s
112113
1, // scale_p
113114
1, // scale_o
115+
0.0f, // logits_soft_cap
114116
stride_q,
115117
stride_k,
116118
stride_v,
@@ -134,6 +136,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
134136
mask.left,
135137
mask.right,
136138
static_cast<ck_tile::index_t>(mask.type),
139+
0, // min_seqlen_q
137140
p_dropout,
138141
has_dropout_randval,
139142
drop_seed_offset};

csrc/flash_attn_ck/mha_fwd_kvcache.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@ fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask,
3333
head_size,
3434
dtype,
3535
false, // is_group_mode
36-
true, // is_v_rowmajor
36+
true, // is_v_rowmajor
37+
false, // has_logits_soft_cap
3738
mask.type,
3839
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
3940
has_lse,

csrc/flash_attn_ck/mha_varlen_fwd.cpp

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,9 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
1717
return fmha_fwd_traits{head_size,
1818
head_size,
1919
dtype,
20-
true, // is_group_mode
21-
true, // is_v_rowmajor
20+
true, // is_group_mode
21+
true, // is_v_rowmajor
22+
false, // has_logits_soft_cap
2223
mask.type,
2324
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
2425
has_lse,
@@ -35,8 +36,9 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m
3536
return fmha_fwd_splitkv_traits{head_size,
3637
head_size,
3738
dtype,
38-
true, // is_group_mode
39-
true, // is_v_rowmajor
39+
true, // is_group_mode
40+
true, // is_v_rowmajor
41+
false, // has_logits_soft_cap
4042
mask.type,
4143
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
4244
has_lse,
@@ -131,6 +133,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
131133
softmax_scale, // scale_s
132134
1, // scale_p
133135
1, // scale_o
136+
0.0f, // logits_soft_cap
134137
stride_q,
135138
stride_k,
136139
stride_v,
@@ -154,6 +157,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
154157
mask.left,
155158
mask.right,
156159
static_cast<ck_tile::index_t>(mask.type),
160+
0, // min_seqlen_q
157161
p_dropout,
158162
has_dropout_randval,
159163
drop_seed_offset};

0 commit comments

Comments (0)