diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py index d8869738f0a4..557258ad7637 100644 --- a/python/mxnet/amp/lists/symbol_fp16.py +++ b/python/mxnet/amp/lists/symbol_fp16.py @@ -615,6 +615,8 @@ FP32_FUNCS.extend([ '_sg_mkldnn_conv', '_sg_mkldnn_fully_connected', + '_sg_mkldnn_selfatt_qk', + '_sg_mkldnn_selfatt_valatt', ]) # Functions that have to be cast to FP32 only for