Commit 8bdb6b5

Update on "introduce triton sdpa kernel to cuda backend"
**Introduce Triton SDPA Kernel to CUDA Backend**

This diff introduces a Triton-optimized implementation of the scaled dot-product attention (SDPA) kernel to the CUDA backend. The new kernel is designed to replace the default Edge SDPA operator during graph transformation, which accelerates model inference and removes the need for SDPA decomposition.

**Changes**

* Added a new file `sdpa.py` to the `fbcode/executorch/backends/cuda/triton/kernels` directory, which contains the Triton-optimized SDPA kernel implementation.
* Added a new file `__init__.py` to `fbcode/executorch/backends/cuda/triton/replacement_pass`, which replaces the given Edge ops with their target Triton kernels.
* Added tests for exporting SDPA with the Triton kernel. Without the Triton kernel, the SDPA model cannot be exported.

**Purpose**

The purpose of this diff is to provide a high-performance SDPA kernel for the CUDA backend, which can be used to accelerate attention-based models on NVIDIA GPUs.

Differential Revision: [D87259044](https://our.internmc.facebook.com/intern/diff/D87259044/)

[ghstack-poisoned]
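For readers unfamiliar with the operator being replaced, the following is a minimal reference sketch of what scaled dot-product attention computes for a single head: `softmax(Q K^T * scale) V`. It is plain Python for illustration only and is not the Triton kernel from this diff. The function name `sdpa_reference` and its list-of-rows inputs are assumptions made for this example.

```python
import math


def sdpa_reference(q, k, v, scale=None):
    """Illustrative single-head SDPA (hypothetical helper, not the diff's kernel).

    q: list of L_q query rows, each of length d
    k: list of L_k key rows, each of length d
    v: list of L_k value rows, each of length d_v
    """
    d = len(q[0])
    if scale is None:
        # Conventional default scaling factor: 1 / sqrt(head_dim).
        scale = 1.0 / math.sqrt(d)

    out = []
    for q_row in q:
        # Scaled dot product of this query with every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        # Numerically stable softmax over the key dimension.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value rows.
        out.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return out
```

A fused kernel typically computes the same result without materializing the full score matrix, which is presumably part of the motivation for swapping it in during graph transformation rather than decomposing the operator.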
2 parents 4a25f0e + 9eb007e commit 8bdb6b5

1 file changed

+0
-3
lines changed

backends/aoti/common_shims.h

Lines changed: 0 additions & 3 deletions
@@ -60,17 +60,14 @@ AOTI_SHIM_EXPORT AOTITorchError
 aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
 
 // Utility functions for device and layout information
-
 AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
-AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64();
-AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
 
 // Dtype utility function needed by Metal backend
 AOTI_SHIM_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype);
