enable llava static generation. #767
Based on the image-to-text generation PR #738, I tested it on a single Gaudi2 card with --use_hpu_graphs:
result = [[{'generated_text': "[\nUSER: What's the content of the image?\nASSISTANT: The image features a pier extending out into a large body of water, likely a lake.\n\n"}]], time = 264.1947269439697ms
[Input/outputs 1 to 4]
Number of HPU graphs = 26
Just want to let you know this works like a charm!
@lkk12014402, could you please provide a brief description of the changes needed in optimum/habana/transformers/models/llava/modeling_llava.py with respect to the base model in transformers?
I see a couple of single-input torch.where calls, which are usually dynamic on HPU. If these run on CPU, that's fine, but if they run on HPU, they might need rewriting.
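For context, here is a small sketch of why the single-input form is shape-dynamic while the three-input form is not. The tensors are illustrative only and are not taken from modeling_llava.py:

```python
import torch

mask_a = torch.tensor([True, False, True, False])
mask_b = torch.tensor([True, True, True, False])

# Single-input form: returns a tuple of index tensors whose length
# depends on how many elements are True, so the shape is data-dependent.
idx_a = torch.where(mask_a)[0]
idx_b = torch.where(mask_b)[0]

# Three-input form: the output has the broadcast shape of its inputs,
# regardless of the mask values, so the shape stays static.
x = torch.arange(4)
out_a = torch.where(mask_a, x, torch.zeros_like(x))
out_b = torch.where(mask_b, x, torch.zeros_like(x))

print(idx_a.shape, idx_b.shape)  # differ with the data
print(out_a.shape, out_b.shape)  # always the same
```

Whether a given call actually needs rewriting depends on where it runs, as noted in the comment above.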
Hi @ssarkar2, I will give a description and check the torch.where() operation as soon as possible.
Hi @ssarkar2,
Description
Let's assume the input text is
Generation with Hugging Face transformers directly
With llava-1.5-7b-hf, Hugging Face transformers produces a text embedding of shape [1, 4, 4096] and an image embedding of shape [1, 576, 4096]. The two embeddings are then merged into the final input embedding of shape [1, 579, 4096] using the merge function here. The merge function also has many dynamic ops, such as torch.where, and the input shape is dynamic during generation. So when we use Gaudi2 for generation, there are 2 problems:
Note: to reproduce, you can use this PR image-to-text example.
My optimization
In order to keep the transformers usage unchanged (same input, same generation script) and to enable static shapes by padding and inserting token_idx for generation, I add a new function
To keep the same input shape during generation, I also use token_idx. So I create 2 auxiliary variables, the explanation of maintaining
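The padding and token_idx idea can be sketched roughly as follows. This is not the code from the PR: the helper names pad_to_static and write_next are hypothetical, the hidden size is shrunk from 4096 to 8 to keep the example light, and the buffer length of 579 + 20 simply combines the merged length above with the --max_new_tokens 20 value used elsewhere in this thread.

```python
import torch

def pad_to_static(embeds, max_len):
    """Right-pad [batch, seq, hidden] embeddings to [batch, max_len, hidden]."""
    batch, seq_len, hidden = embeds.shape
    padded = embeds.new_zeros(batch, max_len, hidden)
    padded[:, :seq_len] = embeds
    # token_idx marks the position where the next token will be written.
    token_idx = torch.tensor(seq_len)
    return padded, token_idx

def write_next(padded, token_idx, new_embed):
    """Write one new-token embedding at token_idx; the buffer shape never changes."""
    padded.index_copy_(1, token_idx.reshape(1), new_embed)
    return padded, token_idx + 1

# Merged prompt of length 579 (hidden size reduced to 8 for illustration).
merged = torch.randn(1, 579, 8)
padded, token_idx = pad_to_static(merged, max_len=579 + 20)

# One generation step: the input stays [1, 599, 8], only token_idx advances.
new_embed = torch.ones(1, 1, 8)
padded, token_idx = write_next(padded, token_idx, new_embed)
print(padded.shape, int(token_idx))
```

Because every step sees the same tensor shape, a graph captured for one step can in principle be reused for the following ones.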
@ssarkar2 please help review~
libinta left a comment
@lkk12014402 can you add a CI test case and rebase?
@libinta I will update the PR with your comments soon.
d44c540 to 1a1ee0b
Hi @libinta, I have resolved the conflicting files. I haven't seen an image-to-text example test case like
@lkk12014402 can you add a file like test_image2text_generation_example.py to cover image2text generation? (Line 76 in 081130d)
Hi @libinta, please help review/check the image-to-text unit test. Thanks~
@pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline", MODELS_TO_TEST["bf16"])
def test_text_generation_bf16(model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str):
Better to name this image_to_text rather than text_generation.
f"--model_name_or_path {model_name}",
f"--batch_size {batch_size}",
"--use_kv_cache",
"--max_new_tokens 20",
Have you run the test with
GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/test_image_to_text_example.py -v -s
If so, you will see:
run_pipeline.py: error: unrecognized arguments: --use_kv_cache --output_dir /tmp/tmpsp9f6li_ --token None
You should pass the same arguments as in this command:
python3 run_pipeline.py \
--model_name_or_path "llava-hf/llava-1.5-7b-hf" \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--prompt "\nUSER: What's the content of the image?\nASSISTANT:" \
--max_new_tokens 20 \
--use_hpu_graphs \
--bf16
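As a rough illustration of this point, the sketch below builds the argument list from only the flags in the command above and splits it with the same regular expression the test under review uses. The args list, the omission of --prompt (to keep the sketch short), and the final quote-stripping step are choices made here, not part of the PR.

```python
import re

# Flags taken from the reviewer's command; --prompt is left out for brevity.
args = [
    "run_pipeline.py",
    '--model_name_or_path "llava-hf/llava-1.5-7b-hf"',
    '--image_path "https://llava-vl.github.io/static/images/view.jpg"',
    "--max_new_tokens 20",
    "--use_hpu_graphs",
    "--bf16",
]

# Same pattern as in the test: capture quoted spans, otherwise split on whitespace.
pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
command = [x for y in args for x in re.split(pattern, y) if x]

# Quoted values keep their quote characters after the split, so strip them
# before passing the list to subprocess (an assumption of this sketch).
command = [x.strip("\"'") for x in command]
print(command)
```

Flags such as --use_kv_cache, --output_dir and --token are absent from the list, which is what avoids the "unrecognized arguments" error quoted above.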
pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
command = [x for y in command for x in re.split(pattern, y) if x]

if fp8:
Remove the fp8 section for now.
regisss left a comment
It seems there are some merge conflicts to solve, can you update your main branch and merge it into this one?
Also, please run
pip install -U ruff
make style
to have the code style check pass.
update code style with the command
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @regisss, I updated the code per your comments. Please review. Thanks~
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
…ace#767) Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com> Co-authored-by: Adam Stachowicz <astachow@habana.ai>
What does this PR do?
support llava image to text generation