Skip to content

Transposing 128x128 blocks is suboptimal #7815

@danthe3rd

Description

@danthe3rd

Describe the issue

Triton is inefficient at transposing 128x128 blocks: Nsight Compute (NCU) reports many shared-memory bank conflicts, and memory throughput peaks at ~1700 GB/s (versus 2800+ GB/s for other block shapes).

(Two images were attached here in the original issue; they are not reproduced in this text.)

Python repro

Python code
import torch
import triton
import triton.language as tl
from typing import List, Optional, Union
import os
def benchmark_fn(
    fn,
    name: Optional[str] = None,
    *,
    cudagraph: bool = False,
    torch_profiler: bool = False,
    warmup_repeats: int = 0,
    repeats: int = 10,
    verbose: bool = False,
    IO: Union[List[torch.Tensor], int, None] = None,
    flops: int = 0,
) -> float:
    if name is None:
        name = fn.__name__
    # warmup
    fn()
    if "PROFILING" in os.environ:
        return 0.0
    if cudagraph:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            fn()
        def fn_graphed():
            g.replay()
        fn_graphed()
        fn = fn_graphed
    # loop
    ev_start = torch.cuda.Event(enable_timing=True)
    ev_end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    for _ in range(warmup_repeats):
        fn()
    ev_start.record()
    for _ in range(repeats):
        fn()
    ev_end.record()
    torch.cuda.synchronize()
    # (elapsed_time returns in `ms`)
    dt = ev_start.elapsed_time(ev_end) / (1000 * repeats)
    if verbose:
        if dt < 0.005:
            time_str = f"{int(dt * 10000000) / 10}us"
        elif dt < 5.0:
            time_str = f"{int(dt * 10000) / 10}ms"
        else:
            time_str = f"{int(dt * 10) / 10}s"
        IO_bytes = 0
        if isinstance(IO, list):
            IO_bytes = sum(
                x.numel() * x.dtype.itemsize for x in IO if isinstance(x, torch.Tensor)
            )
        elif isinstance(IO, (int, float)):
            IO_bytes = int(IO)
        if IO_bytes <= 0:
            bytes_per_s_str = ""
        else:
            bytes_per_s_str = f" [{int(IO_bytes / dt / (1024 ** 3))} GB/s]"
        if flops <= 0:
            flops_str = ""
        else:
            flops_str = f" [{int(flops / dt / (1000 ** 3))} GFlops]"
        print(f"{name}: {time_str}{flops_str}{bytes_per_s_str}")
    # get a trace
    if torch_profiler:
        profiler = torch.profiler.profile(
            on_trace_ready=lambda p: p.export_chrome_trace(f"benchmark_{name}.json.gz"),
            profile_memory=True,
            record_shapes=True,
            with_stack=True,
            with_flops=True,
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
        )
        torch.cuda.synchronize()
        with profiler:
            fn()
            torch.cuda.synchronize()
    return dt
def _kernel_repr(proxy):
    constants = proxy.constants
    return constants["KERNEL_NAME"]
@triton.jit(repr=_kernel_repr)
def transpose_k(
    in_ptr,
    out_ptr,
    BLOCK_SZ: tl.constexpr,
    DIM0: tl.constexpr,
    DIM1: tl.constexpr,
    TRANSPOSE_GRID: tl.constexpr,
    KERNEL_NAME: tl.constexpr,
):
    """Transpose one BLOCK_SZ x BLOCK_SZ tile of a row-major (DIM0, DIM1) matrix.

    The input tile whose top-left element is (b0, b1) is written, transposed,
    to the output tile whose top-left element is (b1, b0) of the row-major
    (DIM1, DIM0) output, so once every program has run, out == in.T.

    There is no masking: DIM0 and DIM1 must be multiples of BLOCK_SZ.
    TRANSPOSE_GRID selects which launch-grid axis walks input rows
    (True: axis 0 -> DIM0 tiles; False: axis 0 -> DIM1 tiles); the launcher
    must size the grid to match. KERNEL_NAME is not used by the body; it only
    feeds the ``repr`` hook.
    """
    if TRANSPOSE_GRID:
        b0, b1 = tl.program_id(axis=0) * BLOCK_SZ, tl.program_id(axis=1) * BLOCK_SZ
    else:
        b1, b0 = tl.program_id(axis=0) * BLOCK_SZ, tl.program_id(axis=1) * BLOCK_SZ
    offs = tl.arange(0, BLOCK_SZ)
    # x[i, j] = in[b0 + i, b1 + j]  (input row stride is DIM1).
    x = tl.load(in_ptr + (b0 + offs[:, None]) * DIM1 + b1 + offs[None, :])
    x = tl.trans(x)
    # After the transpose x[j, i] = in[b0 + i, b1 + j]; write it to
    # out[b1 + j, b0 + i]  (output row stride is DIM0).
    tl.store(out_ptr + (b1 + offs[:, None]) * DIM0 + b0 + offs[None, :], x)


def transpose(x: torch.Tensor, num_warps: int, BLOCK_SZ: int, TRANSPOSE_GRID: bool, KERNEL_NAME: str):
    """Return ``x.T`` as a new contiguous tensor, computed by ``transpose_k``.

    Args:
        x: 2-D contiguous tensor whose dimensions are both multiples of BLOCK_SZ.
        num_warps: warps per Triton program.
        BLOCK_SZ: edge length of the square tile each program transposes.
        TRANSPOSE_GRID: launch-order knob forwarded to the kernel.
        KERNEL_NAME: display name for the compiled kernel (see ``_kernel_repr``).

    Returns:
        A tensor of shape ``(x.shape[1], x.shape[0])`` with x's dtype and device.

    Raises:
        ValueError: if ``x`` is not 2-D, not contiguous, or not tile-aligned
            (the kernel has no masking).
    """
    if x.dim() != 2:
        raise ValueError(f"expected a 2-D tensor, got shape {tuple(x.shape)}")
    if not x.is_contiguous():
        raise ValueError("expected a contiguous (row-major) tensor")
    dim0, dim1 = x.shape
    if dim0 % BLOCK_SZ or dim1 % BLOCK_SZ:
        raise ValueError(
            f"shape {tuple(x.shape)} is not a multiple of BLOCK_SZ={BLOCK_SZ}"
        )
    y = x.new_empty((dim1, dim0))
    tiles0, tiles1 = dim0 // BLOCK_SZ, dim1 // BLOCK_SZ
    # Grid axis 0 must span the dimension the kernel maps program_id(0) onto.
    grid = (tiles0, tiles1) if TRANSPOSE_GRID else (tiles1, tiles0)
    transpose_k[grid](
        x,
        y,
        BLOCK_SZ=BLOCK_SZ,
        TRANSPOSE_GRID=TRANSPOSE_GRID,
        DIM0=dim0,
        DIM1=dim1,
        num_warps=num_warps,
        KERNEL_NAME=KERNEL_NAME,
    )
    return y
# Benchmark sweep over every (dtype, BLOCK_SZ, num_warps, TRANSPOSE_GRID)
# combination on a 16384 x 16384 matrix, timed through CUDA-graph replays.
x = torch.randn([8192 * 2, 8192 * 2], device="cuda", dtype=torch.bfloat16)
for dtype_name, dtype, dtype_bytes in [
    # ("fp32", torch.float32, 4),
    ("bf16", torch.bfloat16, 2),
    ("int8", torch.int8, 1),
]:
    print(f"\n  {dtype_name}")
    # Convert the benchmark matrix to this case's dtype.
    x = x.to(dtype)
    # Bytes moved per call: every element is read once and written once.
    IO = 2 * dtype_bytes * x.numel()
    for BLOCK_SZ in [32, 64, 128]:
        for num_warps in [1, 2, 4, 8, 16]:
            for TRANSPOSE_GRID in [True, False]:
                grid_tag = "T" if TRANSPOSE_GRID else "N"
                name = f"transpose_{dtype_name}_{num_warps}_{BLOCK_SZ}_{grid_tag}"

                # Freeze this configuration as defaults so the callable never
                # depends on loop variables at call time.
                def run_case(
                    x=x,
                    num_warps=num_warps,
                    BLOCK_SZ=BLOCK_SZ,
                    TRANSPOSE_GRID=TRANSPOSE_GRID,
                    KERNEL_NAME=name,
                ):
                    return transpose(
                        x,
                        num_warps=num_warps,
                        BLOCK_SZ=BLOCK_SZ,
                        TRANSPOSE_GRID=TRANSPOSE_GRID,
                        KERNEL_NAME=KERNEL_NAME,
                    )

                latency = benchmark_fn(
                    run_case, name=name, cudagraph=True, IO=IO, verbose=True
                )

Output on H100

Output
  bf16
transpose_bf16_1_32_T: 480.5us [2080 GB/s]
transpose_bf16_1_32_N: 365.0us [2739 GB/s]
transpose_bf16_2_32_T: 479.2us [2086 GB/s]
transpose_bf16_2_32_N: 366.9us [2725 GB/s]
transpose_bf16_4_32_T: 471.8us [2119 GB/s]
transpose_bf16_4_32_N: 372.8us [2681 GB/s]
transpose_bf16_8_32_T: 541.4us [1846 GB/s]
transpose_bf16_8_32_N: 421.1us [2374 GB/s]
transpose_bf16_16_32_T: 691.8us [1445 GB/s]
transpose_bf16_16_32_N: 599.4us [1668 GB/s]
transpose_bf16_1_64_T: 409.3us [2442 GB/s]
transpose_bf16_1_64_N: 374.1us [2673 GB/s]
transpose_bf16_2_64_T: 407.2us [2455 GB/s]
transpose_bf16_2_64_N: 374.4us [2670 GB/s]
transpose_bf16_4_64_T: 405.3us [2466 GB/s]
transpose_bf16_4_64_N: 363.5us [2750 GB/s]
transpose_bf16_8_64_T: 400.4us [2497 GB/s]
transpose_bf16_8_64_N: 360.0us [2777 GB/s]
transpose_bf16_16_64_T: 420.2us [2379 GB/s]
transpose_bf16_16_64_N: 400.1us [2498 GB/s]
transpose_bf16_1_128_T: 607.2us [1646 GB/s]
transpose_bf16_1_128_N: 605.9us [1650 GB/s]
transpose_bf16_2_128_T: 591.6us [1690 GB/s]
transpose_bf16_2_128_N: 592.0us [1689 GB/s]
transpose_bf16_4_128_T: 588.3us [1699 GB/s]
transpose_bf16_4_128_N: 588.4us [1699 GB/s]
transpose_bf16_8_128_T: 590.4us [1693 GB/s]
transpose_bf16_8_128_N: 590.5us [1693 GB/s]
transpose_bf16_16_128_T: 599.6us [1667 GB/s]
transpose_bf16_16_128_N: 599.6us [1667 GB/s]
  int8
transpose_int8_1_32_T: 247.6us [2019 GB/s]
transpose_int8_1_32_N: 205.0us [2438 GB/s]
transpose_int8_2_32_T: 248.0us [2015 GB/s]
transpose_int8_2_32_N: 208.2us [2401 GB/s]
transpose_int8_4_32_T: 261.4us [1912 GB/s]
transpose_int8_4_32_N: 213.1us [2345 GB/s]
transpose_int8_8_32_T: 321.6us [1554 GB/s]
transpose_int8_8_32_N: 305.5us [1636 GB/s]
transpose_int8_16_32_T: 440.7us [1134 GB/s]
transpose_int8_16_32_N: 498.8us [1002 GB/s]
transpose_int8_1_64_T: 242.4us [2062 GB/s]
transpose_int8_1_64_N: 209.3us [2388 GB/s]
transpose_int8_2_64_T: 237.3us [2106 GB/s]
transpose_int8_2_64_N: 199.1us [2510 GB/s]
transpose_int8_4_64_T: 229.9us [2174 GB/s]
transpose_int8_4_64_N: 210.1us [2378 GB/s]
transpose_int8_8_64_T: 234.7us [2129 GB/s]
transpose_int8_8_64_N: 207.8us [2405 GB/s]
transpose_int8_16_64_T: 283.8us [1761 GB/s]
transpose_int8_16_64_N: 248.2us [2013 GB/s]
transpose_int8_1_128_T: 326.0us [1533 GB/s]
transpose_int8_1_128_N: 326.8us [1529 GB/s]
transpose_int8_2_128_T: 298.7us [1673 GB/s]
transpose_int8_2_128_N: 299.4us [1669 GB/s]
transpose_int8_4_128_T: 294.7us [1696 GB/s]
transpose_int8_4_128_N: 294.6us [1696 GB/s]
transpose_int8_8_128_T: 293.6us [1702 GB/s]
transpose_int8_8_128_N: 294.0us [1700 GB/s]
transpose_int8_16_128_T: 295.6us [1691 GB/s]
transpose_int8_16_128_N: 295.8us [1690 GB/s]

Environment details

Triton: 3.4.0
GPU: H100

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions