From 3df2dfe66f10f2a2fd150b0e789618419ece04ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Dav=C3=B3?=
Date: Sun, 6 Oct 2024 15:56:47 +0000
Subject: [PATCH] Save wide and deep model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: David Davó
---
 .../models/wide_deep/wide_deep_utils.py      | 61 ++++++++++++++++---
 1 file changed, 53 insertions(+), 8 deletions(-)

diff --git a/recommenders/models/wide_deep/wide_deep_utils.py b/recommenders/models/wide_deep/wide_deep_utils.py
index 17be4b141..adc0fdc9d 100644
--- a/recommenders/models/wide_deep/wide_deep_utils.py
+++ b/recommenders/models/wide_deep/wide_deep_utils.py
@@ -1,7 +1,8 @@
 # Copyright (c) Recommenders contributors.
 # Licensed under the MIT License.
 
-from typing import Tuple, Dict, Optional, Any
+from typing import Tuple, Dict, Optional, Any, Union
 from dataclasses import dataclass, field
+from pathlib import Path
 import numpy as np
 import pandas as pd
@@ -15,7 +16,7 @@ import recommenders.utils.python_utils as pu
 import recommenders.utils.torch_utils as tu
 
 
-@dataclass(kw_only=True, frozen=True)
+@dataclass(frozen=True)
 class WideAndDeepHyperParams:
     user_dim: int = 32
     item_dim: int = 32
@@ -182,12 +183,14 @@ def __init__(
         n_items: Optional[int] = None,
         epochs: int = 100,
         batch_size: int = 128,
-        loss_fn: str | nn.Module = 'mse',
+        loss_fn: Union[str, nn.Module] = 'mse',
         optimizer: str = 'sgd',
         l1: float = 0.0001,
         optimizer_params: dict[str, Any] = dict(),
         disable_batch_progress: bool = False,
         disable_iter_progress: bool = False,
+        model_dir: Optional[Union[str, Path]] = None,
+        save_model_iter: int = -1,
         prediction_col: str = DEFAULT_PREDICTION_COL,
     ):
         self.n_users = n_users or max(train.n_users, test.n_users)
@@ -230,6 +233,10 @@ def __init__(
         self.current_epoch = 0
         self.epochs = epochs
 
+        self.model_dir = Path(model_dir) if model_dir else None
+        self.save_model_iter = save_model_iter
+        self._check_save_model()
+
         self.train_loss_history = list()
         self.test_loss_history = list()
 
@@ -237,9 +244,24 @@ def __init__(
     def user_col(self) -> str:
         return self.train.user_col
 
+    @property
+    def model_path(self) -> Path:
+        return self.model_dir / f'wide_deep_state_{self.current_epoch:05d}.pth'
+
     @property
     def item_col(self) -> str:
         return self.train.item_col
+
+    def _check_save_model(self) -> bool:
+        # The two conditions should be True/False at the same time
+        if (self.save_model_iter == -1) != (self.model_dir is None):
+            raise ValueError('You should set both save_model_iter and model_dir at the same time')
+
+        if self.model_dir is not None:
+            # Check that save works
+            self.save()
+
+        return True
 
     def fit(self):
         if self.current_epoch >= self.epochs:
@@ -255,6 +277,26 @@ def fit(self):
                 test_loss=self.test_loss_history[-1],
             )
 
+            if self.save_model_iter != -1 and self.current_epoch % self.save_model_iter == 0:
+                self.save()
+
+    def save(self, model_path=None):
+        model_path = Path(model_path) if model_path else self.model_path
+        model_path.parent.mkdir(exist_ok=True)
+
+        torch.save(self.model.state_dict(), model_path)
+
+    def load(self, model_path=None):
+        if model_path is None:
+            print('Model path not specified, automatically loading from model dir')
+            model_path = max(self.model_dir.glob('*.pth'), key=lambda f: int(f.stem.split('_')[-1]))
+            print(' Loading', model_path)
+        else:
+            model_path = Path(model_path)
+
+        self.model.load_state_dict(torch.load(model_path))
+        self.current_epoch = int(model_path.stem.split('_')[-1])
+
     def fit_step(self):
         self.model.train()
 
@@ -292,9 +334,7 @@ def fit_step(self):
 
         self.current_epoch += 1
 
-    def recommend_k_items(
-        self, user_ids=None, item_ids=None, top_k=10, remove_seen=True,
-    ):
+    def _get_uip_cont(self, user_ids, item_ids, remove_seen: bool):
         if user_ids is None:
             user_ids = np.arange(1, self.n_users)
         if item_ids is None:
@@ -316,8 +356,13 @@ def recommend_k_items(
         cont_features = torch.from_numpy(
             np.stack(uip.map(lambda x: self.train._get_continuous_features(*x)).values)
         )
-
-        uip = uip.to_frame(index=False)
+
+        return uip.to_frame(index=False), cont_features
+
+    def recommend_k_items(
+        self, user_ids=None, item_ids=None, top_k=10, remove_seen=True,
+    ):
+        uip, cont_features = self._get_uip_cont(user_ids, item_ids, remove_seen)
 
         with torch.no_grad():
             uip[self.prediction_col] = self.model(
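
Usage sketch (not part of the patch): the snippet below shows how the checkpointing arguments added
by this diff would be used, assuming a trainer class with the constructor shown above. The class name
WideAndDeepModel and the train_set/test_set dataset objects are placeholders, not names defined in the
patch; model_dir, save_model_iter, fit(), save() and load() are the names the patch introduces.

    from pathlib import Path

    # Hypothetical names: WideAndDeepModel, train_set and test_set are illustrative only.
    trainer = WideAndDeepModel(
        train_set,
        test_set,
        epochs=50,
        model_dir=Path("checkpoints"),  # new argument: directory where state dicts are written
        save_model_iter=10,             # new argument: save every 10 epochs; -1 (default) disables saving
    )

    trainer.fit()  # writes checkpoints/wide_deep_state_000NN.pth every 10 epochs

    # Resume later: load() with no argument picks the checkpoint with the highest epoch number,
    # restores the state dict, and sets current_epoch from the file name.
    trainer.load()
    # Or restore a specific epoch explicitly:
    trainer.load("checkpoints/wide_deep_state_00010.pth")

Note that, per _check_save_model(), passing only one of model_dir/save_model_iter raises ValueError,
and the constructor immediately writes an epoch-0 checkpoint as a smoke test, so model_dir must be
writable before training starts.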