Commit 91a0506

Feature(MInference): add xAttention (#149)
Co-authored-by: Guangxuan Xiao <[email protected]>
1 parent 9d76f96 commit 91a0506

6 files changed, +975 -7 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ supported_kv_types = MInferenceConfig.get_available_kv_types()
 
 Currently, we support the following long-context methods:
 
-- **[① KV Cache Generation]:** [MInference](https://arxiv.org/abs/2407.02490), [FlexPrefill](https://openreview.net/forum?id=OfjIlbelrT), [A-shape](https://arxiv.org/abs/2309.17453), [Tri-shape](https://arxiv.org/abs/2412.10319), [MInference w/ static](https://arxiv.org/abs/2407.02490), [Dilated](https://arxiv.org/abs/2004.05150), [Strided](https://arxiv.org/abs/1904.10509)
+- **[① KV Cache Generation]:** [MInference](https://arxiv.org/abs/2407.02490), [xAttention](https://arxiv.org/abs/2503.16428), [FlexPrefill](https://arxiv.org/abs/2502.20766), [A-shape](https://arxiv.org/abs/2309.17453), [Tri-shape](https://arxiv.org/abs/2412.10319), [MInference w/ static](https://arxiv.org/abs/2407.02490), [Dilated](https://arxiv.org/abs/2004.05150), [Strided](https://arxiv.org/abs/1904.10509)
 - **[② KV Cache Compression]:** [StreamingLLM](https://arxiv.org/abs/2309.17453), [SnapKV](https://arxiv.org/abs/2404.14469), [PyramidKV](https://arxiv.org/abs/2406.02069), [KIVI](https://arxiv.org/abs/2402.02750)
 - **[③ KV Cache Retrieval]:** [CacheBlend](https://arxiv.org/abs/2405.16444)
 - **[④ KV Cache Loading]:** [Quest](https://arxiv.org/abs/2406.10774), [RetrievalAttention](https://arxiv.org/abs/2409.10516)

minference/minference_configuration.py

Lines changed: 3 additions & 4 deletions
@@ -1,8 +1,6 @@
-# Copyright (c) 2024 Microsoft
+# Copyright (c) 2024-2025 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
-import os
-
 from .configs.model2path import MODEL2PATH
 
 
@@ -27,6 +25,7 @@ class MInferenceConfig:
         "inf_llm",
         "flexprefill",
         "vllm_flexprefill",
+        "xattention",
     ]
     KV_TYPES = [
         "dense",
@@ -72,7 +71,7 @@ def __init__(
         self.kv_type = kv_type
         self.attn_kwargs = attn_kwargs
 
-    def update_config_path(self, config_path: str, model_name: str):
+    def update_config_path(self, config_path: str = None, model_name: str = None):
         if self.attn_type in self.OTHER_ATTENTION_TYPES:
             return ""
         if config_path is not None:
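
The signature change in the last hunk gives both parameters a default of `None`, so callers can omit them. The following is a minimal sketch of that effect, not the library's code: `ToyConfig`, the contents of its `OTHER_ATTENTION_TYPES` list, and the fallback error are assumptions made for illustration. Only the early return and the `config_path is not None` check are visible in the diff. The sketch also spells the hints as `Optional[str]`, whereas the commit writes `str = None`.

```python
from typing import Optional


class ToyConfig:
    # Hypothetical stand-in: the diff does not show which list
    # "xattention" was appended to, so this membership is assumed.
    OTHER_ATTENTION_TYPES = ["dense", "flexprefill", "xattention"]

    def __init__(self, attn_type: str):
        self.attn_type = attn_type

    def update_config_path(
        self,
        config_path: Optional[str] = None,
        model_name: Optional[str] = None,
    ) -> str:
        # Types that need no per-model pattern file return early,
        # mirroring the first branch shown in the diff.
        if self.attn_type in self.OTHER_ATTENTION_TYPES:
            return ""
        if config_path is not None:
            return config_path
        # Assumed fallback for this sketch only; the real method's
        # remaining logic is outside the hunk.
        raise ValueError(f"no config path known for model {model_name!r}")


cfg = ToyConfig("xattention")
path = cfg.update_config_path()  # no arguments needed with the new defaults
explicit = ToyConfig("minference").update_config_path("patterns.json", "some-model")
```

With the old signature, the argument-free call would raise a `TypeError` for the two missing positional arguments.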

minference/models_patch.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2024 Microsoft
+# Copyright (c) 2024-2025 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
 import json
@@ -97,7 +97,7 @@ def patch_model(self, model):
             self.config.attn_kwargs.setdefault("n_last", 100)
             model = new_patch(model, self.config)
 
-        elif self.config.attn_type in ["flexprefill", "dense"]:
+        elif self.config.attn_type in ["flexprefill", "dense", "xattention"]:
             model = new_patch(model, self.config)
 
         elif self.config.attn_type == "dilated1":

minference/modules/forward.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 from ..modules.minference_forward import minference_prefill_forward
 from ..modules.quest import quest_decode_kernel
 from ..modules.retr_attn import retr_attn
+from ..modules.xattention import xattention_forward
 from ..ops.streaming_kernel import a_shape_kernel, tri_shape_kernel
 
 
@@ -187,6 +188,7 @@ def attn_forward(
         "tri_shape": tri_shape_kernel,
         "minference": minference_prefill_forward,
         "flexprefill": flexprefill_forward,
+        "xattention": xattention_forward,
     }
 
     decoding_forwards = {
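
The second hunk registers `xattention_forward` in a string-keyed mapping of prefill kernels, so supporting a new attention type amounts to one import and one dictionary entry. A rough sketch of that dispatch pattern follows. The kernels here are placeholders that return a label rather than computing attention, and `PREFILL_FORWARDS` and `run_prefill` are hypothetical names; the diff shows neither the mapping's variable name nor how `attn_forward` looks it up.

```python
from typing import Callable, Dict


def dense_forward(q, k, v):
    # Placeholder kernel: a real one would compute attention over q, k, v.
    return "dense"


def xattention_forward(q, k, v):
    # Placeholder standing in for the function imported in the diff.
    return "xattention"


# Hypothetical name; only some entries of the real mapping appear in the hunk.
PREFILL_FORWARDS: Dict[str, Callable] = {
    "dense": dense_forward,
    "xattention": xattention_forward,
}


def run_prefill(attn_type: str, q, k, v):
    # Look up the kernel by its attention-type string and call it.
    try:
        forward = PREFILL_FORWARDS[attn_type]
    except KeyError:
        raise ValueError(f"unsupported attn_type: {attn_type!r}") from None
    return forward(q, k, v)


result = run_prefill("xattention", None, None, None)
```

Keeping the lookup in one table means an unknown type fails in a single place, which is one reason this pattern is common for pluggable kernels.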
