diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index a0b17b333099..36a5d3e9979a 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -224,7 +224,8 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride, - pool_type=pool_type, global_pool=global_pool, cudnn_off=False) + pool_type=pool_type, global_pool=global_pool, cudnn_off=False, + pooling_convention=convention) arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape) arg_names = pooling_fp32.list_arguments() pooling_fp32_exe = pooling_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')