 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
-from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,

 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder

+from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
+from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
+
+
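+# The two kernels are re-registered below as PyTorch custom ops under the
+# "vllm" namespace; callers go through torch.ops.vllm.* and the wrappers
+# simply forward to the imported vllm_flash_attn functions.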
+@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
+def flash_attn_varlen_func(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size: Optional[List[int]] = None,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # custom op does not support tuple input
+    real_window_size: Tuple[int, int]
+    if window_size is None:
+        real_window_size = (-1, -1)
+    else:
+        assert len(window_size) == 2
+        real_window_size = (window_size[0], window_size[1])
+    return _flash_attn_varlen_func(
+        q=q,
+        k=k,
+        v=v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size=real_window_size,
+        softcap=softcap,
+        alibi_slopes=alibi_slopes,
+        block_table=block_table,
+    )
+
+
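+# Fake (meta) implementation for the varlen op: returns an empty tensor with
+# q's shape and dtype so shape inference can run without launching the kernel.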
+@flash_attn_varlen_func.register_fake  # type: ignore
+def _(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size: Optional[List[int]] = None,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.empty_like(q)
+
+
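+# The decode-path kernel (paged KV cache) is likewise exposed as
+# torch.ops.vllm.flash_attn_with_kvcache.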
+@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
+def flash_attn_with_kvcache(
+    decode_query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    cache_seqlens: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    softcap: float = 0.0,
+) -> torch.Tensor:
+    return _flash_attn_with_kvcache(
+        decode_query,
+        key_cache,
+        value_cache,
+        cache_seqlens=cache_seqlens,
+        block_table=block_table,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        alibi_slopes=alibi_slopes,
+        softcap=softcap,
+    )
+
+
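+# Fake (meta) implementation for the decode op: output mirrors decode_query's
+# shape and dtype.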
+@flash_attn_with_kvcache.register_fake  # type: ignore
+def _(
+    decode_query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    cache_seqlens: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    softcap: float = 0.0,
+) -> torch.Tensor:
+    return torch.empty_like(decode_query)
+


 class FlashAttentionBackend(AttentionBackend):
@@ -517,7 +618,7 @@ def forward(
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = flash_attn_varlen_func(
+                out = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -537,34 +638,36 @@ def forward(
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:num_prefill_tokens] = flash_attn_varlen_func(
-                    q=query,
-                    k=key_cache,
-                    v=value_cache,
-                    cu_seqlens_q=prefill_meta.query_start_loc,
-                    max_seqlen_q=prefill_meta.max_query_len,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_k=max_seq_len,
+                output[:
+                       num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
+                           q=query,
+                           k=key_cache,
+                           v=value_cache,
+                           cu_seqlens_q=prefill_meta.query_start_loc,
+                           max_seqlen_q=prefill_meta.max_query_len,
+                           cu_seqlens_k=prefill_meta.seq_start_loc,
+                           max_seqlen_k=max_seq_len,
+                           softmax_scale=self.scale,
+                           causal=True,
+                           alibi_slopes=self.alibi_slopes,
+                           block_table=prefill_meta.block_tables,
+                           softcap=self.logits_soft_cap,
+                       )
+
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output[
+                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
+                    decode_query.unsqueeze(1),
+                    key_cache,
+                    value_cache,
+                    block_table=decode_meta.block_tables,
+                    cache_seqlens=decode_meta.seq_lens_tensor,
                     softmax_scale=self.scale,
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
-                    block_table=prefill_meta.block_tables,
                     softcap=self.logits_soft_cap,
-                )
-
-        if decode_meta := attn_metadata.decode_metadata:
-            # Decoding run.
-            output[num_prefill_tokens:] = flash_attn_with_kvcache(
-                decode_query.unsqueeze(1),
-                key_cache,
-                value_cache,
-                block_table=decode_meta.block_tables,
-                cache_seqlens=decode_meta.seq_lens_tensor,
-                softmax_scale=self.scale,
-                causal=True,
-                alibi_slopes=self.alibi_slopes,
-                softcap=self.logits_soft_cap,
-            ).squeeze(1)
+                ).squeeze(1)

         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)