Skip to content

Commit 67461fb

Browse files
committed
feat: Implement get_image_features method in Aria, Mistral3, and VipLlava models with updated parameters
1 parent 9fe8078 commit 67461fb

File tree

6 files changed

+40
-18
lines changed

6 files changed

+40
-18
lines changed

src/transformers/models/aria/modeling_aria.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,15 +1228,13 @@ def get_decoder(self):
12281228
def get_image_features(
12291229
self,
12301230
pixel_values: torch.FloatTensor,
1231-
vision_feature_layer: Optional[Union[int, List[int]]] = None,
1232-
vision_feature_select_strategy: Optional[str] = None,
1233-
**kwargs,
1231+
pixel_mask: Optional[torch.FloatTensor] = None,
1232+
vision_feature_layer: int = -1,
12341233
):
12351234
return self.model.get_image_features(
12361235
pixel_values=pixel_values,
1236+
pixel_mask=pixel_mask,
12371237
vision_feature_layer=vision_feature_layer,
1238-
vision_feature_select_strategy=vision_feature_select_strategy,
1239-
**kwargs,
12401238
)
12411239

12421240
# Make modules available throught conditional class for BC

src/transformers/models/aria/modular_aria.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,18 @@ def forward(
14971497
"""
14981498
)
14991499
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
1500+
def get_image_features(
1501+
self,
1502+
pixel_values: torch.FloatTensor,
1503+
pixel_mask: Optional[torch.FloatTensor] = None,
1504+
vision_feature_layer: int = -1,
1505+
):
1506+
return self.model.get_image_features(
1507+
pixel_values=pixel_values,
1508+
pixel_mask=pixel_mask,
1509+
vision_feature_layer=vision_feature_layer,
1510+
)
1511+
15001512
@can_return_tuple
15011513
@auto_docstring
15021514
def forward(

src/transformers/models/mistral3/modeling_mistral3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,14 +415,14 @@ def get_decoder(self):
415415
def get_image_features(
416416
self,
417417
pixel_values: torch.FloatTensor,
418+
image_sizes: torch.Tensor,
418419
vision_feature_layer: Optional[Union[int, List[int]]] = None,
419-
vision_feature_select_strategy: Optional[str] = None,
420420
**kwargs,
421421
):
422422
return self.model.get_image_features(
423423
pixel_values=pixel_values,
424+
image_sizes=image_sizes,
424425
vision_feature_layer=vision_feature_layer,
425-
vision_feature_select_strategy=vision_feature_select_strategy,
426426
**kwargs,
427427
)
428428

src/transformers/models/mistral3/modular_mistral3.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,20 @@ def forward(
254254

255255

256256
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
257+
def get_image_features(
258+
self,
259+
pixel_values: torch.FloatTensor,
260+
image_sizes: torch.Tensor,
261+
vision_feature_layer: Optional[Union[int, List[int]]] = None,
262+
**kwargs,
263+
):
264+
return self.model.get_image_features(
265+
pixel_values=pixel_values,
266+
image_sizes=image_sizes,
267+
vision_feature_layer=vision_feature_layer,
268+
**kwargs,
269+
)
270+
257271
def forward(
258272
self,
259273
input_ids: torch.LongTensor = None,

src/transformers/models/vipllava/modeling_vipllava.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -333,18 +333,9 @@ def get_decoder(self):
333333
return self.model
334334

335335
def get_image_features(
336-
self,
337-
pixel_values: torch.FloatTensor,
338-
vision_feature_layer: Optional[Union[int, List[int]]] = None,
339-
vision_feature_select_strategy: Optional[str] = None,
340-
**kwargs,
336+
self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, List[int]]] = None
341337
):
342-
return self.model.get_image_features(
343-
pixel_values=pixel_values,
344-
vision_feature_layer=vision_feature_layer,
345-
vision_feature_select_strategy=vision_feature_select_strategy,
346-
**kwargs,
347-
)
338+
return self.model.get_image_features(pixel_values=pixel_values, vision_feature_layers=vision_feature_layers)
348339

349340
# Make modules available throught conditional class for BC
350341
@property

src/transformers/models/vipllava/modular_vipllava.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ def forward(
184184

185185

186186
class VipLlavaForConditionalGeneration(LlavaForConditionalGeneration):
187+
def get_image_features(
188+
self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, List[int]]] = None
189+
):
190+
return self.model.get_image_features(
191+
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
192+
)
193+
187194
def forward(
188195
self,
189196
input_ids: torch.LongTensor = None,

0 commit comments

Comments
 (0)