Merged
3 changes: 2 additions & 1 deletion examples/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--lora_rank 4 \
--lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
--validation_split_percentage 4 \
--use_flash_attention True
--use_flash_attention True \
--flash_attention_causal_mask True
```

- Multi-card finetuning of Falcon-180B:
10 changes: 10 additions & 0 deletions examples/language-modeling/run_lora_clm.py
Expand Up @@ -156,6 +156,15 @@ class ModelArguments:
)
},
)
flash_attention_causal_mask: bool = field(
default=False,
metadata={
"help": (
"Whether to enable causal mask in Habana flash attention for fine-tuning."
" It is applicable only when use_flash_attention is True."
)
},
)
use_fused_rope: bool = field(
default=True,
metadata={
Expand Down Expand Up @@ -545,6 +554,7 @@ def main():
if model_args.use_flash_attention:
model.generation_config.use_flash_attention = True
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask
if not model_args.use_fused_rope:
model.generation_config.use_fused_rope = False
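
The hunk above adds a dataclass field and copies it onto the model's generation config only when flash attention is enabled. A minimal sketch of that flow, using only the standard library, might look like the following. `DemoModelArguments` and `apply_flash_attention_flags` are hypothetical stand-ins and are not part of `run_lora_clm.py`; a `SimpleNamespace` plays the role of the generation config.

```python
from dataclasses import dataclass, field
from types import SimpleNamespace


@dataclass
class DemoModelArguments:
    """Hypothetical stand-in for the flash-attention fields of ModelArguments."""

    use_flash_attention: bool = field(default=False, metadata={"help": "Enable flash attention."})
    flash_attention_recompute: bool = field(default=False, metadata={"help": "Enable recompute mode."})
    flash_attention_causal_mask: bool = field(
        default=False,
        metadata={
            # Adjacent string literals inside parentheses are joined into ONE string.
            # A trailing comma before the closing parenthesis would create a tuple instead.
            "help": (
                "Whether to enable causal mask in flash attention."
                " Only applies when use_flash_attention is True."
            )
        },
    )


def apply_flash_attention_flags(model_args, generation_config):
    """Copy the flags onto the config only when flash attention is requested."""
    if model_args.use_flash_attention:
        generation_config.use_flash_attention = True
        generation_config.flash_attention_recompute = model_args.flash_attention_recompute
        generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask
    return generation_config


# The causal-mask flag is ignored unless use_flash_attention is also set.
ignored = apply_flash_attention_flags(
    DemoModelArguments(flash_attention_causal_mask=True), SimpleNamespace()
)
applied = apply_flash_attention_flags(
    DemoModelArguments(use_flash_attention=True, flash_attention_causal_mask=True), SimpleNamespace()
)
```

In this sketch `ignored` ends up with no flash-attention attributes at all, which matches the help text's statement that the causal-mask option only applies when `use_flash_attention` is True.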

24 changes: 24 additions & 0 deletions examples/text-generation/README.md
Expand Up @@ -296,6 +296,30 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
```
`--fp8` is required to enable quantization in fp8.

### Using Habana Flash Attention

Habana Flash Attention addresses large sequence lengths in the prompt stage of inference. Using a causal attention mask in the prompt stage requires all input sequences in a batch to have the same length, but it can save memory and thus allow larger batch sizes.

The example below uses the `flash_attention_recompute` mode to reduce memory consumption in the prompt stage. Additionally, since all sequences in the batch have the same length, it uses `flash_attention_causal_mask`, which further improves performance by exploiting the lower-triangular shape of the inputs to the softmax operation.

```bash
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--use_hpu_graphs \
--use_kv_cache \
--reuse_cache \
--trim_logits \
--attn_softmax_bf16 \
--max_input_tokens 31744 \
--max_new_tokens 1024 \
--batch_size=12 \
--use_flash_attention \
--flash_attention_recompute \
--flash_attention_causal_mask \
--book_source
```
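
To make the remark about the triangular shape of the softmax inputs concrete, the sketch below counts how many attention scores a causal mask leaves active for a prompt of `seq_len` tokens. It is plain arithmetic on the mask shape, not a measurement of memory use or speed on Gaudi, and the helper names (`full_score_count`, `causal_score_count`, `build_causal_mask`) are hypothetical.

```python
def full_score_count(seq_len: int) -> int:
    """Number of query/key pairs in an unmasked seq_len x seq_len score matrix."""
    return seq_len * seq_len


def causal_score_count(seq_len: int) -> int:
    """Number of pairs kept by a causal mask: token i attends to tokens 0..i."""
    return seq_len * (seq_len + 1) // 2


def build_causal_mask(seq_len: int) -> list[list[bool]]:
    """Explicit lower-triangular mask, True where attention is allowed."""
    return [[key <= query for key in range(seq_len)] for query in range(seq_len)]


seq_len = 8
mask = build_causal_mask(seq_len)
kept = sum(sum(row) for row in mask)
assert kept == causal_score_count(seq_len)

# Fraction of the score matrix that a kernel aware of the causal structure has to touch.
fraction = causal_score_count(seq_len) / full_score_count(seq_len)
print(f"{kept} of {full_score_count(seq_len)} scores kept ({fraction:.2%})")
```

The number of skipped scores grows quadratically with `seq_len`, which is one way to read why the option is aimed at very long prompts such as the 31744-token input in the command above.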

For more details, see the [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).

## Language Model Evaluation Harness

54 changes: 54 additions & 0 deletions examples/text-generation/run_generation.py
Expand Up @@ -232,6 +232,21 @@ def setup_parser(parser):
action="store_true",
help="Whether to enable Habana Flash Attention, provided that the model supports it.",
)
parser.add_argument(
Review comment:
This seems like a counterintuitive argument for inference.

Review comment:
If there is no performance penalty and memory is also saved, then we can pass it internally as True for the first token when flash attention is used.

Author reply:
Discussed offline. Here is a summary:
We should leave it controllable, because it may be reasonable to turn it off, for example during fine-tuning.
Additionally, this parameter may cause some slight overhead even during inference, so it may be reasonable to turn it off if we don't need it.

"--flash_attention_recompute",
action="store_true",
help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This allows the graph to be split internally, which helps reduce memory consumption.",
)
parser.add_argument(
Review comment:
Can we force this to True when the batch size is 1 and use_flash_attention is passed? We can mention in the help text that this is handled automatically.

Author reply:
Discussed offline. Here is a summary:
The memory improvements are meant to allow a higher batch size and thus increase throughput. The improvements from setting causal=True (and using triangular softmax underneath) scale with batch size and sequence length. So causal masking on bigger batches is still useful, but we need to be mindful that, without special handling of padding tokens, it requires same-length inputs in a batch.

"--flash_attention_causal_mask",
action="store_true",
help="Whether to enable Habana Flash Attention in causal mode on first token generation.",
)
parser.add_argument(
"--book_source",
action="store_true",
help="Whether to use Project Gutenberg books data as input. Useful for testing large sequence lengths.",
)
parser.add_argument(
"--torch_compile",
action="store_true",
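
The new options in this hunk are plain `store_true` switches. The sketch below shows how such flags parse with `argparse`, plus an optional sanity check for dependent flags passed without `--use_flash_attention`. The check is purely illustrative and is not performed by the script in this PR; `build_demo_parser` and `find_ignored_flags` are hypothetical names.

```python
import argparse


def build_demo_parser() -> argparse.ArgumentParser:
    """Hypothetical parser that mirrors only the flash-attention flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_flash_attention", action="store_true")
    parser.add_argument("--flash_attention_recompute", action="store_true")
    parser.add_argument("--flash_attention_causal_mask", action="store_true")
    return parser


def find_ignored_flags(args: argparse.Namespace) -> list[str]:
    """Return dependent flags that were set without --use_flash_attention.

    This check is an illustration only; the script in this PR does not perform it.
    """
    if args.use_flash_attention:
        return []
    dependent = ("flash_attention_recompute", "flash_attention_causal_mask")
    return [name for name in dependent if getattr(args, name)]


parser = build_demo_parser()

# store_true flags default to False and become True when present on the command line.
defaults = parser.parse_args([])
full = parser.parse_args(
    ["--use_flash_attention", "--flash_attention_recompute", "--flash_attention_causal_mask"]
)
orphan = parser.parse_args(["--flash_attention_causal_mask"])

print(find_ignored_flags(full))
print(find_ignored_flags(orphan))
```

Keeping the flags independent, as the author's reply in the review thread argues, leaves the caller free to disable causal masking when batches contain sequences of different lengths.
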
Expand Down Expand Up @@ -271,6 +286,45 @@ def main():
# Benchmark over the prompts below
if args.prompt:
input_sentences = args.prompt
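elif args.book_source:

    def download_book(book_id):
        import os
        import sys

        import requests

        # Fetch the plain-text edition of the book from Project Gutenberg.
        url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt"
        response = requests.get(url)
        if response.status_code == 200:
            # The PID makes the file name unique per process.
            pid = os.getpid()
            save_path = f"/tmp/{book_id}_{pid}.txt"
            with open(save_path, "wb") as file:
                file.write(response.content)
            print(f"Book downloaded and saved to: {save_path}")
            return save_path
        else:
            print("Failed to download book! Exiting...")
            # Exit with a non-zero status so the failure is visible to the caller.
            sys.exit(1)

    def assemble_prompt(prompt_size, book_path):
        # Build one prompt from the first `prompt_size` whitespace-separated words
        # (words, not tokenizer tokens) and repeat it for every batch entry,
        # so all sequences in the batch have the same length.
        prompt = ""
        counter = 0
        with open(book_path) as book_file:
            book_lines = book_file.readlines()
        for line in book_lines:
            for word in line.split():
                counter += 1
                prompt += word + " "
                if counter == prompt_size:
                    return [prompt] * args.batch_size
        # The book has fewer than `prompt_size` words: use the whole text instead of returning None.
        return [prompt] * args.batch_size

    book_ids = [
        2701,  # Moby Dick; Or, The Whale
        1513,  # Romeo and Juliet
        1342,  # Pride and Prejudice
    ]
    input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0]))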
else:
input_sentences = [
"DeepSpeed is a machine learning framework",
2 changes: 2 additions & 0 deletions examples/text-generation/utils.py
Expand Up @@ -344,6 +344,8 @@ def setup_generation_config(args, model, tokenizer):
assert generation_config.bucket_size > 0
generation_config.kv_cache_fp8 = args.kv_cache_fp8
generation_config.use_flash_attention = args.use_flash_attention
generation_config.flash_attention_recompute = args.flash_attention_recompute
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
return generation_config


Expand Up @@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to use flash attention optimization.
flash_attention_recompute (`bool`, *optional*):
Whether to enable recompute mode when using Habana flash attention.
flash_attention_causal_mask (`bool`, *optional*):
Whether to enable the causal mask when using Habana flash attention.
"""

def __init__(self, **kwargs):
Expand All @@ -49,4 +51,5 @@ def __init__(self, **kwargs):
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/utils.py
Expand Up @@ -708,6 +708,7 @@ def generate(
# determine whether flash attention needs to be used
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False
model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False
model_kwargs["use_fused_rope"] = False if not generation_config.use_fused_rope else True

if not self.config.is_encoder_decoder:
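
The config stores unset flags as `None` (via `kwargs.get(..., None)`), and `generate` coerces them to plain booleans before they reach the model. A minimal sketch of that round trip is shown below; `DemoGenerationConfig` and `flash_attention_model_kwargs` are hypothetical stand-ins, and `bool(x)` is used here as an equivalent spelling of `True if x else False`.

```python
class DemoGenerationConfig:
    """Hypothetical stand-in that stores optional flags the way the config above does."""

    def __init__(self, **kwargs):
        # Missing keys become None rather than False.
        self.use_flash_attention = kwargs.get("use_flash_attention", None)
        self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
        self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)


def flash_attention_model_kwargs(config: DemoGenerationConfig) -> dict:
    """Build the per-call kwargs, coercing unset (None) flags to False."""
    return {
        # Passed through unchanged, so an unset value stays None.
        "use_flash_attention": config.use_flash_attention,
        "flash_attention_recompute": bool(config.flash_attention_recompute),
        "flash_attention_causal_mask": bool(config.flash_attention_causal_mask),
    }


unset = flash_attention_model_kwargs(DemoGenerationConfig())
enabled = flash_attention_model_kwargs(
    DemoGenerationConfig(use_flash_attention=True, flash_attention_causal_mask=True)
)
```

The coercion matters because the model's `forward` signatures declare these parameters with a default of `False`, so passing `None` through would blur the difference between "unset" and "disabled".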
27 changes: 23 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Expand Up @@ -199,6 +199,7 @@ def pre_attn_forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
Expand All @@ -211,6 +212,7 @@ def pre_attn_forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
"""
bsz, q_len, _ = hidden_states.size()

Expand Down Expand Up @@ -289,10 +291,15 @@ def pre_attn_forward(
)
else:
# first token
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
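
The causal branch above drops the explicit `attention_mask` and relies on the kernel's causal flag instead. The pure-Python reference below is a sketch of why that is safe only for same-length inputs: without padding, an explicit lower-triangular mask and a causal flag give the same result, but a padded position can only be hidden by an explicit mask. It is a toy single-head attention on scalar features, not the `FusedSDPA` kernel, and `reference_attention` is a hypothetical name.

```python
import math

NEG_INF = float("-inf")


def softmax(row: list[float]) -> list[float]:
    """Numerically stable softmax; -inf entries receive zero weight."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [value / total for value in exps]


def reference_attention(queries, keys, values, mask=None, is_causal=False):
    """Single-head attention on scalar features, for illustration only.

    mask is an optional additive matrix (0.0 keeps a position, -inf hides it).
    is_causal hides every key position that comes after the query position.
    """
    outputs = []
    for q_idx, query in enumerate(queries):
        scores = []
        for k_idx, key in enumerate(keys):
            score = query * key
            if mask is not None:
                score += mask[q_idx][k_idx]
            if is_causal and k_idx > q_idx:
                score = NEG_INF
            scores.append(score)
        weights = softmax(scores)
        outputs.append(sum(w * v for w, v in zip(weights, values)))
    return outputs


queries = [0.5, -1.0, 2.0, 0.25]
keys = [1.0, 0.5, -0.5, 2.0]
values = [1.0, 2.0, 3.0, 4.0]
n = len(queries)

# Explicit lower-triangular additive mask, as a dense attention mask would encode it.
causal_mask = [[0.0 if k <= q else NEG_INF for k in range(n)] for q in range(n)]

with_mask = reference_attention(queries, keys, values, mask=causal_mask)
with_flag = reference_attention(queries, keys, values, is_causal=True)

# Without padding the two formulations agree.
assert all(math.isclose(a, b) for a, b in zip(with_mask, with_flag))

# With left padding, the mask must also hide the pad position (key 0 here);
# the causal flag alone cannot express that, so the results differ.
padded_mask = [[NEG_INF if k == 0 else causal_mask[q][k] for k in range(n)] for q in range(n)]
# Query 0 is itself padding, so keep its own position visible to avoid an all -inf row.
padded_mask[0][0] = 0.0
with_padding_mask = reference_attention(queries, keys, values, mask=padded_mask)
assert not math.isclose(with_padding_mask[1], with_flag[1])
```

This is the same constraint the author summarizes in the review thread: without special handling of padding tokens, causal mode needs every sequence in the batch to have the same length.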
Expand Down Expand Up @@ -424,6 +431,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
Expand All @@ -435,6 +443,7 @@ def forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
"""
residual = hidden_states
output_pre_attn, self_attn_weights, present_key_value = self.pre_attn(
Expand All @@ -449,6 +458,7 @@ def forward(
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
Expand Down Expand Up @@ -479,6 +489,7 @@ def pre_attn(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
Expand All @@ -495,6 +506,7 @@ def pre_attn(
reuse_cache,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
Expand Down Expand Up @@ -545,6 +557,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
Expand All @@ -556,6 +569,7 @@ def forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -637,6 +651,7 @@ def custom_forward(*inputs):
attn_softmax_bf16=attn_softmax_bf16,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
use_fused_rope=use_fused_rope,
)

Expand All @@ -658,6 +673,7 @@ def custom_forward(*inputs):
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
Expand Down Expand Up @@ -727,6 +743,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
Expand All @@ -751,6 +768,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
Expand Down Expand Up @@ -838,6 +856,7 @@ def prepare_inputs_for_generation(
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"cache_idx": kwargs.get("cache_idx"),
}
)
4 changes: 4 additions & 0 deletions optimum/habana/transformers/trainer.py
Expand Up @@ -874,6 +874,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
inputs["use_flash_attention"] = True
if self.model.generation_config.flash_attention_recompute:
inputs["flash_attention_recompute"] = True
if self.model.generation_config.flash_attention_causal_mask:
inputs["flash_attention_causal_mask"] = True
if not self.model.generation_config.use_fused_rope:
inputs["use_fused_rope"] = False

Expand Down Expand Up @@ -1628,6 +1630,8 @@ def evaluation_loop(
inputs["use_flash_attention"] = True
if self.model.generation_config.flash_attention_recompute:
inputs["flash_attention_recompute"] = True
if self.model.generation_config.flash_attention_causal_mask:
inputs["flash_attention_causal_mask"] = True
if not self.model.generation_config.use_fused_rope:
inputs["use_fused_rope"] = False
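
Both trainer hunks repeat the same pattern of adding a key to `inputs` only when the corresponding flag is enabled. One way to express that pattern once is sketched below. `inject_flash_attention_inputs` is a hypothetical helper that does not exist in `trainer.py`, and the sketch assumes the three checks are independent of each other, which the flattened diff does not show.

```python
from types import SimpleNamespace

FLASH_ATTENTION_FLAGS = (
    "use_flash_attention",
    "flash_attention_recompute",
    "flash_attention_causal_mask",
)


def inject_flash_attention_inputs(inputs: dict, generation_config) -> dict:
    """Add a flash-attention key to the model inputs for each flag that is enabled.

    Hypothetical helper mirroring the repeated `if ...: inputs[...] = True` pattern.
    Flags that are False, None, or missing leave `inputs` untouched.
    """
    for name in FLASH_ATTENTION_FLAGS:
        if getattr(generation_config, name, None):
            inputs[name] = True
    return inputs


config = SimpleNamespace(
    use_flash_attention=True,
    flash_attention_recompute=None,
    flash_attention_causal_mask=True,
)
batch_inputs = inject_flash_attention_inputs({"input_ids": [1, 2, 3]}, config)
```

Keys are only added, never set to `False`, so the model's own defaults apply whenever a flag is unset.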
