From c0803c5de1aaa8fdb9e7c1911f21cdddfd458252 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Tue, 7 Sep 2021 12:43:06 -0400
Subject: [PATCH] last batch

---
 flash/core/integrations/vissl/adapter.py | 64 +++---------------------
 1 file changed, 8 insertions(+), 56 deletions(-)

diff --git a/flash/core/integrations/vissl/adapter.py b/flash/core/integrations/vissl/adapter.py
index 4b88c97f8dd..dd11063ab05 100644
--- a/flash/core/integrations/vissl/adapter.py
+++ b/flash/core/integrations/vissl/adapter.py
@@ -39,51 +39,6 @@
 from flash.core.integrations.vissl.hooks import AdaptVISSLHooks
 
 
-# TODO: replace this
-def construct_sample_for_model(batch_data, task):
-    """
-    Given the input batch from the dataloader, verify the input is
-    as expected: the input data and target data is present in the
-    batch.
-    In case of multi-input trainings like PIRL, make sure the data
-    is in right format i.e. the multiple input should be nested
-    under a common key "input".
-    """
-    # sample_key_names = task.data_and_label_keys
-    # inp_key, target_key = sample_key_names["input"], sample_key_names["target"]
-    # all_keys = inp_key + target_key
-
-    # assert len(inp_key) + len(target_key) <= len(
-    #     batch_data
-    # ), "Number of input and target keys in batch and train config don't match."
-
-    # # every input should be a list. The list corresponds to various data sources
-    # # and hence could be used to handle several data modalities.
-    # for key in all_keys:
-    #     assert isinstance(batch_data[key], list), f"key: {key} input is not a list"
-    #     assert (
-    #         len(batch_data[key]) == 1
-    #     ), "Please modify your train step to handle multi-modal input"
-
-    # # single input case
-    # if len(sample_key_names["input"]) == 1 and len(sample_key_names["target"]) == 1:
-    #     sample = {
-    #         "input": batch_data[inp_key[0]][0],
-    #         "target": batch_data[target_key[0]][0],
-    #         "data_valid": batch_data["data_valid"][0],
-    #     }
-
-    # # copy the other keys as-is, method dependent
-    # for k in batch_data.keys():
-    #     if k not in all_keys:
-    #         sample[k] = batch_data[k]
-
-    sample = {}
-    sample["input"] = batch_data[0]
-
-    return sample
-
-
 class MockVISSLTask:
     def __init__(self, vissl_loss, task_config, vissl_model) -> None:
         self.loss = vissl_loss
@@ -95,7 +50,11 @@ def __init__(self, vissl_loss, task_config, vissl_model) -> None:
         self.iteration = 0
         self.max_iteration = 100000  # set this correctly
 
-        self.last_batch = SimpleNamespace
+        self.last_batch = AttrDict({
+            'sample': AttrDict({
+                'input': None
+            })
+        })
 
         # task.loss.checkpoint to None
         # task.loss.center
@@ -282,18 +241,11 @@ def forward(self, batch) -> Any:
         # out = self.single_input_forward(batch[DefaultDataKeys.INPUT], [], self.vissl_heads)
         # return out
 
-        return self.vissl_base_model(batch[DefaultDataKeys.INPUT])
+        return self.vissl_base_model(batch)
 
     def training_step(self, batch: Any, batch_idx: int) -> Any:
-        out = self(batch)
-
-        print('$$$$$$$$$$$$$$$$$$$')
-        print(len(out))
-        print(out[0][0].shape)
-        print('This executes!!!!')
-        print('$$$$$$$$$$$$$$$$$$$')
-
-        self.task.last_batch.sample = construct_sample_for_model(batch, self.task)
+        out = self(batch[DefaultDataKeys.INPUT])
+        self.task.last_batch['sample']['input'] = batch[DefaultDataKeys.INPUT]
 
         # call forward hook from VISSL (momentum updates)
         for hook in self.hooks:
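
Why the AttrDict swap works (an illustrative sketch, not part of the patch): the
pre-patch line assigned the SimpleNamespace class itself rather than an instance,
and a plain namespace would not support the item-style writes now used in
training_step. VISSL's AttrDict is a dict whose keys are also reachable as
attributes, so the new item-style write last_batch['sample']['input'] = ... and
an attribute-style read like task.last_batch.sample (the style the pre-patch
code used) address the same slot. The stand-in class below only mimics that
behaviour for demonstration; it is not VISSL's actual implementation.

    # Minimal stand-in for VISSL's AttrDict (illustration only):
    # a dict whose keys double as attributes.
    class AttrDict(dict):
        def __getattr__(self, key):
            try:
                return self[key]
            except KeyError as err:
                raise AttributeError(key) from err

        def __setattr__(self, key, value):
            self[key] = value

    # The structure MockVISSLTask.__init__ now pre-builds:
    last_batch = AttrDict({"sample": AttrDict({"input": None})})

    # training_step writes with item access...
    last_batch["sample"]["input"] = "dummy-batch"

    # ...and attribute-style access reads the same slot.
    assert last_batch.sample.input == "dummy-batch"

Pre-building the 'sample'/'input' nesting in MockVISSLTask.__init__ also means
training_step can fill in the slot directly each step, which is what let this
patch drop construct_sample_for_model entirely.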