When I use the following code on a TPU VM and use model.generate() for inference, the speed is very slow. It seems that the TPU is not used. What is the problem? #20794

Closed
joytianya opened this issue Dec 16, 2022 · 28 comments

Comments

@joytianya

System Info

When I use the following code on a TPU VM and use model.generate() for inference, the speed is very slow. It seems that the TPU is not used. What is the problem?
The JAX TPU devices do exist:

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
assert "TPU" in device_type

from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer
model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
input_context = "The dog"
# encode input context
input_ids = tokenizer(input_context, return_tensors="np").input_ids
# generate candidates using sampling
outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
print(outputs)

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
assert "TPU" in device_type

from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer
model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
input_context = "The dog"
# encode input context
input_ids = tokenizer(input_context, return_tensors="np").input_ids
# generate candidates using sampling
outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
print(outputs)

Expected behavior

Expect it to be fast

@sgugger
Collaborator

sgugger commented Dec 16, 2022

cc @gante and @sanchit-gandhi

@sanchit-gandhi
Contributor

sanchit-gandhi commented Dec 21, 2022

Hey @joytianya! Sorry about the late reply here! Cool to see that you're using the Flax MT5 model!

The big speed-up from using JAX on TPU comes from JIT compiling a function: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html. It's worth reading this guide to get a feel for how JAX + XLA + TPU work in combination to give you fast kernel execution.

I've written an ipynb notebook that demonstrates how you can JIT compile the generate method: https://github.com/sanchit-gandhi/codesnippets/blob/main/benchmark_flaxmt5_jit_generate.ipynb

Running this using a 'tiny' version of the Flax MT5 model on CPU, I get a 75x speed-up JIT compiling the generate function vs the vanilla generate function! That's fast right!

You can adapt the script for the mt5-small checkpoint as you require 🤗 You'll need to pass any additional args that use boolean control flow in the generate method under static_argnames (as done with max_length, top_k, do_sample).
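
For reference, here's a minimal sketch of that pattern, using the mt5-small checkpoint from your snippet (a sketch rather than the exact notebook code):

import jax
from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer

model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

# args that change Python-level control flow inside generate must be marked static
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "top_k", "do_sample"])

input_ids = tokenizer("The dog", return_tensors="np").input_ids

# the first call compiles (slow); repeated calls with the same input shape reuse the compiled program
outputs = jit_generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True).sequences
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))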

Let me know if you have any other questions, happy to help!

@joytianya
Author

joytianya commented Dec 22, 2022

Thank you very much for your reply. I tried it, and it is indeed effective.
In addition, it reports OOM on a TPU v3-8 when using MT5-XXL. Do you have any suggestions so that I can run MT5-XXL inference on a v3-8 TPU?

jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 320.03M at the bottom of memory. That was not possible. There are 1.20G free, 0B reserved, and 196.31M reservable. If fragmentation is eliminated, the maximum reservable bytes would be 1.20G, so compaction will enable this reservation.  The nearest obstacle is at 196.31M from the bottom with size 160.00M.

@sanchit-gandhi
Contributor

sanchit-gandhi commented Dec 22, 2022

Hey @joytianya! Glad to hear that JIT'ing the generate function worked well!

The MT5-XXL checkpoint is 13 billion params (roughly 52 GB in float32) - this is pretty significant! We have to get pretty advanced to fit such a big model on a single TPU v3-8.

There are two things that you can try:

  1. Half-precision inference: set the computation dtype and model parameters to bfloat16 (half) precision. This will save a significant amount of memory vs float32 (full) precision and should give you near-equivalent results
  2. Model partitioning: use pjit for model parallelism

1 is quite straightforward! 2 is very involved 😅. Let's start with 1!

Here's a code snippet on how you can achieve 1: https://github.com/sanchit-gandhi/codesnippets/blob/main/flaxmt5_inference_half_precision.ipynb
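
The gist of that notebook is roughly the following (a sketch, shown here with the google/mt5-xxl checkpoint; see the notebook for the full version):

import jax.numpy as jnp
from transformers import FlaxMT5ForConditionalGeneration

# run the forward pass in bfloat16 ...
model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-xxl", dtype=jnp.bfloat16)
# ... and cast the stored parameters to bfloat16 too, halving the memory needed for the weights
model.params = model.to_bf16(model.params)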

For pjit, you'll need to modify the code for Flax MT5 to add the sharding annotations. You can see an example for Flax BLOOM here: https://github.com/huggingface/bloom-jax-inference/blob/2a04aa519d262729d54adef3d19d63879f81ea89/bloom_inference/modeling_bloom/modeling_bloom.py#L200-L202 This is pretty advanced stuff! I can explain how it works a bit more if you really need to use pjit.

Best of luck! Hope these answers provide some pointers as to how you can fit the XXL model on a v3-8!

@huggingface huggingface deleted a comment from github-actions bot Jan 16, 2023
@sanchit-gandhi
Contributor

sanchit-gandhi commented Jan 16, 2023

One other thing I forgot! If you're running inference on batches of data, using pmap for data parallelism across TPU devices is by far your best shout.

You can do this easily using the example script run_clm_flax.py with the --do_eval flag. This example wraps up the model loading, data loading and data parallelisation using pmap into one script, so you can run it using a single command:

python run_clm_flax.py \
    --output_dir="./eval-out" \
    --model_name_or_path="google/mt5-small" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --do_eval \
    --per_device_eval_batch_size="64" \
    --overwrite_output_dir

Currently, the evaluation step will only return the eval loss. You can modify it to also return the logits to get the actual predictions as well:

logits = model(**batch, params=params, train=False)[0]
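
Illustratively, something along these lines turns the logits into greedy per-position token predictions (variable names follow the line above; this is not the script's exact code):

logits = model(**batch, params=params, train=False)[0]
pred_ids = logits.argmax(axis=-1)  # greedy prediction for each position
predictions = tokenizer.batch_decode(pred_ids.tolist(), skip_special_tokens=True)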

If nothing else, you can use the run_clm_flax.py script as an example of how we can pmap to effectively parallelise across TPU devices.

@joytianya
Author

Great! Thank you very much for your suggestion. I will try it next.

@sanchit-gandhi
Contributor

Put together a quick codesnippet that isolates pmap: https://github.com/sanchit-gandhi/codesnippets/blob/main/pmap_flaxmt5_generate.ipynb

This doesn't require any optimiser initialisation so should be much more memory efficient than using the previous suggestion of run_clm_flax.py.
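
For anyone following along, that notebook boils down to roughly this pattern (a sketch assuming the mt5-small checkpoint; see the notebook for the full version):

import jax
from flax import jax_utils
from flax.training.common_utils import shard
from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer

model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

def generate(params, batch):
    # generation args are fixed inside the function, so nothing needs to be marked static
    return model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=32, do_sample=True, top_k=30, params=params).sequences

p_generate = jax.pmap(generate, "batch")
params = jax_utils.replicate(model.params)  # one copy of the weights per TPU device

# the total batch size must be divisible by the number of devices
input_context = ["The dog is"] * jax.device_count()
batch = tokenizer(input_context, return_tensors="np", padding=True)
batch = shard(batch.data)  # reshape each array to (num_devices, batch_per_device, seq_len)

sequences = p_generate(params, batch)  # first call compiles; later calls with the same shapes are fast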

@joytianya
Author

OK, does this method also support the XXL model on a TPU v3-8?

@sanchit-gandhi
Contributor

sanchit-gandhi commented Jan 18, 2023

The methodology remains the same for any checkpoint. As to whether the XXL model fits in memory you'll have to experiment for yourself! Definitely worth trying converting the model params to half-precision and running the computations in bf16 for this size model (as done in this code snippet: https://github.com/sanchit-gandhi/codesnippets/blob/main/flaxmt5_inference_half_precision.ipynb)

@joytianya
Author

OK, I am very grateful for your suggestion. I plan to try it and experiment further.

@joytianya
Author

When I load this model, "ClueAI/ChatYuan-large-v1", the following error occurs. How can I solve this problem?

Some weights of the model checkpoint at ClueAI/ChatYuan-large-v1 were not used when initializing FlaxT5ForConditionalGeneration: {('decoder', 'embed_tokens', 'kernel'), ('encoder', 'embed_tokens', 'kernel')}
- This IS expected if you are initializing FlaxT5ForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxT5ForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-16-6177c268ed70> in <module>
      1 model_name = "ClueAI/ChatYuan-large-v1"
      2 #model, params = FlaxMT5ForConditionalGeneration.from_pretrained(model_name, _do_init=False)
----> 3 model, params = FlaxT5ForConditionalGeneration.from_pretrained(model_name, from_pt=True)
      4 
      5 tokenizer = T5Tokenizer.from_pretrained(model_name)

TypeError: cannot unpack non-iterable FlaxT5ForConditionalGeneration object
model, params = FlaxT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True)

@sanchit-gandhi
Contributor

Hey @joytianya! It's not possible to use from_pt=True with _do_init=False. Currently, you need to load PyTorch weights with _do_init=True:

model = FlaxT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True)
params = model.params

Or directly load Flax weights if they are saved in the repo. If you want to load the model instance and weights separately, you can set _do_init=False (see #16148 (comment)):

model, params = FlaxT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1",  _do_init=False)

@github-actions

github-actions bot commented Mar 6, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@joytianya
Author

#16148 (comment)
When I try this, an error occurs. How can I solve this problem?

model, params = FlaxT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1",  _do_init=False)
OSError: ClueAI/ChatYuan-large-v1 does not appear to have a file named flax_model.msgpack but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those weights.

@joytianya
Author

I tried it. How do I configure "max_length", "top_k", "do_sample" and other parameters with this?

https://github.com/sanchit-gandhi/codesnippets/blob/main/pmap_flaxmt5_generate.ipynb

@joytianya
Author

outputs = jit_generate(input_ids=input_ids, max_new_tokens=512, top_k=30, do_sample=True, temperature=0.7).sequences
I found that the generated sequences always have length max_new_tokens.
Can generation terminate when the end-of-sequence token is reached, so as to save time?
What should I do?

@joytianya
Author

I found that the results of each run are the same, even though do_sample=True. How can I configure it to generate randomly?

@joytianya
Author

Hi @sanchit-gandhi, I look forward to your reply.

@sanchit-gandhi
Contributor

sanchit-gandhi commented Mar 8, 2023

Hey @joytianya! Answering your questions sequentially:

  1. _do_init=False is only supported when we directly load Flax weights. The error message we're getting is telling us that the model only has PyTorch weights available. Let's first load the model in PyTorch on CPU, save it as a Flax model, then re-load it on TPU:
import jax
from transformers import FlaxMT5ForConditionalGeneration

SAVE_DIR = "/path/to/save/dir"  # change this to where you want the model to be saved

with jax.default_device(jax.devices("cpu")[0]):
    model = FlaxMT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True)
    model.save_pretrained(SAVE_DIR)

Now the next time you load the model, you can do so with _do_init=False and the default TPU device:

model, params = FlaxT5ForConditionalGeneration.from_pretrained(SAVE_DIR, _do_init=False)

  2. Can you try using static_broadcasted_argnums and passing the argument indices of the variables you want to control:

pmap_generate = jax.pmap(model.generate, "batch", static_broadcasted_argnums=[<PUT A LIST OF THE ARGNUMS YOU WANT TO PASS>])

See https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html for details.

  3. Can generation terminate when the end-of-sequence token is reached, so as to save time?

The model will stop generating when the EOS token is reached. Make sure you have configured your tokenizer correctly: https://huggingface.co/docs/transformers/model_doc/mt5#transformers.T5Tokenizer

  4. I found that the results of each run are the same, even though do_sample=True. How can I configure it to generate randomly?

Do you have a codesnippet you could share that demonstrates this? Thanks!

@joytianya
Author

To explain problems 3 and 4 in detail, I wrote this code and ran it.
For 4: the result of each generation is exactly the same.
For 3: the time varies a lot with max_length and is roughly proportional to it, so generation doesn't seem to end early.

from transformers import T5Tokenizer, FlaxMT5ForConditionalGeneration
import jax
model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small", from_pt=True)
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
# vanilla generate -> JIT generate 
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "top_k", "do_sample"])


def answer(max_length):
    input_context = ["The dog is", "The cat is"]
    input_ids = tokenizer(input_context, return_tensors="np").input_ids
    outputs = jit_generate(input_ids=input_ids, max_length=max_length, top_k=30, do_sample=True).sequences
    res = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    print(outputs)
    print(res)
    return res

answer(20)

import time
start_time = time.time()
for i in range(10):
    answer(20)
print(time.time() - start_time)


answer(1024)

import time
start_time = time.time()
for i in range(10):
    answer(1024)
print(time.time() - start_time)

from transformers import T5Tokenizer, FlaxMT5ForConditionalGeneration
import jax
import jax.numpy as jnp
model = FlaxMT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True, dtype=jnp.bfloat16)
model.params = model.to_bf16(model.params)
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
# copy (replicate) the params across your TPU devices
#params = jax_utils.replicate(params)
# JIT generate (pmap would instead replicate across our JAX devices)
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "max_new_tokens", "top_k", "do_sample", "temperature", "eos_token_id"])

def answer(max_length):
    input_context = ["The dog is", "The cat is"]
    input_ids = tokenizer(input_context, return_tensors="np").input_ids
    outputs = jit_generate(input_ids=input_ids, max_length=max_length, top_k=30, do_sample=True).sequences
    res = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    print(outputs)
    print(res)
    return res

answer(256)

import time
start_time = time.time()
for i in range(10):
    answer(256)
print(time.time() - start_time)


answer(1024)

import time
start_time = time.time()
for i in range(10):
    answer(1024)
print(time.time() - start_time)

@joytianya
Author

For 2, is this correct?

pmap_generate = jax.pmap(model.generate, "batch", static_broadcasted_argnums=[2, 3, 4, 5, 6])
outputs = pmap_generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, top_k=30, do_sample=True, temperature=0.7, params=params).sequences

This error occurs:

 outputs = pmap_generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, top_k=30, do_sample=True, temperature=0.7, params=params).sequences
ValueError: pmapped function has static_broadcasted_argnums=(2, 3, 4, 5, 6) but was called with only 1 positional argument. All static broadcasted arguments must be passed positionally.

@joytianya
Author

Hi @sanchit-gandhi, I look forward to your reply.

@sanchit-gandhi
Contributor

Hey @joytianya,

If you don't need to change the generation params passed to .generate between calls, you can just fix them like this:

from flax import jax_utils
from flax.training.common_utils import shard

def generate(params, batch):
    outputs = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_new_tokens=128, top_k=30, do_sample=True, temperature=0.7, params=params).sequences  # anything that does not depend on `batch` is fixed
    return outputs

p_generate = jax.pmap(generate, "batch")
params = jax_utils.replicate(model.params)  # one copy of the params per TPU device

input_context = ["The dog is" for _ in range(8)]  # batch size needs to be a multiple of the number of TPU devices

batch = tokenizer(input_context, return_tensors="np")
batch = shard(batch.data)  # reshape each array to (num_devices, batch_per_device, seq_len)

# slow - we're compiling
outputs = p_generate(params, batch)

# fast!
outputs = p_generate(params, batch)

@joytianya
Author

To explain problems 3 and 4 in detail, I wrote the code in my earlier comment above and ran it.

For 4: the result of each generation is exactly the same.

For 3: the time varies a lot with max_length and is roughly proportional to it, so generation doesn't seem to end early.

Is this phenomenon correct?

@sanchit-gandhi
Contributor

Hey @joytianya

The result of each generation is exactly the same

We can't really rely on the outputs of the model since it's only been pre-trained, not fine-tuned, so it's bound to output gibberish regardless of what we give it (see https://huggingface.co/google/mt5-small for details). You can try using a fine-tuned checkpoint if you want to look at the actual token predictions.

The time varies a lot with max_length and is roughly proportional to it, so generation doesn't seem to end early

This is because the model has only been pre-trained (not fine-tuned): it never hits the end-of-sequence token, so it generates random outputs until it reaches max length. Therefore, it always generates to max length and never terminates early. If you increase max length, the model generates more tokens, so decoding takes longer.

@joytianya
Author

Hey @sanchit-gandhi,

  1. I did try a fine-tuned checkpoint, ClueAI/ChatYuan-large-v1, and the phenomenon is the same. I used sampling (do_sample=True). With the same code, the results of each run are different on GPU, but on TPU they are still identical.
  2. Additionally, you can see that the length of the generated sentences is much smaller than the max number of tokens, so the end-of-sequence token should already have been hit.
    Hope you can give it a try. The code is the same ChatYuan-large-v1 snippet as in my earlier comment above.

@sanchit-gandhi
Contributor

Hey @joytianya - if running this on a GPU gives one answer and running it on a TPU another, I'm not really sure this is a transformers-based issue but probably a JAX or Flax one.

Could you try re-running the code-snippet under the highest JAX matmul precision? We should then get equivalence on CPU/GPU/TPU. See #15754 (comment) for details.
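
For example, one way to do that is to raise the default matmul precision before running the snippet (using the jit_generate and input_ids from your snippet above; a sketch rather than the exact code from the linked comment):

import jax

# force the highest matmul precision globally (TPUs otherwise use lower-precision matmul passes for float32)
jax.config.update("jax_default_matmul_precision", "float32")

# or scope it to the generate call only
with jax.default_matmul_precision("float32"):
    outputs = jit_generate(input_ids=input_ids, max_length=256, top_k=30, do_sample=True).sequences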

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
