diff --git a/optimum/habana/transformers/models/bloom/modeling_bloom.py b/optimum/habana/transformers/models/bloom/modeling_bloom.py index b76aa05a64..922675183c 100644 --- a/optimum/habana/transformers/models/bloom/modeling_bloom.py +++ b/optimum/habana/transformers/models/bloom/modeling_bloom.py @@ -85,8 +85,13 @@ def gaudi_bloom_build_alibi_tensor( ).unsqueeze(0).expand(num_heads, -1, -1) # Select the part of the tensor that corresponds to our tensor parallel index. - tp_world_size = int(os.environ.get("WORLD_SIZE", 1)) - tp_index = int(os.environ.get("RANK", 0)) + # if inference_tp_size is set use it instead of world size + world = int(os.environ.get("WORLD_SIZE", 1)) + tp_world_size = GaudiBloomForCausalLM.inference_tp_size if GaudiBloomForCausalLM.inference_tp_size else world + tp_index = 0 # if world size == 1 ignore rank and use 0 (for cases where WORLD_SIZE is not equal to tp size) + if tp_world_size > 1: + tp_index = int(os.environ.get("RANK", 0)) + alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index] alibi = alibi.repeat(batch_size, 1, 1) @@ -282,7 +287,8 @@ def gaudi_bloom_convert_to_standard_cache( if training: num_heads = batch_size_times_num_heads // batch_size else: - tp_world_size = int(os.environ.get("WORLD_SIZE", 1)) + world = int(os.environ.get("WORLD_SIZE", 1)) + tp_world_size = GaudiBloomForCausalLM.inference_tp_size if GaudiBloomForCausalLM.inference_tp_size else world num_heads = self.config.n_head // tp_world_size batch_size = batch_size_times_num_heads // num_heads # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] @@ -462,6 +468,13 @@ def custom_forward(*inputs): class GaudiBloomForCausalLM(BloomForCausalLM): + inference_tp_size = None + + def set_tp_for_inference(tp_for_inference: int): + world = int(os.environ.get("WORLD_SIZE", 1)) + assert tp_for_inference == 1 or tp_for_inference == world, "only setting 1 (no tp) or world size is supported" + GaudiBloomForCausalLM.inference_tp_size = tp_for_inference + def prepare_inputs_for_generation( self, input_ids: torch.LongTensor,