Skip to content
Merged
3 changes: 2 additions & 1 deletion examples/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--lora_rank 4 \
--lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
--validation_split_percentage 4 \
--use_flash_attention True
--use_flash_attention True \
--flash_attention_causal_mask True
```

- Multi-card finetuning of Llama2-70B with FSDP and LoRA:
Expand Down
12 changes: 11 additions & 1 deletion examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ class ModelArguments:
)
},
)
flash_attention_causal_mask: bool = field(
default=False,
metadata={
"help": (
"Whether to enable causal mask in Habana flash attention for fine-tuning."
" It is applicable only when use_flash_attention is True.",
)
},
)
use_fused_rope: bool = field(
default=True,
metadata={
Expand Down Expand Up @@ -547,7 +556,8 @@ def main():
if model_args.use_flash_attention:
model.generation_config.use_flash_attention = True
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
if model_args.use_fused_rope is False:
model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask
if not model_args.use_fused_rope:
model.generation_config.use_fused_rope = False

if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None:
Expand Down
24 changes: 24 additions & 0 deletions examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,30 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
```
`--fp8` is required to enable quantization in fp8.

### Using Habana Flash Attention

Habana Flash Attention addresses large sequence lengths on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes.

Below example uses `flash_attention_recompute` mode in order to reduce memory consumption on prompt stage. Additionally since all sequences in a batch are of the same length it uses `flash_attention_causal_mask` which will further improve performance by taking advantage of specific lower-diagonal shape of inputs to softmax operation.

```bash
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--use_hpu_graphs \
--use_kv_cache \
--reuse_cache \
--trim_logits \
--attn_softmax_bf16 \
--max_input_tokens 31744 \
--max_new_tokens 1024 \
--batch_size=12 \
--use_flash_attention \
--flash_attention_recompute \
--flash_attention_causal_mask \
--book_source
```

For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).

## Language Model Evaluation Harness

Expand Down
54 changes: 54 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,21 @@ def setup_parser(parser):
action="store_true",
help="Whether to enable Habana Flash Attention, provided that the model supports it.",
)
parser.add_argument(
"--flash_attention_recompute",
action="store_true",
help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.",
)
parser.add_argument(
"--flash_attention_causal_mask",
action="store_true",
help="Whether to enable Habana Flash Attention in causal mode on first token generation.",
)
parser.add_argument(
"--book_source",
action="store_true",
help="Whether to use project Guttenberg books data as input. Usefull for testing large sequence lenghts.",
)
parser.add_argument(
"--torch_compile",
action="store_true",
Expand Down Expand Up @@ -272,6 +287,45 @@ def main():
# Benchmark over the prompts below
if args.prompt:
input_sentences = args.prompt
elif args.book_source:

def download_book(book_id):
import os

import requests

url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt"
response = requests.get(url)
if response.status_code == 200:
pid = os.getpid()
save_path = f"/tmp/{book_id}_{pid}.txt"
with open(save_path, "wb") as file:
file.write(response.content)
print(f"Book downloaded and saved to: {save_path}")
return save_path
else:
print("Failed to download book! Exiting...")
import sys

sys.exit()

def assemble_prompt(prompt_size, book_path):
prompt = ""
counter = 0
book_lines = open(book_path).readlines()
for line in book_lines:
for word in line.split():
counter += 1
prompt += word + " "
if counter == prompt_size:
return [prompt] * args.batch_size

book_ids = [
2701, # Moby Dick; Or, The Whale
1513, # Romeo and Juliet
1342, # Pride and Prejudice
]
input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0]))
Comment thread
regisss marked this conversation as resolved.
else:
input_sentences = [
"DeepSpeed is a machine learning framework",
Expand Down
2 changes: 2 additions & 0 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ def setup_generation_config(args, model, tokenizer):
if generation_config.reduce_recompile:
assert generation_config.bucket_size > 0
generation_config.use_flash_attention = args.use_flash_attention
generation_config.flash_attention_recompute = args.flash_attention_recompute
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
return generation_config


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to use flash attention optimization.
flash_attention_recompute (`bool`, *optional*):
Whether to enable recompute if use Habana flash attention.
flash_attention_causal_mask (`bool`, *optional*):
Whether to enable causal_mask if use Habana flash attention.
"""

def __init__(self, **kwargs):
Expand All @@ -48,4 +50,5 @@ def __init__(self, **kwargs):
self.reduce_recompile = kwargs.get("reduce_recompile", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
2 changes: 2 additions & 0 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,8 @@ def generate(
# determine whether flash attention needs to be used
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False
model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[-1]
if not generation_config.static_shapes and generation_config.max_new_tokens is not None:
Expand Down
27 changes: 23 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def pre_attn_forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
Expand All @@ -325,6 +326,7 @@ def pre_attn_forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
"""
bsz, q_len, _ = hidden_states.size()

Expand Down Expand Up @@ -408,10 +410,15 @@ def pre_attn_forward(
)
else:
# first token
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
Expand Down Expand Up @@ -498,6 +505,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
Expand All @@ -509,6 +517,7 @@ def forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
"""
if "padding_mask" in kwargs:
warnings.warn(
Expand All @@ -529,6 +538,7 @@ def forward(
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
**kwargs,
)
Expand Down Expand Up @@ -560,6 +570,7 @@ def pre_attn(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
Expand All @@ -576,6 +587,7 @@ def pre_attn(
reuse_cache,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
cache_idx=cache_idx,
)
return hidden_states, attn_weights, present_key_value
Expand Down Expand Up @@ -668,6 +680,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
Expand All @@ -679,6 +692,7 @@ def forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
- add new arg lazy_mode
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Expand Down Expand Up @@ -778,6 +792,7 @@ def forward(
False,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
)
else:
layer_outputs = decoder_layer(
Expand All @@ -793,6 +808,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)
hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -864,6 +880,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
Expand Down Expand Up @@ -893,6 +910,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
)
Expand Down Expand Up @@ -1027,6 +1045,7 @@ def prepare_inputs_for_generation(
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
}
Expand Down
4 changes: 4 additions & 0 deletions optimum/habana/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
inputs["use_flash_attention"] = True
if self.model.generation_config.flash_attention_recompute:
inputs["flash_attention_recompute"] = True
if self.model.generation_config.flash_attention_causal_mask:
inputs["flash_attention_causal_mask"] = True

# TODO: keep syncs for fast DDP?
with self.accelerator.accumulate(model):
Expand Down Expand Up @@ -1806,6 +1808,8 @@ def evaluation_loop(
inputs["use_flash_attention"] = True
if self.model.generation_config.flash_attention_recompute:
inputs["flash_attention_recompute"] = True
if self.model.generation_config.flash_attention_causal_mask:
inputs["flash_attention_causal_mask"] = True

# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
Expand Down