141 changes: 76 additions & 65 deletions oslo/transformers/models/gpt2/modeling_gpt2.py
@@ -47,59 +47,6 @@
logger = logging.get_logger(__name__)


class GPT2PreTrainedModel(PreTrainedModel, OsloModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""

config_class = GPT2Config
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear, onn.Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if "c_proj" in name and "weight" in name:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(
mean=0.0,
std=(
self.config.initializer_range
/ math.sqrt(2 * self.config.n_layer)
),
)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPT2Model):
module.gradient_checkpointing = value


class GPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
@@ -456,6 +403,59 @@ def forward(
return outputs # hidden_states, present, (attentions, cross_attentions)


class GPT2PreTrainedModel(PreTrainedModel, OsloModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""

config_class = GPT2Config
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear, onn.Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if "c_proj" in name and "weight" in name:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(
mean=0.0,
std=(
self.config.initializer_range
/ math.sqrt(2 * self.config.n_layer)
),
)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPT2Model):
module.gradient_checkpointing = value
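
For reference, a minimal sketch (not part of this diff) of the standard deviation that the c_proj weights receive under the scaled-initialization comment above, assuming the GPT2Config defaults initializer_range=0.02 and n_layer=12:

import math

initializer_range = 0.02  # GPT2Config default
n_layer = 12              # GPT2Config default (GPT-2 small)

plain_std = initializer_range                            # used for most weights
scaled_std = initializer_range / math.sqrt(2 * n_layer)  # used for c_proj weights

print(f"plain init std:  {plain_std}")       # 0.02
print(f"c_proj init std: {scaled_std:.5f}")  # ~0.00408, shrunk by 1/sqrt(2N)

The divisor uses 2 * n_layer because, as the comment notes, each transformer block contributes two residual additions (attention and MLP), so the number of residual layers N is twice the block count.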


class GPT2Model(GPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]

@@ -747,10 +747,12 @@ def get_output_embeddings(self):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
# only last token for inputs_ids if past_key_values is defined in kwargs
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@@ -762,18 +764,27 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)
return model_inputs
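
A minimal, self-contained sketch (not the oslo code itself) that replays the cache and position handling of the method above, to show what reaches the forward pass at each generation step; the helper name _prepare and the token ids are illustrative only, and token_type_ids handling is omitted:

import torch

def _prepare(input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None):
    # Mirror of the branches above, trimmed to the parts this change touches.
    if past_key_values:
        input_ids = input_ids[:, -1].unsqueeze(-1)
    position_ids = None
    if attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -1].unsqueeze(-1)
    # inputs_embeds are only consumed on the first step, before a cache exists.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}
    model_inputs["position_ids"] = position_ids
    return model_inputs

# Step 1: no cache yet, so the full prompt (or its embeddings) is passed.
prompt = torch.tensor([[50256, 464, 3797]])
first = _prepare(prompt, attention_mask=torch.ones_like(prompt))
assert first["input_ids"].shape == (1, 3)

# Later steps: with a cache, only the newest token and its position survive.
ids = torch.tensor([[50256, 464, 3797, 2330]])
later = _prepare(ids, past_key_values=object(), attention_mask=torch.ones(1, 4, dtype=torch.long))
assert later["input_ids"].shape == (1, 1)
assert later["position_ids"].tolist() == [[3]]

During generation the real method additionally forwards past_key_values, use_cache, attention_mask, and token_type_ids, as shown in the diff above.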

def forward(
self,