From d65f11af9b2229cefcd050ec8d7fe158b3fb6a4b Mon Sep 17 00:00:00 2001 From: Kalyan Kumar Date: Thu, 14 Mar 2024 11:56:38 +0530 Subject: [PATCH] Added additional check to run with distributed enabled and world_size=1 (#96) * Added additional check to run with distributed enabled and world_size = 1 * Reduce the number of graph splits to avoid memory allocation error for 1x LLAMA1_7b_ft --------- Co-authored-by: Kalyan --- optimum/habana/transformers/models/llama/modeling_llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index dbb8b18f4e..2bd9001cc9 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -696,7 +696,8 @@ def forward( htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): - if lazy_mode and torch.distributed.is_initialized() == False: + if lazy_mode and use_flash_attention and \ + (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1): htcore.mark_step() if output_hidden_states: