Skip to content

Commit a16fbf8

Browse files
committed
Use more reasonable splitkv heuristic
1 parent e9e96d3 commit a16fbf8

File tree

2 files changed

+7
-32
lines changed

2 files changed

+7
-32
lines changed

csrc/flash_attn_ck/flash_common.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int
1818
return num_splits;
1919

2020
// TODO - tile size should match the TileFmhaShape, hardcode for now
21-
const int kM0 = 128;
21+
const int kM0 = 64;
2222
const int kN1 = hdim_v;
2323

2424
const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
2525
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;
2626

2727
if(num_splits < 1 && p_drop == 0.0f)
2828
return num_splits_heuristic_ck(
29-
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
29+
batch * nhead * num_m_blocks, props.multiProcessorCount, num_n_blocks, 16);
3030

3131
return num_splits;
3232
}

csrc/flash_attn_ck/flash_common.hpp

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,40 +35,15 @@ inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* r
3535
}
3636
}
3737

/// Chooses how many KV splits to use for the CK split-KV attention path.
///
/// Starting from one split, the split count is doubled for as long as doubling
/// it once more would still not exceed the number of SMs. Concretely, the
/// first power-of-two `num_splits` for which
/// `batch_nheads_mblocks * num_splits * 2 > num_SMs` is returned. If no such
/// power of two exists within the cap, `max_splits` itself is returned (which
/// need not be a power of two).
///
/// @param batch_nheads_mblocks  Number of work items before splitting (the
///                              caller passes `batch * nhead * num_m_blocks`).
/// @param num_SMs               Number of SMs / compute units to fill (the
///                              caller passes `props.multiProcessorCount`).
/// @param num_n_blocks          Unused by this heuristic; kept so existing
///                              call sites keep compiling.
/// @param max_splits            Upper bound on the returned split count.
/// @return A split count in `[1, max_splits]`, or 1 when `max_splits < 1`.
inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, [[maybe_unused]] int num_n_blocks, int max_splits) {
    // A split count below 1 is meaningless; treat a non-positive cap as "do not split".
    if (max_splits < 1) { return 1; }

    // Widen to 64 bits so neither the doubling nor the product can overflow int.
    const long long work_items = batch_nheads_mblocks;
    for (long long num_splits = 1; num_splits <= max_splits; num_splits *= 2) {
        // Stop as soon as one more doubling would create more work items than SMs.
        if (num_SMs < work_items * (num_splits * 2)) {
            return static_cast<int>(num_splits);
        }
    }

    // Even the largest power of two within the cap leaves SMs idle: use the cap.
    return max_splits;
}
7348

7449
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);

0 commit comments

Comments
 (0)