4
4
*/
5
5
6
6
#include < nanobind/nanobind.h>
7
+ #include < nanobind/ndarray.h>
7
8
#include < nanobind/stl/optional.h>
8
9
#include < nanobind/stl/pair.h>
9
10
#include < nanobind/stl/string.h>
22
23
namespace nb = nanobind;
23
24
using namespace xgrammar ;
24
25
26
+ namespace {
27
+
25
28
std::vector<std::string> CommonEncodedVocabType (
26
29
const nb::typed<nb::list, std::variant<std::string, nb::bytes>> encoded_vocab
27
30
) {
@@ -39,6 +42,32 @@ std::vector<std::string> CommonEncodedVocabType(
39
42
return encoded_vocab_strs;
40
43
}
41
44
45
+ bool GrammarMatcher_FillNextTokenBitmask (
46
+ GrammarMatcher& matcher,
47
+ nb::ndarray<int32_t , nb::device::cpu> arr,
48
+ int32_t index,
49
+ bool debug_print
50
+ ) {
51
+ if (arr.ndim () != 1 && arr.ndim () != 2 ) {
52
+ throw nb::type_error (" token_bitmask tensor must be 1D or 2D" );
53
+ }
54
+
55
+ // Under the hood these are stored with the same standard (DLPack), but nanobind
56
+ // defines its own types, and doesn't expose a way to just get the object directly.
57
+ // We'll just do some pointer hackery to get there, rather than build the type back up manually:
58
+
59
+ // The data in an ndarray is defined as:
60
+ // detail::ndarray_handle* m_handle = nullptr;
61
+ // dlpack::dltensor m_dltensor;
62
+ // Assert this, then skip over m_handle and reinterpret m_dltensor.
63
+ static_assert (sizeof (arr) == sizeof (void *) + sizeof (nb::dlpack::dltensor));
64
+
65
+ const DLTensor& bitmask_dltensor =
66
+ *reinterpret_cast <::DLTensor*>(reinterpret_cast <char *>(&arr) + sizeof (void *));
67
+
68
+ return matcher.FillNextTokenBitmask (bitmask_dltensor, index, debug_print);
69
+ }
70
+
42
71
std::vector<nanobind::bytes> TokenizerInfo_GetDecodedVocab (const TokenizerInfo& tokenizer) {
43
72
const auto & decoded_vocab = tokenizer.GetDecodedVocab ();
44
73
std::vector<nanobind::bytes> py_result;
@@ -49,6 +78,8 @@ std::vector<nanobind::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo&
49
78
return py_result;
50
79
}
51
80
81
+ } // namespace
82
+
52
83
NB_MODULE (xgrammar_bindings, m) {
53
84
auto pyTokenizerInfo = nb::class_<TokenizerInfo>(m, " TokenizerInfo" );
54
85
pyTokenizerInfo
0 commit comments