
Commit 591dff9

Fix and improve apply_token_bitmask benchmark script (#391)
Currently, the bench script is not runnable (the import `from xgrammar.kernels import apply_token_bitmask_inplace_kernels` is not found).

# Change

- Update the script to make it runnable
- Kick off multiple setups in a single run, so a benchmark report can be created in one shot

# Usage

```bash
(xgrammar) Fri Aug 08 22:18:25 [/data/users/jialino/gitrepos/xgrammar] python3 examples/benchmark/bench_apply_token_bitmask_inplace.py
Running cmake --build & --install in /data/users/jialino/gitrepos/xgrammar/build
ninja: no work to do.
-- Install configuration: "RelWithDebInfo"
-- Up-to-date: /home/jialino/uv_env/xgrammar/lib64/python3.12/site-packages/xgrammar/./xgrammar_bindings.cpython-312-x86_64-linux-gnu.so
W0808 22:18:51.578000 320509 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0808 22:18:51.578000 320509 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [01:13<00:00, 4.92s/it]
| Batch | Vocab | Masked cnt | Torch Compile | Triton |
| size | size | | Baseline us | us (speedup) |
|--------:|--------:|-------------:|----------------:|----------------:|
| 1 | 128000 | 1 | 6.04 | 5.52 (1.09x) |
| 1 | 128000 | 64000 | 5.96 | 6.16 (0.97x) |
| 1 | 128000 | 127000 | 6.01 | 6.27 (0.96x) |
| 8 | 128000 | 1 | 10.90 | 6.04 (1.81x) |
| 8 | 128000 | 64000 | 10.90 | 7.76 (1.40x) |
| 8 | 128000 | 127000 | 10.91 | 8.02 (1.36x) |
| 64 | 128000 | 1 | 48.72 | 13.36 (3.65x) |
| 64 | 128000 | 64000 | 48.74 | 46.35 (1.05x) |
| 64 | 128000 | 127000 | 48.74 | 33.26 (1.47x) |
| 512 | 128000 | 1 | 350.11 | 67.43 (5.19x) |
| 512 | 128000 | 64000 | 347.57 | 330.76 (1.05x) |
| 512 | 128000 | 127000 | 345.73 | 250.06 (1.38x) |
| 4096 | 128000 | 1 | 2903.81 | 494.67 (5.87x) |
| 4096 | 128000 | 64000 | 2855.70 | 2516.79 (1.13x) |
| 4096 | 128000 | 127000 | 2720.98 | 1936.44 (1.41x) |
```

Signed-off-by: Jialin Ouyang <[email protected]>
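For context, a minimal sketch of the masking step the benchmark times, based on the imports and helpers the updated script uses (assumes a CUDA device and an xgrammar build with the Triton kernel available; the call shapes mirror how the script invokes the kernels):

```python
# Minimal sketch: build a token bitmask and apply it in place with the Triton kernel.
import torch

from xgrammar.kernels.apply_token_bitmask_inplace_triton import apply_token_bitmask_inplace_triton
from xgrammar.testing import _bool_mask_to_bitmask

batch_size, vocab_size = 8, 128000
logits = torch.randn(batch_size, vocab_size, device="cuda")

# Allow every token except id 0 in each row.
bool_mask = torch.ones(batch_size, vocab_size, dtype=torch.bool, device="cuda")
bool_mask[:, 0] = False

# Pack the boolean mask into int32 words (32 tokens per word), as the script does.
bitmask = _bool_mask_to_bitmask(bool_mask)

# Disallowed logits are set to -inf in place; an optional `indices` argument
# restricts masking to selected batch rows (what the script's --stride exercises).
apply_token_bitmask_inplace_triton(logits, bitmask)
assert torch.isinf(logits[:, 0]).all()
```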
1 parent 77b1888 commit 591dff9

File tree

2 files changed: +108 -76 lines changed


examples/benchmark/README.md

Lines changed: 20 additions & 38 deletions
@@ -21,44 +21,26 @@ python3 bench_grammar_compile_mask_gen.py [-h] [--backend {xgrammar,outlines,lmf
 
 #### Run
 ```bash
-python3 bench_apply_token_bitmask_inplace.py [-h] [--impl {cuda,triton}]
-                                             [--batch_size BATCH_SIZE] [--vocab_size VOCAB_SIZE]
-                                             [--masked_cnt MASKED_CNT] [--stride STRIDE]
-                                             [--logits_dtype {float32,float16,bfloat16}]
-                                             [--warmup WARMUP] [--rep REP]
+python3 examples/benchmark/bench_apply_token_bitmask_inplace.py
 ```
 
 #### Results
-
-| GPU | Batch size | Vocab size | Masked cnt | Triton (μs) | CUDA (μs) | Speedup |
-|:--------------:|-----------:|-----------:|-----------:|-------------:|----------:|--------:|
-| H100 80GB HBM3 | 1 | 128k | 1k | 5.95 | 6.57 | 0.91x |
-| | 1 | 128k | 64k | 6.38 | 6.46 | 0.99x |
-| | 1 | 128k | 127k | 6.69 | 6.48 | 1.03x |
-| | 8 | 128k | 1k | 6.77 | 6.94 | 0.98x |
-| | 8 | 128k | 64k | 8.05 | 9.19 | 0.88x |
-| | 8 | 128k | 127k | 8.49 | 8.08 | 1.05x |
-| | 64 | 128k | 1k | 14.97 | 13.82 | 1.08x |
-| | 64 | 128k | 64k | 43.13 | 30.98 | 1.39x |
-| | 64 | 128k | 127k | 33.85 | 21.43 | 1.58x |
-| | 512 | 128k | 1k | 82.65 | 61.13 | 1.35x |
-| | 512 | 128k | 64k | 293.51 | 194.06 | 1.51x |
-| | 512 | 128k | 127k | 240.11 | 119.77 | 2.00x |
-| | 4096 | 128k | 1k | 566.17 | 417.33 | 1.36x |
-| | 4096 | 128k | 64k | 2198.59 | 1491.79 | 1.47x |
-| | 4096 | 128k | 127k | 1812.39 | 897.17 | 2.02x |
-| A100 SXM4 80GB | 1 | 128k | 1k | 8.32 | 7.97 | 1.04x |
-| | 1 | 128k | 64k | 9.26 | 8.24 | 1.12x |
-| | 1 | 128k | 127k | 8.81 | 8.71 | 1.01x |
-| | 8 | 128k | 1k | 9.56 | 10.31 | 0.93x |
-| | 8 | 128k | 64k | 12.72 | 13.22 | 0.96x |
-| | 8 | 128k | 127k | 13.45 | 11.27 | 1.19x |
-| | 64 | 128k | 1k | 22.95 | 25.57 | 0.90x |
-| | 64 | 128k | 64k | 58.52 | 56.47 | 1.04x |
-| | 64 | 128k | 127k | 44.83 | 39.29 | 1.14x |
-| | 512 | 128k | 1k | 132.92 | 108.60 | 1.22x |
-| | 512 | 128k | 64k | 362.08 | 349.54 | 1.04x |
-| | 512 | 128k | 127k | 306.75 | 233.20 | 1.32x |
-| | 4096 | 128k | 1k | 955.99 | 777.94 | 1.23x |
-| | 4096 | 128k | 64k | 2756.63 | 2707.57 | 1.02x |
-| | 4096 | 128k | 127k | 2472.82 | 1782.41 | 1.39x |
+H100
+| Batch | Vocab | Masked cnt | Torch Compile | Triton |
+| size | size | | Baseline us | us (speedup) |
+|--------:|--------:|-------------:|----------------:|----------------:|
+| 1 | 128000 | 1 | 6.04 | 5.52 (1.09x) |
+| 1 | 128000 | 64000 | 5.96 | 6.16 (0.97x) |
+| 1 | 128000 | 127000 | 6.01 | 6.27 (0.96x) |
+| 8 | 128000 | 1 | 10.90 | 6.04 (1.81x) |
+| 8 | 128000 | 64000 | 10.90 | 7.76 (1.40x) |
+| 8 | 128000 | 127000 | 10.91 | 8.02 (1.36x) |
+| 64 | 128000 | 1 | 48.72 | 13.36 (3.65x) |
+| 64 | 128000 | 64000 | 48.74 | 46.35 (1.05x) |
+| 64 | 128000 | 127000 | 48.74 | 33.26 (1.47x) |
+| 512 | 128000 | 1 | 350.11 | 67.43 (5.19x) |
+| 512 | 128000 | 64000 | 347.57 | 330.76 (1.05x) |
+| 512 | 128000 | 127000 | 345.73 | 250.06 (1.38x) |
+| 4096 | 128000 | 1 | 2903.81 | 494.67 (5.87x) |
+| 4096 | 128000 | 64000 | 2855.70 | 2516.79 (1.13x) |
+| 4096 | 128000 | 127000 | 2720.98 | 1936.44 (1.41x) |
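For reference, the speedup shown in parentheses divides the first (baseline) implementation's time by each other implementation's time. A minimal sketch of that row formatting, reusing the format string from the updated script with example values from the table above:

```python
# One report row: baseline time first, then each other implementation's time
# annotated with speedup = baseline_us / impl_us.
all_us = [48.72, 13.36]  # batch 64, masked cnt 1: Torch Compile baseline, Triton
row = [f"{all_us[0]:.2f}", *[f"{us:.2f} ({all_us[0] / us:>4.2f}x)" for us in all_us[1:]]]
print(row)  # ['48.72', '13.36 (3.65x)']
```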

examples/benchmark/bench_apply_token_bitmask_inplace.py

Lines changed: 88 additions & 38 deletions
@@ -12,38 +12,57 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import argparse
+from itertools import product
+from typing import Any
 
 import torch
+from tabulate import tabulate
+from tqdm import tqdm
 from triton.testing import do_bench
 
-from xgrammar.kernels import apply_token_bitmask_inplace_kernels
+from xgrammar.kernels.apply_token_bitmask_inplace_cuda import apply_token_bitmask_inplace_cuda
+from xgrammar.kernels.apply_token_bitmask_inplace_torch_compile import (
+    apply_token_bitmask_inplace_torch_compile,
+)
+from xgrammar.kernels.apply_token_bitmask_inplace_triton import apply_token_bitmask_inplace_triton
 from xgrammar.testing import _bool_mask_to_bitmask
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--impl", type=str, choices=["cuda", "triton"], default="cuda")
-    parser.add_argument("--batch_size", type=int, default=4096)
-    parser.add_argument("--vocab_size", type=int, default=128000)
-    parser.add_argument("--masked_cnt", type=int, default=1024)
-    parser.add_argument("--stride", type=int, default=1)
-    parser.add_argument(
-        "--logits_dtype", type=str, choices=["float32", "float16", "bfloat16"], default="float32"
-    )
-    parser.add_argument("--warmup", type=int, default=500)
-    parser.add_argument("--rep", type=int, default=2000)
-    args = parser.parse_args()
+IMPL_TORCH_COMPILE: str = "Torch Compile"
+IMPL_TRITON: str = "Triton"
+IMPL_CUDA: str = "CUDA"
+
+ALL_IMPLS: list[str] = [IMPL_TORCH_COMPILE, IMPL_TRITON, IMPL_CUDA]
+
+
+def bench_single_impl(
+    impl: str,
+    logits: torch.Tensor,
+    bitmask: torch.Tensor,
+    logits_expected: torch.Tensor,
+    kwargs: dict[str, Any],
+    args: argparse.Namespace,
+) -> float:
+    if impl == IMPL_TORCH_COMPILE:
+        f = lambda: apply_token_bitmask_inplace_torch_compile(logits, bitmask, **kwargs)
+    elif impl == IMPL_TRITON:
+        f = lambda: apply_token_bitmask_inplace_triton(logits, bitmask, **kwargs)
+    else:
+        f = lambda: apply_token_bitmask_inplace_cuda(logits, bitmask, **kwargs)
+
+    f()
+    torch.testing.assert_close(logits, logits_expected.to("cuda"))
+
+    torch.cuda.synchronize()
+    exec_time = do_bench(f, warmup=args.warmup, rep=args.rep)
+    return exec_time * 1000
 
+
+def bench_single_setup(batch_size: int, masked_cnt: int, args: argparse.Namespace) -> list[float]:
     vocab_size = args.vocab_size
-    batch_size = args.batch_size
-    bitmask_size = (vocab_size + 32 - 1) // 32
-    masked_cnt = args.masked_cnt
     stride = args.stride
     logits_dtype = getattr(torch, args.logits_dtype)
-
     logits = torch.randn(batch_size, vocab_size, dtype=logits_dtype, device="cuda")
-
     if masked_cnt >= vocab_size:
         bool_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool, device="cuda")
     else:
@@ -55,29 +74,60 @@
         bool_mask.scatter_(1, masked_positions, False)
         assert (bool_mask.sum(dim=-1) + masked_cnt == vocab_size).all().item()
     bitmask = _bool_mask_to_bitmask(bool_mask)
-
     masked_batch_ids = torch.arange(0, batch_size, stride, dtype=torch.int32, device="cuda")
     kwargs = {} if stride == 1 else {"indices": masked_batch_ids}
 
-    logits_expected = logits.clone()
-    logits_expected[masked_batch_ids] = torch.masked_fill(
-        logits_expected[masked_batch_ids], ~bool_mask[masked_batch_ids], float("-inf")
+    logits_copies = [logits.clone() for _ in range(len(args.impl))]
+    logits[masked_batch_ids] = torch.masked_fill(
+        logits[masked_batch_ids], ~bool_mask[masked_batch_ids], float("-inf")
     )
+    return [
+        bench_single_impl(impl, logits_copy, bitmask, logits, kwargs, args)
+        for impl, logits_copy in zip(args.impl, logits_copies)
+    ]
 
-    if args.impl == "cuda":
-        if "cuda" not in apply_token_bitmask_inplace_kernels:
-            raise ImportError("CUDA is not installed")
-        f = lambda: apply_token_bitmask_inplace_kernels["cuda"](logits, bitmask, **kwargs)
-    elif args.impl == "triton":
-        if "triton" not in apply_token_bitmask_inplace_kernels:
-            raise ImportError("Triton is not installed")
-        f = lambda: apply_token_bitmask_inplace_kernels["triton"](logits, bitmask, **kwargs)
 
-    f()
-    torch.testing.assert_close(logits, logits_expected.to("cuda"))
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--impl", type=str, nargs="*", choices=ALL_IMPLS, default=[IMPL_TORCH_COMPILE, IMPL_TRITON]
+    )
+    parser.add_argument("--batch-size", type=int, nargs="*", default=[1, 8, 64, 512, 4096])
+    parser.add_argument("--vocab-size", type=int, default=128000)
+    parser.add_argument("--masked-cnt", type=int, nargs="*", default=[1, 64000, 127000])
+    parser.add_argument("--stride", type=int, default=1)
+    parser.add_argument(
+        "--logits_dtype", type=str, choices=["float32", "float16", "bfloat16"], default="float32"
+    )
+    parser.add_argument("--warmup", type=int, default=500)
+    parser.add_argument("--rep", type=int, default=2000)
+    args = parser.parse_args()
 
-    torch.cuda.synchronize()
-    exec_time = do_bench(f, warmup=args.warmup, rep=args.rep)
-    exec_time *= 10**3
+    data_rows = []
+    for batch_size, masked_cnt in tqdm(list(product(args.batch_size, args.masked_cnt))):
+        all_us = bench_single_setup(batch_size, masked_cnt, args)
+        data_rows.append(
+            [
+                batch_size,
+                args.vocab_size,
+                masked_cnt,
+                f"{all_us[0]:.2f}",
+                *[f"{us:.2f} ({all_us[0]/us:>4.2f}x)" for us in all_us[1:]],
+            ]
+        )
 
-    print(f"Implementation: {args.impl}\t| Execution time (μs): {exec_time:.4f}")
+    print(
+        tabulate(
+            data_rows,
+            headers=[
+                "Batch\nsize",
+                "Vocab\nsize",
+                "Masked cnt",
+                f"{args.impl[0]}\nBaseline us",
+                *[f"{impl} \nus (speedup)" for impl in args.impl[1:]],
+            ],
+            tablefmt="pipe",
+            floatfmt=".2f",
+            colalign=["right"] * len(data_rows[0]),
+        )
+    )
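A note on sizes, assuming `_bool_mask_to_bitmask` packs the mask the same way as the ceil-division formula that appeared in the old script: each int32 word covers 32 tokens, so the default 128000-token vocabulary needs 4000 words per batch row.

```python
# Bitmask width per batch row: ceil(vocab_size / 32) int32 words.
vocab_size = 128000
bitmask_words = (vocab_size + 32 - 1) // 32
assert bitmask_words == 4000
```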
