diff --git a/recommenders/models/wide_deep/wide_deep_utils.py b/recommenders/models/wide_deep/wide_deep_utils.py
index adc0fdc9d..ba7cbefe6 100644
--- a/recommenders/models/wide_deep/wide_deep_utils.py
+++ b/recommenders/models/wide_deep/wide_deep_utils.py
@@ -116,6 +116,9 @@ def __init__(
     ):
         self._check_cols_df('ratings', ratings, [user_col, item_col, rating_col])
         self._check_cols_df('item_feat', item_feat, [item_col])
+        self._check_cols_df('user_feat', user_feat, [user_col])
+        self._check_cont_features('item_feat', item_feat, item_col)
+        self._check_cont_features('user_feat', user_feat, user_col)
 
         self.user_col = user_col
         self.item_col = item_col
@@ -123,7 +126,7 @@ def __init__(
         self.ratings = ratings.copy()
         self.item_feat = item_feat.set_index(item_col).copy() if item_feat is not None else pd.DataFrame()
         self.user_feat = user_feat.set_index(user_col).copy() if user_feat is not None else pd.DataFrame()
-        self.n_cont_features = n_cont_features or len(self._get_continuous_features(self.item_feat.index.min(), self.user_feat.index.min()))
+        self.n_cont_features = n_cont_features or self._get_continuous_features([self.item_feat.index.min()], [self.user_feat.index.min()]).shape[1]
         self.n_users = n_users or ratings[user_col].max()+1
         self.n_items = n_items or ratings[item_col].max()+1
 
@@ -140,23 +143,55 @@ def _check_cols_df(df_name: str, df: Optional[pd.DataFrame], cols: list[str]) ->
                 raise ValueError(f"Column '{c}' is not present on {df_name}")
 
         return True
+
+    @staticmethod
+    def _check_cont_features(df_name: str, df: Optional[pd.DataFrame], col: str) -> bool:
+        """Check that every column of ``df`` other than ``col`` has a numeric dtype.
+
+        Returns True when ``df`` is None or empty, or when all feature columns are numeric.
+
+        Raises:
+            ValueError: if a feature column has a non-numeric dtype.
+        """
+        if df is None or df.empty:
+            return True
+
+        for c in df.columns:
+            if c == col:
+                continue
+
+            # Any numeric dtype is accepted (not only float)
+            if not pd.api.types.is_numeric_dtype(df[c].dtype):
+                raise ValueError(f"Column '{c}' from {df_name} has dtype {df[c].dtype}. Only numeric dtypes are allowed.")
+
+        return True
 
     def __len__(self):
         return len(self.ratings)
 
-    def _get_continuous_features(self, item_id, user_id) -> np.array:
+    def _get_continuous_features(self, item_ids, user_ids) -> np.ndarray:
+        """Return the continuous features of each (item, user) pair as a 2-D float32 array.
+
+        Row i holds the features of item_ids[i] followed by those of user_ids[i];
+        the array has zero columns when neither feature table is set.
+
+        Raises:
+            ValueError: if item_ids and user_ids differ in length.
+        """
+        if len(item_ids) != len(user_ids):
+            raise ValueError('item_ids and user_ids should have same length')
+
-        # Put empty array so concat doesn't fail
-        continuous_features = [np.array([])]
+        # Seed with a zero-width block so the result stays 2-D even without features
+        continuous_features = [np.empty((len(item_ids), 0), dtype='float32')]
 
         if not self.item_feat.empty:
-            feats = self.item_feat.loc[item_id]
-            continuous_features.extend(np.array(f).reshape(-1) for f in feats)
+            continuous_features.append(self.item_feat.loc[item_ids].values)
 
         if not self.user_feat.empty:
-            feats = self.user_feat.loc[user_id]
-            continuous_features.extend(np.array(f).reshape(-1) for f in feats)
+            continuous_features.append(self.user_feat.loc[user_ids].values)
 
-        return np.concatenate(continuous_features).astype('float32')
+        # Concatenate column-wise so every row is one (item, user) pair
+        return np.concatenate(continuous_features, axis=1).astype('float32')
 
     def __getitem__(self, idx):
         # TODO: Get additional embeddings too (e.g: user demographics)
@@ -167,7 +202,8 @@ def __getitem__(self, idx):
         }
 
         if self.n_cont_features:
-            ret['continuous_features'] = self._get_continuous_features(item[self.item_col], item[self.user_col])
+            # Reshape because it is only one item
+            ret['continuous_features'] = self._get_continuous_features([item[self.item_col]], [item[self.user_col]]).reshape(-1)
 
         return ret, self.ratings[self.rating_col].iloc[idx]
 
@@ -350,14 +386,16 @@ def _get_uip_cont(self, user_ids, item_ids, remove_seen: bool):
                 self.train.ratings.set_index([self.user_col, self.item_col]).index
             )
 
+        uip = uip.to_frame(index=False)
+
         cont_features = None
-        # TODO: [!] CACHE THE "RANKING POOL" (uip and cont_features) IT TAKES SEVERAL SECONDS TO GEN
         if self.train.n_cont_features > 0:
+            # uip appears to be (user, item) ordered (it is compared with set_index([user_col, item_col]) above); _get_continuous_features takes (item_ids, user_ids)
             cont_features = torch.from_numpy(
-                np.stack(uip.map(lambda x: self.train._get_continuous_features(*x)).values)
+                self.train._get_continuous_features(uip.values[:, 1], uip.values[:, 0])
             )
 
-        return uip.to_frame(index=False), cont_features
+        return uip, cont_features
 
     def recommend_k_items(
         self, user_ids=None, item_ids=None, top_k=10, remove_seen=True,