Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tests
Browse files Browse the repository at this point in the history
ananyahjha93 committed Sep 21, 2021
2 parents be97ef2 + 166006c commit 81cb161
Showing 5 changed files with 12 additions and 10 deletions.
1 change: 1 addition & 0 deletions flash/image/embedding/losses/vissl_losses.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
from vissl.config.attr_dict import AttrDict
else:
AttrDict = object
ClassyLoss = object


def get_loss_fn(loss_name: str, cfg: AttrDict):
6 changes: 4 additions & 2 deletions flash/image/embedding/vissl/adapter.py
Original file line number Diff line number Diff line change
@@ -20,14 +20,16 @@
from flash.core.data.data_source import DefaultDataKeys
from flash.core.model import Task
from flash.core.utilities.imports import _VISSL_AVAILABLE
from flash.image.embedding.vissl.hooks import AdaptVISSLHooks

if _VISSL_AVAILABLE:
from classy_vision.hooks.classy_hook import ClassyHook
from classy_vision.losses import ClassyLoss
from vissl.config.attr_dict import AttrDict
from vissl.models.base_ssl_model import BaseSSLMultiInputOutputModel

from flash.image.embedding.vissl.hooks import AdaptVISSLHooks
else:
ClassyLoss = object
ClassyHook = object


class MockVISSLTask:
4 changes: 3 additions & 1 deletion flash/image/embedding/vissl/hooks.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,9 @@
if _VISSL_AVAILABLE:
from classy_vision.hooks.classy_hook import ClassyHook
else:
ClassyHook = object

class ClassyHook:
_noop = object


class TrainingSetupHook(ClassyHook):
5 changes: 2 additions & 3 deletions flash/image/embedding/vissl/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from flash.core.utilities.imports import _VISSL_AVAILABLE # noqa: F401
from flash.image.embedding.vissl.transforms.multicrop import StandardMultiCropSSLTransform # noqa: F401
from flash.image.embedding.vissl.transforms.utilities import multicrop_collate_fn # noqa: F401

if _VISSL_AVAILABLE:
from classy_vision.dataset.transforms import register_transform # noqa: F401

from flash.image.embedding.vissl.transforms.multicrop import StandardMultiCropSSLTransform # noqa: F401
from flash.image.embedding.vissl.transforms.utilities import multicrop_collate_fn # noqa: F401

register_transform("multicrop_ssl_transform")(StandardMultiCropSSLTransform)
6 changes: 2 additions & 4 deletions tests/image/embedding/utils.py
Original file line number Diff line number Diff line change
@@ -3,16 +3,14 @@
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE
from flash.image import ImageClassificationData
from flash.image.embedding.vissl.transforms.utilities import multicrop_collate_fn
from flash.image.embedding.vissl.transforms import multicrop_collate_fn

if _TORCHVISION_AVAILABLE:
from torchvision.datasets import FakeData

if _VISSL_AVAILABLE:
from classy_vision.dataset.transforms import TRANSFORM_REGISTRY

from flash.image.embedding.vissl.transforms import multicrop_collate_fn # noqa: F401


def ssl_datamodule(
batch_size=2,
@@ -33,7 +31,7 @@ def ssl_datamodule(
preprocess = DefaultPreprocess(
train_transform={
"to_tensor_transform": to_tensor_transform,
"collate": multi_crop_transform,
"collate": collate_fn,
}
)

0 comments on commit 81cb161

Please sign in to comment.