
[FRONTEND][RFC] Low latency kernel launching #3503


Closed
liboyue wants to merge 5 commits

Conversation


@liboyue liboyue commented Mar 29, 2024

1. Summary

This PR is a basic proof of concept of how to reduce Triton kernel launch overhead. The previous description can be found at https://github.com/liboyue/triton/tree/low-latency-jit-function/tmp/README.md

Experiments are based on the Triton 2 release shipped with PyTorch 2.2. Results show that it is possible to speed up Triton kernel launching by around 10x (40us to 4.2us for kernels with a few params, or 62us to 6us for kernels with many params). The optimization has reached a point where I feel that further work won't bring much improvement, so I'll stop here and let the OpenAI team decide whether to incorporate these changes. And since similar changes have already been merged into the Triton main branch, I'll close this PR.

2. Sources of launching overhead

To be fair, the Triton runtime does not incur a very bad overhead. However, when inputs are small, as when inferencing with a kv-cache, a kernel can finish before the next kernel is launched. What's worse, the overhead is proportional to the number of arguments a kernel has (see benchmarks below).

The major source of overhead is the JITFunction.run method. It creates a lot of Python containers and makes some very expensive function calls:
https://github.com/openai/triton/blob/f5722cb9d8a8ce6211c77b85c340391e3e9c78e0/python/triton/runtime/jit.py#L401-L402

https://github.com/openai/triton/blob/f5722cb9d8a8ce6211c77b85c340391e3e9c78e0/python/triton/runtime/jit.py#L416-L419
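For intuition, one of the expensive pieces discussed later in this thread is inspect.Signature.bind. The following standalone snippet (not Triton code; the stub kernel and iteration count are arbitrary) shows how it compares to a plain Python call:

```python
# Standalone illustration: binding arguments via inspect is far slower than
# letting Python's own call machinery do it.
import inspect
import timeit

def kernel_stub(x_ptr, y_ptr, n_elements, BLOCK_SIZE=128):
    pass

sig = inspect.signature(kernel_stub)

t_bind = timeit.timeit(lambda: sig.bind(1, 2, 1024).arguments, number=100_000)
t_call = timeit.timeit(lambda: kernel_stub(1, 2, 1024), number=100_000)

# total seconds for 100k iterations -> microseconds per iteration
print(f"signature.bind: {t_bind * 10:.2f} us/call, plain call: {t_call * 10:.2f} us/call")
```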

3. Benchmarks

The benchmarks measure the run time of the following JITFunction.run implementations:

  • default: triton's implementation.
  • python: an optimized python implementation.
  • python_generated: manually "generated" JITFunction.run
  • python_generated_with_type: manually "generated" JITFunction.run assuming there is a type hint system

Two additional functions are also measured:

  • noop: an empty op, to measure the resolution of the benchmarks.
  • kernel: the actual CUDA kernel with its C wrapper. This is as close as possible to running the bare kernel.

Two running modes are measured:

  • warmup: calling kernels with kernel[grid](..., warmup=True), which measures only the Python overhead of the JITFunction.run() function (see the timing sketch below).
  • (empty): calling kernels normally.
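As a rough illustration of how these two modes can be timed with time.perf_counter (a sketch only, not the actual benchmark script; the kernel, sizes, and iteration counts are made up):

```python
# Sketch: time the Python-side overhead (warmup=True) vs. a normal launch.
import time
import torch
import triton
import triton.language as tl

@triton.jit
def add_one(x_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)

x = torch.zeros(4096, device="cuda")
grid = (triton.cdiv(x.numel(), 256),)
add_one[grid](x, x.numel(), BLOCK=256)  # compile once up front

def time_launch(**extra):
    t0 = time.perf_counter()
    for _ in range(1000):
        add_one[grid](x, x.numel(), BLOCK=256, **extra)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / 1000 * 1e6  # us per launch

print("warmup mode:", time_launch(warmup=True))  # Python overhead only
print("normal mode:", time_launch())             # includes the actual launch
```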

Environment:

  • CPU: Xeon 6154 @ 3.00GHz
  • GPU: RTX 2080 SUPER
  • CUDA: 12.1
  • PyTorch: 2.2.2

The figures below show run time for different input lengths.

3.1. Triton 2.0

Kernel launching overhead in us (data from 11 runs)

| implementation | mean (us) | std (us) |
| --- | --- | --- |
| default_short_warmup | 40.3394 | 0.11303 |
| default_long_warmup | 62.7914 | 0.188443 |
| python_generated_short_warmup | 4.22811 | 0.0353561 |
| python_generated_long_warmup | 6.07617 | 0.0412608 |
| python_generated_with_type_short_warmup | 3.54631 | 0.0219647 |
| python_generated_with_type_long_warmup | 3.84094 | 0.0325193 |
| python_short_warmup | 13.8079 | 0.11882 |
| python_long_warmup | 24.3178 | 0.134266 |

(Figure: kernel_time_triton_2, run time vs. input length)

3.2. Triton 3.0

I only measured the code before and after @apgoucher 's PR. The improvements are really good.

Kernel launching overhead in us (data from 15 runs)

| implementation | mean (us) | std (us) |
| --- | --- | --- |
| default_short_warmup | 12.9805 | 0.0979331 |
| default_long_warmup | 15.8931 | 0.123674 |
| old_short_warmup | 99.1191 | 0.438825 |
| old_long_warmup | 155.938 | 0.66215 |

(Figure: kernel_time_triton_3, run time vs. input length)

4. Proposed solutions

I believe the kernel launching overhead can be further reduced with the following simple optimizations.

4.1. Stronger assumptions on devices

It is very expensive to figure out what the device_type should be.

For example, I think it is reasonable for Triton to assume that no one will have NVIDIA and AMD GPUs on the same machine. Then the device_type can be cached at Triton's initialization time: if there is an NVIDIA GPU then "cuda", else one of the other supported types (I don't know which types are supported). Although this is not a future-proof solution, I believe it is reasonable to make some strong assumptions for now.
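A minimal sketch of this idea, assuming a machine has either NVIDIA or AMD GPUs but not both (cached_device_type is a hypothetical helper, not Triton's API; torch.version.hip is None on CUDA builds of PyTorch):

```python
# Decide the device type once and cache the answer for all later launches.
import functools
import torch

@functools.lru_cache(maxsize=None)
def cached_device_type() -> str:
    if torch.version.hip is not None:
        return "hip"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
```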

4.2. Dynamically generate run() and runner()

It is very expensive to call signature.bind() and pack the call args. Generating these functions at jit time eliminates those expensive calls, which saves a good amount of time. Please see https://github.com/liboyue/triton/blob/low-latency-jit-function/tmp/triton_21/low_latency_jit_python_generated.py for a manually "generated" example.

For example, define a kernel as

@triton.jit
def kernel(
    a,
    b: float,
    c: tl.int,
    d: tl.tensor,
    e: tl.tensor[tl.float32],
    NUM_BLOCKS: tl.constexpr[tl.int] = 10
):
    ...

The generated run() function's signature could be as follows (type hints are omitted since they are not useful here):

def run(
    self,
    a,
    b,
    c,
    d,
    e,
    NUM_BLOCKS=10,
    *,
    grid=None,
    num_warps=None,
    num_ctas=1,
    num_stages=None,
    enable_warp_specialization=False,
    enable_fp_fusion=True,
    extern_libs=None,
    stream=None,
    warmup=False,
    device=None,
    device_type=None
):

    assert type(b) == float
    assert torch.is_tensor(d)
    assert torch.is_tensor(e)
    assert e.dtype == torch.float32

In this way, Python parses params and sets default values, which is much faster than signature.bind().

Furthermore, sig_key, constexpr_key, spec_key, etc. can all be written out explicitly as tuples of run()'s arguments, and c_wrapper's args can be "hard-coded" in the same way.
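As an illustration of what such generation could look like, here is a toy sketch (not the PR's actual generator; the helpers _sig_of and _spec_of are made-up stand-ins for Triton's real mangling and specialization logic):

```python
# Toy sketch: build the source of a specialized run() once at @jit time, with
# the key tuples written out explicitly, and compile it with exec().
import torch

def _sig_of(v):
    return str(v.dtype) if torch.is_tensor(v) else type(v).__name__

def _spec_of(v):
    if isinstance(v, int):
        return (v % 16 == 0, v == 1)
    if torch.is_tensor(v):
        return (v.data_ptr() % 16 == 0,)
    return ()

def generate_run(arg_names, constexprs, defaults):
    regular = [a for a in arg_names if a not in constexprs]
    params = ", ".join(
        f"{a}={defaults[a]!r}" if a in defaults else a for a in arg_names
    )
    src = (
        f"def run({params}, *, grid=None, warmup=False, **opts):\n"
        f"    sig_key = ({', '.join(f'_sig_of({a})' for a in regular)},)\n"
        f"    spec_key = ({', '.join(f'_spec_of({a})' for a in regular)},)\n"
        f"    constexpr_key = ({', '.join(constexprs)},)\n"
        f"    non_constexpr_vals = ({', '.join(regular)},)\n"
        f"    return (sig_key, spec_key, constexpr_key), non_constexpr_vals\n"
    )
    ns = {"_sig_of": _sig_of, "_spec_of": _spec_of}
    exec(src, ns)  # compiled once at decoration time, not per launch
    return ns["run"]

run = generate_run(["a", "b", "c", "NUM_BLOCKS"], ["NUM_BLOCKS"], {"NUM_BLOCKS": 10})
print(run(torch.empty(4), 2.0, 32))
```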

4.3. Improving type hints

With improved type hints, kernel definitions are more informative, so the generated run() functions can rely less on Python runtime introspection and can even perform type checks. This reduces overhead and can provide some additional safety if runtime type checking is enforced. Please see https://github.com/liboyue/triton/blob/low-latency-jit-function/tmp/triton_21/low_latency_jit_python_generated_with_type.py for a manually "generated" example.

(see previous subsection for the example)

5. Triton launch overhead regression

I noticed the binary wrapper overhead is ~3.2us for Triton 2, but this value has regressed slightly to ~3.5us on the main branch.

@liboyue liboyue requested a review from ptillet as a code owner March 29, 2024 01:17
@liboyue liboyue marked this pull request as draft March 29, 2024 01:18
@liboyue liboyue changed the title [FRONTEND][POC] Low latency kernel launching [FRONTEND][RFC] Low latency kernel launching Mar 29, 2024
@jlebar
Collaborator

jlebar commented Mar 29, 2024

First of all, thank you for the high-quality contribution and thorough, well-written PR description. If only every first-time PR was this good.

On the OAI side, we're very much aware of this issue (we were just talking about it yesterday) and we considered doing the C++ approach that you did. Our guess was that C++ couldn't be that much faster than Python, because the C++ still has to call into the Python runtime a lot.

Your experiments seem to validate this. 75us is a lot faster than 250us, but it's still very slow.

The way we've solved the problem of slow kernel launches on our side is by precompiling the kernels (there's an API) and then launching them directly from Pytorch or whatever, so you don't exercise any of the slow Triton stuff. I feel like this is the only way one is going to get kernel launches that are "fast enough".

It's @ptillet's final call, but my feeling is that going from 250us to 75us isn't worth the complexity, because if you care about 250us, you probably still are unhappy at 75us, so either way you probably want to precompile.

@liboyue
Author

liboyue commented Mar 29, 2024

> First of all, thank you for the high-quality contribution and thorough, well-written PR description. If only every first-time PR was this good.
>
> On the OAI side, we're very much aware of this issue (we were just talking about it yesterday) and we considered doing the C++ approach that you did. Our guess was that C++ couldn't be that much faster than Python, because the C++ still has to call into the Python runtime a lot.
>
> Your experiments seem to validate this. 75us is a lot faster than 250us, but it's still very slow.
>
> The way we've solved the problem of slow kernel launches on our side is by precompiling the kernels (there's an API) and then launching them directly from Pytorch or whatever, so you don't exercise any of the slow Triton stuff. I feel like this is the only way one is going to get kernel launches that are "fast enough".
>
> It's @ptillet's final call, but my feeling is that going from 250us to 75us isn't worth the complexity, because if you care about 250us, you probably still are unhappy at 75us, so either way you probably want to precompile.

Hi @jlebar, thanks for the quick feedback! I also feel the same -- maintaining some C++ code just to preprocess args can be difficult over time. But on the bright side, with this change, you probably can just launch triton kernels however you like without much overhead.

This PR mostly optimizes for small inputs, like inferencing with a kv-cache (well, without continuous batching or other techniques to improve GPU utilization), and for a more general audience that is definitely outside of OAI : ) A few microseconds is not that much, but if a model consists of a lot of kernels, it does make a difference, although I don't have any data to back that claim.

For the speed numbers, please note that I benchmarked a PyTorch autograd.Function, so running function.apply(x).backward(dy) costs 70us, not the kernel launch itself. Comparing short_params vs. baseline (184us vs. 73us), the extra overhead is around 110us for two consecutive launches. But comparing the optimized time (73us or 76us vs. 73us), I would say the kernel launching overhead is a few microseconds or less. The rest of the time is spent in PyTorch.

I didn't have much time to run more rigorous benchmarks, but I'll definitely run more if the conclusion is to adopt this PR or similar methods.

@jlebar
Collaborator

jlebar commented Mar 29, 2024

Ah I see, I should be subtracting baseline to get the cost here, so the claim is that the cost for Triton is down to less than 5us, which, if true, is a totally different ballgame.

To verify this I would be interested in launching some kernels without pytorch in the way, simply by invoking a @jit'ed function (as shown in the Triton tutorials).

I am confused by and a little suspicious of the data showing that long_params_optimized is consistently faster than short_params_optimized. Do you have a hypothesis that can explain this?

@liboyue
Author

liboyue commented Mar 29, 2024

> Ah I see, I should be subtracting baseline to get the cost here, so the claim is that the cost for Triton is down to less than 5us, which, if true, is a totally different ballgame.
>
> To verify this I would be interested in launching some kernels without pytorch in the way, simply by invoking a @jit'ed function (as shown in the Triton tutorials).

So do you think it's worth the extra complexity? I will run more benchmarks over the weekend.

> I am confused by and a little suspicious of the data showing that long_params_optimized is consistently faster than short_params_optimized. Do you have a hypothesis that can explain this?

I was debugging for a while but couldn't figure it out. Then I found that I didn't set the grad to zero for long_params_optimized. short and long are 77us and 84us. I have updated the data and code.

@liboyue
Author

liboyue commented Mar 29, 2024

> Ah I see, I should be subtracting baseline to get the cost here, so the claim is that the cost for Triton is down to less than 5us, which, if true, is a totally different ballgame.
>
> To verify this I would be interested in launching some kernels without pytorch in the way, simply by invoking a @jit'ed function (as shown in the Triton tutorials).
>
> I am confused by and a little suspicious of the data showing that long_params_optimized is consistently faster than short_params_optimized. Do you have a hypothesis that can explain this?

OK I just ran some more benchmarks. Please see the updated comment for details.

In summary, the GPU time is 20us, and the CPU time is 31us before and 8us after optimization. That 8us can be completely hidden.

@jlebar
Collaborator

jlebar commented Mar 29, 2024

> So do you think it's worth the extra complexity?

I think the results you show are really promising. Honestly they almost seem too good to be true, maybe something fundamental is missing.

I'll send this to Phil to have a look.

@jlebar
Collaborator

jlebar commented Mar 29, 2024

If you do run more benchmarks, trying something with 50x more arguments of type tensor might be interesting. So like 64 tensor arguments, instead of the 1 or 2 tensor args that the kernels currently have (if I'm reading right).

@jlebar
Collaborator

jlebar commented Mar 29, 2024

I spoke with Phil. We think this is promising and we're in favor of making this change in principle, so long as (a) the full-featured C++ isn't overly complicated, (b) we can delete the relevant Python code so we don't have to maintain two copies of the logic, and (c) the performance gains really are as large as they look.

That said, this is going to be an invasive and high-risk change to OAI, so we'll have to do a lot of testing on our end, and we don't have bandwidth to do that immediately. I'll see if Meta has time to review and test this patch before we get to it, but to set expectations, we'll still need to test this carefully ourselves, and so it might be a while before we can merge it into mainline, maybe multiple months. (While we could keep two copies of this logic in tree while OAI tests the new fast version, that opens up a whole different can of worms, where we have to make all changes to the JIT twice, and test both versions, etc etc, so we're not in favor of that either.)

I totally understand if you want to abandon this change, because this isn't a smooth path. Another option would be to create a separate project (e.g. triton-fast-jit) so that one can do

@fast_triton_jit.jit()
def kernel...

@jlebar
Collaborator

jlebar commented Mar 29, 2024

cc @peterbell10

@liboyue
Author

liboyue commented Mar 29, 2024

> I think the results you show are really promising. Honestly they almost seem too good to be true, maybe something fundamental is missing.

This is very possible. I make a lot of dumb mistakes. But the general idea is valid: interact with python only if needed.

> If you do run more benchmarks, trying something with 50x more arguments of type tensor might be interesting. So like 64 tensor arguments, instead of the 1 or 2 tensor args that the kernels currently have (if I'm reading right).

There are roughly three sources of overhead in parsing params:

  1. Figuring out the Python type of each arg.
  2. Reading arg values from Python.
  3. Figuring out the dtype of torch tensors.

To make it fast, each kind of overhead needs a different treatment.
For 1, we need richer type hints, like "tl.tensor", etc., so that we know the Python type at initialization time; that eliminates 1.
For 2, I don't have a good solution right now.
For 3, we can get the dtype from the torch C++ backend once we have a reference to the tensor. This shouldn't be that bad (I guess).

And if we could launch kernels directly from C++, it could be even faster. I think the bin.c_wrapper call is expensive, packing args is also expensive, and the kernel cache is expensive too because we are sending a lot of Python objects back and forth.

> (a) the full-featured C++ isn't overly complicated, (b) we can delete the relevant Python code so we don't have to maintain two copies of the logic, and (c) the performance gains really are as large as they look.

Totally agreed.

> it might be a while before we can merge it into mainline, maybe multiple months.

This will be a decent-sized refactoring, so I expect it can be merged by the end of this year. It will take me some time to get it ready for review first : )

@liboyue
Author

liboyue commented Mar 29, 2024

Thanks for your quick response! I know this refactoring is not a priority for most industry users. So I'll do my best to make it easy to test and review.

@apgoucher
Collaborator

apgoucher commented Mar 29, 2024

@liboyue Excellent work!

In the part where you assemble and return a complicated Python object like so:

    return py::make_tuple(py::make_tuple(cuda_version_key_,
                                         py::tuple(py::cast(sig_key)),
                                         py::tuple(py::cast(constexpr_key)),
                                         py::tuple(py::cast(spec_key))),
                          py::tuple(py::cast(non_constexpr_arg_values)));

could you instead serialise all of this information into a single std::string (which then gets converted into a Python string or bytes object)? Python dictionaries are more performant when the keys are strings (fast hashing, fast comparison, no pointer indirection, etc.) rather than nested tuples.
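(For what it's worth, here is the same flat-key idea sketched in Python rather than C++; the separators and field names are illustrative only:)

```python
# Serialize the key components into one string so the cache dict hashes and
# compares a single flat object.
def flat_cache_key(version, sig_key, constexpr_key, spec_key):
    return "|".join((
        str(version),
        ",".join(map(str, sig_key)),
        ",".join(map(str, constexpr_key)),
        ",".join(map(str, spec_key)),
    ))

key = flat_cache_key("12.1", ("*fp32", "i32"), (128,), ((True, False), (False,)))
print(key)  # 12.1|*fp32,i32|128|(True, False),(False,)
```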

@liboyue
Author

liboyue commented Apr 10, 2024

> @liboyue Excellent work!
>
> In the part where you assemble and return a complicated Python object like so:
>
>     return py::make_tuple(py::make_tuple(cuda_version_key_,
>                                          py::tuple(py::cast(sig_key)),
>                                          py::tuple(py::cast(constexpr_key)),
>                                          py::tuple(py::cast(spec_key))),
>                           py::tuple(py::cast(non_constexpr_arg_values)));
>
> could you instead serialise all of this information into a single std::string (which then gets converted into a Python string or bytes object)? Python dictionaries are more performant when the keys are strings (fast hashing, fast comparison, no pointer indirection, etc.) rather than nested tuples.

@apgoucher Thanks for the suggestion! I tried, but it turned out Python's hashing is not a bottleneck at all. And Python can be really fast. So I abandoned the C++ function. Please read the updated comment.

@liboyue
Author

liboyue commented Apr 10, 2024

@jlebar please see the updated comment. I think the overhead can be significantly reduced by generating a few Python functions at triton.jit calls.

@apgoucher
Collaborator

> @jlebar please see the updated comment. I think the overhead can be significantly reduced by generating a few Python functions at triton.jit calls.

Oh excellent; in that case this should be a rather minor code change to JITFunction (no FFI for example!) to avoid the expensive bind call?

@bertmaher
Collaborator

This is really cool, and we'd love to have something like it at Meta! I have two questions/comments:

  • it seems like the base revision is wrong since low_latency_jit tries to import some stuff that doesn't exist at that revision. I want to test this myself, so can you let me know which triton I should use as a base?
  • I'm skeptical of using do_bench to measure launch latency, since do_bench uses cudaEvents and does some gpu cache flushing, which can muddy the measurements of on-cpu launch latency. I think wall time using time.perf_counter() is probably a better measure for this sort of work

@bertmaher
Collaborator

I was able to benchmark this approach using https://gist.github.com/bertmaher/e8869ebf5297dfc77e26d51037d21f80 and triton 2.1.0, and I'm seeing 34us launch latency. That is great!

@bertmaher
Collaborator

Hmmm, hold up a second. I did a straight translation of the cpp back to python, and I'm also getting 34 us:

def specialization_key(v):
    if isinstance(v, int):
        return (v % 16 == 0, v % 8 == 0, v == 1)
    if hasattr(v, "data_ptr"):
        return (v.data_ptr() % 16 == 0,)
    return (False,)


def signature_key(v):
    if isinstance(v, int):
        if v >= -(1 << 31) and v <= ((1<<31) - 1):
            return "i32"
        elif v >= 0:
            return "u64"
        else:
            return "i64"
    if hasattr(v, "dtype"):
        return v.dtype
    if isinstance(v, float):
        return "fp32"

class LowLatencyJITFunction(_JITFunction):
    @staticmethod
    def _pinned_memory_of(arg):
        if hasattr(arg, "is_pinned"):
            return arg.is_pinned()
        return False

    def run(
        self,
        *args,
        grid=None,
        num_warps=None,
        num_ctas=1,
        num_stages=None,
        enable_warp_specialization=False,
        enable_fp_fusion=True,
        extern_libs=None,
        stream=None,
        warmup=False,
        device="cuda",
        device_type="cuda",
        **kwargs,
    ):
        spec_key = tuple(specialization_key(v) for p, v in zip(self.params, args) if not p.do_not_specialize)
        sig_key = tuple(signature_key(v) for p, v in zip(self.params, args) if not p.is_constexpr)
        constexpr_key = tuple(v for p, v in zip(self.params, args) if p.is_constexpr)
        non_constexpr_arg_values = tuple(v for p, v in zip(self.params, args) if not p.is_constexpr)
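        # ... (remainder of run() omitted in the comment)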

Which actually lines up with @jlebar's intuition that the C++ is mostly assembling python objects anyways. I think what's really going on here is that the key computation that's in main right now does a lot more "stuff". It handles kwargs/default arguments using inspect.Signature.bind, it allows backends to assemble custom options, it checks the signature against a cached signature, probably some other stuff too.

I do still think there's something to this approach, but it probably needs to have feature-parity with the python implementation to really evaluate. Something I'd be curious about is to actually put the cache itself in C++; that way we could assemble a key entirely with C++ types, which would probably save a good bit of pyobject creation overhead.

@liboyue
Author

liboyue commented Apr 10, 2024

Hi @bertmaher, thanks for the comment. Please read the updated comment and use the latest code. We reached the same conclusion.

@apgoucher
Collaborator

> Which actually lines up with @jlebar's intuition that the C++ is mostly assembling python objects anyways. I think what's really going on here is that the key computation that's in main right now does a lot more "stuff". It handles kwargs/default arguments using inspect.Signature.bind, it allows backends to assemble custom options, it checks the signature against a cached signature, probably some other stuff too.

That inspect.Signature.bind call (and possibly the function that populates it with defaults) contains a lot of overhead. I can reach feature-parity with that more efficiently by memoizing a native Python function that assembles the dict from its own args and kwargs -- essentially, we're using Python's builtin support for handling args/kwargs with defaults instead of using inspect.Signature.bind (which, if you delve into the source code, is really convoluted). I've created a new PR and it passes all of the correctness tests in CI:

#3638

@bertmaher I'd be interested to see how much this change alone is able to reduce the launch overhead (hopefully with your automated profiling tools it shouldn't be too much effort to determine this?). It looks like the main commonality between the two pure-Python implementations that you and @liboyue have written is that they both avoid inspect.Signature.bind. I see that the latter also has some other optimisations such as memoizing the list of indices to the constexpr, non-constexpr, and specialized arguments -- this could also be done here if that helps.
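For illustration, a rough sketch of that binder idea (this is not the actual #3638 code; the names and the memoization scheme are made up):

```python
# Generate a binder whose parameter list mirrors the kernel's signature, so
# Python's own call protocol applies defaults and collects excess kwargs; the
# generated function just packs the bound values into a dict.
from functools import lru_cache

@lru_cache(maxsize=None)
def make_binder(param_names, defaults=()):
    defaults = dict(defaults)
    params = ", ".join(
        f"{n}={defaults[n]!r}" if n in defaults else n for n in param_names
    )
    body = ", ".join(f"{n!r}: {n}" for n in param_names)
    src = f"def binder({params}, **excess_kwargs):\n    return {{{body}}}, excess_kwargs\n"
    ns = {}
    exec(src, ns)
    return ns["binder"]

binder = make_binder(("x_ptr", "n", "BLOCK"), (("BLOCK", 128),))
print(binder(0xDEAD, 1024, num_warps=4))
# -> ({'x_ptr': 57005, 'n': 1024, 'BLOCK': 128}, {'num_warps': 4})
```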

apgoucher added a commit that referenced this pull request Apr 12, 2024
This improves kernel launch latency by 2.2x (from 108us to 49us using
@bertmaher's benchmarking script in issue
#3619 ). Thanks also to
@liboyue's analysis and suggestions.

See the discussion in the third-party PR
#3503 (comment)
@liboyue
Author

liboyue commented Apr 15, 2024

Hi @apgoucher, I tested your optimizations. Results are great.

Nit: this line does not need dict()
https://github.com/openai/triton/blob/a0de891779975492a63f307f315b3063af2238f0/python/triton/runtime/jit.py#L486

@liboyue
Author

liboyue commented Apr 15, 2024

Hi @jlebar @apgoucher, please see the updated comment. I manually "generated" run() functions with and without type hints. The launching latency can be reduced to around

  • 4.2us (short) and 6.2us (long) without type hints
  • 3.6us (short) and 3.9us (long) with type hints

The overheads of #3649 on my machine are 13us (short) and 16us (long). So if the run() function is dynamically generated, there will still be a large improvement.

Thanks for the discussion and quick fix. I'm closing this PR since @apgoucher is already working on it. Feel free to keep discussing in this PR and take any ideas that may help. If this PR ever helped, please link it in the release note if possible : )

@liboyue
Author

liboyue commented Apr 15, 2024

One more note: the actual binary wrapper's launch time seems to have regressed a bit, from 3.2us to 3.5us. I'm not sure whether it's my testing code or Triton, but it may be worth looking into.

@liboyue liboyue closed this Apr 15, 2024
@apgoucher
Collaborator

> Hi @jlebar @apgoucher, please see the updated comment. I manually "generated" run() functions with and without type hints. The launching latency can be reduced to around
>
>   • 4.2us (short) and 6.2us (long) without type hints
>   • 3.6us (short) and 3.9us (long) with type hints
>
> The overheads of #3649 on my machine are 13us (short) and 16us (long). So if the run() function is dynamically generated, there will still be a large improvement.

@liboyue Could you please share the contents of your manually generated run()? According to the profiler, in my code only 17% of the time is being spent inside run() itself -- rather than other functions being called from within run() -- so I'm confused where your further 3x speedup comes from?

In any case, I managed to shave off another 3 microseconds (from 15us to 12us) in PR #3660 by moving an import statement out of hot code, only checking the hashtable once instead of twice (using the dict .get method instead of checking presence followed by retrieving), and moving get_current_target out of hot code.
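The single-lookup pattern, for illustration (cache, key, and compile_and_cache are made-up names, not the real jit.py code):

```python
# One dict lookup via .get instead of a membership test followed by indexing.
cache = {}

def compile_and_cache(key):
    cache[key] = f"<binary for {key}>"  # stand-in for the real compile step
    return cache[key]

key = ("fp32", 128)

kernel = cache.get(key)        # single hashtable probe
if kernel is None:
    kernel = compile_and_cache(key)

# versus the double lookup:
# if key in cache:
#     kernel = cache[key]
# else:
#     kernel = compile_and_cache(key)

print(kernel)
```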

> Thanks for the discussion and quick fix. I'm closing this PR since @apgoucher is already working on it. Feel free to keep discussing in this PR and take any ideas that may help. If this PR ever helped, please link it in the release note if possible : )

@apgoucher
Collaborator

Oh, I see now: your function takes various shortcuts such as assuming that the types of the arguments are known beforehand (but that's not necessarily true -- for example, it's quite common for someone to replace a tensor argument with None, so x.dtype would fail) and you're not stringifying the constexpr arguments when computing the cache key (which in particular would make 0.0 compare as being equal to 0, even though they carry different type information so result in different generated code).

Other than that, it looks like the main advantages of your code are avoiding various function-call indirection and, by dynamically generating the body of run() instead of just my binder(), avoiding constructing and returning various containers (e.g. tuples).

> If this PR ever helped, please link it in the release note if possible : )

It absolutely did! @ptillet can you please ensure that @liboyue and @bertmaher are properly credited? Thanks!

@apgoucher
Collaborator

> Hi @apgoucher, I tested your optimizations. Results are great.
>
> Nit: this line does not need dict()
>
> https://github.com/openai/triton/blob/a0de891779975492a63f307f315b3063af2238f0/python/triton/runtime/jit.py#L486

I included it here because the grid() callable is user-specified so may modify bound_args (e.g. by calling .pop()) which would cause issues because we later call bound_args.values(). But moving the grid canonicalisation code later in the run function (in particular, after we extract the values) means that we can indeed make your simplification, so I've applied it in the latest PR.

@liboyue
Author

liboyue commented Apr 16, 2024

> Oh, I see now: your function takes various shortcuts such as assuming that the types of the arguments are known beforehand

We can definitely know whether an argument is a tensor by looking at the kernel's source code, but I agree this can be tricky. Please see LowLatencyJITFunctionPythonGeneratedShort in https://github.com/liboyue/triton/blob/low-latency-jit-function/tmp/triton_21/low_latency_jit_python_generated.py for an example that does not assume the dtype and checks for None.

This change costs 1us. Now the short kernels' overhead is around 5.3us. The main branch's overhead is around 7.9us (0845e65).

> (but that's not necessarily true -- for example, it's quite common for someone to replace a tensor argument with None, so x.dtype would fail) and you're not stringifying the constexpr arguments when computing the cache key (which in particular would make 0.0 compare as being equal to 0, even though they carry different type information so result in different generated code).

The function I tested has only 1 constexpr and 2 tensors. Adding an `is None` check won't cost a few microseconds, and converting 1 constexpr into a string also won't be that expensive. Also, Triton 2 does not convert constexprs to strings; the experiment I wrote uses a default Triton kernel to compile, so the keys I generate have to match Triton 2's behavior.

> Other than that, it looks like the main advantages of your code are avoiding various function-call indirection and, by dynamically generating the body of run() instead of just my binder(), avoiding constructing and returning various containers (e.g. tuples).

The main advantage is that it avoids parsing args and avoids creating unnecessary objects.

1. Generated signature of run()

    def run(
        self,
        dy,
        a,
        output,
        N,
        M_STRIDE,
        BLOCK_SIZE,
        *,
        grid=None,
        num_warps=DEFAULT_NUM_WARPS,
        num_ctas=1,
        num_stages=DEFAULT_NUM_STAGES,
        enable_warp_specialization=False,
        enable_fp_fusion=True,
        extern_libs=None,
        stream=None,
        warmup=False,
        device=None,
        device_type="cuda",
    )

To check whether device is set, simply do `if device is not None`. The current main does this check only at compile time, which may lead to silent failures: if the same kernel is used in multiple functions and someone refactors only the first function, the code still works but might give unexpected outputs.
https://github.com/openai/triton/blob/0845e65a5432f97b797c0c4dcd71dd546cb71c9d/python/triton/runtime/jit.py#L495-L497

The * in the function signature is quite important: it guarantees that no positional args can overwrite kwargs. Again, if someone refactors one kernel to reduce the number of args by 10 but forgets to refactor the remaining 100 calls to this kernel, the extra 10 args will slip into excess_kwargs.

2. Unnecessary things

2.1. binder() creates one dict and returns a tuple

def dynamic_func(x_ptr, a, output_ptr, N, M_STRIDE, BLOCK_SIZE_N, **excess_kwargs):
    return {'x_ptr': x_ptr, 'a': a, 'output_ptr': output_ptr, 'N': N, 'M_STRIDE': M_STRIDE, 'BLOCK_SIZE_N': BLOCK_SIZE_N}, (mangle_type(x_ptr, False), mangle_type(a, False), mangle_type(output_ptr, False), mangle_type(N, False), mangle_type(M_STRIDE, False), compute_spec_key(x_ptr), compute_spec_key(a), compute_spec_key(output_ptr), compute_spec_key(N), compute_spec_key(M_STRIDE), ), (BLOCK_SIZE_N, ), (x_ptr, a, output_ptr, N, M_STRIDE, ), excess_kwargs

Creating a dict is expensive: you create both keys and values. Returning a tuple should be fine.

2.2. This can be avoided if run() is generated

https://github.com/openai/triton/blob/0845e65a5432f97b797c0c4dcd71dd546cb71c9d/python/triton/runtime/jit.py#L473

3. One question

Is serializing everything into strings faster? In which scenarios (how many arguments)?

https://github.com/openai/triton/blob/0845e65a5432f97b797c0c4dcd71dd546cb71c9d/python/triton/runtime/jit.py#L485

4. Other possible optimizations

  1. cache is a two-level dict. Adding the device id to the hash key and making cache a one-level dict would improve performance a tiny bit (see the sketch below).
  2. Create a large enum of all dtypes. In other words, it is possible to convert every dtype (primitive and torch dtypes) into an integer. This may improve perf by another tiny bit because the hash of a small integer is the integer itself.
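A minimal sketch of both ideas, with illustrative names only (FlatKernelCache and DTYPE_IDS are not existing Triton structures):

```python
# Optimization 1: fold the device id into the key so the kernel cache is a
# single flat dict instead of a dict of dicts.
import torch

class FlatKernelCache:
    def __init__(self):
        self._cache = {}

    def get(self, device_id, key):
        return self._cache.get((device_id, *key))

    def put(self, device_id, key, binary):
        self._cache[(device_id, *key)] = binary

# Optimization 2: map each dtype to a small integer once, so key components
# hash as integers (hash(i) == i for small non-negative ints).
DTYPE_IDS = {dt: i for i, dt in enumerate(
    [int, float, bool, torch.float16, torch.float32, torch.int32, torch.int64]
)}
```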

Overall I think the current main branch is already fast enough. If my suggestions to dynamically generate the run() function and create a stronger typing system are accepted, I think that (together with the optimizations already merged into main) would be close to the best effort we can make at the Python level. If @apgoucher works on further optimizations, I would suggest using a single PR so it is easier for you and others to track these changes.

And, maybe re-run all figures in tutorials?

@Jokeren
Contributor

Jokeren commented Apr 16, 2024

> And, maybe re-run all figures in tutorials?

The figures are automatically generated every night using the documentation workflow

bringlein added a commit to IBM/triton-dejavu that referenced this pull request Apr 10, 2025
The launch overhead of triton kernels is a well known problem (see e.g. [1](triton-lang/triton#3503), [2](triton-lang/triton#2637), [3](triton-lang/triton#6064)). Part of the launch overhead comes from the fact that the triton JIT checks very carefully whether an existing binary is safe to use.

In many scenarios, these checks can be relaxed.
This PR adds such a cache with relaxed checks, implemented as `triton_dejavu.jitcache`. It is a decorator that can be used in front of the `triton.jit` decorator:

```
@triton_dejavu.jitcache(
    check_keys=["x", "BLOCK_SIZE", "USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
)
@triton.jit
def kernel_paged_attention_.... 
```

See the README for details.
---------

Signed-off-by: Burkhard Ringlein <[email protected]>