From f7138416a305a92a1dd243cc6147b45a025318ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Dav=C3=B3?= Date: Sun, 6 Oct 2024 17:47:43 +0000 Subject: [PATCH] WideDeep Avoid eval every iter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: David Davó --- .../models/wide_deep/wide_deep_utils.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/recommenders/models/wide_deep/wide_deep_utils.py b/recommenders/models/wide_deep/wide_deep_utils.py index ba7cbefe6..7544472b7 100644 --- a/recommenders/models/wide_deep/wide_deep_utils.py +++ b/recommenders/models/wide_deep/wide_deep_utils.py @@ -212,6 +212,7 @@ def __init__( optimizer_params: dict[str, Any] = dict(), disable_batch_progress: bool = False, disable_iter_progress: bool = False, + eval_epoch: int = 1, model_dir: Optional[Union[str, Path]] = None, save_model_iter: int = -1, prediction_col: str = DEFAULT_PREDICTION_COL, @@ -255,6 +256,7 @@ def __init__( self.current_epoch = 0 self.epochs = epochs + self.eval_epoch = eval_epoch self.model_dir = Path(model_dir) if model_dir else None self.save_model_iter = save_model_iter @@ -297,7 +299,7 @@ def fit(self): pbar.update() pbar.set_postfix( train_loss=self.train_loss_history[-1], - test_loss=self.test_loss_history[-1], + test_loss=self.test_loss_history[-1][1], ) if self.save_model_iter != -1 and self.current_epoch % self.save_model_iter == 0: @@ -341,19 +343,20 @@ def fit_step(self): self.train_loss_history.append(train_loss / len(self.train_dataloader)) self.model.eval() - num_batches = len(self.test_dataloader) - test_loss = 0 - - with torch.no_grad(): - for X, y in self.test_dataloader: - pred = self.model( - X['interactions'], - continuous_features=X.get('continuous_features', None), - ) - test_loss += self.loss_fn(pred, y).item() - - test_loss /= num_batches - self.test_loss_history.append(test_loss) + if self.eval_epoch != -1 and self.current_epoch % self.eval_epoch == 0:
+ num_batches = len(self.test_dataloader) + test_loss = 0 + + with torch.no_grad(): + for X, y in self.test_dataloader: + pred = self.model( + X['interactions'], + continuous_features=X.get('continuous_features', None), + ) + test_loss += self.loss_fn(pred, y).item() + + test_loss /= num_batches + self.test_loss_history.append((self.current_epoch, test_loss)) self.current_epoch += 1