# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


import random
from typing import List

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import relax
from tvm.contrib import utils


@pytest.mark.skip(reason="Requires FlashInfer enabled and proper setup")
def test_sampling():
    """Check that FlashInfer's sampling kernel follows the input probability distribution."""

    def load_module(name: str, static_modules: List[tvm.runtime.Module]):
        """Link the generated static modules into a single loadable runtime module."""
        assert len(static_modules) > 0
        if len(static_modules) == 1:
            return static_modules[0]
        static_mod = static_modules[0]
        for mod in static_modules[1:]:
            static_mod.import_module(mod)
        temp = utils.tempdir()
        mod_path = temp.relpath(f"{name}.so")
        static_mod.export_library(mod_path)
        return tvm.runtime.load_module(mod_path)

    # Test configuration.
    batch_size = 10
    vocab_size = 5
    num_iterations = 1000
    tol_atol = 0.02  # absolute tolerance
    tol_rtol = 0.05  # relative tolerance

    # Probability tensor (each row sums to 1).
    probs_np = np.array([[0.1, 0.2, 0.3, 0.2, 0.2] for _ in range(batch_size)], dtype="float32")

    dev = tvm.cuda(0)
    probs_tvm = tvm.nd.array(probs_np, device=dev)
    output_tvm = tvm.nd.empty((batch_size,), "int32", device=dev)

    # Build the FlashInfer sampling module for the current CUDA target.
    target = tvm.target.Target.from_device(dev)
    sampling_mod = load_module(
        "flashinfer_sampling",
        relax.backend.cuda.flashinfer.gen_sampling_module(target=target),
    )
    sampling_func = sampling_mod["sampling_from_probs"]

    counts = np.zeros((batch_size, vocab_size), dtype="int32")

    # Draw samples repeatedly and tally how often each token is selected.
    deterministic = False
    for _ in range(num_iterations):
        # Generate a random Philox seed and offset for this draw.
        philox_seed = np.uint64(random.getrandbits(63))
        philox_offset = np.uint64(random.getrandbits(63) % 1000)

        # The kernel expects
        # (probs, output, maybe_indices, deterministic, philox_seed, philox_offset, cuda_stream).
        sampling_func(probs_tvm, output_tvm, None, deterministic, philox_seed, philox_offset, 0)

        out = output_tvm.numpy()
        for i in range(batch_size):
            sampled_token = out[i]
            counts[i, sampled_token] += 1

    # Convert counts to empirical frequencies.
    frequencies = counts / float(num_iterations)

    # For each row, check that the empirical frequency is close to the input probability.
    for row in range(batch_size):
        tvm.testing.assert_allclose(frequencies[row], probs_np[row], rtol=tol_rtol, atol=tol_atol)
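
    # Sanity check on the tallying itself: every draw increments exactly one counter,
    # so each row of empirical frequencies should sum to 1 (up to floating-point rounding).
    np.testing.assert_allclose(frequencies.sum(axis=1), np.ones(batch_size), rtol=1e-6)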


if __name__ == "__main__":
    # Run the test standalone (bypasses pytest and the skip marker).
    test_sampling()