
Test Ascend Atlas 300I Duo NPU device #2471

Closed
seoibiubiu opened this issue Sep 15, 2024 · 24 comments

@seoibiubiu

Does it support Ascend Atlas 300I Duo NPU devices?

@lvhan028
Collaborator

No, it doesn't.

@jinminxi104
Collaborator

jinminxi104 commented Sep 17, 2024

Since we don't have a 300I device, the 300I is untested. However, Huawei's documentation shows that almost all operators supported on the 800T are also supported on the 300 inference series. We recommend trying it on the 300I and letting us know if there are any errors.

@seoibiubiu
Author

Since we don't have a 300I device, the 300I is untested. However, Huawei's documentation shows that almost all operators supported on the 800T are also supported on the 300 inference series. We recommend trying it on the 300I and letting us know if there are any errors.

Thanks, I will try it and give feedback soon.

@seoibiubiu
Author

Here is my report on testing the Atlas 300I Duo device:

My environment

  • NPU Driver: 24.1.RC2
  • CANN version: 8.0.RC2
  • Model: Qwen2-7B-Instruct

Python code

# Offline inference test
from lmdeploy import pipeline
from lmdeploy import PytorchEngineConfig

pipe = pipeline("/opt/models/Qwen2-7B-Instruct", backend_config = PytorchEngineConfig(tp=1, device_type="ascend"))
question = ["Shanghai is", "Please introduce China", "How are you?"]
response = pipe(question)
print(response)

Response

root@c486e2f96ded:/opt/lmdeploy# python3 offline_test.py
[W compiler_depend.ts:623] Warning: expandable_segments currently defaults to false. You can enable this feature by `export PYTORCH_NPU_ALLOC_CONF = expandable_segments:True`. (function operator())
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:301: ImportWarning:
    *************************************************************************************************************
    The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now..
    The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now..
    The backend in torch.distributed.init_process_group set to hccl now..
    The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now..
    The device parameters have been replaced with npu in the function below:
    torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.autocast, torch.load, torch.Generator, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty
    *************************************************************************************************************

  warnings.warn(msg, ImportWarning)
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:260: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu.
  warnings.warn(msg, RuntimeWarning)
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
/usr/local/python3.10.5/lib/python3.10/site-packages/torch/utils/cpp_extension.py:28: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
/opt/lmdeploy/lmdeploy/serve/utils.py:22: DeprecationWarning: There is no current event loop
  event_loop = asyncio.get_event_loop()
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/utils/storage.py:38: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if self.device.type != 'cpu':
[E compiler_depend.ts:270] call aclnnFlashAttentionVarLenScore failed, detail:EZ9999: Inner Error!
EZ9999: 2024-09-20-07:49:24.783.260  Op FlashAttentionScore does not has any binary.
        TraceBack (most recent call last):
        Kernel Run failed. opType: 29, FlashAttentionScore
        launch failed for FlashAttentionScore, errno:561000.

[ERROR] 2024-09-20-07:49:24 (PID:759, Device:0, RankID:-1) ERR01005 OPS internal error
Exception raised from operator() at build/CMakeFiles/torch_npu.dir/compiler_depend.ts:452 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x68 (0xfffd264dd898 in /usr/local/python3.10.5/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x6c (0xfffd264962a8 in /usr/local/python3.10.5/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0xd4f864 (0xfffce481f864 in /usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #3: <unknown function> + 0xe789a0 (0xfffce49489a0 in /usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #4: <unknown function> + 0x5eab64 (0xfffce40bab64 in /usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #5: <unknown function> + 0x5eb018 (0xfffce40bb018 in /usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #6: <unknown function> + 0x5e8520 (0xfffce40b8520 in /usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #7: <unknown function> + 0x946ec (0xfffd265046ec in /usr/local/python3.10.5/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #8: <unknown function> + 0x7624 (0xfffd6cc07624 in /lib/aarch64-linux-gnu/libpthread.so.0)
frame #9: <unknown function> + 0xd162c (0xfffd6cd3162c in /lib/aarch64-linux-gnu/libc.so.6)


2024-09-20 07:49:24,787 - lmdeploy - ERROR - Engine loop failed with error: The Inner error is reported as above. The process exits for this inner error, and the current working operator name is aclnnFlashAttentionVarLenScore.
Since the operator is called asynchronously, the stacktrace may be inaccurate. If you want to get the accurate stacktrace, pleace set the environment variable ASCEND_LAUNCH_BLOCKING=1.
[ERROR] 2024-09-20-07:49:24 (PID:759, Device:0, RankID:-1) ERR00100 PTA call acl api failed
Traceback (most recent call last):
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 944, in async_loop
    await self._async_loop()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 934, in _async_loop
    await __step(True)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 920, in __step
    raise e
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 912, in __step
    raise out
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 856, in _async_loop_background
    await self._async_step_background(
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 735, in _async_step_background
    output = await self._async_model_forward(
  File "/opt/lmdeploy/lmdeploy/utils.py", line 237, in __tmp
    return (await func(*args, **kwargs))
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 633, in _async_model_forward
    ret = await __forward(inputs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 611, in __forward
    return await self.model_agent.async_forward(
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 332, in async_forward
    output = self._forward_impl(inputs,
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 299, in _forward_impl
    output = model_forward(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 154, in model_forward
    output = model(**input_dict)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/graph_runner.py", line 25, in __call__
    return self.model(**kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 340, in forward
    hidden_states = self.model(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 278, in forward
    hidden_states, residual = decoder_layer(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 194, in forward
    hidden_states = self.self_attn(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 101, in forward
    attn_output = self.o_proj(attn_output)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/nn/linear.py", line 921, in forward
    return self.impl.forward(x, self.weight, self.bias, all_reduce)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/default/linear.py", line 20, in forward
    out = F.linear(x, weight, bias)
RuntimeError: The Inner error is reported as above. The process exits for this inner error, and the current working operator name is aclnnFlashAttentionVarLenScore.
Since the operator is called asynchronously, the stacktrace may be inaccurate. If you want to get the accurate stacktrace, pleace set the environment variable ASCEND_LAUNCH_BLOCKING=1.
[ERROR] 2024-09-20-07:49:24 (PID:759, Device:0, RankID:-1) ERR00100 PTA call acl api failed
/usr/local/python3.10.5/lib/python3.10/tempfile.py:837: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpe97gupob'>
  _warnings.warn(warn_message, ResourceWarning)
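
Note: as the log itself suggests, setting ASCEND_LAUNCH_BLOCKING=1 makes NPU kernels launch synchronously, so the stack trace points at the real failing call. A minimal sketch (same script as above; the variable must be set before torch_npu is imported):

import os

# Force synchronous NPU kernel launches for an accurate stack trace.
# Must be set before torch/torch_npu (and therefore lmdeploy) are imported.
os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"

from lmdeploy import pipeline
from lmdeploy import PytorchEngineConfig

pipe = pipeline("/opt/models/Qwen2-7B-Instruct", backend_config=PytorchEngineConfig(tp=1, device_type="ascend"))
print(pipe(["Shanghai is"]))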

@seoibiubiu
Author

Maybe this operator is not supported :(
[screenshot: operator]

@seoibiubiu changed the title from "Ascend NPU support" to "Test Ascend Atlas 300I Duo NPU device by LMDeploy V0.6.0" on Sep 20, 2024

This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.

@github-actions github-actions bot added the Stale label Sep 28, 2024
@jinminxi104
Collaborator

Maybe this operator is not supported

Oh, we use this training op for more efficient computation. We'll solve this issue next month.


github-actions bot commented Oct 8, 2024

This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.

@wangyuanxiong-hub

Maybe this operator is not supported

Oh, we use this training op for more efficient computation. We'll solve this issue next month.

Hi, I'm also troubled by this. Has there been any progress recently?

@qiling1345

Any solution yet? Thanks for helping!


This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.

@github-actions github-actions bot added the Stale label Oct 22, 2024
@yao-fengchen
Collaborator

We have prepared a fix plan for the 300I Duo. Could you help us test it? @zer0py2c @wangyuanxiong-hub @qiling1345

@yao-fengchen
Copy link
Collaborator

We have prepared a fix plan for the 300I Duo. Could you help us test it? @zer0py2c @wangyuanxiong-hub @qiling1345

You can try running it with https://github.com/yao-fengchen/dlinfer/tree/fix_attn and https://github.com/DeepLink-org/lmdeploy/tree/fix_attn on the 300I Duo and let us know if you hit any errors.

@seoibiubiu
Author

We have prepared a fix plan for the 300I Duo. Could you help us test it?

I'm honored! I will try and give feedback soon.

@seoibiubiu
Author

We have prepared a fix plan for the 300I Duo. Could you help us test it? @zer0py2c @wangyuanxiong-hub @qiling1345

You can try running it with https://github.com/yao-fengchen/dlinfer/tree/fix_attn and https://github.com/DeepLink-org/lmdeploy/tree/fix_attn on the 300I Duo and let us know if you hit any errors.

I got an error when executing the offline inference code; my steps are as follows:

Environment

  • NPU driver version: 24.1.rc2
  • NPU firmware version: 7.3.0.1.231
  • CANN version: 8.0.RC3.alpha001

Running the container

  1. Download and extract the fix_attn source
unzip lmdeploy-fix_attn.zip -d /root/projects/
unzip dlinfer-fix_attn.zip -d /root/projects/lmdeploy-fix_attn
mv Ascend-cann-kernels-310p_8.0.RC3.alpha001_linux.run /root/projects/lmdeploy-fix_attn
mv Ascend-cann-toolkit_8.0.RC3.alpha001_linux-aarch64.run /root/projects/lmdeploy-fix_attn
cd /root/projects/lmdeploy-fix_attn
  2. Modify the dlinfer installation commands in the Dockerfile
# dlinfer
# transformers>=4.41.0 is required for internlm2 model
# timm is required for internvl2 model
COPY dlinfer-fix_attn /opt/dlinfer-fix_attn
RUN --mount=type=cache,target=/root/.cache/pip \
    pip3 install "transformers>=4.41.0" timm && \
    cd /opt/dlinfer-fix_attn && \
    pip3 install -r requirements/ascend/full.txt && \
    DEVICE=ascend python3 setup.py develop
  3. Build the image
DOCKER_BUILDKIT=1 \
  docker build --no-cache \
  -t lmdeploy-aarch64-ascend:fix_attn_support_ascend_310p3 \
  -f docker/Dockerfile_aarch64_ascend .
  4. Verify the environment
docker run \
  -e ASCEND_VISIBLE_DEVICES=0 \
  --rm --name lmdeploy \
  -t lmdeploy-aarch64-ascend:fix_attn_support_ascend_310p3 \
  lmdeploy check_env

The response is as follows

Warning : ASCEND_HOME_PATH environment variable is not set.
[W compiler_depend.ts:623] Warning: expandable_segments currently defaults to false. You can enable this feature by `export PYTORCH_NPU_ALLOC_CONF = expandable_segments:True`. (function operator())
sys.platform: linux
Python: 3.10.5 (main, Oct 23 2024, 14:39:51) [GCC 9.4.0]
CUDA available: False
MUSA available: False
numpy_random_seed: 2147483648
GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
PyTorch: 2.1.0
PyTorch compiling details: PyTorch built with:
 - GCC 10.2
 - C++ Version: 201703
 - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
 - OpenMP 201511 (a.k.a. OpenMP 4.5)
 - LAPACK is enabled (usually provided by MKL)
 - NNPACK is enabled
 - CPU capability usage: NO AVX
 - Build settings: BLAS_INFO=open, BUILD_TYPE=Release, CXX_COMPILER=/opt/rh/devtoolset-10/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=open, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=OFF, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

TorchVision: 0.16.0
LMDeploy: 0.6.1+
transformers: 4.45.2
gradio: Not Found
fastapi: 0.115.3
pydantic: 2.9.2
triton: Not Found
  5. Start the container
docker run -itd \
  --name=lmdeploy \
  --ipc=host \
  --runtime ascend \
  -e ASCEND_VISIBLE_DEVICES=0,1 \
  -v /root/models/Qwen2-7B-Instruct/:/Qwen2-7B-Instruct \
  lmdeploy-aarch64-ascend:fix_attn_support_ascend_310p3

Running the test

Python code
from lmdeploy import pipeline
from lmdeploy import PytorchEngineConfig

if __name__ == "__main__":
    pipe = pipeline("/Qwen2-7B-Instruct/", backend_config = PytorchEngineConfig(tp=1, device_type="ascend"))
    question = ["Shanghai is", "Please introduce China", "How are you?"]
    response = pipe(question)
    print(response)
Output
root@46620ecb0997:/opt/lmdeploy# python3 offline_test.py
[W compiler_depend.ts:623] Warning: expandable_segments currently defaults to false. You can enable this feature by `export PYTORCH_NPU_ALLOC_CONF = expandable_segments:True`. (function operator())
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:301: ImportWarning:
    *************************************************************************************************************
    The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now..
    The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now..
    The backend in torch.distributed.init_process_group set to hccl now..
    The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now..
    The device parameters have been replaced with npu in the function below:
    torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.autocast, torch.load, torch.Generator, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty
    *************************************************************************************************************

  warnings.warn(msg, ImportWarning)
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:260: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu.
  warnings.warn(msg, RuntimeWarning)
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
/usr/local/python3.10.5/lib/python3.10/site-packages/torch/utils/cpp_extension.py:28: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
/opt/lmdeploy/lmdeploy/serve/utils.py:22: DeprecationWarning: There is no current event loop
  event_loop = asyncio.get_event_loop()
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/utils/storage.py:38: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if self.device.type != 'cpu':
2024-10-24 01:51:00,982 - lmdeploy - ERROR - request.py:21 - Engine loop failed with error: call aclnnPromptFlashAttention failed, detail:EZ1001: 2024-10-24-01:51:00.977.598 PromptFlashAttention LaunchAicore failed.
        TraceBack (most recent call last):
        attention mask must be NULL,when Qs,Kvs is unAlign or Qs is not equal to Kvs, Qs = 22, Kvs = 22[FUNC:RunBigKernelTilingWithParams][FILE:prompt_flash_attention_tiling.cpp][LINE:2070]
        Tiling failed
        Tiling Failed.
        Kernel GetWorkspace failed. opType: 31
        PromptFlashAttention LaunchAicore failed.

[ERROR] 2024-10-24-01:51:00 (PID:733, Device:0, RankID:-1) ERR01005 OPS internal error
Traceback (most recent call last):
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 948, in async_loop
    await self._async_loop()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 942, in _async_loop
    await __step()
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 930, in __step
    raise e
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 924, in __step
    raise out
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 858, in _async_loop_background
    await self._async_step_background(
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 740, in _async_step_background
    output = await self._async_model_forward(
  File "/opt/lmdeploy/lmdeploy/utils.py", line 241, in __tmp
    return (await func(*args, **kwargs))
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 631, in _async_model_forward
    ret = await __forward(inputs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 609, in __forward
    return await self.model_agent.async_forward(
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 304, in async_forward
    output = self._forward_impl(inputs,
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 271, in _forward_impl
    output = model_forward(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 154, in model_forward
    output = model(**input_dict)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/graph_runner.py", line 25, in __call__
    return self.model(**kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 344, in forward
    hidden_states = self.model(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 284, in forward
    hidden_states, residual = decoder_layer(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 200, in forward
    hidden_states = self.self_attn(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 91, in forward
    attn_output = self.attn_fwd(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/nn/attention.py", line 67, in forward
    return self.impl.forward(
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/dlinfer/attention.py", line 89, in forward
    attn_output = self.paged_attention_fwd(
  File "/opt/lmdeploy/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py", line 93, in paged_attention_fwd
    return prefill_attention(
  File "/opt/lmdeploy/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py", line 27, in prefill_attention
    return ext_ops.prefill_attention(
  File "/opt/dlinfer-fix_attn/dlinfer/utils/graph/custom_op.py", line 65, in inner_func
    return func(*args, **kwargs_default_dict)
  File "/opt/dlinfer-fix_attn/dlinfer/ops/llm.py", line 137, in prefill_attention
    return vendor_ops_registry["prefill_attention"](
  File "/opt/dlinfer-fix_attn/dlinfer/vendor/ascend/torch_npu_ops.py", line 139, in prefill_attention
    torch.ops.npu_ext.npu_prompt_flash_attention_out(
  File "/usr/local/python3.10.5/lib/python3.10/site-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: call aclnnPromptFlashAttention failed, detail:EZ1001: 2024-10-24-01:51:00.977.598 PromptFlashAttention LaunchAicore failed.
        TraceBack (most recent call last):
        attention mask must be NULL,when Qs,Kvs is unAlign or Qs is not equal to Kvs, Qs = 22, Kvs = 22[FUNC:RunBigKernelTilingWithParams][FILE:prompt_flash_attention_tiling.cpp][LINE:2070]
        Tiling failed
        Tiling Failed.
        Kernel GetWorkspace failed. opType: 31
        PromptFlashAttention LaunchAicore failed.

[ERROR] 2024-10-24-01:51:00 (PID:733, Device:0, RankID:-1) ERR01005 OPS internal error
/usr/local/python3.10.5/lib/python3.10/tempfile.py:837: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmp9woyci7n'>
  _warnings.warn(warn_message, ResourceWarning)

@github-actions github-actions bot removed the Stale label Oct 24, 2024
@wangyuanxiong-hub

(Quoting @seoibiubiu's full test report above.)

Yes, I encountered the same problem. It seems the 128 alignment mentioned in Huawei's documentation does not mean "less than 128".

@seoibiubiu
Author

@wangyuanxiong-hub What is your NPU driver version? I checked the official site, and the 24.1.RC2 driver I am using apparently is not compatible with CANN 8.0.RC3.alpha001.
[screenshot: driver versions matching CANN 8.0.RC3.alpha001]

@yao-fengchen
Collaborator

I fixed this problem in the last commit on https://github.com/yao-fengchen/dlinfer/tree/fix_attn. However, there is a limitation: the 310P device does not support GQA. So, for now, only MHA models, such as llama-2-7b-chat-hf, can run on the 310P device. Moreover, the speed may be relatively slow on the 310P. @wangyuanxiong-hub @zer0py2c
https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha001/apiref/appdevgapi/context/aclnnIncreFlashAttentionV4.md
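
For reference, whether a checkpoint uses MHA or GQA can be read from its config.json: GQA means num_key_value_heads is smaller than num_attention_heads. A minimal sketch (model paths are illustrative):

import json

def uses_gqa(model_dir: str) -> bool:
    # GQA: fewer KV heads than attention heads; MHA: the two counts are equal.
    with open(f"{model_dir}/config.json") as f:
        cfg = json.load(f)
    heads = cfg["num_attention_heads"]
    kv_heads = cfg.get("num_key_value_heads", heads)  # missing key implies MHA
    return kv_heads < heads

# Qwen2-7B-Instruct uses GQA (not runnable on 310P for now);
# llama-2-7b-chat-hf and Qwen1.5-7B-Chat are MHA.
print(uses_gqa("/opt/models/Qwen2-7B-Instruct"))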

@seoibiubiu
Author

I fixed this problem in the last commit on https://github.com/yao-fengchen/dlinfer/tree/fix_attn. However, there is a limitation: the 310P device does not support GQA. So, for now, only MHA models, such as llama-2-7b-chat-hf, can run on the 310P device. Moreover, the speed may be relatively slow on the 310P. @wangyuanxiong-hub @zer0py2c https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha001/apiref/appdevgapi/context/aclnnIncreFlashAttentionV4.md

Thank you very much; I learned a lot. 😃

@seoibiubiu changed the title from "Test Ascend Atlas 300I Duo NPU device by LMDeploy V0.6.0" to "Test Ascend Atlas 300I Duo NPU device" on Oct 24, 2024
@yao-fengchen
Collaborator

If there are no issues with PR DeepLink-org/dlinfer#80 when testing the MHA models on the 310P device, we will include this feature in https://github.com/DeepLink-org/dlinfer.

@seoibiubiu
Author

If there are no issues with PR DeepLink-org/dlinfer#80 when testing the MHA models on the 310P device, we will include this feature in https://github.com/DeepLink-org/dlinfer.

I tested with the last commit of dlinfer fix_attn and it passed. Thanks again to the dlinfer team for their contribution!

Environment

  • NPU driver version: 24.1.RC2
  • NPU firmware version: 7.3.0.1.231
  • CANN version: 8.0.RC3.alpha003
  • Model: Qwen1.5-7B-Chat

[screenshot: Qwen1.5-7B-Chat does not use GQA]

Test script

from lmdeploy import pipeline
from lmdeploy import PytorchEngineConfig

if __name__ == "__main__":
    pipe = pipeline("/models/Qwen1.5-7B-Chat/", backend_config = PytorchEngineConfig(tp=1, device_type="ascend"))
    question = ["Shanghai is", "Please introduce China", "How are you?"]
    response = pipe(question)
    print(response)

Output

root@a4aefff7f9da:/opt/lmdeploy# python3 offline_infer.py
[W compiler_depend.ts:623] Warning: expandable_segments currently defaults to false. You can enable this feature by `export PYTORCH_NPU_ALLOC_CONF = expandable_segments:True`. (function operator())
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:301: ImportWarning:
    *************************************************************************************************************
    The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now..
    The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now..
    The backend in torch.distributed.init_process_group set to hccl now..
    The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now..
    The device parameters have been replaced with npu in the function below:
    torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.autocast, torch.load, torch.Generator, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty
    *************************************************************************************************************

  warnings.warn(msg, ImportWarning)
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:260: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu.
  warnings.warn(msg, RuntimeWarning)
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
/usr/local/python3.10.5/lib/python3.10/site-packages/torch/utils/cpp_extension.py:28: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
/opt/lmdeploy/lmdeploy/serve/utils.py:22: DeprecationWarning: There is no current event loop
  event_loop = asyncio.get_event_loop()
/usr/local/python3.10.5/lib/python3.10/site-packages/torch_npu/utils/storage.py:38: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if self.device.type != 'cpu':
/opt/lmdeploy/lmdeploy/pytorch/engine/logits_process.py:332: UserWarning: AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, If you are looking for a user facing API to enable running your inference-only workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code is under risk of producing silent wrong result in some edge cases. See Note [AutoDispatchBelowAutograd] for more details. (Triggered internally at build/CMakeFiles/torch_npu.dir/compiler_depend.ts:74.)
  stop_words = torch.where(self.ignore_eos[:, None], stop_words, -1)
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:671: ImportWarning: TEMetaPathLoader.exec_module() not found; falling back to load_module()
[Response(text='Shanghai, officially known as the City of Shanghai, is a megacity located in the eastern part of China. It is the largest city by population and economic aggregate, serving as the provincial capital of Shanghai Municipality. Known for its iconic landmarks like the Bund, the Oriental Pearl Tower, and the Pudong skyline, it is a major financial, technological, and international hub. It is also a key port in China and plays a significant role in global trade.', generate_token_len=94, input_token_len=22, session_id=0, finish_reason='stop', token_ids=[2016, 30070, 11, 18562, 3881, 438, 279, 4311, 315, 37047, 11, 374, 264, 18740, 4018, 7407, 304, 279, 23149, 949, 315, 5616, 13, 1084, 374, 279, 7772, 3283, 553, 7042, 323, 6955, 23192, 11, 13480, 438, 279, 34931, 6722, 315, 37047, 35703, 2719, 13, 48286, 369, 1181, 26277, 59924, 1075, 279, 29608, 11, 279, 70951, 36243, 21938, 11, 323, 279, 393, 661, 644, 87739, 11, 432, 374, 264, 3598, 5896, 11, 29016, 11, 323, 6489, 18719, 13, 1084, 374, 1083, 264, 1376, 2635, 304, 5616, 323, 11088, 264, 5089, 3476, 304, 3644, 6559, 13], logprobs=None, index=0), Response(text='China, officially known as the People\'s Republic of China (PRC), is a vast and diverse country located in East Asia. It is the world\'s most populous nation, with over 1.4 billion people, covering an area of approximately 9.6 million square kilometers. The country is governed by the Communist Party of China, led by the General Secretary of the Central Committee, and is governed under the principle of socialism with Chinese characteristics.\n\nGeographically, China is bordered by several countries, including Russia to the north, Mongolia and Kazakhstan to the north-west, India and Nepal to the west, and Vietnam, Laos, and Myanmar to the south. It shares the South China Sea with several Southeast Asian nations and has an extensive coastline along the Pacific Ocean to the east.\n\nHistorically, China has a rich civilization dating back thousands of years, known for its contributions to philosophy, science, literature, and the arts. It has been a major player in world affairs, with the Silk Road, the Great Wall, and the Terracotta Army being iconic symbols. The country has experienced various dynasties, including the Han, Ming, and Qing, and has modernized significantly since the 20th century, implementing economic reforms that transformed it into a global economic powerhouse.\n\nEconomically, China is the world\'s second-largest economy, with a mixed market system that includes state-owned enterprises, private enterprises, and a large informal sector. It is a member of the World Trade Organization, the World Bank, and the International Monetary Fund, and is driving global trade through its Belt and Road Initiative.\n\nCultural aspects include a rich culinary tradition, diverse languages (Mandarin being the official), and a strong emphasis on family values. The country practices Confucianism, Taoism, and Buddhism, and has a unique blend of traditional and modern culture.\n\nIn terms of international relations, China is a major player in the United Nations, actively participating in global issues such as climate change, poverty reduction, and peacekeeping. 
It has been working on a "One Belt, One Road" (OBOR) initiative to promote infrastructure development and connectivity across Asia and beyond, further enhancing its influence in the region.', generate_token_len=444, input_token_len=22, session_id=1, finish_reason='stop', token_ids=[22282, 11, 18562, 3881, 438, 279, 8853, 594, 5429, 315, 5616, 320, 6480, 34, 701, 374, 264, 12767, 323, 16807, 3146, 7407, 304, 6326, 13622, 13, 1084, 374, 279, 1879, 594, 1429, 94451, 6995, 11, 448, 916, 220, 16, 13, 19, 7094, 1251, 11, 18202, 458, 3082, 315, 13187, 220, 24, 13, 21, 3526, 9334, 40568, 13, 576, 3146, 374, 26702, 553, 279, 36861, 8554, 315, 5616, 11, 6069, 553, 279, 3251, 12386, 315, 279, 10684, 10341, 11, 323, 374, 26702, 1212, 279, 17508, 315, 50518, 448, 8453, 17452, 382, 9499, 63931, 11, 5616, 374, 76217, 553, 3807, 5837, 11, 2670, 8359, 311, 279, 10200, 11, 90750, 323, 72437, 311, 279, 10200, 37602, 11, 6747, 323, 48964, 311, 279, 9710, 11, 323, 22500, 11, 95744, 11, 323, 52355, 311, 279, 9806, 13, 1084, 13248, 279, 4882, 5616, 15029, 448, 3807, 35564, 14533, 16675, 323, 702, 458, 16376, 79844, 3156, 279, 16462, 21575, 311, 279, 10984, 382, 48983, 2673, 11, 5616, 702, 264, 9080, 34917, 4924, 1182, 9037, 315, 1635, 11, 3881, 369, 1181, 19026, 311, 19128, 11, 8038, 11, 17206, 11, 323, 279, 18560, 13, 1084, 702, 1012, 264, 3598, 2781, 304, 1879, 21978, 11, 448, 279, 51036, 9536, 11, 279, 8513, 9736, 11, 323, 279, 17655, 580, 22193, 13011, 1660, 26277, 17738, 13, 576, 3146, 702, 10321, 5257, 31070, 559, 550, 11, 2670, 279, 20644, 11, 55883, 11, 323, 61912, 11, 323, 702, 6481, 1506, 11941, 2474, 279, 220, 17, 15, 339, 9294, 11, 24931, 6955, 30243, 429, 23507, 432, 1119, 264, 3644, 6955, 74114, 382, 36, 44217, 2673, 11, 5616, 374, 279, 1879, 594, 2086, 66967, 8584, 11, 448, 264, 9519, 3081, 1849, 429, 5646, 1584, 28699, 39819, 11, 869, 39819, 11, 323, 264, 3460, 41787, 10486, 13, 1084, 374, 264, 4462, 315, 279, 4337, 17214, 20395, 11, 279, 4337, 8547, 11, 323, 279, 7179, 73114, 13190, 11, 323, 374, 9842, 3644, 6559, 1526, 1181, 32893, 323, 9536, 37656, 382, 34, 43447, 13566, 2924, 264, 9080, 57341, 13815, 11, 16807, 15459, 320, 44, 437, 42740, 1660, 279, 3946, 701, 323, 264, 3746, 24654, 389, 2997, 2750, 13, 576, 3146, 12378, 14974, 1754, 1103, 2142, 11, 59508, 2142, 11, 323, 60224, 11, 323, 702, 264, 4911, 20334, 315, 8606, 323, 6481, 7674, 382, 641, 3793, 315, 6489, 4300, 11, 5616, 374, 264, 3598, 2781, 304, 279, 3639, 19140, 11, 22040, 23528, 304, 3644, 4714, 1741, 438, 9977, 2297, 11, 19005, 13951, 11, 323, 8919, 32394, 13, 1084, 702, 1012, 3238, 389, 264, 330, 3966, 32893, 11, 3776, 9536, 1, 320, 20608, 868, 8, 20162, 311, 11926, 13737, 4401, 323, 30257, 3941, 13622, 323, 7797, 11, 4623, 46494, 1181, 10173, 304, 279, 5537, 13], logprobs=None, index=1), Response(text="As an AI, I don't have feelings, but I'm here to assist you. How can I help you today?", generate_token_len=25, input_token_len=23, session_id=2, finish_reason='stop', token_ids=[2121, 458, 15235, 11, 358, 1513, 944, 614, 15650, 11, 714, 358, 2776, 1588, 311, 7789, 498, 13, 2585, 646, 358, 1492, 498, 3351, 30], logprobs=None, index=2)]
/usr/local/python3.10.5/lib/python3.10/tempfile.py:837: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpgldhgua0'>
  _warnings.warn(warn_message, ResourceWarning)

@seoibiubiu
Author

seoibiubiu commented Oct 24, 2024

Testing Qwen1.5-7B-Chat with lmdeploy fix_attn and dlinfer fix_attn via offline inference succeeded!

@seoibiubiu
Author

Does it support tensor parallelism on the 300I Duo? @yao-fengchen

@yao-fengchen
Collaborator

In theory, the MHA models support tensor parallelism.
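
For illustration, a minimal sketch of a tp=2 run (untested on the 300I Duo; the model path is illustrative) only changes the tp value in the same config used above:

from lmdeploy import pipeline
from lmdeploy import PytorchEngineConfig

# Shard an MHA model across two Ascend devices; an untested sketch.
pipe = pipeline("/models/llama-2-7b-chat-hf", backend_config=PytorchEngineConfig(tp=2, device_type="ascend"))
print(pipe(["How are you?"]))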
