Skip to content

Commit

Permalink
Speed up WideAndDeep._get_uip_cont by 10x
Browse files Browse the repository at this point in the history
Signed-off-by: David Davó <[email protected]>
  • Loading branch information
daviddavo committed Oct 6, 2024
1 parent 3df2dfe commit cc1e2f9
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions recommenders/models/wide_deep/wide_deep_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,17 @@ def __init__(
):
self._check_cols_df('ratings', ratings, [user_col, item_col, rating_col])
self._check_cols_df('item_feat', item_feat, [item_col])
self._check_cols_df('user_feat', user_feat, [user_col])
self._check_cont_features('item_feat', item_feat, item_col)
self._check_cont_features('user_feat', user_feat, user_col)

self.user_col = user_col
self.item_col = item_col
self.rating_col = rating_col
self.ratings = ratings.copy()
self.item_feat = item_feat.set_index(item_col).copy() if item_feat is not None else pd.DataFrame()
self.user_feat = user_feat.set_index(user_col).copy() if user_feat is not None else pd.DataFrame()
self.n_cont_features = n_cont_features or len(self._get_continuous_features(self.item_feat.index.min(), self.user_feat.index.min()))
self.n_cont_features = n_cont_features or self._get_continuous_features([self.item_feat.index.min()], [self.user_feat.index.min()]).shape[1]

self.n_users = n_users or ratings[user_col].max()+1
self.n_items = n_items or ratings[item_col].max()+1
Expand All @@ -140,23 +143,42 @@ def _check_cols_df(df_name: str, df: Optional[pd.DataFrame], cols: list[str]) ->
raise ValueError(f"Column '{c}' is not present on {df_name}")

return True

@staticmethod
def _check_cont_features(df_name: str, df: Optional[pd.DataFrame], col: str) -> bool:
    """Validate that every feature column of ``df`` has a numeric dtype.

    Args:
        df_name: Name of the dataframe, used only in the error message.
        df: Feature dataframe to validate. ``None`` or an empty dataframe
            is accepted without further checks.
        col: Name of the id column, which is excluded from the dtype check.

    Returns:
        ``True`` when the dataframe is missing, empty, or all of its
        feature columns are numeric.

    Raises:
        ValueError: If any column other than ``col`` has a non-numeric dtype.
    """
    if df is None or df.empty:
        return True

    for c in df.columns:
        if c == col:
            continue

        # Any numeric dtype is accepted, not only floats.
        if not pd.api.types.is_numeric_dtype(df[c].dtype):
            raise ValueError(f"Column '{c}' from {df_name} has dtype {df[c].dtype}. Only numeric dtypes are allowed.")

    # Honour the declared ``-> bool`` return type on the success path too,
    # instead of implicitly returning None.
    return True

def __len__(self):
    """Return the number of rows held in ``self.ratings``."""
    n_ratings = len(self.ratings)
    return n_ratings

def _get_continuous_features(self, item_ids, user_ids) -> np.array:
    """Return the continuous features for a batch of (item, user) pairs.

    Row ``i`` of the result holds the features of ``item_ids[i]`` looked up
    in ``self.item_feat`` followed by the features of ``user_ids[i]`` looked
    up in ``self.user_feat``. A feature table that is empty contributes no
    columns.

    Args:
        item_ids: Sequence of labels used with ``self.item_feat.loc``.
        user_ids: Sequence of labels used with ``self.user_feat.loc``.
            Must have the same length as ``item_ids``.

    Returns:
        A 2-D float32 array with one row per pair. When both feature tables
        are empty the array has zero columns, so ``.shape[1]`` and
        ``.reshape(-1)`` remain valid for callers.

    Raises:
        ValueError: If ``item_ids`` and ``user_ids`` differ in length.
    """
    if len(item_ids) != len(user_ids):
        raise ValueError('item_ids and user_ids should have same length')

    feature_blocks = []

    if not self.item_feat.empty:
        feature_blocks.append(self.item_feat.loc[item_ids].values)

    if not self.user_feat.empty:
        feature_blocks.append(self.user_feat.loc[user_ids].values)

    if not feature_blocks:
        # Keep the result 2-D so that callers reading ``.shape[1]`` get 0
        # instead of an IndexError.
        return np.empty((len(item_ids), 0), dtype='float32')

    # Join item and user features side by side (one row per pair). Stacking
    # the individual rows would interleave items and users along axis 0 and
    # fail outright when the two tables have different numbers of columns.
    # The float32 cast matches what the earlier single-pair implementation
    # returned.
    return np.concatenate(feature_blocks, axis=1).astype('float32')

def __getitem__(self, idx):
# TODO: Get additional embeddings too (e.g: user demographics)
Expand All @@ -167,7 +189,8 @@ def __getitem__(self, idx):
}

if self.n_cont_features:
ret['continuous_features'] = self._get_continuous_features(item[self.item_col], item[self.user_col])
# Reshape because it is only one item
ret['continuous_features'] = self._get_continuous_features([item[self.item_col]], [item[self.user_col]]).reshape(-1)

return ret, self.ratings[self.rating_col].iloc[idx]

Expand Down Expand Up @@ -350,14 +373,15 @@ def _get_uip_cont(self, user_ids, item_ids, remove_seen: bool):
self.train.ratings.set_index([self.user_col, self.item_col]).index
)

uip = uip.to_frame(index=False)

cont_features = None
# TODO: [!] CACHE THE "RANKING POOL" (uip and cont_features) IT TAKES SEVERAL SECONDS TO GEN
if self.train.n_cont_features > 0:
cont_features = torch.from_numpy(
np.stack(uip.map(lambda x: self.train._get_continuous_features(*x)).values)
self.train._get_continuous_features(uip.values[:,0], uip.values[:,1])
)

return uip.to_frame(index=False), cont_features
return uip, cont_features

def recommend_k_items(
self, user_ids=None, item_ids=None, top_k=10, remove_seen=True,
Expand Down

0 comments on commit cc1e2f9

Please sign in to comment.