41 changes: 41 additions & 0 deletions verl/utils/model.py
@@ -696,6 +696,47 @@ def get_hf_auto_model_class(hf_config):
     return actor_module_class


+def extract_multi_modal_inputs(
+    batch_data: list[dict[str, torch.Tensor]],
+    indices: Optional[list[int]] = None,
+) -> dict[str, torch.Tensor | list[torch.Tensor]]:
+    """
+    Extract and process multi-modal inputs from a batch.
+
+    Args:
+        batch_data (list[dict[str, torch.Tensor]]): The batch containing potential multi-modal inputs
+        indices (Optional[list[int]]): If provided, only extract inputs at these indices
+
+    Returns:
+        dict[str, torch.Tensor | list[torch.Tensor]]: Processed multi-modal inputs ready for model consumption
+
+    """
+    multi_modal_inputs = {}
+    multi_modal_inputs_collected = {}
+    has_image_bound = False
+
+    selected_batch_data = batch_data
+    if indices is not None:
+        selected_batch_data = [batch_data[i] for i in indices if i < len(batch_data)]
+
+    for inputs in selected_batch_data:
+        if "image_bound" in inputs:
+            has_image_bound = True
+        for key, value in inputs.items():
+            if value is not None:
+                if key not in multi_modal_inputs_collected:
+                    multi_modal_inputs_collected[key] = []
+                multi_modal_inputs_collected[key].append(value)
+
+    for key, values in multi_modal_inputs_collected.items():
+        if has_image_bound:  # minicpm-o logic
+            multi_modal_inputs[key] = values
+        else:
+            multi_modal_inputs[key] = torch.cat(values, dim=0)
+
+    return multi_modal_inputs
+
+
 @dataclass
 class CausalLMOutputForPPO(CausalLMOutputWithPast):
     log_probs: Optional[torch.FloatTensor] = None
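For reference, a minimal sketch of how the new helper behaves; the keys and tensor shapes below are illustrative assumptions, not values taken from verl's data pipeline:

```python
import torch

from verl.utils.model import extract_multi_modal_inputs

# Hypothetical batch: two samples, each carrying a "pixel_values" tensor.
batch = [
    {"pixel_values": torch.zeros(1, 3, 224, 224)},
    {"pixel_values": torch.zeros(1, 3, 224, 224)},
]

# Default path: per-key values are concatenated along dim 0.
out = extract_multi_modal_inputs(batch)
assert out["pixel_values"].shape == (2, 3, 224, 224)

# If any sample contains "image_bound" (the minicpm-o case), values are
# returned as lists instead of being concatenated.
batch_minicpm = [
    {"pixel_values": torch.zeros(1, 3, 224, 224), "image_bound": torch.tensor([[0, 4]])},
]
out = extract_multi_modal_inputs(batch_minicpm)
assert isinstance(out["pixel_values"], list)
```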
11 changes: 3 additions & 8 deletions verl/workers/actor/dp_actor.py
@@ -98,14 +98,9 @@ def _forward_micro_batch(
         response_length = micro_batch["responses"].size(-1)
         multi_modal_inputs = {}
         if "multi_modal_inputs" in micro_batch.keys():
-            if "image_bound" in micro_batch["multi_modal_inputs"][0]:  # minicpm-o logic
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]]
-            else:
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                    )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])

         with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
             input_ids = micro_batch["input_ids"]
10 changes: 4 additions & 6 deletions verl/workers/actor/megatron_actor.py
@@ -466,12 +466,10 @@ def forward_step(batch_iter, model):

             multi_modal_inputs = {}
             if "multi_modal_inputs" in batch:
-                for key in batch["multi_modal_inputs"][0].keys():
-                    idxs = batch["multi_modal_inputs_idx"]
-                    mmi = batch["multi_modal_inputs"]
-                    multi_modal_inputs[key] = torch.cat(
-                        [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0
-                    )
+                from verl.utils.model import extract_multi_modal_inputs
+
+                indices = batch.get("multi_modal_inputs_idx", None)
+                multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
             responses = batch["responses"]
             response_length = responses.size(1)
             label = position_ids.clone()
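A short sketch of the `indices` path used by the Megatron workers (shapes and keys are again illustrative assumptions): `multi_modal_inputs_idx` selects which batch entries to gather, and out-of-range indices are skipped before concatenation.

```python
import torch

from verl.utils.model import extract_multi_modal_inputs

batch_mm = [
    {"pixel_values": torch.zeros(1, 3, 224, 224)},
    {"pixel_values": torch.zeros(1, 3, 224, 224)},
    {"pixel_values": torch.zeros(1, 3, 224, 224)},
]

# Gather only entries 0 and 2; the out-of-range index 99 is silently dropped.
out = extract_multi_modal_inputs(batch_mm, indices=[0, 2, 99])
assert out["pixel_values"].shape == (2, 3, 224, 224)
```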
7 changes: 3 additions & 4 deletions verl/workers/critic/dp_critic.py
@@ -58,10 +58,9 @@ def _forward_micro_batch(self, micro_batch):
         response_length = micro_batch["responses"].size(-1)
         multi_modal_inputs = {}
         if "multi_modal_inputs" in micro_batch.keys():
-            for key in micro_batch["multi_modal_inputs"][0].keys():
-                multi_modal_inputs[key] = torch.cat(
-                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])

         with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
             input_ids = micro_batch["input_ids"]
11 changes: 3 additions & 8 deletions verl/workers/engine/fsdp/transformer_impl.py
@@ -644,14 +644,9 @@ def prepare_model_inputs(self, micro_batch: TensorDict):

        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch.keys():
-            if "image_bound" in micro_batch["multi_modal_inputs"][0]:  # minicpm-o logic
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]]
-            else:
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                    )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])

        input_ids = micro_batch["input_ids"]
        attention_mask = micro_batch["attention_mask"]
21 changes: 7 additions & 14 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -165,19 +165,14 @@ def _build_megatron_module(self):
         return module

     def _build_optimizer(self):
-        from verl.utils.megatron.optimizer import (
-            get_megatron_optimizer,
-            init_megatron_optim_config,
-        )
+        from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config

         optim_config_megatron = init_megatron_optim_config(self.optimizer_config)
         optimizer = get_megatron_optimizer(model=self.module, config=optim_config_megatron)
         return optimizer

     def _build_lr_scheduler(self):
-        from verl.utils.megatron.optimizer import (
-            get_megatron_optimizer_param_scheduler,
-        )
+        from verl.utils.megatron.optimizer import get_megatron_optimizer_param_scheduler

         optimizer_scheduler = get_megatron_optimizer_param_scheduler(
             optimizer=self.optimizer, config=self.optimizer_config
@@ -495,13 +490,11 @@ def prepare_model_inputs(self, batch: TensorDict):
         ]  # mcore patch recompute qwen2vl's pos ids during forward

         multi_modal_inputs = {}
-        if "multi_modal_inputs" in batch.keys():
-            for key in batch["multi_modal_inputs"][0].keys():
-                idxs = batch["multi_modal_inputs_idx"]
-                mmi = batch["multi_modal_inputs"]
-                multi_modal_inputs[key] = torch.cat(
-                    [mmi[idx].get(key).to(input_ids.device) for idx in idxs if mmi[idx].get(key) is not None], dim=0
-                )
+        if "multi_modal_inputs" in batch:
+            from verl.utils.model import extract_multi_modal_inputs
+
+            indices = batch.get("multi_modal_inputs_idx", None)
+            multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)

         return {
             "input_ids": input_ids,
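One behavioral detail in this hunk: the removed Megatron code moved each tensor to `input_ids.device` while gathering, whereas the shared helper leaves device placement to the caller. A sketch of how a caller could restore that behavior if needed (illustrative only, not part of this diff; the helper name `move_multi_modal_inputs_to_device` is hypothetical):

```python
import torch


def move_multi_modal_inputs_to_device(multi_modal_inputs: dict, device: torch.device) -> dict:
    """Move gathered multi-modal tensors to `device`; list-valued entries
    (the minicpm-o case) are moved element-wise."""
    moved = {}
    for key, value in multi_modal_inputs.items():
        if isinstance(value, list):
            moved[key] = [v.to(device) for v in value]
        else:
            moved[key] = value.to(device)
    return moved
```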
7 changes: 3 additions & 4 deletions verl/workers/reward_model/megatron/reward_model.py
@@ -281,11 +281,10 @@ def forward_step(batch_iter, model):

             multi_modal_inputs = {}
             if "multi_modal_inputs" in batch:
-                for key in batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0
-                    )
+                from verl.utils.model import extract_multi_modal_inputs

+                indices = batch.get("multi_modal_inputs_idx", None)
+                multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
             output = forward_fn(
                 model,
                 input_ids,