Commit 4ccf711

Bind fill_next_token_bitmask against nb::ndarray

1 parent: b87ed7f

File tree

8 files changed: +45, -43 lines

cpp/grammar_matcher.cc

Lines changed: 6 additions & 4 deletions

@@ -270,7 +270,9 @@ class GrammarMatcher::Impl : public GrammarMatcherBase {

   bool AcceptString(const std::string& input_str, bool debug_print = false);

-  bool FillNextTokenBitmask(DLTensor* next_token_bitmask, int index, bool debug_print = false);
+  bool FillNextTokenBitmask(
+      const DLTensor& next_token_bitmask, int index, bool debug_print = false
+  );

   std::string FindJumpForwardString();

@@ -493,13 +495,13 @@ bool GrammarMatcher::Impl::IsTokenBitmaskAllTrue(int32_t* bitmask_data_ptr) {
 }

 bool GrammarMatcher::Impl::FillNextTokenBitmask(
-    DLTensor* next_token_bitmask, int index, bool debug_print
+    const DLTensor& next_token_bitmask, int index, bool debug_print
 ) {
   XGRAMMAR_CHECK(!IsStopTokenAccepted())
       << "GrammarMatcher has terminated after accepting the stop token, but is trying to "
          "find the next token mask";
   int32_t* bitmask_data_ptr =
-      CheckAndGetBitmaskPtr(*next_token_bitmask, tokenizer_info_.GetVocabSize(), index);
+      CheckAndGetBitmaskPtr(next_token_bitmask, tokenizer_info_.GetVocabSize(), index);
   const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab();
   const auto& adaptive_token_mask_cache = compiled_grammar_->adaptive_token_mask_cache;
   const auto& latest_stack_tops = stack_tops_history_.GetLatest();

@@ -851,7 +853,7 @@ bool GrammarMatcher::AcceptString(const std::string& input_str, bool debug_print
 }

 bool GrammarMatcher::FillNextTokenBitmask(
-    DLTensor* next_token_bitmask, int index, bool debug_print
+    const DLTensor& next_token_bitmask, int index, bool debug_print
 ) {
   return pimpl_->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
 }

cpp/nanobind/nanobind.cc

Lines changed: 31 additions & 0 deletions

@@ -4,6 +4,7 @@
  */

 #include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/pair.h>
 #include <nanobind/stl/string.h>

@@ -22,6 +23,8 @@
 namespace nb = nanobind;
 using namespace xgrammar;

+namespace {
+
 std::vector<std::string> CommonEncodedVocabType(
     const nb::typed<nb::list, std::variant<std::string, nb::bytes>> encoded_vocab
 ) {

@@ -39,6 +42,32 @@ std::vector<std::string> CommonEncodedVocabType(
   return encoded_vocab_strs;
 }

+bool GrammarMatcher_FillNextTokenBitmask(
+    GrammarMatcher& matcher,
+    nb::ndarray<int32_t, nb::device::cpu> arr,
+    int32_t index,
+    bool debug_print
+) {
+  if (arr.ndim() != 1 && arr.ndim() != 2) {
+    throw nb::type_error("token_bitmask tensor must be 1D or 2D");
+  }
+
+  // Under the hood these are stored with the same standard (DLPack), but nanobind
+  // defines its own types, and doesn't expose a way to just get the object directly.
+  // We'll just do some pointer hackery to get there, rather than build the type back up manually:
+
+  // The data in an ndarray is defined as:
+  //   detail::ndarray_handle* m_handle = nullptr;
+  //   dlpack::dltensor m_dltensor;
+  // Assert this, then skip over m_handle and reinterpret m_dltensor.
+  static_assert(sizeof(arr) == sizeof(void*) + sizeof(nb::dlpack::dltensor));
+
+  const DLTensor& bitmask_dltensor =
+      *reinterpret_cast<::DLTensor*>(reinterpret_cast<char*>(&arr) + sizeof(void*));
+
+  return matcher.FillNextTokenBitmask(bitmask_dltensor, index, debug_print);
+}
+
 std::vector<nanobind::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer) {
   const auto& decoded_vocab = tokenizer.GetDecodedVocab();
   std::vector<nanobind::bytes> py_result;

@@ -49,6 +78,8 @@ std::vector<nanobind::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo&
   return py_result;
 }

+}  // namespace
+
 NB_MODULE(xgrammar_bindings, m) {
   auto pyTokenizerInfo = nb::class_<TokenizerInfo>(m, "TokenizerInfo");
   pyTokenizerInfo
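The comment in the added binding weighs this layout trick against rebuilding the tensor "manually" from nanobind's public accessors. For reference, a minimal sketch of that manual route is shown below. It is not part of this commit; the helper name BuildBitmaskDLTensor is made up here, and it assumes nanobind's documented ndarray accessors (data(), ndim(), shape_ptr(), stride_ptr()) together with the headers already included above.

// Illustrative alternative (not from this commit): rebuild the DLTensor from
// the ndarray's public accessors instead of reinterpreting its object layout.
// Assumes nanobind's data()/ndim()/shape_ptr()/stride_ptr() accessors.
DLTensor BuildBitmaskDLTensor(nb::ndarray<int32_t, nb::device::cpu>& arr) {
  DLTensor tensor{};
  tensor.data = arr.data();                              // int32_t* buffer
  tensor.device = DLDevice{kDLCPU, 0};                   // the binding already restricts to CPU
  tensor.ndim = static_cast<int32_t>(arr.ndim());
  tensor.dtype.code = kDLInt;                            // int32, one lane
  tensor.dtype.bits = 32;
  tensor.dtype.lanes = 1;
  tensor.shape = const_cast<int64_t*>(arr.shape_ptr());  // shape/strides remain owned by `arr`
  tensor.strides = const_cast<int64_t*>(arr.stride_ptr());
  tensor.byte_offset = 0;
  return tensor;  // only valid while `arr` stays alive
}

The committed version skips this bookkeeping and instead leans on the static_assert above to catch any future change to the ndarray layout at compile time.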

cpp/nanobind/python_methods.cc

Lines changed: 0 additions & 21 deletions

@@ -39,27 +39,6 @@ int TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer) {
   return static_cast<int>(tokenizer.GetVocabType());
 }

-bool GrammarMatcher_FillNextTokenBitmask(
-    GrammarMatcher& matcher,
-    intptr_t token_bitmask_ptr,
-    std::vector<int64_t> shape,
-    int32_t index,
-    bool debug_print
-) {
-  XGRAMMAR_CHECK(shape.size() == 1 || shape.size() == 2) << "token_bitmask tensor must be 1D or 2D";
-
-  DLTensor bitmask_dltensor{
-      reinterpret_cast<void*>(token_bitmask_ptr),
-      DLDevice{kDLCPU, 0},
-      static_cast<int32_t>(shape.size()),
-      GetBitmaskDLType(),
-      shape.data(),
-      nullptr,
-      0
-  };
-  return matcher.FillNextTokenBitmask(&bitmask_dltensor, index, debug_print);
-}
-
 std::vector<int> Testing_DebugGetMaskedTokensFromBitmask(
     intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
 ) {

cpp/nanobind/python_methods.h

Lines changed: 0 additions & 8 deletions

@@ -27,14 +27,6 @@ TokenizerInfo TokenizerInfo_Init(

 int TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer);

-bool GrammarMatcher_FillNextTokenBitmask(
-    GrammarMatcher& matcher,
-    intptr_t token_bitmask_ptr,
-    std::vector<int64_t> shape,
-    int32_t index,
-    bool debug_print
-);
-
 std::vector<int> Testing_DebugGetMaskedTokensFromBitmask(
     intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
 );

include/xgrammar/matcher.h

Lines changed: 3 additions & 1 deletion

@@ -107,7 +107,9 @@ class GrammarMatcher {
   * and with shape (GetBitmaskSize(),) and dtype int32.
   * \return Whether the bitmask need to be applied (not all-true).
   */
-  bool FillNextTokenBitmask(DLTensor* next_token_bitmask, int index = 0, bool debug_print = false);
+  bool FillNextTokenBitmask(
+      const DLTensor& next_token_bitmask, int index = 0, bool debug_print = false
+  );

   /*!
   * \brief Find the jump-forward string for jump-forward decoding. This is the longest string that
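For context, the doc comment above says the bitmask buffer is int32 with shape (GetBitmaskSize(),) and that the return value tells the caller whether the mask needs to be applied. A caller-side sketch of the new by-reference signature follows; it is not from this commit, the helper name FillMaskIntoBuffer is made up, and the ceil(vocab_size / 32) sizing is an assumption standing in for GetBitmaskSize(). The DLTensor setup mirrors the pattern in web/src/xgrammar_binding.cc further down.

#include <dlpack/dlpack.h>
#include <xgrammar/matcher.h>

#include <vector>

// Illustrative caller-side sketch (not from this commit): wrap a caller-owned
// int32 buffer in a DLTensor and pass it by const reference, as the new
// signature expects.
bool FillMaskIntoBuffer(xgrammar::GrammarMatcher& matcher, int vocab_size,
                        std::vector<int32_t>& storage) {
  // Assumption: one bit per token, packed into int32 words, i.e.
  // ceil(vocab_size / 32) entries, standing in for GetBitmaskSize().
  storage.assign((vocab_size + 31) / 32, 0);
  int64_t shape[1] = {static_cast<int64_t>(storage.size())};

  DLTensor bitmask{};
  bitmask.data = storage.data();
  bitmask.device = DLDevice{kDLCPU, 0};
  bitmask.ndim = 1;
  bitmask.dtype.code = kDLInt;
  bitmask.dtype.bits = 32;
  bitmask.dtype.lanes = 1;
  bitmask.shape = shape;
  bitmask.strides = nullptr;  // contiguous, per the DLPack convention
  bitmask.byte_offset = 0;

  // True means the mask actually constrains the vocabulary (it is not
  // all-true) and should be applied to the logits.
  return matcher.FillNextTokenBitmask(bitmask, /*index=*/0);
}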

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ dependencies = [
   "triton; platform_system == 'Linux' and platform_machine == 'x86_64'",
   "mlx-lm; platform_system == 'Darwin' and platform_machine == 'arm64'",
   "ninja",
+  "numpy",
 ]
 dynamic = ["version"]

python/xgrammar/matcher.py

Lines changed: 3 additions & 8 deletions

@@ -6,6 +6,7 @@
 from typing import List, Optional, Tuple, Union

 import torch
+from numpy.typing import ArrayLike

 from .base import XGRObject, _core
 from .compiler import CompiledGrammar

@@ -269,7 +270,7 @@ def accept_string(self, input_str: Union[str, bytes], *, debug_print: bool = Fal
         return self._handle.accept_string(input_str, debug_print)

     def fill_next_token_bitmask(
-        self, bitmask: torch.Tensor, index: int = 0, *, debug_print: bool = False
+        self, bitmask: ArrayLike, index: int = 0, *, debug_print: bool = False
     ) -> bool:
         """Fill the bitmask for the next token prediction. The input bitmask can be generated
         by allocate_token_bitmask, and must be on CPU. bitmask[index] will be filled with the

@@ -299,13 +300,7 @@ def fill_next_token_bitmask(
         RuntimeError
             If the recursion depth is exceeded.
         """
-        if bitmask.device.type != "cpu":
-            raise ValueError("bitmask should be on CPU.")
-        if bitmask.dtype != bitmask_dtype:
-            raise ValueError(f"bitmask should be of type {bitmask_dtype}.")
-        return self._handle.fill_next_token_bitmask(
-            bitmask.data_ptr(), list(bitmask.shape), index, debug_print
-        )
+        return self._handle.fill_next_token_bitmask(bitmask, index, debug_print)

     def find_jump_forward_string(self) -> str:
         """Find the jump-forward string for jump-forward decoding. This is the longest string that

web/src/xgrammar_binding.cc

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ std::vector<int32_t> GrammarMatcher_GetNextTokenBitmask(GrammarMatcher& matcher,
   tensor.strides = &strides[0];
   tensor.byte_offset = 0;
   // 3. Populate tensor, hence result
-  matcher.FillNextTokenBitmask(&tensor);
+  matcher.FillNextTokenBitmask(tensor);
   return result;
 }
