Skip to content

Commit

Permalink
Add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
AntreasAntoniou committed Oct 23, 2018
1 parent 7761bc8 commit 116a856
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 55 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pip-delete-this-directory.txt

#codebase-specific
datasets/*
MAML*

# Unit test / coverage reports
htmlcov/
Expand Down
39 changes: 29 additions & 10 deletions few_shot_learning_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def forward(self, data_batch, epoch, use_second_order, optimize_final_target_los
support_loss, support_preds = self.net_forward(x=x_support_set_task,
y=y_support_set_task,
weights=names_weights_copy,
reset_running_statistics=
backup_running_statistics=
True if (num_step == 0) else False,
training=True, num_step=num_step)

Expand All @@ -222,13 +222,13 @@ def forward(self, data_batch, epoch, use_second_order, optimize_final_target_los
if num_step == (self.args.number_of_training_steps_per_iter - 1):
target_loss, target_preds = self.net_forward(x=x_target_set_task,
y=y_target_set_task, weights=names_weights_copy,
reset_running_statistics=False, training=True,
backup_running_statistics=False, training=True,
num_step=num_step)
task_losses.append(target_loss)
else:
target_loss, target_preds = self.net_forward(x=x_target_set_task,
y=y_target_set_task, weights=names_weights_copy,
reset_running_statistics=False, training=True,
backup_running_statistics=False, training=True,
num_step=num_step)

task_losses.append(per_step_loss_importance_vectors[num_step] * target_loss)
Expand All @@ -254,7 +254,7 @@ def forward(self, data_batch, epoch, use_second_order, optimize_final_target_los

return losses

def net_forward(self, x, y, weights, reset_running_statistics, training, num_step):
def net_forward(self, x, y, weights, backup_running_statistics, training, num_step):
"""
A base model forward pass on some data points x. Using the parameters in the weights dictionary. Also requires
boolean flags indicating whether to reset the running statistics at the end of the run (if at evaluation phase).
Expand All @@ -263,7 +263,7 @@ def net_forward(self, x, y, weights, reset_running_statistics, training, num_ste
:param x: A data batch of shape b, c, h, w
:param y: A data targets batch of shape b, n_classes
:param weights: A dictionary containing the weights to pass to the network.
:param reset_running_statistics: A flag indicating whether to reset the batch norm running statistics to their
:param backup_running_statistics: A flag indicating whether to reset the batch norm running statistics to their
previous values after the run (only for evaluation)
:param training: A flag indicating whether the current process phase is a training or evaluation.
:param num_step: An integer indicating the number of the step in the inner loop.
Expand All @@ -274,7 +274,7 @@ def net_forward(self, x, y, weights, reset_running_statistics, training, num_ste

preds = self.classifier.forward(x=input_var, params=weights,
training=training,
reset_running_statistics=reset_running_statistics, num_step=num_step)
backup_running_statistics=backup_running_statistics, num_step=num_step)

loss = F.cross_entropy(input=preds, target=target_var)

Expand Down Expand Up @@ -330,9 +330,9 @@ def meta_update(self, loss):
def run_train_iter(self, data_batch, epoch):
"""
Runs an outer loop update step on the meta-model's parameters.
:param data_batch:
:param epoch:
:return:
:param data_batch: input data batch containing the support set and target set input, output pairs
:param epoch: the index of the current epoch
:return: The losses of the ran iteration.
"""
self.scheduler.step(epoch=epoch)
if self.current_epoch != epoch:
Expand Down Expand Up @@ -360,7 +360,12 @@ def run_train_iter(self, data_batch, epoch):
return losses

def run_validation_iter(self, data_batch):

"""
Runs an outer loop evaluation step on the meta-model's parameters.
:param data_batch: input data batch containing the support set and target set input, output pairs
:param epoch: the index of the current epoch
:return: The losses of the ran iteration.
"""
if self.training:
self.eval()

Expand All @@ -381,10 +386,24 @@ def run_validation_iter(self, data_batch):
return losses

def save_model(self, model_save_dir, state):
"""
Save the network parameter state and experiment state dictionary.
:param model_save_dir: The directory to store the state at.
:param state: The state containing the experiment state and the network. It's in the form of a dictionary
object.
"""
state['network'] = self.state_dict()
torch.save(state, f=model_save_dir)

def load_model(self, model_save_dir, model_name, model_idx):
"""
Load checkpoint and return the state dictionary containing the network state params and experiment state.
:param model_save_dir: The directory from which to load the files.
:param model_name: The model_name to be loaded from the direcotry.
:param model_idx: The index of the model (i.e. epoch number or 'latest' for the latest saved model of the current
experiment)
:return: A dictionary containing the experiment state and the saved model parameters.
"""
filepath = os.path.join(model_save_dir, "{}_{}".format(model_name, model_idx))
checkpoint_state = torch.load(filepath)
state_dict_loaded = checkpoint_state['network']
Expand Down
Loading

0 comments on commit 116a856

Please sign in to comment.