12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | | - |
16 | 15 | import argparse |
| 16 | +from itertools import product |
| 17 | +from typing import Any |
17 | 18 |
18 | 19 | import torch |
| 20 | +from tabulate import tabulate |
| 21 | +from tqdm import tqdm |
19 | 22 | from triton.testing import do_bench |
20 | 23 |
21 | | -from xgrammar.kernels import apply_token_bitmask_inplace_kernels |
| 24 | +from xgrammar.kernels.apply_token_bitmask_inplace_cuda import apply_token_bitmask_inplace_cuda |
| 25 | +from xgrammar.kernels.apply_token_bitmask_inplace_torch_compile import ( |
| 26 | + apply_token_bitmask_inplace_torch_compile, |
| 27 | +) |
| 28 | +from xgrammar.kernels.apply_token_bitmask_inplace_triton import apply_token_bitmask_inplace_triton |
22 | 29 | from xgrammar.testing import _bool_mask_to_bitmask |
23 | 30 |
24 | | -if __name__ == "__main__": |
25 | | - parser = argparse.ArgumentParser() |
26 | | - parser.add_argument("--impl", type=str, choices=["cuda", "triton"], default="cuda") |
27 | | - parser.add_argument("--batch_size", type=int, default=4096) |
28 | | - parser.add_argument("--vocab_size", type=int, default=128000) |
29 | | - parser.add_argument("--masked_cnt", type=int, default=1024) |
30 | | - parser.add_argument("--stride", type=int, default=1) |
31 | | - parser.add_argument( |
32 | | - "--logits_dtype", type=str, choices=["float32", "float16", "bfloat16"], default="float32" |
33 | | - ) |
34 | | - parser.add_argument("--warmup", type=int, default=500) |
35 | | - parser.add_argument("--rep", type=int, default=2000) |
36 | | - args = parser.parse_args() |
| 31 | +IMPL_TORCH_COMPILE: str = "Torch Compile" |
| 32 | +IMPL_TRITON: str = "Triton" |
| 33 | +IMPL_CUDA: str = "CUDA" |
| 34 | + |
| 35 | +ALL_IMPLS: list[str] = [IMPL_TORCH_COMPILE, IMPL_TRITON, IMPL_CUDA] |
| 36 | + |
| 37 | + |
| 38 | +def bench_single_impl( |
| 39 | + impl: str, |
| 40 | + logits: torch.Tensor, |
| 41 | + bitmask: torch.Tensor, |
| 42 | + logits_expected: torch.Tensor, |
| 43 | + kwargs: dict[str, Any], |
| 44 | + args: argparse.Namespace, |
| 45 | +) -> float: |
| 46 | + if impl == IMPL_TORCH_COMPILE: |
| 47 | + f = lambda: apply_token_bitmask_inplace_torch_compile(logits, bitmask, **kwargs) |
| 48 | + elif impl == IMPL_TRITON: |
| 49 | + f = lambda: apply_token_bitmask_inplace_triton(logits, bitmask, **kwargs) |
| 50 | + else: |
| 51 | + f = lambda: apply_token_bitmask_inplace_cuda(logits, bitmask, **kwargs) |
| 52 | + |
| 53 | + f()  # run once and verify correctness against the reference before timing
| 54 | + torch.testing.assert_close(logits, logits_expected.to("cuda")) |
| 55 | + |
| 56 | + torch.cuda.synchronize() |
| 57 | + exec_time = do_bench(f, warmup=args.warmup, rep=args.rep) |
| 58 | + return exec_time * 1000  # do_bench reports milliseconds; convert to microseconds
37 | 59 |
| 60 | + |
| 61 | +def bench_single_setup(batch_size: int, masked_cnt: int, args: argparse.Namespace) -> list[float]: |
38 | 62 | vocab_size = args.vocab_size |
39 | | - batch_size = args.batch_size |
40 | | - bitmask_size = (vocab_size + 32 - 1) // 32 |
41 | | - masked_cnt = args.masked_cnt |
42 | 63 | stride = args.stride |
43 | 64 | logits_dtype = getattr(torch, args.logits_dtype) |
44 | | - |
45 | 65 | logits = torch.randn(batch_size, vocab_size, dtype=logits_dtype, device="cuda") |
46 | | - |
47 | 66 | if masked_cnt >= vocab_size: |
48 | 67 | bool_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool, device="cuda") |
49 | 68 | else: |
55 | 74 | bool_mask.scatter_(1, masked_positions, False) |
56 | 75 | assert (bool_mask.sum(dim=-1) + masked_cnt == vocab_size).all().item() |
57 | 76 | bitmask = _bool_mask_to_bitmask(bool_mask) |
58 | | - |
59 | 77 | masked_batch_ids = torch.arange(0, batch_size, stride, dtype=torch.int32, device="cuda") |
60 | 78 | kwargs = {} if stride == 1 else {"indices": masked_batch_ids} |
61 | 79 |
62 | | - logits_expected = logits.clone() |
63 | | - logits_expected[masked_batch_ids] = torch.masked_fill( |
64 | | - logits_expected[masked_batch_ids], ~bool_mask[masked_batch_ids], float("-inf") |
| 80 | + logits_copies = [logits.clone() for _ in range(len(args.impl))]  # fresh copy per implementation; logits itself becomes the masked reference below
| 81 | + logits[masked_batch_ids] = torch.masked_fill( |
| 82 | + logits[masked_batch_ids], ~bool_mask[masked_batch_ids], float("-inf") |
65 | 83 | ) |
| 84 | + return [ |
| 85 | + bench_single_impl(impl, logits_copy, bitmask, logits, kwargs, args) |
| 86 | + for impl, logits_copy in zip(args.impl, logits_copies) |
| 87 | + ] |
66 | 88 |
67 | | - if args.impl == "cuda": |
68 | | - if "cuda" not in apply_token_bitmask_inplace_kernels: |
69 | | - raise ImportError("CUDA is not installed") |
70 | | - f = lambda: apply_token_bitmask_inplace_kernels["cuda"](logits, bitmask, **kwargs) |
71 | | - elif args.impl == "triton": |
72 | | - if "triton" not in apply_token_bitmask_inplace_kernels: |
73 | | - raise ImportError("Triton is not installed") |
74 | | - f = lambda: apply_token_bitmask_inplace_kernels["triton"](logits, bitmask, **kwargs) |
75 | 89 |
76 | | - f() |
77 | | - torch.testing.assert_close(logits, logits_expected.to("cuda")) |
| 90 | +if __name__ == "__main__": |
| 91 | + parser = argparse.ArgumentParser() |
| 92 | + parser.add_argument( |
| 93 | + "--impl", type=str, nargs="*", choices=ALL_IMPLS, default=[IMPL_TORCH_COMPILE, IMPL_TRITON] |
| 94 | + ) |
| 95 | + parser.add_argument("--batch-size", type=int, nargs="*", default=[1, 8, 64, 512, 4096]) |
| 96 | + parser.add_argument("--vocab-size", type=int, default=128000) |
| 97 | + parser.add_argument("--masked-cnt", type=int, nargs="*", default=[1, 64000, 127000]) |
| 98 | + parser.add_argument("--stride", type=int, default=1) |
| 99 | + parser.add_argument( |
| 100 | + "--logits_dtype", type=str, choices=["float32", "float16", "bfloat16"], default="float32" |
| 101 | + ) |
| 102 | + parser.add_argument("--warmup", type=int, default=500) |
| 103 | + parser.add_argument("--rep", type=int, default=2000) |
| 104 | + args = parser.parse_args() |
78 | 105 |
79 | | - torch.cuda.synchronize() |
80 | | - exec_time = do_bench(f, warmup=args.warmup, rep=args.rep) |
81 | | - exec_time *= 10**3 |
| 106 | + data_rows = [] |
| 107 | + for batch_size, masked_cnt in tqdm(list(product(args.batch_size, args.masked_cnt))): |
| 108 | + all_us = bench_single_setup(batch_size, masked_cnt, args) |
| 109 | + data_rows.append( |
| 110 | + [ |
| 111 | + batch_size, |
| 112 | + args.vocab_size, |
| 113 | + masked_cnt, |
| 114 | + f"{all_us[0]:.2f}", |
| 115 | + *[f"{us:.2f} ({all_us[0]/us:>4.2f}x)" for us in all_us[1:]], |
| 116 | + ] |
| 117 | + ) |
82 | 118 |
83 | | - print(f"Implementation: {args.impl}\t| Execution time (μs): {exec_time:.4f}") |
| 119 | + print( |
| 120 | + tabulate( |
| 121 | + data_rows, |
| 122 | + headers=[ |
| 123 | + "Batch\nsize", |
| 124 | + "Vocab\nsize", |
| 125 | + "Masked cnt", |
| 126 | + f"{args.impl[0]}\nBaseline us", |
| 127 | + *[f"{impl} \nus (speedup)" for impl in args.impl[1:]], |
| 128 | + ], |
| 129 | + tablefmt="pipe", |
| 130 | + floatfmt=".2f", |
| 131 | + colalign=["right"] * len(data_rows[0]), |
| 132 | + ) |
| 133 | + ) |
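
Example usage, assuming the benchmark is saved as bench_apply_token_bitmask_inplace.py (the filename and flag values here are illustrative, not part of the diff):

    python bench_apply_token_bitmask_inplace.py --impl Triton CUDA --batch-size 64 4096 --masked-cnt 1 64000 --logits_dtype bfloat16

When more than one implementation is selected, the first acts as the baseline: its column reports execution time in μs, and each remaining column reports that implementation's time together with its speedup relative to the baseline.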