Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ class MolmoImageInputs(TensorSchema):
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
]

feat_is_patch: Annotated[
image_input_idx: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# A boolean mask indicating which image features correspond to patch tokens.
# An index tensor that maps image features to their corresponding patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")]


Expand Down Expand Up @@ -1177,7 +1177,7 @@ def __call__(
num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch)

outputs["feat_is_patch"] = feat_is_patch
outputs["image_input_idx"] = image_input_idx
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id

Expand Down Expand Up @@ -1211,8 +1211,9 @@ def get_num_image_tokens(
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h

extra = image_token_length_w * image_token_length_h
joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
# Total tokens: 2 for the start/end markers, plus (w + 1) * h for the
# h rows of w patch tokens, each row followed by one column-separator token.
extra = 2 + (image_token_length_w + 1) * image_token_length_h
joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size)

return extra + joint

Expand Down Expand Up @@ -1299,7 +1300,7 @@ def _get_mm_fields_config(
return dict(
images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
Expand Down Expand Up @@ -1444,7 +1445,7 @@ def _parse_and_validate_image_input(
) -> Optional[MolmoImageInputs]:
images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None)
feat_is_patch = kwargs.pop("feat_is_patch", None)
image_input_idx = kwargs.pop("image_input_idx", None)
num_crops = kwargs.pop("num_crops", None)

if images is None:
Expand All @@ -1466,7 +1467,7 @@ def _parse_and_validate_image_input(
return MolmoImageInputs(
images=images,
image_masks=image_masks,
feat_is_patch=feat_is_patch,
image_input_idx=image_input_idx,
num_crops=num_crops,
)

Expand All @@ -1476,15 +1477,15 @@ def _process_image_input(
) -> list[torch.Tensor]:
images = image_input["images"]
image_masks = image_input["image_masks"]
feat_is_patch = image_input["feat_is_patch"]
image_input_idx = image_input["image_input_idx"]
num_crops = image_input["num_crops"]

# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (
None if image_masks is None else flatten_bn(image_masks, concat=True)
)
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_input_idx_flat = flatten_bn(image_input_idx, concat=True)

image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
Expand All @@ -1494,13 +1495,18 @@ def _process_image_input(
).squeeze(0)

# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch]
for feats, f_is_patch in zip(
image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
]
# Re-order the features using the image_input_idx tensor
results = []
num_crops_list = num_crops.tolist()
for feats, img_idx in zip(
image_features_flat.split(num_crops_list),
image_input_idx_flat.split(num_crops_list),
):
is_valid = img_idx >= 0
valid_img_idx = img_idx[is_valid]
order = torch.argsort(valid_img_idx)
results.append(feats[is_valid][order])
return results

def get_language_model(self) -> torch.nn.Module:
    """Return the module stored in ``self.model`` as the language model."""
    language_model = self.model
    return language_model
Expand Down