diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 91cc97c95c1e..4c2f20f040b2 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -156,7 +156,7 @@ def _expand_inputs_for_generation(
         if is_encoder_decoder:
             assert encoder_outputs is not None
             encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
-                0, expanded_return_idx
+                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
             )
             model_kwargs["encoder_outputs"] = encoder_outputs
         return input_ids, model_kwargs
@@ -226,7 +226,7 @@ def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[t
         For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
         subclasses of :class:`~transformers.PreTrainedModel`.
         """
-        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
+        return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)

     def _get_logits_warper(
         self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 26fc6be67228..57b421c61085 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -1166,6 +1166,34 @@ def cast_to_device(dictionary, device):
                     for value_, parallel_value_ in zip(value, parallel_value):
                         self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))

+    @require_torch_multi_gpu
+    def test_model_parallel_beam_search(self):
+        if not self.test_model_parallel:
+            return
+
+        all_generative_and_parallelizable_model_classes = tuple(
+            set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
+        )
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in all_generative_and_parallelizable_model_classes:
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config)
+
+            def cast_to_device(dictionary, device):
+                output = {}
+                for k, v in dictionary.items():
+                    if isinstance(v, torch.Tensor):
+                        output[k] = v.to(device)
+                    else:
+                        output[k] = v
+
+                return output
+
+            model.parallelize()
+            model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
+

 global_rng = random.Random()
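
For context: `torch.index_select` requires the index tensor and the tensor being indexed to live on the same device, and under model parallelism the beam indices produced during generation can sit on a different GPU than a given layer's cached tensors, which is what the `.to(...)` calls in this diff address. The snippet below is a minimal standalone sketch of that pattern, not code from the patch; the helper name `reorder_layer_past` and the tensor shapes are illustrative only.

```python
# Minimal sketch (not part of the patch) of the device mismatch this diff resolves.
# Under model parallelism, cached tensors such as `past_key_values` may live on a
# different GPU than the beam indices; `index_select` needs both on the same device.
import torch


def reorder_layer_past(layer_past: torch.Tensor, beam_idx: torch.Tensor) -> torch.Tensor:
    # Mirrors the fixed `_reorder_cache` logic: move the index tensor to the
    # device of the tensor being indexed before calling `index_select`.
    return layer_past.index_select(1, beam_idx.to(layer_past.device))


if __name__ == "__main__":
    layer_past = torch.randn(2, 4, 3)      # (2, batch * beams, hidden) -- illustrative shape
    beam_idx = torch.tensor([0, 0, 2, 3])  # indices chosen by beam search, created on CPU
    if torch.cuda.is_available():
        layer_past = layer_past.to("cuda:0")  # simulate a layer placed on another device
    print(reorder_layer_past(layer_past, beam_idx).shape)
```

Without the `.to(layer_past.device)` cast, the `index_select` call raises a device-mismatch error as soon as the indexed tensor has been moved off the device holding the beam indices, which is exactly the failure the new `test_model_parallel_beam_search` test exercises by calling `model.parallelize()` before `model.generate(..., num_beams=2)`.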