Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,24 @@ def set_current_microbatch(model, microbatch_id):
for layer in model_with_decoder.mtp.layers:
layer.transformer_layer.current_microbatch = microbatch_id

# Also set current_microbatch on vision encoder layers so that
# _te_cuda_graph_replay selects the correct graph index. Without this,
# vision layers always use graph 0 (since current_microbatch defaults to 0),
# causing all microbatch forwards to overwrite the same static buffers.
# When backward runs for earlier microbatches, the buffers contain stale
# data from later forwards, producing NaN gradients.
try:
model_with_vision = get_attr_wrapped_model(
model, "vision_model", allow_none=True, return_model_obj=True
)
except RuntimeError:
model_with_vision = None
if model_with_vision is not None and hasattr(model_with_vision, 'vision_model'):
vision_model = model_with_vision.vision_model
if hasattr(vision_model, 'decoder') and hasattr(vision_model.decoder, 'layers'):
for layer in vision_model.decoder.layers:
layer.current_microbatch = microbatch_id


def forward_step_calc_loss(
model,
Expand Down
Loading