diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 6f8b6d46..9ecd43ac 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -176,7 +176,7 @@ def _get_input_embeds( video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - if model.training and pixel_values is None and pixel_values_videos is None: + if pixel_values is None and pixel_values_videos is None: pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device) image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device) image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)