Merged
2 changes: 1 addition & 1 deletion README.md
@@ -87,7 +87,7 @@ supported_kv_types = MInferenceConfig.get_available_kv_types()

Currently, we support the following long-context methods:

-- **[① KV Cache Generation]:** [MInference](https://arxiv.org/abs/2407.02490), [FlexPrefill](https://openreview.net/forum?id=OfjIlbelrT), [A-shape](https://arxiv.org/abs/2309.17453), [Tri-shape](https://arxiv.org/abs/2412.10319), [MInference w/ static](https://arxiv.org/abs/2407.02490), [Dilated](https://arxiv.org/abs/2004.05150), [Strided](https://arxiv.org/abs/1904.10509)
+- **[① KV Cache Generation]:** [MInference](https://arxiv.org/abs/2407.02490), [xAttention](https://arxiv.org/abs/2503.16428), [FlexPrefill](https://arxiv.org/abs/2502.20766), [A-shape](https://arxiv.org/abs/2309.17453), [Tri-shape](https://arxiv.org/abs/2412.10319), [MInference w/ static](https://arxiv.org/abs/2407.02490), [Dilated](https://arxiv.org/abs/2004.05150), [Strided](https://arxiv.org/abs/1904.10509)
- **[② KV Cache Compression]:** [StreamingLLM](https://arxiv.org/abs/2309.17453), [SnapKV](https://arxiv.org/abs/2404.14469), [PyramidKV](https://arxiv.org/abs/2406.02069), [KIVI](https://arxiv.org/abs/2402.02750)
- **[③ KV Cache Retrieval]:** [CacheBlend](https://arxiv.org/abs/2405.16444)
- **[④ KV Cache Loading]:** [Quest](https://arxiv.org/abs/2406.10774), [RetrievalAttention](https://arxiv.org/abs/2409.10516)
7 changes: 3 additions & 4 deletions minference/minference_configuration.py
@@ -1,8 +1,6 @@
-# Copyright (c) 2024 Microsoft
+# Copyright (c) 2024-2025 Microsoft
# Licensed under The MIT License [see LICENSE for details]

-import os

from .configs.model2path import MODEL2PATH


@@ -27,6 +25,7 @@ class MInferenceConfig:
"inf_llm",
"flexprefill",
"vllm_flexprefill",
+"xattention",
]
KV_TYPES = [
"dense",
@@ -72,7 +71,7 @@ def __init__(
self.kv_type = kv_type
self.attn_kwargs = attn_kwargs

-def update_config_path(self, config_path: str, model_name: str):
+def update_config_path(self, config_path: str = None, model_name: str = None):
if self.attn_type in self.OTHER_ATTENTION_TYPES:
return ""
if config_path is not None:
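
Reviewer note: the new `None` defaults let callers invoke `update_config_path()` with no arguments when the attention type returns early. The sketch below illustrates that behaviour on a stand-in class. `ConfigSketch` is hypothetical, the contents of its `OTHER_ATTENTION_TYPES` list are assumed, and everything after the `config_path is not None` check is a placeholder because the hunk is truncated at that point. It also annotates the parameters as `Optional[str]`, which matches the `None` defaults more precisely than a bare `str`.

```python
from typing import Optional


class ConfigSketch:
    """Hypothetical stand-in for MInferenceConfig; not the library's class."""

    # Assumption: the real list is longer; only two entries are shown here.
    OTHER_ATTENTION_TYPES = ["flexprefill", "xattention"]

    def __init__(self, attn_type: str):
        self.attn_type = attn_type

    def update_config_path(
        self,
        config_path: Optional[str] = None,
        model_name: Optional[str] = None,
    ) -> str:
        # Attention types that need no per-model config return early,
        # so both arguments can be omitted for them.
        if self.attn_type in self.OTHER_ATTENTION_TYPES:
            return ""
        if config_path is not None:
            return config_path
        # Placeholder: the real fallback is not visible in the diff.
        raise ValueError(f"no config path available for model {model_name!r}")


# With the defaults, a no-argument call works for the early-return types.
empty_path = ConfigSketch("xattention").update_config_path()
explicit_path = ConfigSketch("minference").update_config_path("cfg.json", "some-model")
```

Under the old signature, the no-argument call would raise a `TypeError` for the missing positional arguments even though neither value is used on the early-return path.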
4 changes: 2 additions & 2 deletions minference/models_patch.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024 Microsoft
+# Copyright (c) 2024-2025 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import json
@@ -97,7 +97,7 @@ def patch_model(self, model):
self.config.attn_kwargs.setdefault("n_last", 100)
model = new_patch(model, self.config)

-elif self.config.attn_type in ["flexprefill", "dense"]:
+elif self.config.attn_type in ["flexprefill", "dense", "xattention"]:
model = new_patch(model, self.config)

elif self.config.attn_type == "dilated1":
2 changes: 2 additions & 0 deletions minference/modules/forward.py
@@ -13,6 +13,7 @@
from ..modules.minference_forward import minference_prefill_forward
from ..modules.quest import quest_decode_kernel
from ..modules.retr_attn import retr_attn
+from ..modules.xattention import xattention_forward
from ..ops.streaming_kernel import a_shape_kernel, tri_shape_kernel


@@ -187,6 +188,7 @@ def attn_forward(
"tri_shape": tri_shape_kernel,
"minference": minference_prefill_forward,
"flexprefill": flexprefill_forward,
+"xattention": xattention_forward,
}

decoding_forwards = {
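
Reviewer note: the `forward.py` hunk registers the new kernel by adding one entry to a dictionary keyed by attention type. A minimal, self-contained sketch of that dispatch pattern follows. The stub functions, the `PREFILL_FORWARDS` name, and `select_prefill_forward` are all hypothetical; the real kernels take tensors and extra keyword arguments that the diff does not show.

```python
from typing import Callable, Dict


def flexprefill_stub(q, k, v):
    # Stand-in for a real prefill kernel; returns a tag instead of tensors.
    return "flexprefill"


def xattention_stub(q, k, v):
    # Stand-in for the newly registered kernel.
    return "xattention"


# Hypothetical table mirroring the shape of the mapping in the diff.
PREFILL_FORWARDS: Dict[str, Callable] = {
    "flexprefill": flexprefill_stub,
    "xattention": xattention_stub,
}


def select_prefill_forward(attn_type: str) -> Callable:
    """Look up the prefill function for attn_type, failing loudly if unknown."""
    try:
        return PREFILL_FORWARDS[attn_type]
    except KeyError:
        known = ", ".join(sorted(PREFILL_FORWARDS))
        raise ValueError(
            f"unknown attn_type {attn_type!r}; expected one of: {known}"
        ) from None


forward = select_prefill_forward("xattention")
result = forward(None, None, None)
```

With this pattern, supporting a new method touches the table rather than the call site, which is consistent with this PR needing only an import and one new dictionary entry in `forward.py`.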