[PERFORMANCE] Implement Linear KV Cache optimization for inference acceleration #143
This PR implements a linear KV cache optimization for inference that reduces memory complexity from O(N) to O(window_size) and computation complexity from O(N²) to O(N × window_size), where N is the sequence length.
## Problem

During inference with dynamic mask attention, the current implementation retains the full key-value history (O(N) memory) and attends over every cached token at each decoding step (O(N²) computation in total). For long sequences (N >> window_size), this becomes increasingly inefficient and memory-intensive.
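For concreteness, here is a hypothetical sketch of that baseline pattern (illustration only, not the actual flash_dmattn code): the cache grows by one entry per token, so step t attends over t keys.

```python
import torch

head_dim, steps = 64, 512
k_cache = torch.empty(0, head_dim)
v_cache = torch.empty(0, head_dim)

for t in range(steps):
    k_t, v_t, q_t = (torch.randn(1, head_dim) for _ in range(3))
    k_cache = torch.cat([k_cache, k_t])   # O(N) memory after N steps
    v_cache = torch.cat([v_cache, v_t])
    attn = (q_t @ k_cache.T).softmax(-1)  # t scores at step t -> O(N^2) total work
    out = attn @ v_cache                  # weighted sum over the full history
```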
## Solution

The optimization leverages the mathematical insight that attention scores are static during inference.

**Key Mathematical Property:**

- Let `S = f(V)` be the attention scores (static/deterministic)
- Let `M_N = TopK(S_{1:N})` be the selected indices for N tokens
- Then `M_N = TopK(S_{1:N}) = TopK(TopK(S_{1:N-1}), S_N) = TopK(M_{N-1}, S_N)`
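This recurrence is easy to sanity-check numerically. A minimal sketch, using random values as stand-ins for the static per-token scores: maintaining the top-k incrementally from `(M_{N-1}, S_N)` yields the same index set as recomputing TopK over the full history.

```python
import torch

torch.manual_seed(0)
k = 4
scores = torch.randn(100)            # S_{1:N}: static per-token importance scores

# Incremental: keep only the running top-k scores and their original indices.
top_vals = scores[:k].clone()
top_idx = torch.arange(k)
for n in range(k, len(scores)):
    cand_vals = torch.cat([top_vals, scores[n:n + 1]])
    cand_idx = torch.cat([top_idx, torch.tensor([n])])
    keep = torch.topk(cand_vals, k).indices       # TopK(M_{N-1}, S_N)
    top_vals, top_idx = cand_vals[keep], cand_idx[keep]

# Full recomputation: TopK(S_{1:N}).
full_idx = torch.topk(scores, k).indices
assert set(top_idx.tolist()) == set(full_idx.tolist())
```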
**Proof of Optimality:** Because the scores are static, a token that falls out of the TopK set can never re-enter it: later tokens only add competition. It is therefore sufficient to retain only the selected `window_size` tokens instead of the full history, with no loss relative to recomputing TopK over all N tokens.

## Implementation
### Core Components

A fixed-size cache that retains only `keep_window_size` key-value pairs, using importance-based eviction to decide which entries to keep (see the sketch below).

### Usage Example
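A minimal sketch of what such a cache could look like in use; the class name `LinearKVCache`, the `keep_window_size` argument, and the `update` method are illustrative assumptions, not necessarily the exact API exposed by `flash_dmattn/kv_cache_optimizer.py`:

```python
import torch

class LinearKVCache:
    """Hypothetical fixed-size KV cache with importance-based eviction."""

    def __init__(self, keep_window_size: int):
        self.keep_window_size = keep_window_size
        self.keys = self.values = self.scores = None

    def update(self, k, v, score):
        """Insert one token's (k, v, score); evict the least important if full."""
        k, v, score = k[None], v[None], score[None]
        if self.keys is None:
            self.keys, self.values, self.scores = k, v, score
            return
        self.keys = torch.cat([self.keys, k])
        self.values = torch.cat([self.values, v])
        self.scores = torch.cat([self.scores, score])
        if self.scores.numel() > self.keep_window_size:
            # Keep the keep_window_size most important entries; positional
            # order is not preserved in this sketch.
            keep = torch.topk(self.scores, self.keep_window_size).indices
            self.keys, self.values, self.scores = (
                self.keys[keep], self.values[keep], self.scores[keep])

# Drive it for many steps: memory stays bounded at keep_window_size entries.
cache = LinearKVCache(keep_window_size=128)
for _ in range(1000):
    k, v = torch.randn(64), torch.randn(64)
    cache.update(k, v, score=k.norm())  # any static importance score works here
assert cache.keys.shape[0] == 128
```

Because eviction inspects only `keep_window_size + 1` scores per step, each update is O(window_size), which is where the overall O(N × window_size) bound comes from.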
## Performance Results

- **Memory usage reduction:** O(N) → O(window_size) KV cache footprint
- **Speed improvement:** O(N²) → O(N × window_size) total attention computation
- **Key benefits:** bounded memory regardless of sequence length, predictable per-token latency, and lower inference cost for long sequences
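As a rough illustration of what these bounds imply, with hypothetical sequence and window sizes (not measured numbers from this PR):

```python
# Back-of-envelope ratios implied by the stated complexities.
N, window_size = 32_768, 2_048
memory_ratio = N / window_size            # O(N) vs O(window_size) cache entries
compute_ratio = N**2 / (N * window_size)  # O(N^2) vs O(N * window_size) work
print(memory_ratio, compute_ratio)        # 16.0 16.0 for these values
```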
## Files Added

- `flash_dmattn/kv_cache_optimizer.py` - Core optimization implementation
- `flash_dmattn/optimized_inference.py` - Integration with the existing interface
- `docs/linear_kv_cache_optimization.md` - Comprehensive documentation
- `simple_test_kv_cache.py` - Basic functionality tests
- `example_linear_kv_cache.py` - Performance demonstration

The optimization is particularly valuable for production inference scenarios where memory efficiency, predictable performance, and cost optimization are critical considerations.
Fixes #141.