from collections import namedtuple
from functools import partial
import math
import os
from typing import NamedTuple
import torch
import torch.nn as nn
import torch.nn.functional as F

import time

try:
    import cudnn
except ImportError:
    cudnn = None
# cudnn = None

# Minimal timing record; only the mean runtime (in seconds) is used below.
Timing = NamedTuple('Timing', [('mean', float)])


from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python
from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python
try:
    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
    # Also grab the varlen entry point; it is used in the varlen branch below.
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    flash_attn_func_v3 = None
    flash_attn_varlen_func_v3 = None

# The FA3 kernels target Hopper (sm90); skip them on other architectures.
if torch.cuda.get_device_capability()[0] != 9:
    flash_attn_func_v3 = None
    flash_attn_varlen_func_v3 = None
# flash_attn_func_v3 = None

flash_attn_func = None  # Deliberately disable the FAv2 benchmarks; restore the import above to re-enable them.

from triton.testing import do_bench

def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
    # # Warmup
    # for _ in range(5):
    #     func(*args, **kwargs)
    # time.sleep(1)
    # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1]
    # Alternative: capture the call in a CUDA graph and time graph replay.
    # s = torch.cuda.Stream()
    # s.wait_stream(torch.cuda.current_stream())
    # with torch.cuda.stream(s):
    #     for _ in range(2):
    #         out = func(*args, **kwargs)
    # torch.cuda.current_stream().wait_stream(s)
    # graph = torch.cuda.CUDAGraph()
    # with torch.cuda.graph(graph):
    #     out = func(*args, **kwargs)
    # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc)
    # # return time_f[1].mean
    # return time_f[1]
    # do_bench reports milliseconds; convert to seconds to match Timing.mean.
    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)
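
# Example usage (hypothetical tensors; `a` and `b` are not defined in this script):
#   t = time_fwd(torch.matmul, a, b, repeats=30)
#   print(f"{t.mean * 1e3:.3f} ms per call")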


def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)):
    if causal:
        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    else:
        if window_size == (None, None):
            avg_seqlen = seqlen_k
        else:
            # Row i attends to columns [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right],
            # clamped to [0, seqlen_k - 1].
            row_idx = torch.arange(seqlen_q, device='cuda')
            col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) if window_size[0] is not None else torch.zeros_like(row_idx)
            col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1)
            avg_seqlen = (col_right - col_left + 1).float().mean().item()
    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)
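
# Sanity check with the default config below (non-causal, no window):
# batch=4, nheads=16, seqlen_q=seqlen_k=8192, headdim=headdim_v=128 gives
# 4 * 16 * 2 * 8192 * 8192 * (128 + 128) ≈ 2.20e12 FLOPs per forward pass.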


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")


def cudnn_spda_setup(q, k, v, causal=False, window_size_left=None):
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_k, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_k, seqlen_k, headdim)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu = q, k, v
    o_gpu = torch.empty_like(q_gpu)
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())

    o, stats = graph.sdpa(
        name="sdpa",
        q=q,
        k=k,
        v=v,
        is_inference=False,
        attn_scale=1.0 / math.sqrt(headdim),
        # use_causal_mask_bottom_right=causal or window_size_left is not None,
        use_causal_mask=causal or window_size_left is not None,
        sliding_window_length=window_size_left if window_size_left is not None and not causal else None,
    )

    o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        stats: stats_gpu,
    }

    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu

    return run
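
# The setup above follows the cuDNN frontend flow: declare graph tensors, describe the
# SDPA op, validate/build the plan once, then replay it via graph.execute(). A sketch of
# how the returned closure is used (q/k/v must be the same (b, h, s, d) tensors captured above):
#   fwd = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True)
#   out = fwd()  # returns the preallocated o_gpu buffer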


def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None):
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_k, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_k, seqlen_k, headdim)
    assert g.shape == (b, nheads, seqlen_q, headdim)
    assert o.shape == (b, nheads, seqlen_q, headdim)
    assert lse.shape == (b, nheads, seqlen_q, 1)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g
    dq_gpu = torch.empty_like(q_gpu)
    dk_gpu = torch.empty_like(k_gpu)
    dv_gpu = torch.empty_like(v_gpu)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())
    o = graph.tensor_like(o_gpu.detach())
    g = graph.tensor_like(g_gpu.detach())
    stats = graph.tensor_like(lse.detach())

    dq, dk, dv = graph.sdpa_backward(
        name="sdpa_backward",
        q=q,
        k=k,
        v=v,
        o=o,
        dO=g,
        stats=stats,
        attn_scale=1.0 / math.sqrt(headdim),
        # use_causal_mask_bottom_right=causal or window_size_left is not None,
        use_causal_mask=causal or window_size_left is not None,
        sliding_window_length=window_size_left if window_size_left is not None and not causal else None,
    )

    dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())
    dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())
    dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        g: g_gpu,
        stats: lse,
        dq: dq_gpu,
        dk: dk_gpu,
        dv: dv_gpu,
    }

    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return dq_gpu, dk_gpu, dv_gpu

    return run
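
# Note: the backward graph consumes the forward softmax stats (logsumexp, shape
# (b, nheads, seqlen_q, 1)). The benchmark below feeds random stats, which is fine
# for timing but would not give meaningful gradients.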


torch.manual_seed(0)
repeats = 10
dropout_p = 0.0
causal = False
dtype = torch.bfloat16
# dtype = torch.float8_e4m3fn
dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
device = 'cuda'
verbose = True
varlen = False
has_backward = False
page_size = None
softcap = 0.0
V_colmajor = False
deterministic = False
batch_size = 2
# seqlen = 2048
seqlen = 8192
# seqlen = 4096
# seqlen = 2047
dim = 2048
# headdim = 128
# headdim = 64
headdim = 256
# for headdim in [64, 128, 256]:
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
# bs_seqlen_vals = [(16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
# bs_seqlen_vals = [(32, 512), (16, 1024)]
# bs_seqlen_vals = [(2, 64 * 132)]
bs_seqlen_vals = [(4, 8192)]
# bs_seqlen_vals = [(1, 16 * 1024)]
time_f = {}
time_b = {}

# for headdim in [64, 128, 256]:
# for headdim in [64, 96, 128, 192]:
# for headdim in [64, 96, 128, 192, 256]:
# for headdim in [64, 96, 128]:
# for headdim in [64, 128, 256]:
# for headdim in [64, 96, 128, 192, 256]:
for headdim in [128]:
    nheads = dim // headdim
    # nheads = 128
    # headdim = 64
    # batch_size = 64
    # seqlen = 512
    # nheads = 8
    # headdim = 128
    nheads_kv = nheads
    # nheads_kv = nheads // 4
    # nheads_kv = 1
    headdim_v = headdim
    # headdim_v = 512
    has_qv = headdim == 64 and headdim_v == 512
    # has_qv = False

    for batch_size, seqlen in bs_seqlen_vals:
        num_splits = 0
        # window_size = (-1, -1)
        window_size = (None, None)
        window_size_fa = (-1, -1)
        # window_size = (seqlen // 2 - 1, 0)
        pack_gqa = None
        # seqlen_q = 64
        seqlen_q = seqlen
        leftpad_k = None
        # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32)
        q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward)
        k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward)
        v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward)
        q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]]
        # Same logical layout as v, but with the seqlen dimension contiguous in memory (column-major for the kernel).
        v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_(has_backward)
        v_fa3 = v if not V_colmajor else v_colmajor
        qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None
        # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
        # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
        # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype)
        g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen)
        o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen)
        stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32)
        if varlen:
            q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]]
            cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q
            cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen
            # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32)
            # q_unpad = q_unpad[:256]
            # seqlen_q = 256
            # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32)
            # q_unpad = q_unpad[:384]
            # seqlen_q = 384
        if page_size is not None:
            assert seqlen % page_size == 0
            k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]]
            page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),
                                   "(b s) -> b s", s=seqlen // page_size)
        else:
            page_table = None
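
        # Paged-KV sketch (hypothetical numbers): with batch=2, seqlen=8192, page_size=256,
        # K/V become (64, 256, h, d) page pools and page_table has shape (2, 32), each row
        # listing the page ids that make up that sequence.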

        for causal in [False, True]:
        # for causal in [False]:
            print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
            nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
            if cudnn is not None:
            # if False:
                if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
                    cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
                    cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
            if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
            # if False:
                # FAv2 uses the (-1, -1) "no window" convention, hence window_size_fa rather than window_size.
                if not varlen:
                    m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
                else:
                    m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
                time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean
                if has_backward:
                    time.sleep(1)
                    if not varlen:
                        _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic,
                                                    repeats=repeats, verbose=False, desc='Fav2')
                    else:
                        _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic,
                                                    repeats=repeats, verbose=False, desc='Fav2')
                    time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean
                # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)
| 339 | + |
            if cudnn is not None:
            # if False:
                if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
                    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                    m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
                    time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean
                    time.sleep(1)
                    m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                    time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean
                    # pytorch_profiler(cudnn_spda, backward=False)
                    # pytorch_profiler(cudnn_spda_bwd, backward=False)
            time.sleep(1)
            if flash_attn_func_v3 is not None:
                if not varlen:
                    # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad=leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                    m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                    # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
                else:
                    m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                    # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
                time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
            if flash_attn_func_python is not None:
                # Kept outside the FA3 guard so m1_py is always defined when it is printed below.
                m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='FA python')
            if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward:
                time.sleep(1)
                if not varlen:
                    _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3')
                else:
                    _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=False, desc='Fav3')
                time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean
                # time.sleep(1)
                # if not varlen:
                #     pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True)
                # else:
                #     pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
            # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')
            if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward:
                _, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='FA python')
| 379 | + |
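            # Backward-pass TFLOPS below use 2.5x the forward FLOP count: the forward does
            # 2 matmuls, while the backward does 5 (dV, dP, dQ, dK plus recomputing QK^T).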
            if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
            # if False:
                print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
                if has_backward:
                    print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
            # Guard matches the conditions under which m2/m2b were actually measured above.
            if cudnn is not None and headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
                print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
                if has_backward:
                    print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
            if flash_attn_func_v3 is not None:
                print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
                if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
                    print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')

            if flash_attn_func_python is not None:
                print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS')
                if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
                    print(f'FA Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS')