diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index 7008b04ec82f..9d2e2fcf9882 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -2201,7 +2201,7 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su pos_emb = self.output_position_encodings(batch_size) elif self.position_encoding_type == "fourier": pos_emb = self.output_position_encodings( - self.output_index_dims, batch_size=batch_size, device=inputs.device, pos=pos + self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos ) # Optionally project them to a target dimension. @@ -2215,7 +2215,9 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su if self.position_encoding_type == "trainable": pos_emb = self.output_position_encodings(batch_size) elif self.position_encoding_type == "fourier": - pos_emb = self.output_position_encodings(index_dims, batch_size, device=inputs.device) + pos_emb = self.output_position_encodings( + index_dims, batch_size, device=inputs.device, dtype=inputs.dtype + ) # Optionally project them to a target dimension. pos_emb = self.positions_projection(pos_emb) @@ -2816,7 +2818,12 @@ def output_size(self): return encoding_size def forward( - self, index_dims: List[int], batch_size: int, device, pos: torch.FloatTensor = None + self, + index_dims: List[int], + batch_size: int, + device: torch.device, + dtype: torch.dtype, + pos: torch.FloatTensor = None, ) -> torch.FloatTensor: pos = _check_or_build_spatial_positions(pos, index_dims, batch_size) fourier_pos_enc = generate_fourier_features( @@ -2825,7 +2832,7 @@ def forward( max_resolution=self.max_resolution, concat_pos=self.concat_pos, sine_only=self.sine_only, - ).to(device) + ).to(device=device, dtype=dtype) return fourier_pos_enc @@ -3156,7 +3163,7 @@ def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool if self.position_encoding_type == "trainable": pos_enc = self.position_embeddings(batch_size) elif self.position_encoding_type == "fourier": - pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device) + pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype) # Optionally project them to a target dimension. pos_enc = self.positions_projection(pos_enc) @@ -3324,7 +3331,7 @@ def _build_network_inputs(self, inputs): if self.position_encoding_type == "trainable": pos_enc = self.position_embeddings(batch_size) elif self.position_encoding_type == "fourier": - pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device) + pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype) # Optionally project them to a target dimension. pos_enc = self.positions_projection(pos_enc)