Skip to content

Commit b396778

Browse files
author
Joey
committed
first draft
1 parent dd3569d commit b396778

15 files changed

+2567
-1
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
*.pyc

README.md

+28-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,28 @@
1-
# TAFA
1+
# Why_I_like_it Improvement
2+
For quick run/debug:
3+
```bash
4+
python main.py --model [model]
5+
```
6+
To tune hyperparameters:
7+
```bash
8+
python tune_parameters_new.py -y config/test.yml
9+
```
10+
where test.yml looks something like this:
11+
```yaml
12+
parameters:
13+
model: rnn_autorec
14+
lam: [1, 10, 100, 1000]
15+
rank: [50, 100, 200, 500]
16+
rec_batch_size: [100]
17+
rec_epoch: [10]
18+
iteration: [10]
19+
lang_learning_rate: [0.0001]
20+
rec_learning_rate: [0.0001]
21+
lang_feature_batch_size: [32]
22+
predict_loss_positive_only: [0]
23+
fix_encoder: [0, 1]
24+
topK: [[5, 10, 15, 20, 50]]
25+
criteria: [NDCG]
26+
metric: [[R-Precision, NDCG, Precision, Recall]]
27+
```
28+
See `utility/modelnames_new.py` for the full list of model names.

main.py

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from utility.argument_check import *
2+
import torch.nn as nn
3+
import git
4+
from models.TAFA import tafa
5+
from utility.data import *
6+
from scipy.sparse import csr_matrix
7+
from utility.model_helper import binarize_dataset, convert_to_rating_matrix_from_lists
8+
from utility.progress import WorkSplitter
9+
from utility.predictor import predict
10+
from utility.metrics import evaluate
11+
12+
13+
def main(args):
    """Train the model on the selected dataset and print ranking metrics.

    The dataset loader is chosen from ``args.data_directory`` (Yelp or
    Amazon). Ratings are optionally binarized, the model is trained through
    ``tafa``, and the best prediction returned by training is turned into
    top-k recommendations that are evaluated against the validation ratings.

    Args:
        args: Parsed command-line namespace. This function reads
            ``data_directory``, ``model``, ``one_class``,
            ``one_class_threshold`` and ``top_k`` directly; every attribute
            is additionally forwarded to ``tafa`` as a keyword argument.

    Raises:
        NotImplementedError: If ``args.data_directory`` contains neither
            ``'yelp'`` nor ``'amazon'``.
    """
    progress = WorkSplitter()
    # vars() hands back the namespace's own __dict__, so the 'train', 'val'
    # and 'document_data' keys added further down also become attributes of
    # ``args`` itself.
    args_dictionary = vars(args)

    # The commit hash is only printed for traceability; do not abort the run
    # when the script is started outside of a git checkout.
    try:
        repo = git.Repo(search_parent_directories=True)
    except git.exc.InvalidGitRepositoryError:
        sha = 'unknown'
    else:
        sha = repo.head.object.hexsha
    print('current git hash is {0}'.format(sha))
    print('loading data directory: {0}'.format(args.data_directory))
    print("Algorithm: {0}".format(args.model))

    progress.section("Loading Data")
    if 'yelp' in args.data_directory:
        load_dataset = load_dataset_yelp
    elif 'amazon' in args.data_directory:
        load_dataset = load_dataset_amazon
    else:
        raise NotImplementedError(
            "no dataset loader for data_directory {0!r}; the path must contain "
            "'yelp' or 'amazon'".format(args.data_directory))
    (train_users, train_items, train_ratings,
     val_users, val_items, val_ratings,
     user_documents, item_documents, word_dict, word_embeddings) = load_dataset(args)

    if args.one_class:
        train_users, train_items, train_ratings = binarize_dataset(
            args.one_class_threshold, train_users, train_items, train_ratings)
        val_users, val_items, val_ratings = binarize_dataset(
            args.one_class_threshold, val_users, val_items, val_ratings)

    train = (train_users, train_items, train_ratings)
    validation = (val_users, val_items, val_ratings)
    document_data = (user_documents, item_documents, word_dict, word_embeddings)

    # tafa receives every command-line option plus the loaded data as keywords.
    args_dictionary['train'] = train
    args_dictionary['val'] = validation
    args_dictionary['document_data'] = document_data
    training_result = tafa(**args_dictionary)

    progress.section("Predict")
    # Generate train and validation matrices; their dimensions are taken from
    # the number of user and item documents.
    num_users = len(user_documents)
    num_items = len(item_documents)
    matrix_train = convert_to_rating_matrix_from_lists(
        num_users, num_items, train_users, train_items, train_ratings, True)
    matrix_val = convert_to_rating_matrix_from_lists(
        num_users, num_items, val_users, val_items, val_ratings, True)
    matrix_train_csr = csr_matrix(matrix_train)
    matrix_val_csr = csr_matrix(matrix_val)

    # NOTE(review): the training matrix is presumably passed so that items a
    # user has already rated are excluded from the top-k list -- confirm in
    # utility.predictor.predict.
    prediction = predict(training_result['best_prediction'], None, None, args.top_k, matrix_train_csr)

    progress.subsection("Evaluation")
    metric_names = ['R-Precision', 'NDCG', 'Clicks', 'Recall', 'Precision', 'MAP']
    result = evaluate(prediction, matrix_val_csr, metric_names, [5, 10, 15, 20, 30, 40, 50])
    for metric in result.keys():
        print("{0}:{1}".format(metric, result[metric]))
    print('best iteration is {0}'.format(training_result['best_iteration']))
63+
64+
65+
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs.
    # They are registered in exactly this order.
    # NOTE: --encoder_rnn_type and --decoder_loss have no ``type`` converter,
    # so a value given on the command line arrives as a plain string instead
    # of the torch object used as the default.
    option_table = [
        ('--data_directory', {'type': str, 'default': "/home/joey/Documents/tafa/amazon_music/"}),
        ('--model', {'type': str, 'default': "rnn_nceautorec_ee"}),
        ('--iteration', {'type': check_int_positive, 'default': 200}),
        ('--lam', {'type': check_float_positive, 'default': 1}),
        ('--rank', {'type': check_int_positive, 'default': 500}),
        ('--optimizer', {'type': str, 'default': 'Adam'}),
        ('--rec_learning_rate', {'type': check_float_positive, 'default': 1e-4}),
        ('--lang_learning_rate', {'type': check_float_positive, 'default': 1e-4}),
        ('--glove_embedding_size', {'type': check_int_positive, 'default': 300}),
        ('--elmo_embedding_size', {'type': int, 'default': None}),
        ('--encoder_hidden_size', {'type': check_int_positive, 'default': 64}),
        ('--attention_size', {'type': check_int_positive, 'default': 256}),
        ('--dropout_p', {'type': float, 'default': 0.5}),
        ('--encoder_num_layers', {'type': check_int_positive, 'default': 1}),
        ('--encoder_dropout_rate', {'type': float, 'default': 0.0}),
        ('--encoder_rnn_type', {'default': nn.LSTM}),
        ('--encoder_concat_layers', {'type': int, 'default': 0}),
        ('--attention_hidden_size', {'type': check_int_positive, 'default': 256}),
        ('--attention_dropout_rate', {'type': float, 'default': 0.0}),
        ('--mc_times', {'type': int, 'default': 5}),
        ('--separate', {'type': int, 'default': 1}),
        ('--decoder_hidden_size', {'type': check_int_positive, 'default': 64}),
        ('--decoder_dropout_rate', {'type': float, 'default': 0.0}),
        ('--decoder_loss', {'default': nn.CrossEntropyLoss(reduction='none')}),
        ('--feature_mask', {'type': int, 'default': 0}),
        ('--custom_mask', {'type': int, 'default': 0}),
        ('--num_heads', {'type': check_int_positive, 'default': 2}),
        ('--activation_function', {'type': str, 'default': 'relu'}),
        ('--loss_function', {'type': str, 'default': 'mse'}),
        ('--one_class', {'type': int, 'default': 1}),
        ('--one_class_threshold', {'type': int, 'default': 3}),  # binarize threshold
        ('--threshold', {'type': int, 'default': -1}),  # for nce
        ('--root', {'type': check_float_positive, 'default': 1}),
        ('--mode', {'type': str, 'default': 'joint'}),
        ('--nce_loss_positive_only', {'type': int, 'default': 0}),
        ('--predict_loss_positive_only', {'type': int, 'default': 0}),
        ('--sample_strategy', {'type': str, 'default': 'random_batch'}),
        ('--distance_dropout_prob', {'type': float, 'default': 0.1}),
        ('--alpha', {'type': float, 'default': 0.2}),
        ('--beta', {'type': float, 'default': 1.0}),
        ('--cml_embedding_dim', {'type': check_int_positive, 'default': 5}),
        ('--norm_factor', {'type': float, 'default': 1.0}),
        ('--momentum', {'type': float, 'default': 0.0}),
        ('--weight_decay', {'type': float, 'default': 0.0}),
        ('--max_len', {'type': check_int_positive, 'default': 302}),
        ('--rec_batch_size', {'type': int, 'default': 100}),
        ('--lang_feature_batch_size', {'type': int, 'default': 32}),
        ('--max_lang_iterations', {'type': int, 'default': 128}),
        ('--gradient_clipping', {'type': float, 'default': 10.0}),
        ('--rec_epoch', {'type': int, 'default': -1}),
        ('--fix_encoder', {'type': int, 'default': 0}),
        ('--criteria', {'type': str, 'default': 'NDCG'}),
        ('--top_k', {'type': check_int_positive, 'default': 50}),
    ]

    parser = argparse.ArgumentParser('main function')
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)

0 commit comments

Comments
 (0)