The launch overhead of Triton kernels is a well-known problem (see e.g. [1](triton-lang/triton#3503), [2](triton-lang/triton#2637), [3](triton-lang/triton#6064)). Part of this overhead comes from the fact that the Triton JIT checks very carefully whether an existing binary is safe to use.
In many scenarios, these checks can be relaxed.
This PR adds a cache with relaxed checks, implemented by `triton_dejavu.jitcache`. It is a decorator that is placed in front of the `triton.jit` decorator:
```
@triton_dejavu.jitcache(
    check_keys=["x", "BLOCK_SIZE", "USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
)
@triton.jit
def kernel_paged_attention_....
```
See the README for details.
---------
Signed-off-by: Burkhard Ringlein <[email protected]>
README.md: 30 additions & 4 deletions
@@ -3,11 +3,13 @@ Triton Deja-vu
 Framework to reduce autotune overhead of [triton-lang](https://github.com/triton-lang/triton) to zero for well known deployments.

 This small framework is based on the [Triton autotuner](https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py) and contributes three features to the Triton community:
-1. Store and safely restore autotuner states using JSON files.
+1. Store and safely restore autotuner states using JSON files.
 2. `ConfigSpaces` to explore a defined space exhaustively.
 3. Bayesian Optimization to speed up the autotuning process.

-Additionally, it allows to use heuristics in combination with the autotuner. Please find more details in the [feature section below](#features).
+Additionally, it allows to use heuristics in combination with the autotuner. Please find more details in the [feature section below](#features).
+
+Besides improvements for autotuning, it also contains useful tools for working with Triton, specifically a [cache for JIT-artifacts](#jitcache).


 Installation
@@ -31,7 +33,7 @@ import triton_dejavu
 @triton_dejavu.autotune(
 ...
 ```
-Second, the environment variable `TRITON_DEJAVU_STORAGE` needs to be set and point to a read and writable directory.
+Second, the environment variable `TRITON_DEJAVU_STORAGE` needs to be set and point to a read and writable directory.


 To use the `ConfigSpaces` feature, replace the `config` parameter for the triton_dejavu autotuner with `config_space` definition:
@@ -90,7 +92,7 @@ So far, we think that the above listed combination determines the applicability

 In addition, users can define a tag to be used by the dejavu storage to be able to differentiate different deployment scenarios (for otherwise identical value combinations).

-Please note, the above list does not include features that do not influence the decision of the autotuner, but influence the behaviour of the kernel or the JIT. For example, the precense or details of `pre_hook` or `post_hook` and also other [`specialization_data`](https://github.com/triton-lang/triton/blob/e87f877eb94efeaeb4ad8697f315932121dec5e0/python/triton/runtime/jit.py#L514) used by the JIT cache are not used by triton-dejavu.
+Please note, the above list does not include features that do not influence the decision of the autotuner, but influence the behaviour of the kernel or the JIT. For example, the presence or details of `pre_hook` or `post_hook` and also other [`specialization_data`](https://github.com/triton-lang/triton/blob/e87f877eb94efeaeb4ad8697f315932121dec5e0/python/triton/runtime/jit.py#L514) used by the JIT cache are not used by triton-dejavu.
 Please note that smac depends on [swig](https://www.swig.org), which need to be installed first.


+### JITCache
+
+The launch overhead of Triton kernels is a well-known problem (see e.g. [1](https://github.com/triton-lang/triton/pull/3503), [2](https://github.com/triton-lang/triton/issues/2637), [3](https://github.com/triton-lang/triton/issues/6064)). Part of the launch overhead comes from the fact that the Triton JIT checks very carefully whether an existing binary is safe to use.
+
+In many scenarios, these checks can be relaxed. Such a cache with relaxed checks is implemented by `triton_dejavu.jitcache`. It is implemented as a decorator that can be placed in front of the `triton.jit` decorator:
+The required `check_keys` argument must provide a list of the kernel parameters marked as `tl.constexpr` that **must be checked** to select the correct kernel binary. Ideally, this is just a subset of all constant kernel parameters.
+For example, if we have two constant parameters A and B, and we know that A will never change in a particular application but B will, then only B has to be checked and the list should look like `check_keys=["B"]`.
+
+Consequently, *the usage of `triton_dejavu.jitcache` is application specific* (and also *experimental*).
+
+Additionally, a user can provide a lock, e.g. `cache_lock=triton_dejavu.global_cache_lock`, to ensure that no re-compilation happens after the cache lock is locked.
+
+The `triton_dejavu.jitcache` reduces the launch overhead of Triton kernels to 30-40 microseconds.
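
The idea behind `check_keys` and `cache_lock` can be illustrated with a small pure-Python sketch. This is not the implementation in `triton_dejavu`, and it does not use Triton at all: `relaxed_cache`, `CacheLock`, `fake_compile`, and the parameter `HEAD_SIZE` are hypothetical names chosen for illustration. The sketch assumes that a cached artifact is selected by looking only at the arguments listed in `check_keys`, and that a locked cache refuses to compile anything new.

```python
class CacheLock:
    """Hypothetical stand-in for a cache lock (not triton_dejavu's class)."""

    def __init__(self):
        self.is_locked = False

    def lock(self):
        self.is_locked = True


def relaxed_cache(check_keys, cache_lock=None):
    """Cache results of `compile_fn`, looking at the `check_keys` arguments only."""

    def decorator(compile_fn):
        cache = {}

        def lookup(**kwargs):
            # Only the listed arguments form the key; all other arguments are
            # assumed to stay constant for the lifetime of the application.
            key = tuple(kwargs[name] for name in check_keys)
            if key not in cache:
                if cache_lock is not None and cache_lock.is_locked:
                    raise RuntimeError(f"cache is locked, no entry for {key}")
                cache[key] = compile_fn(**kwargs)
            return cache[key]

        return lookup

    return decorator


compile_log = []  # records every (simulated) compilation
toy_lock = CacheLock()


@relaxed_cache(check_keys=["BLOCK_SIZE"], cache_lock=toy_lock)
def fake_compile(HEAD_SIZE, BLOCK_SIZE):
    # Stand-in for an expensive JIT compilation.
    compile_log.append((HEAD_SIZE, BLOCK_SIZE))
    return f"binary<HEAD_SIZE={HEAD_SIZE}, BLOCK_SIZE={BLOCK_SIZE}>"


fake_compile(HEAD_SIZE=128, BLOCK_SIZE=16)  # cache miss, compiles
fake_compile(HEAD_SIZE=128, BLOCK_SIZE=16)  # cache hit, no compilation
fake_compile(HEAD_SIZE=128, BLOCK_SIZE=32)  # new BLOCK_SIZE, compiles
toy_lock.lock()  # from here on, a cache miss raises instead of compiling
```

In this sketch, a call with a different `HEAD_SIZE` but a known `BLOCK_SIZE` would silently return the stale artifact, because `HEAD_SIZE` is not part of the key. That trade-off is why such a cache is only safe when the unchecked parameters really are constant in the application.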