-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Refactor: move DeepEP from Docker images to wheel building #5534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 36 commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
92fce7d
Add deep_ep/CMakeLists.txt
yuantailing e67c4ea
Compile nvshmem
yuantailing 732c3b3
Download DeepEP as tar.gz rather than submodule
yuantailing a3e2d33
Copy DeepEP python files
yuantailing 0fbbb91
Update build_wheel.py
yuantailing 821756d
Fix: no symlinks in .whl
yuantailing 57c3215
Support more CUDA archs
yuantailing 1a86b3a
Speed up NVSHMEM build
yuantailing 99e733d
Fix NVSHMEM clang aarch64 build
yuantailing 19c07b5
Build NVSHMEM with g++; Minor fixes
yuantailing 84e6e35
Remove DeepEP from the Docker image
yuantailing 0c35157
Move CMAKE_ARGS to CMAKE_CACHE_ARGS
yuantailing 39c2fac
Update DeepEP version
yuantailing 2696352
Refine patch command and file glob
yuantailing e6ad386
Store nvshmem_src_3.2.5-1.txz in LFS
yuantailing 314522d
Update staging Docker image
yuantailing 3d62063
Fix the use of `split`
yuantailing 6dd5661
Add `patch` to Docker image
yuantailing 77c49e0
Fix if(DEFINED ...)
yuantailing f59d2be
Update staging Docker image
yuantailing 8ec209a
Add a readme about generating nvshmem_fast_build.patch
yuantailing 940c901
Add `libmlx5.so` and headers to Docker image
yuantailing 5b77338
Merge remote-tracking branch 'origin/main' into contains_deep_ep
yuantailing f358ab8
Fix for license check
yuantailing 5e23a56
Fix cmake configure dependencies
yuantailing a2cf24e
No specify `libmlx5.so.1`
yuantailing dce177e
Update staging Docker images
yuantailing 635f0d9
Add torch/lib to rpath
yuantailing e472e6b
Merge branch 'main' into contains_deep_ep
yuantailing c0b0759
Update Docker images
yuantailing d154d20
Remove GITHUB_MIRROR from Docker image environment
yuantailing 0dd8442
Fit to #5476
yuantailing 9bab622
Fix for Ninja
yuantailing c625e99
Check empty GITHUB_MIRROR
yuantailing 869627f
Fix Dockerfile ARG
yuantailing bfd4818
Update staging Docker images
yuantailing 1427067
Fix indentation
yuantailing 671a9fa
Add LINK_DEPENDS
yuantailing 670a2bf
Polish comments
yuantailing 9d41946
Update Docker images
yuantailing acccee8
Merge branch 'main' into contains_deep_ep
yuantailing c921142
Set DISABLE_AGGRESSIVE_PTX_INSTRS
yuantailing ca3cdcf
Update comments
yuantailing File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| set(DEEP_EP_COMMIT c381dadf43a85062f6a8947592017ee513abc70b) | ||
| set(NVSHMEM_URL_HASH | ||
| SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a) | ||
|
|
||
| add_custom_target(deep_ep) | ||
|
|
||
| # CUDA architectures | ||
| # ================== | ||
|
|
||
| # Filter CUDA arch >= 9.0 | ||
| set(DEEP_EP_CUDA_ARCHITECTURES "") | ||
| foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) | ||
| string(REGEX MATCHALL "^([1-9][0-9]*)([0-9])[af]?(-real|-virtual)?$" MATCHES | ||
| ${CUDA_ARCH}) | ||
| if(NOT CMAKE_MATCH_0) | ||
| message(FATAL_ERROR "Invalid CUDA arch format: \"${CUDA_ARCH}\"") | ||
| endif() | ||
| set(CUDA_ARCH_MAJOR ${CMAKE_MATCH_1}) | ||
| set(CUDA_ARCH_MINOR ${CMAKE_MATCH_2}) | ||
| set(CUDA_ARCH_POSTFIX ${CMAKE_MATCH_3}) | ||
| if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 9) | ||
| list(APPEND DEEP_EP_CUDA_ARCHITECTURES | ||
| "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}") | ||
| endif() | ||
| endforeach() | ||
|
|
||
| # Skip build if there is no suitable CUDA arch | ||
| if(WIN32) | ||
| set(DEEP_EP_CUDA_ARCHITECTURES "") | ||
| endif() | ||
| message( | ||
| STATUS "deep_ep DEEP_EP_CUDA_ARCHITECTURES: ${DEEP_EP_CUDA_ARCHITECTURES}") | ||
| file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/cuda_architectures.txt | ||
| "${DEEP_EP_CUDA_ARCHITECTURES}") | ||
| if(NOT DEEP_EP_CUDA_ARCHITECTURES) | ||
| return() | ||
| endif() | ||
|
|
||
| # Prepare files | ||
| # ============= | ||
|
|
||
| # Download DeepEP | ||
| include(FetchContent) | ||
| if(DEFINED ENV{GITHUB_MIRROR} AND NOT "$ENV{GITHUB_MIRROR}" STREQUAL "") | ||
| set(GITHUB_URL "$ENV{GITHUB_MIRROR}") | ||
| else() | ||
| set(GITHUB_URL "https://github.com") | ||
| endif() | ||
| set(DEEP_EP_URL | ||
| "${GITHUB_URL}/deepseek-ai/DeepEP/archive/${DEEP_EP_COMMIT}.tar.gz") | ||
| message(STATUS "deep_ep DEEP_EP_URL: ${DEEP_EP_URL}") | ||
| FetchContent_Declare(deep_ep_download URL ${DEEP_EP_URL}) | ||
| FetchContent_MakeAvailable(deep_ep_download) | ||
| set(DEEP_EP_SOURCE_DIR ${deep_ep_download_SOURCE_DIR}) | ||
|
|
||
| # Copy and update python files | ||
| set(DEEP_EP_PYTHON_DEST ${CMAKE_CURRENT_BINARY_DIR}/python/deep_ep) | ||
| file(REMOVE_RECURSE ${DEEP_EP_PYTHON_DEST}) | ||
| file(MAKE_DIRECTORY ${DEEP_EP_PYTHON_DEST}) | ||
| configure_file(${DEEP_EP_SOURCE_DIR}/LICENSE ${DEEP_EP_PYTHON_DEST}/LICENSE | ||
| COPYONLY) | ||
| set(_files __init__.py buffer.py utils.py) | ||
| foreach(_f IN LISTS _files) | ||
| set(_src "${DEEP_EP_SOURCE_DIR}/deep_ep/${_f}") | ||
| set(_dst "${DEEP_EP_PYTHON_DEST}/${_f}") | ||
| file(READ "${_src}" _content) | ||
| string(REPLACE "deep_ep_cpp" "tensorrt_llm.deep_ep_cpp_tllm" _content | ||
| "${_content}") | ||
| string( | ||
| PREPEND | ||
| _content | ||
| "# Adapted from https://github.com/deepseek-ai/DeepEP/blob/${DEEP_EP_COMMIT}/deep_ep/${_f}\n" | ||
| ) | ||
| file(WRITE "${_dst}" "${_content}") | ||
| set_property( | ||
| DIRECTORY | ||
| APPEND | ||
| PROPERTY CMAKE_CONFIGURE_DEPENDS ${_src}) | ||
| endforeach() | ||
|
|
||
| # Delete stale nvshmem on patch update | ||
| set(NVSHMEM_STAMP_FILE ${CMAKE_CURRENT_BINARY_DIR}/nvshmem_stamp.txt) | ||
| file(SHA256 ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch NVSHMEM_PATCH_HASH) | ||
| file(SHA256 ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch | ||
| NVSHMEM_PATCH_2_HASH) | ||
| set(NVSHMEM_STAMP_CONTENT "${NVSHMEM_URL_HASH}") | ||
| string(APPEND NVSHMEM_STAMP_CONTENT " PATCH_COMMAND v1") | ||
| string(APPEND NVSHMEM_STAMP_CONTENT " ${NVSHMEM_PATCH_HASH}") | ||
| string(APPEND NVSHMEM_STAMP_CONTENT " 103") | ||
| string(APPEND NVSHMEM_STAMP_CONTENT " ${NVSHMEM_PATCH_2_HASH}") | ||
| set(OLD_NVSHMEM_STAMP_CONTENT "") | ||
| if(EXISTS ${NVSHMEM_STAMP_FILE}) | ||
| file(READ ${NVSHMEM_STAMP_FILE} OLD_NVSHMEM_STAMP_CONTENT) | ||
| endif() | ||
| if(NOT OLD_NVSHMEM_STAMP_CONTENT STREQUAL NVSHMEM_STAMP_CONTENT) | ||
| file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/nvshmem_project-prefix) | ||
| file(WRITE ${NVSHMEM_STAMP_FILE} "${NVSHMEM_STAMP_CONTENT}") | ||
| endif() | ||
| set_property( | ||
| DIRECTORY APPEND | ||
| PROPERTY CMAKE_CONFIGURE_DEPENDS | ||
| ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch | ||
| ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch) | ||
|
|
||
| # Add NVSHMEM | ||
| # =========== | ||
|
|
||
| # NVSHMEM only works with GCC. Building NVSHMEM with Clang results in | ||
| # compilation errors. Using NVSHMEM with Clang results in slow builds and device | ||
| # link issues. | ||
| if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") | ||
| set(CMAKE_C_COMPILER gcc) | ||
| set(CMAKE_CXX_COMPILER g++) | ||
| set(CMAKE_CUDA_HOST_COMPILER g++) | ||
| endif() | ||
|
|
||
| # Add nvshmem external project | ||
| include(ExternalProject) | ||
| ExternalProject_Add( | ||
| nvshmem_project | ||
| URL file://${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_src_3.2.5-1.txz | ||
| URL_HASH ${NVSHMEM_URL_HASH} | ||
| PATCH_COMMAND patch -p1 --forward --batch -i | ||
| ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch | ||
| COMMAND sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i | ||
| src/CMakeLists.txt | ||
| COMMAND patch -p1 --forward --batch -i | ||
| ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch | ||
| CMAKE_CACHE_ARGS | ||
| -DCMAKE_C_COMPILER:STRING=${CMAKE_C_COMPILER} | ||
| -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} | ||
| -DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER} | ||
| -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} | ||
| -DCMAKE_CUDA_ARCHITECTURES:STRING=${DEEP_EP_CUDA_ARCHITECTURES} | ||
| -DCMAKE_CUDA_HOST_COMPILER:STRING=${CMAKE_CUDA_HOST_COMPILER} | ||
| -DCMAKE_CUDA_COMPILER_LAUNCHER:STRING=${CMAKE_CUDA_COMPILER_LAUNCHER} | ||
| -DNVSHMEM_BUILD_EXAMPLES:BOOL=0 | ||
| -DNVSHMEM_BUILD_PACKAGES:BOOL=0 | ||
| -DNVSHMEM_BUILD_TESTS:BOOL=0 | ||
| -DNVSHMEM_IBGDA_SUPPORT:BOOL=1 | ||
| -DNVSHMEM_IBRC_SUPPORT:BOOL=0 | ||
| -DNVSHMEM_MPI_SUPPORT:BOOL=0 | ||
| -DNVSHMEM_PMIX_SUPPORT:BOOL=0 | ||
| -DNVSHMEM_SHMEM_SUPPORT:BOOL=0 | ||
| -DNVSHMEM_TIMEOUT_DEVICE_POLLING:BOOL=0 | ||
| -DNVSHMEM_UCX_SUPPORT:BOOL=0 | ||
| -DNVSHMEM_USE_GDRCOPY:BOOL=0 | ||
| -DNVSHMEM_USE_NCCL:BOOL=0 | ||
| INSTALL_COMMAND "" | ||
| BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build | ||
| BUILD_BYPRODUCTS | ||
| ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/lib/libnvshmem.a) | ||
| add_library(nvshmem_project::nvshmem STATIC IMPORTED) | ||
| add_dependencies(nvshmem_project::nvshmem nvshmem_project) | ||
| file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/include) | ||
| set_target_properties( | ||
| nvshmem_project::nvshmem | ||
| PROPERTIES IMPORTED_LOCATION | ||
| ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/lib/libnvshmem.a | ||
| INTERFACE_INCLUDE_DIRECTORIES | ||
| ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/include) | ||
|
|
||
| # Add DeepEP cpp | ||
| # ============== | ||
|
|
||
| # Let CMake generate `fatbinData` for CUDA separable compilation. Set to FALSE | ||
| # or TRUE are both OK, but it generates `code=lto_90a` rather than `code=sm_90a` | ||
| # for arch `90a-real` if set to TRUE. | ||
| set(CMAKE_INTERPROCEDURAL_OPTIMIZATION FALSE) | ||
|
|
||
| # Find torch_python | ||
| find_library(TORCH_PYTHON_LIB torch_python REQUIRED | ||
| HINTS ${TORCH_INSTALL_PREFIX}/lib) | ||
|
|
||
| # Add deep_ep_cpp_tllm | ||
| file(GLOB_RECURSE SRC_CPP ${DEEP_EP_SOURCE_DIR}/csrc/*.cpp) | ||
| file(GLOB_RECURSE SRC_CU ${DEEP_EP_SOURCE_DIR}/csrc/*.cu) | ||
| pybind11_add_module(deep_ep_cpp_tllm ${SRC_CPP} ${SRC_CU}) | ||
| set_target_properties( | ||
| deep_ep_cpp_tllm | ||
| PROPERTIES CXX_STANDARD_REQUIRED ON | ||
| CUDA_STANDARD_REQUIRED ON | ||
| CXX_STANDARD 17 | ||
| CUDA_STANDARD 17 | ||
| CUDA_SEPARABLE_COMPILATION ON | ||
| CUDA_ARCHITECTURES "${DEEP_EP_CUDA_ARCHITECTURES}" | ||
| INSTALL_RPATH "$ORIGIN/libs/nvshmem;${TORCH_INSTALL_PREFIX}/lib" | ||
| BUILD_WITH_INSTALL_RPATH TRUE) | ||
| target_compile_options( | ||
| deep_ep_cpp_tllm | ||
| PRIVATE ${TORCH_CXX_FLAGS} -O3 $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-O3> | ||
| $<$<COMPILE_LANGUAGE:CUDA>:--ptxas-options=--register-usage-level=10>) | ||
| target_compile_definitions(deep_ep_cpp_tllm | ||
| PRIVATE TORCH_EXTENSION_NAME=deep_ep_cpp_tllm) | ||
| target_link_libraries( | ||
| deep_ep_cpp_tllm PRIVATE nvshmem_project::nvshmem ${TORCH_LIBRARIES} | ||
| ${TORCH_PYTHON_LIB}) | ||
| target_link_options( | ||
| deep_ep_cpp_tllm PRIVATE | ||
| -Wl,--version-script,${CMAKE_CURRENT_SOURCE_DIR}/deep_ep_cpp_tllm.version | ||
| -Wl,--no-undefined-version) | ||
|
|
||
| # Set targets | ||
| # =========== | ||
| add_dependencies(deep_ep deep_ep_cpp_tllm nvshmem_project) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| How to generate `nvshmem_fast_build.patch`? | ||
|
|
||
| 1. Build without `nvshmem_fast_build.patch`. | ||
| 2. Try linking DeepEP to NVSHMEM while omitting one object file. | ||
| 3. Repeat step 2 until no more object files can be omitted. | ||
| 4. Remove the unused files from NVSHMEM's `CMakelists.txt`, and save the differences as `nvshmem_fast_build.patch`. | ||
|
|
||
| The script `strip_nvshmem_helper.py` automatically performs steps 2 and 3. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| { | ||
| global: PyInit_deep_ep_cpp_tllm; | ||
| local: *; | ||
| }; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt | ||
| index cba899bba..c27337601 100644 | ||
| --- a/src/CMakeLists.txt | ||
| +++ b/src/CMakeLists.txt | ||
| @@ -264,48 +264,20 @@ set(NVSHMEM_HOST_SOURCES_NOMAXREGCOUNT | ||
| host/comm/rma.cu | ||
| host/stream/comm/quiet_on_stream.cu | ||
| host/stream/comm/cuda_interface_sync.cu | ||
| - host/stream/coll/alltoall/alltoall.cu | ||
| host/stream/coll/barrier/barrier.cu | ||
| - host/stream/coll/broadcast/broadcast.cu | ||
| - host/stream/coll/fcollect/fcollect.cu | ||
| - host/stream/coll/rdxn/reduce_and.cu | ||
| - host/stream/coll/rdxn/reduce_or.cu | ||
| - host/stream/coll/rdxn/reduce_xor.cu | ||
| - host/stream/coll/rdxn/reduce_min.cu | ||
| host/stream/coll/rdxn/reduce_max.cu | ||
| - host/stream/coll/rdxn/reduce_prod.cu | ||
| - host/stream/coll/rdxn/reduce_sum.cu | ||
| host/stream/coll/rdxn/reduce_team.cu | ||
| - host/stream/coll/reducescatter/reducescatter_and.cu | ||
| - host/stream/coll/reducescatter/reducescatter_or.cu | ||
| - host/stream/coll/reducescatter/reducescatter_xor.cu | ||
| - host/stream/coll/reducescatter/reducescatter_min.cu | ||
| - host/stream/coll/reducescatter/reducescatter_max.cu | ||
| - host/stream/coll/reducescatter/reducescatter_prod.cu | ||
| - host/stream/coll/reducescatter/reducescatter_sum.cu | ||
| ) | ||
|
|
||
| set(NVSHMEM_HOST_SOURCES | ||
| host/bootstrap/bootstrap.cpp | ||
| host/bootstrap/bootstrap_loader.cpp | ||
| host/coll/cpu_coll.cpp | ||
| - host/coll/alltoall/alltoall.cpp | ||
| - host/coll/alltoall/alltoall_on_stream.cpp | ||
| host/coll/barrier/barrier.cpp | ||
| host/coll/barrier/barrier_on_stream.cpp | ||
| - host/coll/broadcast/broadcast.cpp | ||
| - host/coll/broadcast/broadcast_on_stream.cpp | ||
| - host/coll/fcollect/fcollect.cpp | ||
| - host/coll/fcollect/fcollect_on_stream.cpp | ||
| - host/coll/rdxn/rdxn.cpp | ||
| - host/coll/rdxn/rdxn_on_stream.cpp | ||
| - host/coll/reducescatter/reducescatter.cpp | ||
| - host/coll/reducescatter/reducescatter_on_stream.cpp | ||
| host/comm/putget.cpp | ||
| - host/comm/fence.cpp | ||
| host/comm/quiet.cpp | ||
| host/comm/sync.cpp | ||
| - host/comm/amo.cpp | ||
| host/proxy/proxy.cpp | ||
| host/transport/transport.cpp | ||
| host/transport/p2p/p2p.cpp | ||
| @@ -1006,3 +978,12 @@ set(CPACK_RPM_PACKAGE_REQUIRES_PREUN "/sbin/ldconfig") | ||
|
|
||
| include(CPack) | ||
| # End Installation definitions | ||
| + | ||
| +set_target_properties( | ||
| + git_commit | ||
| + nvshmem_device_project | ||
| + nvshmem_bootstrap_pmi | ||
| + nvshmem_bootstrap_pmi2 | ||
| + nvshmem_host | ||
| + nvshmem-info | ||
| + PROPERTIES EXCLUDE_FROM_ALL TRUE) |
Git LFS file not shown
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import pathlib | ||
| import re | ||
| import subprocess | ||
|
|
||
| project_dir = pathlib.Path(__file__).parent.parent.parent.parent | ||
|
|
||
| # Run `find cpp/build | grep kernels/internode_ll.cu.o$` to get the directory | ||
| deep_ep_obj_dir = project_dir / "cpp/build/tensorrt_llm/deep_ep/CMakeFiles/deep_ep_cpp_tllm.dir/__/__/_deps/deep_ep_download-src/csrc" | ||
| assert deep_ep_obj_dir.is_dir() | ||
|
|
||
| # Run `find cpp/build | grep host/bootstrap/bootstrap.cpp.o$` to get the directory | ||
| # Please set to `nvshmem.dir` rather than `nvshmem_host.dir` | ||
| nvshmem_obj_dir = project_dir / "cpp/build/tensorrt_llm/deep_ep/nvshmem-build/src/CMakeFiles/nvshmem.dir" | ||
| assert nvshmem_obj_dir.is_dir() | ||
|
|
||
| # Parse -gencode arguments | ||
| with (project_dir / | ||
| "cpp/build/tensorrt_llm/deep_ep/cuda_architectures.txt").open() as f: | ||
| cuda_architectures = f.read() | ||
| pattern = re.compile(r'^([1-9][0-9]*[0-9][af]?)(-real|-virtual)?$') | ||
| gencode_args = [] | ||
| for cuda_arch in cuda_architectures.split(";"): | ||
| matches = re.match(pattern, cuda_arch) | ||
| assert matches is not None, f"Invalid cuda arch \"{cuda_arch}\"" | ||
| sm_version = matches.group(1) | ||
| postfix = matches.group(2) or "" | ||
| code = { | ||
| "": f"[compute_{sm_version},sm_{sm_version}]", | ||
| "-real": f"[sm_{sm_version}]", | ||
| "-virtual": f"[compute_{sm_version}]", | ||
| }[postfix] | ||
| gencode_args.append(f"-gencode=arch=compute_{sm_version},{code=:s}") | ||
|
|
||
| temp_dir = project_dir / "cpp/build/tensorrt_llm/deep_ep/strip_nvshmem_helper" | ||
| temp_dir.mkdir(exist_ok=True) | ||
| ranlib = temp_dir / "liba.a" | ||
| if ranlib.exists(): | ||
| ranlib.unlink() | ||
|
|
||
| deep_ep_obj_list = sorted(deep_ep_obj_dir.glob("kernels/**/*.o")) | ||
| nvshmem_obj_set = set(nvshmem_obj_dir.glob("**/*.o")) | ||
| for exclude_obj in sorted(nvshmem_obj_set): | ||
| # Create liba.a with one object file less | ||
| subprocess.check_call( | ||
| ["ar", "rcs", ranlib, *(nvshmem_obj_set - {exclude_obj})]) | ||
| # Test whether there are undefined symbols | ||
| res = subprocess.call([ | ||
| "/usr/local/cuda/bin/nvcc", *gencode_args, "-Xlinker", "--no-undefined", | ||
| "-shared", *deep_ep_obj_list, ranlib, "-o", temp_dir / "a.out" | ||
| ]) | ||
| # If there is no undefined symbols, then print "-" indicating the file could be omitted. | ||
| print("-" if res == 0 else "+", | ||
| str(exclude_obj.relative_to(nvshmem_obj_dir))[:-2]) | ||
| # Unlink the ranlib because `ar` does append | ||
| ranlib.unlink() | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.