Model registry #246

Open
lorenzoh opened this issue Jul 6, 2022 · 0 comments
Labels: api-proposal (Implementation or suggestion for new APIs and improvements to existing APIs), enhancement (New feature or request), plans (Long-term plans)

lorenzoh commented Jul 6, 2022

In addition to its existing feature registries, FastAI.jl will get a model registry.

The model registry should allow

  • domain and third-party packages to register additional models when they are loaded;
  • using models from different "backend" deep learning libraries, e.g.
    Flux.jl, Lux.jl, PyTorch and JAX; and
  • searching for models that suit a specific task, fit a computational budget,
    or are built with a specific deep learning library.
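A minimal sketch of how such a registry could be structured, assuming a single global `Dict` keyed by model ID that packages extend from their `__init__` functions. `ModelEntry`, `MODEL_REGISTRY` and `registermodel!` are hypothetical names, not FastAI.jl API; `models()` mirrors the name used in the examples in this issue, and the blocks and builder in the example registration are placeholders.

```julia
# Hypothetical sketch, not FastAI.jl API: a global model registry that
# packages extend when they are loaded.

# One registry entry: searchable metadata plus a function that builds the model.
struct ModelEntry
    id::String        # e.g. "metalhead/resnet18/backbone"
    backend::Symbol   # e.g. :flux, :lux, :pytorch, :jax
    input::Any        # input block
    output::Any       # output block
    nparams::Int      # parameter count, for budget-based search
    build::Function   # (pretrained::Bool) -> model
end

# Single registry shared by all packages, keyed by model ID.
const MODEL_REGISTRY = Dict{String,ModelEntry}()

# A domain or third-party package would call this from its `__init__` function,
# so its models become available as soon as the package is loaded.
function registermodel!(entry::ModelEntry)
    haskey(MODEL_REGISTRY, entry.id) && error("model $(entry.id) is already registered")
    MODEL_REGISTRY[entry.id] = entry
    return entry
end

models() = MODEL_REGISTRY

# Example registration; blocks are placeholder Symbols and the builder is a stub.
registermodel!(ModelEntry("mypkg/tinymlp", :flux, :features, :logits, 1_000,
                          pretrained -> (name = "tinymlp", pretrained = pretrained)))
```

Rejecting duplicate IDs is one possible policy; silently overwriting would let a package replace another's entry, which makes lookups depend on package load order.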

Examples

Some code examples to show how the model registry can be used:

Loading models

Load a pretrained ResNet implemented in Metalhead.jl for transfer learning:

load(models()["metalhead/resnet18/head"], pretrained=true)

Load the same ResNet as an untrained backbone for a different task:

load(models()["metalhead/resnet18/backbone"], pretrained=false)
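One way `load` could handle the `pretrained` keyword is to forward it to a builder function stored in each registry entry. This is a sketch under that assumption: `ModelConfig` and `stubresnet` are hypothetical, and the stub returns a NamedTuple where a real builder would construct, for example, a Metalhead.jl model and load its weights.

```julia
# Hypothetical sketch, not FastAI.jl API: `load` forwards the `pretrained`
# flag to a builder function stored in each registry entry.
struct ModelConfig
    id::String
    build::Function   # (pretrained::Bool) -> model
end

load(config::ModelConfig; pretrained::Bool = false) = config.build(pretrained)

# Stub builder: a real one might construct a Metalhead.jl ResNet and, if
# `pretrained` is true, load its weights. `part` selects head vs. backbone.
stubresnet(part::Symbol) = pretrained -> (part = part, pretrained = pretrained)

registry = Dict(id => ModelConfig(id, stubresnet(part)) for (id, part) in [
    ("metalhead/resnet18/head", :head),
    ("metalhead/resnet18/backbone", :backbone)
])

transfer = load(registry["metalhead/resnet18/head"], pretrained = true)
scratch = load(registry["metalhead/resnet18/backbone"], pretrained = false)
```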

Searching for models

Find models that take in preprocessed images:

filter(models(), input=ImageTensor{2})

Or find a suitable model for a supervised learning task directly:

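# `_` stands for the task's constructor arguments; any supervised task works, e.g.
# SupervisedTask(_), ImageSegmentation(_) or TabularClassificationSingle(_)
task = ImageSegmentation(_)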
filter(models(), input=task.blocks.x, output=task.blocks.y)

List models implemented in PyTorch:

filter(models(), backend=:pytorch)

Find models within a parameter budget (here, fewer than one million parameters):

filter(models(), input=ImageTensor{2}, nparams=<(1_000_000))
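The search examples mix three kinds of criteria: a block type (`input=ImageTensor{2}`), a plain value (`backend=:pytorch`) and a predicate (`nparams=<(1000000)`). A minimal sketch of how `filter` could interpret them with multiple dispatch, assuming entries are NamedTuples wrapped in a hypothetical `Registry` type (so `Base.filter` is extended without type piracy). `ImageTensor` and `ConvFeatures` here are minimal stand-ins for the FastAI.jl blocks, and the entries and parameter counts are made up.

```julia
# Hypothetical sketch, not FastAI.jl API: keyword-based search over registry
# entries. The block types below are minimal stand-ins for FastAI.jl blocks.
abstract type Block end
struct ImageTensor{N} <: Block
    nchannels::Int
end
struct ConvFeatures{N} <: Block
    nfeatures::Int
end

struct Registry
    entries::Vector{NamedTuple}
end

# How one criterion is matched against one field value:
matches(value, criterion::Type) = value isa criterion    # block type, e.g. ImageTensor{2}
matches(value, criterion::Function) = criterion(value)   # predicate, e.g. <(1_000_000)
matches(value, criterion) = value == criterion           # plain value, e.g. :pytorch

# Keep entries whose fields satisfy every keyword criterion.
function Base.filter(reg::Registry; criteria...)
    keep(entry) = all(matches(getproperty(entry, k), c) for (k, c) in pairs(criteria))
    return Registry(filter(keep, reg.entries))
end

ids(reg::Registry) = [entry.id for entry in reg.entries]

# Illustrative entries with made-up IDs and parameter counts.
reg = Registry(NamedTuple[
    (id = "example/smallcnn", backend = :flux, input = ImageTensor{2}(3),
     output = ConvFeatures{2}(256), nparams = 500_000),
    (id = "example/largecnn", backend = :pytorch, input = ImageTensor{2}(3),
     output = ConvFeatures{2}(512), nparams = 20_000_000),
    (id = "example/volumecnn", backend = :flux, input = ImageTensor{3}(1),
     output = ConvFeatures{3}(64), nparams = 800_000)
])

ids(filter(reg, input = ImageTensor{2}))                          # → ["example/smallcnn", "example/largecnn"]
ids(filter(reg, backend = :pytorch))                              # → ["example/largecnn"]
ids(filter(reg, input = ImageTensor{2}, nparams = <(1_000_000)))  # → ["example/smallcnn"]
```

`<(1_000_000)` works as a predicate because it returns a `Base.Fix2`, which is a subtype of `Function` and therefore hits the predicate method.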

Training workflow

Since every model in the registry carries block information (its input and output blocks),
we can use it to automatically construct task-specific models with the taskmodel API
(possibly extended with an additional backend argument).

config = models()["torchvision/resnet18/backbone"]
backbone = load(config)

task = ImageSegmentation(_)  # `_` stands for the task's constructor arguments
# build the task-specific model
model = taskmodel(task,            # includes info about the required input and target blocks
                  config.backend,  # dispatch on the DL library used, here :pytorch
                  backbone,
                  config.input,    # backbone input block: `ImageTensor{2}(3)`
                  config.output)   # backbone output block: `ConvFeatures{2}(512)`

learner = tasklearner(task, data; model)
fit!(learner, 10)
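The backend argument to `taskmodel` is a Symbol such as `:pytorch`. A minimal sketch of how it could be turned into something Julia can dispatch on, assuming a `Val`-based design: each backend integration package adds one method for its `Val{:backend}`. This standalone `taskmodel` only stands in for the proposed backend-aware method and is not FastAI.jl's existing function; `ImageSegmentationTask` is a hypothetical stand-in for a task, the input and output blocks are omitted for brevity, and the returned NamedTuples are placeholder models.

```julia
# Hypothetical sketch, not FastAI.jl API: a backend-aware `taskmodel` that turns
# the registry's backend Symbol into a type with `Val`, so that each backend
# package can add its own method. Blocks are omitted for brevity.
struct ImageSegmentationTask end   # stand-in for a FastAI.jl task

# Entry point: convert e.g. :pytorch into Val(:pytorch) and re-dispatch.
taskmodel(task, backend::Symbol, backbone) = taskmodel(task, Val(backend), backbone)

# Fallback for backends without an implementation.
taskmodel(task, ::Val{B}, backbone) where {B} =
    error("no taskmodel implementation for backend :$B")

# A Flux.jl integration package would define this method...
taskmodel(::ImageSegmentationTask, ::Val{:flux}, backbone) =
    (backend = :flux, backbone = backbone, head = :decoder)      # placeholder model

# ...and a PyTorch integration package this one.
taskmodel(::ImageSegmentationTask, ::Val{:pytorch}, backbone) =
    (backend = :pytorch, backbone = backbone, head = :decoder)   # placeholder model

model = taskmodel(ImageSegmentationTask(), :pytorch, :resnet18_backbone)
```

Keeping the backend as a plain Symbol in the registry keeps entries simple to store and search, while `Val` still lets multiple dispatch pick the implementation; the price is one dynamic dispatch at the `Val(backend)` call, which is negligible next to building a model.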