
Conversation


@dongxuy04 dongxuy04 commented Aug 26, 2025

Changes in this PR:

  • Fix possible hang issue in Alltoall
  • Fix possible race condition in moePrepare and EPLB
  • Move some post-merge tests to pre-merge

Summary by CodeRabbit

  • New Features
    • Expanded multi-GPU test coverage with two new online EPLB variants for DeepSeekV3Lite.
  • Performance
    • Optimized MoE data transfer path for more predictable, warp-synchronous behavior.
  • Bug Fixes
    • Strengthened synchronization in MoE load balancing to prevent race conditions.
    • Added backpressure when releasing counters to avoid overwriting unconsumed slots.
    • Ensured CUDA operations complete before cross-thread barriers to improve determinism.
  • Tests
    • Added new pre-merge and post-merge scenarios to improve regression coverage for multi-GPU configurations.

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.


coderabbitai bot commented Aug 26, 2025

📝 Walkthrough

Walkthrough

Switches S2G workspace copy from asynchronous bulk cp_async to a synchronous register-based write loop; adds atomic load/store for a shared signal; introduces a spin-wait before releasing FIFO values; inserts a CUDA synchronize after workspace initialization; and adds two online EPLB tests for DeepSeekV3Lite on multi-GPU.

Changes

Cohort / File(s) Summary of modifications
Fused MOE S2G transfer
cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu
Added startWorkspaceS2GReg and replaced startWorkspaceS2G call with it; removed cp_async bulk commit/wait usage; implemented synchronous per-thread int4 writes from shared memory to FIFO with warp sync.
MOE load-balance signal atomics
cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu
Replaced direct reads/writes of signal->stepAndOwner with __atomic_load_n/__atomic_store_n using acquire/release semantics; logic otherwise unchanged.
MOE prepare backpressure
cpp/tensorrt_llm/kernels/moePrepareKernels.cu
Added spin-wait in CounterCommunicator::releaseValue to wait for slot reset (0) before writing value+1.
Workspace init sync
tensorrt_llm/_mnnvl_utils.py
Inserted torch.cuda.synchronize() immediately after moe_initialize_workspace() and before barrier.
Integration tests (GB200 multi-GPU)
tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
Added two online EPLB tests for DeepSeekV3Lite to pre_merge and post_merge: nvfp4_4gpus_online_eplb[fp8kv=True] and bfloat16_4gpus_online_eplb[mtp_nextn=2].
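The backpressure change in moePrepareKernels.cu can be illustrated with a host-side analogue. This is a hedged sketch only: the real code is a device-side spin on FIFO memory, whereas here Python threads and a plain list stand in, and the names (`CounterFifo`, `release_value`, `acquire_value`) are illustrative rather than the actual C++ API. The key idea carries over: the writer now waits for a slot to be reset to 0 before overwriting it, and values are stored as `value + 1` so that 0 can keep meaning "slot free".

```python
import threading
import time

SLOT_FREE = 0  # a consumed slot is reset to 0 by the reader


class CounterFifo:
    """Toy analogue of the counter FIFO: the writer waits for a slot to
    be reset (backpressure) before overwriting it, so a slow reader can
    never lose an unconsumed value."""

    def __init__(self, num_slots: int):
        self.values = [SLOT_FREE] * num_slots

    def release_value(self, index: int, value: int) -> None:
        # Spin until the reader has consumed (reset) this slot --
        # this wait is what the fix adds.
        while self.values[index] != SLOT_FREE:
            time.sleep(0)  # yield, like a device-side spin-wait
        # Store value + 1 so that 0 keeps meaning "slot free".
        self.values[index] = value + 1

    def acquire_value(self, index: int) -> int:
        while self.values[index] == SLOT_FREE:
            time.sleep(0)
        value = self.values[index] - 1
        self.values[index] = SLOT_FREE  # reset, releasing backpressure
        return value


def demo() -> list:
    """One writer, one reader, a single slot: no value is ever lost."""
    fifo = CounterFifo(1)
    out = []

    def reader():
        for _ in range(10):
            out.append(fifo.acquire_value(0))

    t = threading.Thread(target=reader)
    t.start()
    for v in range(10):
        fifo.release_value(0, v)
    t.join()
    return out
```

Without the spin-wait in `release_value`, a fast writer could overwrite slot 0 before the reader consumed it, silently dropping counters, which is exactly the race the fix closes.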

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Host
  participant GPU as GPU Kernel
  participant Shmem as Shared Memory
  participant FIFO as FIFO/Global Mem
  participant Signal as Step/Owner Signal

  rect rgba(200,230,255,0.3)
    Note over Host,GPU: Workspace initialization
    Host->>GPU: moe_initialize_workspace(...)
    Host->>GPU: torch.cuda.synchronize()
    Host->>Host: barrier.wait()
  end

  rect rgba(230,255,230,0.4)
    Note over GPU,FIFO: S2G workspace transfer (synchronous register copy)
    GPU->>Shmem: Read 128B chunks
    loop per lane (int4 writes)
      GPU->>FIFO: Write 128B via registers
      GPU->>GPU: __syncwarp()
    end
  end

  rect rgba(255,245,200,0.5)
    Note over GPU,Signal: Load-balance signaling with atomics
    GPU->>Signal: __atomic_load_n(stepAndOwner, ACQ)
    alt ready
      GPU-->>GPU: proceed
    else not ready
      GPU-->>GPU: wait/retry
    end
    Host->>Signal: __atomic_store_n(stepAndOwner, REL)
  end

  rect rgba(255,230,230,0.5)
    Note over GPU,FIFO: CounterCommunicator release with backpressure
    loop until slot zero
      GPU->>FIFO: check values[index] == 0
    end
    GPU->>FIFO: write value+1
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

Community want to contribute

Suggested reviewers

  • WeiHaocheng
  • hlu1
  • kaiyux
  • litaotju



@dongxuy04
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #16561 [ run ] triggered by Bot


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (6)
tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml (1)

36-37: Confirm intent: tests duplicated in both pre_merge and post_merge lists

You added the two online_eplb variants to pre_merge, but they still exist in post_merge. If the goal is to “move” them, consider removing them from post_merge to avoid running twice and burning CI time.

Apply this diff (outside the changed hunk) if duplication is not intended:

-  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
-  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
cpp/tensorrt_llm/kernels/moePrepareKernels.cu (1)

2-2: Update copyright year to include 2025

Header still says 2022-2024 while this file is modified in 2025.

- * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION.  All rights reserved.
cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu (1)

2-2: Update copyright year to include 2025

This file’s header still lists 2019-2023.

- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2019-2025, NVIDIA CORPORATION.  All rights reserved.
cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu (1)

779-783: Optional: keep cp.async path feature-gated for perf on safe architectures

If the hang was specific to cp.async bulk S2G on certain platforms, consider retaining the original startWorkspaceS2G path behind a compile-time or runtime switch (e.g., env var) to preserve peak performance where it’s stable.

I can wire a small flag plumbed through FusedMoeWorkspace/MoeSingleCommMeta to toggle.
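A runtime switch of the kind suggested above could be as small as an environment-variable check on the Python side. This is only a sketch of the pattern: the variable name `TRTLLM_MOE_USE_CP_ASYNC` is hypothetical and does not exist in the codebase; the actual plumbing through FusedMoeWorkspace/MoeSingleCommMeta would be a design decision for the maintainers.

```python
import os


def use_cp_async_path(env: dict = os.environ) -> bool:
    """Sketch of an opt-in runtime switch: re-enable the cp.async bulk
    S2G path only on platforms where it is known to be stable.
    TRTLLM_MOE_USE_CP_ASYNC is a hypothetical variable name, defaulting
    to the safe synchronous register path ("0")."""
    return env.get("TRTLLM_MOE_USE_CP_ASYNC", "0") == "1"
```

Defaulting to the safe path and requiring an explicit opt-in keeps the hang fix as the out-of-the-box behavior while preserving a way to recover peak performance.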

tensorrt_llm/_mnnvl_utils.py (2)

372-372: Device synchronization before barrier is reasonable; consider scoping to the active device

torch.cuda.synchronize() syncs all devices. Narrowing it to the active device avoids unintended stalls if multiple devices are visible in the process.

-        torch.cuda.synchronize()
+        torch.cuda.synchronize(MnnvlMemory.dev_id)

If moe_initialize_workspace may use non-default streams, this still correctly serializes all work on the specified device.
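The reason the synchronize must come before the barrier can be shown with a small thread-only analogue (no CUDA or torch involved). `FakeStream` here is a stand-in for a CUDA stream, not a real API: work is enqueued asynchronously, and `synchronize()` blocks until all previously enqueued work has finished, mirroring what `torch.cuda.synchronize()` guarantees in the fix.

```python
import queue
import threading


class FakeStream:
    """Minimal stand-in for a CUDA stream: work runs asynchronously on a
    worker thread; synchronize() blocks until all previously enqueued
    work has finished, like torch.cuda.synchronize()."""

    def __init__(self):
        self._q = queue.Queue()
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        while True:
            fn = self._q.get()
            fn()
            self._q.task_done()

    def launch(self, fn):
        self._q.put(fn)

    def synchronize(self):
        self._q.join()


def init_then_barrier(n_ranks: int = 2) -> list:
    """Each 'rank' initializes its slot of a shared workspace
    asynchronously, then synchronizes its stream BEFORE the cross-rank
    barrier, so no rank passes the barrier while another rank's
    initialization is still in flight."""
    workspace = [None] * n_ranks
    barrier = threading.Barrier(n_ranks)

    def rank(i):
        stream = FakeStream()
        stream.launch(lambda: workspace.__setitem__(i, f"rank{i}-ready"))
        stream.synchronize()  # the inserted torch.cuda.synchronize()
        barrier.wait()        # now every slot is guaranteed filled
        assert all(v is not None for v in workspace)

    threads = [threading.Thread(target=rank, args=(i,)) for i in range(n_ranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return workspace
```

Dropping the `synchronize()` call would let a rank reach the barrier with its initialization still pending, which is the race the one-line fix in `_mnnvl_utils.py` removes.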


1-2: Update SPDX/copyright to 2025

Header lists 2022-2024.

-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 80043af and cc175d8.

📒 Files selected for processing (5)
  • cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu (2 hunks)
  • cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu (2 hunks)
  • cpp/tensorrt_llm/kernels/moePrepareKernels.cu (1 hunks)
  • tensorrt_llm/_mnnvl_utils.py (1 hunks)
  • tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml (1 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tensorrt_llm/_mnnvl_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_mnnvl_utils.py
  • cpp/tensorrt_llm/kernels/moePrepareKernels.cu
  • cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu
  • cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu
**/*.{c,cc,cpp,cxx,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{c,cc,cpp,cxx,cu}: Closing braces of C++ namespaces must include a trailing comment naming the namespace (e.g., } // namespace foo)
Use Allman brace style; empty for/while loop semicolon on its own line; always use braces for control statements
C++ filenames must be lowerCamelCase (e.g., thisIsAFilename.cpp) and be case-insensitively unique within a compilation target
Use smart pointers; prefer unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases; do not use deprecated smart pointers
In implementation, prefer C++ comments (//); use inline C comments only for annotating parameters in calls (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) or chained x = y = z)

Files:

  • cpp/tensorrt_llm/kernels/moePrepareKernels.cu
  • cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu
  • cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu
**/*.{c,cc,cpp,cxx,cu,h,hh,hpp,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{c,cc,cpp,cxx,cu,h,hh,hpp,hxx,cuh}: Prefer const or constexpr variables over #define for constants; variables not modified after initialization must be declared const
Avoid using literals (except 0, nullptr, true, false) outside of initialization; prefer named constexpr constants
Type names (classes, structs, enums, typedefs) must be UpperCamelCase
Local variables, methods, and namespaces must be lowerCamelCase
Non-magic-number global variables that are non-static/not in anonymous namespace must be prefixed with g (e.g., gDontUseGlobalFoos)
Non-magic-number globals that are static or in an anonymous namespace must be prefixed with s (e.g., sMutableStaticGlobal)
Locally visible static variables should be lowerCamelCase prefixed with s (e.g., static std::once_flag sFlag)
Member variables should be lowerCamelCase prefixed with m (e.g., mNbFooValues); public members may omit but prefix is encouraged for clarity
Constants (enums, globals, static constants, and function-scope magic numbers) should be UPPER_SNAKE_CASE with k prefix (e.g., kDIGIT_NUM)
Avoid Hungarian notation except limited 'apps Hungarian' like nb for counts; literal suffixes should be uppercase (e.g., 1234L)
Use spaces only; indent with 4 spaces (no tabs)
Format C++ code with clang-format (LLVM style) and limit lines to 120 characters; exceptions must be bracketed with // clang-format off/on
Disable code with #if/#endif (prefer mnemonic conditions) or macros that noop in release; do not comment out code; avoid dead code
Use the least forceful cast necessary; avoid removing const/volatile; avoid C-style and functional casts (except explicit constructors); cast void* to T* with static_cast; use reinterpret_cast only as last resort; avoid dynamic_cast
Switch on enum should cover all values and omit default when possible; switch statements must be well-structured with no fall-through except between adjacent empty cases; each case must end with break or throw; returns at end of case are not allowed; if ...

Files:

  • cpp/tensorrt_llm/kernels/moePrepareKernels.cu
  • cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu
  • cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu
🧠 Learnings (1)
📚 Learning: 2025-08-19T03:35:20.866Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4616-4626
Timestamp: 2025-08-19T03:35:20.866Z
Learning: In the MOE profiler TMA workspace preparation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu), the overlapping of TMA WS regions for NONE and FINALIZE variants is deliberate design to save memory space, as confirmed by djns99. The comment "reuse the same pointers to save space" reflects this intentional behavior.

Applied to files:

  • tensorrt_llm/_mnnvl_utils.py
  • cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu
🧬 Code graph analysis (3)
tensorrt_llm/_mnnvl_utils.py (2)
cpp/tensorrt_llm/runtime/torchUtils.h (1)
  • cuda (129-132)
cpp/include/tensorrt_llm/runtime/cudaStream.h (1)
  • synchronize (81-96)
cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu (1)
cpp/tensorrt_llm/kernels/moePrepareKernels.cu (2)
  • value (60-68)
  • value (60-60)
cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu (1)
cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h (1)
  • BYTES_PER_128B_BLOCK (69-107)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (4)
tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml (2)

36-37: Pre-merge budget: ensure stability and runtime are acceptable

These online_eplb cases can be heavy. If they are occasionally flaky or close to the time budget, mark them with a higher timeout or gating condition, otherwise they may block merges. Let’s confirm their current average runtime on GB200 4-GPU.

I can help profile recent CI runs if you share logs.


36-37: Test selectors exist and are correct

Both parametrized tests are present in tests/integration/defs/accuracy/test_llm_api_pytorch.py and match the YAML entries exactly:

  • test_bfloat16_4gpus_online_eplb(self, mtp_nextn) at line 1384
  • test_nvfp4_4gpus_online_eplb(self, fp8kv) at line 1408

No issues—selectors are spelled correctly and need no further changes.

cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu (2)

606-608: Good fix: acquire load for shared signal

Switching to __atomic_load_n with __ATOMIC_ACQUIRE is appropriate here to prevent tearing and ensure reads after the load aren’t reordered before it.


622-623: Good fix: release store for shared signal

Using __atomic_store_n with __ATOMIC_RELEASE correctly orders prior writes before publishing the new step/owner state.
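The publish/consume handshake these two atomics implement can be sketched in Python. Caveat: CPython cannot express `__ATOMIC_ACQUIRE`/`__ATOMIC_RELEASE` orderings; the GIL happens to give a comparable guarantee for free, so plain attribute accesses stand in for the atomics here. The class and member names (`StepSignal`, `step_and_owner`, `payload`) mirror the kernel's intent but are illustrative, not the actual C++ structures.

```python
import threading


class StepSignal:
    """Toy analogue of the step/owner handshake: the producer writes its
    payload first and only then flips the flag (release-store); the
    consumer spins on the flag (acquire-load) and may read the payload
    only after observing the new flag value."""

    def __init__(self):
        self.step_and_owner = 0  # the published flag word
        self.payload = None      # data that must be visible once the flag flips

    def publish(self, step: int, payload) -> None:
        self.payload = payload       # write payload first...
        self.step_and_owner = step   # ...then "release-store" the flag

    def wait_for(self, step: int):
        while self.step_and_owner != step:  # "acquire-load" + spin
            pass
        return self.payload                 # safe to read after the load


def demo() -> list:
    sig = StepSignal()
    seen = []
    t = threading.Thread(target=lambda: seen.append(sig.wait_for(1)))
    t.start()
    sig.publish(1, "expert-weights-ready")
    t.join()
    return seen
```

In CUDA/C++ the orderings must be explicit: without acquire/release the compiler or hardware may reorder the payload accesses around the flag, or tear the read, which is the race the fix addresses.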

@kaiyux kaiyux enabled auto-merge (squash) August 26, 2025 15:09
@tensorrt-cicd
Collaborator

PR_Github #16561 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12437 completed with status: 'FAILURE'

@dongxuy04
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16627 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16627 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12482 completed with status: 'SUCCESS'

@kaiyux kaiyux merged commit abdb273 into NVIDIA:main Aug 27, 2025
5 checks passed