@@ -35,40 +35,15 @@ inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* r
35
35
}
36
36
}
37
37
38
- inline int num_splits_heuristic_ck (int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
38
+ inline int num_splits_heuristic_ck (int batch_nheads_mblocks, int num_SMs, [[maybe_unused]] int num_n_blocks, int max_splits) {
39
39
// If we have enough to almost fill the SMs, then just use 1 split
40
- if (batch_nheads_mblocks >= 0 .8f * num_SMs) { return 1 ; }
41
- max_splits = std::min ({max_splits, num_SMs, num_n_blocks});
42
- float max_efficiency = 0 .f ;
43
- std::vector<float > efficiency;
44
- efficiency.reserve (max_splits);
45
- auto ceildiv = [](int a, int b) { return (a + b - 1 ) / b; };
46
- // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
47
- // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
48
- // (i.e. it's 11 splits anyway).
49
- // So we check if the number of blocks per split is the same as the previous num_splits.
50
- auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
51
- return num_splits == 1 || ceildiv (num_n_blocks, num_splits) != ceildiv (num_n_blocks, num_splits - 1 );
52
- };
53
- for (int num_splits = 1 ; num_splits <= max_splits; num_splits++) {
54
- if (!is_split_eligible (num_splits)) {
55
- efficiency.push_back (0 .f );
56
- } else {
57
- float n_waves = float (batch_nheads_mblocks * num_splits) / num_SMs;
58
- float eff = n_waves / ceil (n_waves);
59
- // printf("num_splits = %d, eff = %f\n", num_splits, eff);
60
- if (eff > max_efficiency) { max_efficiency = eff; }
61
- efficiency.push_back (eff);
62
- }
63
- }
64
- for (int num_splits = 1 ; num_splits <= max_splits; num_splits++) {
65
- if (!is_split_eligible (num_splits)) { continue ; }
66
- if (efficiency[num_splits - 1 ] >= 0.85 * max_efficiency) {
67
- // printf("num_splits chosen = %d\n", num_splits);
40
+ for (int num_splits = 1 ; num_splits <= max_splits; num_splits *= 2 ) {
41
+ if (num_SMs < batch_nheads_mblocks * (num_splits * 2 )) {
68
42
return num_splits;
69
43
}
70
44
}
71
- return 1 ;
45
+
46
+ return max_splits;
72
47
}
73
48
74
49
int override_num_splits_if_necessary (int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);
0 commit comments