From f5690e04e2212113ec7b84a29c84653b4a51bb28 Mon Sep 17 00:00:00 2001 From: Nir Sonnenschein Date: Wed, 27 Sep 2023 09:56:46 +0300 Subject: [PATCH 1/2] Add support of tp_size != WORLD_SIZE to bloom Add suport for cases where tp_size != world size we encountered a case during develoment of DS chat step 3 where the tp size for inference needs to be different than the world size. The code in modeling_bloom.py assumes that tp size is always equal to world size in inference and this change allows to configure it differently. --- .../models/bloom/modeling_bloom.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/models/bloom/modeling_bloom.py b/optimum/habana/transformers/models/bloom/modeling_bloom.py index b76aa05a64..385d91f3dc 100644 --- a/optimum/habana/transformers/models/bloom/modeling_bloom.py +++ b/optimum/habana/transformers/models/bloom/modeling_bloom.py @@ -85,8 +85,13 @@ def gaudi_bloom_build_alibi_tensor( ).unsqueeze(0).expand(num_heads, -1, -1) # Select the part of the tensor that corresponds to our tensor parallel index. - tp_world_size = int(os.environ.get("WORLD_SIZE", 1)) - tp_index = int(os.environ.get("RANK", 0)) + # if inference_tp_size is set use it instead of world size + world = int(os.environ.get("WORLD_SIZE", 1)) + tp_world_size = GaudiBloomForCausalLM.inference_tp_size if GaudiBloomForCausalLM.inference_tp_size else world + tp_index = 0 # if world size == 1 ignore rank and use 0 (for cases where WORLD_SIZE is not equal to tp size) + if tp_world_size > 1: + tp_index = int(os.environ.get("RANK", 0)) + alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index] alibi = alibi.repeat(batch_size, 1, 1) @@ -282,7 +287,8 @@ def gaudi_bloom_convert_to_standard_cache( if training: num_heads = batch_size_times_num_heads // batch_size else: - tp_world_size = int(os.environ.get("WORLD_SIZE", 1)) + world = int(os.environ.get("WORLD_SIZE", 1)) + tp_world_size = GaudiBloomForCausalLM.inference_tp_size if GaudiBloomForCausalLM.inference_tp_size else world num_heads = self.config.n_head // tp_world_size batch_size = batch_size_times_num_heads // num_heads # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] @@ -462,6 +468,13 @@ def custom_forward(*inputs): class GaudiBloomForCausalLM(BloomForCausalLM): + inference_tp_size = None + + def set_tp_for_inteference(tp_for_inference): + world = int(os.environ.get("WORLD_SIZE", 1)) + assert tp_for_inference == 1 or tp_for_inference == world,"only setting 1 (no tp) or world size is supported" + GaudiBloomForCausalLM.inference_tp_size = tp_for_inference + def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, From 636f1a8916b104c17acaa9511d7f3e6a0be30afe Mon Sep 17 00:00:00 2001 From: Nir Sonnenschein Date: Mon, 23 Oct 2023 10:13:34 +0300 Subject: [PATCH 2/2] Update optimum/habana/transformers/models/bloom/modeling_bloom.py Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- optimum/habana/transformers/models/bloom/modeling_bloom.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/bloom/modeling_bloom.py b/optimum/habana/transformers/models/bloom/modeling_bloom.py index 385d91f3dc..922675183c 100644 --- a/optimum/habana/transformers/models/bloom/modeling_bloom.py +++ b/optimum/habana/transformers/models/bloom/modeling_bloom.py @@ -470,9 +470,9 @@ def custom_forward(*inputs): class GaudiBloomForCausalLM(BloomForCausalLM): inference_tp_size = None - def set_tp_for_inteference(tp_for_inference): + def set_tp_for_inference(tp_for_inference: int): world = int(os.environ.get("WORLD_SIZE", 1)) - assert tp_for_inference == 1 or tp_for_inference == world,"only setting 1 (no tp) or world size is supported" + assert tp_for_inference == 1 or tp_for_inference == world, "only setting 1 (no tp) or world size is supported" GaudiBloomForCausalLM.inference_tp_size = tp_for_inference def prepare_inputs_for_generation(