
Commit 939663d

Author: Siyuan Feng (committed)
Support multinomial_from_uniform dispatch
1 parent fffd168 commit 939663d

File tree

22 files changed: +1218 additions, -158 deletions


include/tvm/relax/attrs/sampling.h

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/attrs/sampling.h
 * \brief Attributes for sampling operators.
 */
#ifndef TVM_RELAX_ATTRS_SAMPLING_H_
#define TVM_RELAX_ATTRS_SAMPLING_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in multinomial_from_uniform operator */
struct MultinomialFromUniformAttrs : public tvm::AttrsNode<MultinomialFromUniformAttrs> {
  DataType dtype;

  TVM_DECLARE_ATTRS(MultinomialFromUniformAttrs, "relax.attrs.MultinomialFromUniformAttrs") {
    TVM_ATTR_FIELD(dtype)
        .set_default(DataType::Int(64))
        .describe("Data type of the output indices.");
  }
};  // struct MultinomialFromUniformAttrs

}  // namespace relax
}  // namespace tvm

#endif  // TVM_RELAX_ATTRS_SAMPLING_H_
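
For orientation, a short sketch of where this attribute surfaces on the Python side. The wrapper name and signature below are assumptions inferred from the dispatcher added later in this commit, not something this header defines:

# Hypothetical sketch: the `dtype` attribute declared above becomes
# `call.attrs.dtype` on constructed Relax calls (wrapper signature assumed).
from tvm import relax

prob = relax.Var("prob", relax.TensorStructInfo((2, 8), "float32"))
usample = relax.Var("uniform_sample", relax.TensorStructInfo((2, 1), "float32"))
indices = relax.Var("sample_indices", relax.TensorStructInfo((2, 1), "int64"))

call = relax.op.multinomial_from_uniform(prob, usample, indices, dtype="int64")
# call.attrs is a MultinomialFromUniformAttrs; its dtype field carries the
# output index dtype, defaulting to int64 as declared in this header.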

python/tvm/relax/backend/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -17,5 +17,6 @@
 """Relax backends"""
 
 from . import contrib
-from .pattern_registry import get_pattern, get_patterns_with_prefix
+from .dispatch_sampling import DispatchSampling
 from .dispatch_sort_scan import DispatchSortScan
+from .pattern_registry import get_pattern, get_patterns_with_prefix
python/tvm/relax/backend/dispatch_sampling.py (new file)

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
"""Dispatch sampling operators to platform dependent implementation."""


from tvm import relax
from tvm.ir import Op
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext, module_pass
from tvm.relax import expr_functor

from .utils import BackendDispatcher


@expr_functor.mutator
class SamplingDispatcher(BackendDispatcher):
    """Dispatcher to dispatch sampling op."""

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        if not isinstance(call.op, Op):
            return super().visit_call_(call)

        if call.op.name == "relax.multinomial_from_uniform":
            from tvm.relax.backend_tir import (  # pylint: disable=import-outside-toplevel
                generic_get_sample_index,
                gpu_multinomial_from_uniform,
            )

            prob, uniform_sample, sample_indices = call.args
            tgt = self._get_target(call.struct_info)
            dtype = call.attrs.dtype
            _, prob_dtype = self.get_shape_dtype(prob)
            sample_shape, sample_dtype = self.get_shape_dtype(uniform_sample)
            sample_indices_shape, sample_indices_dtype = self.get_shape_dtype(sample_indices)

            if len(sample_shape) != 2 or sample_shape[1] != 1:
                raise ValueError("uniform_sample should be a 2D tensor with shape (N, 1)")

            if len(sample_indices_shape) != 2 or sample_indices_shape[1] != 1:
                raise ValueError("sample_indices should be a 2D tensor with shape (N, 1)")

            if self.is_gpu_target(tgt):
                gv = self.builder_.add_func(
                    gpu_multinomial_from_uniform(
                        prob_dtype, sample_dtype, sample_indices_dtype, dtype
                    ),
                    "gpu_multinomial_from_uniform",
                )
                return relax.call_tir(
                    gv,
                    [prob, uniform_sample, sample_indices],
                    out_sinfo=call.struct_info,
                )
            else:
                cumsum_prob = relax.op.cumsum(prob, axis=1, dtype=prob_dtype, exclusive=False)
                gv = self.builder_.add_func(
                    generic_get_sample_index(prob_dtype, sample_dtype, sample_indices_dtype, dtype),
                    "get_sample_index",
                )
                return relax.call_tir(
                    gv,
                    [cumsum_prob, uniform_sample, sample_indices],
                    out_sinfo=call.struct_info,
                )

        return super().visit_call_(call)


@module_pass(opt_level=0, name="DispatchSampling")
class DispatchSampling:
    """Pass to dispatch sampling operators to platform dependent implementation."""

    def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
        sampling_dispatcher = SamplingDispatcher(mod)
        for gv, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = sampling_dispatcher.visit_expr(func)
                sampling_dispatcher.builder_.update_func(gv, func)
        return sampling_dispatcher.builder_.finalize()
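
For context, a minimal usage sketch of the new pass. It is not part of this commit: the module construction, shapes, and the "cuda" target are illustrative assumptions, and the relax.op.multinomial_from_uniform wrapper is assumed to mirror the (prob, uniform_sample, sample_indices, dtype) arguments handled above.

# Hypothetical usage sketch (module contents and target are assumptions).
import tvm
from tvm import relax
from tvm.relax.backend import DispatchSampling

# Build a small Relax module that calls the sampling op.
bb = relax.BlockBuilder()
prob = relax.Var("prob", relax.TensorStructInfo((4, 32000), "float32"))
usample = relax.Var("uniform_sample", relax.TensorStructInfo((4, 1), "float32"))
indices = relax.Var("sample_indices", relax.TensorStructInfo((4, 1), "int64"))
with bb.function("main", [prob, usample, indices]):
    out = bb.emit(relax.op.multinomial_from_uniform(prob, usample, indices, dtype="int64"))
    bb.emit_func_output(out)
mod = bb.get()

# Under a GPU target the call is rewritten to a call_tir into the
# gpu_multinomial_from_uniform kernel; on other targets it becomes
# cumsum followed by the generic get_sample_index kernel.
with tvm.target.Target("cuda"):
    mod = DispatchSampling()(mod)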

python/tvm/relax/backend/dispatch_sort_scan.py

Lines changed: 10 additions & 36 deletions
@@ -26,48 +26,22 @@
 from tvm.ir import GlobalVar, Op
 from tvm.ir.module import IRModule
 from tvm.ir.transform import PassContext, module_pass
-from tvm.relax import PyExprMutator, expr_functor
+from tvm.relax import expr_functor
 from tvm.target import Target
 
-
-def is_gpu_target(target: Target) -> bool:
-    """Check if the target is a GPU target."""
-    return "gpu" in target.keys
+from .utils import BackendDispatcher
 
 
 @expr_functor.mutator
-class SortScanDispatcher(PyExprMutator):
-    """
-    Dispatcher to dispatch sort and scan.
-
-    """
+class SortScanDispatcher(BackendDispatcher):
+    """Dispatcher to dispatch sort and scan."""
 
     calls_to_update: Dict[GlobalVar, Target]
 
     def __init__(self, mod):
         super().__init__(mod)
         self.calls_to_update = {}
 
-    def _get_target(self, sinfo: relax.StructInfo) -> Target:
-        # Get target information from TensorStructInfo
-        if isinstance(sinfo, relax.TensorStructInfo):
-            vdevice = sinfo.vdevice
-            if vdevice is not None:
-                return vdevice.target
-        elif isinstance(sinfo, relax.TupleStructInfo):
-            for f in sinfo.fields:
-                tgt = self._get_target(f)
-                if tgt != Target.current():
-                    return tgt
-        # Return the target in current context
-        target = Target.current()
-        if target is None:
-            raise ValueError(
-                "Target not found. Please ensure that the target is annotated within the module, "
-                "or alternatively, execute this within a specified target context."
-            )
-        return target
-
     def apply_dlight_gpu_fallback(
         self,
     ) -> None:

@@ -107,7 +81,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.sort_thrust
                     kwargs["workspace"] = self.allocate_workspace(call)
-                elif is_gpu_target(tgt):
+                elif self.is_gpu_target(tgt):
                     te_func = topi.cuda.sort
             return self.builder_.call_te(
                 te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs

@@ -120,7 +94,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.argsort_thrust
                     kwargs["workspace"] = self.allocate_workspace(call)
-                elif is_gpu_target(tgt):
+                elif self.is_gpu_target(tgt):
                     te_func = topi.cuda.argsort
             return self.builder_.call_te(
                 te_func,

@@ -137,7 +111,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
                 if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
                     te_func = topi.cuda.topk_thrust
                     kwargs["workspace"] = self.allocate_workspace(call)
-                elif is_gpu_target(tgt):
+                elif self.is_gpu_target(tgt):
                     te_func = topi.cuda.topk
             tir_call = self.builder_.call_te(
                 te_func,

@@ -162,7 +136,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
             if (
                 shape is not None
                 and (axis == -1 or axis == len(shape) - 1)
-                and is_gpu_target(tgt)
+                and self.is_gpu_target(tgt)
                 and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan")
                 and call.op.name == "relax.cumsum"
                 and call.attrs.exclusive == 0

@@ -202,11 +176,11 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
 
             with tgt:
                 if call.op.name == "relax.cumsum":
-                    te_func = topi.cuda.cumsum if is_gpu_target(tgt) else topi.cumsum
+                    te_func = topi.cuda.cumsum if self.is_gpu_target(tgt) else topi.cumsum
                     if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"):
                         kwargs["workspace"] = self.allocate_workspace(call)
                 elif call.op.name == "relax.cumprod":
-                    te_func = topi.cuda.cumprod if is_gpu_target(tgt) else topi.cumprod
+                    te_func = topi.cuda.cumprod if self.is_gpu_target(tgt) else topi.cumprod
                 else:
                     raise ValueError(f"Unsupported op: {call.op.name}")
                 tir_call = self.builder_.call_te(

python/tvm/relax/backend/utils.py

Lines changed: 54 additions & 1 deletion
@@ -17,8 +17,61 @@
 # pylint: disable=invalid-name
 """Utils for BYOC pattern matching"""
 
-from tvm.relax import DataflowVar
+from typing import Tuple
+from tvm import relax
+from tvm.relax import DataflowVar, PyExprMutator
 from tvm.relax.transform import PatternCheckContext
+from tvm.target import Target
+
+
+class BackendDispatcher(PyExprMutator):
+    """Base class for backend dispatcher"""
+
+    def __init__(self, mod):
+        super().__init__(mod)
+
+    @staticmethod
+    def is_gpu_target(target: Target) -> bool:
+        """Check if the target is a GPU target."""
+        return "gpu" in target.keys
+
+    @staticmethod
+    def get_shape_dtype(expr: relax.Expr) -> Tuple[relax.ShapeExpr, str]:
+        """Get shape and dtype from an expression.
+        If the shape and dtype is unknown, raise an error."""
+        sinfo = expr.struct_info
+        if not isinstance(expr.struct_info, relax.TensorStructInfo):
+            raise ValueError(
+                f"Expecting a expr with TensorStructInfo, but got {expr} with {expr.struct_info}"
+            )
+
+        shape, dtype = sinfo.shape, sinfo.dtype
+        if shape is None:
+            raise ValueError(
+                f"Expecting a expr with known shape, but got {expr} with unknown shape"
+            )
+
+        return shape, dtype
+
+    def _get_target(self, sinfo: relax.StructInfo) -> Target:
+        # Get target information from TensorStructInfo
+        if isinstance(sinfo, relax.TensorStructInfo):
+            vdevice = sinfo.vdevice
+            if vdevice is not None:
+                return vdevice.target
+        elif isinstance(sinfo, relax.TupleStructInfo):
+            for f in sinfo.fields:
+                tgt = self._get_target(f)
+                if tgt != Target.current():
+                    return tgt
+        # Return the target in current context
+        target = Target.current()
+        if target is None:
+            raise ValueError(
+                "Target not found. Please ensure that the target is annotated within the module, "
+                "or alternatively, execute this within a specified target context."
+            )
+        return target
 
 
 def has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
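
As a quick illustration of the new helper, a hedged sketch of how get_shape_dtype behaves; the variables below are assumptions, and BackendDispatcher is an internal utility rather than a public API.

# Hypothetical illustration of BackendDispatcher.get_shape_dtype.
from tvm import relax
from tvm.relax.backend.utils import BackendDispatcher

x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
shape, dtype = BackendDispatcher.get_shape_dtype(x)
# shape is the ShapeExpr [4, 8]; dtype is float32.

y = relax.Var("y", relax.ObjectStructInfo())
# BackendDispatcher.get_shape_dtype(y) would raise ValueError, since `y`
# does not carry TensorStructInfo.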

python/tvm/relax/backend_tir/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -17,5 +17,6 @@
 """Relax backends, tir based"""
 
 from . import contrib
-from .pattern import get_tir_pattern
 from .cumsum import gpu_2d_continuous_cumsum
+from .pattern import get_tir_pattern
+from .sampling import gpu_multinomial_from_uniform, generic_get_sample_index

python/tvm/relax/backend_tir/cumsum.py

Lines changed: 4 additions & 4 deletions
@@ -41,10 +41,10 @@ def gpu_2d_continuous_cumsum(
     Parameters
     ----------
     ty_len : int
-        The length of thread.y
+        The length of `threadIdx.y`
 
     tx_len : int
-        The length of thread.x
+        The length of `threadIdx.x`
 
     thread_elem : int
         The number of elements processed by single thread

@@ -64,8 +64,8 @@ def gpu_2d_continuous_cumsum(
     out_dtype = out_dtype or in_dtype
 
     # Configuration for GPU kernel
-    TX = T.int64(tx_len)  # thread.x
-    TY = T.int64(ty_len)  # thread.y
+    TX = T.int64(tx_len)  # threadIdx.x
+    TY = T.int64(ty_len)  # threadIdx.y
     N = T.int64(thread_elem)  # number of elements in single thread
 
     if not _is_power_of_two(TX) or not _is_power_of_two(TY) or not _is_power_of_two(N):
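
A small hedged example of how this kernel generator might be invoked; the keyword arguments follow the docstring above, while the specific values and the return-type comment are assumptions.

# Hypothetical invocation of the 2D continuous cumsum kernel generator.
from tvm.relax.backend_tir import gpu_2d_continuous_cumsum

# Request a kernel using an 8 x 32 thread block (threadIdx.y x threadIdx.x),
# with each thread handling 4 contiguous elements; all three sizes must be
# powers of two per the check above.
cumsum_func = gpu_2d_continuous_cumsum(ty_len=8, tx_len=32, thread_elem=4, in_dtype="float32")
# cumsum_func is expected to be a TIR PrimFunc that a dispatcher can register
# into an IRModule via its block builder.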
