Conversation

@gante
Contributor

@gante gante commented Aug 27, 2024

What does this PR do?

Step 2 of #32685 - Removes the GenerationMixin inheritance from PreTrainedModel. Instead, model classes with generative capabilities directly inherit GenerationMixin.

Why?

Currently, we have a circular dependency between PreTrainedModel and GenerationMixin:

  • PreTrainedModel 👈 GenerationMixin: PreTrainedModel has a can_generate() method, which depends on methods that exist in GenerationMixin. Depending on the value of can_generate(), it may hold a GenerationConfig object.
  • GenerationMixin 👈 PreTrainedModel: GenerationMixin needed to inspect the type of the model instance, to raise informative exceptions for the user. This was needed because ALL our models could call generate, but most of them didn't support it.

This PR breaks this circular dependency:

  1. GenerationMixin becomes a stand-alone class with no dependencies on PreTrainedModel. It is now a proper mixin: it may be used with other model base classes, if users desire to do so.
  2. PreTrainedModel doesn't inherit GenerationMixin. This means that non-generative models will become less bloated :)

What else can we improve as a result of this change?

  1. [added in this PR] can_generate() can be simplified: if a model is a subclass of GenerationMixin, then it can generate (see the sketch after this list)
  2. [added in this PR] no need to validate the model class in generate -- all GenerationMixin subclasses can call generate
  3. Because of 1., all the changes planned in #32685 (tracker: move prepare_inputs_for_generation into the generation mixin 🧹) become much simpler to implement (can_generate() no longer depends on prepare_inputs_for_generation -> easier to make structural changes there) 🤗
  4. Perhaps in the future, we can move the GenerationConfig instance to GenerationMixin, so that non-generative models don't hold a generation_config attribute.
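
For illustration, here is a minimal sketch of the simplified check described in point 1 -- a standalone helper, not the exact PreTrainedModel.can_generate implementation merged in this PR. The expected output assumes a transformers version in which the removal has fully landed:

from transformers import BertModel, GenerationMixin, LlamaForCausalLM


def can_generate(model_class: type) -> bool:
    # Sketch: a class can call generate() iff it opted in to GenerationMixin
    return issubclass(model_class, GenerationMixin)


print(can_generate(LlamaForCausalLM))  # True  -- causal LMs inherit GenerationMixin directly
print(can_generate(BertModel))         # False -- encoder-only models do not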

🚨🚨 Caveats 🚨🚨

The changes in this PR have no visible consequences in the following cases:
✅ A user loads a transformers model, like LlamaForCausalLM
✅ A user loads custom modeling code from the hub with our auto classes, like this example

However, there are breaking changes in the following situations:
❌ A user has custom code, inheriting PreTrainedModel, and wants to call generate
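
In that case, the custom class needs to opt in to GenerationMixin explicitly. A minimal sketch of the change, assuming the inheritance is removed as described above (the class name is just an illustration):

from transformers import GenerationMixin, PreTrainedModel

# Before this PR, generate() was available simply by inheriting PreTrainedModel:
# class MyModelForCausalLM(PreTrainedModel):
#     ...

# After this PR, a generative custom model must also inherit GenerationMixin:
class MyModelForCausalLM(PreTrainedModel, GenerationMixin):
    ...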

@gante gante changed the title from "Generation: deprecate GenerationMixin inherited by PreTrainedModel" to "Generation: PreTrainedModel no longer inherits GenerationMixin" on Aug 28, 2024
@gante gante changed the title from "Generation: PreTrainedModel no longer inherits GenerationMixin" to "Generation: PreTrainedModel no longer inherits GenerationMixin 🚨 🚨" on Aug 28, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante
Contributor Author

gante commented Aug 29, 2024

hold up, found a way to make it fully BC 💛

@KeshavSingh29
Contributor

@gante
I came across this exact issue: "A user has custom code, inheriting PreTrainedModel, and wants to call generate"

Basically I was trying to recreate a GPT-2 model by inheriting PreTrainedModel.
Here is my code:

from __future__ import annotations

import torch
import torch.nn as nn
from transformers import PretrainedConfig
from transformers import PreTrainedModel

from src.transformer_block.layer_norm import LayerNorm
from src.transformer_block.t_block import GPTTransformerBlock


class GPTConfig(PretrainedConfig):
    """
    Configuration class for GPT-2 model
    """

    model_type = "gpt_fast_llm"

    def __init__(
        self,
        vocab_size: int = 200019,
        context_len: int = 256,
        embedding_dim: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        drop_rate: float = 0.0,
        qkv_bias: bool = True,
        batch_size: int = 8,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.context_len = context_len
        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.drop_rate = drop_rate
        self.qkv_bias = qkv_bias
        self.batch_size = batch_size


class GPTModel(PreTrainedModel):
    """
    The base model for GPT-2 architecture
    Input:
        x : tensor of shape (batch_size, seq_len)
    Output:
        logits : tensor of shape (batch_size, seq_len, vocab_size)
    """

    config_class = GPTConfig

    def __init__(self, config: GPTConfig):
        super().__init__(config)
        self.token_embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_dim,
        )
        self.pos_embedding = nn.Embedding(
            num_embeddings=config.context_len,
            embedding_dim=config.embedding_dim,
        )
        self.drop_embedding = nn.Dropout(config.drop_rate)
        self.transformer_blocks = nn.Sequential(
            *[GPTTransformerBlock(config=config) for _ in range(config.n_layers)],
        )
        self.final_norm = LayerNorm(config.embedding_dim)
        self.out_head = nn.Linear(
            config.embedding_dim,
            config.vocab_size,
            bias=False,
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        batch_size, seq_len = input_ids.shape
        token_emb = self.token_embedding(input_ids)
        pos_emb = self.pos_embedding(torch.arange(seq_len, device=input_ids.device))
        x = token_emb + pos_emb
        x = self.drop_embedding(x)
        x = self.transformer_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)

        if labels is not None:
            loss = torch.nn.functional.cross_entropy(
                logits.flatten(0, 1),
                labels.flatten(),
            )
            return {"logits": logits, "loss": loss}
        return {"logits": logits}

I already have an out_head linear layer to map the hidden states to vocabulary logits. But when I call model.can_generate() it returns False, and I run into this error:

TypeError: The current model class (GPTModel) is not compatible with `.generate()`, as it doesn't have a language model head.

Is there a way to fix it?
