Skip to content

Commit abbf37c

Browse files
committed
Update base for Update on "introduce triton sdpa kernel to cuda backend"
**Introduce Triton SDPA Kernel to CUDA Backend** This diff introduces a Triton-optimized implementation of the scaled dot-product attention (SDPA) kernel to the CUDA backend. The new kernel is designed to replace the default Edge SDPA operator during graph transformation, which accelerates model inference and removes the need for SDPA decomposition. **Changes** * Added a new file `sdpa.py` to the `fbcode/executorch/backends/cuda/triton/kernels` directory, which contains the Triton-optimized SDPA kernel implementation. * Added a new file `__init__.py` to `fbcode/executorch/backends/cuda/triton/replacement_pass`, which replaces the given existing Edge ops with the target Triton kernels. * Added tests for exporting SDPA with the Triton kernel. Without the Triton kernel, an SDPA model cannot be exported. **Purpose** The purpose of this diff is to provide a high-performance SDPA kernel for the CUDA backend, which can be used to accelerate attention-based models on NVIDIA GPUs. Differential Revision: [D87259044](https://our.internmc.facebook.com/intern/diff/D87259044/) [ghstack-poisoned]
2 parents b8cdeb3 + bee30ac commit abbf37c

File tree

9 files changed

+197
-99
lines changed

9 files changed

+197
-99
lines changed

CMakeLists.txt

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,9 @@ if(EXECUTORCH_BUILD_PYBIND)
801801
torch
802802
)
803803

804+
# RPATH for _portable_lib.so
805+
set(_portable_lib_rpath "$ORIGIN/../../../torch/lib")
806+
804807
if(EXECUTORCH_BUILD_EXTENSION_MODULE)
805808
# Always use static linking for pybindings to avoid runtime symbol
806809
# resolution issues
@@ -835,6 +838,7 @@ if(EXECUTORCH_BUILD_PYBIND)
835838

836839
if(EXECUTORCH_BUILD_QNN)
837840
list(APPEND _dep_libs qnn_executorch_backend)
841+
string(APPEND _portable_lib_rpath ":$ORIGIN/../../backends/qualcomm")
838842
endif()
839843

840844
if(EXECUTORCH_BUILD_ENN)
@@ -886,19 +890,20 @@ if(EXECUTORCH_BUILD_PYBIND)
886890
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
887891
target_link_libraries(portable_lib PRIVATE ${_dep_libs})
888892

889-
# Set RPATH to find PyTorch libraries relative to the installation location
890-
# This goes from executorch/extension/pybindings up to site-packages, then to
891-
# torch/lib. Don't do this to APPLE, as it will error out on the following
892-
# error:
893+
# Set RPATH to find PyTorch and backend libraries relative to the installation
894+
# location. This goes from executorch/extension/pybindings up to
895+
# site-packages, then to torch/lib. If QNN is enabled, also add
896+
# backends/qualcomm/. Skip this on APPLE, because it fails with the
897+
# following error:
893898
#
894899
if(APPLE)
895900
# Skip setting @loader_path for APPLE, since it causes error like ld:
896901
# duplicate LC_RPATH '@loader_path' in '<site-packages>/torch/lib/
897902
# libtorch_cpu.dylib'
898903
else()
899904
set_target_properties(
900-
portable_lib PROPERTIES BUILD_RPATH "$ORIGIN/../../../torch/lib"
901-
INSTALL_RPATH "$ORIGIN/../../../torch/lib"
905+
portable_lib PROPERTIES BUILD_RPATH "${_portable_lib_rpath}"
906+
INSTALL_RPATH "${_portable_lib_rpath}"
902907
)
903908
endif()
904909

backends/qualcomm/CMakeLists.txt

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,47 @@ get_filename_component(
2323
_common_include_directories "${EXECUTORCH_SOURCE_DIR}/.." ABSOLUTE
2424
)
2525

26+
# We only download QNN SDK when we build pip wheel for ExecuTorch. Please don't
27+
# change this code unless you know what you are doing.
28+
# Auto-provision the Qualcomm (QNN) SDK. This only happens when building the
# ExecuTorch pip wheel. Please don't change this code unless you know what you
# are doing.
#
# A previously staged SDK under <binary dir>/sdk/qnn is reused; otherwise
# download_qnn_sdk.py is run at configure time and the path it prints on stdout
# becomes QNN_SDK_ROOT. In both cases QNN_SDK_ROOT is forced into the cache and
# exported to the environment of the running CMake process.
if(EXECUTORCH_BUILD_WHEEL_DO_NOT_USE)
  set(_qnn_default_sdk_dir "${CMAKE_CURRENT_BINARY_DIR}/sdk/qnn")

  # "<sdk>/lib" can only exist when "<sdk>" itself exists, so a single check
  # covers both conditions.
  if(EXISTS "${_qnn_default_sdk_dir}/lib")
    message(STATUS "Found cached Qualcomm SDK at ${_qnn_default_sdk_dir}")
    set(QNN_SDK_ROOT
        "${_qnn_default_sdk_dir}"
        CACHE PATH "Qualcomm SDK root directory" FORCE
    )
  else()
    message(STATUS "Downloading Qualcomm SDK")
    # Path arguments are quoted so that a ';' in a directory name is not
    # treated as a list separator and split into several arguments.
    execute_process(
      COMMAND
        ${PYTHON_EXECUTABLE}
        "${EXECUTORCH_SOURCE_DIR}/backends/qualcomm/scripts/download_qnn_sdk.py"
        --dst-folder "${_qnn_default_sdk_dir}" --print-sdk-path
      WORKING_DIRECTORY "${EXECUTORCH_SOURCE_DIR}"
      RESULT_VARIABLE _qnn_sdk_download_result
      OUTPUT_VARIABLE _qnn_sdk_download_output
      ERROR_VARIABLE _qnn_sdk_download_error
      OUTPUT_STRIP_TRAILING_WHITESPACE
    )
    # A non-zero (or non-numeric) result, or an empty stdout, means the script
    # did not report a usable SDK path.
    if(NOT _qnn_sdk_download_result EQUAL 0 OR "${_qnn_sdk_download_output}"
                                               STREQUAL ""
    )
      message(
        FATAL_ERROR
          "Failed to download Qualcomm SDK. stdout: ${_qnn_sdk_download_output}\n"
          "stderr: ${_qnn_sdk_download_error}\n"
          "result: ${_qnn_sdk_download_result}"
      )
    endif()
    # Quoted: the script's stdout must be stored as one path value even if it
    # contains ';'.
    set(QNN_SDK_ROOT
        "${_qnn_sdk_download_output}"
        CACHE PATH "Qualcomm SDK root directory" FORCE
    )
  endif()
  set(ENV{QNN_SDK_ROOT} "${QNN_SDK_ROOT}")
endif()
66+
2667
if(NOT DEFINED QNN_SDK_ROOT)
2768
message(
2869
FATAL_ERROR
@@ -214,7 +255,9 @@ add_subdirectory(
214255
install(
215256
TARGETS qnn_executorch_backend
216257
EXPORT ExecuTorchTargets
217-
DESTINATION ${CMAKE_INSTALL_LIBDIR}
258+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm
259+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm
260+
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm
218261
)
219262

220263
# QNN pybind
@@ -275,4 +318,12 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
275318
${QNN_EXECUTORCH_ROOT_DIR}/aot/python
276319
${CMAKE_CURRENT_BINARY_DIR}/qnn_executorch/python
277320
)
321+
322+
install(
323+
TARGETS PyQnnManagerAdaptor PyQnnWrapperAdaptor
324+
LIBRARY
325+
DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm/python
326+
RUNTIME
327+
DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm/python
328+
)
278329
endif()

backends/qualcomm/scripts/download_qnn_sdk.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Add these imports for additional logging
1+
import argparse
22
import ctypes
33
import logging
44
import os
@@ -592,3 +592,46 @@ def install_qnn_sdk() -> bool:
592592

593593
# libc++ and QNN SDK setup
594594
return _ensure_libcxx_stack() and _ensure_qnn_sdk_lib()
595+
596+
597+
def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point for staging the Qualcomm SDK.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        0 on success, 1 when the SDK could not be installed or downloaded.
    """
    cli = argparse.ArgumentParser(
        description="Helper utility for Qualcomm SDK staging."
    )
    cli.add_argument(
        "--dst-folder",
        type=pathlib.Path,
        default=SDK_DIR,
        help="Destination directory for the Qualcomm SDK.",
    )
    # The two switches below are plain boolean flags.
    cli.add_argument(
        "--print-sdk-path",
        action="store_true",
        help="Print the resolved Qualcomm SDK path to stdout.",
    )
    cli.add_argument(
        "--install-sdk",
        action="store_true",
        help="Ensure the SDK and runtime libraries are staged and loaded.",
    )
    opts = cli.parse_args(argv)

    logging.basicConfig(level=logging.INFO)

    if opts.install_sdk:
        if not install_qnn_sdk():
            return 1
        # Report QNN_SDK_ROOT from the environment when it is set, otherwise
        # fall back to the requested destination folder.
        resolved = pathlib.Path(os.environ.get("QNN_SDK_ROOT", opts.dst_folder))
    else:
        resolved = _download_qnn_sdk(dst_folder=opts.dst_folder)
        if resolved is None:
            return 1

    # `resolved` cannot be None here: both failure paths returned above.
    if opts.print_sdk_path:
        print(resolved)
    return 0
634+
635+
636+
if __name__ == "__main__":
637+
raise SystemExit(main())

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,22 @@ Error Runner<T>::generate_from_prompt_or_file(
378378
stats_.inference_start_ms = time_in_ms();
379379

380380
int32_t seq_len = config.seq_len;
381-
seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
381+
if (seq_len > context_len_) {
382+
ET_LOG(
383+
Info,
384+
"Warning: Requested seq_len (%d) exceeds compiled max_seq_len (%d). Clamping to %d.",
385+
seq_len,
386+
context_len_,
387+
context_len_);
388+
seq_len = context_len_;
389+
} else if (seq_len <= 0) {
390+
ET_LOG(
391+
Info,
392+
"Warning: Invalid seq_len (%d). Using compiled max_seq_len (%d).",
393+
seq_len,
394+
context_len_);
395+
seq_len = context_len_;
396+
}
382397
int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
383398

384399
// encode the (string) prompt into tokens sequence

examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,30 @@ Result<int64_t> TokenGenerator<T>::generate(
323323
break;
324324
}
325325
}
326+
327+
// Check if generation was truncated due to seq_len limit (no EOS token)
328+
if (eos_ids_->count(cur_token) == 0 && pos >= seq_len - 1) {
329+
printf("\n");
330+
ET_LOG(
331+
Info,
332+
"Warning: Generation stopped at seq_len limit (%d) without reaching EOS token. Response may be incomplete.",
333+
seq_len);
334+
if (seq_len >= metadata_.context_len) {
335+
ET_LOG(
336+
Info,
337+
"- seq_len (%d) already equals compiled max_seq_len (%d). Consider recompiling with larger --max_seq_len.",
338+
seq_len,
339+
metadata_.context_len);
340+
} else {
341+
ET_LOG(
342+
Info,
343+
"- seq_len (%d) is less than compiled max_seq_len (%d). Consider increasing --seq_len (up to %d).",
344+
seq_len,
345+
metadata_.context_len,
346+
metadata_.context_len);
347+
}
348+
}
349+
326350
return pos - start_pos;
327351
}
328352
// Explicit instantiations

exir/serde/export_serialize.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,17 +2143,23 @@ def deserialize_meta_func(serialized_target: str):
21432143
def import_nn_module_stack(key, path, ty):
21442144
return key, (path, ty)
21452145

2146-
# Helper function that splits strings by commas except for those
2147-
# encapsulated by parens, which are valid traces.
2148-
# TODO: Currently this is needed due to indexing Sequential
2149-
# layers introducing names in the form "layer.slice(1, None, None)".
2150-
# If that naming is improved, this fancier splitting can probably be
2151-
# reverted to a simple split by comma.
2146+
# Helper function to split string by commas, accounting for nested parentheses/brackets
21522147
def metadata_split(metadata):
    """Split ``metadata`` on top-level commas into exactly three fields.

    Commas nested inside parentheses or square brackets are kept, so an entry
    such as ``"layer.slice(1, None, None)"`` stays in one field. Each field is
    stripped of surrounding whitespace.

    Raises:
        ValueError: if a closing bracket has no opener, or an opening bracket
            is never closed.
        AssertionError: if the split does not yield exactly three fields.
    """
    out = []
    start, depth = 0, 0
    for position, char in enumerate(metadata):
        if char in "[(":
            depth += 1
        elif char in ")]":
            depth -= 1
            if depth < 0:
                raise ValueError(f"Mismatched brackets in metadata: {metadata}")
        elif char == "," and depth == 0:
            out.append(metadata[start:position].strip())
            start = position + 1
    # An opener that is never closed is as malformed as a stray closer; reject
    # it here instead of silently swallowing the remaining commas.
    if depth != 0:
        raise ValueError(f"Mismatched brackets in metadata: {metadata}")
    out.append(metadata[start:].strip())
    assert len(out) == 3, f"Expected 3 fields in metadata, got {len(out)}: {metadata}"
    return out
21572163

21582164
nn_module_stack = dict(
21592165
import_nn_module_stack(*metadata_split(item))

setup.py

Lines changed: 24 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@
5757
import site
5858
import subprocess
5959
import sys
60-
import sysconfig
61-
import tempfile
6260

6361
from distutils import log # type: ignore[import-not-found]
6462
from distutils.sysconfig import get_python_lib # type: ignore[import-not-found]
@@ -463,84 +461,6 @@ def run(self):
463461
if self._ran_build:
464462
return
465463

466-
try:
467-
# Following code is for building the Qualcomm backend.
468-
from backends.qualcomm.scripts.download_qnn_sdk import (
469-
_download_qnn_sdk,
470-
is_linux_x86,
471-
)
472-
473-
if is_linux_x86():
474-
os.environ["EXECUTORCH_BUILDING_WHEEL"] = "1"
475-
476-
with tempfile.TemporaryDirectory() as tmpdir:
477-
tmp_path = Path(tmpdir)
478-
sdk_path = _download_qnn_sdk(dst_folder=tmp_path)
479-
480-
if not sdk_path:
481-
raise RuntimeError(
482-
"Qualcomm SDK not found, cannot build backend"
483-
)
484-
485-
# Determine paths
486-
prj_root = Path(__file__).parent.resolve()
487-
build_sh = prj_root / "backends/qualcomm/scripts/build.sh"
488-
build_root = prj_root / "build-x86"
489-
490-
if not build_sh.exists():
491-
raise FileNotFoundError(f"{build_sh} not found")
492-
493-
# Run build.sh with SDK path exported
494-
env = dict(**os.environ)
495-
env["QNN_SDK_ROOT"] = str(sdk_path)
496-
subprocess.check_call(
497-
[
498-
str(build_sh),
499-
"--skip_linux_android",
500-
"--skip_linux_embedded",
501-
],
502-
env=env,
503-
)
504-
505-
# Copy the main .so into the wheel package
506-
so_src = (
507-
build_root / "backends/qualcomm/libqnn_executorch_backend.so"
508-
)
509-
so_dst = Path(
510-
self.get_ext_fullpath(
511-
"executorch.backends.qualcomm.qnn_backend"
512-
)
513-
)
514-
self.mkpath(str(so_dst.parent)) # ensure destination exists
515-
self.copy_file(str(so_src), str(so_dst))
516-
logging.info(f"Copied Qualcomm backend: {so_src} -> {so_dst}")
517-
518-
# Copy Python adaptor .so files
519-
ext_suffix = sysconfig.get_config_var("EXT_SUFFIX")
520-
521-
so_files = [
522-
(
523-
"executorch.backends.qualcomm.python.PyQnnManagerAdaptor",
524-
prj_root
525-
/ f"backends/qualcomm/python/PyQnnManagerAdaptor{ext_suffix}",
526-
),
527-
(
528-
"executorch.backends.qualcomm.python.PyQnnWrapperAdaptor",
529-
prj_root
530-
/ f"backends/qualcomm/python/PyQnnWrapperAdaptor{ext_suffix}",
531-
),
532-
]
533-
534-
for module_name, so_src in so_files:
535-
so_dst = Path(self.get_ext_fullpath(module_name))
536-
self.mkpath(str(so_dst.parent))
537-
self.copy_file(str(so_src), str(so_dst))
538-
logging.info(f"Copied Qualcomm backend: {so_src} -> {so_dst}")
539-
540-
except ImportError:
541-
logging.error("Fail to build Qualcomm backend")
542-
logging.exception("Import error")
543-
544464
if self.editable_mode:
545465
self._ran_build = True
546466
self.run_command("build")
@@ -837,6 +757,11 @@ def run(self): # noqa C901
837757
cmake_build_args += ["--target", "custom_ops_aot_lib"]
838758
cmake_build_args += ["--target", "quantized_ops_aot_lib"]
839759

760+
if cmake_cache.is_enabled("EXECUTORCH_BUILD_QNN"):
761+
cmake_build_args += ["--target", "qnn_executorch_backend"]
762+
cmake_build_args += ["--target", "PyQnnManagerAdaptor"]
763+
cmake_build_args += ["--target", "PyQnnWrapperAdaptor"]
764+
840765
# Set PYTHONPATH to the location of the pip package.
841766
os.environ["PYTHONPATH"] = (
842767
site.getsitepackages()[0] + ";" + os.environ.get("PYTHONPATH", "")
@@ -924,5 +849,24 @@ def run(self): # noqa C901
924849
dst="executorch/data/lib/",
925850
dependent_cmake_flags=[],
926851
),
852+
BuiltFile(
853+
src_dir="%CMAKE_CACHE_DIR%/backends/qualcomm/%BUILD_TYPE%/",
854+
src_name="qnn_executorch_backend",
855+
dst="executorch/backends/qualcomm/",
856+
is_dynamic_lib=True,
857+
dependent_cmake_flags=["EXECUTORCH_BUILD_QNN"],
858+
),
859+
BuiltExtension(
860+
src_dir="%CMAKE_CACHE_DIR%/backends/qualcomm/%BUILD_TYPE%/",
861+
src="PyQnnManagerAdaptor.*",
862+
modpath="executorch.backends.qualcomm.python.PyQnnManagerAdaptor",
863+
dependent_cmake_flags=["EXECUTORCH_BUILD_QNN"],
864+
),
865+
BuiltExtension(
866+
src_dir="%CMAKE_CACHE_DIR%/backends/qualcomm/%BUILD_TYPE%/",
867+
src="PyQnnWrapperAdaptor.*",
868+
modpath="executorch.backends.qualcomm.python.PyQnnWrapperAdaptor",
869+
dependent_cmake_flags=["EXECUTORCH_BUILD_QNN"],
870+
),
927871
],
928872
)

0 commit comments

Comments
 (0)