diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b7a87f99a179..84e01c7d4ea9 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1667,9 +1667,9 @@ def forward(self, hidden_states): """ assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) - tensor_shape = input_tensors[0].shape + tensor_shape = input_tensors[0].shape[chunk_dim] assert all( - input_tensor.shape == tensor_shape for input_tensor in input_tensors + input_tensor.shape[chunk_dim] == tensor_shape for input_tensor in input_tensors ), "All input tenors have to be of the same shape" # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability