def patch_scoped_linear_all_reduce(model):
    """Swap DeepSpeed ``LinearAllreduce`` children for ``ScopedLinearAllReduce``.

    Walks the module tree rooted at ``model`` depth-first. Every direct child
    whose exact type is ``LinearAllreduce`` is replaced, in place on its
    parent, by a ``ScopedLinearAllReduce`` built from it. The original child
    is then visited recursively, whether or not it was replaced.

    Args:
        model: module whose ``named_children()`` are inspected and patched.
    """
    # Imported lazily so the module can be loaded without DeepSpeed installed.
    from deepspeed.module_inject.layers import LinearAllreduce

    from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce

    for child_name, child in model.named_children():
        # Exact type match on purpose: subclasses of LinearAllreduce are left alone.
        if type(child) is LinearAllreduce:
            setattr(model, child_name, ScopedLinearAllReduce(mod=child))
        patch_scoped_linear_all_reduce(child)
class Matmul(torch.nn.Module):
    """Parameter-free module whose forward pass is ``torch.matmul(x, y)``.

    Used by the attention layer as ``matmul_qk`` and ``matmul_av`` so the two
    matrix products exist as named submodules.
    NOTE(review): presumably this lets hooks or quantization tooling target
    them individually; confirm with the author.
    """

    def forward(self, x, y):
        """Return the matrix product of ``x`` and ``y``."""
        product = torch.matmul(x, y)
        return product
None self.past_value = None self.inp_seq_len = -1 @@ -126,7 +139,7 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): self.reorder(self.past_value, beam_idx, seq_length, head_dim) return (self.past_key.shape, self.past_value.shape) - def forward( + def pre_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -137,7 +150,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ): """ Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py The only differences are: @@ -209,7 +222,7 @@ def forward( key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups) value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.norm_factor + attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -232,7 +245,7 @@ def forward( query_states.dtype ) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = self.matmul_av(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( @@ -244,18 +257,57 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = 
def post_attn_forward(self, attn_output):
    """Finish the attention output projection after the all-reduce step.

    When ``self.o_proj`` is a ``ScopedLinearAllReduce``, the projection is
    split into matmul, all-reduce and a final step; this method applies the
    final step. For any other ``o_proj`` the input is returned untouched.

    Args:
        attn_output: attention output that has already been projected and
            all-reduced.

    Returns:
        The result of ``self.o_proj.post_all_reduce(attn_output)`` for a
        scoped projection, otherwise ``attn_output`` itself.
    """
    if self.o_proj.__class__ is ScopedLinearAllReduce:
        # post_all_reduce builds its result out of place (input + bias), so
        # the value must be returned; discarding it would drop the bias.
        return self.o_proj.post_all_reduce(attn_output)
    return attn_output
present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - attn_softmax_bf16=attn_softmax_bf16, - reuse_cache=reuse_cache, + output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + token_idx, + attn_softmax_bf16, + reuse_cache, ) - hidden_states = residual + hidden_states + self.self_attn.attention_all_reduce(output_pre_attn) + output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) + self.mlp.mlp_all_reduce(output_post_attn_pre_mlp) + output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp) - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) + outputs = (output_post_mlp,) if output_attentions: outputs += (self_attn_weights,) @@ -319,6 +364,48 @@ def forward( return outputs + def pre_attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + hidden_states = self.input_layernorm(hidden_states) + output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + token_idx, + attn_softmax_bf16, 
# Splitting DeepSpeed LinearAllReduce to three parts to avoid redundant memory consumption
class ScopedLinearAllReduce(torch.nn.Module):
    """A linear layer split into matmul, all-reduce and bias-add steps.

    The instance adopts the complete state of the wrapped module by copying
    its ``__dict__``, so attribute containers are shared with ``mod`` rather
    than duplicated. The wrapped module must provide ``weight``, ``bias`` and
    ``mp_group``.

    The three steps are meant to be called in order by the owner:
    ``forward`` -> ``all_reduce`` -> ``post_all_reduce``.
    """

    def __init__(self, mod, *args, **kwargs):
        """Build the wrapper from ``mod``.

        Args:
            mod: module whose attributes (including ``weight``, ``bias`` and
                ``mp_group``) are taken over.
            *args, **kwargs: accepted for call compatibility and ignored.
        """
        # Initialise the nn.Module bookkeeping first so the instance is valid
        # on its own; mod's state then overrides it.
        super().__init__()
        self.__dict__.update(mod.__dict__)

    def forward(self, input):
        """Return ``input @ weight^T`` with no reduction and no bias (pre all-reduce step)."""
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        return output

    def all_reduce(self, input):
        """Reduce ``input`` across ``self.mp_group``; does nothing when the group is ``None``.

        The result is not returned here; callers keep using the tensor they passed in.
        """
        if self.mp_group is not None:
            # Imported lazily so DeepSpeed is only required for distributed runs.
            from deepspeed import comm as dist

            dist.inference_all_reduce(input, group=self.mp_group)

    def post_all_reduce(self, input):
        """Return ``input + bias``, or ``input`` itself when the layer has no bias."""
        if self.bias is None:
            return input
        return input + self.bias