Pull request overview
Fixes multi-GPU (MGPU) kernel-cache collisions in the FMHA v3 ASM path by ensuring cached kernel modules are keyed per device, preventing cross-device module reuse that can trigger “mismatched device ordinal launch” failures (notably in TE’s JAX integration).
Changes:
- Add current HIP device ordinal to the FMHA forward kernel cache key.
- Add current HIP device ordinal to the FMHA backward kernel cache key (pre/dqdkdv/post kernels).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| csrc/cpp_itfs/mha_fwd.cu | Prefixes the forward ASM kernel cache key with the current device id to avoid cross-device cache collisions. |
| csrc/cpp_itfs/mha_bwd.cu | Prefixes the backward ASM kernel cache key with the current device id for the same MGPU safety. |
    // Include device ID in cache key so each GPU gets its own loaded module
    int current_device;
    HIP_CALL(hipGetDevice(&current_device));
    std::string dev_prefix = std::to_string(current_device) + ":";

    const auto& cfg = it->second;
    const char* name = cfg.knl_name.c_str();
    std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);
    std::string key = dev_prefix + name;
get_kernel_co_name() relies on is_mi308_device(), which (via get_pci_chip_id() in aiter_hip_common.h) caches the PCI chip id in a function-static value based on whichever device was current on first call. In a multi-GPU process with heterogeneous devices, this can pick the wrong hsaco subfolder (MI300 vs MI308) for subsequent devices. Consider making the MI308/MI300 selection device-specific (e.g., keyed by hipGetDevice() or by querying properties for the current device each time / caching per device id).
Suggested change:
- // Include device ID in cache key so each GPU gets its own loaded module
- int current_device;
- HIP_CALL(hipGetDevice(&current_device));
- std::string dev_prefix = std::to_string(current_device) + ":";
- const auto& cfg = it->second;
- const char* name = cfg.knl_name.c_str();
- std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);
- std::string key = dev_prefix + name;
+ // Include device ID and co_name in cache key so each GPU/variant gets its own loaded module
+ int current_device;
+ HIP_CALL(hipGetDevice(&current_device));
+ const auto& cfg = it->second;
+ const char* name = cfg.knl_name.c_str();
+ std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);
+ std::string key = std::to_string(current_device) + ":" + co_name + ":" + name;
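The per-device caching mentioned in the review comment above could look roughly like the following. This is a minimal sketch under stated assumptions, not aiter's code: `PerDeviceCache` and its `query` callback are hypothetical names, and the callback stands in for whatever device-property lookup the real get_pci_chip_id() performs. The point is only that the cached value is stored per device ordinal rather than in a single function-static variable.

```cpp
#include <functional>
#include <mutex>
#include <unordered_map>
#include <utility>

// Hypothetical helper: caches one value per device ordinal instead of a
// single function-static value shared by the whole process.
template <typename T>
class PerDeviceCache {
public:
    // `query` is an assumed callback that computes the value for one device
    // (in real code it would query the properties of that device).
    explicit PerDeviceCache(std::function<T(int)> query)
        : query_(std::move(query)) {}

    // Returns the cached value for `device`; the query runs only on the
    // first request for that ordinal.
    T get(int device) {
        std::lock_guard<std::mutex> lock(mu_);
        auto it = cache_.find(device);
        if (it == cache_.end()) {
            it = cache_.emplace(device, query_(device)).first;
        }
        return it->second;
    }

private:
    std::function<T(int)> query_;
    std::mutex mu_;
    std::unordered_map<int, T> cache_;
};
```

A caller would pass the ordinal returned by hipGetDevice() into `get()`, so a heterogeneous MI300/MI308 process would resolve the hsaco subfolder separately for each device.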
    static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>> impl_ptr_map;

    // Include device ID in cache key so each GPU gets its own loaded module
    int current_device;
    HIP_CALL(hipGetDevice(&current_device));
    std::string dev_prefix = std::to_string(current_device) + ":";
impl_ptr_map is a process-wide static std::unordered_map that is mutated via emplace() without any synchronization. In MGPU setups it’s common for different host threads to run on different devices concurrently, and concurrent inserts/reads on unordered_map are undefined behavior (can crash/corrupt the cache). Consider making this cache thread_local (as in fmha_fwd_v3) or protecting all accesses with a mutex (or another thread-safe cache).
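The mutex option from this comment might be sketched as below. Everything here is illustrative: `KernelCache`, `get_or_load`, and `FakeKernel` are hypothetical names, and `FakeKernel` is only a stand-in for AiterAsmKernel so that the example compiles without HIP. The sketch shows the locking pattern and the device-qualified key, not the actual module-loading logic.

```cpp
#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Stand-in for AiterAsmKernel (assumption: the real class loads a code
// object through HIP; here it only records what was requested).
struct FakeKernel {
    std::string co_name;
    std::string name;
    int device;
};

// Hypothetical wrapper: every lookup and insert goes through one mutex, so
// host threads driving different GPUs cannot race on the map.
class KernelCache {
public:
    // Returns the kernel for (device, co_name, name), creating it on first use.
    FakeKernel* get_or_load(int device, const std::string& co_name,
                            const std::string& name) {
        const std::string key =
            std::to_string(device) + ":" + co_name + ":" + name;
        std::lock_guard<std::mutex> lock(mu_);
        auto it = map_.find(key);
        if (it == map_.end()) {
            std::unique_ptr<FakeKernel> k(new FakeKernel{co_name, name, device});
            it = map_.emplace(key, std::move(k)).first;
        }
        // The unique_ptr keeps the object at a stable address across rehashes.
        return it->second.get();
    }

    // Number of cached kernels.
    std::size_t size() {
        std::lock_guard<std::mutex> lock(mu_);
        return map_.size();
    }

private:
    std::mutex mu_;
    std::unordered_map<std::string, std::unique_ptr<FakeKernel>> map_;
};
```

Compared with a `thread_local` cache, a single locked map loads each module once per device rather than once per thread per device, at the cost of a short critical section on every lookup.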
Motivation
Note: This is not necessary with #2221, so ideally that PR gets merged instead. This change is only required as long as we use the standard HIP module load.
In TE's JAX integration, the MGPU tests fail because multiple devices register kernels in the shared impl_ptr_map. A device can then attempt to launch a kernel whose module was loaded on another device, which results in a "mismatched device ordinal launch" error.
Technical Details
Adds the current HIP device ID to the impl_ptr_map cache key, so each device gets its own loaded module.
Test Plan
Verify TE JAX integration tests pass
Test Result
TE JAX integration tests do indeed pass