This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add way of skipping pretrained weights download #5172

Merged
merged 3 commits into main from transformer-no-load-weights on May 2, 2021

Conversation

@epwalsh (Member) commented on Apr 30, 2021

Fixes #4599.

Changes proposed in this pull request:

  • Adds a load_weights: bool (default = True) parameter to cached_transformers.get() and to all higher-level modules that call it, such as PretrainedTransformerEmbedder and PretrainedTransformerMismatchedEmbedder. Setting this parameter to False skips downloading and loading the pretrained transformer weights, so only the architecture is instantiated (see the sketch below). In particular, you can set it to False via the overrides parameter when loading an AllenNLP model/predictor from an archive to avoid an unnecessary download.
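
For instance, here is a minimal sketch of using the new parameter directly from Python. The import path and positional model_name argument follow the usual AllenNLP conventions and are not spelled out in this PR, so treat it as illustrative:

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Only the architecture is instantiated from the model config; the pretrained
# weights themselves are neither downloaded nor loaded.
embedder = PretrainedTransformerEmbedder("bert-base-cased", load_weights=False)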

For example, suppose your training config looks something like this:

{
  "model": {
    "type": "basic_classifier",
    "text_field_embedder": {
      "tokens": {
        "type": "pretrained_transformer",
        "model_name": "bert-base-cased",
        // ... other stuff ...
      }
    },
  },
  // ... other stuff ...
}

Now suppose you have an archive from training this model, model.tar.gz. You can load the trained model into a predictor without re-downloading the pretrained transformer weights like so:

from allennlp.predictors import Predictor

overrides = {"model.text_field_embedder.tokens.load_weights": False}
predictor = Predictor.from_path("model.tar.gz", overrides=overrides)
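
Note that the fine-tuned weights saved in model.tar.gz are still loaded from the archive as usual; setting load_weights to False only skips downloading and loading the original pretrained weights, which would be overwritten by the archived state anyway.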

@epwalsh (Member, Author) commented on Apr 30, 2021

Unfortunately this actually doesn't address #5170, because the SrlBert model uses the transformers library directly. But that's not hard to fix. I'll follow up with a separate PR for that in allennlp-models.

@ArjunSubramonian (Contributor) left a comment

Awesome! Looks great to me. Thanks for meticulously going down the stack :)

@epwalsh merged commit a463e0e into main on May 2, 2021
@epwalsh deleted the transformer-no-load-weights branch on May 2, 2021 at 21:51
dirkgr pushed a commit that referenced this pull request on May 10, 2021
* add way of skipping pretrained weights download

* clarify docstring

* add link to PR in CHANGELOG