Fix dtype inference in arange_like operator #15930
Great. This will fix the bug with float16 inputs. Would you mind just adding a simple unit test to check the dtype with float16 inputs?
@eric-haibin-lin Did you observe any crash with fp16 input? With the code snippet below, it doesn't seem to crash but just gives numpy.float32 output:
import mxnet as mx
import numpy as np
x = mx.sym.Variable('x', dtype=np.float16)
y = mx.sym.reshape(x, shape=(0, 0, -1))
z = mx.sym.contrib.arange_like(y, axis=-1)
# grad_req='null': no gradient buffers are needed for a forward-only check
mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null')
mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(np.float16)
out = mod.forward(is_train=False)
print(out[0].dtype)
No, I didn't expect a crash. I expected it to copy the dtype attribute, like other xx_like ops.
@eric-haibin-lin Do you think the code snippet below can be used as a test case?
import mxnet as mx
import numpy as np
dtypes = [np.float16, np.float32, np.float64]
for t in dtypes:
    x = mx.sym.Variable('x', dtype=t)
    y = mx.sym.reshape(x, shape=(0, 0, -1))
    z = mx.sym.contrib.arange_like(y, axis=-1)
    # grad_req='null': no gradient buffers are needed for a forward-only check
    mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null')
    mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t)
    out = mod.forward(is_train=False)
    assert out[0].dtype == np.float32
Yes. Could you also check the forward output against [0, 1, 2, ...]?
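For the value check requested here, the expected output can be worked out without running MXNet. The following is a minimal pure-Python sketch, not code from this PR: `resolve_reshape` and `expected_arange_like` are hypothetical helper names, the reshape helper handles only the `0` (copy the input dimension) and `-1` (infer) codes used in the snippet, and it assumes arange_like with `axis` set returns one value per element along that axis with the default start of 0 and step of 1.

```python
def resolve_reshape(in_shape, spec):
    """Resolve a reshape spec that uses only 0 (copy input dim) and -1 (infer)."""
    # Replace each 0 with the input dimension at the same position.
    out = [in_shape[i] if s == 0 else s for i, s in enumerate(spec)]
    total = 1
    for d in in_shape:
        total *= d
    known = 1
    for d in out:
        if d != -1:
            known *= d
    # The single -1 entry absorbs whatever element count is left over.
    return tuple(total // known if d == -1 else d for d in out)


def expected_arange_like(shape, axis, start=0.0, step=1.0):
    """Values expected along one axis: start, start + step, ..."""
    n = shape[axis]
    return [start + i * step for i in range(n)]


# Shapes from the snippet above: x is (3, 4, 5, 6), reshaped with (0, 0, -1).
y_shape = resolve_reshape((3, 4, 5, 6), (0, 0, -1))
expected = expected_arange_like(y_shape, axis=-1)
print(y_shape, len(expected), expected[:3])  # (3, 4, 30) 30 [0.0, 1.0, 2.0]
```

In the actual test, the executor output could then be converted with `out[0].asnumpy()` and compared element by element against `expected`.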
I would like to preserve the dtype attribute, with a default behavior when dtype is None.
Just want to provide the same user experience for
Description
Remove the dtype argument from the parameter structure and use ElemwiseType for type inference instead.
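To illustrate the idea behind this change, here is a simplified pure-Python model of elementwise type inference. It is a sketch under stated assumptions, not MXNet's C++ implementation: `infer_elemwise_type` and `UNKNOWN` are hypothetical names, and dtypes are represented as plain strings. The rule it models is that every input and output of the operator shares a single dtype, so an unknown output dtype is filled in from the input.

```python
UNKNOWN = -1  # stand-in for "dtype not yet inferred" (assumption of this sketch)


def infer_elemwise_type(in_types, out_types):
    """Unify all input and output dtypes into one shared dtype.

    Returns (done, in_types, out_types); done is False when nothing is known yet.
    """
    known = {t for t in list(in_types) + list(out_types) if t != UNKNOWN}
    if len(known) > 1:
        # Two different concrete dtypes cannot be reconciled.
        raise TypeError("incompatible dtypes: %s" % sorted(known))
    if not known:
        return False, list(in_types), list(out_types)
    dtype = known.pop()
    return True, [dtype] * len(in_types), [dtype] * len(out_types)


# A float16 input with an unknown output: the output inherits float16.
print(infer_elemwise_type(["float16"], [UNKNOWN]))
# (True, ['float16'], ['float16'])
```

Under this model there is no separate dtype parameter to fall back on, which is why the output follows the input dtype.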