Skip to content

Commit 7ab6e93

Browse files
IrfnfnkemedSeven-Streams
authored andcommitted
[Fix] Add support for kDLCUDAHost in token bitmask device type check (mlc-ai#242)
- Modify the device type check to allow kDLCUDAHost in addition to kDLCPU. - This change ensures that token bitmasks can be stored in pinned memory (kDLCUDAHost).
1 parent e899521 commit 7ab6e93

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

cpp/grammar_matcher.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,11 @@ int32_t* CheckAndGetBitmaskPtr(const DLTensor& token_bitmask, int vocab_size, in
4646
<< "The provided index is out of bounds";
4747
}
4848

49-
XGRAMMAR_CHECK(token_bitmask.device.device_type == kDLCPU)
50-
<< "The provided bitmask's device is not valid: should be CPU";
49+
XGRAMMAR_CHECK(
50+
token_bitmask.device.device_type == kDLCPU ||
51+
token_bitmask.device.device_type == kDLCUDAHost ||
52+
token_bitmask.device.device_type == kDLROCMHost
53+
) << "The provided bitmask's device is not valid: should be CPU";
5154

5255
return reinterpret_cast<int32_t*>(token_bitmask.data) + index * buffer_size;
5356
}
@@ -67,10 +70,14 @@ void ApplyTokenBitmaskInplaceCPU(
6770
DLTensor* logits, const DLTensor& bitmask, std::optional<std::vector<int>> indices
6871
) {
6972
// Check device and dim
70-
XGRAMMAR_CHECK(logits->device.device_type == kDLCPU)
71-
<< "The provided logits's device is not valid: should be CPU";
72-
XGRAMMAR_CHECK(bitmask.device.device_type == kDLCPU)
73-
<< "The provided bitmask's device is not valid: should be CPU";
73+
XGRAMMAR_CHECK(
74+
logits->device.device_type == kDLCPU || logits->device.device_type == kDLCUDAHost ||
75+
logits->device.device_type == kDLROCMHost
76+
) << "The provided logits's device is not valid: should be CPU";
77+
XGRAMMAR_CHECK(
78+
bitmask.device.device_type == kDLCPU || bitmask.device.device_type == kDLCUDAHost ||
79+
bitmask.device.device_type == kDLROCMHost
80+
) << "The provided bitmask's device is not valid: should be CPU";
7481
XGRAMMAR_CHECK(logits->ndim == 2 || logits->ndim == 1)
7582
<< "The provided logits's shape is not valid: should be 2D or 1D";
7683
XGRAMMAR_CHECK(bitmask.ndim == 2 || bitmask.ndim == 1)

0 commit comments

Comments
 (0)