Skip to content

Commit

Permalink
miss device_type when checking is_bf16_supported on ascend platform (I…
Browse files Browse the repository at this point in the history
  • Loading branch information
lvhan028 authored and AllentDan committed Nov 13, 2024
1 parent c3d65b7 commit 082c530
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/check_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __check_model_transformers_version(config, trans_version):
f'but transformers {trans_version} is installed.')
_handle_exception(e, 'transformers', logger, message=message)

def __check_model_dtype_support(config):
def __check_model_dtype_support(config, device_type):
"""Checking model dtype support."""
logger.debug('Checking <Model> dtype support.')

Expand All @@ -215,7 +215,7 @@ def __check_model_dtype_support(config):
model_path=model_path,
dtype=dtype)
if model_config.dtype == torch.bfloat16:
assert is_bf16_supported(), (
assert is_bf16_supported(device_type), (
'bf16 is not supported on your device')
except AssertionError as e:
message = (
Expand All @@ -234,7 +234,7 @@ def __check_model_dtype_support(config):
_, trans_version = __check_transformers_version()
config = __check_config(trans_version)
__check_model_transformers_version(config, trans_version)
__check_model_dtype_support(config)
__check_model_dtype_support(config, device_type)
check_awq(config, device_type)


Expand Down

0 comments on commit 082c530

Please sign in to comment.