feat: VLM support via megatron backend #1115
Conversation
📝 Walkthrough

Updates the Megatron-Bridge submodule pointer; pins/raises transformers dependency versions; adds a new VLM GRPO Megatron config; changes PackedTensor to allow None entries and adds empty_like; switches PackedTensor aggregation to concat.
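Because several review threads below center on the new PackedTensor semantics, here is a minimal illustrative sketch of those semantics (not the real implementation, which lives in nemo_rl/data/multimodal_utils.py):

```python
# Minimal sketch (not the real class) of the PackedTensor behavior this PR
# introduces, per the walkthrough and review comments below: entries may now
# be None, and empty_like() produces an all-None twin of the same shape.
from typing import Optional

import torch


class PackedTensorSketch:
    def __init__(self, tensors: list[Optional[torch.Tensor]], dim_to_pack: int = 0):
        self.tensors = tensors
        self.dim_to_pack = dim_to_pack

    def as_tensor(self) -> Optional[torch.Tensor]:
        # None entries are skipped; an all-None pack yields None.
        non_none = [t for t in self.tensors if t is not None]
        if not non_none:
            return None
        return torch.cat(non_none, dim=self.dim_to_pack)

    @classmethod
    def empty_like(cls, other: "PackedTensorSketch") -> "PackedTensorSketch":
        # Same number of slots and pack dim, but every entry masked out.
        return cls([None] * len(other.tensors), other.dim_to_pack)
```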
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Trainer
    participant DataLoader
    participant PolicyWorker
    participant MegatronModel
    participant Bridge as Megatron-Bridge
    Trainer->>DataLoader: next() batches (text + multimodal)
    DataLoader-->>PolicyWorker: data_dict
    PolicyWorker->>PolicyWorker: multimodal_data = get_multimodal_dict(as_tensors=True, device=...)
    PolicyWorker->>PolicyWorker: get_ltor_masks_and_position_ids(data=..., eod_token=0, pad_token=0, ...)
    PolicyWorker->>MegatronModel: forward(input_ids=..., position_ids=..., attention_mask=..., packed_seq_params=..., **multimodal_data)
    Note right of MegatronModel: If VLM, unwrap `language_model` to reach decoder
    MegatronModel->>Bridge: execute layers / parallel runtime
    MegatronModel-->>PolicyWorker: logits / losses
    PolicyWorker-->>Trainer: aggregated metrics
```
```mermaid
sequenceDiagram
    autonumber
    participant RunScript as run_vlm_grpo.py
    participant TruncationLogic
    participant vLLM
    RunScript->>TruncationLogic: compute token length
    alt length > max_seq_length
        TruncationLogic->>RunScript: truncate token_ids, mask PackedTensor multimodal fields, set loss_multiplier=0
        RunScript->>vLLM: call with vllm_kwargs (text-only, no images)
    else length <= max_seq_length
        TruncationLogic->>RunScript: keep full dialog & images
        RunScript->>vLLM: call with vllm_kwargs (content + images)
    end
    vLLM-->>RunScript: generation outputs
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
✅ Submodule Fast-Forward Check Results

Check based on commit: 8b47c39 (PR #1115)

Submodules that are properly updated:
- Megatron-Bridge: ✅ PR branch is ahead of main branch (fast-forward)

All submodule changes look good! ✨
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
nemo_rl/models/megatron/common.py (1)
398-485: Device construction bug in broadcast_tensor (runtime error on non-src ranks). torch.device(device) with device as an int will raise; also the created buffers should target the CUDA device explicitly.
Apply:
```diff
-    device = torch.cuda.current_device()
+    device_index = torch.cuda.current_device()
+    device = torch.device("cuda", device_index)
@@
-    if tensor is None:
-        # Create tensor if it wasn't provided by the caller
-        tensor = torch.empty(received_shape, dtype=received_dtype, device=device)
+    if tensor is None:
+        # Create tensor if it wasn't provided by the caller
+        tensor = torch.empty(received_shape, dtype=received_dtype, device=device)
@@
-    if tensor.device != torch.device(device):
+    if tensor.device != device:
         raise ValueError(
             f"Rank {rank}: Provided tensor is on device {tensor.device}, "
             f"but expected broadcast device is {device}."
         )
```

nemo_rl/models/policy/megatron_policy_worker.py (4)
769-778: Incorrect source for padded vocab size (bug). TokenizerConfig doesn't carry padded_vocab_size; use the built tokenizer.
Apply:
```diff
-        self.final_padded_vocab_size = tokenizer_config.padded_vocab_size
+        self.final_padded_vocab_size = self.megatron_tokenizer.padded_vocab_size
```
1168-1191: Align VLM path with DTensor: null position_ids when passing multimodal_data. Prevents conflicts with model-specific positional handling for VL inputs.
Apply:
```diff
-    multimodal_data = data_dict.get_multimodal_dict(as_tensors=True, device=input_ids.device)
+    multimodal_data = data_dict.get_multimodal_dict(as_tensors=True, device=input_ids.device)
+    if len(multimodal_data) > 0:
+        position_ids = None
```
1427-1433: tokens_to_generate is miscomputed (can go negative). max_new_tokens already means "how many to generate." Don't subtract prompt length.
Apply:
```diff
-            tokens_to_generate=self.cfg["generation"]["max_new_tokens"]  # type: ignore
-            - data["input_ids"].size(1),
+            tokens_to_generate=int(self.cfg["generation"]["max_new_tokens"]),  # type: ignore
```

Optionally clamp to engine limits if needed.
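A self-contained sketch of that optional clamp; `engine_max_len` is a placeholder for whatever sequence limit the generation engine actually exposes:

```python
# Hedged sketch of the optional clamp: never request more tokens than the
# engine's sequence budget allows. "engine_max_len" is a placeholder name.
def clamp_tokens_to_generate(max_new_tokens: int, prompt_len: int, engine_max_len: int) -> int:
    return max(0, min(max_new_tokens, engine_max_len - prompt_len))


assert clamp_tokens_to_generate(512, 1000, 1024) == 24  # prompt nearly fills the budget
```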
1825-1835: Severe performance bug: redundant full reload inside loop. The ordinary offload path rebuilds and loads state_dict once per parameter. Build once.
Apply:
```diff
-        else:
-            # Ordinary offload case
-            if move_params:
-                for name, param in model.state_dict().items():
-                    new_state_dict = {}
-                    for name, item in model.state_dict().items():
-                        if isinstance(item, torch.Tensor):
-                            item = item.detach().to(
-                                device=device, non_blocking=True, copy=True
-                            )
-                        new_state_dict[name] = item
-                    model.load_state_dict(new_state_dict)
+        else:
+            # Ordinary offload case
+            if move_params:
+                new_state_dict = {}
+                for name, item in model.state_dict().items():
+                    if isinstance(item, torch.Tensor):
+                        item = item.detach().to(device=device, non_blocking=True, copy=True)
+                    new_state_dict[name] = item
+                model.load_state_dict(new_state_dict)
```
🧹 Nitpick comments (4)
3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1)
1-1: Add a VLM smoke test in CI to catch submodule drift early. Add a minimal test that instantiates the new GRPO VLM config, feeds a tiny dummy batch (image + prompt), and runs a single forward step. This will immediately flag ABI or API mismatches introduced by future Bridge/LM pointer updates. A sketch of such a test follows.
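A hedged skeleton of that smoke test. Only the config path comes from this PR; `load_config`, `build_megatron_policy`, and `forward_step` are hypothetical stand-ins for whatever setup utilities the test suite actually provides:

```python
# Skeleton only: the helper names are hypothetical, not this repo's real API.
import pytest
import torch

CONFIG = "examples/configs/vlm_grpo_3B_megatron.yaml"  # path from this PR


@pytest.mark.skip(reason="skeleton only; wire up the real config/policy helpers")
def test_vlm_grpo_megatron_forward_smoke():
    cfg = load_config(CONFIG)                # hypothetical helper
    policy = build_megatron_policy(cfg)      # hypothetical helper
    batch = {
        "input_ids": torch.randint(0, 100, (1, 16)),
        "pixel_values": torch.rand(1, 3, 28, 28),  # tiny dummy image
    }
    out = policy.forward_step(batch)         # hypothetical API
    assert out is not None                   # one forward step is enough to catch ABI drift
```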
examples/configs/vlm_grpo_3B_megatron.yaml (2)
21-25: Comment is misleading with current settings. You note dynamic batching is incompatible with pipeline parallelism, but pipeline_model_parallel_size is 1 in this config. Consider clarifying the comment or enabling dynamic batching in this specific single-PP setup if desired.
161-161: Trim trailing whitespace. YAML linter flags trailing spaces on this line.
Apply:
```diff
-    kwargs: 
+    kwargs:
```

nemo_rl/models/policy/megatron_policy_worker.py (1)
179-184: Use get_global_rank for PP broadcast source (API compatibility). torch.distributed.get_process_group_ranks may be unavailable across versions. Prefer get_global_rank; optionally keep a fallback.
Apply:
```diff
-    pp_ranks = torch.distributed.get_process_group_ranks(pp_group)
-    global_src = pp_ranks[src_rank]
+    try:
+        global_src = torch.distributed.get_global_rank(pp_group, src_rank)
+    except Exception:
+        # Fallback for older torch; remove if not needed in your env
+        pp_ranks = torch.distributed.get_process_group_ranks(pp_group)
+        global_src = pp_ranks[src_rank]
```

Please confirm the targeted torch (2.7.1) exposes get_global_rank in your runtime.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
uv.lock is excluded by !**/*.lock
📒 Files selected for processing (8)
- 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1 hunk)
- 3rdparty/Megatron-Bridge-workspace/setup.py (1 hunk)
- 3rdparty/Megatron-LM-workspace/Megatron-LM (1 hunk)
- examples/configs/vlm_grpo_3B_megatron.yaml (1 hunk)
- nemo_rl/distributed/batched_data_dict.py (1 hunk)
- nemo_rl/models/megatron/common.py (1 hunk)
- nemo_rl/models/policy/megatron_policy_worker.py (4 hunks)
- pyproject.toml (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (4)
nemo_rl/distributed/batched_data_dict.py (1)
nemo_rl/data/multimodal_utils.py (2)
- PackedTensor (21-132)
- concat (69-99)
examples/configs/vlm_grpo_3B_megatron.yaml (3)
tests/unit/models/generation/test_vllm_generation.py (3)
- test_vllm_megatron_weight_update_with_packing (1777-1831)
- test_vllm_megatron_weight_update_memory (1601-1683)
- get_basic_megatron_test_config (121-206)

nemo_rl/algorithms/grpo.py (3)

- setup (133-408)
- grpo_train (509-968)
- MasterConfig (117-125)

examples/run_grpo_math.py (1)

- main (191-268)
nemo_rl/models/megatron/common.py (1)
nemo_rl/distributed/batched_data_dict.py (1)
get_multimodal_dict(88-99)
nemo_rl/models/policy/megatron_policy_worker.py (1)
nemo_rl/distributed/batched_data_dict.py (1)
get_multimodal_dict(88-99)
🪛 YAMLlint (1.37.1)
examples/configs/vlm_grpo_3B_megatron.yaml
[error] 161-161: trailing spaces
(trailing-spaces)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Lint check
🔇 Additional comments (8)
3rdparty/Megatron-Bridge-workspace/setup.py (1)
36-36: Specs match: no action needed

The `transformers>=4.55.0` requirement is identical in both `setup.py` and `pyproject.toml`, so the consistency check will pass.

nemo_rl/distributed/batched_data_dict.py (1)
129-129: Switching to PackedTensor.concat preserves per-sample slicing semantics — looks good; confirm no callers depended on flattened behavior.

`concat()` keeps underlying segments, making `slice`/`__len__` align with other per-sample keys. Verify any downstream that assumed `flattened_concat()`'s "one tensor per input-batch" invariant; the sketch below makes the difference concrete.
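A pure-torch illustration of the two aggregation styles (PackedTensor itself is not used here, so no API beyond what this review describes is assumed):

```python
import torch

segments_a = [torch.ones(2, 4)]
segments_b = [torch.zeros(3, 4)]

# concat-style: keep the segment list, so sample i is still recoverable
kept = segments_a + segments_b
sample_1 = kept[1]  # the zeros tensor, aligned with other per-sample keys

# flattened_concat-style: fuse along the pack dim, per-sample boundaries are gone
flat = torch.cat(segments_a + segments_b, dim=0)  # shape (5, 4); which rows were sample 1?
```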
28-28: transformers==4.55.4 pin validated
- Release 4.55.4 exists on PyPI and satisfies vLLM’s ≥4.53.2 and huggingface-hub<1.0,≥0.34.0 bounds.
- Torch 2.7.1 plus torchvision≥0.22.0 resolves to 0.22.1 (vLLM’s requirement), matching torch==2.7.1.
- Megatron-Bridge’s megatron-core[dev] extra pulls in transformers without version bounds, so 4.55.4 is acceptable.
examples/configs/vlm_grpo_3B_megatron.yaml (2)
29-35: Sequence packing is disabled; CP requires it. Given context_parallel_size=1 here, that's fine. If CP>1 later, remember to enable sequence_packing (your code enforces this; see the guard sketch after this comment). No change required now, just a heads-up.
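A minimal sketch of such a guard. The config key paths below are illustrative assumptions, not the repo's actual schema:

```python
# Hedged sketch: context parallelism requires sequence packing to be enabled.
def validate_cp_requires_packing(cfg: dict) -> None:
    cp = cfg["policy"]["megatron_cfg"].get("context_parallel_size", 1)          # assumed path
    packing = cfg["policy"].get("sequence_packing", {}).get("enabled", False)   # assumed path
    if cp > 1 and not packing:
        raise ValueError("context_parallel_size > 1 requires sequence_packing.enabled=true")
```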
104-126: vLLM generation for a VL model: confirm image pathway is exercised. The training path adds multimodal kwargs; this vLLM generation config won't pass images unless your data layer and vLLM adapter handle them. If you plan to eval generation with images, ensure the vLLM side accepts/forwards the same multimodal keys.
nemo_rl/models/megatron/common.py (1)
335-343: Switch to keyword args for get_ltor_masks_and_position_ids looks good. Clearer and future-proof against signature changes.
nemo_rl/models/policy/megatron_policy_worker.py (2)
64-66: FSDP import path update is correct. Alias preserved; matches current mcore adapter location.
254-256: Unwrapping language_model for VLM before router tweaks is appropriate. Ensures router freeze/bias maintenance reaches the real module.
Also applies to: 270-272
@yfw there's a conflict with the mbridge submodule
❌ Submodule Fast-Forward Check Failed

Check based on commit: 8c034be (PR #1115)

Submodules that need attention:
- Megatron-Bridge: ❌ Commits have DIVERGED from a common ancestor

Please ensure all submodule commits are fast-forwards of the main branch before merging.
Actionable comments posted: 0
🧹 Nitpick comments (2)
nemo_rl/models/policy/megatron_policy_worker.py (2)
252-260: Safer VLM unwrap; also freeze router bias and guard helper availability

Unwrapping to `language_model` blindly can misfire if the module already exposes a decoder. Also, freezing only `router.weight` may leave `router.bias` trainable unless configs are perfectly set. Finally, `_maintain_float32_expert_bias()` should be called defensively.

Apply this diff to harden both hooks:

```diff
         if isinstance(model_module, Float16Module):
             model_module = model_module.module
-        # Handle VLM models
-        if hasattr(model_module, "language_model"):
-            model_module = model_module.language_model
+        # Handle VLM wrappers: unwrap only if needed
+        if hasattr(model_module, "language_model") and not hasattr(model_module, "decoder"):
+            model_module = model_module.language_model
         for layer in model_module.decoder.layers:
             if hasattr(layer.mlp, "router"):
                 layer.mlp.router.weight.requires_grad = False
+                if hasattr(layer.mlp.router, "bias") and layer.mlp.router.bias is not None:
+                    layer.mlp.router.bias.requires_grad = False
@@
         if isinstance(model_module, Float16Module):
             model_module = model_module.module
-        # Handle VLM models
-        if hasattr(model_module, "language_model"):
-            model_module = model_module.language_model
+        # Handle VLM wrappers: unwrap only if needed
+        if hasattr(model_module, "language_model") and not hasattr(model_module, "decoder"):
+            model_module = model_module.language_model
         for layer in model_module.decoder.layers:
             if hasattr(layer.mlp, "router"):
-                layer.mlp.router._maintain_float32_expert_bias()
+                if hasattr(layer.mlp.router, "_maintain_float32_expert_bias"):
+                    layer.mlp.router._maintain_float32_expert_bias()
```

Verification asks:
- Confirm `moe_router_bias_update_rate` is actually 0.0 in your configs at runtime; otherwise bias may still update via optimizer even if frozen late.

Also applies to: 268-276
1187-1194: Normalize multimodal tensor devices to the model's input device

`get_multimodal_dict(as_tensors=True, device=...)` moves only `PackedTensor` values. Optional multimodal tensors (if any) may remain on CPU and trigger device mismatches.

Apply this diff to align devices and use the sharded input's device:

```diff
-    multimodal_data = data_dict.get_multimodal_dict(as_tensors=True, device=input_ids.device)
+    multimodal_data = data_dict.get_multimodal_dict(
+        as_tensors=True, device=input_ids_cp_sharded.device
+    )
+    # Ensure any stray tensor values are moved as well
+    multimodal_data = {
+        k: (v.to(input_ids_cp_sharded.device) if torch.is_tensor(v) and v.device != input_ids_cp_sharded.device else v)
+        for k, v in multimodal_data.items()
+    }
     output_tensor = model(
         input_ids=input_ids_cp_sharded,
         position_ids=position_ids,
         attention_mask=attention_mask,
         packed_seq_params=packed_seq_params,
         **multimodal_data,
     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
uv.lock is excluded by !**/*.lock
📒 Files selected for processing (5)
- 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1 hunk)
- 3rdparty/Megatron-Bridge-workspace/setup.py (1 hunk)
- examples/configs/vlm_grpo_3B_megatron.yaml (1 hunk)
- nemo_rl/models/policy/megatron_policy_worker.py (3 hunks)
- pyproject.toml (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (4)
- 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
- 3rdparty/Megatron-Bridge-workspace/setup.py
- examples/configs/vlm_grpo_3B_megatron.yaml
- pyproject.toml
🧰 Additional context used
🧬 Code graph analysis (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)
nemo_rl/distributed/batched_data_dict.py (1)
get_multimodal_dict(88-99)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Coverage (doc-test)
- GitHub Check: Coverage (e2e)
Actionable comments posted: 2
🧹 Nitpick comments (3)
nemo_rl/models/generation/vllm/utils.py (1)
60-66: Optional: Prefer text prompt fallback when content exists but no images. If msg is a formatted string and images are absent, using {"prompt": msg} preserves the chat template better than token_ids fallback.
Apply this minimal tweak alongside the guard:
```diff
-        if not images or i >= len(images) or images[i] is None:
-            prompts.append(_get_regular_prompt(i))
+        if not images or i >= len(images) or images[i] is None:
+            prompts.append({"prompt": msg} if isinstance(msg, str) else _get_regular_prompt(i))
             continue
```

nemo_rl/data/multimodal_utils.py (1)
68-72: Side-effect in as_tensor duplicates .to() behavior.Device migration inside as_tensor mutates internal state and can surprise callers. Prefer removing the in-place loop and rely on the return’s .to() or the class’s .to().
No extra changes needed if you apply the previous diff.
examples/run_vlm_grpo.py (1)
215-218: Also strip non-text content in message_log on truncation. Processors or analytics that still read message_log may try to handle image entries. Replace them with text-only to avoid accidental heavy processing.
Apply:
```diff
     for chat_message in message_log:
         chat_message["token_ids"] = chat_message["token_ids"][
             : min(4, max_seq_length // len(message_log))
         ]
+        # Remove non-text content for truncated samples
+        if isinstance(chat_message.get("content"), list):
+            chat_message["content"] = [
+                c if isinstance(c, dict) and c.get("type") == "text" else {"type": "text", "text": ""}
+                for c in chat_message["content"]
+            ]
         for key, value in chat_message.items():
             if isinstance(value, PackedTensor):
                 chat_message[key] = PackedTensor.empty_like(value)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/run_vlm_grpo.py (2 hunks)
- nemo_rl/data/multimodal_utils.py (2 hunks)
- nemo_rl/models/generation/vllm/utils.py (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/run_vlm_grpo.py (1)
nemo_rl/data/multimodal_utils.py (2)
- PackedTensor (21-148)
- empty_like (80-82)
nemo_rl/data/multimodal_utils.py (1)
nemo_rl/distributed/batched_data_dict.py (2)
- to (804-811)
- slice (680-701)
🔇 Additional comments (5)
nemo_rl/data/multimodal_utils.py (2)
79-83: empty_like helper is a good addition. This cleanly supports truncation masking and None-enabled flows.
32-46: Confirm callers or make PackedTensor safely handle None

Accepting Optional[torch.Tensor] is fine, but many downstream sites assume non-None (call .as_tensor(), .to(), torch.cat). Ensure callers never pass None or update PackedTensor.as_tensor()/related ops to skip/handle None entries.
- Key places to check: examples/run_vlm_grpo.py:190-192, nemo_rl/data/llm_message_utils.py:592, nemo_rl/distributed/batched_data_dict.py:95-96 (v.as_tensor(device=...)).
- Also scan other .as_tensor() usages and tests that assert shapes for assumptions about non-None contents.
examples/run_vlm_grpo.py (3)
204-209: Graceful truncation fallback for vLLM kwargs looks good. Clear signal to downstream that this sample should be text-only and ignored for loss.
Please sanity-check that the collator batches vllm_content/vllm_images into per-sample lists (list[str|None], list[list]) as expected.
219-226: Non-truncated vLLM kwargs assembly looks correct. Consistent with utils consuming vllm_content/vllm_images.
234-234: Passing kwargs via expansion keeps output dict tidy. LGTM.
Actionable comments posted: 0
♻️ Duplicate comments (1)
nemo_rl/data/multimodal_utils.py (1)
52-63: Fix: `.to(None)` crash and side effects in `as_tensor`; make it side-effect-free.

- Calling `.to(device)` with `device=None` will raise; this path is hit on the final return.
- The method mutates `self.tensors` when `device is not None`, which is surprising for an accessor.

Apply:

```diff
-    def as_tensor(self, device: Optional[torch.device] = None) -> Optional[torch.Tensor]:
-        if device is not None:
-            # Move only non-None tensors to device, preserve Nones
-            for i, item in enumerate(self.tensors):
-                if item is not None:
-                    self.tensors[i] = item.to(device)
-        non_none_tensors = [t for t in self.tensors if t is not None]
-        if len(non_none_tensors) == 0:
-            return None
-        else:
-            return torch.cat(non_none_tensors, dim=self.dim_to_pack).to(device)
+    def as_tensor(self, device: Optional[str | torch.device] = None) -> Optional[torch.Tensor]:
+        non_none_tensors = [t for t in self.tensors if t is not None]
+        if not non_none_tensors:
+            return None
+        out = torch.cat(non_none_tensors, dim=self.dim_to_pack)
+        return out if device is None else out.to(device)
```

Note the widened type to accept `str | torch.device`, matching common PyTorch usage.
🧹 Nitpick comments (7)
nemo_rl/data/multimodal_utils.py (7)
33-36: Broaden input typing: avoid `list[None]`, accept `Sequence[Optional[Tensor]]`.

`list[None]` is awkward for type checkers and too specific. Prefer `Sequence[Optional[torch.Tensor]]`, while still storing as `list[Optional[torch.Tensor]]`.

```diff
-        tensors: Union[torch.Tensor, list[Optional[torch.Tensor]], list[None]],
+        tensors: Union[torch.Tensor, Sequence[Optional[torch.Tensor]]],
```

Add the missing import:

```python
from typing import Optional, Union, Sequence
```
39-46: Avoid aliasing caller-owned lists; shallow-copy input list. Currently `self.tensors = tensors` shares the caller's list reference and allows external mutation. Make a shallow copy.

```diff
         elif isinstance(tensors, list):
             assert len(tensors) > 0, (
                 "Input tensors to PackedTensor must be a non-empty list"
             )
-            self.tensors: list[Optional[torch.Tensor]] = tensors
+            self.tensors: list[Optional[torch.Tensor]] = list(tensors)
```

Optional: if empty slices are possible, consider allowing an empty list instead of asserting.
47-49: Use `TypeError` for invalid type and keep the message short. Aligns with Ruff TRY004/TRY003 hints.

```diff
-        else:
-            raise ValueError(
-                f"Unsupported type for input tensors to PackedTensor: {type(tensors)}"
-            )
+        else:
+            raise TypeError(f"Unsupported PackedTensor input type: {type(tensors).__name__}")
```
147-148: Guard against `None` outputs from sub-`as_tensor()` in `flattened_concat`.

`as_tensor()` can return `None`; passing `[None, ...]` through is allowed by the ctor but often undesirable. Filter them out.

```diff
-    tensors = [p.as_tensor() for p in from_packed_tensors]
+    tensors = [t for p in from_packed_tensors if (t := p.as_tensor()) is not None]
```

If empty after filtering, decide whether to allow an empty PackedTensor or raise; align with the decision on empty inputs.
159-168: Avoid attribute error if `processor` lacks `tokenizer`. Some processors may not expose `.tokenizer`. Guard before accessing `model_input_names`.

```diff
-    all_keys.difference_update(set(processor.tokenizer.model_input_names))
+    tokenizer_names = set()
+    if hasattr(processor, "tokenizer") and hasattr(processor.tokenizer, "model_input_names"):
+        tokenizer_names = set(processor.tokenizer.model_input_names)
+    all_keys.difference_update(tokenizer_names)
```
171-179: Hard-coded class check for `SmolVLMProcessor`; confirm Qwen2.5-VL needs defaults. If Qwen2.5-VL (via AutoProcessor) requires a non-zero pack dim for any key, capture it here rather than keying on class name. Consider a per-key mapping or a processor-provided hint, as sketched below.
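One possible shape for that suggestion, sketched with assumed key names and a hypothetical processor-provided hint attribute:

```python
# Sketch only: a per-key pack-dim table, optionally overridden by a hint the
# processor could expose, instead of matching on the processor class name.
DEFAULT_PACK_DIMS: dict[str, int] = {
    "pixel_values": 0,      # assumed key name
    "image_grid_thw": 0,    # assumed key name
}


def pack_dim_for(key: str, processor: object) -> int:
    hint = getattr(processor, "pack_dims", None)  # hypothetical hint attribute
    if isinstance(hint, dict) and key in hint:
        return hint[key]
    return DEFAULT_PACK_DIMS.get(key, 0)
```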
32-83: Findings: no string-device calls; as_tensor mutates internal storage — recommend documenting or making it non‑mutating
- Only call that forwards a device into as_tensor is BatchedDataDict.get_multimodal_dict (nemo_rl/distributed/batched_data_dict.py:95) which passes an Optional[torch.device]; I found no callsite that passes a string to as_tensor.
- as_tensor(device=...) performs in-place moves into self.tensors (nemo_rl/data/multimodal_utils.py:52–57). Tests use PackedTensor.to("cuda") when they expect mutation, not as_tensor.
- Action: either make as_tensor non‑mutating (use temporary moved tensors for concat) or explicitly document the side‑effect in PackedTensor.as_tensor docstring. Change target: nemo_rl/data/multimodal_utils.py (as_tensor implementation).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- nemo_rl/data/multimodal_utils.py (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (1)
nemo_rl/data/multimodal_utils.py (1)
nemo_rl/distributed/batched_data_dict.py (2)
- to (804-811)
- slice (680-701)
🪛 Ruff (0.12.2)
nemo_rl/data/multimodal_utils.py
47-49: Prefer TypeError exception for invalid type
(TRY004)
47-49: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: CI quality check
🔇 Additional comments (2)
nemo_rl/data/multimodal_utils.py (2)
68-72: `to()` implementation LGTM. Handles `None` entries correctly and returns `self` for chaining; accepts `str | torch.device`.
79-83: Nice addition: `empty_like`. Useful for masking/truncation flows; consistent length and `dim_to_pack` semantics.
terrykong left a comment:
change lgtm
some things to add before merging:
- nightly test for mbridge path (maybe same config as dtensor so apples to apples)
- could you compare qwen2.5 vl with mcore and dtensor to show convergence remains unchanged in PR description
- confirm from a few nightly runs of different model classes that convergence remains unchanged from the transformers bump?
@NVIDIA-NeMo/rl_reviewers_mcore to review
Signed-off-by: Yi-Fu Wu <[email protected]>
Co-authored-by: Yuki Huang <[email protected]>
What does this PR do?
This PR adds VLM support via megatron backend for the Qwen2.5-VL model. This PR also upgrades transformers to 4.55.4 since the Qwen2.5-VL change in Megatron-Bridge requires 4.55+ and an issue with reward models is fixed in 4.55.4 (#811 (comment)).
Issues

Closes #822
Usage
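A hedged launch sketch: the entrypoint and config come from this PR's changed files, while the exact invocation (the `uv run` wrapper and `--config` flag) is an assumption based on how this repo's other example scripts are typically launched.

```sh
# Assumed invocation; script and config paths are from this PR's file list.
uv run python examples/run_vlm_grpo.py --config examples/configs/vlm_grpo_3B_megatron.yaml
```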
Before your PR is "Ready for review"
Pre checks:
Additional Information