Commit 54bd9a0

register custom op for flash attn and use from torch.ops (#7536)
1 parent 50b8d08 commit 54bd9a0
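This commit wraps the two vllm_flash_attn kernels as PyTorch custom ops and routes every call site through the torch.ops.vllm namespace, so Dynamo can keep them in a captured graph as single operator nodes instead of breaking the graph on an opaque Python call into the extension. As a minimal sketch of the same torch.library.custom_op / register_fake / torch.ops pattern (PyTorch 2.4+ APIs), using a hypothetical demo::scaled_add op rather than anything from this commit:

import torch


# Hypothetical example op ("demo::scaled_add"); the commit applies the same
# pattern to vllm::flash_attn_varlen_func and vllm::flash_attn_with_kvcache.
@torch.library.custom_op("demo::scaled_add", mutates_args=[])
def scaled_add(x: torch.Tensor, y: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Real implementation: this is where the actual kernel would run.
    return x + scale * y


@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Fake (meta) implementation: only describes output shape/dtype,
    # so tracing never has to execute the real kernel.
    return torch.empty_like(x)


# Call sites go through the registered name, mirroring the diff's
# torch.ops.vllm.flash_attn_varlen_func / flash_attn_with_kvcache usage.
out = torch.ops.demo.scaled_add(torch.randn(4), torch.randn(4), scale=2.0)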

File tree

5 files changed: +220 -41 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 7 additions & 0 deletions
@@ -163,6 +163,13 @@ steps:
   - pytest -v -s models/test_oot_registration.py # it needs a clean process
   - pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py

+- label: torch compile integration test
+  source_file_dependencies:
+  - vllm/
+  commands:
+  - pytest -v -s ./compile/test_full_graph.py
+
+
 - label: Vision Language Models Test # 42min
   mirror_hardwares: [amd]
   source_file_dependencies:

tests/compile/test_full_graph.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import os
+
+import pytest
+
+
+@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
+def test_full_graph(model):
+    # make sure these models can be captured in full graph mode
+    os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
+
+    from vllm import LLM, SamplingParams
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B")
+    llm.generate(prompts, sampling_params)
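The new test runs a real model with VLLM_TEST_DYNAMO_GRAPH_CAPTURE=1; judging from the commit, the point is that the forward pass can now be captured by Dynamo as one full graph because the flash-attention calls are registered custom ops. The same property can be checked at toy scale with torch.compile(fullgraph=True), which errors on any graph break. The demo::double op below is a hypothetical stand-in, not part of vLLM:

import torch


@torch.library.custom_op("demo::double", mutates_args=[])
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@double.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


@torch.compile(fullgraph=True)  # errors out on any Dynamo graph break
def fn(x: torch.Tensor) -> torch.Tensor:
    # The custom op shows up as a single node in the captured graph.
    return torch.ops.demo.double(x) + 1


fn(torch.randn(8))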

tests/kernels/test_flash_attn.py

Lines changed: 61 additions & 12 deletions
@@ -2,13 +2,16 @@

 import pytest
 import torch
-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

-NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
+import vllm.attention.backends.flash_attn  # noqa: F401
+
+NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
-NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
+# one value large enough to test overflow in index calculation.
+# one value small enough to test the schema op check
+NUM_BLOCKS = [32768, 2048]


 def ref_paged_attn(
@@ -72,6 +75,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
@@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv(
     dtype: torch.dtype,
     block_size: int,
     soft_cap: Optional[float],
+    num_blocks: int,
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -91,7 +96,7 @@
     scale = head_size**-0.5

     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
-    key_cache = torch.randn(NUM_BLOCKS,
+    key_cache = torch.randn(num_blocks,
                             block_size,
                             num_kv_heads,
                             head_size,
@@ -101,21 +106,40 @@

     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
-                                 NUM_BLOCKS,
+                                 num_blocks,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)

-    output = flash_attn_with_kvcache(
-        q=query.unsqueeze(1),
-        k_cache=key_cache,
-        v_cache=value_cache,
+    output = torch.ops.vllm.flash_attn_with_kvcache(
+        decode_query=query.unsqueeze(1),
+        key_cache=key_cache,
+        value_cache=value_cache,
         softmax_scale=scale,
         causal=True,
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)

+    if num_blocks <= 2048:
+        test_utils = ["test_faketensor", "test_schema"]
+    else:
+        test_utils = ["test_faketensor"]
+
+    torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
+                          args=tuple(),
+                          kwargs=dict(
+                              decode_query=query.unsqueeze(1),
+                              key_cache=key_cache,
+                              value_cache=value_cache,
+                              softmax_scale=scale,
+                              causal=True,
+                              block_table=block_tables,
+                              cache_seqlens=kv_lens_tensor,
+                              softcap=soft_cap if soft_cap is not None else 0,
+                          ),
+                          test_utils=test_utils)
+
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
@@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("sliding_window", [None])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
     seq_lens: List[Tuple[int, int]],
@@ -146,6 +171,7 @@ def test_varlen_with_paged_kv(
     dtype: torch.dtype,
     block_size: int,
     soft_cap: Optional[float],
+    num_blocks: int,
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -166,7 +192,7 @@
                         num_query_heads,
                         head_size,
                         dtype=dtype)
-    key_cache = torch.randn(NUM_BLOCKS,
+    key_cache = torch.randn(num_blocks,
                             block_size,
                             num_kv_heads,
                             head_size,
@@ -181,11 +207,11 @@

     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
-                                 NUM_BLOCKS,
+                                 num_blocks,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)

-    output = flash_attn_varlen_func(
+    output = torch.ops.vllm.flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
@@ -200,6 +226,29 @@
         softcap=soft_cap if soft_cap is not None else 0,
     )

+    if num_blocks <= 2048:
+        test_utils = ["test_faketensor", "test_schema"]
+    else:
+        test_utils = ["test_faketensor"]
+
+    torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
+                          args=tuple(),
+                          kwargs=dict(
+                              q=query,
+                              k=key_cache,
+                              v=value_cache,
+                              cu_seqlens_q=cu_query_lens,
+                              cu_seqlens_k=cu_kv_lens,
+                              max_seqlen_q=max_query_len,
+                              max_seqlen_k=max_kv_len,
+                              softmax_scale=scale,
+                              causal=True,
+                              window_size=window_size,
+                              block_table=block_tables,
+                              softcap=soft_cap if soft_cap is not None else 0,
+                          ),
+                          test_utils=test_utils)
+
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
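torch.library.opcheck exercises the registration itself: "test_faketensor" verifies that the fake implementation agrees with the real kernel's output metadata, while "test_schema" verifies the op against its declared schema (for example, the empty mutates_args list). The schema check is only enabled for the smaller num_blocks value, presumably because it is too costly to run against the 32768-block caches. A self-contained sketch with a hypothetical demo::offset op, not part of vLLM:

import torch


@torch.library.custom_op("demo::offset", mutates_args=[])
def offset(x: torch.Tensor, delta: float) -> torch.Tensor:
    return x + delta


@offset.register_fake
def _(x: torch.Tensor, delta: float) -> torch.Tensor:
    return torch.empty_like(x)


# Raises if the fake impl or the declared schema is inconsistent
# with the real implementation.
torch.library.opcheck(torch.ops.demo.offset,
                      args=(torch.randn(16), 0.5),
                      test_utils=["test_schema", "test_faketensor"])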

vllm/attention/backends/flash_attn.py

Lines changed: 129 additions & 26 deletions
@@ -3,7 +3,6 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -18,6 +17,108 @@
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder

+from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
+from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
+
+
+@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
+def flash_attn_varlen_func(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size: Optional[List[int]] = None,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # custom op does not support tuple input
+    real_window_size: Tuple[int, int]
+    if window_size is None:
+        real_window_size = (-1, -1)
+    else:
+        assert len(window_size) == 2
+        real_window_size = (window_size[0], window_size[1])
+    return _flash_attn_varlen_func(
+        q=q,
+        k=k,
+        v=v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size=real_window_size,
+        softcap=softcap,
+        alibi_slopes=alibi_slopes,
+        block_table=block_table,
+    )
+
+
+@flash_attn_varlen_func.register_fake  # type: ignore
+def _(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size: Optional[List[int]] = None,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
+def flash_attn_with_kvcache(
+    decode_query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    cache_seqlens: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    softcap: float = 0.0,
+) -> torch.Tensor:
+    return _flash_attn_with_kvcache(
+        decode_query,
+        key_cache,
+        value_cache,
+        cache_seqlens=cache_seqlens,
+        block_table=block_table,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        alibi_slopes=alibi_slopes,
+        softcap=softcap,
+    )
+
+
+@flash_attn_with_kvcache.register_fake  # type: ignore
+def _(
+    decode_query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    cache_seqlens: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    softcap: float = 0.0,
+) -> torch.Tensor:
+    return torch.empty_like(decode_query)
+

 class FlashAttentionBackend(AttentionBackend):

@@ -517,7 +618,7 @@ def forward(
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = flash_attn_varlen_func(
+                out = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -537,34 +638,36 @@
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:num_prefill_tokens] = flash_attn_varlen_func(
-                    q=query,
-                    k=key_cache,
-                    v=value_cache,
-                    cu_seqlens_q=prefill_meta.query_start_loc,
-                    max_seqlen_q=prefill_meta.max_query_len,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_k=max_seq_len,
+                output[:
+                       num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
+                           q=query,
+                           k=key_cache,
+                           v=value_cache,
+                           cu_seqlens_q=prefill_meta.query_start_loc,
+                           max_seqlen_q=prefill_meta.max_query_len,
+                           cu_seqlens_k=prefill_meta.seq_start_loc,
+                           max_seqlen_k=max_seq_len,
+                           softmax_scale=self.scale,
+                           causal=True,
+                           alibi_slopes=self.alibi_slopes,
+                           block_table=prefill_meta.block_tables,
+                           softcap=self.logits_soft_cap,
+                       )
+
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output[
+                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
+                    decode_query.unsqueeze(1),
+                    key_cache,
+                    value_cache,
+                    block_table=decode_meta.block_tables,
+                    cache_seqlens=decode_meta.seq_lens_tensor,
                     softmax_scale=self.scale,
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
-                    block_table=prefill_meta.block_tables,
                     softcap=self.logits_soft_cap,
-                )
-
-        if decode_meta := attn_metadata.decode_metadata:
-            # Decoding run.
-            output[num_prefill_tokens:] = flash_attn_with_kvcache(
-                decode_query.unsqueeze(1),
-                key_cache,
-                value_cache,
-                block_table=decode_meta.block_tables,
-                cache_seqlens=decode_meta.seq_lens_tensor,
-                softmax_scale=self.scale,
-                causal=True,
-                alibi_slopes=self.alibi_slopes,
-                softcap=self.logits_soft_cap,
-            ).squeeze(1)
+                ).squeeze(1)

         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
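Both register_fake implementations return torch.empty_like of the query tensor, since each kernel's output shares the query's shape and dtype; that metadata is all the compiler needs to trace through the op without launching the real CUDA kernel. Note also that the varlen wrapper takes window_size as an Optional[List[int]] and rebuilds the (left, right) tuple internally because, as the inline comment says, the custom-op schema does not support tuple inputs. A rough illustration of what register_fake buys, using a hypothetical demo::relu_copy op and PyTorch's (private) FakeTensorMode, the mechanism torch.compile uses to propagate shapes:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode


@torch.library.custom_op("demo::relu_copy", mutates_args=[])
def relu_copy(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x)


@relu_copy.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


with FakeTensorMode():
    x = torch.empty(2, 128, 8, 64)    # fake tensor: shape/dtype only, no storage
    y = torch.ops.demo.relu_copy(x)   # dispatches to the registered fake impl
    assert y.shape == x.shape         # shapes propagate without real compute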
