This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

.
ananyahjha93 committed Sep 10, 2021
1 parent 2bce93e commit 603b421
Showing 3 changed files with 15 additions and 15 deletions.
2 changes: 0 additions & 2 deletions flash/image/embedding/heads/vissl_heads.py
@@ -34,8 +34,6 @@ def swav_head(
normalize_feats: bool = True,
activation_name: str = "ReLU",
use_weight_norm_prototypes: bool = True,
batchnorm_eps: float = 1e-5,
batchnorm_momentum: float = 0.1,
**kwargs,
) -> nn.Module:
cfg = VISSLAdapter.get_model_config_template()
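For reference, a minimal sketch of the trimmed signature after this change. Only the parameters visible in the hunk above are shown; the other arguments of the real swav_head are elided and the stub body is a placeholder:

# Sketch only: after this commit batchnorm_eps and batchnorm_momentum are no
# longer explicit keyword arguments, so a caller still passing them falls
# through to **kwargs.
from torch import nn

def swav_head(
    normalize_feats: bool = True,
    activation_name: str = "ReLU",
    use_weight_norm_prototypes: bool = True,
    **kwargs,
) -> nn.Module:
    ...  # the real implementation builds the head from VISSLAdapter's model config template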
1 change: 0 additions & 1 deletion flash/image/embedding/strategies/vissl_strategies.py
@@ -23,7 +23,6 @@
from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS


# TODO: update head creation using config?
def dino(head: str = 'swav_head', **kwargs):
loss_fn = IMAGE_EMBEDDER_LOSS_FUNCTIONS.get('dino_loss')(**kwargs)
head = IMAGE_EMBEDDER_HEADS.get(head)(**kwargs)
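The dino strategy above builds both pieces through Flash registries. A hedged sketch of that lookup pattern, using only the head registry whose import is visible in this file; whether swav_head can be constructed with just this keyword argument is an assumption:

# Sketch of the registry pattern used by dino() above: .get() returns the
# registered factory, which the strategy then calls with its **kwargs.
from flash.image.embedding.heads import IMAGE_EMBEDDER_HEADS

head_factory = IMAGE_EMBEDDER_HEADS.get("swav_head")   # registered callable
head = head_factory(use_weight_norm_prototypes=True)   # example keyword argument forwarded via **kwargs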
27 changes: 15 additions & 12 deletions flash/image/embedding/vissl/adapter.py
@@ -183,7 +183,13 @@ def get_model_config_template():
return cfg

def forward(self, batch) -> Any:
return self.vissl_base_model(batch)
model_output = self.vissl_base_model(batch)

# vissl-specific
if len(model_output) == 1:
model_output = model_output[0]

return model_output

def training_step(self, batch: Any, batch_idx: int) -> Any:
out = self(batch[DefaultDataKeys.INPUT])
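A stand-alone illustration of the new branch in forward() above: VISSL's base model returns its output wrapped in a list (one entry per head), so with a single head the adapter now unwraps it to a plain tensor before it reaches the loss. The model below is a stand-in, not the real vissl_base_model:

import torch

def fake_vissl_base_model(batch):
    # Stand-in for self.vissl_base_model: VISSL returns a list of head
    # outputs even when only one head is configured.
    return [torch.randn(batch.shape[0], 128)]

batch = torch.randn(4, 3, 224, 224)
model_output = fake_vissl_base_model(batch)
if len(model_output) == 1:          # the "vissl-specific" unwrap above
    model_output = model_output[0]  # plain tensor for the loss function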
@@ -193,22 +199,19 @@ def training_step(self, batch: Any, batch_idx: int) -> Any:
for hook in self.hooks:
hook.on_forward(self.vissl_task)

# out can be torch.Tensor/List target is torch.Tensor
# loss = self.vissl_loss(out, target=None)
loss = self.loss_fn(out, target=None)
self.log_dict({'train_loss': loss})

# TODO: log
# TODO: Include call to ClassyHooks during training
# return loss
return loss

def validation_step(self, batch: Any, batch_idx: int) -> None:
out = self(batch)
out = self(batch[DefaultDataKeys.INPUT])
self.task.last_batch['sample']['input'] = batch[DefaultDataKeys.INPUT]

# out can be torch.Tensor/List target is torch.Tensor
# loss = self.vissl_loss(out, target)
loss = self.loss_fn(out, target=None)
self.log_dict({'val_loss': loss})

# TODO: log
# TODO: Include call to ClassyHooks during training
# return loss
return loss

def test_step(self, batch: Any, batch_idx: int) -> None:
# vissl_input, target = batch
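Taken together, the training_step and validation_step edits follow the usual LightningModule contract: index the Flash batch dict by DefaultDataKeys.INPUT, log the scalar via self.log_dict, and return the loss so Lightning can run the backward pass. A minimal self-contained sketch of that pattern; the linear trunk and squared-mean loss are hypothetical stand-ins for the VISSL trunk/head and loss, and the DefaultDataKeys import path assumes the Flash 0.5-era layout:

import torch
from torch import nn
from pytorch_lightning import LightningModule
from flash.core.data.data_source import DefaultDataKeys  # assumed Flash 0.5-era path

class StepSketch(LightningModule):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(32, 8)  # stand-in for the VISSL trunk + head

    def forward(self, x):
        return self.trunk(x)

    def training_step(self, batch, batch_idx):
        out = self(batch[DefaultDataKeys.INPUT])  # Flash batches are dicts keyed by DefaultDataKeys
        loss = out.pow(2).mean()                  # hypothetical stand-in loss
        self.log_dict({"train_loss": loss})       # mirrors the adapter's logging
        return loss                               # Lightning backpropagates the returned loss

    def validation_step(self, batch, batch_idx):
        out = self(batch[DefaultDataKeys.INPUT])
        loss = out.pow(2).mean()
        self.log_dict({"val_loss": loss})
        return loss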
