Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 28, 2018
1 parent 4030bff commit ebbe347
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,15 +717,16 @@ def spacetodepth(attrs, inputs, proto_obj):


def hardmax(attrs, inputs, proto_obj):
"""Returns batched one-hot vectors."""
input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0]
input_shape = input_tensor_data[1]

axis = int(attrs.get('axis', 1))
axis = axis if axis >= 0 else len(input_shape) + axis

if axis == len(input_shape) - 1:
argmax = symbol.argmax(inputs[0], axis=-1)
one_hot = symbol.one_hot(argmax, depth=input_shape[-1])
amax = symbol.argmax(inputs[0], axis=-1)
one_hot = symbol.one_hot(amax, depth=input_shape[-1])
return one_hot, attrs, inputs

# since reshape doesn't take a tensor for shape,
Expand All @@ -734,7 +735,7 @@ def hardmax(attrs, inputs, proto_obj):
new_shape = (int(np.prod(input_shape[:axis])),
int(np.prod(input_shape[axis:])))
reshape_op = symbol.reshape(inputs[0], new_shape)
argmax = symbol.argmax(reshape_op, axis=-1)
one_hot = symbol.one_hot(argmax, depth=new_shape[-1])
amax = symbol.argmax(reshape_op, axis=-1)
one_hot = symbol.one_hot(amax, depth=new_shape[-1])
hardmax_op = symbol.reshape(one_hot, input_shape)
return hardmax_op, attrs, inputs

0 comments on commit ebbe347

Please sign in to comment.