Commit ae5de6a

Adding JIT cache (#5)
The launch overhead of triton kernels is a well-known problem (see e.g. [1](triton-lang/triton#3503), [2](triton-lang/triton#2637), [3](triton-lang/triton#6064)). Part of the launch overhead comes from the fact that the triton JIT checks very carefully whether an existing binary is safe to use.

In many scenarios, these checks can be relaxed. This PR adds such a cache with relaxed checks, implemented by `triton_dejavu.jitcache`. It is a decorator that is placed in front of the `triton.jit` decorator:

```
@triton_dejavu.jitcache(
    check_keys=["x", "BLOCK_SIZE", "USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
)
@triton.jit
def kernel_paged_attention_....
```

For details, see the README.

---------

Signed-off-by: Burkhard Ringlein <[email protected]>
1 parent 2f7c278 commit ae5de6a

3 files changed: +396 -5 lines changed

README.md

Lines changed: 30 additions & 4 deletions
@@ -3,11 +3,13 @@ Triton Deja-vu
33
Framework to reduce autotune overhead of [triton-lang](https://github.com/triton-lang/triton) to zero for well known deployments.
44

55
This small framework is based on the [Triton autotuner](https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py) and contributes three features to the Triton community:
6-
1. Store and safely restore autotuner states using JSON files.
6+
1. Store and safely restore autotuner states using JSON files.
77
2. `ConfigSpaces` to explore a defined space exhaustively.
88
3. Bayesian Optimization to speed up the autotuning process.
99

10-
Additionally, it allows to use heuristics in combination with the autotuner. Please find more details in the [feature section below](#features).
10+
Additionally, it allows to use heuristics in combination with the autotuner. Please find more details in the [feature section below](#features).
11+
12+
Besides improvements for autotuning, it also contains useful tools for working with triton, specifically a [cache for JIT-artifacts](#jitcache).
1113

1214

1315
Installation
@@ -31,7 +33,7 @@ import triton_dejavu
3133
@triton_dejavu.autotune(
3234
...
3335
```
34-
Second, the environment variable `TRITON_DEJAVU_STORAGE` needs to be set and point to a read and writable directory.
36+
Second, the environment variable `TRITON_DEJAVU_STORAGE` needs to be set and point to a read and writable directory.
3537

3638

3739
To use the `ConfigSpaces` feature, replace the `config` parameter for the triton_dejavu autotuner with `config_space` definition:
@@ -90,7 +92,7 @@ So far, we think that the above listed combination determines the applicability
9092

9193
In addition, users can define a tag to be used by the dejavu storage to be able to differentiate different deployment scenarios (for otherwise identical value combinations).
9294

93-
Please note, the above list does not include features that do not influence the decision of the autotuner, but influence the behaviour of the kernel or the JIT. For example, the precense or details of `pre_hook` or `post_hook` and also other [`specialization_data`](https://github.com/triton-lang/triton/blob/e87f877eb94efeaeb4ad8697f315932121dec5e0/python/triton/runtime/jit.py#L514) used by the JIT cache are not used by triton-dejavu.
95+
Please note, the above list does not include features that do not influence the decision of the autotuner, but influence the behaviour of the kernel or the JIT. For example, the presence or details of `pre_hook` or `post_hook` and also other [`specialization_data`](https://github.com/triton-lang/triton/blob/e87f877eb94efeaeb4ad8697f315932121dec5e0/python/triton/runtime/jit.py#L514) used by the JIT cache are not used by triton-dejavu.
9496

9597

9698
#### Example
@@ -186,6 +188,30 @@ pip install "triton-dejavu[BO] @ file:./triton-dejavu"
186188
Please note that smac depends on [swig](https://www.swig.org), which need to be installed first.
187189

188190

191+
### JITCache
192+
193+
The launch overhead of triton kernels is a well-known problem (see e.g. [1](https://github.com/triton-lang/triton/pull/3503), [2](https://github.com/triton-lang/triton/issues/2637), [3](https://github.com/triton-lang/triton/issues/6064)). Part of the launch overhead comes from the fact that the triton JIT checks very carefully whether an existing binary is safe to use.
194+
195+
In many scenarios, these checks can be relaxed. Such a cache with relaxed checks is implemented by `triton_dejavu.jitcache`. It is a decorator that is placed in front of the `triton.jit` decorator:
196+
197+
```
198+
@triton_dejavu.jitcache(
199+
check_keys=["x", "BLOCK_SIZE", "USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
200+
)
201+
@triton.jit
202+
def kernel_paged_attention_...
203+
```
204+
205+
The required `check_keys` argument must provide a list of the kernel parameters marked as `tl.constexpr` that **must be checked** to select the correct kernel binary. Ideally, this is just a subset of all constant kernel parameters.
206+
For example, if a kernel has two constant parameters A and B, and we know that A will never change in a particular application while B will, then only B needs to be checked and the list should look like `check_keys=["B"]`.
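
The role of `check_keys` can be sketched in plain Python. This is only an illustration of the idea, not the code in `triton_dejavu/jit_cache.py`: the names `relaxed_cache` and `fake_compile` are hypothetical, and the "binary" is just a string standing in for a compiled kernel. The sketch assumes the cache looks up binaries using only the values of the listed parameters.

```python
import functools


def relaxed_cache(check_keys):
    """Hypothetical stand-in: reuse a compiled artifact whenever the values
    of the parameters named in `check_keys` match, ignoring everything else."""

    def decorator(compile_fn):
        binaries = {}

        @functools.wraps(compile_fn)
        def wrapper(**kwargs):
            # Only the listed parameters take part in the lookup key.
            key = tuple(kwargs[name] for name in check_keys)
            if key not in binaries:
                binaries[key] = compile_fn(**kwargs)
            return binaries[key]

        wrapper.binaries = binaries
        return wrapper

    return decorator


compile_calls = []


@relaxed_cache(check_keys=["BLOCK_SIZE"])
def fake_compile(HEAD_DIM, BLOCK_SIZE):
    # Placeholder for an expensive JIT compilation step.
    compile_calls.append((HEAD_DIM, BLOCK_SIZE))
    return f"binary(HEAD_DIM={HEAD_DIM}, BLOCK_SIZE={BLOCK_SIZE})"


first = fake_compile(HEAD_DIM=64, BLOCK_SIZE=16)   # miss: compiles
second = fake_compile(HEAD_DIM=64, BLOCK_SIZE=16)  # hit: no compilation
third = fake_compile(HEAD_DIM=64, BLOCK_SIZE=32)   # miss: checked key changed
# An unchecked parameter changed, so the binary built for HEAD_DIM=64 is reused.
stale = fake_compile(HEAD_DIM=128, BLOCK_SIZE=16)
```

The last call shows why the choice of keys depends on the application: a parameter left out of `check_keys` must really be constant, otherwise a binary compiled for a different value is silently reused.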
207+
208+
Consequently, *the usage of `triton_dejavu.jitcache` is application specific* (and also *experimental*).
209+
210+
Additionally, a user can provide a lock, e.g. `cache_lock=triton_dejavu.global_cache_lock`, to ensure that no re-compilation happens once the cache lock is locked.
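
A minimal sketch of how such a lock could interact with a cache, in plain Python. The classes `SketchCacheLock` and `LockableCache` are hypothetical and do not mirror the actual `triton_dejavu` API. Raising an error on a miss while locked is an assumption made for this example; the real behaviour may differ.

```python
class SketchCacheLock:
    """Hypothetical lock object: a flag shared between caches."""

    def __init__(self):
        self.is_locked = False

    def lock(self):
        self.is_locked = True

    def unlock(self):
        self.is_locked = False


class LockableCache:
    """Hypothetical cache that refuses to compile once its lock is locked."""

    def __init__(self, compile_fn, cache_lock):
        self._compile_fn = compile_fn
        self._lock = cache_lock
        self._binaries = {}

    def get(self, key):
        if key in self._binaries:
            return self._binaries[key]
        if self._lock.is_locked:
            # Assumption of this sketch: a miss after locking is an error.
            raise RuntimeError(f"cache is locked, no binary for key {key!r}")
        self._binaries[key] = self._compile_fn(key)
        return self._binaries[key]


sketch_lock = SketchCacheLock()
cache = LockableCache(lambda key: f"binary{key}", sketch_lock)

cache.get((16,))    # warm-up phase: compilation is still allowed
sketch_lock.lock()  # from here on, only cache hits are served
```

In this sketch, all kernel variants that the application needs would have to be compiled during a warm-up phase before the lock is set.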
211+
212+
`triton_dejavu.jitcache` reduces the launch overhead of triton kernels to 30-40 microseconds.
213+
214+
189215
Compatibility
190216
------------------
191217

triton_dejavu/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,8 +15,9 @@
1515
# *******************************************************************************/
1616
#
1717

18-
__version__ = "0.7.2"
18+
__version__ = "0.7.3"
1919

2020

2121
from .dejavu_storage import global_dejavu_storage
2222
from .autotuner import autotune, ConfigSpace
23+
from .jit_cache import jitcache, global_cache_lock
