
Commit 90ec006

Authored by zyongye, LiuXiaoxuanPKU, simon-mo, heheda12345, and WoosukKwon
[gpt-oss] flashinfer attention sink init (vllm-project#22330)
Signed-off-by: simon-mo <[email protected]>
Co-authored-by: LiuXiaoxuanPKU <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: Minseok Lee <[email protected]>
1 parent a47e6ff commit 90ec006

File tree

1 file changed: +10 -0 lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 10 additions & 0 deletions
@@ -611,6 +611,7 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -635,6 +636,15 @@ def __init__(
                 "are not implemented for "
                 "FlashInferImpl")

+        self.sinks: Optional[torch.Tensor] = None
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads "
+                "as the number of heads in the layer"
+            )
+            assert sinks.dtype == torch.float32, "Sinks must be of type float32"
+            self.sinks = sinks
+
     def forward(
         self,
         layer: torch.nn.Module,

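Note that this change only threads the attention-sink tensor through FlashInferImpl's constructor and validates it; nothing in this diff consumes the sinks during the attention computation itself. As a minimal standalone sketch of what the new checks accept (the num_heads value and the zero initialization below are assumptions for illustration, not part of the commit):

    import torch

    num_heads = 64  # assumed example value; in vLLM this comes from the model config

    # One float32 sink value per attention head, matching the shape and dtype
    # that the new __init__ checks require.
    sinks = torch.zeros(num_heads, dtype=torch.float32)

    # Same validation pattern as the code added in this commit:
    assert sinks.shape[0] == num_heads, (
        "Sinks must have the same number of heads "
        "as the number of heads in the layer"
    )
    assert sinks.dtype == torch.float32, "Sinks must be of type float32"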