diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 861d027a26..a825476d1e 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -240,7 +240,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): model = deepspeed.init_inference(model, **ds_inference_kwargs) model = model.module - if model.config.model_type == "llama" or "falcon": + if model.config.model_type in ["llama", "falcon"]: patch_scoped_linear_all_reduce(model) if args.quant_config: diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index a329ec1ac0..8f7ed7b168 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -65,9 +65,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: out = F.dropout(x, p=prob, training=training) if training: out = residual + out + return out else: - out.add_(residual) - return out + residual.add_(out) + return residual def apply_customized_rope(q, k, cos, sin, position_ids): @@ -536,21 +537,25 @@ def forward( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) residual = hidden_states - hidden_states, present, attn_scores, attention_layernorm_out, mlp_layernorm_out = ( - self.pre_attn( # layernorm + attention before AllReduce - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, - reuse_cache=reuse_cache, - cache_idx=cache_idx, - **kwargs, - ) + ( + hidden_states, + present, + attn_scores, + attention_layernorm_out, + mlp_layernorm_out, + ) = self.pre_attn( # layernorm + attention before AllReduce + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + **kwargs, ) self.self_attention.attention_all_reduce(hidden_states)