[NVSHMEM] Extend CUDA backend to compile and link TIR modules with NVSHMEM #18093
CMakeLists.txt
Outdated
  message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -I${NVSHMEM_INCLUDE_DIR} -L${NVSHMEM_LIB_DIR}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${NVSHMEM_INCLUDE_DIR} -L${NVSHMEM_LIB_DIR}")
Use CMake primitives such as `target_include_directories` instead of appending to `CMAKE_CXX_FLAGS` and `CMAKE_CUDA_FLAGS`.
Changed.
    The loaded VM module.
    """
    if device is None:
        device = Device(device_type=0, device_id=0)
Is this change intentional?
Yes. In the latest dlpack, `DLDeviceType` has enum values 1 to 17, so the value `Device(device_type=0, device_id=0)` would raise an unrecognized-device-type error. Since it is meant to indicate a null value, I replaced the subsequent usage with an `Optional<Device>` type; see the changes to `UseDefaultDeviceIfNone`.
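For readers following along, the idea of replacing a sentinel device with an optional can be sketched roughly as below. This is not the code from the PR: the `Device` struct, the `kDefaultDevice` constant, and the signature given to `UseDefaultDeviceIfNone` are all assumptions made for illustration, with `std::optional` standing in for TVM's `Optional<Device>`.

```cpp
#include <optional>

// Hypothetical stand-in for a DLPack-style device descriptor (not TVM's type).
struct Device {
  int device_type;
  int device_id;
};

// Assumed default for this sketch: device type 1 (kDLCPU in DLPack), id 0.
constexpr Device kDefaultDevice{1, 0};

// Assumed signature: returns the caller's device when one is given,
// otherwise falls back to the default. An empty optional replaces the old
// "device_type=0" sentinel, which is not a valid DLDeviceType value.
inline Device UseDefaultDeviceIfNone(const std::optional<Device>& device) {
  return device.value_or(kDefaultDevice);
}
```

The point of the design is that "no device given" is expressed by the absence of a value rather than by an out-of-range enum value that newer dlpack versions reject.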
Please include one test case that exercises the basic functionality, such as just calling the function that gets the worker ID.
  CUmodule mod = static_cast<CUmodule>(cuModule);
  auto status = nvshmemx_init_status();
  // The NVSHMEM library must have completed device initialization prior to
  // nvshmemx_cumodule_init. If not, we skip the cumodule initialization.
If the device is not initialized, we should return with an error.
The design here is to enable NVSHMEM compilation and linking broadly for every kernel, including kernels that do not use NVSHMEM and whose NVSHMEM context is never initialized.
In that case, `nvshmemx_init_status()` is used to check whether `nvshmemx_cumodule_init` needs to be called. If the device is not initialized, we simply skip `nvshmemx_cumodule_init`.
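The skip-if-uninitialized behavior described above can be sketched in isolation as follows. This is a simplified model, not the PR's code: `InitStatus` and `MaybeInitModule` are hypothetical names, the enum stands in for the value returned by `nvshmemx_init_status()`, and the callback stands in for the call to `nvshmemx_cumodule_init`, so the snippet compiles without NVSHMEM installed.

```cpp
#include <functional>
#include <stdexcept>

// Hypothetical model of the library's initialization state, ordered from
// "nothing done" to "device initialization complete".
enum class InitStatus { kNotInitialized, kBootstrapped, kInitialized };

// Runs the per-module initialization only when the library has finished
// device initialization. Returns true if the initialization ran and false
// if it was skipped. A nonzero return code from the callback is an error.
inline bool MaybeInitModule(InitStatus status,
                            const std::function<int()>& init_module) {
  if (status < InitStatus::kInitialized) {
    return false;  // Kernel does not need NVSHMEM state; skip silently.
  }
  if (init_module() != 0) {
    throw std::runtime_error("module initialization failed");
  }
  return true;
}
```

Under this design a module loaded before, or without, NVSHMEM initialization still loads normally, while a failure of the initialization call itself is still reported.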
  decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
  decl_stream << "#endif\n";

  decl_stream << "\n#ifdef _WIN32\n";
Why change this?
This is because NVSHMEM contains `#include <cstdint>`, which conflicts with the original `#define int64_t long long` and can lead to a CUDA compilation error. `#define` semantics are quite error-prone, so I removed it and use `using` instead.
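To illustrate why the macro is fragile, here is a small sketch. It is not the generated code from the PR, which is not shown in this thread; the `sketch` namespace and the `Widen` helper are hypothetical and exist only so that the snippet compiles next to `<cstdint>` on a host compiler.

```cpp
#include <climits>
#include <cstdint>

// With a macro such as
//   #define int64_t long long
// the preprocessor rewrites every later occurrence of the token int64_t,
// including the declarations inside a subsequently included <cstdint>,
// which then no longer parse. A type alias is an ordinary scoped
// declaration and does not touch the text of other headers.

namespace sketch {  // Hypothetical namespace, used only for this example.
using int64_t = long long;
using uint64_t = unsigned long long;
}  // namespace sketch

static_assert(sizeof(sketch::int64_t) * CHAR_BIT >= 64,
              "alias must hold at least 64 bits");

// The alias and the standard header coexist in one translation unit.
inline sketch::int64_t Widen(std::int32_t x) {
  return static_cast<sketch::int64_t>(x);
}
```

The alias is placed in a namespace here because a global `using int64_t = long long;` can itself clash with the standard typedef on platforms where `int64_t` is `long`; how the PR scopes its aliases is not stated in the comments above.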
Added a test case under
@tvm-bot rerun
Failed to re-run CI in https://github.com/apache/tvm/actions/runs/15886908209
@tvm-bot rerun
Failed to re-run CI in https://github.com/apache/tvm/actions/runs/15887848436
@tvm-bot rerun
This PR enables native NVSHMEM compilation support.