@@ -65,12 +65,6 @@ def __init__(self, wte=None) -> None:
         self.type_feature: str
         self.img_processor: CLIPVisionModel

-    def set_img_features(self, img_features: torch.FloatTensor) -> None:
-        self.img_features = img_features
-
-    def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
-        self.img_sizes = img_sizes
-
     def get_img_features(self,
                          img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        LAYER_IDX = self.layer_idx
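The deleted setters staged image features and sizes as mutable state on the module between calls; after this hunk that state is gone and every image input travels through forward(). A minimal sketch of the call-site difference, assuming a hypothetical `embedding` module with the signatures shown in this diff:

    # old pattern (removed): image state is staged on the module first,
    # so two concurrent requests could clobber each other's images
    def embed_old(embedding, input_ids, pixel_values, img_features, img_sizes):
        embedding.set_img_features(img_features)
        embedding.set_img_sizes(img_sizes)
        return embedding(input_ids, pixel_values)

    # new pattern: stateless, everything arrives as arguments
    def embed_new(embedding, input_ids, pixel_values, image_sizes):
        return embedding(input_ids, pixel_values, image_sizes)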
@@ -144,21 +138,16 @@ def __init__(self,
         self.layer_idx = config.img_processor.get('layer_idx', -2)
         self.type_feature = config.img_processor.get('type_feature', 'patch')

-    def forward(self,
-                input_ids: torch.LongTensor,
+    def forward(self, input_ids: torch.LongTensor,
                 pixel_values: torch.FloatTensor,
-                image_sizes=None) -> torch.FloatTensor:
+                image_sizes: torch.Tensor) -> torch.FloatTensor:
         """process and merge text embeddings with image embeddings."""

+        # (batch_size, max_num_crops, 3, height, width)
         img_embeds = pixel_values
-        img_sizes = image_sizes

-        if self.img_features is not None:
-            img_embeds = self.img_features.clone()
-            self.img_features = None
-
-        if self.img_sizes is not None:
-            img_sizes = self.img_sizes
+        # (batch_size, 2)
+        img_sizes = image_sizes

         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
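With this hunk, image_sizes changes from an optional argument to a required tensor, and the new shape comments document the expected layout. A hedged sketch of inputs matching those shapes (the concrete values below are illustrative, not from the source):

    import torch

    batch_size, max_num_crops = 2, 5
    # (batch_size, max_num_crops, 3, height, width)
    pixel_values = torch.randn(batch_size, max_num_crops, 3, 336, 336)
    # (batch_size, 2): one (height, width) pair per image, here assumed
    # already padded to multiples of the 336-pixel crop size
    image_sizes = torch.tensor([[672, 1008], [336, 672]])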
@@ -190,11 +179,8 @@ def forward(self,
         output_imgs = []
         output_len = []

-        if isinstance(img_sizes, torch.Tensor):
-            img_sizes.squeeze_(0)
-
         for _bs in range(bs):
-            h, w = img_sizes
+            h, w = img_sizes[_bs]
             h = h // 336
             w = w // 336
             B_ = h * w
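Because img_sizes now keeps its (batch_size, 2) shape instead of being squeezed to a single pair, each loop iteration reads its own (h, w), so batches with more than one image index correctly. A worked example of the per-image crop arithmetic above (values assumed):

    # an image padded to 672 x 1008 splits into a 2 x 3 grid of 336-px crops
    h, w = 672, 1008
    h = h // 336   # -> 2
    w = w // 336   # -> 3
    B_ = h * w     # -> 6 crops for this image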