Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions cpp/grammar_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,12 @@ void ApplyTokenBitmaskInplaceCPU(
logits->ndim == 2
? std::make_pair(static_cast<int>(logits->shape[0]), static_cast<int>(logits->shape[1]))
: std::make_pair(1, static_cast<int>(logits->shape[0]));
int logits_stride0 = logits->strides[0];
std::pair<int, int> bitmask_shape =
bitmask.ndim == 2
? std::make_pair(static_cast<int>(bitmask.shape[0]), static_cast<int>(bitmask.shape[1]))
: std::make_pair(1, static_cast<int>(bitmask.shape[0]));
int bitmask_stride0 = bitmask.strides[0];

XGRAMMAR_CHECK(
vocab_size <= bitmask_shape.second * DynamicBitset::BITS_PER_BLOCK &&
Expand All @@ -133,18 +135,18 @@ void ApplyTokenBitmaskInplaceCPU(
// Apply mask
if (indices.has_value()) {
for (auto idx : indices.value()) {
uint32_t* data_ptr = reinterpret_cast<uint32_t*>(bitmask.data) + idx * bitmask_shape.second;
uint32_t* data_ptr = reinterpret_cast<uint32_t*>(bitmask.data) + idx * bitmask_stride0;
DynamicBitset bitset(vocab_size, data_ptr);
auto logits_ptr = reinterpret_cast<float*>(logits->data) + idx * logits_shape.second;
auto logits_ptr = reinterpret_cast<float*>(logits->data) + idx * logits_stride0;
for (int i = bitset.FindFirstZero(); i != -1; i = bitset.FindNextZero(i)) {
logits_ptr[i] = -std::numeric_limits<float>::infinity();
}
}
} else {
for (int idx = 0; idx < logits_shape.first; ++idx) {
uint32_t* data_ptr = reinterpret_cast<uint32_t*>(bitmask.data) + idx * bitmask_shape.second;
uint32_t* data_ptr = reinterpret_cast<uint32_t*>(bitmask.data) + idx * bitmask_stride0;
DynamicBitset bitset(vocab_size, data_ptr);
auto logits_ptr = reinterpret_cast<float*>(logits->data) + idx * logits_shape.second;
auto logits_ptr = reinterpret_cast<float*>(logits->data) + idx * logits_stride0;
for (int i = bitset.FindFirstZero(); i != -1; i = bitset.FindNextZero(i)) {
logits_ptr[i] = -std::numeric_limits<float>::infinity();
}
Expand Down
2 changes: 2 additions & 0 deletions cpp/nanobind/nanobind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,10 @@ NB_MODULE(xgrammar_bindings, m) {
&Kernels_ApplyTokenBitmaskInplaceCPU,
nb::arg("logits_ptr"),
nb::arg("logits_shape"),
nb::arg("logits_strides"),
nb::arg("bitmask_ptr"),
nb::arg("bitmask_shape"),
nb::arg("bitmask_strides"),
nb::arg("vocab_size"),
nb::arg("indices").none(),
nb::call_guard<nb::gil_scoped_release>()
Expand Down
8 changes: 6 additions & 2 deletions cpp/nanobind/python_methods.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,25 @@ std::pair<bool, int> Testing_IsSingleTokenBitmask(
void Kernels_ApplyTokenBitmaskInplaceCPU(
intptr_t logits_ptr,
std::pair<int64_t, int64_t> logits_shape,
std::pair<int64_t, int64_t> logits_strides,
intptr_t bitmask_ptr,
std::pair<int64_t, int64_t> bitmask_shape,
std::pair<int64_t, int64_t> bitmask_strides,
int vocab_size,
std::optional<std::vector<int>> indices
) {
std::array<int64_t, 2> logits_shape_arr = {logits_shape.first, logits_shape.second};
std::array<int64_t, 2> logits_strides_arr = {logits_strides.first, logits_strides.second};
std::array<int64_t, 2> bitmask_shape_arr = {bitmask_shape.first, bitmask_shape.second};
std::array<int64_t, 2> bitmask_strides_arr = {bitmask_strides.first, bitmask_strides.second};

DLTensor logits_dltensor{
reinterpret_cast<void*>(logits_ptr),
DLDevice{kDLCPU, 0},
2,
DLDataType{kDLFloat, 32, 1},
logits_shape_arr.data(),
nullptr,
logits_strides_arr.data(),
0
};

Expand All @@ -106,7 +110,7 @@ void Kernels_ApplyTokenBitmaskInplaceCPU(
2,
GetBitmaskDLType(),
bitmask_shape_arr.data(),
nullptr,
bitmask_strides_arr.data(),
0
};

Expand Down
2 changes: 2 additions & 0 deletions cpp/nanobind/python_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ std::pair<bool, int> Testing_IsSingleTokenBitmask(
/// Python-facing wrapper for the CPU token-bitmask kernel: sets logits of
/// disallowed tokens to -inf, in place.
///
/// \param logits_ptr Address of the float32 logits buffer, passed as an integer.
/// \param logits_shape (rows, cols) of the logits tensor.
/// \param logits_strides Per-dimension strides of the logits tensor, so that
///        non-contiguous (e.g. sliced) tensors are handled correctly.
///        NOTE(review): presumably element strides (as from Tensor.stride()),
///        not byte strides — confirm against the DLPack convention used.
/// \param bitmask_ptr Address of the packed token-bitmask buffer (32 vocab
///        bits per 32-bit word; a zero bit marks a token to be masked).
/// \param bitmask_shape (rows, cols) of the bitmask tensor.
/// \param bitmask_strides Per-dimension strides of the bitmask tensor.
/// \param vocab_size Number of vocabulary entries covered by the mask.
/// \param indices Optional row indices to apply the mask to; when absent,
///        every row is processed.
void Kernels_ApplyTokenBitmaskInplaceCPU(
    intptr_t logits_ptr,
    std::pair<int64_t, int64_t> logits_shape,
    std::pair<int64_t, int64_t> logits_strides,
    intptr_t bitmask_ptr,
    std::pair<int64_t, int64_t> bitmask_shape,
    std::pair<int64_t, int64_t> bitmask_strides,
    int vocab_size,
    std::optional<std::vector<int>> indices
);
Expand Down
19 changes: 18 additions & 1 deletion python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,28 @@ def apply_token_bitmask_inplace_cpu(
raise ValueError("bitmask should be 1D or 2D, but got {}D".format(bitmask.dim()))

logits_shape = (1, logits.shape[0]) if logits.dim() == 1 else (logits.shape[0], logits.shape[1])
logits_stride = logits.stride()
logits_stride = (
(logits_stride[0], 1) if logits.dim() == 1 else (logits_stride[0], logits_stride[1])
)

bitmask_shape = (
(1, bitmask.shape[0]) if bitmask.dim() == 1 else (bitmask.shape[0], bitmask.shape[1])
)
bitmask_stride = bitmask.stride()
bitmask_stride = (
(bitmask_stride[0], 1) if bitmask.dim() == 1 else (bitmask_stride[0], bitmask_stride[1])
)

vocab_size = min(logits.shape[-1], bitmask.shape[-1] * 32) if vocab_size is None else vocab_size

_core.kernels.apply_token_bitmask_inplace_cpu(
logits.data_ptr(), logits_shape, bitmask.data_ptr(), bitmask_shape, vocab_size, indices
logits.data_ptr(),
logits_shape,
logits_stride,
bitmask.data_ptr(),
bitmask_shape,
bitmask_stride,
vocab_size,
indices,
)
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def apply_token_bitmask_inplace_triton(
indices,
num_rows,
vocab_size,
logits.shape[-1],
bitmask.shape[-1],
logits.stride()[0],
bitmask.stride()[0],
NUM_SMS,
BLOCK_SIZE,
num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
Expand Down
40 changes: 40 additions & 0 deletions tests/python/test_token_bitmask_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,46 @@ def test_apply_token_bitmask_inplace(device: str):
torch.testing.assert_close(logits, expected)


@pytest.mark.parametrize("device", ("cpu", "cuda"))
def test_apply_token_bitmask_inplace_shape_stride_mismatch(device: str):
    """Regression test: the kernel must honor the logits row stride.

    The logits tensor below is a non-contiguous column slice, so its row
    stride (col + 1) differs from its row length (col). A kernel that steps
    rows by shape[-1] instead of stride[0] would write -inf into the wrong
    elements.
    """
    if device == "cuda" and not _is_cuda_available:
        pytest.skip(reason="CUDA is not installed")

    col = 100
    compacted_col = (col + 31) // 32  # 32 vocab bits packed per int32 word
    neginf = float("-inf")
    # The bitmask keeps even positions (0-indexed) in the first row and odd
    # positions in the second row; the complementary positions are masked to
    # -inf (True in bool_mask means the logit is kept — see `expected` below).
    bool_mask = torch.tensor(
        [[i % 2 == 0 for i in range(col)], [i % 2 == 1 for i in range(col)]],
        dtype=torch.bool,
        device=device,
    )
    # In int32 binary representation,
    # 0x55555555 = 1431655765   (bits set at even positions)
    # 0xAAAAAAAA = -1431655766  (bits set at odd positions)
    bitmask = torch.tensor(
        [[1431655765] * compacted_col, [-1431655766] * compacted_col],
        dtype=torch.int32,
        device=device,
    )
    # col + 1 columns: slicing off the last column below yields a
    # non-contiguous view whose underlying row stride stays col + 1.
    master_logits = torch.tensor(
        [[i + 0.1 for i in range(col + 1)], [i + 0.2 for i in range(col + 1)]],
        dtype=torch.float32,
        device=device,
    )
    logits = master_logits[:, :col]

    # Ensure the test environment setup is accurate (i.e. shape[-1] != stride[0])
    assert logits.size() == (2, col)
    assert logits.stride() == (col + 1, 1)

    expected = torch.where(bool_mask, logits, neginf)

    xgr.apply_token_bitmask_inplace(logits, bitmask)
    torch.testing.assert_close(logits, expected)


def get_apply_token_bitmask_kernel(impl: str) -> Callable:
if impl == "cpu":
from xgrammar.kernels.apply_token_bitmask_inplace_cpu import apply_token_bitmask_inplace_cpu
Expand Down
Loading