[FRONTEND][RFC] Low latency kernel launching #3503
Conversation
First of all, thank you for the high-quality contribution and thorough, well-written PR description. If only every first-time PR was this good. On the OAI side, we're very much aware of this issue (we were just talking about it yesterday) and we considered doing the C++ approach that you did. Our guess was that C++ couldn't be that much faster than Python, because the C++ still has to call into the Python runtime a lot. Your experiments seem to validate this. 75us is a lot faster than 250us, but it's still very slow. The way we've solved the problem of slow kernel launches on our side is by precompiling the kernels (there's an API) and then launching them directly from Pytorch or whatever, so you don't exercise any of the slow Triton stuff. I feel like this is the only way one is going to get kernel launches that are "fast enough". It's @ptillet's final call, but my feeling is that going from 250us to 75us isn't worth the complexity, because if you care about 250us, you probably still are unhappy at 75us, so either way you probably want to precompile.
Hi @jlebar, thanks for the quick feedback! I also feel the same -- maintaining some C++ code just to preprocess args can be difficult over time. But on the bright side, with this change, you probably can just launch kernels directly. This PR is mostly optimizing for small inputs, like inferencing with a kv-cache (well, without continuous batching or other techniques to optimize GPU utilization), and for a more general audience that is definitely outside of OAI : ) A few microseconds is not that much, but if a model consists of a lot of kernels, then it does make a difference, although I don't have any data to back my claim. For the speed, please note I benchmarked pytorch too. I didn't have too much time to run more rigorous benchmarks, but I'll definitely run more benchmarks if the conclusion is to use this PR or similar methods.
Ah I see, I should be subtracting the pytorch time then. To verify this I would be interested in launching some kernels without pytorch in the way, simply by invoking a compiled kernel directly. I am confused by, and a little suspicious of, the data showing that long_params_optimized is consistently faster than short_params_optimized. Do you have a hypothesis that can explain this?
So do you think it's worth the extra complexity? I'll run more benchmarks over the weekend.
I was debugging for a while but couldn't figure it out. Then I just found I didn't set grad to zero for long_params_optimized. Short and long are 77us and 84us. Updated data and code.

OK, I just ran some more benchmarks. Please see the updated comment for details. In summary, GPU time is 20us, and CPU time is 31us before and 8us after optimization. And that 8us can be completely hidden.

I think the results you show are really promising. Honestly, they almost seem too good to be true; maybe something fundamental is missing. I'll send this to Phil to have a look.

If you do run more benchmarks, trying something with 50x more arguments of type tensor might be interesting. So like 64 tensor arguments, instead of the 1 or 2 tensor args that the kernels currently have (if I'm reading right).
I spoke with Phil. We think this is promising and we're in favor of making this change in principle, so long as (a) the full-featured C++ isn't overly complicated, (b) we can delete the relevant Python code so we don't have to maintain two copies of the logic, and (c) the performance gains really are as large as they look. That said, this is going to be an invasive and high-risk change to OAI, so we'll have to do a lot of testing on our end, and we don't have bandwidth to do that immediately. I'll see if Meta has time to review and test this patch before we get to it, but to set expectations, we'll still need to test this carefully ourselves, and so it might be a while before we can merge it into mainline, maybe multiple months. (While we could keep two copies of this logic in tree while OAI tests the new fast version, that opens up a whole different can of worms, where we have to make all changes to the JIT twice, and test both versions, etc etc, so we're not in favor of that either.) I totally understand if you want to abandon this change, because this isn't a smooth path. Another option would be to create a separate project.
cc @peterbell10
This is very possible. I make a lot of dumb mistakes. But the general idea is valid: interact with python only if needed.
There are roughly three sources of overhead in parsing params, and to make it fast, we need different treatment for each kind of overhead. And if we could launch kernels directly from C++, it could be even faster.
Totally agreed.
I expect this to be a decent-sized refactoring, so it can probably be merged by the end of this year. It will take me some time to get it ready for review first : )
Thanks for your quick response! I know this refactoring is not a priority for most industry users, so I'll do my best to make it easy to test and review.
@liboyue Excellent work! In the part where you assemble and return a complicated Python object, could you instead serialise all of this information into a single object?
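The referenced code block did not survive here, but the suggestion is roughly the following sketch; `make_launch_key` and its inputs are hypothetical names for illustration, not Triton APIs.

```
# Hypothetical sketch: flatten the per-launch metadata (argument dtypes,
# constexpr values, specialization bits) into one string so the whole key is
# built and hashed in a single step instead of hashing a nested structure.
def make_launch_key(arg_dtypes, constexprs, specializations):
    return "|".join([
        ",".join(str(d) for d in arg_dtypes),
        ",".join(str(c) for c in constexprs),
        ",".join("1" if s else "0" for s in specializations),
    ])

key = make_launch_key(["fp32", "fp32", "i32"], [1024], [True, True, False])
print(key)  # fp32,fp32,i32|1024|1,1,0
```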
@apgoucher Thanks for the suggestion! I tried, but it turned out Python's hashing is not a bottleneck at all, and Python can be really fast. So I abandoned the C++ function. Please read the updated comment.
@jlebar please see the updated comment. I think the overhead can be significantly reduced by generating a few Python functions at `jit` time.
Oh excellent; in that case this should be a rather minor code change to JITFunction (no FFI, for example!) to avoid the expensive `signature.bind()` call.
This is really cool, and we'd love to have something like it at Meta! I have two questions/comments:
I was able to benchmark this approach using https://gist.github.com/bertmaher/e8869ebf5297dfc77e26d51037d21f80 and triton 2.1.0, and I'm seeing 34us launch latency. That is great!
Hmmm, hold up a second. I did a straight translation of the C++ back to Python, and I'm also getting 34us:
Which actually lines up with @jlebar's intuition that the C++ is mostly assembling Python objects anyway. I think what's really going on here is that the key computation that's in main right now does a lot more "stuff": it handles kwargs/default arguments using `inspect`-based signature binding, for example. I do still think there's something to this approach, but it probably needs to have feature parity with the Python implementation to really evaluate it. Something I'd be curious about is to actually put the cache itself in C++; that way we could assemble a key entirely with C++ types, which would probably save a good bit of pyobject creation overhead.
Hi @bertmaher, thanks for the comment. Please read the updated comment and use the latest code. We reached the same conclusion.
@bertmaher I'd be interested to see how much this change alone is able to reduce the launch overhead (hopefully with your automated profiling tools it shouldn't be too much effort to determine this?). It looks like the main commonality between the two pure-Python implementations that you and @liboyue have written is that they both avoid the expensive `signature.bind()` call.
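Aside: a self-contained micro-benchmark of the point being made here; the kernel-like signature and parameter names below are made up for illustration, not taken from Triton.

```
import inspect
import timeit

# A stand-in for a kernel's Python-level parameter list.
def fake_kernel(x, y, out, n_elements, BLOCK_SIZE=1024, num_warps=4):
    pass

sig = inspect.signature(fake_kernel)
args = (1, 2, 3, 4)

def with_signature_bind():
    bound = sig.bind(*args)
    bound.apply_defaults()
    return bound.arguments

# A "generated" binder: Python itself handles positional args and defaults.
def generated_binder(x, y, out, n_elements, BLOCK_SIZE=1024, num_warps=4):
    return (x, y, out, n_elements, BLOCK_SIZE, num_warps)

n = 100_000
print("signature.bind():", timeit.timeit(with_signature_bind, number=n))
print("generated binder:", timeit.timeit(lambda: generated_binder(*args), number=n))
```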
This improves kernel launch latency by 2.2x (from 108us to 49us using @bertmaher's benchmarking script in issue #3619 ). Thanks also to @liboyue's analysis and suggestions. See the discussion in the third-party PR #3503 (comment)
Hi @apgoucher, I tested your optimizations. Results are great. One nit: there's a line that doesn't seem to be needed.
Hi @jlebar @apgoucher, please see the updated comment. I manually "generated" the `run()` functions as described there.
The overheads of #3649 on my machine are 13us (short) and 16us (long). Thanks for the discussion and quick fix. I'm closing this PR since @apgoucher is already working on it. Feel free to keep discussing in this PR and take any ideas that may help. If this PR ever helped, please link it in the release note if possible : )
One more note: the actual binary wrapper's launch time seems to have regressed a bit, from 3.2us to 3.5us. Not sure whether it's my testing code or `triton` itself.
@liboyue Could you please share the contents of your manually generated `run()` function? In any case, I managed to shave off another 3 microseconds (from 15us to 12us) in PR #3660 by moving an import statement out of hot code and only checking the hashtable once instead of twice (with a single dict lookup).
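For illustration only (not the actual diff in PR #3660, and `compile_kernel` below is a made-up placeholder): "checking the hashtable once instead of twice" is essentially the difference between an `in` test followed by an index, and a single `dict.get` on the hot path.

```
cache = {}

def compile_kernel(key):
    # stand-in for the expensive compile-and-cache step
    return f"binary-for-{key}"

# Two hash lookups per hit: one for the membership test, one for the index.
def get_kernel_twice(key):
    if key in cache:
        return cache[key]
    cache[key] = compile_kernel(key)
    return cache[key]

# One hash lookup per hit: .get() probes the table once; compile only on a miss.
def get_kernel_once(key):
    kernel = cache.get(key)
    if kernel is None:
        kernel = cache[key] = compile_kernel(key)
    return kernel
```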
Oh, I see now: your function takes various shortcuts, such as assuming that the types of the arguments are known beforehand (but that's not necessarily true -- for example, it's quite common for someone to replace a tensor argument with `None`). Other than that, it looks like the main advantages of your code are avoiding various function-call indirection and dynamically generating the body of `run()`.
It absolutely did! @ptillet can you please ensure that @liboyue and @bertmaher are properly credited? Thanks!
I included it here because the |
We can definitely know if one argument is a tensor by looking into the kernel's source code, but I agree this can be tricky. Please see the updated code. This change costs 1us; now the short kernels' overhead is around 5.3us.
The function I tested only has 1 constexpr and 2 tensors.
The main advantage is that it avoids parsing args and avoids creating unnecessary objects: the generated signature of `run()` spells out the kernel's parameters explicitly (see section 4.2 of the description below, and the sketch that follows).
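A sketch of the easy half of that claim, assuming the `JITFunction` keeps the wrapped Python function on a `.fn` attribute (as current Triton versions do): constexpr parameters can be told apart from runtime arguments by inspecting the kernel's signature. Knowing which runtime arguments are actually tensors would need a deeper look at the kernel body, which is the tricky part.

```
import inspect
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pass  # body irrelevant for signature inspection

def classify_params(jit_fn):
    # Split a @triton.jit kernel's parameters into constexprs and runtime args.
    constexprs, runtime_args = [], []
    for name, param in inspect.signature(jit_fn.fn).parameters.items():
        # the annotation may be tl.constexpr itself or the string "tl.constexpr"
        if "constexpr" in str(param.annotation):
            constexprs.append(name)
        else:
            runtime_args.append(name)
    return constexprs, runtime_args

print(classify_params(add_kernel))
# (['BLOCK_SIZE'], ['x_ptr', 'y_ptr', 'out_ptr', 'n_elements'])
```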
The figures are automatically generated every night using the documentation workflow |
The launch overhead of triton kernels is a well known problem (see e.g. [1](triton-lang/triton#3503), [2](triton-lang/triton#2637), [3](triton-lang/triton#6064)). Part of the launch overhead comes from the fact that the triton JIT checks very carefully whether an existing binary is safe to use. In many scenarios, these checks can be relaxed. This PR adds such a cache with relaxed checks, implemented by `triton_dejavu.jitcache`. It is implemented as a decorator that can be used in front of the `triton.jit` decorator:

```
@triton_dejavu.jitcache(
    check_keys=["x", "BLOCK_SIZE", "USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
)
@triton.jit
def kernel_paged_attention_....
```

See the Readme for details.

---------

Signed-off-by: Burkhard Ringlein <[email protected]>
1. Summary
This PR is a basic POC of how to reduce triton kernel launching overhead. Previous description can be found in https://github.com/liboyue/triton/tree/low-latency-jit-function/tmp/README.md
Experiments are based on the `triton` 2 that was shipped with `PyTorch` 2.2. Results show that it is possible to speed up `triton` kernel launching by around 10x (40us to 4.2us for kernels with a few params, or 62us to 6us for kernels with many params). The optimization has already reached a point where I feel that keeping working on it won't bring much more improvement, so I'll stop here and let the OpenAI team decide whether to incorporate those changes or not. And since similar changes are already merged into `triton`'s `main` branch, I'll close this PR.

2. Sources of launching overhead
To be fair, the triton runtime does not incur very bad overhead. However, when inputs are small, like inferencing with a kv-cache, a kernel can finish before the next kernel is launched. What's worse, the overhead is proportional to the number of arguments a kernel has (see benchmarks below).
The major source of overhead is the `JITFunction.run` method. It creates a lot of Python containers and has some very expensive function calls:

https://github.com/openai/triton/blob/f5722cb9d8a8ce6211c77b85c340391e3e9c78e0/python/triton/runtime/jit.py#L401-L402
https://github.com/openai/triton/blob/f5722cb9d8a8ce6211c77b85c340391e3e9c78e0/python/triton/runtime/jit.py#L416-L419
3. Benchmarks
The benchmarks measure run time for 3 `JITFunction.run` implementations:
- `triton`'s implementation.
- An optimized `JITFunction.run`.
- An optimized `JITFunction.run` assuming there is a type hint system.

Two additional functions are also measured. Two running modes are measured; one of them launches with `kernel[grid](..., warmup=True)`, which measures the Python overhead in the `JITFunction.run()` function.

Environment:
The figures show run time for different input lengths.
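For reference, a minimal sketch of how the `warmup=True` mode can be timed (the trivial `add_kernel` here is a placeholder, and the exact keyword arguments accepted by `run()` vary between Triton versions):

```
import time
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.randn_like(x)
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)

# The first call compiles; later warmup=True calls exercise only the Python-side
# argument processing in JITFunction.run, without launching anything on the GPU.
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024, warmup=True)

n = 1000
t0 = time.perf_counter()
for _ in range(n):
    add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024, warmup=True)
t1 = time.perf_counter()
print(f"Python launch overhead: {(t1 - t0) / n * 1e6:.1f} us per call")
```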
3.1. Triton 2.0
Kernel launching overhead in us (data from 11 runs)
3.2. Triton 3.0
I only measured the code before and after @apgoucher's PR. The improvements are really good.
Kernel launching overhead in us (data from 15 runs)
4. Proposed solutions
I believe the kernel launching overhead can be further reduced with the following simple optimizations.
4.1. Stronger assumptions on devices
It is very expensive to figure out what the `device_type` should be. For example, I guess it is ok for `triton` to assume no one will have NVIDIA and AMD GPUs on the same machine. Then, the `device_type` can be cached at `triton`'s initialization time: if there is an NVIDIA GPU then `"cuda"`, else ... (idk which types are supported). Although this is not a future-proof solution, I believe it is reasonable to make some strong assumptions for now.

4.2. Dynamically generate `run()` and `runner()`
It is very expensive to call `signature.bind()` and pack call args. Generating these functions at `jit` time can eliminate these expensive calls, which can save a good amount of time. Please see https://github.com/liboyue/triton/blob/low-latency-jit-function/tmp/triton_21/low_latency_jit_python_generated.py for a manually "generated" example.

For example, define a kernel as follows.
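The original example code blocks are not reproduced here; below is a minimal sketch of the idea, with a hypothetical `add_kernel` and deliberately simplified keys (the real signature and specialization keys contain more information than shown).

```
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

# A run() generated specifically for add_kernel's parameter list: Python itself
# binds positional/keyword args and applies defaults, so no signature.bind() is
# needed, and the cache keys are spelled out as plain tuples of the arguments.
def generated_run(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE, *, grid, num_warps=4, warmup=False):
    sig_key = (x_ptr.dtype, y_ptr.dtype, out_ptr.dtype, type(n_elements))
    constexpr_key = (BLOCK_SIZE,)
    spec_key = (
        x_ptr.data_ptr() % 16 == 0,
        y_ptr.data_ptr() % 16 == 0,
        out_ptr.data_ptr() % 16 == 0,
        n_elements % 16 == 0,
    )
    key = (sig_key, constexpr_key, spec_key, num_warps)
    # look up (or compile) the binary for this key, then launch it on `grid`;
    # the real implementation would call the cached kernel's C wrapper here.
    ...
```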
The generated `run()` function's signature can simply mirror the kernel's parameter list (type hints are useless here so omitted). In this way, Python parses params and sets default values, which is much faster than `signature.bind()`. Furthermore, `sig_key`, `constexpr_key`, `spec_key`, etc. can all be written explicitly as tuples of `run()`'s arguments. `c_wrapper`'s args can also be "hard-coded" in the same way.

4.3. Improving type hints

With improved type hints, kernel definitions are more informative, so that the generated `run()` functions can rely less on the Python runtime and even perform type checks. This reduces overhead and can provide some more safety if runtime type checking is enforced. Please see https://github.com/liboyue/triton/blob/low-latency-jit-function/tmp/triton_21/low_latency_jit_python_generated_with_type.py for a manually "generated" example (see the previous subsection for the kernel example).

5. `triton` launching regression

I noticed the binary wrapper overhead is ~3.2us for `triton` 2, but this value regressed a bit to ~3.5us for the `main` branch.