Skip to content

Commit 1023d93

Browse files
committed
Update on "introduce triton sdpa kernel to cuda backend"
**Introduce Triton SDPA Kernel to CUDA Backend** This diff introduces a Triton-optimized implementation of the scaled dot-product attention (SDPA) kernel to the CUDA backend. The new kernel is designed to replace the default Edge SDPA operator during graph transformation, accelerating model inference and avoiding SDPA decomposition. **Changes** * Added a new file `sdpa.py` to the `fbcode/executorch/backends/cuda/triton/kernels` directory, which contains the Triton-optimized SDPA kernel implementation. * Added a new file `__init__.py` to `fbcode/executorch/backends/cuda/triton/replacement_pass`, which replaces the given existing edge ops with the target Triton kernels. * Added tests for SDPA exporting with the Triton kernel. Without the Triton kernel, the SDPA model cannot be exported. **Purpose** The purpose of this diff is to provide a high-performance SDPA kernel for the CUDA backend, which can be used to accelerate attention-based models on NVIDIA GPUs. Differential Revision: [D87259044](https://our.internmc.facebook.com/intern/diff/D87259044/) [ghstack-poisoned]
2 parents 1ce8bc6 + abbf37c commit 1023d93

File tree

10 files changed

+209
-114
lines changed

10 files changed

+209
-114
lines changed

CMakeLists.txt

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,9 @@ if(EXECUTORCH_BUILD_PYBIND)
801801
torch
802802
)
803803

804+
# RPATH for _portable_lib.so
805+
set(_portable_lib_rpath "$ORIGIN/../../../torch/lib")
806+
804807
if(EXECUTORCH_BUILD_EXTENSION_MODULE)
805808
# Always use static linking for pybindings to avoid runtime symbol
806809
# resolution issues
@@ -835,6 +838,7 @@ if(EXECUTORCH_BUILD_PYBIND)
835838

836839
if(EXECUTORCH_BUILD_QNN)
837840
list(APPEND _dep_libs qnn_executorch_backend)
841+
string(APPEND _portable_lib_rpath ":$ORIGIN/../../backends/qualcomm")
838842
endif()
839843

840844
if(EXECUTORCH_BUILD_ENN)
@@ -886,19 +890,20 @@ if(EXECUTORCH_BUILD_PYBIND)
886890
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
887891
target_link_libraries(portable_lib PRIVATE ${_dep_libs})
888892

889-
# Set RPATH to find PyTorch libraries relative to the installation location
890-
# This goes from executorch/extension/pybindings up to site-packages, then to
891-
# torch/lib. Don't do this to APPLE, as it will error out on the following
892-
# error:
893+
# Set RPATH to find PyTorch and backend libraries relative to the installation
894+
# location. This goes from executorch/extension/pybindings up to
895+
# site-packages, then to torch/lib. If QNN is enabled, also add
896+
# backends/qualcomm/. Don't do this to APPLE, as it will error out on the
897+
# following error:
893898
#
894899
if(APPLE)
895900
# Skip setting @loader_path for APPLE, since it causes error like ld:
896901
# duplicate LC_RPATH '@loader_path' in '<site-packages>/torch/lib/
897902
# libtorch_cpu.dylib'
898903
else()
899904
set_target_properties(
900-
portable_lib PROPERTIES BUILD_RPATH "$ORIGIN/../../../torch/lib"
901-
INSTALL_RPATH "$ORIGIN/../../../torch/lib"
905+
portable_lib PROPERTIES BUILD_RPATH "${_portable_lib_rpath}"
906+
INSTALL_RPATH "${_portable_lib_rpath}"
902907
)
903908
endif()
904909

backends/cuda/triton/kernels/sdpa.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
"""
8-
Optimized Triton SDPA Kernel for ExecuTorch CUDA Backend.
8+
Triton SDPA Kernel for ExecuTorch CUDA Backend.
99
1010
This module provides a Triton-optimized implementation of scaled dot-product attention
11-
that can replace the default ATen SDPA operator during graph transformation.
11+
that can replace the default ATen/Edge SDPA operator during graph transformation to allow
12+
us to export the model without decomposing the SDPA operator in a libtorch-free environment
13+
and have better performance.
1214
"""
1315

1416
import math
@@ -221,22 +223,17 @@ def sdpa(
221223
"""
222224
Triton fused Scaled Dot-Product Attention with support for different sequence lengths.
223225
224-
Supports different sequence lengths for query and key/value:
225-
- Query: [B, H, L_q, D]
226-
- Key: [B, H, L_kv, D]
227-
- Value: [B, H, L_kv, D]
228-
- Output: [B, H, L_q, D] (matches query shape)
229226
Args:
230-
query: Query tensor [B, H, L_q, D]
231-
key: Key tensor [B, H, L_kv, D]
232-
value: Value tensor [B, H, L_kv, D]
233-
attn_mask: Optional attention mask [B, H, L_q, L_kv] or broadcastable shape
234-
dropout_p: must be 0.0 (not supported)
227+
query: Query tensor with size [B, H, L_q, D] and dtype torch.bfloat16
228+
key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16
229+
value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16
230+
attn_mask: Optional attention mask [B, H, L_q, L_kv] or broadcastable shape (2D: [L_q, L_kv] or 3D: [B, L_q, L_kv])
231+
dropout_p: must be 0.0 (other values are not supported)
235232
is_causal: whether to apply causal masking
236-
scale: attention scale (default: 1/sqrt(d))
237-
enable_gqa: must be False (not supported)
233+
scale: attention scale (default: 1/sqrt(D))
234+
enable_gqa: must be False (True is not supported)
238235
Returns:
239-
Output tensor [B, H, L_q, D]
236+
Output tensor [B, H, L_q, D] with dtype torch.bfloat16
240237
"""
241238
# Validate inputs
242239
if not (query.is_cuda and key.is_cuda and value.is_cuda):

backends/qualcomm/CMakeLists.txt

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,47 @@ get_filename_component(
2323
_common_include_directories "${EXECUTORCH_SOURCE_DIR}/.." ABSOLUTE
2424
)
2525

26+
# We only download QNN SDK when we build pip wheel for ExecuTorch. Please don't
27+
# change this code unless you know what you are doing.
28+
if(EXECUTORCH_BUILD_WHEEL_DO_NOT_USE)
29+
set(_qnn_default_sdk_dir "${CMAKE_CURRENT_BINARY_DIR}/sdk/qnn")
30+
31+
if(EXISTS "${_qnn_default_sdk_dir}" AND EXISTS "${_qnn_default_sdk_dir}/lib")
32+
message(STATUS "Found cached Qualcomm SDK at ${_qnn_default_sdk_dir}")
33+
set(QNN_SDK_ROOT
34+
${_qnn_default_sdk_dir}
35+
CACHE PATH "Qualcomm SDK root directory" FORCE
36+
)
37+
else()
38+
message(STATUS "Downloading Qualcomm SDK")
39+
execute_process(
40+
COMMAND
41+
${PYTHON_EXECUTABLE}
42+
${EXECUTORCH_SOURCE_DIR}/backends/qualcomm/scripts/download_qnn_sdk.py
43+
--dst-folder ${_qnn_default_sdk_dir} --print-sdk-path
44+
WORKING_DIRECTORY ${EXECUTORCH_SOURCE_DIR}
45+
RESULT_VARIABLE _qnn_sdk_download_result
46+
OUTPUT_VARIABLE _qnn_sdk_download_output
47+
ERROR_VARIABLE _qnn_sdk_download_error
48+
OUTPUT_STRIP_TRAILING_WHITESPACE
49+
)
50+
if(NOT _qnn_sdk_download_result EQUAL 0 OR _qnn_sdk_download_output
51+
STREQUAL ""
52+
)
53+
message(
54+
FATAL_ERROR
55+
"Failed to download Qualcomm SDK. stdout: ${_qnn_sdk_download_output}\n"
56+
"stderr: ${_qnn_sdk_download_error}"
57+
)
58+
endif()
59+
set(QNN_SDK_ROOT
60+
${_qnn_sdk_download_output}
61+
CACHE PATH "Qualcomm SDK root directory" FORCE
62+
)
63+
endif()
64+
set(ENV{QNN_SDK_ROOT} ${QNN_SDK_ROOT})
65+
endif()
66+
2667
if(NOT DEFINED QNN_SDK_ROOT)
2768
message(
2869
FATAL_ERROR
@@ -214,7 +255,9 @@ add_subdirectory(
214255
install(
215256
TARGETS qnn_executorch_backend
216257
EXPORT ExecuTorchTargets
217-
DESTINATION ${CMAKE_INSTALL_LIBDIR}
258+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm
259+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm
260+
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm
218261
)
219262

220263
# QNN pybind
@@ -275,4 +318,12 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
275318
${QNN_EXECUTORCH_ROOT_DIR}/aot/python
276319
${CMAKE_CURRENT_BINARY_DIR}/qnn_executorch/python
277320
)
321+
322+
install(
323+
TARGETS PyQnnManagerAdaptor PyQnnWrapperAdaptor
324+
LIBRARY
325+
DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm/python
326+
RUNTIME
327+
DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm/python
328+
)
278329
endif()

backends/qualcomm/scripts/download_qnn_sdk.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Add these imports for additional logging
1+
import argparse
22
import ctypes
33
import logging
44
import os
@@ -592,3 +592,46 @@ def install_qnn_sdk() -> bool:
592592

593593
# libc++ and QNN SDK setup
594594
return _ensure_libcxx_stack() and _ensure_qnn_sdk_lib()
595+
596+
597+
def main(argv: Optional[List[str]] = None) -> int:
598+
parser = argparse.ArgumentParser(
599+
description="Helper utility for Qualcomm SDK staging."
600+
)
601+
parser.add_argument(
602+
"--dst-folder",
603+
type=pathlib.Path,
604+
default=SDK_DIR,
605+
help="Destination directory for the Qualcomm SDK.",
606+
)
607+
parser.add_argument(
608+
"--print-sdk-path",
609+
action="store_true",
610+
help="Print the resolved Qualcomm SDK path to stdout.",
611+
)
612+
parser.add_argument(
613+
"--install-sdk",
614+
action="store_true",
615+
help="Ensure the SDK and runtime libraries are staged and loaded.",
616+
)
617+
args = parser.parse_args(argv)
618+
619+
logging.basicConfig(level=logging.INFO)
620+
621+
sdk_path: Optional[pathlib.Path]
622+
if args.install_sdk:
623+
if not install_qnn_sdk():
624+
return 1
625+
sdk_path = pathlib.Path(os.environ.get("QNN_SDK_ROOT", args.dst_folder))
626+
else:
627+
sdk_path = _download_qnn_sdk(dst_folder=args.dst_folder)
628+
if sdk_path is None:
629+
return 1
630+
631+
if args.print_sdk_path and sdk_path is not None:
632+
print(sdk_path)
633+
return 0
634+
635+
636+
if __name__ == "__main__":
637+
raise SystemExit(main())

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,22 @@ Error Runner<T>::generate_from_prompt_or_file(
378378
stats_.inference_start_ms = time_in_ms();
379379

380380
int32_t seq_len = config.seq_len;
381-
seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
381+
if (seq_len > context_len_) {
382+
ET_LOG(
383+
Info,
384+
"Warning: Requested seq_len (%d) exceeds compiled max_seq_len (%d). Clamping to %d.",
385+
seq_len,
386+
context_len_,
387+
context_len_);
388+
seq_len = context_len_;
389+
} else if (seq_len <= 0) {
390+
ET_LOG(
391+
Info,
392+
"Warning: Invalid seq_len (%d). Using compiled max_seq_len (%d).",
393+
seq_len,
394+
context_len_);
395+
seq_len = context_len_;
396+
}
382397
int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
383398

384399
// encode the (string) prompt into tokens sequence

examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,30 @@ Result<int64_t> TokenGenerator<T>::generate(
323323
break;
324324
}
325325
}
326+
327+
// Check if generation was truncated due to seq_len limit (no EOS token)
328+
if (eos_ids_->count(cur_token) == 0 && pos >= seq_len - 1) {
329+
printf("\n");
330+
ET_LOG(
331+
Info,
332+
"Warning: Generation stopped at seq_len limit (%d) without reaching EOS token. Response may be incomplete.",
333+
seq_len);
334+
if (seq_len >= metadata_.context_len) {
335+
ET_LOG(
336+
Info,
337+
"- seq_len (%d) already equals compiled max_seq_len (%d). Consider recompiling with larger --max_seq_len.",
338+
seq_len,
339+
metadata_.context_len);
340+
} else {
341+
ET_LOG(
342+
Info,
343+
"- seq_len (%d) is less than compiled max_seq_len (%d). Consider increasing --seq_len (up to %d).",
344+
seq_len,
345+
metadata_.context_len,
346+
metadata_.context_len);
347+
}
348+
}
349+
326350
return pos - start_pos;
327351
}
328352
// Explicit instantiations

exir/serde/export_serialize.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,17 +2143,23 @@ def deserialize_meta_func(serialized_target: str):
21432143
def import_nn_module_stack(key, path, ty):
21442144
return key, (path, ty)
21452145

2146-
# Helper function that splits strings by commas except for those
2147-
# encapsulated by parens, which are valid traces.
2148-
# TODO: Currently this is needed due to indexing Sequential
2149-
# layers introducing names in the form "layer.slice(1, None, None)".
2150-
# If that naming is improved, this fancier splitting can probably be
2151-
# reverted to a simple split by comma.
2146+
# Helper function to split string by commas, accounting for nested parentheses/brackets
21522147
def metadata_split(metadata):
2153-
# Remove the parentheses and commas inside them
2154-
metadata = re.sub(r"\(.*?\)", "", metadata)
2155-
# Split the string by comma, except for those inside parentheses
2156-
return re.split(r"(?<!\()\s*,\s*(?!\()", metadata)
2148+
out = []
2149+
start, depth = 0, 0
2150+
for position, char in enumerate(metadata):
2151+
if char in "[(":
2152+
depth += 1
2153+
elif char in ")]":
2154+
depth -= 1
2155+
if depth < 0:
2156+
raise ValueError(f"Mismatched brackets in metadata: {metadata}")
2157+
elif char == "," and depth == 0:
2158+
out.append(metadata[start:position].strip())
2159+
start = position + 1
2160+
out.append(metadata[start:].strip())
2161+
assert len(out) == 3
2162+
return out
21572163

21582164
nn_module_stack = dict(
21592165
import_nn_module_stack(*metadata_split(item))

0 commit comments

Comments
 (0)