diff --git a/verl/utils/model.py b/verl/utils/model.py
index e0c275e8c43..15fdecd62da 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -696,6 +696,47 @@ def get_hf_auto_model_class(hf_config):
     return actor_module_class
 
 
+def extract_multi_modal_inputs(
+    batch_data: list[dict[str, torch.Tensor]],
+    indices: Optional[list[int]] = None,
+) -> dict[str, torch.Tensor | list[torch.Tensor]]:
+    """
+    Extract and process multi-modal inputs from a batch.
+
+    Args:
+        batch_data (list[dict[str, torch.Tensor]]): The batch containing potential multi-modal inputs
+        indices (Optional[list[int]]): If provided, only extract inputs at these indices
+
+    Returns:
+        dict[str, torch.Tensor | list[torch.Tensor]]: Processed multi-modal inputs ready for model consumption
+
+    """
+    multi_modal_inputs = {}
+    multi_modal_inputs_collected = {}
+    has_image_bound = False
+
+    selected_batch_data = batch_data
+    if indices is not None:
+        selected_batch_data = [batch_data[i] for i in indices if i < len(batch_data)]
+
+    for inputs in selected_batch_data:
+        if "image_bound" in inputs:
+            has_image_bound = True
+        for key, value in inputs.items():
+            if value is not None:
+                if key not in multi_modal_inputs_collected:
+                    multi_modal_inputs_collected[key] = []
+                multi_modal_inputs_collected[key].append(value)
+
+    for key, values in multi_modal_inputs_collected.items():
+        if has_image_bound:  # minicpm-o logic
+            multi_modal_inputs[key] = values
+        else:
+            multi_modal_inputs[key] = torch.cat(values, dim=0)
+
+    return multi_modal_inputs
+
+
 @dataclass
 class CausalLMOutputForPPO(CausalLMOutputWithPast):
     log_probs: Optional[torch.FloatTensor] = None
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 32caab3ef10..99184725112 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -98,14 +98,9 @@ def _forward_micro_batch(
         response_length = micro_batch["responses"].size(-1)
         multi_modal_inputs = {}
         if "multi_modal_inputs" in micro_batch.keys():
-            if "image_bound" in micro_batch["multi_modal_inputs"][0]:  # minicpm-o logic
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]]
-            else:
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                    )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])
 
         with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
             input_ids = micro_batch["input_ids"]
diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
index fb16922e4d2..b3c2a52a9b7 100644
--- a/verl/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -466,12 +466,10 @@ def forward_step(batch_iter, model):
 
             multi_modal_inputs = {}
             if "multi_modal_inputs" in batch:
-                for key in batch["multi_modal_inputs"][0].keys():
-                    idxs = batch["multi_modal_inputs_idx"]
-                    mmi = batch["multi_modal_inputs"]
-                    multi_modal_inputs[key] = torch.cat(
-                        [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0
-                    )
+                from verl.utils.model import extract_multi_modal_inputs
+
+                indices = batch.get("multi_modal_inputs_idx", None)
+                multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
             responses = batch["responses"]
             response_length = responses.size(1)
             label = position_ids.clone()
diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py
index e45f6f2048d..3f7314c93be 100644
--- a/verl/workers/critic/dp_critic.py
+++ b/verl/workers/critic/dp_critic.py
@@ -58,10 +58,9 @@ def _forward_micro_batch(self, micro_batch):
         response_length = micro_batch["responses"].size(-1)
         multi_modal_inputs = {}
         if "multi_modal_inputs" in micro_batch.keys():
-            for key in micro_batch["multi_modal_inputs"][0].keys():
-                multi_modal_inputs[key] = torch.cat(
-                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])
 
         with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
             input_ids = micro_batch["input_ids"]
diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py
index 46430d26741..d3548211dda 100644
--- a/verl/workers/engine/fsdp/transformer_impl.py
+++ b/verl/workers/engine/fsdp/transformer_impl.py
@@ -644,14 +644,9 @@ def prepare_model_inputs(self, micro_batch: TensorDict):
 
         multi_modal_inputs = {}
         if "multi_modal_inputs" in micro_batch.keys():
-            if "image_bound" in micro_batch["multi_modal_inputs"][0]:  # minicpm-o logic
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]]
-            else:
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                    )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])
 
         input_ids = micro_batch["input_ids"]
         attention_mask = micro_batch["attention_mask"]
diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 42df0da4ffc..9d6ad2f2134 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -165,19 +165,14 @@ def _build_megatron_module(self):
         return module
 
     def _build_optimizer(self):
-        from verl.utils.megatron.optimizer import (
-            get_megatron_optimizer,
-            init_megatron_optim_config,
-        )
+        from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config
 
         optim_config_megatron = init_megatron_optim_config(self.optimizer_config)
         optimizer = get_megatron_optimizer(model=self.module, config=optim_config_megatron)
         return optimizer
 
     def _build_lr_scheduler(self):
-        from verl.utils.megatron.optimizer import (
-            get_megatron_optimizer_param_scheduler,
-        )
+        from verl.utils.megatron.optimizer import get_megatron_optimizer_param_scheduler
 
         optimizer_scheduler = get_megatron_optimizer_param_scheduler(
             optimizer=self.optimizer, config=self.optimizer_config
@@ -495,13 +490,11 @@ def prepare_model_inputs(self, batch: TensorDict):
         ]  # mcore patch recompute qwen2vl's pos ids during forward
 
         multi_modal_inputs = {}
-        if "multi_modal_inputs" in batch.keys():
-            for key in batch["multi_modal_inputs"][0].keys():
-                idxs = batch["multi_modal_inputs_idx"]
-                mmi = batch["multi_modal_inputs"]
-                multi_modal_inputs[key] = torch.cat(
-                    [mmi[idx].get(key).to(input_ids.device) for idx in idxs if mmi[idx].get(key) is not None], dim=0
-                )
+        if "multi_modal_inputs" in batch:
+            from verl.utils.model import extract_multi_modal_inputs
+
+            indices = batch.get("multi_modal_inputs_idx", None)
+            multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
 
         return {
             "input_ids": input_ids,
diff --git a/verl/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py
index 3457a530764..5a27bf45e47 100644
--- a/verl/workers/reward_model/megatron/reward_model.py
+++ b/verl/workers/reward_model/megatron/reward_model.py
@@ -281,11 +281,10 @@ def forward_step(batch_iter, model):
 
             multi_modal_inputs = {}
             if "multi_modal_inputs" in batch:
-                for key in batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0
-                    )
+                from verl.utils.model import extract_multi_modal_inputs
 
+                indices = batch.get("multi_modal_inputs_idx", None)
+                multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
             output = forward_fn(
                 model,
                 input_ids,