
Commit cf02f9b

Add FlexAttention to V1 (#16078)
Signed-off-by: drisspg <[email protected]>
1 parent c4296b1 commit cf02f9b
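
For reference, selecting the new backend from user code works the same way the test below does: choose the V1 engine and the FLEX_ATTENTION backend via environment variables before constructing the LLM. This is only a minimal sketch assembled from the settings exercised in tests/kernels/test_flex_attention.py, not code from this commit; the model name is simply the one the test uses.

import os

# Select the V1 engine and the FlexAttention backend before vLLM is imported,
# mirroring the environment variables set in the new integration test.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"

from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", enforce_eager=True)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)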

File tree

5 files changed: +575 -0 lines changed


tests/kernels/test_flex_attention.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
"""Integration tests for FlexAttention backend vs default backend"""

import random

import numpy as np
import pytest
import torch
from packaging import version

from vllm import LLM, SamplingParams

TORCH_VERSION = version.parse(torch.__version__)
MINIMUM_TORCH_VERSION = version.parse("2.7.0")


def set_seed(seed):
    """Set seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_vs_default_backend(monkeypatch):
    """Test that FlexAttention produces the same outputs as the default backend.

    This test compares the outputs from the FlexAttention backend with
    the default backend, ensuring they are identical when using the same seed.
    """
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    seed = 42
    max_tokens = 32
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    sampling_params = SamplingParams(temperature=0.0,
                                     top_p=1.0,
                                     seed=seed,
                                     max_tokens=max_tokens)

    # Run with flex attention
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

        set_seed(seed)

        llm_flex = LLM(
            model_name,
            tensor_parallel_size=1,
            num_gpu_blocks_override=128,
            enforce_eager=True,
        )
        output_flex = llm_flex.generate(prompts, sampling_params)

    # Run with default backend
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
        set_seed(seed)
        llm_default = LLM(
            model_name,
            tensor_parallel_size=1,
            num_gpu_blocks_override=128,
            enforce_eager=True,
        )
        output_default = llm_default.generate(prompts, sampling_params)

    # Compare outputs from both backends
    for i, (flex_result,
            default_result) in enumerate(zip(output_flex, output_default)):
        prompt = prompts[i]
        flex_text = flex_result.outputs[0].text
        default_text = default_result.outputs[0].text

        assert flex_text == default_text, (
            f"FlexAttention output doesn't match default for: {prompt!r}\n"
            f"FlexAttention: {flex_text!r}\n"
            f"Default: {default_text!r}")


if __name__ == "__main__":
    pytest.main([__file__])

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions
@@ -1409,6 +1409,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "FLASHINFER_VLLM_V1",
             "ROCM_AITER_MLA",
             "TORCH_SDPA_VLLM_V1",
+            "FLEX_ATTENTION",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
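
The entry added above feeds the V1 support oracle: a backend name set through VLLM_ATTENTION_BACKEND must appear in V1_BACKENDS for the V1 engine to be used. As a rough illustration of that whitelist check (simplified, not the exact method from arg_utils.py):

# Simplified illustration of the whitelist check; the real
# _is_v1_supported_oracle does considerably more than this.
V1_BACKENDS = [
    "FLASHINFER_VLLM_V1",
    "ROCM_AITER_MLA",
    "TORCH_SDPA_VLLM_V1",
    "FLEX_ATTENTION",
]


def backend_supported_on_v1(name):
    # An unset backend means "let the platform choose", which V1 accepts.
    return name is None or name in V1_BACKENDS


assert backend_supported_on_v1("FLEX_ATTENTION")
assert not backend_supported_on_v1("SOME_V0_ONLY_BACKEND")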

vllm/platforms/cuda.py

Lines changed: 3 additions & 0 deletions
@@ -226,6 +226,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             if selected_backend == _Backend.FLASHINFER:
                 logger.info_once("Using FlashInfer backend on V1 engine.")
                 return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
+            if selected_backend == _Backend.FLEX_ATTENTION:
+                logger.info("Using FlexAttenion backend on V1 engine.")
+                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
             if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ class _Backend(enum.Enum):
     BLOCK_SPARSE_FLASH_ATTN = enum.auto()
     DUAL_CHUNK_FLASH_ATTN = enum.auto()
     NO_ATTENTION = enum.auto()
+    FLEX_ATTENTION = enum.auto()


 class PlatformEnum(enum.Enum):
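
The new enum member is the in-process counterpart of the "FLEX_ATTENTION" string accepted by VLLM_ATTENTION_BACKEND. A minimal illustration of that name-to-member lookup, using a trimmed copy of the enum rather than vLLM's actual selector code:

import enum


class _Backend(enum.Enum):
    # Trimmed copy for illustration only; the real enum lives in
    # vllm/platforms/interface.py and has many more members.
    NO_ATTENTION = enum.auto()
    FLEX_ATTENTION = enum.auto()


assert _Backend["FLEX_ATTENTION"] is _Backend.FLEX_ATTENTION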
