This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

last batch
ananyahjha93 committed Sep 7, 2021
1 parent bb95027 commit c0803c5
Showing 1 changed file with 8 additions and 56 deletions.
64 changes: 8 additions & 56 deletions flash/core/integrations/vissl/adapter.py
@@ -39,51 +39,6 @@
from flash.core.integrations.vissl.hooks import AdaptVISSLHooks


-# TODO: replace this
-def construct_sample_for_model(batch_data, task):
-    """
-    Given the input batch from the dataloader, verify the input is
-    as expected: the input data and target data are present in the
-    batch.
-    In case of multi-input trainings like PIRL, make sure the data
-    is in the right format, i.e. the multiple inputs should be nested
-    under a common key "input".
-    """
-    # sample_key_names = task.data_and_label_keys
-    # inp_key, target_key = sample_key_names["input"], sample_key_names["target"]
-    # all_keys = inp_key + target_key
-
-    # assert len(inp_key) + len(target_key) <= len(
-    #     batch_data
-    # ), "Number of input and target keys in batch and train config don't match."
-
-    # # every input should be a list. The list corresponds to various data sources
-    # # and hence could be used to handle several data modalities.
-    # for key in all_keys:
-    #     assert isinstance(batch_data[key], list), f"key: {key} input is not a list"
-    #     assert (
-    #         len(batch_data[key]) == 1
-    #     ), "Please modify your train step to handle multi-modal input"
-
-    # # single input case
-    # if len(sample_key_names["input"]) == 1 and len(sample_key_names["target"]) == 1:
-    #     sample = {
-    #         "input": batch_data[inp_key[0]][0],
-    #         "target": batch_data[target_key[0]][0],
-    #         "data_valid": batch_data["data_valid"][0],
-    #     }
-
-    #     # copy the other keys as-is, method dependent
-    #     for k in batch_data.keys():
-    #         if k not in all_keys:
-    #             sample[k] = batch_data[k]
-
-    sample = {}
-    sample["input"] = batch_data[0]
-
-    return sample


class MockVISSLTask:
    def __init__(self, vissl_loss, task_config, vissl_model) -> None:
        self.loss = vissl_loss
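
Note: every validation step in the deleted helper was already commented out, so before this commit it effectively reduced to wrapping the first element of the batch in a dict. A sketch of that surviving behavior, for comparison only (not part of the diff):

def construct_sample_for_model(batch_data, task):
    # Only the single-input wrapping was live; all key/format checks
    # above it were commented out, and `task` went unused.
    return {"input": batch_data[0]}

The commit removes the helper and instead writes the input directly into the pre-built last_batch structure, as the next two hunks show.
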
@@ -95,7 +50,11 @@ def __init__(self, vissl_loss, task_config, vissl_model) -> None:
        self.iteration = 0
        self.max_iteration = 100000  # set this correctly

-        self.last_batch = SimpleNamespace
+        self.last_batch = AttrDict({
+            'sample': AttrDict({
+                'input': None
+            })
+        })

        # task.loss.checkpoint to None
        # task.loss.center
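
Note: the replaced line assigned the SimpleNamespace class itself rather than an instance, so any attribute set on last_batch would have landed on the shared class object. The new code pre-builds the nested structure with AttrDict (VISSL's attribute-style dict; the import is not shown in this hunk), which is what lets training_step write last_batch['sample']['input'] while VISSL hooks read last_batch.sample.input. A minimal stand-in, for illustration only (the real class ships with VISSL):

class AttrDict(dict):
    # Toy version of an attribute-style dict: items double as attributes.
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

    def __setattr__(self, name, value):
        self[name] = value

last_batch = AttrDict({'sample': AttrDict({'input': None})})
last_batch['sample']['input'] = 'batch tensor goes here'
assert last_batch.sample.input == 'batch tensor goes here'  # same slot, two spellings
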
@@ -282,18 +241,11 @@ def forward(self, batch) -> Any:
        # out = self.single_input_forward(batch[DefaultDataKeys.INPUT], [], self.vissl_heads)

        # return out
-        return self.vissl_base_model(batch[DefaultDataKeys.INPUT])
+        return self.vissl_base_model(batch)

    def training_step(self, batch: Any, batch_idx: int) -> Any:
-        out = self(batch)
-
-        print('$$$$$$$$$$$$$$$$$$$')
-        print(len(out))
-        print(out[0][0].shape)
-        print('This executes!!!!')
-        print('$$$$$$$$$$$$$$$$$$$')
-
-        self.task.last_batch.sample = construct_sample_for_model(batch, self.task)
+        out = self(batch[DefaultDataKeys.INPUT])
+        self.task.last_batch['sample']['input'] = batch[DefaultDataKeys.INPUT]

        # call forward hook from VISSL (momentum updates)
        for hook in self.hooks:
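
Note: forward now receives the bare input tensor because training_step unpacks batch[DefaultDataKeys.INPUT] once, passes it to the model, and stashes the same object on task.last_batch, where the VISSL hooks wired up via AdaptVISSLHooks expect to find the batch that was just consumed. A hypothetical sketch of a hook reading it, to show why the stash matters (hook name and body are illustrative, not from this commit):

class ExampleMomentumHook:
    def on_forward(self, task) -> None:
        # VISSL-style hooks get only `task`, so the current batch has to
        # be reachable from it; attribute access works because last_batch
        # is an AttrDict.
        current_input = task.last_batch.sample.input
        # a momentum/EMA update would consume current_input here
        print("hook saw input of type", type(current_input).__name__)
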
