Merge pull request #129 from coqui-ai/fix_eval
Multiple bug fixes and add on_train_epoch_start callback
erogol authored Nov 16, 2023
2 parents 47781f5 + 5b3cb63 commit 385cced
Showing 4 changed files with 76 additions and 17 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pypi-release.yml
@@ -10,7 +10,7 @@ jobs:
build-sdist:
runs-on: ubuntu-20.04
steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
- name: Verify tag matches version
run: |
set -ex
@@ -19,7 +19,7 @@ jobs:
if [[ "$version" != "$tag" ]]; then
exit 1
fi
-      - uses: actions/setup-python@v2
+      - uses: actions/setup-python@v4
with:
python-version: 3.9
- run: |
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -21,9 +21,9 @@ jobs:
python-version: [3.8, 3.9, "3.10", "3.11"]
experimental: [false]
steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
-        uses: coqui-ai/setup-python@pip-cache-key-py-ver
+        uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
architecture: x64
42 changes: 42 additions & 0 deletions trainer/callbacks.py
@@ -7,6 +7,8 @@ def __init__(self) -> None:
self.callbacks_on_init_end = []
self.callbacks_on_epoch_start = []
self.callbacks_on_epoch_end = []
self.callbacks_on_train_epoch_start = []
self.callbacks_on_train_epoch_end = []
self.callbacks_on_train_step_start = []
self.callbacks_on_train_step_end = []
self.callbacks_on_keyboard_interrupt = []
@@ -21,6 +23,10 @@ def parse_callbacks_dict(self, callbacks_dict: Dict[str, Callable]) -> None:
self.callbacks_on_epoch_start.append(value)
elif key == "on_epoch_end":
self.callbacks_on_epoch_end.append(value)
elif key == "on_train_epoch_start":
self.callbacks_on_train_epoch_start.append(value)
elif key == "on_train_epoch_end":
self.callbacks_on_train_epoch_end.append(value)
elif key == "on_train_step_start":
self.callbacks_on_train_step_start.append(value)
elif key == "on_train_step_end":
@@ -102,6 +108,42 @@ def on_epoch_end(self, trainer) -> None:
for callback in self.callbacks_on_epoch_end:
callback(trainer)

def on_train_epoch_start(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_epoch_start"):
trainer.model.module.on_train_epoch_start(trainer)
else:
if hasattr(trainer.model, "on_train_epoch_start"):
trainer.model.on_train_epoch_start(trainer)

if hasattr(trainer.criterion, "on_train_epoch_start"):
trainer.criterion.on_train_epoch_start(trainer)

if hasattr(trainer.optimizer, "on_train_epoch_start"):
trainer.optimizer.on_train_epoch_start(trainer)

if self.callbacks_on_train_epoch_start:
for callback in self.callbacks_on_train_epoch_start:
callback(trainer)

def on_train_epoch_end(self, trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_epoch_end"):
trainer.model.module.on_train_epoch_end(trainer)
else:
if hasattr(trainer.model, "on_train_epoch_end"):
trainer.model.on_train_epoch_end(trainer)

if hasattr(trainer.criterion, "on_train_epoch_end"):
trainer.criterion.on_train_epoch_end(trainer)

if hasattr(trainer.optimizer, "on_train_epoch_end"):
trainer.optimizer.on_train_epoch_end(trainer)

if self.callbacks_on_train_epoch_end:
for callback in self.callbacks_on_train_epoch_end:
callback(trainer)

@staticmethod
def before_backward_pass(trainer, loss_dict) -> None:
if hasattr(trainer.model, "module"):
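
For orientation, here is a minimal sketch of how the new hooks can be driven from user code. Only the callback keys handled by parse_callbacks_dict above and the single trainer argument they receive come from this diff; the Trainer constructor call at the end is indicative only and is therefore commented out.

# Sketch only: the callback keys and the `trainer` argument come from
# parse_callbacks_dict above; everything else (the Trainer(...) call and the
# attributes used for printing) is illustrative.

def log_epoch_start(trainer) -> None:
    # Runs once per training epoch, before the first training step.
    print(f"starting train epoch {trainer.epochs_done}")

def log_epoch_end(trainer) -> None:
    # Runs once per training epoch, after the last training step.
    print(f"finished train epoch {trainer.epochs_done} at step {trainer.total_steps_done}")

callbacks = {
    "on_train_epoch_start": log_epoch_start,
    "on_train_epoch_end": log_epoch_end,
}

# trainer = Trainer(args, config, output_path, model=model, callbacks=callbacks)

A model, criterion, or optimizer can equally define its own on_train_epoch_start / on_train_epoch_end methods; the dispatcher above checks for them with hasattr before running the registered callables.
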
43 changes: 30 additions & 13 deletions trainer/trainer.py
@@ -443,6 +443,10 @@ def __init__( # pylint: disable=dangerous-default-value
if not self.config.log_model_step:
self.config.log_model_step = self.config.save_step

# make sure that start_with_eval is disabled if eval is disabled
if not self.config.run_eval and self.start_with_eval:
self.start_with_eval = False

self.total_steps_done = 0
self.epochs_done = 0
self.restore_step = 0
@@ -525,6 +529,16 @@ def __init__( # pylint: disable=dangerous-default-value
# setup optimizer
self.optimizer = self.get_optimizer(self.model, self.config)

# If using multiple optimizers with grad accumulation but no custom optimize method, raise an error
if (
self.grad_accum_steps != 1
and isinstance(self.optimizer, list)
and not isimplemented(self.model, "optimize")
):
raise ValueError(
" [!] Coqui Trainer does not support grad_accum_steps for multiple-optimizer setup, please set grad_accum_steps to 1 or implement in your model a custom method called ´optimize` that need to deal with dangling gradients in multiple-optimizer setup!"
)

# CALLBACK
self.callbacks = TrainerCallback()
self.callbacks.parse_callbacks_dict(callbacks)
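
The new guard can be read as a standalone predicate; the sketch below restates it outside the Trainer. The function name and the boolean standing in for isimplemented(self.model, "optimize") are not part of the library.

# Hypothetical restatement of the guard above; `has_custom_optimize` stands in
# for isimplemented(self.model, "optimize").
from typing import Any, List, Union

def check_grad_accum_setup(grad_accum_steps: int,
                           optimizer: Union[Any, List[Any]],
                           has_custom_optimize: bool) -> None:
    # Accumulating gradients across steps while several optimizers share the model
    # leaves dangling gradients between optimizer updates, so the Trainer refuses
    # this combination unless the model takes over stepping via a custom `optimize`.
    if grad_accum_steps != 1 and isinstance(optimizer, list) and not has_custom_optimize:
        raise ValueError(
            "grad_accum_steps must be 1 with multiple optimizers unless the model "
            "implements a custom `optimize` method."
        )
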
@@ -1480,6 +1494,8 @@ def train_epoch(self) -> None:
self.model.train()
epoch_start_time = time.time()

self.callbacks.on_train_epoch_start(self)

self.c_logger.print_train_start()
loader_start_time = time.time()
# TRAINING EPOCH -> iterate over the training samples
@@ -1502,6 +1518,8 @@
torch.set_grad_enabled(True)

epoch_time = time.time() - epoch_start_time
self.callbacks.on_train_epoch_end(self)

# scheduler step
if self.scheduler is not None and self.config.scheduler_after_epoch:
if isinstance(self.scheduler, list):
@@ -1884,14 +1902,12 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
def save_best_model(self) -> None:
"""Save the best model. It only saves if the current target loss is smaller then the previous."""

-        eval_loss = None
-        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
-            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
train_loss = self._pick_target_avg_loss(self.keep_avg_train)

# save the model and update the best_loss
self.best_loss = save_best_model(
-            train_loss if eval_loss is None else eval_loss,
+            eval_loss if eval_loss else train_loss,
self.best_loss,
self.config,
self.model,
@@ -1908,9 +1924,7 @@ def save_best_model(self) -> None:
@rank_zero_only
def save_checkpoint(self) -> None:
"""Save the current model checkpoint."""
-        eval_loss = None
-        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
-            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
train_loss = self._pick_target_avg_loss(self.keep_avg_train)

save_checkpoint(
@@ -2101,18 +2115,21 @@ def _detach_loss_dict(loss_dict: Dict) -> Dict:

def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
"""Pick the target loss to compare models"""

# if the keep_avg_target is None or empty return None
if keep_avg_target is None or len(list(keep_avg_target.avg_values.keys())) == 0:
return None

target_avg_loss = None
# return if target loss defined in the model config
# if not available in Dict use loss_1 as by default loss
if "target_loss" in self.config and self.config.target_loss:
if f"avg_{self.config.target_loss}" in keep_avg_target.avg_values.keys():
return keep_avg_target[f"avg_{self.config.target_loss}"]
target_loss = keep_avg_target["avg_loss_1"]
if target_loss is None:
raise ValueError(
" [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed or set the target_loss in the model config incorrectly."
)
return target_loss

raise ValueError(
" [!] Target loss not found in the keep_avg_target. You might be exiting the training loop before it is computed or set the target_loss in the model config incorrectly."
)

# take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
if isinstance(self.optimizer, list):
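
Taken together, the save_best_model and save_checkpoint changes mean the checkpoint-comparison loss now falls back from eval to train automatically. A minimal sketch of that selection, assuming _pick_target_avg_loss returns None whenever the eval averages are missing or empty; the helper name below is not part of the library.

# Sketch of the comparison loss after this change; `pick_comparison_loss` is a
# hypothetical helper mirroring `eval_loss if eval_loss else train_loss` above.
from typing import Optional

def pick_comparison_loss(eval_loss: Optional[float], train_loss: float) -> float:
    # _pick_target_avg_loss now returns None when keep_avg_eval is None or empty
    # (e.g. run_eval disabled), so the eval loss is used when available and truthy,
    # and the train loss otherwise.
    return eval_loss if eval_loss else train_loss

print(pick_comparison_loss(None, 0.7))   # 0.7  -> eval disabled, compare on train loss
print(pick_comparison_loss(0.31, 0.7))   # 0.31 -> eval ran, compare on eval loss
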
