
Conversation

@OyvindTafjord
Contributor

What does this PR do?

As requested by @patrickvonplaten in conversation on issue #9200, this fixes a crash when trying to use beam search on T5 models split across multiple GPUs using model.parallelize(). It uses the fix from #9219, applied to the T5-specific code (also related is #9596 which refactored the _reorder_cache functions).
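Concretely, the crash happens because beam_idx lives on the first GPU while the cached layer_past_state tensors for layers placed on other GPUs by model.parallelize() do not, so index_select complains about mismatched devices. A minimal sketch of the device-aware reordering the fix applies (the helper name is just for illustration, not the exact diff):

def reorder_layer_past(layer_past_states, beam_idx):
    # Move beam_idx onto each cached state's device before indexing, so cached
    # key/value states that model.parallelize() placed on other GPUs can still
    # be gathered along the beam dimension.
    return tuple(
        state.index_select(0, beam_idx.to(state.device))
        for state in layer_past_states
    )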

I tested the fix on a t5-small model. Before:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  
tokenizer = AutoTokenizer.from_pretrained("allenai/unifiedqa-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/unifiedqa-t5-small")
device_map = {0: range(0, 3), 1: range(3, 6)}
input_string = "What was the color of the sky?\\nIt was a dark stormy night."
input_ids = tokenizer.encode(input_string, return_tensors="pt").to("cuda:0")
output = model.generate(input_ids, num_beams=2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/oyvindt/miniconda3/envs/transformers4/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/oyvindt/miniconda3/envs/transformers4/lib/python3.9/site-packages/transformers/generation_utils.py", line 1044, in generate
    return self.beam_search(
  File "/home/oyvindt/miniconda3/envs/transformers4/lib/python3.9/site-packages/transformers/generation_utils.py", line 1788, in beam_search
    model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
  File "/home/oyvindt/miniconda3/envs/transformers4/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1635, in _reorder_cache
    layer_past_state.index_select(0, beam_idx),
RuntimeError: Input, output and indices must be on the current device

After:

...
output = model.generate(input_ids, num_beams=2)
tokenizer.batch_decode(output, skip_special_tokens=True)
 --> ['dark stormy']

As far as I know, this small fix shouldn't have any adverse effects. As for why the tests added in #9219 didn't catch this, it's possibly because they aren't generally run in multi-GPU setups?

@patrickvonplaten merged commit bd3b599 into huggingface:master on May 14, 2021
@bing0037

@OyvindTafjord Hi, I am trying to figure out how to use model parallelization on T5, but I'm having some problems. I tried to reproduce your result but got the following error:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  
tokenizer = AutoTokenizer.from_pretrained("allenai/unifiedqa-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/unifiedqa-t5-small")
device_map = {0: range(0, 3), 1: range(3, 6)}
model.parallelize(device_map)
input_string = "What was the color of the sky?\\nIt was a dark stormy night."
input_ids = tokenizer.encode(input_string, return_tensors="pt").to("cuda:0")

output = model.generate(input_ids, num_beams=2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/transformers/generation_utils.py", line 922, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/transformers/generation_utils.py", line 417, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/transformers/models/t5/modeling_t5.py", line 897, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 114, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/torch/nn/functional.py", line 1724, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1587428091666/work/aten/src/THC/generic/THCTensorIndex.cu:403
  • My current environment:
    transformers: 4.7.0.dev0
    torch: 1.5.0

Could you please help me figure out the problem and give me some direction on where I should start? I don't have much experience with model parallelization; do I need to modify the input_ids?

Thanks in advance.

@OyvindTafjord
Contributor Author

@bing0037 Hm, I tested with 4.7.0 now and the above code works for me. I noticed my initial set of commands was missing the critical model.parallelize(device_map) step, but it looks like you made sure to include that?

You could double-check that model.encoder.first_device returns the expected 'cuda:0'; the code at https://github.com/huggingface/transformers/blob/master/src/transformers/models/t5/modeling_t5.py#L870 should then make sure the embeddings end up on that same device, so you shouldn't hit the line 897 error above.
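For reference, a quick sanity check along those lines could look like this (reusing the tokenizer, device_map and input_string from the snippet above):

model.parallelize(device_map)
print(model.device_map)            # device -> encoder/decoder block indices assigned to it
print(model.encoder.first_device)  # expect 'cuda:0' for the device_map above

# Put the inputs on the encoder's first device rather than hard-coding cuda:0
input_ids = tokenizer.encode(input_string, return_tensors="pt").to(model.encoder.first_device)
output = model.generate(input_ids, num_beams=2)
print(tokenizer.batch_decode(output, skip_special_tokens=True))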

@bing0037

@OyvindTafjord Thank you for your reply. The problem was an inconsistency in my command; the command above works well.
By the way, the above command covers parallelized model inference. Could you please give me some suggestions for parallelized model training?
Currently, I am trying to fine-tune the t5-large model with run_summarization.py on multiple GPUs using model parallelization.

  • My test 1: I added model.parallelize() directly in run_summarization.py, but got the following error:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    

+    device_map = {0: [0, 1, 2],
+                  1: [3, 4, 5, 6, 7, 8, 9],
+                  3: [10, 11, 12, 13, 14, 15, 16],
+                  4: [17, 18, 19, 20, 21, 22, 23]}
+    model.parallelize(device_map)  # Splits the model across several devices

    model.resize_token_embeddings(len(tokenizer))

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
Traceback (most recent call last):
  File "run_summarization.py", line 616, in <module>
    main()
  File "run_summarization.py", line 540, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/transformers/trainer.py", line 1300, in train
    args.max_grad_norm,
  File "/home/guest/anaconda3/envs/huggingface_latest/lib/python3.6/site-packages/torch/nn/utils/clip_grad.py", line 30, in clip_grad_norm_
    total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
RuntimeError: All input tensors must be on the same device. Received cuda:0 and cuda:7
  • My test 2: I also installed the SageMaker SDK to check whether SageMaker model parallelism is enabled, but it isn't:
    pip install git+https://github.com/aws/sagemaker-python-sdk.git
    pip install sagemaker
    >>> from transformers.file_utils import is_sagemaker_mp_enabled
    >>> is_sagemaker_mp_enabled()
    False

Could you give me some resources that I could refer to? Thank you!

@OyvindTafjord
Contributor Author

@bing0037 I haven't tried the parallelize functionality in the context of training, so I'm not of much help on that.

Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request Jul 15, 2021
amyeroberts pushed a commit that referenced this pull request Jun 15, 2023
* Fix LLaMa beam search when using parallelize

same issue as T5 #11717

* fix code format in modeling_llama.py

* fix format of _reorder_cache in modeling_llama.py