When I use the following code on a TPU VM and use model.generate() to infer, the speed is very slow. It seems that the TPU is not being used. What is the problem? #20794
cc @gante and @sanchit-gandhi
Hey @joytianya! Sorry about the late reply here! Cool to see that you're using the Flax MT5 model!

The big speed-up from using JAX on TPU comes from JIT compiling a function: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html. It's worth reading this guide to get a feel for how JAX + XLA + TPU work in combination to give you fast kernel execution.

I've written an ipynb notebook that demonstrates how you can JIT compile the generate method: https://github.com/sanchit-gandhi/codesnippets/blob/main/benchmark_flaxmt5_jit_generate.ipynb Running this using a 'tiny' version of the Flax MT5 model on CPU, I get a 75x speed-up from JIT compiling the generate function vs the vanilla generate function! That's fast, right! You can adapt the script for your own checkpoint.

Let me know if you have any other questions, happy to help!
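For reference, a minimal sketch of the JIT'd-generate pattern described above (this is not taken from the linked notebook; the checkpoint, the from_pt conversion, and max_length are illustrative):

```python
import jax
from transformers import T5Tokenizer, FlaxMT5ForConditionalGeneration

model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small", from_pt=True)
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

# max_length (and any other generation flags) must be static so XLA can compile a fixed-shape program
jit_generate = jax.jit(model.generate, static_argnames=["max_length"])

input_ids = tokenizer(["The dog is"], return_tensors="np").input_ids
outputs = jit_generate(input_ids=input_ids, max_length=32).sequences  # first call traces + compiles (slow)
outputs = jit_generate(input_ids=input_ids, max_length=32).sequences  # later calls reuse the compiled program (fast)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```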
Thank you very much for your reply. I tried it and it is indeed effective. However, I then ran into the following out-of-memory error:

jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 320.03M at the bottom of memory. That was not possible. There are 1.20G free, 0B reserved, and 196.31M reservable. If fragmentation is eliminated, the maximum reservable bytes would be 1.20G, so compaction will enable this reservation. The nearest obstacle is at 196.31M from the bottom with size 160.00M.
Hey @joytianya! Glad to hear that JIT'ing the generate function worked well! The MT5-XXL checkpoint is 13 billion params (roughly 52GB in float32) - this is pretty significant! We have to get pretty advanced to fit such a big model on a single TPU v3-8. There are two things that you can try:

1. Convert the model params to half precision (bfloat16) and run the computation in bf16.
2. Shard the model across the TPU devices with pjit (model parallelism).

1 is quite straightforward! 2 is very involved 😅. Let's start with 1! Here's a code snippet on how you can achieve 1: https://github.com/sanchit-gandhi/codesnippets/blob/main/flaxmt5_inference_half_precision.ipynb

For pjit, you'll need to modify the code for Flax MT5 to add the sharding annotations. You can see an example for Flax BLOOM here: https://github.com/huggingface/bloom-jax-inference/blob/2a04aa519d262729d54adef3d19d63879f81ea89/bloom_inference/modeling_bloom/modeling_bloom.py#L200-L202 This is pretty advanced stuff! I can explain how it works a bit more if you really need to use pjit.

Best of luck! Hope these answers provide some pointers as to how you can fit the XXL model on a v3-8!
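A minimal sketch of option 1, along the lines of the linked half-precision notebook (the XXL checkpoint name here is an assumption for illustration):

```python
import jax.numpy as jnp
from transformers import FlaxMT5ForConditionalGeneration

# load with bfloat16 as the computation dtype, then also cast the weights to bf16,
# roughly halving the memory needed to hold the params
model = FlaxMT5ForConditionalGeneration.from_pretrained(
    "google/mt5-xxl", from_pt=True, dtype=jnp.bfloat16
)
model.params = model.to_bf16(model.params)
```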
One other thing I forgot! If you're running inference on batches of data, using pmap to parallelise the computation across your TPU devices can give you a further speed-up. You can do this easily by adapting the example script run_clm_flax.py.

Currently, the evaluation step in that script only returns the eval loss. You can modify it to also return the logits to get the actual predictions as well.

If nothing else, you can use the run_clm_flax.py script as an example of how we can pmap to effectively parallelise across TPU devices.
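As a rough sketch of that modification: a pmapped eval step that returns the logits alongside the loss. This is an approximation, not the actual run_clm_flax.py code; `model` is assumed to be the Flax model loaded by the script, and the shifted cross-entropy stands in for the script's causal-LM loss.

```python
import jax
import optax

def eval_step(params, batch):
    labels = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]
    # shifted next-token cross-entropy, as in a causal-LM eval loss
    loss = optax.softmax_cross_entropy(
        logits[..., :-1, :], jax.nn.one_hot(labels[..., 1:], logits.shape[-1])
    ).mean()
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return metrics, logits  # also return the per-device logits so predictions can be decoded

p_eval_step = jax.pmap(eval_step, axis_name="batch")
```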
Great! Thank you very much for your suggestion. I will try it next.
Put together a quick code snippet that isolates the pmapped inference step. This doesn't require any optimiser initialisation, so it should be much more memory efficient than the previous suggestion of run_clm_flax.py.
OK, does this method also support the XXL model on the TPU v3-8?
The methodology remains the same for any checkpoint. As to whether the XXL model fits in memory, you'll have to experiment for yourself! It's definitely worth converting the model params to half precision and running the computations in bf16 for a model of this size (as done in this code snippet: https://github.com/sanchit-gandhi/codesnippets/blob/main/flaxmt5_inference_half_precision.ipynb)
OK, I am very grateful for your suggestion. I plan to try it and experiment further.
When I load the model "ClueAI/ChatYuan-large-v1", the following error occurs. How can I solve this problem?

Some weights of the model checkpoint at ClueAI/ChatYuan-large-v1 were not used when initializing FlaxT5ForConditionalGeneration: {('decoder', 'embed_tokens', 'kernel'), ('encoder', 'embed_tokens', 'kernel')}
- This IS expected if you are initializing FlaxT5ForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxT5ForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-16-6177c268ed70> in <module>
1 model_name = "ClueAI/ChatYuan-large-v1"
2 #model, params = FlaxMT5ForConditionalGeneration.from_pretrained(model_name, _do_init=False)
----> 3 model, params = FlaxT5ForConditionalGeneration.from_pretrained(model_name, from_pt=True)
4
5 tokenizer = T5Tokenizer.from_pretrained(model_name)
TypeError: cannot unpack non-iterable FlaxT5ForConditionalGeneration object
model, params = FlaxT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True)
Hey @joytianya! It's not possible to unpack (model, params) from from_pretrained unless you pass _do_init=False; with from_pt=True alone, only the model is returned. You can do:

model = FlaxT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True)
params = model.params

Or directly load Flax weights if they are saved in the repo. If you want to load the model instance and weights separately, you can set _do_init=False:

model, params = FlaxT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", _do_init=False)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
model, params = FlaxT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", _do_init=False)

OSError: ClueAI/ChatYuan-large-v1 does not appear to have a file named flax_model.msgpack but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those weights.
I tried it. How do I configure "max_length", "top_k", "do_sample" and other parameters with this?

outputs = jit_generate(input_ids=input_ids, max_new_tokens=512, top_k=30, do_sample=True, temperature=0.7).sequences
I found that the results of each run are the same, even though do_sample=True. How do I configure it so that generation is actually random?

Hi @sanchit-gandhi, I look forward to your reply.
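On the randomness question: in Flax, sampling is driven by an explicit PRNG key, so the same key reproduces the same output on every call. A minimal sketch, reusing jit_generate and input_ids from the snippet above and passing a fresh key to generate's prng_key argument each call:

```python
import jax

rng = jax.random.PRNGKey(0)
for _ in range(3):
    # split off a fresh key per call so the sampled tokens differ between runs
    rng, sample_rng = jax.random.split(rng)
    outputs = jit_generate(
        input_ids=input_ids, max_new_tokens=512, top_k=30, do_sample=True,
        temperature=0.7, prng_key=sample_rng,
    ).sequences
```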
Hey @joytianya! Answering your questions sequentially:
import jax
from transformers import FlaxMT5ForConditionalGeneration
SAVE_DIR = "/path/to/save/dir" # change this to where you want the model to be saved
with jax.default_device(jax.devices("cpu")[0]):
model = FlaxMT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True)
model.save_pretrained(SAVE_DIR)

Now the next time you load the model, you can do so with:

model, params = FlaxT5ForConditionalGeneration.from_pretrained(SAVE_DIR, _do_init=False)

To configure the generation params with pmap, you can mark them as static arguments:

pmap_generate = jax.pmap(model.generate, "batch", static_broadcasted_argnums=[<PUT A LIST OF THE ARGNUMS YOU WANT TO PASS>])

See https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html for details.
The model will stop generating when the EOS token is reached. Make sure you have configured your tokenizer correctly: https://huggingface.co/docs/transformers/model_doc/mt5#transformers.T5Tokenizer
Do you have a code snippet you could share that demonstrates this? Thanks!
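A quick way to sanity-check the EOS setup mentioned above, reusing model and input_ids from the earlier snippets (the checkpoint name is illustrative):

```python
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
print(tokenizer.eos_token, tokenizer.eos_token_id)  # for T5/MT5 tokenizers this is "</s>"

# generation only terminates early if an eos_token_id is set (it defaults to the model config's value)
outputs = model.generate(input_ids, eos_token_id=tokenizer.eos_token_id, max_length=64).sequences
```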
In order to explain problems 3 and 4 in detail, I wrote this code and ran it:

from transformers import T5Tokenizer, FlaxMT5ForConditionalGeneration
import jax
model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small", from_pt=True)
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
# vanilla generate -> JIT generate
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "top_k", "do_sample"])
def answer(max_length):
input_context = ["The dog is", "The cat is"]
input_ids = tokenizer(input_context, return_tensors="np").input_ids
outputs = jit_generate(input_ids=input_ids, max_length=max_length, top_k=30, do_sample=True).sequences
res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(outputs)
print(res)
return res
answer(20)

import time
start_time = time.time()
for i in range(10):
    answer(20)
print(time.time() - start_time)

answer(1024)

import time
start_time = time.time()
for i in range(10):
    answer(1024)
print(time.time() - start_time)

from transformers import T5Tokenizer, FlaxMT5ForConditionalGeneration
import jax
import jax.numpy as jnp
model = FlaxMT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True, dtype=jnp.bfloat16)
model.params = model.to_bf16(model.params)
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
# copy (replicate) the params across your TPU devices
#params = jax_utils.replicate(params)
# pmap generate (like jit, but replicated across our JAX devices)
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "max_new_tokens", "top_k", "do_sample", "temperature", "eos_token_id"])
def answer(max_length):
input_context = ["The dog is", "The cat is"]
input_ids = tokenizer(input_context, return_tensors="np").input_ids
outputs = jit_generate(input_ids=input_ids, max_length=max_length, top_k=30, do_sample=True).sequences
res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(outputs)
print(res)
return res
answer(256)

import time
start_time = time.time()
for i in range(10):
    answer(256)
print(time.time() - start_time)

answer(1024)

import time
start_time = time.time()
for i in range(10):
    answer(1024)
print(time.time() - start_time)
For 2, is this correct?

pmap_generate = jax.pmap(model.generate, "batch", static_broadcasted_argnums=[2, 3, 4, 5, 6])
outputs = pmap_generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, top_k=30, do_sample=True, temperature=0.7, params=params).sequences

This error occurs:

outputs = pmap_generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, top_k=30, do_sample=True, temperature=0.7, params=params).sequences
ValueError: pmapped function has static_broadcasted_argnums=(2, 3, 4, 5, 6) but was called with only 1 positional argument. All static broadcasted arguments must be passed positionally.
Hi @sanchit-gandhi, I look forward to your reply.
Hey @joytianya, if you don't want to change the generation params between calls, you can fix them inside a wrapper function and pmap that instead:

import jax
from flax import jax_utils
from flax.training.common_utils import shard

def generate(params, batch):
    # anything that does not depend on `batch` or `params` is baked in at compile time
    outputs = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_new_tokens=128, top_k=30, do_sample=True, temperature=0.7, params=params).sequences
    return outputs

p_generate = jax.pmap(generate, "batch")

# copy (replicate) the params across your TPU devices
params = jax_utils.replicate(params)

input_context = ["The dog is" for _ in range(8)]  # batch size needs to be a multiple of the number of TPU devices
batch = tokenizer(input_context, return_tensors="np")
batch = shard(dict(batch))

# slow - we're compiling
outputs = p_generate(params, batch)
# fast!
outputs = p_generate(params, batch)
Is this behaviour expected?
Hey @joytianya
We can't really rely on the outputs of the model since it's only been pre-trained, not fine-tuned, so it's bound to output gibberish regardless of what we give it (see https://huggingface.co/google/mt5-small for details). You can try using a fine-tuned checkpoint if you want to look at the actual token predictions.
This is because the model has only been pre-trained (not fine-tuned): the model never hits the end-of-sequence token; it generates random outputs until it hits max length. Therefore, it always generates to max length and never terminates early. So if you increase max length, the model generates more tokens, and decoding takes longer.
Hey @sanchit-gandhi,
from transformers import T5Tokenizer, FlaxMT5ForConditionalGeneration
import jax
import jax.numpy as jnp
model = FlaxMT5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v1", from_pt=True, dtype=jnp.bfloat16)
model.params = model.to_bf16(model.params)
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v1")
# copy (replicate) the params across your TPU devices
#params = jax_utils.replicate(params)
# pmap generate (like jit, but replicated across our JAX devices)
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "max_new_tokens", "top_k", "do_sample", "temperature", "eos_token_id"])
def answer(max_length):
input_context = ["The dog is", "The cat is"]
input_ids = tokenizer(input_context, return_tensors="np").input_ids
outputs = jit_generate(input_ids=input_ids, max_length=max_length, top_k=30, do_sample=True).sequences
res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(outputs)
print(res)
return res
answer(256)

import time
start_time = time.time()
for i in range(10):
    answer(256)
print(time.time() - start_time)

answer(1024)

import time
start_time = time.time()
for i in range(10):
    answer(1024)
print(time.time() - start_time)
Hey @joytianya - if running this on a GPU gives one answer and running it on a TPU another, I'm not really sure this is a transformers-based issue, but probably a JAX or Flax one. Could you try re-running the code snippet under the highest JAX matmul precision? We should then get equivalence on CPU/GPU/TPU. See #15754 (comment) for details.
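A minimal sketch of re-running under the highest matmul precision, using JAX's default_matmul_precision context manager (the generate call itself is just illustrative, reusing jit_generate and input_ids from the snippets above):

```python
import jax

# force float32 matmuls on TPU (the default uses lower-precision bfloat16 inputs)
with jax.default_matmul_precision("float32"):
    outputs = jit_generate(input_ids=input_ids, max_length=256, top_k=30, do_sample=True).sequences
```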
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
When I use the following code on a TPU VM and use model.generate() to infer, the speed is very slow. It seems that the TPU is not being used. What is the problem?
The JAX devices do exist.
Who can help?
No response
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)

Reproduction
Expected behavior
Expect it to be fast