-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathbaseline_link.py
96 lines (82 loc) · 3.15 KB
/
baseline_link.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
from collections import Counter
from typing import Dict
import numpy as np
import pandas as pd
from torch_geometric.seed import seed_everything
from relbench.base import Dataset, RecommendationTask, Table
from relbench.datasets import get_dataset
from relbench.tasks import get_task
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="rel-stack")
parser.add_argument("--task", type=str, default="user-post-comment")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
seed_everything(args.seed)
dataset: Dataset = get_dataset(args.dataset, download=True)
task: RecommendationTask = get_task(args.dataset, args.task, download=True)
train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")
trainval_table_df = pd.concat([train_table.df, val_table.df], axis=0)
trainval_table = Table(
df=trainval_table_df,
fkey_col_to_pkey_table=train_table.fkey_col_to_pkey_table,
pkey_col=train_table.pkey_col,
time_col=train_table.time_col,
)
def past_visit_aggr(x):
lst_cat = []
for e in list(x):
lst_cat.extend(e)
counter = Counter(lst_cat)
topk = [elem for elem, _ in counter.most_common(task.eval_k)]
# padding
if len(topk) < task.eval_k:
topk.extend([-1] * (task.eval_k - len(topk)))
return topk
def evaluate(
train_table: Table,
pred_table: Table,
name: str,
) -> Dict[str, float]:
is_test = task.dst_entity_col not in pred_table.df
if name == "past_visit":
"""Predict the most frequently-visited dst nodes per each src node."""
df = (
train_table.df.groupby(task.src_entity_col)[task.dst_entity_col]
.apply(past_visit_aggr)
.reset_index(name="__pred__")
)
pred_ser = pd.merge(pred_table.df, df, how="left", on=task.src_entity_col)[
"__pred__"
]
# Replace NaN with [-1, -1, ..., -1] prediction
pred_ser = pred_ser.apply(
lambda x: x if isinstance(x, list) else [-1] * task.eval_k
)
pred = np.stack(pred_ser.values)
elif name == "global_popularity":
"""Predict the globally most visited dst nodes and predict them across the src
nodes."""
lst_cat = []
for lst in train_table.df[task.dst_entity_col]:
lst_cat.extend(lst)
counter = Counter(lst_cat)
topk = [elem for elem, _ in counter.most_common(task.eval_k)]
# padding
if len(topk) < task.eval_k:
topk.extend([-1] * (task.eval_k - len(topk)))
pred = np.tile(np.array(topk), (len(pred_table), 1))
else:
raise ValueError("Unknown eval name called {name}.")
return task.evaluate(pred, None if is_test else pred_table)
eval_name_list = ["past_visit", "global_popularity"]
for name in eval_name_list:
train_metrics = evaluate(train_table, train_table, name=name)
val_metrics = evaluate(train_table, val_table, name=name)
test_metrics = evaluate(trainval_table, test_table, name=name)
print(f"{name}:")
print(f"Train: {train_metrics}")
print(f"Val: {val_metrics}")
print(f"Test: {test_metrics}")