
Commit e41ac58

minor: fix structured output generation
1 parent 8d93486 commit e41ac58

File tree

9 files changed: +529 −123 lines changed


scratchpad/constrained/base_backend.py

Lines changed: 52 additions & 32 deletions
@@ -16,46 +16,66 @@ class BaseGrammarObject:
     pass
 
 
+INVALID_GRAMMAR_OBJ: BaseGrammarObject = BaseGrammarObject()
+
+
 class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
-        self.cache = {}
-        self.cache_lock = Lock()
-
-    def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
-        with self.cache_lock:
-            if key in self.cache:
-                cache_hit = True
-                entry = self.cache[key]
-            else:
-                cache_hit = False
-                entry = CacheEntry(None, Event())
-                self.cache[key] = entry
-
-        if cache_hit:
-            entry.event.wait()
+        self.cache: Dict[Tuple[str, str], CacheEntry] = {}
+
+    def _not_supported(self, key_type: str, key_string: str) -> None:
+        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")
+
+    def dispatch_fallback(
+        self, key_type: str, key_string: str
+    ) -> Optional[BaseGrammarObject]:
+        """
+        This function should not be reached in any case.
+        """
+        raise ValueError(f"Invalid key_type: {key_type}={key_string}")
+
+    def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
+        return self._not_supported("json", key_string)
+
+    def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
+        return self._not_supported("regex", key_string)
+
+    def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
+        return self._not_supported("ebnf", key_string)
+
+    def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
+        return self._not_supported("structural_tag", key_string)
+
+    def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
+        key_type, key_string = key
+        if key_type == "json":
+            return self.dispatch_json(key_string)
+        elif key_type == "regex":
+            return self.dispatch_regex(key_string)
+        elif key_type == "ebnf":
+            return self.dispatch_ebnf(key_string)
+        elif key_type == "structural_tag":
+            return self.dispatch_structural_tag(key_string)
+        elif key_type == "structural_pattern":
+            return self.dispatch_structural_pattern(key_string)
         else:
-            entry.value = self.init_value_impl(key)
-            entry.event.set()
-        return entry.value.copy() if entry.value else None
-
-    def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
-        raise NotImplementedError()
+            return self.dispatch_fallback(key_type, key_string)
 
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            entry = self.cache.get(key)
-            if not entry or not entry.event.is_set():
-                return None
-            val = self.cache[key].value
-            return val.copy() if val else None
+    def get_cached_or_future_value(
+        self, key: Tuple[str, str]
+    ) -> Optional[BaseGrammarObject]:
+        value = self.cache.get(key)
+        if value:
+            return value.copy(), True
+        value = self.executor.submit(self._init_value_dispatch, key)
+        return value, False
 
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self.init_value, key)
+    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
+        self.cache[key] = value
 
     def reset(self):
-        with self.cache_lock:
-            self.cache.clear()
+        self.cache.clear()
 
 
 def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
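
For reference, a minimal caller-side sketch (not part of this commit) of how the reworked cache-or-future API above could be used. It relies only on the methods shown in this diff and assumes the module is importable from the file path in the header; a real deployment would use a concrete backend subclass rather than the base class.

from concurrent.futures import Future
from scratchpad.constrained.base_backend import BaseGrammarBackend

backend = BaseGrammarBackend()  # assumption: stand-in for a concrete subclass
key = ("json", '{"type": "object"}')

value, cache_hit = backend.get_cached_or_future_value(key)
if not cache_hit:
    # Compilation was submitted to the backend's thread pool.
    assert isinstance(value, Future)
    grammar = value.result()  # None here, since the base class skips unsupported keys
    if grammar is not None:
        backend.set_cache(key, grammar)  # publish the compiled grammar for later requests
else:
    grammar = value  # already a copy of the cached grammar object
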
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+# Adapt from
+# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+
+from typing import List, Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+from scratchpad.utils import get_device_core_count
+
+
+@triton.jit
+def apply_token_bitmask_inplace_kernel(
+    logits_ptr,
+    bitmask_ptr,
+    indices_ptr,
+    num_rows,
+    vocab_size,
+    logits_strides,
+    bitmask_strides,
+    NUM_SMS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
+    where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
+    the masked logits will be set to -inf.
+
+    Parameters
+    ----------
+    logits_ptr : tl.tensor
+        Pointer to the logits tensor to apply the bitmask to.
+
+    bitmask_ptr : tl.tensor
+        Pointer to the bitmask tensor to apply.
+
+    indices_ptr : Optional[tl.tensor]
+        Optional pointer to indices tensor specifying which rows to apply the mask to.
+
+    num_rows : int
+        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
+
+    vocab_size : int
+        Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
+        same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.
+
+    logits_strides : int
+        Stride between rows in the logits tensor.
+
+    bitmask_strides : int
+        Stride between rows in the bitmask tensor.
+
+    NUM_SMS : int
+        Number of streaming multiprocessors to use.
+
+    BLOCK_SIZE : int
+        Size of processing blocks.
+    """
+
+    pid = tl.program_id(0)
+    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
+    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
+        row_id = work_id // num_blocks
+        block_offset = (work_id % num_blocks) * BLOCK_SIZE
+        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
+        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
+        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
+        vocab_mask = offsets < vocab_size
+        packed_bitmask_mask = bitmask_offsets < bitmask_strides
+        packed_bitmask = tl.load(
+            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
+            packed_bitmask_mask,
+        )
+        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
+        bitmask = bitmask.reshape(BLOCK_SIZE)
+
+        tl.store(
+            logits_ptr + batch_id * logits_strides + offsets,
+            -float("inf"),
+            vocab_mask & bitmask,
+        )
+
+
+def apply_token_bitmask_inplace_triton(
+    logits: torch.Tensor,
+    bitmask: torch.Tensor,
+    indices: Optional[Union[List[int], torch.Tensor]] = None,
+):
+    NUM_SMS = get_device_core_count()
+    BLOCK_SIZE = 4096
+    BITS_PER_BLOCK = 32
+
+    # Check input dtype
+    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"
+
+    # Check input tensor shapes.
+    logits_shape = logits.shape
+    bitmask_shape = bitmask.shape
+    if logits.ndim == 1:
+        logits_shape = (1, logits_shape[0])
+    if bitmask.ndim == 1:
+        bitmask_shape = (1, bitmask_shape[0])
+
+    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
+    assert required_bitmask_width >= bitmask_shape[1], (
+        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
+        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
+    )
+
+    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)
+
+    num_rows = None
+    if isinstance(indices, list) or isinstance(indices, torch.Tensor):
+        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
+        num_rows = indices.shape[0]
+    else:
+        assert (
+            logits_shape[0] == bitmask_shape[0]
+        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
+        num_rows = logits_shape[0]
+
+    if NUM_SMS > 0:
+        grid = (NUM_SMS,)
+    else:
+        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+        grid = (num_rows * num_blocks,)
+        NUM_SMS = triton.next_power_of_2(grid[0])
+
+    apply_token_bitmask_inplace_kernel[grid](
+        logits,
+        bitmask,
+        indices,
+        num_rows,
+        vocab_size,
+        logits_shape[1],
+        bitmask_shape[1],
+        NUM_SMS,
+        BLOCK_SIZE,
+        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
+        num_stages=3,
+    )
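
Below is a small, hypothetical usage sketch (not part of the commit) illustrating the bitmask convention the kernel expects: each int32 packs 32 consecutive token ids, least-significant bit first, and a cleared bit drives the corresponding logit to -inf. It assumes a CUDA device and that apply_token_bitmask_inplace_triton from the new file above is in scope (the file's import path is not shown in this view).

import torch

vocab_size = 128
logits = torch.randn(2, vocab_size, device="cuda")

# Packed allow-list bitmask: bit = 1 keeps a token, bit = 0 masks its logit to -inf.
allowed = [3, 7, 42]
bitmask = torch.zeros(2, (vocab_size + 31) // 32, dtype=torch.int32, device="cuda")
for tok in allowed:
    bitmask[:, tok // 32] |= 1 << (tok % 32)

apply_token_bitmask_inplace_triton(logits, bitmask)
# After the call, every logit outside `allowed` is -inf, so constrained sampling
# can only pick tokens permitted by the grammar.
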
