Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 59 additions & 82 deletions src/transformers/models/modernbert/modeling_modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,53 +569,41 @@ class ModernBertPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = False

def _init_weights(self, module: nn.Module):
cutoff_factor = self.config.initializer_cutoff_factor
if cutoff_factor is None:
cutoff_factor = 3

def init_weight(module: nn.Module, std: float):
nn.init.trunc_normal_(
module.weight,
mean=0.0,
std=std,
a=-cutoff_factor * std,
b=cutoff_factor * std,
)

if isinstance(module, nn.Linear):
if module.bias is not None:
nn.init.zeros_(module.bias)

stds = {
def __init__(self, *args, **kwargs):
    """Build the model and cache the per-role initialization scales.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor. Afterwards ``self.init_stds`` maps a role name
    (``"in"``, ``"out"``, ``"embedding"``, ``"final_out"``) to the standard
    deviation used when initializing weights of that role.
    """
    super().__init__(*args, **kwargs)

    # Relies on ``self.config`` being available once the parent constructor
    # has run.
    base_std = self.config.initializer_range
    depth_scale = math.sqrt(2.0 * self.config.num_hidden_layers)

    self.init_stds = {
        "in": base_std,
        # Output projections are scaled down by sqrt(2 * num_hidden_layers).
        "out": base_std / depth_scale,
        "embedding": base_std,
        "final_out": self.config.hidden_size**-0.5,
    }

def _init_weights_trunc_normal(self, module: nn.Module, std: float):
cutoff_factor = self.config.initializer_cutoff_factor
if cutoff_factor is None:
cutoff_factor = 3

nn.init.trunc_normal_(
module.weight,
mean=0.0,
std=std,
a=-cutoff_factor * std,
b=cutoff_factor * std,
)

if isinstance(module, nn.Linear):
if module.bias is not None:
nn.init.zeros_(module.bias)

def _init_weights(self, module: nn.Module):
if isinstance(module, ModernBertEmbeddings):
init_weight(module.tok_embeddings, stds["embedding"])
self._init_weights_trunc_normal(module.tok_embeddings, self.init_stds["embedding"])
elif isinstance(module, ModernBertMLP):
init_weight(module.Wi, stds["in"])
init_weight(module.Wo, stds["out"])
self._init_weights_trunc_normal(module.Wi, self.init_stds["in"])
self._init_weights_trunc_normal(module.Wo, self.init_stds["out"])
elif isinstance(module, ModernBertAttention):
init_weight(module.Wqkv, stds["in"])
init_weight(module.Wo, stds["out"])
elif isinstance(module, ModernBertPredictionHead):
init_weight(module.dense, stds["out"])
elif isinstance(module, ModernBertForMaskedLM):
init_weight(module.decoder, stds["out"])
elif isinstance(
module,
(
ModernBertForSequenceClassification,
ModernBertForMultipleChoice,
ModernBertForTokenClassification,
ModernBertForQuestionAnswering,
),
):
init_weight(module.classifier, stds["final_out"])
self._init_weights_trunc_normal(module.Wqkv, self.init_stds["in"])
self._init_weights_trunc_normal(module.Wo, self.init_stds["out"])
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
if module.bias is not None:
Expand Down Expand Up @@ -967,6 +955,13 @@ def __init__(self, config: ModernBertConfig):
# Initialize weights and apply final processing
self.post_init()

def _init_weights(self, module: nn.Module):
    """Initialize ``module``, adding the masked-LM specific layers.

    The parent implementation runs first. Then, for a
    ``ModernBertForMaskedLM`` the ``decoder`` layer, or for a
    ``ModernBertPredictionHead`` the ``dense`` layer, is re-initialized
    with the ``"out"`` standard deviation. Other modules are left as the
    parent initialized them.
    """
    super()._init_weights(module)

    if isinstance(module, ModernBertForMaskedLM):
        target = module.decoder
    elif isinstance(module, ModernBertPredictionHead):
        target = module.dense
    else:
        # Nothing beyond the parent's handling for this module type.
        return

    self._init_weights_trunc_normal(target, self.init_stds["out"])

def get_output_embeddings(self):
    """Return the module stored in ``self.decoder``."""
    output_layer = self.decoder
    return output_layer

Expand Down Expand Up @@ -1090,12 +1085,7 @@ def forward(
)


@auto_docstring(
custom_intro="""
The ModernBert Model with a sequence classification head on top that performs pooling.
"""
)
class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
class ModernBertClassificationModel(ModernBertPreTrainedModel):
def __init__(self, config: ModernBertConfig):
super().__init__(config)
self.num_labels = config.num_labels
Expand All @@ -1104,11 +1094,29 @@ def __init__(self, config: ModernBertConfig):
self.model = ModernBertModel(config)
self.head = ModernBertPredictionHead(config)
self.drop = torch.nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.classifier = nn.Linear(config.hidden_size, self.output_dim)

# Initialize weights and apply final processing
self.post_init()

@property
def output_dim(self):
    """Width of the classifier output; here ``config.num_labels``."""
    cfg = self.config
    return cfg.num_labels

def _init_weights(self, module: nn.Module):
    """Initialize ``module``, adding the classification specific layers.

    The parent implementation runs first. Then, for a
    ``ModernBertClassificationModel`` the ``classifier`` layer is
    re-initialized with the ``"final_out"`` standard deviation, and for a
    ``ModernBertPredictionHead`` the ``dense`` layer is re-initialized with
    the ``"out"`` standard deviation. Other modules are left as the parent
    initialized them.
    """
    super()._init_weights(module)

    if isinstance(module, ModernBertClassificationModel):
        layer, std_key = module.classifier, "final_out"
    elif isinstance(module, ModernBertPredictionHead):
        layer, std_key = module.dense, "out"
    else:
        # Nothing beyond the parent's handling for this module type.
        return

    self._init_weights_trunc_normal(layer, self.init_stds[std_key])


@auto_docstring(
custom_intro="""
The ModernBert Model with a sequence classification head on top that performs pooling.
"""
)
class ModernBertForSequenceClassification(ModernBertClassificationModel):
@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -1232,19 +1240,7 @@ def forward(
The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
"""
)
class ModernBertForTokenClassification(ModernBertPreTrainedModel):
def __init__(self, config: ModernBertConfig):
super().__init__(config)
self.num_labels = config.num_labels

self.model = ModernBertModel(config)
self.head = ModernBertPredictionHead(config)
self.drop = torch.nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
self.post_init()

class ModernBertForTokenClassification(ModernBertClassificationModel):
@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -1323,18 +1319,7 @@ def forward(


@auto_docstring
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
def __init__(self, config: ModernBertConfig):
super().__init__(config)
self.num_labels = config.num_labels

self.model = ModernBertModel(config)
self.head = ModernBertPredictionHead(config)
self.drop = torch.nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.post_init()

class ModernBertForQuestionAnswering(ModernBertClassificationModel):
@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -1419,18 +1404,10 @@ def forward(
The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
"""
)
class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
def __init__(self, config: ModernBertConfig):
super().__init__(config)
self.config = config

self.model = ModernBertModel(config)
self.head = ModernBertPredictionHead(config)
self.drop = torch.nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, 1)

# Initialize weights and apply final processing
self.post_init()
class ModernBertForMultipleChoice(ModernBertClassificationModel):
@property
def output_dim(self):
    """Width of the classifier output; always 1 for this class."""
    single_logit = 1
    return single_logit

@auto_docstring
def forward(
Expand Down
Loading