diff --git a/include/tvm/relax/attrs/sampling.h b/include/tvm/relax/attrs/sampling.h new file mode 100644 index 000000000000..a878dd9766d7 --- /dev/null +++ b/include/tvm/relax/attrs/sampling.h @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/sampling.h + * \brief Attributes for sampling operators. + */ +#ifndef TVM_RELAX_ATTRS_SAMPLING_H_ +#define TVM_RELAX_ATTRS_SAMPLING_H_ + +#include <tvm/relax/expr.h> + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in multinomial_from_uniform operator */ +struct MultinomialFromUniformAttrs : public tvm::AttrsNode<MultinomialFromUniformAttrs> { + DataType dtype; + + TVM_DECLARE_ATTRS(MultinomialFromUniformAttrs, "relax.attrs.MultinomialFromUniformAttrs") { + TVM_ATTR_FIELD(dtype) .set_default(DataType::Int(64)) .describe("Data type of the output indices."); + } +}; // struct MultinomialFromUniformAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_SAMPLING_H_ diff --git a/python/tvm/relax/backend/__init__.py b/python/tvm/relax/backend/__init__.py index e4a89bdb95ad..6d0ca302018c 100644 --- a/python/tvm/relax/backend/__init__.py +++ b/python/tvm/relax/backend/__init__.py @@ -17,5 +17,6 @@ """Relax backends""" from . import contrib -from .pattern_registry import get_pattern, get_patterns_with_prefix +from .dispatch_sampling import DispatchSampling from .dispatch_sort_scan import DispatchSortScan +from .pattern_registry import get_pattern, get_patterns_with_prefix diff --git a/python/tvm/relax/backend/dispatch_sampling.py b/python/tvm/relax/backend/dispatch_sampling.py new file mode 100644 index 000000000000..68d162fdf19b --- /dev/null +++ b/python/tvm/relax/backend/dispatch_sampling.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local +"""Dispatch sampling operators to platform dependent implementation.""" + + +from tvm import relax +from tvm.ir import Op +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext, module_pass +from tvm.relax import expr_functor + +from .utils import BackendDispatcher + + +@expr_functor.mutator +class SamplingDispatcher(BackendDispatcher): + """Dispatcher to dispatch sampling op.""" + + def visit_call_(self, call: relax.Call) -> relax.Expr: + if not isinstance(call.op, Op): + return super().visit_call_(call) + + if call.op.name == "relax.multinomial_from_uniform": + from tvm.relax.backend_tir import ( # pylint: disable=import-outside-toplevel + generic_get_sample_index, + gpu_multinomial_from_uniform, + ) + + prob, uniform_sample, sample_indices = call.args + tgt = self._get_target(call.struct_info) + dtype = call.attrs.dtype + _, prob_dtype = self.get_shape_dtype(prob) + sample_shape, sample_dtype = self.get_shape_dtype(uniform_sample) + sample_indices_shape, sample_indices_dtype = self.get_shape_dtype(sample_indices) + + if len(sample_shape) != 2 or sample_shape[1] != 1: + raise ValueError("uniform_sample should be a 2D tensor with shape (N, 1)") + + if len(sample_indices_shape) != 2 or sample_indices_shape[1] != 1: + raise ValueError("sample_indices should be a 2D tensor with shape (N, 1)") + + if self.is_gpu_target(tgt): + gv = self.builder_.add_func( + gpu_multinomial_from_uniform( + prob_dtype, sample_dtype, sample_indices_dtype, dtype + ), + "gpu_multinomial_from_uniform", + ) + return relax.call_tir( + gv, + [prob, uniform_sample, sample_indices], + out_sinfo=call.struct_info, + ) + else: + cumsum_prob = relax.op.cumsum(prob, axis=1, dtype=prob_dtype, exclusive=False) + gv = self.builder_.add_func( + generic_get_sample_index(prob_dtype, sample_dtype, sample_indices_dtype, dtype), + "get_sample_index", + ) + return relax.call_tir( + gv, + [cumsum_prob, uniform_sample, sample_indices], + out_sinfo=call.struct_info, + ) + + return super().visit_call_(call) + + +@module_pass(opt_level=0, name="DispatchSampling") +class DispatchSampling: + """Pass to dispatch sampling operators to platform dependent implementation.""" + + def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: + sampling_dispatcher = SamplingDispatcher(mod) + for gv, func in mod.functions_items(): + if isinstance(func, relax.Function): + func = sampling_dispatcher.visit_expr(func) + sampling_dispatcher.builder_.update_func(gv, func) + return sampling_dispatcher.builder_.finalize() diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index 53948b8449b0..e37869c40c46 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -26,21 +26,15 @@ from tvm.ir import GlobalVar, Op from tvm.ir.module import IRModule from tvm.ir.transform import PassContext, module_pass -from tvm.relax import PyExprMutator, expr_functor +from tvm.relax import expr_functor from tvm.target import Target - -def is_gpu_target(target: Target) -> bool: - """Check if the target is a GPU target.""" - return "gpu" in target.keys +from .utils import BackendDispatcher @expr_functor.mutator -class SortScanDispatcher(PyExprMutator): - """ - Dispatcher to dispatch sort and scan.
- - """ +class SortScanDispatcher(BackendDispatcher): + """Dispatcher to dispatch sort and scan.""" calls_to_update: Dict[GlobalVar, Target] @@ -48,26 +42,6 @@ def __init__(self, mod): super().__init__(mod) self.calls_to_update = {} - def _get_target(self, sinfo: relax.StructInfo) -> Target: - # Get target information from TensorStructInfo - if isinstance(sinfo, relax.TensorStructInfo): - vdevice = sinfo.vdevice - if vdevice is not None: - return vdevice.target - elif isinstance(sinfo, relax.TupleStructInfo): - for f in sinfo.fields: - tgt = self._get_target(f) - if tgt != Target.current(): - return tgt - # Return the target in current context - target = Target.current() - if target is None: - raise ValueError( - "Target not found. Please ensure that the target is annotated within the module, " - "or alternatively, execute this within a specified target context." - ) - return target - def apply_dlight_gpu_fallback( self, ) -> None: @@ -107,7 +81,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: if can_use_thrust(tgt, "tvm.contrib.thrust.sort"): te_func = topi.cuda.sort_thrust kwargs["workspace"] = self.allocate_workspace(call) - elif is_gpu_target(tgt): + elif self.is_gpu_target(tgt): te_func = topi.cuda.sort return self.builder_.call_te( te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs @@ -120,7 +94,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: if can_use_thrust(tgt, "tvm.contrib.thrust.sort"): te_func = topi.cuda.argsort_thrust kwargs["workspace"] = self.allocate_workspace(call) - elif is_gpu_target(tgt): + elif self.is_gpu_target(tgt): te_func = topi.cuda.argsort return self.builder_.call_te( te_func, @@ -137,7 +111,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: if can_use_thrust(tgt, "tvm.contrib.thrust.sort"): te_func = topi.cuda.topk_thrust kwargs["workspace"] = self.allocate_workspace(call) - elif is_gpu_target(tgt): + elif self.is_gpu_target(tgt): te_func = topi.cuda.topk tir_call = self.builder_.call_te( te_func, @@ -162,7 +136,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: if ( shape is not None and (axis == -1 or axis == len(shape) - 1) - and is_gpu_target(tgt) + and self.is_gpu_target(tgt) and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan") and call.op.name == "relax.cumsum" and call.attrs.exclusive == 0 @@ -202,11 +176,11 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: with tgt: if call.op.name == "relax.cumsum": - te_func = topi.cuda.cumsum if is_gpu_target(tgt) else topi.cumsum + te_func = topi.cuda.cumsum if self.is_gpu_target(tgt) else topi.cumsum if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"): kwargs["workspace"] = self.allocate_workspace(call) elif call.op.name == "relax.cumprod": - te_func = topi.cuda.cumprod if is_gpu_target(tgt) else topi.cumprod + te_func = topi.cuda.cumprod if self.is_gpu_target(tgt) else topi.cumprod else: raise ValueError(f"Unsupported op: {call.op.name}") tir_call = self.builder_.call_te( diff --git a/python/tvm/relax/backend/utils.py b/python/tvm/relax/backend/utils.py index e5ecb7c5f4f1..fdc0e99756de 100644 --- a/python/tvm/relax/backend/utils.py +++ b/python/tvm/relax/backend/utils.py @@ -17,8 +17,61 @@ # pylint: disable=invalid-name """Utils for BYOC pattern matching""" -from tvm.relax import DataflowVar +from typing import Tuple +from tvm import relax +from tvm.relax import DataflowVar, PyExprMutator from tvm.relax.transform import PatternCheckContext +from tvm.target import Target + + +class BackendDispatcher(PyExprMutator): + """Base class 
for backend dispatcher""" + + def __init__(self, mod): + super().__init__(mod) + + @staticmethod + def is_gpu_target(target: Target) -> bool: + """Check if the target is a GPU target.""" + return "gpu" in target.keys + + @staticmethod + def get_shape_dtype(expr: relax.Expr) -> Tuple[relax.ShapeExpr, str]: + """Get shape and dtype from an expression. + If the shape and dtype is unknown, raise an error.""" + sinfo = expr.struct_info + if not isinstance(expr.struct_info, relax.TensorStructInfo): + raise ValueError( + f"Expecting a expr with TensorStructInfo, but got {expr} with {expr.struct_info}" + ) + + shape, dtype = sinfo.shape, sinfo.dtype + if shape is None: + raise ValueError( + f"Expecting a expr with known shape, but got {expr} with unknown shape" + ) + + return shape, dtype + + def _get_target(self, sinfo: relax.StructInfo) -> Target: + # Get target information from TensorStructInfo + if isinstance(sinfo, relax.TensorStructInfo): + vdevice = sinfo.vdevice + if vdevice is not None: + return vdevice.target + elif isinstance(sinfo, relax.TupleStructInfo): + for f in sinfo.fields: + tgt = self._get_target(f) + if tgt != Target.current(): + return tgt + # Return the target in current context + target = Target.current() + if target is None: + raise ValueError( + "Target not found. Please ensure that the target is annotated within the module, " + "or alternatively, execute this within a specified target context." + ) + return target def has_leaking_intermediate_variables(context: PatternCheckContext) -> bool: diff --git a/python/tvm/relax/backend_tir/__init__.py b/python/tvm/relax/backend_tir/__init__.py index 10def47b8d5f..b64bdcda6bb6 100644 --- a/python/tvm/relax/backend_tir/__init__.py +++ b/python/tvm/relax/backend_tir/__init__.py @@ -17,5 +17,6 @@ """Relax backends, tir based""" from . import contrib -from .pattern import get_tir_pattern from .cumsum import gpu_2d_continuous_cumsum +from .pattern import get_tir_pattern +from .sampling import gpu_multinomial_from_uniform, generic_get_sample_index diff --git a/python/tvm/relax/backend_tir/cumsum.py b/python/tvm/relax/backend_tir/cumsum.py index ade961ecf17d..1bb7c6b2c118 100644 --- a/python/tvm/relax/backend_tir/cumsum.py +++ b/python/tvm/relax/backend_tir/cumsum.py @@ -41,10 +41,10 @@ def gpu_2d_continuous_cumsum( Parameters ---------- ty_len : int - The length of thread.y + The length of `threadIdx.y` tx_len : int - The length of thread.x + The length of `threadIdx.x` thread_elem : int The number of elements processed by single thread @@ -64,8 +64,8 @@ def gpu_2d_continuous_cumsum( out_dtype = out_dtype or in_dtype # Configuration for GPU kernel - TX = T.int64(tx_len) # thread.x - TY = T.int64(ty_len) # thread.y + TX = T.int64(tx_len) # threadIdx.x + TY = T.int64(ty_len) # threadIdx.y N = T.int64(thread_elem) # number of elements in single thread if not _is_power_of_two(TX) or not _is_power_of_two(TY) or not _is_power_of_two(N): diff --git a/python/tvm/relax/backend_tir/sampling.py b/python/tvm/relax/backend_tir/sampling.py new file mode 100644 index 000000000000..a0a5c29ddf7e --- /dev/null +++ b/python/tvm/relax/backend_tir/sampling.py @@ -0,0 +1,339 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-nested-blocks +"""Backend kernels for sampling operator.""" + +import math +from typing import Callable, Optional +from tvm.script import tir as T +from tvm.tir import PrimFunc + + +def _is_power_of_two(n: int): + """Check if n is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 + + +def gpu_multinomial_from_uniform( + prob_dtype: str = "float32", + sample_dtype: str = "float32", + sample_indices_dtype: str = "int64", + dtype: str = "int64", + ty_len: int = 4, + tx_len: int = 32, + thread_elem: int = 4, + eps: float = 1e-6, +) -> PrimFunc: + """Generate GPU kernel for multinomial_from_uniform operator. + + Parameters + ---------- + ty_len : int + The length of `threadIdx.y` + + tx_len : int + The length of `threadIdx.x` + + thread_elem : int + The number of elements processed by single thread + + prob_dtype : str + The probability data type + + sample_dtype : str + The sample data type + + sample_indices_dtype : str + The sample indices data type + + dtype : str + The output data type + + Returns + ------- + func : PrimFunc + The generated function + """ + + TX = T.int64(tx_len) # threadIdx.x + TY = T.int64(ty_len) # threadIdx.y + + # number of elements to be processed by single thread + thread_elem = T.int64(thread_elem) + # number of elements to be processed by single warp + warp_elem = T.int64(tx_len * thread_elem) + # number of elements to be processed by single block(SM) + block_elem = T.int64(tx_len * ty_len * thread_elem) + + LOG_TX = T.int64(int(math.log2(tx_len))) + LOG_TY = T.int64(int(math.log2(ty_len))) + + if ( + not _is_power_of_two(tx_len) + or not _is_power_of_two(ty_len) + or not _is_power_of_two(thread_elem) + ): + raise ValueError( + "Configuration of tx_len, ty_len, thread_elem must be power of 2," + f"but got {tx_len}, {ty_len}, {thread_elem}" + ) + + @T.macro + def block_cumsum( + ty: T.int64, + tx: T.int64, + source_local: T.Buffer, + output_shared: T.Buffer, + ): + """cumsum inside block (SM)""" + # Inclusive scan inside thread + for i in T.unroll(1, thread_elem): + source_local[i] += source_local[i - 1] + # Store data to shared memory + for i in T.vectorized(thread_elem): + output_shared[ty * warp_elem + tx * thread_elem + i] = source_local[i] + # Inclusive scan inside warp + for i in T.unroll(LOG_TX): + for j in T.vectorized(thread_elem): + idx: T.int64 = ty * warp_elem + tx * thread_elem + if tx >= (1 << i): + output_shared[idx + j] += output_shared[ + idx - (1 << i) * thread_elem + thread_elem - 1 + ] + # Inclusive scan inside block + for i in T.unroll(1, TY): + for j in T.vectorized(thread_elem): + if ty == 0: + idx: T.int64 = i * warp_elem + tx * thread_elem + output_shared[idx + j] += output_shared[i * warp_elem - 1] + + def compare_bool_not_equal(a: T.bool, b: T.bool) -> T.bool: + # Vulkan does not support compare two bool value direct + # return a != b + return T.Cast("int8", a) != T.Cast("int8", b) + + @T.macro + def block_adjacent_difference_left( + ty: T.int64, + tx: T.int64, + source_local: T.Buffer, + output_local: T.Buffer, + ): + with T.block(): + shared_buf = T.alloc_buffer((TX * TY,), "bool", 
scope="shared") + tx_idx = ty * TX + tx + shared_buf[tx_idx] = source_local[thread_elem - 1] + output_local[0] = T.if_then_else( + tx_idx != 0, + compare_bool_not_equal(source_local[0], shared_buf[tx_idx - 1]), + source_local[0], + ) + for i in T.unroll(1, thread_elem): + output_local[i] = compare_bool_not_equal(source_local[i], source_local[i - 1]) + + def op_reduce_min(a, b): + return T.min(a, b) + + def op_reduce_sum(a, b): + return a + b + + @T.macro + def block_reduce_with_mask( + ty: T.int64, + tx: T.int64, + init_value, + data_local: T.Buffer, + output_local: T.Buffer, + dtype: str, + reduce_op: Callable, # T.macro + mask_local: Optional[T.Buffer] = None, + ): + with T.block(): + local_sum = T.alloc_buffer((), dtype, scope="local") + shared_buf = T.alloc_buffer((TX * TY,), dtype, scope="shared") + idx = ty * TX + tx + + local_sum[()] = T.Cast(dtype, init_value) + for i in T.unroll(thread_elem): + if mask_local is not None: + if mask_local[i]: + local_sum[()] = reduce_op(local_sum[()], data_local[i]) + else: + local_sum[()] = reduce_op(local_sum[()], data_local[i]) + shared_buf[idx] = local_sum[()] + + for i in T.unroll(LOG_TX + LOG_TY): + if idx % (1 << (i + 1)) == 0: + shared_buf[idx] = reduce_op(shared_buf[idx], shared_buf[idx + (1 << i)]) + output_local[()] = shared_buf[0] + + @T.macro + def single_batch_sampling( + prob, + row_idx, + vocab_size, + ty, + tx, + step_iter, + threshold, + aggregate, + uniform_sample, + sample_id_local, + ): + with T.block(): + prob_gt_threshold = T.alloc_buffer((thread_elem,), prob_dtype, scope="local") + cumsum = T.alloc_buffer((block_elem,), prob_dtype, scope="shared") + greater_than_u = T.alloc_buffer((thread_elem,), "bool", scope="local") + mask = T.alloc_buffer((thread_elem,), "bool", scope="local") + valid = T.alloc_buffer((thread_elem,), "bool", scope="local") + indices = T.alloc_buffer((thread_elem), dtype, scope="local") + step_aggregate = T.alloc_buffer((), prob_dtype, scope="local") + # Load prob data from global memory to local memory + for v in T.unroll(thread_elem): + idx = step_iter * block_elem + ty * warp_elem + tx * thread_elem + v + prob_local = T.if_then_else( + idx < vocab_size, + prob[row_idx, idx], + T.Cast(prob_dtype, 0), + ) + prob_gt_threshold[v] = T.if_then_else( + prob_local > threshold, prob_local, T.Cast(prob_dtype, 0) + ) + valid[v] = prob_local > threshold and idx < vocab_size + + block_reduce_with_mask( + ty, + tx, + init_value=0, + data_local=prob_gt_threshold, + output_local=step_aggregate, + dtype=prob_dtype, + reduce_op=op_reduce_sum, + mask_local=None, + ) + if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= uniform_sample - eps): + block_cumsum(ty, tx, prob_gt_threshold, cumsum) + # Note: it should be `T.vectorized` instead of `T.unroll` + # However, it will cause vulkan codegen error + for v in T.unroll(thread_elem): + greater_than_u[v] = ( + cumsum[ty * warp_elem + tx * thread_elem + v] + aggregate[()] + >= uniform_sample - eps + ) + + block_adjacent_difference_left(ty, tx, greater_than_u, mask) + # Same as above, it should be `T.vectorized` + for v in T.unroll(thread_elem): + mask[v] = mask[v] and valid[v] + indices[v] = step_iter * block_elem + ty * warp_elem + tx * thread_elem + v + block_reduce_with_mask( + ty, + tx, + init_value=vocab_size - 1, + data_local=indices, + output_local=sample_id_local, + dtype=dtype, + reduce_op=op_reduce_min, + mask_local=mask, + ) + + aggregate[()] += step_aggregate[()] + + @T.prim_func + def parallel_sampling_from_prob( + var_prob: T.handle, + 
var_uniform_samples: T.handle, + var_row_indices: T.handle, + var_sampled_token_ids: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1}) + n, vocab_size, batch_size = T.int64(), T.int64(), T.int64() + # match buffers + prob = T.match_buffer(var_prob, (n, vocab_size), prob_dtype) + uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1), sample_dtype) + row_indices = T.match_buffer(var_row_indices, (batch_size, 1), sample_indices_dtype) + token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), dtype) + # local buffers + aggregate = T.alloc_buffer((), prob_dtype, scope="local") + sample_id_local = T.alloc_buffer((), dtype, scope="local") + step_iter = T.alloc_buffer((), "int32", scope="local") + + for bx in T.thread_binding(batch_size, thread="blockIdx.x"): + row_idx = row_indices[bx, 0] + for ty in T.thread_binding(TY, thread="threadIdx.y"): + for tx in T.thread_binding(TX, thread="threadIdx.x"): + u = uniform_samples[bx, 0] + aggregate[()] = T.Cast(prob_dtype, 0) + step_iter[()] = T.int32(0) + # at least one iteration + while T.tvm_thread_invariant( + (step_iter[()] == 0 or aggregate[()] < u - eps) + and T.Cast("int64", step_iter[()]) < T.ceildiv(vocab_size, block_elem) + ): + single_batch_sampling( + prob, + row_idx, + vocab_size, + ty, + tx, + T.Cast("int64", step_iter[()]), + 0.0, + aggregate, + u, + sample_id_local, + ) + step_iter[()] += 1 + if tx == 0 and ty == 0: + token_ids[bx, 0] = sample_id_local[()] + + return parallel_sampling_from_prob + + +def generic_get_sample_index( + prob_dtype: str = "float32", + sample_dtype: str = "float32", + sample_indices_dtype: str = "int64", + dtype: str = "int64", +): + """Generate a generic get_sample_index kernel.""" + + @T.prim_func(private=True) + def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): + batch, vocab_size = T.int64(), T.int64() + prob = T.match_buffer(A, (batch, vocab_size), prob_dtype) + out_batch = T.int64() + usample = T.match_buffer(B, (out_batch, 1), sample_dtype) + sample_indices = T.match_buffer(C, (out_batch, 1), sample_indices_dtype) + output_index = T.match_buffer(D, (out_batch, 1), dtype) + + for ax0, ax1 in T.grid(out_batch, vocab_size): + with T.block("T_get_sample_index"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.writes(output_index[v_ax0, 0]) + if ( + usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1] + or v_ax1 + 1 == vocab_size + ): + if v_ax1 == 0: + output_index[v_ax0, 0] = 0 + elif ( + usample[v_ax0, T.int64(0)] + >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1] + ): + output_index[v_ax0, 0] = v_ax1 + + return _get_sample_index diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 45428692b830..725a930fd680 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2352,6 +2352,7 @@ def multinomial_from_uniform( uniform_sample: Tensor, sample_indices: Optional[Tensor] = None, dtype: str = "int64", + name: str = "multinomial_from_uniform", ): """Returns a tensor where each row contains the index sampled from the multinomial probability distribution located in the corresponding row of tensor prob. 
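For reference, the sampling rule that both the TIR kernels above and this frontend op implement can be written as a short NumPy sketch (illustration only, not part of this patch; the helper name is made up): for each request, take the row of the inclusive cumulative distribution selected by sample_indices and return the first index whose cumulative probability exceeds the drawn uniform value, clamped to the last index.

import numpy as np

def multinomial_from_uniform_ref(prob, uniform_sample, sample_indices):
    # Inclusive per-row cumulative distribution, matching relax.op.cumsum(..., exclusive=False).
    cumsum = np.cumsum(prob, axis=-1)
    out = np.empty((uniform_sample.shape[0], 1), dtype="int64")
    for i in range(uniform_sample.shape[0]):
        row = cumsum[sample_indices[i, 0]]
        # First position j with uniform_sample[i] < row[j]; fall back to the last index.
        j = np.searchsorted(row, uniform_sample[i, 0], side="right")
        out[i, 0] = min(j, prob.shape[1] - 1)
    return out

prob = np.array([[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]], dtype="float32")
usample = np.array([[0.4], [0.9]], dtype="float32")
indices = np.array([[0], [1]], dtype="int64")
print(multinomial_from_uniform_ref(prob, usample, indices))  # -> [[1], [2]]
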
@@ -2403,8 +2404,6 @@ def multinomial_from_uniform( multinomial_from_uniform(prob, usample, sample_indices) -> [[1], [2]] """ - prob_dtype = prob.dtype - sample_dtype = uniform_sample.dtype out_batch = uniform_sample.shape[0] if sample_indices is not None: @@ -2417,40 +2416,9 @@ def multinomial_from_uniform( ), "Number of samples must match the number of probability distributions." sample_indices = Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1)) - sample_indices_dtype = sample_indices.dtype - - @T.prim_func(private=True) - def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): - batch, vocab_size = T.int64(), T.int64() - prob = T.match_buffer(A, (batch, vocab_size), prob_dtype) - out_batch = T.int64() - usample = T.match_buffer(B, (out_batch, 1), sample_dtype) - sample_indices = T.match_buffer(C, (out_batch, 1), sample_indices_dtype) - output_index = T.match_buffer(D, (out_batch, 1), dtype) - - for ax0, ax1 in T.grid(out_batch, vocab_size): - with T.block("T_get_sample_index"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.writes(output_index[v_ax0, 0]) - if ( - usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1] - or v_ax1 + 1 == vocab_size - ): - if v_ax1 == 0: - output_index[v_ax0, 0] = 0 - elif ( - usample[v_ax0, T.int64(0)] - >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1] - ): - output_index[v_ax0, 0] = v_ax1 - - cumsum_prob = cumsum(prob, axis=1, exclusive=False) - - return tensor_ir_op( - _get_sample_index, - "get_sample_index", - args=[cumsum_prob, uniform_sample, sample_indices], - out=Tensor.placeholder([out_batch, 1], dtype), + return wrap_nested( + _op.multinomial_from_uniform(prob._expr, uniform_sample._expr, sample_indices._expr, dtype), + name, ) @@ -2554,12 +2522,12 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): for ax0, ax1 in T.grid(batch, vocab_size): with T.block("T_get_renorm_prob"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0: + if not _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0): renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0] - elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1: + elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1): if v_ax1 + 1 == vocab_size: renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1] - elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0: + elif not _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1): renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1] @T.prim_func(private=True) diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 5b585e18b450..4581defa1a77 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -96,11 +96,12 @@ tile, ) from .mask import masked_fill -from .qdq import quantize, dequantize +from .qdq import dequantize, quantize +from .sampling import multinomial_from_uniform from .search import argmax, argmin, where from .set import unique -from .sorting import sort, argsort, topk -from .statistical import cumsum, cumprod, max, mean, min, prod, std, sum, variance +from .sorting import argsort, sort, topk +from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance from .ternary import ewise_fma from .unary import ( abs, diff --git a/python/tvm/relax/op/sampling.py b/python/tvm/relax/op/sampling.py new file mode 100644 index 000000000000..bcd43a392247 --- /dev/null +++ b/python/tvm/relax/op/sampling.py @@ -0,0 +1,87 @@ 
+# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Sampling operators.""" + +from .. import args_converter +from ..expr import Expr +from . import _ffi_api + + +@args_converter.auto +def multinomial_from_uniform( + prob: Expr, + uniform_sample: Expr, + sample_indices: Expr, + dtype: str = "int64", +) -> Expr: + """Returns a tensor where each row contains the index sampled from the multinomial + probability distribution located in the corresponding row of tensor prob. + + Notes + ----- + For better cpu performance, use 'vm.builtin.multinomial_from_uniform'. + For accurate results, ensure probabilities are between 0 and 1 and sum to 1. + + Parameters + ---------- + prob : relax.Expr + A 2-D tensor of shape (batch, vocab_size) representing probability distributions. + Each row is a distribution across vocabulary for a batch, where: + Values range from [0, 1], indicating the probability of each vocabulary item. + The sum of values in each row is 1, forming a valid distribution. + + uniform_sample : relax.Expr + The uniformly sampled 2-D tensor with the shape (n, 1). + Values range from 0 to 1, indicating probabilities sampled uniformly. + + sample_indices : relax.Expr + The 2-D tensor with the shape [n, 1], which indicates the specific + probability distribution to sample from. The value of sample_indices[i] + determines that the ith token should be sampled from the sample_indices[i]th + probability distribution. For instance, if there are 3 distinct probability + distributions and the requirement is to sample 2, 3, and 4 tokens from each, + then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2]. + + dtype : str + The data type of the output tensor. + + Returns + ------- + result : relax.Expr + The computed tensor with shape (n, 1). + + Examples + -------- + .. 
code-block:: python + + prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]] + usample = [[0.4], [0.9]] + sample_indices = [[0], [1]] + + multinomial_from_uniform(prob, usample) + -> [[1], [2]] + multinomial_from_uniform(prob, usample, sample_indices) + -> [[1], [2]] + + """ + + return _ffi_api.multinomial_from_uniform( # type: ignore + prob, + uniform_sample, + sample_indices, + dtype, + ) diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 36ba46a1a5e3..d068f800d0e9 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -81,6 +81,7 @@ def default_build_pipeline(): def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: seq = tvm.transform.Sequential( [ + backend.DispatchSampling(), backend.DispatchSortScan(), transform.LegalizeOps(), transform.RewriteDataflowReshape(), diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 6dbf5c5dfdb4..ef9ae775450b 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -20,32 +20,38 @@ import builtins import functools import inspect -from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import tvm from tvm import DataType, relax from tvm.ir import PrimExpr, VDevice -from ..ir import decl_function, lookup_vdevice -from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, ShapeExpr, Var, VarBinding, const -from tvm.relax.utils import gen_call_tir_inputs - +from tvm.relax import ( + Call, + Expr, + ExternFunc, + ShapeExpr, + TupleGetItem, + Var, + VarBinding, + const, +) ############################### Operators ############################### from tvm.relax.op import ( abs, acos, acosh, - asin, - asinh, - atan, - atanh, add, arange, argmax, argmin, argsort, + asin, + asinh, assert_op, astype, + atan, + atanh, bitwise_and, bitwise_not, bitwise_or, @@ -53,12 +59,13 @@ broadcast_to, builtin, call_builtin_with_ctx, + call_dps_packed, call_inplace_packed, call_pure_packed, call_tir, call_tir_inplace, call_tir_with_grad, - call_dps_packed, + ccl, ceil, clip, collapse_sum_like, @@ -68,10 +75,12 @@ cosh, cumprod, cumsum, - einsum, - scatter_elements, + dequantize, divide, + dynamic_strided_slice, + einsum, equal, + erf, ewise_fma, exp, expand_dims, @@ -108,8 +117,10 @@ memory, min, minimum, + multinomial_from_uniform, multiply, negative, + nn, not_equal, null_value, ones, @@ -119,75 +130,70 @@ print, prod, quantize, - dequantize, repeat, reshape, - tensor_to_shape, - shape_to_tensor, round, rsqrt, + scatter_elements, shape_of, - std, - strided_slice, - dynamic_strided_slice, - sum, - take, - variance, + shape_to_tensor, sigmoid, sign, sin, sinh, sort, split, + sqrt, square, squeeze, - sqrt, + std, + strided_slice, subtract, + sum, + take, tan, tanh, - erf, + tensor_to_shape, tile, topk, tril, triu, unique, + variance, vm, where, wrap_param, zeros, zeros_like, - nn, - ccl, ) - +from tvm.relax.op.builtin import stop_lift_params +from tvm.relax.struct_info import StructInfo +from tvm.relax.utils import args_converter, gen_call_tir_inputs +from tvm.runtime import Object as tvm_Object +from tvm.runtime import ObjectGeneric from tvm.runtime.ndarray import ( cpu, cuda, device, + ext_dev, gpu, - rocm, - opencl, + hexagon, metal, + opencl, + rocm, vpi, vulkan, - ext_dev, - hexagon, webgpu, ) -from tvm.relax.op.builtin import stop_lift_params -from tvm.relax.struct_info import StructInfo -from tvm.relax.utils import 
args_converter -from tvm.runtime import Object as tvm_Object -from tvm.runtime import ObjectGeneric - +from ..ir import decl_function, lookup_vdevice from . import _ffi_api, frame ##################### Python Native Function Alias ###################### py_print = builtins.print -py_tuple = tuple -py_str = str +py_tuple = tuple # pylint: disable=used-before-assignment +py_str = str # pylint: disable=used-before-assignment ################################ Device ################################ @@ -741,6 +747,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "metal", "min", "minimum", + "multinomial_from_uniform", "multiply", "negative", "not_equal", diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 679ae4e8adc0..313e6c5f4412 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -479,14 +479,27 @@ def visit_if(self: Parser, node: doc.If) -> None: The doc AST if node. """ with self.var_table.with_frame(): - with T.If(self.eval_expr(node.test)): - with T.Then(): + predicate = self.eval_expr(node.test) + if isinstance(predicate, (PrimExpr, tvm.tir.expr.ExprOp)): + with T.If(self.eval_expr(node.test)): + with T.Then(): + with self.var_table.with_frame(): + self.visit_body(node.body) + if node.orelse: + with T.Else(): + with self.var_table.with_frame(): + self.visit_body(node.orelse) + elif isinstance(predicate, bool): + if predicate: with self.var_table.with_frame(): self.visit_body(node.body) - if node.orelse: - with T.Else(): - with self.var_table.with_frame(): - self.visit_body(node.orelse) + elif node.orelse: + with self.var_table.with_frame(): + self.visit_body(node.orelse) + else: + self.report_error( + node.test, f"If condition must be a boolean expression, but got {predicate}" + ) @dispatch.register(token="tir", type_name="Assert") diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py index b23baa031303..d5ed4fd99768 100644 --- a/python/tvm/target/detect_target.py +++ b/python/tvm/target/detect_target.py @@ -81,7 +81,11 @@ def _detect_vulkan(dev: Device) -> Target: "supports_int8": f_get_target_property(dev, "supports_int8"), "supports_int16": f_get_target_property(dev, "supports_int16"), "supports_int64": f_get_target_property(dev, "supports_int64"), + "supports_8bit_buffer": f_get_target_property(dev, "supports_8bit_buffer"), "supports_16bit_buffer": f_get_target_property(dev, "supports_16bit_buffer"), + "supports_storage_buffer_storage_class": f_get_target_property( + dev, "supports_storage_buffer_storage_class" + ), } ) diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 022ef31c66d0..36527c35841e 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -550,7 +550,7 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& // TODO(tvm-team): Currently, it is unable to express partially-static shape. Revisit when // PrimValue lands. return TensorStructInfo(data_sinfo->dtype, n_axis, data_sinfo->vdevice); -} // namespace relax +} // TODO(tvm-team): Register FRelaxInferLayout, TMixedPrecisionPolicy TVM_REGISTER_OP("relax.dynamic_strided_slice") diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc new file mode 100644 index 000000000000..35ee4c486b1d --- /dev/null +++ b/src/relax/op/tensor/sampling.cc @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file sampling.cc + * \brief sampling operators. + */ + +#include "sampling.h" + +#include <tvm/arith/analyzer.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/* relax.multinomial_from_uniform */ +TVM_REGISTER_NODE_TYPE(MultinomialFromUniformAttrs); + +Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, DataType dtype) { + ObjectPtr<MultinomialFromUniformAttrs> attrs = make_object<MultinomialFromUniformAttrs>(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.multinomial_from_uniform"); + return Call(op, {std::move(prob), std::move(uniform_sample), std::move(sample_indices)}, + Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.multinomial_from_uniform").set_body_typed(multinomial_from_uniform); + +StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) { + CheckNumArguments(call, ctx); + TensorStructInfo prob_sinfo = GetInputTensorStructInfo(call, 0, ctx); + TensorStructInfo uniform_sample_sinfo = GetInputTensorStructInfo(call, 1, ctx); + TensorStructInfo sample_indices_sinfo = GetInputTensorStructInfo(call, 2, ctx); + const auto* attrs = call->attrs.as<MultinomialFromUniformAttrs>(); + + if (!prob_sinfo->dtype.is_float()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Multinomial_from_uniform op requires the input prob to have float dtype. " + "However, the given prob dtype is " + << prob_sinfo->dtype); + } + if (!uniform_sample_sinfo->dtype.is_float()) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Multinomial_from_uniform op requires the input uniform_sample to have float " + "dtype. However, the given uniform_sample dtype is " + << uniform_sample_sinfo->dtype); + } + if (!sample_indices_sinfo->dtype.is_int()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Multinomial from uniform op requires the input sample_indices to have int " + "dtype. However, the given sample_indices dtype is " + << sample_indices_sinfo->dtype); + } + if (prob_sinfo->IsUnknownNdim() || uniform_sample_sinfo->IsUnknownNdim() || + sample_indices_sinfo->IsUnknownNdim()) { + return TensorStructInfo(attrs->dtype, kUnknownNDim, prob_sinfo->vdevice); + } + if (prob_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Multinomial_from_uniform op requires the input prob to be a 2D tensor. " + "However, the given prob tensor has ndim " + << prob_sinfo->ndim); + } + if (uniform_sample_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Multinomial_from_uniform op requires the input uniform_sample to be a 2D " + "tensor. However, the given uniform_sample tensor has ndim " + << uniform_sample_sinfo->ndim); + } + if (sample_indices_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Multinomial_from_uniform op requires the input sample_indices to be a 2D " + "tensor. However, the given sample_indices tensor has ndim " + << sample_indices_sinfo->ndim); + } + + // Expected to be `(batch, vocab_size)` + const auto* prob_shape = prob_sinfo->shape.as<ShapeExprNode>(); + // Expected to be `(n, 1)` + const auto* uniform_sample_shape = uniform_sample_sinfo->shape.as<ShapeExprNode>(); + // Expected to be `(n, 1)` + const auto* sample_indices_shape = sample_indices_sinfo->shape.as<ShapeExprNode>(); + // The output shape is expected to be `(n, 1)` + + if (prob_shape == nullptr || uniform_sample_shape == nullptr || sample_indices_shape == nullptr) { + return TensorStructInfo(attrs->dtype, 2, prob_sinfo->vdevice); + } + + PrimExpr batch = prob_shape->values[0]; + PrimExpr n = uniform_sample_shape->values[0]; + arith::Analyzer ana; + if (!ana.CanProveEqual(n, sample_indices_shape->values[0])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Multinomial_from_uniform op requires the input uniform_sample and " + "sample_indices to have the same batch size. " + "However, the given uniform_sample tensor has batch size `" + << n << "` and the given sample_indices tensor has batch size `" + << sample_indices_shape->values[0] << "`"); + } + if (!tir::is_one(uniform_sample_shape->values[1]) || + !tir::is_one(sample_indices_shape->values[1])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Multinomial_from_uniform op requires the input uniform_sample and " + "sample_indices to be 2D tensors with the second dimension being 1. " + "However, the given uniform_sample tensor has shape " + << uniform_sample_sinfo->shape + << " and the given sample_indices tensor has shape " + << sample_indices_sinfo->shape); + } + return TensorStructInfo(ShapeExpr({n, 1}), attrs->dtype, prob_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.multinomial_from_uniform") + .set_attrs_type<MultinomialFromUniformAttrs>() + .set_num_inputs(3) + .add_argument("prob", "Tensor", "The probability tensor.") + .add_argument("uniform_sample", "Tensor", "The uniform sample tensor.") + .add_argument("sample_indices", "Tensor", "The sample indices tensor.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMultinomialFromUniform) + .set_attr<Bool>("FPurity", Bool(true)); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/sampling.h b/src/relax/op/tensor/sampling.h new file mode 100644 index 000000000000..d13aa835d68d --- /dev/null +++ b/src/relax/op/tensor/sampling.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file sampling.h + * \brief The functions to make Relax tensor sampling operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_SAMPLING_H_ +#define TVM_RELAX_OP_TENSOR_SAMPLING_H_ + +#include <tvm/relax/attrs/sampling.h> + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*!
+ * \brief Returns a tensor where each row contains the index sampled from the multinomial + * probability distribution located in the corresponding row of tensor prob. + * \param prob A 2-D tensor of shape (batch, vocab_size) representing probability distributions. + * Each row is a distribution across vocabulary for a batch, where: + * Values range from [0, 1], indicating the probability of each vocabulary item. + * The sum of values in each row is 1, forming a valid distribution. + * \param uniform_sample A 2-D tensor with the shape (n, 1). Values range from 0 to 1, indicating + * probabilities sampled uniformly. + * \param sample_indices The 2-D tensor with the shape [n, 1], which indicates the specific + * probability distribution to sample from. The value of sample_indices[i] + * determines that the ith token should be sampled from the sample_indices[i]th + * probability distribution. For instance, if there are 3 distinct probability + * distributions and the requirement is to sample 2, 3, and 4 tokens from each, + * then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2]. + * \param dtype The data type of the output tensor. + * \return The sampled result. + */ +Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, DataType dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_SAMPLING_H_ diff --git a/tests/python/relax/test_backend_dispatch_sampling.py b/tests/python/relax/test_backend_dispatch_sampling.py new file mode 100644 index 000000000000..18d625d01995 --- /dev/null +++ b/tests/python/relax/test_backend_dispatch_sampling.py @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring + +import tvm +import tvm.script +import tvm.testing +from tvm.ir.base import assert_structural_equal +from tvm.relax.backend import DispatchSampling +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +@I.ir_module +class MultiFromUniformModule: + @R.function + def foo( + prob: R.Tensor((3, 5), "float32"), + uniform_sample: R.Tensor((6, 1), "float32"), + sample_indices: R.Tensor((6, 1), "int64"), + ): + with R.dataflow(): + gv = R.multinomial_from_uniform(prob, uniform_sample, sample_indices, dtype="int64") + R.output(gv) + return gv + + +def test_dispatch_multinomial_from_uniform_generic(): + # fmt: off + @I.ir_module + class Expected: + @T.prim_func(private=True) + def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): + batch, vocab_size = T.int64(), T.int64() + prob = T.match_buffer(A, (batch, vocab_size)) + out_batch = T.int64() + usample = T.match_buffer(B, (out_batch, 1)) + sample_indices = T.match_buffer(C, (out_batch, 1), "int64") + output_index = T.match_buffer(D, (out_batch, 1), "int64") + # with T.block("root"): + for ax0, ax1 in T.grid(out_batch, vocab_size): + with T.block("T_get_sample_index"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + if usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1] or v_ax1 + T.int64(1) == vocab_size: + if v_ax1 == T.int64(0): + output_index[v_ax0, 0] = T.int64(0) + else: + if usample[v_ax0, T.int64(0)] >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)]: + output_index[v_ax0, 0] = v_ax1 + + @R.function + def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), dtype="int64")) -> R.Tensor((6, 1), dtype="int64"): + cls = Expected + with R.dataflow(): + lv: R.Tensor((3, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="float32", exclusive=0) + gv = R.call_tir(cls.get_sample_index, (lv, uniform_sample, sample_indices), out_sinfo=R.Tensor((6, 1), dtype="int64")) + R.output(gv) + return gv + # fmt: on + + with tvm.target.Target("llvm"): + mod = DispatchSampling()(MultiFromUniformModule) + + assert_structural_equal(mod, Expected) + + +def test_dispatch_multinomial_from_uniform_gpu(): + # fmt: off + @I.ir_module + class Expected: + @T.prim_func + def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle): + T.func_attr({"tir.is_scheduled": 1}) + n, vocab_size = T.int64(), T.int64() + prob = T.match_buffer(var_prob, (n, vocab_size)) + batch_size = T.int64() + uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1)) + row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int64") + token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int64") + # with T.block("root"): + aggregate = T.alloc_buffer((), scope="local") + sample_id_local = T.alloc_buffer((), "int64", scope="local") + step_iter = T.alloc_buffer((), "int32", scope="local") + for bx in T.thread_binding(batch_size, thread="blockIdx.x"): + row_idx: T.int64 = row_indices[bx, 0] + for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + u: T.float32 = uniform_samples[bx, 0] + aggregate[()] = T.Cast("float32", 0) + step_iter[()] = 0 + while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + 
T.int64(512) - T.int64(1)) // T.int64(512)): + with T.block(""): + T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()]) + T.writes(sample_id_local[()], aggregate[()]) + prob_gt_threshold = T.alloc_buffer((T.int64(4),), scope="local") + cumsum = T.alloc_buffer((T.int64(512),), scope="shared") + greater_than_u = T.alloc_buffer((T.int64(4),), "bool", scope="local") + mask = T.alloc_buffer((T.int64(4),), "bool", scope="local") + valid = T.alloc_buffer((T.int64(4),), "bool", scope="local") + indices = T.alloc_buffer((T.int64(4),), "int64", scope="local") + step_aggregate = T.alloc_buffer((), scope="local") + for v in T.unroll(T.int64(4)): + idx: T.int64 = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v + prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0)) + prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0), prob_local, T.Cast("float32", 0)) + valid[v] = prob_local > T.float32(0) and idx < vocab_size + with T.block(""): + T.reads(prob_gt_threshold[T.int64(0):T.int64(4)]) + T.writes(step_aggregate[()]) + local_sum = T.alloc_buffer((), scope="local") + shared_buf = T.alloc_buffer((T.int64(128),), scope="shared") + idx: T.int64 = ty * T.int64(32) + tx + local_sum[()] = T.Cast("float32", 0) + for i in T.unroll(T.int64(4)): + local_sum[()] = local_sum[()] + prob_gt_threshold[i] + shared_buf[idx] = local_sum[()] + for i in T.unroll(T.int64(7)): + if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0): + shared_buf[idx] = shared_buf[idx] + shared_buf[idx + T.shift_left(T.int64(1), i)] + step_aggregate[()] = shared_buf[0] + if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= u - T.float32(9.9999999999999995e-07)): + for i in T.unroll(T.int64(1), T.int64(4)): + prob_gt_threshold[i] = prob_gt_threshold[i] + prob_gt_threshold[i - T.int64(1)] + for i in T.vectorized(T.int64(4)): + cumsum[ty * T.int64(128) + tx * T.int64(4) + i] = prob_gt_threshold[i] + for i in T.unroll(T.int64(5)): + for j in T.vectorized(T.int64(4)): + idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + if tx >= T.shift_left(T.int64(1), i): + cumsum[idx + j] = cumsum[idx + j] + cumsum[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)] + for i in T.unroll(T.int64(1), T.int64(4)): + for j in T.vectorized(T.int64(4)): + if ty == T.int64(0): + idx: T.int64 = i * T.int64(128) + tx * T.int64(4) + cumsum[idx + j] = cumsum[idx + j] + cumsum[i * T.int64(128) - T.int64(1)] + for v in T.unroll(T.int64(4)): + greater_than_u[v] = cumsum[ty * T.int64(128) + tx * T.int64(4) + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07) + with T.block(""): + T.reads(greater_than_u[T.int64(0):T.int64(4)]) + T.writes(mask[T.int64(0):T.int64(4)]) + shared_buf = T.alloc_buffer((T.int64(128),), "bool", scope="shared") + tx_idx: T.int64 = ty * T.int64(32) + tx + shared_buf[tx_idx] = greater_than_u[T.int64(3)] + mask[0] = T.if_then_else(tx_idx != T.int64(0), T.Cast("int8", greater_than_u[0]) != T.Cast("int8", shared_buf[tx_idx - T.int64(1)]), greater_than_u[0]) + for i in T.unroll(T.int64(1), T.int64(4)): + mask[i] = T.Cast("int8", greater_than_u[i]) != T.Cast("int8", greater_than_u[i - T.int64(1)]) + for v in T.unroll(T.int64(4)): + mask[v] = mask[v] and valid[v] + indices[v] = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * 
T.int64(128) + tx * T.int64(4) + v + with T.block(""): + T.reads(mask[T.int64(0):T.int64(4)], indices[T.int64(0):T.int64(4)]) + T.writes(sample_id_local[()]) + local_sum = T.alloc_buffer((), "int64", scope="local") + shared_buf = T.alloc_buffer((T.int64(128),), "int64", scope="shared") + idx: T.int64 = ty * T.int64(32) + tx + local_sum[()] = T.Cast("int64", vocab_size - T.int64(1)) + for i in T.unroll(T.int64(4)): + if mask[i]: + local_sum[()] = T.min(local_sum[()], indices[i]) + shared_buf[idx] = local_sum[()] + for i in T.unroll(T.int64(7)): + if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0): + shared_buf[idx] = T.min(shared_buf[idx], shared_buf[idx + T.shift_left(T.int64(1), i)]) + sample_id_local[()] = shared_buf[0] + aggregate[()] = aggregate[()] + step_aggregate[()] + step_iter[()] = step_iter[()] + 1 + if tx == T.int64(0) and ty == T.int64(0): + token_ids[bx, 0] = sample_id_local[()] + + @R.function + def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), dtype="int64")) -> R.Tensor((6, 1), dtype="int64"): + cls = Expected + with R.dataflow(): + gv = R.call_tir(cls.parallel_sampling_from_prob, (prob, uniform_sample, sample_indices), out_sinfo=R.Tensor((6, 1), dtype="int64")) + R.output(gv) + return gv + # fmt: on + + with tvm.target.Target("cuda"): + mod = DispatchSampling()(MultiFromUniformModule) + + assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 8bf52d7918e5..a632a867432b 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -849,7 +849,7 @@ def test(self): vm["test"](*effects) -@tvm.testing.requires_gpu +@tvm.testing.requires_cuda def test_multinomial_from_uniform(): prob_shape = (3, 5) @@ -863,27 +863,6 @@ def foo(self, prob: Tensor, uniform_sample: Tensor, sample_indices: Tensor): # fmt: off @I.ir_module class Expected: - @T.prim_func(private=True) - def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): - batch, vocab_size = T.int64(), T.int64() - prob = T.match_buffer(A, (batch, vocab_size)) - out_batch = T.int64() - usample = T.match_buffer(B, (out_batch, 1)) - sample_indices = T.match_buffer(C, (out_batch, 1), "int64") - output_index = T.match_buffer(D, (out_batch, 1), "int64") - # with T.block("root"): - for ax0, ax1 in T.grid(out_batch, vocab_size): - with T.block("T_get_sample_index"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(usample[v_ax0, T.int64(0)], prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)]) - T.writes(output_index[v_ax0, 0]) - if usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1] or v_ax1 + T.int64(1) == vocab_size: - if v_ax1 == T.int64(0): - output_index[v_ax0, 0] = T.int64(0) - else: - if usample[v_ax0, T.int64(0)] >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)]: - output_index[v_ax0, 0] = v_ax1 - @R.function def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): @@ -896,11 +875,9 @@ def _initialize_effect() -> R.Tuple(R.Object): @R.function def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) - cls = 
Expected with R.dataflow(): - cumsum: R.Tensor((3, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=0) - lv1 = R.call_tir(cls.get_sample_index, (cumsum, uniform_sample, sample_indices), out_sinfo=R.Tensor((6, 1), dtype="int64")) - gv1: R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)) = lv1, (_io,) + multinomial_from_uniform: R.Tensor((6, 1), dtype="int64") = R.multinomial_from_uniform(prob, uniform_sample, sample_indices, dtype="int64") + gv1: R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)) = multinomial_from_uniform, (_io,) R.output(gv1) return gv1 # fmt: on @@ -919,11 +896,12 @@ def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1) tvm.ir.assert_structural_equal(mod, Expected) - target = tvm.target.Target("cuda -libs=thrust", host="llvm") + target = tvm.target.Target("cuda", host="llvm") with target: + mod = relax.backend.DispatchSampling()(mod) mod = tir.transform.DefaultGPUSchedule()(mod) ex = relax.build(mod, target) - dev = tvm.cuda(0) + dev = tvm.device(str(target), 0) vm = relax.VirtualMachine(ex, dev) effects = vm["_initialize_effect"]() @@ -1001,14 +979,14 @@ def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0]) T.writes(renorm_prob[v_ax0, 0]) - if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > T.int64(1)) == T.bool(False): + if not (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > T.int64(1)): renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0] else: - if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True): + if cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < top_k[v_ax0, 0]: if v_ax1 + T.int64(1) == vocab_size: renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1] else: - if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) == T.bool(False): + if not (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]): renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] @R.function diff --git a/tests/python/relax/test_op_sampling.py b/tests/python/relax/test_op_sampling.py new file mode 100644 index 000000000000..d8806cf62500 --- /dev/null +++ b/tests/python/relax/test_op_sampling.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_multinomial_from_uniform(): + bb = relax.BlockBuilder() + prob0 = relax.Var("prob", R.Tensor((3, 5), "float32")) + prob1 = relax.Var("prob", R.Tensor(ndim=2, dtype="float32")) + prob2 = relax.Var("prob", R.Tensor(dtype="float32")) + + uniform_sample0 = relax.Var("u", R.Tensor((6, 1), "float32")) + uniform_sample1 = relax.Var("u", R.Tensor(ndim=2, dtype="float32")) + uniform_sample2 = relax.Var("u", R.Tensor(dtype="float32")) + + sample_indices0 = relax.Var("s", R.Tensor((6, 1), "int64")) + sample_indices1 = relax.Var("s", R.Tensor((6, 1), "int32")) + + _check_inference( + bb, + relax.op.multinomial_from_uniform(prob0, uniform_sample0, sample_indices0), + R.Tensor((6, 1), "int64"), + ) + _check_inference( + bb, + relax.op.multinomial_from_uniform(prob0, uniform_sample0, sample_indices0, dtype="int32"), + R.Tensor((6, 1), "int32"), + ) + _check_inference( + bb, + relax.op.multinomial_from_uniform(prob1, uniform_sample1, sample_indices1), + R.Tensor(ndim=2, dtype="int64"), + ) + _check_inference( + bb, + relax.op.multinomial_from_uniform(prob1, uniform_sample1, sample_indices1, dtype="int32"), + R.Tensor(ndim=2, dtype="int32"), + ) + _check_inference( + bb, + relax.op.multinomial_from_uniform(prob2, uniform_sample2, sample_indices0), + R.Tensor(dtype="int64"), + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 25a904a157da..2dcbc89d47a6 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -486,5 +486,29 @@ def func() -> None: assert func.body.node.dom.extent.dtype == "int64" +def test_deterministic_branch(): + """Test deterministic branch""" + + def create_func(predicate: bool): + @T.prim_func(private=True) + def func() -> None: + if predicate: + T.evaluate(0) + else: + T.evaluate(1) + + return func + + def create_expected(value): + @T.prim_func(private=True) + def expected() -> None: + T.evaluate(value) + + return expected + + tvm.ir.assert_structural_equal(create_func(True), create_expected(0)) + tvm.ir.assert_structural_equal(create_func(False), create_expected(1)) + + if __name__ == "__main__": tvm.testing.main()
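For reference, a minimal end-to-end sketch of the new operator on the CPU path (illustration only, not part of this patch; the module and variable names are made up, and it assumes relax.build's default pipeline, which now includes DispatchSampling, performs the remaining legalization):

import numpy as np
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class SamplerModule:
    @R.function
    def main(
        prob: R.Tensor((2, 3), "float32"),
        uniform_sample: R.Tensor((2, 1), "float32"),
        sample_indices: R.Tensor((2, 1), "int64"),
    ):
        with R.dataflow():
            gv = R.multinomial_from_uniform(prob, uniform_sample, sample_indices, dtype="int64")
            R.output(gv)
        return gv


dev = tvm.cpu(0)
target = tvm.target.Target("llvm")
with target:
    # DispatchSampling also runs inside the default build pipeline; applying it
    # explicitly here just makes the lowering to cumsum + get_sample_index visible.
    mod = relax.backend.DispatchSampling()(SamplerModule)
    ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, dev)

prob = tvm.nd.array(np.array([[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]], dtype="float32"), dev)
usample = tvm.nd.array(np.array([[0.4], [0.9]], dtype="float32"), dev)
indices = tvm.nd.array(np.array([[0], [1]], dtype="int64"), dev)
print(vm["main"](prob, usample, indices))  # expected: [[1], [2]]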