ref: inner train loop (intermediate step) 10/n (#3369)
williamFalcon authored Sep 6, 2020
1 parent b375a26 commit 8542146
Showing 5 changed files with 37 additions and 34 deletions.
29 changes: 17 additions & 12 deletions pytorch_lightning/trainer/training_loop.py
@@ -739,10 +739,12 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# track metrics to log
batch_log_metrics = []

# bookkeeping
using_results_obj = False
self.hiddens = None

# track all outputs across time and num of optimizers
- batch_outputs = [[] for i in range(len(self.train_loop.get_optimizers_iterable()))]
+ batch_outputs = [[] for _ in range(len(self.train_loop.get_optimizers_iterable()))]

if batch is None:
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
@@ -757,16 +759,13 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
if response == -1:
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

- splits = [batch]
- if self.truncated_bptt_steps is not None:
- model_ref = self.get_model()
- with self.profiler.profile('tbptt_split_batch'):
- splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
+ # lightning module hook
+ splits = self.train_loop.tbptt_split_batch(batch)

- self.hiddens = None
for split_idx, split_batch in enumerate(splits):
self.split_idx = split_idx

# loop over optimizers
for opt_idx, optimizer in self.train_loop.get_optimizers_iterable():
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
@@ -780,7 +779,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# -------------------
# calculate loss (train step + train step end)
# -------------------
- opt_closure_result = self.optimizer_closure(
+ opt_closure_result = self.training_step_and_backward(
split_batch,
batch_idx,
opt_idx,
@@ -808,13 +807,19 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
- if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
- or (self.batch_idx + 1) == self.num_training_batches):
+ accumulation_done = (self.batch_idx + 1) % self.accumulate_grad_batches == 0
+ is_final_batch = (self.batch_idx + 1) == self.num_training_batches
+ if accumulation_done or is_final_batch:
# hook
grad_norm_dic = self.train_loop.on_before_backward(batch_idx, optimizer)

+ # wrap forward + backward pass in closure for 2nd order optimizers
+ train_step_and_backward_closure = lambda: self.training_step_and_backward(
+ split_batch, batch_idx, opt_idx, optimizer, self.hiddens,
+ ).loss

# optimizer step
- self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
+ self.train_loop.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)

# hook
self.train_loop.on_before_zero_grad(optimizer)
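A note on the accumulation check introduced above: the optimizer only steps once accumulate_grad_batches batches have contributed gradients, or on the last batch of the epoch so that leftover gradients are not dropped. A minimal standalone sketch of that condition (the function name and arguments here are illustrative, not trainer attributes):

    def should_step(batch_idx, accumulate_grad_batches, num_training_batches):
        # mirror of the condition above: step every `accumulate_grad_batches`
        # batches, and always step on the final batch of the epoch
        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        is_final_batch = (batch_idx + 1) == num_training_batches
        return accumulation_done or is_final_batch

    # with accumulate_grad_batches=4 and 10 batches per epoch,
    # the optimizer steps after batches 3, 7 and 9 (0-indexed)
    assert [i for i in range(10) if should_step(i, 4, 10)] == [3, 7, 9]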
@@ -843,7 +848,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
)
return result

- def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
+ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
"""
wrap the forward step in a closure so second order methods work
"""
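The closure handed to optimizer_step above exists because second-order optimizers such as LBFGS call the closure themselves, possibly several times per .step(), to re-evaluate the loss; first-order optimizers just call it once. A minimal plain-PyTorch sketch of the pattern, independent of the Lightning internals in this diff:

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    x, y = torch.randn(32, 10), torch.randn(32, 1)

    def closure():
        # LBFGS may invoke this several times per step to re-evaluate the loss
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    optimizer.step(closure)  # SGD/Adam also accept a closure and simply call it once

This is why the refactor builds train_step_and_backward_closure from training_step_and_backward and passes it down to the accelerator's optimizer_step, instead of running forward/backward eagerly and then stepping.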
24 changes: 11 additions & 13 deletions pytorch_lightning/trainer/training_loop_temp.py
@@ -214,21 +214,11 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
)
return result

- def optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
- # calls .step(), .zero_grad()
- # override function to modify this behavior

+ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
with self.trainer.profiler.profile('optimizer_step'):
- lambda_closure = lambda: self.trainer.optimizer_closure(
- split_batch,
- batch_idx,
- opt_idx,
- optimizer,
- self.trainer.hiddens,
- ).loss

# optimizer step lightningModule hook
- self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure)
+ self.trainer.accelerator_backend.optimizer_step(optimizer, batch_idx, opt_idx,
+ train_step_and_backward_closure)

def on_before_zero_grad(self, optimizer):
model = self.trainer.get_model()
@@ -280,3 +270,11 @@ def process_hiddens(self, opt_closure_result):
if isinstance(opt_closure_result.training_step_output, Result):
opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
return hiddens

+ def tbptt_split_batch(self, batch):
+ splits = [batch]
+ if self.trainer.truncated_bptt_steps is not None:
+ model_ref = self.trainer.get_model()
+ with self.trainer.profiler.profile('tbptt_split_batch'):
+ splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
+ return splits
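tbptt_split_batch is also a LightningModule hook, so a model can control how a sequence batch is chopped up for truncated backpropagation through time. As a rough illustration (not the library's exact implementation), a default-style split for tensors shaped (batch, time, ...) could look like:

    import torch

    def tbptt_split_batch(batch, split_size):
        # split every tensor in the batch along the time dimension (dim=1)
        seq_len = batch[0].size(1)
        return [[x[:, t:t + split_size] for x in batch]
                for t in range(0, seq_len, split_size)]

    x = torch.randn(8, 100, 32)   # (batch, time, features)
    y = torch.randn(8, 100, 1)
    splits = tbptt_split_batch([x, y], split_size=20)
    assert len(splits) == 5 and splits[0][0].shape == (8, 20, 32)

Each split is then fed through training_step, with the hidden state carried across splits via self.hiddens as shown in the loop above.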
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer_steps_dict_return.py
@@ -47,7 +47,7 @@ def test_training_step_dict(tmpdir):
assert pbar_metrics['pbar_acc2'] == 19.0

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_result_return.py
@@ -84,7 +84,7 @@ def test_training_step_result_log_step_only(tmpdir):
assert f'step_log_acc2_b{batch_idx}' in train_step_out

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


@@ -158,7 +158,7 @@ def test_training_step_result_log_epoch_only(tmpdir):
assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


@@ -293,7 +293,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert 'epoch_step_epoch_log_acc2' in train_step_out

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


@@ -372,7 +372,7 @@ def test_training_step_epoch_end_result(tmpdir):
assert 'epoch_step_epoch_log_acc2' in train_step_out

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)


8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_scalar_return.py
@@ -43,7 +43,7 @@ def test_training_step_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


@@ -80,7 +80,7 @@ def training_step_scalar_with_step_end(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


@@ -127,7 +127,7 @@ def test_full_training_loop_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171


@@ -170,5 +170,5 @@ def test_train_step_epoch_end_scalar(tmpdir):
assert train_step_out.item() == 171

# make sure the optimizer closure returns the correct things
- opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
+ opt_closure_result = trainer.training_step_and_backward(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'].item() == 171
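Taken together, the control flow this refactor is heading toward looks roughly like the framework-free sketch below: split the batch, build a forward + backward closure, and let the optimizer step through that closure once accumulation is done (illustrative only, not the trainer's actual code):

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    accumulate_grad_batches = 4

    def run_training_batch(batch, batch_idx, num_training_batches):
        splits = [batch]  # a tbptt split would go here for sequence data
        for sx, sy in splits:
            def training_step_and_backward():
                # forward + backward wrapped in a closure so 2nd order optimizers can re-run it
                loss = torch.nn.functional.mse_loss(model(sx), sy)
                loss.backward()
                return loss

            accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
            is_final_batch = (batch_idx + 1) == num_training_batches
            if accumulation_done or is_final_batch:
                optimizer.step(training_step_and_backward)
                optimizer.zero_grad()
            else:
                training_step_and_backward()  # accumulate gradients without stepping

    batches = [(torch.randn(16, 10), torch.randn(16, 1)) for _ in range(10)]
    for i, b in enumerate(batches):
        run_training_batch(b, i, num_training_batches=len(batches))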
