File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
torchvision/prototype/models Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -202,7 +202,7 @@ def _init_weights(self):
202202 nn .init .zeros_ (self .heads .head .weight )
203203 nn .init .zeros_ (self .heads .head .bias )
204204
205- def forward (self , x : torch .Tensor ):
205+ def _process_input (self , x : torch .Tensor ) -> torch . Tensor :
206206 n , c , h , w = x .shape
207207 p = self .patch_size
208208 torch ._assert (h == self .image_size , "Wrong image height!" )
@@ -221,7 +221,14 @@ def forward(self, x: torch.Tensor):
221221 # embedding dimension
222222 x = x .permute (0 , 2 , 1 )
223223
224- # Expand the class token to the full batch.
224+ return x
225+
226+ def forward (self , x : torch .Tensor ):
227+ # Reshaping and permuting the input tensor
228+ x = self ._process_input (x )
229+ n = x .shape [0 ]
230+
231+ # Expand the class token to the full batch
225232 batch_class_token = self .class_token .expand (n , - 1 , - 1 )
226233 x = torch .cat ([batch_class_token , x ], dim = 1 )
227234
You can’t perform that action at this time.
0 commit comments