[CUDA graphs] Enable full cuda graphs with FA3 AoT scheduling #20301

Merged (5 commits) on Jul 1, 2025

cmake/external_projects/vllm_flash_attn.cmake (2 changes: 1 addition & 1 deletion)

@@ -38,7 +38,7 @@ else()
  FetchContent_Declare(
          vllm-flash-attn
          GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
          GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524
          GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
          GIT_PROGRESS TRUE
          # Don't share the vllm-flash-attn build between build types
          BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

vllm/v1/attention/backends/flash_attn.py (59 changes: 53 additions & 6 deletions)

@@ -36,6 +36,9 @@

logger = init_logger(__name__)

# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16


class FlashAttentionBackend(AttentionBackend):

@@ -114,6 +117,7 @@ class FlashAttentionMetadata:
    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

    # for local attention
    @dataclass
@@ -158,15 +162,35 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.aot_schedule = (get_flash_attn_version() == 3)
        self.use_full_cuda_graph = compilation_config.full_cuda_graph
        if self.use_full_cuda_graph:
            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
            # yet. This is because the scheduler and kernel need to always use
            # the same num_splits (which acts as an upper bound with the
            # dynamic split scheduler) which is currently heuristically decided
            # by the kernel launching code.
            self.aot_schedule = False
            if not self.aot_schedule:
                raise ValueError(
                    "AoT scheduling is required for full cuda graph.")
            capture_sizes = compilation_config.cudagraph_capture_sizes
            if not capture_sizes:
                raise ValueError(
                    "cudagraph_capture_sizes should not be None when "
                    "full_cuda_graph is True.")
            self.max_cudagraph_size = max(capture_sizes)
            if self.max_cudagraph_size > 992:
Review thread on the line above:

LucasWilkinson (Collaborator), Jul 1, 2025:
Were you hitting an assert/IMA here? If the batch is > 992, FA3 "should"
silently fall back to no dynamic split (AoT scheduling). Granted, if we fall
back to no dynamic split, the manually set num_splits may be high, since the
kernel will try to split to that value instead of using it as an upper bound,
so perf may be bad. I think it still makes sense to limit this.

PR author (Collaborator):
@LucasWilkinson I didn't test bs > 992. I just wanted to add a safety check
because I'm not sure what would happen in that case.

Collaborator:
Makes sense; I agree with this approach. We can try to add support in a later PR.

                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            self.scheduler_metadata = torch.zeros(
                self.runner.max_num_reqs + 1,
                dtype=torch.int32,
                device=self.runner.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
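
For readers who want to exercise this path, a minimal usage sketch follows. The field names full_cuda_graph and cudagraph_capture_sizes come from this diff; the entry point, the dict form of compilation_config, the capture-size list, and the placeholder model are assumptions about vLLM's user-facing API, not part of this PR.

    # Illustrative only: enable full CUDA graph capture on an FA3-capable GPU.
    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        compilation_config={
            "full_cuda_graph": True,
            # Optional explicit capture sizes; the largest one must stay
            # <= 992 per the check added in __init__ above.
            "cudagraph_capture_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
        },
    )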

@@ -226,6 +250,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                    cu_seqlens_q=cu_query_lens,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=self.max_num_splits,
                )
            return None

@@ -302,6 +327,26 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                max_seq_len=max_seq_len,
                causal=True)

        if self.use_full_cuda_graph:
            assert scheduler_metadata is not None
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        max_num_splits = 0
        if (self.use_full_cuda_graph
                and num_actual_tokens <= self.max_cudagraph_size):
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,

@@ -318,6 +363,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
        )
        return attn_metadata

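For a rough sense of scale on the NOTE above about intermediate buffers of shape [num_splits, num_heads, num_tokens, head_size], here is a back-of-envelope sketch. The fp32 element size is an assumption; FA3 may keep additional buffers (e.g. LSE values), so treat this as illustrative only.

    # Rough, illustrative estimate of the split-KV workspace implied by a
    # fixed num_splits upper bound; assumes fp32 partial outputs.
    def split_buffer_bytes(num_splits: int, num_heads: int,
                           num_tokens: int, head_size: int,
                           bytes_per_elem: int = 4) -> int:
        # Shape [num_splits, num_heads, num_tokens, head_size], per the NOTE
        # in build() above.
        return num_splits * num_heads * num_tokens * head_size * bytes_per_elem

    # 16 splits, 32 heads, 992 tokens, head_size 128:
    # 16 * 32 * 992 * 128 * 4 = 260,046,848 bytes (about 248 MiB).
    print(split_buffer_bytes(16, 32, 992, 128))

This is why max_num_splits stays at 0 (no upper bound) outside full cuda graph mode and is only raised to the default of 16 when graphs are captured.
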
@@ -510,6 +556,7 @@ def forward(
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                num_splits=attn_metadata.max_num_splits,
            )
        return output

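The zero-padded persistent scheduler_metadata buffer and the fixed num_splits both exist because a captured CUDA graph replays fixed pointers and launch shapes. A minimal, generic illustration of that pattern in plain PyTorch follows; it is not vLLM code, and the toy kernel is a stand-in for the attention kernel.

    import torch

    # Persistent buffers whose addresses are baked into the captured graph.
    static_meta = torch.zeros(64, dtype=torch.int32, device="cuda")
    static_out = torch.zeros(64, dtype=torch.float32, device="cuda")

    def kernel_like(meta: torch.Tensor, out: torch.Tensor) -> None:
        # Stand-in for a kernel that consumes scheduler metadata.
        out.copy_(meta.float() * 2)

    # Warmup outside capture, then capture the call into a CUDA graph.
    kernel_like(static_meta, static_out)
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        kernel_like(static_meta, static_out)

    # At replay time, new metadata is copied into the same buffer and the
    # unused tail is zeroed, mirroring what build() does with
    # self.scheduler_metadata above.
    new_meta = torch.arange(10, dtype=torch.int32, device="cuda")
    static_meta[:10] = new_meta
    static_meta[10:] = 0
    g.replay()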