From 9a7ab51bd81db4f3202c886415e349f502556223 Mon Sep 17 00:00:00 2001 From: Birdylx <29754889+Birdylx@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:19:38 +0800 Subject: [PATCH] [NPU] fix npu llava infer (#757) `paddle.amp.is_bfloat16_supported()` will raise an error on NPU devices, and bfloat16 is supported by default on 910B NPU devices. Co-authored-by: LokeZhou --- paddlemix/examples/llava/run_predict_multiround.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddlemix/examples/llava/run_predict_multiround.py b/paddlemix/examples/llava/run_predict_multiround.py index fb1a9746c..811c4e6a1 100644 --- a/paddlemix/examples/llava/run_predict_multiround.py +++ b/paddlemix/examples/llava/run_predict_multiround.py @@ -36,7 +36,11 @@ def main(args): paddle.seed(seed=0) compute_dtype = "float16" if args.fp16 else "bfloat16" - if compute_dtype== "bfloat16" and not paddle.amp.is_bfloat16_supported(): + if "npu" in paddle.get_device(): + is_bfloat16_supported = True + else: + is_bfloat16_supported = paddle.amp.is_bfloat16_supported() + if compute_dtype== "bfloat16" and not is_bfloat16_supported: logger.warning("bfloat16 is not supported on your device,change to float32") compute_dtype = "float32"