Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into bugfix/tabular_from_csv_sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Sep 23, 2021
2 parents 3b59a76 + ca3870a commit b510f52
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Lightning Flash
integrations/fiftyone
integrations/learn2learn
integrations/icevision
integrations/vissl

.. toctree::
:maxdepth: 1
Expand Down
33 changes: 33 additions & 0 deletions docs/source/integrations/vissl.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
.. _vissl:

#####
VISSL
#####

`VISSL <https://github.com/facebookresearch/vissl>`__ is a library from Facebook AI Research for state-of-the-art self-supervised learning.
We integrate VISSL models and algorithms into Flash with the :ref:`image embedder <image_embedder>` task.

Using VISSL with Flash
----------------------

The ImageEmbedder task in Flash can be configured with different backbones, projection heads, image transforms and loss functions so that you can train your feature extractor using a SOTA SSL method.

.. code-block:: python
from flash.image import ImageEmbedder
embedder = ImageEmbedder(
backbone="resnet",
training_strategy="barlow_twins",
head="simclr_head",
pretraining_transform="barlow_twins_transform",
training_strategy_kwargs={"latent_embedding_dim": 256, "dims": [2048, 2048, 256]},
pretraining_transform_kwargs={"size_crops": [196]},
)
The user can pass arguments to the training strategy, image transforms and backbones using the optional dictionary arguments the ImageEmbedder task accepts.
The training strategies club together the projection head, the loss function as well as VISSL hooks for a particular algorithm and the arguments to customize these can passed via ``training_strategy_kwargs``.
As an example, in the above code block, the ``latent_embedding_dim`` is an argument to the BarlowTwins loss function from VISSL, while the ``dims`` argument configures the projection head to output 256 dim vectors for the loss function.

If you find VISSL integration in Flash useful for your research, please don't forget to cite us and the VISSL library.
You can find our bibtex on `Flash <https://github.com/PyTorchLightning/lightning-flash>`__ and VISSL's bibxtex on their `github <https://github.com/facebookresearch/vissl>`__ page.
15 changes: 12 additions & 3 deletions docs/source/reference/image_embedder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,24 @@ The Task
Image embedding encodes an image into a vector of features which can be used for a downstream task.
This could include: clustering, similarity search, or classification.

The :class:`~flash.image.embedding.model.ImageEmbedder` internally relies on `VISSL <https://vissl.ai/>`_.

------

*******
Example
*******

Let's see how to use the :class:`~flash.image.embedding.model.ImageEmbedder` with a pretrained backbone to obtain feature vectors from the hymenoptera data.
Once we've downloaded the data, we create the :class:`~flash.image.embedding.model.ImageEmbedder` and perform inference (obtaining feature vectors / embeddings) using :meth:`~flash.image.embedding.model.ImageEmbedder.predict`.
Here's the full example:
Let's see how to configure a training strategy for the :class:`~flash.image.embedding.model.ImageEmbedder` task.
A vanilla :class:`~flash.core.data.data_module.DataModule` object be created using standard Datasets as shown below.
Then the user can configure the :class:`~flash.image.embedding.model.ImageEmbedder` task with ``training_strategy``, ``backbone``, ``head`` and ``pretraining_transform``.
There are options provided to send additional arguments to config selections.
This task can now be sent to the ``fit()`` method of :class:`~flash.core.trainer.Trainer`.

.. note::

A lot of VISSL loss functions use hard-coded ``torch.distributed`` methods. The user is suggested to use ``accelerator=ddp`` even with a single GPU.
Only ``barlow_twins`` training strategy works on the CPU. All other loss functions are configured to work on GPUs.

.. literalinclude:: ../../../flash_examples/image_embedder.py
:language: python
Expand Down
14 changes: 14 additions & 0 deletions flash/image/embedding/heads/vissl_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@


class SimCLRHead(nn.Module):
"""VISSL adpots a complicated config input to create an MLP.
This class simplifies the standard SimCLR projection head.
Can be configured to be used with barlow twins and moco as well.
Returns MLP according to dimensions provided as a list.
linear-layer -> batch-norm (if flag) -> Relu -> ...
Args:
model_config: Model config AttrDict from VISSL
dims: list of dimensions for creating a projection head
use_bn: use batch-norm after each linear layer or not
"""

def __init__(
self,
model_config: AttrDict,
Expand Down
23 changes: 13 additions & 10 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,23 @@ class ImageEmbedder(AdapterTask):
more details, see :ref:`image_embedder`.
Args:
embedding_dim: Dimension of the embedded vector. ``None`` uses the default from the backbone.
backbone: A model to use to extract image features, defaults to ``"swav-imagenet"``.
pretrained: Use a pretrained backbone, defaults to ``True``.
loss_fn: Loss function for training and finetuning, defaults to :func:`torch.nn.functional.cross_entropy`
training_strategy: Training strategy from VISSL,
select between 'simclr', 'swav', 'dino', 'moco', or 'barlow_twins'.
head: projection head used for task, select between
'simclr_head', 'swav_head', 'dino_head', 'moco_head', or 'barlow_twins_head'.
pretraining_transform: transform applied to input image for pre-training SSL model.
Select between 'simclr_transform', 'swav_transform', 'dino_transform',
'moco_transform', or 'barlow_twins_transform'.
backbone: VISSL backbone, defaults to ``resnet``.
pretrained: Use a pretrained backbone, defaults to ``False``.
optimizer: Optimizer to use for training and finetuning, defaults to :class:`torch.optim.SGD`.
optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
scheduler: The scheduler or scheduler class to use.
scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics`
package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict
containing a combination of the aforementioned. In all cases, each metric needs to have the signature
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
pooling_fn: Function used to pool image to generate embeddings, defaults to :func:`torch.max`.
backbone_kwargs: arguments to be passed to VISSL backbones, i.e. ``vision_transformer`` and ``resnet``.
training_strategy_kwargs: arguments passed to VISSL loss function, projection head and training hooks.
pretraining_transform_kwargs: arguments passed to VISSL transforms.
"""

training_strategies: FlashRegistry = IMAGE_EMBEDDER_STRATEGIES
Expand All @@ -73,7 +76,7 @@ def __init__(
head: str,
pretraining_transform: str,
backbone: str = "resnet",
pretrained: bool = True,
pretrained: bool = False,
optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
Expand Down
7 changes: 6 additions & 1 deletion flash/image/embedding/vissl/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@


class MockVISSLTask:
"""Mock task class from VISSL to support loss, configs, base_model, last batch etc."""

def __init__(self, vissl_adapter, vissl_loss, task_config, vissl_model) -> None:
self.vissl_adapter = vissl_adapter
self.loss = vissl_loss
Expand All @@ -51,7 +53,10 @@ def __init__(self, vissl_adapter, vissl_loss, task_config, vissl_model) -> None:


class VISSLAdapter(Adapter, AdaptVISSLHooks):
"""The ``VISSLAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with VISSL."""
"""The ``VISSLAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with VISSL.
Also inherits from ``AdaptVISSLHooks`` to support VISSL hooks.
"""

required_extras: str = "image"

Expand Down
1 change: 1 addition & 0 deletions flash/image/embedding/vissl/transforms/multicrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class StandardMultiCropSSLTransform(nn.Module):
crops.
This transform was proposed in SwAV - https://arxiv.org/abs/2006.09882
This transform can act as a base transform class for SimCLR, SwAV, MoCo, Barlow Twins and DINO from VISSL.
This transform has been modified from the ImgPilToMultiCrop code present at
https://github.com/facebookresearch/vissl/blob/master/vissl/data/ssl_transforms/img_pil_to_multicrop.py
Expand Down

0 comments on commit b510f52

Please sign in to comment.