Skip to content

Commit 1ce8bc6

Browse files
committed
Update on "introduce triton sdpa kernel to cuda backend"
**Introduce Triton SDPA Kernel to CUDA Backend** This diff introduces a Triton-optimized implementation of scaled dot-product attention (SDPA) kernel to the CUDA backend. The new kernel is designed to replace the default Edge SDPA operator during graph transformation to accelerate the model inference and get rid of sdpa decomposition. **Changes** * Added a new file `sdpa.py` to `fbcode/executorch/backends/cuda/triton/kernels` and `fbcode/executorch/backends/cuda/triton/kernels` directories, which contains the Triton-optimized SDPA kernel implementation. * Added a new file `__init__.py` to `fbcode/executorch/backends/cuda/triton/replacement_pass`, which replaces the given existing edge ops with target triton kernels. * Added tests for sdpa exporting with triton kernel. Without the triton kernel, sdpa model can not be exported. **Purpose** The purpose of this diff is to provide a high-performance SDPA kernel for the CUDA backend, which can be used to accelerate attention-based models on NVIDIA GPUs. Differential Revision: [D87259044](https://our.internmc.facebook.com/intern/diff/D87259044/) [ghstack-poisoned]
2 parents 7a573e2 + b8cdeb3 commit 1ce8bc6

40 files changed

+765
-517
lines changed

.ci/scripts/test_backend.sh

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,25 @@ if [[ "$FLOW" == *vulkan* ]]; then
5757
fi
5858

5959
if [[ "$FLOW" == *arm* ]]; then
60+
6061
# Setup ARM deps.
61-
.ci/scripts/setup-arm-baremetal-tools.sh
62+
if [[ "$FLOW" == *vgf* ]]; then
63+
.ci/scripts/setup-arm-baremetal-tools.sh --enable-mlsdk-deps --install-mlsdk-deps-with-pip
64+
else
65+
.ci/scripts/setup-arm-baremetal-tools.sh
66+
fi
6267
source examples/arm/ethos-u-scratch/setup_path.sh
6368

6469
if [[ "$FLOW" == *ethos_u* ]]; then
6570
# Prepare a test runner binary that can run on the Corstone-3x0 FVPs
6671
backends/arm/scripts/build_executorch.sh
6772
backends/arm/test/setup_testing.sh
6873
fi
74+
75+
if [[ "$FLOW" == *vgf* ]]; then
76+
# Prepare a test runner binary for VKML runtime
77+
backends/arm/test/setup_testing_vkml.sh
78+
fi
6979
fi
7080

7181
if [[ $IS_MACOS -eq 1 ]]; then

.ci/scripts/test_model_e2e.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ case "$MODEL_NAME" in
181181
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR"
182182
;;
183183
whisper-*)
184-
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR --model_name ${MODEL_NAME}"
184+
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR"
185185
;;
186186
gemma3)
187187
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --image_path $IMAGE_PATH"

.github/workflows/test-backend-arm.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
uses: ./.github/workflows/_test_backend.yml
2727
with:
2828
backend: arm
29-
flows: '["arm_tosa_fp", "arm_tosa_int", "arm_ethos_u55", "arm_ethos_u85"]'
29+
flows: '["arm_tosa_fp", "arm_tosa_int", "arm_ethos_u55", "arm_ethos_u85", "arm_vgf_fp", "arm_vgf_int"]'
3030
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
3131
timeout: 120
3232
run-linux: true

.github/workflows/trunk.yml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,40 @@ jobs:
317317
# Test test_arm_baremetal.sh with test
318318
backends/arm/test/test_arm_baremetal.sh "${ARM_TEST}"
319319
320+
test-arm-backend-vkml:
321+
name: test-arm-backend-vkml
322+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
323+
permissions:
324+
id-token: write
325+
contents: read
326+
strategy:
327+
matrix:
328+
include:
329+
- test_arm_baremetal: test_pytest_ops_vkml
330+
fail-fast: false
331+
with:
332+
runner: linux.2xlarge.memory
333+
docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk
334+
submodules: 'recursive'
335+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
336+
timeout: 120
337+
script: |
338+
# The generic Linux job chooses to use base env, not the one setup by the image
339+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
340+
conda activate "${CONDA_ENV}"
341+
source .ci/scripts/utils.sh
342+
install_executorch "--use-pt-pinned-commit"
343+
344+
.ci/scripts/setup-arm-baremetal-tools.sh --disable-ethos-u-deps --enable-mlsdk-deps --install-mlsdk-deps-with-pip
345+
346+
# Increase number of files user can monitor to bypass buck failures.
347+
# Hopefully this is high enough for this setup.
348+
sudo sysctl fs.inotify.max_user_watches=1048576 # 1024 * 1024
349+
350+
ARM_TEST=${{ matrix.test_arm_baremetal }}
351+
352+
backends/arm/test/test_arm_baremetal.sh "${ARM_TEST}"
353+
320354
test-arm-cortex-m-size-test:
321355
name: test-arm-cortex-m-size-test
322356
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

0 commit comments

Comments
 (0)