Commit 604d004

Ahajha and Ubospica authored
Bind fill_next_token_bitmask against nb::ndarray (#338)
Nanobind has built-in support for the DLPack standard, which lets us accept anything adhering to the DLPack spec here rather than just PyTorch objects. See https://nanobind.readthedocs.io/en/latest/ndarray.html

This does add a dependency on numpy, but only for a single type, which does seem a little overkill. If I push on #233 a little, though, these definitions should only be needed in the stub files and become unnecessary in the `.py` files, so ideally this is a temporary dependency.

My goal is to reduce the dependency on PyTorch a bit by making the code more general. I don't know if (or even think that) we can remove it entirely, but this seems worthwhile to do.

The annotation on the `nb::ndarray` class will actually check at runtime that the parameter has the desired properties (in this case, that it's on the CPU and holds int32_t values), and will simply fail to call the function if not. This might be a slight breaking change in terms of which exception actually gets raised, but I think this is reasonable.

A few related changes:

- Converted a few function signatures to take DLTensors by `const&` instead of `*`; I didn't see a good reason for the pointer.
- Added an anonymous namespace in `nanobind.cc`, just to keep things hygienic.

---------

Signed-off-by: Ubospica <[email protected]>
Co-authored-by: Ubospica <[email protected]>
Co-authored-by: Yixin Dong <[email protected]>
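To illustrate the intended effect, here is a sketch (not part of the commit itself; it assumes an already-constructed xgr.GrammarMatcher `matcher` and xgr.TokenizerInfo `tokenizer_info`, built as in the test added below):

import math

import numpy as np

# Bitmask layout used throughout this PR: one int32 word covers 32 tokens.
n_words = math.ceil(tokenizer_info.vocab_size / 32)

# NumPy arrays implement the DLPack protocol, so a plain CPU int32 array
# should now be accepted directly, where previously only torch.Tensor worked.
np_bitmask = np.zeros((1, n_words), dtype=np.int32)
matcher.fill_next_token_bitmask(np_bitmask)

# An array with the wrong device, dtype, or ndim fails the runtime checks in
# the binding and raises RuntimeError.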
1 parent d8906a2 commit 604d004

6 files changed: +78 -38 lines changed

cpp/nanobind/nanobind.cc

Lines changed: 40 additions & 1 deletion
@@ -4,6 +4,7 @@
  */

 #include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/pair.h>
 #include <nanobind/stl/string.h>
@@ -20,7 +21,8 @@
 #include "xgrammar/exception.h"

 namespace nb = nanobind;
-using namespace xgrammar;
+
+namespace xgrammar {

 std::vector<std::string> CommonEncodedVocabType(
     const nb::typed<nb::list, std::variant<std::string, nb::bytes>> encoded_vocab
@@ -39,6 +41,39 @@ std::vector<std::string> CommonEncodedVocabType(
   return encoded_vocab_strs;
 }

+bool GrammarMatcher_FillNextTokenBitmask(
+    GrammarMatcher& matcher, nb::ndarray<> arr, int32_t index, bool debug_print
+) {
+  if (arr.ndim() != 1 && arr.ndim() != 2) {
+    throw std::runtime_error("token_bitmask tensor must be 1D or 2D");
+  }
+
+  // 2. Device: ensure the tensor is on CPU
+  if (arr.device_type() != nb::device::cpu::value) {
+    throw std::runtime_error("token_bitmask array must be on CPU");
+  }
+
+  // 3. Data type: ensure 32-bit integers
+  if (arr.dtype() != nb::dtype<int32_t>()) {
+    throw std::runtime_error("token_bitmask array must be int32");
+  }
+
+  // Under the hood these are stored with the same standard (DLPack), but nanobind
+  // defines its own types, and doesn't expose a way to just get the object directly.
+  // We'll just do some pointer hackery to get there, rather than build the type back up manually:
+
+  // The data in an ndarray is defined as:
+  // detail::ndarray_handle* m_handle = nullptr;
+  // dlpack::dltensor m_dltensor;
+  // Assert this, then skip over m_handle and reinterpret m_dltensor.
+  static_assert(sizeof(arr) == sizeof(void*) + sizeof(nb::dlpack::dltensor));
+
+  DLTensor* bitmask_dltensor_ptr =
+      reinterpret_cast<::DLTensor*>(reinterpret_cast<char*>(&arr) + sizeof(void*));
+
+  return matcher.FillNextTokenBitmask(bitmask_dltensor_ptr, index, debug_print);
+}
+
 std::vector<nanobind::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer) {
   const auto& decoded_vocab = tokenizer.GetDecodedVocab();
   std::vector<nanobind::bytes> py_result;
@@ -55,6 +90,10 @@ static void RegisterRuntimeError(nb::module_& m, const char* name) {
   static_cast<void>(nb::exception<T>{m, name, PyExc_RuntimeError});
 }

+}  // namespace xgrammar
+
+using namespace xgrammar;
+
 NB_MODULE(xgrammar_bindings, m) {
   RegisterRuntimeError<DeserializeFormatError>(m, "DeserializeFormatError");
   RegisterRuntimeError<DeserializeVersionError>(m, "DeserializeVersionError");
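Worth noting about the pointer arithmetic above: the static_assert pins the size assumption (an nb::ndarray is exactly one handle pointer plus a DLPack tensor), so a nanobind release that changes the object's size breaks the build instead of silently reinterpreting garbage. A same-size reordering of the two members would still slip past the assert; that residual risk is the cost of the shortcut until nanobind exposes the underlying DLTensor directly.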

cpp/nanobind/python_methods.cc

Lines changed: 0 additions & 21 deletions
@@ -41,27 +41,6 @@ int TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer) {
   return static_cast<int>(tokenizer.GetVocabType());
 }

-bool GrammarMatcher_FillNextTokenBitmask(
-    GrammarMatcher& matcher,
-    intptr_t token_bitmask_ptr,
-    std::vector<int64_t> shape,
-    int32_t index,
-    bool debug_print
-) {
-  XGRAMMAR_CHECK(shape.size() == 1 || shape.size() == 2) << "token_bitmask tensor must be 1D or 2D";
-
-  DLTensor bitmask_dltensor{
-      reinterpret_cast<void*>(token_bitmask_ptr),
-      DLDevice{kDLCPU, 0},
-      static_cast<int32_t>(shape.size()),
-      GetBitmaskDLType(),
-      shape.data(),
-      nullptr,
-      0
-  };
-  return matcher.FillNextTokenBitmask(&bitmask_dltensor, index, debug_print);
-}
-
 std::vector<int> Testing_DebugGetMaskedTokensFromBitmask(
     intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
 ) {

cpp/nanobind/python_methods.h

Lines changed: 0 additions & 8 deletions
@@ -28,14 +28,6 @@ TokenizerInfo TokenizerInfo_Init(

 int TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer);

-bool GrammarMatcher_FillNextTokenBitmask(
-    GrammarMatcher& matcher,
-    intptr_t token_bitmask_ptr,
-    std::vector<int64_t> shape,
-    int32_t index,
-    bool debug_print
-);
-
 std::vector<int> Testing_DebugGetMaskedTokensFromBitmask(
     intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
 );

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ dependencies = [
     "triton; platform_system == 'Linux' and platform_machine == 'x86_64'",
     "mlx-lm; platform_system == 'Darwin' and platform_machine == 'arm64'",
     "ninja",
+    "numpy",
     "typing-extensions>=4.9.0",
 ]

python/xgrammar/matcher.py

Lines changed: 5 additions & 8 deletions
@@ -7,6 +7,7 @@
 from typing import List, Optional, Tuple, Union

 import torch
+from numpy.typing import ArrayLike

 from .base import XGRObject, _core
 from .compiler import CompiledGrammar
@@ -281,7 +282,7 @@ def accept_string(self, input_str: Union[str, bytes], *, debug_print: bool = Fal
         return self._handle.accept_string(input_str, debug_print)

     def fill_next_token_bitmask(
-        self, bitmask: torch.Tensor, index: int = 0, *, debug_print: bool = False
+        self, bitmask: ArrayLike, index: int = 0, *, debug_print: bool = False
     ) -> bool:
         """Fill the bitmask for the next token prediction. The input bitmask can be generated
         by allocate_token_bitmask, and must be on CPU. bitmask[index] will be filled with the
@@ -309,15 +310,11 @@ def fill_next_token_bitmask(
         Raises
         ------
         RuntimeError
+            If the bitmask is invalid (not on CPU, not int32, shape mismatch).
+
             If the recursion depth is exceeded.
         """
-        if bitmask.device.type != "cpu":
-            raise ValueError("bitmask should be on CPU.")
-        if bitmask.dtype != bitmask_dtype:
-            raise ValueError(f"bitmask should be of type {bitmask_dtype}.")
-        return self._handle.fill_next_token_bitmask(
-            bitmask.data_ptr(), list(bitmask.shape), index, debug_print
-        )
+        return self._handle.fill_next_token_bitmask(bitmask, index, debug_print)

     def find_jump_forward_string(self) -> str:
         """Find the jump-forward string for jump-forward decoding. This is the longest string that

tests/python/test_grammar_matcher_basic.py

Lines changed: 32 additions & 0 deletions
@@ -1,5 +1,6 @@
 """Test the basic functionality of GrammarMatcher."""

+import math
 import sys
 from typing import List, Optional, Union

@@ -15,6 +16,8 @@
     _is_grammar_accept_string,
 )

+_is_cuda_available = torch.cuda.is_available()
+
 json_grammar = xgr.Grammar.builtin_json_grammar()


@@ -363,5 +366,34 @@ def test_override_stop_tokens(tokenizer_path: str, override_stop_tokens: List[in
     assert matcher_2.stop_token_ids == override_stop_tokens


+def test_fill_next_token_bitmask_errors():
+    # llama 3.1 8b
+    tokenizer = AutoTokenizer.from_pretrained(
+        "meta-llama/Meta-Llama-3-8B-Instruct", use_fast=True, trust_remote_code=True
+    )
+    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
+    matcher = _get_matcher_from_grammar_and_tokenizer_info(json_grammar, tokenizer_info)
+
+    bitmask1 = torch.zeros(1, math.ceil(tokenizer_info.vocab_size / 32) - 1, dtype=torch.int32)
+    with pytest.raises(RuntimeError):
+        matcher.fill_next_token_bitmask(bitmask1)
+
+    bitmask2 = torch.zeros(1, math.ceil(tokenizer_info.vocab_size / 32), dtype=torch.int32)
+    with pytest.raises(RuntimeError):
+        matcher.fill_next_token_bitmask(bitmask2, index=1)
+
+    bitmask3 = torch.zeros(1, math.ceil(tokenizer_info.vocab_size / 32), dtype=torch.float32)
+    with pytest.raises(RuntimeError):
+        matcher.fill_next_token_bitmask(bitmask3)
+
+    if _is_cuda_available:
+        bitmask3 = torch.zeros(1, math.ceil(tokenizer_info.vocab_size / 32), 1, dtype=torch.int32)
+        with pytest.raises(RuntimeError):
+            matcher.fill_next_token_bitmask(bitmask3)
+
+    bitmask_correct = torch.zeros(1, math.ceil(tokenizer_info.vocab_size / 32), dtype=torch.int32)
+    matcher.fill_next_token_bitmask(bitmask_correct)
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
