fix beamsearch crash and incorrect output in decode-only model and en…#627

Merged
regisss merged 4 commits into
mainfrom
beam_search_fix
Feb 20, 2024

Conversation

@sywangyi
Collaborator

@sywangyi sywangyi commented Jan 8, 2024

…code-decode model

Found with "do_sample=False, num_beams=4" in generation. Tested models include:
decoder-only: "gpt2", "tiiuae/falcon-7b-instruct", "distilgpt2"
encoder-decoder: "facebook/bart-large-cnn", "sshleifer/distilbart-cnn-12-6", "philschmid/bart-large-cnn-samsum"

@sywangyi
Collaborator Author

sywangyi commented Jan 8, 2024

@regisss

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sywangyi
Collaborator Author

sywangyi commented Jan 8, 2024

The bug occurs because static_shapes is not set in self.generation_config (see https://github.com/huggingface/optimum-habana/pull/627/files#diff-284d6c109e6788a4405e302113203f6b7e98be68afc7dd3d85b89c6329b9275eL540-R541). If generation_config is passed explicitly with static_shapes=True, the issue is hidden, which is the case in the optimum-habana text-generation example. I ran a Transformers pipeline on Gaudi without passing generation_config to generate(), which is how I found the bug.

@sywangyi sywangyi force-pushed the beam_search_fix branch 4 times, most recently from 6a05a32 to 54f780a Compare January 9, 2024 02:21
@sywangyi
Collaborator Author

Here is a simple test:

from transformers import pipeline
import habana_frameworks.torch  # loads the Habana backend so that the "hpu" device is available
import torch
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Patch Transformers with the Gaudi-specific implementations
adapt_transformers_to_gaudi()

# Beam search without sampling: the configuration that triggers the bug
generation_kwargs = dict(do_sample=False, num_beams=4, use_cache=True, max_new_tokens=32)

generator = pipeline(
    "text-generation",
    model="gpt2",
    torch_dtype=torch.bfloat16,
    device="hpu",
    **generation_kwargs,
)

print(generator("DeepSpeed is a machine learning framework"))

Before the fix, the output is:
[{'generated_text': 'DeepSpeed is a machine learning framework'}]

After the fix, the output is:
[{'generated_text': "DeepSpeed is a machine learning framework that can be used to build machine learning models.\n\nIn this tutorial, we'll learn how to build a machine learning model that can be used to build"}]

@sywangyi
Collaborator Author

@libinta @mandy-li

Comment thread optimum/habana/transformers/generation/utils.py Outdated
Comment thread optimum/habana/transformers/generation/utils.py
Comment thread optimum/habana/transformers/generation/utils.py
@sywangyi
Collaborator Author

Here is a simple test for summarization using an encoder-decoder model:

from transformers import pipeline
import habana_frameworks.torch  # loads the Habana backend so that the "hpu" device is available
import torch
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Patch Transformers with the Gaudi-specific implementations
adapt_transformers_to_gaudi()

# Beam search without sampling: the configuration that triggers the bug
generation_kwargs = dict(do_sample=False, num_beams=4, use_cache=True, max_new_tokens=32)

generator = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    torch_dtype=torch.bfloat16,
    device="hpu",
    **generation_kwargs,
)

# Two inputs of different lengths, processed as one batch
print(generator(["(CNN)The Palestinian Authority officially became", "DeepSpeed is a machine learning"], batch_size=2))

Without the fix, it core dumps.

@regisss
Collaborator

regisss commented Jan 15, 2024

@sywangyi It seems the PR "breaks" regular generation. Running

python run_generation.py --model_name_or_path gpt2 --use_hpu_graphs --use_kv_cache --max_new_tokens 100 --prompt "Hello world"

on main returns

01/15/2024 21:55:18 - INFO - __main__ - Running generate...

Input/outputs:
input 1: ('Hello world',)
output 1: ('Hello world, I\'m not sure what to say.\n\n"I\'m sorry, but I\'m not sure what to say.\n\n"I\'m sorry, but I\'m not sure what to say.\n\n"I\'m sorry, but I\'m not sure what to say.\n\n"I\'m sorry, but I\'m not sure what to say.\n\n"I\'m sorry, but I\'m not sure what to say.\n\n"I\'m sorry, but I\'m not',)


Stats:
--------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 859.1097929752864 tokens/second
Number of HPU graphs                = 12
Memory allocated                    = 0.51 GB
Max memory allocated                = 0.62 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 1.3922957889735699 seconds
--------------------------------------------------------------------------------------------------------------

but on your branch I get

01/15/2024 21:55:53 - INFO - __main__ - Running generate...

Input/outputs:
input 1: ('Hello world',)
output 1: ('Hello world,',)


Stats:
-------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 34314.56502749868 tokens/second
Number of HPU graphs                = 6
Memory allocated                    = 0.51 GB
Max memory allocated                = 0.62 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 0.5902187939500436 seconds
-------------------------------------------------------------------------------------------------------------

@sywangyi
Collaborator Author

I uploaded a commit to fix it.

Comment thread optimum/habana/transformers/generation/utils.py
@bhargaveede

I see a perf drop: 1.082 (main) vs 0.972 (on your branch).
@sywangyi Do you see the same?

python3 /root/optimum-habana/examples/summarization/run_summarization.py --model_name_or_path t5-3b --do_predict --predict_with_generate --dataset_name cnn_dailymail --dataset_config 3.0.0 --use_habana --per_device_eval_batch_size 2 --gaudi_config_name Habana/t5 --generation_num_beams 4 --ignore_pad_token_for_loss False --pad_to_max_length --use_hpu_graphs_for_inference --use_lazy_mode --max_predict_samples 200 --bf16 --bf16_full_eval --output_dir /tmp/tmpccx5bpdn

@sywangyi
Collaborator Author

sywangyi commented Jan 23, 2024

Hi, the difference is caused by the change in
[image]
My concern is that using "cur_len == self.generation_config.max_length" as the stop criterion is incorrect:

  1. self.generation_config.max_length is not equal to max_new_tokens + prompt length in generation_config, which leads to output text with an incorrect length (you can try it with the text generation script I showed in the comment above). Also, the current logic generates 2 fewer tokens than using the stopping criteria; if you use cur_len == self.generation_config.max_length+2, the perf comparison is apples to apples.
  2. There are many stopping criteria, and users can even define their own. This logic bypasses them.
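
To make point 1 concrete, here is a minimal sketch in plain Python. It assumes a max_length left at 20 (the Transformers default for GenerationConfig.max_length) and a hypothetical 7-token prompt with max_new_tokens=32. The helper names are illustrative and do not come from optimum-habana.

```python
# Illustrative only: compare a stop check against a fixed config value
# with one against prompt length + max_new_tokens.

STALE_MAX_LENGTH = 20  # assumption: max_length left at its default in the config


def should_stop_static(cur_len: int, max_length: int = STALE_MAX_LENGTH) -> bool:
    """Stop check of the form `cur_len == max_length`."""
    return cur_len == max_length


def expected_total_length(prompt_len: int, max_new_tokens: int) -> int:
    """Total length the user actually asked for."""
    return prompt_len + max_new_tokens


def generated_tokens(prompt_len: int, stop) -> int:
    """Count the tokens generated before `stop(cur_len)` becomes true."""
    cur_len = prompt_len
    while not stop(cur_len):
        cur_len += 1  # one new token per decoding step
    return cur_len - prompt_len


prompt_len, max_new_tokens = 7, 32
target = expected_total_length(prompt_len, max_new_tokens)

print(generated_tokens(prompt_len, should_stop_static))     # 13
print(generated_tokens(prompt_len, lambda n: n >= target))  # 32
```

With these assumed numbers the fixed check stops after 13 new tokens instead of the 32 requested. Note also that an equality check never fires if the prompt is already longer than max_length.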

@bhargaveede

  1. cur_len == self.generation_config.max_length+2

Understood.
cur_len == self.generation_config.max_length+2 is also a static check, whereas slicing causes the shape to change across runs, and that is what impacts perf. We need some way of using pad_token_id with input_ids, or some other workaround, to avoid the shape change.

@bhargaveede

Also, whenever a user defines their own stopping criteria, perf degradation is expected depending on whether those criteria support static_shapes or not. However, in the default case, we should keep static shapes for better perf. I will also brainstorm ways to avoid slicing.
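
As a rough illustration of the static-shape idea discussed here, the sketch below pads a buffer to a fixed length once and writes each new token in place at a running index, so the buffer length never changes between steps. It uses plain Python lists rather than HPU tensors, and PAD_TOKEN_ID, make_static_buffer and append_static are hypothetical names chosen for the example.

```python
PAD_TOKEN_ID = 0  # hypothetical pad id, for illustration only


def make_static_buffer(prompt, max_length, pad_token_id=PAD_TOKEN_ID):
    """Pad the prompt to max_length once; return (buffer, token_idx)."""
    buffer = list(prompt) + [pad_token_id] * (max_length - len(prompt))
    return buffer, len(prompt)


def append_static(buffer, token_idx, new_token):
    """Write the new token in place; the buffer length never changes."""
    buffer[token_idx] = new_token
    return token_idx + 1


buffer, token_idx = make_static_buffer([11, 12, 13], max_length=8)
lengths = []
for new_token in (21, 22, 23):
    token_idx = append_static(buffer, token_idx, new_token)
    lengths.append(len(buffer))

print(buffer)     # [11, 12, 13, 21, 22, 23, 0, 0]
print(lengths)    # [8, 8, 8]
print(token_idx)  # 6
```

The trade-off is that the buffer length no longer says how many real tokens exist, so any length check has to use the running index instead.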

@sywangyi
Collaborator Author

sywangyi commented Jan 29, 2024

@bhargaveede and @regisss, below is the max length criteria that we use to determine whether generation has ended. You can see that it uses input_ids.shape[-1] and compares it with max_length. If you want to use static shapes, we need to pass token_idx in and override the max length criteria from Transformers. User-defined criteria also need token_idx, because input_ids.shape includes the padding length, which is not reliable. What is your point?

class MaxLengthCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
    in mind for decoder-only type of transformers, this will include the initial prompted tokens.

    Args:
        max_length (`int`):
            The maximum length that the output sequence can have in number of tokens.
        max_position_embeddings (`int`, `optional`):
            The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
    """

    def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
        self.max_length = max_length
        self.max_position_embeddings = max_position_embeddings

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        cur_len = input_ids.shape[-1]
        is_done = cur_len >= self.max_length
        if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
            logger.warning_once(
                "This is a friendly reminder - the current text generation call will exceed the model's predefined "
                f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
                "exceptions, performance degradation, or nothing at all."
            )
        return is_done

@bhargaveede

Can't we use StaticMaxLengthCriteria?

@sywangyi
Collaborator Author

sywangyi commented Jan 29, 2024

The stopping criteria list now contains more than one criterion. For example, for text generation it contains [<transformers.generation.stopping_criteria.MaxLengthCriteria object at 0x7fc31fc862f0>, <optimum.habana.transformers.generation.utils.StaticMaxLengthCriteria object at 0x7fc31fc86320>]. As soon as one of the criteria meets its exit requirement, generation ends. And if a user defines a criterion such as ending on some specific "ids", the correct logic is to get the latest generated ids and compare them to see whether the ending requirement is met.
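
The "any criterion ends generation" behaviour can be sketched in a few lines of plain Python. The factory functions below are hypothetical stand-ins, not the Transformers classes, and they operate on an unpadded list of token ids.

```python
from typing import Callable, List, Sequence

Criterion = Callable[[Sequence[int]], bool]


def max_length_criterion(max_length: int) -> Criterion:
    """Stop once the sequence has reached max_length tokens."""
    return lambda ids: len(ids) >= max_length


def stop_on_ids_criterion(stop_ids: Sequence[int]) -> Criterion:
    """User-style criterion: stop when the latest token is one of stop_ids."""
    stop_set = set(stop_ids)
    return lambda ids: len(ids) > 0 and ids[-1] in stop_set


def should_stop(criteria: List[Criterion], ids: Sequence[int]) -> bool:
    """Generation ends as soon as any single criterion is satisfied."""
    return any(criterion(ids) for criterion in criteria)


criteria = [max_length_criterion(10), stop_on_ids_criterion([99])]
print(should_stop(criteria, [1, 2, 3]))        # False
print(should_stop(criteria, [1, 2, 99]))       # True
print(should_stop(criteria, list(range(10))))  # True
```

With a padded, fixed-length buffer, ids[-1] would be a pad token rather than the latest generated id, which is why such a criterion would need the current token index.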

@bhargaveede

bhargaveede commented Jan 29, 2024

Got it. I think StaticMaxLengthCriteria and MaxLengthCriteria shouldn't be used together. If both are present, we should avoid that.
As per your earlier comment (quoted below), can we change the code "cur_len == self.generation_config.max_length+2" to something like "cur_len == adjusted_max_length"?
"self.generation_config.max_length is not equal to max_new_tokens + prompt length in generation_config, which leads to output text with an incorrect length (you can try it with the text generation script I showed in the comment above). Also, the current logic generates 2 fewer tokens than using the stopping criteria; if you use cur_len == self.generation_config.max_length+2, the perf comparison is apples to apples."
@regisss
For user-defined criteria, it's fine if we lose static shapes, since that depends on the user.
But for all the default criteria, we should keep the checks valid for static shapes. What do you suggest? @regisss

@regisss
Collaborator

regisss commented Feb 5, 2024

@sywangyi @bhargaveede MaxLengthCriteria and StaticMaxLengthCriteria should be mutually exclusive if I understand correctly.

Ideally, we could remove the StaticMaxLengthCriteria class from the codebase and override the MaxLengthCriteria from Transformers:

  • it should get token_idx=None as an optional argument
  • if token_idx is None, fall back to Transformers implementation
  • if token_idx is not None, compare it with max length

Would that solve this issue and work for you?
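
A minimal sketch of the three bullets above, under stated assumptions: TokenIdxMaxLengthCriteria is a hypothetical stand-in class rather than the code merged in this PR, token_idx is passed as a plain integer, and input_ids is a batch of Python lists, so len(input_ids[0]) plays the role of input_ids.shape[-1].

```python
from typing import Optional, Sequence


class TokenIdxMaxLengthCriteria:
    """Hypothetical stand-in for an overridden MaxLengthCriteria."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(
        self,
        input_ids: Sequence[Sequence[int]],
        scores=None,
        token_idx: Optional[int] = None,
    ) -> bool:
        if token_idx is None:
            # Fallback: same comparison as the Transformers implementation,
            # based on the (possibly padded) sequence length.
            cur_len = len(input_ids[0])
        else:
            # Static shapes: the buffer is padded, so use the running index.
            cur_len = token_idx
        return cur_len >= self.max_length


padded = [[11, 12, 13, 0, 0, 0, 0, 0]]  # 3 real tokens padded to length 8
criteria = TokenIdxMaxLengthCriteria(max_length=8)

print(criteria(padded))               # True
print(criteria(padded, token_idx=3))  # False
print(criteria(padded, token_idx=8))  # True
```

Without token_idx the padded buffer looks finished immediately, which is the unreliability of input_ids.shape mentioned earlier in the thread.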

@sywangyi
Collaborator Author

sywangyi commented Feb 5, 2024

Agreed, we should override MaxLengthCriteria, pass token_idx in, and remove StaticMaxLengthCriteria, which makes things simpler.

@bhargaveede

Changes look fine to me. I will check locally if t5 and bart are fine with the patch.

@bhargaveede

@regisss @sywangyi Bart and T5 are fine

Collaborator

@regisss regisss left a comment

I left a couple of comments. Also, there is now a merge conflict as I merged #651 yesterday. Let me know if you need any help to solve them.

Comment thread optimum/habana/transformers/generation/utils.py
Comment thread optimum/habana/transformers/generation/utils.py
@regisss
Collaborator

regisss commented Feb 19, 2024

Hmm still seeing some throughput regressions, for example with BLOOMZ-7b:

python3 /root/workspace/optimum-habana/examples/text-generation/run_generation.py --model_name_or_path bigscience/bloomz-7b1 --batch_size 1 --use_kv_cache --max_new_tokens 100 --use_hpu_graphs --bf16

I get 37.98 tokens/s instead of 41.52 on Gaudi1, and 106.42 instead of 130.1 on Gaudi2.
Do you get the same @sywangyi ?

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@sywangyi
Collaborator Author

Yes, I reproduced it and found that it is related to token_idx, possibly because it is an HPU tensor. I uploaded a fix.

@sywangyi
Collaborator Author

Could you try on your side and see if it's working now, @regisss?

@regisss
Collaborator

regisss commented Feb 20, 2024

All regression tests passed, except the torch.compile one that seems to be broken:

Traceback (most recent call last):
  File "/root/workspace/optimum-habana/examples/text-generation/run_generation.py", line 562, in <module>
    main()
  File "/root/workspace/optimum-habana/examples/text-generation/run_generation.py", line 335, in main
    generate(None, args.reduce_recompile)
  File "/root/workspace/optimum-habana/examples/text-generation/run_generation.py", line 312, in generate
    outputs = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/optimum/habana/transformers/generation/utils.py", line 655, in generate
    self._validate_generated_length(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1139, in _validate_generated_length
    if input_ids_length >= generation_config.max_length:
RuntimeError: [Rank:0] FATAL ERROR :: MODULE:PT_BRIDGE Exception in Lowering thread...
Graph compile failed. synStatus=synStatus 26 [Generice failure]. 
[Rank:0] Habana exception raised from compile at graph.cpp:503

Here is the command to reproduce it (torch.compile is available on Gaudi2 only):

PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python3 run_generation.py --model_name_or_path meta-llama/Llama-2-7b-hf --batch_size 1 --use_kv_cache --max_new_tokens 100 --attn_softmax_bf16 --reuse_cache --trim_logits --torch_compile --bf16

It comes from this line in Transformers: https://github.com/huggingface/transformers/blob/345b9b1a6a308a1fa6559251eb33ead2211240ac/src/transformers/generation/utils.py#L1139
Maybe we should override the _validate_generated_length method so that it takes the current length as an input argument for comparison with max length?

@sywangyi
Collaborator Author

Hi @regisss, I used your command and see a different error on Gaudi2.
The error is:
Traceback (most recent call last):
  File "run_generation.py", line 562, in <module>
    main()
  File "run_generation.py", line 335, in main
    generate(None, args.reduce_recompile)
  File "run_generation.py", line 312, in generate
    outputs = model.generate(
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/wangyi/optimum-habana/optimum/habana/transformers/generation/utils.py", line 789, in generate
    return self.greedy_search(
  File "/home/user/wangyi/optimum-habana/optimum/habana/transformers/generation/utils.py", line 1370, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1521, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1530, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/wangyi/optimum-habana/optimum/habana/transformers/models/llama/modeling_llama.py", line 763, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1521, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1530, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1521, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1530, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/wangyi/optimum-habana/optimum/habana/transformers/models/llama/modeling_llama.py", line 672, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1521, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1530, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/wangyi/optimum-habana/optimum/habana/transformers/models/llama/modeling_llama.py", line 465, in forward
    output_pre_attn, self_attn_weights, present_key_value = self.pre_attn(
  File "/home/user/wangyi/optimum-habana/optimum/habana/transformers/models/llama/modeling_llama.py", line 507, in pre_attn
    hidden_states = self.input_layernorm(hidden_states)
  File "/home/user/wangyi/optimum-habana/optimum/habana/transformers/models/llama/modeling_llama.py", line 507, in
    hidden_states = self.input_layernorm(hidden_states)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 3917, in forward
    return compiled_fn(full_args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 1482, in g
    return f(*args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 2533, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 1506, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 1594, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 1482, in g
    return f(*args)
  File "<eval_with_key>.1220", line 17, in forward
  File "/usr/local/lib/python3.8/dist-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: shape mismatch: value tensor of shape [1, 32, 1, 128] cannot be broadcast to indexing result of shape [1, 32, 128]

Also, this PR contains code like the following:
[image]
so I wonder why this error occurs.

Comment thread optimum/habana/transformers/generation/utils.py Outdated
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
@regisss regisss merged commit 6a483a2 into main Feb 20, 2024
@regisss regisss deleted the beam_search_fix branch February 20, 2024 09:30
jychen21 pushed a commit to jychen21/optimum-habana that referenced this pull request Feb 27, 2024
HolyFalafel pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Mar 11, 2024
@libinta libinta mentioned this pull request Mar 13, 2024
3 tasks
gplutop7 pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Oct 15, 2025
Co-authored-by: Piotr Bielak <pbielak@users.noreply.github.com>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>