Fix kernel map collision on MGPU context #2401

Open
Micky774 wants to merge 1 commit into main from zain/um-mgpu-collision

Micky774 (Contributor) commented Mar 20, 2026

Motivation

Note: This is not necessary with #2221, so ideally that PR gets merged instead. This change is only required as long as we use the standard HIP module load.

In TE's JAX integration, the MGPU tests fail because multiple devices register separate kernels in the impl_ptr_map under the same key. A device can then attempt to launch a kernel that was loaded by another device, which results in a mismatched device ordinal launch error.

Technical Details

Adds device ID to map key
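
The effect of the key change can be illustrated with a minimal, HIP-free sketch. It is not the PR's code: `LoadedKernel`, `make_kernel_key`, and `get_or_load` are hypothetical names, and the device ordinal is passed in explicitly (the PR obtains it with `hipGetDevice`) so that the snippet compiles without a GPU toolchain.

```cpp
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a loaded kernel module (the real map stores AiterAsmKernel).
struct LoadedKernel {
    int device;        // ordinal of the device the module was loaded on
    std::string name;  // kernel name
};

// Build a cache key that is unique per (device, kernel name) pair,
// using the same "<device>:<name>" layout as the PR.
inline std::string make_kernel_key(int device, const std::string& kernel_name) {
    return std::to_string(device) + ":" + kernel_name;
}

// Return the cached entry for this device and kernel, "loading" it on first use.
inline const LoadedKernel& get_or_load(std::unordered_map<std::string, LoadedKernel>& cache,
                                       int device,
                                       const std::string& kernel_name) {
    const std::string key = make_kernel_key(device, kernel_name);
    auto it = cache.find(key);
    if (it == cache.end()) {
        it = cache.emplace(key, LoadedKernel{device, kernel_name}).first;
    }
    return it->second;
}
```

With a key made only of the kernel name, the second device would get back the entry created by the first one. With the device prefix, each device gets its own entry.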

Test Plan

Verify TE JAX integration tests pass

Test Result

TE JAX integration tests do indeed pass

Submission Checklist

Micky774 requested review from a team and Copilot on March 20, 2026, 17:29

github-actions (Contributor)

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2401 --add-label <label>

Copilot AI (Contributor) left a comment

Pull request overview

Fixes multi-GPU (MGPU) kernel-cache collisions in the FMHA v3 ASM path by ensuring cached kernel modules are keyed per device, preventing cross-device module reuse that can trigger “mismatched device ordinal launch” failures (notably in TE’s JAX integration).

Changes:

  • Add current HIP device ordinal to the FMHA forward kernel cache key.
  • Add current HIP device ordinal to the FMHA backward kernel cache key (pre/dqdkdv/post kernels).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| csrc/cpp_itfs/mha_fwd.cu | Prefixes the forward ASM kernel cache key with the current device id to avoid cross-device cache collisions. |
| csrc/cpp_itfs/mha_bwd.cu | Prefixes the backward ASM kernel cache key with the current device id for the same MGPU safety. |

Comment thread csrc/cpp_itfs/mha_fwd.cu
Comment on lines +245 to +253
// Include device ID in cache key so each GPU gets its own loaded module
int current_device;
HIP_CALL(hipGetDevice(&current_device));
std::string dev_prefix = std::to_string(current_device) + ":";

const auto& cfg = it->second;
const char* name = cfg.knl_name.c_str();
std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);
std::string key = dev_prefix + name;

Copilot AI Mar 20, 2026

get_kernel_co_name() relies on is_mi308_device(), which (via get_pci_chip_id() in aiter_hip_common.h) caches the PCI chip id in a function-static value based on whichever device was current on first call. In a multi-GPU process with heterogeneous devices, this can pick the wrong hsaco subfolder (MI300 vs MI308) for subsequent devices. Consider making the MI308/MI300 selection device-specific (e.g., keyed by hipGetDevice() or by querying properties for the current device each time / caching per device id).

Suggested change
- // Include device ID in cache key so each GPU gets its own loaded module
- int current_device;
- HIP_CALL(hipGetDevice(&current_device));
- std::string dev_prefix = std::to_string(current_device) + ":";
- const auto& cfg = it->second;
- const char* name = cfg.knl_name.c_str();
- std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);
- std::string key = dev_prefix + name;
+ // Include device ID and co_name in cache key so each GPU/variant gets its own loaded module
+ int current_device;
+ HIP_CALL(hipGetDevice(&current_device));
+ const auto& cfg = it->second;
+ const char* name = cfg.knl_name.c_str();
+ std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);
+ std::string key = std::to_string(current_device) + ":" + co_name + ":" + name;
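
The "caching per device id" option from the comment above could look roughly like the following sketch. It is an assumption-laden illustration, not code from aiter: `PerDeviceChipIdCache` is a hypothetical name, and the `query` callback stands in for whatever lookup `get_pci_chip_id()` actually performs, which is not shown in this thread.

```cpp
#include <functional>
#include <mutex>
#include <unordered_map>
#include <utility>

// Caches one chip id per device ordinal instead of one value per process.
// Hypothetical helper; the query callback is an assumed stand-in for the
// real per-device PCI chip id lookup.
class PerDeviceChipIdCache {
public:
    using QueryFn = std::function<unsigned(int)>;

    explicit PerDeviceChipIdCache(QueryFn query) : query_(std::move(query)) {}

    // Returns the chip id for `device`, querying it only on the first request
    // for that particular device.
    unsigned get(int device) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = ids_.find(device);
        if (it == ids_.end()) {
            it = ids_.emplace(device, query_(device)).first;
        }
        return it->second;
    }

private:
    QueryFn query_;
    std::mutex mutex_;
    std::unordered_map<int, unsigned> ids_;
};
```

Unlike a single function-static value, the result no longer depends on which device happened to be current during the first call.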

Copilot uses AI. Check for mistakes.
Comment thread csrc/cpp_itfs/mha_bwd.cu
Comment on lines 372 to +377
static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>> impl_ptr_map;

// Include device ID in cache key so each GPU gets its own loaded module
int current_device;
HIP_CALL(hipGetDevice(&current_device));
std::string dev_prefix = std::to_string(current_device) + ":";

Copilot AI Mar 20, 2026

impl_ptr_map is a process-wide static std::unordered_map that is mutated via emplace() without any synchronization. In MGPU setups it’s common for different host threads to run on different devices concurrently, and concurrent inserts/reads on unordered_map are undefined behavior (can crash/corrupt the cache). Consider making this cache thread_local (as in fmha_fwd_v3) or protecting all accesses with a mutex (or another thread-safe cache).
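
A minimal sketch of the mutex option, assuming a small wrapper class around the map. `KernelCache` and `FakeKernel` are hypothetical names; `FakeKernel` stands in for `AiterAsmKernel`, whose constructor is not shown in this thread.

```cpp
#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for AiterAsmKernel.
struct FakeKernel {
    std::string key;
};

// A cache whose lookups and inserts are serialized by a single mutex.
class KernelCache {
public:
    // Returns the cached kernel for `key`, creating it on first use.
    // The returned pointer stays valid for the lifetime of the cache,
    // because entries are heap-allocated and never erased.
    FakeKernel* get_or_create(const std::string& key) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = map_.find(key);
        if (it == map_.end()) {
            it = map_.emplace(key, std::make_unique<FakeKernel>(FakeKernel{key})).first;
        }
        return it->second.get();
    }

    std::size_t size() {
        std::lock_guard<std::mutex> lock(mutex_);
        return map_.size();
    }

private:
    std::mutex mutex_;
    std::unordered_map<std::string, std::unique_ptr<FakeKernel>> map_;
};
```

The `thread_local` alternative avoids the lock entirely, at the cost of each host thread loading its own copy of every kernel it uses.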

Copilot uses AI. Check for mistakes.

Micky774 (Contributor, Author)

Note: This is not necessary with #2221, so ideally that PR gets merged instead. This change is only required as long as we use the standard HIP module load.
