Skip to content

Commit 45af5c7

Browse files
authored
[Frontend][ONNX] Support RandomNormal operator (#9493)
1 parent 008367d commit 45af5c7

File tree

8 files changed

+418
-2
lines changed

8 files changed

+418
-2
lines changed

include/tvm/relay/attrs/random.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@ struct UniformAttrs : public tvm::AttrsNode<UniformAttrs> {
4949
}
5050
};
5151

52+
struct NormalAttrs : public tvm::AttrsNode<NormalAttrs> {
53+
Array<Integer> out_shape;
54+
DataType out_dtype;
55+
56+
TVM_DECLARE_ATTRS(NormalAttrs, "relay.attrs.NormalAttrs") {
57+
TVM_ATTR_FIELD(out_shape).describe("Shape of random numbers to generate");
58+
TVM_ATTR_FIELD(out_dtype)
59+
.set_default(NullValue<DataType>())
60+
.describe("Data type of the generated numbers");
61+
}
62+
};
63+
5264
} // namespace relay
5365
} // namespace tvm
5466
#endif // TVM_RELAY_ATTRS_RANDOM_H_

python/tvm/relay/frontend/onnx.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3888,6 +3888,62 @@ def _impl_v12(cls, inputs, attr, params):
38883888
return _op.einsum(inputs, equation)
38893889

38903890

3891+
class RandomNormal(OnnxOpConverter):
3892+
"""Operator converter for random_normal"""
3893+
3894+
@classmethod
3895+
def _impl_v1(cls, inputs, attr, params):
3896+
dtype = get_type(attr.get("dtype", 1))
3897+
mean = attr.get("mean", 0.0)
3898+
scale = attr.get("scale", 1.0)
3899+
seed = attr.get("seed", None)
3900+
shape = attr["shape"]
3901+
3902+
assert dtype in [
3903+
"float32",
3904+
"float64",
3905+
], "Only float random value generation is currently supported."
3906+
3907+
if seed is None:
3908+
seed = np.random.randint(1e6)
3909+
else:
3910+
seed = int(seed)
3911+
key = _random.threefry_key(seed)
3912+
output = _op.random.normal(key, shape, dtype=dtype, mean=mean, scale=scale)
3913+
_, vals = _expr.TupleWrapper(output, 2)
3914+
return vals
3915+
3916+
3917+
class RandomNormalLike(OnnxOpConverter):
3918+
"""Operator converter for random_normal_like"""
3919+
3920+
@classmethod
3921+
def _impl_v1(cls, inputs, attr, params):
3922+
dtype = attr.get("dtype", None)
3923+
scale = attr.get("scale", 1.0)
3924+
mean = attr.get("mean", 0.0)
3925+
seed = attr.get("seed", None)
3926+
shape = infer_shape(inputs[0])
3927+
if dtype is None:
3928+
dtype = infer_type(inputs[0]).checked_type.dtype
3929+
else:
3930+
dtype = get_type(dtype)
3931+
3932+
assert dtype in [
3933+
"float32",
3934+
"float64",
3935+
], "Only float random value generation is currently supported."
3936+
3937+
if seed is None:
3938+
seed = np.random.randint(1e6)
3939+
else:
3940+
seed = int(seed)
3941+
key = _random.threefry_key(seed)
3942+
output = _op.random.normal(key, shape, dtype=dtype, mean=mean, scale=scale)
3943+
_, vals = _expr.TupleWrapper(output, 2)
3944+
return vals
3945+
3946+
38913947
class RandomUniform(OnnxOpConverter):
38923948
"""Operator converter for random_uniform"""
38933949

@@ -3906,6 +3962,38 @@ def _impl_v1(cls, inputs, attr, params):
39063962

39073963
if seed is None:
39083964
seed = np.random.randint(1e6)
3965+
else:
3966+
seed = int(seed)
3967+
key = _random.threefry_key(seed)
3968+
output = _op.random.uniform(key, shape, dtype=dtype, low=low, high=high)
3969+
_, vals = _expr.TupleWrapper(output, 2)
3970+
return vals
3971+
3972+
3973+
class RandomUniformLike(OnnxOpConverter):
3974+
"""Operator converter for random_uniform_like"""
3975+
3976+
@classmethod
3977+
def _impl_v1(cls, inputs, attr, params):
3978+
dtype = attr.get("dtype", None)
3979+
high = attr.get("high", 1.0)
3980+
low = attr.get("low", 0.0)
3981+
seed = attr.get("seed", None)
3982+
shape = infer_shape(inputs[0])
3983+
if dtype is None:
3984+
dtype = infer_type(inputs[0]).checked_type.dtype
3985+
else:
3986+
dtype = get_type(dtype)
3987+
3988+
assert dtype in [
3989+
"float32",
3990+
"float64",
3991+
], "Only float random value generation is currently supported."
3992+
3993+
if seed is None:
3994+
seed = np.random.randint(1e6)
3995+
else:
3996+
seed = int(seed)
39093997
key = _random.threefry_key(seed)
39103998
output = _op.random.uniform(key, shape, dtype=dtype, low=low, high=high)
39113999
_, vals = _expr.TupleWrapper(output, 2)
@@ -4396,7 +4484,10 @@ def _get_convert_map(opset):
43964484
"QLinearGlobalAveragePool": QLinearGlobalAveragePool.get_converter(opset),
43974485
"QLinearLeakyRelu": QLinearLeakyRelu.get_converter(opset),
43984486
# Random number generation.
4487+
"RandomNormal": RandomNormal.get_converter(opset),
4488+
"RandomNormalLike": RandomNormalLike.get_converter(opset),
43994489
"RandomUniform": RandomUniform.get_converter(opset),
4490+
"RandomUniformLike": RandomUniformLike.get_converter(opset),
44004491
# Loss functions / training
44014492
"NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
44024493
"SoftmaxCrossEntropyLoss": SoftmaxCrossEntropyLoss.get_converter(opset),

python/tvm/relay/op/random/_kernel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@
3131
# Distribution
3232
register_strategy("random.uniform", strategy.uniform_strategy)
3333
register_pattern("random.uniform", OpPattern.OPAQUE)
34+
register_strategy("random.normal", strategy.normal_strategy)
35+
register_pattern("random.normal", OpPattern.OPAQUE)

python/tvm/relay/op/random/kernel.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,51 @@ def uniform(key, shape, dtype="float32", low=0.0, high=1.0):
183183
if not isinstance(high, Expr):
184184
high = const(high, dtype=dtype)
185185
return _make.uniform(key, low, high, shape, dtype)
186+
187+
188+
def normal(key, shape, dtype="float32", mean=0.0, scale=1.0):
189+
"""Draw samples from a normal distribution.
190+
191+
Example
192+
-------
193+
194+
.. code-block:: python
195+
196+
key = threefry_key(0)
197+
key, random_values = normal(key, (100,), low=0, high=10)
198+
199+
Parameters
200+
----------
201+
key : relay.Expr
202+
key that uniquely determines the random values. Multiple uses with the
203+
same generator will generate the same random values. This generator should be
204+
treated as an opaque pointer. You can create one from calling
205+
:py:func:`threefry_key`, :py:func:`threefry_split`, or
206+
:py:func:`threefry_generate`. **Do not use this generator again after calling
207+
this function.**
208+
209+
shape : Sequence[int]
210+
Desired outputs shape of random numbers.
211+
212+
dtype : str
213+
Desired outputs type of random numbers.
214+
215+
low : float or relay.Expr, optional
216+
Mean of the normal distribution.
217+
218+
high : float or relay.Expr, optional
219+
Standard deviation of the normal distribution.
220+
221+
Returns
222+
-------
223+
new_key : relay.Expr
224+
New random key to pass to future uses of random functions.
225+
226+
random_values : relay.Expr
227+
The generated normal distributed random numbers.
228+
"""
229+
if not isinstance(mean, Expr):
230+
mean = const(mean, dtype=dtype)
231+
if not isinstance(scale, Expr):
232+
scale = const(scale, dtype=dtype)
233+
return _make.normal(key, mean, scale, shape, dtype)

python/tvm/relay/op/strategy/generic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,6 +1639,18 @@ def uniform_strategy(attrs, inputs, out_type, target):
16391639
return strategy
16401640

16411641

1642+
@override_native_generic_func("normal_strategy")
1643+
def normal_strategy(attrs, inputs, out_type, target):
1644+
"""normal generic strategy"""
1645+
strategy = _op.OpStrategy()
1646+
strategy.add_implementation(
1647+
wrap_compute_uniform(topi.random.normal),
1648+
wrap_topi_schedule(topi.generic.schedule_extern),
1649+
name="normal.generic",
1650+
)
1651+
return strategy
1652+
1653+
16421654
def wrap_compute_scanop(topi_compute):
16431655
"""Wrap scanop style topi compute"""
16441656

python/tvm/topi/random/kernel.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Pseudorandom number kernels."""
18+
import math
1819
import numpy as np
1920

2021
import tvm
@@ -544,3 +545,60 @@ def uniform(gen, low, high, out_shape, out_dtype):
544545
uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)
545546

546547
return new_gen, uniform_values
548+
549+
550+
def normal(gen, mean, scale, out_shape, out_dtype):
551+
"""Draw samples from a normal distribution.
552+
The algorithm is based on Box-Muller transform
553+
554+
Parameters
555+
----------
556+
gen : ThreefryKey
557+
Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
558+
reused in another function, otherwise random numbers will be repeated.
559+
560+
mean : Tensor[(), out_dtype]
561+
The mean of the normal distribution.
562+
563+
scale : Tensor[(), out_dtype]
564+
The standard deviation of the normal distribution.
565+
566+
out_shape : Sequence[int]
567+
Output shape of the random numbers.
568+
569+
out_dtype : str
570+
The output dtype.
571+
572+
Returns
573+
-------
574+
new_gen : ThreefryKey
575+
New generator state that is distinct from `gen`.
576+
577+
out : Tensor[out_shape, out_dtype]
578+
Tensor of random numbers with shape `out_shape` and type `out_dtype`.
579+
"""
580+
out_shape = list(out_shape)
581+
# Box-Muller transform need two pieces of original uniform data
582+
out_shape.insert(0, 2)
583+
new_gen, uniform_values = uniform(
584+
gen,
585+
tvm.tir.const(0.0, out_dtype),
586+
tvm.tir.const(1.0, out_dtype),
587+
out_shape,
588+
out_dtype,
589+
)
590+
two_pi = tvm.tir.const(2.0 * math.pi, out_dtype)
591+
uniform_values_1 = tvm.topi.strided_slice(uniform_values, [0], [1], strides=[1], axes=[0])
592+
uniform_values_1 = tvm.topi.squeeze(uniform_values_1, axis=0)
593+
uniform_values_2 = tvm.topi.strided_slice(uniform_values, [1], [2], strides=[1], axes=[0])
594+
uniform_values_2 = tvm.topi.squeeze(uniform_values_2, axis=0)
595+
uniform_values_1 = tvm.topi.subtract(tvm.tir.const(1.0, out_dtype), uniform_values_1)
596+
sqrt_values = tvm.topi.sqrt(
597+
tvm.topi.multiply(tvm.tir.const(-2.0, out_dtype), tvm.topi.log(uniform_values_1))
598+
)
599+
sin_values = tvm.topi.sin(tvm.topi.multiply(two_pi, uniform_values_2))
600+
random_values = tvm.topi.add(
601+
tvm.topi.multiply(tvm.topi.multiply(sqrt_values, sin_values), scale), mean
602+
)
603+
604+
return new_gen, random_values

src/relay/op/random/kernel.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,5 +132,52 @@ RELAY_REGISTER_OP("random.uniform")
132132
.add_argument("high", "Tensor", "Higher bound of the distribution")
133133
.add_type_rel("Uniform", UniformRel);
134134

135+
TVM_REGISTER_NODE_TYPE(NormalAttrs);
136+
137+
bool NormalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
138+
const TypeReporter& reporter) {
139+
const NormalAttrs* param = attrs.as<NormalAttrs>();
140+
ICHECK_EQ(types.size(), 4) << "Normal should have three inputs and one output";
141+
142+
std::vector<IndexExpr> oshape;
143+
for (auto& x : param->out_shape) {
144+
oshape.push_back(x);
145+
}
146+
DataType out_dtype = param->out_dtype;
147+
// we are supporting float32 and float64 at the moment.
148+
if (!(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64))) {
149+
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
150+
<< "We only support generating Normal random value of "
151+
<< "type float32 or float64, got " << out_dtype << ".");
152+
return false;
153+
}
154+
reporter->Assign(types[0], ThreefryKeyType());
155+
reporter->Assign(types[1], TensorType({}, out_dtype));
156+
reporter->Assign(types[2], TensorType({}, out_dtype));
157+
// generate returns the next key and an array of random values
158+
reporter->Assign(types[3], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)}));
159+
return true;
160+
}
161+
162+
Expr MakeNormal(Expr key, Expr mean, Expr scale, Array<Integer> out_shape, DataType out_dtype) {
163+
auto attrs = make_object<NormalAttrs>();
164+
attrs->out_shape = out_shape;
165+
attrs->out_dtype = out_dtype;
166+
static const Op& op = Op::Get("random.normal");
167+
return Call(op, {key, mean, scale}, Attrs(attrs), {});
168+
}
169+
170+
TVM_REGISTER_GLOBAL("relay.op.random._make.normal").set_body_typed(MakeNormal);
171+
172+
RELAY_REGISTER_OP("random.normal")
173+
.describe(
174+
R"doc(Generate an array of random numbers under normal distribution.)doc" TVM_ADD_FILELINE)
175+
.set_num_inputs(3)
176+
.set_attrs_type<NormalAttrs>()
177+
.add_argument("key", "Tensor", "Input Threefry key")
178+
.add_argument("mean", "Tensor", "Mean of the distribution")
179+
.add_argument("scale", "Tensor", "Standard deviation of the distribution")
180+
.add_type_rel("Normal", NormalRel);
181+
135182
} // namespace relay
136183
} // namespace tvm

0 commit comments

Comments
 (0)