Describe the issue
Triton is not really efficient at transposing 128x128 blocks of elements. NCU reports many shared-memory bank conflicts, and memory throughput peaks at ~1700 GB/s (instead of 2800+ GB/s for other block sizes).


Python repro
```python
import torch
import triton
import triton.language as tl
from typing import List, Optional, Union
import os


def benchmark_fn(
    fn,
    name: Optional[str] = None,
    *,
    cudagraph: bool = False,
    torch_profiler: bool = False,
    warmup_repeats: int = 0,
    repeats: int = 10,
    verbose: bool = False,
    IO: Union[List[torch.Tensor], int, None] = None,
    flops: int = 0,
) -> float:
    if name is None:
        name = fn.__name__
    # warmup
    fn()
    if "PROFILING" in os.environ:
        return 0.0
    if cudagraph:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            fn()

        def fn_graphed():
            g.replay()

        fn_graphed()
        fn = fn_graphed
    # timed loop
    ev_start = torch.cuda.Event(enable_timing=True)
    ev_end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    for _ in range(warmup_repeats):
        fn()
    ev_start.record()
    for _ in range(repeats):
        fn()
    ev_end.record()
    torch.cuda.synchronize()
    # (elapsed_time returns in `ms`)
    dt = ev_start.elapsed_time(ev_end) / (1000 * repeats)
    if verbose:
        if dt < 0.005:
            time_str = f"{int(dt * 10000000) / 10}us"
        elif dt < 5.0:
            time_str = f"{int(dt * 10000) / 10}ms"
        else:
            time_str = f"{int(dt * 10) / 10}s"
        IO_bytes = 0
        if isinstance(IO, list):
            IO_bytes = sum(
                x.numel() * x.dtype.itemsize for x in IO if isinstance(x, torch.Tensor)
            )
        elif isinstance(IO, (int, float)):
            IO_bytes = int(IO)
        if IO_bytes <= 0:
            bytes_per_s_str = ""
        else:
            bytes_per_s_str = f" [{int(IO_bytes / dt / (1024 ** 3))} GB/s]"
        if flops <= 0:
            flops_str = ""
        else:
            flops_str = f" [{int(flops / dt / (1000 ** 3))} GFlops]"
        print(f"{name}: {time_str}{flops_str}{bytes_per_s_str}")
    # get a trace
    if torch_profiler:
        profiler = torch.profiler.profile(
            on_trace_ready=lambda p: p.export_chrome_trace(f"benchmark_{name}.json.gz"),
            profile_memory=True,
            record_shapes=True,
            with_stack=True,
            with_flops=True,
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
        )
        torch.cuda.synchronize()
        with profiler:
            fn()
        torch.cuda.synchronize()
    return dt


def _kernel_repr(proxy):
    constants = proxy.constants
    return constants["KERNEL_NAME"]


@triton.jit(repr=_kernel_repr)
def transpose_k(in_ptr, out_ptr, BLOCK_SZ: tl.constexpr, DIM0: tl.constexpr, DIM1: tl.constexpr, TRANSPOSE_GRID: tl.constexpr, KERNEL_NAME: tl.constexpr):
    # TRANSPOSE_GRID only swaps which launch-grid axis maps to which tile coordinate
    if TRANSPOSE_GRID:
        b0, b1 = tl.program_id(axis=0) * BLOCK_SZ, tl.program_id(axis=1) * BLOCK_SZ
    else:
        b1, b0 = tl.program_id(axis=0) * BLOCK_SZ, tl.program_id(axis=1) * BLOCK_SZ
    # load a BLOCK_SZ x BLOCK_SZ tile, transpose it, and store it back at the same tile offsets
    x = tl.load(in_ptr + (b0 + tl.arange(0, BLOCK_SZ)[:, None]) * DIM1 + b1 + tl.arange(0, BLOCK_SZ)[None, :])
    x = tl.trans(x)
    tl.store(out_ptr + (b0 + tl.arange(0, BLOCK_SZ)[:, None]) * DIM1 + b1 + tl.arange(0, BLOCK_SZ)[None, :], x)


def transpose(x: torch.Tensor, num_warps: int, BLOCK_SZ: int, TRANSPOSE_GRID: bool, KERNEL_NAME: str):
    y = torch.empty_like(x)
    grid = (x.shape[0] // BLOCK_SZ, x.shape[1] // BLOCK_SZ)
    transpose_k[grid](
        x,
        y,
        BLOCK_SZ=BLOCK_SZ,
        TRANSPOSE_GRID=TRANSPOSE_GRID,
        DIM0=x.shape[0],
        DIM1=x.shape[1],
        num_warps=num_warps,
        KERNEL_NAME=KERNEL_NAME,
    )
    return y


x = torch.randn([8192 * 2, 8192 * 2], device="cuda", dtype=torch.bfloat16)
for dtype_name, dtype, dtype_bytes in [
    # ("fp32", torch.float32, 4),
    ("bf16", torch.bfloat16, 2),
    ("int8", torch.int8, 1),
]:
    print(f"\n {dtype_name}")
    x = x.to(dtype)
    IO = x.numel() * dtype_bytes * 2  # bytes read + bytes written
    for BLOCK_SZ in [32, 64, 128]:
        for num_warps in [1, 2, 4, 8, 16]:
            for TRANSPOSE_GRID in [True, False]:
                name = f"transpose_{dtype_name}_{num_warps}_{BLOCK_SZ}_{'T' if TRANSPOSE_GRID else 'N'}"
                latency = benchmark_fn(
                    lambda: transpose(x, num_warps=num_warps, BLOCK_SZ=BLOCK_SZ, TRANSPOSE_GRID=TRANSPOSE_GRID, KERNEL_NAME=name),
                    name=name,
                    cudagraph=True,
                    IO=IO,
                    verbose=True,
                )
```
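
For profiling under NCU, the benchmark harness above is not needed; a single launch of one of the slow configurations is enough. Here is a minimal sketch, assuming the definitions from the repro above are already in scope (the chosen configuration matches the bf16 / BLOCK_SZ=128 case in the output below):

```python
# Minimal single-launch repro of a slow case (bf16, BLOCK_SZ=128), convenient for
# inspecting one kernel in isolation. Assumes `torch` and `transpose` from the
# repro above are in scope; the kernel name below is just a label.
x = torch.randn([8192 * 2, 8192 * 2], device="cuda", dtype=torch.bfloat16)
y = transpose(x, num_warps=4, BLOCK_SZ=128, TRANSPOSE_GRID=False,
              KERNEL_NAME="transpose_bf16_4_128_N")
torch.cuda.synchronize()
```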
Output on H100
```
bf16
transpose_bf16_1_32_T: 480.5us [2080 GB/s]
transpose_bf16_1_32_N: 365.0us [2739 GB/s]
transpose_bf16_2_32_T: 479.2us [2086 GB/s]
transpose_bf16_2_32_N: 366.9us [2725 GB/s]
transpose_bf16_4_32_T: 471.8us [2119 GB/s]
transpose_bf16_4_32_N: 372.8us [2681 GB/s]
transpose_bf16_8_32_T: 541.4us [1846 GB/s]
transpose_bf16_8_32_N: 421.1us [2374 GB/s]
transpose_bf16_16_32_T: 691.8us [1445 GB/s]
transpose_bf16_16_32_N: 599.4us [1668 GB/s]
transpose_bf16_1_64_T: 409.3us [2442 GB/s]
transpose_bf16_1_64_N: 374.1us [2673 GB/s]
transpose_bf16_2_64_T: 407.2us [2455 GB/s]
transpose_bf16_2_64_N: 374.4us [2670 GB/s]
transpose_bf16_4_64_T: 405.3us [2466 GB/s]
transpose_bf16_4_64_N: 363.5us [2750 GB/s]
transpose_bf16_8_64_T: 400.4us [2497 GB/s]
transpose_bf16_8_64_N: 360.0us [2777 GB/s]
transpose_bf16_16_64_T: 420.2us [2379 GB/s]
transpose_bf16_16_64_N: 400.1us [2498 GB/s]
transpose_bf16_1_128_T: 607.2us [1646 GB/s]
transpose_bf16_1_128_N: 605.9us [1650 GB/s]
transpose_bf16_2_128_T: 591.6us [1690 GB/s]
transpose_bf16_2_128_N: 592.0us [1689 GB/s]
transpose_bf16_4_128_T: 588.3us [1699 GB/s]
transpose_bf16_4_128_N: 588.4us [1699 GB/s]
transpose_bf16_8_128_T: 590.4us [1693 GB/s]
transpose_bf16_8_128_N: 590.5us [1693 GB/s]
transpose_bf16_16_128_T: 599.6us [1667 GB/s]
transpose_bf16_16_128_N: 599.6us [1667 GB/s]
int8
transpose_int8_1_32_T: 247.6us [2019 GB/s]
transpose_int8_1_32_N: 205.0us [2438 GB/s]
transpose_int8_2_32_T: 248.0us [2015 GB/s]
transpose_int8_2_32_N: 208.2us [2401 GB/s]
transpose_int8_4_32_T: 261.4us [1912 GB/s]
transpose_int8_4_32_N: 213.1us [2345 GB/s]
transpose_int8_8_32_T: 321.6us [1554 GB/s]
transpose_int8_8_32_N: 305.5us [1636 GB/s]
transpose_int8_16_32_T: 440.7us [1134 GB/s]
transpose_int8_16_32_N: 498.8us [1002 GB/s]
transpose_int8_1_64_T: 242.4us [2062 GB/s]
transpose_int8_1_64_N: 209.3us [2388 GB/s]
transpose_int8_2_64_T: 237.3us [2106 GB/s]
transpose_int8_2_64_N: 199.1us [2510 GB/s]
transpose_int8_4_64_T: 229.9us [2174 GB/s]
transpose_int8_4_64_N: 210.1us [2378 GB/s]
transpose_int8_8_64_T: 234.7us [2129 GB/s]
transpose_int8_8_64_N: 207.8us [2405 GB/s]
transpose_int8_16_64_T: 283.8us [1761 GB/s]
transpose_int8_16_64_N: 248.2us [2013 GB/s]
transpose_int8_1_128_T: 326.0us [1533 GB/s]
transpose_int8_1_128_N: 326.8us [1529 GB/s]
transpose_int8_2_128_T: 298.7us [1673 GB/s]
transpose_int8_2_128_N: 299.4us [1669 GB/s]
transpose_int8_4_128_T: 294.7us [1696 GB/s]
transpose_int8_4_128_N: 294.6us [1696 GB/s]
transpose_int8_8_128_T: 293.6us [1702 GB/s]
transpose_int8_8_128_N: 294.0us [1700 GB/s]
transpose_int8_16_128_T: 295.6us [1691 GB/s]
transpose_int8_16_128_N: 295.8us [1690 GB/s]
```
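
For reference, the same harness can also time a plain PyTorch transposed copy of the same tensor, which gives a rough idea of what copy throughput is reachable for this shape. This is only a sketch: it reuses `benchmark_fn` and `x` from the repro above, and no numbers for it were collected here.

```python
# Eager PyTorch baseline timed with the same harness (sketch only).
# Assumes `benchmark_fn` and `x` from the repro above are in scope.
IO = x.numel() * x.dtype.itemsize * 2  # bytes read + bytes written
benchmark_fn(
    lambda: x.t().contiguous(),  # materialize the transposed copy
    name="torch_transpose",
    cudagraph=True,
    IO=IO,
    verbose=True,
)
```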
Environment details
Triton: 3.4.0
GPU: H100