WideDeep Avoid eval every iter
Signed-off-by: David Davó <[email protected]>
daviddavo committed Oct 6, 2024
1 parent cc1e2f9 commit f713841
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions recommenders/models/wide_deep/wide_deep_utils.py
@@ -212,6 +212,7 @@ def __init__(
         optimizer_params: dict[str, Any] = dict(),
         disable_batch_progress: bool = False,
         disable_iter_progress: bool = False,
+        eval_epoch: int = 1,
         model_dir: Optional[Union[str, Path]] = None,
         save_model_iter: int = -1,
         prediction_col: str = DEFAULT_PREDICTION_COL,
@@ -255,6 +256,7 @@ def __init__(
 
         self.current_epoch = 0
         self.epochs = epochs
+        self.eval_epoch = eval_epoch
 
         self.model_dir = Path(model_dir) if model_dir else None
         self.save_model_iter = save_model_iter
@@ -297,7 +299,7 @@ def fit(self):
             pbar.update()
             pbar.set_postfix(
                 train_loss=self.train_loss_history[-1],
-                test_loss=self.test_loss_history[-1],
+                test_loss=self.test_loss_history[-1][1],
             )
 
             if self.save_model_iter != -1 and self.current_epoch % self.save_model_iter == 0:
@@ -341,19 +343,20 @@ def fit_step(self):
         self.train_loss_history.append(train_loss / len(self.train_dataloader))
         self.model.eval()
 
-        num_batches = len(self.test_dataloader)
-        test_loss = 0
-
-        with torch.no_grad():
-            for X, y in self.test_dataloader:
-                pred = self.model(
-                    X['interactions'],
-                    continuous_features=X.get('continuous_features', None),
-                )
-                test_loss += self.loss_fn(pred, y).item()
-
-        test_loss /= num_batches
-        self.test_loss_history.append(test_loss)
+        if self.eval_epoch != -1 and self.current_epoch % self.eval_epoch == 0:
+            num_batches = len(self.test_dataloader)
+            test_loss = 0
+
+            with torch.no_grad():
+                for X, y in self.test_dataloader:
+                    pred = self.model(
+                        X['interactions'],
+                        continuous_features=X.get('continuous_features', None),
+                    )
+                    test_loss += self.loss_fn(pred, y).item()
+
+            test_loss /= num_batches
+            self.test_loss_history.append((self.current_epoch, test_loss))
 
         self.current_epoch += 1
 
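For readers skimming the diff: the change gates the entire test pass behind a single condition. The helper below is a standalone sketch, not part of the commit; it only mirrors that condition to show which epochs end up being evaluated for a given eval_epoch value.

# Standalone sketch (not from the commit) mirroring the gate added in fit_step.
def should_evaluate(current_epoch: int, eval_epoch: int) -> bool:
    """eval_epoch=-1 disables evaluation; otherwise evaluate every eval_epoch-th epoch."""
    return eval_epoch != -1 and current_epoch % eval_epoch == 0

# eval_epoch=1 keeps the previous behaviour (evaluate every epoch).
print([e for e in range(10) if should_evaluate(e, 1)])   # [0, 1, 2, ..., 9]
print([e for e in range(10) if should_evaluate(e, 3)])   # [0, 3, 6, 9]
print([e for e in range(10) if should_evaluate(e, -1)])  # []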

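The other user-visible change is that entries in test_loss_history are now (epoch, loss) tuples rather than bare floats, which is why the progress-bar postfix in fit() reads test_loss_history[-1][1]. A minimal illustration with made-up values:

# Illustrative values only; in practice the history is filled in by fit_step.
test_loss_history = [(0, 0.71), (3, 0.64), (6, 0.59)]
last_epoch, last_loss = test_loss_history[-1]
assert test_loss_history[-1][1] == last_loss  # the indexing used in fit()'s postfix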