diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index 6250997253..7d72438224 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -201,7 +201,7 @@ def __check_model_transformers_version(config, trans_version):
                     f'but transformers {trans_version} is installed.')
             _handle_exception(e, 'transformers', logger, message=message)
 
-    def __check_model_dtype_support(config):
+    def __check_model_dtype_support(config, device_type):
         """Checking model dtype support."""
         logger.debug('Checking dtype support.')
 
@@ -215,7 +215,7 @@ def __check_model_dtype_support(config):
                 model_path=model_path,
                 dtype=dtype)
             if model_config.dtype == torch.bfloat16:
-                assert is_bf16_supported(), (
+                assert is_bf16_supported(device_type), (
                     'bf16 is not supported on your device')
         except AssertionError as e:
             message = (
@@ -234,7 +234,7 @@ def __check_model_dtype_support(config):
     _, trans_version = __check_transformers_version()
     config = __check_config(trans_version)
     __check_model_transformers_version(config, trans_version)
-    __check_model_dtype_support(config)
+    __check_model_dtype_support(config, device_type)
     check_awq(config, device_type)
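
For context: the change threads `device_type` down to the bf16 capability check so the assertion is no longer CUDA-only. The sketch below is not lmdeploy's actual `is_bf16_supported` implementation, only a minimal illustration of the per-device dispatch that the added argument enables; the non-CUDA branch and its default are assumptions.

import torch


def is_bf16_supported(device_type: str = 'cuda') -> bool:
    """Illustrative device-aware bf16 check (hypothetical, not lmdeploy's)."""
    if device_type == 'cuda':
        # torch.cuda.is_bf16_supported() inspects the GPU's compute capability.
        return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    # For other backends the capability would come from that backend's plugin;
    # fall back to False when it is unknown.
    return False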