Split the graphs to run with flash_attention on 1x #75

Merged: 3 commits merged into HabanaAI:habana-main from kalyanjk:decoder_mark_step on Mar 4, 2024

Conversation

@kalyanjk

With flash attention enabled at larger batch sizes, the recipe's ARC HBM memory size exceeds the QueueComputeScal ARC HBM memory. Hence, split the graph on 1x (single card).

@kalyanjk requested a review from a user on February 26, 2024 17:07

for layer_idx, decoder_layer in enumerate(self.layers):
    if torch.distributed.is_initialized() == False:
        htcore.mark_step()
Collaborator

@kalyanjk what is the impact for input/output cases that did not run into OOM? Should we add a command-line argument in text-generation?

@kalyanjk, why only mark_step() for 1x?

Author

For 8x, mark_step will be introduced through a collective call.

Author

> @kalyanjk what is the impact for input/output cases that did not run into OOM? Should we add a command-line argument in text-generation?

The issue is not OOM. The real issue is that the recipe size is too large and the compilation time is too high.

Please update the condition as below:
if lazy_mode and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1):
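
For illustration, the suggested condition can be written as a small standalone predicate. This is a sketch under assumptions, not the merged code: `should_mark_step` is a hypothetical helper, and the distributed state is passed in as plain arguments (standing in for `torch.distributed.is_initialized()` and `torch.distributed.get_world_size()`) so the logic can be checked without a Gaudi device.

```python
def should_mark_step(lazy_mode: bool, dist_initialized: bool, world_size: int = 1) -> bool:
    """Hypothetical helper mirroring the suggested condition.

    lazy_mode        -- whether the model runs in lazy mode
    dist_initialized -- stand-in for torch.distributed.is_initialized()
    world_size       -- stand-in for torch.distributed.get_world_size();
                        only meaningful when dist_initialized is True
    """
    # Single card: either no process group exists, or it has exactly one rank.
    single_card = (not dist_initialized) or world_size == 1
    # The explicit graph split is only requested in lazy mode on a single card.
    return lazy_mode and single_card


# The four cases discussed in the thread.
single_card_no_dist = should_mark_step(True, False)
single_card_with_dist = should_mark_step(True, True, 1)
multi_card = should_mark_step(True, True, 8)
not_lazy = should_mark_step(False, False)
```

In the suggested condition itself, the `or` short-circuits, so `get_world_size()` is only evaluated once the process group is initialized.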

@puneeshkhanna

@kalyanjk - we can abandon this PR. I have handled the change in #65.
This also helps 8x inference.
I'm checking 1x perf results too.
I also need to check the fine-tuning script once.

@puneeshkhanna

Wait, we should not put the mark step after the start of the loop. It will create more graphs and lower perf.

@kalyanjk
Author

> Wait, we should not put the mark step after the start of the loop. It will create more graphs and lower perf.

@puneeshkhanna
On G3, we were seeing good perf with mark_step inside the for loop. With mark_step outside the for loop, we are not able to run on a single card. This issue is also present on G2.
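
The two placements being compared can be sketched without Habana software. In this illustration, `mark_step` is an ordinary callable standing in for `htcore.mark_step()`, and `run_decoder_layers`, `split_per_layer`, and the toy layers are hypothetical names, not code from this PR. The sketch only counts how many graph breaks each placement requests; it says nothing about recipe size or performance.

```python
from typing import Callable, List, Sequence


def run_decoder_layers(
    layers: Sequence[Callable[[int], int]],
    hidden: int,
    mark_step: Callable[[], None],
    split_per_layer: bool,
) -> int:
    """Run toy layers in order, optionally requesting a graph break per layer."""
    for layer in layers:
        if split_per_layer:
            # Placement from this PR: one break at the start of every iteration,
            # so each decoder layer would end up in its own (smaller) graph.
            mark_step()
        hidden = layer(hidden)
    return hidden


# Record every requested graph break instead of touching a real device.
breaks: List[int] = []


def fake_mark_step() -> None:
    breaks.append(1)


# Four toy "decoder layers" that each add 1 to the hidden value.
toy_layers = [lambda x: x + 1 for _ in range(4)]

out_split = run_decoder_layers(toy_layers, 0, fake_mark_step, split_per_layer=True)
splits_requested = len(breaks)

breaks.clear()
out_single = run_decoder_layers(toy_layers, 0, fake_mark_step, split_per_layer=False)
no_splits_requested = len(breaks)
```

Both runs compute the same result; the only difference is one requested break per layer versus none inside the loop, which is the trade-off between more, smaller graphs and a single large one discussed above.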

@msinnha1 left a comment

Verified the change; it is required for faster recipe compilation.

_gaudi_prepare_4d_causal_attention_mask,
)

import habana_frameworks.torch.core as htcore

If you rebase to the latest, this htcore import is not required, as it is already part of PR #65.

@msinnha1 left a comment

lgtm

@ghost merged commit eec5b3f into HabanaAI:habana-main on Mar 4, 2024
astachowiczhabana pushed six commits that referenced this pull request (Apr 5, 2024, twice; Apr 19, 2024; Apr 22, 2024; Apr 24, 2024, twice), each with the same message:
* Split the graphs to run with flash_attention on 1x

* Added lazy_mode check and removed additional htcore import

---------

Co-authored-by: Kalyan <kkumar@habana.ai>

@kalyanjk
Author

This PR solves the actual issue #126

@astachowiczhabana commented on Jun 12, 2024

huggingface#875

@kalyanjk deleted the decoder_mark_step branch on July 5, 2024 11:47
This pull request was closed.
7 participants