46 changes: 46 additions & 0 deletions include/tvm/relax/attrs/sampling.h
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/sampling.h
* \brief Attributes for sampling operators.
*/
#ifndef TVM_RELAX_ATTRS_SAMPLING_H_
#define TVM_RELAX_ATTRS_SAMPLING_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in multinomial_from_uniform operator */
struct MultinomialFromUniformAttrs : public tvm::AttrsNode<MultinomialFromUniformAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(MultinomialFromUniformAttrs, "relax.attrs.MultinomialFromUniformAttrs") {
TVM_ATTR_FIELD(dtype)
.set_default(DataType::Int(64))
.describe("Data type of the output indices.");
}
}; // struct MultinomialFromUniformAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_SAMPLING_H_
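A minimal usage sketch, not part of this change: it assumes the existing relax.op.multinomial_from_uniform frontend, and the shapes are illustrative only (the dispatcher added below expects uniform_sample and sample_indices to be (N, 1) tensors).

# Hypothetical sketch: the dtype argument maps to MultinomialFromUniformAttrs::dtype above.
from tvm import relax

bb = relax.BlockBuilder()
prob = relax.Var("prob", relax.TensorStructInfo([4, 32000], "float32"))  # per-row probabilities
uniform_sample = relax.Var("uniform_sample", relax.TensorStructInfo([4, 1], "float32"))  # U(0, 1) draws
sample_indices = relax.Var("sample_indices", relax.TensorStructInfo([4, 1], "int64"))  # row selector
with bb.function("sample", [prob, uniform_sample, sample_indices]):
    out = bb.emit(
        relax.op.multinomial_from_uniform(prob, uniform_sample, sample_indices, dtype="int64")
    )
    bb.emit_func_output(out)
mod = bb.get()  # IRModule containing the high-level op, before dispatch/legalization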
3 changes: 2 additions & 1 deletion python/tvm/relax/backend/__init__.py
@@ -17,5 +17,6 @@
"""Relax backends"""

from . import contrib
from .pattern_registry import get_pattern, get_patterns_with_prefix
from .dispatch_sampling import DispatchSampling
from .dispatch_sort_scan import DispatchSortScan
from .pattern_registry import get_pattern, get_patterns_with_prefix
94 changes: 94 additions & 0 deletions python/tvm/relax/backend/dispatch_sampling.py
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
"""Dispatch sampling operators to platform dependent implementation."""


from tvm import relax
from tvm.ir import Op
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext, module_pass
from tvm.relax import expr_functor

from .utils import BackendDispatcher


@expr_functor.mutator
class SamplingDispatcher(BackendDispatcher):
"""Dispatcher to dispatch sampling op."""

def visit_call_(self, call: relax.Call) -> relax.Expr:
if not isinstance(call.op, Op):
return super().visit_call_(call)

if call.op.name == "relax.multinomial_from_uniform":
from tvm.relax.backend_tir import ( # pylint: disable=import-outside-toplevel
generic_get_sample_index,
gpu_multinomial_from_uniform,
)

prob, uniform_sample, sample_indices = call.args
tgt = self._get_target(call.struct_info)
dtype = call.attrs.dtype
_, prob_dtype = self.get_shape_dtype(prob)
sample_shape, sample_dtype = self.get_shape_dtype(uniform_sample)
sample_indices_shape, sample_indices_dtype = self.get_shape_dtype(sample_indices)

if len(sample_shape) != 2 or sample_shape[1] != 1:
raise ValueError("uniform_sample should be a 2D tensor with shape (N, 1)")

if len(sample_indices_shape) != 2 or sample_indices_shape[1] != 1:
raise ValueError("sample_indices should be a 2D tensor with shape (N, 1)")

if self.is_gpu_target(tgt):
gv = self.builder_.add_func(
gpu_multinomial_from_uniform(
prob_dtype, sample_dtype, sample_indices_dtype, dtype
),
"gpu_multinomial_from_uniform",
)
return relax.call_tir(
gv,
[prob, uniform_sample, sample_indices],
out_sinfo=call.struct_info,
)
else:
cumsum_prob = relax.op.cumsum(prob, axis=1, dtype=prob_dtype, exclusive=False)
gv = self.builder_.add_func(
generic_get_sample_index(prob_dtype, sample_dtype, sample_indices_dtype, dtype),
"get_sample_index",
)
return relax.call_tir(
gv,
[cumsum_prob, uniform_sample, sample_indices],
out_sinfo=call.struct_info,
)

return super().visit_call_(call)


@module_pass(opt_level=0, name="DispatchSampling")
class DispatchSampling:
"""Pass to dispatch scan and sort operators to platform dependent implementation."""

def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
sampling_dispatcher = SamplingDispatcher(mod)
for gv, func in mod.functions_items():
if isinstance(func, relax.Function):
func = sampling_dispatcher.visit_expr(func)
sampling_dispatcher.builder_.update_func(gv, func)
return sampling_dispatcher.builder_.finalize()
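A hedged pipeline sketch of applying the new pass; the target, the surrounding passes, and mod (an existing relax IRModule) are placeholders, not prescribed by this change.

# Hypothetical pipeline sketch: dispatch sampling (and sort/scan) before legalization.
import tvm
from tvm import relax

with tvm.target.Target("cuda"):  # _get_target falls back to Target.current() when no vdevice is set
    seq = tvm.transform.Sequential(
        [
            relax.backend.DispatchSampling(),  # rewrites relax.multinomial_from_uniform
            relax.backend.DispatchSortScan(),  # rewrites sort/scan/cumsum ops
            relax.transform.LegalizeOps(),  # lowers the remaining high-level ops
        ]
    )
    mod = seq(mod)  # `mod` is an existing relax.IRModule (placeholder)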
46 changes: 10 additions & 36 deletions python/tvm/relax/backend/dispatch_sort_scan.py
@@ -26,48 +26,22 @@
from tvm.ir import GlobalVar, Op
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext, module_pass
from tvm.relax import PyExprMutator, expr_functor
from tvm.relax import expr_functor
from tvm.target import Target


def is_gpu_target(target: Target) -> bool:
"""Check if the target is a GPU target."""
return "gpu" in target.keys
from .utils import BackendDispatcher


@expr_functor.mutator
class SortScanDispatcher(PyExprMutator):
"""
Dispatcher to dispatch sort and scan.

"""
class SortScanDispatcher(BackendDispatcher):
"""Dispatcher to dispatch sort and scan."""

calls_to_update: Dict[GlobalVar, Target]

def __init__(self, mod):
super().__init__(mod)
self.calls_to_update = {}

def _get_target(self, sinfo: relax.StructInfo) -> Target:
# Get target information from TensorStructInfo
if isinstance(sinfo, relax.TensorStructInfo):
vdevice = sinfo.vdevice
if vdevice is not None:
return vdevice.target
elif isinstance(sinfo, relax.TupleStructInfo):
for f in sinfo.fields:
tgt = self._get_target(f)
if tgt != Target.current():
return tgt
# Return the target in current context
target = Target.current()
if target is None:
raise ValueError(
"Target not found. Please ensure that the target is annotated within the module, "
"or alternatively, execute this within a specified target context."
)
return target

def apply_dlight_gpu_fallback(
self,
) -> None:
@@ -107,7 +81,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.sort_thrust
kwargs["workspace"] = self.allocate_workspace(call)
elif is_gpu_target(tgt):
elif self.is_gpu_target(tgt):
te_func = topi.cuda.sort
return self.builder_.call_te(
te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs
@@ -120,7 +94,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.argsort_thrust
kwargs["workspace"] = self.allocate_workspace(call)
elif is_gpu_target(tgt):
elif self.is_gpu_target(tgt):
te_func = topi.cuda.argsort
return self.builder_.call_te(
te_func,
@@ -137,7 +111,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
te_func = topi.cuda.topk_thrust
kwargs["workspace"] = self.allocate_workspace(call)
elif is_gpu_target(tgt):
elif self.is_gpu_target(tgt):
te_func = topi.cuda.topk
tir_call = self.builder_.call_te(
te_func,
@@ -162,7 +136,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
if (
shape is not None
and (axis == -1 or axis == len(shape) - 1)
and is_gpu_target(tgt)
and self.is_gpu_target(tgt)
and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan")
and call.op.name == "relax.cumsum"
and call.attrs.exclusive == 0
@@ -202,11 +176,11 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:

with tgt:
if call.op.name == "relax.cumsum":
te_func = topi.cuda.cumsum if is_gpu_target(tgt) else topi.cumsum
te_func = topi.cuda.cumsum if self.is_gpu_target(tgt) else topi.cumsum
if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"):
kwargs["workspace"] = self.allocate_workspace(call)
elif call.op.name == "relax.cumprod":
te_func = topi.cuda.cumprod if is_gpu_target(tgt) else topi.cumprod
te_func = topi.cuda.cumprod if self.is_gpu_target(tgt) else topi.cumprod
else:
raise ValueError(f"Unsupported op: {call.op.name}")
tir_call = self.builder_.call_te(
55 changes: 54 additions & 1 deletion python/tvm/relax/backend/utils.py
@@ -17,8 +17,61 @@
# pylint: disable=invalid-name
"""Utils for BYOC pattern matching"""

from tvm.relax import DataflowVar
from typing import Tuple
from tvm import relax
from tvm.relax import DataflowVar, PyExprMutator
from tvm.relax.transform import PatternCheckContext
from tvm.target import Target


class BackendDispatcher(PyExprMutator):
"""Base class for backend dispatcher"""

def __init__(self, mod):
super().__init__(mod)

@staticmethod
def is_gpu_target(target: Target) -> bool:
"""Check if the target is a GPU target."""
return "gpu" in target.keys

@staticmethod
def get_shape_dtype(expr: relax.Expr) -> Tuple[relax.ShapeExpr, str]:
"""Get shape and dtype from an expression.
If the shape or dtype is unknown, raise an error."""
sinfo = expr.struct_info
if not isinstance(sinfo, relax.TensorStructInfo):
raise ValueError(
f"Expecting a expr with TensorStructInfo, but got {expr} with {expr.struct_info}"
)

shape, dtype = sinfo.shape, sinfo.dtype
if shape is None:
raise ValueError(
f"Expecting a expr with known shape, but got {expr} with unknown shape"
)

return shape, dtype

def _get_target(self, sinfo: relax.StructInfo) -> Target:
# Get target information from TensorStructInfo
if isinstance(sinfo, relax.TensorStructInfo):
vdevice = sinfo.vdevice
if vdevice is not None:
return vdevice.target
elif isinstance(sinfo, relax.TupleStructInfo):
for f in sinfo.fields:
tgt = self._get_target(f)
if tgt != Target.current():
return tgt
# Return the target in current context
target = Target.current()
if target is None:
raise ValueError(
"Target not found. Please ensure that the target is annotated within the module, "
"or alternatively, execute this within a specified target context."
)
return target
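A minimal subclass sketch (the class and op name are hypothetical, not part of this change) showing how a dispatcher is expected to build on these helpers: validate operands with get_shape_dtype, resolve the target with _get_target, and branch on is_gpu_target.

# Hypothetical dispatcher sketch built on BackendDispatcher.
from tvm import relax
from tvm.ir import Op
from tvm.relax import expr_functor
from tvm.relax.backend.utils import BackendDispatcher

@expr_functor.mutator
class ExampleDispatcher(BackendDispatcher):
    """Dispatch a made-up op, "relax.example_op", per target kind."""

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        if not isinstance(call.op, Op) or call.op.name != "relax.example_op":
            return super().visit_call_(call)
        shape, dtype = self.get_shape_dtype(call.args[0])  # raises if shape/dtype is unknown
        target = self._get_target(call.struct_info)  # vdevice target or Target.current()
        if self.is_gpu_target(target):
            pass  # emit a GPU TIR kernel via self.builder_.add_func(...) and relax.call_tir(...)
        else:
            pass  # emit the generic lowering
        return super().visit_call_(call)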


def has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
3 changes: 2 additions & 1 deletion python/tvm/relax/backend_tir/__init__.py
@@ -17,5 +17,6 @@
"""Relax backends, tir based"""

from . import contrib
from .pattern import get_tir_pattern
from .cumsum import gpu_2d_continuous_cumsum
from .pattern import get_tir_pattern
from .sampling import gpu_multinomial_from_uniform, generic_get_sample_index
8 changes: 4 additions & 4 deletions python/tvm/relax/backend_tir/cumsum.py
@@ -41,10 +41,10 @@ def gpu_2d_continuous_cumsum(
Parameters
----------
ty_len : int
The length of thread.y
The length of `threadIdx.y`

tx_len : int
The length of thread.x
The length of `threadIdx.x`

thread_elem : int
The number of elements processed by a single thread
Expand All @@ -64,8 +64,8 @@ def gpu_2d_continuous_cumsum(
out_dtype = out_dtype or in_dtype

# Configuration for GPU kernel
TX = T.int64(tx_len) # thread.x
TY = T.int64(ty_len) # thread.y
TX = T.int64(tx_len) # threadIdx.x
TY = T.int64(ty_len) # threadIdx.y
N = T.int64(thread_elem) # number of elements in single thread

if not _is_power_of_two(TX) or not _is_power_of_two(TY) or not _is_power_of_two(N):
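A hedged usage sketch of the kernel builder above; only the parameter names come from this file, while the values and the call_tir wiring are illustrative.

# Hypothetical usage sketch for gpu_2d_continuous_cumsum.
from tvm.relax.backend_tir import gpu_2d_continuous_cumsum

# ty_len, tx_len and thread_elem must all be powers of two; these values are illustrative.
cumsum_func = gpu_2d_continuous_cumsum(
    ty_len=8, tx_len=32, thread_elem=4, in_dtype="float32", out_dtype="float32"
)
# cumsum_func is a TIR PrimFunc; a dispatcher would register and invoke it, e.g.:
#   gv = self.builder_.add_func(cumsum_func, "gpu_2d_continuous_cumsum")
#   out = relax.call_tir(gv, [data], out_sinfo=data.struct_info)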