
Fix 2D tensor communication for asymmetric DP in Bridge Communicator#4021

Merged
yashaswikarnati merged 1 commit into NVIDIA:main from yashaswikarnati:yash/nmfw-47-dp-fan-in-fix
Mar 25, 2026

Conversation

@yashaswikarnati
Contributor

@yashaswikarnati yashaswikarnati commented Mar 24, 2026

Summary

Adds native 2D tensor support to BridgeCommunicator for modules that produce [B*S, H] outputs (e.g., vision encoders), fixing a crash when DP sizes differ across modules.

The problem

The bridge communicator assumed all tensors are 3D [S, B, H] and used unsqueeze(-1) / squeeze(-1) to adapt 2D tensors. This breaks with asymmetric DP because the bridge splits/concatenates along dim_mapping['b']=1 — but after unsqueeze(-1), dim 1 is the hidden dimension, not batch:

[577, 4096] → unsqueeze(-1) → [577, 4096, 1]
                                 s=0  b=1   h=2
                                      ^^^
                          bridge thinks this is batch
                          but it's actually hidden

  • Fan-in (DP=3 → DP=1): cat(dim=1) on 3 tensors concatenates along hidden → [577, 12288] instead of [1731, 4096].
  • Fan-out (DP=1 → DP=3): split(dim=1, 3) splits hidden into thirds → [1731, 1365] instead of [577, 4096].

The bug is silent when the DP ratio is 1:1, because cat/split of a single tensor are no-ops.
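The failure mode above can be reproduced at the shape level. This is an illustrative sketch (using NumPy for the shape math, not the actual BridgeCommunicator code) of why unsqueezing to 3D and concatenating at dim 1 corrupts 2D [B*S, H] outputs:

```python
import numpy as np

# Three DP ranks each produce a [577, 4096] vision-encoder output.
outputs = [np.empty((577, 4096), dtype=np.float32) for _ in range(3)]

# Old path: append a singleton dim, then cat at dim_mapping['b']=1.
adapted = [o[..., None] for o in outputs]           # each becomes [577, 4096, 1]
fan_in = np.concatenate(adapted, axis=1)[..., 0]    # dim 1 is hidden, not batch
assert fan_in.shape == (577, 12288)                 # wrong: hidden dims concatenated

# Desired result: concatenate along the folded batch dimension.
correct = np.concatenate(outputs, axis=0)
assert correct.shape == (1731, 4096)
```

The unsqueeze makes the shapes line up, so nothing raises; the corruption only surfaces downstream, and only when more than one tensor is involved.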

The fix

Add tensor_ndim parameter to BridgeCommunicator. For 2D tensors, batch is folded into dim 0, so fan-in/fan-out uses cat/split at dim 0 directly — no unsqueeze needed:

Fan-in:   3 × [577, H]  → cat(dim=0)    → [1731, H]
Fan-out:  [1731, H]      → split(3, dim=0) → 3 × [577, H]

MultiModulePipelineCommunicator accepts module_output_ndim dict to configure each bridge per source module.
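The fixed 2D path can be sketched as follows. This is a shape-level illustration under the PR's stated convention (batch folded into dim 0 when tensor_ndim=2), not the real communicator code:

```python
import numpy as np

H = 4096  # hidden size, as in the example above

# Fan-in (DP=3 -> DP=1): concatenate three [577, H] shards along dim 0.
parts = [np.empty((577, H), dtype=np.float32) for _ in range(3)]
merged = np.concatenate(parts, axis=0)
assert merged.shape == (1731, H)

# Fan-out (DP=1 -> DP=3): split [1731, H] back into three [577, H] shards.
shards = np.split(merged, 3, axis=0)
assert all(s.shape == (577, H) for s in shards)
```

Because batch and sequence are already folded together in dim 0, no unsqueeze/squeeze adapter is needed and the round trip is exact.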

Files changed

  • bridge_communicator.py — tensor_ndim param, _batch_dim property
  • multimodule_communicator.py — module_output_ndim param, passed to bridges
  • test_mimo_1f1b_schedule.py — 4 new 8-GPU asymmetric DP test configs
  • test_bridge_communicator.py — 2 new 2D fan-in/fan-out tests
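The _batch_dim property listed above plausibly reduces to a small dimension-selection rule. A hypothetical sketch (the real property in bridge_communicator.py may differ):

```python
def batch_dim(tensor_ndim: int, dim_mapping: dict) -> int:
    """Pick the dimension that fan-in/fan-out cat/split should operate on."""
    if tensor_ndim == 2:
        return 0                 # 2D [B*S, H]: batch is folded into dim 0
    return dim_mapping['b']      # 3D [S, B, H]: e.g. dim 1 for SBH layout

assert batch_dim(2, {'s': 0, 'b': 1, 'h': 2}) == 0
assert batch_dim(3, {'s': 0, 'b': 1, 'h': 2}) == 1
```

Routing all cat/split calls through one such property is what lets the same fan-in/fan-out code handle both layouts.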

Test plan

  • 8-GPU: Enc DP=4 → LLM TP=2 PP=2 DP=1 (fan-in)
  • 8-GPU: Enc TP=2 PP=2 DP=1 → LLM DP=4 (fan-out)
  • 8-GPU: Enc DP=2 → LLM TP=2 PP=3 DP=1 (fan-in)
  • 8-GPU: Enc TP=2 PP=2, LLM TP=2 PP=2 (symmetric, no regression)
  • 8-GPU: Bridge-level 2D fan-in/fan-out fwd+bwd

🤖 Generated with Claude Code

@yashaswikarnati yashaswikarnati requested review from a team as code owners March 24, 2026 19:52
@copy-pr-bot

copy-pr-bot bot commented Mar 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft March 24, 2026 19:52
@github-actions
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

_prepare_tensor_for_comm() always inserted a singleton dim at position
-1 (the end), regardless of dim_mapping. With SBH format, the bridge
operates on dim_mapping['b']=1, but after unsqueeze(-1), dim 1 is the
hidden dimension, not batch. This caused incorrect cat/split operations
when DP sizes differ between modules (fan-in/fan-out).

Fix: add tensor_ndim parameter to BridgeCommunicator. For 2D tensors
[B*S, H], batch is folded into dim 0, so fan-in/fan-out uses cat/split
at dim 0 directly — no unsqueeze/squeeze needed. Each bridge gets
tensor_ndim from module_output_ndim config in the communicator.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yashaswikarnati yashaswikarnati force-pushed the yash/nmfw-47-dp-fan-in-fix branch from 099b859 to 32422aa Compare March 24, 2026 19:57
@yashaswikarnati
Contributor Author

/claude review

@yashaswikarnati yashaswikarnati changed the title Fix 2D tensor communication for asymmetric DP in MIMO bridge Fix 2D tensor communication for asymmetric DP in Bridge Communicator Mar 24, 2026
Contributor

@claude claude bot left a comment


LGTM

@yashaswikarnati yashaswikarnati marked this pull request as ready for review March 24, 2026 20:03
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team March 24, 2026 20:03
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review PR is in the "final review" stage label Mar 25, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci added Approved All necessary approvals have been made and removed Final Review PR is in the "final review" stage labels Mar 25, 2026
@yashaswikarnati
Contributor Author

/ok to test 32422aa

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Mar 25, 2026
@yashaswikarnati yashaswikarnati added this pull request to the merge queue Mar 25, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23552484202

Merged via the queue into NVIDIA:main with commit 694c3a9 Mar 25, 2026
70 checks passed
@yashaswikarnati yashaswikarnati deleted the yash/nmfw-47-dp-fan-in-fix branch March 25, 2026 17:06

Labels

Approved — All necessary approvals have been made
complexity: low


5 participants