Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: Optional[int] = None,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
Expand Down Expand Up @@ -334,18 +335,37 @@ def forward(

key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

if layer_past is not None:
_, q_len, _ = hidden_states.size()
bucket_internal_decode_stage = cache_idx is not None and q_len == 1

if not bucket_internal_decode_stage:
if layer_past is not None:
past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
if token_idx is not None:
# Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled.
key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1))
value = past_value.index_add(
1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1)
)
else:
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = torch.cat((key, value), dim=-1) if use_cache else None
else:
assert token_idx is not None, "Invalid parameters: token_idx is None at decode stage with bucket_internal"
assert (
layer_past is not None
), "Invalid parameters: layer_past is None at decode stage with bucket_internal"

past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
if token_idx is not None:
# Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled.
key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1))
value = past_value.index_add(
1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1)
)
else:
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = torch.cat((key, value), dim=-1) if use_cache else None
key = past_key.index_copy_(1, token_idx - 1, key)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you verify this works with tgi-gaudi.. out of place op was used to fix a specific issue when tensor cache is disabled otherwise we saw error

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sent you ticket link of empty tensor optional error with tgi-gaudi

Copy link
Copy Markdown
Contributor Author

@mgonchar mgonchar Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vidyasiv I tried to rollback changes from your commit #1181 and it works for me on latest 1.18 with command line

PT_HPU_DISABLE_TENSOR_CACHE=1 python run_generation.py --model_name_or_path bigcode/starcoder --batch_size 2 --use_hpu_graphs --use_kv_cache --max_new_tokens 100 --bf16

and output is the same as without PT_HPU_DISABLE_TENSOR_CACHE variable. It seems that original issues was fixed

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for bucket it also works fine and gives the same output:

PT_HPU_DISABLE_TENSOR_CACHE=1 python run_generation.py --model_name_or_path bigcode/starcoder --batch_size 2 --use_hpu_graphs --use_kv_cache  --max_new_tokens 100 --bf16 --bucket_size=128 --bucket_internal

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vidyasiv What's the TGI config that was leading to an error?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@regisss issue in tgi from original ticket:

# server:
text-generation-launcher --model-id bigcode/starcoderbase-3b --sharded false --hostname 127.0.0.1 --max-input-length 2048  --max-batch-size 8 --dtype bfloat16

# In container: 
docker run -it --runtime=habana --name gaudi-tgi-scb-3b-e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN -e ENABLE_HPU_GRAPH=True -e BATCH_BUCKET_SIZE=8  -e PREFILL_BATCH_BUCKET_SIZE=4  -e PAD_SEQUENCE_TO_MULTIPLE_OF=128 --cap-add=sys_nice --net=host --entrypoint bash tgi_gaudi

HF equivalent back then was to set PT_HPU_DISABLE_TENSOR_CACHE=1 and --use_hpu_graphs

value = past_value.index_copy_(1, token_idx - 1, value)
present = layer_past

if bucket_internal_decode_stage:
key = key[:, :cache_idx, :]
value = value[:, :cache_idx, :]
attention_mask = attention_mask[:, :, :, :cache_idx]

if not output_attentions and head_mask is None and use_flash_attention:
# Difference with the original implementation: there is no need to transpose the key here,
Expand All @@ -367,6 +387,11 @@ def forward(
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

if bucket_internal_decode_stage:
# Return only past key value shapes and not the tensors during decode phase (q len is 1)
# to avoid making past key values as persistent output tensors of HPU graphs.
present = present.shape

outputs = (attn_output, present)
if output_attentions:
if self.multi_query:
Expand All @@ -392,6 +417,7 @@ def gaudi_gpt_bigcode_block_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: Optional[int] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Expand All @@ -413,6 +439,7 @@ def gaudi_gpt_bigcode_block_forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
Expand Down Expand Up @@ -475,6 +502,7 @@ def gaudi_gpt_bigcode_model_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: Optional[int] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
"""
Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Expand Down Expand Up @@ -638,6 +666,7 @@ def gaudi_gpt_bigcode_model_forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)

hidden_states = outputs[0]
Expand Down Expand Up @@ -750,6 +779,7 @@ def prepare_inputs_for_generation(
"flash_attention_recompute": kwargs.get("flash_attention_recompute", False),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax", False),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False),
"cache_idx": kwargs.get("cache_idx", None),
}
)
return model_inputs
Expand All @@ -775,6 +805,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: Optional[int] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Expand Down Expand Up @@ -803,6 +834,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)
hidden_states = transformer_outputs[0]

Expand Down
10 changes: 2 additions & 8 deletions tests/test_text_generation_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,15 @@
("EleutherAI/gpt-neox-20b", 1, False, 50.67672679310354, False),
("meta-llama/Llama-2-7b-hf", 1, True, 141.25776956002076, True),
("tiiuae/falcon-40b", 1, True, 25.202450111088346, False),
(
"bigcode/starcoder",
256,
True,
6846.575763562658,
False,
), # TODO: Enable check_output after model bigcode/starcoder is fixed
("bigcode/starcoder", 256, True, 6846.575763562658, True),
("Salesforce/codegen2-1B", 1, False, 446.4029486883532, False),
("mosaicml/mpt-30b", 1, False, 36.06464336116623, False),
("mistralai/Mistral-7B-v0.1", 1, True, 130.2172236767782, True),
("mistralai/Mixtral-8x7B-v0.1", 1, False, 23.7931001677926, True),
("microsoft/phi-2", 1, False, 224.72307766211117, False),
("meta-llama/Meta-Llama-3-8B", 1, True, 129, False),
("meta-llama/Llama-2-7b-hf", 512, True, 12808, False),
("meta-llama/Llama-2-7b-hf", 512, False, 8711, False), # in some cases like TGI, reuse_cache isnt used
("meta-llama/Llama-2-7b-hf", 512, False, 8711, False), # in some cases like TGI, reuse_cache isn't used
("stabilityai/stablelm-2-12b", 1, False, 74.8904496532218, False),
("codellama/CodeLlama-34b-hf", 1, True, 32.644, False),
("bigcode/starcoder2-3b", 1, False, 261.07213776344133, True),
Expand Down