diff --git a/oslo/transformers/models/gpt2/modeling_gpt2.py b/oslo/transformers/models/gpt2/modeling_gpt2.py
index 8ca6b243..4a903ae9 100644
--- a/oslo/transformers/models/gpt2/modeling_gpt2.py
+++ b/oslo/transformers/models/gpt2/modeling_gpt2.py
@@ -47,59 +47,6 @@
 logger = logging.get_logger(__name__)
 
 
-class GPT2PreTrainedModel(PreTrainedModel, OsloModel):
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    config_class = GPT2Config
-    load_tf_weights = load_tf_weights_in_gpt2
-    base_model_prefix = "transformer"
-    is_parallelizable = True
-    supports_gradient_checkpointing = True
-
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
-    def _init_weights(self, module):
-        """Initialize the weights."""
-        if isinstance(module, (nn.Linear, onn.Conv1D)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
-        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
-        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
-        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
-        #
-        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
-        for name, p in module.named_parameters():
-            if "c_proj" in name and "weight" in name:
-                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                p.data.normal_(
-                    mean=0.0,
-                    std=(
-                        self.config.initializer_range
-                        / math.sqrt(2 * self.config.n_layer)
-                    ),
-                )
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, GPT2Model):
-            module.gradient_checkpointing = value
-
-
 class GPT2Attention(nn.Module):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
@@ -456,6 +403,59 @@ def forward(
         return outputs  # hidden_states, present, (attentions, cross_attentions)
 
 
+class GPT2PreTrainedModel(PreTrainedModel, OsloModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = GPT2Config
+    load_tf_weights = load_tf_weights_in_gpt2
+    base_model_prefix = "transformer"
+    is_parallelizable = True
+    supports_gradient_checkpointing = True
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, (nn.Linear, onn.Conv1D)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if "c_proj" in name and "weight" in name:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                p.data.normal_(
+                    mean=0.0,
+                    std=(
+                        self.config.initializer_range
+                        / math.sqrt(2 * self.config.n_layer)
+                    ),
+                )
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, GPT2Model):
+            module.gradient_checkpointing = value
+
+
 class GPT2Model(GPT2PreTrainedModel):
     _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
@@ -747,10 +747,12 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
+    ):
         token_type_ids = kwargs.get("token_type_ids", None)
-        # only last token for inputs_ids if past is defined in kwargs
-        if past:
+        # only last token for inputs_ids if past_key_values is defined in kwargs
+        if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
             if token_type_ids is not None:
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@@ -762,18 +764,27 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
+            if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs
 
     def forward(
         self,