diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index 56058f0fa8dc3c..868df9711313bb 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -21,7 +21,6 @@ from paddle.base.data_feeder import check_dtype from paddle.device import ( is_compiled_with_cuda, - is_compiled_with_rocm, ) from paddle.device.cuda import get_device_capability from paddle.framework import ( @@ -43,7 +42,7 @@ def _get_arch_info(): # Get SMVersion from device. - if is_compiled_with_cuda() or is_compiled_with_rocm(): + if is_compiled_with_cuda(): cuda_version = paddle.version.cuda() if ( cuda_version is not None and cuda_version != 'False'