Skip to content
2 changes: 1 addition & 1 deletion examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):

model = deepspeed.init_inference(model, **ds_inference_kwargs)
model = model.module
if model.config.model_type == "llama" or "falcon":
if model.config.model_type in ["llama", "falcon"]:
patch_scoped_linear_all_reduce(model)

if args.quant_config:
Expand Down
39 changes: 22 additions & 17 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
out = F.dropout(x, p=prob, training=training)
if training:
out = residual + out
return out
else:
out.add_(residual)
return out
residual.add_(out)
return residual


def apply_customized_rope(q, k, cos, sin, position_ids):
Expand Down Expand Up @@ -536,21 +537,25 @@ def forward(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
hidden_states, present, attn_scores, attention_layernorm_out, mlp_layernorm_out = (
self.pre_attn( # layernorm + attention before AllReduce
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
token_idx=token_idx,
reuse_cache=reuse_cache,
cache_idx=cache_idx,
**kwargs,
)
(
hidden_states,
present,
attn_scores,
attention_layernorm_out,
mlp_layernorm_out,
) = self.pre_attn( # layernorm + attention before AllReduce
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
token_idx=token_idx,
reuse_cache=reuse_cache,
cache_idx=cache_idx,
**kwargs,
)

self.self_attention.attention_all_reduce(hidden_states)
Expand Down