From 603b42151a7361c13c3418e0c4bca4911ad476db Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Fri, 10 Sep 2021 14:14:05 -0400 Subject: [PATCH] . --- flash/image/embedding/heads/vissl_heads.py | 2 -- .../embedding/strategies/vissl_strategies.py | 1 - flash/image/embedding/vissl/adapter.py | 27 ++++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/flash/image/embedding/heads/vissl_heads.py b/flash/image/embedding/heads/vissl_heads.py index 73d1b70bd0..34a69caefc 100644 --- a/flash/image/embedding/heads/vissl_heads.py +++ b/flash/image/embedding/heads/vissl_heads.py @@ -34,8 +34,6 @@ def swav_head( normalize_feats: bool = True, activation_name: str = "ReLU", use_weight_norm_prototypes: bool = True, - batchnorm_eps: float = 1e-5, - batchnorm_momentum: float = 0.1, **kwargs, ) -> nn.Module: cfg = VISSLAdapter.get_model_config_template() diff --git a/flash/image/embedding/strategies/vissl_strategies.py b/flash/image/embedding/strategies/vissl_strategies.py index 5b973e399c..75ea04763b 100644 --- a/flash/image/embedding/strategies/vissl_strategies.py +++ b/flash/image/embedding/strategies/vissl_strategies.py @@ -23,7 +23,6 @@ from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS -# TODO: update head creation using config? def dino(head: str = 'swav_head', **kwargs): loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get('dino_loss')(**kwargs) head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs) diff --git a/flash/image/embedding/vissl/adapter.py b/flash/image/embedding/vissl/adapter.py index 122fbc1661..95794872f1 100644 --- a/flash/image/embedding/vissl/adapter.py +++ b/flash/image/embedding/vissl/adapter.py @@ -183,7 +183,13 @@ def get_model_config_template(): return cfg def forward(self, batch) -> Any: - return self.vissl_base_model(batch) + model_output = self.vissl_base_model(batch) + + # vissl-specific + if len(model_output) == 1: + model_output = model_output[0] + + return model_output def training_step(self, batch: Any, batch_idx: int) -> Any: out = self(batch[DefaultDataKeys.INPUT]) @@ -193,22 +199,19 @@ def training_step(self, batch: Any, batch_idx: int) -> Any: for hook in self.hooks: hook.on_forward(self.vissl_task) - # out can be torch.Tensor/List target is torch.Tensor - # loss = self.vissl_loss(out, target=None) + loss = self.loss_fn(out, target=None) + self.log_dict({'train_loss': loss}) - # TODO: log - # TODO: Include call to ClassyHooks during training - # return loss + return loss def validation_step(self, batch: Any, batch_idx: int) -> None: - out = self(batch) + out = self(batch[DefaultDataKeys.INPUT]) + self.task.last_batch['sample']['input'] = batch[DefaultDataKeys.INPUT] - # out can be torch.Tensor/List target is torch.Tensor - # loss = self.vissl_loss(out, target) + loss = self.loss_fn(out, target=None) + self.log_dict({'val_loss': loss}) - # TODO: log - # TODO: Include call to ClassyHooks during training - # return loss + return loss def test_step(self, batch: Any, batch_idx: int) -> None: # vissl_input, target = batch