diff --git a/.github/Dockerfile.buildwheel b/.github/Dockerfile.buildwheel index 92d1fa72874..abfd3b8de24 100644 --- a/.github/Dockerfile.buildwheel +++ b/.github/Dockerfile.buildwheel @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # ARG PY_VERSION=3.11 -FROM quay.io/ascend/manylinux:8.3.rc1-910b-manylinux_2_28-py${PY_VERSION} +FROM quay.io/ascend/manylinux:8.3.rc2-910b-manylinux_2_28-py${PY_VERSION} ARG COMPILE_CUSTOM_KERNELS=1 ARG SOC_VERSION diff --git a/.github/workflows/_e2e_nightly_multi_node.yaml b/.github/workflows/_e2e_nightly_multi_node.yaml index 99b2036a3cc..d91e503213d 100644 --- a/.github/workflows/_e2e_nightly_multi_node.yaml +++ b/.github/workflows/_e2e_nightly_multi_node.yaml @@ -15,7 +15,7 @@ on: required: false type: string description: base image for pods - default: "swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11" + default: "swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11" config_file_path: required: true type: string @@ -69,7 +69,7 @@ jobs: # This is the runner with no NPU for k8s controller runs-on: ${{ inputs.runner }} container: - image: m.daocloud.io/quay.io/ascend/cann:8.3.rc1-a3-ubuntu22.04-py3.11 + image: m.daocloud.io/quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11 env: KUBECONFIG: /tmp/kubeconfig KUBECTL: /root/.cache/.kube/kubectl diff --git a/.github/workflows/_e2e_nightly_single_node.yaml b/.github/workflows/_e2e_nightly_single_node.yaml index 8b4a425727d..07ef9be82a4 100644 --- a/.github/workflows/_e2e_nightly_single_node.yaml +++ b/.github/workflows/_e2e_nightly_single_node.yaml @@ -29,7 +29,7 @@ on: image: required: false type: string - default: "swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11" + default: "swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11" tests: required: true type: string diff --git a/.github/workflows/_e2e_nightly_single_node_models.yaml b/.github/workflows/_e2e_nightly_single_node_models.yaml index 29cd12b3ec3..1ce99fe3666 100644 --- a/.github/workflows/_e2e_nightly_single_node_models.yaml +++ b/.github/workflows/_e2e_nightly_single_node_models.yaml @@ -59,7 +59,7 @@ jobs: name: ${{inputs.model_list}} accuracy test runs-on: ${{ inputs.runner }} container: - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11 + image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11 env: VLLM_USE_MODELSCOPE: True GHA_VLLM_ASCEND_VERSION: ${{ inputs.vllm-ascend }} @@ -108,10 +108,7 @@ jobs: if: ${{ inputs.runner == 'linux-aarch64-a2-4' && contains(inputs.model_list, 'Qwen3-Next-80B-A3B-Instruct') }} shell: bash -l {0} run: | - wget -q https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/Ascend-BiSheng-toolkit_aarch64.run -O /tmp/Ascend-BiSheng-toolkit_aarch64.run - chmod a+x /tmp/Ascend-BiSheng-toolkit_aarch64.run - /tmp/Ascend-BiSheng-toolkit_aarch64.run --install - . /usr/local/Ascend/8.3.RC1/bisheng_toolkit/set_env.sh + . 
/usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" - name: Resolve vllm-ascend version @@ -225,4 +222,4 @@ jobs: path: ./benchmarks/accuracy/ if-no-files-found: warn retention-days: 90 - overwrite: true \ No newline at end of file + overwrite: true diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 75c3d9471dd..6906930ac61 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -179,6 +179,7 @@ jobs: VLLM_USE_MODELSCOPE: True if: ${{ inputs.type == 'full' }} run: | + pytest -sv tests/e2e/multicard/test_quantization.py pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py pytest -sv tests/e2e/multicard/test_full_graph_mode.py @@ -211,7 +212,7 @@ jobs: if: ${{ needs.e2e.result == 'success' && needs.e2e-2-cards.result == 'success' && inputs.type == 'full' }} runs-on: linux-aarch64-a3-4 container: - image: m.daocloud.io/quay.io/ascend/cann:8.3.rc1-a3-ubuntu22.04-py3.11 + image: m.daocloud.io/quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11 env: VLLM_LOGGING_LEVEL: ERROR VLLM_USE_MODELSCOPE: True @@ -274,11 +275,8 @@ jobs: - name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct) shell: bash -l {0} run: | - wget -q https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/Ascend-BiSheng-toolkit_aarch64.run -O /tmp/Ascend-BiSheng-toolkit_aarch64.run - chmod a+x /tmp/Ascend-BiSheng-toolkit_aarch64.run - /tmp/Ascend-BiSheng-toolkit_aarch64.run --install - . /usr/local/Ascend/8.3.RC1/bisheng_toolkit/set_env.sh - python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" + . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh + python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27.whl" - name: Run vllm-project/vllm-ascend Qwen3 Next test working-directory: ./vllm-ascend @@ -287,5 +285,5 @@ jobs: VLLM_WORKER_MULTIPROC_METHOD: spawn VLLM_USE_MODELSCOPE: True run: | - . /usr/local/Ascend/8.3.RC1/bisheng_toolkit/set_env.sh + . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh pytest -sv tests/e2e/multicard/test_qwen3_next.py diff --git a/.github/workflows/_nightly_image_build.yaml b/.github/workflows/_nightly_image_build.yaml index 609a62ce755..baa18ae7440 100644 --- a/.github/workflows/_nightly_image_build.yaml +++ b/.github/workflows/_nightly_image_build.yaml @@ -45,7 +45,7 @@ jobs: --network host \ --platform linux/arm64 \ -f .github/Dockerfile.nightly.${TARGET} \ - --build-arg CANN_VERSION="8.3.rc1" \ + --build-arg CANN_VERSION="8.3.rc2" \ --build-arg UBUNTU_VERSION="22.04" \ --build-arg PYTHON_VERSION="3.11" \ -t "$IMAGE_TAG" . 
diff --git a/.github/workflows/nightly_benchmarks.yaml b/.github/workflows/nightly_benchmarks.yaml index d8c425d0d2d..21144b4082e 100644 --- a/.github/workflows/nightly_benchmarks.yaml +++ b/.github/workflows/nightly_benchmarks.yaml @@ -55,7 +55,7 @@ jobs: vllm_ascend_branch: main max-parallel: 1 container: - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11 + image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11 volumes: - /usr/local/dcmi:/usr/local/dcmi - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi diff --git a/.github/workflows/release_whl.yml b/.github/workflows/release_whl.yml index d23d427efff..b095e696e84 100644 --- a/.github/workflows/release_whl.yml +++ b/.github/workflows/release_whl.yml @@ -96,8 +96,14 @@ jobs: --exclude libge_common_base.so \ --exclude libc10.so \ --exclude libc_sec.so \ + --exclude libnnopbase.so \ + --exclude libprofapi.so \ + --exclude libgraph_base.so \ + --exclude libgraph.so \ + --exclude libexe_graph.so \ --exclude "libascend*.so" \ --exclude "libtorch*.so" \ + --exclude "libopapi.so" \ --exclude "liberror_manager.so" done rm -f dist/*.whl diff --git a/.github/workflows/vllm_ascend_test_310p.yaml b/.github/workflows/vllm_ascend_test_310p.yaml index b3d3132e7e5..9e14ddfb621 100644 --- a/.github/workflows/vllm_ascend_test_310p.yaml +++ b/.github/workflows/vllm_ascend_test_310p.yaml @@ -58,7 +58,7 @@ jobs: runs-on: ${{ matrix.os }} container: # TODO(yikun): Remove m.daocloud.io prefix when infra proxy ready - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-310p-ubuntu22.04-py3.11 + image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-310p-ubuntu22.04-py3.11 env: VLLM_LOGGING_LEVEL: ERROR VLLM_USE_MODELSCOPE: True diff --git a/.github/workflows/vllm_ascend_test_full_vllm_main.yaml b/.github/workflows/vllm_ascend_test_full_vllm_main.yaml index dbd632912af..0c93b7742c0 100644 --- a/.github/workflows/vllm_ascend_test_full_vllm_main.yaml +++ b/.github/workflows/vllm_ascend_test_full_vllm_main.yaml @@ -41,5 +41,5 @@ jobs: with: vllm: main runner: linux-aarch64-a2 - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11 + image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11 type: full diff --git a/.github/workflows/vllm_ascend_test_pr_full.yaml b/.github/workflows/vllm_ascend_test_pr_full.yaml index 1c699391eee..754334b9990 100644 --- a/.github/workflows/vllm_ascend_test_pr_full.yaml +++ b/.github/workflows/vllm_ascend_test_pr_full.yaml @@ -76,5 +76,5 @@ jobs: with: vllm: ${{ matrix.vllm_version }} runner: linux-aarch64-a2 - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11 + image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11 type: full diff --git a/.github/workflows/vllm_ascend_test_pr_light.yaml b/.github/workflows/vllm_ascend_test_pr_light.yaml index 06c8cfd26dd..f293fa53115 100644 --- a/.github/workflows/vllm_ascend_test_pr_light.yaml +++ b/.github/workflows/vllm_ascend_test_pr_light.yaml @@ -76,8 +76,8 @@ jobs: if: ${{ needs.lint.result == 'success' && (needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.ut_tracker == 'true') }} runs-on: ubuntu-latest container: - # fixme: vllm-ascend install failed with 8.3.rc1 on github action - image: quay.io/ascend/cann:8.2.rc1-910b-ubuntu22.04-py3.11 + # 
fixme: vllm-ascend install failed with 8.3.rc2 on github action + image: quay.io/ascend/cann:8.2.rc2-910b-ubuntu22.04-py3.11 env: VLLM_LOGGING_LEVEL: ERROR VLLM_USE_MODELSCOPE: True @@ -151,5 +151,5 @@ jobs: with: vllm: ${{ matrix.vllm_version }} runner: linux-aarch64-a2 - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11 + image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11 type: light diff --git a/.github/workflows/vllm_ascend_test_report.yaml b/.github/workflows/vllm_ascend_test_report.yaml index d4cd79a87ae..0e63356e9ba 100644 --- a/.github/workflows/vllm_ascend_test_report.yaml +++ b/.github/workflows/vllm_ascend_test_report.yaml @@ -74,7 +74,7 @@ jobs: with: vllm: v0.11.2 runner: ${{ matrix.runner }} - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11 + image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11 model_list: ${{ toJson(matrix.model_list) }} upload: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.vllm-ascend-version == 'latest' }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 34819347edb..f2f42d5b159 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: - id: codespell args: [ --toml, pyproject.toml, - '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', + '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND' ] additional_dependencies: @@ -37,7 +37,7 @@ repos: - id: typos args: [ "--force-exclude", - "--exclude", "csrc/mla_preprocess/**" + "--exclude", "csrc/**" ] - repo: https://github.com/PyCQA/isort rev: 6.0.1 diff --git a/CMakeLists.txt b/CMakeLists.txt index 272bdb13c74..f0136bc48e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,7 +63,8 @@ ascendc_library(vllm_ascend_kernels SHARED message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") file(GLOB VLLM_ASCEND_SRC -${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp) +${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp +${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp) include_directories( ${pybind11_INCLUDE_DIRS} @@ -81,6 +82,7 @@ set( ${TORCH_NPU_INCLUDE_DIRS} ${ASCEND_HOME_PATH}/include ${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform + ${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform ) pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC}) @@ -88,6 +90,7 @@ pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC}) target_link_directories( vllm_ascend_C PRIVATE + ${TORCH_LIBRARY_DIRS} ${TORCH_NPU_PATH}/lib/ ${ASCEND_HOME_PATH}/lib64 ) @@ -96,7 +99,7 @@ target_link_libraries( vllm_ascend_C PUBLIC ${TORCH_LIBRARIES} - libtorch_npu.so + torch_npu vllm_ascend_kernels ascendcl tiling_api @@ -104,6 +107,7 @@ target_link_libraries( platform ascendalog dl + opapi ) target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib") diff --git a/Dockerfile b/Dockerfile index 2cc85ae8dbf..cc5605ee0bf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. 
# -FROM quay.io/ascend/cann:8.3.rc1-910b-ubuntu22.04-py3.11 +FROM quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 diff --git a/Dockerfile.310p b/Dockerfile.310p index 354f02a6c7d..9d2032631c2 100644 --- a/Dockerfile.310p +++ b/Dockerfile.310p @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # -FROM quay.io/ascend/cann:8.3.rc1-310p-ubuntu22.04-py3.11 +FROM quay.io/ascend/cann:8.3.rc2-310p-ubuntu22.04-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 diff --git a/Dockerfile.310p.openEuler b/Dockerfile.310p.openEuler index 3463939c254..659a56c6f7c 100644 --- a/Dockerfile.310p.openEuler +++ b/Dockerfile.310p.openEuler @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # -FROM quay.io/ascend/cann:8.3.rc1-310p-openeuler24.03-py3.11 +FROM quay.io/ascend/cann:8.3.rc2-310p-openeuler24.03-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 diff --git a/Dockerfile.a3 b/Dockerfile.a3 index ba6703e087d..de6f1a5aefa 100644 --- a/Dockerfile.a3 +++ b/Dockerfile.a3 @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # -FROM quay.io/ascend/cann:8.3.rc1-a3-ubuntu22.04-py3.11 +FROM quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 diff --git a/Dockerfile.a3.openEuler b/Dockerfile.a3.openEuler index 259aa98eb5b..7761f341f91 100644 --- a/Dockerfile.a3.openEuler +++ b/Dockerfile.a3.openEuler @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # -FROM quay.io/ascend/cann:8.3.rc1-a3-openeuler24.03-py3.11 +FROM quay.io/ascend/cann:8.3.rc2-a3-openeuler24.03-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 diff --git a/Dockerfile.openEuler b/Dockerfile.openEuler index 4c6c6f9e31e..9666dee487a 100644 --- a/Dockerfile.openEuler +++ b/Dockerfile.openEuler @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # -FROM quay.io/ascend/cann:8.3.rc1-910b-openeuler24.03-py3.11 +FROM quay.io/ascend/cann:8.3.rc2-910b-openeuler24.03-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt new file mode 100644 index 00000000000..dab92509d45 --- /dev/null +++ b/csrc/CMakeLists.txt @@ -0,0 +1,642 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +cmake_minimum_required(VERSION 3.16) + +project(cann_ops_custom) + +option(BUILD_OPEN_PROJECT "Build open ascend ops project." 
ON) +option(ENABLE_CCACHE "Enable ccache capability" ON) +set(ASCEND_COMPUTE_UNIT "ascend910b" CACHE STRING "soc that need to be compiled") +set(ASCEND_OP_NAME "ALL" CACHE STRING "operators that need to be compiled") +set(VENDOR_NAME "customize" CACHE STRING "vendor name") + +include(cmake/config.cmake) +include(cmake/func.cmake) +include(cmake/intf.cmake) + +if (BUILD_OPEN_PROJECT) + set(_op_host_aclnn_link + $ + exe_graph + register + c_sec + ) + set(CMAKE_MODULE_PATH + ${CMAKE_MODULE_PATH} + ${CMAKE_CURRENT_LIST_DIR}/cmake/modules + ) + + set(CMAKE_PREFIX_PATH + ${CMAKE_PREFIX_PATH} + ${ASCEND_CANN_PACKAGE_PATH} + ) + + find_package(alog MODULE REQUIRED) + add_library(op_host_aclnn SHARED EXCLUDE_FROM_ALL) + target_link_libraries(op_host_aclnn PRIVATE + ${_op_host_aclnn_link} + ) + target_compile_options(op_host_aclnn PRIVATE + $<$:-std=gnu++1z> + ) + + add_library(op_host_aclnnInner SHARED EXCLUDE_FROM_ALL) + target_link_libraries(op_host_aclnnInner PRIVATE + ${_op_host_aclnn_link} + ) + target_compile_options(op_host_aclnnInner PRIVATE + $<$:-std=gnu++1z> + ) + + add_library(op_host_aclnnExc SHARED EXCLUDE_FROM_ALL) + target_link_libraries(op_host_aclnnExc PRIVATE + ${_op_host_aclnn_link} + ) + target_compile_options(op_host_aclnnExc PRIVATE + $<$:-std=gnu++1z> + ) + + # op api + add_library(opapi SHARED) + # When compiling a specified operator without aclnn src + if (NOT "${ASCEND_OP_NAME}" STREQUAL "ALL") + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp + COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp + ) + target_sources(opapi PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/opapi_stub.cpp + ) + endif() + + target_compile_options(opapi PRIVATE + $<$:-std=gnu++1z> + ) + target_include_directories(opapi PRIVATE + $ + $ + $ + ) + target_compile_options(opapi PRIVATE + -Werror=format + ) + target_compile_definitions(opapi PRIVATE + -DACLNN_LOG_FMT_CHECK + ) + target_link_libraries(opapi PRIVATE + $ + -Wl,--whole-archive + ops_aclnn + -Wl,--no-whole-archive + -lopapi + nnopbase + profapi + ge_common_base + ascend_dump + ascendalog + dl + ) + set_target_properties(opapi PROPERTIES OUTPUT_NAME + cust_opapi + ) + install(TARGETS opapi + LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_api/lib + ) + + # op proto + add_library(opsproto SHARED) + target_compile_options(opsproto PRIVATE + $<$:-std=c++11> + -fvisibility=hidden + ) + target_compile_definitions(opsproto PRIVATE + LOG_CPP + PROCESS_LOG + ) + target_link_libraries(opsproto PRIVATE + $ + $ + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive + -Wl,--no-as-needed + exe_graph + graph + graph_base + register + ascendalog + error_manager + platform + -Wl,--as-needed + c_sec + ) + set_target_properties(opsproto PROPERTIES OUTPUT_NAME + cust_opsproto_rt2.0 + ) + install(TARGETS opsproto + LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR} + ) + + # op tiling + add_library(optiling SHARED) + target_sources(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/utils/src/fallback_comm.cpp + ) + target_compile_options(optiling PRIVATE + $<$:-std=c++11> + -fvisibility=hidden + ) + target_compile_definitions(optiling PRIVATE + LOG_CPP + PROCESS_LOG + ) + target_link_libraries(optiling PRIVATE + $ + $ + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive + -Wl,--no-as-needed + graph + graph_base + exe_graph + platform + register + ascendalog + error_manager + -Wl,--as-needed + -Wl,--whole-archive + tiling_api + -Wl,--no-whole-archive + mmpa + c_sec + ) + 
set_target_properties(optiling PROPERTIES OUTPUT_NAME + cust_opmaster_rt2.0 + ) + add_custom_command(TARGET optiling + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${TILING_CUSTOM_DIR} + COMMAND ln -sf $ ${TILING_CUSTOM_FILE} + ) + install(TARGETS optiling + LIBRARY DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR} + ) + + # optiling compat + set(compat_optiling_dir ${CMAKE_CURRENT_BINARY_DIR}/compat) + set(compat_optiling_file ${compat_optiling_dir}/liboptiling.so) + add_custom_target(optiling_compat ALL + DEPENDS ${compat_optiling_file} + ) + + add_custom_command( + OUTPUT ${compat_optiling_file} + COMMAND ${CMAKE_COMMAND} -E make_directory ${compat_optiling_dir} + COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$ ${compat_optiling_file} + ) + + install(FILES ${compat_optiling_file} + DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/op_tiling + ) + + add_ops_tiling_keys( + OP_NAME "ALL" + TILING_KEYS ${TILING_KEY} + ) + + add_opc_config( + OP_NAME "ALL" + CONFIG ${OP_DEBUG_CONFIG} + ) + + if(ADD_OPS_COMPILE_OPTION_V2) + add_ops_compile_options( + OP_NAME "ALL" + OPTIONS ${OPS_COMPILE_OPTIONS} + ) + endif() +endif () + +add_subdirectory(utils) + +set(OP_LIST) +set(OP_DIR_LIST) +op_add_subdirectory(OP_LIST OP_DIR_LIST) + +foreach (OP_DIR ${OP_DIR_LIST}) + add_subdirectory(${OP_DIR}/op_host) +endforeach () + +set(OP_DEPEND_DIR_LIST) +op_add_depend_directory( + OP_LIST ${OP_LIST} + OP_DIR_LIST OP_DEPEND_DIR_LIST +) +foreach (OP_DEPEND_DIR ${OP_DEPEND_DIR_LIST}) + add_subdirectory(${OP_DEPEND_DIR}/op_host) +endforeach () + +# ------------------------------------------------ aclnn ------------------------------------------------ +get_target_property(base_aclnn_srcs op_host_aclnn SOURCES) +get_target_property(base_aclnn_inner_srcs op_host_aclnnInner SOURCES) +get_target_property(base_aclnn_exclude_srcs op_host_aclnnExc SOURCES) + +if (BUILD_OPEN_PROJECT) + set(base_aclnn_binary_dir ${ASCEND_AUTOGEN_DIR}) +else() + get_target_property(base_aclnn_binary_dir op_host_aclnn BINARY_DIR) +endif () + +set(generate_aclnn_srcs) +set(generate_aclnn_inner_srcs) +set(generate_aclnn_headers) +set(generate_proto_dir ${base_aclnn_binary_dir}) +set(generate_exclude_proto_srcs) +set(generate_proto_srcs) +set(generate_proto_headers) + +if (base_aclnn_srcs) + foreach (_src ${base_aclnn_srcs}) + string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}") + if (is_match) + get_filename_component(name_without_ext ${_src} NAME_WE) + + string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext}) + list(APPEND generate_aclnn_srcs ${base_aclnn_binary_dir}/aclnn_${_op_name}.cpp) + list(APPEND generate_aclnn_headers ${base_aclnn_binary_dir}/aclnn_${_op_name}.h) + list(APPEND generate_proto_srcs ${generate_proto_dir}/${_op_name}_proto.cpp) + list(APPEND generate_proto_headers ${generate_proto_dir}/${_op_name}_proto.h) + endif () + endforeach () +else () + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp + COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp + ) + + target_sources(op_host_aclnn PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_stub.cpp + ) +endif () + +if (base_aclnn_inner_srcs) + foreach (_src ${base_aclnn_inner_srcs}) + string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}") + if (is_match) + get_filename_component(name_without_ext ${_src} NAME_WE) + string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext}) + list(APPEND generate_aclnn_inner_srcs 
${base_aclnn_binary_dir}/inner/aclnnInner_${_op_name}.cpp) + list(APPEND generate_proto_srcs ${generate_proto_dir}/inner/${_op_name}_proto.cpp) + list(APPEND generate_proto_headers ${generate_proto_dir}/inner/${_op_name}_proto.h) + endif () + endforeach () +else () + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp + COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp + ) + + target_sources(op_host_aclnnInner PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_inner_stub.cpp + ) +endif () + +if (base_aclnn_exclude_srcs) + foreach (_src ${base_aclnn_exclude_srcs}) + string(REGEX MATCH "^${CMAKE_CURRENT_SOURCE_DIR}" is_match "${_src}") + if (is_match) + get_filename_component(name_without_ext ${_src} NAME_WE) + string(REGEX REPLACE "_def$" "" _op_name ${name_without_ext}) + list(APPEND generate_exclude_proto_srcs ${generate_proto_dir}/exc/${_op_name}_proto.cpp) + list(APPEND generate_proto_srcs ${generate_proto_dir}/exc/${_op_name}_proto.cpp) + list(APPEND generate_proto_headers ${generate_proto_dir}/exc/${_op_name}_proto.h) + endif () + endforeach () +else() + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp + COMMAND touch ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp + ) + + target_sources(op_host_aclnnExc PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/op_host_aclnn_exc_stub.cpp + ) +endif () + +if (BUILD_OPEN_PROJECT) + if (generate_aclnn_srcs OR generate_aclnn_inner_srcs) + set(ops_aclnn_src ${generate_aclnn_srcs} ${generate_aclnn_inner_srcs}) + else () + set(ops_aclnn_src ${CMAKE_CURRENT_BINARY_DIR}/ops_aclnn_src_stub.cpp) + + add_custom_command(OUTPUT ${ops_aclnn_src} + COMMAND touch ${ops_aclnn_src} + ) + endif () + + set_source_files_properties(${ops_aclnn_src} + PROPERTIES GENERATED TRUE + ) + add_library(ops_aclnn STATIC + ${ops_aclnn_src} + ) + target_compile_options(ops_aclnn PRIVATE + $<$:-std=gnu++1z> + ) + target_link_libraries(ops_aclnn PRIVATE + $ + ) + add_dependencies(ops_aclnn opbuild_gen_default opbuild_gen_inner) + + set_source_files_properties(${generate_proto_srcs} + PROPERTIES GENERATED TRUE + ) + target_sources(opsproto PRIVATE + ${generate_proto_srcs} + ) + add_dependencies(opsproto ops_proto_headers) + + install(FILES ${generate_proto_headers} + DESTINATION packages/vendors/${VENDOR_NAME}/op_proto/inc OPTIONAL + ) + + redefine_file_macro( + TARGET_NAME + op_host_aclnn + op_host_aclnnInner + op_host_aclnnExc + opapi + opsproto + optiling + ops_aclnn + ) +else() + if (generate_aclnn_srcs OR generate_aclnn_inner_srcs) + set_source_files_properties(${generate_aclnn_srcs} ${generate_aclnn_inner_srcs} + TARGET_DIRECTORY acl_op_builtin + PROPERTIES GENERATED TRUE + ) + + target_sources(acl_op_builtin PRIVATE + ${generate_aclnn_srcs} + ${generate_aclnn_inner_srcs} + ) + endif () + + if (generate_proto_srcs) + set_source_files_properties(${generate_proto_srcs} + TARGET_DIRECTORY opsproto opsproto_rt2.0 + PROPERTIES GENERATED TRUE + ) + target_sources(opsproto PRIVATE + ${generate_proto_srcs} + ) + add_dependencies(opsproto ops_proto_headers) + + target_sources(opsproto_rt2.0 PRIVATE + ${generate_proto_srcs} + ) + add_dependencies(opsproto_rt2.0 ops_proto_headers) + endif () + + add_target_source( + TARGET_NAME opmaster_rt2.0 opmaster_static_rt2.0 + BASE_TARGET optiling + SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR} + ) + + add_target_source( + TARGET_NAME opsproto_rt2.0 opsproto_static_rt2.0 + BASE_TARGET opsproto + SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR} + ) + + add_static_ops( + ACLNN_SRC 
${generate_aclnn_srcs} + ACLNN_INNER_SRC ${generate_aclnn_inner_srcs} + SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR} + ) +endif () + +if (generate_aclnn_headers) + install(FILES ${generate_aclnn_headers} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL + ) +endif () + +add_library(ops_proto_headers INTERFACE) + +target_include_directories(ops_proto_headers INTERFACE + $ + $ + $ + $ +) + +if ((NOT BUILD_OPEN_PROJECT) AND ("${PRODUCT_SIDE}" STREQUAL "device")) + ExternalProject_Add(extern_opbuild_gen_default + SOURCE_DIR ${TOP_DIR}/cmake/superbuild + CONFIGURE_COMMAND ${CMAKE_COMMAND} + -G ${CMAKE_GENERATOR} + -DHOST_PACKAGE=opp + -DBUILD_MOD=ops + -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/opbuild_output + -DFEATURE_LIST=custom_opbuild_out_dir=${generate_proto_dir} + + BUILD_COMMAND TARGETS=opbuild_gen_all $(MAKE) + INSTALL_COMMAND "" + LIST_SEPARATOR :: + EXCLUDE_FROM_ALL TRUE + ) + add_dependencies(ops_proto_headers extern_opbuild_gen_default) +else() + add_dependencies(ops_proto_headers opbuild_gen_default opbuild_gen_inner opbuild_gen_exc) +endif () + +if (NOT BUILD_OPEN_PROJECT) + if (generate_proto_srcs) + install_package( + PACKAGE ops_adv + TARGETS ops_proto_headers + FILES ${generate_proto_headers} + DESTINATION include/ops_adv/proto + ) + endif () +endif () + +# ------------------------------------------------ opbuild ------------------------------------------------ +if (BUILD_OPEN_PROJECT) + if (generate_aclnn_srcs) + add_custom_command(OUTPUT ${generate_aclnn_srcs} ${generate_aclnn_headers} + COMMAND mkdir -p ${base_aclnn_binary_dir} + COMMAND OPS_PROTO_SEPARATE=1 + OPS_ACLNN_GEN=1 + OPS_PROJECT_NAME=aclnn + ${OP_BUILD_TOOL} + $ + ${base_aclnn_binary_dir} + ) + endif () + + add_custom_target(opbuild_gen_default + DEPENDS ${generate_aclnn_srcs} ${generate_aclnn_headers} op_host_aclnn + ) + + if (generate_aclnn_inner_srcs) + add_custom_command(OUTPUT ${generate_aclnn_inner_srcs} + COMMAND mkdir -p ${base_aclnn_binary_dir}/inner + COMMAND OPS_PROTO_SEPARATE=1 + OPS_ACLNN_GEN=1 + OPS_PROJECT_NAME=aclnnInner + ${OP_BUILD_TOOL} + $ + ${base_aclnn_binary_dir}/inner + ) + endif () + + add_custom_target(opbuild_gen_inner + DEPENDS ${generate_aclnn_inner_srcs} op_host_aclnnInner + ) + + if (generate_exclude_proto_srcs) + add_custom_command(OUTPUT ${generate_exclude_proto_srcs} + COMMAND mkdir -p ${base_aclnn_binary_dir}/exc + COMMAND OPS_PROTO_SEPARATE=1 + OPS_ACLNN_GEN=0 + OPS_PROJECT_NAME=aclnnExc + ${OP_BUILD_TOOL} + $ + ${base_aclnn_binary_dir}/exc + ) + endif () + + add_custom_target(opbuild_gen_exc + DEPENDS ${generate_exclude_proto_srcs} op_host_aclnnExc + ) +endif () + +# ------------------------------------------------ generate adapt py ------------------------------------------------ +add_custom_target(generate_adapt_py + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_impl_build.py + \"\" + \"\" + \"\" + \"\" + ${ASCEND_IMPL_OUT_DIR} + ${ASCEND_AUTOGEN_DIR} + --opsinfo-dir ${base_aclnn_binary_dir} ${base_aclnn_binary_dir}/inner ${base_aclnn_binary_dir}/exc +) + +add_dependencies(generate_adapt_py opbuild_gen_default opbuild_gen_inner opbuild_gen_exc) + +foreach (_op_name ${OP_LIST}) + install(FILES ${ASCEND_IMPL_OUT_DIR}/dynamic/${_op_name}.py + DESTINATION ${IMPL_DYNAMIC_INSTALL_DIR} + OPTIONAL + ) +endforeach () + +install(DIRECTORY ${OPS_ADV_UTILS_KERNEL_INC}/ + DESTINATION ${IMPL_INSTALL_DIR}/ascendc/common +) + +foreach (op_dir ${OP_DIR_LIST}) + get_filename_component(_op_name "${op_dir}" NAME) + + file(GLOB KERNEL_FILES + ${op_dir}/op_kernel/*.cpp + 
${op_dir}/op_kernel/*.h + ) + + install(FILES ${KERNEL_FILES} + DESTINATION ${IMPL_INSTALL_DIR}/ascendc/${_op_name} + OPTIONAL + ) +endforeach () + +# ------------------------------------------------ generate compile cmd ------------------------------------------------ +if (BUILD_OPEN_PROJECT) + add_custom_target(prepare_build ALL) + add_custom_target(generate_compile_cmd ALL) + add_custom_target(generate_ops_info ALL) + add_dependencies(prepare_build generate_adapt_py generate_compile_cmd) + + foreach (compute_unit ${ASCEND_COMPUTE_UNIT}) + add_compile_cmd_target( + COMPUTE_UNIT ${compute_unit} + ) + + add_ops_info_target( + COMPUTE_UNIT ${compute_unit} + ) + endforeach () +else() + add_dependencies(tbe_ops_json_info generate_adapt_py) +endif () + +# ------------------------------------------------ opp kernel ------------------------------------------------ +if (ENABLE_OPS_KERNEL) + add_custom_target(ops_kernel ALL) + add_custom_target(ops_config ALL) + add_dependencies(ops_kernel ops_config) + + foreach (compute_unit ${ASCEND_COMPUTE_UNIT}) + add_bin_compile_target( + COMPUTE_UNIT + ${compute_unit} + OP_INFO + ${OP_DIR_LIST} + ) + endforeach () +endif () + +if (BUILD_OPEN_PROJECT) + add_custom_target(modify_vendor ALL + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/scripts/install.sh ${CMAKE_CURRENT_BINARY_DIR}/scripts/upgrade.sh + ) + + # modify VENDOR_NAME in install.sh and upgrade.sh + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/scripts/install.sh ${CMAKE_CURRENT_BINARY_DIR}/scripts/upgrade.sh + COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/scripts + COMMAND cp -r ${ASCEND_PROJECT_DIR}/scripts/* ${CMAKE_CURRENT_BINARY_DIR}/scripts/ + COMMAND chmod +w ${CMAKE_CURRENT_BINARY_DIR}/scripts/* + COMMAND sed -i "s/vendor_name=customize/vendor_name=${VENDOR_NAME}/g" ${CMAKE_CURRENT_BINARY_DIR}/scripts/* + ) + + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/scripts/ + DESTINATION . FILE_PERMISSIONS OWNER_EXECUTE OWNER_READ GROUP_READ + ) + + # gen version.info + set(version_info_dir ${CMAKE_CURRENT_BINARY_DIR}) + set(version_info_file ${version_info_dir}/version.info) + add_custom_target(gen_version_info ALL + DEPENDS ${version_info_file} + ) + + add_custom_command(OUTPUT ${version_info_file} + COMMAND bash ${ASCENDC_CMAKE_UTIL_DIR}/gen_version_info.sh ${ASCEND_CANN_PACKAGE_PATH} ${version_info_dir} + ) + + install(FILES ${version_info_file} + DESTINATION packages/vendors/${VENDOR_NAME}/ + ) + + # CPack config + set(CPACK_PACKAGE_NAME ${CMAKE_PROJECT_NAME}) + set(CPACK_PACKAGE_VERSION ${CMAKE_PROJECT_VERSION}) + set(CPACK_PACKAGE_DESCRIPTION "CPack ops project") + set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CPack ops project") + set(CPACK_PACKAGE_DIRECTORY ${CMAKE_BINARY_DIR}) + set(CPACK_PACKAGE_FILE_NAME "CANN-custom_ops-${CANN_VERSION}-linux.${CMAKE_SYSTEM_PROCESSOR}.run") + set(CPACK_GENERATOR External) + set(CPACK_CMAKE_GENERATOR "Unix Makefiles") + set(CPACK_EXTERNAL_ENABLE_STAGING TRUE) + set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${ASCEND_CMAKE_DIR}/makeself.cmake) + set(CPACK_EXTERNAL_BUILT_PACKAGES ${CPACK_PACKAGE_DIRECTORY}/_CPack_Packages/Linux/External/${CPACK_PACKAGE_FILE_NAME}/${CPACK_PACKAGE_FILE_NAME}) + include(CPack) +endif () diff --git a/csrc/aclnn_torch_adapter/NPUBridge.cpp b/csrc/aclnn_torch_adapter/NPUBridge.cpp new file mode 100644 index 00000000000..dc335cb45b5 --- /dev/null +++ b/csrc/aclnn_torch_adapter/NPUBridge.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2020, Huawei Technologies Co., Ltd +// All rights reserved. 
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "NPUBridge.h"
+
+namespace vllm_ascend
+{
+    NPUStorageImpl *NPUBridge::GetNpuStorageImpl(c10::StorageImpl *storageImpl)
+    {
+        return static_cast<NPUStorageImpl *>(storageImpl);
+    }
+
+    NPUStorageImpl *NPUBridge::GetNpuStorageImpl(c10::Storage &&storage)
+    {
+        return static_cast<NPUStorageImpl *>(storage.unsafeGetStorageImpl());
+    }
+
+    NPUStorageImpl *NPUBridge::GetNpuStorageImpl(const at::Tensor &tensor)
+    {
+        return static_cast<NPUStorageImpl *>(tensor.storage().unsafeGetStorageImpl());
+    }
+
+    NPUStorageDesc &NPUBridge::GetNpuStorageImplDesc(const at::Tensor &tensor)
+    {
+        return static_cast<NPUStorageImpl *>(tensor.storage().unsafeGetStorageImpl())->npu_desc_;
+    }
+}
diff --git a/csrc/aclnn_torch_adapter/NPUBridge.h b/csrc/aclnn_torch_adapter/NPUBridge.h
new file mode 100644
index 00000000000..e93a10485a9
--- /dev/null
+++ b/csrc/aclnn_torch_adapter/NPUBridge.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2020, Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <ATen/ATen.h>
+#include "NPUStorageImpl.h"
+
+namespace vllm_ascend
+{
+
+    class NPUBridge
+    {
+    public:
+        // at::tensor to NPUStorageImpl
+        static NPUStorageImpl *GetNpuStorageImpl(const at::Tensor &tensor);
+
+        // c10::StorageImpl to NPUStorageImpl
+        static NPUStorageImpl *GetNpuStorageImpl(c10::StorageImpl *storageImpl);
+
+        // c10::Storage to NPUStorageImpl
+        static NPUStorageImpl *GetNpuStorageImpl(c10::Storage &&storage);
+
+        // tensor to NPUStorageDesc
+        static NPUStorageDesc &GetNpuStorageImplDesc(const at::Tensor &tensor);
+    };
+}
diff --git a/csrc/aclnn_torch_adapter/NPUStorageImpl.cpp b/csrc/aclnn_torch_adapter/NPUStorageImpl.cpp
new file mode 100644
index 00000000000..9dfe0c0ccaf
--- /dev/null
+++ b/csrc/aclnn_torch_adapter/NPUStorageImpl.cpp
@@ -0,0 +1,52 @@
+// Copyright (c) 2020, Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "NPUStorageImpl.h"
+
+namespace vllm_ascend
+{
+
+    NPUStorageImpl::NPUStorageImpl(
+        use_byte_size_t use_byte_size,
+        size_t size_bytes,
+        at::DataPtr data_ptr,
+        at::Allocator *allocator,
+        bool resizable) : c10::StorageImpl(use_byte_size,
+                                           size_bytes,
+                                           at::DataPtr(std::move(data_ptr)),
+                                           allocator,
+                                           resizable)
+    {
+    }
+
+    void NPUStorageImpl::release_resources()
+    {
+        StorageImpl::release_resources();
+    }
+
+    c10::intrusive_ptr<c10::StorageImpl> make_npu_storage_impl(
+        c10::StorageImpl::use_byte_size_t,
+        c10::SymInt size_bytes,
+        c10::DataPtr data_ptr,
+        c10::Allocator *allocator,
+        bool resizable)
+    {
+        if (data_ptr == nullptr)
+        {
+            data_ptr = allocator->allocate(size_bytes.as_int_unchecked());
+        }
+        // Correctly create NPUStorageImpl object.
+        c10::intrusive_ptr<NPUStorageImpl> npu_storage_impl = c10::make_intrusive<NPUStorageImpl>(
+            c10::StorageImpl::use_byte_size_t(),
+            size_bytes.as_int_unchecked(),
+            std::move(data_ptr),
+            allocator,
+            resizable);
+        // There is no need to consider the NPUStorageDesc information, it will be carried out in the subsequent processing.
+ return npu_storage_impl; + } + +} diff --git a/csrc/aclnn_torch_adapter/NPUStorageImpl.h b/csrc/aclnn_torch_adapter/NPUStorageImpl.h new file mode 100644 index 00000000000..fcf293b1a24 --- /dev/null +++ b/csrc/aclnn_torch_adapter/NPUStorageImpl.h @@ -0,0 +1,67 @@ +// Copyright (c) 2020, Huawei Technologies Co., Ltd +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "acl/acl_rt.h" +#include "acl/acl_base.h" + +namespace vllm_ascend +{ + + struct NPUStorageDesc + { + public: + struct use_byte_size_t + { + }; + + c10::SmallVector base_sizes_; + c10::SmallVector base_strides_; + c10::SmallVector storage_sizes_; + int64_t base_offset_ = 0; + use_byte_size_t base_dtype_ = {}; + aclFormat origin_format_ = ACL_FORMAT_UNDEFINED; + aclFormat npu_format_ = ACL_FORMAT_ND; + // used to make CANN GE tensor from storagImpl + caffe2::TypeMeta data_type_ = caffe2::TypeMeta::Make(); + }; + + struct NPUStorageImpl : public c10::StorageImpl + { + explicit NPUStorageImpl( + use_byte_size_t use_byte_size, + size_t size_bytes, + at::DataPtr data_ptr, + at::Allocator *allocator, + bool resizable); + ~NPUStorageImpl() override = default; + + void release_resources() override; + + NPUStorageDesc npu_desc_; + + NPUStorageDesc get_npu_desc() const + { + return npu_desc_; + } + }; + + c10::intrusive_ptr make_npu_storage_impl( + c10::StorageImpl::use_byte_size_t, + c10::SymInt size_bytes, + c10::DataPtr data_ptr, + c10::Allocator *allocator, + bool resizable); + +} diff --git a/csrc/aclnn_torch_adapter/op_api_common.h b/csrc/aclnn_torch_adapter/op_api_common.h new file mode 100644 index 00000000000..e4c8a517d0e --- /dev/null +++ b/csrc/aclnn_torch_adapter/op_api_common.h @@ -0,0 +1,591 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OP_API_COMMON_ADAPTER +#define OP_API_COMMON_ADAPTER + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/interface/EnvVariables.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "NPUBridge.h" +#include "NPUStorageImpl.h" + +#define NPU_NAME_SPACE at_npu::native +using namespace at; + +typedef struct aclOpExecutor aclOpExecutor; +typedef struct aclTensor aclTensor; +typedef struct aclScalar aclScalar; +typedef struct aclIntArray aclIntArray; +typedef struct aclFloatArray aclFloatArray; +typedef struct aclBoolArray aclBoolArray; +typedef struct aclTensorList aclTensorList; + +typedef aclTensor *(*_aclCreateTensor)( + const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type, + const int64_t *stride, int64_t offset, aclFormat format, + const int64_t *storage_dims, uint64_t storage_dims_num, void *tensor_data); +typedef aclScalar *(*_aclCreateScalar)(void *value, aclDataType data_type); +typedef aclIntArray *(*_aclCreateIntArray)(const int64_t *value, uint64_t size); +typedef aclFloatArray *(*_aclCreateFloatArray)(const float *value, + uint64_t size); +typedef aclBoolArray *(*_aclCreateBoolArray)(const bool *value, uint64_t size); +typedef aclTensorList *(*_aclCreateTensorList)(const aclTensor *const *value, + uint64_t size); + +typedef int (*_aclDestroyTensor)(const aclTensor *tensor); +typedef int (*_aclDestroyScalar)(const aclScalar *scalar); +typedef int (*_aclDestroyIntArray)(const aclIntArray *array); +typedef int (*_aclDestroyFloatArray)(const aclFloatArray *array); +typedef int (*_aclDestroyBoolArray)(const aclBoolArray *array); +typedef int (*_aclDestroyTensorList)(const aclTensorList *array); + +constexpr int kHashBufSize = 8192; +constexpr int kHashBufMaxSize = kHashBufSize + 1024; +extern thread_local char g_hashBuf[kHashBufSize]; +extern thread_local int g_hashOffset; + +#ifdef MMCV_WITH_XLA +#define DEVICE_TYPE at_npu::key::NativeDeviceType +#else +#define DEVICE_TYPE c10::DeviceType::PrivateUse1 +#endif + +#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \ + _(at::ScalarType::Byte, ACL_UINT8) \ + _(at::ScalarType::Char, ACL_INT8) \ + _(at::ScalarType::Short, ACL_INT16) \ + _(at::ScalarType::Int, ACL_INT32) \ + _(at::ScalarType::Long, ACL_INT64) \ + _(at::ScalarType::Half, ACL_FLOAT16) \ + _(at::ScalarType::Float, ACL_FLOAT) \ + _(at::ScalarType::Double, ACL_DOUBLE) \ + _(at::ScalarType::ComplexHalf, ACL_DT_UNDEFINED) \ + _(at::ScalarType::ComplexFloat, ACL_COMPLEX64) \ + _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \ + _(at::ScalarType::Bool, ACL_BOOL) \ + _(at::ScalarType::QInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QInt32, ACL_DT_UNDEFINED) \ + _(at::ScalarType::BFloat16, ACL_BF16) \ + _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \ + _(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \ + _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED) + +constexpr aclDataType kATenScalarTypeToAclDataTypeTable + [static_cast(at::ScalarType::NumOptions) + 1] = { +#define DEFINE_ENUM(_1, n) n, + AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM) +#undef DEFINE_ENUM +}; + +#define GET_OP_API_FUNC(apiName) \ + 
reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName)) + +#define MEMCPY_TO_BUF(data_expression, size_expression) \ + if (g_hashOffset + (size_expression) > kHashBufSize) { \ + g_hashOffset = kHashBufMaxSize; \ + return; \ + } \ + memcpy(g_hashBuf + g_hashOffset, data_expression, size_expression); \ + g_hashOffset += size_expression; + +bool IsOpInputBaseFormat(const at::Tensor &tensor) +{ + if (!tensor.is_privateuseone()) { + return true; + } + const auto format = vllm_ascend::NPUBridge::GetNpuStorageImplDesc(tensor).npu_format_; + return (format == ACL_FORMAT_ND) || (format == ACL_FORMAT_NCHW) || (format == ACL_FORMAT_NHWC) || + (format == ACL_FORMAT_NCDHW); +} + +inline const char *GetOpApiLibName(void) { return "libopapi.so"; } + +inline const char *GetCustOpApiLibName(void) { return "libcust_opapi.so"; } + +inline void *GetOpApiFuncAddrInLib(void *handler, const char *libName, + const char *apiName) { + auto funcAddr = dlsym(handler, apiName); + return funcAddr; +} + +inline void *GetOpApiLibHandler(const char *libName) { + auto handler = dlopen(libName, RTLD_LAZY); + return handler; +} + +inline void *GetOpApiFuncAddr(const char *apiName) { + static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName()); + if (custOpApiHandler != nullptr) { + auto funcAddr = + GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + + static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName()); + if (opApiHandler == nullptr) { + return nullptr; + } + return GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName); +} + +inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor) { + c10::Scalar expScalar; + const at::Tensor *aclInput = &tensor; + if (aclInput->scalar_type() == at::ScalarType::Double) { + double value = *(double *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Long) { + int64_t value = *(int64_t *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Float) { + float value = *(float *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Int) { + int value = *(int *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Half) { + c10::Half value = *(c10::Half *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Bool) { + int8_t value = *(int8_t *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::ComplexDouble) { + c10::complex value = *(c10::complex *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::ComplexFloat) { + c10::complex value = *(c10::complex *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::BFloat16) { + c10::BFloat16 value = *(c10::BFloat16 *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } + return expScalar; +} + +inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) { + at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory(); + int deviceIndex = 0; + return cpuPinMemTensor.to(c10::Device(DEVICE_TYPE, 
deviceIndex), + cpuPinMemTensor.scalar_type(), true, true); +} + +inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar, + at::ScalarType scalar_data_type) { + return CopyTensorHostToDevice( + scalar_to_tensor(cpu_scalar).to(scalar_data_type)); +} + +inline aclTensor *ConvertType(const at::Tensor &at_tensor) { + static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor); + if (aclCreateTensor == nullptr) { + return nullptr; + } + + if (!at_tensor.defined()) { + return nullptr; + } + at::ScalarType scalar_data_type = at_tensor.scalar_type(); + aclDataType acl_data_type = + kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)]; + TORCH_CHECK( + acl_data_type != ACL_DT_UNDEFINED, + std::string(c10::toString(scalar_data_type)) + " has not been supported") + c10::SmallVector storageDims; + // if acl_data_type is ACL_STRING, storageDims is empty. + auto itemsize = at_tensor.itemsize(); + TORCH_CHECK(itemsize != 0, "When ConvertType, tensor item size cannot be zero."); + + const auto dimNum = at_tensor.sizes().size(); + aclFormat format = ACL_FORMAT_ND; + if (!IsOpInputBaseFormat(at_tensor)) { + format = vllm_ascend::NPUBridge::GetNpuStorageImpl(at_tensor)->npu_desc_.npu_format_; + if (acl_data_type != ACL_STRING) { + storageDims = vllm_ascend::NPUBridge::GetNpuStorageImpl(at_tensor)->npu_desc_.storage_sizes_; + } + } else { + switch (dimNum) { + case 3: + format = ACL_FORMAT_NCL; + break; + case 4: + format = ACL_FORMAT_NCHW; + break; + case 5: + format = ACL_FORMAT_NCDHW; + break; + default: + format = ACL_FORMAT_ND; + } + if (acl_data_type != ACL_STRING) { + storageDims.push_back(at_tensor.storage().nbytes() / itemsize); + } + } + + if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) { + c10::Scalar expScalar = ConvertTensorToScalar(at_tensor); + at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type); + return aclCreateTensor(aclInput.sizes().data(), aclInput.sizes().size(), + acl_data_type, aclInput.strides().data(), + aclInput.storage_offset(), format, + storageDims.data(), storageDims.size(), + const_cast(aclInput.storage().data())); + } + + auto acl_tensor = aclCreateTensor( + at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type, + at_tensor.strides().data(), at_tensor.storage_offset(), format, + storageDims.data(), storageDims.size(), + const_cast(at_tensor.storage().data())); + return acl_tensor; +} + +inline aclScalar *ConvertType(const at::Scalar &at_scalar) { + static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar); + if (aclCreateScalar == nullptr) { + return nullptr; + } + + at::ScalarType scalar_data_type = at_scalar.type(); + aclDataType acl_data_type = + kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)]; + TORCH_CHECK( + acl_data_type != ACL_DT_UNDEFINED, + std::string(c10::toString(scalar_data_type)) + " has not been supported") + aclScalar *acl_scalar = nullptr; + switch (scalar_data_type) { + case at::ScalarType::Double: { + double value = at_scalar.toDouble(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::Long: { + int64_t value = at_scalar.toLong(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::Bool: { + bool value = at_scalar.toBool(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + case at::ScalarType::ComplexDouble: { + auto value = at_scalar.toComplexDouble(); + acl_scalar = aclCreateScalar(&value, acl_data_type); + break; + } + default: + acl_scalar = nullptr; + 
break; + } + return acl_scalar; +} + +inline aclIntArray *ConvertType(const at::IntArrayRef &at_array) { + static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray); + if (aclCreateIntArray == nullptr) { + return nullptr; + } + auto array = aclCreateIntArray(at_array.data(), at_array.size()); + return array; +} + +template <std::size_t N> +inline aclBoolArray *ConvertType(const std::array<bool, N> &value) { + static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray); + if (aclCreateBoolArray == nullptr) { + return nullptr; + } + + auto array = aclCreateBoolArray(value.data(), value.size()); + return array; +} + +inline aclBoolArray *ConvertType(const at::ArrayRef<bool> &value) { + static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray); + if (aclCreateBoolArray == nullptr) { + return nullptr; + } + + auto array = aclCreateBoolArray(value.data(), value.size()); + return array; +} + +inline aclTensorList *ConvertType(const at::TensorList &at_tensor_list) { + static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList); + if (aclCreateTensorList == nullptr) { + return nullptr; + } + + std::vector<const aclTensor *> tensor_list(at_tensor_list.size()); + for (size_t i = 0; i < at_tensor_list.size(); i++) { + tensor_list[i] = ConvertType(at_tensor_list[i]); + } + auto acl_tensor_list = + aclCreateTensorList(tensor_list.data(), tensor_list.size()); + return acl_tensor_list; +} + +inline aclTensor *ConvertType(const c10::optional<at::Tensor> &opt_tensor) { + if (opt_tensor.has_value() && opt_tensor.value().defined()) { + return ConvertType(opt_tensor.value()); + } + return nullptr; +} + +inline aclIntArray *ConvertType( + const c10::optional<at::IntArrayRef> &opt_array) { + if (opt_array.has_value()) { + return ConvertType(opt_array.value()); + } + return nullptr; +} + +inline aclScalar *ConvertType(const c10::optional<at::Scalar> &opt_scalar) { + if (opt_scalar.has_value()) { + return ConvertType(opt_scalar.value()); + } + return nullptr; +} + +inline aclDataType ConvertType(const at::ScalarType scalarType) { + return kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalarType)]; +} + +template <typename T> +T ConvertType(T value) { + return value; +} + +template <typename Tuple, size_t... I> +auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr, + std::index_sequence<I...>) { + typedef int (*OpApiFunc)( + typename std::decay<decltype(std::get<I>(params))>::type...); + auto func = reinterpret_cast<OpApiFunc>(opApiAddr); + return func; +} + +template <typename Tuple> +auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr) { + static constexpr auto size = std::tuple_size<Tuple>::value; + return ConvertToOpApiFunc(params, opApiAddr, + std::make_index_sequence<size>{}); +} + +inline void Release(aclTensor *p) { + static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor); + if (aclDestroyTensor == nullptr) { + return; + } + aclDestroyTensor(p); +} + +inline void Release(aclScalar *p) { + static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar); + if (aclDestroyScalar == nullptr) { + return; + } + aclDestroyScalar(p); +} + +inline void Release(aclIntArray *p) { + static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray); + if (aclDestroyIntArray == nullptr) { + return; + } + + aclDestroyIntArray(p); +} + +inline void Release(aclBoolArray *p) { + static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray); + if (aclDestroyBoolArray == nullptr) { + return; + } + + aclDestroyBoolArray(p); +} + +inline void Release(aclTensorList *p) { + static const auto aclDestroyTensorList = + GET_OP_API_FUNC(aclDestroyTensorList); + if (aclDestroyTensorList == nullptr) { + return; + } + + 
aclDestroyTensorList(p); +} + +template <typename T> +void Release(T value) { + (void)value; +} + +template <typename Tuple, size_t... I> +void CallRelease(Tuple t, std::index_sequence<I...>) { + (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...}; +} + +template <typename Tuple> +void ReleaseConvertTypes(Tuple &t) { + static constexpr auto size = std::tuple_size<Tuple>::value; + CallRelease(t, std::make_index_sequence<size>{}); +} + +template <typename... Ts> +constexpr auto ConvertTypes(Ts &... args) { + return std::make_tuple(ConvertType(args)...); +} + +template <typename Function, typename Tuple, size_t... I> +auto call(Function f, Tuple t, std::index_sequence<I...>) { + return f(std::get<I>(t)...); +} + +template <typename Function, typename Tuple> +auto call(Function f, Tuple t) { + static constexpr auto size = std::tuple_size<Tuple>::value; + return call(f, t, std::make_index_sequence<size>{}); +} + +template <std::size_t N> +void AddParamToBuf(const std::array<bool, N> &value) { + MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool)); +} + +template <typename T> +void AddParamToBuf(const T &value) { + MEMCPY_TO_BUF(&value, sizeof(T)); +} + +void AddParamToBuf(const at::Tensor &); +void AddParamToBuf(const at::Scalar &); +void AddParamToBuf(const at::IntArrayRef &); +void AddParamToBuf(const at::ArrayRef<bool> &); +void AddParamToBuf(const at::TensorList &); +void AddParamToBuf(const c10::optional<at::Tensor> &); +void AddParamToBuf(const c10::optional<at::IntArrayRef> &); +void AddParamToBuf(const c10::optional<at::Scalar> &); +void AddParamToBuf(const at::ScalarType); +void AddParamToBuf(const string &); +void AddParamToBuf(); + +template <typename T, typename... Args> +void AddParamToBuf(const T &arg, Args &... args) { + AddParamToBuf(arg); + AddParamToBuf(args...); +} + +uint64_t CalcHashId(); +typedef int (*InitHugeMemThreadLocal)(void *, bool); +typedef void (*UnInitHugeMemThreadLocal)(void *, bool); +typedef void (*ReleaseHugeMem)(void *, bool); + +#define EXEC_NPU_CMD(aclnn_api, ...) \ + do { \ + static const auto getWorkspaceSizeFuncAddr = \ + GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \ + static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \ + static const auto initMemAddr = \ + GetOpApiFuncAddr("InitHugeMemThreadLocal"); \ + static const auto unInitMemAddr = \ + GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \ + static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \ + TORCH_CHECK( \ + getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, \ + #aclnn_api, " or ", #aclnn_api "GetWorkspaceSize", " not in ", \ + GetOpApiLibName(), ", or ", GetOpApiLibName(), "not found."); \ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ + uint64_t workspace_size = 0; \ + uint64_t *workspace_size_addr = &workspace_size; \ + aclOpExecutor *executor = nullptr; \ + aclOpExecutor **executor_addr = &executor; \ + InitHugeMemThreadLocal initMemFunc = \ + reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr); \ + UnInitHugeMemThreadLocal unInitMemFunc = \ + reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr); \ + if (initMemFunc) { \ + initMemFunc(nullptr, false); \ + } \ + auto converted_params = \ + ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \ + static auto getWorkspaceSizeFunc = \ + ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \ + auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ + TORCH_CHECK(workspace_status == 0, \ + "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ + void *workspace_addr = nullptr; \ + if (workspace_size != 0) { \ + at::TensorOptions options = \ + at::TensorOptions(torch_npu::utils::get_npu_device_type()); \ + auto workspace_tensor = \ + at::empty({workspace_size}, options.dtype(kByte)); \ + workspace_addr = const_cast<void *>(workspace_tensor.storage().data()); \ + } \ + auto acl_call = 
[converted_params, workspace_addr, workspace_size, \ + acl_stream, executor]() -> int { \ + typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, \ + const aclrtStream); \ + OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr); \ + auto api_ret = \ + opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \ + TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", \ + aclGetRecentErrMsg()); \ + ReleaseConvertTypes(converted_params); \ + ReleaseHugeMem releaseMemFunc = \ + reinterpret_cast<ReleaseHugeMem>(releaseMemAddr); \ + if (releaseMemFunc) { \ + releaseMemFunc(nullptr, false); \ + } \ + return api_ret; \ + }; \ + at_npu::native::OpCommand cmd; \ + cmd.Name(#aclnn_api); \ + cmd.SetCustomHandler(acl_call); \ + cmd.Run(); \ + if (unInitMemFunc) { \ + unInitMemFunc(nullptr, false); \ + } \ + } while (false) + +#endif diff --git a/csrc/build.sh b/csrc/build.sh new file mode 100644 index 00000000000..76efeaaa072 --- /dev/null +++ b/csrc/build.sh @@ -0,0 +1,189 @@ +#!/bin/bash +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +set -e + +CURRENT_DIR=$(dirname $(readlink -f ${BASH_SOURCE[0]})) +BUILD_DIR=${CURRENT_DIR}/build +OUTPUT_DIR=${CURRENT_DIR}/output +USER_ID=$(id -u) +PARENT_JOB="false" +CHECK_COMPATIBLE="true" +ASAN="false" +COV="false" +VERBOSE="false" + +if [ "${USER_ID}" != "0" ]; then + DEFAULT_TOOLKIT_INSTALL_DIR="${HOME}/Ascend/ascend-toolkit/latest" + DEFAULT_INSTALL_DIR="${HOME}/Ascend/latest" +else + DEFAULT_TOOLKIT_INSTALL_DIR="/usr/local/Ascend/ascend-toolkit/latest" + DEFAULT_INSTALL_DIR="/usr/local/Ascend/latest" +fi + +CUSTOM_OPTION="-DBUILD_OPEN_PROJECT=ON" + +function help_info() { + echo "Usage: $0 [options]" + echo "Options:" + echo + echo "-h|--help Displays help message." + echo + echo "-n|--op-name Specifies the compiled operator. If there are multiple values, separate them with semicolons and use quotation marks. The default is all." + echo " For example: -n \"flash_attention_score\" or -n \"flash_attention_score;flash_attention_score_grad\"" + echo + echo "-c|--compute-unit Specifies the chip type. If there are multiple values, separate them with semicolons and use quotation marks. The default is ascend910b." + echo " For example: -c \"ascend910b\" or -c \"ascend910b;ascend310p\"" + echo + echo "--cov Compiles with cov." + echo + echo "--verbose Displays more compilation information." + echo +} + +function log() { + local current_time=`date +"%Y-%m-%d %H:%M:%S"` + echo "[$current_time] "$1 +} + +function set_env() +{ + source $ASCEND_CANN_PACKAGE_PATH/bin/setenv.bash || echo "0" + + export BISHENG_REAL_PATH=$(which bisheng || true) + + if [ -z "${BISHENG_REAL_PATH}" ];then + log "Error: bisheng compilation tool not found, Please check whether the cann package or environment variables are set." 
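+        # Quick sanity check (illustrative only; reuses the same path sourced above):
+        #   source ${ASCEND_CANN_PACKAGE_PATH}/bin/setenv.bash && which bisheng
+        # If bisheng still cannot be found, the CANN package at that path is likely incomplete.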
+ exit 1 + fi +} + +function clean() +{ + if [ -n "${BUILD_DIR}" ];then + rm -rf ${BUILD_DIR} + fi + mkdir -p ${BUILD_DIR} ${OUTPUT_DIR} +} + +function cmake_config() +{ + local extra_option="$1" + log "Info: cmake config ${CUSTOM_OPTION} ${extra_option} ." + cmake .. ${CUSTOM_OPTION} ${extra_option} +} + +function build() +{ + local target="$1" + if [ "${VERBOSE}" == "true" ];then + local option="--verbose" + fi + cmake --build . --target ${target} ${JOB_NUM} ${option} +} + +function gen_bisheng(){ + local ccache_program=$1 + local gen_bisheng_dir=${BUILD_DIR}/gen_bisheng_dir + + if [ ! -d "${gen_bisheng_dir}" ];then + mkdir -p ${gen_bisheng_dir} + fi + + pushd ${gen_bisheng_dir} + $(> bisheng) + echo "#!/bin/bash" >> bisheng + echo "ccache_args=""\"""${ccache_program} ${BISHENG_REAL_PATH}""\"" >> bisheng + echo "args=""$""@" >> bisheng + + if [ "${VERBOSE}" == "true" ];then + echo "echo ""\"""$""{ccache_args} ""$""args""\"" >> bisheng + fi + + echo "eval ""\"""$""{ccache_args} ""$""args""\"" >> bisheng + chmod +x bisheng + + export PATH=${gen_bisheng_dir}:$PATH + popd +} + +function build_package(){ + build package +} + +function build_host(){ + build_package +} + +function build_kernel(){ + build ops_kernel +} + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + help_info + exit + ;; + -n|--op-name) + ascend_op_name="$2" + shift 2 + ;; + -c|--compute-unit) + ascend_compute_unit="$2" + shift 2 + ;; + *) + help_info + exit 1 + ;; + esac +done + +if [ -n "${ascend_compute_unit}" ];then + CUSTOM_OPTION="${CUSTOM_OPTION} -DASCEND_COMPUTE_UNIT=${ascend_compute_unit}" +fi + +if [ -n "${ascend_op_name}" ];then + CUSTOM_OPTION="${CUSTOM_OPTION} -DASCEND_OP_NAME=${ascend_op_name}" +fi + +if [ -n "${ASCEND_HOME_PATH}" ];then + ASCEND_CANN_PACKAGE_PATH=${ASCEND_HOME_PATH} +elif [ -n "${ASCEND_OPP_PATH}" ];then + ASCEND_CANN_PACKAGE_PATH=$(dirname ${ASCEND_OPP_PATH}) +elif [ -d "${DEFAULT_TOOLKIT_INSTALL_DIR}" ];then + ASCEND_CANN_PACKAGE_PATH=${DEFAULT_TOOLKIT_INSTALL_DIR} +elif [ -d "${DEFAULT_INSTALL_DIR}" ];then + ASCEND_CANN_PACKAGE_PATH=${DEFAULT_INSTALL_DIR} +else + log "Error: Please set the toolkit package installation directory through parameter -p|--package-path." 
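+    # For example (default location assumed, matching DEFAULT_TOOLKIT_INSTALL_DIR above):
+    #   export ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest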
+ exit 1 +fi + +if [ "${PARENT_JOB}" == "false" ];then + CPU_NUM=$(($(cat /proc/cpuinfo | grep "^processor" | wc -l)*2)) + JOB_NUM="-j${CPU_NUM}" +fi + +CUSTOM_OPTION="${CUSTOM_OPTION} -DCUSTOM_ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH} -DCHECK_COMPATIBLE=${CHECK_COMPATIBLE}" + +set_env +clean + +ccache_system=$(which ccache || true) +if [ -n "${ccache_system}" ];then + CUSTOM_OPTION="${CUSTOM_OPTION} -DENABLE_CCACHE=ON -DCUSTOM_CCACHE=${ccache_system}" + gen_bisheng ${ccache_system} +fi + +cd ${BUILD_DIR} +cmake_config +build_package diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh new file mode 100644 index 00000000000..9dba287e3ae --- /dev/null +++ b/csrc/build_aclnn.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +ROOT_DIR=$1 +SOC_VERSION=$2 + +if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then + # ASCEND310P series + # currently, no custom aclnn ops for ASCEND310 series + # CUSTOM_OPS="" + # SOC_ARG="ascend310p" + exit 0 +elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then + # ASCEND910B (A2) series + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + SOC_ARG="ascend910b" +elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then + # ASCEND910C (A3) series + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + SOC_ARG="ascend910_93" +else + # others + # currently, no custom aclnn ops for other series + exit 0 +fi + +# build custom ops +cd csrc +rm -rf build output +echo "building custom ops $CUSTOM_OPS for $SOC_VERSION" +bash build.sh -n $CUSTOM_OPS -c $SOC_ARG + +# install custom ops to vllm_ascend/_cann_ops_custom +./output/CANN-custom_ops*.run --install-path=$ROOT_DIR/vllm_ascend/_cann_ops_custom +source $ROOT_DIR/vllm_ascend/_cann_ops_custom/vendors/customize/bin/set_env.bash diff --git a/csrc/cmake/config.cmake b/csrc/cmake/config.cmake new file mode 100644 index 00000000000..38553f824e1 --- /dev/null +++ b/csrc/cmake/config.cmake @@ -0,0 +1,235 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +######################################################################################################################## +# Environment Check +######################################################################################################################## + +# Python3 +find_package(Python3) +if ((NOT Python3_FOUND) OR (${Python3_EXECUTABLE} STREQUAL "")) + message(FATAL_ERROR "Can't find python3.") +endif () +set(HI_PYTHON "${Python3_EXECUTABLE}" CACHE STRING "python executor") + +# Get the base CANN path +if (CUSTOM_ASCEND_CANN_PACKAGE_PATH) + set(ASCEND_CANN_PACKAGE_PATH ${CUSTOM_ASCEND_CANN_PACKAGE_PATH}) +elseif (DEFINED ENV{ASCEND_HOME_PATH}) + set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH}) +elseif (DEFINED ENV{ASCEND_OPP_PATH}) + get_filename_component(ASCEND_CANN_PACKAGE_PATH "$ENV{ASCEND_OPP_PATH}/.." 
ABSOLUTE) +else() + set(ASCEND_CANN_PACKAGE_PATH "/usr/local/Ascend/latest") +endif () +message(STATUS "ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH}") + +######################################################################################################################## +# Common Configuration +######################################################################################################################## + +# Switches +option(PREPARE_BUILD "Prepare build." OFF) +option(ENABLE_OPS_HOST "Build ops host." ON) +option(ENABLE_OPS_KERNEL "Build ops kernel." ON) +if (TESTS_EXAMPLE_OPS_TEST OR TESTS_UT_OPS_TEST) + set(ENABLE_OPS_KERNEL OFF) +endif () +set(OP_DEBUG_CONFIG "false" CACHE STRING "op debug config") + +# Path configuration +# Source tree related paths +get_filename_component(OPS_ADV_DIR "${CMAKE_CURRENT_SOURCE_DIR}" REALPATH) +get_filename_component(OPS_ADV_CMAKE_DIR "${OPS_ADV_DIR}/cmake" REALPATH) +get_filename_component(OPS_ADV_UTILS_KERNEL_INC "${OPS_ADV_DIR}/utils/inc/kernel" REALPATH) + + +# Build tree related paths +set(ASCEND_IMPL_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/impl CACHE STRING "ascend impl output directories") +set(ASCEND_BINARY_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/binary CACHE STRING "ascend binary output directories") +set(ASCEND_AUTOGEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/autogen CACHE STRING "Auto generate file directories") +set(ASCEND_CUSTOM_OPTIONS ${ASCEND_AUTOGEN_DIR}/custom_compile_options.ini) +set(ASCEND_CUSTOM_TILING_KEYS ${ASCEND_AUTOGEN_DIR}/custom_tiling_keys.ini) +set(ASCEND_CUSTOM_OPC_OPTIONS ${ASCEND_AUTOGEN_DIR}/custom_opc_options.ini) +set(OP_BUILD_TOOL ${ASCEND_CANN_PACKAGE_PATH}/tools/opbuild/op_build CACHE STRING "op_build tool") +file(MAKE_DIRECTORY ${ASCEND_AUTOGEN_DIR}) +file(REMOVE ${ASCEND_CUSTOM_OPTIONS}) +file(TOUCH ${ASCEND_CUSTOM_OPTIONS}) +file(REMOVE ${ASCEND_CUSTOM_TILING_KEYS}) +file(TOUCH ${ASCEND_CUSTOM_TILING_KEYS}) +file(REMOVE ${ASCEND_CUSTOM_OPC_OPTIONS}) +file(TOUCH ${ASCEND_CUSTOM_OPC_OPTIONS}) +if (BUILD_OPEN_PROJECT) + if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project/cmake) + set(ASCEND_PROJECT_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project) + else() + set(ASCEND_PROJECT_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/op_project_templates/ascendc/customize) + endif() + set(ASCEND_CMAKE_DIR ${ASCEND_PROJECT_DIR}/cmake CACHE STRING "ascend project cmake") + set(IMPL_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/${VENDOR_NAME}_impl) + set(IMPL_DYNAMIC_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/${VENDOR_NAME}_impl/dynamic) + set(ACLNN_INC_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_api/include) +else() + set(ASCEND_CMAKE_DIR ${TOP_DIR}/asl/ops/cann/ops/built-in/ascendc/samples/customize/cmake CACHE STRING "ascend project cmake") + set(IMPL_INSTALL_DIR lib/ascendc/impl) + set(IMPL_DYNAMIC_INSTALL_DIR lib/ascendc/impl/dynamic) + set(ACLNN_INC_INSTALL_DIR lib/include) + set(OPS_STATIC_TYPES infer train) + set(OPS_STATIC_SCRIPT ${TOP_DIR}/asl/ops/cann/ops/built-in/kernel/binary_script/build_opp_kernel_static.py) +endif () +set(ASCENDC_CMAKE_UTIL_DIR ${ASCEND_CMAKE_DIR}/util) +set(CUSTOM_DIR ${CMAKE_BINARY_DIR}/custom) +set(TILING_CUSTOM_DIR ${CUSTOM_DIR}/op_impl/ai_core/tbe/op_tiling) +set(TILING_CUSTOM_FILE ${TILING_CUSTOM_DIR}/liboptiling.so) + +# Temporary adaptation for ascendc changes, to be removed after switching to the new version of ascendc +if(EXISTS ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_gen_options.py) + set(ADD_OPS_COMPILE_OPTION_V2 ON) +else() + 
set(ADD_OPS_COMPILE_OPTION_V2 OFF) +endif() + +######################################################################################################################## +# CMake Options, Default Parameters Setting +# Configure CMake options and default parameters according to the CMake build process +# CMake build process: 1) Configuration phase; 2) Build phase; 3) Installation phase; +######################################################################################################################## +if (BUILD_OPEN_PROJECT) + # Build phase + # Build type + # The Generator in CMake is a tool used to generate native build systems. Generally divided into two types: + # 1. Single-configuration generator: + # In the configuration phase, only one build type is allowed to be specified through the variable CMAKE_BUILD_TYPE; + # In the build phase, the build type cannot be changed, and only the build type specified through the variable CMAKE_BUILD_TYPE in the configuration phase can be used; + # Common generators of this type include: Ninja, Unix Makefiles + # 2. Multi-configuration generator: + # In the configuration phase, only the list of build types available in the build phase is specified through the variable CMAKE_CONFIGURATION_TYPES; + # In the build phase, the specific build type of the build phase is specified through the "--config" parameter; + # Common generators of this type include: Xcode, Visual Studio + # Therefore: + # 1. In the single-configuration generator scenario, if the build type (CMAKE_BUILD_TYPE) is not specified, the default is Debug; + # 2. In the multi-configuration generator scenario, if the build types available in the build phase (CMAKE_CONFIGURATION_TYPES) are not specified, + # it is defaulted to the full set of build types allowed by CMake [Debug;Release;MinSizeRel;RelWithDebInfo] + get_property(GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) + if (GENERATOR_IS_MULTI_CONFIG) + if (NOT CMAKE_CONFIGURATION_TYPES) + set(CMAKE_CONFIGURATION_TYPES "Debug;Release;MinSizeRel;RelWithDebInfo" CACHE STRING "Configuration Build type" FORCE) + endif () + else () + if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type(default Debug)" FORCE) + endif () + endif () + + # Build phase + # Executable runtime library file search path RPATH + # Do not skip RPATH in UTest and Example scenarios + if (TESTS_UT_OPS_TEST OR TESTS_EXAMPLE_OPS_TEST) + set(CMAKE_SKIP_RPATH FALSE) + else () + set(CMAKE_SKIP_RPATH TRUE) + endif () + + # Build phase + # CCACHE configuration + if (ENABLE_CCACHE) + if (CUSTOM_CCACHE) + set(CCACHE_PROGRAM ${CUSTOM_CCACHE}) + else() + find_program(CCACHE_PROGRAM ccache) + endif () + if (CCACHE_PROGRAM) + set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE PATH "C cache Compiler") + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM} CACHE PATH "CXX cache Compiler") + endif () + endif () + + # Installation phase + # Installation path + # When CMAKE_INSTALL_PREFIX is not explicitly set (i.e., CMAKE_INSTALL_PREFIX takes the default value), + # correct its value to be level with the build tree root directory CMAKE_CURRENT_BINARY_DIR + if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + get_filename_component(_Install_Path_Prefix "${CMAKE_CURRENT_BINARY_DIR}/../output" REALPATH) + set(CMAKE_INSTALL_PREFIX "${_Install_Path_Prefix}" CACHE STRING "Install path" FORCE) + endif () +endif () + +######################################################################################################################## +# Public 
Compilation Parameters +######################################################################################################################## +list(TRANSFORM ASCEND_COMPUTE_UNIT TOLOWER) +if (BUILD_OPEN_PROJECT) + message(STATUS "ENABLE_CCACHE=${ENABLE_CCACHE}, CUSTOM_CCACHE=${CUSTOM_CCACHE}") + message(STATUS "CCACHE_PROGRAM=${CCACHE_PROGRAM}") + message(STATUS "ASCEND_COMPUTE_UNIT=${ASCEND_COMPUTE_UNIT}") + message(STATUS "ASCEND_OP_NAME=${ASCEND_OP_NAME}") + message(STATUS "TILING_KEY=${TILING_KEY}") + message(STATUS "TESTS_UT_OPS_TEST=${TESTS_UT_OPS_TEST}") + message(STATUS "TESTS_EXAMPLE_OPS_TEST=${TESTS_EXAMPLE_OPS_TEST}") +endif () + +######################################################################################################################## +# Preprocessing +######################################################################################################################## +if (BUILD_OPEN_PROJECT) + if (NOT PREPARE_BUILD AND ENABLE_OPS_KERNEL) + if (TILING_KEY) + string(REPLACE ";" "::" EP_TILING_KEY "${TILING_KEY}") + else() + set(EP_TILING_KEY FALSE) + endif () + + if (OPS_COMPILE_OPTIONS) + string(REPLACE ";" "::" EP_OPS_COMPILE_OPTIONS "${OPS_COMPILE_OPTIONS}") + else() + set(EP_OPS_COMPILE_OPTIONS FALSE) + endif () + + string(REPLACE ";" "::" EP_ASCEND_COMPUTE_UNIT "${ASCEND_COMPUTE_UNIT}") + + execute_process(COMMAND bash ${CMAKE_CURRENT_SOURCE_DIR}/cmake/scripts/prepare.sh + -s ${CMAKE_CURRENT_SOURCE_DIR} + -b ${CMAKE_CURRENT_BINARY_DIR}/prepare_build + -p ${ASCEND_CANN_PACKAGE_PATH} + --autogen-dir ${ASCEND_AUTOGEN_DIR} + --build-open-project ${BUILD_OPEN_PROJECT} + --binary-out-dir ${ASCEND_BINARY_OUT_DIR} + --impl-out-dir ${ASCEND_IMPL_OUT_DIR} + --op-build-tool ${OP_BUILD_TOOL} + --ascend-cmake-dir ${ASCEND_CMAKE_DIR} + --tiling-key ${EP_TILING_KEY} + --ops-compile-options ${EP_OPS_COMPILE_OPTIONS} + --check-compatible ${CHECK_COMPATIBLE} + --ascend-compute_unit ${EP_ASCEND_COMPUTE_UNIT} + --op_debug_config ${OP_DEBUG_CONFIG} + --ascend-op-name "${ASCEND_OP_NAME}" + RESULT_VARIABLE result + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE PREPARE_BUILD_OUTPUT_VARIABLE) + if (result) + message(FATAL_ERROR "Error: ops prepare build failed.") + endif () + + file(REMOVE ${ASCEND_CUSTOM_OPTIONS}) + file(TOUCH ${ASCEND_CUSTOM_OPTIONS}) + file(REMOVE ${ASCEND_CUSTOM_TILING_KEYS}) + file(TOUCH ${ASCEND_CUSTOM_TILING_KEYS}) + file(REMOVE ${ASCEND_CUSTOM_OPC_OPTIONS}) + file(TOUCH ${ASCEND_CUSTOM_OPC_OPTIONS}) + endif () +endif () + +######################################################################################################################## +# Other Configuration +######################################################################################################################## +if (BUILD_OPEN_PROJECT) + if (TESTS_UT_OPS_TEST) + include(${OPS_ADV_CMAKE_DIR}/config_utest.cmake) + endif () +endif () diff --git a/csrc/cmake/func.cmake b/csrc/cmake/func.cmake new file mode 100644 index 00000000000..f2bebf75639 --- /dev/null +++ b/csrc/cmake/func.cmake @@ -0,0 +1,609 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +function(add_target_source) + cmake_parse_arguments(ADD "" "BASE_TARGET;SRC_DIR" "TARGET_NAME" ${ARGN}) + + get_target_property(all_srcs ${ADD_BASE_TARGET} SOURCES) + set(add_srcs) + foreach(_src ${all_srcs}) + string(REGEX MATCH "^${ADD_SRC_DIR}" is_match "${_src}") + if (is_match) + list(APPEND add_srcs ${_src}) + endif () + endforeach() + + get_target_property(all_includes ${ADD_BASE_TARGET} INCLUDE_DIRECTORIES) + set(add_includes) + foreach(_include ${all_includes}) + string(REGEX MATCH "^${ADD_SRC_DIR}" is_match "${_include}") + if (is_match) + list(APPEND add_includes ${_include}) + endif () + endforeach() + + foreach(_target_name ${ADD_TARGET_NAME}) + target_sources(${_target_name} PRIVATE + ${add_srcs} + ) + + target_include_directories(${_target_name} PRIVATE + ${add_includes} + ) + endforeach() +endfunction() + +function(op_add_subdirectory OP_LIST OP_DIR_LIST) + set(_OP_LIST) + set(_OP_DIR_LIST) + + file(GLOB OP_HOST_CMAKE_FILES "${CMAKE_CURRENT_SOURCE_DIR}/**/op_host/CMakeLists.txt") + + foreach(OP_CMAKE_FILE ${OP_HOST_CMAKE_FILES}) + get_filename_component(OP_HOST_DIR "${OP_CMAKE_FILE}" DIRECTORY) + get_filename_component(OP_DIR "${OP_HOST_DIR}" DIRECTORY) + get_filename_component(OP_NAME "${OP_DIR}" NAME) + + if (NOT BUILD_OPEN_PROJECT) + if (EXISTS ${TOP_DIR}/asl/ops/cann/ops/built-in/tbe/impl/ascendc/${OP_NAME}) + continue() + endif () + endif () + + if (DEFINED ASCEND_OP_NAME AND NOT "${ASCEND_OP_NAME}" STREQUAL "") + if (NOT "${ASCEND_OP_NAME}" STREQUAL "all" AND NOT "${ASCEND_OP_NAME}" STREQUAL "ALL") + if (NOT ${OP_NAME} IN_LIST ASCEND_OP_NAME) + continue() + endif () + endif () + endif () + + list(APPEND _OP_LIST ${OP_NAME}) + list(APPEND _OP_DIR_LIST ${OP_DIR}) + endforeach() + + list(REMOVE_DUPLICATES _OP_LIST) + list(REMOVE_DUPLICATES _OP_DIR_LIST) + list(SORT _OP_LIST) + list(SORT _OP_DIR_LIST) + set(${OP_LIST} ${_OP_LIST} PARENT_SCOPE) + set(${OP_DIR_LIST} ${_OP_DIR_LIST} PARENT_SCOPE) +endfunction() + +function(op_add_depend_directory) + cmake_parse_arguments(DEP "" "OP_DIR_LIST" "OP_LIST" ${ARGN}) + set(_OP_DEPEND_DIR_LIST) + foreach(op_name ${DEP_OP_LIST}) + if (DEFINED ${op_name}_depends) + foreach(depend_info ${${op_name}_depends}) + if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${depend_info}/op_host/CMakeLists.txt) + continue() + endif () + + get_filename_component(_depend_op_name "${depend_info}" NAME) + if (NOT BUILD_OPEN_PROJECT) + if (EXISTS ${TOP_DIR}/asl/ops/cann/ops/built-in/tbe/impl/ascendc/${_depend_op_name}) + continue() + endif () + endif () + + if (NOT ${_depend_op_name} IN_LIST DEP_OP_LIST) + list(APPEND _OP_DEPEND_DIR_LIST ${CMAKE_CURRENT_SOURCE_DIR}/${depend_info}) + endif () + endforeach() + endif() + endforeach() + + list(SORT _OP_DEPEND_DIR_LIST) + set(${DEP_OP_DIR_LIST} ${_OP_DEPEND_DIR_LIST} PARENT_SCOPE) +endfunction() + +function(add_compile_cmd_target) + cmake_parse_arguments(CMD "" "COMPUTE_UNIT" "" ${ARGN}) + + if(ADD_OPS_COMPILE_OPTION_V2) + set(OP_DEBUG_CONFIG_OPTION --opc-config-file ${ASCEND_CUSTOM_OPC_OPTIONS}) + else() + if(OP_DEBUG_CONFIG) + set(OP_DEBUG_CONFIG_OPTION --op-debug-config ${OP_DEBUG_CONFIG}) + 
endif() + set(OP_TILING_KEY_OPTION --tiling-keys ${ASCEND_CUSTOM_TILING_KEYS}) + endif() + + set(_OUT_DIR ${ASCEND_BINARY_OUT_DIR}/${CMD_COMPUTE_UNIT}) + set(GEN_OUT_DIR ${_OUT_DIR}/gen) + set(COMPILE_CMD_TARGET generate_compile_cmd_${CMD_COMPUTE_UNIT}) + add_custom_target(${COMPILE_CMD_TARGET} ALL + COMMAND mkdir -p ${GEN_OUT_DIR} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py + ${base_aclnn_binary_dir}/aic-${CMD_COMPUTE_UNIT}-ops-info.ini + ${GEN_OUT_DIR} + ${CMD_COMPUTE_UNIT} + ${OP_TILING_KEY_OPTION} + ${OP_DEBUG_CONFIG_OPTION} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py + ${base_aclnn_binary_dir}/inner/aic-${CMD_COMPUTE_UNIT}-ops-info.ini + ${GEN_OUT_DIR} + ${CMD_COMPUTE_UNIT} + ${OP_TILING_KEY_OPTION} + ${OP_DEBUG_CONFIG_OPTION} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_bin_param_build.py + ${base_aclnn_binary_dir}/exc/aic-${CMD_COMPUTE_UNIT}-ops-info.ini + ${GEN_OUT_DIR} + ${CMD_COMPUTE_UNIT} + ${OP_TILING_KEY_OPTION} + ${OP_DEBUG_CONFIG_OPTION} + ) + + add_dependencies(${COMPILE_CMD_TARGET} opbuild_gen_default opbuild_gen_inner opbuild_gen_exc) + add_dependencies(generate_compile_cmd ${COMPILE_CMD_TARGET}) +endfunction() + +function(add_ops_info_target) + cmake_parse_arguments(OPINFO "" "COMPUTE_UNIT" "" ${ARGN}) + + set(OPS_INFO_TARGET generate_ops_info_${OPINFO_COMPUTE_UNIT}) + set(OPS_INFO_JSON ${ASCEND_AUTOGEN_DIR}/aic-${OPINFO_COMPUTE_UNIT}-ops-info.json) + set(CUSTOM_OPS_INFO_DIR ${CUSTOM_DIR}/op_impl/ai_core/tbe/config/${OPINFO_COMPUTE_UNIT}) + + set(OPS_INFO_INI ${base_aclnn_binary_dir}/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini) + set(OPS_INFO_INNER_INI ${base_aclnn_binary_dir}/inner/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini) + set(OPS_INFO_EXCLUDE_INI ${base_aclnn_binary_dir}/exc/aic-${OPINFO_COMPUTE_UNIT}-ops-info.ini) + + add_custom_command(OUTPUT ${OPS_INFO_JSON} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/parse_ini_to_json.py + ${OPS_INFO_INI} + ${OPS_INFO_INNER_INI} + ${OPS_INFO_EXCLUDE_INI} + ${OPS_INFO_JSON} + COMMAND mkdir -p ${CUSTOM_OPS_INFO_DIR} + COMMAND cp -f ${OPS_INFO_JSON} ${CUSTOM_OPS_INFO_DIR} + ) + + add_custom_target(${OPS_INFO_TARGET} ALL + DEPENDS ${OPS_INFO_JSON} + ) + + add_dependencies(${OPS_INFO_TARGET} opbuild_gen_default opbuild_gen_inner opbuild_gen_exc) + add_dependencies(generate_ops_info ${OPS_INFO_TARGET}) + + install(FILES ${OPS_INFO_JSON} + DESTINATION packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/config/${OPINFO_COMPUTE_UNIT} OPTIONAL + ) +endfunction() + +function(add_ops_compile_options) + cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;OPTIONS" ${ARGN}) + + if(NOT OP_COMPILE_OPTIONS) + return() + endif() + + if(ADD_OPS_COMPILE_OPTION_V2) + execute_process(COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_gen_options.py + ${ASCEND_CUSTOM_OPTIONS} ${OP_COMPILE_OP_NAME} + ${OP_COMPILE_COMPUTE_UNIT} ${OP_COMPILE_OPTIONS} + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE EXEC_INFO + ERROR_VARIABLE EXEC_ERROR) + if (EXEC_RESULT) + message("add ops compile options info: ${EXEC_INFO}") + message("add ops compile options error: ${EXEC_ERROR}") + message(FATAL_ERROR "Error: add ops compile options failed!") + endif () + else() + file(APPEND ${ASCEND_CUSTOM_OPTIONS} + "${OP_COMPILE_OP_NAME},${OP_COMPILE_COMPUTE_UNIT},${OP_COMPILE_OPTIONS}\n" + ) + endif() +endfunction() + +function(add_ops_tiling_keys) + cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;TILING_KEYS" ${ARGN}) + + if(NOT OP_COMPILE_TILING_KEYS) + return() + endif() + + 
if(ADD_OPS_COMPILE_OPTION_V2) + list(JOIN OP_COMPILE_TILING_KEYS "," STRING_TILING_KEYS) + add_ops_compile_options( + OP_NAME ${OP_COMPILE_OP_NAME} + OPTIONS --tiling_key=${STRING_TILING_KEYS} + ) + else() + file(APPEND ${ASCEND_CUSTOM_TILING_KEYS} + "${OP_COMPILE_OP_NAME},${OP_COMPILE_COMPUTE_UNIT},${OP_COMPILE_TILING_KEYS}\n" + ) + endif() +endfunction() + +function(add_opc_config) + cmake_parse_arguments(OP_COMPILE "" "OP_NAME" "COMPUTE_UNIT;CONFIG" ${ARGN}) + + if(NOT ADD_OPS_COMPILE_OPTION_V2) + return() + endif() + + if(NOT OP_COMPILE_CONFIG) + return() + endif() + + string(REPLACE "," ";" OP_COMPILE_CONFIG_LIST "${OP_COMPILE_CONFIG}") + + set(_OPC_CONFIG) + + foreach(_option ${OP_COMPILE_CONFIG_LIST}) + if("${_option}" STREQUAL "ccec_g") + list(APPEND _OPC_CONFIG "-g") + elseif("${_option}" STREQUAL "ccec_O0") + list(APPEND _OPC_CONFIG "-O0") + elseif("${_option}" STREQUAL "oom") + list(APPEND _OPC_CONFIG "--oom") + elseif("${_option}" STREQUAL "dump_cce") + list(APPEND _OPC_CONFIG "--save-temp-files") + endif() + endforeach() + + if(_OPC_CONFIG) + add_ops_compile_options( + OP_NAME ${OP_COMPILE_OP_NAME} + OPTIONS ${_OPC_CONFIG} + ) + endif() +endfunction() + +function(add_ops_src_copy) + cmake_parse_arguments(SRC_COPY "" "TARGET_NAME;SRC;DST;BE_RELIED;COMPUTE_UNIT" "" ${ARGN}) + + set(OPS_UTILS_INC_KERNEL_TARGET ops_utils_inc_kernel_${SRC_COPY_COMPUTE_UNIT}) + if (EXISTS ${OPS_ADV_UTILS_KERNEL_INC}) + if (NOT TARGET ${OPS_UTILS_INC_KERNEL_TARGET}) + get_filename_component(_ROOT_OPS_SRC_DIR "${SRC_COPY_DST}" DIRECTORY) + set(OPS_UTILS_INC_KERNEL_DIR ${_ROOT_OPS_SRC_DIR}/ascendc/common) + add_custom_command(OUTPUT ${OPS_UTILS_INC_KERNEL_DIR} + COMMAND mkdir -p ${OPS_UTILS_INC_KERNEL_DIR} + COMMAND cp -rf ${OPS_ADV_UTILS_KERNEL_INC}/*.* ${OPS_UTILS_INC_KERNEL_DIR} + ) + + add_custom_target(${OPS_UTILS_INC_KERNEL_TARGET} + DEPENDS ${OPS_UTILS_INC_KERNEL_DIR} + ) + endif () + endif () + + if (NOT TARGET ${SRC_COPY_TARGET_NAME}) + set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done) + add_custom_command(OUTPUT ${_BUILD_FLAG} + COMMAND mkdir -p ${SRC_COPY_DST} + COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST} + COMMAND touch ${_BUILD_FLAG} + ) + + add_custom_target(${SRC_COPY_TARGET_NAME} + DEPENDS ${_BUILD_FLAG} + ) + endif () + + if (TARGET ${OPS_UTILS_INC_KERNEL_TARGET}) + add_dependencies(${SRC_COPY_TARGET_NAME} ${OPS_UTILS_INC_KERNEL_TARGET}) + endif () + + if (DEFINED SRC_COPY_BE_RELIED) + add_dependencies(${SRC_COPY_BE_RELIED} ${SRC_COPY_TARGET_NAME}) + endif () + +endfunction() + +function(add_bin_compile_target) + cmake_parse_arguments(BINARY "" "COMPUTE_UNIT" "OP_INFO" ${ARGN}) + + set(_INSTALL_DIR packages/vendors/${VENDOR_NAME}/op_impl/ai_core/tbe/kernel) + set(_OUT_DIR ${ASCEND_BINARY_OUT_DIR}/${BINARY_COMPUTE_UNIT}) + + set(BIN_OUT_DIR ${_OUT_DIR}/bin) + set(GEN_OUT_DIR ${_OUT_DIR}/gen) + set(SRC_OUT_DIR ${_OUT_DIR}/src) + file(MAKE_DIRECTORY ${BIN_OUT_DIR}) + + foreach(_op_info ${BINARY_OP_INFO}) + get_filename_component(_op_name "${_op_info}" NAME) + set(${_op_name}_dir ${_op_info}) + endforeach() + + set(_ops_target_list) + set(compile_scripts) + file(GLOB scripts_list ${GEN_OUT_DIR}/*.sh) + list(APPEND compile_scripts ${scripts_list}) + + foreach(bin_script ${compile_scripts}) + get_filename_component(bin_file ${bin_script} NAME_WE) + string(REPLACE "-" ";" bin_sep ${bin_file}) + list(GET bin_sep 0 op_type) + list(GET bin_sep 1 op_file) + list(GET bin_sep 2 op_index) + + if (NOT DEFINED ${op_file}_dir) + continue() + endif () + + if (NOT TARGET 
${op_file}) + add_custom_target(${op_file}) + add_dependencies(ops_kernel ${op_file}) + endif () + + set(OP_TARGET_NAME ${op_file}_${BINARY_COMPUTE_UNIT}) + + if (NOT TARGET ${OP_TARGET_NAME}) + add_custom_target(${OP_TARGET_NAME}) + add_dependencies(${op_file} ${OP_TARGET_NAME}) + list(APPEND _ops_target_list ${OP_TARGET_NAME}) + + set(OP_SRC_OUT_DIR ${SRC_OUT_DIR}/${op_file}) + set(OP_BIN_OUT_DIR ${BIN_OUT_DIR}/${op_file}) + file(MAKE_DIRECTORY ${OP_SRC_OUT_DIR}) + + add_ops_src_copy( + TARGET_NAME + ${OP_TARGET_NAME}_src_copy + SRC + ${${op_file}_dir} + DST + ${OP_SRC_OUT_DIR} + COMPUTE_UNIT + ${BINARY_COMPUTE_UNIT} + ) + + if (DEFINED ${op_file}_depends) + foreach(depend_info ${${op_file}_depends}) + get_filename_component(_depend_op_name "${depend_info}" NAME) + set(_depend_op_target ${_depend_op_name}_${BINARY_COMPUTE_UNIT}_src_copy) + add_ops_src_copy( + TARGET_NAME + ${_depend_op_target} + SRC + ${CMAKE_SOURCE_DIR}/${depend_info} + DST + ${SRC_OUT_DIR}/${_depend_op_name} + COMPUTE_UNIT + ${BINARY_COMPUTE_UNIT} + BE_RELIED + ${OP_TARGET_NAME}_src_copy + ) + endforeach() + endif () + + set(DYNAMIC_PY_FILE ${OP_SRC_OUT_DIR}/${op_type}.py) + add_custom_command(OUTPUT ${DYNAMIC_PY_FILE} + COMMAND cp -rf ${ASCEND_IMPL_OUT_DIR}/dynamic/${op_file}.py ${DYNAMIC_PY_FILE} + ) + + add_custom_target(${OP_TARGET_NAME}_py_copy + DEPENDS ${DYNAMIC_PY_FILE} + ) + + add_custom_command(OUTPUT ${OP_BIN_OUT_DIR} + COMMAND mkdir -p ${OP_BIN_OUT_DIR} + ) + + add_custom_target(${OP_TARGET_NAME}_mkdir + DEPENDS ${OP_BIN_OUT_DIR} + ) + + install(DIRECTORY ${OP_BIN_OUT_DIR} + DESTINATION ${_INSTALL_DIR}/${BINARY_COMPUTE_UNIT} OPTIONAL + ) + + install(FILES ${BIN_OUT_DIR}/${op_file}.json + DESTINATION ${_INSTALL_DIR}/config/${BINARY_COMPUTE_UNIT} OPTIONAL + ) + endif () + + set(_group "1-0") + if (DEFINED ASCEND_OP_NAME AND NOT "${ASCEND_OP_NAME}" STREQUAL "") + if (NOT "${ASCEND_OP_NAME}" STREQUAL "all" AND NOT "${ASCEND_OP_NAME}" STREQUAL "ALL") + if (${op_file} IN_LIST ASCEND_OP_NAME) + list(LENGTH ASCEND_OP_NAME _len) + list(FIND ASCEND_OP_NAME ${op_file} _index) + math(EXPR _next_index "${_index} + 1") + if (${_next_index} LESS ${_len}) + list(GET ASCEND_OP_NAME ${_next_index} _group_str) + set(_regex "^[0-9]+-[0-9]+$") + string(REGEX MATCH "${_regex}" match "${_group_str}") + if (match) + set(_group ${_group_str}) + endif () + endif () + endif () + endif () + endif () + + string(REPLACE "-" ";" _group_sep ${_group}) + + list(GET _group_sep 1 start_index) + set(end_index ${op_index}) + list(GET _group_sep 0 step) + + set(_compile_flag false) + if (${start_index} LESS ${end_index}) + foreach(i RANGE ${start_index} ${end_index} ${step}) + if (${i} EQUAL ${end_index}) + set(_compile_flag true) + break() + endif () + endforeach() + elseif (${start_index} EQUAL ${end_index}) + set(_compile_flag true) + else() + set(_compile_flag false) + endif () + + if (_compile_flag) + set(_BUILD_COMMAND) + set(_BUILD_FLAG ${GEN_OUT_DIR}/${OP_TARGET_NAME}_${op_index}.done) + if (ENABLE_OPS_HOST) + list(APPEND _BUILD_COMMAND export ASCEND_CUSTOM_OPP_PATH=${CUSTOM_DIR} &&) + endif () + list(APPEND _BUILD_COMMAND export HI_PYTHON="python3" &&) + list(APPEND _BUILD_COMMAND export TILINGKEY_PAR_COMPILE=1 &&) + list(APPEND _BUILD_COMMAND export BIN_FILENAME_HASHED=1 &&) + list(APPEND _BUILD_COMMAND bash ${bin_script} ${OP_SRC_OUT_DIR}/${op_type}.py ${OP_BIN_OUT_DIR}) + if(CMAKE_GENERATOR MATCHES "Unix Makefiles") + list(APPEND _BUILD_COMMAND && echo $(MAKE)) + endif() + + add_custom_command(OUTPUT ${_BUILD_FLAG} + COMMAND 
${_BUILD_COMMAND} + COMMAND touch ${_BUILD_FLAG} + WORKING_DIRECTORY ${GEN_OUT_DIR} + ) + + add_custom_target(${OP_TARGET_NAME}_${op_index} + DEPENDS ${_BUILD_FLAG} + ) + + if (ENABLE_OPS_HOST) + add_dependencies(${OP_TARGET_NAME}_${op_index} optiling generate_ops_info) + endif () + add_dependencies(${OP_TARGET_NAME}_${op_index} ${OP_TARGET_NAME}_src_copy ${OP_TARGET_NAME}_py_copy ${OP_TARGET_NAME}_mkdir) + add_dependencies(${OP_TARGET_NAME} ${OP_TARGET_NAME}_${op_index}) + endif () + endforeach() + + if (_ops_target_list) + set(OPS_CONFIG_TARGET ops_config_${BINARY_COMPUTE_UNIT}) + set(BINARY_INFO_CONFIG_FILE ${BIN_OUT_DIR}/binary_info_config.json) + + add_custom_command(OUTPUT ${BINARY_INFO_CONFIG_FILE} + COMMAND ${HI_PYTHON} ${ASCENDC_CMAKE_UTIL_DIR}/ascendc_ops_config.py -p ${BIN_OUT_DIR} -s ${BINARY_COMPUTE_UNIT} + ) + + add_custom_target(${OPS_CONFIG_TARGET} + DEPENDS ${BINARY_INFO_CONFIG_FILE} + ) + + add_dependencies(ops_config ${OPS_CONFIG_TARGET}) + + foreach(_op_target ${_ops_target_list}) + add_dependencies(${OPS_CONFIG_TARGET} ${_op_target}) + endforeach() + + install(FILES ${BINARY_INFO_CONFIG_FILE} + DESTINATION ${_INSTALL_DIR}/config/${BINARY_COMPUTE_UNIT} OPTIONAL + ) + endif () +endfunction() + +function(redefine_file_macro) + cmake_parse_arguments(_FILE "" "" "TARGET_NAME" ${ARGN}) + + foreach(_target_name ${_FILE_TARGET_NAME}) + target_compile_options(${_target_name} PRIVATE + -Wno-builtin-macro-redefined + ) + + get_target_property(_srcs ${_target_name} SOURCES) + + foreach(_src ${_srcs}) + get_filename_component(_src_name "${_src}" NAME) + set_source_files_properties(${_src} + PROPERTIES COMPILE_DEFINITIONS __FILE__="${_src_name}" + ) + endforeach() + endforeach() +endfunction() + +function(add_static_ops) + cmake_parse_arguments(STATIC "" "SRC_DIR" "ACLNN_SRC;ACLNN_INNER_SRC" ${ARGN}) + set(prepare_ops_adv_static_target prepare_ops_adv_static) + set(static_src_temp_dir ${CMAKE_CURRENT_BINARY_DIR}/static_src_temp_dir) + set(modified_files) + foreach(ops_type ${OPS_STATIC_TYPES}) + get_target_property(all_srcs aclnn_ops_${ops_type} SOURCES) + set(add_srcs) + set(generate_aclnn_srcs) + foreach(_src ${all_srcs}) + string(REGEX MATCH "^${STATIC_SRC_DIR}" is_match "${_src}") + if (is_match) + list(APPEND add_srcs ${_src}) + endif () + endforeach() + + foreach(_src ${add_srcs}) + get_filename_component(name_without_ext ${_src} NAME_WE) + string(REGEX REPLACE "^aclnn_" "" _op_name ${name_without_ext}) + + foreach(_aclnn_src ${STATIC_ACLNN_SRC}) + get_filename_component(aclnn_name ${_aclnn_src} NAME_WE) + if("aclnn_${_op_name}" STREQUAL "${aclnn_name}") + list(APPEND generate_aclnn_srcs ${_aclnn_src}) + break() + endif() + endforeach() + + foreach(_aclnn_inner_src ${STATIC_ACLNN_INNER_SRC}) + get_filename_component(aclnn_inner_name ${_aclnn_inner_src} NAME_WE) + if("aclnnInner_${_op_name}" STREQUAL "${aclnn_inner_name}") + list(APPEND generate_aclnn_srcs ${_aclnn_inner_src}) + break() + endif() + endforeach() + endforeach() + + if(add_srcs) + list(TRANSFORM add_srcs REPLACE "${STATIC_SRC_DIR}" "${static_src_temp_dir}" OUTPUT_VARIABLE add_static_srcs) + list(APPEND modified_files ${add_static_srcs}) + set(aclnn_ops_static_target aclnn_ops_${ops_type}_static) + set_source_files_properties(${add_static_srcs} + TARGET_DIRECTORY ${aclnn_ops_static_target} + PROPERTIES GENERATED TRUE + ) + + target_sources(${aclnn_ops_static_target} PRIVATE + ${add_static_srcs} + ) + add_dependencies(${aclnn_ops_static_target} ${prepare_ops_adv_static_target}) + endif() + + if(generate_aclnn_srcs) 
+ list(REMOVE_DUPLICATES generate_aclnn_srcs) + set(aclnn_op_target acl_op_${ops_type}_builtin) + set_source_files_properties(${generate_aclnn_srcs} + TARGET_DIRECTORY ${aclnn_op_target} + PROPERTIES GENERATED TRUE + ) + + target_sources(${aclnn_op_target} PRIVATE + ${generate_aclnn_srcs} + ) + endif() + endforeach() + + if(NOT TARGET ${prepare_ops_adv_static_target}) + list(REMOVE_DUPLICATES modified_files) + add_custom_command(OUTPUT ${static_src_temp_dir} + COMMAND mkdir -p ${static_src_temp_dir} + COMMAND cp -rf ${STATIC_SRC_DIR}/src ${static_src_temp_dir} + COMMAND ${HI_PYTHON} -B ${OPS_STATIC_SCRIPT} InsertIni -p ${static_src_temp_dir} -f ${modified_files} + ) + + add_custom_target(${prepare_ops_adv_static_target} + DEPENDS ${static_src_temp_dir} + ) + endif() +endfunction() + +if (BUILD_OPEN_PROJECT) + if (TESTS_UT_OPS_TEST) + include(${OPS_ADV_CMAKE_DIR}/func_utest.cmake) + endif () + if (TESTS_EXAMPLE_OPS_TEST) + include(${OPS_ADV_CMAKE_DIR}/func_examples.cmake) + endif () +endif () diff --git a/csrc/cmake/intf.cmake b/csrc/cmake/intf.cmake new file mode 100644 index 00000000000..20c63563cbd --- /dev/null +++ b/csrc/cmake/intf.cmake @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +if (BUILD_OPEN_PROJECT) + include(${OPS_ADV_CMAKE_DIR}/intf_pub.cmake) +endif () diff --git a/csrc/cmake/intf_pub.cmake b/csrc/cmake/intf_pub.cmake new file mode 100644 index 00000000000..4856aeef400 --- /dev/null +++ b/csrc/cmake/intf_pub.cmake @@ -0,0 +1,75 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +# Custom package scenario, public compilation configuration for Host side targets +# Note: To ensure compatibility with the built-in package compilation process, the intf_pub name cannot be changed +add_library(intf_pub INTERFACE) +target_include_directories(intf_pub + INTERFACE + ${ASCEND_CANN_PACKAGE_PATH}/include + ${ASCEND_CANN_PACKAGE_PATH}/include/external + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/platform + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/runtime + ${ASCEND_CANN_PACKAGE_PATH}/include/experiment/msprof +) +target_link_directories(intf_pub + INTERFACE + ${ASCEND_CANN_PACKAGE_PATH}/lib64 +) +target_compile_options(intf_pub + INTERFACE + -fPIC + -O2 + -Wall -Wundef -Wcast-qual -Wpointer-arith -Wdate-time + -Wfloat-equal -Wformat=2 -Wshadow + -Wsign-compare -Wunused-macros -Wvla -Wdisabled-optimization -Wempty-body -Wignored-qualifiers + -Wimplicit-fallthrough=3 -Wtype-limits -Wshift-negative-value -Wswitch-default + -Wframe-larger-than=32768 -Woverloaded-virtual + -Wnon-virtual-dtor -Wshift-overflow=2 -Wshift-count-overflow + -Wwrite-strings -Wmissing-format-attribute -Wformat-nonliteral + -Wdelete-non-virtual-dtor -Wduplicated-cond + -Wtrampolines -Wsized-deallocation -Wlogical-op -Wsuggest-attribute=format + -Wduplicated-branches + -Wmissing-include-dirs -Wformat-signedness + -Wreturn-local-addr -Wextra + -Wredundant-decls -Wfloat-conversion + -Wno-write-strings -Wall -Wno-dangling-else -Wno-comment -Wno-conversion-null -Wno-return-type + -Wno-unknown-pragmas -Wno-sign-compare + -Wno-error=undef + -Wno-error=comment + -Wno-error=conversion-null + -Wno-error=dangling-else + -Wno-error=return-type + -Wno-error=shadow + -Wno-error=sign-compare + -Wno-error=unknown-pragmas + -Wno-error=unused-parameter + -Wno-error=cast-qual + -Wno-error=format= + -Wno-error=maybe-uninitialized + -Wno-error=missing-field-initializers + -Wno-error=redundant-decls + -Wno-error=unused-variable + $<$:-Wnested-externs> + $<$:-g> + $,-fstack-protector-strong,-fstack-protector-all> +) +target_compile_definitions(intf_pub + INTERFACE + $<$:_GLIBCXX_USE_CXX11_ABI=0> + $<$:_FORTIFY_SOURCE=2> +) +target_link_options(intf_pub + INTERFACE + $<$,EXECUTABLE>:-pie> + $<$:-s> + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack +) diff --git a/csrc/cmake/modules/Findalog.cmake b/csrc/cmake/modules/Findalog.cmake new file mode 100644 index 00000000000..d016a50ef94 --- /dev/null +++ b/csrc/cmake/modules/Findalog.cmake @@ -0,0 +1,113 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +if (alog_FOUND) + message(STATUS "Package alog has been found.") + return() +endif() + +set(_cmake_targets_defined "") +set(_cmake_targets_not_defined "") +set(_cmake_expected_targets "") +foreach(_cmake_expected_target IN ITEMS slog alog alog_headers) + list(APPEND _cmake_expected_targets "${_cmake_expected_target}") + if(TARGET "${_cmake_expected_target}") + list(APPEND _cmake_targets_defined "${_cmake_expected_target}") + else() + list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}") + endif() +endforeach() +unset(_cmake_expected_target) + +if(_cmake_targets_defined STREQUAL _cmake_expected_targets) + unset(_cmake_targets_defined) + unset(_cmake_targets_not_defined) + unset(_cmake_expected_targets) + unset(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() + +if(NOT _cmake_targets_defined STREQUAL "") + string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}") + string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n") +endif() +unset(_cmake_targets_defined) +unset(_cmake_targets_not_defined) +unset(_cmake_expected_targets) + +find_path(_INCLUDE_DIR + NAMES base/alog_pub.h + NO_CMAKE_SYSTEM_PATH + NO_CMAKE_FIND_ROOT_PATH) + +find_library(slog_SHARED_LIBRARY + NAMES libascendalog.so + PATH_SUFFIXES lib64 + NO_CMAKE_SYSTEM_PATH + NO_CMAKE_FIND_ROOT_PATH) + +find_library(alog_SHARED_LIBRARY + NAMES libascendalog.so + PATH_SUFFIXES lib64 + NO_CMAKE_SYSTEM_PATH + NO_CMAKE_FIND_ROOT_PATH) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(alog + FOUND_VAR + alog_FOUND + REQUIRED_VARS + _INCLUDE_DIR + slog_SHARED_LIBRARY + alog_SHARED_LIBRARY +) + +if(alog_FOUND) + set(alog_INCLUDE_DIR "${_INCLUDE_DIR}") + include(CMakePrintHelpers) + message(STATUS "Variables in alog module:") + cmake_print_variables(alog_INCLUDE_DIR) + cmake_print_variables(slog_SHARED_LIBRARY) + cmake_print_variables(alog_SHARED_LIBRARY) + + add_library(slog SHARED IMPORTED) + set_target_properties(slog PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "LOG_CPP;PROCESS_LOG" + INTERFACE_LINK_LIBRARIES "alog_headers" + IMPORTED_LOCATION "${slog_SHARED_LIBRARY}" + ) + + add_library(alog SHARED IMPORTED) + set_target_properties(alog PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "LOG_CPP;PROCESS_LOG" + INTERFACE_LINK_LIBRARIES "alog_headers" + IMPORTED_LOCATION "${alog_SHARED_LIBRARY}" + ) + + add_library(alog_headers INTERFACE IMPORTED) + set_target_properties(alog_headers PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${alog_INCLUDE_DIR}" + ) + + include(CMakePrintHelpers) + cmake_print_properties(TARGETS slog + PROPERTIES INTERFACE_COMPILE_DEFINITIONS INTERFACE_LINK_LIBRARIES IMPORTED_LOCATION + ) + cmake_print_properties(TARGETS alog + PROPERTIES INTERFACE_COMPILE_DEFINITIONS INTERFACE_LINK_LIBRARIES IMPORTED_LOCATION + ) + cmake_print_properties(TARGETS alog_headers + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + ) +endif() + +# Cleanup temporary variables. 
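+# Typical consumption (sketch): append this modules directory to CMAKE_MODULE_PATH, then
+#   find_package(alog REQUIRED)
+# and link against the imported 'alog' target; alog_INCLUDE_DIR is also set for convenience.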
+set(_INCLUDE_DIR) diff --git a/csrc/cmake/scripts/prepare.sh b/csrc/cmake/scripts/prepare.sh new file mode 100644 index 00000000000..dac9da76862 --- /dev/null +++ b/csrc/cmake/scripts/prepare.sh @@ -0,0 +1,130 @@ +#!/bin/bash +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +CPU_NUM=$(($(cat /proc/cpuinfo | grep "^processor" | wc -l)*2)) +JOB_NUM="-j${CPU_NUM}" + +while [[ $# -gt 0 ]]; do + case $1 in + -s) + PATH_TO_SOURCE="$2" + shift 2 + ;; + -b) + PATH_TO_BUILD="$2" + shift 2 + ;; + -p) + ASCEND_CANN_PACKAGE_PATH="$2" + shift 2 + ;; + --autogen-dir) + ASCEND_AUTOGEN_DIR="$2" + shift 2 + ;; + --build-open-project) + BUILD_OPEN_PROJECT="$2" + shift 2 + ;; + --binary-out-dir) + ASCEND_BINARY_OUT_DIR="$2" + shift 2 + ;; + --impl-out-dir) + ASCEND_IMPL_OUT_DIR="$2" + shift 2 + ;; + --op-build-tool) + OP_BUILD_TOOL="$2" + shift 2 + ;; + --ascend-cmake-dir) + ASCEND_CMAKE_DIR="$2" + shift 2 + ;; + --tiling-key) + TILING_KEY="$2" + shift 2 + ;; + --ops-compile-options) + OPS_COMPILE_OPTIONS="$2" + shift 2 + ;; + --check-compatible) + CHECK_COMPATIBLE="$2" + shift 2 + ;; + --ascend-compute_unit) + ASCEND_COMPUTE_UNIT="$2" + shift 2 + ;; + --ascend-op-name) + ASCEND_OP_NAME="$2" + shift 2 + ;; + --op_debug_config) + OP_DEBUG_CONFIG="$2" + shift 2 + ;; + *) + break + ;; + esac +done + +function clean() { + if [ -n "${PATH_TO_BUILD}" ];then + rm -rf ${PATH_TO_BUILD} + mkdir -p ${PATH_TO_BUILD} + fi +} + +function convert_string() { + local _input=$1 + _output=$(echo $_input | sed 's/::/;/g') + echo "${_output}" +} + +function set_env() { + CONVERT_TILING_KEY="$(convert_string ${TILING_KEY})" + + CONVERT_OPS_COMPILE_OPTIONS="$(convert_string ${OPS_COMPILE_OPTIONS})" + + CONVERT_ASCEND_COMPUTE_UNIT="$(convert_string ${ASCEND_COMPUTE_UNIT})" +} + +function build() { + cd ${PATH_TO_BUILD} + cmake ${PATH_TO_SOURCE} \ + -DBUILD_OPEN_PROJECT=${BUILD_OPEN_PROJECT} \ + -DPREPARE_BUILD=ON \ + -DCUSTOM_ASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH} \ + -DASCEND_AUTOGEN_DIR=${ASCEND_AUTOGEN_DIR} \ + -DASCEND_BINARY_OUT_DIR=${ASCEND_BINARY_OUT_DIR} \ + -DASCEND_IMPL_OUT_DIR=${ASCEND_IMPL_OUT_DIR} \ + -DOP_BUILD_TOOL=${OP_BUILD_TOOL} \ + -DASCEND_CMAKE_DIR=${ASCEND_CMAKE_DIR} \ + -DCHECK_COMPATIBLE=${CHECK_COMPATIBLE} \ + -DTILING_KEY="${CONVERT_TILING_KEY}" \ + -DOPS_COMPILE_OPTIONS="${CONVERT_OPS_COMPILE_OPTIONS}" \ + -DASCEND_COMPUTE_UNIT=${CONVERT_ASCEND_COMPUTE_UNIT} \ + -DOP_DEBUG_CONFIG=${OP_DEBUG_CONFIG} \ + -DASCEND_OP_NAME=${ASCEND_OP_NAME} + + make ${JOB_NUM} prepare_build +} + +function main() { + clean + set_env + build +} + +main diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/CMakeLists.txt b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/CMakeLists.txt new file mode 100644 index 00000000000..44fecd39d85 --- /dev/null +++ 
b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/CMakeLists.txt @@ -0,0 +1,54 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME GroupedMatmulSwigluQuantWeightNzTensorList + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror +) + +target_sources(op_host_aclnnExc PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp +) + +target_sources(opapi PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(aclnn_ops_train PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + ) + + target_sources(aclnn_ops_infer PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + ) +endif () + +target_sources(optiling PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp +) + +file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h") + +install(FILES ${_GMM_Aclnn_header} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp new file mode 100644 index 00000000000..d5992610a80 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include +#include +#include "aclnn_kernels/contiguous.h" +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_kernels/common/op_error_check.h" +#include "opdev/common_types.h" +#include "opdev/data_type_utils.h" +#include "opdev/format_utils.h" +#include "opdev/op_dfx.h" +#include "opdev/op_executor.h" +#include "opdev/op_log.h" +#include "opdev/platform.h" +#include "opdev/shape_utils.h" +#include "opdev/tensor_view_utils.h" +#include "opdev/make_op_executor.h" +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h" +#include "aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h" + +using namespace op; + +#ifdef __cplusplus +extern "C" { +#endif + +static constexpr int64_t SPLIT = 2; +static constexpr int64_t K_LIMIT = 65536; +static constexpr int64_t N_LIMIT = 4096; +static constexpr int64_t NZ_DIM_3 = 32; +static constexpr int64_t NZ_DIM_2 = 16; +static constexpr int64_t OUTPUT_IDX_0 = 0; +static constexpr int64_t OUTPUT_IDX_1 = 1; +static constexpr size_t X_DIM_LIMIT = 2; +static constexpr size_t WEIGHT_ND_DIM_LIMIT = 2; +static constexpr size_t WEIGHT_NZ_DIM_LIMIT = 4; +static constexpr size_t WEIGHT_SCALE_DIM_LIMIT = 1; +static constexpr size_t TOKEN_SCALE_DIM_LIMIT = 1; +static constexpr size_t GROUP_LIST_DIM_LIMIT = 1; +static constexpr size_t QUANTOUT_DIM_LIMIT = 2; +static constexpr size_t QUANTSCALEOUT_DIM_LIMIT = 1; + +static const std::initializer_list X_DTYPE_SUPPORT_LIST = {DataType::DT_INT8}; +static const std::initializer_list WEIGHT_DTYPE_SUPPORT_LIST = {DataType::DT_INT8}; +static const std::initializer_list WEIGHT_SCALE_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16}; +static const std::initializer_list X_SCALE_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT, DataType::DT_FLOAT16, DataType::DT_BF16}; +static const std::initializer_list GROUP_LIST_DTYPE_SUPPORT_LIST = {DataType::DT_INT64}; +static const std::initializer_list QUANTOUT_DTYPE_SUPPORT_LIST = {DataType::DT_INT8}; +static const std::initializer_list QUANTSCALEOUT_DTYPE_SUPPORT_LIST = {DataType::DT_FLOAT}; + +static bool CheckNotNull(const aclTensor* x, const aclTensorList* weight, const aclTensor* bias, const aclTensor* offset, + const aclTensorList* weightScale, const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale, const aclTensor* outputOffset) +{ + OP_CHECK_NULL(x, return false); + OP_CHECK_NULL(weight, return false); + OP_CHECK_NULL(weightScale, return false); + OP_CHECK_NULL(xScale, return false); + OP_CHECK_NULL(groupList, return false); + OP_CHECK_NULL(output, return false); + OP_CHECK_NULL(outputScale, return false); + if (bias != nullptr) { + OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where bias is not 0. " + "Features and accuracy are not guaranteed if inputting bias with values other than 0s."); + } + if (offset != nullptr) { + OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where offset is not 0. " + "Features and accuracy are not guaranteed if inputting bias with values other than 0s."); + } + if (outputOffset != nullptr) { + OP_LOGW("aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario where outputOffset is not 0. 
" + "Features and accuracy are not guaranteed if inputting bias with values other than 0s."); + } + return true; +} + +static bool CheckInputOutDims(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale, + const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale) +{ + OP_CHECK_WRONG_DIMENSION(x, X_DIM_LIMIT, return false); + op::Format weightViewFormat = (*weight)[0]->GetViewFormat(); + if (IsPrivateFormat(weightViewFormat)){ + OP_CHECK_WRONG_DIMENSION((*weight)[0], WEIGHT_NZ_DIM_LIMIT, return false); + } else { + OP_CHECK_WRONG_DIMENSION((*weight)[0], WEIGHT_ND_DIM_LIMIT, return false); + } + OP_CHECK_WRONG_DIMENSION((*weightScale)[0], WEIGHT_SCALE_DIM_LIMIT, return false); + OP_CHECK_WRONG_DIMENSION(xScale, TOKEN_SCALE_DIM_LIMIT, return false); + OP_CHECK_WRONG_DIMENSION(groupList, GROUP_LIST_DIM_LIMIT, return false); + OP_CHECK_WRONG_DIMENSION(output, QUANTOUT_DIM_LIMIT, return false); + OP_CHECK_WRONG_DIMENSION(outputScale, QUANTSCALEOUT_DIM_LIMIT, return false); + return true; +} + +static bool CheckInputOutShape(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale, + const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale) +{ + int64_t m = x->GetViewShape().GetDim(0); + int64_t k = x->GetViewShape().GetDim(1); + int64_t n = (*weightScale)[0]->GetViewShape().GetDim(0); + int64_t e = weight->Size(); + if (n % SPLIT != 0){ + OP_LOGE(ACLNN_ERR_PARAM_INVALID, + "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, N is %ld , not an even number.", n); + return false; + } + int64_t nAfterHalve = static_cast(n / SPLIT); + // x shape is expected to be [M, K] + op::Shape xExpectShape = {m, k}; + // The ND shape of each weight in TensorList is expected to be [K, N] + op::Shape weightNDExpectShape = {k, n}; + // The NZ shape of each weight in TensorList is expected to be [N // 32, K // 16, 16, 32] + op::Shape weightNZExpectShape = {static_cast(n / NZ_DIM_3), + static_cast(k / NZ_DIM_2), + NZ_DIM_2, NZ_DIM_3}; + // weightScale shape is expected to be [N] + op::Shape weightScaleExpectShape = {n}; + // xScale shape is expected to be [E, N] + op::Shape xScaleExpectShape = {m}; + // output shape is expected to be [M, N] + op::Shape outputExpectShape = {m, nAfterHalve}; + // outputScale shape is expected to be [M] + op::Shape outputScaleExpectShape = {m}; + for (size_t i = 0; i < weight->Size(); ++i) { + op::Format weightViewFormat = (*weight)[i]->GetViewFormat(); + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(x, xExpectShape, return false); + if (IsPrivateFormat(weightViewFormat)){ + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weight)[i], weightNZExpectShape, return false); + } else { + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weight)[i], weightNDExpectShape, return false); + } + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE((*weightScale)[i], weightScaleExpectShape, return false); + } + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(xScale, xScaleExpectShape, return false); + + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(output, outputExpectShape, return false); + OP_CHECK_SHAPE_NOT_EQUAL_WITH_EXPECTED_SIZE(outputScale, outputScaleExpectShape, return false); + // The length of groupList should be less than or equal to the number of experts in weight + int64_t groupListLen = groupList->GetViewShape().GetDim(0); + if(groupListLen > e) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, + "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, Length of 'groupList' 
out of range (expected to be in range of [1, %ld], but got %ld)", + e, groupListLen); + return false; + } + if(nAfterHalve > N_LIMIT) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, + "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + where N after halve is %ld greater than %ld.", + nAfterHalve, N_LIMIT); + return false; + } + if(k >= K_LIMIT) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, + "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + The tail axis dimension of input0(x) is %ld, which need lower than %ld.", + k, K_LIMIT); + return false; + } + return true; +} + +static bool CheckDtypeValid(const aclTensor* x, const aclTensorList* weight, const aclTensorList* weightScale, + const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale) +{ + OP_CHECK_DTYPE_NOT_SUPPORT(x, X_DTYPE_SUPPORT_LIST, return false); + for (size_t i = 0; i < weight->Size(); ++i) { + OP_CHECK_DTYPE_NOT_SUPPORT((*weight)[i], WEIGHT_DTYPE_SUPPORT_LIST, return false); + OP_CHECK_DTYPE_NOT_SUPPORT((*weightScale)[i], WEIGHT_SCALE_DTYPE_SUPPORT_LIST, return false); + } + OP_CHECK_DTYPE_NOT_SUPPORT(xScale, X_SCALE_DTYPE_SUPPORT_LIST, return false); + OP_CHECK_DTYPE_NOT_SUPPORT(groupList, GROUP_LIST_DTYPE_SUPPORT_LIST, return false); + OP_CHECK_DTYPE_NOT_SUPPORT(output, QUANTOUT_DTYPE_SUPPORT_LIST, return false); + OP_CHECK_DTYPE_NOT_SUPPORT(outputScale, QUANTSCALEOUT_DTYPE_SUPPORT_LIST, return false); + return true; +} + +static bool CheckFormat(const aclTensor* x, const aclTensorList* weight, const aclTensor* output) +{ + bool isNZ = (*weight)[0]->GetStorageFormat() == op::Format::FORMAT_FRACTAL_NZ; + if (!isNZ) { + // fp16 in fp32 out that is split k template, not precision-advanced now + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + weight Format expect is FRACTAL_NZ, but got [%s].", op::ToString((*weight)[0]->GetStorageFormat()).GetString()); + return false; + } + if (IsPrivateFormat(x->GetStorageFormat())) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + x Format Not support Private Format."); + return false; + } + if (IsPrivateFormat(output->GetStorageFormat())) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "aclnnGroupedMatmulSwigluQuantWeightNzTensorList, The current version does not support the scenario.\ + output Format Not support Private Format."); + return false; + } + return true; +} + +static aclnnStatus CheckParams(const aclTensor* x, const aclTensorList* weight, const aclTensor* bias, const aclTensor* offset, + const aclTensorList* weightScale, const aclTensor* xScale, const aclTensor* groupList, + const aclTensor* output, const aclTensor* outputScale, const aclTensor* outputOffset) { + // 1. Check if parameters are null pointers + CHECK_RET(CheckNotNull(x, weight, bias, offset, weightScale, xScale, + groupList, output, outputScale, outputOffset), ACLNN_ERR_PARAM_NULLPTR); + + // 2. Verify input and output parameter dimensions + CHECK_RET(CheckInputOutDims(x, weight, weightScale, xScale, + groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID); + + // 3. Verify input and output shape parameters + CHECK_RET(CheckInputOutShape(x, weight, weightScale, xScale, + groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID); + + // 4. 
Check if the input data types are within the supported data type range + CHECK_RET(CheckDtypeValid(x, weight, weightScale, xScale, + groupList, output, outputScale), ACLNN_ERR_PARAM_INVALID); + + // 5. Check if data format is supported + CHECK_RET(CheckFormat(x, weight, output), ACLNN_ERR_PARAM_INVALID); + + return ACLNN_SUCCESS; +} + +static aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSizeCommon(const aclTensor *x, const aclTensorList *weight, + const aclTensor *bias, const aclTensor *offset, + const aclTensorList *weightScale, const aclTensor *xScale, + const aclTensor *groupList, + aclTensor *output, aclTensor *outputScale, + aclTensor *outputOffset, uint64_t *workspaceSize, + aclOpExecutor **executor){ + // Fixed pattern, create OpExecutor + auto uniqueExecutor = CREATE_EXECUTOR(); + CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR); + // Fixed pattern, parameter check + auto ret = CheckParams(x, weight, bias, offset, weightScale, xScale, + groupList, output, outputScale, outputOffset); + CHECK_RET(ret == ACLNN_SUCCESS, ret); + // Empty tensor scenario + if (output->IsEmpty() || groupList->IsEmpty() || outputScale->IsEmpty()) { + *workspaceSize = 0; + uniqueExecutor.ReleaseTo(executor); + return ACLNN_SUCCESS; + } + // Convert to contiguous + x = l0op::Contiguous(x, uniqueExecutor.get()); + CHECK_RET(x != nullptr, ACLNN_ERR_INNER_NULLPTR); + for (size_t i = 0; i < weight->Size(); ++i) { + (*weight)[i]->SetOriginalShape((*weight)[i]->GetViewShape()); + } + xScale = l0op::Contiguous(xScale, uniqueExecutor.get()); + CHECK_RET(xScale != nullptr, ACLNN_ERR_INNER_NULLPTR); + groupList = l0op::Contiguous(groupList, uniqueExecutor.get()); + CHECK_RET(groupList != nullptr, ACLNN_ERR_INNER_NULLPTR); + // Call L0 operator capability + auto ret_0 = l0op::GroupedMatmulSwigluQuantWeightNzTensorList(x, weight, weightScale, xScale, groupList, uniqueExecutor.get()); + CHECK_RET(ret_0 != std::tuple(nullptr, nullptr), ACLNN_ERR_INNER_NULLPTR); + auto out0 = std::get(ret_0); + auto ret_1 = l0op::ViewCopy(out0, output, uniqueExecutor.get()); + CHECK_RET(ret_1 != nullptr, ACLNN_ERR_INNER_NULLPTR); + auto out1 = std::get(ret_0); + auto ret_2 = l0op::ViewCopy(out1, outputScale, uniqueExecutor.get()); + CHECK_RET(ret_2 != nullptr, ACLNN_ERR_INNER_NULLPTR); + *workspaceSize = uniqueExecutor->GetWorkspaceSize(); + uniqueExecutor.ReleaseTo(executor); + return ACLNN_SUCCESS; +} + +aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize(const aclTensor *x, const aclTensorList *weight, + const aclTensor *bias, const aclTensor *offset, + const aclTensorList *weightScale, const aclTensor *xScale, + const aclTensor *groupList, + aclTensor *output, aclTensor *outputScale, + aclTensor *outputOffset, uint64_t *workspaceSize, + aclOpExecutor **executor) { + OP_CHECK_COMM_INPUT(workspaceSize, executor); + L2_DFX_PHASE_1(aclnnGroupedMatmulSwigluQuantWeightNzTensorList, + DFX_IN(x, weight, bias, offset, weightScale, xScale, groupList), + DFX_OUT(output, outputScale, outputOffset)); + // weight is forcibly bound to StorageFormat and ViewFormat as NZ in this scenario + CHECK_RET(weight != nullptr, ACLNN_ERR_PARAM_NULLPTR); + for (size_t i = 0; i < weight->Size(); ++i) { + auto storgeShape = (*weight)[i]->GetStorageShape(); + auto viewShape = (*weight)[i]->GetViewShape(); + aclTensor* weightNZ = const_cast((*weight)[i]); + CHECK_COND((storgeShape.GetDimNum() == WEIGHT_NZ_DIM_LIMIT), + ACLNN_ERR_PARAM_INVALID, + 
"aclnnGroupedMatmulSwigluQuantWeightNZTensorList, The dimnum of storageShape for second input (weight) \ + must be 4. \n But StorageShape got %s , and dimNum is %lu.", + op::ToString(storgeShape).GetString(), storgeShape.GetDimNum()); + // The StorageFormat of weight is unconditionally regarded as NZ + weightNZ->SetStorageFormat(op::Format::FORMAT_FRACTAL_NZ); + if (viewShape.GetDimNum() == WEIGHT_NZ_DIM_LIMIT){ + // If the viewShape of weight is 4-dimensional, it is regarded as NZ + weightNZ->SetViewFormat(op::Format::FORMAT_FRACTAL_NZ); + } else if (viewShape.GetDimNum() == WEIGHT_ND_DIM_LIMIT){ + // If the viewShape of weight is 2-dimensional, it is regarded as ND + weightNZ->SetViewFormat(op::Format::FORMAT_ND); + } + } + // Call the common interface + return aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSizeCommon(x, weight, bias, offset, weightScale, xScale, groupList, + output, outputScale, outputOffset, workspaceSize, executor); +} + +aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorList(void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream) { + L2_DFX_PHASE_2(aclnnGroupedMatmulSwigluQuantWeightNzTensorList); + CHECK_COND(CommonOpExecutorRun(workspace, workspaceSize, executor, stream) == ACLNN_SUCCESS, ACLNN_ERR_INNER, + "This is an error in GroupedMatmulSwigluQuantWeightNzTensorList launch aicore"); + return ACLNN_SUCCESS; +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h new file mode 100644 index 00000000000..407f27f480e --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef OP_API_INC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H +#define OP_API_INC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H +#include "aclnn/aclnn_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief The first interface of aclnnGroupedMatmulSwigluQuantWeightNzTensorList, which calculates the workspace size according to the specific calculation process. + * @domain aclnn_ops_infer + * + * @param [in] x: Represents x in the formula. The data type supports INT8, and the data format supports ND. + * @param [in] weight: + * Represents weight in the formula. The data type supports INT8, and the data format supports NZ. + * @param [in] weightScale: Represents quantization parameters. The data type supports FLOAT16, BFLOAT16, and FLOAT32. The data format supports ND, with a maximum length of 128. + * Represents per Channel parameters. The data type supports FLOAT16 and BFLOAT16. The data format supports ND. + * @param [in] xScale: + * Represents per Token quantization parameters. 
The data type supports FLOAT32, and the data format supports ND. + * @param [in] groupList: Required parameter, representing the index situation on the input and output grouping axes. The data type supports INT64. + * @param [out] quantOutput: Represents out in the formula. The data type supports INT8, and the data format supports ND. + * @param [out] quantScaleOutput: Represents outQuantScale in the formula. The data type supports Float32. + * @param [out] workspaceSize: Returns the workspace size that users need to apply for on the npu device side. + * @param [out] executor: Returns the op executor, containing the operator calculation process. + * @return aclnnStatus: Returns the status code. + */ +__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize( + const aclTensor *x, const aclTensorList *weight, const aclTensor *bias, const aclTensor *offset, + const aclTensorList *weightScale, const aclTensor *xScale, const aclTensor *groupList, + aclTensor *output, aclTensor *outputScale, aclTensor *outputOffset, uint64_t *workspaceSize, aclOpExecutor **executor); + +/** + * @brief The second interface of aclnnGroupedMatmulSwigluQuantWeightNzTensorList, used to execute calculations. + * @param [in] workspace: The starting address of the workspace memory applied for on the npu device side. + * @param [in] workspaceSize: The workspace size applied for on the npu device side, obtained from the first interface aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize. + * @param [in] stream: acl stream. + * @param [in] executor: op executor, containing the operator calculation process. + * @return aclnnStatus: Returns the status code. + */ +__attribute__((visibility("default"))) aclnnStatus aclnnGroupedMatmulSwigluQuantWeightNzTensorList(void* workspace, + uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp new file mode 100644 index 00000000000..9181552b1c4 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
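For reference, a minimal host-side call sequence for the two aclnn entry points documented in the header above might look like the sketch below. It is illustrative only: the tensors are assumed to have been created elsewhere (for example with aclCreateTensor) with the dtypes and shapes documented above, stream is an existing aclrtStream, and error handling is reduced to early returns.

#include "acl/acl.h"
#include "aclnn_grouped_matmul_swiglu_quant_weight_nz_tensor_list.h"

// Sketch of the usual two-phase aclnn pattern: query the workspace size,
// allocate device memory, then launch on a stream. The function name and the
// caller-owned tensors are assumptions for illustration.
aclnnStatus RunGmmSwigluQuant(const aclTensor *x, const aclTensorList *weight,
                              const aclTensorList *weightScale, const aclTensor *xScale,
                              const aclTensor *groupList, aclTensor *output,
                              aclTensor *outputScale, aclrtStream stream)
{
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;
    aclnnStatus ret = aclnnGroupedMatmulSwigluQuantWeightNzTensorListGetWorkspaceSize(
        x, weight, nullptr /* bias */, nullptr /* offset */, weightScale, xScale, groupList,
        output, outputScale, nullptr /* outputOffset */, &workspaceSize, &executor);
    if (ret != ACLNN_SUCCESS) {
        return ret;
    }
    void *workspace = nullptr;
    if (workspaceSize > 0 && aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) {
        return ACLNN_ERR_INNER;
    }
    ret = aclnnGroupedMatmulSwigluQuantWeightNzTensorList(workspace, workspaceSize, executor, stream);
    // Synchronize before releasing the workspace; real callers may instead keep
    // the buffer alive and reuse it across launches.
    aclrtSynchronizeStream(stream);
    if (workspace != nullptr) {
        aclrtFree(workspace);
    }
    return ret;
}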
+ */ + +#include "opdev/op_log.h" +#include "opdev/op_dfx.h" +#include "opdev/make_op_executor.h" +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h" + +using namespace op; + +namespace l0op { +OP_TYPE_REGISTER(GroupedMatmulSwigluQuantWeightNzTensorList); + +const std::tuple GroupedMatmulSwigluQuantWeightNzTensorList(const aclTensor *x, + const aclTensorList *weight, + const aclTensorList *perChannelScale, + const aclTensor *perTokenScale, + const aclTensor *groupList, + aclOpExecutor *executor) { + L0_DFX(GroupedMatmulSwigluQuantWeightNzTensorList, x, weight, perChannelScale, perTokenScale, groupList); + if (x == nullptr) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "x is nullptr."); + return std::tuple(nullptr, nullptr); + } + int64_t m = perTokenScale->GetViewShape().GetDim(0); + int64_t n = (*perChannelScale)[0]->GetViewShape().GetDim(0); + int64_t nAfterHalve = static_cast(n / 2); + gert::Shape outShape({m, nAfterHalve}); + gert::Shape scaleOutShape({m}); + auto out = executor->AllocTensor(outShape, DataType::DT_INT8, ge::FORMAT_ND); + auto scaleOut = executor->AllocTensor(scaleOutShape, DataType::DT_FLOAT, ge::FORMAT_ND); + auto ret = INFER_SHAPE(GroupedMatmulSwigluQuantWeightNzTensorList, + OP_INPUT(x, weight, perChannelScale, perTokenScale, groupList), + OP_OUTPUT(out, scaleOut)); + if (ret != ACLNN_SUCCESS) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "InferShape failed."); + return std::tuple(nullptr, nullptr); + } + ret = ADD_TO_LAUNCHER_LIST_AICORE(GroupedMatmulSwigluQuantWeightNzTensorList, + OP_INPUT(x, weight, perChannelScale, perTokenScale, groupList), + OP_OUTPUT(out, scaleOut)); + if (ret != ACLNN_SUCCESS) { + OP_LOGE(ACLNN_ERR_PARAM_INVALID, "ADD_TO_LAUNCHER_LIST_AICORE failed."); + return std::tuple(nullptr, nullptr); + } + return std::tie(out, scaleOut); +} + +} // namespace l0op diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h new file mode 100644 index 00000000000..f47ad8a88f7 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
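As a reading aid for the level-0 signature above: the op emits a row-wise symmetrically quantized activation, so the int8 output and the per-row float scale belong together. A minimal sketch of how a caller could map the pair back to approximate floating-point values, assuming the conventional symmetric int8 scheme used by the kernel (the exact scale constant lives in the kernel utils header, which is not part of this excerpt):

#include <cstdint>

// out      : [M, N / 2] int8 quantized swiglu output (quantOutput / y)
// outScale : [M] float per-row dequantization scale (quantScaleOutput / y_scale)
// Returns the approximate fp32 value at row i, column j. Illustrative only.
static inline float DequantOutput(const int8_t *out, const float *outScale,
                                  int64_t halfN, int64_t i, int64_t j)
{
    return static_cast<float>(out[i * halfN + j]) * outScale[i];
}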
+ */ +#ifndef OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_OP_H +#define OP_API_INC_LEVEL0_OP_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_OP_H + +#include "opdev/op_executor.h" + +namespace l0op { +const std::tuple GroupedMatmulSwigluQuantWeightNzTensorList(const aclTensor *x, + const aclTensorList *weight, + const aclTensorList *perChannelScale, + const aclTensor *perTokenScale, + const aclTensor *groupList, + aclOpExecutor *executor); +} + +#endif diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp new file mode 100644 index 00000000000..bd7a80b1d96 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp @@ -0,0 +1,65 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_def.cpp + * \brief + */ + +#include +#include "register/op_def_registry.h" +namespace ops { +class GroupedMatmulSwigluQuantWeightNzTensorList : public OpDef { +public: + explicit GroupedMatmulSwigluQuantWeightNzTensorList(const char* name) : OpDef(name) + { + this->Input("x") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("weight") + .ParamType(DYNAMIC) + .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8}) + .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + this->Input("weight_scale") + .ParamType(DYNAMIC) + .DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("x_scale") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT,ge::DT_FLOAT,ge::DT_FLOAT}) + .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND}); + this->Input("group_list") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64,ge::DT_INT64,ge::DT_INT64}) + .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND}); + this->Output("y") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8,ge::DT_INT8,ge::DT_INT8}) + .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND}); + this->Output("y_scale") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT,ge::DT_FLOAT,ge::DT_FLOAT}) + .Format({ge::FORMAT_ND,ge::FORMAT_ND,ge::FORMAT_ND}); + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true); + + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + } +}; + +OP_ADD(GroupedMatmulSwigluQuantWeightNzTensorList); +} diff --git 
a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp new file mode 100644 index 00000000000..5e3d44320a5 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_proto.cpp + * \brief + */ +#include "register/op_impl_registry.h" +#include "log/ops_log.h" +#include "platform/platform_info.h" + +using namespace ge; +namespace ops { +const int64_t X_INDEX = 0; +const int64_t WEIGHTSCALE_INDEX = 2; +const int64_t M_DIM_INDEX = 0; +const int64_t N_DIM_INDEX = 0; +static ge::graphStatus InferShape4GroupedMatmulSwigluQuantWeightNzTensorList(gert::InferShapeContext* context) { + const gert::Shape* xShape = context->GetInputShape(X_INDEX); + const gert::Shape* weightScaleShape = context->GetDynamicInputShape(WEIGHTSCALE_INDEX, 0); + int64_t m = xShape->GetDim(M_DIM_INDEX); + int64_t n = static_cast(weightScaleShape->GetDim(N_DIM_INDEX) / 2); + auto outShape = context->GetOutputShape(0); + outShape->SetDimNum(2); + outShape->SetDim(0, m); + outShape->SetDim(1, n); + auto outScaleShape = context->GetOutputShape(1); + outScaleShape->SetDimNum(1); + outScaleShape->SetDim(0, m); + return GRAPH_SUCCESS; +} + +static graphStatus InferDataType4GroupedMatmulSwigluQuantWeightNzTensorList(gert::InferDataTypeContext* context) { + context->SetOutputDataType(0, DataType::DT_INT8); + context->SetOutputDataType(1, DataType::DT_FLOAT); + return GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(GroupedMatmulSwigluQuantWeightNzTensorList) + .InferShape(InferShape4GroupedMatmulSwigluQuantWeightNzTensorList) + .InferDataType(InferDataType4GroupedMatmulSwigluQuantWeightNzTensorList); +} // namespace ops diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp new file mode 100644 index 00000000000..f34959398ba --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
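To make the shape inference above concrete, here is a small worked example under assumed sizes (the dimensions are not taken from the patch):

// Assumed example for InferShape4GroupedMatmulSwigluQuantWeightNzTensorList:
//   x            : [4096, 2048]  (M = 4096 tokens, K = 2048)
//   weight_scale : [3072]        (N = 3072 output channels before the swiglu split)
// Inferred outputs:
//   y       : [4096, 3072 / 2] = [4096, 1536], dtype DT_INT8
//   y_scale : [4096],                          dtype DT_FLOAT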
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.cpp + * \brief + */ +#include +#include +#include "register/op_impl_registry.h" +#include "log/ops_log.h" +#include "error/ops_error.h" +#include "tiling/tiling_base.h" +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h" +using namespace ge; +using namespace AscendC; +using namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling; + +template +static T1 CeilDiv(T1 a, T2 b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +namespace optiling { + +struct GMMSwigluCompileInfo { + uint64_t ubSize_ = 0; + uint32_t aicNum_ = 0; + uint32_t baseM_ = 128; + uint32_t baseN_ = 256; +}; + +static uint64_t CalcMaxTmpSize(const uint32_t row, const uint64_t n) { + std::vector shape_vec = {static_cast(row * n)}; + Shape shape(shape_vec); + uint32_t max; + uint32_t min; + GetSwiGLUMaxMinTmpSize(shape, 4, max, min, false); + uint32_t averageTmp = (max + min) >> 1; + GetAscendQuantMaxMinTmpSize(shape, 4, max, min); + uint32_t average = (max + min) >> 1; + average = average > averageTmp ? average : averageTmp; + GetAscendDequantMaxMinTmpSize(shape, 4, max, min); + averageTmp = (max + min) >> 1; + return average > averageTmp ? average : averageTmp; +} + +static uint64_t CalRows(const uint64_t ubSize, const uint64_t n) { + uint64_t tokenSize = n << 2; + uint64_t expectSize = ubSize - tokenSize; + uint64_t rows = expectSize / (8 + tokenSize); + uint64_t realSize = (8 + tokenSize) * rows + CalcMaxTmpSize(rows, n); + while (expectSize < realSize) { + rows -= CeilDiv(realSize - expectSize, (8 + tokenSize) << 2); + realSize = (8 + tokenSize) * rows + CalcMaxTmpSize(rows, n); + } + return rows; +} + +static void SetTilingKey(gert::TilingContext* context, bool isSplitWorkSpace) { + if(isSplitWorkSpace){ + context->SetTilingKey(1); + context->SetScheduleMode(BATCH_MODE_SCHEDULE); + } else { + context->SetTilingKey(0); + context->SetScheduleMode(BATCH_MODE_SCHEDULE); + } +} + +static bool IsPreFill(GMMSwigluQuantTilingData &tilingData) { + int64_t k = tilingData.gmmSwigluBaseParams.get_K(); + int64_t n = tilingData.gmmSwigluBaseParams.get_N(); + int64_t m = tilingData.gmmSwigluBaseParams.get_M(); + int64_t groupNum = tilingData.gmmSwigluBaseParams.get_groupNum(); + if (groupNum == 128 && m >= PREFILL_M_MIN_SIZE) { // 128:prefiling groupNum + std::array kNList = {k, n}; // 2: kNList size + if (PREFILL_WHITE_LIST.count(kNList)) { + return true; + } + } + return false; +} + +ASCENDC_EXTERN_C graphStatus TilingGMMSwigluQuant(gert::TilingContext* context) { + // set info + OPS_LOG_I(context->GetNodeName(), "Begin Run GMM Swiglu Tiling ."); + + auto compileInfoPtr = context->GetCompileInfo(); + auto xTensor = context->GetInputTensor(X_INDEX); + OPS_LOG_E_IF_NULL(context, xTensor, return GRAPH_FAILED); + const int64_t m = xTensor->GetStorageShape().GetDim(0); + const int64_t k = xTensor->GetStorageShape().GetDim(1); + auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0); + OPS_LOG_E_IF_NULL(context, wTensor, return GRAPH_FAILED); + const int64_t n = wTensor->GetStorageShape().GetDim(0) * wTensor->GetStorageShape().GetDim(3); + auto groupListTensor = context->GetDynamicInputTensor(GROUPLIST_INDEX, 0); + 
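    // Note (illustrative, assumed sizes): n above is recovered from the FRACTAL_NZ storage
    // shape [N / 32, K / 16, 16, 32]. For a logical [K, N] = [2048, 3072] weight, the storage
    // shape is [96, 128, 16, 32], so n = GetDim(0) * GetDim(3) = 96 * 32 = 3072, the same N
    // that the host-side checks derive from the per-channel scale length.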
OPS_LOG_E_IF_NULL(context, groupListTensor, return GRAPH_FAILED); + const int64_t groupNum = groupListTensor->GetStorageShape().GetDim(0); + GMMSwigluQuantTilingData tilingData; + const int64_t row = CalRows(compileInfoPtr->ubSize_, n); + tilingData.gmmSwigluBaseParams.set_groupNum(groupNum); + tilingData.gmmSwigluBaseParams.set_coreNum(compileInfoPtr->aicNum_); + tilingData.gmmSwigluBaseParams.set_K(k); + tilingData.gmmSwigluBaseParams.set_N(n); + tilingData.gmmSwigluBaseParams.set_M(m); + tilingData.gmmSwiglu.set_maxProcessRowNum(row); + tilingData.gmmSwiglu.set_groupListLen(groupNum); + tilingData.gmmSwiglu.set_tokenLen(n); + + OPS_LOG_D(context->GetNodeName(),"grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling."); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.groupNum: %ld", groupNum); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.coreNum: %u ", compileInfoPtr->aicNum_); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.M: %ld", m); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.K: %ld", k); + OPS_LOG_D(context->GetNodeName(),"gmmSwigluBaseParams.N: %ld", n); + OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.maxProcessRowNum: %ld", row); + OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.groupListLen: %ld", groupNum); + OPS_LOG_D(context->GetNodeName(),"gmmSwiglu.tokenLen: %ld", n); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + using namespace matmul_tiling; + MatmulApiTiling tiling(ascendcPlatform); + tiling.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_INT8); + tiling.SetBType(TPosition::GM, CubeFormat::NZ, matmul_tiling::DataType::DT_INT8); + tiling.SetCType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_INT32); + tiling.SetBias(false); + tiling.SetShape(compileInfoPtr->baseM_, compileInfoPtr->baseN_, k); + tiling.SetOrgShape(m, n, k); + tiling.SetBufferSpace(-1, -1, -1); + OPS_ERR_IF(tiling.GetTiling(tilingData.mmTilingData) == -1, + OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling, get tiling failed"), + return GRAPH_FAILED); + auto workspaceSizes = context->GetWorkspaceSizes(1); + bool isPreFill = IsPreFill(tilingData); + tilingData.gmmSwigluBaseParams.set_isPreFill(isPreFill); + int64_t usrWorkspaceLimut = isPreFill ? PREFILL_USER_WORKSPACE_LIMIT : USER_WORKSPACE_LIMIT; + int64_t mLimit = ((usrWorkspaceLimut / DOUBLE_WORKSPACE_SPLIT) / INT32_DTYPE_SIZE) / n; + OPS_ERR_IF(mLimit <= 0, + OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(),"mLimit is %ld must over then 0.", mLimit), + return GRAPH_FAILED); + tilingData.gmmSwigluBaseParams.set_mLimit(mLimit); + workspaceSizes[0] = SYS_WORKSPACE_SIZE + ((mLimit * DOUBLE_WORKSPACE_SPLIT > m \ + ? m \ + : mLimit * DOUBLE_WORKSPACE_SPLIT) * n * sizeof(int32_t)); + bool isSplitWorkSpace = m > mLimit * DOUBLE_WORKSPACE_SPLIT; + OPS_LOG_D(context->GetNodeName(), "USER_WORKSPACE_LIMIT: %ld", usrWorkspaceLimut); + OPS_LOG_D(context->GetNodeName(), "mLimit: %ld", mLimit); + OPS_LOG_D(context->GetNodeName(), "workspaceSizes: %lu", workspaceSizes[0]); + OPS_LOG_D(context->GetNodeName(), "isSplitWorkSpace: %s", isSplitWorkSpace ? "true" : "false"); + OPS_LOG_D(context->GetNodeName(), "isPreFill: %s", isPreFill ? 
"true" : "false"); + SetTilingKey(context, isSplitWorkSpace); + tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->SetBlockDim(compileInfoPtr->aicNum_); // block dim is the number of aicube + context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize()); + + OPS_LOG_D(context->GetNodeName(), "End Run GMM Swiglu Tiling."); + return GRAPH_SUCCESS; +} + +ASCENDC_EXTERN_C graphStatus TilingPrepareForGMMSwigluQuant(gert::TilingParseContext* context) { + // get info + fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); + OPS_LOG_E_IF_NULL(context, platformInfoPtr, return GRAPH_FAILED); + auto compileInfoPtr = context->GetCompiledInfo(); + OPS_LOG_E_IF_NULL(context, compileInfoPtr, return GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); + compileInfoPtr->aicNum_ = ascendcPlatform.GetCoreNumAic(); + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize_); + OPS_LOG_D(context->GetNodeName(), "ubSize is %lu, aicNum is %u.", compileInfoPtr->ubSize_, compileInfoPtr->aicNum_); + return GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(GroupedMatmulSwigluQuantWeightNzTensorList) +.Tiling(TilingGMMSwigluQuant) +.TilingParse(TilingPrepareForGMMSwigluQuant); +} // namespace optiling diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h new file mode 100644 index 00000000000..ccc0d459cb8 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_host/grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_tiling.h + * \brief + */ +#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H +#define AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H + +#include +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(GMMSwigluBaseParams) + TILING_DATA_FIELD_DEF(uint32_t, groupNum); + TILING_DATA_FIELD_DEF(uint32_t, coreNum); + TILING_DATA_FIELD_DEF(uint32_t, K); + TILING_DATA_FIELD_DEF(uint32_t, N); + TILING_DATA_FIELD_DEF(uint32_t, M); + TILING_DATA_FIELD_DEF(uint32_t, mLimit); + TILING_DATA_FIELD_DEF(uint64_t, isPreFill); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(GMMSwigluBaseParamsOp, GMMSwigluBaseParams) + +BEGIN_TILING_DATA_DEF(GMMSwiglu) + TILING_DATA_FIELD_DEF(uint32_t, maxProcessRowNum); + TILING_DATA_FIELD_DEF(uint32_t, groupListLen); + TILING_DATA_FIELD_DEF(uint32_t, tokenLen); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(GMMSwigluOp, GMMSwiglu) + +BEGIN_TILING_DATA_DEF(GMMSwigluQuantTilingData) + TILING_DATA_FIELD_DEF_STRUCT(GMMSwigluBaseParams, gmmSwigluBaseParams); + TILING_DATA_FIELD_DEF_STRUCT(GMMSwiglu, gmmSwiglu); + TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, mmTilingData); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(GroupedMatmulSwigluQuantWeightNzTensorList, GMMSwigluQuantTilingData) +} + +namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling { +constexpr uint32_t X_INDEX = 0; +constexpr uint32_t WEIGHT_INDEX = 1; +constexpr uint32_t GROUPLIST_INDEX = 4; +constexpr uint32_t BATCH_MODE_SCHEDULE = 1; +constexpr uint32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024; +constexpr int64_t USER_WORKSPACE_LIMIT = 256 * 1024 * 1024; +constexpr int64_t PREFILL_USER_WORKSPACE_LIMIT = 64 * 1024 * 1024; +constexpr int64_t DOUBLE_WORKSPACE_SPLIT = 2; +constexpr int64_t INT32_DTYPE_SIZE = 4; +constexpr uint32_t PREFILL_M_MIN_SIZE = 16 * 1024; + +const std::set> PREFILL_WHITE_LIST = { // used for preFill case + {{2048, 1536}}, + {{4096, 3072}} +}; +} // namespace GroupedMatmulSwigluQuantWeightNzTensorListTiling + +#endif // AIR_CXX_RUNTIME_V2_OP_IMPL_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp new file mode 100644 index 00000000000..ba60d80ba98 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
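// Reading the constants above together with IsPreFill() in the tiling source
// (an illustrative restatement, not new logic): the smaller 64 MiB
// PREFILL_USER_WORKSPACE_LIMIT applies only when groupNum == 128,
// M >= PREFILL_M_MIN_SIZE (16 * 1024 tokens) and {K, N} matches one of the
// whitelisted pairs {2048, 1536} or {4096, 3072}; every other shape falls back
// to the 256 MiB USER_WORKSPACE_LIMIT.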
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list.cpp + * \brief + */ +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list.h" +#include +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h" +using namespace AscendC; +using namespace matmul; +using namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST; +using MM_DTYPE_Y = int32_t; + +template +using xType = MatmulType; + +template +using weightType = MatmulType; + +using yType = MatmulType; + +#define GMM_CV_SPLIT_IMP(computeClass, dtypeC, transA, transB, sync, cfg, aType, bType, cType) \ + do { \ + using matmulType = MMImplType, bType, cType, cType, cfg>; \ + matmulType::MT mm; \ + GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, gmmSwigluBaseParams, gmmSwigluBaseParams_, tiling); \ + GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, mmTilingData, mmTilingData_, tiling); \ + GET_TILING_DATA_MEMBER(GMMSwigluQuantTilingData, gmmSwiglu, gmmSwiglu_, tiling); \ + if ASCEND_IS_AIC { \ + mm.SetSubBlockIdx(0); \ + mm.Init(&mmTilingData_, &tPipe); \ + } \ + computeClass computeOp(mm); \ + computeOp.Init(x, weight, perChannelScale, perTokenScale, groupList, quantOutput, quantScaleOutput, \ + user1, &gmmSwigluBaseParams_, &mmTilingData_, &gmmSwiglu_, &tPipe); \ + computeOp.Process(); \ + } while (0) + +extern "C" __global__ __aicore__ void grouped_matmul_swiglu_quant_weight_nz_tensor_list(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput, + GM_ADDR workspace, GM_ADDR tiling) { + TPipe tPipe; + AscendCUtils::SetOverflow(1); + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); + GM_ADDR user1 = GetUserWorkspace(workspace); + if (TILING_KEY_IS(0)) { // antiquant msd + KERNEL_TASK_TYPE(0, KERNEL_TYPE_MIX_AIC_1_2); + GMM_CV_SPLIT_IMP( + GMMSwigluCompute, // computeClass + DTYPE_WEIGHT_SCALE, + false, // transA + false, // transB + false, // sync + NZ_CFG_MDL, // cfg + xType, // aType + weightType, // bType + yType); // cType + } else if(TILING_KEY_IS(1)){ + KERNEL_TASK_TYPE(1, KERNEL_TYPE_MIX_AIC_1_2); + GMM_CV_SPLIT_IMP( + GMMSwigluSplitWorkSpaceCompute, // computeClass + DTYPE_WEIGHT_SCALE, + false, // transA + false, // transB + false, // sync + NZ_CFG_MDL, // cfg + xType, // aType + weightType, // bType + yType); // cType + } +} diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h new file mode 100644 index 00000000000..45d95488f52 --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list.h @@ -0,0 +1,498 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list.h + * \brief + */ +#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H +#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_H + +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h" +namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST { +/** @brief intenal computation class +*/ +template +class GMMSwigluCompute{ + public: + using AT = typename mmType::AT::T; + using BT = typename mmType::BT::T; + using B = typename mmType::BT; + using CT = typename mmType::CT::T; + using BiasT = typename mmType::BiasT::T; + using WT = int8_t; + constexpr static bool transposeX = mmType::AT::isTrans; + constexpr static bool transposeW = mmType::BT::isTrans; + + /** @brief constructor */ + __aicore__ inline GMMSwigluCompute(typename mmType::MT& mm_): mm(mm_) {} + + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput, + GM_ADDR workspace, + const GMMSwigluBaseParams* __restrict gmmBaseParamsIN, + const TCubeTiling* __restrict mmTilingDataIN, + const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN); + __aicore__ inline void Process(); + private: + __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx); + + __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig); + + __aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline uint64_t GetWOffset(uint32_t tailN, uint32_t k); + + __aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN); + template + __aicore__ inline void UpdateChannelScale(uint32_t loopidx); + __aicore__ inline void VectorCompute(uint32_t loopidx); + template + __aicore__ inline void PreLoadTokenAndChannel(LocalTensor& channelScaleLocal); + __aicore__ inline void UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig); + __aicore__ inline void customDataCopyIn(uint32_t outLoopIdx); + __aicore__ inline void customDataCopyOut(); + __aicore__ inline void Dequant(uint32_t loopidx); + __aicore__ inline void Quant(uint32_t loopidx); + __aicore__ inline void Swiglu(uint32_t loopidx); + private: + typename mmType::MT& mm; + const GMMSwigluBaseParams* __restrict gmmBaseParams; + const GMMSwiglu* __restrict gmmSwiglu; + const TCubeTiling* __restrict mmTilingData; + uint32_t blockIdx; + VecConfig vecConfig; + TPipe* pipe; + GlobalTensor xGM, weightGM; + GlobalTensor perChannelScaleGM; + GlobalTensor perTokenScaleGM; + GlobalTensor groupListGM; + GlobalTensor quantOutputGM; + GlobalTensor quantScaleOutputGM; + GlobalTensor mmOutGM; + // define the que + TQue mmOutQueue; + TQue perChannelScaleInQueue; + TQue quantOutQueue; + TQue quantScaleOutQueue; + TBuf reduceWorkspace; + TBuf castWorkspace; + bool sequentialWrite = true; + uint32_t cubeNum; // Matmul completions on the kernel + uint32_t groupNum; // Matmul completions on the kernel + int32_t preOffset; + int64_t aicCoreNum; + int64_t aivCoreNum; + GM_ADDR xTensorPtr; + GM_ADDR weightTensorPtr; + GM_ADDR perChannelScalePtr; +}; + +template +__aicore__ inline void GMMSwigluCompute::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput, + 
GM_ADDR workspace, + const GMMSwigluBaseParams* __restrict gmmSwigluBaseParamsIn, + const TCubeTiling* __restrict mmTilingDataIN, + const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN) +{ + aicCoreNum = GetBlockNum(); + aivCoreNum = aicCoreNum * 2; + blockIdx = GetBlockIdx(); + mmTilingData = mmTilingDataIN; + gmmBaseParams = gmmSwigluBaseParamsIn; + gmmSwiglu = gmmSwigluIN; + pipe = tPipeIN; + xTensorPtr = x; + weightTensorPtr = weight; + perChannelScalePtr = perChannelScale; + groupNum = gmmSwiglu->groupListLen; + if ASCEND_IS_AIC { + groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen); + mmOutGM.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->M * gmmSwiglu->tokenLen); + } + if ASCEND_IS_AIV { + mmOutGM.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->M * gmmSwiglu->tokenLen); + perChannelScaleGM.SetGlobalBuffer((__gm__ CHANNELDTYPE *)perChannelScale, gmmSwiglu->groupListLen * gmmSwiglu->tokenLen); + perTokenScaleGM.SetGlobalBuffer((__gm__ float *)perTokenScale, gmmSwiglu->maxProcessRowNum); + groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen); + quantOutputGM.SetGlobalBuffer((__gm__ int8_t *)quantOutput, gmmBaseParams->M * gmmSwiglu->tokenLen / 2); + quantScaleOutputGM.SetGlobalBuffer((__gm__ float *)quantScaleOutput, gmmSwiglu->maxProcessRowNum); + } +} + +template +__aicore__ inline void GMMSwigluCompute::Process() { + MNConfig mnConfig; + if ASCEND_IS_AIC { + preOffset = 0; + int32_t prevSplitValue = 0; + for (uint32_t groupIdx = 0, count = 0; groupIdx < gmmSwiglu->groupListLen; ++groupIdx) { + UpdateMnConfig(mnConfig); + int32_t currSplitValue = static_cast(groupListGM.GetValue(groupIdx)); + int32_t splitValue = currSplitValue - prevSplitValue; + prevSplitValue = currSplitValue; + SetMNConfig(splitValue, groupIdx, mnConfig); + if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) { + continue; + } + mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM); + mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN); + + uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN; + uint32_t curBlock = blockIdx >= count ? blockIdx : blockIdx + gmmBaseParams->coreNum; + uint32_t thresholdM_dimN = THRESHOLD_BLOCK_NUM * mnConfig.blockDimN; + + while (curBlock < curCount) { + MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN); + MMCompute(groupIdx, mnConfig, blockIdx); + curBlock += aicCoreNum; + } + count = curCount % gmmBaseParams->coreNum; + } + SyncAll(); + } + + if ASCEND_IS_AIV { + UpdateVecConfig(blockIdx, vecConfig); + if (blockIdx < vecConfig.usedCoreNum) { + LocalTensor channelScaleLocal = perChannelScaleInQueue.AllocTensor(); + LocalTensor mmLocal = mmOutQueue.AllocTensor(); + LocalTensor quantLocal = quantOutQueue.AllocTensor(); + LocalTensor quantScaleLocal = quantScaleOutQueue.AllocTensor(); + mmOutQueue.EnQue(mmLocal); + quantScaleOutQueue.EnQue(quantScaleLocal); + quantOutQueue.EnQue(quantLocal); + PreLoadTokenAndChannel(channelScaleLocal); + } + SyncAll(); + if (blockIdx < vecConfig.usedCoreNum) { + for (uint32_t outLoopIdx = 0; outLoopIdx < vecConfig.outLoopNum; outLoopIdx++) { + vecConfig.innerLoopNum = outLoopIdx == (vecConfig.outLoopNum - 1) + ? 
vecConfig.tailLoopNum + : gmmSwiglu->maxProcessRowNum; + customDataCopyIn(outLoopIdx); + for (uint32_t innerLoopIdx = 0; innerLoopIdx < vecConfig.innerLoopNum; innerLoopIdx++) { + UpdateChannelScale(innerLoopIdx); + VectorCompute(innerLoopIdx); + } + customDataCopyOut(); + } + + LocalTensor channelScaleLocal = perChannelScaleInQueue.DeQue(); + LocalTensor mmLocal = mmOutQueue.DeQue(); + LocalTensor quantLocal = quantOutQueue.DeQue(); + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + perChannelScaleInQueue.FreeTensor(channelScaleLocal); + mmOutQueue.FreeTensor(mmLocal); + quantScaleOutQueue.FreeTensor(quantScaleLocal); + quantOutQueue.FreeTensor(quantLocal); + } else { + return; + } + } +} + +template +template +__aicore__ inline void GMMSwigluCompute::PreLoadTokenAndChannel(LocalTensor& channelScaleLocal) +{ + GlobalTensor perChannelScaleTensor; + perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr(vecConfig.curGroupIdx, perChannelScalePtr)); + + DataCopyExtParams copyChannelParams{1, static_cast(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0 ,0, 0}; + if constexpr(!IsSameType::value) { + LocalTensor dstLocalT = channelScaleLocal.template ReinterpretCast(); + DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyChannelParams, padParams); + PipeBarrier(); + Cast(channelScaleLocal, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen); + } else { + DataCopyPad(channelScaleLocal, perChannelScaleTensor, copyChannelParams, padParams); + } + perChannelScaleInQueue.EnQue(channelScaleLocal); +} + +template +__aicore__ inline void GMMSwigluCompute::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx) +{ + uint32_t tailN = mnConfig.nIdx * mnConfig.singleN; + uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN; + uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? 
mnConfig.singleM + : mnConfig.m - mnConfig.mIdx * mnConfig.singleM; + uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k; + if constexpr (transposeX) { + xOffset = mnConfig.mIdx * mnConfig.singleM; + } + uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN; + xGM.SetGlobalBuffer((__gm__ int8_t *)xTensorPtr + mnConfig.xBaseOffset); + weightGM.SetGlobalBuffer(GetTensorAddr(groupIdx, weightTensorPtr) + GetWOffset(tailN, mnConfig.k)); + if (mnConfig.blockDimM == 1){ + weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE); + } + mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset; + mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k); + mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k); + mm.SetTensorA(xGM[xOffset], transposeX); + mm.SetTensorB(weightGM, transposeW); + mm.template IterateAll(mmOutGM[mnConfig.workSpaceOffset], 0); +} + +template +__aicore__ inline void GMMSwigluCompute::UpdateMnConfig(MNConfig &mnConfig) { + if constexpr (B::format == CubeFormat::NZ) { + mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<32>(mnConfig.n); // 16: nz format last two dim size + } else { + mnConfig.wBaseOffset += mnConfig.k * mnConfig.n; + } + mnConfig.nAxisBaseOffset += mnConfig.n; + mnConfig.mAxisBaseOffset += mnConfig.m; + mnConfig.xBaseOffset += mnConfig.m * mnConfig.k; + mnConfig.yBaseOffset += mnConfig.m * mnConfig.n; +} + +template +__aicore__ inline void GMMSwigluCompute::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) { + SetMKN(splitValue, groupIdx, mnConfig); + mnConfig.baseM = BASIC_M; + mnConfig.baseN = BASIC_N; + mnConfig.singleM = SINGLE_CORE_M; + mnConfig.singleN = SINGLE_CORE_N; +} + +template +__aicore__ inline void GMMSwigluCompute::SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) +{ + mnConfig.m = static_cast(splitValue); + mnConfig.k = gmmBaseParams->K; // tilingData + mnConfig.n = gmmBaseParams->N; // tilingData +} + +template +__aicore__ inline uint64_t GMMSwigluCompute::GetWOffset(uint32_t tailN, uint32_t k) { + uint64_t wOffset = 0; + if constexpr (mmType::BT::format == CubeFormat::NZ) { + wOffset = tailN * AlignUp<16>(k); // 16: nz format last two dim size + } else { + wOffset = tailN; + } + return wOffset; +} + +template +__aicore__ inline void GMMSwigluCompute::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN) { + mnConfig.mIdx = (curBlock - count) / mnConfig.blockDimN; + mnConfig.nIdx = (curBlock - count) % mnConfig.blockDimN; +} + +template +__aicore__ inline void GMMSwigluCompute::UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig) +{ + // Step 1: Read grouplist reduceSum to calculate total data count + int64_t prevM = 0; + for (uint32_t groupIdx = 0; groupIdx < gmmSwiglu->groupListLen; groupIdx++){ + int64_t currM = groupListGM.GetValue(groupIdx); + int64_t tempM = currM - prevM; + prevM = currM; + vecConfig.M += tempM; + } + // Step 2: Calculate core allocation + uint32_t eachCoreTaskNum = (vecConfig.M + aivCoreNum - 1) / aivCoreNum; + vecConfig.usedCoreNum = vecConfig.M >= aivCoreNum ? aivCoreNum : vecConfig.M; + uint32_t tailCoreIdx = vecConfig.M - (eachCoreTaskNum - 1) * vecConfig.usedCoreNum; + vecConfig.taskNum = blockIdx < tailCoreIdx ? eachCoreTaskNum : eachCoreTaskNum - 1; + vecConfig.startIdx = blockIdx < tailCoreIdx + ? 
eachCoreTaskNum * blockIdx + :((eachCoreTaskNum - 1) * blockIdx + tailCoreIdx); + vecConfig.curIdx = vecConfig.startIdx; + vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen; + vecConfig.curOffset = vecConfig.startOffset; + int64_t curStartIdx = vecConfig.startIdx; + prevM = 0; + for (uint32_t groupIdx = 0; groupIdx < gmmSwiglu->groupListLen; groupIdx++){ + int64_t currM = groupListGM.GetValue(groupIdx); + int64_t tempM = currM - prevM; + prevM = currM; + if (curStartIdx >= 0 && curStartIdx - tempM < 0) { + vecConfig.curGroupIdx = groupIdx; + vecConfig.nextUpadteInterVal = tempM - curStartIdx; + } + curStartIdx -= tempM; + } + // Step 3: Calculate total data volume + vecConfig.outLoopNum = (vecConfig.taskNum + gmmSwiglu->maxProcessRowNum - 1) / gmmSwiglu->maxProcessRowNum; + vecConfig.tailLoopNum = vecConfig.taskNum % gmmSwiglu->maxProcessRowNum + ? vecConfig.taskNum % gmmSwiglu->maxProcessRowNum + : gmmSwiglu->maxProcessRowNum; + pipe->Reset(); + // Step 4: Allocate space + pipe->InitBuffer(mmOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen * sizeof(int32_t)); + pipe->InitBuffer(perChannelScaleInQueue, 1, gmmSwiglu->tokenLen * sizeof(float)); + pipe->InitBuffer(quantOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)); + pipe->InitBuffer(quantScaleOutQueue, 1, AlignUp(gmmSwiglu->maxProcessRowNum, 8) * sizeof(float)); + pipe->InitBuffer(reduceWorkspace, 1024 * sizeof(float)); + pipe->InitBuffer(castWorkspace, 32 * sizeof(int8_t)); +} + +template +__aicore__ inline void GMMSwigluCompute::customDataCopyIn(uint32_t outLoopIdx) +{ + LocalTensor _inMMLocal_0 = mmOutQueue.DeQue(); + DataCopyExtParams copyParams_0{1, static_cast(vecConfig.innerLoopNum * gmmSwiglu->tokenLen * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams_0{false, 0 ,0, 0}; + DataCopyPad(_inMMLocal_0, mmOutGM[vecConfig.curOffset], copyParams_0, padParams_0); + + mmOutQueue.EnQue(_inMMLocal_0); + + LocalTensor _inMMLocal_1 = mmOutQueue.DeQue(); + + Cast(_inMMLocal_1.ReinterpretCast(), _inMMLocal_1, RoundMode::CAST_NONE, vecConfig.innerLoopNum * gmmSwiglu->tokenLen); + + mmOutQueue.EnQue(_inMMLocal_1); + LocalTensor _inMMLocal_2 = mmOutQueue.DeQue(); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + for (uint32_t i = 0; i < vecConfig.innerLoopNum; i++){ + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + float scale = perTokenScaleGM.GetValue(vecConfig.curIdx); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + Muls(_inMMLocal_2[i * gmmSwiglu->tokenLen], _inMMLocal_2[i * gmmSwiglu->tokenLen], scale, gmmSwiglu->tokenLen); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + vecConfig.curIdx++; + } + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + vecConfig.curOffset = vecConfig.curIdx * gmmSwiglu->tokenLen; + mmOutQueue.EnQue(_inMMLocal_2); +} + +template +template +__aicore__ inline void GMMSwigluCompute::UpdateChannelScale(uint32_t loopIdx){ + // Update perChannel + if (unlikely(vecConfig.nextUpadteInterVal == 0)) { + int64_t loop = gmmSwiglu->groupListLen - vecConfig.curGroupIdx; + while (loop--) { + int64_t curTemp = groupListGM.GetValue(vecConfig.curGroupIdx); + vecConfig.curGroupIdx++; + int64_t nextTemp = groupListGM.GetValue(vecConfig.curGroupIdx); + if(nextTemp != curTemp){ + vecConfig.nextUpadteInterVal = nextTemp - curTemp; + break; + } + } + LocalTensor _inChannel = perChannelScaleInQueue.DeQue(); + DataCopyExtParams copyParams{1, static_cast(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0 ,0, 0}; + + GlobalTensor 
perChannelScaleTensor; + perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr(vecConfig.curGroupIdx, perChannelScalePtr)); + + if constexpr(!IsSameType::value) { + LocalTensor dstLocalT = _inChannel.template ReinterpretCast(); + DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyParams, padParams); + PipeBarrier(); + Cast(_inChannel, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen); + } else { + DataCopyPad(_inChannel, perChannelScaleTensor, copyParams, padParams); + } + PipeBarrier(); + perChannelScaleInQueue.EnQue(_inChannel); + } +} + +template +__aicore__ inline void GMMSwigluCompute::VectorCompute(uint32_t loopIdx) { + Dequant(loopIdx); + Swiglu(loopIdx); + Quant(loopIdx); +} + +template +__aicore__ inline void GMMSwigluCompute::Dequant(uint32_t loopIdx) { + // perChanelScale * perTokenScale + LocalTensor mmLocal = mmOutQueue.DeQue(); + LocalTensor perChannelLocal = perChannelScaleInQueue.DeQue(); + Mul(mmLocal[loopIdx * gmmSwiglu->tokenLen], mmLocal[loopIdx * gmmSwiglu->tokenLen], perChannelLocal, gmmSwiglu->tokenLen); + vecConfig.nextUpadteInterVal--; + mmOutQueue.EnQue(mmLocal); + perChannelScaleInQueue.EnQue(perChannelLocal); +} + +template +__aicore__ inline void GMMSwigluCompute::Swiglu(uint32_t loopIdx) { + // High-level API swiglu + LocalTensor _inMMLocal = mmOutQueue.DeQue(); + float beta = 1.0f; + LocalTensor workspaceLocal= reduceWorkspace.Get(); + LocalTensor src0Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / 2]; + LocalTensor src1Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen]; + SwiGLU(workspaceLocal, src0Local, src1Local, beta, gmmSwiglu->tokenLen / 2); + PipeBarrier(); + DataCopyParams repeatParams{1, static_cast((gmmSwiglu->tokenLen / 2) / 8), 0, 0}; + DataCopy(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], workspaceLocal, repeatParams); + mmOutQueue.EnQue(_inMMLocal); +} + +template +__aicore__ inline void GMMSwigluCompute::Quant(uint32_t loopIdx) { + LocalTensor _inMMLocal = mmOutQueue.DeQue(); + Abs(_inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT], + _inMMLocal[loopIdx * gmmSwiglu->tokenLen], + gmmSwiglu->tokenLen / BISECT); + LocalTensor workspaceLocal= reduceWorkspace.Get(); + PipeBarrier(); + ReduceMaxTemplate(workspaceLocal, + _inMMLocal, loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT, gmmSwiglu->tokenLen / BISECT); + PipeBarrier(); + float quantScale = workspaceLocal.GetValue(0) / QUANT_SCALE_INT8; + PipeBarrier(); + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + PipeBarrier(); + quantScaleLocal.SetValue(loopIdx, quantScale); + PipeBarrier(); + quantScale = 1 / quantScale; + PipeBarrier(); + Muls(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], _inMMLocal[loopIdx * gmmSwiglu->tokenLen], + quantScale, gmmSwiglu->tokenLen / BISECT); + PipeBarrier(); + LocalTensor quantLocal = quantOutQueue.DeQue(); + int32_t dstTempOffset = static_cast(loopIdx * gmmSwiglu->tokenLen / BISECT); + int32_t srcTempOffset = static_cast(loopIdx * gmmSwiglu->tokenLen); + int32_t tempCount = static_cast(gmmSwiglu->tokenLen / BISECT); + LocalTensor castSpace = castWorkspace.Get(); + CastFp32ToInt8Template(quantLocal, _inMMLocal, castSpace, dstTempOffset, srcTempOffset, tempCount); + mmOutQueue.EnQue(_inMMLocal); + quantOutQueue.EnQue(quantLocal); +} + +template +__aicore__ inline void GMMSwigluCompute::customDataCopyOut() { + // perChanelScale * perTokenScale + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + DataCopyParams copyParams_0{1, 
(uint16_t)(vecConfig.innerLoopNum * sizeof(float)), 0, 0}; + PipeBarrier(); + DataCopyPad(quantScaleOutputGM[vecConfig.startIdx], quantScaleLocal, copyParams_0); + LocalTensor quantLocal = quantOutQueue.DeQue(); + DataCopyParams copyParams_1{1, (uint16_t)(vecConfig.innerLoopNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)), 0, 0}; + PipeBarrier(); + DataCopyPad(quantOutputGM[vecConfig.startIdx * gmmSwiglu->tokenLen / 2], quantLocal, copyParams_1); + PipeBarrier(); + vecConfig.startIdx += vecConfig.innerLoopNum; + vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen; + quantOutQueue.EnQue(quantLocal); + quantScaleOutQueue.EnQue(quantScaleLocal); +} + +} // namespace GROUPED_MATMUL +#endif // ASCENDC_GROUPED_MATMUL_QUANT_MIXCORE_H diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h new file mode 100644 index 00000000000..a5c8571d95d --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h @@ -0,0 +1,588 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_split_ws.h + * \brief + */ +#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H +#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H + +#include "grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h" +namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST { +/** @brief internal computation class +*/ + +template +class GMMSwigluSplitWorkSpaceCompute{ + public: + using AT = typename mmType::AT::T; + using BT = typename mmType::BT::T; + using B = typename mmType::BT; + using CT = typename mmType::CT::T; + using BiasT = typename mmType::BiasT::T; + using WT = int8_t; + constexpr static bool transposeX = mmType::AT::isTrans; + constexpr static bool transposeW = mmType::BT::isTrans; + + /** @brief constructor */ + __aicore__ inline GMMSwigluSplitWorkSpaceCompute(typename mmType::MT& mm_): mm(mm_) {} + + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, GM_ADDR quantScaleOutput, + GM_ADDR workspace, + const GMMSwigluBaseParams* __restrict gmmBaseParamsIN, + const TCubeTiling* __restrict mmTilingDataIN, + const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN); + __aicore__ inline void Process(); + + private: + __aicore__ inline void MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx, GlobalTensor &mmOutGM); + + __aicore__ inline void UpdateMnConfig(MNConfig &mnConfig); + + __aicore__ inline void SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline void SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig); + + __aicore__ inline uint64_t GetWOffset(uint32_t tailN, uint32_t k); + + __aicore__ inline void MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN); + + template + __aicore__ inline void UpdateChannelScale(uint32_t loopidx, VecConfig& vecConfig); + + __aicore__ inline void VectorCompute(uint32_t loopidx, VecConfig& vecConfig); + + template + __aicore__ inline void PreLoadTokenAndChannel(LocalTensor& channelScaleLocal, VecConfig& vecConfig); + + __aicore__ inline void UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig); + + __aicore__ inline void UpdateWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig, int32_t workspaceSplitLoopIdx); + + __aicore__ inline void InitWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig); + + __aicore__ inline void customDataCopyIn(uint32_t outLoopIdx, GlobalTensor &mmOutGM, VecConfig& vecConfig); + + __aicore__ inline void customDataCopyOut(VecConfig& vecConfig); + + __aicore__ inline void Dequant(uint32_t loopidx, VecConfig& vecConfig); + + __aicore__ inline void Quant(uint32_t loopidx, VecConfig& vecConfig); + + __aicore__ inline void Swiglu(uint32_t loopidx, VecConfig& vecConfig); + + private: + typename mmType::MT& mm; + const GMMSwigluBaseParams* __restrict gmmBaseParams; + const GMMSwiglu* __restrict gmmSwiglu; + const TCubeTiling* __restrict mmTilingData; + uint32_t blockIdx; + WorkSpaceSplitConfig workspaceSplitConfig; + TPipe* pipe; + GlobalTensor xGM; + GlobalTensor weightGM; + GlobalTensor perChannelScaleGM; + GlobalTensor perTokenScaleGM; + GlobalTensor groupListGM; + GlobalTensor quantOutputGM; + GlobalTensor quantScaleOutputGM; + GlobalTensor mmOutGM1; + GlobalTensor mmOutGM2; + // define the que + TQue mmOutQueue; + TQue perChannelScaleInQueue; + TQue 
quantOutQueue; + TQue quantScaleOutQueue; + TBuf reduceWorkspace; + TBuf castWorkspace; + bool sequentialWrite = true; + uint32_t cubeNum; // Matmul completions on the kernel + uint32_t groupNum; // Matmul completions on the kernel + int64_t aicCoreNum; + int64_t aivCoreNum; + GM_ADDR xTensorPtr; + GM_ADDR weightTensorPtr; + GM_ADDR perChannelScalePtr; +}; + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Init(GM_ADDR x, GM_ADDR weight, + GM_ADDR perChannelScale, GM_ADDR perTokenScale, + GM_ADDR groupList, GM_ADDR quantOutput, + GM_ADDR quantScaleOutput, GM_ADDR workspace, + const GMMSwigluBaseParams* __restrict gmmSwigluBaseParamsIn, + const TCubeTiling* __restrict mmTilingDataIN, + const GMMSwiglu* __restrict gmmSwigluIN, TPipe* tPipeIN) +{ + aicCoreNum = GetBlockNum(); + aivCoreNum = aicCoreNum * 2; + blockIdx = GetBlockIdx(); + pipe = tPipeIN; + xTensorPtr = x; + weightTensorPtr = weight; + perChannelScalePtr = perChannelScale; + mmTilingData = mmTilingDataIN; + gmmBaseParams = gmmSwigluBaseParamsIn; + gmmSwiglu = gmmSwigluIN; + groupNum = gmmSwiglu->groupListLen; + if ASCEND_IS_AIC { + groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen); + mmOutGM1.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->mLimit * gmmSwiglu->tokenLen); + mmOutGM2.SetGlobalBuffer((__gm__ int32_t *)workspace + gmmBaseParams->mLimit * gmmSwiglu->tokenLen, + gmmBaseParams->mLimit * gmmSwiglu->tokenLen); + } + if ASCEND_IS_AIV { + mmOutGM1.SetGlobalBuffer((__gm__ int32_t *)workspace, gmmBaseParams->mLimit * gmmSwiglu->tokenLen); + mmOutGM2.SetGlobalBuffer((__gm__ int32_t *)workspace + gmmBaseParams->mLimit * gmmSwiglu->tokenLen, + gmmBaseParams->mLimit * gmmSwiglu->tokenLen); + perChannelScaleGM.SetGlobalBuffer((__gm__ CHANNELDTYPE *)perChannelScale, + gmmSwiglu->groupListLen * gmmSwiglu->tokenLen); + perTokenScaleGM.SetGlobalBuffer((__gm__ float *)perTokenScale, gmmBaseParams->M); + groupListGM.SetGlobalBuffer((__gm__ int64_t *)groupList, gmmSwiglu->groupListLen); + quantOutputGM.SetGlobalBuffer((__gm__ int8_t *)quantOutput, gmmBaseParams->M * gmmSwiglu->tokenLen / 2); + quantScaleOutputGM.SetGlobalBuffer((__gm__ float *)quantScaleOutput, gmmBaseParams->M); + } +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::InitWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig) +{ + workspaceSplitConfig.M = groupListGM.GetValue(gmmSwiglu->groupListLen - 1); + workspaceSplitConfig.loopCount = Ceil(workspaceSplitConfig.M, gmmBaseParams->mLimit); + workspaceSplitConfig.notLastTaskSize = gmmBaseParams->mLimit; + workspaceSplitConfig.lastLoopTaskSize = workspaceSplitConfig.M - (workspaceSplitConfig.loopCount - 1) * gmmBaseParams->mLimit; + workspaceSplitConfig.leftMatrixStartIndex = 0; + workspaceSplitConfig.rightMatrixExpertStartIndex = 0; + workspaceSplitConfig.rightMatrixExpertNextStartIndex = 0; + workspaceSplitConfig.isLastLoop = false; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::UpdateWorkSpaceSplitConfig(WorkSpaceSplitConfig &workspaceSplitConfig, int32_t workspaceSplitLoopIdx) +{ + workspaceSplitConfig.leftMatrixStartIndex = workspaceSplitLoopIdx * gmmBaseParams->mLimit; + workspaceSplitConfig.rightMatrixExpertStartIndex = workspaceSplitConfig.rightMatrixExpertNextStartIndex; + workspaceSplitConfig.rightMatrixExpertEndIndex = workspaceSplitConfig.rightMatrixExpertStartIndex; + // Calculate the right expert matrix end index (rightMatrixExpertEndIndex) and the next start index 
(rightMatrixExpertNextStartIndex) + int32_t curTaskNum = 0; + int32_t nextTaskNum = 0; + while(workspaceSplitConfig.rightMatrixExpertEndIndex < gmmSwiglu->groupListLen) + { + curTaskNum = groupListGM.GetValue(workspaceSplitConfig.rightMatrixExpertEndIndex) - workspaceSplitConfig.leftMatrixStartIndex; + int32_t nextTaskIdx = workspaceSplitConfig.rightMatrixExpertEndIndex >= gmmSwiglu->groupListLen - 1 \ + ? gmmSwiglu->groupListLen - 1 \ + : workspaceSplitConfig.rightMatrixExpertEndIndex + 1; + nextTaskNum = groupListGM.GetValue(nextTaskIdx) - workspaceSplitConfig.leftMatrixStartIndex; + if (curTaskNum > gmmBaseParams->mLimit){ + workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex; + break; + } else if (curTaskNum == gmmBaseParams->mLimit && nextTaskNum > gmmBaseParams->mLimit){ + workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex + 1; + break; + } else if (nextTaskNum > gmmBaseParams->mLimit){ + workspaceSplitConfig.rightMatrixExpertEndIndex++; + workspaceSplitConfig.rightMatrixExpertNextStartIndex = workspaceSplitConfig.rightMatrixExpertEndIndex; + break; + } + workspaceSplitConfig.rightMatrixExpertEndIndex++; + } + workspaceSplitConfig.isLastLoop = workspaceSplitLoopIdx == workspaceSplitConfig.loopCount - 1 ? true : false; + + if (workspaceSplitConfig.isLastLoop) { + workspaceSplitConfig.rightMatrixExpertEndIndex = workspaceSplitConfig.rightMatrixExpertEndIndex >= gmmSwiglu->groupListLen \ + ? gmmSwiglu->groupListLen - 1 \ + : workspaceSplitConfig.rightMatrixExpertEndIndex; + } +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Process() { + InitWorkSpaceSplitConfig(workspaceSplitConfig); + int32_t parallelNum = gmmBaseParams->isPreFill ? 2 : 1; // 2: double workspace buffer + for (int32_t workspaceSplitLoopIdx = 0; workspaceSplitLoopIdx < workspaceSplitConfig.loopCount; workspaceSplitLoopIdx++) { + UpdateWorkSpaceSplitConfig(workspaceSplitConfig, workspaceSplitLoopIdx); + GlobalTensor mmOutGM = (workspaceSplitLoopIdx % 2 == 0 ) ? mmOutGM1 : mmOutGM2; + + if ASCEND_IS_AIC { + if (workspaceSplitLoopIdx >= parallelNum){ // first parallelNum core no need to wait + SyncAll(); + } + MNConfig mnConfig; + int32_t prevSplitValue = workspaceSplitConfig.leftMatrixStartIndex; + for (uint32_t groupIdx = workspaceSplitConfig.rightMatrixExpertStartIndex, count = 0; groupIdx <= workspaceSplitConfig.rightMatrixExpertEndIndex; ++groupIdx) { + UpdateMnConfig(mnConfig); + int32_t currSplitValue = static_cast(groupListGM.GetValue(groupIdx)); + currSplitValue = currSplitValue > (workspaceSplitLoopIdx + 1) * gmmBaseParams->mLimit \ + ? (workspaceSplitLoopIdx + 1) * gmmBaseParams->mLimit \ + : currSplitValue; + int32_t splitValue = currSplitValue - prevSplitValue; + prevSplitValue = currSplitValue; + SetMNConfig(splitValue, groupIdx, mnConfig); + if (mnConfig.m <= 0 || mnConfig.k <= 0 || mnConfig.n <= 0) { + continue; + } + mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM); + mnConfig.blockDimN = Ceil(mnConfig.n, mnConfig.singleN); + + uint32_t curCount = count + mnConfig.blockDimM * mnConfig.blockDimN; + uint32_t curBlock = blockIdx >= count ? 
blockIdx : blockIdx + gmmBaseParams->coreNum; + uint32_t thresholdM_dimN = THRESHOLD_BLOCK_NUM * mnConfig.blockDimN; + + while (curBlock < curCount) { + MNBlockIdxCompute(mnConfig, curBlock, count, thresholdM_dimN); + MMCompute(groupIdx, mnConfig, blockIdx, mmOutGM); + curBlock += aicCoreNum; + } + count = curCount % gmmBaseParams->coreNum; + } + SyncAll(); + } + + if ASCEND_IS_AIV { + VecConfig vecConfig; + UpdateVecConfig(blockIdx, vecConfig); + if (blockIdx < vecConfig.usedCoreNum) { + LocalTensor channelScaleLocal = perChannelScaleInQueue.AllocTensor(); + LocalTensor mmLocal = mmOutQueue.AllocTensor(); + LocalTensor quantLocal = quantOutQueue.AllocTensor(); + LocalTensor quantScaleLocal = quantScaleOutQueue.AllocTensor(); + mmOutQueue.EnQue(mmLocal); + quantScaleOutQueue.EnQue(quantScaleLocal); + quantOutQueue.EnQue(quantLocal); + PreLoadTokenAndChannel(channelScaleLocal, vecConfig); + } + SyncAll(); + if (blockIdx < vecConfig.usedCoreNum) { + for (uint32_t outLoopIdx = 0; outLoopIdx < vecConfig.outLoopNum; outLoopIdx++) { + vecConfig.innerLoopNum = outLoopIdx == (vecConfig.outLoopNum - 1) + ? vecConfig.tailLoopNum + : gmmSwiglu->maxProcessRowNum; + PipeBarrier(); + customDataCopyIn(outLoopIdx, mmOutGM, vecConfig); + PipeBarrier(); + for (uint32_t innerLoopIdx = 0; innerLoopIdx < vecConfig.innerLoopNum; innerLoopIdx++) { + UpdateChannelScale(innerLoopIdx, vecConfig); + VectorCompute(innerLoopIdx, vecConfig); + } + PipeBarrier(); + customDataCopyOut(vecConfig); + PipeBarrier(); + } + + LocalTensor channelScaleLocal = perChannelScaleInQueue.DeQue(); + LocalTensor mmLocal = mmOutQueue.DeQue(); + LocalTensor quantLocal = quantOutQueue.DeQue(); + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + perChannelScaleInQueue.FreeTensor(channelScaleLocal); + mmOutQueue.FreeTensor(mmLocal); + quantScaleOutQueue.FreeTensor(quantScaleLocal); + quantOutQueue.FreeTensor(quantLocal); + } + if (workspaceSplitLoopIdx < workspaceSplitConfig.loopCount - parallelNum){ + SyncAll(); + } + } + } +} + +template +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::PreLoadTokenAndChannel(LocalTensor& channelScaleLocal, VecConfig& vecConfig) +{ + GlobalTensor perChannelScaleTensor; + perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr(vecConfig.curGroupIdx, perChannelScalePtr)); + + DataCopyExtParams copyChannelParams{1, static_cast(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0 ,0, 0}; + if constexpr(!IsSameType::value) { + LocalTensor dstLocalT = channelScaleLocal.template ReinterpretCast(); + DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyChannelParams, padParams); + PipeBarrier(); + Cast(channelScaleLocal, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen); + } else { + DataCopyPad(channelScaleLocal, perChannelScaleTensor, copyChannelParams, padParams); + } + perChannelScaleInQueue.EnQue(channelScaleLocal); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::MMCompute(uint32_t groupIdx, MNConfig& mnConfig, uint32_t coreIdx, GlobalTensor &mmOutGM) +{ + uint32_t tailN = mnConfig.nIdx * mnConfig.singleN; + uint32_t curSingleN = mnConfig.nIdx < mnConfig.blockDimN - 1 ? mnConfig.singleN : mnConfig.n - tailN; + uint32_t curSingleM = mnConfig.mIdx < mnConfig.blockDimM - 1 ? 
mnConfig.singleM + : mnConfig.m - mnConfig.mIdx * mnConfig.singleM; + uint64_t xOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.k; + if constexpr (transposeX) { + xOffset = mnConfig.mIdx * mnConfig.singleM; + } + uint64_t outOffset = mnConfig.mIdx * mnConfig.singleM * mnConfig.n + tailN; + xGM.SetGlobalBuffer((__gm__ int8_t *)xTensorPtr + mnConfig.xBaseOffset + workspaceSplitConfig.leftMatrixStartIndex * mnConfig.k); + weightGM.SetGlobalBuffer(GetTensorAddr(groupIdx, weightTensorPtr) + GetWOffset(tailN, mnConfig.k)); + if (mnConfig.blockDimM == 1){ + weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE); + } else { + weightGM.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); + } + mnConfig.workSpaceOffset = outOffset + mnConfig.yBaseOffset; + mm.SetOrgShape(mnConfig.m, mnConfig.n, mnConfig.k); + mm.SetSingleShape(curSingleM, curSingleN, mnConfig.k); + mm.SetTensorA(xGM[xOffset], transposeX); + mm.SetTensorB(weightGM, transposeW); + mm.template IterateAll(mmOutGM[mnConfig.workSpaceOffset], 0); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::UpdateMnConfig(MNConfig &mnConfig) { + if constexpr (B::format == CubeFormat::NZ) { + mnConfig.wBaseOffset += AlignUp<16>(mnConfig.k) * AlignUp<32>(mnConfig.n); // 16: nz format last two dim size + } else { + mnConfig.wBaseOffset += mnConfig.k * mnConfig.n; + } + mnConfig.nAxisBaseOffset += mnConfig.n; + mnConfig.mAxisBaseOffset += mnConfig.m; + mnConfig.xBaseOffset += mnConfig.m * mnConfig.k; + mnConfig.yBaseOffset += mnConfig.m * mnConfig.n; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::SetMNConfig(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) { + SetMKN(splitValue, groupIdx, mnConfig); + mnConfig.baseM = BASIC_M; + mnConfig.baseN = BASIC_N; + mnConfig.singleM = SINGLE_CORE_M; + mnConfig.singleN = SINGLE_CORE_N; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::SetMKN(const int32_t splitValue, const uint32_t groupIdx, MNConfig &mnConfig) +{ + mnConfig.m = static_cast(splitValue); + mnConfig.k = gmmBaseParams->K; // tilingData + mnConfig.n = gmmBaseParams->N; // tilingData +} + +template +__aicore__ inline uint64_t GMMSwigluSplitWorkSpaceCompute::GetWOffset(uint32_t tailN, uint32_t k) { + uint64_t wOffset = 0; + if constexpr (mmType::BT::format == CubeFormat::NZ) { + wOffset = tailN * AlignUp<16>(k); // 16: nz format last two dim size + } else { + wOffset = tailN; + } + return wOffset; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::MNBlockIdxCompute(MNConfig &mnConfig, const uint32_t curBlock, + const uint32_t count, const uint32_t thresholdM_dimN) { + mnConfig.mIdx = (curBlock - count) / mnConfig.blockDimN; + mnConfig.nIdx = (curBlock - count) % mnConfig.blockDimN; +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::UpdateVecConfig(uint32_t blockIdx, VecConfig& vecConfig) +{ + // Step 1: Read grouplist reduceSum to calculate total data count + vecConfig.M = workspaceSplitConfig.isLastLoop \ + ? workspaceSplitConfig.lastLoopTaskSize\ + : workspaceSplitConfig.notLastTaskSize; + // Step 2: Calculate core allocation + uint32_t eachCoreTaskNum = (vecConfig.M + aivCoreNum - 1) / aivCoreNum; + vecConfig.usedCoreNum = vecConfig.M >= aivCoreNum ? aivCoreNum : vecConfig.M; + uint32_t tailCoreIdx = vecConfig.M - (eachCoreTaskNum - 1) * vecConfig.usedCoreNum; + vecConfig.taskNum = blockIdx < tailCoreIdx ? eachCoreTaskNum : eachCoreTaskNum - 1; + vecConfig.startIdx = blockIdx < tailCoreIdx + ? 
eachCoreTaskNum * blockIdx + :((eachCoreTaskNum - 1) * blockIdx + tailCoreIdx); + vecConfig.curIdx = vecConfig.startIdx; + vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen; + vecConfig.curOffset = vecConfig.startOffset; + int64_t curStartIdx = vecConfig.startIdx; + int64_t prevM = workspaceSplitConfig.leftMatrixStartIndex; + for (uint32_t groupIdx = workspaceSplitConfig.rightMatrixExpertStartIndex; groupIdx <= workspaceSplitConfig.rightMatrixExpertEndIndex; groupIdx++){ + int64_t currM = groupListGM.GetValue(groupIdx); + int64_t tempM = currM - prevM; + prevM = currM; + if (curStartIdx >= 0 && curStartIdx - tempM < 0) { + vecConfig.curGroupIdx = groupIdx; + vecConfig.nextUpadteInterVal = tempM - curStartIdx; + } + curStartIdx -= tempM; + } + // Step 3: Calculate total data volume + vecConfig.outLoopNum = (vecConfig.taskNum + gmmSwiglu->maxProcessRowNum - 1) / gmmSwiglu->maxProcessRowNum; + vecConfig.tailLoopNum = vecConfig.taskNum % gmmSwiglu->maxProcessRowNum + ? vecConfig.taskNum % gmmSwiglu->maxProcessRowNum + : gmmSwiglu->maxProcessRowNum; + pipe->Reset(); + // Step 4: Allocate space + pipe->InitBuffer(mmOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen * sizeof(int32_t)); + pipe->InitBuffer(perChannelScaleInQueue, 1, gmmSwiglu->tokenLen * sizeof(float)); + pipe->InitBuffer(quantOutQueue, 1, gmmSwiglu->maxProcessRowNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)); + pipe->InitBuffer(quantScaleOutQueue, 1, AlignUp(gmmSwiglu->maxProcessRowNum, 8) * sizeof(float)); + pipe->InitBuffer(reduceWorkspace, 1024 * sizeof(float)); + pipe->InitBuffer(castWorkspace, 32 * sizeof(int8_t)); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::customDataCopyIn(uint32_t outLoopIdx, GlobalTensor &mmOutGM, VecConfig& vecConfig) +{ + LocalTensor _inMMLocal_0 = mmOutQueue.DeQue(); + DataCopyExtParams copyParams_0{1, static_cast(vecConfig.innerLoopNum * gmmSwiglu->tokenLen * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams_0{false, 0 ,0, 0}; + PipeBarrier(); + DataCopyPad(_inMMLocal_0, mmOutGM[vecConfig.curOffset], copyParams_0, padParams_0); + mmOutQueue.EnQue(_inMMLocal_0); + + LocalTensor _inMMLocal_1 = mmOutQueue.DeQue(); + + Cast(_inMMLocal_1.ReinterpretCast(), _inMMLocal_1, RoundMode::CAST_NONE, vecConfig.innerLoopNum * gmmSwiglu->tokenLen); + + mmOutQueue.EnQue(_inMMLocal_1); + LocalTensor _inMMLocal_2 = mmOutQueue.DeQue(); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + for (uint32_t i = 0; i < vecConfig.innerLoopNum; i++){ + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + float scale = perTokenScaleGM.GetValue(vecConfig.curIdx + workspaceSplitConfig.leftMatrixStartIndex); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + Muls(_inMMLocal_2[i * gmmSwiglu->tokenLen], _inMMLocal_2[i * gmmSwiglu->tokenLen], scale, gmmSwiglu->tokenLen); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + vecConfig.curIdx++; + } + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + vecConfig.curOffset = vecConfig.curIdx * gmmSwiglu->tokenLen; + mmOutQueue.EnQue(_inMMLocal_2); +} + +template +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::UpdateChannelScale(uint32_t loopIdx, VecConfig& vecConfig){ + // Update perChannel + if (unlikely(vecConfig.nextUpadteInterVal == 0)) { + int64_t loop = gmmSwiglu->groupListLen - vecConfig.curGroupIdx; + while (loop--) { + int64_t curTemp = groupListGM.GetValue(vecConfig.curGroupIdx); + vecConfig.curGroupIdx++; + int64_t nextTemp = groupListGM.GetValue(vecConfig.curGroupIdx); + if(nextTemp != curTemp){ + 
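+                    // next non-empty group found: the difference of the cumulative group_list
+                    // entries is the number of rows left before the per-channel scale is reloaded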
vecConfig.nextUpadteInterVal = nextTemp - curTemp; + break; + } + } + LocalTensor _inChannel = perChannelScaleInQueue.DeQue(); + DataCopyExtParams copyParams{1, static_cast(gmmSwiglu->tokenLen * sizeof(DTYPE_CS)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0 ,0, 0}; + + GlobalTensor perChannelScaleTensor; + perChannelScaleTensor.SetGlobalBuffer(GetTensorAddr(vecConfig.curGroupIdx, perChannelScalePtr)); + + if constexpr(!IsSameType::value) { + LocalTensor dstLocalT = _inChannel.template ReinterpretCast(); + DataCopyPad(dstLocalT[gmmSwiglu->tokenLen], perChannelScaleTensor, copyParams, padParams); + PipeBarrier(); + Cast(_inChannel, dstLocalT[gmmSwiglu->tokenLen], RoundMode::CAST_NONE, gmmSwiglu->tokenLen); + } else { + DataCopyPad(_inChannel, perChannelScaleTensor, copyParams, padParams); + } + PipeBarrier(); + + perChannelScaleInQueue.EnQue(_inChannel); + } +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::VectorCompute(uint32_t loopIdx, VecConfig& vecConfig) { + Dequant(loopIdx, vecConfig); + Swiglu(loopIdx, vecConfig); + Quant(loopIdx, vecConfig); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Dequant(uint32_t loopIdx, VecConfig& vecConfig) { + // perChanelScale * perTokenScale + LocalTensor mmLocal = mmOutQueue.DeQue(); + LocalTensor perChannelLocal = perChannelScaleInQueue.DeQue(); + Mul(mmLocal[loopIdx * gmmSwiglu->tokenLen], mmLocal[loopIdx * gmmSwiglu->tokenLen], perChannelLocal, gmmSwiglu->tokenLen); + vecConfig.nextUpadteInterVal--; + mmOutQueue.EnQue(mmLocal); + perChannelScaleInQueue.EnQue(perChannelLocal); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Swiglu(uint32_t loopIdx, VecConfig& vecConfig) { + // High-level API swiglu + LocalTensor _inMMLocal = mmOutQueue.DeQue(); + float beta = 1.0f; + LocalTensor workspaceLocal= reduceWorkspace.Get(); + LocalTensor src0Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / 2]; + LocalTensor src1Local = _inMMLocal[loopIdx * gmmSwiglu->tokenLen]; + SwiGLU(workspaceLocal, src0Local, src1Local, beta, gmmSwiglu->tokenLen / 2); + PipeBarrier(); + DataCopyParams repeatParams{1, static_cast((gmmSwiglu->tokenLen / 2) / 8), 0, 0}; + DataCopy(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], workspaceLocal, repeatParams); + mmOutQueue.EnQue(_inMMLocal); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::Quant(uint32_t loopIdx, VecConfig& vecConfig) { + LocalTensor _inMMLocal = mmOutQueue.DeQue(); + Abs(_inMMLocal[loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT], + _inMMLocal[loopIdx * gmmSwiglu->tokenLen], + gmmSwiglu->tokenLen / BISECT); + LocalTensor workspaceLocal= reduceWorkspace.Get(); + PipeBarrier(); + ReduceMaxTemplate(workspaceLocal, + _inMMLocal, loopIdx * gmmSwiglu->tokenLen + gmmSwiglu->tokenLen / BISECT, gmmSwiglu->tokenLen / BISECT); + PipeBarrier(); + float quantScale = workspaceLocal.GetValue(0) / QUANT_SCALE_INT8; + PipeBarrier(); + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + PipeBarrier(); + quantScaleLocal.SetValue(loopIdx, quantScale); + PipeBarrier(); + quantScale = 1 / quantScale; + PipeBarrier(); + Muls(_inMMLocal[loopIdx * gmmSwiglu->tokenLen], _inMMLocal[loopIdx * gmmSwiglu->tokenLen], + quantScale, gmmSwiglu->tokenLen / BISECT); + PipeBarrier(); + LocalTensor quantLocal = quantOutQueue.DeQue(); + int32_t dstTempOffset = static_cast(loopIdx * gmmSwiglu->tokenLen / BISECT); + int32_t srcTempOffset = static_cast(loopIdx * gmmSwiglu->tokenLen); + int32_t tempCount = 
static_cast(gmmSwiglu->tokenLen / BISECT); + LocalTensor castSpace = castWorkspace.Get(); + CastFp32ToInt8Template(quantLocal, _inMMLocal, castSpace, dstTempOffset, srcTempOffset, tempCount); + mmOutQueue.EnQue(_inMMLocal); + quantOutQueue.EnQue(quantLocal); +} + +template +__aicore__ inline void GMMSwigluSplitWorkSpaceCompute::customDataCopyOut(VecConfig& vecConfig) { + LocalTensor quantScaleLocal = quantScaleOutQueue.DeQue(); + DataCopyParams copyParams_0{1, (uint16_t)(vecConfig.innerLoopNum * sizeof(float)), 0, 0}; + PipeBarrier(); + DataCopyPad(quantScaleOutputGM[workspaceSplitConfig.leftMatrixStartIndex + vecConfig.startIdx], quantScaleLocal, copyParams_0); + LocalTensor quantLocal = quantOutQueue.DeQue(); + DataCopyParams copyParams_1{1, (uint16_t)(vecConfig.innerLoopNum * gmmSwiglu->tokenLen / 2 * sizeof(int8_t)), 0, 0}; + PipeBarrier(); + DataCopyPad(quantOutputGM[(workspaceSplitConfig.leftMatrixStartIndex + vecConfig.startIdx) * gmmSwiglu->tokenLen / 2], quantLocal, copyParams_1); + PipeBarrier(); + vecConfig.startIdx += vecConfig.innerLoopNum; + vecConfig.startOffset = vecConfig.startIdx * gmmSwiglu->tokenLen; + quantOutQueue.EnQue(quantLocal); + quantScaleOutQueue.EnQue(quantScaleLocal); +} + +} // namespace GROUPED_MATMUL +#endif // ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_SPLIT_WS_H diff --git a/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h new file mode 100644 index 00000000000..37ddd845dfd --- /dev/null +++ b/csrc/grouped_matmul_swiglu_quant_weight_nz_tensor_list/op_kernel/grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h @@ -0,0 +1,240 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file grouped_matmul_swiglu_quant_weight_nz_tensor_list_utils.h + * \brief + */ +#ifndef ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H +#define ASCENDC_GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST_UTILS_H + +#include "kernel_tiling/kernel_tiling.h" +#include "kernel_operator.h" +#include "lib/matmul_intf.h" + +namespace GROUPED_MATMUL_SWIGLU_QUANT_WEIGHT_NZ_TENSOR_LIST { +using namespace AscendC; +constexpr uint32_t INT8_BITS = 8; // a int8 number has 8 bits +constexpr uint32_t UB_BLOCK_UNIT_SIZE = 32; // 32: a block has 32 bytes data +constexpr uint32_t THRESHOLD_BLOCK_NUM = 8; +constexpr uint32_t UB_BLOCK_DOUBLE_UNIT_SIZE = 64; // 64: a block has 64 bytes data +constexpr uint32_t HALF_UB_BLOCK_UNIT_SIZE = UB_BLOCK_UNIT_SIZE / 2; // 2: a float16 data has two bytes +constexpr uint32_t SINGLE_CORE_M = 128; +constexpr uint32_t SINGLE_CORE_N = 512; +constexpr uint32_t SINGLE_CORE_K = 7168; +constexpr uint32_t BASIC_M = 128; +constexpr uint32_t BASIC_N = 256; +constexpr uint32_t BASIC_K = 128; +constexpr uint32_t STEP_M = 1; +constexpr uint32_t STEP_N = 1; +constexpr uint32_t STEP_Ka = 4; +constexpr uint32_t STEP_Kb = 4; +constexpr uint32_t DEPTH_A1 = 8; +constexpr uint32_t DEPTH_B1 = 8; +constexpr uint32_t VEC_LEN_ONCE_REPEAT_ELE = 64; +constexpr uint32_t VEC_LEN_ONCE_REPEAT_BLOCK = 8; +constexpr uint32_t BISECT = 2; +constexpr uint32_t MOD_32_MASK = 0x1F; +constexpr uint32_t MOD_16_MASK = 0x0F; +constexpr uint32_t ALIGN_8_ELE = 8; +constexpr uint32_t ALIGN_16_ELE = 16; +constexpr float QUANT_SCALE_INT8 = 127.0f; +constexpr MatmulConfig NZ_CFG_MDL = GetMDLConfig(false, false, 0, true, false, false, true); +constexpr MatmulConfig GetMMCFG() { + MatmulConfig MM_CFG = NZ_CFG_MDL; + MM_CFG.singleCoreM = SINGLE_CORE_M; + MM_CFG.singleCoreN= SINGLE_CORE_N; + MM_CFG.singleCoreK= SINGLE_CORE_K; + MM_CFG.basicM= BASIC_M; + MM_CFG.basicN= BASIC_N; + MM_CFG.basicK= BASIC_K; + return MM_CFG; +} + +constexpr static MatmulApiStaticTiling GetMMTiling(const MatmulApiStaticTiling& mmTiling) +{ + MatmulApiStaticTiling tiling = mmTiling; + tiling.stepM = STEP_M; + tiling.stepN = STEP_N; + tiling.stepKa = STEP_Ka; + tiling.stepKb = STEP_Kb; + tiling.depthA1 = DEPTH_A1; + tiling.depthB1 = DEPTH_B1; + return tiling; +} +template +struct MMImplType { + using AT = AT_; + using BT = BT_; + using CT = CT_; + using BiasT = BiasT_; + static constexpr MatmulConfig cfg = GetMMCFG(); + static constexpr MatmulApiStaticTiling mdl = GetMMTiling(GetMatmulApiTiling(cfg)); + using MT = matmul::MatmulImpl; +}; + +struct MNConfig { + int64_t m = 0; + int64_t k = 0; + int64_t n = 0; + int64_t baseM = 0; + int64_t baseN = 0; + int64_t mIdx = 0; + int64_t nIdx = 0; + int64_t blockDimM = 0; + int64_t blockDimN = 0; + int64_t singleM = 0; + int64_t singleN = 0; + int64_t wBaseOffset = 0; + int64_t nAxisBaseOffset = 0; + int64_t mAxisBaseOffset = 0; + int64_t xBaseOffset = 0; + int64_t yBaseOffset = 0; + int64_t wOutOffset = 0; + int64_t workSpaceOffset = 0; +}; + +struct VecConfig { + int64_t M = 0; + int64_t usedCoreNum = 0; + int64_t startOffset = 0; + int64_t curOffset = 0; + int64_t startIdx = 0; + int64_t curIdx = 0; + int64_t taskNum = 0; + int64_t curGroupIdx = 0; + int64_t outLoopNum = 0; + int64_t innerLoopNum = 0; + int64_t tailLoopNum = 0; + int64_t nextUpadteInterVal = 0; +}; + +struct WorkSpaceSplitConfig { + int64_t M = 0; + int64_t loopCount = 0; + int64_t leftMatrixStartIndex = 0; + int64_t rightMatrixExpertStartIndex = 0; + int64_t rightMatrixExpertNextStartIndex = 0; + int64_t 
rightMatrixExpertEndIndex = 0;
+    int64_t notLastTaskSize = 0;
+    int64_t lastLoopTaskSize = 0;
+    bool isLastLoop = false;
+};
+
+template <uint32_t base, typename T>
+__aicore__ inline T AlignUp(T a) {
+    return (a + base - 1) / base * base;
+}
+
+template <typename T>
+__aicore__ inline T AlignUp(T a, T base) {
+    return (a + base - 1) / base * base;
+}
+
+template <typename T>
+__aicore__ inline T AlignDown(T a, T base) {
+    if (unlikely(base == 0)) {
+        return a;
+    }
+    return a / base * base;
+}
+
+template <>
+__aicore__ inline uint32_t AlignUp<4, uint32_t>(uint32_t a) {
+    // Round a up to the nearest multiple of 4: a multiple of 4 has its last two bits zero,
+    // so clearing the last two bits of (a + 3) gives the smallest multiple of 4 that is >= a.
+    // e.g. a = 5: (5 + 3) & ~3 = 8; a = 8: (8 + 3) & ~3 = 8.
+    return (a + 3) & ~3;  // & ~3: set last two bits of (a+3) to be zero
+}
+
+template <>
+__aicore__ inline uint32_t AlignUp<8, uint32_t>(uint32_t a) {
+    // In general, the least multiple of b (b a power of 2) that is >= a is
+    // result = (a + (b - 1)) & ~(b - 1).
+    return (a + 7) & ~7;  // & ~7: set last three bits of (a+7) to be zero
+}
+
+template <>
+__aicore__ inline uint32_t AlignUp<16, uint32_t>(uint32_t a) {
+    // In general, the least multiple of b (b a power of 2) that is >= a is
+    // result = (a + (b - 1)) & ~(b - 1).
+    return (a + 15) & ~15;  // & ~15: set last four bits of (a+15) to be zero
+}
+
+template <>
+__aicore__ inline uint32_t AlignUp<32, uint32_t>(uint32_t a) {
+    // refer to the above comments.
+    return (a + 31) & ~31;  // & ~31: set last five bits of (a+31) to be zero
+}
+
+__aicore__ inline void ReduceMaxTemplate(LocalTensor<float>& dstLocal, LocalTensor<float>& srcLocal,
+                                         uint32_t srcOffset, uint32_t count)
+{
+    if (likely(count > VEC_LEN_ONCE_REPEAT_ELE && count % VEC_LEN_ONCE_REPEAT_ELE == 0)) {
+        WholeReduceMax(dstLocal,
+                       srcLocal[srcOffset], VEC_LEN_ONCE_REPEAT_ELE,
+                       count / VEC_LEN_ONCE_REPEAT_ELE, 1, 1,
+                       VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
+        PipeBarrier<PIPE_V>();
+        WholeReduceMax(dstLocal, dstLocal,
+                       count / VEC_LEN_ONCE_REPEAT_ELE, 1, 1, 1,
+                       VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
+    } else if (count <= VEC_LEN_ONCE_REPEAT_ELE) {
+        WholeReduceMax(dstLocal,
+                       srcLocal[srcOffset],
+                       count, 1, 1, 1, VEC_LEN_ONCE_REPEAT_BLOCK, ReduceOrder::ORDER_ONLY_VALUE);
+    } else {
+        ReduceMax(dstLocal, srcLocal[srcOffset], dstLocal, count, false);
+    }
+}
+
+__aicore__ inline void CastFp32ToInt8Template(LocalTensor<int8_t>& dstLocal, LocalTensor<float>& srcLocal,
+                                              LocalTensor<int8_t>& oneBlockWorkspace,
+                                              int32_t dstOffset, int32_t srcOffset, int32_t count)
+{
+    Cast(srcLocal[srcOffset].ReinterpretCast<half>(), srcLocal[srcOffset], RoundMode::CAST_RINT, count);
+    PipeBarrier<PIPE_V>();
+    if ((dstOffset & MOD_32_MASK) == 0) {
+        Cast(dstLocal[dstOffset],
+             srcLocal[srcOffset].ReinterpretCast<half>(),
+             RoundMode::CAST_RINT, count);
+    } else if ((dstOffset & MOD_16_MASK) == 0) {
+        Cast(dstLocal[dstOffset + ALIGN_16_ELE],
+             srcLocal[srcOffset + ALIGN_8_ELE].ReinterpretCast<half>(),
+             RoundMode::CAST_RINT, count - ALIGN_16_ELE);
+        PipeBarrier<PIPE_V>();
+        Cast(oneBlockWorkspace, srcLocal[srcOffset].ReinterpretCast<half>(),
+             RoundMode::CAST_RINT, ALIGN_16_ELE);
+        PipeBarrier<PIPE_V>();
+        for (int32_t i = 0; i < ALIGN_16_ELE; i++) {
+            int8_t temp = oneBlockWorkspace.GetValue(i);
+            dstLocal.SetValue(dstOffset + i, temp);
+        }
+        PipeBarrier<PIPE_V>();
+    }
+}
+
+template <typename T>
+__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) {
+    __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
+    uint64_t tensorPtrOffset = *dataAddr;  // The offset of the data address from the first address.
+    // Moving 3 bits to the right means dividing by sizeof(uint64_t).
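+    // Layout assumed here: tensorPtr first stores the byte offset of the address table;
+    // that table holds one GM address per tensor, so entry `index` is the base address
+    // of the index-th tensor in the list.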
+ __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3); + return reinterpret_cast<__gm__ T*>(*(retPtr + index)); +} + +} // namespace GROUPED_MATMUL + +#endif // ASCENDC_GROUPED_MATMUL_UTILS_H diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 90e7f03afac..06338e4f475 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -27,12 +27,14 @@ #include "ops.h" #include "utils.h" #include "mla_preprocess/op_host/mla_preprocess.h" +#include "aclnn_torch_adapter/op_api_common.h" #include #include #include namespace vllm_ascend { +const int64_t INT4_NUMS_IN_INT32 = 8; void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping, aclrtStream stream) { torch::Device src_device = src.device(); @@ -520,6 +522,71 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic cmd.Run(); return y_out; } + +std::tuple grouped_matmul_swiglu_quant( + const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale, + const at::Tensor &group_list, const c10::optional &bias, const c10::optional &offset) +{ + int m = x.sizes()[0]; + int n = weight.sizes()[2]; + bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt; + if (is_a8w4) { + n *= INT4_NUMS_IN_INT32; + } + + at::Tensor output = at::empty({m, n/2}, x.options().dtype(c10::ScalarType::Char)); + at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float)); + at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float)); + + EXEC_NPU_CMD( + aclnnGroupedMatmulSwigluQuantWeightNZ, + x, + weight, + bias, + offset, + weight_scale, + x_scale, + group_list, + output, + output_scale, + output_offset); + return std::tuple(output, output_scale, output_offset); +} + +std::tuple grouped_matmul_swiglu_quant_weight_nz_tensor_list( + const at::Tensor & x, + const at::TensorList & weight, + const at::TensorList & weight_scale, + const at::Tensor & x_scale, + const at::Tensor & group_list, + const c10::optional & bias, + const c10::optional & offset) +{ + auto x_size = x.sizes(); + int n = weight[0].sizes()[1]; + int m = x_size[0]; + int k = x_size[1]; + + at::Tensor output = at::zeros({m, n/2}, x.options().dtype(at::kChar)); + at::Tensor output_scale = at::zeros({m}, x.options().dtype(at::kFloat)); + at::Tensor output_offset = at::zeros({m}, x.options().dtype(at::kFloat)); + + EXEC_NPU_CMD( + aclnnGroupedMatmulSwigluQuantWeightNzTensorList, + x, + weight, + bias, + offset, + weight_scale, + x_scale, + group_list, + output, + output_scale, + output_offset); + + return std::tuple(output, output_scale, output_offset); +} + } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -576,4 +643,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()"); ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks); + + ops.def( + "grouped_matmul_swiglu_quant(Tensor x, Tensor weight, Tensor weight_scale, Tensor x_scale," + " Tensor group_list, *, Tensor? bias=None," + " Tensor? offset=None) -> (Tensor output, Tensor output_scale, Tensor output_offset)"); + ops.impl("grouped_matmul_swiglu_quant", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant); + + ops.def( + "grouped_matmul_swiglu_quant_weight_nz_tensor_list(Tensor x, Tensor[] weight, Tensor[] weight_scale, Tensor x_scale," + " Tensor group_list, *," + " Tensor? bias=None, Tensor? 
offset=None) ->" + " (Tensor output, Tensor output_scale, Tensor output_offset)" + ); + ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index dbb056be89c..26b3d66de03 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -35,7 +35,7 @@ namespace vllm_ascend { namespace meta { - +const int64_t INT4_NUMS_IN_INT32 = 8; std::tuple rotary_embedding_meta( at::Tensor &positions, at::Tensor &query, @@ -114,14 +114,50 @@ std::tuple mla_preproces return {q_out0, kv_cache_out0, q_out1, kv_cache_out1}; } +std::tuple grouped_matmul_swiglu_quant( + const at::Tensor &x, const at::Tensor &weight, const at::Tensor &weight_scale, const at::Tensor &x_scale, + const at::Tensor &group_list, const c10::optional &bias, const c10::optional &offset) +{ + int m = x.sizes()[0]; + int n = weight.sizes()[2]; + bool is_a8w4 = x.dtype() == at::kChar && weight.dtype() == at::kInt; + if (is_a8w4) { + n *= INT4_NUMS_IN_INT32; + } + at::Tensor output = at::empty({m, n/2}, x.options().dtype(c10::ScalarType::Char)); + at::Tensor output_scale = at::empty({m}, x.options().dtype(c10::ScalarType::Float)); + at::Tensor output_offset = at::empty({}, x.options().dtype(c10::ScalarType::Float)); + return {output, output_scale, output_offset}; +} + +std::tuple grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta( + const at::Tensor & x, + const at::TensorList & weight, + const at::TensorList & weight_scale, + const at::Tensor & x_scale, + const at::Tensor & group_list, + const c10::optional & bias, + const c10::optional & offset) +{ + auto x_size = x.sizes(); + int n = weight[0].sizes()[1]; + int m = x_size[0]; + int k = x_size[1]; + + at::Tensor output = at::zeros({m, n/2}, c10::dtype(c10::ScalarType::Char)); + at::Tensor output_scale = at::zeros({m}, c10::dtype(c10::ScalarType::Float)); + at::Tensor output_offset = at::zeros({m}, c10::dtype(c10::ScalarType::Float)); + + return std::tuple(output, output_scale, output_offset); +} } // namespace meta } // namespace vllm_ascend namespace { - // Register the meta implementations of the custom kernels for symbolic tracing, this will also - // the custom kernel been captured into aclgraph - TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { +// Register the meta implementations of the custom kernels for symbolic tracing, this will also +// the custom kernel been captured into aclgraph +TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { // Rotary embedding meta implementation ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); // Masked input and mask meta implementation @@ -132,5 +168,9 @@ namespace { ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta); // MLA preprocess ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess); + // grouped_matmul_swiglu_quant meta implementation + ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant); + // Grouped matmul swiglu quant weight nz tensor list + ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta); } } diff --git a/csrc/utils/CMakeLists.txt b/csrc/utils/CMakeLists.txt new file mode 100644 index 00000000000..db468cb2bc5 --- /dev/null +++ b/csrc/utils/CMakeLists.txt @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. 
+# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_library(ops_utils_tiling_headers INTERFACE) + +target_include_directories(ops_utils_tiling_headers INTERFACE + $ + $<$:$> + $<$:$> + $<$:$> + $<$:$> + $ +) + +target_compile_definitions(ops_utils_tiling_headers INTERFACE + OPS_UTILS_LOG_SUB_MOD_NAME="OP_TILING" + OPS_UTILS_LOG_PACKAGE_TYPE=$,"[Custom]",""> +) + +add_library(ops_utils_proto_headers INTERFACE) + +target_include_directories(ops_utils_proto_headers INTERFACE + $ + $<$:$> + $<$:$> + $<$:$> + $ +) + +target_compile_definitions(ops_utils_proto_headers INTERFACE + OPS_UTILS_LOG_SUB_MOD_NAME="OP_PROTO" + OPS_UTILS_LOG_PACKAGE_TYPE=$,"[Custom]",""> +) + +if(NOT BUILD_OPEN_PROJECT) + install_package( + PACKAGE ops_adv + TARGETS ops_utils_tiling_headers ops_utils_proto_headers + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/inc/ + DESTINATION include/ops_adv/utils + ) +endif() diff --git a/csrc/utils/inc/aclnn_util.h b/csrc/utils/inc/aclnn_util.h new file mode 100644 index 00000000000..472ea4dbb63 --- /dev/null +++ b/csrc/utils/inc/aclnn_util.h @@ -0,0 +1,14 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef OP_API_INC_ACLNN_UTIL_H +#define OP_API_INC_ACLNN_UTIL_H + +#define ACLNN_API __attribute__((visibility("default"))) +#endif // OP_API_INC_ACLNN_UTIL_H \ No newline at end of file diff --git a/csrc/utils/inc/error/ops_error.h b/csrc/utils/inc/error/ops_error.h new file mode 100644 index 00000000000..fbb5c295cc9 --- /dev/null +++ b/csrc/utils/inc/error/ops_error.h @@ -0,0 +1,25 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file ops_error.h + * \brief + */ + +#pragma once + +#include "log/ops_log.h" + +/* 基础报错 */ +#define OPS_REPORT_VECTOR_INNER_ERR(OPS_DESC, ...) 
OPS_INNER_ERR_STUB("E89999", OPS_DESC, __VA_ARGS__) +#define OPS_REPORT_CUBE_INNER_ERR(OPS_DESC, ...) OPS_INNER_ERR_STUB("E69999", OPS_DESC, __VA_ARGS__) + +/* 条件报错 */ +#define OPS_ERR_IF(COND, LOG_FUNC, EXPR) OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR) diff --git a/csrc/utils/inc/fallback.h b/csrc/utils/inc/fallback.h new file mode 100644 index 00000000000..eb19050d3ee --- /dev/null +++ b/csrc/utils/inc/fallback.h @@ -0,0 +1,497 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file fallback.h + * \brief + */ + +#ifndef ACLNNFALLBACK_OPAPI_H_ +#define ACLNNFALLBACK_OPAPI_H_ + +#include + +#include +#include +#include +#include + +#include "aclnn/aclnn_base.h" +#include "fallback_comm.h" +#include "error/ops_error.h" +#include "runtime/base.h" + +namespace fallback { +using namespace std; +using namespace gert; +using namespace ge; +using namespace std; + +namespace std_utils { + template + struct index_sequence {}; + + template + struct make_index_sequence_helper : make_index_sequence_helper {}; + + template + struct make_index_sequence_helper<0, Is...> { + using type = index_sequence; + }; + + template + using make_index_sequence = typename make_index_sequence_helper::type; +} + +using aclOpExecutor = struct aclOpExecutor; +using aclTensor = struct aclTensor; +using aclScalar = struct aclScalar; +using aclIntArray = struct aclIntArray; +using aclFloatArray = struct aclFloatArray; +using aclBoolArray = struct aclBoolArray; +using aclTensorList = struct aclTensorList; + +using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims, uint64_t view_dims_num, aclDataType data_type, + const int64_t* stride, int64_t offset, aclFormat format, + const int64_t* storage_dims, uint64_t storage_dims_num, void* tensor_data); + +using _aclCreateScalar = aclScalar* (*)(void* value, aclDataType data_type); +using _aclCreateIntArray = aclIntArray* (*)(const int64_t* value, uint64_t size); +using _aclCreateFloatArray = aclFloatArray* (*)(const float* value, uint64_t size); +using _aclCreateBoolArray = aclBoolArray* (*)(const bool* value, uint64_t size); +using _aclCreateTensorList = aclTensorList* (*)(const aclTensor* const* value, uint64_t size); + +using _aclDestroyTensor = int (*)(const aclTensor* tensor); +using _aclDestroyScalar = int (*)(const aclScalar* scalar); +using _aclDestroyIntArray = int (*)(const aclIntArray* array); +using _aclDestroyFloatArray = int (*)(const aclFloatArray* array); +using _aclDestroyBoolArray = int (*)(const aclBoolArray* array); +using _aclDestroyTensorList = int (*)(const aclTensorList* array); + +#define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName)) + +inline const char* GetOpApiLibName(void) { + return "libopapi.so"; +} + +inline const char* GetCustOpApiLibName(void) { + return "libcust_opapi.so"; +} + +inline void* GetOpApiFuncAddrInLib(void* handler, const char* libName, const char* apiName) { + auto funcAddr = dlsym(handler, apiName); + if (funcAddr == nullptr) { + 
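+        // a missing symbol is only a warning here; GetOpApiFuncAddr() keeps probing the
+        // remaining opapi / aclnn libraries before reporting an error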
OPS_LOG_W("aclnnfallback", "dlsym %s from %s failed, error:%s.", apiName, libName, dlerror()); + } + return funcAddr; +} + +inline void* GetOpApiLibHandler(const char* libName) { + auto handler = dlopen(libName, RTLD_LAZY); + if (handler == nullptr) { + OPS_LOG_W("aclnnfallback", "dlopen %s failed, error:%s.", libName, dlerror()); + } + return handler; +} + +inline void* GetAclnnArrdByApiName(const char *apiName) { + vector libs = {"libaclnn_ops_infer.so", "libaclnn_ops_train.so", "libaclnn_math.so", + "libaclnn_rand.so", "libaclnn_sparse.so", "libaclnn_fft.so"}; + for (const auto &libName : libs) { + static auto libHandler = GetOpApiLibHandler(libName.c_str()); + if (libHandler != nullptr) { + auto funcAddr = GetOpApiFuncAddrInLib(libHandler, libName.c_str(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + } + OPS_LOG_E("aclnnfallback", "api %s can't find in any aclnn lib.", apiName); + return nullptr; +} + +inline void* GetOpApiFuncAddr(const char* apiName) { + static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName()); + if (custOpApiHandler != nullptr) { + auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + + static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName()); + if (opApiHandler != nullptr) { + auto funcAddr = GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + OPS_LOG_D("aclnnfallback", "opapi lib is not exist,will use aclnn lib."); + return GetAclnnArrdByApiName(apiName); +} + +inline aclTensor* ConvertType(aclTensor* ge_tensor) { + return ge_tensor; +} + +inline aclIntArray* ConvertType(const std::vector &arr) { + if (arr.empty()) { + return nullptr; + } + static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray); + auto array = aclCreateIntArray(arr.data(), arr.size()); + return array; +} + +inline aclDataType GetConvertType(const gert::Tensor* ge_tensor) { + // convert data type + auto dataType_ge = ge_tensor->GetDataType(); + auto dataType = aclDataType::ACL_FLOAT16; + if (dataType_ge == DT_FLOAT) { + dataType = aclDataType::ACL_FLOAT; + } else if (dataType_ge == DT_BF16) { + dataType = aclDataType::ACL_BF16; + } else if (dataType_ge == DT_BOOL) { + dataType = aclDataType::ACL_BOOL; + } else if (dataType_ge == DT_INT64) { + dataType = aclDataType::ACL_INT64; + } else if (dataType_ge == DT_INT32) { + dataType = aclDataType::ACL_INT32; + } else if (dataType_ge == DT_UINT64) { + dataType = aclDataType::ACL_UINT64; + } else if (dataType_ge == DT_UINT32) { + dataType = aclDataType::ACL_UINT32; + } else if (dataType_ge == DT_INT8) { + dataType = aclDataType::ACL_INT8; + } else if (dataType_ge == DT_UINT8) { + dataType = aclDataType::ACL_UINT8; + } else if (dataType_ge == DT_INT4) { + dataType = aclDataType::ACL_INT4; + } else { + dataType = aclDataType::ACL_FLOAT16; + } + + return dataType; +} + +inline aclTensor* ConvertType(const gert::Tensor* ge_tensor) { + if (ge_tensor == nullptr) { + return nullptr; + } + + static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor); + OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr); + + void* device_addr = nullptr; + auto tensor_place = ge_tensor->GetPlacement(); + device_addr = const_cast(ge_tensor->GetAddr()); + + auto dataType = GetConvertType(ge_tensor); + + OPS_LOG_D("aclnnfallback", "aclCreateTensor: tensor type is %d", dataType); + + // convert 
shape + auto gert_shape = ge_tensor->GetStorageShape(); + std::vector shape; + for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) { + shape.push_back(gert_shape.GetDim(i)); + } + + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + + aclTensor* out = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), + 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), device_addr); + + OPS_ERR_IF(out == nullptr, + OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr); + + return out; +} + +inline aclTensorList* ConvertType(std::vector& ge_tenserList) { + OPS_ERR_IF(ge_tenserList.size() == 0, + OPS_LOG_E("aclnnfallback", "ge_tenserList size 0"), return nullptr); + + static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList); + OPS_ERR_IF(aclCreateTensorList == nullptr, + OPS_LOG_E("aclnnfallback", "ge_tenserList size 0"), return nullptr); + + std::vector tmp; + for (size_t i = 0; i < ge_tenserList.size(); i++) { + auto t_acl = ConvertType(ge_tenserList[i]); + tmp.push_back(t_acl); + } + + aclTensorList* tensorList = aclCreateTensorList(tmp.data(), tmp.size()); + return tensorList; +} + +template +inline aclScalar* ConvertScalarType(T value) { + static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar); + OPS_ERR_IF(aclCreateScalar == nullptr, + OPS_LOG_E("aclnnfallback", "aclCreateScalar nullptr"), return nullptr); + if (typeid(value) == typeid(float)) { + return aclCreateScalar(&value, aclDataType::ACL_FLOAT); + } + return nullptr; +} + +template +T ConvertType(T value) { + return value; +} + +inline aclTensor* ConvertMmType(const gert::Tensor* ge_tensor, bool transpose, bool enable_NZ=false) { + if (ge_tensor == nullptr) { + return nullptr; + } + auto gert_shape = ge_tensor->GetStorageShape(); + if (gert_shape.GetDimNum() <= 1) { + return ConvertType(ge_tensor); + } + + static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor); + OPS_ERR_IF(aclCreateTensor == nullptr, OPS_LOG_E("aclnnfallback", "aclCreateTensor nullptr"), return nullptr); + + void* device_addr = const_cast(ge_tensor->GetAddr()); + // convert data type + auto dataType_ge = ge_tensor->GetDataType(); + auto dataType = ToAclDataType(dataType_ge); + // convert shape + std::vector shape; + for (size_t i = 0; i < gert_shape.GetDimNum(); ++i) { + shape.push_back(gert_shape.GetDim(i)); + } + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + + auto viewShape = shape; + // 对于transpose后的tensor对后两维度进行strides, viewShape转换 + if (transpose) { + // dimM 为倒数第二维, dimN 为倒数第一维度 + auto dimM = shape.size() - 2; + auto dimN = shape.size() - 1; + auto swap = strides[dimN]; + strides[dimN] = strides[dimM]; + strides[dimM] = swap; + // 修改viewShape + viewShape[dimN] = shape[dimM]; + viewShape[dimM] = shape[dimN]; + } + auto acl_format = aclFormat::ACL_FORMAT_ND; + if (enable_NZ && GetPrimaryFormat(ge_tensor->GetStorageFormat()) == ge::Format::FORMAT_FRACTAL_NZ) { + acl_format = aclFormat::ACL_FORMAT_FRACTAL_NZ; + } + aclTensor* out = aclCreateTensor(viewShape.data(), shape.size(), dataType, strides.data(), + 0, acl_format, shape.data(), shape.size(), device_addr); + OPS_ERR_IF(out == nullptr, OPS_LOG_E("aclnnfallback", "out nullptr"), return nullptr); + + return out; +} + +inline void Release(aclTensor* p) { + static const auto aclDestroyTensor = 
GET_OP_API_FUNC(aclDestroyTensor); + OPS_ERR_IF(aclDestroyTensor == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyTensor is null"), return); + aclDestroyTensor(p); +} + +inline void Release(aclScalar* p) { + static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar); + OPS_ERR_IF(aclDestroyScalar == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyScalar is null"), return); + aclDestroyScalar(p); +} + +inline void Release(aclIntArray* p) { + static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray); + OPS_ERR_IF(aclDestroyIntArray == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyIntArray is null"), return); + aclDestroyIntArray(p); +} + +inline void Release(aclBoolArray* p) { + static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray); + OPS_ERR_IF(aclDestroyBoolArray == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyBoolArray is null"), return); + aclDestroyBoolArray(p); +} + +inline void Release(aclTensorList* p) { + static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList); + OPS_ERR_IF(aclDestroyTensorList == nullptr, + OPS_LOG_E("aclnnfallback", "aclDestroyTensorList is null"), return); + aclDestroyTensorList(p); +} + +template +void Release(T value) { + (void)value; +} + +template +void CallRelease(Tuple t, std_utils::index_sequence) { + (void)std::initializer_list{(Release(std::get(t)), 0)...}; +} + +template +void ReleaseConvertTypes(Tuple& t) { + static constexpr auto size = std::tuple_size::value; + CallRelease(t, std_utils::make_index_sequence{}); +} + +template +auto ConvertTypes(Ts&... args) -> decltype(std::make_tuple(ConvertType(args)...)) { + auto tp = std::make_tuple(ConvertType(args)...); + return tp; +} + +template +auto call(Function f, Tuple t, std_utils::index_sequence) -> int { + return f(std::get(t)...); +} + +template +auto call(Function f, Tuple t) -> int { + static constexpr auto size = std::tuple_size::value; + return call(f, t, std_utils::make_index_sequence{}); +} + +template +auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr, std_utils::index_sequence) + -> int (*)(typename std::decay(params))>::type...) 
{ + using OpApiFunc = int (*)(typename std::decay(params))>::type...); + auto func = reinterpret_cast(opApiAddr); + return func; +} + +template +auto ConvertToOpApiFunc(const Tuple& params, void* opApiAddr) + -> typename std::enable_if::value != 0, + decltype(ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence::value>{}))>::type { + static constexpr auto size = std::tuple_size::value; + return ConvertToOpApiFunc(params, opApiAddr, std_utils::make_index_sequence{}); +} + +template +class ConvertedParams { + public: + ConvertedParams(Tuple&& convertedParams) : convertedParams_(std::move(convertedParams)){}; + ConvertedParams(ConvertedParams&& other) : convertedParams_(std::move(other.convertedParams_)) { + other.validParams_ = false; + }; + ConvertedParams& operator=(ConvertedParams&& other) { + if (this == &other) { + return *this; + } + + convertedParams_ = std::move(other.convertedParams_); + validParams_ = true; + other.validParams_ = false; + return *this; + } + + ConvertedParams() = delete; + ConvertedParams(const ConvertedParams& other) = delete; + ConvertedParams& operator=(const ConvertedParams& other) = delete; + + ~ConvertedParams() { + if (validParams_) { + ReleaseConvertTypes(convertedParams_); + } + } + + const Tuple& GetConvertedParams() const { + return convertedParams_; + } + + private: + Tuple convertedParams_; + bool validParams_{true}; +}; + +using InitHugeMemThreadLocal = int (*)(void*, bool); +using UnInitHugeMemThreadLocal = void (*)(void*, bool); +using ReleaseHugeMem = void (*)(void*, bool); +using PTAGetExecCache = aclOpExecutor* (*)(uint64_t, uint64_t*); +using InitPTACacheThreadLocal = void (*)(); +using SetPTAHashKey = void (*)(uint64_t); +using CanUsePTACache = bool (*)(const char*); + +using ResetCacheThreadLocal = void (*)(); + +#define EXEC_OPAPI_CMD(aclnn_api, ...) 
\ + ({ \ + static auto ret = GRAPH_SUCCESS; \ + do { \ + static const auto ResetCacheThreadLocalAddr = GetOpApiFuncAddr("ResetCacheThreadLocal"); \ + static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \ + static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \ + if (getWorkspaceSizeFuncAddr == nullptr || opApiFuncAddr == nullptr || ResetCacheThreadLocalAddr == nullptr) { \ + OPS_LOG_E("aclnnfallback", "%s or %s not in %s or %s or ResetCacheThreadLocal not found.", \ + #aclnn_api "GetWorkspaceSize", #aclnn_api, GetOpApiLibName(), GetOpApiLibName()); \ + ret = GRAPH_FAILED; \ + break; \ + } \ + auto ResetCacheThreadLocalFunc = reinterpret_cast(ResetCacheThreadLocalAddr); \ + ResetCacheThreadLocalFunc(); \ + uint64_t workspace_size = 0; \ + uint64_t* workspace_size_addr = &workspace_size; \ + aclOpExecutor* executor = nullptr; \ + aclOpExecutor** executor_addr = &executor; \ + auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \ + static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \ + auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ + if (workspace_status != 0) { \ + OPS_LOG_E("aclnnfallback", "call %s failed:", #aclnn_api); \ + ret = GRAPH_FAILED; \ + break; \ + } \ + void* workspace_addr = nullptr; \ + if (workspace_size > 0) { \ + workspace_addr = host_api_ctx->MallocWorkspace(workspace_size); \ + if (workspace_addr == nullptr) { \ + OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed", #aclnn_api); \ + ret = GRAPH_FAILED; \ + break; \ + } \ + } \ + auto acl_stream = host_api_ctx->GetStream(); \ + auto acl_call = [converted_params, workspace_addr, workspace_size, host_api_ctx, acl_stream, \ + executor]() -> int { \ + using OpApiFunc = int (*)(void*, uint64_t, aclOpExecutor*, const aclrtStream); \ + OpApiFunc opApiFunc = reinterpret_cast(opApiFuncAddr); \ + auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \ + ReleaseConvertTypes(converted_params); \ + host_api_ctx->FreeWorkspace(); \ + if (api_ret != 0) { \ + OPS_LOG_E("aclnnfallback", "call %s allocate workspace failed api_ret: %d", #aclnn_api, api_ret); \ + return GRAPH_FAILED; \ + } \ + return api_ret; \ + }; \ + \ + ret = acl_call(); \ + } while (false); \ + (ret); \ + }) + +} // namespace fallback + +#endif // ACLNNFALLBACK_OPAPI_H_ diff --git a/csrc/utils/inc/fallback_comm.h b/csrc/utils/inc/fallback_comm.h new file mode 100644 index 00000000000..a2dd5cfd135 --- /dev/null +++ b/csrc/utils/inc/fallback_comm.h @@ -0,0 +1,38 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file fallback_comm.h + * \brief + */ + +#ifndef INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_ +#define INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_ + +#include "aclnn/aclnn_base.h" +#include "exe_graph/runtime/op_execute_context.h" +#include "exe_graph/runtime/tensor.h" +#include "register/op_impl_registry.h" +#include "runtime/base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +namespace fallback { + +aclDataType ToAclDataType(ge::DataType dtype); +} // namespace fallback + +#ifdef __cplusplus +} +#endif + +#endif // INC_EXTERNAL_GRAPH_FALLBACK_COMMON_H_ diff --git a/csrc/utils/inc/kernel/dropmask.h b/csrc/utils/inc/kernel/dropmask.h new file mode 100644 index 00000000000..13ed9c3501b --- /dev/null +++ b/csrc/utils/inc/kernel/dropmask.h @@ -0,0 +1,121 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dropmask.h + * \brief + */ + +#ifndef DROPMASK_H +#define DROPMASK_H + +#include "util.h" + +using AscendC::DROPOUT_MODE_BIT_MISALIGN; +using AscendC::DropOutShapeInfo; +using AscendC::DropOut; + +struct DropMaskInfo { + // for compute dropout mask offset + // 参数按B N G S1 S2全部切分设置进行偏移计算,没有切分的轴对应的参数设置为合适的0或者原始值 + int64_t n2G; // n2 * g + int64_t gSize; // g + int64_t s1Size; // s1 + int64_t s2Size; // s2 + int64_t gOutIdx; // g out index + int64_t bSSOffset; // boidx * s1 * s2 ===bSSOffset + int64_t n2OutIdx; // n out index + int64_t s1OutIdx; // s1 out index ===s1oIdx + int64_t s1InnerIdx; // s1 inner index, 配比 ===loopIdx + int64_t s1BaseSize; // S1基本块大小 + int64_t splitS1BaseSize; // s1 split size ===vec1S1BaseSize + int64_t s2StartIdx; // s2 start index + int64_t s2Idx; // s2 index =====s2LoopCount + int64_t s2BaseNratioSize; // s2的配比长度: s2BaseSize(S2基本块大小) * nRatio + + // for copy in dropout mask + uint32_t s1CopySize; + uint32_t s2CopySize; + int64_t s2TotalSize; + + // for compute dropout mask + uint32_t firstAxis; + uint32_t lstAxis; + uint32_t maskLstAxis; + int64_t vecCoreOffset = 0; + float keepProb; + + bool boolMode; +}; + +template +__aicore__ inline int64_t ComputeDropOffset(DropMaskInfo &dropMaskInfo) +{ + if constexpr (hasDrop == true) { + // boidx * n2 * g* s1 * s2 + int64_t bOffset = dropMaskInfo.bSSOffset * dropMaskInfo.n2G; + // n2oIdx * g * s1 *s2 + int64_t n2Offset = dropMaskInfo.n2OutIdx * dropMaskInfo.gSize * dropMaskInfo.s1Size * dropMaskInfo.s2Size; + // goIdx * s1 * s2 + int64_t gOffset = dropMaskInfo.gOutIdx * dropMaskInfo.s1Size * dropMaskInfo.s2Size; + // s1oIdx * s1BaseSize * s2Size + s1innerindex * vec1S1BaseSize * s2Size + int64_t s1Offset = (dropMaskInfo.s1OutIdx * dropMaskInfo.s1BaseSize + dropMaskInfo.vecCoreOffset + + dropMaskInfo.s1InnerIdx * dropMaskInfo.splitS1BaseSize) * dropMaskInfo.s2Size; + // s2StartIdx + s2index * s2BaseNratioSize + int64_t s2Offset = dropMaskInfo.s2StartIdx + dropMaskInfo.s2Idx * dropMaskInfo.s2BaseNratioSize; + return bOffset + n2Offset + gOffset + s1Offset + s2Offset; + } else { + return 0; + } +} + +template +__aicore__ inline void 
CopyInDropMask(LocalTensor&dstTensor, GlobalTensor& srcBoolTensor, + GlobalTensor& srcByteTensor, DropMaskInfo &dropMaskInfo, int64_t alignedSize = blockBytes) +{ + if constexpr (hasDrop == true) { + int64_t dropMaskOffset = ComputeDropOffset(dropMaskInfo); + if (unlikely(dropMaskInfo.boolMode)) { + BoolCopyIn(dstTensor, srcBoolTensor, dropMaskOffset, + dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize); + } else { + Bit2Int8CopyIn(dstTensor, srcByteTensor, dropMaskOffset, 1, + dropMaskInfo.s1CopySize, dropMaskInfo.s2CopySize, dropMaskInfo.s2TotalSize, alignedSize); + } + return; + } +} + +template +__aicore__ inline void ComputeDropMask(LocalTensor& dstTensor, LocalTensor& srcTensor, + LocalTensor& dropoutBuffer, LocalTensor& tmpDropBuffer, DropMaskInfo &dropMaskInfo) +{ + if constexpr (hasDrop == true) { + DropOutShapeInfo dropOutShapeInfo; + dropOutShapeInfo.firstAxis = dropMaskInfo.firstAxis; + dropOutShapeInfo.srcLastAxis = dropMaskInfo.lstAxis; + + if (unlikely(dropMaskInfo.boolMode)) { + dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis, blockBytes) * blockBytes; + DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo); + } else { + dropOutShapeInfo.maskLastAxis = CeilDiv(dropMaskInfo.maskLstAxis / byteBitRatio, blockBytes) * blockBytes; + if (likely(dropMaskInfo.lstAxis / byteBitRatio % blockBytes == 0)) { + DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, dropMaskInfo.keepProb, dropOutShapeInfo); + } else { + DropOut(dstTensor, srcTensor, dropoutBuffer, tmpDropBuffer, + dropMaskInfo.keepProb, dropOutShapeInfo); + } + } + return; + } +} + +#endif // DROPMASK_H diff --git a/csrc/utils/inc/kernel/pse.h b/csrc/utils/inc/kernel/pse.h new file mode 100644 index 00000000000..e6cd8e7b1f0 --- /dev/null +++ b/csrc/utils/inc/kernel/pse.h @@ -0,0 +1,483 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file pse.h + * \brief + */ + +#ifndef FLASH_ATTENTION_SCORE_PSE_H +#define FLASH_ATTENTION_SCORE_PSE_H + +#include "kernel_operator.h" +#include "util.h" + +constexpr static int64_t pseS1S2 = 0; +constexpr static int64_t pse1S2 = 1; +constexpr static int64_t pseSlopeBn = 2; +constexpr static int64_t pseSlopeN = 3; + +constexpr static uint8_t pseEncodeALibiS2Full = 0x11; + +enum class PseTypeEnum { + PSE_OUTER_MUL_ADD_TYPE = 0, // default + PSE_OUTER_ADD_MUL_TYPE, + PSE_INNER_MUL_ADD_TYPE, + PSE_INNER_MUL_ADD_SQRT_TYPE, + PSE_INVALID_TYPE +}; + +struct PseInfo { + int64_t blockCount; + int64_t bSSOffset; // boidx * s1 * s2 + int64_t boIdx; + int64_t gSize; + int64_t goIdx; + int64_t loopIdx; + int64_t n2G; + int64_t n2oIdx; + int64_t pseBSize; + int64_t pseS1Size; // for alibi + int64_t pseS2ComputeSize; // for alibi, do not need assignment + int64_t pseS2Size; // for alibi + uint32_t pseShapeType; + int64_t readS2Size; // for alibi, do not need assignment + int64_t s1BaseSize; + int64_t s1Size; + int64_t s1oIdx; + int64_t s2AlignedSize; + int64_t s2BaseNratioSize; + int64_t s2LoopCount; + int64_t s2RealSize; + int64_t s2Size; + int64_t s2SizeAcc; // accumulated sum of s2 size + int64_t s2StartIdx; + int64_t vec1S1BaseSize; + int64_t vec1S1RealSize; + uint32_t pseEncodeType; // for distinguish alibi + uint32_t pseType; // 0: outer, mul-add 1:outer, add-mul 2:inner, mul-add 3:inner, mul-add-sqrt + int64_t pseAlibiBaseS1; + int64_t pseAlibiBaseS2; + int64_t qStartIdx; + int64_t kvStartIdx; + int64_t vecCoreOffset = 0; + bool needCast; + bool align8 = false; + bool pseEndogenous = false; +}; + +template +__aicore__ inline void DataCopyInCommon(LocalTensor &dstTensor, GlobalTensor &srcTensor, int64_t offset, + int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int32_t dtypeSize, + int32_t alignedS2Size) +{ + if constexpr (hasPse == true) { + uint32_t shapeArray[] = {static_cast(s1Size), static_cast(alignedS2Size)}; + dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND)); + dstTensor.SetSize(s1Size * alignedS2Size); + DataCopyParams dataCopyParams; + dataCopyParams.blockCount = s1Size; + dataCopyParams.blockLen = CeilDiv(s2Size * dtypeSize, blockBytes); // 单位32B + dataCopyParams.dstStride = alignedS2Size * dtypeSize / blockBytes - dataCopyParams.blockLen; // gap + if (actualS2Len * dtypeSize % blockBytes == 0) { + dataCopyParams.srcStride = + (actualS2Len * dtypeSize - dataCopyParams.blockLen * blockBytes) / blockBytes; // srcGap + DataCopy(dstTensor, srcTensor[offset], dataCopyParams); + } else { + dataCopyParams.blockLen = s2Size * dtypeSize; // 单位Byte + dataCopyParams.srcStride = (actualS2Len * dtypeSize - dataCopyParams.blockLen); + dataCopyParams.dstStride = (alignedS2Size - s2Size) * dtypeSize / blockBytes; + DataCopyPadParams dataCopyPadParams; + dataCopyPadParams.isPad = false; + DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams); + } + } +} + +template +__aicore__ inline void DataCopyIn(LocalTensor &dstTensor, GlobalTensor &srcTensor, int64_t offset, + int64_t s1Size, int64_t s2Size, int64_t actualS2Len, int64_t alignedSize = 16) +{ + if constexpr (hasPse == true) { + int32_t dtypeSize = sizeof(INPUT_T); + int32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize; + DataCopyInCommon(dstTensor, srcTensor, offset, s1Size, s2Size, + actualS2Len, dtypeSize, alignedS2Size); + } +} + +template +__aicore__ inline void DataCopyInAlign8(LocalTensor &dstTensor, GlobalTensor &srcTensor, int64_t offset, + int64_t s1Size, int64_t s2Size, int64_t 
actualS2Len) +{ + if constexpr (hasPse == true) { + int32_t dtypeSize = sizeof(INPUT_T); + if (dtypeSize == 0){ + return; + } + int32_t alignedS2Size = CeilDiv(s2Size, 32 / dtypeSize) * (32 / dtypeSize); + DataCopyInCommon(dstTensor, srcTensor, offset, s1Size, s2Size, + actualS2Len, dtypeSize, alignedS2Size); + } +} + +/* +dst = BroadcastAdd(src0, src1) +src0 shape: (s1, s2) +src1 shape: (1, s2) +dst shape: (s1, s2) +*/ +template +__aicore__ inline void BroadcastAdd(const LocalTensor &src0Tensor, const LocalTensor &src1Tensor, + int64_t src0Offset, int32_t src1Size, int32_t repeatTimes) +{ + if constexpr (hasPse == true) { + /* Total data number of single step should be smaller than 256bytes. + * If larger, we need to do add multiple times. */ + int32_t innerLoop = src1Size / repeatMaxSize; // s2轴整块计算次数 + int32_t innerRemain = src1Size % repeatMaxSize; // s2轴尾块计算量 + BinaryRepeatParams binaryRepeatParams; + binaryRepeatParams.src0BlkStride = 1; + binaryRepeatParams.src0RepStride = src1Size / blockSize; + binaryRepeatParams.src1BlkStride = 1; + binaryRepeatParams.src1RepStride = 0; + binaryRepeatParams.dstRepStride = binaryRepeatParams.src0RepStride; + binaryRepeatParams.blockNumber = binaryRepeatParams.src0RepStride; + + for (int32_t j = 0; j < innerLoop; j++) { + auto innerOffset = j * repeatMaxSize; + auto ubOffset = src0Offset + innerOffset; + Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], repeatMaxSize, repeatTimes, + binaryRepeatParams); + } + if (innerRemain > 0) { + auto innerOffset = innerLoop * repeatMaxSize; + auto ubOffset = src0Offset + innerOffset; + Add(src0Tensor[ubOffset], src0Tensor[ubOffset], src1Tensor[innerOffset], innerRemain, repeatTimes, + binaryRepeatParams); + } + } +} + +template +__aicore__ inline void PseBroadcastAdd(int32_t s1Size, int32_t s2Size, int32_t computeSize, const LocalTensor &pseUb, + const LocalTensor &dstTensor, uint32_t pseShapeType) +{ + if constexpr (hasPse == true) { + if (pseShapeType == pseS1S2 || pseShapeType == pseSlopeBn || pseShapeType == pseSlopeN) { + Add(dstTensor, dstTensor, pseUb, computeSize); + } else { + /* Total repeated times should be <= repeatMaxTimes. If larger, + * we need to do multiple inner loops. 
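+             * (repeatMaxTimes is 255, as defined in util.h in this patch, so the s1 axis is handled in chunks of at most 255 rows per Add call.)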
*/ + int32_t s1OuterLoop = s1Size / repeatMaxTimes; + int32_t s1OuterRemain = s1Size % repeatMaxTimes; + for (int32_t s1OuterIdx = 0; s1OuterIdx < s1OuterLoop; s1OuterIdx++) { + int32_t s1OuterOffset = s1OuterIdx * repeatMaxTimes * s2Size; + BroadcastAdd(dstTensor, pseUb, s1OuterOffset, s2Size, repeatMaxTimes); + } + if (s1OuterRemain > 0) { + int32_t s1OuterOffset = s1OuterLoop * repeatMaxTimes * s2Size; + BroadcastAdd(dstTensor, pseUb, s1OuterOffset, s2Size, s1OuterRemain); + } + } + } +} +template __aicore__ inline int64_t PseComputeOffset(PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + int64_t bOffset = 0; + int64_t n2Offset = 0; + int64_t s1Offset = 0; + int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize; + int64_t gOffset = 0; + if (pseInfo.pseShapeType == pseS1S2) { + // b, n2, g, s1, s2 + bOffset = pseInfo.bSSOffset * pseInfo.n2G; + n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s1Size * pseInfo.s2Size; + gOffset = pseInfo.goIdx * pseInfo.s1Size * pseInfo.s2Size; + s1Offset = (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + pseInfo.loopIdx * pseInfo.vec1S1BaseSize) * pseInfo.s2Size; + } else if (pseInfo.pseShapeType == pse1S2) { + // b, n2, g, 1, s2 + bOffset = pseInfo.s2SizeAcc * pseInfo.n2G; + n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.s2Size; + gOffset = pseInfo.goIdx * pseInfo.s2Size; + } + if (pseInfo.pseBSize == 1) { + bOffset = 0; + } + return bOffset + n2Offset + gOffset + s1Offset + s2Offset; + } else { + return 0; + } +} + +template __aicore__ inline int64_t PseAlibiComputeOffset(PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + int64_t bOffset = (pseInfo.boIdx % pseInfo.pseBSize) * pseInfo.n2G * pseInfo.pseS2Size * pseInfo.pseS1Size; + int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize * pseInfo.pseS2Size * pseInfo.pseS1Size; + int64_t gOffset = pseInfo.goIdx * pseInfo.pseS2Size * pseInfo.pseS1Size; + int64_t row = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + pseInfo.loopIdx * pseInfo.vec1S1BaseSize; + int64_t column = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize; + int64_t m = 0; + int64_t k = 0; + if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) { + int64_t threshold = pseInfo.s1Size - pseInfo.pseS1Size; + if (row >= threshold) { + m = row - threshold; + k = column; + } else { + m = row % pseInfo.pseS1Size; + k = pseInfo.pseS2Size - (row - column) - (pseInfo.pseS1Size - m); + } + } else { + int64_t threshold = pseInfo.pseS2Size - pseInfo.pseS1Size; + int64_t posVal = row - column - threshold; + if (threshold >= 0) { + if (posVal >= 0) { + m = posVal; + k = 0; + } else { + m = 0; + k = -posVal; + } + } else { + m = posVal; + k = 0; + } + } + int64_t s1Offset = m * pseInfo.pseS2Size; + int64_t s2Offset = k; + pseInfo.readS2Size = Min(pseInfo.s2AlignedSize, pseInfo.pseS2Size - k); + pseInfo.pseS2ComputeSize = Align(pseInfo.readS2Size); + + return bOffset + n2Offset + gOffset + s1Offset + s2Offset; + } else { + return 0; + } +} + +template __aicore__ inline bool NeedPseAlibiCompute(PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + // Alibi编码只计算下三角 + if (pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + (pseInfo.loopIdx + 1) * pseInfo.vec1S1BaseSize <= + pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize) { + return false; + } + return true; + } else { + return false; + } +} + +template +__aicore__ inline void PseAlibiCopyIn(LocalTensor &dstTensor, LocalTensor &tmpTensor, + GlobalTensor &srcTensor, PseInfo 
&pseInfo, int64_t alignedSize = 16) +{ + if constexpr (hasPse == true) { + if (!NeedPseAlibiCompute(pseInfo)) { + return; + } + int64_t offset = PseAlibiComputeOffset(pseInfo); + if constexpr (IsSameType::value) { + if (!pseInfo.align8){ + DataCopyIn(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size, + pseInfo.pseS2Size, alignedSize); + } else { + DataCopyInAlign8(dstTensor, srcTensor, offset, pseInfo.vec1S1RealSize, + pseInfo.readS2Size, pseInfo.pseS2Size); + } + return; + } + + DataCopyIn(tmpTensor, srcTensor, offset, pseInfo.vec1S1RealSize, pseInfo.readS2Size, + pseInfo.pseS2Size, alignedSize); + if (pseInfo.needCast) { + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize); + } + return; + } +} + +template +__aicore__ inline void PseSlopeCopyIn(LocalTensor &dstTensor, LocalTensor &helpTensor, + __gm__ uint8_t *pseSlope, GlobalTensor &alibiGm, PseInfo &pseInfo, + int64_t alignedSize = 16) { + if constexpr (hasPse == true) { + int64_t bOffset = 0; + int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize; + int64_t gOffset = pseInfo.goIdx; + + if (pseInfo.pseShapeType == pseSlopeBn) { + bOffset = pseInfo.boIdx * pseInfo.n2G; + } + int64_t offset = bOffset + n2Offset + gOffset; + + DataCopyIn(helpTensor, alibiGm, 0, pseInfo.vec1S1RealSize, + pseInfo.s2RealSize, pseInfo.pseAlibiBaseS2, alignedSize); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + + if (pseInfo.needCast) { + int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize; + Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize); + pipe_barrier(PIPE_V); + + int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + pseInfo.loopIdx * pseInfo.vec1S1BaseSize; + int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize; + + float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx); + + Adds(dstTensor, dstTensor, posShift, computeSize); + pipe_barrier(PIPE_V); + Abs(dstTensor, dstTensor, computeSize); + pipe_barrier(PIPE_V); + float slopes = ((__gm__ T *)pseSlope)[offset] * -1; + if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) { + Sqrt(dstTensor, dstTensor, computeSize); + pipe_barrier(PIPE_V); + } + Muls(dstTensor, dstTensor, slopes, computeSize); + pipe_barrier(PIPE_V); + } + } +} + +template +__aicore__ inline void PseSlopeCast(LocalTensor &dstTensor, LocalTensor &helpTensor, + __gm__ uint8_t *pseSlope, PseInfo &pseInfo) { + if constexpr (hasPse == true) { + int64_t bOffset = 0; + int64_t n2Offset = pseInfo.n2oIdx * pseInfo.gSize; + int64_t gOffset = pseInfo.goIdx; + + if (pseInfo.pseShapeType == pseSlopeBn) { + bOffset = pseInfo.boIdx * pseInfo.n2G; + } + int64_t offset = bOffset + n2Offset + gOffset; + int64_t computeSize = pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize; + Cast(dstTensor, helpTensor, RoundMode::CAST_NONE, computeSize); + pipe_barrier(PIPE_V); + + int64_t s1Offset = pseInfo.s1oIdx * pseInfo.s1BaseSize + pseInfo.vecCoreOffset + + pseInfo.loopIdx * pseInfo.vec1S1BaseSize; + int64_t s2Offset = pseInfo.s2StartIdx + pseInfo.s2LoopCount * pseInfo.s2BaseNratioSize; + + float posShift = float(s2Offset + pseInfo.kvStartIdx - s1Offset - pseInfo.qStartIdx); + + Adds(dstTensor, dstTensor, posShift, 
computeSize); + pipe_barrier(PIPE_V); + Abs(dstTensor, dstTensor, computeSize); + pipe_barrier(PIPE_V); + float slopes = ((__gm__ T *)pseSlope)[offset] * -1; + if (pseInfo.pseType == (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) { + Sqrt(dstTensor, dstTensor, computeSize); + pipe_barrier(PIPE_V); + } + Muls(dstTensor, dstTensor, slopes, computeSize); + pipe_barrier(PIPE_V); + } +} + +template +__aicore__ inline void PseCopyIn(LocalTensor &dstTensor, LocalTensor &tmpTensor, + GlobalTensor &srcTensor, PseInfo &pseInfo, int64_t alignedSize = 16) +{ + if constexpr (hasPse == true) { + if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) { + return PseAlibiCopyIn(dstTensor, tmpTensor, srcTensor, pseInfo, alignedSize); + } + int64_t offset = PseComputeOffset(pseInfo); + int64_t s1Size = pseInfo.pseShapeType == pse1S2 ? (pseInfo.blockCount == 0 ? 1 : pseInfo.blockCount) : + pseInfo.vec1S1RealSize; + + if constexpr (IsSameType::value) { + if (!pseInfo.align8){ + DataCopyIn(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, + pseInfo.s2Size, alignedSize); + } else { + DataCopyInAlign8(dstTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size); + } + return; + } + DataCopyIn(tmpTensor, srcTensor, offset, s1Size, pseInfo.s2RealSize, pseInfo.s2Size, + alignedSize); + if (pseInfo.needCast) { + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + Cast(dstTensor, tmpTensor, RoundMode::CAST_NONE, s1Size * pseInfo.s2AlignedSize); + } + return; + } +} + +template +__aicore__ inline void PseAlibiCompute(LocalTensor &dstTensor, LocalTensor &pseTensor, PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + if (!NeedPseAlibiCompute(pseInfo)) { + return; + } + Add(dstTensor, dstTensor, pseTensor, pseInfo.vec1S1RealSize * pseInfo.pseS2ComputeSize); + return; + } +} + +template +__aicore__ inline void PseCompute(LocalTensor &dstTensor, LocalTensor &pseTensor, PseInfo &pseInfo) +{ + if constexpr (hasPse == true) { + if (pseInfo.pseEncodeType == pseEncodeALibiS2Full) { + return PseAlibiCompute(dstTensor, pseTensor, pseInfo); + } + int64_t computeSize = (pseInfo.pseShapeType == pseS1S2 || pseInfo.pseShapeType == pseSlopeBn || + pseInfo.pseShapeType == pseSlopeN) + ? 
pseInfo.vec1S1RealSize * pseInfo.s2AlignedSize + : pseInfo.s2AlignedSize; + PseBroadcastAdd(pseInfo.vec1S1RealSize, pseInfo.s2AlignedSize, computeSize, pseTensor, + dstTensor, pseInfo.pseShapeType); + return; + } +} + +template +__aicore__ inline void PseInnerAlibiCreate(GlobalTensor &dstTensor, LocalTensor &helpTensor, PseInfo &pseInfo) { + if constexpr (hasPse == true) { + if (pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_TYPE && pseInfo.pseType != (uint32_t)PseTypeEnum::PSE_INNER_MUL_ADD_SQRT_TYPE) { + return; + } + event_t eventIdMte3ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V)); + event_t eventIdMte3ToS = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_S)); + event_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + float tmpValue = -1.0; + + for (int64_t i = 0; i < pseInfo.pseAlibiBaseS1; i++) { + CreateVecIndex(helpTensor, (half)(i * tmpValue), pseInfo.pseAlibiBaseS2); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + DataCopy(dstTensor[i * pseInfo.pseAlibiBaseS2], helpTensor, pseInfo.pseAlibiBaseS2); + SetFlag(eventIdMte3ToV); + WaitFlag(eventIdMte3ToV); + SetFlag(eventIdMte3ToS); + WaitFlag(eventIdMte3ToS); + } + } +} +#endif diff --git a/csrc/utils/inc/kernel/util.h b/csrc/utils/inc/kernel/util.h new file mode 100644 index 00000000000..2c7d2089323 --- /dev/null +++ b/csrc/utils/inc/kernel/util.h @@ -0,0 +1,144 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file util.h + * \brief + */ + +#ifndef FLASH_ATTENTION_UTIL_H +#define FLASH_ATTENTION_UTIL_H + +constexpr int32_t blockBytes = 32; +constexpr int32_t byteBitRatio = 8; +constexpr int64_t prefixAttenMaskDownHeight = 1024; +constexpr static int32_t blockSize = blockBytes / 4; // 4 means sizeof(T) +constexpr static int32_t repeatMaxBytes = 256; +constexpr static int32_t repeatMaxTimes = 255; +constexpr static int32_t repeatMaxSize = repeatMaxBytes / 4; // 4 means sizeof(T) + +using AscendC::LocalTensor; +using AscendC::GlobalTensor; +using AscendC::DataFormat; +using AscendC::ShapeInfo; +using AscendC::DataCopyParams; +using AscendC::DataCopyPadParams; +using AscendC::BinaryRepeatParams; +using AscendC::IsSameType; +using AscendC::HardEvent; +using AscendC::SetFlag; +using AscendC::WaitFlag; + +enum class LayOutTypeEnum { None = 0, LAYOUT_BSH = 1, LAYOUT_SBH = 2, LAYOUT_BNSD = 3, LAYOUT_TND = 4, LAYOUT_NTD_TND = 5}; + +namespace math { +template __aicore__ inline T Ceil(T a, T b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +template __aicore__ inline T Align(T a, T b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b * b; +} +} + +template +__aicore__ inline T1 CeilDiv(T1 a, T2 b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +template +__aicore__ inline T1 Max(T1 a, T2 b) +{ + return (a > b) ? (a) : (b); +} + +template +__aicore__ inline T1 Min(T1 a, T2 b) +{ + return (a > b) ? 
(b) : (a); +} + +__aicore__ inline void BoolCopyIn(LocalTensor &dstTensor, GlobalTensor &srcTensor, + int64_t srcOffset, uint32_t s1Size, uint32_t s2Size, int64_t totalS2Size, int64_t alignedSize = blockBytes) +{ + uint32_t alignedS2Size = CeilDiv(s2Size, alignedSize) * alignedSize; + uint32_t shapeArray[] = {s1Size, alignedS2Size}; + dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND)); + dstTensor.SetSize(s1Size * alignedS2Size); + DataCopyParams dataCopyParams; + dataCopyParams.blockCount = s1Size; + dataCopyParams.dstStride = 0; + if (totalS2Size == blockBytes && alignedSize == 64) { // totalS2Size < 64 && totalS2Size % blockBytes == 0 + dataCopyParams.dstStride = 1; + alignedSize = blockBytes; + alignedS2Size = CeilDiv(s2Size, blockBytes) * blockBytes; + } + if (totalS2Size % alignedSize == 0) { + dataCopyParams.blockLen = alignedS2Size / blockBytes; + dataCopyParams.srcStride = (totalS2Size - alignedS2Size) / blockBytes; + DataCopy(dstTensor, srcTensor[srcOffset], dataCopyParams); + } else { + dataCopyParams.blockLen = s2Size; + dataCopyParams.srcStride = totalS2Size - s2Size; + DataCopyPadParams dataCopyPadParams; + dataCopyPadParams.isPad = true; + dataCopyPadParams.rightPadding = Min(alignedS2Size - s2Size, blockBytes); + dataCopyPadParams.paddingValue = 1; + DataCopyPad(dstTensor, srcTensor[srcOffset], dataCopyParams, dataCopyPadParams); + } +} + +__aicore__ inline void Bit2Int8CopyIn(LocalTensor &dstTensor, GlobalTensor &srcTensor, + int64_t srcOffset, uint32_t batchSize, uint32_t s1BaseSize, uint32_t s2BaseSize, int64_t s2TotalSize, + int64_t alignedSize = blockBytes) +{ + uint32_t alignedS2Size = CeilDiv(s2BaseSize / byteBitRatio, alignedSize) * alignedSize; + uint32_t shapeArray[] = {batchSize * s1BaseSize, alignedS2Size}; + dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND)); + dstTensor.SetSize(batchSize * s1BaseSize * alignedS2Size); + DataCopyParams dataCopyParams; + dataCopyParams.blockCount = batchSize * s1BaseSize; + dataCopyParams.blockLen = CeilDiv(s2BaseSize / byteBitRatio, blockBytes); + dataCopyParams.dstStride = 0; + if (s2TotalSize / byteBitRatio % alignedSize == 0 && s2BaseSize / byteBitRatio % alignedSize == 0) { + dataCopyParams.srcStride = + (s2TotalSize / byteBitRatio - dataCopyParams.blockLen * blockBytes) / blockBytes; + DataCopy(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams); + } else { + dataCopyParams.blockLen = CeilDiv(s2BaseSize , byteBitRatio); + dataCopyParams.srcStride = (s2TotalSize - s2BaseSize) / byteBitRatio; + DataCopyPadParams dataCopyPadParams; + dataCopyPadParams.isPad = true; + dataCopyPadParams.rightPadding = 0; + dataCopyPadParams.paddingValue = 0; + DataCopyPad(dstTensor, srcTensor[srcOffset / byteBitRatio], dataCopyParams, dataCopyPadParams); + } +} + +__aicore__ inline int32_t Align(int32_t shape) +{ + int32_t alignFactor = 16; + int32_t alignedSize = CeilDiv(shape, alignFactor) * alignFactor; + return alignedSize; +} + +#endif // FLASH_ATTENTION_UTIL_H diff --git a/csrc/utils/inc/log/inner/dfx_base.h b/csrc/utils/inc/log/inner/dfx_base.h new file mode 100644 index 00000000000..0fd1edb4ead --- /dev/null +++ b/csrc/utils/inc/log/inner/dfx_base.h @@ -0,0 +1,190 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dfx_base.h + * \brief 外部模块不应直接引用本头文件 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ops { +namespace utils { + +class LogBase { +public: + static constexpr const int MAX_LOG_LEN = 16000; + static constexpr const int MSG_HDR_LEN = 200; + + static inline uint64_t GetTid() + { + return static_cast(syscall(__NR_gettid)); + } + + static inline const char *GetStr(const std::string &str) + { + return str.c_str(); + } + + static inline const char *GetStr(const char *str) + { + return str; + } + + static inline const std::string &GetOpInfo(const std::string &str) + { + return str; + } + + static inline const char *GetOpInfo(const char *str) + { + return str; + } + + static inline std::string GetOpInfo(const gert::TilingContext *context) + { + return GetOpInfoFromContext(context); + } + + static inline std::string GetOpInfo(const gert::TilingParseContext *context) + { + return GetOpInfoFromContext(context); + } + + static inline std::string GetOpInfo(const gert::InferShapeContext *context) + { + return GetOpInfoFromContext(context); + } + + static inline std::string GetOpInfo(const gert::InferDataTypeContext *context) + { + return GetOpInfoFromContext(context); + } + +private: + template static inline std::string GetOpInfoFromContext(T context) + { + if (context == nullptr) { + return "nil:nil"; + } + std::string opInfo = context->GetNodeType() != nullptr ? context->GetNodeType() : "nil"; + opInfo += ":"; + opInfo += context->GetNodeName() != nullptr ? context->GetNodeName() : "nil"; + return opInfo; + } +}; + +} // namespace utils + +template +std::string Shape2String(const T& shape) { + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); +} +} // namespace ops + +// 使用本宏前需预定义标识子模块名称的 OPS_UTILS_LOG_SUB_MOD_NAME +// 如: #define OPS_UTILS_LOG_SUB_MOD_NAME "OP_TILING" 或通过 CMake 传递预定义宏 +#define OPS_LOG_STUB(MOD_ID, LOG_LEVEL, OPS_DESC, FMT, ...) \ + do { \ + if (AlogCheckDebugLevel(static_cast(MOD_ID), (LOG_LEVEL)) == 1) { \ + AlogRecord(static_cast(MOD_ID), DLOG_TYPE_DEBUG, (LOG_LEVEL), \ + "[%s:%d][%s]%s[%s][%lu] OpName:[%s] " #FMT, \ + __FILE__, __LINE__, (OPS_UTILS_LOG_SUB_MOD_NAME), \ + (OPS_UTILS_LOG_PACKAGE_TYPE), __FUNCTION__, ops::utils::LogBase::GetTid(), \ + ops::utils::LogBase::GetStr(ops::utils::LogBase::GetOpInfo(OPS_DESC)), ##__VA_ARGS__); \ + } \ + }while (0) + +#define OPS_LOG_STUB_IF(COND, LOG_FUNC, EXPR) \ + static_assert(std::is_same::type>::value, "condition should be bool"); \ + do { \ + if (__builtin_expect((COND), 0)) { \ + LOG_FUNC; \ + EXPR; \ + } \ + } while (0) + +#define OPS_INNER_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) \ + do { \ + OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \ + REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \ + } while (0) + +#define OPS_CALL_ERR_STUB(ERR_CODE_STR, OPS_DESC, FMT, ...) 
\ + do { \ + OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__); \ + REPORT_INNER_ERR_MSG(ERR_CODE_STR, FMT, ##__VA_ARGS__); \ + } while (0) + +#define OPS_LOG_STUB_D(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_DEBUG, OPS_DESC, FMT, ##__VA_ARGS__) +#define OPS_LOG_STUB_I(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_INFO, OPS_DESC, FMT, ##__VA_ARGS__) +#define OPS_LOG_STUB_W(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_WARN, OPS_DESC, FMT, ##__VA_ARGS__) +#define OPS_LOG_STUB_E(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_ERROR, OPS_DESC, FMT, ##__VA_ARGS__) +#define OPS_LOG_STUB_EVENT(OPS_DESC, FMT, ...) OPS_LOG_STUB(OP, DLOG_EVENT, OPS_DESC, FMT, ##__VA_ARGS__) + +#define OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, FMT, ...) \ + do { \ + if (0 == AlogCheckDebugLevel(OP, (LEVEL))) { \ + break; \ + } \ + char msgbufxyz[ops::utils::LogBase::MAX_LOG_LEN]; \ + size_t msgmaxlen = (MSG_LENGTH - ops::utils::LogBase::MSG_HDR_LEN); \ + int rettmp = snprintf_s(msgbufxyz, sizeof(msgbufxyz), sizeof(msgbufxyz) - 1, FMT, ##__VA_ARGS__); \ + if (rettmp == -1) { \ + msgbufxyz[sizeof(msgbufxyz) - 1] = '\0'; \ + } \ + size_t msglength = std::strlen(msgbufxyz); \ + if (msglength < msgmaxlen) { \ + OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgbufxyz); \ + break; \ + } \ + char *msgchunkbegin = msgbufxyz; \ + char *msgchunkend = nullptr; \ + while (msgchunkbegin < msgbufxyz + msglength) { \ + if (msgchunkbegin[0] == '\n') { \ + OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), ""); \ + msgchunkbegin += 1; \ + continue; \ + } \ + msgchunkend = std::strchr(msgchunkbegin, '\n'); \ + if (msgchunkend == nullptr) { \ + msgchunkend = msgchunkbegin + std::strlen(msgchunkbegin); \ + } \ + while (msgchunkend > msgchunkbegin) { \ + std::string msgchunk(msgchunkbegin, \ + std::min(msgmaxlen, static_cast(msgchunkend - msgchunkbegin))); \ + OPS_LOG_STUB(OP, (LEVEL), (OPS_DESC), "%s", msgchunk.c_str()); \ + msgchunkbegin += msgchunk.size(); \ + } \ + msgchunkbegin += 1; \ + } \ + } while (0) diff --git a/csrc/utils/inc/log/ops_log.h b/csrc/utils/inc/log/ops_log.h new file mode 100644 index 00000000000..e7653a89bc6 --- /dev/null +++ b/csrc/utils/inc/log/ops_log.h @@ -0,0 +1,59 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file ops_log.h + * \brief + */ + +#pragma once + +#include "log/inner/dfx_base.h" + +/* 基础日志 */ +#define OPS_LOG_D(OPS_DESC, ...) OPS_LOG_STUB_D(OPS_DESC, __VA_ARGS__) +#define OPS_LOG_I(OPS_DESC, ...) OPS_LOG_STUB_I(OPS_DESC, __VA_ARGS__) +#define OPS_LOG_W(OPS_DESC, ...) OPS_LOG_STUB_W(OPS_DESC, __VA_ARGS__) +#define OPS_LOG_E(OPS_DESC, ...) OPS_INNER_ERR_STUB("EZ9999", OPS_DESC, __VA_ARGS__) +#define OPS_LOG_E_WITHOUT_REPORT(OPS_DESC, ...) OPS_LOG_STUB_E(OPS_DESC, __VA_ARGS__) +#define OPS_LOG_EVENT(OPS_DESC, ...) OPS_LOG_STUB_EVENT(OPS_DESC, __VA_ARGS__) + +/* 全量日志 + * 输出超长日志, 若日志超长, 则会被分为多行输出 */ +#define OPS_LOG_FULL(LEVEL, OPS_DESC, ...) OPS_LOG_STUB_FULL(LEVEL, OPS_DESC, __VA_ARGS__) +#define OPS_LOG_D_FULL(OPS_DESC, ...) 
OPS_LOG_STUB_FULL(DLOG_DEBUG, OPS_DESC, __VA_ARGS__) +#define OPS_LOG_I_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_INFO, OPS_DESC, __VA_ARGS__) +#define OPS_LOG_W_FULL(OPS_DESC, ...) OPS_LOG_STUB_FULL(DLOG_WARN, OPS_DESC, __VA_ARGS__) + +/* 条件日志 */ +#define OPS_LOG_D_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_D(OP_DESC, __VA_ARGS__), EXPR) +#define OPS_LOG_I_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_I(OP_DESC, __VA_ARGS__), EXPR) +#define OPS_LOG_W_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_W(OP_DESC, __VA_ARGS__), EXPR) +#define OPS_LOG_E_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_E(OP_DESC, __VA_ARGS__), EXPR) +#define OPS_LOG_EVENT_IF(COND, OP_DESC, EXPR, ...) OPS_LOG_STUB_IF(COND, OPS_LOG_EVENT(OP_DESC, __VA_ARGS__), EXPR) + +#define OPS_LOG_E_IF_NULL(OPS_DESC, PTR, EXPR) \ + if (__builtin_expect((PTR) == nullptr, 0)) { \ + OPS_LOG_STUB_E(OPS_DESC, "%s is nullptr!", #PTR); \ + OPS_CALL_ERR_STUB("EZ9999", OPS_DESC, "%s is nullptr!", #PTR); \ + EXPR; \ + } + +#define OPS_CHECK(COND, LOG_FUNC, EXPR) \ + if (COND) { \ + LOG_FUNC; \ + EXPR; \ + } + +#define OP_CHECK(COND, LOG_FUNC, EXPR) \ + if (COND) { \ + LOG_FUNC; \ + EXPR; \ + } diff --git a/csrc/utils/inc/tiling/data_copy_transpose_tiling.h b/csrc/utils/inc/tiling/data_copy_transpose_tiling.h new file mode 100644 index 00000000000..7e8d15d7f42 --- /dev/null +++ b/csrc/utils/inc/tiling/data_copy_transpose_tiling.h @@ -0,0 +1,47 @@ +/** + * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file data_copy_transpose_tiling.h + * \brief + */ + +#pragma once + +#include +#include +#include "data_copy_transpose_tiling_def.h" + +namespace optiling { + +inline void GetDataCopyTransposeTiling(const ge::Shape &dstShape, const ge::Shape &srcShape, const uint32_t typeSize, + optiling::CopyTransposeTiling &tiling) +{ + std::vector dstShapeInfo = dstShape.GetDims(); + std::vector srcShapeInfo = srcShape.GetDims(); + + tiling.set_dstShapeB(dstShapeInfo[0]); + tiling.set_dstShapeN(dstShapeInfo[1]); + tiling.set_dstShapeS(dstShapeInfo[2]); + tiling.set_dstShapeH(dstShapeInfo[3]); + tiling.set_dstShapeHN(tiling.get_dstShapeH() / tiling.get_dstShapeN()); + + tiling.set_srcShapeB(srcShapeInfo[0]); + tiling.set_srcShapeN(srcShapeInfo[1]); + tiling.set_srcShapeS(srcShapeInfo[2]); + tiling.set_srcShapeHN(srcShapeInfo[3]); + tiling.set_originalShapeNLen(tiling.get_srcShapeHN() * typeSize); + tiling.set_shapeSHValue(tiling.get_dstShapeS() * tiling.get_dstShapeH()); + tiling.set_shapeNsValue(tiling.get_dstShapeN() * tiling.get_dstShapeS()); + tiling.set_shapeNsnValue(tiling.get_dstShapeN() * tiling.get_srcShapeS() * tiling.get_srcShapeN()); + tiling.set_shapeBHValue(tiling.get_dstShapeB() * tiling.get_dstShapeH()); +} + +} // namespace optiling diff --git a/csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h b/csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h new file mode 100644 index 00000000000..510b5cdadb3 --- /dev/null +++ b/csrc/utils/inc/tiling/data_copy_transpose_tiling_def.h @@ -0,0 +1,43 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file data_copy_transpose_tiling_def.h + * \brief + */ + +#pragma once + +#include +#include + +namespace optiling { + +BEGIN_TILING_DATA_DEF(CopyTransposeTiling) +TILING_DATA_FIELD_DEF(uint32_t, dstShapeB); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeN); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeS); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeHN); +TILING_DATA_FIELD_DEF(uint32_t, dstShapeH); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeB); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeN); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeS); +TILING_DATA_FIELD_DEF(uint32_t, srcShapeHN); +TILING_DATA_FIELD_DEF(uint32_t, originalShapeNLen); +TILING_DATA_FIELD_DEF(uint32_t, shapeSHValue); +TILING_DATA_FIELD_DEF(uint32_t, shapeNsValue); +TILING_DATA_FIELD_DEF(uint32_t, shapeNsnValue); +TILING_DATA_FIELD_DEF(uint32_t, invalidParamCopyTransposeTiling); +TILING_DATA_FIELD_DEF(uint32_t, shapeBHValue); +TILING_DATA_FIELD_DEF(uint32_t, paramsAlign); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(CopyTransposeTilingOp, CopyTransposeTiling) + +} // namespace optiling diff --git a/csrc/utils/inc/tiling/tiling_base.h b/csrc/utils/inc/tiling/tiling_base.h new file mode 100644 index 00000000000..9776d90c874 --- /dev/null +++ b/csrc/utils/inc/tiling/tiling_base.h @@ -0,0 +1,225 @@ +/** + * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file tiling_base.h + * \brief + */ + +#pragma once + +#include +#include +#include +#include +#include "log/ops_log.h" + +#ifdef ASCENDC_OP_TEST +#define ASCENDC_EXTERN_C extern "C" +#else +#define ASCENDC_EXTERN_C +#endif + +namespace optiling { + +struct AiCoreParams { + uint64_t ubSize; + uint64_t blockDim; + uint64_t aicNum; + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; +}; + +struct FlashAttentionScoreGradCompileInfo { + uint32_t aivNum; + uint32_t aicNum; + uint64_t ubSize; + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; + uint64_t l2CacheSize; + int64_t coreNum; +}; + +class TilingBaseClass { +public: + TilingBaseClass() = default; + + explicit TilingBaseClass(gert::TilingContext *context) : context_(context) + { + } + + virtual ~TilingBaseClass() = default; + + // Tiling执行框架 + // 1、GRAPH_SUCCESS: 成功,并且不需要继续执行后续Tiling类的实现 + // 2、GRAPH_FAILED: 失败,中止整个Tiling流程 + // 3、GRAPH_PARAM_INVALID: 本类不支持,需要继续往下执行其他Tiling类的实现 + ge::graphStatus DoTiling() + { + auto ret = GetShapeAttrsInfo(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = GetPlatformInfo(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + if (!IsCapable()) { + return ge::GRAPH_PARAM_INVALID; + } + ret = DoOpTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = DoLibApiTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = GetWorkspaceSize(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + ret = PostTiling(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + context_->SetTilingKey(GetTilingKey()); + DumpTilingInfo(); + return ge::GRAPH_SUCCESS; + } + + // 更新 context + virtual void Reset(gert::TilingContext *context) + { + context_ = context; + } + +protected: + virtual bool IsCapable() = 0; + // 1、获取平台信息比如CoreNum、UB/L1/L0C资源大小 + virtual ge::graphStatus GetPlatformInfo() = 0; + // 2、获取INPUT/OUTPUT/ATTR信息 + virtual ge::graphStatus GetShapeAttrsInfo() = 0; + // 3、计算数据切分TilingData + virtual ge::graphStatus DoOpTiling() = 0; + // 4、计算高阶API的TilingData + virtual ge::graphStatus DoLibApiTiling() = 0; + // 5、计算TilingKey + [[nodiscard]] virtual uint64_t GetTilingKey() const = 0; + // 6、计算Workspace 大小 + virtual ge::graphStatus GetWorkspaceSize() = 0; + // 7、保存Tiling数据 + virtual ge::graphStatus PostTiling() = 0; + // 8、Dump Tiling数据 + virtual void DumpTilingInfo() + { + int32_t enable = AlogCheckDebugLevel(static_cast(OP), DLOG_DEBUG); + if (enable != 1) { + return; + } + auto buf = (uint32_t *)context_->GetRawTilingData()->GetData(); + auto bufLen = context_->GetRawTilingData()->GetDataSize(); + std::ostringstream oss; + oss << "Start to dump tiling info. 
tilingkey:" << GetTilingKey() << ", tiling data size:" << bufLen + << ", content:"; + for (size_t i = 0; i < bufLen / sizeof(uint32_t); i++) { + oss << *(buf + i) << ","; + if (oss.str().length() > 640) { // Split according to 640 to avoid truncation + OPS_LOG_D(context_, "%s", oss.str().c_str()); + oss.str(""); + } + } + OPS_LOG_D(context_, "%s", oss.str().c_str()); + } + + static uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum) + { + uint32_t ration; + if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) { + return sliceNum; + } + ration = aivCoreNum / aicCoreNum; + return (sliceNum + (ration - 1)) / ration; + } + + template [[nodiscard]] std::string GetShapeDebugStr(const T &shape) const + { + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); + } + + [[nodiscard]] std::string GetTensorDebugStr(const gert::StorageShape *shape, + const gert::CompileTimeTensorDesc *tensor) + { + if (shape == nullptr || tensor == nullptr) { + return "nil "; + } + std::ostringstream oss; + oss << "(dtype: " << ge::TypeUtils::DataTypeToSerialString(tensor->GetDataType()) << "),"; + oss << "(shape:" << GetShapeDebugStr(shape->GetStorageShape()) << "),"; + oss << "(ori_shape:" << GetShapeDebugStr(shape->GetOriginShape()) << "),"; + oss << "(format: " + << ge::TypeUtils::FormatToSerialString( + static_cast(ge::GetPrimaryFormat(tensor->GetStorageFormat()))) + << "),"; + oss << "(ori_format: " << ge::TypeUtils::FormatToSerialString(tensor->GetOriginFormat()) << ") "; + return oss.str(); + } + + [[nodiscard]] std::string GetTilingContextDebugStr() + { + std::ostringstream oss; + for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetInputsNum(); ++i) { + oss << "input" << i << ": "; + oss << GetTensorDebugStr(context_->GetInputShape(i), context_->GetInputDesc(i)); + } + + for (size_t i = 0; i < context_->GetComputeNodeInfo()->GetOutputsNum(); ++i) { + oss << "output" << i << ": "; + oss << GetTensorDebugStr(context_->GetOutputShape(i), context_->GetOutputDesc(i)); + } + return oss.str(); + } + + [[nodiscard]] std::string GetTilingDataDebugStr() const + { + auto rawTilingData = context_->GetRawTilingData(); + auto rawTilingDataSize = rawTilingData->GetDataSize(); + auto data = reinterpret_cast(rawTilingData->GetData()); + size_t len = rawTilingDataSize / sizeof(int32_t); + std::ostringstream oss; + for (size_t i = 0; i < len; i++) { + oss << data[i] << ", "; + } + return oss.str(); + } + +protected: + gert::TilingContext *context_ = nullptr; + std::unique_ptr ascendcPlatform_{nullptr}; + uint32_t blockDim_{0}; + uint64_t workspaceSize_{0}; + uint64_t tilingKey_{0}; + AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0}; +}; + +} // namespace optiling diff --git a/csrc/utils/inc/tiling/tiling_templates_registry.h b/csrc/utils/inc/tiling/tiling_templates_registry.h new file mode 100644 index 00000000000..53fc590aaf0 --- /dev/null +++ b/csrc/utils/inc/tiling/tiling_templates_registry.h @@ -0,0 +1,162 @@ +/** + * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file tiling_templates_registry.h + * \brief + */ + +#pragma once + +#include +#include +#include +#include +#include "tiling/tiling_base.h" +#include "log/ops_log.h" +#include "error/ops_error.h" + +namespace optiling { + +template std::unique_ptr TILING_CLASS(gert::TilingContext *context) +{ + return std::unique_ptr(new (std::nothrow) T(context)); +} + +using TilingClassCase = std::unique_ptr (*)(gert::TilingContext *); + +class TilingCases { +public: + explicit TilingCases(std::string op_type) : op_type_(std::move(op_type)) + { + } + + template void AddTiling(int32_t priority) + { + OPS_ERR_IF(cases_.find(priority) != cases_.end(), + OPS_REPORT_VECTOR_INNER_ERR(op_type_, "There are duplicate registrations."), return); + cases_[priority] = TILING_CLASS; + OPS_ERR_IF( + cases_[priority] == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling func failed, please check the class name."), + return); + } + + const std::map &GetTilingCases() + { + return cases_; + } + +private: + std::map cases_; + const std::string op_type_; +}; + +class TilingRegistry { +public: + TilingRegistry() = default; + +#ifdef ASCENDC_OP_TEST + static TilingRegistry &GetInstance(); +#else + static TilingRegistry &GetInstance() + { + static TilingRegistry registry_impl_; + return registry_impl_; + } +#endif + + std::shared_ptr RegisterOp(const std::string &op_type) + { + if (registry_map_.find(op_type) == registry_map_.end()) { + registry_map_[op_type] = std::shared_ptr(new (std::nothrow) TilingCases(op_type)); + } + OPS_ERR_IF(registry_map_[op_type] == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(op_type, "Register tiling func failed, please check the class name."), + return nullptr); + return registry_map_[op_type]; + } + + ge::graphStatus DoTilingImpl(gert::TilingContext *context) + { + const char *op_type = context->GetNodeType(); + auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); + for (auto it = tilingTemplateRegistryMap.begin(); it != tilingTemplateRegistryMap.end(); ++it) { + auto tilingTemplate = it->second(context); + if (tilingTemplate != nullptr) { + ge::graphStatus status = tilingTemplate->DoTiling(); + if (status != ge::GRAPH_PARAM_INVALID) { + OPS_LOG_D(context, "Do general op tiling success priority=%d", it->first); + return status; + } + OPS_LOG_D(context, "Ignore general op tiling priority=%d", it->first); + } + } + OPS_REPORT_VECTOR_INNER_ERR(op_type, "Do op tiling failed, no valid template is found."); + return ge::GRAPH_FAILED; + } + + ge::graphStatus DoTilingImpl(gert::TilingContext *context, const std::vector &priorities) + { + const char *op_type = context->GetNodeType(); + auto tilingTemplateRegistryMap = GetTilingTemplates(op_type); + for (auto priorityId : priorities) { + auto templateFunc = tilingTemplateRegistryMap[priorityId](context); + if (templateFunc != nullptr) { + ge::graphStatus status = templateFunc->DoTiling(); + if (status == ge::GRAPH_SUCCESS) { + OPS_LOG_D(context, "Do general op tiling success priority=%d", priorityId); + return status; + } + OPS_LOG_D(context, "Ignore general op tiling priority=%d", priorityId); + } + } + return ge::GRAPH_FAILED; + } + + const std::map &GetTilingTemplates(const std::string &op_type) + { + 
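+        // Look up the tiling cases registered for this op type; an unregistered op type reports an error and gets the empty case map.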
OPS_ERR_IF(registry_map_.find(op_type) == registry_map_.end(), + OPS_REPORT_VECTOR_INNER_ERR(op_type, "Get op tiling func failed, please check the op name."), + return empty_tiling_case_); + return registry_map_[op_type]->GetTilingCases(); + } + +private: + std::map> registry_map_; + const std::map empty_tiling_case_ {}; +}; + +class Register { +public: + explicit Register(std::string op_type) : op_type_(std::move(op_type)) + { + } + + template Register &tiling(int32_t priority) + { + auto tilingCases = TilingRegistry::GetInstance().RegisterOp(op_type_); + OPS_ERR_IF(tilingCases == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(op_type_, "Register op tiling failed, please the op name."), + return *this); + tilingCases->AddTiling(priority); + return *this; + } + +private: + const std::string op_type_; +}; + +// op_type: 算子名称, class_name: 注册的 tiling 类, +// priority: tiling 类的优先级, 越小表示优先级越高, 即被选中的概率越大 +#define REGISTER_TILING_TEMPLATE(op_type, class_name, priority) \ + static Register VAR_UNUSED##op_type_##class_name##priority_register = Register(op_type).tiling(priority) + +} // namespace optiling diff --git a/csrc/utils/inc/tiling/tiling_type.h b/csrc/utils/inc/tiling/tiling_type.h new file mode 100644 index 00000000000..d417b0b6b48 --- /dev/null +++ b/csrc/utils/inc/tiling/tiling_type.h @@ -0,0 +1,136 @@ +/** + * Copyright (c) 2023-2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file tiling_type.h + * \brief + */ + +#pragma once + +#include + +namespace optiling { + +enum class AxisEnum { + B = 0, + N2 = 1, + G = 2, + S1 = 3, + S2 = 4, + D = 5, + NONE = 9, +}; + +enum class DtypeEnum { + FLOAT16 = 0, + FLOAT32 = 1, + BFLOAT16 = 2, + FLOAT16_PRECISION = 3, +}; + +enum class PerformanceOrientedEnum { + BIG_BUFFER = 1, + BIG_DOUBLE_BUFFER = 2, +}; + +enum class MatmulConfig { + NULL_CONFIG = 0, + NORMAL_CONFIG = 1, + MDL_CONFIG = 2 +}; + +enum class PseConfig { + NO_PSE = 0, + EXIST_PSE = 1 +}; + +enum class AttenMaskConfig { + NO_ATTEN_MASK = 0, + EXIST_ATTEN_MASK = 1 +}; + +enum class DropOutConfig { + NO_DROP_OUT = 0, + EXIST_DROP_OUT = 1 +}; + +enum class CubeFormatEnum { + ND = 0, + NZ = 1 +}; +enum class LayoutEnum { + BSND = 0, + SBND = 1, + BNSD = 2, + TND = 3 +}; + +enum class CubeInputSourceEnum { + GM = 0, + L1 = 1 +}; + +enum class OptionEnum { + DISABLE = 0, + ENABLE = 1 +}; + +enum class SparseEnum { + ALL = 0, + NONE = 1, + ANY = 2, + CAUSAL = 3, + BAND = 4, + PREFIX = 5, + BAND_COMPRESS = 6, + RIGHT_DOWN_CAUSAL = 7, + RIGHT_DOWN_CAUSAL_BAND = 8, + BAND_LEFT_UP_CAUSAL = 9 +}; + +constexpr uint64_t RecursiveSum() +{ + return 0; +} + +template constexpr uint64_t RecursiveSum(T templateId, Args... 
templateIds)
+{
+    return static_cast<uint64_t>(templateId) + 10 * RecursiveSum(templateIds...);
+}
+
+// How the TilingKey is assembled:
+// FlashAttentionScore/FlashAttentionScoreGrad pack the tiling key from decimal digits. The key fields, from the
+// lowest digit to the highest, are: Ub0, Ub1, Block, DataType, Format, Sparse, and then specialized-template fields.
+// Ub0, Ub1: the axes that are split inside UB, expressed with AxisEnum. At most two axes may be split, hence UB0
+// and UB1; if there is no intra-UB split, use AXIS_NONE. UB0 and UB1 each occupy one decimal digit.
+// Block: the axis used to split work across cores, expressed with AxisEnum; one decimal digit.
+// DataType: the input/output data type supported by this tiling key, expressed with SupportedDtype; one decimal digit.
+// Format: the format supported by this tiling key, expressed with InputLayout; one decimal digit.
+// Sparse: whether this tiling key supports Sparse, expressed with SparseCapability; one decimal digit.
+// Any remaining specialized scenarios define their own digit fields and values.
+// usage: get tilingKey from the input types
+// uint64_t tilingKey = GET_FLASHATTENTION_TILINGKEY(AxisEnum::AXIS_S1, AxisEnum::AXIS_S2, AxisEnum::AXIS_N2,
+//                                                   SupportedDtype::FLOAT32, InputLayout::BSH, SparseCapability::SUPPORT_ALL)
+
+constexpr uint64_t TILINGKEYOFFSET = uint64_t(10000000000000000000UL); // 10^19
+template <typename... Args> constexpr uint64_t GET_TILINGKEY(Args... templateIds)
+{
+    return TILINGKEYOFFSET + RecursiveSum(templateIds...);
+}
+
+// usage: get tilingKey from the input types
+// uint64_t tilingKey = TILINGKEY(S2, S1, N2, FLOAT32, BSND, ALL)
+
+#define TILINGKEY(ub2, ub1, block, dtype, layout, sparse) \
+    (GET_TILINGKEY(AxisEnum::ub2, AxisEnum::ub1, AxisEnum::block, DtypeEnum::dtype, LayoutEnum::layout, \
+                   SparseEnum::sparse))
+
+} // namespace optiling
diff --git a/csrc/utils/src/fallback_comm.cpp b/csrc/utils/src/fallback_comm.cpp
new file mode 100644
index 00000000000..949cb728934
--- /dev/null
+++ b/csrc/utils/src/fallback_comm.cpp
@@ -0,0 +1,53 @@
+/**
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file fallback_comm.cpp
+ * \brief
+ */
+
+#include "fallback_comm.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "aclnn/aclnn_base.h"
+#include "runtime/base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+namespace fallback {
+using namespace std;
+using namespace gert;
+using namespace ge;
+
+aclDataType ToAclDataType(ge::DataType dtype) {
+    static const std::vector<ge::DataType> CANN_CONVERT_TO_ACL_DataType_LIST = {
+        ge::DataType::DT_FLOAT, ge::DataType::DT_FLOAT16, ge::DataType::DT_INT8, ge::DataType::DT_INT32,
+        ge::DataType::DT_UINT8, ge::DataType::DT_INT16, ge::DataType::DT_UINT16, ge::DataType::DT_UINT32,
+        ge::DataType::DT_INT64, ge::DataType::DT_DOUBLE, ge::DataType::DT_BOOL, ge::DataType::DT_STRING,
+        ge::DataType::DT_COMPLEX64, ge::DataType::DT_COMPLEX128, ge::DataType::DT_BF16, ge::DataType::DT_UINT64,
+        ge::DataType::DT_INT4};
+    auto iter = std::find(CANN_CONVERT_TO_ACL_DataType_LIST.begin(), CANN_CONVERT_TO_ACL_DataType_LIST.end(), dtype);
+    if (iter == CANN_CONVERT_TO_ACL_DataType_LIST.end()) {
+        return aclDataType::ACL_DT_UNDEFINED;
+    }
+    return static_cast<aclDataType>(dtype);
+}
+
+} // namespace fallback
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/docs/source/developer_guide/contribution/multi_node_test.md b/docs/source/developer_guide/contribution/multi_node_test.md
index 1fdcc3c5907..a57a19c62a8 100644
--- a/docs/source/developer_guide/contribution/multi_node_test.md
+++ b/docs/source/developer_guide/contribution/multi_node_test.md
@@ -90,7 +90,7 @@ currently, the multi-node test workflow defined in the [vllm_ascend_test_nightly
     uses: ./.github/workflows/_e2e_nightly_multi_node.yaml
     with:
       soc_version: a3
-      image: m.daocloud.io/quay.io/ascend/cann:8.3.rc1-a3-ubuntu22.04-py3.11
+      image: m.daocloud.io/quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
       replicas: 1
       size: ${{ matrix.test_config.size }}
       config_file_path: ${{ matrix.test_config.config_file_path }}
diff --git a/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md b/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md
index 27986aabbbd..04bde6fe9a5 100644
--- a/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md
+++ b/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md
@@ -6,7 +6,7 @@ MTP boosts inference performance by parallelizing the prediction of multiple tok
 ## How to Use MTP
 
 To enable MTP for DeepSeek-V3 models, add the following parameter when starting the service:
-`--speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False}`
+`--speculative_config '{"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": false}'`
 
 - `num_speculative_tokens`: The number of speculative tokens which enable model to predict multiple tokens at once, if provided. It will default to the number in the draft model config if present, otherwise, it is required.
 - `disable_padded_drafter_batch`: Disable input padding for speculative decoding. If set to True, speculative input batches can contain sequences of different lengths, which may only be supported by certain attention backends. This currently only affects the MTP method of speculation, default is False.
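To make the flag concrete, here is a minimal launch sketch. The model path, port, and parallelism settings below are placeholders rather than values taken from this change; only the shape of the speculative config is the point.

```shell
# Hypothetical example: serve a DeepSeek-V3 class checkpoint with one-step MTP speculation.
# /path/to/DeepSeek-V3, the port, and --tensor-parallel-size are placeholders for your deployment.
vllm serve /path/to/DeepSeek-V3 \
    --port 8000 \
    --tensor-parallel-size 8 \
    --speculative-config '{"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": false}'
```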
diff --git a/docs/source/developer_guide/feature_guide/add_custom_aclnn_op.md b/docs/source/developer_guide/feature_guide/add_custom_aclnn_op.md
new file mode 100644
index 00000000000..79a923a0b2e
--- /dev/null
+++ b/docs/source/developer_guide/feature_guide/add_custom_aclnn_op.md
@@ -0,0 +1,25 @@
+# Adding a custom aclnn operation
+
+This document describes how to add a custom aclnn operation to vllm-ascend.
+
+## How do custom aclnn operations work in vllm-ascend?
+
+Custom aclnn operations are built and installed into the `vllm_ascend/cann_ops_custom` directory during the build process of vllm-ascend. The aclnn operators are then bound to the `torch.ops._C_ascend` module, enabling users to invoke them from vllm-ascend Python code.
+
+To enable custom operations, use the following code:
+
+```python
+from vllm_ascend.utils import enable_custom_op
+
+enable_custom_op()
+```
+
+## How to add a custom aclnn operation?
+
+- Create a new operation folder under the `csrc` directory
+- Create `op_host` and `op_kernel` directories for the host and kernel source code
+- Add build options in `csrc/build_aclnn.sh` for the supported SOCs. Note that multiple ops should be separated with `;`, e.g. `CUSTOM_OPS=op1;op2;op3`
+- Bind the aclnn operators to the `torch.ops._C_ascend` module in `csrc/torch_binding.cpp`
+- Write a meta implementation in `csrc/torch_binding_meta.cpp` so that the op can be captured into an ACL graph
+
+After a successful build of vllm-ascend, the custom aclnn operation can be invoked from Python code.
diff --git a/docs/source/developer_guide/feature_guide/index.md b/docs/source/developer_guide/feature_guide/index.md
index 91f6badb4ba..592850e664a 100644
--- a/docs/source/developer_guide/feature_guide/index.md
+++ b/docs/source/developer_guide/feature_guide/index.md
@@ -12,4 +12,5 @@ eplb_swift_balancer.md
 Multi_Token_Prediction
 ACL_Graph
 KV_Cache_Pool_Guide
+add_custom_aclnn_op
 :::
diff --git a/docs/source/installation.md b/docs/source/installation.md
index 757752be35f..3f80341372a 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -80,19 +80,19 @@ source vllm-ascend-env/bin/activate
 pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple attrs 'numpy<2.0.0' decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions
 
 # Download and install the CANN package.
-wget --header="Referer: https://www.hiascend.com/" https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.3.RC1/Ascend-cann-toolkit_8.3.RC1_linux-"$(uname -i)".run -chmod +x ./Ascend-cann-toolkit_8.3.RC1_linux-"$(uname -i)".run -./Ascend-cann-toolkit_8.3.RC1_linux-"$(uname -i)".run --full -# https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C22B800TP052/Ascend-cann-kernels-910b_8.3.rc1_linux-aarch64.run +wget --header="Referer: https://www.hiascend.com/" https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.3.RC2/Ascend-cann-toolkit_8.3.RC2_linux-"$(uname -i)".run +chmod +x ./Ascend-cann-toolkit_8.3.RC2_linux-"$(uname -i)".run +./Ascend-cann-toolkit_8.3.RC2_linux-"$(uname -i)".run --full +# https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C22B800TP052/Ascend-cann-kernels-910b_8.3.rc2_linux-aarch64.run source /usr/local/Ascend/ascend-toolkit/set_env.sh -wget --header="Referer: https://www.hiascend.com/" https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.3.RC1/Ascend-cann-kernels-910b_8.3.RC1_linux-"$(uname -i)".run -chmod +x ./Ascend-cann-kernels-910b_8.3.RC1_linux-"$(uname -i)".run -./Ascend-cann-kernels-910b_8.3.RC1_linux-"$(uname -i)".run --install +wget --header="Referer: https://www.hiascend.com/" https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.3.RC2/Ascend-cann-kernels-910b_8.3.RC2_linux-"$(uname -i)".run +chmod +x ./Ascend-cann-kernels-910b_8.3.RC2_linux-"$(uname -i)".run +./Ascend-cann-kernels-910b_8.3.RC2_linux-"$(uname -i)".run --install -wget --header="Referer: https://www.hiascend.com/" https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.3.RC1/Ascend-cann-nnal_8.3.RC1_linux-"$(uname -i)".run -chmod +x ./Ascend-cann-nnal_8.3.RC1_linux-"$(uname -i)".run -./Ascend-cann-nnal_8.3.RC1_linux-"$(uname -i)".run --install +wget --header="Referer: https://www.hiascend.com/" https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.3.RC2/Ascend-cann-nnal_8.3.RC2_linux-"$(uname -i)".run +chmod +x ./Ascend-cann-nnal_8.3.RC2_linux-"$(uname -i)".run +./Ascend-cann-nnal_8.3.RC2_linux-"$(uname -i)".run --install source /usr/local/Ascend/nnal/atb/set_env.sh ``` diff --git a/docs/source/tutorials/DeepSeek-R1.md b/docs/source/tutorials/DeepSeek-R1.md new file mode 100644 index 00000000000..432ae40f7b5 --- /dev/null +++ b/docs/source/tutorials/DeepSeek-R1.md @@ -0,0 +1,293 @@ +# DeepSeek-R1 + +## Introduction + +DeepSeek-R1 is a high-performance Mixture-of-Experts (MoE) large language model developed by DeepSeek Company. It excels in complex logical reasoning, mathematical problem-solving, and code generation. By dynamically activating its expert networks, it delivers exceptional performance while maintaining computational efficiency. Building upon R1, DeepSeek-R1-W8A8 is a fully quantized version of the model. It employs 8-bit integer (INT8) quantization for both weights and activations, which significantly reduces the model's memory footprint and computational requirements, enabling more efficient deployment and application in resource-constrained environments. +This article takes the deepseek- R1-w8a8 version as an example to introduce the deployment of the R1 series models. + +## Supported Features + +Refer to [supported features](../user_guide/support_matrix/supported_models.md) to get the model's supported feature matrix. + +Refer to [feature guide](../user_guide/feature_guide/index.md) to get the feature's configuration. 
+ +## Environment Preparation + +### Model Weight + +- `DeepSeek-R1-w8a8`(Quantized version): require 1 Atlas 800 A3 (64G × 16) nodes or 2 Atlas 800 A2 (64G × 8) nodes. [Download model weight](https://www.modelscope.cn/models/vllm-ascend/DeepSeek-R1-W8A8) + +It is recommended to download the model weight to the shared directory of multiple nodes. + +### Verify Multi-node Communication(Optional) + +If you want to deploy multi-node environment, you need to verify multi-node communication according to [verify multi-node communication environment](../installation.md#verify-multi-node-communication). + +### Installation + +You can using our official docker image and install extra operator for supporting `DeepSeek-R1-w8a8`. + +:::{note} +Only AArch64 architecture are supported currently due to extra operator's installation limitations. +::: + +:::::{tab-set} +:sync-group: install + +::::{tab-item} A3 series +:sync: A3 + +1. Start the docker image on your node, refer to [using docker](../installation.md#set-up-using-docker). + + +In addition, if you don't want to use the docker image as above, you can also build all from source: + +- Install `vllm-ascend` from source, refer to [installation](../installation.md). + +- Install extra operator for supporting `DeepSeek-R1-w8a8`, refer to the above tab. + +If you want to deploy multi-node environment, you need to set up environment on each node. + +## Deployment +### Service-oriented Deployment + +- `DeepSeek-R1-w8a8`: require 1 Atlas 800 A3 (64G × 16) nodes or 2 Atlas 800 A2 (64G × 8). + +:::::{tab-set} +:sync-group: install + +::::{tab-item} DeepSeek-R1-w8a8 A3 series +:sync: A3 + +```shell +#!/bin/sh + +# this obtained through ifconfig +# nic_name is the network interface name corresponding to local_ip of the current node +nic_name="xxxx" +local_ip="xxxx" + +# AIV +export HCCL_OP_EXPANSION_MODE="AIV" + +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export HCCL_BUFFSIZE=200 +export VLLM_ASCEND_ENABLE_MLAPO=1 +export VLLM_RPC_TIMEOUT=3600000 +export VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS=3600000 +export VLLM_TORCH_PROFILER_DIR="PATH/profile" +export VLLM_ASCEND_ENABLE_FLASHCOMM1=0 +export DISABLE_L2_CACHE=1 + +vllm serve path/DeepSeek-R1-W8A8 \ + --host 0.0.0.0 \ + --port 8000 \ + --data-parallel-size 4 \ + --tensor-parallel-size 4 \ + --quantization ascend \ + --seed 1024 \ + --served-model-name deepseek_r1 \ + --enable-expert-parallel \ + --max-num-seqs 16 \ + --max-model-len 8192 \ + --max-num-batched-tokens 2048 \ + --trust-remote-code \ + --no-enable-prefix-caching \ + --gpu-memory-utilization 0.92 \ + --speculative-config '{"num_speculative_tokens":1,"method":"deepseek_mtp"}' \ + --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}' \ + --additional-config '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":false}}' +``` + +:::: +::::{tab-item} DeepSeek-R1-w8a8 A2 series +:sync: A2 + +Run the following scripts on two nodes respectively. 
+ +**Node 0** + +```shell +#!/bin/sh + +# this obtained through ifconfig +# nic_name is the network interface name corresponding to local_ip of the current node +nic_name="xxxx" +local_ip="xxxx" + +# AIV +export HCCL_OP_EXPANSION_MODE="AIV" + +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export HCCL_BUFFSIZE=200 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export VLLM_ASCEND_ENABLE_MLAPO=1 +export HCCL_INTRA_PCIE_ENABLE=1 +export HCCL_INTRA_ROCE_ENABLE=0 + +vllm serve path/DeepSeek-R1-W8A8 \ + --host 0.0.0.0 \ + --port 8000 \ + --data-parallel-size 4 \ + --data-parallel-size_local 2 \ + --data-parallel-address $local_ip \ + --data-parallel-rpc-port 13389 \ + --tensor-parallel-size 4 \ + --quantization ascend \ + --seed 1024 \ + --served-model-name deepseek_r1 \ + --enable-expert-parallel \ + --max-num-seqs 20 \ + --max-model-len 8192 \ + --max-num-batched-tokens 4096 \ + --trust-remote-code \ + --no-enable-prefix-caching \ + --gpu-memory-utilization 0.92 \ + --speculative-config '{"num_speculative_tokens":1,"method":"deepseek_mtp"}' \ + --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}' \ + --additional-config '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":false}}' +``` + +**Node 1** + +```shell +#!/bin/sh + +# this obtained through ifconfig +# nic_name is the network interface name corresponding to local_ip of the current node +nic_name="xxxx" +local_ip="xxxx" +node0_ip="xxxx" # same as the local_IP address in node 0 + +# AIV +export HCCL_OP_EXPANSION_MODE="AIV" + +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +export HCCL_BUFFSIZE=200 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export VLLM_ASCEND_ENABLE_MLAPO=1 +export HCCL_INTRA_PCIE_ENABLE=1 +export HCCL_INTRA_ROCE_ENABLE=0 + +vllm serve path/DeepSeek-R1-W8A8 \ + --host 0.0.0.0 \ + --port 8000 \ + --headless \ + --data-parallel-size 4 \ + --data-parallel-size-local 2 \ + --data-parallel-start-rank 2 \ + --data-parallel-address $node0_ip \ + --data-parallel-rpc-port 13389 \ + --tensor-parallel-size 4 \ + --quantization ascend \ + --seed 1024 \ + --served-model-name deepseek_r1 \ + --enable-expert-parallel \ + --max-num-seqs 20 \ + --max-model-len 8192 \ + --max-num-batched-tokens 4096 \ + --trust-remote-code \ + --no-enable-prefix-caching \ + --gpu-memory-utilization 0.94 \ + --speculative-config '{"num_speculative_tokens":1,"method":"deepseek_mtp"}' \ + --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}' \ + --additional-config '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":false}}' +``` + +:::: +::::: + +### Prefill-Decode Disaggregation + +Not supported yet. + +## Functional Verification + +Once your server is started, you can query the model with input prompts: + +```shell +curl http://:/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "deepseek_r1", + "prompt": "The future of AI is", + "max_tokens": 50, + "temperature": 0 + }' +``` + +## Accuracy Evaluation + +Here are two accuracy evaluation methods. + +### Using AISBench + +1. Refer to [Using AISBench](../developer_guide/evaluation/using_ais_bench.md) for details. + +2. 
After execution, you can get the result, here is the result of `DeepSeek-R1-w8a8` in `vllm-ascend:0.11.0rc2` for reference only. + +| dataset | version | metric | mode | vllm-api-general-chat | +|----- | ----- | ----- | ----- | -----| +| aime2024dataset | - | accuracy | gen | 80.00 | +| gpqadataset | - | accuracy | gen | 72.22 | + +### Using Language Model Evaluation Harness + +As an example, take the `gsm8k` dataset as a test dataset, and run accuracy evaluation of `DeepSeek-R1-w8a8` in online mode. + +1. Refer to [Using lm_eval](../developer_guide/evaluation/using_lm_eval.md) for `lm_eval` installation. + +2. Run `lm_eval` to execute the accuracy evaluation. + +```shell +lm_eval \ + --model local-completions \ + --model_args model=path/DeepSeek-R1-w8a8,base_url=http://:/v1/completions,tokenized_requests=False,trust_remote_code=True \ + --tasks gsm8k \ + --output_path ./ +``` + +3. After execution, you can get the result. + +## Performance +### Using AISBench + +Refer to [Using AISBench for performance evaluation](../developer_guide/evaluation/using_ais_bench.md#execute-performance-evaluation) for details. + +### Using vLLM Benchmark + +Run performance evaluation of `DeepSeek-R1-w8a8` as an example. + +Refer to [vllm benchmark](https://docs.vllm.ai/en/latest/contributing/benchmarks.html) for more details. + +There are three `vllm bench` subcommand: +- `latency`: Benchmark the latency of a single batch of requests. +- `serve`: Benchmark the online serving throughput. +- `throughput`: Benchmark offline inference throughput. + +Take the `serve` as an example. Run the code as follows. + +```shell +export VLLM_USE_MODELSCOPE=true +vllm bench serve --model path/DeepSeek-R1-w8a8 --dataset-name random --random-input 200 --num-prompt 200 --request-rate 1 --save-result --result-dir ./ +``` + +After about several minutes, you can get the performance evaluation result. diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md index 321ec22d9cc..892a7774b98 100644 --- a/docs/source/tutorials/index.md +++ b/docs/source/tutorials/index.md @@ -17,10 +17,10 @@ multi_npu_qwen3_moe multi_npu_quantization single_node_300i DeepSeek-V3.2-Exp.md +DeepSeek-R1.md multi_node multi_node_kimi multi_node_qwen3vl -multi_node_pd_disaggregation_llmdatadist multi_node_pd_disaggregation_mooncake multi_node_ray ::: diff --git a/docs/source/tutorials/multi_node_pd_disaggregation_llmdatadist.md b/docs/source/tutorials/multi_node_pd_disaggregation_llmdatadist.md deleted file mode 100644 index 3bd06daa53f..00000000000 --- a/docs/source/tutorials/multi_node_pd_disaggregation_llmdatadist.md +++ /dev/null @@ -1,241 +0,0 @@ -# Prefill-Decode Disaggregation Llmdatadist Verification (Qwen) - -## Getting Start - -vLLM-Ascend now supports prefill-decode (PD) disaggregation with Expert Parallel (EP) options. This guide takes one-by-one steps to verify these features with constrained resources. - -Using the Qwen3-30B-A3B model as an example, use vllm-ascend v0.10.1rc1 (with vLLM v0.10.1.1) on 3 Atlas 800T A2 servers to deploy the "1P2D" architecture. Assume the IP address of the prefiller server is 192.0.0.1, and the decoder servers are 192.0.0.2 (decoder 1) and 192.0.0.3 (decoder 2). On each server, use 2 NPUs to deploy one service instance. - -## Verify Multi-Node Communication Environment - -### Physical Layer Requirements - -- The physical machines must be located on the same WLAN, with network connectivity. -- All NPUs must be interconnected. 
Intra-node connectivity is via HCCS, and inter-node connectivity is via RDMA. - -### Verification Process - -1. Single Node Verification: - -Execute the following commands on each node in sequence. The results must all be `success` and the status must be `UP`: - -```bash -# Check the remote switch ports -for i in {0..7}; do hccn_tool -i $i -lldp -g | grep Ifname; done -# Get the link status of the Ethernet ports (UP or DOWN) -for i in {0..7}; do hccn_tool -i $i -link -g ; done -# Check the network health status -for i in {0..7}; do hccn_tool -i $i -net_health -g ; done -# View the network detected IP configuration -for i in {0..7}; do hccn_tool -i $i -netdetect -g ; done -# View gateway configuration -for i in {0..7}; do hccn_tool -i $i -gateway -g ; done -# View NPU network configuration -cat /etc/hccn.conf -``` - -2. Get NPU IP Addresses - -```bash -for i in {0..7}; do hccn_tool -i $i -ip -g;done -``` - -3. Cross-Node PING Test - -```bash -# Execute on the target node (replace 'x.x.x.x' with actual npu ip address) -for i in {0..7}; do hccn_tool -i $i -ping -g address x.x.x.x;done -``` - -## Generate Ranktable - -The rank table is a JSON file that specifies the mapping of Ascend NPU ranks to nodes. For more details, please refer to the [vllm-ascend examples](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/README.md). Execute the following commands for reference. - -```shell -cd vllm-ascend/examples/disaggregate_prefill_v1/ -bash gen_ranktable.sh --ips \ - --npus-per-node --network-card-name --prefill-device-cnt --decode-device-cnt \ - [--local-device-ids ,,...] -``` - -Assume that we use devices 0 and 1 on the prefiller server node and devices 6 and 7 on both of the decoder server nodes. The following commands are for reference. (`--local-device-ids` is necessary if you specify certain NPU devices on the local server.) - -```shell -# On the prefiller node -cd vllm-ascend/examples/disaggregate_prefill_v1/ -bash gen_ranktable.sh --ips 192.0.0.1 192.0.0.2 192.0.0.3 \ - --npus-per-node 2 --network-card-name eth0 --prefill-device-cnt 2 --decode-device-cnt 4 --local-device-ids 0,1 - -# On the decoder 1 -cd vllm-ascend/examples/disaggregate_prefill_v1/ -bash gen_ranktable.sh --ips 192.0.0.1 192.0.0.2 192.0.0.3 \ - --npus-per-node 2 --network-card-name eth0 --prefill-device-cnt 2 --decode-device-cnt 4 --local-device-ids 6,7 - -# On the decoder 2 -cd vllm-ascend/examples/disaggregate_prefill_v1/ -bash gen_ranktable.sh --ips 192.0.0.1 192.0.0.2 192.0.0.3 \ - --npus-per-node 2 --network-card-name eth0 --prefill-device-cnt 2 --decode-device-cnt 4 --local-device-ids 6,7 -``` - -The rank table will be generated at /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json - -|Parameter | Meaning | -| --- | --- | -| --ips | Each node's local IP address (prefiller nodes should be in front of decoder nodes) | -| --npus-per-node | Each node's NPU clips | -| --network-card-name | The physical machines' NIC | -|--prefill-device-cnt | NPU clips used for prefill | -|--decode-device-cnt |NPU clips used for decode | -|--local-device-ids |Optional. No need if using all devices on the local node. | - -## Prefiller/Decoder Deployment - -We can run the following scripts to launch a server on the prefiller/decoder node, respectively. 
- -:::::{tab-set} - -::::{tab-item} Prefiller node - -```shell -export HCCL_IF_IP=192.0.0.1 # node ip -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="/path/to/your/generated/ranktable.json" -export OMP_PROC_BIND=false -export OMP_NUM_THREADS=10 - -vllm serve /model/Qwen3-30B-A3B \ - --host 0.0.0.0 \ - --port 13700 \ - --tensor-parallel-size 2 \ - --no-enable-prefix-caching \ - --seed 1024 \ - --served-model-name qwen3-moe \ - --max-model-len 6144 \ - --max-num-batched-tokens 6144 \ - --trust-remote-code \ - --gpu-memory-utilization 0.9 \ - --enable-expert-parallel \ - --kv-transfer-config \ - '{"kv_connector": "LLMDataDistCMgrConnector", - "kv_buffer_device": "npu", - "kv_role": "kv_producer", - "kv_parallel_size": 1, - "kv_port": "20001", - "engine_id": "0", - "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" - }' \ - --additional-config \ - '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' \ - --enforce-eager -``` - -:::: - -::::{tab-item} Decoder node 1 - -```shell -export HCCL_IF_IP=192.0.0.2 # node ip -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="/path/to/your/generated/ranktable.json" -export OMP_PROC_BIND=false -export OMP_NUM_THREADS=10 - -vllm serve /model/Qwen3-30B-A3B \ - --host 0.0.0.0 \ - --port 13700 \ - --no-enable-prefix-caching \ - --tensor-parallel-size 2 \ - --seed 1024 \ - --served-model-name qwen3-moe \ - --max-model-len 6144 \ - --max-num-batched-tokens 6144 \ - --trust-remote-code \ - --gpu-memory-utilization 0.9 \ - --enable-expert-parallel \ - --kv-transfer-config \ - '{"kv_connector": "LLMDataDistCMgrConnector", - "kv_buffer_device": "npu", - "kv_role": "kv_consumer", - "kv_parallel_size": 1, - "kv_port": "20001", - "engine_id": "0", - "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" - }' \ - --additional-config \ - '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' -``` - -:::: - -::::{tab-item} Decoder node 2 - -```shell -export HCCL_IF_IP=192.0.0.3 # node ip -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="/path/to/your/generated/ranktable.json" -export OMP_PROC_BIND=false -export OMP_NUM_THREADS=10 - -vllm serve /model/Qwen3-30B-A3B \ - --host 0.0.0.0 \ - --port 13700 \ - --no-enable-prefix-caching \ - --tensor-parallel-size 2 \ - --seed 1024 \ - --served-model-name qwen3-moe \ - --max-model-len 6144 \ - --max-num-batched-tokens 6144 \ - --trust-remote-code \ - --gpu-memory-utilization 0.9 \ - --enable-expert-parallel \ - --kv-transfer-config \ - '{"kv_connector": "LLMDataDistCMgrConnector", - "kv_buffer_device": "npu", - "kv_role": "kv_consumer", - "kv_parallel_size": 1, - "kv_port": "20001", - "engine_id": "0", - "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" - }' \ - --additional-config \ - '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' -``` - -:::: - -::::: - -## 
Example Proxy for Deployment - -Run a proxy server on the same node with the prefiller service instance. You can get the proxy program in the repository's examples: [load\_balance\_proxy\_server\_example.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py) - -```shell -python load_balance_proxy_server_example.py \ - --host 192.0.0.1 \ - --port 8080 \ - --prefiller-hosts 192.0.0.1 \ - --prefiller-port 13700 \ - --decoder-hosts 192.0.0.2 192.0.0.3 \ - --decoder-ports 13700 13700 -``` - -## Verification - -Check service health using the proxy server endpoint. - -```shell -curl http://192.0.0.1:8080/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "qwen3-moe", - "prompt": "Who are you?", - "max_tokens": 100, - "temperature": 0 - }' -``` diff --git a/docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md b/docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md index 4614860928e..d11ea137f0b 100644 --- a/docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md +++ b/docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md @@ -1,10 +1,10 @@ -# Prefill-Decode Disaggregation Mooncake Verification (Qwen) +# Prefill-Decode Disaggregation Mooncake Verification (Deepseek) ## Getting Start vLLM-Ascend now supports prefill-decode (PD) disaggregation with EP (Expert Parallel) options. This guide take one-by-one steps to verify these features with constrained resources. -Take the Qwen3-235B model as an example, use 4 Atlas 800T A3 servers to deploy the "2P1D" architecture. Assume the ip of the prefiller server is 192.0.0.1 (prefill 1) and 192.0.0.2 (prefill 2), and the decoder servers are 192.0.0.3 (decoder 1) and 192.0.0.4 (decoder 2). On each server, use 8 NPUs 16 chips to deploy one service instance. +Take the Deepseek-r1-w8a8 model as an example, use 4 Atlas 800T A3 servers to deploy the "2P1D" architecture. Assume the ip of the prefiller server is 192.0.0.1 (prefill 1) and 192.0.0.2 (prefill 2), and the decoder servers are 192.0.0.3 (decoder 1) and 192.0.0.4 (decoder 2). On each server, use 8 NPUs 16 chips to deploy one service instance. ## Verify Multi-Node Communication Environment @@ -15,6 +15,11 @@ Take the Qwen3-235B model as an example, use 4 Atlas 800T A3 servers to deploy t ### Verification Process +Execute the following commands on each node in sequence. The results must all be `success` and the status must be `UP`: + +:::::{tab-set} +::::{tab-item} A3 + 1. Single Node Verification: Execute the following commands on each node in sequence. The results must all be `success` and the status must be `UP`: @@ -34,8 +39,9 @@ for i in {0..15}; do hccn_tool -i $i -gateway -g ; done 2. Check NPU network configuration: +Ensure that the hccn.conf file exists in the environment. If using Docker, mount it into the container. + ```bash -# Ensure that the hccn.conf file exists in the environment. If using Docker, mount it into the container. cat /etc/hccn.conf ``` @@ -52,6 +58,97 @@ for i in {0..15}; do hccn_tool -i $i -ip -g | grep ipaddr; done for i in {0..15}; do hccn_tool -i $i -ping -g address x.x.x.x;done ``` +:::: + +::::{tab-item} A2 + +1. Single Node Verification: + +Execute the following commands on each node in sequence. 
The results must all be `success` and the status must be `UP`: + +```bash +# Check the remote switch ports +for i in {0..7}; do hccn_tool -i $i -lldp -g | grep Ifname; done +# Get the link status of the Ethernet ports (UP or DOWN) +for i in {0..7}; do hccn_tool -i $i -link -g ; done +# Check the network health status +for i in {0..7}; do hccn_tool -i $i -net_health -g ; done +# View the network detected IP configuration +for i in {0..7}; do hccn_tool -i $i -netdetect -g ; done +# View gateway configuration +for i in {0..7}; do hccn_tool -i $i -gateway -g ; done +``` + +2. Check NPU network configuration: + +Ensure that the hccn.conf file exists in the environment. If using Docker, mount it into the container. + +```bash +cat /etc/hccn.conf +``` + +3. Get NPU IP Addresses + +```bash +for i in {0..7}; do hccn_tool -i $i -ip -g;done +``` + +4. Cross-Node PING Test + +```bash +# Execute on the target node (replace 'x.x.x.x' with actual npu ip address) +for i in {0..7}; do hccn_tool -i $i -ping -g address x.x.x.x;done +``` + +:::: + +::::: + +## Run with Docker +Start a Docker container on each node. + +```{code-block} bash + :substitutions: +# Update the vllm-ascend image +export IMAGE=m.daocloud.io/quay.io/ascend/vllm-ascend:|vllm_ascend_version| +export NAME=vllm-ascend + +# Run the container using the defined variables +# Note: If you are running bridge network with docker, please expose available ports for multiple nodes communication in advance +docker run --rm \ +--name $NAME \ +--net=host \ +--shm-size=1g \ +--device /dev/davinci0 \ +--device /dev/davinci1 \ +--device /dev/davinci2 \ +--device /dev/davinci3 \ +--device /dev/davinci4 \ +--device /dev/davinci5 \ +--device /dev/davinci6 \ +--device /dev/davinci7 \ +--device /dev/davinci8 \ +--device /dev/davinci9 \ +--device /dev/davinci10 \ +--device /dev/davinci11 \ +--device /dev/davinci12 \ +--device /dev/davinci13 \ +--device /dev/davinci14 \ +--device /dev/davinci15 \ +--device /dev/davinci_manager \ +--device /dev/devmm_svm \ +--device /dev/hisi_hdc \ +-v /usr/local/dcmi:/usr/local/dcmi \ +-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \ +-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ +-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ +-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ +-v /etc/ascend_install.info:/etc/ascend_install.info \ +-v /etc/hccn.conf:/etc/hccn.conf \ +-v /mnt/sfs_turbo/.cache:/root/.cache \ +-it $IMAGE bash +``` + ## Install Mooncake Mooncake is the serving platform for Kimi, a leading LLM service provided by Moonshot AI. First, we need to obtain the Mooncake project. Refer to the following command: @@ -93,7 +190,15 @@ make install We can run the following scripts to launch a server on the prefiller/decoder node, respectively. Please note that each P/D node will occupy ports ranging from kv_port to kv_port + num_chips to initialize socket listeners. To avoid any issues, port conflicts should be prevented. Additionally, ensure that each node's engine_id is uniquely assigned to avoid conflicts. -### Layerwise +### launch_online_dp.py +Use `launch_online_dp.py` to launch external dp vllm servers. +[launch\_online\_dp.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/external_online_dp/launch_online_dp.py) + +### run_dp_template.sh +Modify `run_dp_template.py` on each node. 
+[run\_dp\_template.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/external_online_dp/run_dp_template.sh) + +#### Layerwise :::::{tab-set} :sync-group: nodes @@ -102,36 +207,42 @@ We can run the following scripts to launch a server on the prefiller/decoder nod :sync: prefill node1 ```shell -unset ftp_proxy -unset https_proxy -unset http_proxy -export HCCL_IF_IP=192.0.0.1 -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export HCCL_BUFFSIZE=1024 +nic_name="eth0" # network card name +local_ip="192.0.0.1" +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name export OMP_PROC_BIND=false export OMP_NUM_THREADS=10 -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH - -vllm serve /model/Qwen3-235B-A22B-W8A8 \ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export HCCL_BUFFSIZE=256 +export TASK_QUEUE_ENABLE=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=$1 +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:$LD_LIBRARY_PATH +vllm serve /path_to_weight/DeepSeek-r1_w8a8_mtp \ --host 0.0.0.0 \ - --port 8004 \ - --api-server-count 1 \ - --data-parallel-size 2 \ - --data-parallel-size-local 2 \ - --data-parallel-address 192.0.0.1 \ - --data-parallel-rpc-port 13389 \ - --tensor-parallel-size 8 \ + --port $2 \ + --data-parallel-size $3 \ + --data-parallel-rank $4 \ + --data-parallel-address $5 \ + --data-parallel-rpc-port $6 \ + --tensor-parallel-size $7 \ --enable-expert-parallel \ --seed 1024 \ + --served-model-name ds_r1 \ + --max-model-len 40000 \ + --max-num-batched-tokens 16384 \ + --max-num-seqs 8 \ --enforce-eager \ - --distributed-executor-backend mp \ - --served-model-name qwen3-moe \ - --max-model-len 32768 \ - --max-num-batched-tokens 32768 \ --trust-remote-code \ - --gpu-memory-utilization 0.9 \ + --gpu-memory-utilization 0.9 \ + --quantization ascend \ + --no-enable-prefix-caching \ + --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \ + --additional-config '{"recompute_scheduler_enable":true,"enable_shared_expert_dp": true}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeLayerwiseConnector", "kv_role": "kv_producer", @@ -157,36 +268,42 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ :sync: prefill node2 ```shell -unset ftp_proxy -unset https_proxy -unset http_proxy -export HCCL_IF_IP=192.0.0.2 -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export HCCL_BUFFSIZE=1024 +nic_name="eth0" # network card name +local_ip="192.0.0.2" +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name export OMP_PROC_BIND=false export OMP_NUM_THREADS=10 -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH - -vllm serve /model/Qwen3-235B-A22B-W8A8 \ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export HCCL_BUFFSIZE=256 +export TASK_QUEUE_ENABLE=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=$1 +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:$LD_LIBRARY_PATH +vllm serve /path_to_weight/DeepSeek-r1_w8a8_mtp \ --host 0.0.0.0 \ - --port 8004 \ - --api-server-count 1 \ 
- --data-parallel-size 2 \ - --data-parallel-size-local 2 \ - --data-parallel-address 192.0.0.2 \ - --data-parallel-rpc-port 13389 \ - --tensor-parallel-size 8 \ + --port $2 \ + --data-parallel-size $3 \ + --data-parallel-rank $4 \ + --data-parallel-address $5 \ + --data-parallel-rpc-port $6 \ + --tensor-parallel-size $7 \ --enable-expert-parallel \ --seed 1024 \ + --served-model-name ds_r1 \ + --max-model-len 40000 \ + --max-num-batched-tokens 16384 \ + --max-num-seqs 8 \ --enforce-eager \ - --distributed-executor-backend mp \ - --served-model-name qwen3-moe \ - --max-model-len 32768 \ - --max-num-batched-tokens 32768 \ --trust-remote-code \ - --gpu-memory-utilization 0.9 \ + --gpu-memory-utilization 0.9 \ + --quantization ascend \ + --no-enable-prefix-caching \ + --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \ + --additional-config '{"recompute_scheduler_enable":true,"enable_shared_expert_dp": true}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeLayerwiseConnector", "kv_role": "kv_producer", @@ -208,42 +325,47 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ :::: -::::{tab-item} Decoder node 1 (master Node) +::::{tab-item} Decoder node 1 :sync: decoder node1 ```shell -unset ftp_proxy -unset https_proxy -unset http_proxy -export HCCL_IF_IP=192.0.0.3 -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export HCCL_BUFFSIZE=2048 +nic_name="eth0" # network card name +local_ip="192.0.0.3" +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name export OMP_PROC_BIND=false export OMP_NUM_THREADS=10 -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH - -vllm serve /model/Qwen3-235B-A22B-W8A8 \ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export VLLM_ASCEND_ENABLE_MLAPO=1 +export HCCL_BUFFSIZE=600 +export TASK_QUEUE_ENABLE=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=$1 +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:$LD_LIBRARY_PATH +vllm serve /path_to_weight/DeepSeek-r1_w8a8_mtp \ --host 0.0.0.0 \ - --port 8004 \ - --api-server-count 1 \ - --data-parallel-size 32 \ - --data-parallel-size-local 16 \ - --data-parallel-address 192.0.0.3 \ - --data-parallel-rpc-port 5964 \ - --tensor-parallel-size 1 \ + --port $2 \ + --data-parallel-size $3 \ + --data-parallel-rank $4 \ + --data-parallel-address $5 \ + --data-parallel-rpc-port $6 \ + --tensor-parallel-size $7 \ --enable-expert-parallel \ --seed 1024 \ - --distributed-executor-backend mp \ - --served-model-name qwen3-moe \ - --max-model-len 32768 \ - --max-num-batched-tokens 512 \ - --max-num_seqs 16 \ + --served-model-name ds_r1 \ + --max-model-len 40000 \ + --max-num-batched-tokens 256 \ + --max-num-seqs 40 \ --trust-remote-code \ + --gpu-memory-utilization 0.94 \ + --quantization ascend \ --no-enable-prefix-caching \ - --gpu-memory-utilization 0.9 \ - --compilation-config '{"cudagraph_capture_sizes":[16]}' \ + --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \ + --additional-config '{"recompute_scheduler_enable":true,"multistream_overlap_shared_expert": true,"lm_head_tensor_parallel_size":16}' \ + --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeLayerwiseConnector", "kv_role": "kv_consumer", @@ -261,47 +383,50 @@ vllm 
serve /model/Qwen3-235B-A22B-W8A8 \ } } }' -``` :::: -::::{tab-item} Decoder node 2 (primary node) +::::{tab-item} Decoder node 2 :sync: decoder node2 ```shell -unset ftp_proxy -unset https_proxy -unset http_proxy -export HCCL_IF_IP=192.0.0.4 -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export HCCL_BUFFSIZE=2048 +nic_name="eth0" # network card name +local_ip="192.0.0.4" +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name export OMP_PROC_BIND=false export OMP_NUM_THREADS=10 -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH - -vllm serve /model/Qwen3-235B-A22B-W8A8 \ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export VLLM_ASCEND_ENABLE_MLAPO=1 +export HCCL_BUFFSIZE=600 +export TASK_QUEUE_ENABLE=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=$1 +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:$LD_LIBRARY_PATH +vllm serve /path_to_weight/DeepSeek-r1_w8a8_mtp \ --host 0.0.0.0 \ - --port 8004 \ - --headless \ - --data-parallel-size 32 \ - --data-parallel-size-local 16 \ - --data-parallel-start-rank 16 \ - --data-parallel-address 192.0.0.3 \ - --data-parallel-rpc-port 5964 \ - --tensor-parallel-size 1 \ + --port $2 \ + --data-parallel-size $3 \ + --data-parallel-rank $4 \ + --data-parallel-address $5 \ + --data-parallel-rpc-port $6 \ + --tensor-parallel-size $7 \ --enable-expert-parallel \ --seed 1024 \ - --distributed-executor-backend mp \ - --served-model-name qwen3-moe \ - --max-model-len 32768 \ - --max-num-batched-tokens 512 \ - --max-num_seqs 16 \ + --served-model-name ds_r1 \ + --max-model-len 40000 \ + --max-num-batched-tokens 256 \ + --max-num-seqs 40 \ --trust-remote-code \ + --gpu-memory-utilization 0.94 \ + --quantization ascend \ --no-enable-prefix-caching \ - --gpu-memory-utilization 0.9 \ - --compilation-config '{"cudagraph_capture_sizes":[16]}' \ + --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \ + --additional-config '{"recompute_scheduler_enable":true,"multistream_overlap_shared_expert": true,"lm_head_tensor_parallel_size":16}' \ + --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeLayerwiseConnector", "kv_role": "kv_consumer", @@ -309,6 +434,7 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ "engine_id": "2", "kv_connector_module_path": "vllm_ascend.distributed.mooncake_layerwise_connector", "kv_connector_extra_config": { + "prefill": { "dp_size": 2, "tp_size": 8 @@ -325,7 +451,7 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ ::::: -### Non-layerwise +#### Non-layerwise :::::{tab-set} :sync-group: nodes @@ -334,36 +460,42 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ :sync: prefill node1 ```shell -unset ftp_proxy -unset https_proxy -unset http_proxy -export HCCL_IF_IP=192.0.0.1 -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export HCCL_BUFFSIZE=1024 +nic_name="eth0" # network card name +local_ip="192.0.0.1" +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name export OMP_PROC_BIND=false export OMP_NUM_THREADS=10 -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH - -vllm serve 
/model/Qwen3-235B-A22B-W8A8 \ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export HCCL_BUFFSIZE=256 +export TASK_QUEUE_ENABLE=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=$1 +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:$LD_LIBRARY_PATH +vllm serve /path_to_weight/DeepSeek-r1_w8a8_mtp \ --host 0.0.0.0 \ - --port 8004 \ - --api-server-count 1 \ - --data-parallel-size 2 \ - --data-parallel-size-local 2 \ - --data-parallel-address 192.0.0.1 \ - --data-parallel-rpc-port 13389 \ - --tensor-parallel-size 8 \ + --port $2 \ + --data-parallel-size $3 \ + --data-parallel-rank $4 \ + --data-parallel-address $5 \ + --data-parallel-rpc-port $6 \ + --tensor-parallel-size $7 \ --enable-expert-parallel \ --seed 1024 \ + --served-model-name ds_r1 \ + --max-model-len 40000 \ + --max-num-batched-tokens 16384 \ + --max-num-seqs 8 \ --enforce-eager \ - --distributed-executor-backend mp \ - --served-model-name qwen3-moe \ - --max-model-len 32768 \ - --max-num-batched-tokens 32768 \ --trust-remote-code \ - --gpu-memory-utilization 0.9 \ + --gpu-memory-utilization 0.9 \ + --quantization ascend \ + --no-enable-prefix-caching \ + --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \ + --additional-config '{"recompute_scheduler_enable":true,"enable_shared_expert_dp": true}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeConnector", "kv_role": "kv_producer", @@ -389,36 +521,42 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ :sync: prefill node2 ```shell -unset ftp_proxy -unset https_proxy -unset http_proxy -export HCCL_IF_IP=192.0.0.2 -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export HCCL_BUFFSIZE=1024 +nic_name="eth0" # network card name +local_ip="192.0.0.2" +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name export OMP_PROC_BIND=false export OMP_NUM_THREADS=10 -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH - -vllm serve /model/Qwen3-235B-A22B-W8A8 \ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export HCCL_BUFFSIZE=256 +export TASK_QUEUE_ENABLE=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=$1 +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:$LD_LIBRARY_PATH +vllm serve /path_to_weight/DeepSeek-r1_w8a8_mtp \ --host 0.0.0.0 \ - --port 8004 \ - --api-server-count 1 \ - --data-parallel-size 2 \ - --data-parallel-size-local 2 \ - --data-parallel-address 192.0.0.2 \ - --data-parallel-rpc-port 13389 \ - --tensor-parallel-size 8 \ + --port $2 \ + --data-parallel-size $3 \ + --data-parallel-rank $4 \ + --data-parallel-address $5 \ + --data-parallel-rpc-port $6 \ + --tensor-parallel-size $7 \ --enable-expert-parallel \ --seed 1024 \ + --served-model-name ds_r1 \ + --max-model-len 40000 \ + --max-num-batched-tokens 16384 \ + --max-num-seqs 8 \ --enforce-eager \ - --distributed-executor-backend mp \ - --served-model-name qwen3-moe \ - --max-model-len 32768 \ - --max-num-batched-tokens 32768 \ --trust-remote-code \ - --gpu-memory-utilization 0.9 \ + --gpu-memory-utilization 0.9 \ + --quantization ascend \ + --no-enable-prefix-caching \ + --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \ + --additional-config 
'{"recompute_scheduler_enable":true,"enable_shared_expert_dp": true}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeConnector", "kv_role": "kv_producer", @@ -440,42 +578,47 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ :::: -::::{tab-item} Decoder node 1 (master node) +::::{tab-item} Decoder node 1 :sync: decoder node1 ```shell -unset ftp_proxy -unset https_proxy -unset http_proxy -export HCCL_IF_IP=192.0.0.3 -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export HCCL_BUFFSIZE=2048 +nic_name="eth0" # network card name +local_ip="192.0.0.3" +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name export OMP_PROC_BIND=false export OMP_NUM_THREADS=10 -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH - -vllm serve /model/Qwen3-235B-A22B-W8A8 \ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export VLLM_ASCEND_ENABLE_MLAPO=1 +export HCCL_BUFFSIZE=600 +export TASK_QUEUE_ENABLE=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=$1 +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:$LD_LIBRARY_PATH +vllm serve /path_to_weight/DeepSeek-r1_w8a8_mtp \ --host 0.0.0.0 \ - --port 8004 \ - --api-server-count 1 \ - --data-parallel-size 32 \ - --data-parallel-size-local 16 \ - --data-parallel-address 192.0.0.3 \ - --data-parallel-rpc-port 5964 \ - --tensor-parallel-size 1 \ + --port $2 \ + --data-parallel-size $3 \ + --data-parallel-rank $4 \ + --data-parallel-address $5 \ + --data-parallel-rpc-port $6 \ + --tensor-parallel-size $7 \ --enable-expert-parallel \ --seed 1024 \ - --distributed-executor-backend mp \ - --served-model-name qwen3-moe \ - --max-model-len 32768 \ - --max-num-batched-tokens 512 \ - --max-num_seqs 16 \ + --served-model-name ds_r1 \ + --max-model-len 40000 \ + --max-num-batched-tokens 256 \ + --max-num-seqs 40 \ --trust-remote-code \ + --gpu-memory-utilization 0.94 \ + --quantization ascend \ --no-enable-prefix-caching \ - --gpu-memory-utilization 0.9 \ - --compilation-config '{"cudagraph_capture_sizes":[16]}' \ + --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \ + --additional-config '{"recompute_scheduler_enable":true,"multistream_overlap_shared_expert": true,"lm_head_tensor_parallel_size":16}' \ + --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeConnector", "kv_role": "kv_consumer", @@ -493,47 +636,50 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ } } }' -``` :::: -::::{tab-item} Decoder node 2 (primary Node) +::::{tab-item} Decoder node 2 :sync: decoder node2 ```shell -unset ftp_proxy -unset https_proxy -unset http_proxy -export HCCL_IF_IP=192.0.0.4 -export GLOO_SOCKET_IFNAME="eth0" # network card name -export TP_SOCKET_IFNAME="eth0" -export HCCL_SOCKET_IFNAME="eth0" -export HCCL_BUFFSIZE=2048 +nic_name="eth0" # network card name +local_ip="192.0.0.4" +export HCCL_IF_IP=$local_ip +export GLOO_SOCKET_IFNAME=$nic_name +export TP_SOCKET_IFNAME=$nic_name +export HCCL_SOCKET_IFNAME=$nic_name export OMP_PROC_BIND=false export OMP_NUM_THREADS=10 -export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH - -vllm serve /model/Qwen3-235B-A22B-W8A8 \ +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export VLLM_ASCEND_ENABLE_MLAPO=1 +export 
HCCL_BUFFSIZE=600 +export TASK_QUEUE_ENABLE=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=$1 +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:$LD_LIBRARY_PATH +vllm serve /path_to_weight/DeepSeek-r1_w8a8_mtp \ --host 0.0.0.0 \ - --port 8004 \ - --headless \ - --data-parallel-size 32 \ - --data-parallel-size-local 16 \ - --data-parallel-start-rank 16 \ - --data-parallel-address 192.0.0.3 \ - --data-parallel-rpc-port 5964 \ - --tensor-parallel-size 1 \ + --port $2 \ + --data-parallel-size $3 \ + --data-parallel-rank $4 \ + --data-parallel-address $5 \ + --data-parallel-rpc-port $6 \ + --tensor-parallel-size $7 \ --enable-expert-parallel \ --seed 1024 \ - --distributed-executor-backend mp \ - --served-model-name qwen3-moe \ - --max-model-len 32768 \ - --max-num-batched-tokens 512 \ - --max-num_seqs 16 \ + --served-model-name ds_r1 \ + --max-model-len 40000 \ + --max-num-batched-tokens 256 \ + --max-num-seqs 40 \ --trust-remote-code \ + --gpu-memory-utilization 0.94 \ + --quantization ascend \ --no-enable-prefix-caching \ - --gpu-memory-utilization 0.9 \ - --compilation-config '{"cudagraph_capture_sizes":[16]}' \ + --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \ + --additional-config '{"recompute_scheduler_enable":true,"multistream_overlap_shared_expert": true,"lm_head_tensor_parallel_size":16}' \ + --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeConnector", "kv_role": "kv_consumer", @@ -557,9 +703,28 @@ vllm serve /model/Qwen3-235B-A22B-W8A8 \ ::::: +### Start the service + +```bash +# on 190.0.0.1 +python launch_online_dp.py --dp-size 2 --tp-size 8 --dp-size-local 2 --dp-rank-start 0 --dp-address 190.0.0.1 --dp-rpc-port 12321 --vllm-start-port 7100 +# on 190.0.0.2 +python launch_online_dp.py --dp-size 2 --tp-size 8 --dp-size-local 2 --dp-rank-start 0 --dp-address 190.0.0.2 --dp-rpc-port 12321 --vllm-start-port 7100 +# on 190.0.0.3 +python launch_online_dp.py --dp-size 32 --tp-size 1 --dp-size-local 16 --dp-rank-start 0 --dp-address 190.0.0.3 --dp-rpc-port 12321 --vllm-start-port 7100 +# on 190.0.0.4 +python launch_online_dp.py --dp-size 32 --tp-size 1 --dp-size-local 16 --dp-rank-start 16 --dp-address 190.0.0.3 --dp-rpc-port 12321 --vllm-start-port 7100 +``` + ## Example Proxy for Deployment -Run a proxy server on the same node with the prefiller service instance. You can get the proxy program in the repository's examples: [load\_balance\_proxy\_layerwise\_server\_example.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py) or [load\_balance\_proxy\_server\_example.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py) +Run a proxy server on the same node where your prefiller service instance is deployed. You can find the proxy implementation in the repository's examples directory. 
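A note on how the launcher wiring maps to the proxy configuration below (this is our reading of the example scripts, not an authoritative description): each `launch_online_dp.py` invocation appears to start one vLLM API server per local DP rank on consecutive ports beginning at `--vllm-start-port`, so the two prefill nodes each expose ports 7100-7101 and the two decode nodes each expose ports 7100-7115. The proxy examples below simply enumerate exactly those host/port pairs.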
+ +We provide two different proxy implementations with distinct request routing behaviors: + +- **`load_balance_proxy_layerwise_server_example.py`**: Requests are first routed to the D nodes, which then forward to the P nodes as needed.This proxy is designed for use with the MooncakeLayerwiseConnector.[load\_balance\_proxy\_layerwise\_server\_example.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py) + +- **`load_balance_proxy_server_example.py`**: Requests are first routed to the P nodes, which then forward to the D nodes for subsequent processing.This proxy is designed for use with the MooncakeConnector.[load\_balance\_proxy\_server\_example.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py) :::::{tab-set} @@ -567,12 +732,51 @@ Run a proxy server on the same node with the prefiller service instance. You can ```shell python load_balance_proxy_layerwise_server_example.py \ - --host 192.0.0.1 \ - --port 8080 \ - --prefiller-hosts 192.0.0.1 192.0.0.2\ - --prefiller-port 8004 8004\ - --decoder-hosts 192.0.0.3\ - --decoder-ports 8004 + --port 1999 \ + --host 192.0.0.1 \ + --prefiller-hosts \ + 192.0.0.1 \ + 192.0.0.1 \ + 192.0.0.2 \ + 192.0.0.2 \ + --prefiller-ports \ + 7100 7101 7100 7101 \ + --decoder-hosts \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + --decoder-ports \ + 7100 7101 7102 7103 7104 7105 7106 7107 7108 7109 7110 7111 7112 7113 7114 7115\ + 7100 7101 7102 7103 7104 7105 7106 7107 7108 7109 7110 7111 7112 7113 7114 7115\ ``` :::: @@ -581,18 +785,127 @@ python load_balance_proxy_layerwise_server_example.py \ ```shell python load_balance_proxy_server_example.py \ - --host 192.0.0.1 \ - --port 8080 \ - --prefiller-hosts 192.0.0.1 192.0.0.2\ - --prefiller-port 8004 8004\ - --decoder-hosts 192.0.0.3\ - --decoder-ports 8004 + --port 1999 \ + --host 192.0.0.1 \ + --prefiller-hosts \ + 192.0.0.1 \ + 192.0.0.1 \ + 192.0.0.2 \ + 192.0.0.2 \ + --prefiller-ports \ + 7100 7101 7100 7101 \ + --decoder-hosts \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.3 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + 192.0.0.4 \ + --decoder-ports \ + 7100 7101 7102 7103 7104 7105 7106 7107 7108 7109 7110 7111 7112 7113 7114 7115\ + 7100 7101 7102 7103 7104 7105 7106 7107 7108 7109 7110 7111 7112 7113 7114 7115\ ``` :::: ::::: +|Parameter | meaning | +| --- | --- | +| --port | Proxy service Port | +| --host | Proxy service Host IP| +| --prefiller-hosts | Hosts of prefiller nodes | +| --prefiller-ports | Ports of prefiller nodes | +| --decoder-hosts | Hosts of decoder nodes | +| --decoder-ports | Ports of decoder nodes | + +You can get the proxy program in the repository's examples, 
[load\_balance\_proxy\_server\_example.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py).
+
+## Benchmark
+
+We recommend using the [aisbench](https://gitee.com/aisbench/benchmark) tool to assess performance. Execute the following commands to install aisbench:
+
+```shell
+git clone https://gitee.com/aisbench/benchmark.git
+cd benchmark/
+pip3 install -e ./
+```
+
+You need to unset the HTTP proxy before assessing performance, as follows:
+
+```shell
+# unset proxy
+unset http_proxy
+unset https_proxy
+```
+
+- You can place your datasets in the dir: `benchmark/ais_bench/datasets`
+- You can change the configuration in the dir: `benchmark/ais_bench/benchmark/configs/models/vllm_api`. Take `vllm_api_stream_chat.py` as an example:
+
+```python
+models = [
+    dict(
+        attr="service",
+        type=VLLMCustomAPIChatStream,
+        abbr='vllm-api-stream-chat',
+        path="/root/.cache/ds_r1",
+        model="dsr1",
+        request_rate = 14,
+        retry = 2,
+        host_ip = "192.0.0.1", # Proxy service host IP
+        host_port = 8000, # Proxy service Port
+        max_out_len = 10,
+        batch_size=768,
+        trust_remote_code=True,
+        generation_kwargs = dict(
+            temperature = 0,
+            seed = 1024,
+            ignore_eos=False,
+        )
+    )
+]
+```
+
+- Taking the gsm8k dataset as an example, execute the following command to assess performance:
+
+```shell
+ais_bench --models vllm_api_stream_chat --datasets gsm8k_gen_0_shot_cot_str_perf --debug --mode perf
+```
+
+- For more details on aisbench commands and parameters, refer to [aisbench](https://gitee.com/aisbench/benchmark).
+
+## FAQ
+
+### 1. Prefiller nodes need to warm up
+
+Since some NPU operators require several rounds of warm-up to reach their best performance, we recommend preheating the service with a few requests before running performance tests, so that you measure the best end-to-end throughput.
+
 ## Verification
 
 Check service health using the proxy server endpoint.
diff --git a/docs/source/tutorials/multi_npu_qwen3_next.md b/docs/source/tutorials/multi_npu_qwen3_next.md
index 637fb4a61ca..325745ac3d4 100644
--- a/docs/source/tutorials/multi_npu_qwen3_next.md
+++ b/docs/source/tutorials/multi_npu_qwen3_next.md
@@ -49,17 +49,14 @@ The [Triton Ascend](https://gitee.com/ascend/triton-ascend) is required when you
 Install the Ascend BiSheng toolkit:
 
 ```bash
-wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/Ascend-BiSheng-toolkit_aarch64.run
-chmod a+x Ascend-BiSheng-toolkit_aarch64.run
-./Ascend-BiSheng-toolkit_aarch64.run --install
-source /usr/local/Ascend/8.3.RC1/bisheng_toolkit/set_env.sh
+source /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh
 ```
 
 Install Triton Ascend:
 
 ```bash
-wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl
-pip install triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl
+wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27.whl
+pip install triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27.whl
 ```
 
 ::::
 
@@ -76,7 +73,7 @@ Coming soon ...
Please make sure you have already executed the command: ```bash -source /usr/local/Ascend/8.3.RC1/bisheng_toolkit/set_env.sh +source /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh ``` :::::{tab-set} diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md index a8732b177c9..beb82b3ebc0 100644 --- a/docs/source/user_guide/feature_guide/index.md +++ b/docs/source/user_guide/feature_guide/index.md @@ -7,12 +7,13 @@ This section provides a detailed usage guide of vLLM Ascend features. :maxdepth: 1 graph_mode quantization +quantization-llm-compressor sleep_mode structured_output lora eplb_swift_balancer netloader dynamic_batch -kv_pool_mooncake +kv_pool external_dp ::: diff --git a/docs/source/user_guide/feature_guide/kv_pool_mooncake.md b/docs/source/user_guide/feature_guide/kv_pool.md similarity index 84% rename from docs/source/user_guide/feature_guide/kv_pool_mooncake.md rename to docs/source/user_guide/feature_guide/kv_pool.md index 9188d7d1354..4b5ec13f8ac 100644 --- a/docs/source/user_guide/feature_guide/kv_pool_mooncake.md +++ b/docs/source/user_guide/feature_guide/kv_pool.md @@ -1,4 +1,4 @@ -# Mooncacke Store Deployment Guide +# Ascend Store Deployment Guide ## Environmental Dependencies @@ -8,27 +8,30 @@ * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 * vLLM:main branch * vLLM-Ascend:main branch - * Mooncake:main branch - - Installation and Compilation Guide:https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#build-and-use-binaries - - Make sure to build with `-DUSE_ASCEND_DIRECT` to enable ADXL engine. - - An example command for compiling ADXL: - - `rm -rf build && mkdir -p build && cd build \ && cmake .. -DCMAKE_INSTALL_PREFIX=/opt/transfer-engine/ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DUSE_ASCEND_DIRECT=ON -DBUILD_SHARED_LIBS=ON -DBUILD_UNIT_TESTS=OFF \ && make -j \ && make install` - - Also, you need to set environment variables to point to them `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64/python3.11/site-packages/mooncake`, or copy the .so files to the `/usr/local/lib64` directory after compilation ### KV Pooling Parameter Description **kv_connector_extra_config**: Additional Configurable Parameters for Pooling. -**mooncake_rpc_port**: Port for RPC Communication Between Pooling Scheduler Process and Worker Process: Each Instance Requires a Unique Port Configuration. +**lookup_rpc_port**: Port for RPC Communication Between Pooling Scheduler Process and Worker Process: Each Instance Requires a Unique Port Configuration. **load_async**: Whether to Enable Asynchronous Loading. The default value is false. -**register_buffer**: Whether to Register Video Memory with the Backend. Registration is Not Required When Used with MooncakeConnectorV1; It is Required in All Other Cases. The Default Value is false. +**backend**: Set the storage backend for kvpool, with the default being mooncake. + +## Example of using Mooncake as a KVCache pooling backend +* Software: + * Mooncake:main branch + + Installation and Compilation Guide:https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#build-and-use-binaries -## Run Mooncake Master + Make sure to build with `-DUSE_ASCEND_DIRECT` to enable ADXL engine. -### 1.Configure mooncake.json + An example command for compiling ADXL: + + `rm -rf build && mkdir -p build && cd build \ && cmake .. 
-DCMAKE_INSTALL_PREFIX=/opt/transfer-engine/ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DUSE_ASCEND_DIRECT=ON -DBUILD_SHARED_LIBS=ON -DBUILD_UNIT_TESTS=OFF \ && make -j \ && make install` + + Also, you need to set environment variables to point to them `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64/python3.11/site-packages/mooncake`, or copy the .so files to the `/usr/local/lib64` directory after compilation + +### run mooncake master + +#### 1.Configure mooncake.json The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path where mooncake.json is located. @@ -54,7 +57,7 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path **master_server_address**: Configured with the IP and port of the master service. **global_segment_size**: Expands the kvcache size registered by the PD node to the master. -### 2. Start mooncake_master +#### 2. Start mooncake_master Under the mooncake folder: @@ -64,9 +67,9 @@ mooncake_master --port 50088 --eviction_high_watermark_ratio 0.95 --eviction_rat `eviction_high_watermark_ratio` determines the watermark where Mooncake Store will perform eviction,and `eviction_ratio` determines the portion of stored objects that would be evicted. -## Pooling and Prefill Decode Disaggregate Scenario +### Pooling and Prefill Decode Disaggregate Scenario -### 1.Run `prefill` Node and `decode` Node +#### 1.Run `prefill` Node and `decode` Node Using MultiConnector to simultaneously utilize both p2p connectors and pooled connectors. P2P performs kv_transfer, while pooling creates a larger prefix-cache. @@ -123,9 +126,10 @@ python3 -m vllm.entrypoints.openai.api_server \ } }, { - "kv_connector": "MooncakeConnectorStoreV1", + "kv_connector": "AscendStoreConnector", "kv_role": "kv_producer", - "mooncake_rpc_port":"0" + "lookup_rpc_port":"0", + "backend": "mooncake" } ] } @@ -185,16 +189,17 @@ python3 -m vllm.entrypoints.openai.api_server \ } }, { - "kv_connector": "MooncakeConnectorStoreV1", + "kv_connector": "AscendStoreConnector", "kv_role": "kv_consumer", - "mooncake_rpc_port":"1" + "lookup_rpc_port":"1", + "backend": "mooncake" } ] } }' > d.log 2>&1 ``` -### 2、Start proxy_server. +#### 2、Start proxy_server. ``` bash proxy.sh @@ -212,7 +217,7 @@ python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_e --decoder-ports 8200 \ ``` -### 3. Run Inference +#### 3. Run Inference Configure the localhost, port, and model weight path in the command to your own settings. @@ -228,9 +233,9 @@ Long question: curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? 
Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' ``` -## Pooling and Mixed Deployment Scenario +### Pooling and Mixed Deployment Scenario -### 1、Run Mixed Department Script +#### 1、Run Mixed Department Script The mixed script is essentially a pure pooling scenario for the P node. @@ -263,19 +268,17 @@ python3 -m vllm.entrypoints.openai.api_server \ --max-num-batched-tokens 4096 \ --kv-transfer-config \ '{ - "kv_connector": "MooncakeConnectorStoreV1", + "kv_connector": "AscendStoreConnector", "kv_role": "kv_both", "kv_connector_extra_config": { - "register_buffer": true, "use_layerwise": false, - "mooncake_rpc_port":"0" + "lookup_rpc_port":"1", + "backend": "mooncake" } }' > mix.log 2>&1 ``` -`register_buffer` is set to `false` by default and need to be set to `true` only in PD-mixed scenario. - -### 2. Run Inference +#### 2. Run Inference Configure the localhost, port, and model weight path in the command to your own settings. The requests sent will only go to the port where the mixed deployment script is located, and there is no need to start a separate proxy. diff --git a/docs/source/user_guide/feature_guide/quantization-llm-compressor.md b/docs/source/user_guide/feature_guide/quantization-llm-compressor.md new file mode 100644 index 00000000000..a97b4de2940 --- /dev/null +++ b/docs/source/user_guide/feature_guide/quantization-llm-compressor.md @@ -0,0 +1,65 @@ +# llm-compressor Quantization Guide + +Model quantization is a technique that reduces the size and computational requirements of a model by lowering the data precision of the weights and activation values in the model, thereby saving the memory and improving the inference speed. + +## Supported llm-compressor Quantization Types + +Support CompressedTensorsW8A8 static weight + +weight: per-channel, int8, symmetric; activation: per-tensor, int8, symmetric. + +Support CompressedTensorsW8A8Dynamic weight + +weight: per-channel, int8, symmetric; activation: per-token, int8, symmetric, dynamic. + +## Install llm-compressor + +To quantize a model, you should install [llm-compressor](https://github.com/vllm-project/llm-compressor/blob/main/README.md). It is a unified library for creating compressed models for faster inference with vLLM. + +Install llm-compressor + +```bash +pip install llmcompressor +``` + +### Generate the W8A8 weights + +```bash +cd examples/quantization/llm-compressor + +python3 w8a8_int8_dynamic.py +``` + +for more details, see the [Official Sample](https://github.com/vllm-project/llm-compressor/tree/main/examples). + +## Run the model + +Now, you can run the quantized model with vLLM Ascend. 
Examples for online and offline inference are provided as follows: + +### Offline inference + +```python +import torch + +from vllm import LLM, SamplingParams + +prompts = [ + "Hello, my name is", + "The future of AI is", +] +sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40) + +llm = LLM(model="{quantized_model_save_path}", + max_model_len=2048, + trust_remote_code=True) + +outputs = llm.generate(prompts, sampling_params) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +``` + +### Online inference + +Start the quantized model using vLLM Ascend; no modifications to the startup command are required. diff --git a/examples/quantization/llm-compressor/w8a8_int8.py b/examples/quantization/llm-compressor/w8a8_int8.py new file mode 100644 index 00000000000..9a6cb392f0c --- /dev/null +++ b/examples/quantization/llm-compressor/w8a8_int8.py @@ -0,0 +1,160 @@ +import os +import torch + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, \ + AutoTokenizer, AutoProcessor, AutoConfig, AutoImageProcessor + +from llmcompressor import oneshot +from llmcompressor.modifiers.awq import AWQModifier +from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier +from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme, QuantizationType, QuantizationStrategy + +W8A8_W_cha_A_ten_static_symmetric = { + "group_0": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.INT, + strategy=QuantizationStrategy.CHANNEL, + symmetric=True, + dynamic=False + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.INT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False + ), + ), +} + +# supported modifiers +MODIFIER_DICT = { + "PTQ": QuantizationModifier, + "AWQ": AWQModifier, + "GPTQ": GPTQModifier, +} + +# supported schemes +SCHEMES_DICT = { + "W8A8_W_cha_A_ten_static_symmetric": W8A8_W_cha_A_ten_static_symmetric, +} + +MODEL_DICT = { + "qwen3": AutoModelForCausalLM, +} + +TOKENIZER_DICT = { + "qwen3": AutoTokenizer, +} + + +def load_environment_variables(): + env_vars = { + 'model_path': "Qwen/Qwen3-32B", + 'export_path': "/llm-compressor/export/GPTQ/W8A8_W_cha_A_ten_static_symmetric", + 'modifier': "GPTQ", + 'schemes': "W8A8_W_cha_A_ten_static_symmetric", + 'calib_prompt_path': "HuggingFaceH4/ultrachat_200k" + } + + # verify export model path + if env_vars['export_path'] is None: + env_vars['export_path'] = env_vars['model_path'].rstrip("/") + "-" + env_vars['modifier'] + if env_vars['schemes'] is not None: + env_vars['export_path'] += "-" + env_vars['schemes'] + os.makedirs(env_vars['export_path'], exist_ok=True) + + return env_vars + + +def load_calibration_text_dataset(calib_prompt_path, tokenizer): + # Load dataset + for f in os.listdir(calib_prompt_path): + print(f) + if any(f.lower().endswith('.jsonl') for f in os.listdir(calib_prompt_path)): + ds = load_dataset('json', data_dir=calib_prompt_path, split='validation') + elif any(f.lower().endswith('.parquet') for f in os.listdir(calib_prompt_path)): + ds = load_dataset("parquet", data_dir=calib_prompt_path, split="train[:512]") + else: + raise ValueError("Unsupported calibration file format: {}".format( + calib_prompt_path.split('.')[-1])) + + # Preprocess dataset + def preprocess(example): + if 
tokenizer.chat_template is not None: + return {"text": tokenizer.apply_chat_template( + example["messages"], tokenize=False)} + else: + return {"text": example["messages"]} + + # Tokenize inputs + def tokenize(sample): + return tokenizer( + sample["text"], + add_special_tokens=False, + ) + + ds = ds.map(preprocess) + ds = ds.map(tokenize, remove_columns=ds.column_names) + return ds + + +# Define a oneshot data collator for multimodal inputs. +def data_collator(batch): + assert len(batch) == 1 + return { + key: torch.tensor(value, dtype=torch.bfloat16 if key == "pixel_values" else torch.long) + for key, value in batch[0].items() + } + + +def quantize_model(model, env_vars, dataset_dict=None): + # since the MoE gate layers are sensitive to quantization, we add them to the ignore + # list so they remain at full precision + ignore = ["lm_head", "re:.*mlp.down_proj"] + + # define a llmcompressor recipe + recipe = [ + MODIFIER_DICT[env_vars['modifier']]( + config_groups=SCHEMES_DICT[env_vars['schemes']], + ignore=ignore, + ), + ] + + # quantize the model + oneshot( + model=model, + dataset=dataset_dict, + recipe=recipe, + trust_remote_code_model=True, + ) + + +def save_quantized_model(model, tokenizer, save_path, save_compressed=False): + model.save_pretrained(save_path, save_compressed=save_compressed) + tokenizer.save_pretrained(save_path) + + +if __name__ == '__main__': + # get environment variables + env_vars = load_environment_variables() + + # support model type list + config = AutoConfig.from_pretrained(env_vars['model_path'], trust_remote_code=True) + model_type = config.model_type + + model = MODEL_DICT[model_type].from_pretrained( + env_vars['model_path'], torch_dtype="auto", trust_remote_code=True + ) + tokenizer = TOKENIZER_DICT[model_type].from_pretrained(env_vars['model_path'], trust_remote_code=True) + + ds = load_calibration_text_dataset(env_vars["calib_prompt_path"], tokenizer) + + # Quantize the model + quantize_model(model, env_vars, ds) + + # save the quantized model + save_quantized_model(model, tokenizer, env_vars['export_path'], True) \ No newline at end of file diff --git a/examples/quantization/llm-compressor/w8a8_int8_dynamic.py b/examples/quantization/llm-compressor/w8a8_int8_dynamic.py new file mode 100644 index 00000000000..1cc9d21c663 --- /dev/null +++ b/examples/quantization/llm-compressor/w8a8_int8_dynamic.py @@ -0,0 +1,83 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.modifiers.smoothquant import SmoothQuantModifier +from llmcompressor.utils import dispatch_for_generation + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. 
+def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure algorithms. In this case, we: +# * apply SmoothQuant to make the activations easier to quantize +# * quantize the weights to int8 with GPTQ (static per channel) +# * quantize the activations to int8 (dynamic per token) +recipe = [ + SmoothQuantModifier(smoothing_strength=0.8), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]), +] + +# Apply algorithms and save to output_dir +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("npu") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-Dynamic-Per-Token" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 3888115886b..7778a6f1bde 100644 --- a/mypy.ini +++ b/mypy.ini @@ -15,6 +15,15 @@ ignore_missing_imports = True [mypy-lm_eval.*] ignore_missing_imports = True +[mypy-compressed_tensors.*] +ignore_missing_imports = True + +[mypy-datasets.*] +ignore_missing_imports = True + +[mypy-llmcompressor.*] +ignore_missing_imports = True + [mypy-msprobe.*] ignore_missing_imports = True -allow_untyped_imports = True \ No newline at end of file +allow_untyped_imports = True diff --git a/pyproject.toml b/pyproject.toml index 1fa9d15fc0e..a10ff9a834d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,11 @@ [build-system] # Should be mirrored in requirements.txt requires = [ + "attrs", "cmake>=3.26", "decorator", "einops", + "googleapis-common-protos", "numpy<2.0.0", "packaging", "pip", @@ -12,6 +14,7 @@ requires = [ "scipy", "pandas", "pandas-stubs", + "psutil", "setuptools>=64", "setuptools-scm>=8", "transformers<=4.57.1", @@ -23,6 +26,7 @@ requires = [ "quart", "numba", "opencv-python-headless<=4.11.0.86", # Required to avoid numpy version conflict with vllm + "compressed_tensors>=0.11.0" ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index 41a143902f7..2a176f84727 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ torchvision wheel pandas-stubs opencv-python-headless<=4.11.0.86 # Required to avoid numpy version conflict with vllm +compressed_tensors>=0.11.0 # requirements for disaggregated prefill msgpack diff --git a/setup.py b/setup.py index 0cee690e618..1bf800813e4 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ from sysconfig import get_paths from typing import Dict, List -from setuptools import Extension, find_packages, setup +from setuptools import Command, Extension, find_packages, setup from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py from setuptools.command.develop import develop @@ -199,6 +199,27 @@ def run(self): super().run() +class build_and_install_aclnn(Command): + description = "Build and install AclNN by running build_aclnn.sh" + user_options = [] 
+ + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + try: + print("Running bash build_aclnn.sh ...") + subprocess.check_call( + ["bash", "csrc/build_aclnn.sh", ROOT_DIR, envs.SOC_VERSION]) + print("buid_aclnn.sh executed successfully!") + except subprocess.CalledProcessError as e: + print(f"Error running build_aclnn.sh: {e}") + raise SystemExit(e.returncode) + + class cmake_build_ext(build_ext): # A dict of extension directories that have been configured. did_config: Dict[str, bool] = {} @@ -385,8 +406,22 @@ def target_name(s: str) -> str: shutil.copy(src_path, dst_path) print(f"Copy: {src_path} -> {dst_path}") + # copy back _cann_ops_custom directory + src_cann_ops_custom = os.path.join(ROOT_DIR, "vllm_ascend", + "_cann_ops_custom") + dst_cann_ops_custom = os.path.join(self.build_lib, "vllm_ascend", + "_cann_ops_custom") + if os.path.exists(src_cann_ops_custom): + import shutil + if os.path.exists(dst_cann_ops_custom): + shutil.rmtree(dst_cann_ops_custom) + shutil.copytree(src_cann_ops_custom, dst_cann_ops_custom) + print(f"Copy: {src_cann_ops_custom} -> {dst_cann_ops_custom}") + def run(self): - # First, run the standard build_ext command to compile the extensions + # First, ensure ACLNN custom-ops is built and installed. + self.run_command("build_aclnn") + # Then, run the standard build_ext command to compile the extensions super().run() @@ -450,6 +485,7 @@ def _read_requirements(filename: str) -> List[str]: cmdclass = { "develop": custom_develop, "build_py": custom_build_info, + "build_aclnn": build_and_install_aclnn, "build_ext": cmake_build_ext, "install": custom_install } diff --git a/tests/e2e/multicard/test_chunk_gated_delta_rule.py b/tests/e2e/multicard/test_chunk_gated_delta_rule.py new file mode 100644 index 00000000000..a0e4b6ef9df --- /dev/null +++ b/tests/e2e/multicard/test_chunk_gated_delta_rule.py @@ -0,0 +1,33 @@ +import torch + +from tests.ut.base import PytestBase +from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule + + +class TestChunkGatedDeltaRule(PytestBase): + + def test_triton_fusion_ops(self, mock_moe_env): + q = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu() + k = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu() + v = torch.randn(1, 17, 8, 128, dtype=torch.bfloat16).npu() + g = torch.randn(1, 17, 8, dtype=torch.float32).npu() + beta = torch.randn(1, 17, 8, dtype=torch.bfloat16).npu() + initial_state = torch.randn(3, 8, 128, 128, dtype=torch.bfloat16).npu() + q_start_loc = torch.range(0, 3, dtype=torch.int).npu() + + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule(q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=q_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True) + + assert core_attn_out_non_spec.shape == (1, 17, 8, 128) + assert last_recurrent_state.shape == (3, 8, 128, 128) diff --git a/tests/e2e/multicard/test_prefix_caching.py b/tests/e2e/multicard/test_prefix_caching.py index e5660c4d331..e29916623ba 100644 --- a/tests/e2e/multicard/test_prefix_caching.py +++ b/tests/e2e/multicard/test_prefix_caching.py @@ -58,6 +58,7 @@ ] +@pytest.mark.skip(reason="Fix me, the accuracy is not correct") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [50]) def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None: diff --git a/tests/e2e/multicard/test_quantization.py b/tests/e2e/multicard/test_quantization.py new file mode 100644 index 
00000000000..67c57daf09e --- /dev/null +++ b/tests/e2e/multicard/test_quantization.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +Run `pytest tests/e2e/multicard/test_quantization.py`. +""" +from modelscope import snapshot_download # type: ignore + +from tests.e2e.conftest import VllmRunner + + +def test_models_distributed_quantized_W8A8(): + example_prompts = [ + "The president of the United States is", + ] + max_tokens = 5 + with VllmRunner(snapshot_download("neuralmagic/Qwen2.5-3B-quantized.w8a8"), + tensor_parallel_size=2, + max_model_len=4096, + gpu_memory_utilization=0.8, + enforce_eager=False) as vllm_model: + vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens) + + golden_results = [ + 'The president of the United States is the head of state and', + ] + + for i in range(len(vllm_output)): + assert golden_results[i] == vllm_output[i][1] + print(f"Generated text: {vllm_output[i][1]!r}") diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py index a162191c0fe..e51748ea1e2 100644 --- a/tests/e2e/multicard/test_qwen3_next.py +++ b/tests/e2e/multicard/test_qwen3_next.py @@ -24,7 +24,6 @@ import os from unittest.mock import patch -import pytest from modelscope import snapshot_download # type: ignore from tests.e2e.conftest import VllmRunner @@ -64,7 +63,6 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY(): del vllm_model -@pytest.mark.skip(reason="Fix me, the accuracy is not correct") def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY(): example_prompts = [ "Hello, my name is", @@ -74,11 +72,14 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY(): ] max_tokens = 20 - with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct", - tensor_parallel_size=4, - max_model_len=4096, - gpu_memory_utilization=0.8, - distributed_executor_backend="mp") as vllm_model: + with VllmRunner( + "Qwen/Qwen3-Next-80B-A3B-Instruct", + tensor_parallel_size=4, + max_model_len=4096, + gpu_memory_utilization=0.8, + distributed_executor_backend="mp", + enforce_eager=True, + ) as vllm_model: ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model @@ -87,6 +88,7 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY(): max_model_len=4096, gpu_memory_utilization=0.8, distributed_executor_backend="mp", + enforce_eager=True, additional_config={ "ascend_scheduler_config": { "enabled": True, diff --git a/tests/e2e/nightly/ops/test_gmm_swiglu_quant_weight_nz_tensor_list.py b/tests/e2e/nightly/ops/test_gmm_swiglu_quant_weight_nz_tensor_list.py new file mode 100644 index 00000000000..7e87e6d41e1 --- /dev/null +++ b/tests/e2e/nightly/ops/test_gmm_swiglu_quant_weight_nz_tensor_list.py @@ -0,0 
+1,148 @@ +import gc + +import torch +import torch_npu + +from vllm_ascend.utils import enable_custom_op + +# enable internal format +torch_npu.npu.config.allow_internal_format = True +# enable vllm-ascend custom ops +enable_custom_op() + + +def gmm_swiglu_quant(x: torch.Tensor, weight: torch.Tensor, + perChannelScale: torch.Tensor, + perTokenScale: torch.Tensor, m: int): + """ + Perform quantized GMM (Grouped Matrix Multiplication) operation with SwiGLU activation function. + + Parameters: + x (torch.Tensor): Input tensor with shape (m, k). + weight (torch.Tensor): Weight tensor with shape (k, n). + perChannelScale (torch.Tensor): Per-channel scaling factor with shape (n,). + perTokenScale (torch.Tensor): Per-token scaling factor with shape (m,). + m (int): Number of tokens (rows of x). + + Returns: + quantOutput (torch.Tensor): Quantized output tensor with shape (m, k // 2). + quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (m,). + """ + # Perform matrix multiplication with int32 precision + c_temp1 = torch.matmul(x.to(torch.int32), weight.to(torch.int32)) + c_temp1 = c_temp1.to(torch.float32) # Convert back to float32 for scaling + + # Apply per-channel and per-token scaling + c_temp2 = torch.mul(c_temp1, perChannelScale) + c_temp3 = torch.mul(c_temp2, perTokenScale.reshape(m, 1)) + + # Split the result into two parts to apply SwiGLU activation function + c_temp4, gate = c_temp3.chunk(2, dim=-1) + c_temp5 = c_temp4 * torch.sigmoid(c_temp4) # SwiGLU activation + c_temp6 = c_temp5 * gate # Element-wise multiplication with gating values + + # Quantize the output + max = torch.max( + torch.abs(c_temp6), + -1).values # Find maximum absolute value to calculate scaling factor + quantScaleOutput = 127 / max # Calculate quantization scaling factor + quantOutput = torch.round(c_temp6 * quantScaleOutput.reshape(m, 1)).to( + torch.int8) # Quantize to int8 + quantScaleOutput = 1 / quantScaleOutput # Inverse quantization scaling factor for subsequent dequantization + + return quantOutput, quantScaleOutput + + +def process_groups(x: torch.Tensor, weight: torch.Tensor, + perChannelScale: torch.Tensor, perTokenScale: torch.Tensor, + groupList: torch.Tensor): + """ + Process input data by groups and call GMM_Swiglu_quant function for quantized computation. + + Parameters: + x (torch.Tensor): Input tensor with shape (M, K). + weight (torch.Tensor): List of weight tensors, each with shape (E, K, N). + perChannelScale (torch.Tensor): List of per-channel scaling factors, each with shape (E, N). + perTokenScale (torch.Tensor): Per-token scaling factor with shape (M,). + groupList (list): List defining the number of tokens in each group. + + Returns: + quantOutput (torch.Tensor): Quantized output tensor with shape (M, N // 2). + quantScaleOutput (torch.Tensor): Quantization scaling factor with shape (M,). 
+ """ + M, N = x.shape[0], weight.shape[2] # Get the shape of the input tensor + quantOutput = torch.zeros(M, N // 2).to( + torch.int8) # Initialize quantized output tensor + quantScaleOutput = torch.zeros(M).to( + torch.float32) # Initialize quantization scaling factor tensor + + start_idx = 0 # Starting index + preV = 0 # Number of tokens in the previous group + groupList = groupList.tolist() + # Iterate through groupList to process data by groups + for i, v in enumerate(groupList): + currV = v + tempV = currV - preV # Calculate number of tokens in the current group + preV = currV # Update number of tokens in the previous group + if tempV > 0: + # Call GMM_Swiglu_quant to process the current group + quantOutput[start_idx:start_idx + tempV], quantScaleOutput[start_idx:start_idx + tempV] = \ + gmm_swiglu_quant(x[start_idx:start_idx + tempV], + weight[i], + perChannelScale[i], + perTokenScale[start_idx:start_idx + tempV], + tempV) + + start_idx += tempV # Update starting index to process the next group + return quantOutput, quantScaleOutput + + +@torch.inference_mode() +def test_gmm_swiglu_quant_weight_nz_tensor_list(): + M, K, E, N = 8192, 7168, 4, 4096 + + # x (M, K) - int8 + x = torch.randint(-128, 127, (M, K), dtype=torch.int8) + + # weight (E, N, K) - int8 + weight = torch.randint(-128, 127, size=(E, K, N), dtype=torch.int8) + + # weight_scale (E, N) - float32 + weight_scale = torch.rand(E, N) * 0.9 + 0.1 # uniform(0.1, 1.0) + weight_scale = weight_scale.to(torch.float32) + + weight_nz_npu = [] + weight_scale_npu = [] + for i in range(E): + weight_nz_npu.append(torch_npu.npu_format_cast(weight[i].npu(), 29)) + weight_scale_npu.append(weight_scale[i].npu()) + + # x_scale (M,) - float32 + x_scale = torch.rand(M) * 0.9 + 0.1 # uniform(0.1, 1.0) + x_scale = x_scale.to(torch.float32) + + group_list = torch.tensor([2048, 4096, 6144, 8192], dtype=torch.int64) + + output_cpu, output_scale_cpu = process_groups(x, weight, weight_scale, + x_scale, group_list) + output_npu, output_scale_npu, _ = \ + torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(x.npu(), + weight_nz_npu, + weight_scale_npu, + x_scale.npu(), + group_list.npu()) + output_npu_valid = output_npu[:group_list[-1], :] + output_scale_npu_valid = output_scale_npu[:group_list[-1]] + + torch.testing.assert_close(output_npu_valid.cpu(), + output_cpu, + atol=1, + rtol=2**-13) + torch.testing.assert_close(output_scale_npu_valid.cpu(), + output_scale_cpu, + atol=1e-9, + rtol=1e-6) + + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() diff --git a/tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py b/tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py new file mode 100644 index 00000000000..28e724bbb93 --- /dev/null +++ b/tests/e2e/nightly/ops/test_grouped_matmul_swiglu_quant.py @@ -0,0 +1,175 @@ +import gc + +import numpy as np +import torch +import torch_npu + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + + +def x_int8_to_x_int4(x: torch.Tensor): + m, k = x.shape + x_high_4bit = torch.floor(x.to(torch.float16) // 16).to(torch.int8) + x_low_4bit = ( + torch.bitwise_and(x.view(torch.int16), 0x0f0f).view(torch.int8) - 8) + x_int4 = torch.empty((2 * m, k), dtype=torch.int8) + x_int4[::2, :] = x_high_4bit + x_int4[1::2, :] = x_low_4bit + return x_int4 + + +def custom_mm(x: torch.Tensor, weight: torch.Tensor, + weight_scale: torch.Tensor, m: int): + """ + Performing Quantized GMM (General Matrix Multiplication) Operation + Parameters: + x (torch.Tensor): Input 
tensor with shape (m, k). + weight (torch.Tensor): Weight tensor with shape (k, n). + weight_scale (torch.Tensor): Scaling factor for each channel. + - In perGroup scenario: Shape is (k_group_num, n). Note: When k_group_num == 1, it is a perChannel scenario. + - In perChannel scenario: Shape is (n). + m (int): Number of tokens (number of rows in x). + Returns: + mm_out(fp16): Result of MatMul + perGroup or perChannel dequantization. + """ + # Perform matrix multiplication with int32 precision + k, n = weight.shape + mm_out = torch.zeros((m, n), dtype=torch.float16) + # perGroup scenario + if len(weight_scale.shape) == 2 and weight_scale.shape[0] != 1: + k_group = weight_scale.shape[0] + per_group_ele = k // k_group + x_grouped = x.view(-1, k_group, per_group_ele).transpose(0, 1) + weight_grouped = weight.view(k_group, per_group_ele, n) + + c_temp = torch.bmm(x_grouped.to(torch.int32), + weight_grouped.to(torch.int32)).to(torch.float16) + for k_idx in range(k_group): + mm_out += (c_temp[k_idx] * + weight_scale[k_idx].view(1, -1).to(torch.float16)).to( + torch.float16) + # perChannel scenario + elif len(weight_scale.shape) == 1 or (len(weight_scale.shape) == 2 + and weight_scale.shape[0] == 1): + c_temp = torch.matmul(x.to(torch.int32), + weight.to(torch.int32)).to(torch.float32) + mm_out = c_temp * weight_scale.view(1, -1).to(torch.float16) + return mm_out.to(torch.float32) + + +def gmm_swiglu_quant_golden_a8_w4(x: torch.Tensor, weight: torch.Tensor, + weight_scale: torch.Tensor, + per_token_scale: torch.Tensor, + bias: torch.Tensor, + group_list: torch.Tensor): + """ + Process the input data by group and call the GMM_Swiglu_quant function for quantization computation. + Parameters: + x (torch.Tensor): Input tensor with shape (M, K), type INT8. + weight (torch.Tensor): List of weight tensors, each with shape (E, K, N), data type INT8 but data range INT4, representing INT4 values. + weight_scale (torch.Tensor): Scaling factor for each channel. + - In perGroup scenario: shape (E, k_group_num, N). + - In perChannel scenario: shape (E, N). + per_token_scale (torch.Tensor): Scaling factor for each token, shape (M, ). + bias: torch.Tensor, + group_list (list): List defining the number of tokens in each group. + Returns: + quant_output (torch.Tensor): Quantized output tensor with shape (M, N // 2). + quant_scale_output (torch.Tensor): Quantization scaling factor, shape (M, ). 
+ """ + M, N = x.shape[0], weight.shape[2] + quant_output = torch.zeros(M, N // 2).to(torch.int8) + quant_scale_output = torch.zeros(M).to(torch.float32) + # Preprocessing X_INT8 -> X_INT4 + x_int4 = x_int8_to_x_int4(x) + start_idx = 0 + # Number of tokens in the previous group + pre_v = 0 + group_list = group_list.tolist() + # Traverse group_list and process data by group + for i, v in enumerate(group_list): + curr_v = v + # Calculate the number of tokens in the current group " * 2 " because 1 row of Int8--> 2 rows of Int4 + temp_v = int((curr_v - pre_v) * 2) + # Update the number of tokens in the previous group + pre_v = curr_v + if (temp_v > 0): + mm_out = custom_mm(x_int4[int(start_idx):int(start_idx + temp_v)], + weight[i], weight_scale[i], temp_v) + mm_num_concat = ((mm_out[::2] * 16 + mm_out[1::2]) + + bias[i].view(1, -1)) + per_token_quant = mm_num_concat * per_token_scale[start_idx // 2:( + start_idx + temp_v) // 2].view(-1, 1) + swiglu, gate = per_token_quant.chunk(2, dim=-1) + temp = swiglu * torch.sigmoid(swiglu) + temp = temp * gate + max_value = torch.max(torch.abs(temp), dim=-1).values + quant_scale_output_temp = 127 / max_value + quant_output[start_idx // 2:(start_idx + temp_v) // + 2] = torch.round(temp * + quant_scale_output_temp.reshape( + temp_v // 2, 1)).to(torch.int8) + quant_scale_output[start_idx // 2:(start_idx + temp_v) // + 2] = 1 / quant_scale_output_temp + start_idx += temp_v + return quant_output, quant_scale_output + + +def generate_non_decreasing_sequence(length, upper_limit): + # Generate random increasing sequence + random_increments = torch.randint(0, 128, (length, )) + sequence = torch.cumsum(random_increments, dim=0) + + # Make sure the last value is less than the upper limit + if sequence[-1] >= upper_limit: + scale_factor = upper_limit / sequence[-1] + sequence = (sequence * scale_factor).to(torch.int64) + return sequence + + +@torch.inference_mode() +def test_grouped_matmul_swiglu_quant_kernel(): + E = 16 + M = 512 + K = 7168 + N = 4096 + torch.npu.config.allow_internal_format = True + x = torch.randint(-5, 5, (M, K), dtype=torch.int8).npu() + weight_ori = torch.randint(-5, 5, (E, K, N), dtype=torch.int8) + weight_nz = torch_npu.npu_format_cast(weight_ori.npu().to(torch.float32), + 29) + pack_weight = torch_npu.npu_quantize(weight_nz, + torch.tensor([1.], device='npu'), + None, torch.quint4x2, -1, False) + + weight_scale = torch.randn(E, 1, N) + scale_np = weight_scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu() + pertoken_scale = torch.randn(M).to(torch.float32).npu() + group_list = generate_non_decreasing_sequence(E, M).npu() + bias = torch.zeros((E, N), dtype=torch.float32, + device="npu").uniform_(-5, 5) + + output_golden, output_scale_golden = gmm_swiglu_quant_golden_a8_w4( + x.cpu(), weight_ori, weight_scale, pertoken_scale.cpu(), bias.cpu(), + group_list.cpu()) + + output, output_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant( + x=x, + weight=pack_weight, + bias=bias, + group_list=group_list, + weight_scale=scale_uint64_tensor, + x_scale=pertoken_scale) + torch.testing.assert_close(output_golden, output.cpu(), atol=1, rtol=0.005) + torch.testing.assert_close(output_scale_golden, + output_scale.cpu(), + atol=1, + rtol=0.005) + + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() diff --git a/tests/e2e/singlecard/test_aclgraph.py b/tests/e2e/singlecard/test_aclgraph.py index 17d1f4a4dfd..60cb3c16fa2 100644 --- 
a/tests/e2e/singlecard/test_aclgraph.py +++ b/tests/e2e/singlecard/test_aclgraph.py @@ -137,7 +137,7 @@ def test_models_with_aclgraph_full_decode_only( vllm_aclgraph_qwen_answers = [ ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle formed by two random points on a square's perimeter is", - 'i$.\n\nLet $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$ and' + ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' ] vllm_aclgraph_ds_answers = [ diff --git a/tests/e2e/vllm_interface/vllm_test.cfg b/tests/e2e/vllm_interface/vllm_test.cfg index 9723d49cad7..dfd540384bc 100644 --- a/tests/e2e/vllm_interface/vllm_test.cfg +++ b/tests/e2e/vllm_interface/vllm_test.cfg @@ -1,2 +1,2 @@ # Base docker image used to build the vllm-ascend e2e test image, which is built in the vLLM repository -BASE_IMAGE_NAME="quay.io/ascend/cann:8.3.rc1-910b-ubuntu22.04-py3.11" +BASE_IMAGE_NAME="quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11" diff --git a/tests/ut/attention/test_attention_mask.py b/tests/ut/attention/test_attention_mask.py index c8139b7167b..9bd4cd0e304 100644 --- a/tests/ut/attention/test_attention_mask.py +++ b/tests/ut/attention/test_attention_mask.py @@ -74,10 +74,11 @@ def test_get_attn_mask(self): attn_mask = attention_mask_builder.get_attn_mask( max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu")) self.assertEqual(attn_mask.shape, (2048, 2048)) - self.assertEqual(attn_mask[0][-1], torch.tensor(True)) - self.assertEqual(attention_mask_builder._seq_len_cached, 1024) + self.assertEqual(attn_mask[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + self.assertEqual(attention_mask_builder._seq_len_cached, 2048) self.assertEqual(attention_mask_builder.attn_mask_cache.shape, - (1024, 1024)) + (2048, 2048)) self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], torch.tensor(float("-inf"), dtype=torch.float16)) diff --git a/tests/ut/distributed/mooncake/test_config_data.py b/tests/ut/distributed/mooncake/test_config_data.py index 4408b41a825..bd8d07930f4 100644 --- a/tests/ut/distributed/mooncake/test_config_data.py +++ b/tests/ut/distributed/mooncake/test_config_data.py @@ -1,6 +1,13 @@ +import sys +import types import unittest +from unittest.mock import MagicMock -from vllm_ascend.distributed.mooncake.config_data import ( +fake_engine = types.ModuleType("mooncake.engine") +fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined] +sys.modules["mooncake.engine"] = fake_engine + +from vllm_ascend.distributed.kvpool.backend.mooncake_backend import ( # noqa: E402 _convert_to_bytes, _parse_global_segment_size) diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 8d21a02b9e3..a0edff8e3f3 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -1051,7 +1051,7 @@ def setUp(self): 'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash', mock_string_to_int64_hash), patch( - 'vllm_ascend.distributed.mooncake.transfer_engine.TransferEngine', + 'vllm_ascend.distributed.mooncake_transfer_engine.TransferEngine', return_value=self.mock_transfer_engine), patch( 'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread', diff --git 
a/tests/ut/models/test_qwen2_5_vl.py b/tests/ut/models/test_qwen2_5_vl.py deleted file mode 100644 index 7111aaed6c8..00000000000 --- a/tests/ut/models/test_qwen2_5_vl.py +++ /dev/null @@ -1,488 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F -from pytest_mock import MockerFixture - -from tests.ut.base import PytestBase -from vllm_ascend.models.qwen2_5_vl import ( - AscendQwen2_5_VisionAttention, AscendQwen2_5_VisionBlock, - AscendQwen2_5_VisionPatchEmbed, AscendQwen2_5_VisionRotaryEmbedding, - AscendQwen2_5_VisionTransformer, AscendQwen2_5_VLForConditionalGeneration) - - -class TestAscendQwen2_5_VisionAttention(PytestBase): - - def init_attention( - self, - mocker, - embed_dim=1000, - num_heads=10, - projection_size=100, - quant_config=None, - prefix="", - ): - mocker_attn = mocker.patch( - "vllm_ascend.models.qwen2_5_vl.Qwen2_5_VisionAttention.__init__") - - attention = AscendQwen2_5_VisionAttention( - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - ) - args, kwargs = mocker_attn.call_args - assert args == (embed_dim, num_heads, projection_size, None, "") - assert not kwargs - attention.num_attention_heads_per_partition = num_heads - return attention - - def test_attn_init_should_normal(self, mocker: MockerFixture): - embed_dim = 1000 - num_heads = 10 - projection_size = 100 - quant_config = None - prefix = "" - vit = self.init_attention( - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - mocker=mocker, - ) - assert vit.embed_dim == 1000 - assert vit.hidden_size_per_attention_head == 10 - - def test_attn_init_should_raise_error(self, mocker: MockerFixture): - embed_dim = 1000 - num_heads = 7 - projection_size = 100 - quant_config = None - prefix = "" - with pytest.raises(AssertionError): - # projection_size should divided by num heads - self.init_attention( - mocker=mocker, - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - ) - - def test_split_qkv(self, mocker: MockerFixture): - attention = self.init_attention(mocker=mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - q, k, v = attention.split_qkv(torch.rand((100, 10, 300))) - assert q.shape == (100, 10, 10, 10) - assert k.shape == (100, 10, 10, 10) - assert v.shape == (100, 10, 10, 10) - - def test_attn_forward(self, mocker: MockerFixture): - attention = self.init_attention(mocker=mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim - cu_seqlens = torch.tensor([10, 50, 100]) - cos = torch.rand((1, 100, 1, 128)) - sin = torch.rand((1, 100, 1, 128)) - - qkv = lambda x: (x, 0) # noqa - split_qkv = lambda x: [ #noqa - torch.rand((100, 3, 10, 128)) for i in range(3) - ] # noqa - npu_rotary_mul = lambda q, cos, sin: q # noqa - _npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa - proj = lambda x: (x, 0) # noqa - - mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv) - mocker_split_qkv = mocker.patch.object( - attention, - "split_qkv", - side_effect=split_qkv, - ) - mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul", - side_effect=npu_rotary_mul) - mocker_npu_flash_attention_unpad = mocker.patch( - 
"torch_npu._npu_flash_attention_unpad", - side_effect=_npu_flash_attention_unpad, - ) - mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj) - attention.__dict__["qkv"] = mocker_qkv - attention.__dict__["split_qkv"] = mocker_split_qkv - attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul - attention.__dict__["_npu_flash_attention_unpad"] = ( - mocker_npu_flash_attention_unpad) - attention.__dict__["proj"] = mocker_proj - - output = attention.forward( - x=x, - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin, - ) - qkv_args, qkv_kwargs = mocker_qkv.call_args - assert qkv_args == (x, ) - assert not qkv_kwargs - - split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args - assert split_qkv_args == (x, ) - assert not split_qkv_kwargs - - npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args - assert npu_rotary_mul_args[1:] == (cos, sin) - assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128]) - assert not npu_rotary_mul_kwargs - - assert output.shape == torch.Size([100, 3, 1280]) - - -class TestAscendQwen2_5_VisionBlock(PytestBase): - - def init_vision_block( - self, - mocker, - dim=100, - num_heads=10, - mlp_hidden_dim=100, - ): - mocker_vit = mocker.patch( - "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionBlock.__init__", - return_value=None, - ) - - mocker_attn = mocker.patch( - "vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionAttention.__init__", - return_value=None, - ) - - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - vision_block = AscendQwen2_5_VisionBlock( - dim=dim, - num_heads=num_heads, - mlp_hidden_dim=mlp_hidden_dim, - ) - args, kwargs = mocker_vit.call_args - assert args == (dim, num_heads, mlp_hidden_dim, F.silu, None, None, "") - assert not kwargs - - args1, kwargs1 = mocker_attn.call_args - assert not args1 - assert kwargs1 == { - "embed_dim": dim, - "num_heads": num_heads, - "projection_size": dim, - "quant_config": None, - "prefix": ".attn", - } - return vision_block - - def test_init_vision_block_should_normal( - self, - mocker: MockerFixture, - ): - vision_block = self.init_vision_block(mocker) - assert isinstance(vision_block, AscendQwen2_5_VisionBlock) - - def test_vision_block_forward(self, mocker: MockerFixture): - x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d - cu_seqlens = torch.tensor([10, 50, 100]) - cos = torch.rand((1, 100, 1, 128)) - sin = torch.rand((1, 100, 1, 128)) - vision_block = self.init_vision_block(mocker) - mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x) - mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x) - vision_block.__dict__["attn"] = mocker_attn - vision_block.__dict__["mlp"] = mocker_mlp - - output = vision_block.forward(x.clone(), cu_seqlens, cos, sin) - - _, attn_kwargs = mocker_attn.call_args - assert attn_kwargs == { - "cu_seqlens": cu_seqlens, - "cos": cos, - "sin": sin, - } - - assert torch.all(x * 3 == output) - - -class TestAscendQwen2_5_VisionPatchEmbed(PytestBase): - - def test_forward(self): - patch_embed = AscendQwen2_5_VisionPatchEmbed() - - ret = patch_embed(torch.rand((120, 1176))) - assert ret.shape == (120, 1152) - - -class TestAscendQwen2_5_VisionRotaryEmbedding(PytestBase): - - def init_rotary_embedding( - self, - mocker, - dim=128, - ): - mocker_ebed = mocker.patch( - "vllm_ascend.models.qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding.__init__", - return_value=None, - ) - mocker.patch("torch.nn.Module.__setattr__") - 
mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - rotary_embedding = AscendQwen2_5_VisionRotaryEmbedding(dim=dim, ) - args, kwargs = mocker_ebed.call_args - assert args == (dim, 10000.0) - assert not kwargs - return rotary_embedding - - def test_init_rotary_embedding_should_normal(self, mocker: MockerFixture): - rotary_embedding = self.init_rotary_embedding(mocker) - assert isinstance(rotary_embedding, - AscendQwen2_5_VisionRotaryEmbedding) - - -class TestAscendQwen2_5_VisionTransformer(PytestBase): - - input_data = torch.tensor([[0.1, 0.2], [0.3, 0.4]]) - - def init_vision_transformer( - self, - mocker, - ): - norm_eps = 1e-6 - vision_config = mocker.MagicMock() - vision_config.patch_size = 16 - vision_config.temporal_patch_size = 2 - vision_config.in_channels = 3 - vision_config.hidden_act = "gelu" - vision_config.depth = 0 - vision_config.num_heads = 10 - vision_config.hidden_size = 300 - - mocker.patch( - "vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_rank", - return_value=0, - ) - mocker.patch("vllm.distributed.utils.divide", return_value=100) - mocker.patch( - "vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", - return_value=2, - ) - mocker.patch( - "vllm.model_executor.layers.linear.divide", - return_value=2, - ) - mocker.patch( - "vllm.model_executor.layers.linear.get_tensor_model_parallel_rank", - return_value=0) - mocker.patch( - "vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_world_size", - return_value=2, - ) - mocker.patch( - "vllm_ascend.ops.linear.divide", - return_value=2, - ) - - mock_group = mocker.MagicMock() - mock_group.rank_in_group = 0 - mock_group.world_size = 2 - mocker.patch( - "vllm_ascend.ops.linear_op.get_tp_group", - return_value=mock_group, - ) - mocker.patch( - "vllm.distributed.parallel_state.get_tp_group", - return_value=mock_group, - ) - - vision_transformer = AscendQwen2_5_VisionTransformer( - vision_config, - norm_eps, - ) - - assert not vision_transformer.interleaved - return vision_transformer - - def test_init_vision_transformer(self, mocker: MockerFixture): - vision_transformer = self.init_vision_transformer(mocker) - assert isinstance(vision_transformer, AscendQwen2_5_VisionTransformer) - - @pytest.mark.parametrize( - "interleaved, expected", - [ - ( - False, - torch.tensor([ - input_data[0, 0].cos(), - input_data[0, 1].cos(), - input_data[0, 0].cos(), - input_data[0, 1].cos(), - input_data[1, 0].cos(), - input_data[1, 1].cos(), - input_data[1, 0].cos(), - input_data[1, 1].cos(), - ]), - ), - ( - True, - torch.tensor([ - input_data[0, 0].cos(), - input_data[0, 0].cos(), - input_data[0, 1].cos(), - input_data[0, 1].cos(), - input_data[1, 0].cos(), - input_data[1, 0].cos(), - input_data[1, 1].cos(), - input_data[1, 1].cos(), - ]), - ), - ], - ) - def test_cal_cos_sin(self, interleaved, expected, mocker: MockerFixture): - vision_transformer = self.init_vision_transformer(mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - vision_transformer.__dict__["interleaved"] = interleaved - vision_transformer.__dict__["hidden_size_per_attention_head"] = 2 - vision_transformer.hidden_size_per_attention_head = 4 - cos_new, _ = vision_transformer.cal_cos_sin(self.input_data) - assert cos_new.shape == (1, 32, 1, 2) - - def test_pad_qkv_bias(self, mocker: MockerFixture): - attention = self.init_vision_transformer(mocker) - 
mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - res = attention.pad_qkv_bias(torch.rand((300))) - assert res.shape[0] == 384 - - def test_pad_qkv_weight(self, mocker: MockerFixture): - attention = self.init_vision_transformer(mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker.patch( - "torch_npu.npu_format_cast", - return_value=torch.rand((384, 300)), - ) - res = attention.pad_qkv_weight(torch.rand((300, 300))) - assert res.shape == (384, 300) - - def test_pad_proj_weight(self, mocker: MockerFixture): - attention = self.init_vision_transformer(mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker.patch( - "torch_npu.npu_format_cast", - return_value=torch.rand((300, 384)), - ) - res = attention.pad_proj_weight(torch.rand((300, 300))) - assert res.shape == (300, 384) - - def test_pad_qkv_weight_scale_offset(self, mocker: MockerFixture): - attention = self.init_vision_transformer(mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - res = attention.pad_qkv_weight_scale_offset(torch.rand((300, 1))) - assert res.shape == (384, 1) - - def test_pad_qkv_deq_scale_quant_bias(self, mocker: MockerFixture): - attention = self.init_vision_transformer(mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - res = attention.pad_qkv_deq_scale_quant_bias(torch.rand((300))) - assert res.shape[0] == 384 - - def test_forward(self, mocker: MockerFixture): - vision_transformer = self.init_vision_transformer(mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - x = torch.randn(1, 3, 224, 224) - grid_thw = torch.tensor([[1, 4, 4]]) - mocker_patch_embed = mocker.patch.object( - vision_transformer, - "patch_embed", - side_effect=lambda _: torch.randn(16, 512), # noqa - ) - mocker_rot_pos_emb = mocker.patch.object( - vision_transformer, - "rot_pos_emb", - side_effect=lambda _: torch.randn(16, 64), # noqa - ) - mocker_get_window_index = mocker.patch.object( - vision_transformer, - "get_window_index", - side_effect=lambda _: (torch.arange(8), [4, 8, 12, 16]), # noqa - ) - mocker_cal_cos_sin = mocker.patch.object( - vision_transformer, - "cal_cos_sin", - side_effect=lambda _: - (torch.randn(16, 32), torch.randn(16, 32)), # noqa - ) - mocker_merger = mocker.patch.object( - vision_transformer, - "merger", - side_effect=lambda _: torch.randn(16, 256), # noqa - ) - vision_transformer.__dict__["vision_blocks"] = [ - lambda *args, **kwargs: torch.randn(16, 1, 512) # noqa - ] - vision_transformer.__dict__["patch_embed"] = mocker_patch_embed - vision_transformer.__dict__["rot_pos_emb"] = mocker_rot_pos_emb - vision_transformer.__dict__[ - "get_window_index"] = mocker_get_window_index - vision_transformer.__dict__["cal_cos_sin"] = mocker_cal_cos_sin - vision_transformer.__dict__["merger"] = mocker_merger - vision_transformer.__dict__["fullatt_block_indexes"] = [0, 2] - vision_transformer.__dict__["spatial_merge_unit"] = 2 - ret = vision_transformer.forward(x, grid_thw) - assert ret.shape == (8, 256) - mocker_patch_embed.assert_called_with(x) - 
mocker_rot_pos_emb.assert_called_with(grid_thw) - mocker_get_window_index.assert_called_with(grid_thw) - mocker_cal_cos_sin.assert_called_once() - mocker_merger.assert_called_once() - - -class TestAscendQwen2_5_VLForConditionalGeneration(PytestBase): - - def test_init_vl_for_conditional_generation(self, mocker: MockerFixture): - vllm_config = mocker.MagicMock() - vllm_config.vision_config = "vision_config" - vllm_config.rms_norm_eps = 1e-5 - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker_vl = mocker.patch( - "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.__init__", - return_value=None, - ) - mocker_vit = mocker.patch( - "vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionTransformer.__init__", - return_value=None, - ) - - vl_for_conditional_generation = AscendQwen2_5_VLForConditionalGeneration( - vllm_config=vllm_config) - args, kwargs = mocker_vl.call_args - assert not args - assert kwargs == {"vllm_config": vllm_config, "prefix": ""} - mocker_vit.assert_called_once() - assert isinstance( - vl_for_conditional_generation, - AscendQwen2_5_VLForConditionalGeneration, - ) diff --git a/tests/ut/models/test_qwen2_5_vl_without_padding.py b/tests/ut/models/test_qwen2_5_vl_without_padding.py deleted file mode 100644 index 00caf810e61..00000000000 --- a/tests/ut/models/test_qwen2_5_vl_without_padding.py +++ /dev/null @@ -1,422 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F -from pytest_mock import MockerFixture -from vllm.model_executor.models.qwen2_5_vl import \ - Qwen2_5_VLForConditionalGeneration - -from tests.ut.base import PytestBase -from vllm_ascend.models.qwen2_5_vl_without_padding import ( - AscendQwen2_5_VisionAttention_Without_Padding, - AscendQwen2_5_VisionBlock_Without_Padding, - AscendQwen2_5_VisionPatchEmbed_Without_Padding, - AscendQwen2_5_VisionTransformer_Without_Padding, - AscendQwen2_5_VLForConditionalGeneration_Without_Padding) - - -class TestAscendQwen2_5_VisionAttention_Without_Padding(PytestBase): - - def init_attention( - self, - mocker, - embed_dim=1000, - num_heads=10, - projection_size=100, - quant_config=None, - prefix="", - ): - mocker_attn = mocker.patch( - "vllm_ascend.models.qwen2_5_vl_without_padding.Qwen2_5_VisionAttention.__init__" - ) - - attention = AscendQwen2_5_VisionAttention_Without_Padding( - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - ) - args, kwargs = mocker_attn.call_args - assert args == (embed_dim, num_heads, projection_size, None, "") - assert not kwargs - attention.num_attention_heads_per_partition = num_heads - return attention - - def test_vit_init_should_normal(self, mocker: MockerFixture): - embed_dim = 1000 - num_heads = 10 - projection_size = 100 - quant_config = None - prefix = "" - vit = self.init_attention( - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - mocker=mocker, - ) - assert vit.embed_dim == 1000 - assert vit.hidden_size_per_attention_head == 10 - - def test_vit_init_should_raise_error(self, mocker: MockerFixture): - embed_dim = 1000 - num_heads = 7 - projection_size = 100 - quant_config = None - prefix = "" - with pytest.raises(AssertionError): - # projection_size should divided by num heads - self.init_attention( - mocker=mocker, - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - 
quant_config=quant_config, - prefix=prefix, - ) - - def test_vit_forward(self, mocker: MockerFixture): - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - attention = self.init_attention(mocker=mocker) - x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim - cu_seqlens = torch.tensor([10, 50, 100]) - cos = torch.rand((1, 100, 1, 128)) - sin = torch.rand((1, 100, 1, 128)) - - qkv = lambda x: (x, 0) # noqa - split_qkv = lambda x: [ #noqa - torch.rand((100, 3, 10, 128)) for i in range(3) - ] # noqa - npu_rotary_mul = lambda q, cos, sin: q # noqa - _npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa - proj = lambda x: (x, 0) # noqa - - mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv) - mocker_split_qkv = mocker.patch.object( - attention, - "split_qkv", - side_effect=split_qkv, - ) - mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul", - side_effect=npu_rotary_mul) - mocker_npu_flash_attention_unpad = mocker.patch( - "torch_npu._npu_flash_attention_unpad", - side_effect=_npu_flash_attention_unpad, - ) - mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj) - attention.__dict__["qkv"] = mocker_qkv - attention.__dict__["split_qkv"] = mocker_split_qkv - attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul - attention.__dict__["_npu_flash_attention_unpad"] = ( - mocker_npu_flash_attention_unpad) - attention.__dict__["proj"] = mocker_proj - - output = attention.forward( - x=x, - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin, - ) - qkv_args, qkv_kwargs = mocker_qkv.call_args - assert qkv_args == (x, ) - assert not qkv_kwargs - - split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args - assert split_qkv_args == (x, ) - assert not split_qkv_kwargs - - npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args - assert npu_rotary_mul_args[1:] == (cos, sin) - assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128]) - assert not npu_rotary_mul_kwargs - - assert output.shape == torch.Size([100, 3, 1280]) - - -class TestAscendQwen2_5_VisionBlock_Without_Padding(PytestBase): - - def init_vision_block( - self, - mocker, - dim=100, - num_heads=10, - mlp_hidden_dim=100, - ): - mocker_vit = mocker.patch( - "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionBlock.__init__", - return_value=None, - ) - - mocker_attn = mocker.patch( - "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionAttention_Without_Padding.__init__", - return_value=None, - ) - - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - vision_block = AscendQwen2_5_VisionBlock_Without_Padding( - dim=dim, - num_heads=num_heads, - mlp_hidden_dim=mlp_hidden_dim, - ) - args, kwargs = mocker_vit.call_args - assert args == (dim, num_heads, mlp_hidden_dim, F.silu, None, None, "") - assert not kwargs - - args1, kwargs1 = mocker_attn.call_args - assert not args1 - assert kwargs1 == { - "embed_dim": dim, - "num_heads": num_heads, - "projection_size": dim, - "quant_config": None, - "prefix": ".attn", - } - return vision_block - - def test_init_vision_block_should_normal( - self, - mocker: MockerFixture, - ): - vision_block = self.init_vision_block(mocker) - assert isinstance(vision_block, - AscendQwen2_5_VisionBlock_Without_Padding) - - def test_vision_block_forward(self, mocker: MockerFixture): - x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d - cu_seqlens = 
torch.tensor([10, 50, 100]) - cos = torch.rand((1, 100, 1, 128)) - sin = torch.rand((1, 100, 1, 128)) - vision_block = self.init_vision_block(mocker) - mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x) - mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x) - vision_block.__dict__["attn"] = mocker_attn - vision_block.__dict__["mlp"] = mocker_mlp - - output = vision_block.forward(x.clone(), cu_seqlens, cos, sin) - - _, attn_kwargs = mocker_attn.call_args - assert attn_kwargs == { - "cu_seqlens": cu_seqlens, - "cos": cos, - "sin": sin, - } - - assert torch.all(x * 3 == output) - - -class TestAscendQwen2_5_VisionPatchEmbed_Without_Padding(PytestBase): - - def test_forward(self): - patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding() - - ret = patch_embed(torch.rand((120, 1176))) - assert ret.shape == (120, 1152) - - -class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase): - - input_data = torch.tensor([[0.1, 0.2], [0.3, 0.4]]) - - def init_vision_transformer( - self, - mocker, - ): - norm_eps = 1e-6 - vision_config = mocker.MagicMock() - vision_config.patch_size = 16 - vision_config.temporal_patch_size = 2 - vision_config.in_channels = 3 - vision_config.hidden_act = "gelu" - vision_config.depth = 0 - vision_config.hidden_size = 1280 - vision_config.num_heads = 16 - - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker_vit = mocker.patch( - "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer.__init__", - return_value=None, - ) - mocker_vision_rotary_embedding = mocker.patch( - "vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionRotaryEmbedding.__init__", - return_value=None, - ) - mocker.patch( - "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionBlock_Without_Padding.__init__", - return_value=None, - ) - mocker.patch( - "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionPatchEmbed_Without_Padding.__init__", - return_value=None, - ) - mocker.patch( - "vllm_ascend.models.qwen2_5_vl_without_padding.parallel_state.get_tensor_model_parallel_world_size", - return_value=1, - ) - mocker.patch( - "vllm_ascend.models.qwen2_5_vl_without_padding.parallel_state.get_tensor_model_parallel_rank", - return_value=0, - ) - mocker.patch("vllm.distributed.utils.divide", return_value=100) - - vision_transformer = AscendQwen2_5_VisionTransformer_Without_Padding( - vision_config, - norm_eps, - ) - args, kwargs = mocker_vit.call_args - assert args == (vision_config, norm_eps, None, "") - assert not kwargs - mocker_vision_rotary_embedding.assert_called_once() - return vision_transformer - - def test_init_vision_transformer(self, mocker: MockerFixture): - vision_transformer = self.init_vision_transformer(mocker) - assert isinstance(vision_transformer, - AscendQwen2_5_VisionTransformer_Without_Padding) - - @pytest.mark.parametrize( - "interleaved, expected", - [ - ( - False, - torch.tensor([ - input_data[0, 0].cos(), - input_data[0, 1].cos(), - input_data[0, 0].cos(), - input_data[0, 1].cos(), - input_data[1, 0].cos(), - input_data[1, 1].cos(), - input_data[1, 0].cos(), - input_data[1, 1].cos(), - ]), - ), - ( - True, - torch.tensor([ - input_data[0, 0].cos(), - input_data[0, 0].cos(), - input_data[0, 1].cos(), - input_data[0, 1].cos(), - input_data[1, 0].cos(), - input_data[1, 0].cos(), - input_data[1, 1].cos(), - input_data[1, 1].cos(), - ]), - ), - ], - ) - def test_cal_cos_sin(self, interleaved, expected, mocker: 
MockerFixture): - vision_transformer = self.init_vision_transformer(mocker) - vision_transformer.__dict__["interleaved"] = interleaved - vision_transformer.__dict__["hidden_size_per_attention_head"] = 2 - vision_transformer.hidden_size_per_attention_head = 4 - cos_new, _ = vision_transformer.cal_cos_sin(self.input_data) - assert cos_new.shape == (1, 4, 1, 2) - assert torch.allclose(cos_new.view(-1), expected) - - def test_forward(self, mocker: MockerFixture): - vision_transformer = self.init_vision_transformer(mocker) - x = torch.randn(1, 3, 224, 224) - grid_thw = torch.tensor([[1, 4, 4]]) - mocker_patch_embed = mocker.patch.object( - vision_transformer, - "patch_embed", - side_effect=lambda _: torch.randn(16, 512), # noqa - ) - mocker_rot_pos_emb = mocker.patch.object( - vision_transformer, - "rot_pos_emb", - side_effect=lambda _: torch.randn(16, 64), # noqa - ) - mocker_get_window_index = mocker.patch.object( - vision_transformer, - "get_window_index", - side_effect=lambda _: (torch.arange(8), [4, 8, 12, 16]), # noqa - ) - mocker_cal_cos_sin = mocker.patch.object( - vision_transformer, - "cal_cos_sin", - side_effect=lambda _: - (torch.randn(16, 32), torch.randn(16, 32)), # noqa - ) - mocker_merger = mocker.patch.object( - vision_transformer, - "merger", - side_effect=lambda _: torch.randn(16, 256), # noqa - ) - vision_transformer.__dict__["vision_blocks"] = [ - lambda *args, **kwargs: torch.randn(16, 1, 512) # noqa - ] - vision_transformer.__dict__["patch_embed"] = mocker_patch_embed - vision_transformer.__dict__["rot_pos_emb"] = mocker_rot_pos_emb - vision_transformer.__dict__[ - "get_window_index"] = mocker_get_window_index - vision_transformer.__dict__["cal_cos_sin"] = mocker_cal_cos_sin - vision_transformer.__dict__["merger"] = mocker_merger - vision_transformer.__dict__["fullatt_block_indexes"] = [0, 2] - vision_transformer.__dict__["spatial_merge_unit"] = 2 - ret = vision_transformer.forward(x, grid_thw) - assert ret.shape == (8, 256) - mocker_patch_embed.assert_called_with(x) - mocker_rot_pos_emb.assert_called_with(grid_thw) - mocker_get_window_index.assert_called_with(grid_thw) - mocker_cal_cos_sin.assert_called_once() - mocker_merger.assert_called_once() - - -class TestAscendQwen2_5_VLForConditionalGeneration_Without_Padding(PytestBase): - - def test_init_vl_for_conditional_generation(self, mocker: MockerFixture): - vllm_config = mocker.MagicMock() - vllm_config.vision_config = "vision_config" - vllm_config.rms_norm_eps = 1e-5 - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker_vl = mocker.patch( - "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.__init__", - return_value=None, - ) - mocker_vit = mocker.patch( - "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionTransformer_Without_Padding.__init__", - return_value=None, - ) - - vl_for_conditional_generation = AscendQwen2_5_VLForConditionalGeneration_Without_Padding( - vllm_config=vllm_config) - args, kwargs = mocker_vl.call_args - assert not args - assert kwargs == {"vllm_config": vllm_config, "prefix": ""} - mocker_vit.assert_called_once() - assert isinstance( - vl_for_conditional_generation, - AscendQwen2_5_VLForConditionalGeneration_Without_Padding, - ) - - def test_overridden_methods(self): - self.assert_method_overridden( - AscendQwen2_5_VLForConditionalGeneration_Without_Padding, - Qwen2_5_VLForConditionalGeneration, - "_process_image_input", - ) - - self.assert_method_overridden( - 
AscendQwen2_5_VLForConditionalGeneration_Without_Padding, - Qwen2_5_VLForConditionalGeneration, - "_process_video_input", - ) - - @staticmethod - def assert_method_overridden(subclass, parent, method_name: str): - """assert subclass override parent method""" - parent_func = parent.__dict__.get(method_name) - child_func = subclass.__dict__.get(method_name) - - assert child_func is not None, f"{subclass.__name__} should defined {method_name}" - assert child_func is not parent_func, f"{method_name} should override in {subclass.__name__}" diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py index 4622692dd00..b667767ba79 100644 --- a/tests/ut/quantization/test_quant_config.py +++ b/tests/ut/quantization/test_quant_config.py @@ -65,7 +65,7 @@ def test_override_quantization_method(self, mock_is_available): # Test when NPU is available mock_is_available.return_value = True result = AscendQuantConfig.override_quantization_method(None, None) - self.assertEqual(result, ASCEND_QUANTIZATION_METHOD) + self.assertIsNone(result) # Test when NPU is not available mock_is_available.return_value = False @@ -93,7 +93,7 @@ def test_get_quant_method_for_linear(self): self.assertIs(method, mock_ascend_linear.return_value) mock_ascend_linear.assert_called_once_with( self.ascend_config, ".attn", - self.ascend_config.packed_modules_mapping) + self.ascend_config.packed_modules_mapping, linear_layer) def test_get_quant_method_for_attention(self): attention_layer = MagicMock(spec=Attention) diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py new file mode 100644 index 00000000000..bb2409da5de --- /dev/null +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -0,0 +1,314 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import torch +from vllm.config import CacheConfig, CompilationMode, VllmConfig + +from tests.ut.base import TestBase +from vllm_ascend.spec_decode.eagle_proposer import EagleProposer +from vllm_ascend.spec_decode.interface import SpecDcodeType + + +class TestEagleProposerInitialization(TestBase): + + def setUp(self): + self.vllm_config = MagicMock(spec=VllmConfig) + self.vllm_config.speculative_config = MagicMock() + self.vllm_config.cache_config = MagicMock(spec=CacheConfig) + self.vllm_config.scheduler_config = MagicMock() + self.vllm_config.model_config = MagicMock() + self.device = torch.device("cpu") + self.runner = MagicMock() + + self.vllm_config.cache_config.block_size = 16 + self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 + self.vllm_config.scheduler_config.max_num_seqs = 32 + self.vllm_config.model_config.dtype = torch.float16 + self.vllm_config.model_config.max_model_len = 2048 + + def test_initialization_eagle(self): + self.vllm_config.speculative_config.method = "eagle" + self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096 + self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE + self.vllm_config.model_config.enforce_eager = False + + proposer = EagleProposer(vllm_config=self.vllm_config, + device=self.device, + runner=self.runner) + + self.assertEqual(proposer.name, SpecDcodeType.EAGLE) + self.assertEqual(proposer.block_size, 16) + self.assertEqual(proposer.hidden_size, 4096) + self.assertTrue(proposer.use_cuda_graph) + + self.assertEqual(proposer.input_ids.shape, (1024, )) + self.assertEqual(proposer.positions.shape, (1024, )) + self.assertEqual(proposer.hidden_states.shape, (1024, 4096)) + 
self.assertEqual(proposer.arange.shape, (33, )) + + def test_initialization_eagle3(self): + self.vllm_config.speculative_config.method = "eagle3" + self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 2048 + self.vllm_config.compilation_config.mode = CompilationMode.NONE + self.vllm_config.model_config.enforce_eager = True + + proposer = EagleProposer(vllm_config=self.vllm_config, + device=self.device, + runner=self.runner) + + self.assertEqual(proposer.name, SpecDcodeType.EAGLE3) + self.assertEqual(proposer.hidden_size, 2048) + self.assertFalse(proposer.use_cuda_graph) + self.assertEqual(proposer.hidden_states.shape, (1024, 2048)) + + +class TestEagleProposerLoadModel(TestBase): + + def setUp(self): + self.vllm_config = MagicMock(spec=VllmConfig) + self.vllm_config.speculative_config = MagicMock() + self.vllm_config.speculative_config.method = "eagle" + self.device = torch.device("cpu") + self.runner = MagicMock() + + self.vllm_config.cache_config.block_size = 16 + self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 + self.vllm_config.scheduler_config.max_num_seqs = 32 + self.vllm_config.model_config.dtype = torch.float16 + self.vllm_config.model_config.max_model_len = 2048 + + self.proposer = EagleProposer(vllm_config=self.vllm_config, + device=self.device, + runner=self.runner) + + @patch( + "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") + @patch("vllm_ascend.spec_decode.eagle_proposer.get_model") + @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group") + def test_load_model_pp1(self, mock_pp_group, mock_get_model, + mock_get_layers): + mock_pp_group.return_value.world_size = 1 + mock_target_layers = {"layer1": MagicMock(), "layer2": MagicMock()} + mock_draft_layers = {"layer1": MagicMock(), "layer3": MagicMock()} + mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers] + + mock_model = MagicMock() + mock_model.model.embed_tokens = MagicMock() + mock_model.lm_head = MagicMock() + mock_get_model.return_value = MagicMock() + self.proposer.name = SpecDcodeType.EAGLE + + self.proposer.load_model(mock_model) + mock_get_model.assert_called_once() + self.assertEqual(self.proposer.attn_layer_name, "layer3") + self.assertIs(self.proposer.model.model.embed_tokens, + mock_model.model.embed_tokens) + + @patch( + "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") + @patch("vllm_ascend.spec_decode.eagle_proposer.get_model") + @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group") + def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model, + mock_get_layers): + mock_pp_group.return_value.world_size = 2 + mock_target_layers = {"layer1": MagicMock()} + mock_draft_layers = {"layer2": MagicMock()} + mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers] + + mock_model = MagicMock() + original_embed = MagicMock() + mock_get_model.return_value = MagicMock(model=MagicMock( + embed_tokens=original_embed)) + + self.proposer.load_model(mock_model) + + self.assertIsNot(self.proposer.model.model.embed_tokens, + mock_model.model.embed_tokens) + self.assertEqual(self.proposer.attn_layer_name, "layer2") + + @patch( + "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") + @patch("vllm_ascend.spec_decode.eagle_proposer.get_model") + @patch("vllm_ascend.spec_decode.eagle_proposer.get_pp_group") + @patch("vllm_ascend.spec_decode.eagle_proposer.supports_multimodal") + def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group, + mock_get_model, 
mock_get_layers): + mock_model = MagicMock() + mock_model.get_language_model.return_value.lm_head = MagicMock() + mock_supports_multi.return_value = True + original_embed = MagicMock() + mock_get_model.return_value = MagicMock(model=MagicMock( + embed_tokens=original_embed)) + + mock_target_layers = {"layer1": MagicMock()} + mock_draft_layers = {"layer2": MagicMock()} + mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers] + mock_pp_group.return_value.world_size = 2 + + self.proposer.model = MagicMock() + self.proposer.name = SpecDcodeType.EAGLE + + self.proposer.load_model(mock_model) + mock_model.get_language_model.assert_called_once() + self.assertIs(self.proposer.model.lm_head, + mock_model.get_language_model.return_value.lm_head) + + +class TestEagleProposerDummyRun(TestBase): + + def setUp(self): + self.vllm_config = MagicMock(spec=VllmConfig) + self.vllm_config.speculative_config = MagicMock() + self.device = torch.device("cpu") + self.runner = MagicMock() + self.runner._select_moe_comm_method.return_value = "alltoall" + + self.vllm_config.cache_config.block_size = 16 + self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 + self.vllm_config.scheduler_config.max_num_seqs = 32 + self.vllm_config.model_config.dtype = torch.float16 + self.vllm_config.model_config.max_model_len = 2048 + + self.proposer = EagleProposer(vllm_config=self.vllm_config, + device=self.device, + runner=self.runner) + self.proposer.model = MagicMock() + + @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") + def test_dummy_run_basic(self, mock_context): + num_tokens = 32 + with_prefill = False + + self.proposer.dummy_run(num_tokens=num_tokens, + with_prefill=with_prefill) + + mock_context.assert_called_once() + + @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") + def test_dummy_run_with_prefill(self, mock_context): + mock_context.return_value.__enter__.return_value = None + self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4) + + self.runner._select_moe_comm_method.assert_called_with(64) + self.proposer.model.assert_called_once() + + +class TestEagleProposerGenerateTokenIds(TestBase): + + def setUp(self): + self.vllm_config = MagicMock(spec=VllmConfig) + self.vllm_config.speculative_config = MagicMock() + self.vllm_config.speculative_config.method = "eagle" + self.device = torch.device("cpu") + self.runner = MagicMock() + self.runner.input_batch = MagicMock() + self.runner.input_batch.req_ids = [0, 1, 2] + self.runner.requests = { + 0: MagicMock(get_token_id=lambda x: 100), + 1: MagicMock(get_token_id=lambda x: 101), + 2: MagicMock(get_token_id=lambda x: 102), + } + + self.vllm_config.cache_config.block_size = 16 + self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 + self.vllm_config.scheduler_config.max_num_seqs = 32 + self.vllm_config.model_config.dtype = torch.float16 + self.vllm_config.model_config.max_model_len = 2048 + + self.proposer = EagleProposer(vllm_config=self.vllm_config, + device=self.device, + runner=self.runner) + self.proposer.attn_layer_name = "layer_0" + self.proposer._propose = MagicMock( + return_value=torch.tensor([[1, 2], [3, 4], [5, 6]])) + + def test_generate_token_ids_without_metadata(self): + valid_sampled = [[20, 30, 40]] + valid_sampled = [np.array(sublist) for sublist in valid_sampled] + scheduler_output = MagicMock() + scheduler_output.num_scheduled_tokens = [2, 1, 3] + positions = torch.tensor([0, 1, 2, 3, 4, 5]) + hidden_states = torch.randn(6, 4096) + num_scheduled = 6 + + 
mock_attn_metadata = MagicMock() + mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5]) + mock_attn_metadata.query_start_loc = torch.tensor([0, 2, 3, 6]) + mock_attn_metadata.block_tables = MagicMock() + self.proposer._get_eagle_atten_dict = MagicMock( + return_value={"layer_0": mock_attn_metadata}) + + result = self.proposer.generate_token_ids( + valid_sampled_token_ids=valid_sampled, + scheduler_output=scheduler_output, + positions=positions, + num_scheduled_tokens=num_scheduled, + hidden_states=hidden_states, + ) + + self.proposer._propose.assert_called_once() + self.assertEqual(result, [[1, 2], [3, 4], [5, 6]]) + + def test_generate_token_ids_with_metadata(self): + valid_sampled = [[5], [6, 7], [8, 9, 10]] + valid_sampled = [np.array(sublist) for sublist in valid_sampled] + spec_metadata = MagicMock() + spec_metadata.num_draft_tokens = [2, 3, 4] + + mock_attn_metadata = MagicMock() + mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5]) + mock_attn_metadata.query_start_loc = torch.tensor([0, 1, 3, 6]) + mock_attn_metadata.block_tables = MagicMock() + self.proposer._get_eagle_atten_dict = MagicMock( + return_value={"layer_0": mock_attn_metadata}) + self.proposer._prepare_inputs = MagicMock( + return_value=(torch.tensor([0, 2, 5]), torch.tensor([1, 3, 5]))) + + result = self.proposer.generate_token_ids( + valid_sampled_token_ids=valid_sampled, + spec_decode_metadata=spec_metadata, + positions=torch.randn(6, 1), + hidden_states=torch.randn(6, 4096), + ) + + self.proposer._prepare_inputs.assert_called_once() + self.assertEqual(self.proposer._propose.call_count, 1) + self.assertEqual(len(result), 3) + + +class TestEagleProposerHelperMethods(TestBase): + + def setUp(self): + self.vllm_config = MagicMock(spec=VllmConfig) + self.vllm_config.scheduler_config = MagicMock(max_num_seqs=3) + self.device = torch.device("cpu") + self.runner = MagicMock() + self.runner.input_batch = MagicMock() + self.runner.input_batch.req_ids = [0, 1, 2] + self.runner.arange_np = np.arange(10) + self.runner.input_batch.num_reqs = 3 + + self.vllm_config.cache_config.block_size = 16 + self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 + self.vllm_config.scheduler_config.max_num_seqs = 32 + self.vllm_config.model_config.dtype = torch.float16 + self.vllm_config.model_config.max_model_len = 2048 + + self.proposer = EagleProposer(vllm_config=self.vllm_config, + device=self.device, + runner=self.runner) + + def test_prepare_inputs(self): + self.proposer.token_arange_np = np.arange(10) + mock_attn = MagicMock() + mock_attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5]) + num_rejected = torch.tensor([1, 0, 1], device=self.device) + + with patch.object(self.proposer, + '_prepare_inputs', + return_value=(torch.tensor([0, 2, 5]), + torch.tensor([1, 2, 4]))): + cu_num_tokens, indices = self.proposer._prepare_inputs( + mock_attn, num_rejected) + self.assertEqual(cu_num_tokens.tolist(), [0, 2, 5]) + self.assertEqual(indices.tolist(), [1, 2, 4]) diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 798cf14af8e..5fe5cde3e80 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -9,7 +9,8 @@ from tests.ut.base import TestBase from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, AscendDeviceType +from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, + COMPRESSED_TENSORS_METHOD, AscendDeviceType) class TestNPUPlatform(TestBase): @@ -47,8 +48,9 @@ def test_class_variables(self): 
self.assertEqual(NPUPlatform.device_control_env_var, "ASCEND_RT_VISIBLE_DEVICES") self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1") - self.assertEqual(NPUPlatform.supported_quantization, - [ASCEND_QUANTIZATION_METHOD]) + self.assertEqual( + NPUPlatform.supported_quantization, + [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD]) def test_is_sleep_mode_available(self): self.assertTrue(self.platform.is_sleep_mode_available()) diff --git a/tests/ut/torchair/test_torchair_worker.py b/tests/ut/torchair/test_torchair_worker.py index 32d5a92e655..0397aee17c7 100644 --- a/tests/ut/torchair/test_torchair_worker.py +++ b/tests/ut/torchair/test_torchair_worker.py @@ -59,6 +59,7 @@ def test_init_device(self, mock_platform, mock_init_dist_env): worker.vllm_config = MagicMock() worker.parallel_config = MagicMock() worker.parallel_config.local_world_size = 0 + worker.parallel_config.data_parallel_size = 1 result = worker._init_device() @@ -93,6 +94,7 @@ def test_init_device_torchair_worker(self, mock_platform, worker.vllm_config = MagicMock() worker.parallel_config = MagicMock() worker.parallel_config.local_world_size = 0 + worker.parallel_config.data_parallel_size = 1 result = worker._init_device() diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index fbc7fdc4299..5a12981a370 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -329,6 +329,8 @@ def test_init_device(self, mock_platform, mock_init_dist_env): worker.model_config = MagicMock() worker.parallel_config = MagicMock() worker.parallel_config.local_world_size = 0 + worker.parallel_config.data_parallel_size = 1 + worker.model_config.seed = 42 # Test _init_device diff --git a/vllm_ascend/_cann_ops_custom/.gitkeep b/vllm_ascend/_cann_ops_custom/.gitkeep new file mode 100644 index 00000000000..df36e2ec719 --- /dev/null +++ b/vllm_ascend/_cann_ops_custom/.gitkeep @@ -0,0 +1,3 @@ +# This folder is reserved for the installation of custom aclnn operators tailored for vLLM-Ascend. +# Source code of the operators can be found in the `src` folder. +# The operators are compiled into a custom CANN software package and installed to this folder automatically. 
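The worker unit tests above now pin `parallel_config.data_parallel_size = 1` and `model_config.seed = 42` on their mocked configs, presumably because the updated `_init_device` path reads those fields and a bare `MagicMock` attribute does not behave like the integer it stands in for. A minimal sketch of why the pin is needed (illustrative only, not the worker code itself):

```python
from unittest.mock import MagicMock

cfg = MagicMock()
# An unpinned attribute is itself a MagicMock, so arithmetic on it yields
# another MagicMock rather than a number:
print(type(cfg.data_parallel_size))        # <class 'unittest.mock.MagicMock'>
print(type(8 // cfg.data_parallel_size))   # still a MagicMock, not an int

# Pinning concrete values makes the mocked config usable wherever the code
# under test does real arithmetic or hands the value to a strict API:
cfg.data_parallel_size = 1
cfg.seed = 42
assert 8 // cfg.data_parallel_size == 8
```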
diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index 3514984d826..2c963b5ce28 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -67,8 +67,6 @@ def get_mask_scale_factor(dtype: torch.dtype = torch.float16): def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, device: torch.device): - if max_seq_len == 2048: - return self.chunked_prefill_attn_mask.to(torch.bool) self._update_attn_cache(max_seq_len, dtype) return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( ).to(device, non_blocking=True) diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 0915b38a519..04195d1cc5b 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -31,8 +31,13 @@ def register_connector(): KVConnectorFactory.register_connector( "MooncakeConnectorStoreV1", - "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1", - "MooncakeConnectorV1") + "vllm_ascend.distributed.kvpool.ascend_store_connector", + "AscendStoreConnector") + + KVConnectorFactory.register_connector( + "AscendStoreConnector", + "vllm_ascend.distributed.kvpool.ascend_store_connector", + "AscendStoreConnector") KVConnectorFactory.register_connector( "MooncakeLayerwiseConnector", diff --git a/vllm_ascend/distributed/cpu_offload_connector.py b/vllm_ascend/distributed/cpu_offload_connector.py index 2e91f715232..c6983b69e23 100644 --- a/vllm_ascend/distributed/cpu_offload_connector.py +++ b/vllm_ascend/distributed/cpu_offload_connector.py @@ -29,6 +29,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request @@ -58,7 +59,10 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata): class CPUOffloadingConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): if not vllm_config.cache_config.enable_prefix_caching: self.connector_scheduler: Optional[ CPUOffloadingConnectorScheduler] = None diff --git a/vllm_ascend/distributed/kvpool/__init__.py b/vllm_ascend/distributed/kvpool/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm_ascend/distributed/kvpool/ascend_store_connector.py b/vllm_ascend/distributed/kvpool/ascend_store_connector.py new file mode 100644 index 00000000000..9f4833555db --- /dev/null +++ b/vllm_ascend/distributed/kvpool/ascend_store_connector.py @@ -0,0 +1,194 @@ +import threading +from typing import Any, Optional + +import torch +import zmq +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.forward_context import ForwardContext +from vllm.utils import logger +from vllm.utils.network_utils import make_zmq_socket +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.request import Request +from vllm.v1.serial_utils import MsgpackDecoder + +from 
vllm_ascend.distributed.kvpool.pool_scheduler import ( + KVPoolScheduler, get_zmq_rpc_path_lookup) +from vllm_ascend.distributed.kvpool.pool_worker import KVPoolWorker + + +class AscendStoreConnector(KVConnectorBase_V1): + + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): + super().__init__(vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config) + self.kv_role = vllm_config.kv_transfer_config.kv_role + + self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "use_layerwise", False) + + connector_name = vllm_config.kv_transfer_config.kv_connector + if connector_name == "MooncakeConnectorStoreV1": + logger.warning( + "It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future." + ) + + self.kv_caches: dict[str, torch.Tensor] = {} + + self._block_size = vllm_config.cache_config.block_size + + self.sended_but_unfinished_reqs: set[str] = set() + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = KVPoolScheduler(vllm_config, + self.use_layerwise) + else: + self.connector_worker = KVPoolWorker( + vllm_config, + self.use_layerwise, + ) + + assert self.connector_worker is not None + if vllm_config.parallel_config.rank == 0: + self.lookup_server = LookupKeyServer(self.connector_worker, + vllm_config, + self.use_layerwise) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + self.connector_worker.start_load_kv(self._get_connector_metadata()) + + def wait_for_layer_load(self, layer_name: str) -> None: + if not self.use_layerwise: + return + self.connector_worker.wait_for_layer_load() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + if not self.use_layerwise: + return + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + self.connector_worker.save_kv_layer(self._get_connector_metadata()) + + def 
wait_for_save(self): + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + + if self.use_layerwise: + return + + self.connector_worker.wait_for_save(self._get_connector_metadata()) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + meta = self._get_connector_metadata() + done_sending, done_recving = self.connector_worker.get_finished() + sended_and_finished: set[str] = set() + for item in list(self.sended_but_unfinished_reqs): + if item not in meta.unfinished_request_ids: + sended_and_finished.add(item) + self.sended_but_unfinished_reqs.remove(item) + for item in done_sending: + if item in meta.unfinished_request_ids: + self.sended_but_unfinished_reqs.add(item) + else: + sended_and_finished.add(item) + + return sended_and_finished, done_recving + + +class LookupKeyServer: + + def __init__( + self, + pool_worker: KVPoolWorker, + vllm_config: "VllmConfig", + use_layerwise: bool, + ): + self.decoder = MsgpackDecoder() + self.decoder_tensor = MsgpackDecoder(torch.Tensor) + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_lookup(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REP, # type: ignore[attr-defined] + bind=True, + ) + + self.pool_worker = pool_worker + self.running = True + self.use_layerwise = use_layerwise + + def process_request(): + while self.running: + all_frames = self.socket.recv_multipart(copy=False) + token_len = int.from_bytes(all_frames[0], byteorder="big") + hash_frames = all_frames[1:] + hashes_str = self.decoder.decode(hash_frames) + result = self.pool_worker.lookup_scheduler( + token_len, hashes_str, self.use_layerwise) + response = result.to_bytes(4, "big") + self.socket.send(response) + + self.thread = threading.Thread(target=process_request, daemon=True) + self.thread.start() + + def close(self): + self.socket.close(linger=0) + # TODO: close the thread! 
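`get_finished` above defers reporting a request as sent while it still appears in the metadata's `unfinished_request_ids`, and only surfaces it on a later call once the scheduler has dropped it from that set. A small self-contained sketch of the same bookkeeping (plain Python; the function name and arguments are illustrative and not part of the connector API):

```python
def reconcile_finished(done_sending: set[str],
                       unfinished: set[str],
                       pending: set[str]) -> set[str]:
    """Mirror the AscendStoreConnector.get_finished() bookkeeping.

    `pending` holds requests whose KV save completed while the request was
    still marked unfinished; they are reported only after they disappear
    from `unfinished`.
    """
    reported: set[str] = set()
    # Flush previously deferred requests that have since finished.
    for req_id in list(pending):
        if req_id not in unfinished:
            reported.add(req_id)
            pending.remove(req_id)
    # Classify newly completed sends.
    for req_id in done_sending:
        if req_id in unfinished:
            pending.add(req_id)
        else:
            reported.add(req_id)
    return reported


pending: set[str] = set()
# Step 1: "r1" finished sending but the request is still marked unfinished,
# so nothing is reported yet.
assert reconcile_finished({"r1"}, {"r1"}, pending) == set()
# Step 2: on a later call "r1" is no longer unfinished, so it is reported.
assert reconcile_finished(set(), set(), pending) == {"r1"}
```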
diff --git a/vllm_ascend/distributed/kvpool/backend/__init__.py b/vllm_ascend/distributed/kvpool/backend/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/backend/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm_ascend/distributed/kvpool/backend/backend.py b/vllm_ascend/distributed/kvpool/backend/backend.py new file mode 100644 index 00000000000..3aeccbf352c --- /dev/null +++ b/vllm_ascend/distributed/kvpool/backend/backend.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +from vllm.config import ParallelConfig + + +class Backend(ABC): + + def __init__(self, parallel_config: ParallelConfig): + pass + + def set_device(self): + pass + + def register_buffer(self, ptrs: list[int], lengths: list[int]): + pass + + @abstractmethod + def exists(self, keys: list[str]) -> list[int]: + pass + + @abstractmethod + def put(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]]): + pass + + @abstractmethod + def get(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]]): + pass diff --git a/vllm_ascend/distributed/kvpool/backend/memcache_backend.py b/vllm_ascend/distributed/kvpool/backend/memcache_backend.py new file mode 100644 index 00000000000..0da6d092c4f --- /dev/null +++ b/vllm_ascend/distributed/kvpool/backend/memcache_backend.py @@ -0,0 +1,74 @@ +# Standard +from enum import Enum + +import torch +from vllm.config import ParallelConfig +from vllm.utils import logger + +from vllm_ascend.distributed.kvpool.backend.backend import Backend + + +class MmcDirect(Enum): + COPY_L2G = 0 + COPY_G2L = 1 + COPY_G2H = 2 + COPY_H2G = 3 + + +class MemcacheBackend(Backend): + + def __init__(self, parallel_config: ParallelConfig): + try: + from memcache import DistributedObjectStore # type: ignore + except ImportError as e: + raise ImportError( + "Please install memcache by following the instructions at " + "https://gitee.com/ascend/memfabric_hybrid " # noqa: E501 + "to run vLLM with MemcacheConnector.") from e + try: + self.rank = parallel_config.rank + self.store = DistributedObjectStore() + res = self.store.init(self.rank) + assert res == 0 + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + + def set_device(self): + device = torch.device(f"npu:{self.rank}") + torch.npu.set_device(device) + + def register_buffer(self, ptrs: list[int], sizes: list[int]): + for ptr, size in zip(ptrs, sizes): + ret_value = self.store.register_buffer(ptr, size) + if ret_value != 0: + raise RuntimeError("Memcache memory registration failed.") + + def exists(self, keys: list[str]) -> list[int]: + return self.store.batch_is_exist(keys) + + def get(self, key: list[str], addr: list[list[int]], + size: list[list[int]]): + try: + res = self.store.batch_get_into_layers(key, addr, size, + MmcDirect.COPY_G2L.value) + for value in res: + if value != 0: + logger.error(f"Failed to get key {key},res:{res}") + except Exception as e: + logger.error(f"Failed to get key {key}. 
{e}") + + def put(self, key: list[str], addr: list[list[int]], + size: list[list[int]]): + try: + res = self.store.batch_put_from_layers(key, addr, size, + MmcDirect.COPY_L2G.value) + for value in res: + if value != 0: + logger.error(f"Failed to get key {key},res:{res}") + except Exception as e: + logger.error(f"Failed to put key {key},error:{e}") diff --git a/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py b/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py new file mode 100644 index 00000000000..314c4dcc9b4 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py @@ -0,0 +1,188 @@ +# Standard +import json +import os +import re +from dataclasses import dataclass +from typing import Union + +# Third Party +from vllm.config import ParallelConfig +from vllm.utils import logger +from vllm.utils.network_utils import get_ip + +from vllm_ascend.distributed.kvpool.backend.backend import Backend +from vllm_ascend.distributed.mooncake_transfer_engine import global_te + +DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB + + +class MooncakeBackend(Backend): + + def __init__(self, parallel_config: ParallelConfig): + try: + from mooncake.store import MooncakeDistributedStore # type: ignore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + self.config = MooncakeStoreConfig.load_from_env() + self.store = MooncakeDistributedStore() + if self.config.protocol == "ascend": + local_hostname = get_ip() + transfer_engine = global_te.get_transfer_engine(local_hostname, + device_name=None) + self.local_seg = local_hostname + ":" + str( + transfer_engine.get_rpc_port()) + ret = self.store.setup(self.local_seg, self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + transfer_engine.get_engine()) + if ret != 0: + msg = "Initialize mooncake failed." 
+ logger.error(msg) + raise RuntimeError(msg) + + def register_buffer(self, ptrs: list[int], lengths: list[int]): + global_te.register_buffer(ptrs, lengths) + + def exists(self, keys: list[str]) -> list[int]: + return self.store.batch_is_exist(keys) + + def put(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]]): + try: + res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes) + for value in res: + if value < 0: + logger.error(f"Failed to put key {keys},res:{res}") + except Exception as e: + logger.error(f"Failed to put key {keys},error:{e}") + + def get(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]]): + try: + res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes) + for value in res: + if value < 0: + logger.error(f"Failed to get key {keys}, res:{res}") + except Exception as e: + logger.error(f"Failed to get key {keys}, error:{e}") + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: Union[int, str] + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + use_ascend_direct: bool + + @staticmethod + def from_file(file_path: str) -> "MooncakeStoreConfig": + with open(file_path) as file: + config = json.load(file) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=_parse_global_segment_size( + config.get("global_segment_size", + DEFAULT_GLOBAL_SEGMENT_SIZE)), + local_buffer_size=(config.get("local_buffer_size", + DEFAULT_LOCAL_BUFFER_SIZE)), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address"), + use_ascend_direct=config.get("use_ascend_direct", False)) + + @staticmethod + def load_from_env() -> "MooncakeStoreConfig": + config_path = os.getenv("MOONCAKE_CONFIG_PATH") + if not config_path: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_path) + + +def _parse_global_segment_size(value) -> int: + """ + Parse storage size strings with support for units: GB, MB, KB, B + + Args: + value: Input value (int, str, or other convertible types) + + Returns: + int: Size in bytes + + Raises: + ValueError: For invalid format, missing number, or negative values + TypeError: For unsupported input types + """ + + if isinstance(value, int): + return value + elif not isinstance(value, str): + try: + return int(value) + except (TypeError, ValueError) as e: + raise TypeError( + f"Unsupported type for global_segment_size: {type(value)}" + ) from e + + cleaned_input = value.strip().lower() + if not cleaned_input: + raise ValueError("global segment size cannot be empty.") + + UNIT_MULTIPLIERS = { + 'gb': 1024**3, # 1 GB = 1024^3 bytes + 'mb': 1024**2, # 1 MB = 1024^2 bytes + 'kb': 1024, # 1 KB = 1024 bytes + 'b': 1 # 1 B = 1 byte + } + pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$' + match = re.match(pattern, cleaned_input) + + if not match: + raise ValueError(f"Invalid format: '{value}'") + + number_str = match.group(1) + unit = match.group(2) or 'b' + + multiplier = UNIT_MULTIPLIERS[unit] + return _convert_to_bytes(number_str, multiplier, value) + + +def _convert_to_bytes(number_str: str, multiplier: int, + original_input: str) -> int: + """ + Convert numeric string to byte count + + Args: + number_str: Numeric portion of input + multiplier: Unit conversion factor + original_input: Original 
input string (for error messages) + + Returns: + int: Byte count + + Raises: + ValueError: For invalid numbers or negative results + """ + try: + numeric_value = float(number_str) + except ValueError: + raise ValueError( + f"Invalid numeric value '{number_str}' in: '{original_input}'") + # Calculate byte count + try: + byte_count = int(numeric_value * multiplier) + except OverflowError: + raise ValueError(f"Storage size too large: '{original_input}'") + return byte_count diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py new file mode 100644 index 00000000000..e3b0873d686 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/config_data.py @@ -0,0 +1,364 @@ +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple, Union + +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.utils import logger +from vllm.utils.math_utils import cdiv +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.core.sched.output import NewRequestData + + +#Parameters related to the key +@dataclass +class KeyMetadata: + """name of the LLM model""" + + model_name: str + """ worker id when running under a distributed setting """ + head_or_tp_rank: int + + +@dataclass(order=True) +class PoolKey: + key_metadata: KeyMetadata + chunk_hash: str + + def __hash__(self): + return hash(( + self.key_metadata.model_name, + self.key_metadata.head_or_tp_rank, + self.chunk_hash, + )) + + def to_string(self): + return ( + f"{self.key_metadata.model_name}" + f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}" + ) + + def split_layers(self, num_layers: int) -> List["LayerPoolKey"]: + """Split the key into multiple keys for each layer""" + keys = [] + for layer_id in range(num_layers): + keys.append( + LayerPoolKey( + self.key_metadata, + self.chunk_hash, + layer_id, + )) + return keys + + +@dataclass(order=True) +class LayerPoolKey(PoolKey): + """A key for the layer cache engine""" + + layer_id: int + + def __hash__(self): + return hash(( + self.key_metadata.model_name, + self.key_metadata.head_or_tp_rank, + self.chunk_hash, + self.layer_id, + )) + + def to_string(self): + return ( + f"{self.key_metadata.model_name}" + f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}" + ) + + +class ChunkedTokenDatabase(): + + def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool): + self.metadata = metadata + self.block_size = block_size + self.use_mla = use_mla + self.kv_caches_base_addr: list[int] = [] + self.block_len: list[int] = [] + + def _make_key_by_hash(self, + chunk_hash: str, + layer_id: Optional[int] = None): + assert self.metadata is not None + return PoolKey( + self.metadata, + chunk_hash, + ) + + def set_kv_caches_base_addr(self, kv_caches_base_addr: list[int]): + self.kv_caches_base_addr = kv_caches_base_addr + + def set_block_len(self, block_len: list[int]): + self.block_len = block_len + + def prepare_value(self, start: int, end: int, block_ids: list[int]): + addr_list = [] + size_list = [] + block_id = block_ids[start // self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): + block_len = (self.block_len[index % 2] + if self.use_mla else self.block_len[0]) + + addr = base_addr + block_id * block_len + length = int(block_len / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(length) + return addr_list, size_list, block_id + + def prepare_value_layer(self, start: 
int, end: int, block_ids: list[int], + layer_id: int): + block_id = block_ids[start // self.block_size] + if self.use_mla: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[1] + length_k = int(self.block_len[0] / self.block_size * (end - start)) + length_v = int(self.block_len[1] / self.block_size * (end - start)) + size_list = [length_k, length_v] + else: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[0] + length = int(self.block_len[0] / self.block_size * (end - start)) + size_list = [length, length] + addr_list = [addr_k, addr_v] + return addr_list, size_list + + def process_tokens( + self, + token_len: int, + block_hashes: Union[list[BlockHash], list[str]], + mask_num: int = 0, + ) -> Iterable[Tuple[int, int, PoolKey]]: + """Process the tokens and return the corresponding cache engine keys. + + :param Union[torch.Tensor, List[int]] tokens: The tokens to process. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched, + and the Falses will ALWAYS be at the PREFIX of the tensor. + + :param bool make_key: Whether to make the cache engine key or not. + If False, the hash value will be returned instead. + + :returns: A iterable of tuples with three elements. The first element + is the start index of the tokens for the key. The second element + is the end index of the tokens for the key. The third element is + the cache engine key (or hash) for the tokens. + + :raises: ValueError if the number of Falses in the mask is not a + multiple of the chunk size. + """ + if not block_hashes: + return + if not isinstance(block_hashes[0], str): + block_hashes = [ + h.hex() # type: ignore[union-attr] + for h in block_hashes + ] + start_idx = 0 + for chunk_id, hash_val in enumerate(block_hashes): + start_idx = chunk_id * self.block_size + if start_idx >= token_len: + break + end_idx = min(start_idx + self.block_size, token_len) + if start_idx < mask_num: + continue + else: + yield start_idx, end_idx, self._make_key_by_hash(hash_val) + + +#Parameters related to the connector metadata +@dataclass +class LoadSpec: + # Number of tokens cached in vLLM + vllm_cached_tokens: int + # Number of tokens that are cached in kvpool + kvpool_cached_tokens: int + # Whether the scheduler allow us to load the tokens + can_load: bool + + +@dataclass +class RequestTracker: + # Request id + req_id: str + + # The token ids that has been scheduled so far + token_len: int + + # The block ids that has been allocated so far + # NOTE: allocated blocks could be more than the number of tokens + # FIXME: need to check whether the block ids will be changed after + # preemption + allocated_block_ids: list[int] + + # The number of tokens that has been savd + num_saved_tokens: int = 0 + + @staticmethod + def from_new_request( + new_request: "NewRequestData", + num_tokens_to_compute: int, + ) -> "RequestTracker": + """Create the request tracker from a new request. + + Args: + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. 
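+
+        A minimal illustrative sketch (not part of this module; `new_req`,
+        `new_token_ids` and `new_block_ids` are hypothetical locals):
+
+            tracker = RequestTracker.from_new_request(new_req, 256)
+            # later scheduling steps extend the tracker incrementally
+            tracker.update(new_token_ids, new_block_ids)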
+ + """ + unfolded_block_ids = [] + + if not isinstance(new_request.block_ids[0], list): + unfolded_block_ids = new_request.block_ids.copy() + else: + unfolded_block_ids = new_request.block_ids[0].copy() + + return RequestTracker( + req_id=new_request.req_id, + token_len=num_tokens_to_compute, + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) + + def update( + self, + new_token_ids: list[int], + new_block_ids: Union[tuple[list[int], ...], list[int]], + ) -> None: + """Update the request tracker when a running request is + scheduled again + """ + + self.token_len = self.token_len + len(new_token_ids) + + if len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") + self.allocated_block_ids.extend(new_block_ids) + + +@dataclass +class ReqMeta: + # Request id + req_id: str + # Request tokens + token_len_chunk: int + + block_ids: list[int] + + block_hashes: list[BlockHash] + + can_save: Optional[bool] = None + # load_spec + load_spec: Optional[LoadSpec] = None + + is_last_chunk: Optional[bool] = None + + @staticmethod + def from_request_tracker( + tracker: RequestTracker, + block_size: int, + load_spec: Optional[LoadSpec] = None, + skip_save: Optional[bool] = False, + block_hashes: list[BlockHash] = [], + is_last_chunk: Optional[bool] = None, + discard_partial_chunks: bool = True, + ) -> Optional["ReqMeta"]: + """Create the request metadata from a request tracker. + + Args: + tracker (RequestTracker): the request tracker. + block_size (int): the block size in vLLM. + load_spec (Optional[LoadSpec]): the load spec for KV cache loading. + skip_save (bool): whether to skip the save operation. + discard_partial_chunks (bool): whether to discard partial chunks. + + Returns: + the request metadata if we need to perform load/save + operations, None otherwise. + """ + input_token_len = tracker.token_len + + # For save operation: do not save if the following condition is met + # 1. has already been saved before (num_saved_tokens > 0) + # 2. 
number of unsaved tokens is not reached the chunk boundary + chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * + block_size if discard_partial_chunks else 0) + # Calculate number of tokens to save based on discard_partial_chunks + # setting + num_tokens_to_save = ((input_token_len // block_size * block_size) + if discard_partial_chunks else input_token_len) + + skip_save = skip_save or num_tokens_to_save < chunk_boundary + if skip_save and load_spec is None: + return None + + # If we need to save, update the number of saved tokens + if not skip_save: + tracker.num_saved_tokens = num_tokens_to_save + + # # For load operation: check whether the request is scheduled to load + if load_spec is not None and load_spec.can_load: + logger.debug( + "Scheduled to load %d tokens for request %s", + load_spec.kvpool_cached_tokens, + tracker.req_id, + ) + else: + # Do not load if not in `can_load` state + load_spec = None + logger.debug( + f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}" + ) + return ReqMeta( + req_id=tracker.req_id, + token_len_chunk=num_tokens_to_save, + block_ids=tracker.allocated_block_ids, + can_save=not skip_save, + load_spec=load_spec, + block_hashes=block_hashes, + is_last_chunk=is_last_chunk, + ) + + +class AscendConnectorMetadata(KVConnectorMetadata): + + def __init__(self, unfinished_request_ids): + self.requests = [] + self.unfinished_request_ids = unfinished_request_ids + + def add_request(self, req_meta: ReqMeta) -> None: + """Add a request to the metadata. + + Args: + req_meta (ReqMeta): the request metadata. + """ + self.requests.append(req_meta) + + +@dataclass +class LasyerMultiBlockReqMeta: + req_id: str + keys: List[LayerPoolKey] + starts: List[int] + ends: list[int] + block_ids: list[int] + layer_id: int + is_last_chunk: bool = True diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py new file mode 100644 index 00000000000..b30158ae8c2 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/kv_transfer.py @@ -0,0 +1,246 @@ +import queue +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional + +import torch +from vllm.utils import logger +from vllm.v1.core.kv_cache_utils import BlockHash + +from vllm_ascend.distributed.kvpool.backend.backend import Backend + +# isort: off +from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase, + LasyerMultiBlockReqMeta + ) +# isort: on + + +class KVTransferThread(threading.Thread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, ready_event: threading.Event, name: str): + super().__init__(daemon=True, name=name) + self.m_store = m_store + self.ready_event = ready_event + self.tp_rank = tp_rank + self.token_database = token_database + self.done_task_lock = threading.Lock() + self.request_queue: queue.Queue[Any] = queue.Queue() + # TODO(jianzs): make this configurable + self.executor = ThreadPoolExecutor(max_workers=32) + self.finished_requests: set[str] = set() + + def add_request( + self, + req_id: str, + token_len: int, + block_ids: list[int], + block_hashes: list[BlockHash], + mask_num: int = 0, + is_last_chunk: Optional[bool] = None, + ) -> torch.Tensor: + req = ({ + "req_id": req_id, + "token_len": token_len, + "block_ids": block_ids, + "block_hashes": block_hashes, + "mask_num": mask_num, + "is_last_chunk": is_last_chunk, + }) + self.request_queue.put(req) + + def get_and_clear_finished_requests(self) -> 
set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. + """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + self.finished_requests.clear() + return finished_requests + + def set_finished_request(self, req_id): + with self.done_task_lock: + self.finished_requests.add(req_id) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.m_store.set_device() + self.ready_event.set() + while True: + try: + request_data = self.request_queue.get() + if request_data is None: + logger.warning("Received a None request!") + self.request_queue.task_done() + continue + self._handle_request(request_data) + except Exception as e: + logger.error(f"Error in KVCacheTransferThread: {e}") + + def _handle_request(self, req_meta: dict[str, Any]): + pass + + +class KVCacheStoreSendingThread(KVTransferThread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, put_step: int, ready_event: threading.Event): + super().__init__(m_store, + token_database, + tp_rank, + ready_event, + name="KVCacheSendingThread") + self.put_step = put_step + + def _handle_request(self, req_meta: dict[str, Any]): + token_len = req_meta["token_len"] + mask_num = req_meta["mask_num"] + block_ids = req_meta["block_ids"] + block_hashes = req_meta["block_hashes"] + req_id = req_meta["req_id"] + is_last_chunk = req_meta["is_last_chunk"] + addr_list = [] + size_list = [] + key_list = [] + for start, end, key in self.token_database.process_tokens( + token_len, block_hashes, mask_num): + addr, size, _ = self.token_database.prepare_value( + start, end, block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] + if key_list_tp: + torch.npu.current_stream().synchronize() + self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) + if is_last_chunk: + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreRecvingThread(KVTransferThread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, ready_event: threading.Event): + super().__init__(m_store, + token_database, + tp_rank, + ready_event, + name="KVCacheStoreRecvingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + token_len = req_meta["token_len"] + mask_num = req_meta["mask_num"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + block_hashes = req_meta["block_hashes"] + addr_list = [] + size_list = [] + key_list = [] + for start, end, key in self.token_database.process_tokens( + token_len, block_hashes, mask_num): + addr, size, _ = self.token_database.prepare_value( + start, end, block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_c = key_list[self.tp_rank % + len(key_list):] + key_list[:self.tp_rank % + len(key_list)] + addr_list_c = addr_list[self.tp_rank % + len(addr_list):] + addr_list[:self.tp_rank % + len(addr_list)] + size_list_c = size_list[self.tp_rank % + len(size_list):] + size_list[:self.tp_rank % + len(size_list)] + self.m_store.get(key_list_c, addr_list_c, size_list_c) + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class 
KVCacheStoreLayerSendingThread(KVTransferThread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, put_step: int, ready_event: threading.Event, + num_layers: int): + super().__init__(m_store, + token_database, + tp_rank, + ready_event, + name="KVCacheStoreLayerSendingThread") + self.final_layer_id = num_layers - 1 + self.put_step = put_step + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + addr_list = [] + size_list = [] + key_list = [] + for index, key in enumerate(req_meta.keys): + addr, size = self.token_database.prepare_value_layer( + req_meta.starts[index], req_meta.ends[index], + req_meta.block_ids, req_meta.layer_id) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] + if key_list_tp: + torch.npu.current_stream().synchronize() + self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) + if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk: + self.set_finished_request(req_meta.req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerRecvingThread(KVTransferThread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, ready_event: threading.Event, + get_event: threading.Event): + super().__init__(m_store, + token_database, + tp_rank, + ready_event, + name="KVCacheStoreLayerRecvingThread") + self.get_event = get_event + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + addr_list = [] + size_list = [] + key_list = [] + for index, key in enumerate(req_meta.keys): + addr, size = self.token_database.prepare_value_layer( + req_meta.starts[index], req_meta.ends[index], + req_meta.block_ids, req_meta.layer_id) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_c = key_list[self.tp_rank % + len(key_list):] + key_list[:self.tp_rank % + len(key_list)] + addr_list_c = addr_list[self.tp_rank % + len(addr_list):] + addr_list[:self.tp_rank % + len(addr_list)] + size_list_c = size_list[self.tp_rank % + len(size_list):] + size_list[:self.tp_rank % + len(size_list)] + self.m_store.get(key_list_c, addr_list_c, size_list_c) + + self.request_queue.task_done() + self.get_event.set() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py similarity index 52% rename from vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py rename to vllm_ascend/distributed/kvpool/pool_scheduler.py index aad4dc6e9c3..06041b5a6e5 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py @@ -1,174 +1,33 @@ -import threading from typing import Any, Optional -import torch import vllm.envs as envs import zmq -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, 
KVConnectorMetadata, KVConnectorRole) -from vllm.forward_context import ForwardContext +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata from vllm.utils import logger from vllm.utils.network_utils import make_zmq_socket from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import Request -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm.v1.serial_utils import MsgpackEncoder -from vllm_ascend.distributed.mooncake.config_data import ( - LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker) -from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine +from vllm_ascend.distributed.kvpool.config_data import ( + AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker) -class MooncakeConnectorV1(KVConnectorBase_V1): - - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) - self.kv_role = vllm_config.kv_transfer_config.kv_role - - self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "use_layerwise", False) - - self.kv_caches: dict[str, torch.Tensor] = {} - - self._block_size = vllm_config.cache_config.block_size - - self.sended_but_unfinished_reqs: set[str] = set() - - if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = MooncakeStoreConnectorV1Scheduler( - vllm_config, self.use_layerwise) - else: - self.connector_worker = MooncakeEngine( - vllm_config, - self.use_layerwise, - ) - - assert self.connector_worker is not None - if vllm_config.parallel_config.rank == 0: - self.lookup_server = MooncakeLookupServer( - self.connector_worker, vllm_config, self.use_layerwise) - - ############################################################ - # Scheduler Side Methods - ############################################################ - - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: - assert self.connector_scheduler is not None - return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) - - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - assert self.connector_scheduler is not None - return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens) - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - assert self.connector_scheduler is not None - return self.connector_scheduler.build_connector_meta(scheduler_output) - - def request_finished( - self, - request: "Request", - block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: - assert self.connector_scheduler is not None - return self.connector_scheduler.request_finished(request, block_ids) - - ############################################################ - # Worker Side Methods - ############################################################ - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - assert self.connector_worker is not None - self.connector_worker.register_kv_caches(kv_caches) - - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - assert self.connector_worker is not None - assert isinstance(self._get_connector_metadata(), - MooncakeConnectorMetadata) - self.connector_worker.start_load_kv(self._get_connector_metadata()) - - 
def wait_for_layer_load(self, layer_name: str) -> None: - """MooncakeStoreConnector does not do layerwise saving.""" - if not self.use_layerwise: - return - self.connector_worker.wait_for_layer_load() - - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - """MooncakeStoreConnector does not save explicitly.""" - if not self.use_layerwise: - return - - if self.kv_role == "kv_consumer": - # Don't do save if the role is kv_consumer - return - self.connector_worker.save_kv_layer(self._get_connector_metadata()) - - def wait_for_save(self): - """MooncakeStoreConnector does not save explicitly.""" - if self.kv_role == "kv_consumer": - # Don't do save if the role is kv_consumer - return - - if self.use_layerwise: - self.connector_worker.wait_layer_transfer_finish() - return - - self.connector_worker.wait_for_save(self._get_connector_metadata()) - - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: - """Get the finished recving and sending requests.""" - assert self.connector_worker is not None - meta = self._get_connector_metadata() - done_sending, done_recving = self.connector_worker.get_finished() - sended_and_finished: set[str] = set() - for item in list(self.sended_but_unfinished_reqs): - if item not in meta.unfinished_request_ids: - sended_and_finished.add(item) - self.sended_but_unfinished_reqs.remove(item) - for item in done_sending: - if item in meta.unfinished_request_ids: - self.sended_but_unfinished_reqs.add(item) - else: - sended_and_finished.add(item) - - return sended_and_finished, done_recving - - -def get_zmq_rpc_path_mooncake( - vllm_config: Optional["VllmConfig"] = None, ) -> str: - base_url = envs.VLLM_RPC_BASE_PATH - # Default to 0 if not configured - rpc_port = 0 - if vllm_config is not None: - rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( - "mooncake_rpc_port", 0) - logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) - return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}" - - -class MooncakeStoreConnectorV1Scheduler: +class KVPoolScheduler: def __init__(self, vllm_config: "VllmConfig", use_layerwise): - self.client = MooncakeLookupClient(vllm_config) + self.client = LookupKeyClient(vllm_config) self.use_layerwise = use_layerwise self.kv_role = vllm_config.kv_transfer_config.kv_role self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "consumer_is_to_load", False) self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) - # request_id -> (vllm cached tokes, mooncake cached tokens) + # request_id -> (vllm cached tokes, kvpool cached tokens) self.load_specs: dict[str, LoadSpec] = {} self._block_size = vllm_config.cache_config.block_size # request_id -> full_token_ids @@ -201,14 +60,13 @@ def get_num_new_matched_tokens( return 0, False if self._discard_partial_chunks: - token_block_end = len(request.prompt_token_ids - ) // self._block_size * self._block_size - token_ids = torch.tensor( - request.prompt_token_ids[:token_block_end]) + token_len = len(request.prompt_token_ids + ) // self._block_size * self._block_size else: - token_ids = torch.tensor(request.prompt_token_ids) + token_len = len(request.prompt_token_ids) - num_external_hit_tokens = self.client.lookup(token_ids) + num_external_hit_tokens = self.client.lookup(token_len, + request.block_hashes) if num_external_hit_tokens == request.num_tokens: num_external_hit_tokens -= 1 @@ -216,7 +74,7 @@ def 
get_num_new_matched_tokens( need_to_allocate = num_external_hit_tokens - num_computed_tokens logger.info( - "Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d", + "Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d", request.request_id, request.num_tokens, num_external_hit_tokens, @@ -228,11 +86,11 @@ def get_num_new_matched_tokens( self.load_specs[request.request_id] = LoadSpec( vllm_cached_tokens=num_computed_tokens, - mooncake_cached_tokens=num_external_hit_tokens, + kvpool_cached_tokens=num_external_hit_tokens, can_load=False, ) - return need_to_allocate, self.load_async + return need_to_allocate, self.load_async and not self.use_layerwise def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", @@ -261,10 +119,10 @@ def update_state_after_alloc(self, request: "Request", assert ( num_external_tokens > 0 and num_external_tokens - == self.load_specs[request.request_id].mooncake_cached_tokens - + == self.load_specs[request.request_id].kvpool_cached_tokens - self.load_specs[request.request_id].vllm_cached_tokens ), (f"Mismatch in number of tokens: {num_external_tokens} vs " - f"{self.load_specs[request.request_id].mooncake_cached_tokens} - " + f"{self.load_specs[request.request_id].kvpool_cached_tokens} - " f"{self.load_specs[request.request_id].vllm_cached_tokens}" f" for request {request.request_id}") @@ -289,7 +147,7 @@ def build_connector_meta( self._unfinished_requests.pop(finished_req_id, None) self._unfinished_request_ids.discard(finished_req_id) - meta = MooncakeConnectorMetadata(self._unfinished_request_ids) + meta = AscendConnectorMetadata(self._unfinished_request_ids) for request in scheduler_output.scheduled_new_reqs: # Right now, we only load KV for new requests @@ -304,12 +162,15 @@ def build_connector_meta( self._block_size * self._block_size) if self._discard_partial_chunks else len( request.prompt_token_ids)) + request_tuple = self._unfinished_requests.get(request.req_id) + request_real = request_tuple[0] # type: ignore[index] req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, load_spec=load_spec, skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) + block_hashes=request_real.block_hashes, + is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) @@ -317,33 +178,14 @@ def build_connector_meta( meta.add_request(req_meta) cached_reqs = scheduler_output.scheduled_cached_reqs - if isinstance(cached_reqs, list) and not force_skip_save: - for i, req in enumerate(cached_reqs): - request_tracker = self._request_trackers[req.req_id] - request_tracker.update(req.new_token_ids, req.new_block_ids) - last_chunk_tokens_num = ((len(req.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else - len(req.prompt_token_ids)) - req_meta = ReqMeta.from_request_tracker( - request_tracker, - self._block_size, - load_spec=None, - skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) - >= last_chunk_tokens_num, - discard_partial_chunks=self._discard_partial_chunks, - ) - if req_meta is not None: - meta.add_request(req_meta) - elif not force_skip_save: + if not force_skip_save: for i, req_id in enumerate(cached_reqs.req_ids): request_tracker = self._request_trackers[req_id] num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] req_tuple = self._unfinished_requests.get(req_id) if req_tuple: request = req_tuple[0] - num_current_tokens = 
len(request_tracker.token_ids) + num_current_tokens = request_tracker.token_len new_token_ids = request.all_token_ids[ num_current_tokens:num_current_tokens + num_new_tokens] else: @@ -355,8 +197,7 @@ def build_connector_meta( continue request_tracker.update(new_token_ids, new_block_ids) # decode not save - if len(request_tracker.token_ids) > len( - request.prompt_token_ids): + if request_tracker.token_len > len(request.prompt_token_ids): continue last_chunk_tokens_num = ((len(request.prompt_token_ids) // @@ -368,7 +209,8 @@ def build_connector_meta( self._block_size, load_spec=None, skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) + block_hashes=request.block_hashes, + is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) @@ -384,15 +226,14 @@ def build_connector_meta( load_spec = self.load_specs.pop(request_id, None) if not load_spec: continue - num_tokens_to_compute = load_spec.mooncake_cached_tokens + num_tokens_to_compute = load_spec.kvpool_cached_tokens if (num_tokens_to_compute % self._block_size != 0) and (num_tokens_to_compute == len(request.prompt_token_ids) - 1): num_tokens_to_compute = num_tokens_to_compute + 1 request_tracker = RequestTracker( req_id=request_id, - token_ids=request.prompt_token_ids[:num_tokens_to_compute]. - copy(), + token_len=num_tokens_to_compute, allocated_block_ids=block_ids, num_saved_tokens=0, ) @@ -404,6 +245,7 @@ def build_connector_meta( self._block_size, load_spec=load_spec, skip_save=None, + block_hashes=request.block_hashes, discard_partial_chunks=self._discard_partial_chunks, ) if req_meta is not None: @@ -431,12 +273,12 @@ def request_finished( return delay_free_blocks, None -class MooncakeLookupClient: +class LookupKeyClient: def __init__(self, vllm_config: "VllmConfig"): self.encoder = MsgpackEncoder() self.ctx = zmq.Context() # type: ignore[attr-defined] - socket_path = get_zmq_rpc_path_mooncake(vllm_config) + socket_path = get_zmq_rpc_path_lookup(vllm_config) self.socket = make_zmq_socket( self.ctx, socket_path, @@ -444,9 +286,12 @@ def __init__(self, vllm_config: "VllmConfig"): bind=False, ) - def lookup(self, token_ids: torch.Tensor) -> int: - request = self.encoder.encode(token_ids) - self.socket.send_multipart(request, copy=False) + def lookup(self, token_len: int, block_hashes: list[BlockHash]) -> int: + hash_strs = [h.hex() for h in block_hashes] + hash_frames = self.encoder.encode(hash_strs) + token_len_bytes = token_len.to_bytes(4, byteorder="big") + all_frames = [token_len_bytes] + list(hash_frames) + self.socket.send_multipart(all_frames, copy=False) resp = self.socket.recv() result = int.from_bytes(resp, "big") return result @@ -455,39 +300,19 @@ def close(self): self.socket.close(linger=0) -class MooncakeLookupServer: - - def __init__( - self, - mooncake_engine: MooncakeEngine, - vllm_config: "VllmConfig", - use_layerwise: bool, - ): - self.decoder = MsgpackDecoder(torch.Tensor) - self.ctx = zmq.Context() # type: ignore[attr-defined] - socket_path = get_zmq_rpc_path_mooncake(vllm_config) - self.socket = make_zmq_socket( - self.ctx, - socket_path, - zmq.REP, # type: ignore[attr-defined] - bind=True, - ) - - self.mooncake_engine = mooncake_engine - self.running = True - - def process_request(): - while self.running: - frames = self.socket.recv_multipart(copy=False) - token_ids = self.decoder.decode(frames) - result = self.mooncake_engine.lookup_scheduler( - token_ids, use_layerwise) - response = result.to_bytes(4, "big") - 
self.socket.send(response) - - self.thread = threading.Thread(target=process_request, daemon=True) - self.thread.start() - - def close(self): - self.socket.close(linger=0) - # TODO: close the thread! +def get_zmq_rpc_path_lookup( + vllm_config: Optional["VllmConfig"] = None, ) -> str: + base_url = envs.VLLM_RPC_BASE_PATH + # Default to 0 if not configured + rpc_port = 0 + if vllm_config is not None: + extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config + if "lookup_rpc_port" in extra_config: + rpc_port = extra_config["lookup_rpc_port"] + elif "mooncake_rpc_port" in extra_config: + rpc_port = extra_config["mooncake_rpc_port"] + logger.warning( + "It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future." + ) + logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) + return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}" diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/kvpool/pool_worker.py similarity index 57% rename from vllm_ascend/distributed/mooncake/mooncake_engine.py rename to vllm_ascend/distributed/kvpool/pool_worker.py index 143d2c91cad..b03d2808928 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -1,25 +1,33 @@ # Standard import math import threading -import time -from typing import Generator, List, Optional, Union +from typing import Dict, Generator, Optional, Type # Third Party import torch from vllm.config import VllmConfig from vllm.utils import logger -from vllm.utils.torch_utils import get_kv_cache_torch_dtype - -from vllm_ascend.distributed.mooncake.config_data import ( - ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, - MooncakeEngineMetadata) -from vllm_ascend.distributed.mooncake.kv_transfer import ( +from vllm.v1.core.kv_cache_utils import BlockHash + +from vllm_ascend.distributed.kvpool.backend.backend import Backend +from vllm_ascend.distributed.kvpool.backend.memcache_backend import \ + MemcacheBackend +from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \ + MooncakeBackend +from vllm_ascend.distributed.kvpool.config_data import ( + AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata, + LasyerMultiBlockReqMeta) +from vllm_ascend.distributed.kvpool.kv_transfer import ( KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) -from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore + +backend_map: Dict[str, Type[Backend]] = { + "mooncake": MooncakeBackend, + "memcache": MemcacheBackend, +} -class MooncakeEngine: +class KVPoolWorker: #The main class for the cache engine. 
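+    # Worker-side engine overview: registers the KV-cache buffers with the
+    # selected backend ("mooncake" or "memcache"), starts the KV sending and
+    # receiving threads, and serves lookup/load/save of KV blocks keyed by
+    # their block hashes.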
def __init__( @@ -29,6 +37,7 @@ def __init__( ): model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config + self.dp_rank = parallel_config.data_parallel_rank self.use_mla = False if (hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) @@ -40,37 +49,37 @@ def __init__( self.kv_role = vllm_config.kv_transfer_config.kv_role self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) - self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "register_buffer", False) + self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "backend", "mooncake") self.block_size = vllm_config.cache_config.block_size self.current_layer = 0 - # self.use_mla = first_kv_cache_tuple[0].size( - # -1) != first_kv_cache_tuple[1].size(-1) self.num_layers = model_config.get_num_layers(parallel_config) self.block_size = vllm_config.cache_config.block_size - num_kv_head = model_config.get_num_kv_heads(parallel_config) - head_size = model_config.get_head_size() - kv_dtype = get_kv_cache_torch_dtype( - vllm_config.cache_config.cache_dtype, model_config.dtype) - self.hidden_dim_size = num_kv_head * head_size + if self.use_mla: - kv_shape = (self.num_layers, 1, self.block_size, 1, head_size) + self.num_kv_head = 1 + else: + self.num_kv_head = model_config.get_total_num_kv_heads() + + if self.num_kv_head < self.tp_size: + self.put_step = self.tp_size // self.num_kv_head + self.head_or_tp_rank = self.tp_rank // self.put_step else: - kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, - head_size) - self.metadata = MooncakeEngineMetadata( + self.head_or_tp_rank = self.tp_rank + self.put_step = 1 + + self.metadata = KeyMetadata( model_config.model, - parallel_config.world_size, - parallel_config.rank, - kv_dtype, - kv_shape, - self.block_size, - self.use_mla, + self.head_or_tp_rank, ) - self.token_database = ChunkedTokenDatabase(self.metadata) + self.token_database = ChunkedTokenDatabase(self.metadata, + self.block_size, + self.use_mla) - self.m_store = Mooncakestore(parallel_config) + real_backend = backend_map.get(self.backend.lower()) + self.m_store = real_backend( # type: ignore[misc] + parallel_config) self.kv_send_thread: Optional[KVTransferThread] = None self.kv_recv_thread: Optional[KVTransferThread] = None @@ -108,94 +117,83 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_caches = kv_caches self.kv_caches_base_addr = [] + ptrs = [] + lengths = [] for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches if self.use_mla: for i, cache in enumerate(cache_or_caches, 0): base_addr = cache.data_ptr() self.kv_caches_base_addr.append(base_addr) - if self.register_buffer: - region_len = self.num_blocks * self.block_len[i % 2] - self._register(base_addr, region_len) + region_len = self.num_blocks * self.block_len[i % 2] + ptrs.append(base_addr) + lengths.append(region_len) else: cache_list = [cache_or_caches ] if self.use_mla else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() self.kv_caches_base_addr.append(base_addr) - if self.register_buffer: - region_len = self.num_blocks * self.block_len[0] - self._register(base_addr, region_len) + region_len = self.num_blocks * self.block_len[0] + ptrs.append(base_addr) + lengths.append(region_len) + self.m_store.register_buffer(ptrs, lengths) + self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr) + 
self.token_database.set_block_len(self.block_len) if self.use_layerwise: self.get_event = threading.Event() if self.kv_role in ['kv_producer', 'kv_both']: ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreLayerSendingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event_sending, - self.num_layers) + self.m_store, self.token_database, self.tp_rank, + self.put_step, ready_event_sending, self.num_layers) self.kv_send_thread.start() ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreLayerRecvingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, - self.block_size, ready_event, self.get_event) + self.m_store, self.token_database, self.tp_rank, ready_event, + self.get_event) self.kv_recv_thread.start() ready_event.wait() else: if self.kv_role in ['kv_producer', 'kv_both']: ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreSendingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event_sending) + self.m_store, self.token_database, self.tp_rank, + self.put_step, ready_event_sending) self.kv_send_thread.start() if self.load_async: ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreRecvingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event) + self.m_store, self.token_database, self.tp_rank, + ready_event) self.kv_recv_thread.start() ready_event.wait() - def _register(self, ptr, length): - logger.debug( - "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " - "block_lens=%s", ptr, length, self.num_blocks, self.block_len) - try: - self.m_store.register_buffer(ptr, length) - except Exception as e: - raise RuntimeError( - f"Mooncake memory registration failed. 
Error is: {e}") - - def start_load_kv(self, metadata: MooncakeConnectorMetadata): + def start_load_kv(self, metadata: AscendConnectorMetadata): self.current_layer = 0 self.layerwise_retrievers = [] for request in metadata.requests: load_spec = request.load_spec if load_spec is None or not load_spec.can_load: #load =0 continue - tokens = request.token_ids + token_len = request.token_len_chunk req_id = request.req_id - if (load_spec.mooncake_cached_tokens % self.block_size - != 0) and (load_spec.mooncake_cached_tokens - == tokens.shape[0] - 1): - tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1] + if (load_spec.kvpool_cached_tokens % self.block_size + != 0) and (load_spec.kvpool_cached_tokens + == token_len - 1): + token_len = request.load_spec.kvpool_cached_tokens + 1 else: - tokens = tokens[:request.load_spec.mooncake_cached_tokens] - masked_token_count = (request.load_spec.vllm_cached_tokens // - self.block_size * self.block_size) - token_mask = torch.ones_like(tokens, dtype=torch.bool) - token_mask[:masked_token_count] = False + token_len = request.load_spec.kvpool_cached_tokens + mask_num = (request.load_spec.vllm_cached_tokens // + self.block_size * self.block_size) if self.use_layerwise: layerwise_retriever = self.retrieve_layer( req_id, - tokens, + token_len, request.block_ids, - token_mask, + request.block_hashes, + mask_num, ) next(layerwise_retriever) # first layer load self.layerwise_retrievers.append(layerwise_retriever) @@ -203,102 +201,84 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): if self.load_async: self.kv_recv_thread.add_request( # type: ignore[union-attr] req_id, - tokens, + token_len, request.block_ids, - token_mask, + request.block_hashes, + mask_num, ) else: - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, token_mask): - addr, size, block_id = self.prepare_value( - start, end, request.block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - self.m_store.get_batch(key_list, addr_list, size_list, - blockIds) - else: - for start, end, key in self.token_database.process_tokens( - tokens, token_mask): - addr, size, _ = self.prepare_value( - start, end, request.block_ids) - self.m_store.get(key, addr, size) - - def prepare_value(self, start: int, end: int, block_ids: list[int]): - addr_list = [] - size_list = [] - block_id = block_ids[start // self.block_size] - for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = (self.block_len[index % 2] - if self.use_mla else self.block_len[0]) - - addr = base_addr + block_id * block_len - length = int(block_len / self.block_size * (end - start)) - addr_list.append(addr) - size_list.append(length) - return addr_list, size_list, block_id + addr_list = [] + size_list = [] + key_list = [] + for start, end, key in self.token_database.process_tokens( + token_len, request.block_hashes, mask_num): + addr, size, _ = self.token_database.prepare_value( + start, end, request.block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_c = key_list[self.tp_rank % len( + key_list):] + key_list[:self.tp_rank % len(key_list)] + addr_list_c = addr_list[self.tp_rank % + len(addr_list + ):] + addr_list[:self.tp_rank % + len(addr_list)] + size_list_c = size_list[self.tp_rank % + len(size_list + ):] + size_list[:self.tp_rank % + len(size_list)] + 
self.m_store.get(key_list_c, addr_list_c, size_list_c) def wait_for_layer_load(self) -> None: - """MooncakeConnector does not do layerwise saving.""" for layerwise_retriever in self.layerwise_retrievers: ret_token_mask = next(layerwise_retriever) if self.current_layer == self.num_layers - 1: assert ret_token_mask is not None num_retrieved_tokens = ret_token_mask.sum().item() - logger.info(f"Retrieved {num_retrieved_tokens} tokens") + logger.debug(f"Retrieved {num_retrieved_tokens} tokens") def save_kv_layer(self, - connector_metadata: MooncakeConnectorMetadata) -> None: - """MooncakeConnector does not save explicitly.""" + connector_metadata: AscendConnectorMetadata) -> None: if self.current_layer == 0: self.layerwise_storers = [] for request in connector_metadata.requests: - save_spec = request.save_spec - if save_spec is None or not save_spec.can_save: + can_save = request.can_save + if can_save is None or not can_save: continue - token_ids = request.token_ids + token_len = request.token_len_chunk req_id = request.req_id - assert isinstance(token_ids, torch.Tensor) - assert token_ids.is_cpu # TODO: whether need to remov saveThread # no lookup, skipmask - skip_leading_tokens = max( - self.lookup(token_ids, self.use_layerwise), - save_spec.skip_leading_tokens, - ) - if skip_leading_tokens == len(token_ids): + skip_leading_tokens = self.lookup(token_len, + request.block_hashes, + self.use_layerwise) + if skip_leading_tokens == token_len: if request.is_last_chunk: self.kv_send_thread.set_finished_request( # type: ignore[union-attr] req_id) continue # skip this request - skip_leading_tokens = (skip_leading_tokens // self.block_size * - self.block_size) + mask_num = (skip_leading_tokens // self.block_size * + self.block_size) - store_mask = torch.ones_like(token_ids, dtype=torch.bool) - store_mask[:skip_leading_tokens] = False logger.info( "Storing KV cache for %d out of %d tokens " "(skip_leading_tokens=%d) for request %s", - len(token_ids) - skip_leading_tokens, - len(token_ids), + token_len - skip_leading_tokens, + token_len, skip_leading_tokens, request.req_id, ) layerwise_storer = self.store_layer( req_id, - token_ids, - mask=store_mask, + token_len, + block_hashes=request.block_hashes, + mask_num=mask_num, block_ids=request.block_ids, + is_last_chunk=request.is_last_chunk, ) self.layerwise_storers.append(layerwise_storer) for layerwise_storer in self.layerwise_storers: @@ -306,59 +286,53 @@ def save_kv_layer(self, next(layerwise_storer) except Exception: raise - self.current_layer = self.current_layer + 1 + self.current_layer = self.current_layer + 1 - def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): - """MooncakeConnector does not save explicitly.""" + def wait_for_save(self, connector_metadata: AscendConnectorMetadata): for request in connector_metadata.requests: - save_spec = request.save_spec - if save_spec is None or not save_spec.can_save: + can_save = request.can_save + if can_save is None or not can_save: continue - token_ids = request.token_ids + token_len = request.token_len_chunk req_id = request.req_id - assert isinstance(token_ids, torch.Tensor) - assert token_ids.is_cpu - skip_leading_tokens = max( - self.lookup(token_ids, self.use_layerwise), - save_spec.skip_leading_tokens, - ) - if skip_leading_tokens == len(token_ids): + skip_leading_tokens = self.lookup(token_len, request.block_hashes, + self.use_layerwise) + if skip_leading_tokens == token_len: if request.is_last_chunk: self.kv_send_thread.set_finished_request( # type: ignore[union-attr] req_id) 
continue # skip this request - skip_leading_tokens = (skip_leading_tokens // self.block_size * - self.block_size) - - store_mask = torch.ones_like(token_ids, dtype=torch.bool) - store_mask[:skip_leading_tokens] = False + mask_num = (skip_leading_tokens // self.block_size * + self.block_size) logger.info( "Storing KV cache for %d out of %d tokens " "(skip_leading_tokens=%d) for request %s", - len(token_ids) - skip_leading_tokens, - len(token_ids), + token_len - skip_leading_tokens, + token_len, skip_leading_tokens, request.req_id, ) self.kv_send_thread.add_request( # type: ignore[union-attr] req_id, - token_ids, + token_len, request.block_ids, - store_mask, + request.block_hashes, + mask_num, request.is_last_chunk, ) def retrieve_layer( self, req_id: str, - tokens: torch.Tensor, + token_len: int, block_ids: list[int], - mask: Optional[torch.Tensor] = None, + block_hashes: list[BlockHash], + mask_num: int = 0, ) -> Generator[Optional[torch.Tensor], None, None]: """ Retrieve the KV cache in a layerwise manner. @@ -376,20 +350,16 @@ def retrieve_layer( be the boolean mask indicating which tokens are retrieved and will only be returned in the last iteration. """ + num_required_tokens = token_len - mask_num - if mask is not None: - num_required_tokens = torch.sum(mask).item() - else: - num_required_tokens = len(tokens) - - ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu") + ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu") starts = [] ends = [] keys = [] first_flag = True for start, end, key in self.token_database.process_tokens( - tokens, mask): + token_len, block_hashes, mask_num): keys_multi_layer = key.split_layers(self.num_layers) starts.append(start) ends.append(end) @@ -421,16 +391,18 @@ def retrieve_layer( retrieved_tokens = torch.sum(ret_mask) logger.debug(f"Retrieved {retrieved_tokens} " f"out of {num_required_tokens} " - f"out of total {len(tokens)} tokens") + f"out of total {token_len} tokens") yield ret_mask def store_layer( self, req_id: str, - tokens: torch.Tensor, + token_len: int, block_ids: list[int], - mask: Optional[torch.Tensor] = None, + block_hashes: list[BlockHash], + is_last_chunk: bool, + mask_num: int = 0, ) -> Generator[None, None, None]: """ Store the KV cache in a layerwise manner. @@ -452,17 +424,13 @@ def store_layer( storage backends. In the last iteration, it puts the memory objects of the last layer to the storage backends. 
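+
+        Rough usage sketch (illustrative only; in practice save_kv_layer
+        drives the generator with one next() call per layer):
+
+            storer = self.store_layer(req_id, token_len, block_ids,
+                                      block_hashes, is_last_chunk=True)
+            for _ in range(self.num_layers):
+                next(storer)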
""" - - if mask is not None: - num_stored_tokens = torch.sum(mask).item() - else: - num_stored_tokens = len(tokens) + num_stored_tokens = token_len - mask_num starts = [] ends = [] keys = [] for start, end, key in self.token_database.process_tokens( - tokens, mask): + token_len, block_hashes, mask_num): keys_multi_layer = key.split_layers(self.num_layers) starts.append(start) ends.append(end) @@ -473,7 +441,7 @@ def store_layer( for layer_id, keys_multi_chunk in enumerate(keys): req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, - layer_id) + layer_id, is_last_chunk) self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] req_meta) # type: ignore[union-attr, call-arg, arg-type] yield @@ -481,7 +449,7 @@ def store_layer( for layer_id in range(self.num_layers): yield logger.debug( - f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") + f"Stored {num_stored_tokens} out of total {token_len} tokens") def get_finished(self) -> tuple[set[str], set[str]]: done_sending = ( @@ -500,13 +468,10 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.tp_rank) return done_sending, done_recving - def wait_layer_transfer_finish(self): - time.sleep(10) - pass - def lookup( self, - tokens: Union[torch.Tensor, List[int]], + token_len: int, + block_hashes: list[BlockHash], use_layerwise: bool, ) -> int: """ @@ -517,34 +482,24 @@ def lookup( end = 0 keys = [] try: - if use_layerwise: - for start, end, key in self.token_database.process_tokens( - tokens): + starts = [] + for start, end, key in self.token_database.process_tokens( + token_len, block_hashes): + if use_layerwise: keys_multi_layer = key.split_layers(self.num_layers) for item in keys_multi_layer: keys.append(item.to_string()) - # batch is_exists - ress = self.m_store.batch_exists(keys) - res = 1 - for value in ress: - if value != 1: - res = 0 - break - if res == 1: - continue - else: - return start - else: - starts = [] - for start, end, key in self.token_database.process_tokens( - tokens): + else: keys.append(key.to_string()) - starts.append(start) - res = self.m_store.batch_exists( - keys) # type: ignore[assignment] - for index, value in enumerate(res): # type: ignore[arg-type] - if value != 1: - return starts[index] + starts.append(start) + + res = self.m_store.exists(keys) # type: ignore[assignment] + + if use_layerwise: + res = self.check_all_layers_exists(res, self.num_layers) + for index, value in enumerate(res): # type: ignore[arg-type] + if value != 1: + return starts[index] # all tokens where found, return the maximal end except Exception as e: logger.error(f"Remote connection failed in contains: {e}") @@ -553,7 +508,8 @@ def lookup( def lookup_scheduler( self, - tokens: Union[torch.Tensor, List[int]], + token_len: int, + block_hashes: list[BlockHash], use_layerwise: bool, ) -> int: """ @@ -564,59 +520,59 @@ def lookup_scheduler( end = 0 keys = [] try: - if use_layerwise: - for start, end, key in self.token_database.process_tokens( - tokens): + starts = [] + for start, end, key in self.token_database.process_tokens( + token_len, block_hashes): + if use_layerwise: keys_multi_layer = key.split_layers(self.num_layers) for item in keys_multi_layer: keys.append(item.to_string()) - # batch is_exists - ress = self.m_store.batch_exists(keys) - res = 1 - for value in ress: - if value != 1: - res = 0 - break - if res == 1: - continue - else: - return start - else: - starts = [] - for start, end, key in self.token_database.process_tokens( - tokens): + else: keys.append(key.to_string()) 
- starts.append(start) - multi_tp_keys = keys[:] - for i in range(1, self.tp_size): - for item in keys: - new_str = item.replace( # type: ignore[attr-defined] - "@0", f"@{i}", 1) - multi_tp_keys.append(new_str) - res = self.m_store.batch_exists( - multi_tp_keys) # type: ignore[assignment] - num_block = len(keys) - multi_tp_values = [ - res[i * num_block:(i + 1) * - num_block] # type: ignore[index] - for i in range(self.tp_size) - ] - index = self.find_min_first_non_one_index(multi_tp_values) - if index != -1: - return starts[index] + starts.append(start) + + multi_tp_keys = keys[:] + for i in range(1, min(self.tp_size, self.num_kv_head)): + for item in keys: + new_str = item.replace( # type: ignore[attr-defined] + "@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1) + multi_tp_keys.append(new_str) + + res = self.m_store.exists( + multi_tp_keys) # type: ignore[assignment] + num_block = len(keys) + if use_layerwise: + res = self.check_all_layers_exists(res, self.num_layers) + num_block = len(keys) // self.num_layers + multi_tp_values = [ + res[i * num_block:(i + 1) * num_block] # type: ignore[index] + for i in range(min(self.tp_size, self.num_kv_head)) + ] + index = self.find_min_first_non_one_index(multi_tp_values) + if index != -1: + return starts[index] # all tokens where found, return the maximal end except Exception as e: logger.error(f"Remote connection failed in contains: {e}") return start return end + def check_all_layers_exists(self, res: list[int], + num_layers: int) -> list[int]: + total_chunks = len(res) // num_layers + result = [] + + for chunk_idx in range(total_chunks): + start = chunk_idx * num_layers + end = start + num_layers + chunk = res[start:end] + result.append(1 if all(x == 1 for x in chunk) else 0) + + return result + def find_min_first_non_one_index(self, arr): try: return min(idx for row in arr for idx, val in enumerate(row) if val != 1) except ValueError: return -1 - - def close(self) -> None: - """Close the cache engine and free all the resources""" - self.m_store.close() diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index c0bd06d4b89..5c5a0a5bef3 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -28,6 +28,7 @@ from vllm.utils import logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus import vllm_ascend.envs as envs_ascend @@ -100,7 +101,10 @@ def add_new_req(self, request_id: str, local_block_ids: list[int], class LLMDataDistCMgrConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py deleted file mode 100644 index 2434b4dbc05..00000000000 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ /dev/null @@ -1,534 +0,0 @@ -import array -import hashlib -import json -import os -import re -from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple, Union - -import torch -from 
vllm.distributed.kv_transfer.kv_connector.v1.base import \ - KVConnectorMetadata -from vllm.utils import logger -from vllm.utils.math_utils import cdiv -from vllm.v1.core.sched.output import NewRequestData - -DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB - - -@dataclass -class MooncakeEngineMetadata: - """name of the LLM model""" - - model_name: str - """ world size when running under a distributed setting """ - world_size: int - """ worker id when running under a distributed setting """ - worker_id: int - """ the format of kv tensors """ - kv_dtype: torch.dtype - """ the shape of kv tensors """ - """ (num_layer, 2, metadata.block_size, num_kv_head, head_size) """ - kv_shape: tuple[int, int, int, int, int] - block_size: int = 128 - """ whether use MLA""" - use_mla: bool = False - - -@dataclass(order=True) -class MooncakeEngineKey: - model_name: str - world_size: int - worker_id: int - chunk_hash: str - - def __hash__(self): - return hash(( - self.model_name, - self.world_size, - self.worker_id, - self.chunk_hash, - )) - - def to_string(self): - return (f"{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash}") - - def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: - """Split the key into multiple keys for each layer""" - keys = [] - for layer_id in range(num_layers): - keys.append( - LayerMooncakeEngineKey( - self.model_name, - self.world_size, - self.worker_id, - self.chunk_hash, - layer_id, - )) - return keys - - def to_dict(self): - # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack. - return { - "__type__": "CacheEngineKey", - "model_name": self.model_name, - "world_size": self.world_size, - "worker_id": self.worker_id, - "chunk_hash": self.chunk_hash, - } - - @staticmethod - def from_dict(d): - return MooncakeEngineKey( - model_name=d["model_name"], - world_size=d["world_size"], - worker_id=d["worker_id"], - chunk_hash=d["chunk_hash"], - ) - - -@dataclass(order=True) -class LayerMooncakeEngineKey(MooncakeEngineKey): - """A key for the layer cache engine""" - - layer_id: int - - def __hash__(self): - return hash(( - self.model_name, - self.world_size, - self.worker_id, - self.chunk_hash, - self.layer_id, - )) - - def to_string(self): - return (f"{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}") - - -class ChunkedTokenDatabase(): - - def __init__( - self, - metadata: MooncakeEngineMetadata, - ): - self.metadata = metadata - - def _make_key_by_hash(self, - chunk_hash: str, - layer_id: Optional[int] = None): - assert self.metadata is not None - return MooncakeEngineKey( - self.metadata.model_name, - self.metadata.world_size, - self.metadata.worker_id, - chunk_hash, - ) - - def _hash( - self, - tokens: Union[torch.Tensor, List[int]], - prefix_hash: str, - ) -> str: - # TODO: change it to a more efficient hash function - if isinstance(tokens, torch.Tensor): - tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes() - elif isinstance(tokens, list): - tokens_bytes = array.array("I", tokens).tobytes() - return hashlib.sha256(prefix_hash.encode("ascii") + - tokens_bytes).hexdigest() - - def _chunk_tokens( - self, - tokens: Union[torch.Tensor, List[int]], - ) -> Iterable[Union[torch.Tensor, List[int]]]: - """ - Chunk the tokens into chunks of size self.metadata.block_size. 
- - :param tokens: the input tokens, with shape [seq_len] - device: the target device after chunking - - :return: a generator of chunks of tokens, each with - shape [metadata.block_size] - """ - for i in range(0, len(tokens), self.metadata.block_size): - yield tokens[i:i + self.metadata.block_size] - - def _prefix_hash( - self, - token_chunks: Iterable[Union[torch.Tensor, List[int]]], - ) -> Iterable[str]: - prefix_hash = '' - for token_chunk in token_chunks: - prefix_hash = self._hash(token_chunk, prefix_hash) - yield prefix_hash - - def process_tokens( - self, - tokens: Union[torch.Tensor, List[int]], - mask: Optional[torch.Tensor] = None, - ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]: - """Process the tokens and return the corresponding cache engine keys. - - :param Union[torch.Tensor, List[int]] tokens: The tokens to process. - - :param Optional[torch.Tensor] mask: The mask for the tokens. Should - have the same length as tokens. And the mask should ALWAYS be like - FFFFFTTTTTTT, where True means the tokens needs to be matched, - and the Falses will ALWAYS be at the PREFIX of the tensor. - - :param bool make_key: Whether to make the cache engine key or not. - If False, the hash value will be returned instead. - - :returns: A iterable of tuples with three elements. The first element - is the start index of the tokens for the key. The second element - is the end index of the tokens for the key. The third element is - the cache engine key (or hash) for the tokens. - - :raises: ValueError if the number of Falses in the mask is not a - multiple of the chunk size. - """ - if mask is not None: - num_falses = mask.numel() - mask.long().sum().item() - else: - num_falses = 0 - - if num_falses % self.metadata.block_size != 0: - raise ValueError( - "The number of Falses in the mask is not a multiple of the chunk size." - ) - total_len = len(tokens) - - token_chunks = self._chunk_tokens(tokens) - prefix_hashes = self._prefix_hash(token_chunks) - - start_idx = 0 - for chunk_id, hash_val in enumerate(prefix_hashes): - start_idx = chunk_id * self.metadata.block_size - end_idx = min(start_idx + self.metadata.block_size, total_len) - if start_idx < num_falses: - continue - else: - yield start_idx, end_idx, self._make_key_by_hash(hash_val) - - -@dataclass -class LoadSpec: - # Number of tokens cached in vLLM - vllm_cached_tokens: int - # Number of tokens that are cached in mooncake - mooncake_cached_tokens: int - # Whether the scheduler allow us to load the tokens - can_load: bool - - -@dataclass -class SaveSpec: - # Skip already saved tokens - skip_leading_tokens: int - # Whether the scheduler allow us to save the tokens - can_save: bool - - -@dataclass -class RequestTracker: - # Request id - req_id: str - - # The token ids that has been scheduled so far - token_ids: list[int] - - # The block ids that has been allocated so far - # NOTE: allocated blocks could be more than the number of tokens - # FIXME: need to check whether the block ids will be changed after - # preemption - allocated_block_ids: list[int] - - # The number of tokens that has been savd - num_saved_tokens: int = 0 - - @staticmethod - def from_new_request( - new_request: "NewRequestData", - num_tokens_to_compute: int, - ) -> "RequestTracker": - """Create the request tracker from a new request. - - Args: - new_request (NewRequestData): the new request data. - num_tokens_to_compute (int): the number of tokens that will - be 'computed', including the `num_computed_tokens` (vLLM's - local cache hit) and new tokens that will be scheduled. 
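The `ChunkedTokenDatabase` removed above keys each fixed-size token chunk with a SHA-256 hash chained over the whole prefix, so two prompts share a key only while their prefixes match. A minimal sketch of that hashing scheme (the helper name is illustrative, not from the diff):

import array
import hashlib

def prefix_chunk_hashes(tokens: list[int], block_size: int) -> list[str]:
    # Mirrors _hash/_prefix_hash above: each chunk's hash covers the entire
    # prefix up to and including that chunk.
    hashes: list[str] = []
    prefix_hash = ""
    for i in range(0, len(tokens), block_size):
        chunk_bytes = array.array("I", tokens[i:i + block_size]).tobytes()
        prefix_hash = hashlib.sha256(prefix_hash.encode("ascii") + chunk_bytes).hexdigest()
        hashes.append(prefix_hash)
    return hashes

# Prompts that diverge only in the second chunk share the first key but not the second.
a = prefix_chunk_hashes([1, 2, 3, 4, 5, 6], block_size=3)
b = prefix_chunk_hashes([1, 2, 3, 9, 9, 9], block_size=3)
assert a[0] == b[0] and a[1] != b[1]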
- - """ - # vLLM 0.9.0 update: request.block_ids changed from list[int] to - # list[list[int]] - # Need to check the type of request.block_ids - - unfolded_block_ids = [] - - if not isinstance(new_request.block_ids[0], list): - unfolded_block_ids = new_request.block_ids.copy() - else: - unfolded_block_ids = new_request.block_ids[0].copy() - - return RequestTracker( - req_id=new_request.req_id, - token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. - copy(), - allocated_block_ids=unfolded_block_ids, - num_saved_tokens=0, - ) - - def update( - self, - new_token_ids: list[int], - new_block_ids: Union[tuple[list[int], ...], list[int]], - ) -> None: - """Update the request tracker when a running request is - scheduled again - """ - - self.token_ids.extend(new_token_ids) - - if len(new_block_ids) == 0: - new_block_ids = [] - elif isinstance(new_block_ids, tuple): - new_block_ids = new_block_ids[0] - elif isinstance(new_block_ids, list): - pass - else: - raise ValueError( - f"Unsupported new_block_ids type {type(new_block_ids)}") - self.allocated_block_ids.extend(new_block_ids) - - -@dataclass -class ReqMeta: - # Request id - req_id: str - # Request tokens - token_ids: torch.Tensor - - block_ids: list[int] - # # Slot mapping if exchange for block_id - # slot_mapping: torch.Tensor - # Skip save or not - save_spec: Optional[SaveSpec] = None - # load_spec - load_spec: Optional[LoadSpec] = None - - is_last_chunk: Optional[bool] = None - - @staticmethod - def from_request_tracker( - tracker: RequestTracker, - block_size: int, - load_spec: Optional[LoadSpec] = None, - skip_save: Optional[bool] = False, - is_last_chunk: Optional[bool] = None, - discard_partial_chunks: bool = True, - ) -> Optional["ReqMeta"]: - """Create the request metadata from a request tracker. - - Args: - tracker (RequestTracker): the request tracker. - block_size (int): the block size in vLLM. - load_spec (Optional[LoadSpec]): the load spec for KV cache loading. - skip_save (bool): whether to skip the save operation. - discard_partial_chunks (bool): whether to discard partial chunks. - - Returns: - the request metadata if we need to perform load/save - operations, None otherwise. - """ - input_token_ids = tracker.token_ids - input_token_len = len(input_token_ids) - - # For save operation: do not save if the following condition is met - # 1. has already been saved before (num_saved_tokens > 0) - # 2. 
number of unsaved tokens is not reached the chunk boundary - skip_leading_tokens = tracker.num_saved_tokens - chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * - block_size if discard_partial_chunks else 0) - # Calculate number of tokens to save based on discard_partial_chunks - # setting - num_tokens_to_save = ((input_token_len // block_size * block_size) - if discard_partial_chunks else input_token_len) - - skip_save = skip_save or num_tokens_to_save < chunk_boundary - if skip_save and load_spec is None: - return None - - # If we need to save, update the number of saved tokens - if not skip_save: - tracker.num_saved_tokens = num_tokens_to_save - save_spec = SaveSpec(skip_leading_tokens, not skip_save) - - # Calculate the token ids and slot mappings for load and save - # OPTIMIZATION: pre-allocate the buffer for token ids and block ids - token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save] - - # # For load operation: check whether the request is scheduled to load - if load_spec is not None and load_spec.can_load: - logger.debug( - "Scheduled to load %d tokens for request %s", - load_spec.mooncake_cached_tokens, - tracker.req_id, - ) - else: - # Do not load if not in `can_load` state - load_spec = None - logger.debug( - f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}" - ) - return ReqMeta( - req_id=tracker.req_id, - token_ids=token_ids, - block_ids=tracker.allocated_block_ids, - save_spec=save_spec, - load_spec=load_spec, - is_last_chunk=is_last_chunk, - ) - - -class MooncakeConnectorMetadata(KVConnectorMetadata): - - def __init__(self, unfinished_request_ids): - self.requests = [] - self.unfinished_request_ids = unfinished_request_ids - - def add_request(self, req_meta: ReqMeta) -> None: - """Add a request to the metadata. - - Args: - req_meta (ReqMeta): the request metadata. 
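`ReqMeta.from_request_tracker` above only persists whole blocks when `discard_partial_chunks` is set: a save is skipped unless the scheduled tokens reach the next block boundary past what is already stored. A small sketch of that arithmetic, with illustrative names and values:

def tokens_to_save(num_saved: int, num_scheduled: int, block_size: int,
                   discard_partial_chunks: bool = True) -> int:
    # Returns how many leading tokens are persisted after this step, or -1
    # when the save is skipped. chunk_boundary mirrors
    # cdiv(num_saved + 1, block_size) * block_size in the code above.
    chunk_boundary = ((num_saved + block_size) // block_size * block_size
                      if discard_partial_chunks else 0)
    num_to_save = (num_scheduled // block_size * block_size
                   if discard_partial_chunks else num_scheduled)
    return -1 if num_to_save < chunk_boundary else num_to_save

assert tokens_to_save(num_saved=128, num_scheduled=200, block_size=128) == -1   # only 72 new tokens: skip
assert tokens_to_save(num_saved=128, num_scheduled=300, block_size=128) == 256  # one new full block: save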
- """ - self.requests.append(req_meta) - - -@dataclass -class LasyerMultiBlockReqMeta: - req_id: str - keys: List[LayerMooncakeEngineKey] - starts: List[int] - ends: list[int] - block_ids: list[int] - layer_id: int - - -@dataclass -class MooncakeStoreConfig: - local_hostname: str - metadata_server: str - global_segment_size: Union[int, str] - local_buffer_size: int - protocol: str - device_name: str - master_server_address: str - use_ascend_direct: bool - - @staticmethod - def from_file(file_path: str) -> "MooncakeStoreConfig": - with open(file_path) as file: - config = json.load(file) - return MooncakeStoreConfig( - local_hostname=config.get("local_hostname"), - metadata_server=config.get("metadata_server"), - global_segment_size=_parse_global_segment_size( - config.get("global_segment_size", - DEFAULT_GLOBAL_SEGMENT_SIZE)), - local_buffer_size=(config.get("local_buffer_size", - DEFAULT_LOCAL_BUFFER_SIZE)), - protocol=config.get("protocol", "tcp"), - device_name=config.get("device_name", ""), - master_server_address=config.get("master_server_address"), - use_ascend_direct=config.get("use_ascend_direct", False)) - - @staticmethod - def load_from_env() -> "MooncakeStoreConfig": - config_path = os.getenv("MOONCAKE_CONFIG_PATH") - if not config_path: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") - return MooncakeStoreConfig.from_file(config_path) - - -def _parse_global_segment_size(value) -> int: - """ - Parse storage size strings with support for units: GB, MB, KB, B - - Args: - value: Input value (int, str, or other convertible types) - - Returns: - int: Size in bytes - - Raises: - ValueError: For invalid format, missing number, or negative values - TypeError: For unsupported input types - """ - - if isinstance(value, int): - return value - elif not isinstance(value, str): - try: - return int(value) - except (TypeError, ValueError) as e: - raise TypeError( - f"Unsupported type for global_segment_size: {type(value)}" - ) from e - - cleaned_input = value.strip().lower() - if not cleaned_input: - raise ValueError("global segment size cannot be empty.") - - UNIT_MULTIPLIERS = { - 'gb': 1024**3, # 1 GB = 1024^3 bytes - 'mb': 1024**2, # 1 MB = 1024^2 bytes - 'kb': 1024, # 1 KB = 1024 bytes - 'b': 1 # 1 B = 1 byte - } - pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$' - match = re.match(pattern, cleaned_input) - - if not match: - raise ValueError(f"Invalid format: '{value}'") - - number_str = match.group(1) - unit = match.group(2) or 'b' - - multiplier = UNIT_MULTIPLIERS[unit] - return _convert_to_bytes(number_str, multiplier, value) - - -def _convert_to_bytes(number_str: str, multiplier: int, - original_input: str) -> int: - """ - Convert numeric string to byte count - - Args: - number_str: Numeric portion of input - multiplier: Unit conversion factor - original_input: Original input string (for error messages) - - Returns: - int: Byte count - - Raises: - ValueError: For invalid numbers or negative results - """ - try: - numeric_value = float(number_str) - except ValueError: - raise ValueError( - f"Invalid numeric value '{number_str}' in: '{original_input}'") - # Calculate byte count - try: - byte_count = int(numeric_value * multiplier) - except OverflowError: - raise ValueError(f"Storage size too large: '{original_input}'") - return byte_count diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py deleted file mode 100644 index 4472f678ddd..00000000000 --- a/vllm_ascend/distributed/mooncake/kv_transfer.py 
+++ /dev/null @@ -1,282 +0,0 @@ -import queue -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional - -import torch -from vllm.utils import logger - -from vllm_ascend.distributed.mooncake.config_data import ( - ChunkedTokenDatabase, LasyerMultiBlockReqMeta) -from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore - - -class KVTransferThread(threading.Thread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, name: str): - super().__init__(daemon=True, name=name) - self.tp_rank = tp_rank - self.tp_size = tp_size - self.m_store = m_store - self.ready_event = ready_event - self.kv_caches_base_addr = local_kv_caches_base_addr - self.block_len = block_len - self.token_database = token_database - self.block_size = block_size - self.done_task_lock = threading.Lock() - # TODO(jianzs): find a better way to detect MLA. - self.use_mla = len(block_len) == 2 - - self.request_queue: queue.Queue[Any] = queue.Queue() - # TODO(jianzs): make this configurable - self.executor = ThreadPoolExecutor(max_workers=32) - self.finished_requests: set[str] = set() - - def prepare_value(self, start: int, end: int, block_ids: list[int]): - addr_list = [] - size_list = [] - block_id = block_ids[start // self.block_size] - for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = (self.block_len[index % 2] - if self.use_mla else self.block_len[0]) - - addr = base_addr + block_id * block_len - length = int(block_len / self.block_size * (end - start)) - addr_list.append(addr) - size_list.append(length) - return addr_list, size_list, block_id - - def prepare_value_layer(self, start: int, end: int, block_ids: list[int], - layer_id: int): - block_id = block_ids[start // self.block_size] - if self.use_mla: - addr_k = self.kv_caches_base_addr[layer_id * - 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + - 1] + block_id * self.block_len[1] - length_k = int(self.block_len[0] / self.block_size * (end - start)) - length_v = int(self.block_len[1] / self.block_size * (end - start)) - size_list = [length_k, length_v] - else: - addr_k = self.kv_caches_base_addr[layer_id * - 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + - 1] + block_id * self.block_len[0] - length = int(self.block_len[0] / self.block_size * (end - start)) - size_list = [length, length] - addr_list = [addr_k, addr_v] - return addr_list, size_list - - def add_request( - self, - req_id: str, - tokens: torch.Tensor, - block_ids: list[int], - mask: Optional[torch.Tensor] = None, - is_last_chunk: Optional[bool] = None, - ) -> torch.Tensor: - req = ({ - "req_id": req_id, - "tokens": tokens, - "block_ids": block_ids, - "mask": mask, - "is_last_chunk": is_last_chunk, - }) - self.request_queue.put(req) - - def get_and_clear_finished_requests(self) -> set[str]: - """ - Get and clear the requests that have been completed. - Returns: - A set of request IDs that have been completed. 
- """ - with self.done_task_lock: - finished_requests = self.finished_requests.copy() - self.finished_requests.clear() - return finished_requests - - def set_finished_request(self, req_id): - with self.done_task_lock: - self.finished_requests.add(req_id) - - def run(self): - """Run the thread to handle KV cache transfer requests.""" - self.ready_event.set() - while True: - try: - request_data = self.request_queue.get() - if request_data is None: - logger.warning("Received a None request!") - self.request_queue.task_done() - continue - self._handle_request(request_data) - except Exception as e: - logger.error(f"Error in KVCacheTransferThread: {e}") - - def _handle_request(self, req_meta: dict[str, Any]): - pass - - -class KVCacheStoreSendingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheSendingThread") - - def _handle_request(self, req_meta: dict[str, Any]): - tokens = req_meta["tokens"] - mask = req_meta["mask"] - block_ids = req_meta["block_ids"] - req_id = req_meta["req_id"] - is_last_chunk = req_meta["is_last_chunk"] - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, block_id = self.prepare_value( - start, end, block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - torch.npu.current_stream().synchronize() - self.m_store.put_batch(key_list, addr_list, size_list, blockIds) - else: - torch.npu.current_stream().synchronize() - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, _ = self.prepare_value(start, end, block_ids) - self.m_store.put(key, addr, size) - if is_last_chunk: - self.set_finished_request(req_id) - self.request_queue.task_done() - - -class KVCacheStoreRecvingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreRecvingThread") - - def _handle_request(self, req_meta: dict[str, Any]): - tokens = req_meta["tokens"] - mask = req_meta["mask"] - block_ids = req_meta["block_ids"] - req_id = req_meta["req_id"] - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, block_id = self.prepare_value( - start, end, block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - self.m_store.get_batch(key_list, addr_list, size_list, blockIds) - else: - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, _ = self.prepare_value(start, end, block_ids) - self.m_store.get(key, addr, size) - self.set_finished_request(req_id) - self.request_queue.task_done() - - -class 
KVCacheStoreLayerSendingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, - num_layers: int): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreLayerSendingThread") - self.final_layer_id = num_layers - 1 - - def add_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: - self.request_queue.put(req_meta) - - def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta): - torch.npu.current_stream().synchronize() - for index, key in enumerate(req_meta.keys): - addr, size = self.prepare_value_layer(req_meta.starts[index], - req_meta.ends[index], - req_meta.block_ids, - req_meta.layer_id) - self.m_store.put(key, addr, size) - if req_meta.layer_id == self.final_layer_id: - self.set_finished_request(req_meta.req_id) - self.request_queue.task_done() - - -class KVCacheStoreLayerRecvingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, - get_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreLayerRecvingThread") - self.get_event = get_event - - def add_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: - self.request_queue.put(req_meta) - - def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta): - for index, key in enumerate(req_meta.keys): - addr, size = self.prepare_value_layer(req_meta.starts[index], - req_meta.ends[index], - req_meta.block_ids, - req_meta.layer_id) - self.m_store.get(key, addr, size) - self.request_queue.task_done() - self.get_event.set() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py deleted file mode 100644 index 01020d72d87..00000000000 --- a/vllm_ascend/distributed/mooncake/mooncake_store.py +++ /dev/null @@ -1,127 +0,0 @@ -# Standard -import os - -# Third Party -from mooncake.store import ReplicateConfig # type: ignore -from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.utils import logger -from vllm.utils.network_utils import get_ip - -from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey -from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te - -from .config_data import MooncakeStoreConfig - -METADATA_BYTES_LEN = 24 -BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790")) - - -class Mooncakestore(): - - def __init__(self, parallel_config: ParallelConfig): - try: - from mooncake.store import MooncakeDistributedStore # type: ignore - except ImportError as e: - raise ImportError( - "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector.") from e - tp_rank = get_tensor_model_parallel_rank() - tp_size = parallel_config.tensor_parallel_size - dp_rank = parallel_config.data_parallel_rank_local - 
all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) - if not all_device_ids: - device_ids_list = list( - range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) - else: - device_ids_list = list(map(int, all_device_ids.split(','))) - assert len(device_ids_list) > tp_rank - device_id = device_ids_list[tp_rank] - self.config = MooncakeStoreConfig.load_from_env() - self.store = MooncakeDistributedStore() - if self.config.protocol == "ascend" and not self.config.use_ascend_direct: - local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \ - ":npu_" + str(device_id) - ret = self.store.setup(local_hostname, self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address) - else: - local_hostname = get_ip() - transfer_engine = get_global_te(local_hostname, device_name=None) - self.local_seg = local_hostname + ":" + str( - transfer_engine.get_rpc_port()) - ret = self.store.setup(self.local_seg, self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address, - transfer_engine.get_engine()) - if ret != 0: - msg = "Initialize mooncake failed." - logger.error(msg) - raise RuntimeError(msg) - - def exists(self, key: MooncakeEngineKey) -> bool: - return self.store.is_exist(key.to_string()) == 1 - - def batch_exists(self, keys: list[str]) -> list[int]: - return self.store.batch_is_exist(keys) - - def register_buffer(self, ptr, length): - return self.store.register_buffer(ptr, length) - - def get_batch(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]], block_ids: list[int]): - try: - res = self.store.batch_get_into_multi_buffers( - keys, addrs, sizes, True) - for value in res: - if value < 0: - logger.error(f"Failed to get key {keys},res:{res}") - except Exception as e: - logger.error(f"Failed to get key {keys}. 
{e}") - - def put_batch(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]], block_ids: list[int]): - try: - config = ReplicateConfig() - config.preferred_segment = self.local_seg - config.prefer_alloc_in_same_node = True - res = self.store.batch_put_from_multi_buffers( - keys, addrs, sizes, config) - for value in res: - if value < 0: - logger.error(f"Failed to put key {keys},res:{res}") - except Exception as e: - logger.error(f"Failed to put key {keys},error:{e}") - - def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): - expect_res = sum(size) - key_str = key.to_string() - try: - res = self.store.batch_get_into_ascend(key_str, addr, size) - if res[0] != expect_res: - logger.error(f"Failed to get key: [{key_str}] .") - except Exception: - logger.error(f"Failed to get key: [{key_str}] .") - return res - - def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): - key_str = key.to_string() - try: - ret = self.store.batch_put_from_ascend(key_str, addr, size) - if ret[0] != 0: - logger.error(f"Failed to put key {key_str}.") - except Exception: - logger.error(f"Failed to put key {key_str}.") - - return ret - - def close(self): - self.store.close() - logger.info("Closed the mooncake store connection") diff --git a/vllm_ascend/distributed/mooncake/transfer_engine.py b/vllm_ascend/distributed/mooncake/transfer_engine.py deleted file mode 100644 index d4e172b7857..00000000000 --- a/vllm_ascend/distributed/mooncake/transfer_engine.py +++ /dev/null @@ -1,38 +0,0 @@ -import ipaddress -import threading -from typing import Optional - -from mooncake.engine import TransferEngine # type: ignore - -_global_te = None -_global_te_lock = threading.Lock() - - -def get_global_te(hostname: str, device_name: Optional[str]): - try: - ip = ipaddress.ip_address(hostname) - if isinstance(ip, ipaddress.IPv6Address): - raise RuntimeError( - "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6." 
- ) - except ValueError: - pass - - global _global_te - if _global_te is None: - with _global_te_lock: - # Double-Checked Locking - if _global_te is None: - if TransferEngine is None: - raise RuntimeError("mooncake is not available") - transfer_engine = TransferEngine() - device_name = device_name if device_name is not None else "" - ret_value = transfer_engine.initialize(hostname, - "P2PHANDSHAKE", - "ascend", device_name) - if ret_value != 0: - raise RuntimeError( - f"TransferEngine initialization failed with ret_value: {ret_value}" - ) - _global_te = transfer_engine - return _global_te diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 403b17e4f9a..754bba7b68b 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -31,11 +31,12 @@ get_tensor_model_parallel_rank, get_tp_group) from vllm.utils import logger from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import RequestStatus import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config -from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te +from vllm_ascend.distributed.mooncake_transfer_engine import global_te from vllm_ascend.distributed.utils import get_transfer_timeout_value from vllm_ascend.utils import prefill_context_parallel_enable @@ -634,7 +635,10 @@ def add_new_req( class MooncakeConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id @@ -944,7 +948,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): else: hostname = f"{self.side_channel_host}:0:npu_{self.device_id}" logger.info("Initializing Mooncake work %s", engine_id) - self.engine = get_global_te(hostname, device_name=None) + self.engine = global_te.get_transfer_engine(hostname, device_name=None) self.te_rpc_port = self.engine.get_rpc_port() # Background thread for sending or receiving KV caches. 
@@ -1054,6 +1058,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_caches = kv_caches kv_caches_base_addr = [] + ptrs = [] + lengths = [] for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches if self.use_mla: @@ -1061,13 +1067,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[i % 2] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) + ptrs.append(base_addr) + lengths.append(region_len) elif self.use_sparse: for i, cache in enumerate(cache_or_caches, 0): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[i % 3] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) + ptrs.append(base_addr) + lengths.append(region_len) else: cache_list = [ cache_or_caches @@ -1076,8 +1084,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[0] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) - + ptrs.append(base_addr) + lengths.append(region_len) + global_te.register_buffer(ptrs, lengths) # After KV Caches registered, start the sending or receiving thread. metadata = MooncakeAgentMetadata( engine_id=self.engine_id, @@ -1101,14 +1110,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_recv_thread.start() ready_event.wait() - def _register(self, ptr, length): - logger.debug( - "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " - "block_lens=%s", ptr, length, self.num_blocks, self.block_len) - ret_value = self.engine.register_memory(ptr, length) - if ret_value != 0: - raise RuntimeError("Mooncake memory registration failed.") - def get_finished(self) -> tuple[set[str], set[str]]: done_sending = ( self.kv_send_thread. 
diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index ccb6d344970..215becc5477 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -30,6 +30,7 @@ from vllm.utils import logger from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -359,7 +360,10 @@ def add_new_req(self, class MooncakeLayerwiseConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id self._connector_metadata = MooncakeLayerwiseConnectorMetadata() diff --git a/vllm_ascend/distributed/mooncake_transfer_engine.py b/vllm_ascend/distributed/mooncake_transfer_engine.py new file mode 100644 index 00000000000..fceecd4c4aa --- /dev/null +++ b/vllm_ascend/distributed/mooncake_transfer_engine.py @@ -0,0 +1,53 @@ +import ipaddress +import threading +from typing import Optional + +from mooncake.engine import TransferEngine # type: ignore + + +class GlobalTE(): + + def __init__(self): + self.transfer_engine = None + self.is_register_buffer: bool = False + self.transfer_engine_lock = threading.Lock() + self.register_buffer_lock = threading.Lock() + + def get_transfer_engine(self, hostname: str, device_name: Optional[str]): + try: + ip = ipaddress.ip_address(hostname) + if isinstance(ip, ipaddress.IPv6Address): + raise RuntimeError( + "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6." 
+ ) + except ValueError: + pass + if self.transfer_engine is None: + with self.transfer_engine_lock: + # Double-Checked Locking + if self.transfer_engine is None: + if TransferEngine is None: + raise RuntimeError("mooncake is not available") + self.transfer_engine = TransferEngine() + device_name = device_name if device_name is not None else "" + ret_value = self.transfer_engine.initialize( + hostname, "P2PHANDSHAKE", "ascend", device_name) + if ret_value != 0: + raise RuntimeError( + f"TransferEngine initialization failed with ret_value: {ret_value}" + ) + return self.transfer_engine + + def register_buffer(self, ptrs: list[int], sizes: list[int]): + with self.register_buffer_lock: + assert self.transfer_engine is not None, "Transfer engine must be initialized" + if self.is_register_buffer: + return + for ptr, size in zip(ptrs, sizes): + ret_value = self.transfer_engine.register_memory(ptr, size) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed.") + self.is_register_buffer = True + + +global_te = GlobalTE() diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 956df2eb315..31eae8d7cbe 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -1,7 +1,5 @@ from vllm import ModelRegistry -import vllm_ascend.envs as envs_ascend - def register_model(): ModelRegistry.register_model( @@ -10,24 +8,11 @@ def register_model(): ModelRegistry.register_model( "Qwen3VLMoeForConditionalGeneration", - "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration" - ) + "vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration") ModelRegistry.register_model( "Qwen3VLForConditionalGeneration", - "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration" - ) - - if envs_ascend.USE_OPTIMIZED_MODEL: - ModelRegistry.register_model( - "Qwen2_5_VLForConditionalGeneration", - "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration" - ) - else: - ModelRegistry.register_model( - "Qwen2_5_VLForConditionalGeneration", - "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding" - ) + "vllm_ascend.models.qwen3_vl:AscendQwen3VLForConditionalGeneration") # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py deleted file mode 100644 index 0ff31712668..00000000000 --- a/vllm_ascend/models/qwen2_5_vl.py +++ /dev/null @@ -1,556 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Adapted from vllm/model_executor/models/qwen2_5_vl.py -# Copyright 2023 The vLLM team. -# -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
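The new `mooncake_transfer_engine` module above replaces the per-connector `get_global_te`/`register_memory` calls with a process-wide singleton and one batched, idempotent buffer registration. A hedged usage sketch, assuming mooncake is installed and an NPU is reachable; the hostname and pointer values are placeholders, not taken from the diff:

from vllm_ascend.distributed.mooncake_transfer_engine import global_te

# Hostname follows the "<ip>:<port>:npu_<device_id>" form the connector builds above.
engine = global_te.get_transfer_engine("192.168.0.10:0:npu_0", device_name=None)
rpc_port = engine.get_rpc_port()

# In register_kv_caches the connector now collects every KV region first and
# registers them in a single call; the placeholder values stand in for
# tensor.data_ptr() and the region length in bytes.
ptrs = [0x12340000]
lengths = [256 * 1024]
global_te.register_buffer(ptrs, lengths)
global_te.register_buffer(ptrs, lengths)  # no-op: the is_register_buffer guard has already fired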
- -from functools import partial -from typing import Callable, Iterable, Optional, Set, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch_npu -from einops import rearrange -from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( - Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) -from vllm.config import VllmConfig -from vllm.distributed import parallel_state -from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import get_act_and_mul_fn -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2_5_vl import ( - Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, - Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer, - Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration, - Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo) -from vllm.model_executor.models.utils import maybe_prefix -from vllm.multimodal import MULTIMODAL_REGISTRY - -from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz - -MIN_PAD_SIZE = 64 # min_size to pad weight -MAX_PAD_SIZE = 128 # max_size to pad weight - - -class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention): - - def __init__( - self, - embed_dim: int, - num_heads: int, - projection_size: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__( - embed_dim, - num_heads, - projection_size, - quant_config, - prefix, - ) - self.embed_dim = embed_dim - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: - # [s, b, 3 * head * head_dim] - seq_len, bs, _ = qkv.shape - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] - q, k, v = qkv.chunk(3, dim=2) - - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) - q, k, v = (x.view(*new_shape) for x in (q, k, v)) - return q, k, v - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - # [s, b, c] --> [s, b, head * 3 * head_dim] - x, _ = self.qkv(x) - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] - q, k, v = self.split_qkv(x) - batch_size = q.shape[1] - - q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() - for x in (q, k, v)) - q = torch_npu.npu_rotary_mul(q, cos, sin) - k = torch_npu.npu_rotary_mul(k, cos, sin) - - q, k, v = [ - rearrange(x, "b s h d -> (b s) h d").contiguous() - for x in (q, k, v) - ] - - context_layer = torch.empty_like(q) - - # operator requires pta version >= 2.5.1 - torch_npu._npu_flash_attention_unpad( - query=q, - key=k, - value=v, - seq_len=cu_seqlens, - scale_value=self.origin_hidden_size_per_attention_head**-0.5, - num_heads=self.num_attention_heads_per_partition, - num_kv_heads=self.num_attention_heads_per_partition, - out=context_layer) - - context_layer = rearrange(context_layer, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() - - output, _ = self.proj(context_layer) - return output - - -class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock): - - def __init__( - self, - dim: int, - num_heads: int, - mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(dim=dim, - num_heads=num_heads, - mlp_hidden_dim=mlp_hidden_dim, - act_fn=act_fn, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=prefix) - - self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding): - - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__(dim, theta) - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.inv_freq = inv_freq - - -class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer): - - def __init__( - self, - vision_config: Qwen2_5_VLVisionConfig, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - interleaved=False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix) - norm_layer = partial(RMSNorm, eps=norm_eps) - self.interleaved = interleaved - self.enable_pad = False - head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim // - 2) - self.patch_embed = Qwen2_5_VisionPatchEmbed( - patch_size=vision_config.patch_size, - temporal_patch_size=vision_config.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) - - act_fn = get_act_and_mul_fn(vision_config.hidden_act) - self.blocks = nn.ModuleList([ - AscendQwen2_5_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=act_fn, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.enable_pad = True - self.origin_hidden_size_per_attention_head 
= self.hidden_size_per_attention_head - self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2 - self.half_pad_hidden_size_per_attention_head = ( - MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2 - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - if self.enable_pad: - cos = torch.nn.functional.pad( - cos, (0, self.half_pad_hidden_size_per_attention_head)) - sin = torch.nn.functional.pad( - sin, (0, self.half_pad_hidden_size_per_attention_head)) - - if not self.interleaved: - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - else: - cos_new = rearrange(torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2) - sin_new = rearrange(torch.stack((sin, sin), dim=-1), - "... d two -> ...(d two)", - two=2) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def pad_qkv_bias(self, bias): - first_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, :self.half_origin_hidden_size_per_attention_head] - second_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, self.half_origin_hidden_size_per_attention_head:] - first_half_padded = torch.nn.functional.pad( - first_half, (0, self.half_pad_hidden_size_per_attention_head)) - second_half_padded = torch.nn.functional.pad( - second_half, (0, self.half_pad_hidden_size_per_attention_head)) - bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2) - bias_final = bias_padded.reshape(-1) - return bias_final - - def pad_qkv_weight(self, data): - qkv_weight_first_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, :self.half_origin_hidden_size_per_attention_head, :] - qkv_weight_second_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, self.half_origin_hidden_size_per_attention_head:, :] - - qkv_weight_first_half_padded = torch.nn.functional.pad( - qkv_weight_first_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_second_half_padded = torch.nn.functional.pad( - qkv_weight_second_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_padded = torch.cat( - [qkv_weight_first_half_padded, qkv_weight_second_half_padded], - dim=2) - qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) - - if is_enable_nz(): - qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( - qkv_weight_final) - qkv_weight_final_copy = torch_npu.npu_format_cast( - qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND) - return qkv_weight_final_copy - - return qkv_weight_final - - def pad_proj_weight(self, data): - out_weight = torch.nn.functional.pad( - data.reshape(self.hidden_size, -1, - self.half_origin_hidden_size_per_attention_head), - (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( - self.hidden_size, -1) - - if is_enable_nz(): - out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) - out_weight_copy = torch_npu.npu_format_cast( - out_weight_copy, ACL_FORMAT_FRACTAL_ND) - return out_weight_copy - - return out_weight - - def pad_qkv_weight_scale_offset(self, data): - reshaped_data = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, 1) - data1 = reshaped_data[:, 
:, :self. - half_origin_hidden_size_per_attention_head, :] - data2 = reshaped_data[:, :, self. - half_origin_hidden_size_per_attention_head:, :] - data1_paded = torch.nn.functional.pad( - data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, - 0, 0, 0)) - data2_paded = torch.nn.functional.pad( - data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, - 0, 0, 0)) - res = torch.cat([data1_paded, data2_paded], dim=2) - res = res.reshape(-1, 1) - return res - - def pad_qkv_deq_scale_quant_bias(self, data): - reshaped_data = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head) - data1 = reshaped_data[:, :, :self. - half_origin_hidden_size_per_attention_head] - data2 = reshaped_data[:, :, - self.half_origin_hidden_size_per_attention_head:] - - data1_paded = torch.nn.functional.pad( - data1, (0, self.half_pad_hidden_size_per_attention_head)) - data2_paded = torch.nn.functional.pad( - data2, (0, self.half_pad_hidden_size_per_attention_head)) - - res = torch.cat([data1_paded, data2_paded], dim=2) - res = res.reshape(-1) - return res - - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("mlp.gate_up_proj.", "mlp.gate_proj.", 0), - ("mlp.gate_up_proj.", "mlp.up_proj.", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - if ("attn.proj.weight_scale" in name or - "attn.proj.weight_offset" in name) and self.enable_pad: - continue - elif ("attn.proj.deq_scale" in name - or "attn.proj.quant_bias" in name) and self.enable_pad: - continue - elif ("attn.qkv.weight_scale" in name - or "attn.qkv.weight_offset" in name) and self.enable_pad: - param.data = self.pad_qkv_weight_scale_offset(param.data) - elif ("attn.qkv.deq_scale" in name - or "attn.qkv.quant_bias" in name) and self.enable_pad: - param.data = self.pad_qkv_deq_scale_quant_bias(param.data) - elif ("attn.proj.weight" in name) and self.enable_pad: - param.data = self.pad_proj_weight(param.data) - elif ("attn.qkv.weight" in name) and self.enable_pad: - param.data = self.pad_qkv_weight(param.data) - elif ("attn.qkv.bias" in name) and self.enable_pad: - param.data = self.pad_qkv_bias(param.data) - loaded_params.add(name) - return loaded_params - - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - pos_ids.append( - torch.stack([hpos_ids, 
wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = (self.window_size // - self.spatial_merge_size // self.patch_size) - - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h = grid_h // self.spatial_merge_size - llm_grid_w = grid_w // self.spatial_merge_size - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) - index_padded = index_padded.reshape(grid_t, num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, - vit_merger_window_size) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum( - 0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - return window_index, cu_window_seqlens - - def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, - ) -> torch.Tensor: - # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, - 0]).cpu().to(torch.int32) - - # patchify - x = self.patch_embed(x) - - # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - # windows attention - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=x.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32) - seq_len, _ = x.size() - x = x.reshape(seq_len // self.spatial_merge_unit, - self.spatial_merge_unit, -1) - x = x[window_index, :, :] - x = x.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - # transformers - x = x.unsqueeze(1) - for layer_num, blk in enumerate(self.blocks): - if layer_num in self.fullatt_block_indexes: - cu_seqlens_now = cu_seqlens - else: - cu_seqlens_now = cu_window_seqlens - x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin) - - # adapter - x = self.merger(x) - reverse_indices = torch.argsort(window_index) - x = x[reverse_indices, :] - return x - - -@MULTIMODAL_REGISTRY.register_processor( - Qwen2_5_VLMultiModalProcessor, - info=Qwen2_5_VLProcessingInfo, - dummy_inputs=Qwen2_5_VLDummyInputsBuilder) -class 
AscendQwen2_5_VLForConditionalGeneration( - Qwen2_5_VLForConditionalGeneration): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.visual = AscendQwen2_5_VisionTransformer( - vision_config=config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - ) - - def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]: - - grid_thw = image_input["image_grid_thw"] - assert grid_thw.ndim == 2 - - if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) - else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - with set_ascend_forward_context(None, self.vllm_config): - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - - # Split concatenated embeddings for each image item. - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return image_embeds.split(sizes.tolist()) - - def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]: - - grid_thw = video_input["video_grid_thw"] - assert grid_thw.ndim == 2 - - if video_input["type"] == "video_embeds": - video_embeds = video_input["video_embeds"].type(self.visual.dtype) - else: - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - with set_ascend_forward_context(None, self.vllm_config): - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw) - - # Split concatenated embeddings for each video item. - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return video_embeds.split(sizes.tolist()) diff --git a/vllm_ascend/models/qwen2_5_vl_without_padding.py b/vllm_ascend/models/qwen2_5_vl_without_padding.py deleted file mode 100644 index d51a5aca9a6..00000000000 --- a/vllm_ascend/models/qwen2_5_vl_without_padding.py +++ /dev/null @@ -1,617 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
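In the removed `_process_image_input`/`_process_video_input` above (and in the upstream class they wrapped), the concatenated visual embeddings are split per item by `grid_thw.prod(-1) // merge_size // merge_size`. A quick numeric check of that sizing, with illustrative grids:

import torch

# grid_thw rows are (t, h, w) in vision patches; after the spatial merge each
# item contributes t*h*w // merge_size // merge_size embedding rows.
grid_thw = torch.tensor([[1, 28, 28], [1, 16, 32]])
merge_size = 2
sizes = grid_thw.prod(-1) // merge_size // merge_size
assert sizes.tolist() == [196, 128]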
- -from functools import partial -from typing import Callable, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch_npu -from einops import rearrange -from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( - Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) - -try: - from transformers.models.qwen3_vl.configuration_qwen3_vl import \ - Qwen3VLConfig - from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \ - Qwen3VLMoeConfig -except ImportError: - pass -from vllm.config import VllmConfig -from vllm.distributed import parallel_state -from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, - get_act_and_mul_fn) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.qwen2_5_vl import ( - Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, - Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder, - Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor, - Qwen2_5_VLProcessingInfo) - -try: - from vllm.model_executor.models.qwen3_vl import ( - Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer, - Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration, - Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) - from vllm.model_executor.models.qwen3_vl_moe import ( - Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo) -except ImportError: - Qwen3_VisionBlock = object - Qwen3_VisionPatchEmbed = object - Qwen3_VisionTransformer = object - Qwen3VLDummyInputsBuilder = object - Qwen3VLForConditionalGeneration = object - Qwen3VLMultiModalProcessor = object - Qwen3VLProcessingInfo = object - Qwen3VLMoeForConditionalGeneration = object - Qwen3VLMoeProcessingInfo = object -from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix -from vllm.multimodal import MULTIMODAL_REGISTRY - -from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding - - -class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention): - - def __init__( - self, - embed_dim: int, - num_heads: int, - projection_size: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__( - embed_dim, - num_heads, - projection_size, - quant_config, - prefix, - ) - self.embed_dim = embed_dim - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - # [s, b, c] --> [s, b, head * 3 * head_dim] - x, _ = self.qkv(x) - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] - q, k, v = self.split_qkv(x) - batch_size = q.shape[1] - - q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() - for x in (q, k, v)) - q = torch_npu.npu_rotary_mul(q, cos, sin) - k = torch_npu.npu_rotary_mul(k, cos, sin) - - q, k, v = [ - rearrange(x, "b s h d -> (b s) h d").contiguous() - for x in (q, k, v) - ] - - context_layer = torch.empty_like(q) - - # operator requires pta version >= 2.5.1.dev20250226 - torch_npu._npu_flash_attention_unpad( - query=q, - key=k, - value=v, - seq_len=cu_seqlens, - scale_value=self.hidden_size_per_attention_head**-0.5, - num_heads=self.num_attention_heads_per_partition, - num_kv_heads=self.num_attention_heads_per_partition, - out=context_layer) - - context_layer = rearrange(context_layer, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() - - output, _ = self.proj(context_layer) - return output - - -class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock): - - def __init__(self, - dim: int, - num_heads: int, - mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - super().__init__(dim=dim, - num_heads=num_heads, - mlp_hidden_dim=mlp_hidden_dim, - act_fn=act_fn, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=prefix) - self.attn = AscendQwen2_5_VisionAttention_Without_Padding( - embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen2_5_VisionPatchEmbed_Without_Padding(Qwen2_5_VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) - return x - - -class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer - ): - - def __init__( - self, - vision_config: Qwen2_5_VLVisionConfig, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - interleaved=False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix) - norm_layer = partial(RMSNorm, eps=norm_eps) - self.interleaved = interleaved - head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim // - 2) - self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding( - patch_size=vision_config.patch_size, - temporal_patch_size=vision_config.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) - - act_fn = get_act_and_mul_fn(vision_config.hidden_act) - self.blocks = nn.ModuleList([ - AscendQwen2_5_VisionBlock_Without_Padding( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=act_fn, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - - if not self.interleaved: - cos_new = 
torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - else: - cos_new = rearrange(torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2) - sin_new = rearrange(torch.stack((sin, sin), dim=-1), - "... d two -> ...(d two)", - two=2) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = (self.window_size // - self.spatial_merge_size // self.patch_size) - - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h = grid_h // self.spatial_merge_size - llm_grid_w = grid_w // self.spatial_merge_size - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) - index_padded = index_padded.reshape(grid_t, num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, - vit_merger_window_size) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum( - 0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - return window_index, cu_window_seqlens - - def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, - ) -> torch.Tensor: - # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, - 0]).cpu().to(torch.int32) - - # patchify - x = self.patch_embed(x) - - # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - # windows attention - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=x.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) - cu_window_seqlens = 
torch.unique_consecutive(cu_window_seqlens) - cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32) - seq_len, _ = x.size() - x = x.reshape(seq_len // self.spatial_merge_unit, - self.spatial_merge_unit, -1) - x = x[window_index, :, :] - x = x.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - # transformers - x = x.unsqueeze(1) - for layer_num, blk in enumerate(self.blocks): - if layer_num in self.fullatt_block_indexes: - cu_seqlens_now = cu_seqlens - else: - cu_seqlens_now = cu_window_seqlens - x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin) - - # adapter - x = self.merger(x) - reverse_indices = torch.argsort(window_index) - x = x[reverse_indices, :] - return x - - -class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) - x = x + self.proj.bias - return x - - -class AscendQwen3_VisionBlock(Qwen3_VisionBlock): - - def __init__( - self, - dim: int, - num_heads: int, - mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False, - ) -> None: - super().__init__(dim=dim, - num_heads=num_heads, - mlp_hidden_dim=mlp_hidden_dim, - act_fn=act_fn, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=prefix, - use_data_parallel=use_data_parallel) - - self.attn = AscendQwen2_5_VisionAttention_Without_Padding( - embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer): - - def __init__( - self, - vision_config, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix, - use_data_parallel) - norm_layer = partial(nn.LayerNorm, eps=norm_eps) - self.patch_embed = AscendQwen3_VisionPatchEmbed( - patch_size=self.patch_size, - temporal_patch_size=self.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) - self.blocks = nn.ModuleList([ - AscendQwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = 
sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def forward( - self, - x: torch.Tensor, - grid_thw: list[list[int]], - ) -> torch.Tensor: - hidden_states = x.to(device=self.device, dtype=self.dtype) - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) - grid_thw_tensor = torch.tensor(grid_thw, - device=self.device, - dtype=torch.int32) - cu_seqlens = torch.repeat_interleave( - grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], - grid_thw_tensor[:, 0]).cpu().to(torch.int32) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - hidden_states = hidden_states.unsqueeze(1) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) - - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - deepstack_feature_lists = [] - for layer_num, blk in enumerate(self.blocks): - hidden_states = blk(hidden_states, - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin) - if layer_num in self.deepstack_visual_indexes: - deepstack_merger_idx = self.deepstack_visual_indexes.index( - layer_num) - deepstack_feature = self.deepstack_merger_list[ - deepstack_merger_idx](hidden_states) - deepstack_feature_lists.append(deepstack_feature) - hidden_states = self.merger(hidden_states) - hidden_states = torch.cat( - [hidden_states] + deepstack_feature_lists, - dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] - return hidden_states - - -@MULTIMODAL_REGISTRY.register_processor( - Qwen2_5_VLMultiModalProcessor, - info=Qwen2_5_VLProcessingInfo, - dummy_inputs=Qwen2_5_VLDummyInputsBuilder) -class AscendQwen2_5_VLForConditionalGeneration_Without_Padding( - Qwen2_5_VLForConditionalGeneration): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.visual = AscendQwen2_5_VisionTransformer_Without_Padding( - vision_config=config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - ) - - def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]: - - grid_thw = image_input["image_grid_thw"] - assert grid_thw.ndim == 2 - - if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) - else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - - # Split concatenated embeddings for each image item. - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return image_embeds.split(sizes.tolist()) - - def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]: - - grid_thw = video_input["video_grid_thw"] - assert grid_thw.ndim == 2 - - if video_input["type"] == "video_embeds": - video_embeds = video_input["video_embeds"].type(self.visual.dtype) - else: - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) - - # Split concatenated embeddings for each video item. 
- merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return video_embeds.split(sizes.tolist()) - - -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - supports_encoder_tp_data = True - - # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.visual.": "visual.", - "lm_head.": "language_model.lm_head.", - "model.language_model.": "language_model.model.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen3VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.visual = AscendQwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel) - - -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLMoeProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class AscendQwen3VLMoeForConditionalGeneration( - Qwen3VLMoeForConditionalGeneration): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - supports_encoder_tp_data = True - - # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.visual.": "visual.", - "lm_head.": "language_model.lm_head.", - "model.language_model.": "language_model.model.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.visual = AscendQwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - ) diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index 8578dec4e32..b1d7b5444a9 100644 --- a/vllm_ascend/models/qwen3_next.py +++ b/vllm_ascend/models/qwen3_next.py @@ -423,50 +423,20 @@ def _forward( non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 - batch_size = initial_state.shape[0] - core_attn_out = [] - last_recurrent_state = [] - - for b_idx in range(batch_size): - start, end = non_spec_query_start_loc[ - b_idx], non_spec_query_start_loc[b_idx + 1] - cur_q = query_non_spec[:, start:end, ...] - cur_k = key_non_spec[:, start:end, ...] - cur_v = value_non_spec[:, start:end, ...] - cur_g = g_non_spec[:, start:end, ...] - cur_b = beta_non_spec[:, start:end, ...] 
- cur_state = initial_state[b_idx].unsqueeze(0) - - ( - cur_core_attn_out_non_spec, - cur_last_recurrent_state, - ) = chunk.chunk_gated_delta_rule( - query=cur_q, - key=cur_k, - value=cur_v, - g=cur_g, - beta=cur_b, - initial_state=cur_state, - output_final_state=True, - use_qk_l2norm_in_kernel=True, - ) - - core_attn_out.append(cur_core_attn_out_non_spec) - last_recurrent_state.append(cur_last_recurrent_state) - - tar_dtype = core_attn_out[0].dtype - tar_device = core_attn_out[0].device - tar_shape = list(core_attn_out[0].shape) - tar_shape[1] = non_spec_query_start_loc[-1] - core_attn_out_non_spec = torch.empty(tar_shape, - dtype=tar_dtype, - device=tar_device) - for b_idx in range(batch_size): - cur_core_attn_out = core_attn_out[b_idx] - start, end = non_spec_query_start_loc[ - b_idx], non_spec_query_start_loc[b_idx + 1] - core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out - last_recurrent_state = torch.cat(last_recurrent_state, dim=0) + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk.chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True) # Init cache ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( @@ -675,7 +645,7 @@ def _forward_core( initial_state[~has_initial_state, ...] = 0 batch_size = initial_state.shape[0] - core_attn_out = [] + temp_core_attn_out = [] last_recurrent_state = [] for b_idx in range(batch_size): @@ -702,18 +672,18 @@ def _forward_core( use_qk_l2norm_in_kernel=True, ) - core_attn_out.append(cur_core_attn_out_non_spec) + temp_core_attn_out.append(cur_core_attn_out_non_spec) last_recurrent_state.append(cur_last_recurrent_state) - tar_dtype = core_attn_out[0].dtype - tar_device = core_attn_out[0].device - tar_shape = list(core_attn_out[0].shape) + tar_dtype = temp_core_attn_out[0].dtype + tar_device = temp_core_attn_out[0].device + tar_shape = list(temp_core_attn_out[0].shape) tar_shape[1] = non_spec_query_start_loc[-1] core_attn_out_non_spec = torch.empty(tar_shape, dtype=tar_dtype, device=tar_device) for b_idx in range(batch_size): - cur_core_attn_out = core_attn_out[b_idx] + cur_core_attn_out = temp_core_attn_out[b_idx] start, end = non_spec_query_start_loc[ b_idx], non_spec_query_start_loc[b_idx + 1] core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out diff --git a/vllm_ascend/models/qwen3_vl.py b/vllm_ascend/models/qwen3_vl.py new file mode 100644 index 00000000000..c79e71e7197 --- /dev/null +++ b/vllm_ascend/models/qwen3_vl.py @@ -0,0 +1,264 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
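[Editor's note] The qwen3_next.py hunk above replaces the per-sequence Python loop with a single variable-length call into the `chunk_gated_delta_rule` wrapper added by this patch under vllm_ascend/ops/triton/fla/chunk.py. A hedged sketch of that calling convention (shapes are illustrative, the real inputs come from the model's state management, and the kernel needs an Ascend Triton build, so the call itself is left commented out):

    import torch
    # from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule

    # Two sequences of length 2 and 4, flattened into a single batch of T = 6.
    B, T, H, K, V = 1, 6, 4, 64, 64
    q = torch.randn(B, T, H, K, dtype=torch.bfloat16)
    k = torch.randn(B, T, H, K, dtype=torch.bfloat16)
    v = torch.randn(B, T, H, V, dtype=torch.bfloat16)
    g = torch.randn(B, T, H, dtype=torch.bfloat16)          # log-space gates
    beta = torch.rand(B, T, H, dtype=torch.bfloat16)
    cu_seqlens = torch.tensor([0, 2, 6], dtype=torch.long)  # N+1 sequence boundaries
    h0 = torch.zeros(2, H, K, V, dtype=torch.float32)       # one initial state per sequence

    # o, ht = chunk_gated_delta_rule(q, k, v, g, beta,
    #                                initial_state=h0,
    #                                output_final_state=True,
    #                                cu_seqlens=cu_seqlens,
    #                                head_first=False,
    #                                use_qk_l2norm_in_kernel=True)
    # expected shapes per the docstring below: o -> [1, 6, 4, 64], ht -> [2, 4, 64, 64]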
+ +from functools import partial +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from transformers.models.qwen3_vl.configuration_qwen3_vl import \ + Qwen3VLConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \ + Qwen3VLMoeConfig +except ImportError: + pass +from vllm.config import VllmConfig +from vllm.distributed import utils as dist_utils +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention + +try: + from vllm.model_executor.models.qwen3_vl import ( + Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer, + Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) + from vllm.model_executor.models.qwen3_vl_moe import ( + Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo) +except ImportError: + Qwen3_VisionBlock = object + Qwen3_VisionPatchEmbed = object + Qwen3_VisionTransformer = object + Qwen3VLDummyInputsBuilder = object + Qwen3VLForConditionalGeneration = object + Qwen3VLMultiModalProcessor = object + Qwen3VLProcessingInfo = object + Qwen3VLMoeForConditionalGeneration = object + Qwen3VLMoeProcessingInfo = object +from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix +from vllm.multimodal import MULTIMODAL_REGISTRY + + +class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.matmul( + self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) + x = x + self.proj.bias + return x + + +class AscendQwen3_VisionBlock(Qwen3_VisionBlock): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, + quant_config, prefix, use_data_parallel) + self.attn = Qwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__(vision_config, norm_eps, quant_config, prefix, + use_data_parallel) + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + self.patch_embed = AscendQwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + self.blocks = nn.ModuleList([ + AscendQwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for 
layer_idx in range(vision_config.depth) + ]) + self.hidden_size_per_attention_head = dist_utils.divide( + self.hidden_size, self.num_heads) + + def cal_cos_sin(self, rotary_pos_emb): + cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] + sin = rotary_pos_emb.sin() + cos_new = torch.cat((cos, cos), dim=-1) + sin_new = torch.cat((sin, sin), dim=-1) + cos_new = cos_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + sin_new = sin_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + return cos_new, sin_new + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + grid_thw_tensor = torch.tensor(grid_thw, + device=self.device, + dtype=torch.int32) + cu_seqlens = torch.repeat_interleave( + grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], + grid_thw_tensor[:, 0]).cpu().to(torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cos, sin = self.cal_cos_sin(rotary_pos_emb) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens, + cos=cos, + sin=sin) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + supports_encoder_tp_data = True + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen3VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.visual = AscendQwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLMoeProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class AscendQwen3VLMoeForConditionalGeneration( + Qwen3VLMoeForConditionalGeneration): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + supports_encoder_tp_data = True + + # To ensure correct weight loading and mapping. 
+ hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.visual = AscendQwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) diff --git a/vllm_ascend/ops/triton/fla/chunk.py b/vllm_ascend/ops/triton/fla/chunk.py new file mode 100644 index 00000000000..2d3dade7741 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk.py @@ -0,0 +1,226 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +import warnings +from typing import Optional + +import torch +from einops import rearrange +from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd +from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .solve_tril import solve_tril +from .utils import input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. 
+ A = chunk_scaled_dot_kkt_fwd(k=k, + beta=beta, + g_cumsum=g, + cu_seqlens=cu_seqlens, + output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len( + beta.shape + ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2) + q, k, v, beta, g = map( + lambda x: rearrange(x, 'b h t ... -> b t h ...'), + (q, k, v, beta, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if initial_state is not None and initial_state.shape[0] != len( + cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1]**-0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, + use_qk_l2norm_in_kernel) + if head_first: + o = rearrange(o, 'b t h ... 
-> b h t ...') + return o, final_state \ No newline at end of file diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_h.py b/vllm_ascend/ops/triton/fla/chunk_delta_h.py new file mode 100644 index 00000000000..846623ad53f --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk_delta_h.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp + +_CONDITIONS = ("seq7168", ) + + +@triton.heuristics({ + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_nh = tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + T_max = 1 * T + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + stride_v = H * V + stride_k = Hg * K + stride_w = H * K + + b_h1_bv1 = tl.zeros([128, 64], dtype=tl.float32) + b_h1_bv2 = tl.zeros([128, 64], dtype=tl.float32) + + v_start1 = 0 + v_start2 = 64 + + offs_k = tl.arange(0, 128)[:, None] + offs_v1 = v_start1 + tl.arange(0, 64)[None, :] + offs_v2 = v_start2 + tl.arange(0, 64)[None, :] + mask_kv1 = (offs_k < K) & (offs_v1 < V) + mask_kv2 = (offs_k < K) & (offs_v2 < V) + + # load initial state + if USE_INITIAL_STATE: + h0_ptr = h0 + i_nh * K * V + ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1 + b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1, + other=0.0).to(tl.float32) + + ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1 + b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2, + other=0.0).to(tl.float32) + + # main recurrence + for i_t in range(NT): + h_base = h + (boh + i_t) * H * K * V + i_h * K * V + + p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1), + (128, 64), (1, 0)) + tl.store(p_h1_bv1, + b_h1_bv1.to(p_h1_bv1.dtype.element_ty), + boundary_check=(0, 1)) + + p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2), + (128, 64), (1, 0)) + tl.store(p_h1_bv2, + b_h1_bv2.to(p_h1_bv2.dtype.element_ty), + boundary_check=(0, 1)) + + offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None] + offs_k_wv = tl.arange(0, 128)[None, :] + mask_w = (offs_t_wv < T) & (offs_k_wv < K) + + w_base = w + bos * H * K + i_h * K + ptr_w = w_base + offs_t_wv * stride_w + offs_k_wv * 1 + b_w = tl.load(ptr_w, mask=mask_w, other=0.0) + + 
k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K + p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT), + (128, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + v_new_base = v_new + bos * H * V + i_h * V + + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos + i_h * T_max + last_idx) + + offs_t = i_t * BT + tl.arange(0, BT) + mask_t = offs_t < T + g_ptr = g + bos + i_h * T_max + b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0) + + b_g = safe_exp(b_g_last - b_g) + b_g_last = tl.exp(b_g_last) + + offs_t_v = (i_t * BT + tl.arange(0, BT))[:, None] + mask_v1 = (offs_t_v < T) & (offs_v1 < V) + + v_base = v + bos * H * V + i_h * V + ptr_v1 = v_base + offs_t_v * stride_v + offs_v1 * 1 + b_v1 = tl.load(ptr_v1, mask=mask_v1, other=0.0) + b_v_new1 = b_v1.to(tl.float32) + b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype)) + + if SAVE_NEW_VALUE: + p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), + (i_t * BT, v_start1), (BT, 64), + (1, 0)) + tl.store(p_v_new1, + b_v_new1.to(p_v_new1.dtype.element_ty), + boundary_check=(0, 1)) + + if USE_G: + b_v_new1 = b_v_new1 * b_g[:, None] + b_h1_bv1 = b_h1_bv1 * b_g_last + + b_v_new1 = b_v_new1.to(k.dtype.element_ty) + b_h1_bv1 += tl.dot(b_k, b_v_new1) + + mask_v2 = (offs_t_v < T) & (offs_v2 < V) + ptr_v2 = v_base + offs_t_v * stride_v + offs_v2 * 1 + b_v2 = tl.load(ptr_v2, mask=mask_v2, other=0.0) + b_v_new2 = b_v2.to(tl.float32) + b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype)) + + if SAVE_NEW_VALUE: + p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), + (i_t * BT, v_start2), (BT, 64), + (1, 0)) + tl.store(p_v_new2, + b_v_new2.to(p_v_new2.dtype.element_ty), + boundary_check=(0, 1)) + + if USE_G: + b_v_new2 = b_v_new2 * b_g[:, None] + b_h1_bv2 = b_h1_bv2 * b_g_last + + b_v_new2 = b_v_new2.to(k.dtype.element_ty) + b_h1_bv2 += tl.dot(b_k, b_v_new2) + + # epilogue + if STORE_FINAL_STATE: + ht_ptr = ht + i_nh * K * V + + p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1), + (128, 64), (1, 0)) + tl.store(p_ht1_bv1, + b_h1_bv1.to(p_ht1_bv1.dtype.element_ty), + boundary_check=(0, 1)) + + p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2), + (128, 64), (1, 0)) + tl.store(p_ht1_bv2, + b_h1_bv2.to(p_ht1_bv2.dtype.element_ty), + boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None else None) + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." 
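    # [Editor's note] Illustrative numbers for the varlen bookkeeping above,
    # assuming prepare_chunk_offsets returns the exclusive prefix sum of
    # per-sequence chunk counts (hypothetical lengths, not from this patch):
    #   cu_seqlens = [0, 100, 228], BT = 64
    #   per-sequence lengths -> [100, 128]
    #   chunks per sequence  -> [ceil(100/64), ceil(128/64)] = [2, 2]
    #   chunk_offsets        -> [0, 2, 4], so N = 2 and NT = 4
    # With the varlen batch dimension B = 1, h below is then allocated as
    # (B, NT, H, K, V) = (1, 4, H, K, V).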
+ + h = k.new_empty(B, NT, H, K, V) + final_state = (k.new_empty(N, H, K, V, dtype=torch.float32) + if output_final_state else None) + + v_new = torch.empty_like(u) if save_new_value else None + g = g.transpose(1, 2).contiguous() + + def grid(meta): + return (1, N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + num_warps=4, + num_stages=2, + ) + return h, v_new, final_state diff --git a/vllm_ascend/ops/triton/fla/chunk_o.py b/vllm_ascend/ops/triton/fla/chunk_o.py new file mode 100644 index 00000000000..5a3578a8261 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk_o.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_offsets, safe_exp + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + T_max = T + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int64) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + + for i_t in range(NT): + i_tg = boh + i_t + h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), + (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), + (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), + (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + offs_t = i_t * BT + tl.arange(0, BT) + mask_t = offs_t < T + g_ptr = g + bos + i_h * T_max + b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0) + + b_o = b_o * tl.exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_i = tl.arange(0, BT).to(tl.float32) + m_A = o_i[:, None] >= o_i[None, :] + b_A = tl.where(m_A, b_A, 0) 
+ + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + # to fix mma -> mma layout conversion + # already solved by fla v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = chunk_size + + if scale is None: + scale = k.shape[-1]**-0.5 + + o = torch.empty_like(v) + if cu_seqlens is None: + N, chunk_offsets = B, None + else: + N, chunk_offsets = ( + len(cu_seqlens) - 1, + prepare_chunk_offsets(cu_seqlens, BT), + ) + + def grid(meta): + return (triton.cdiv(V, meta['BV']), N * H) + + g = g.transpose(1, 2).contiguous() + chunk_fwd_kernel_o[grid]( + q=q, + k=k, + v=v, + h=h, + g=g, + o=o, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=128, + BV=128, + num_warps=4, + num_stages=2, + ) + return o diff --git a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py new file mode 100644 index 00000000000..aa183149a67 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices, safe_exp + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_G': lambda args: args['g_cumsum'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, # [H, B, T] + g_cumsum, # [H, B, T] + A, + cu_seqlens, + chunk_indices, + T, + B, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + bt_stride = B * T + i_t_i, _ = tl.program_id(0), tl.program_id(1) + + for i_bh in range(B * H): + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t_i * 2).to( + tl.int32), tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + i_t = i_t_i + o_t = tl.arange(0, BT) + o_t_fp32 = o_t.to(tl.float32) + + p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T, ), (1, ), + (i_t * BT, ), (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), + (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T, ), + (1, ), (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A *= safe_exp(b_g_diff) + + b_A *= b_beta[:, None] + b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), + (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. + gk (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`. + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. 
+ """ + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.cpu() + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + chunk_indices = chunk_indices.npu() + cu_seqlens = cu_seqlens.npu() + else: + chunk_indices = None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + + chunk_scaled_dot_kkt_fwd_kernel[(NT, 1)]( + k=k, + beta=torch.permute(beta, (2, 0, 1)).contiguous(), + g_cumsum=torch.permute(g_cumsum, (2, 0, 1)).contiguous(), + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + Hg=Hg, + K=K, + BT=BT, + BK=128, + num_warps=8, + num_stages=3, + multibuffer=True, + ) + return A diff --git a/vllm_ascend/ops/triton/fla/cumsum.py b/vllm_ascend/ops/triton/fla/cumsum.py new file mode 100644 index 00000000000..e93a2438ffa --- /dev/null +++ b/vllm_ascend/ops/triton/fla/cumsum.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices + + +@triton.heuristics({ + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BLOCK_T: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, + CHUNK_SIZE: tl.constexpr = 64, +): + i_block, i_b = tl.program_id(0), tl.program_id(1) + N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE + + if IS_VARLEN: + i_s, i_block = tl.load(chunk_indices + i_block * 2).to( + tl.int32), tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_s).to( + tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1), + (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)) + ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1), + (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)) + b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32) + b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE)) + b_s = tl.trans(b_s, (2, 0, 1)) + b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE) + if HAS_SCALE: + b_o *= scale + b_o = tl.trans(b_o, (2, 0, 1)) + b_o = tl.reshape(b_o, (H, BLOCK_T)) + else: + ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), + (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)) + ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), + (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)) + b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32) + b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H)) + b_s = tl.trans(b_s, (1, 0, 2)) + b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE) + if HAS_SCALE: + b_o *= scale + b_o = tl.trans(b_o, (1, 0, 2)) + b_o = tl.reshape(b_o, (BLOCK_T, H)) + + tl.store(ptr_o, 
b_o.to(s.dtype.element_ty), boundary_check=(0, )) + return + + +def chunk_local_cumsum_scalar( + g, + chunk_size, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.Tensor] = torch.float, +): + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2**(chunk_size.bit_length() - + 1), "chunk_size must be a power of 2" + OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size)) + block_indices = prepare_chunk_indices( + cu_seqlens, + chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None + num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv( + T, OPTIM_BLOCK_SIZE) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (num_blocks, B) + chunk_local_cumsum_scalar_kernel[grid](s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=block_indices, + T=T, + B=B, + H=H, + BLOCK_T=OPTIM_BLOCK_SIZE, + CHUNK_SIZE=chunk_size, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=8, + num_stages=3) + return g + + +def chunk_local_cumsum(g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[ + 0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype) + else: + raise ValueError(f"Unsupported input shape {g.shape}, " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise") diff --git a/vllm_ascend/ops/triton/fla/fla.py b/vllm_ascend/ops/triton/fla/layernorm_guard.py similarity index 62% rename from vllm_ascend/ops/triton/fla/fla.py rename to vllm_ascend/ops/triton/fla/layernorm_guard.py index 79039002d1f..c99f9e08d4b 100644 --- a/vllm_ascend/ops/triton/fla/fla.py +++ b/vllm_ascend/ops/triton/fla/layernorm_guard.py @@ -7,7 +7,6 @@ # mypy: ignore-errors import torch -import torch.nn.functional as F from vllm.triton_utils import tl, triton MAX_CORES = 65535 @@ -200,100 +199,3 @@ def forward( is_rms_norm=is_rms_norm, ) return y.reshape(x_shape_og) - - -def torch_chunk_gated_delta_rule( - query, - key, - value, - g, - beta, - chunk_size=64, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=False, -): - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = F.normalize(query, p=2, dim=-1) - key = F.normalize(key, p=2, dim=-1) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - batch_size, sequence_length, num_heads, k_head_dim = key.shape - v_head_dim = value.shape[-1] - pad_size = (chunk_size - num_heads % chunk_size) % chunk_size - query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) - key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) - value = F.pad(value, (0, 0, 0, pad_size)) - beta = F.pad(beta, (0, pad_size)) - g = F.pad(g, (0, pad_size)) - tot_heads = num_heads + pad_size - scale = 1 / (query.shape[-1]**0.5) - query = query * scale - - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - # reshape to chunks - query, key, value, k_beta, v_beta = [ - 
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) - for x in (query, key, value, k_beta, v_beta) - ] - g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=0) - - # chunk decay - g = g.cumsum(dim=-1) - decay_mask = ((g.unsqueeze(-1) - - g.unsqueeze(-2)).tril().exp().float()).tril() - attn = -( - (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - - last_recurrent_state = (torch.zeros(batch_size, sequence_length, - k_head_dim, v_head_dim).to(value) if - initial_state is None else initial_state.to(value)) - - core_attn_out = torch.zeros_like(value) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=1) - - # for each chunk - for i in range(0, tot_heads // chunk_size): - q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * - decay_mask[:, :, i]).masked_fill_(mask, 0) - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn @ v_new - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() + - (k_i * - (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( - -1, -2) @ v_new) - - if not output_final_state: - last_recurrent_state = None - core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], - core_attn_out.shape[1], -1, - core_attn_out.shape[-1]) - core_attn_out = core_attn_out[:, :, :num_heads] - core_attn_out = core_attn_out.transpose(1, - 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py new file mode 100644 index 00000000000..a80003207ca --- /dev/null +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -0,0 +1,419 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, + LARGE_BLOCK_T: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + base_t = i_t * LARGE_BLOCK_T + + NTASKS: tl.constexpr = 2 + N_BLOCKS: tl.constexpr = LARGE_BLOCK_T // 16 // NTASKS + + for taskid in range(0, NTASKS): + base_t += taskid * (LARGE_BLOCK_T // NTASKS) + + # use make_block_ptr to reduce vector computation + b_A = tl.zeros((N_BLOCKS, 16, 16), dtype=tl.float32) + for blkid in range(0, N_BLOCKS): + row_start_o = base_t + blkid * 16 + col_start_o = row_start_o % BT + + # 1 Create in-block offset + offs_rows_in_block = tl.arange(0, 16) + offs_cols_in_block = tl.arange(0, 16) + + # 2 Calculate the pointer of each element + ptr_A_subrec16 = (A + row_start_o * H * BT + col_start_o + + offs_rows_in_block[:, None] * H * BT + + offs_cols_in_block[None, :]) + + # 3 Create a mask to prevent out-of-bounds access + global_rows = row_start_o + offs_rows_in_block[:, None] + global_cols = col_start_o + offs_cols_in_block[None, :] + load_mask = (global_rows < T) & (global_cols < BT) + + # 4 Use mask to safely load data + b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, + other=0.0).to(tl.float32) + b_A = tl.insert_slice( + ful=b_A, + sub=b_A_subrec16[None, :, :], # (1, 16, 16) + offsets=[blkid, 0, 0], + sizes=[1, 16, 16], + strides=[1, 1, 1]) + + local_ori_A = tl.trans(b_A, (1, 0, 2)) + local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS)) + + # Convert mask into matrix multiplication to avoid for loops ub oom + tmp = tl.arange(0, 16).to(tl.float32) + rows = tmp[:, None] + cols = tmp[None, :] + is_lower = (rows > cols).to(b_A.dtype) + b_A = -b_A * is_lower + + # for loop to update N_BLOCKS row vector + for i in range(1, 16): + nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), + (1, 16 * N_BLOCKS), + (16 * N_BLOCKS, 1)) + b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16)) + + dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2)) + dot_product = tl.sum(dot_tmp, 0) + b_a = b_a + dot_product + + b_a_new_expanded = b_a[:, None, :] + b_A = tl.insert_slice(ful=b_A, + sub=b_a_new_expanded, + offsets=[0, i, 0], + sizes=[N_BLOCKS, 1, 16], + strides=[1, 1, 1]) + + on_diagonal = (rows == cols) + b_A = tl.where(on_diagonal, b_A + 1.0, b_A) + + b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0), + (N_BLOCKS * 16, 16), (1, 0)) + + # 1 Create in-block offset + offs_rows_to_store = tl.arange(0, N_BLOCKS * 16) + offs_cols_to_store = tl.arange(0, 16) + + # 2 Calculate the pointer of each element + p_Ai = (Ad + base_t * H * 16 + 0 + + 
offs_rows_to_store[:, None] * H * 16 + + offs_cols_to_store[None, :]) + # 3 Create a mask to prevent out-of-bounds access, only check rows + global_store_rows = base_t + offs_rows_to_store[:, None] + store_mask = global_store_rows < T + # 4 use mask to save data safely + tl.store(p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + mask=store_mask) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), + (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot( + tl.dot(Ai_22, A_21, input_precision="ieee"), + Ai_11, + input_precision="ieee", + ) + tl.store( + p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t_val = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + i_t = i_t_val + else: + bos, eos = i_b * T, i_b * T + T + + # Base pointers (already offset by batch and head) + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + # load Ai_22 (Ad block at row i_t * 64 + 16, col 0, 16 * 16) + offs_m = i_t * 64 + 16 + tl.arange(0, 16) + offs_n = tl.arange(0, 16) + mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16) + ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :] + Ai_22 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32) + + # load A_21 (A block at row i_t * 64 + 16, 
col 0, 16 * 16) + mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64) + ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :] + A_21 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) + tmp = tl.dot(Ai_22, A_21, input_precision="ieee") + + # load Ai_11 (Ad block at row i_t * 64, col 0, 16 * 16) + offs_m = i_t * 64 + tl.arange(0, 16) + offs_n = tl.arange(0, 16) + mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16) + ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :] + Ai_11 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32) + + Ai_21 = -tl.dot(tmp, Ai_11, input_precision="ieee") + + # load Ai_44 (Ad block at row i_t * 64 + 48, col 0, 16 * 16) + offs_m = i_t * 64 + 48 + tl.arange(0, 16) + offs_n = tl.arange(0, 16) + mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16) + ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :] + Ai_44 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32) + + # load A_43 (Ad block at row i_t * 64 + 48, col 32, 16 * 16) + offs_n = 32 + tl.arange(0, 16) + mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64) + ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :] + A_43 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) + tmp = tl.dot(Ai_44, A_43, input_precision="ieee") + + # load Ai_33 (Ad block at row i_t * 64 + 32, col 0, 16 * 16) + offs_m = i_t * 64 + 32 + tl.arange(0, 16) + offs_n = tl.arange(0, 16) + mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16) + ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :] + Ai_33 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32) + + Ai_43 = -tl.dot(tmp, Ai_33, input_precision="ieee") + + # build Ai_22_32 (32 * 32) + Ai_22_32 = tl.zeros((32, 32), tl.float32) + Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1)) + Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1)) + Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1)) + + # load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32) + offs_m = i_t * 64 + 32 + tl.arange(0, 32) + offs_n = tl.arange(0, 32) + mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64) + ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :] + A_21_32 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) + tmp = tl.dot(Ai_22_32, A_21_32, input_precision="ieee") + + # build Ai_11_32 (32 * 32) + Ai_11_32 = tl.zeros((32, 32), tl.float32) + Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1)) + Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1)) + Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1)) + + Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee") + + # store Ai_11_32 to (i_t * 64, 0) + offs_m = i_t * 64 + tl.arange(0, 32) + offs_n = tl.arange(0, 32) + mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) + ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :] + tl.store(ptr_Ai, + Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + mask=mask_store) + + # store Ai_22_32 to (i_t * 64 + 32, 32) + offs_m = i_t * 64 + 32 + tl.arange(0, 32) + offs_n = 32 + tl.arange(0, 32) + mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) + ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :] + tl.store(ptr_Ai, + Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + mask=mask_store) + + # store Ai_21_32 to (i_t * 64 + 32, 32) + offs_n = tl.arange(0, 32) + mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) + ptr_Ai = Ai + offs_m[:, None] * (H * 64) + 
offs_n[None, :] + tl.store(ptr_Ai, + Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + mask=mask_store) + + # zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63) + offs_m = i_t * 64 + tl.arange(0, 32) + offs_n = 32 + tl.arange(0, 32) + mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < BT) + ptr_Ai = Ai + offs_m[:, None] * (H * BT) + offs_n[None, :] + zero_block = tl.zeros((32, 32), dtype=ptr_Ai.dtype.element_ty) + tl.store(ptr_Ai, zero_block, mask=mask_store) + + +def solve_tril( + A: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + """ + Compute the inverse of the matrix I + A + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, BT], where BT should only be 16, 32, or 64. + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. Default: `None`. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float`. + If `None`, the output dtype will be the same as the input dtype. + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty(B, + T, + H, + 16, + device=A.device, + dtype=torch.float if BT != 16 else output_dtype) + + LARGE_BLOCK_T = 608 * 2 + + chunk_indices = (prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) + if cu_seqlens is not None else None) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv( + T, LARGE_BLOCK_T) + + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + LARGE_BLOCK_T=LARGE_BLOCK_T, + num_warps=1, + num_stages=4, + ) + + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = (merge_16x16_to_32x32_inverse_kernel + if BT == 32 else merge_16x16_to_64x64_inverse_kernel) + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + num_warps=4, + num_stages=3, + ) + return Ai diff --git a/vllm_ascend/ops/triton/fla/utils.py b/vllm_ascend/ops/triton/fla/utils.py new file mode 100644 index 00000000000..4d2cd1350ff --- /dev/null +++ b/vllm_ascend/ops/triton/fla/utils.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +from typing import Callable + +import torch +from vllm.triton_utils import tl, triton + + +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + indices = torch.cat([ + torch.arange(n) + for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() + ]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], + 1).to(cu_seqlens) + + +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + return torch.cat([ + cu_seqlens.new_tensor([0]), + triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + ]).cumsum(-1) + + +def input_guard( + fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else + i.contiguous() for i in args) + contiguous_kwargs = { + k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) + for k, v in kwargs.items() + } + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.npu.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@triton.jit +def safe_exp(x): + return tl.exp(tl.where(x <= 0, x, float("-inf"))) diff --git a/vllm_ascend/ops/triton/fla/wy_fast.py b/vllm_ascend/ops/triton/fla/wy_fast.py new file mode 100644 index 00000000000..1d4c295553f --- /dev/null +++ b/vllm_ascend/ops/triton/fla/wy_fast.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional, Tuple + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices, + T, H: tl.constexpr, Hg: tl.constexpr, + K: tl.constexpr, V: tl.constexpr, + BT: tl.constexpr, BK: tl.constexpr, + BV: tl.constexpr, IS_VARLEN: tl.constexpr): + T_max = T + i_t_o = tl.program_id(0) + + for i_bh in range(H): + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t_o * 2).to( + tl.int32), tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + offs_t = tl.arange(0, BT) + global_offs_t = i_t * BT + offs_t + mask_t = global_offs_t < T + + offs_t_2d = global_offs_t[:, None] + offs_bt = tl.arange(0, BT)[None, :] + ptr_A = (A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1) + mask_A = mask_t[:, None] + b_A = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) + + ptr_g = g + bos + i_h * T_max + global_offs_t + b_g = tl.exp(tl.load(ptr_g, mask=mask_t, other=0.0)).to(tl.float32) + + ptr_beta = beta + bos + i_h * T_max + global_offs_t + b_beta = tl.load(ptr_beta, mask=mask_t, other=0.0).to(tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + offs_v = i_v * BV + tl.arange(0, BV)[None, :] + mask_v = (mask_t[:, None]) & (offs_v < V) + + ptr_v = (v + (bos * H + i_h) * V + offs_t_2d * (H * V) + + offs_v * 1) + b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32) + + b_vb = (b_v * b_beta[:, None]) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + + ptr_u = (u + (bos * H + i_h) * V + offs_t_2d * (H * V) + + offs_v * 1) + tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v) + + for i_k in range(tl.cdiv(K, BK)): + offs_k = i_k * BK + tl.arange(0, BK)[None, :] + mask_k = (mask_t[:, None]) & (offs_k < K) + ptr_k = (k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d * + (Hg * K) + offs_k * 1) + b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32) + + b_kb = (b_k * b_beta[:, None] * b_g[:, None]) + b_w = tl.dot(b_A, b_kb) + + ptr_w = (w + (bos * H + i_h) * K + offs_t_2d * (H * K) + + offs_k * 1) + tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) \ + if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + BK = 64 + BV = 64 + + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + beta = beta.transpose(1, 2).contiguous() + g_cumsum = g_cumsum.transpose(1, 2).contiguous() + recompute_w_u_fwd_kernel[(NT, B)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + num_warps=4, + num_stages=3, + ) + 
return w, u diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index a361789f3dd..faa57b6140f 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -27,3 +27,5 @@ import vllm_ascend.patch.worker.patch_weight_loader # noqa import vllm_ascend.patch.worker.patch_multimodal_merge # noqa import vllm_ascend.patch.worker.patch_minicpm # noqa +import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa +import vllm_ascend.patch.worker.patch_rope # noqa diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py new file mode 100644 index 00000000000..27f08751bff --- /dev/null +++ b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py @@ -0,0 +1,501 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from functools import lru_cache, partial + +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import \ + Qwen2_5_VLVisionConfig +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.layer import maybe_get_vit_flash_attn_backend +from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, + Qwen2_5_VisionPatchMerger, Qwen2_5_VisionTransformer, + Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLImageInputs, + Qwen2_5_VLVideoInputs) +from vllm.model_executor.models.utils import cast_overflow_tensors +from vllm.model_executor.models.vision import ( + get_vit_attn_backend, run_dp_sharded_mrope_vision_model) + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_forward_context import set_ascend_forward_context + +MIN_PAD_SIZE = 64 # min_size to pad weight +MAX_PAD_SIZE = 128 # max_size to pad weight + + +class AscendQwen2_5_VisionAttention(nn.Module): + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + max_seqlen: torch.Tensor, + seqlens: torch.Tensor, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + seq_len, batch_size, _ = x.shape + + # Split q k v. + qkv = einops.rearrange( + x, + "s b (three head head_dim) -> b s three head head_dim", + three=3, + head=self.num_attention_heads_per_partition, + ) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + origin_shape = q.shape[-1] + + # Convert cumulative tensor to intervals and move it to cpu. 
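+ # e.g. cu_seqlens = tensor([0, 16, 48, 64]) -> torch.diff -> tensor([16, 32, 16]),
+ # i.e. the per-sequence lengths that the unpadded attention op below expects.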
+ cu_seqlens = torch.diff(cu_seqlens).to("cpu") + + cos = rotary_pos_emb_cos + sin = rotary_pos_emb_sin + cos = einops.rearrange( + torch.stack((cos, cos), dim=-1), + "... d two -> ...(d two)", + two=2, + ) + sin = einops.rearrange( + torch.stack((sin, sin), dim=-1), + "... d two -> ...(d two)", + two=2, + ) + cos = cos.reshape(1, -1, 1, self.hidden_size_per_attention_head) + sin = sin.reshape(1, -1, 1, self.hidden_size_per_attention_head) + q = torch_npu.npu_rotary_mul(q, cos, sin) + k = torch_npu.npu_rotary_mul(k, cos, sin) + + q, k, v = [ + einops.rearrange(x, "b s h d -> (b s) h d").contiguous() + for x in (q, k, v) + ] + + enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL + and self.hidden_size_per_attention_head > MIN_PAD_SIZE + and self.hidden_size_per_attention_head < MAX_PAD_SIZE) + + if enable_pad: + pad_len = MAX_PAD_SIZE - origin_shape + # q/k/v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE] + q = F.pad(q, (0, pad_len), mode="constant", value=0) + k = F.pad(k, (0, pad_len), mode="constant", value=0) + v = F.pad(v, (0, pad_len), mode="constant", value=0) + + context_layer = torch.empty_like(q) + + # operator requires pta version >= 2.5.1 + torch_npu._npu_flash_attention_unpad( + query=q, + key=k, + value=v, + seq_len=cu_seqlens, + scale_value=self.hidden_size_per_attention_head**-0.5, + num_heads=self.num_attention_heads_per_partition, + num_kv_heads=self.num_attention_heads_per_partition, + out=context_layer, + ) + + if enable_pad: + context_layer = context_layer[..., :origin_shape] + + context_layer = einops.rearrange(context_layer, + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() + + output, _ = self.proj(context_layer) + return output + + +class AscendQwen2_5_VisionBlock(nn.Module): + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + max_seqlen: torch.Tensor, # Only used for Flash Attention + seqlens: torch.Tensor, # Only used for xFormers + ) -> torch.Tensor: + x_attn = self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + x_fused_norm, residual = self.norm2(x, residual=x_attn) + x = residual + self.mlp(x_fused_norm) + return x + + +class AscendQwen2_5_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2_5_VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + attn_backend_override: AttentionBackendEnum | None = None, + ) -> None: + nn.Module.__init__(self) + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + in_channels = vision_config.in_channels + depth = vision_config.depth + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.out_hidden_size + + # args for get_window_index_thw + self.window_size = vision_config.window_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.fullatt_block_indexes = vision_config.fullatt_block_indexes + self.spatial_merge_unit = self.spatial_merge_size**2 + # TODO[@lucaskabela]: Investigate fixing this usage + # see https://github.com/vllm-project/vllm/issues/27044 + # DO NOT MOVE THIS IMPORT + from vllm.compilation.backends import set_model_tag + + with 
set_model_tag("Qwen2_5_VisionPatchEmbed"): + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) + + norm_layer = partial(RMSNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) + + use_upstream_fa = False + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( + self.attn_backend, + use_upstream_fa, + attn_backend_override=attn_backend_override, + )) + + with set_model_tag("Qwen2_5_VisionBlock"): + self.blocks = nn.ModuleList([ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=get_act_and_mul_fn(vision_config.hidden_act), + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + attn_backend_override=attn_backend_override, + ) for layer_idx in range(depth) + ]) + + with set_model_tag("Qwen2_5_VisionPatchMerger"): + self.merger = Qwen2_5_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) + + def rotary_pos_emb_thw(self, t, h, w): + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = (hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten()) + wpos_ids = (wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten()) + pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) + max_size = max(h, w) + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_size) + + cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) + cos_w = cos[pos_ids[:, 1]] + sin_h = sin[pos_ids[:, 0]] + sin_w = sin[pos_ids[:, 1]] + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + + cos_combined = cos_combined.reshape( + cos_combined.shape[0] // self.spatial_merge_unit, + self.spatial_merge_unit, + -1, + ) + sin_combined = sin_combined.reshape( + sin_combined.shape[0] // self.spatial_merge_unit, + self.spatial_merge_unit, + -1, + ) + + return cos_combined, sin_combined + + @lru_cache(maxsize=1024) # noqa: B019 + def get_rope_by_thw(self, t, h, w): + window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw( + t, h, w) + cos_thw, sin_thw = self.rotary_pos_emb_thw(t, h, w) + + cos_thw = cos_thw[window_index_thw, :, :] + cos_thw = cos_thw.flatten(start_dim=0, end_dim=1) + sin_thw = sin_thw[window_index_thw, :, :] + sin_thw = sin_thw.flatten(start_dim=0, end_dim=1) + + cu_seqlens_thw = torch.repeat_interleave( + torch.tensor([h * w], dtype=torch.int32), t) + return ( + cos_thw, + sin_thw, + window_index_thw, + 
cu_seqlens_window_thw, + cu_seqlens_thw, + ) + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + # patchify + seq_len, _ = x.size() + rotary_pos_emb_cos: list = [] + rotary_pos_emb_sin: list = [] + window_index: list = [] + cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)] + cu_seqlens: list = [] + + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + window_index_id = 0 + cu_window_seqlens_last = 0 + for t, h, w in grid_thw: + t, h, w = int(t), int(h), int(w) + llm_h = h // self.spatial_merge_size + llm_w = w // self.spatial_merge_size + + ( + cos_thw, + sin_thw, + window_index_thw, + cu_seqlens_window_thw, + cu_seqlens_thw, + ) = self.get_rope_by_thw(t, h, w) + + window_index.append(window_index_thw + window_index_id) + window_index_id += t * llm_h * llm_w + + cu_seqlens_window_thw = cu_seqlens_window_thw + cu_window_seqlens_last + cu_window_seqlens_last = cu_seqlens_window_thw[-1] + cu_window_seqlens.append(cu_seqlens_window_thw) + + rotary_pos_emb_cos.append(cos_thw) + rotary_pos_emb_sin.append(sin_thw) + + cu_seqlens.append(cu_seqlens_thw) + + rotary_pos_emb_cos = torch.cat(rotary_pos_emb_cos) + rotary_pos_emb_sin = torch.cat(rotary_pos_emb_sin) + window_index = torch.cat(window_index) + # compute reverse indices + reverse_indices = self.invert_permutation(window_index) + cu_window_seqlens = torch.cat(cu_window_seqlens) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_seqlens = torch.cat(cu_seqlens) + cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # transformers + # pre-compute seqlens for window/full attn to reduce cuMemcpy operations + max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen( + cu_seqlens) + max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( + cu_window_seqlens) + + cu_seqlens = cu_seqlens.to( # type: ignore[attr-defined] + device=self.device, + non_blocking=True) + cu_window_seqlens = cu_window_seqlens.to( # type: ignore[attr-defined] + device=self.device, + non_blocking=True) + rotary_pos_emb_cos = rotary_pos_emb_cos.to( # type: ignore[attr-defined] + device=self.device, + non_blocking=True) + rotary_pos_emb_sin = rotary_pos_emb_sin.to( # type: ignore[attr-defined] + device=self.device, + non_blocking=True) + window_index = window_index.to( # type: ignore[attr-defined] + device=hidden_states.device, + non_blocking=True) + reverse_indices = reverse_indices.to(device=hidden_states.device, + non_blocking=True) + + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + hidden_states = hidden_states.unsqueeze(1) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen_full + seqlens_now = seqlens_full + else: + cu_seqlens_now = cu_window_seqlens + max_seqlen_now = max_seqlen_window + seqlens_now = seqlens_window + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen_now, + seqlens=seqlens_now, + ) + + # For Qwen2.5-VL-3B, float16 will overflow at last block + # for long visual tokens sequences. 
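+ # (cast_overflow_tensors clamps overflowed values back into the finite fp16 range,
+ # so the merger below does not receive inf activations.)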
+ if hidden_states.dtype == torch.float16: + hidden_states = cast_overflow_tensors(hidden_states) + + # adapter + hidden_states = self.merger(hidden_states) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + +class AscendQwen2_5_VLForConditionalGeneration(nn.Module): + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"] + with set_ascend_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + else: + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return image_embeds.split(sizes) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"] + with set_ascend_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d", + ) + else: + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw_list) + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return video_embeds.split(sizes) + + +# NOTE: This will be removed after MMEncoderAttention has been extract as a CustomOp in vllm. +Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward + +# NOTE: These will be removed after https://github.com/vllm-project/vllm/pull/29388 is merged. +Qwen2_5_VLForConditionalGeneration._process_image_input = AscendQwen2_5_VLForConditionalGeneration._process_image_input +Qwen2_5_VLForConditionalGeneration._process_video_input = AscendQwen2_5_VLForConditionalGeneration._process_video_input + +# NOTE: These will be removed after vllm-ascend is aligned with vllm latest main. +Qwen2_5_VisionBlock.forward = AscendQwen2_5_VisionBlock.forward +Qwen2_5_VisionTransformer.__init__ = AscendQwen2_5_VisionTransformer.__init__ +Qwen2_5_VisionTransformer.rotary_pos_emb_thw = AscendQwen2_5_VisionTransformer.rotary_pos_emb_thw +Qwen2_5_VisionTransformer.get_rope_by_thw = AscendQwen2_5_VisionTransformer.get_rope_by_thw +Qwen2_5_VisionTransformer.forward = AscendQwen2_5_VisionTransformer.forward diff --git a/vllm_ascend/patch/worker/patch_rope.py b/vllm_ascend/patch/worker/patch_rope.py new file mode 100644 index 00000000000..cb40af86728 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_rope.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import torch.nn as nn +from vllm.model_executor.layers.rotary_embedding.base import \ + RotaryEmbeddingBase + + +class AscendRotaryEmbeddingBase(nn.Module): + + def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_sin = self.cos_sin_cache[:seqlen] + cos, sin = cos_sin.chunk(2, dim=-1) + return cos, sin + + +# NOTE: These will be removed after vllm-ascend is aligned with vllm latest main. +RotaryEmbeddingBase.get_cos_sin = AscendRotaryEmbeddingBase.get_cos_sin diff --git a/vllm_ascend/patch/worker/patch_triton.py b/vllm_ascend/patch/worker/patch_triton.py index eb3f300bfac..2f5af43be48 100644 --- a/vllm_ascend/patch/worker/patch_triton.py +++ b/vllm_ascend/patch/worker/patch_triton.py @@ -1,10 +1,7 @@ -import vllm.model_executor.layers.fla.ops.chunk -import vllm.model_executor.layers.fla.ops.fused_recurrent -import vllm.model_executor.layers.fla.ops.layernorm_guard import vllm.model_executor.layers.mamba.ops.causal_conv1d -from vllm_ascend.ops.triton.fla.fla import (LayerNormFn, - torch_chunk_gated_delta_rule) +from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule +from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn from vllm_ascend.ops.triton.fla.sigmoid_gating import \ fused_recurrent_gated_delta_rule_fwd_kernel from vllm_ascend.ops.triton.mamba.casual_conv1d import ( @@ -14,4 +11,4 @@ vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn -vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule +vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = chunk_gated_delta_rule diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 0797da3270e..7cc84fc6ae3 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -30,12 +30,34 @@ init_ascend_config) from vllm_ascend.torchair.utils import (check_torchair_cache_exist, delete_torchair_cache_file) -from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, AscendDeviceType, - enable_sp, get_ascend_device_type, is_vl_model, - prefill_context_parallel_enable, - update_aclgraph_sizes, - update_cudagraph_capture_sizes, - update_default_aclgraph_sizes) + +# isort: off +from vllm_ascend.utils import ( + ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType, + enable_sp, get_ascend_device_type, is_vl_model, + prefill_context_parallel_enable, update_aclgraph_sizes, + update_cudagraph_capture_sizes, update_default_aclgraph_sizes) + +# set custom ops path +CUR_DIR = os.path.dirname(os.path.realpath(__file__)) +CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "vllm_ascend", "_cann_ops_custom", + "vendors", "customize") +CUSTOM_LIB_PATH = os.path.join(CUSTOM_OPP_PATH, "op_api", "lib") + +if os.path.exists(CUSTOM_OPP_PATH): + current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "") + if current_cust_opp_path: + os.environ[ + 
"ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}" + else: + os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH + +if os.path.exists(CUSTOM_LIB_PATH): + current_lib_path = os.environ.get("LD_LIBRARY_PATH", "") + if current_lib_path: + os.environ["LD_LIBRARY_PATH"] = f"{CUSTOM_LIB_PATH}:{current_lib_path}" + else: + os.environ["LD_LIBRARY_PATH"] = CUSTOM_LIB_PATH if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -56,7 +78,9 @@ class NPUPlatform(Platform): device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" dispatch_key: str = "PrivateUse1" - supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD] + supported_quantization: list[str] = [ + ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD + ] def is_sleep_mode_available(self) -> bool: return True @@ -79,6 +103,8 @@ def pre_register_and_update(cls, if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) + from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \ + AscendCompressedTensorsConfig # noqa: F401 from vllm_ascend.quantization.quant_config import \ AscendQuantConfig # noqa: F401 diff --git a/vllm_ascend/distributed/mooncake/__init__.py b/vllm_ascend/quantization/compressed_tensors/__init__.py similarity index 100% rename from vllm_ascend/distributed/mooncake/__init__.py rename to vllm_ascend/quantization/compressed_tensors/__init__.py diff --git a/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py b/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 00000000000..f95ff7f0215 --- /dev/null +++ b/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,252 @@ +from typing import TYPE_CHECKING, Any, Optional, cast + +import torch +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import ( + QUANTIZATION_METHODS, register_quantization_config) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \ + CompressedTensorsScheme +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + find_matched_target, is_activation_quantization_format, + should_ignore_layer) + +from vllm_ascend.quantization.quant_config import (AscendLinearMethod, + AscendQuantConfig) +from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod +from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) + +QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]] + + +def remove_quantization_method(): + if COMPRESSED_TENSORS_METHOD in QUANTIZATION_METHODS: + QUANTIZATION_METHODS.remove(COMPRESSED_TENSORS_METHOD) + + +remove_quantization_method() + + +@register_quantization_config(COMPRESSED_TENSORS_METHOD) +class AscendCompressedTensorsConfig(QuantizationConfig): + + def __init__( + self, + target_scheme_map: dict[str, Any], + ignore: list[str], + quant_format: str, + config: Optional[dict[str, Any]] = None, + ): + super().__init__() + self.ignore = ignore + 
self.quant_format = quant_format + # Map from [target -> scheme] + self.target_scheme_map = target_scheme_map + self.quant_description = config + + def get_name(self) -> str: + return "compressed-tensors" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.int8, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "Ascend hardware does not support \"get_min_capability\" feature.") + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, + Any]) -> "AscendCompressedTensorsConfig": + ignore: list[str] = cast(list[str], config.get("ignore", [])) + quant_format = cast(str, config.get("format")) + target_scheme_map = cls._quantization_scheme_map_from_config( + config=config) + + return cls( + target_scheme_map=target_scheme_map, + ignore=ignore, + quant_format=quant_format, + config=config, + ) + + @classmethod + def _quantization_scheme_map_from_config( + cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: + """ + :param config: The `quantization_config` dictionary from config.json + :return: A dictionary mapping target layer names to their corresponding + quantization_args for weights and input activations + """ + target_scheme_map: dict[str, Any] = dict() + quant_format = cast(str, config.get("format")) + + # The quant_config has multiple config_groups, each containing + # an input_activations key with details about how the activations are + # quantized, a weights key indicating how the weights are quantized, + # and a list of targets under the `targets` key, dictating which + # layers are impacted by the quantization details. The quantization + # details follow the structure defined by the QuantizationArgs + # pydantic model, which is used to verify the structure of the + # quant_config and also store the details for later use. 
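+ # An illustrative (model-agnostic) W8A8 config_group entry looks roughly like:
+ # "config_groups": {"group_0": {
+ #     "targets": ["Linear"],
+ #     "weights": {"num_bits": 8, "type": "int", "strategy": "channel",
+ #                 "symmetric": true, "dynamic": false},
+ #     "input_activations": {"num_bits": 8, "type": "int", "strategy": "token",
+ #                           "symmetric": true, "dynamic": true}}}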
+ + config_groups = config.get("config_groups", dict()) + for _, quant_config in config_groups.items(): + targets = quant_config.get("targets") + for target in targets: + target_scheme_map[target] = {} + target_scheme_map[target][ + "weights"] = QuantizationArgs.model_validate( + quant_config.get("weights")) + + target_scheme_map[target]["input_activations"] = None + target_scheme_map[target]["format"] = quant_config.get( + "format") + format = target_scheme_map[target].get("format") + # If no per-config format defined, use global format in config + act_quant_format = ( + is_activation_quantization_format(format) + if format is not None else + is_activation_quantization_format(quant_format)) + input_activations = quant_config.get("input_activations") + if act_quant_format and input_activations is not None: + target_scheme_map[target]["input_activations"] = ( + QuantizationArgs.model_validate( + quant_config.get("input_activations"))) + return target_scheme_map + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD + # collect schemes + quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) + + # choose quantization method + quant_method: LinearMethodBase = UnquantizedLinearMethod() + if quant_scheme is not None: + layer.scheme = quant_scheme + ascend_quant_config = AscendQuantConfig(self.quant_description + or {}) + quant_method = AscendLinearMethod(ascend_quant_config, prefix, + None, layer) + return quant_method + return None + + def get_scheme(self, + layer: torch.nn.Module, + layer_name: Optional[str] = None + ) -> Optional["CompressedTensorsScheme"]: + """ + compressed-tensors supports non uniform in the following way: + + targets of config_groups: There can be N config_groups which each + have a quantization scheme. Each config_group has a list of targets + which can be a full layer_name, a regex for a layer_name, or + an nn.Module name. + + Detect whether a layer_name is found in any target and + use the quantization scheme corresponding to the matched target + to select the CompressedTensorsScheme used for inference. + """ + + # Find the "target" in the compressed-tensors config + # that our layer conforms to. + if should_ignore_layer(layer_name, + ignore=self.ignore, + fused_mapping=self.packed_modules_mapping): + return None + + # Will be empty for models with only sparsity + weight_quant = input_quant = None + if self.target_scheme_map: + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=self.target_scheme_map.keys(), + fused_mapping=self.packed_modules_mapping, + ) + + scheme_dict = self.target_scheme_map[matched_target] + weight_quant = scheme_dict.get("weights") + input_quant = scheme_dict.get("input_activations") + + if weight_quant is None: + logger.warning_once("Acceleration for non-quantized schemes is " + "not supported by Compressed Tensors. 
" + "Falling back to UnquantizedLinearMethod") + return None + + else: + # Find the quant_scheme + scheme = self._get_scheme_from_parts( + weight_quant=weight_quant, + input_quant=input_quant, + ) + return scheme + + def _get_scheme_from_parts( + self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> "CompressedTensorsScheme": + act_quant_format = is_activation_quantization_format(self.quant_format) + if act_quant_format and input_quant is not None: + if self._is_static_tensor_w8a8(weight_quant, input_quant): + return AscendW8A8LinearMethod() + + if self._is_dynamic_token_w8a8(weight_quant, input_quant): + return AscendW8A8DynamicLinearMethod() + + raise NotImplementedError( + "No compressed-tensors compatible scheme was found.") + + def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL.value) + is_tensor = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TENSOR.value) + is_static = not weight_quant.dynamic and not input_quant.dynamic + is_symmetric = weight_quant.symmetric and input_quant.symmetric + + # Only symmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_8_bits and is_tensor and is_symmetric and is_static + + def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + is_symmetric = weight_quant.symmetric and input_quant.symmetric + + # Only symmetric input quantization supported. + # Only symmetric weight quantization supported. 
+ return is_8_bits and is_token and is_symmetric and is_dynamic + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + self.target_scheme_map = hf_to_vllm_mapper.apply_dict( + self.target_scheme_map) + self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index d66963041a9..72c04e50b70 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -94,8 +94,10 @@ def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig": @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - if torch.npu.is_available(): - return ASCEND_QUANTIZATION_METHOD + if hf_quant_cfg is not None: + quant_method = hf_quant_cfg.get("quant_method", None) + if quant_method is None and torch.npu.is_available(): + return ASCEND_QUANTIZATION_METHOD return None def get_quant_method(self, layer: torch.nn.Module, @@ -113,7 +115,7 @@ def get_quant_method(self, layer: torch.nn.Module, self.packed_modules_mapping): return AscendUnquantizedLinearMethod() return AscendLinearMethod(self, prefix, - self.packed_modules_mapping) + self.packed_modules_mapping, layer) elif isinstance(layer, Attention) and \ 'fa_quant_type' in self.quant_description.keys() and \ self.quant_description['fa_quant_type'] is not None: @@ -126,13 +128,13 @@ def get_quant_method(self, layer: torch.nn.Module, self.packed_modules_mapping): return AscendUnquantizedFusedMoEMethod(layer.moe_config) return AscendFusedMoEMethod(self, prefix, - self.packed_modules_mapping) + self.packed_modules_mapping, layer) elif isinstance(layer, VocabParallelEmbedding): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): return UnquantizedEmbeddingMethod() return AscendEmbeddingMethod(self, prefix, - self.packed_modules_mapping) + self.packed_modules_mapping, layer) return None def is_layer_skipped_ascend( @@ -259,11 +261,16 @@ class AscendLinearMethod(LinearMethodBase): quant_config: The Ascend quantization config. """ - def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, Any]) -> None: + def __init__(self, + quant_config: AscendQuantConfig, + prefix: str, + packed_modules_mapping: Dict[str, Any] | None, + layer: torch.nn.Module = None) -> None: self.quant_method = get_quant_method(quant_config.quant_description, - prefix, "linear", - packed_modules_mapping) + prefix, + "linear", + packed_modules_mapping, + layer=layer) def create_weights( self, @@ -401,11 +408,16 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): quant_config: The Ascend quantization config. 
""" - def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, Any]): + def __init__(self, + quant_config: AscendQuantConfig, + prefix: str, + packed_modules_mapping: Dict[str, Any], + layer: torch.nn.Module = None): self.quant_method = get_quant_method(quant_config.quant_description, - prefix, "moe", - packed_modules_mapping) + prefix, + "moe", + packed_modules_mapping, + layer=layer) def create_weights( self, @@ -485,7 +497,10 @@ class AscendEmbeddingMethod(AscendLinearMethod): """ def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, Any]) -> None: + packed_modules_mapping: Dict[str, Any], + layer: torch.nn.Module) -> None: self.quant_method = get_quant_method(quant_config.quant_description, - prefix, "linear", - packed_modules_mapping) + prefix, + "linear", + packed_modules_mapping, + layer=layer) diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index 6d914c0dade..eaaaee86702 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -1,7 +1,10 @@ from typing import Any, Dict, Optional, Type +import torch from vllm.logger import logger +from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD + from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod) @@ -60,8 +63,28 @@ def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, def get_quant_method(quant_description: Dict[str, Any], prefix: str, layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None): - logger.info_once("Using the vLLM Ascend Quantization now!") + packed_modules_mapping: Optional[Dict[str, Any]] = None, + layer: torch.nn.Module = None): + if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD: + return get_quant_method_llmcompressor(layer) + + return get_quant_method_modelslim(quant_description, prefix, layer_type, + packed_modules_mapping) + + +def get_quant_method_llmcompressor(layer: torch.nn.Module): + logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") + if layer.scheme is None: + raise ValueError("A scheme must be defined for each layer") + return layer.scheme + + +def get_quant_method_modelslim( + quant_description: Dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: Optional[Dict[str, Any]] = None): + logger.info_once("Using the vLLM Ascend modelslim Quantization now!") if packed_modules_mapping is None: packed_modules_mapping = dict() # Attention diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index c4f8fb048f5..8a7bbfe7263 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -25,7 +25,8 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.fused_moe.experts_selector import select_experts -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, + COMPRESSED_TENSORS_METHOD, AscendDeviceType, get_ascend_device_type, is_enable_nz) @@ -149,6 +150,10 @@ def apply( ) quant_bias = layer.quant_bias if tp_rank == 0 else None + if getattr(layer, "ascend_quant_method", + "") == COMPRESSED_TENSORS_METHOD: + quant_bias = bias + if get_ascend_device_type() == AscendDeviceType._310P: # On 300I Duo platform, we need transpose again if # using nz. This transpose can be skipped in torchair. 
@@ -187,6 +192,11 @@ def process_weights_after_loading(self, layer): layer.weight.data, ACL_FORMAT_FRACTAL_NZ) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) + if getattr(layer, "ascend_quant_method", + "") == COMPRESSED_TENSORS_METHOD: + deq_scale = layer.input_scale.data * layer.weight_scale.data + layer.deq_scale = torch.nn.Parameter(deq_scale, + requires_grad=False) class AscendW8A8FusedMoEMethod: diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index efb1d5f5c4c..0a74bcbfdcf 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -41,6 +41,7 @@ VllmConfig = None ASCEND_QUANTIZATION_METHOD = "ascend" +COMPRESSED_TENSORS_METHOD = "compressed-tensors" SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"] REGISTERED_ASCEND_OPS = {} diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ff55d1d1897..2e7c4ea299b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -991,8 +991,8 @@ def _make_attention_mask(self, seq_lens, position, max_seq_len, self.dtype, self.device) # Prefill with cache hit. elif attn_state == AscendAttentionState.PrefillCacheHit: - return self.attn_mask_builder.get_attn_mask( - 2048, self.dtype, self.device) + return self.attn_mask_builder.get_splitfuse_attn_mask().to( + torch.bool) # Decode-only situation. else: return None diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index e9000eae38e..df7fec602d0 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -208,12 +208,18 @@ def _init_device(self): NPUPlatform.set_device(device) NPUPlatform.empty_cache() - visible_device_count = (torch.npu.device_count() - if torch.npu.is_available() else 0) - assert self.parallel_config.local_world_size <= visible_device_count, ( - f"local_world_size ({self.parallel_config.local_world_size}) must be " - f"less than or equal to the number of visible devices " - f"({visible_device_count}).") + if (self.parallel_config.data_parallel_size > 1 + and self.parallel_config.data_parallel_size_local > 0 + and self.parallel_config.distributed_executor_backend + not in ["ray", "external_launcher"] and + self.vllm_config.parallel_config.data_parallel_backend != "ray" + and self.vllm_config.parallel_config.nnodes_within_dp == 1): + visible_device_count = (torch.npu.device_count() + if torch.npu.is_available() else 0) + assert self.parallel_config.local_world_size <= visible_device_count, ( + f"local_world_size ({self.parallel_config.local_world_size}) must " + f"be less than or equal to the number of visible devices " + f"({visible_device_count}).") self.init_npu_memory = NPUPlatform.mem_get_info()[0] # Initialize the distributed environment.