Batch_dot does not support FP16 well #11796
Comments
Oops, wrong button. Relevant links: https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/dot-inl.h#L1347-L1364 and https://github.com/dmlc/mshadow/blob/master/mshadow/dot_engine-inl.h#L528-L539. While the float path uses the strided gemm, the half_t path calls the regular gemm. Instead, the strided gemm in cuBLAS, which supports half_t, could be used: https://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemmbatched
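For orientation, here is a minimal Python sketch of what batch_dot computes: one independent GEMM per batch slice. The shapes are arbitrary and this only illustrates the operator's semantics, not the C++ dispatch in the files linked above; the point is that issuing the per-slice GEMMs one by one (as the half_t path effectively does) gives the same result as a single batched call, just with many more kernel launches.

```python
# Illustrative sketch only (arbitrary shapes): batch_dot performs one GEMM per
# batch slice. Looping over regular GEMMs produces the same values as the
# single batched operator call, only with far more launches.
import mxnet as mx

a = mx.nd.ones((8, 4, 3))
b = mx.nd.ones((8, 3, 5))

batched = mx.nd.batch_dot(a, b)  # shape (8, 4, 5), one operator call
looped = mx.nd.stack(*[mx.nd.dot(a[i], b[i]) for i in range(a.shape[0])])

assert (batched - looped).abs().sum().asscalar() == 0
```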
I'm adding cuBLAS strided gemm calls in mshadow: dmlc/mshadow#353
Merged.
@szha: can we reopen this? For some reason, the fix in dmlc/mshadow#353 was reverted by this commit by @eric-haibin-lin. This code, run on this version, takes 0.9s on a V100 (and 0.0318s when using float32 instead, about a 30x slowdown!). We want to implement transformers using TensorCores for training, but there is no way of doing this in MXNet at the moment. What is the plan for exposing any form of GEMM to users with Real16 and TensorCore support?
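The snippet referred to above is not reproduced here; as a rough, hypothetical stand-in (made-up shapes and iteration count), a micro-benchmark along these lines would compare the two dtypes:

```python
# Not the original poster's code; a rough sketch of a comparable micro-benchmark
# of batch_dot in float16 vs float32, with made-up shapes.
import time
import mxnet as mx

ctx = mx.gpu(0)
for dtype in ('float32', 'float16'):
    a = mx.nd.random.uniform(shape=(512, 256, 64), ctx=ctx).astype(dtype)
    b = mx.nd.random.uniform(shape=(512, 64, 256), ctx=ctx).astype(dtype)
    mx.nd.waitall()                  # finish setup before starting the clock
    start = time.time()
    for _ in range(100):
        c = mx.nd.batch_dot(a, b)
    mx.nd.waitall()                  # ops are asynchronous; sync before stopping
    print(dtype, 'avg seconds per call:', (time.time() - start) / 100)
```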
Sorry about the revert. I found that it is better to implement FP16 ops in MXNet instead of in mshadow, since there is built-in functionality to detect/enable TensorCores. I can make a PR in maybe two or three days. @sbodenstein, are you using symbol or gluon to train the transformer?
@eric-haibin-lin: we are using symbol to train the transformer. It would be great to re-enable this as soon as possible. Is there any reason not to expose
@eric-haibin-lin: any updates about this?
Added in #13716 |
The batch_dot operator does not support FP16 well and can make training slower compared to using FP32. This was tested using the Transformer model in GluonNLP. The feature has already been added in NVIDIA's MXNet, so I think it would be good to enable it in master as well.
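For reference, a toy sketch of how the FP16 batch_dot path is typically reached in Gluon; the block below is a hypothetical stand-in for dot-product attention, not the GluonNLP Transformer itself.

```python
# Toy stand-in for attention, not the GluonNLP Transformer: casting the block
# and feeding it float16 inputs is what routes batch_dot onto the FP16 kernels.
import mxnet as mx
from mxnet import gluon

class ToyAttention(gluon.HybridBlock):
    def hybrid_forward(self, F, q, k):
        # (batch, len_q, dim) x (batch, len_k, dim)^T -> (batch, len_q, len_k)
        return F.batch_dot(q, k, transpose_b=True)

ctx = mx.gpu(0)
net = ToyAttention()
net.initialize(ctx=ctx)
net.cast('float16')

q = mx.nd.random.uniform(shape=(32, 128, 64), ctx=ctx).astype('float16')
k = mx.nd.random.uniform(shape=(32, 128, 64), ctx=ctx).astype('float16')
scores = net(q, k)          # exercises the float16 batch_dot kernel
mx.nd.waitall()
```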