-
Notifications
You must be signed in to change notification settings - Fork 20
Enable Flash Attention in recompute and causal modes #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ff27bfa
363fc7c
d0d7403
e20451f
347a0bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -232,6 +232,21 @@ def setup_parser(parser): | |
| action="store_true", | ||
| help="Whether to enable Habana Flash Attention, provided that the model supports it.", | ||
| ) | ||
| parser.add_argument( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems as a counter intuitive argument for inferencing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there is no performance penalty and memory is also saved then we can internally pass it as True for 1st token when flash attention is used.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed offline. Here is summary: |
||
| "--flash_attention_recompute", | ||
| action="store_true", | ||
| help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.", | ||
| ) | ||
| parser.add_argument( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we forcefully set this to True when batch size is 1 and when use_flash_attention is passed. We can add it to help text that this will be taken care.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed offline. Here is summary: |
||
| "--flash_attention_causal_mask", | ||
| action="store_true", | ||
| help="Whether to enable Habana Flash Attention in causal mode on first token generation.", | ||
| ) | ||
| parser.add_argument( | ||
| "--book_source", | ||
| action="store_true", | ||
| help="Whether to use project Guttenberg books data as input. Usefull for testing large sequence lenghts.", | ||
| ) | ||
| parser.add_argument( | ||
| "--torch_compile", | ||
| action="store_true", | ||
|
|
@@ -271,6 +286,45 @@ def main(): | |
| # Benchmark over the prompts below | ||
| if args.prompt: | ||
| input_sentences = args.prompt | ||
| elif args.book_source: | ||
|
|
||
| def download_book(book_id): | ||
| import os | ||
|
|
||
| import requests | ||
|
|
||
| url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt" | ||
| response = requests.get(url) | ||
| if response.status_code == 200: | ||
| pid = os.getpid() | ||
| save_path = f"/tmp/{book_id}_{pid}.txt" | ||
| with open(save_path, "wb") as file: | ||
| file.write(response.content) | ||
| print(f"Book downloaded and saved to: {save_path}") | ||
| return save_path | ||
| else: | ||
| print("Failed to download book! Exiting...") | ||
| import sys | ||
|
|
||
| sys.exit() | ||
|
|
||
| def assemble_prompt(prompt_size, book_path): | ||
| prompt = "" | ||
| counter = 0 | ||
| book_lines = open(book_path).readlines() | ||
| for line in book_lines: | ||
| for word in line.split(): | ||
| counter += 1 | ||
| prompt += word + " " | ||
| if counter == prompt_size: | ||
| return [prompt] * args.batch_size | ||
|
|
||
| book_ids = [ | ||
| 2701, # Moby Dick; Or, The Whale | ||
| 1513, # Romeo and Juliet | ||
| 1342, # Pride and Prejudice | ||
| ] | ||
| input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0])) | ||
| else: | ||
| input_sentences = [ | ||
| "DeepSpeed is a machine learning framework", | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.