@@ -36,6 +36,9 @@
 
 logger = init_logger(__name__)
 
+# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
+
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -114,6 +117,7 @@ class FlashAttentionMetadata:
     # Optional aot scheduling
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0
 
     # for local attention
     @dataclass
@@ -158,15 +162,35 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
         if self.use_full_cuda_graph:
-            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
-            # yet. This is because the scheduler and kernel need to always use
-            # the same num_splits (which acts as an upper bound with the
-            # dynamic split scheduler) which is currently heuristically decided
-            # by the kernel launching code.
-            self.aot_schedule = False
+            if not self.aot_schedule:
+                raise ValueError(
+                    "AoT scheduling is required for full cuda graph.")
+            capture_sizes = compilation_config.cudagraph_capture_sizes
+            if not capture_sizes:
+                raise ValueError(
+                    "cudagraph_capture_sizes should not be None when "
+                    "full_cuda_graph is True.")
+            self.max_cudagraph_size = max(capture_sizes)
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                self.runner.max_num_reqs + 1,
+                dtype=torch.int32,
+                device=self.runner.device,
+            )
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
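Why cap the split count at 16 rather than leave it unbounded? Under full CUDA graph capture, the split-KV path's partial-result buffers, shaped `[num_splits, num_heads, num_tokens, head_size]` per the NOTE in `build()` below, must be allocated up front, and they grow linearly with `num_splits`. A rough back-of-the-envelope sketch of that footprint; the head count and head size here are hypothetical, and fp32 accumulation is an assumption:

```python
def split_buffer_bytes(num_splits: int, num_heads: int, num_tokens: int,
                       head_size: int, bytes_per_elem: int = 4) -> int:
    # Intermediate buffer shape per the NOTE in build():
    # [num_splits, num_heads, num_tokens, head_size]
    return num_splits * num_heads * num_tokens * head_size * bytes_per_elem

# Hypothetical dims: 32 heads of size 128, at the 992-token capture limit.
mib = split_buffer_bytes(16, 32, 992, 128) / 2**20
print(f"{mib:.0f} MiB")  # 248 MiB at the default of 16 splits
```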
@@ -226,6 +250,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
                     window_size=self.aot_sliding_window,
+                    num_splits=self.max_num_splits,
                 )
             return None
 
@@ -302,6 +327,26 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 max_seq_len=max_seq_len,
                 causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+        max_num_splits = 0
+        if (self.use_full_cuda_graph
+                and num_actual_tokens <= self.max_cudagraph_size):
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
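The zero-fill above is load-bearing because `self.scheduler_metadata` is a persistent, fixed-address buffer reused across batches: without it, entries written by an earlier, larger batch would be read by some thread blocks as live scheduler metadata. A minimal standalone sketch of the same copy-then-clear pattern (illustrative only, not vLLM code):

```python
import torch

buf = torch.zeros(8, dtype=torch.int32)  # persistent, reused buffer

def load_metadata(metadata: torch.Tensor) -> torch.Tensor:
    n = metadata.shape[0]
    buf[:n] = metadata  # copy fresh metadata into the reused buffer
    buf[n:] = 0         # clear stale entries from any previous, larger batch
    return buf[:n]

load_metadata(torch.ones(6, dtype=torch.int32))  # fills buf[:6]
load_metadata(torch.ones(2, dtype=torch.int32))  # smaller batch
assert int(buf[2:].sum()) == 0  # tail is zeroed, not left over from call one
```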
@@ -318,6 +363,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            max_num_splits=max_num_splits,
         )
         return attn_metadata
 
@@ -510,6 +556,7 @@ def forward(
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                num_splits=attn_metadata.max_num_splits,
             )
         return output
 
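Both halves of this change serve CUDA graph capture's core constraint: replay re-issues the kernels recorded at capture time, with the same tensor addresses and launch configuration, so anything dynamically sized (a kernel-chosen `num_splits`, a freshly allocated scheduler-metadata tensor) has to be replaced by a fixed-size, pre-allocated equivalent. A minimal sketch of that static-buffer discipline using PyTorch's public graph API (a generic illustration, not the attention kernels themselves):

```python
import torch

# Static (persistent) tensors: their addresses get baked into the graph.
static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up on a side stream so one-time initializations happen outside capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

# Replay reuses the captured kernels and buffers; only the data changes.
static_in.fill_(3.0)
g.replay()
assert torch.equal(static_out, torch.full((8,), 6.0, device="cuda"))
```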