Commit 2b6e58b

WoosukKwon authored and CSWYF3634076 committed
[CUDA graphs] Enable full cuda graphs with FA3 AoT scheduling (vllm-project#20301)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent ad741d6 commit 2b6e58b
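
For context, a minimal usage sketch (not part of this commit) of turning the feature on from the user side. It assumes vLLM's LLM entrypoint forwards a compilation_config mapping to the full_cuda_graph and cudagraph_capture_sizes fields read in this diff; the model name and capture sizes are placeholders.

# Minimal sketch, not part of this commit: enabling full CUDA graph capture
# with FA3 AoT scheduling from the user side. Model name and capture sizes
# are illustrative assumptions.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",     # illustrative model
    compilation_config={
        "full_cuda_graph": True,                  # read as compilation_config.full_cuda_graph
        "cudagraph_capture_sizes": [1, 2, 4, 8],  # max() must be <= 992 per this change
    },
)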

2 files changed: +54 −7 lines changed

cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
           vllm-flash-attn
           GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-          GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524
+          GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
           GIT_PROGRESS TRUE
           # Don't share the vllm-flash-attn build between build types
           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

vllm/v1/attention/backends/flash_attn.py

Lines changed: 53 additions & 6 deletions

@@ -36,6 +36,9 @@
 
 logger = init_logger(__name__)
 
+# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
+
 
 class FlashAttentionBackend(AttentionBackend):
 
@@ -114,6 +117,7 @@ class FlashAttentionMetadata:
     # Optional aot scheduling
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0
 
     # for local attention
     @dataclass
@@ -158,15 +162,35 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
         if self.use_full_cuda_graph:
-            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
-            # yet. This is because the scheduler and kernel need to always use
-            # the same num_splits (which acts as an upper bound with the
-            # dynamic split scheduler) which is currently heuristically decided
-            # by the kernel launching code.
-            self.aot_schedule = False
+            if not self.aot_schedule:
+                raise ValueError(
+                    "AoT scheduling is required for full cuda graph.")
+            capture_sizes = compilation_config.cudagraph_capture_sizes
+            if not capture_sizes:
+                raise ValueError(
+                    "cudagraph_capture_sizes should not be None when "
+                    "full_cuda_graph is True.")
+            self.max_cudagraph_size = max(capture_sizes)
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                self.runner.max_num_reqs + 1,
+                dtype=torch.int32,
+                device=self.runner.device,
+            )
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
@@ -226,6 +250,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
                     window_size=self.aot_sliding_window,
+                    num_splits=self.max_num_splits,
                 )
             return None
 
@@ -302,6 +327,26 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
            max_seq_len=max_seq_len,
            causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+        max_num_splits = 0
+        if (self.use_full_cuda_graph
+                and num_actual_tokens <= self.max_cudagraph_size):
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
@@ -318,6 +363,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            max_num_splits=max_num_splits,
         )
         return attn_metadata
 
@@ -510,6 +556,7 @@ def forward(
             q_descale=layer._q_scale.expand(descale_shape),
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
+            num_splits=attn_metadata.max_num_splits,
         )
         return output
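The piece that makes AoT scheduling compatible with full CUDA graphs is the persistent scheduler_metadata buffer: graph replay reuses captured tensor addresses, so each step's variable-length metadata is copied into a fixed, pre-allocated buffer and the stale tail is zeroed. A standalone sketch of that pattern, with assumed sizes and names (not vLLM code):

import torch

# Standalone sketch (assumed sizes/names, not vLLM code) of the persistent
# buffer pattern used in build() above. CUDA graphs replay fixed tensor
# addresses, so per-step metadata of varying length must be written into a
# pre-allocated buffer, with the unused tail zeroed so no thread block picks
# up stale entries from a previous step.
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_ENTRIES = 128 + 1  # stands in for runner.max_num_reqs + 1
persistent_metadata = torch.zeros(MAX_ENTRIES, dtype=torch.int32, device=device)

def refresh(new_metadata: torch.Tensor) -> torch.Tensor:
    """Copy this step's metadata into the persistent buffer and zero the rest."""
    n = new_metadata.shape[0]
    persistent_metadata[:n] = new_metadata  # overwrite the live prefix
    persistent_metadata[n:] = 0             # clear stale entries
    return persistent_metadata[:n]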

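Likewise, fixing num_splits to an upper bound trades memory for graph-compatibility. A rough back-of-the-envelope for the intermediate buffer mentioned in the comment above, using illustrative dimensions (all numbers below are assumptions, not values from this commit):

# Rough estimate of the [num_splits, num_heads, num_tokens, head_size]
# intermediate buffer; every dimension here is an illustrative assumption.
num_splits = 16       # _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
num_heads = 32
num_tokens = 992      # largest supported capture size per this change
head_size = 128
bytes_per_elem = 4    # assuming fp32 accumulation

buffer_bytes = num_splits * num_heads * num_tokens * head_size * bytes_per_elem
print(f"~{buffer_bytes / 2**20:.0f} MiB per such buffer")  # ~248 MiB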