diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index 9cd8bb6a7c8..91ecd884d85 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -79,6 +79,8 @@ def gptmodel_forward_qwen2_5_vl( assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet" pre_process = unwrap_model(model).pre_process post_process = unwrap_model(model).post_process + pixel_values = multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None + image_grid_thw = multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None if pack_seqs: batch_size, seq_len = attention_mask.shape[:2] input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) @@ -88,8 +90,8 @@ def gptmodel_forward_qwen2_5_vl( attention_mask=None, position_ids=position_ids, packed_seq_params=packed_seq_params, - pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device), - image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device), + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, ) if post_process and logits_processor is not None: @@ -105,8 +107,8 @@ def gptmodel_forward_qwen2_5_vl( input_ids=new_input_ids, position_ids=new_position_ids, attention_mask=new_attention_mask, - pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device), - image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device), + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, ) output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process) if value_model and post_process: diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index af40d78e1d8..7028f6e2ac2 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -409,7 +409,9 @@ def forward_step(batch_iter, model): multi_modal_inputs = {} if "multi_modal_inputs" in batch: for key in batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat([batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0) + idxs = batch["multi_modal_inputs_idx"] + mmi = batch["multi_modal_inputs"] + multi_modal_inputs[key] = torch.cat([mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0) responses = batch["responses"] response_length = responses.size(1) label = copy.deepcopy(position_ids)