ref: inner train loop (intermediate step) 14/n #3373

Merged: 2 commits merged on Sep 6, 2020
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1162,7 +1162,7 @@ def train(self):
                 self.train_loop.on_train_epoch_start(epoch)
 
                 # run train epoch
-                self.run_training_epoch()
+                self.train_loop.run_training_epoch()
 
                 if self.max_steps and self.max_steps <= self.global_step:

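Note for readers skimming the diff: the one-line change above switches Trainer.train() from running the epoch itself to delegating to its train_loop object. Below is a minimal, self-contained sketch of that composition pattern; the class and attribute names are illustrative stand-ins, not the actual Lightning implementation.

```python
class TrainLoop:
    """Owns the epoch logic; keeps a back-reference to the trainer for shared state."""

    def __init__(self, trainer):
        self.trainer = trainer

    def run_training_epoch(self):
        # in Lightning this is the full epoch loop; here it is just a stub
        print(f"epoch run, global_step={self.trainer.global_step}")


class Trainer:
    def __init__(self, max_epochs=1):
        self.max_epochs = max_epochs
        self.global_step = 0
        self.train_loop = TrainLoop(self)  # composition instead of a mixin

    def train(self):
        for _ in range(self.max_epochs):
            # before this PR: self.run_training_epoch() (defined on a mixin)
            # after this PR:  delegate to the TrainLoop component
            self.train_loop.run_training_epoch()


Trainer(max_epochs=2).train()
```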
98 changes: 0 additions & 98 deletions pytorch_lightning/trainer/training_loop.py
@@ -223,14 +223,6 @@ class TrainerTrainLoopMixin(ABC):
     def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""
 
-    @abstractmethod
-    def is_function_implemented(self, *args, **kwargs):
-        """Warning: this is just empty shell for code implemented in other class."""
-
-    @abstractmethod
-    def run_evaluation(self, *args, **kwargs):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def detect_nan_tensors(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -255,96 +247,6 @@ def call_hook(self, hook_name, *args, **kwargs):
     def has_arg(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def run_training_epoch(self):
-
-        # get model
-        model = self.get_model()
-
-        # modify dataloader if needed (ddp, etc...)
-        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
-
-        # track epoch output
-        epoch_output = [[] for _ in range(self.train_loop.num_optimizers)]
-
-        # enable profiling for the dataloader
-        train_dataloader = self.data_connector.get_profiled_train_dataloader(train_dataloader)
-        dataloader_idx = 0
-        for batch_idx, (batch, is_last_batch) in train_dataloader:
-            # stop epoch if we limited the number of training batches
-            if batch_idx >= self.num_training_batches:
-                break
-
-            self.batch_idx = batch_idx
-            model.global_step = self.global_step
-
-            # ------------------------------------
-            # TRAINING_STEP + TRAINING_STEP_END
-            # ------------------------------------
-            batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
-
-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
-                batch_output.training_step_output_for_epoch_end,
-                self.train_loop.early_stopping_accumulator,
-                self.train_loop.checkpoint_accumulator
-            )
-
-            # hook
-            self.train_loop.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
-
-            # when returning -1 from train_step, we end epoch early
-            self.should_stop = batch_output.signal == -1
-
-            # -----------------------------------------
-            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
-            # -----------------------------------------
-            should_check_val = self.should_check_val(batch_idx, is_last_batch)
-            if should_check_val:
-                self.run_evaluation(test_mode=False)
-
-            # -----------------------------------------
-            # SAVE LOGGERS (ie: Tensorboard, etc...)
-            # -----------------------------------------
-            self.save_loggers_in_training_loop(batch_idx)
-
-            # -----------------------------------------
-            # SAVE METRICS TO LOGGERS
-            # -----------------------------------------
-            self.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
-
-            # update LR schedulers
-            monitor_metrics = deepcopy(self.callback_metrics)
-            monitor_metrics.update(batch_output.batch_log_metrics)
-            self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
-
-            # progress global step according to grads progress
-            self.increment_accumulated_grad_global_step()
-
-            # max steps reached, end training
-            if self.max_steps is not None and self.max_steps == self.global_step:
-                break
-
-            # end epoch early
-            # stop when the flag is changed or we've gone past the amount
-            # requested in the batches
-            if self.should_stop:
-                break
-
-        # process epoch outputs
-        self.run_training_epoch_end(
-            epoch_output,
-            self.train_loop.checkpoint_accumulator,
-            self.train_loop.early_stopping_accumulator,
-            self.train_loop.num_optimizers
-        )
-
-        # checkpoint callback
-        self.check_checkpoint_callback(self.train_loop.should_check_val)
-
-        # epoch end hook
-        self.run_on_epoch_end_hook()
-
     def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
         """
         Figure out what needs to be tracked/logged at the end of the epoch
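The deletions above also retire the mixin's "empty shell" abstract stubs (is_function_implemented, run_evaluation), which existed only so the mixed-in method body could reach functionality implemented on the concrete Trainer. A rough sketch of the old pattern next to its replacement, using simplified hypothetical names rather than the real classes:

```python
from abc import ABC, abstractmethod


# Before: the mixin declares abstract "empty shells" and trusts the class it is
# mixed into to provide them, so calls resolve only through inheritance.
class TrainLoopMixinSketch(ABC):
    @abstractmethod
    def run_evaluation(self, *args, **kwargs):
        """Empty shell; the concrete Trainer implements this."""

    def run_training_epoch(self):
        self.run_evaluation(test_mode=False)  # implicit dependency on the host class


# After: a standalone object holds an explicit trainer reference, so the same
# call becomes a visible attribute access and no abstract stubs are needed.
class TrainLoopSketch:
    def __init__(self, trainer):
        self.trainer = trainer

    def run_training_epoch(self):
        self.trainer.run_evaluation(test_mode=False)  # explicit dependency
```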
92 changes: 91 additions & 1 deletion pytorch_lightning/trainer/training_loop_temp.py
@@ -10,7 +10,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.utilities.parsing import AttributeDict
-from copy import copy
+from copy import copy, deepcopy
 
 
 class TrainLoop:
@@ -299,3 +299,93 @@ def tbptt_split_batch(self, batch):
         with self.trainer.profiler.profile('tbptt_split_batch'):
             splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
         return splits
+
+    def run_training_epoch(self):
+
+        # get model
+        model = self.trainer.get_model()
+
+        # modify dataloader if needed (ddp, etc...)
+        train_dataloader = self.trainer.accelerator_backend.process_dataloader(self.trainer.train_dataloader)
+
+        # track epoch output
+        epoch_output = [[] for _ in range(self.num_optimizers)]
+
+        # enable profiling for the dataloader
+        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
+        dataloader_idx = 0
+        for batch_idx, (batch, is_last_batch) in train_dataloader:
+            # stop epoch if we limited the number of training batches
+            if batch_idx >= self.trainer.num_training_batches:
+                break
+
+            self.trainer.batch_idx = batch_idx
+            model.global_step = self.trainer.global_step
+
+            # ------------------------------------
+            # TRAINING_STEP + TRAINING_STEP_END
+            # ------------------------------------
+            batch_output = self.trainer.run_training_batch(batch, batch_idx, dataloader_idx)
+
+            # only track outputs when user implements training_epoch_end
+            # otherwise we will build up unnecessary memory
+            epoch_end_outputs = self.trainer.process_train_step_outputs(
+                batch_output.training_step_output_for_epoch_end,
+                self.early_stopping_accumulator,
+                self.checkpoint_accumulator
+            )
+
+            # hook
+            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+
+            # when returning -1 from train_step, we end epoch early
+            self.trainer.should_stop = batch_output.signal == -1
+
+            # -----------------------------------------
+            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
+            # -----------------------------------------
+            should_check_val = self.trainer.should_check_val(batch_idx, is_last_batch)
+            if should_check_val:
+                self.trainer.run_evaluation(test_mode=False)
+
+            # -----------------------------------------
+            # SAVE LOGGERS (ie: Tensorboard, etc...)
+            # -----------------------------------------
+            self.trainer.save_loggers_in_training_loop(batch_idx)
+
+            # -----------------------------------------
+            # SAVE METRICS TO LOGGERS
+            # -----------------------------------------
+            self.trainer.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
+
+            # update LR schedulers
+            monitor_metrics = deepcopy(self.trainer.callback_metrics)
+            monitor_metrics.update(batch_output.batch_log_metrics)
+            self.trainer.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
+
+            # progress global step according to grads progress
+            self.trainer.increment_accumulated_grad_global_step()
+
+            # max steps reached, end training
+            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step:
+                break
+
+            # end epoch early
+            # stop when the flag is changed or we've gone past the amount
+            # requested in the batches
+            if self.trainer.should_stop:
+                break
+
+        # process epoch outputs
+        self.trainer.run_training_epoch_end(
+            epoch_output,
+            self.checkpoint_accumulator,
+            self.early_stopping_accumulator,
+            self.num_optimizers
+        )
+
+        # checkpoint callback
+        self.check_checkpoint_callback(self.should_check_val)
+
+        # epoch end hook
+        self.trainer.run_on_epoch_end_hook()
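One small detail in the added code: deepcopy is now imported in this file because the LR-scheduler update builds monitor_metrics from a deep copy of the trainer's callback_metrics before merging in the batch's log metrics. A tiny standalone example of why the copy matters, using plain dicts rather than Lightning objects:

```python
from copy import deepcopy

# The scheduler update wants callback metrics merged with this batch's log
# metrics, without mutating the trainer's own callback_metrics dict.
callback_metrics = {"val_loss": 0.42}
batch_log_metrics = {"train_loss": 0.10}

monitor_metrics = deepcopy(callback_metrics)
monitor_metrics.update(batch_log_metrics)

assert "train_loss" not in callback_metrics  # original left untouched
print(monitor_metrics)  # {'val_loss': 0.42, 'train_loss': 0.1}
```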
2 changes: 1 addition & 1 deletion tests/core/test_results.py
@@ -133,7 +133,7 @@ def test_result_obj_predictions_ddp_spawn(tmpdir):

     prediction_file = Path('predictions.pt')
 
-    model = EvalModelTemplate()
+    model = EvalModelTemplate(learning_rate=0.002)
     model.test_option = option
     model.prediction_file = prediction_file.as_posix()
     model.test_step = model.test_step_result_preds