diff --git a/.github/workflows/_runs-on-nv-step1.yml b/.github/workflows/_runs-on-nv-step1.yml index 845288aa4..79b7ebc84 100644 --- a/.github/workflows/_runs-on-nv-step1.yml +++ b/.github/workflows/_runs-on-nv-step1.yml @@ -77,8 +77,6 @@ jobs: && source ${ENV_PATH}/pt2.0_diopi \ && python main.py --mode gen_data" \ || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 ) - source ~/Aoss_env.sh - ads-cli cp ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ elif [[ "${GETRUNNER}" == *diopi* ]];then ssh SH1424 """ set -e @@ -86,8 +84,6 @@ jobs: cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1} && cd diopi_test/python && srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_V100} --time=20 --gres=gpu:1 bash -c 'python main.py --mode gen_data' \ || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 ) - source ~/Aoss_env.sh - ads-cli cp ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ """ else ln -s ${GEN_DATA_PATH}/${GEN_DATA}/diopi ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ diff --git a/.github/workflows/data-cron.yml b/.github/workflows/data-cron.yml index 61b173e40..f031b2e32 100644 --- a/.github/workflows/data-cron.yml +++ b/.github/workflows/data-cron.yml @@ -93,25 +93,10 @@ jobs: mkdir -p ${DATA_DIR}/source/${GEN_DATA} rsync -a --delete ${CLUSTER_V100}:${DATA_DIR}/source/${GEN_DATA}/diopi/ ${DATA_DIR}/source/${GEN_DATA}/diopi/ """ -# ssh ${CLUSTER_ASCEND_910B} """ -# mkdir -p ${DATA_DIR}/source/${GEN_DATA} -# rsync -a --delete ${CLUSTER_V100}:${DATA_DIR}/source/${GEN_DATA}/diopi/ ${DATA_DIR}/source/${GEN_DATA}/diopi/ -# """ - source ~/Aoss_env.sh - ads-cli --dryrun --deleteSrc cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DATA_DIR}/source/${GEN_DATA}/diopi/ /dev/null 2>&1 >/dev/null - ads-cli cp ${DATA_DIR}/source/${GEN_DATA}/diopi/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DATA_DIR}/source/${GEN_DATA}/diopi/ -# ssh ${CLUSTER_SUPA} """ -# rsync -a ${CLUSTER_V100}:${DATA_DIR}/source/${GEN_DATA}/diopi/ ${DATA_DIR}/source/${GEN_DATA}/diopi/ -# """ - - Copy-Gen-Data-Ascend-910b: - name: Copy-Gen-Data-Ascend-910b - runs-on: tps-ascend-ci-910b - needs: CheckAndRsync - if: needs.CheckAndRsync.outputs.to_gen_data == 'true' - steps: - - name: Copy Gen-Data - run: | - set -e - source ~/Aoss_env.sh - ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss-internal.cn-sh-01c.sensecoreapi-oss.cn${DATA_DIR}/source/${GEN_DATA}/diopi/ ${DATA_DIR}/source/${GEN_DATA}/diopi/ \ No newline at end of file + ssh ${CLUSTER_ASCEND_910B} """ + mkdir -p ${DATA_DIR}/source/${GEN_DATA} + rsync -a --delete ${CLUSTER_V100}:${DATA_DIR}/source/${GEN_DATA}/diopi/ ${DATA_DIR}/source/${GEN_DATA}/diopi/ + """ + # ssh ${CLUSTER_SUPA} """ + # rsync -a ${CLUSTER_V100}:${DATA_DIR}/source/${GEN_DATA}/diopi/ ${DATA_DIR}/source/${GEN_DATA}/diopi/ + # """ diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index beb3d2345..038a5a990 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -96,10 +96,8 @@ jobs: && rsync -a --delete ${GITHUB_WORKSPACE}/source/ ${CLUSTER_SUPA}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to supa" ssh ${CLUSTER_1424} "mkdir -p ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source" \ && rsync -a --delete ${GITHUB_WORKSPACE}/source/ ${CLUSTER_1424}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to sh1424" - # ssh ${CLUSTER_ASCEND_910B} "mkdir -p ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source" \ - # && rsync -a --delete ${GITHUB_WORKSPACE}/source/ ${CLUSTER_ASCEND_910B}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to ascend 910b" - source ~/Aoss_env.sh - ads-cli cp ${GITHUB_WORKSPACE}/source/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ + ssh ${CLUSTER_ASCEND_910B} "mkdir -p ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source" \ + && rsync -a --delete ${GITHUB_WORKSPACE}/source/ ${CLUSTER_ASCEND_910B}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to ascend 910b" # ssh ${CLUSTER_KLX} "mkdir -p ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source" \ # && rsync -a --delete ${GITHUB_WORKSPACE}/source/ ${CLUSTER_KLX}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to kunlunxin" @@ -295,12 +293,6 @@ jobs: needs: [Rsync] if: ${{ contains( needs.Rsync.outputs.output, 'ASCEND' ) }} steps: - - name: COPY Source - run: | - set -e - source ~/Aoss_env.sh - ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss-internal.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ - ads-cli --dryrun --deleteSrc cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss-internal.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ /dev/null 2>&1 >/dev/null - name: build run: | set -e @@ -327,10 +319,7 @@ jobs: set -e cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} if [[ \"${{ needs.Rsync.outputs.output }}\" == *GENDATA* ]];then - # rsync -a ${CLUSTER_V100}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/${GEN_DATA}/diopi ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ - source ~/Aoss_env.sh - ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss-internal.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ - ads-cli --dryrun --deleteSrc cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss-internal.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ /dev/null 2>&1 >/dev/null + rsync -a ${CLUSTER_V100}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/${GEN_DATA}/diopi ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ else ln -s ${GEN_DATA_PATH}/${GEN_DATA}/diopi ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ fi diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index fcaae81a2..281e74f1c 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -8629,7 +8629,7 @@ atol=1e-2, rtol=1e-2, para=dict( - actualSeqLengths=[[150,],], + actualSeqLengths=[[5,],], numHeads=[32,], numKeyValueHeads=[32,], dim=[128,], @@ -8652,6 +8652,12 @@ "shape": ((1026, 4096),), "dtype": [np.float16,], }, + { + "ins": ["attenMask"], + "value": [[False, False, False, False, False]], + "dtype": [np.bool_,], + "gen_policy": "gen_tensor_by_value", + }, { "ins": ["blockTable"], "value": ([[0, 1],],), diff --git a/diopi_test/python/conformance/customized_test.py b/diopi_test/python/conformance/customized_test.py index 02913403d..3d9a156ea 100644 --- a/diopi_test/python/conformance/customized_test.py +++ b/diopi_test/python/conformance/customized_test.py @@ -661,6 +661,7 @@ def paged_attention( query, key, value, + attenMask, actualSeqLengths, numHeads, numKeyValueHeads, diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index 9d848b5d3..c47e8c3d2 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -6057,6 +6057,7 @@ def paged_attention( query, key, value, + attenMask, actualSeqLengths, numHeads, numKeyValueHeads, @@ -6075,6 +6076,7 @@ def paged_attention( query, key, value, + attenMask, actualSeqLengths, numHeads, numKeyValueHeads, diff --git a/impl/ascend/convert_config.yaml b/impl/ascend/convert_config.yaml index 707b6373a..128a39885 100755 --- a/impl/ascend/convert_config.yaml +++ b/impl/ascend/convert_config.yaml @@ -458,3 +458,7 @@ - diopiMaxPool2dBackward: tensor_dtype: indices: (int64)->int32 + +- diopiStd: + dtype: (float64)->float32 + layout: ND \ No newline at end of file diff --git a/impl/ascend/functions/reduce.cpp b/impl/ascend/functions/reduce.cpp index 916b0427b..5aedacd73 100755 --- a/impl/ascend/functions/reduce.cpp +++ b/impl/ascend/functions/reduce.cpp @@ -3,6 +3,7 @@ * @author DeepLink * @copyright (c) 2023, DeepLink. */ +#include #include "../aclnn/acl_scalar.hpp" #include "../aclnn/adaptor.hpp" @@ -52,6 +53,31 @@ diopiError_t diopiMean(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC return diopiSuccess; } +diopiError_t diopiStd(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t dim, bool unbiased) { + AscendTensor inputAt(input); + AscendTensor outAt(out); + + bool keepdim = false; + if (inputAt.dim() == outAt.dim()) { + keepdim = true; + } + + int64_t correction = 0; + if (unbiased) { + correction = 1; + } + + if (dim.data == nullptr || dim.len == 0) { + std::vector allDim(inputAt.dim()); + std::iota(allDim.begin(), allDim.end(), 0); + diopiSize_t rDim = vectorToDiopiSize(allDim); + DIOPI_ASCEND_CALL_ACLNN(aclnnStd, ctx, input, rDim, correction, keepdim, out); + } else { + DIOPI_ASCEND_CALL_ACLNN(aclnnStd, ctx, input, dim, correction, keepdim, out); + } + return diopiSuccess; +} + diopiError_t diopiAll(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const int64_t* dim) { diopiSize_t inputSize, outSize; diopiGetTensorShape(input, &inputSize); diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index 80fb671ea..1d281d051 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -223,6 +223,7 @@ ascend: - diopiSqrt - diopiSqrtInp - diopiStack +- diopiStd - diopiTanh - diopiTanhBackward - diopiTanhInp diff --git a/impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp b/impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp index 903f21a02..680c75a3e 100644 --- a/impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp +++ b/impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp @@ -39,9 +39,9 @@ diopiError_t diopiTokenAttentionInference(diopiContextHandle_t ctx, diopiTensorH } diopiError_t diopiPagedAttention(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, - diopiConstTensorHandle_t v, diopiSize_t actualSeqLengths, int64_t numHeads, int64_t numKeyValueHeads, int64_t dim, - diopiConstTensorHandle_t blockTable, int64_t blockSize) { - BEGIN_CALL_ACL_OP(out, q, k, v, blockTable); + diopiConstTensorHandle_t v, diopiConstTensorHandle_t attenMask, diopiSize_t actualSeqLengths, int64_t numHeads, + int64_t numKeyValueHeads, int64_t dim, diopiConstTensorHandle_t blockTable, int64_t blockSize) { + BEGIN_CALL_ACL_OP(out, q, k, v, blockTable, attenMask); at::IntArrayRef actSeqLen(actualSeqLengths.data, actualSeqLengths.len); TORCH_CHECK(actualSeqLengths.len == qAt.size(0), "The size of the first dimension of q must be equal to the length of actualSeqLengths!"); TORCH_CHECK(actualSeqLengths.len == outAt.size(0), "The size of the first dimension of out must be equal to the length of actualSeqLengths!"); @@ -59,13 +59,13 @@ diopiError_t diopiPagedAttention(diopiContextHandle_t ctx, diopiTensorHandle_t o at::TensorList keyTensors = kAt; at::TensorList valueTensors = vAt; int64_t innerPrecise = 1; - at::Tensor paddingMask, attenMask, dequantScale1, quantScale1, dequantScale2, quantScale2, quantOffset2, antiquantScale, antiquantOffset, kvPaddingSize; + at::Tensor paddingMask, dequantScale1, quantScale1, dequantScale2, quantScale2, quantOffset2, antiquantScale, antiquantOffset, kvPaddingSize; EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnIncreFlashAttentionV4, qAt, keyTensors, valueTensors, paddingMask, - attenMask, + attenMaskAt, actSeqLen, dequantScale1, quantScale1, diff --git a/proto/include/diopi/functions_ext.h b/proto/include/diopi/functions_ext.h index 87bed2f35..80d2e5cda 100644 --- a/proto/include/diopi/functions_ext.h +++ b/proto/include/diopi/functions_ext.h @@ -638,6 +638,7 @@ DIOPI_API diopiError_t diopiTokenSoftmaxReduceVInference(diopiContextHandle_t ct * @param[in] q Tensor representing the query matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim]. * @param[in] k Tensor representing the key matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim] * @param[in] v Tensor representing the value matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim] + * @param[in] attenMask Tensor representing the mask matrix in the attention mechanism. shape = [1, single_seq_len] * @param[in] actual_seq_lengths Tensor representing the sequence length in each batch. shape = [batch_size] * @param[in] num_heads head number of q and out. * @param[in] num_kv_heads head number of key and value. @@ -646,8 +647,8 @@ DIOPI_API diopiError_t diopiTokenSoftmaxReduceVInference(diopiContextHandle_t ct * @param[in] block_size Size of eatch block unit. */ DIOPI_API diopiError_t diopiPagedAttention(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, - diopiConstTensorHandle_t v, diopiSize_t actual_seq_lengths, int64_t num_heads, int64_t num_kv_heads, int64_t dim, - diopiConstTensorHandle_t block_table, int64_t block_size); + diopiConstTensorHandle_t v, diopiConstTensorHandle_t attenMask, diopiSize_t actual_seq_lengths, int64_t num_heads, + int64_t num_kv_heads, int64_t dim, diopiConstTensorHandle_t block_table, int64_t block_size); /** * @brief The no pad implementation of * \text{context_attention_out}(\mathrm{q},\mathrm{k},\mathrm{v})=\text{softmax}(\frac{\mathrm{qk}^\mathrm{T}}{\sqrt{\mathrm{d_k}}})\mathrm{v}. For details,