
ONNX export: Random uniform, Random normal
vandanavk committed Dec 18, 2018
1 parent 5e46db4 commit a32a550
Showing 3 changed files with 102 additions and 4 deletions.
python/mxnet/contrib/onnx/mx2onnx/_op_translations.py (52 additions, 0 deletions)
@@ -1655,3 +1655,55 @@ def convert_size(node, **kwargs):
    and return the created node.
    """
    return create_basic_op_node('Size', node, kwargs)


@mx_op.register("_random_uniform")
def convert_random_uniform(node, **kwargs):
"""Map MXNet's random_uniform operator attributes to onnx's RandomUniform
operator and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

# Converting to float32
low = float(attrs.get("low", 0))
high = float(attrs.get("high", 1.0))
shape = convert_string_to_list(attrs.get('shape', '[]'))
dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(attrs.get('dtype', 'float32'))]

node = onnx.helper.make_node(
'RandomUniform',
input_nodes,
[name],
low=low,
high=high,
dtype=dtype,
shape=shape,
name=name
)
return [node]


@mx_op.register("_random_normal")
def convert_random_normal(node, **kwargs):
"""Map MXNet's random_normal operator attributes to onnx's RandomNormal
operator and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

# Converting to float32
mean = float(attrs.get("loc", 0))
scale = float(attrs.get("scale", 1.0))
shape = convert_string_to_list(attrs.get('shape', '[]'))
dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(attrs.get('dtype', 'float32'))]

node = onnx.helper.make_node(
'RandomNormal',
input_nodes,
[name],
mean=mean,
scale=scale,
dtype=dtype,
shape=shape,
name=name
)
return [node]
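For context, a minimal export-side usage sketch (not part of this commit). It assumes MXNet's contrib ONNX export API of this era, mx.contrib.onnx.export_model; the symbol construction and the output file name random_add.onnx are illustrative placeholders.

import numpy as np
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet

# A toy symbol whose output adds uniform noise to an input; exporting it routes
# the _random_uniform op through convert_random_uniform above.
data = mx.sym.Variable('data')
noise = mx.sym.random.uniform(low=0.5, high=1.0, shape=(2, 2), dtype='float32')
sym = mx.sym.broadcast_add(data, noise)

# export_model serializes the symbol (no trained params here) to an ONNX file.
onnx_file = onnx_mxnet.export_model(sym, {}, [(2, 2)], np.float32, 'random_add.onnx')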
python/mxnet/contrib/onnx/onnx2mx/_op_translations.py (18 additions, 4 deletions)
@@ -29,14 +29,28 @@ def identity(attrs, inputs, proto_obj):
 
 def random_uniform(attrs, inputs, proto_obj):
     """Draw random samples from a uniform distribtuion."""
-    new_attr = translation_utils._remove_attributes(attrs, ['seed'])
-    return 'random_uniform', new_attr, inputs
+    try:
+        from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed. "
+                          "Instructions to install - https://github.com/onnx/onnx")
+    new_attrs = translation_utils._remove_attributes(attrs, ['seed'])
+    new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(new_attrs.get('dtype', 1))]
+    return 'random_uniform', new_attrs, inputs
 
 
 def random_normal(attrs, inputs, proto_obj):
     """Draw random samples from a Gaussian distribution."""
+    try:
+        from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed. "
+                          "Instructions to install - https://github.com/onnx/onnx")
     new_attr = translation_utils._remove_attributes(attrs, ['seed'])
-    new_attr = translation_utils._fix_attribute_names(new_attr, {'mean' : 'loc'})
-    return 'random_uniform', new_attr, inputs
+    new_attr = translation_utils._fix_attribute_names(new_attr, {'mean': 'loc'})
+    new_attr['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(new_attr.get('dtype', 1))]
+    return 'random_normal', new_attr, inputs
 
 
 # Arithmetic Operations
 def add(attrs, inputs, proto_obj):
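A matching import-side sketch (also not part of this commit): reading an ONNX graph that contains RandomUniform or RandomNormal nodes goes through mx.contrib.onnx.import_model, which dispatches to the translation functions above. The file name random_ops.onnx is a placeholder.

from mxnet.contrib import onnx as onnx_mxnet

# import_model returns the MXNet symbol plus arg/aux parameter dicts; the ONNX
# RandomUniform/RandomNormal nodes come back as MXNet random_uniform/random_normal ops.
sym, arg_params, aux_params = onnx_mxnet.import_model('random_ops.onnx')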
tests/python-pytest/onnx/export/mxnet_export_test.py (32 additions, 0 deletions)
@@ -329,6 +329,38 @@ def test_ops(op_name, inputs, input_tensors, numpy_op):
                            np.logical_not(input_data[0]).astype(np.float32))


@with_seed()
def test_random_uniform():
    shape = (2, 2)
    inputs = []
    outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=shape)]
    nodes = [helper.make_node("RandomUniform", [], ["output"], low=0.5, high=1.0, shape=shape)]
    graph = helper.make_graph(nodes,
                              "random_uniform_test",
                              inputs,
                              outputs)
    model = helper.make_model(graph)
    bkd_rep = backend.prepare(model)
    output = bkd_rep.run([])
    assert output[0].shape == shape


@with_seed()
def test_random_normal():
    shape = (2, 2)
    inputs = []
    outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=shape)]
    nodes = [helper.make_node("RandomNormal", [], ["output"], mean=0, scale=1, shape=shape)]
    graph = helper.make_graph(nodes,
                              "random_normal_test",
                              inputs,
                              outputs)
    model = helper.make_model(graph)
    bkd_rep = backend.prepare(model)
    output = bkd_rep.run([])
    assert output[0].shape == shape


def _assert_sym_equal(lhs, rhs):
    assert lhs.list_inputs() == rhs.list_inputs()  # input names must be identical
    assert len(lhs.list_outputs()) == len(rhs.list_outputs())  # number of outputs must be identical
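The two new tests assert only on output shape. As a local sanity check beyond this commit (and assuming the backend returns array-like outputs, as the existing tests imply when comparing with NumPy), one could additionally verify that the uniform samples fall inside the requested range; a sketch reusing the objects built in test_random_uniform:

samples = np.asarray(bkd_rep.run([])[0])
# With low=0.5 and high=1.0, RandomUniform draws should lie in [0.5, 1.0).
assert samples.shape == shape
assert (samples >= 0.5).all() and (samples < 1.0).all()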
