# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py

from typing import List, Optional, Union

import torch
import triton
import triton.language as tl

from scratchpad.utils import get_device_core_count


@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply a token bitmask to logits in-place using Triton. The bitmask is a bit-packed
    tensor in which each bit covers one token: 0 means the token is masked, 1 means it is
    allowed. After applying the bitmask, the logits of masked tokens are set to -inf.

    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits tensor to apply the bitmask to.

    bitmask_ptr : tl.tensor
        Pointer to the bitmask tensor to apply.

    indices_ptr : Optional[tl.tensor]
        Optional pointer to an indices tensor specifying which rows to apply the mask to.

    num_rows : int
        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.

    vocab_size : int
        Size of the vocabulary dimension. If the logits tensor has no vocabulary padding,
        this equals its second dimension; otherwise it is the actual (unpadded) vocabulary size.

    logits_strides : int
        Stride between rows in the logits tensor.

    bitmask_strides : int
        Stride between rows in the bitmask tensor.

    NUM_SMS : int
        Number of streaming multiprocessors to use (one persistent program per SM).

    BLOCK_SIZE : int
        Number of vocabulary entries processed per block.
    """
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    # Persistent loop: each program strides over (row, block) work items, so the grid
    # can stay fixed at NUM_SMS regardless of batch size and vocabulary size.
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
            packed_bitmask_mask,
        )
        # Unpack 32 tokens per int32: a cleared bit marks a masked token.
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        # Write -inf only where the token is within the vocabulary and its bit is cleared.
        tl.store(
            logits_ptr + batch_id * logits_strides + offsets,
            -float("inf"),
            vocab_mask & bitmask,
        )

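
# A minimal pure-PyTorch reference of the same masking rule, useful for unit-testing the
# Triton kernel. This helper is not part of the original file: the name
# `_apply_token_bitmask_reference` is hypothetical, and it assumes 2D inputs with the same
# bit-packing convention (32 tokens per int32, least-significant bit first). Unlike the
# kernel, it returns a new tensor instead of modifying the logits in place.
def _apply_token_bitmask_reference(
    logits: torch.Tensor, bitmask: torch.Tensor
) -> torch.Tensor:
    vocab_size = logits.shape[-1]
    # Unpack each int32 into 32 bits, least-significant bit first.
    shifts = torch.arange(32, device=bitmask.device, dtype=torch.int32)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1  # (batch, ceil(vocab/32), 32)
    keep = bits.reshape(bitmask.shape[0], -1)[:, :vocab_size].bool()
    # A cleared bit means the token is masked, so fill those positions with -inf.
    return logits.masked_fill(~keep, float("-inf"))
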

def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32

    # Check input dtype.
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"

    # Check input tensor shapes.
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])

    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: at most {required_bitmask_width} int32s are allowed for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )

    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)

    if isinstance(indices, (list, torch.Tensor)):
        # as_tensor avoids an extra copy (and a warning) when indices is already a tensor
        # with the right dtype and device.
        indices = torch.as_tensor(indices, dtype=torch.int32, device=logits.device)
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]

    # Launch one persistent program per SM when the core count is known; otherwise fall
    # back to one program per (row, block) work item.
    if NUM_SMS > 0:
        grid = (NUM_SMS,)
    else:
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])

    # The second dimensions serve as row strides, which assumes row-contiguous tensors.
    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
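

# Hedged usage sketch, not part of the original file: it builds a tiny batch, packs a
# bitmask that allows only the first 8 tokens of each row, and checks the in-place Triton
# result against the pure-PyTorch reference above. Assumes a CUDA device is available.
if __name__ == "__main__":
    batch, vocab = 2, 100
    logits = torch.randn(batch, vocab, device="cuda", dtype=torch.float32)
    # One int32 packs 32 token bits; setting the low 8 bits allows tokens 0..7.
    bitmask = torch.zeros(batch, (vocab + 31) // 32, device="cuda", dtype=torch.int32)
    bitmask[:, 0] = (1 << 8) - 1
    expected = _apply_token_bitmask_reference(logits, bitmask)
    apply_token_bitmask_inplace_triton(logits, bitmask)
    torch.testing.assert_close(logits, expected)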