Skip to content

Commit ac858a8

Browse files
committed
Added test
1 parent e2862b7 commit ac858a8

File tree

1 file changed

+103
-0
lines changed

1 file changed

+103
-0
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
19+
import random
20+
import numpy as np
21+
import tvm
22+
import tvm.testing
23+
import pytest
24+
from tvm import relax
25+
from tvm.contrib import utils
26+
from typing import List
27+
28+
29+
30+
31+
@pytest.mark.skip(reason="Requires FlashInfer enabled and proper setup")
32+
def test_sampling():
33+
34+
def load_module(name: str, static_modules: List[tvm.runtime.Module]):
35+
assert len(static_modules) > 0
36+
if len(static_modules) == 1:
37+
return static_modules[0]
38+
static_mod = static_modules[0]
39+
for mod in static_modules[1:]:
40+
static_mod.import_module(mod)
41+
temp = utils.tempdir()
42+
mod_path = temp.relpath(f"{name}.so")
43+
static_mod.export_library(mod_path)
44+
return tvm.runtime.load_module(mod_path)
45+
46+
# Test configuration
47+
batch_size = 10
48+
vocab_size = 5
49+
num_iterations = 1000
50+
tol_atol = 0.02
51+
tol_rtol = 0.05 # relative tolerance
52+
53+
# Probability tensor (each row sums to 1)
54+
probs_np = np.array([[0.1, 0.2, 0.3, 0.2, 0.2] for _ in range(batch_size)], dtype="float32")
55+
56+
dev = tvm.cuda(0)
57+
probs_tvm = tvm.nd.array(probs_np, device=dev)
58+
output_tvm = tvm.nd.empty((batch_size,), "int32", device=dev)
59+
60+
device = tvm.cuda()
61+
target = tvm.target.Target.from_device(device)
62+
sampling_mod = load_module(
63+
"flashinfer_sampling",
64+
relax.backend.cuda.flashinfer.gen_sampling_module(
65+
target=target,
66+
),
67+
)
68+
sampling_func = sampling_mod["sampling_from_probs"]
69+
70+
counts = np.zeros((batch_size, vocab_size), dtype="int32")
71+
72+
for _ in range(num_iterations):
73+
deterministic = False
74+
# Generate seed and a random offset.
75+
philox_seed = np.uint64(random.getrandbits(63))
76+
philox_offset = np.uint64(random.getrandbits(63) % 1000)
77+
78+
# the kernel expects (probs, output, maybe_indices, deterministic, philox_seed, philox_offset, cuda_stream)
79+
sampling_func(probs_tvm, output_tvm, None, deterministic,
80+
philox_seed, philox_offset, 0)
81+
82+
out = output_tvm.asnumpy()
83+
for i in range(batch_size):
84+
sampled_token = out[i]
85+
counts[i, sampled_token] += 1
86+
87+
# Convert counts to frequencies.
88+
frequencies = counts / float(num_iterations)
89+
90+
# For each row, check that the empirical frequency is close to the input probability.
91+
for row in range(batch_size):
92+
tvm.testing.assert_allclose(
93+
frequencies[row],
94+
probs_np[row],
95+
rtol=tol_rtol,
96+
atol=tol_atol
97+
)
98+
99+
if __name__ == "__main__":
100+
# Run the test standalone (if not using pytest)
101+
test_sampling()
102+
103+

0 commit comments

Comments
 (0)