
Commit 1bc8cf8

Valery Chernov (vvchernov) authored
[ONNX] Support Bernoulli op on ONNX front-end (#13802)
* add Bernoulli converter for onnx front-end
* test for bernoulli was implemented
* fix tuple split; update test for stability with different seed on ort and tvm sides
* check that output values are 0 or 1
* remove std check as meaningless
* calculate theoretical mean and compare with result; remove ort for comparison; clean code
* add customized input as arg
* add test with input sequence of 0 and 1
* pylint fix
* fix inputs-shape issue
* add binomial test
* fix input type
* small fix
* update 0-1 check
* init arrays in numpy style
* check result determinism for fixed seed
* fix inputs issue
* modify binomial test
* pylint fix

Co-authored-by: Valery Chernov <[email protected]>
1 parent: c2cc019 · commit: 1bc8cf8

File tree: 2 files changed, +190 −0 lines changed


python/tvm/relay/frontend/onnx.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -5669,6 +5669,36 @@ def _impl_v16(cls, inputs, attr, params):
         )
 
 
+class Bernoulli(OnnxOpConverter):
+    """Operator converter for Bernoulli"""
+
+    @classmethod
+    def _impl_v15(cls, inputs, attr, params):
+        in_dtype = infer_type(inputs[0]).checked_type.dtype
+        assert in_dtype in [
+            "float32",
+            "float64",
+        ], "Only float input tensor is currently supported."
+        # The data type for the elements of the output tensor.
+        # if not specified, we will use the data type of the input tensor
+        out_dtype = attr.get("dtype", None)
+        if out_dtype is None:
+            out_dtype = in_dtype
+        else:
+            out_dtype = get_type(out_dtype)
+
+        seed = attr.get("seed", None)
+        if seed is None:
+            seed = np.random.randint(1e6)
+        else:
+            seed = int(seed)
+
+        key = _random.threefry_key(seed)
+        inter_outputs = _op.random.uniform(key, infer_shape(inputs[0]), in_dtype)
+        _, uniform_nums = _expr.TupleWrapper(inter_outputs, 2)
+        return _op.cast(_op.less(uniform_nums, inputs[0]), out_dtype)
+
+
 class RandomNormal(OnnxOpConverter):
     """Operator converter for random_normal"""
```
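The converter uses the standard inverse-sampling reduction: draw `u ~ Uniform[0, 1)` with the same shape as the input probability tensor `p`, then emit `cast(u < p, out_dtype)`, so each element is 1 with probability `p`. A minimal NumPy sketch of that idea for reference (illustrative only; the name `bernoulli_like` is mine, not part of the patch):

```python
import numpy as np

def bernoulli_like(p, out_dtype="int32", seed=0):
    """Sample elementwise Bernoulli(p) via the uniform-threshold trick."""
    rng = np.random.default_rng(seed)
    u = rng.random(p.shape)  # u ~ Uniform[0, 1), same shape as p
    # u < p holds with probability exactly p for each element.
    return (u < p).astype(out_dtype)

p = np.full([8], 0.5, dtype="float32")
print(bernoulli_like(p, seed=42))  # e.g. [1 0 1 0 0 1 1 0]
```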

```diff
@@ -6436,6 +6466,7 @@ def _get_convert_map(opset):
         "QLinearGlobalAveragePool": QLinearGlobalAveragePool.get_converter(opset),
         "QLinearLeakyRelu": QLinearLeakyRelu.get_converter(opset),
         # Random number generation.
+        "Bernoulli": Bernoulli.get_converter(opset),
         "RandomNormal": RandomNormal.get_converter(opset),
         "RandomNormalLike": RandomNormalLike.get_converter(opset),
         "RandomUniform": RandomUniform.get_converter(opset),
```
tests/python/frontend/onnx/test_forward.py

Lines changed: 159 additions & 0 deletions
```diff
@@ -6914,6 +6914,165 @@ def verify_qlinearsigmoid(a_shape):
     verify_qlinearsigmoid([])
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_random_bernoulli(target, dev):
+    """test_random_bernoulli"""
+
+    def _get_tvm_output(
+        inputs,
+        out_dtype="int32",
+        seed=None,
+        target=target,
+        dev=dev,
+        use_vm=False,
+        freeze_params=False,
+    ):
+        def get_bernoulli_model(shape, in_dtype="float32", out_dtype="int32", seed=None):
+            onnx_itype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+            onnx_otype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_dtype)]
+            node = helper.make_node(
+                "Bernoulli",
+                ["input"],
+                ["output"],
+            )
+            dtype_attr = helper.make_attribute("dtype", onnx_otype)
+            node.attribute.append(dtype_attr)
+            if seed is not None:
+                seed_attr = helper.make_attribute("seed", float(seed))
+                node.attribute.append(seed_attr)
+
+            graph = helper.make_graph(
+                [node],
+                "random_bernoulli_test",
+                inputs=[helper.make_tensor_value_info("input", onnx_itype, list(shape))],
+                outputs=[helper.make_tensor_value_info("output", onnx_otype, list(shape))],
+            )
+            return helper.make_model(graph, producer_name="random_bernoulli_test")
+
+        shape = inputs.shape
+        in_dtype = inputs.dtype
+        model = get_bernoulli_model(shape, in_dtype, out_dtype, seed)
+
+        if use_vm:
+            return get_tvm_output_with_vm(
+                model,
+                inputs,
+                target,
+                dev,
+                freeze_params=freeze_params,
+            )
+        else:
+            return get_tvm_output(
+                model,
+                inputs,
+                target,
+                dev,
+            )
+
+    def binom_test(input, ideal_mean, threshold=0.05):
+        # This test is strictly appropriate when input probabilities are all identical.
+        # In that case, it should lead to flaky failures in only one run in a million (p>=1e-6).
+        # The test should be over-conservative when input probabilities are not identical.
+        # (i.e., It should have a rate of flaky failures lower than one run in a million.)
+        # If this test starts repeatedly throwing flaky failures, consult a statistician
+        # in addition to your regular debugging.
+        bnm_test_res = scipy.stats.binomtest(
+            k=np.sum(input, dtype="int32"), n=len(input), p=ideal_mean
+        )
+        return bnm_test_res.pvalue > threshold
+
+    def verify_bernoulli(
+        inputs=None,
+        shape=[],
+        in_dtype="float32",
+        out_dtype="int32",
+        seed=None,
+        target=target,
+        dev=dev,
+        use_vm=False,
+        freeze_params=False,
+        in_out_equal=False,
+    ):
+        if inputs is None:
+            assert len(shape) != 0
+            inputs = np.random.uniform(size=shape).astype(in_dtype)
+
+        tvm_out = _get_tvm_output(
+            inputs,
+            out_dtype,
+            seed,
+            target,
+            dev,
+            use_vm,
+            freeze_params,
+        )
+
+        if isinstance(tvm_out, list):
+            tvm_out = tvm_out[0]
+        # check that values are 0 or 1
+        tvm_flat = tvm_out.flatten()
+        assert np.array_equal(tvm_flat, tvm_flat.astype("bool"))
+        if in_out_equal:
+            tvm.testing.assert_allclose(inputs, tvm_out)
+        else:
+            # check that mean value is close to the theoretical one by binomial test
+            ideal_mean = np.mean(inputs)
+            repeats = 3
+            check = False
+            for i in range(repeats):
+                if binom_test(tvm_flat, ideal_mean):
+                    check = True
+                    break
+                else:
+                    # repeat with new seed
+                    seed = np.random.randint(1e6)
+                    tvm_flat = _get_tvm_output(
+                        inputs,
+                        out_dtype,
+                        seed,
+                        target,
+                        dev,
+                        use_vm,
+                        freeze_params,
+                    ).flatten()
+            assert check, "Binomial test failed"
+
+    # Test input sequence of 0 and 1
+    inputs = np.random.randint(2, size=[10000]).astype("float32")
+    verify_bernoulli(inputs, in_out_equal=True)
+
+    # Binomial test input with 0.5 values
+    val_num = 10000
+    inputs = np.ones([val_num], dtype="float32") * 0.5
+    verify_bernoulli(inputs)
+
+    # Binomial test input with 0.1 values
+    inputs = np.ones([val_num], dtype="float32") * 0.1
+    verify_bernoulli(inputs)
+
+    # Simple test
+    verify_bernoulli(shape=[val_num])
+
+    # Floating output type
+    verify_bernoulli(shape=[val_num], out_dtype="float32")
+
+    # Double input type
+    verify_bernoulli(shape=[val_num], in_dtype="float64")
+
+    # Test N-D tensor generation
+    verify_bernoulli(shape=[2, 4, 100, 100])
+
+    # Test with seed
+    verify_bernoulli(shape=[val_num], seed=np.random.randint(1e6))
+
+    # Test result determinism with the same seeds
+    inputs = np.random.uniform(size=[val_num])
+    fixed_seed = np.random.randint(1e6)
+    tvm_out_1 = _get_tvm_output(inputs, seed=fixed_seed)
+    tvm_out_2 = _get_tvm_output(inputs, seed=fixed_seed)
+    tvm.testing.assert_allclose(tvm_out_1, tvm_out_2)
+
+
 @tvm.testing.parametrize_targets("llvm")
 def test_random_uniform(target, dev):
     """test_random_uniform"""
```
