Split the graphs to run with flash_attention on 1x #75
for layer_idx, decoder_layer in enumerate(self.layers):
    if torch.distributed.is_initialized() == False:
        htcore.mark_step()
@kalyanjk what is the impact on input/output sizes that do not cause OOM? Should we add a command-line argument for this in text-generation?
@kalyanjk, why call mark_step() only for 1x?
For 8x, a mark_step is already introduced through the collective call.
> @kalyanjk what is the impact on input/output sizes that do not cause OOM? Should we add a command-line argument for this in text-generation?
The issue is not OOM. The real issue is that the recipe size is too large and the compilation time is too high.
Please update as below
if lazy_mode and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1):
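The suggested condition can be read as a small predicate: break the graph explicitly only in lazy mode, and only when no multi-card collective will do it. The sketch below is a hypothetical helper, not code from this PR; the name `needs_explicit_mark_step` and the choice to pass the two distributed queries as callables are assumptions made so the logic can be checked without a device.

```python
from typing import Callable


def needs_explicit_mark_step(
    lazy_mode: bool,
    is_initialized: Callable[[], bool],
    get_world_size: Callable[[], int],
) -> bool:
    """Hypothetical predicate mirroring the condition suggested in the review."""
    if not lazy_mode:
        # Outside lazy mode there is no accumulated graph to split.
        return False
    # Short-circuit: get_world_size() is only consulted when a process
    # group has been initialized, matching the order in the suggested line.
    return (not is_initialized()) or get_world_size() == 1
```

In model code this would presumably be called as `needs_explicit_mark_step(lazy_mode, torch.distributed.is_initialized, torch.distributed.get_world_size)`, with `htcore.mark_step()` invoked when it returns `True`.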
Wait, we should not put a mark_step right after the start of the loop. It will create more graphs, and performance will be lower.
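The trade-off discussed here (smaller graphs versus more graphs) can be illustrated with a toy model. This is only a sketch under stated assumptions: `ToyLazyRecorder` and `run_decoder` are hypothetical names, the "ops" are just strings, and nothing here reflects how the Habana bridge actually builds or caches recipes.

```python
from typing import List


class ToyLazyRecorder:
    """Toy stand-in for a lazy backend: ops accumulate until mark_step()."""

    def __init__(self) -> None:
        self.pending: List[str] = []
        self.graphs: List[List[str]] = []

    def record(self, op: str) -> None:
        self.pending.append(op)

    def mark_step(self) -> None:
        # Flush accumulated ops as one graph; do nothing if none are pending.
        if self.pending:
            self.graphs.append(self.pending)
            self.pending = []


def run_decoder(num_layers: int, split_per_layer: bool) -> List[int]:
    """Return the op count of each graph produced by one toy forward pass."""
    rec = ToyLazyRecorder()
    rec.record("embed")
    for layer_idx in range(num_layers):
        if split_per_layer:
            rec.mark_step()  # break the graph at the top of every layer
        rec.record(f"attn_{layer_idx}")
        rec.record(f"mlp_{layer_idx}")
    rec.mark_step()  # final flush at the end of the step
    return [len(g) for g in rec.graphs]


print(run_decoder(4, split_per_layer=False))  # [9]
print(run_decoder(4, split_per_layer=True))   # [1, 2, 2, 2, 2]
```

In this toy model, splitting per layer turns one 9-op graph into five graphs of at most 2 ops each: each graph is smaller, which is the motivation given for the change, but there are more of them, which is the cost raised in the comment above.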
msinnha1 left a comment
Verified the change; it is required for faster recipe compilation.
    _gaudi_prepare_4d_causal_attention_mask,
)

import habana_frameworks.torch.core as htcore
If you rebase to the latest, this htcore import is no longer required, as it is part of PR #65.
* Split the graphs to run with flash_attention on 1x
* Added lazy_mode check and removed additional htcore import
---------
Co-authored-by: Kalyan <kkumar@habana.ai>
This PR solves the actual issue, #126.
With flash attention enabled for larger batch sizes, the recipe arc HBM memory size exceeds the QueueComputeScal arc HBM memory. Hence, the graph is split on 1x.