import argparse

import git
import torch.nn as nn
from scipy.sparse import csr_matrix

from models.TAFA import tafa
from utility.argument_check import *
from utility.data import *
from utility.model_helper import binarize_dataset, convert_to_rating_matrix_from_lists
from utility.progress import WorkSplitter
from utility.predictor import predict
from utility.metrics import evaluate


def main(args):
    progress = WorkSplitter()
    args_dictionary = vars(args)
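
    # Log the current git commit so each run can be traced to an exact code version.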
    repo = git.Repo(search_parent_directories=True)
    sha = repo.head.object.hexsha
    print('current git hash is {0}'.format(sha))
    print('loading data directory: {0}'.format(args.data_directory))
    print("Algorithm: {0}".format(args.model))

    progress.section("Loading Data")
    if 'yelp' in args.data_directory:
        (train_users, train_items, train_ratings,
         val_users, val_items, val_ratings,
         user_documents, item_documents,
         word_dict, word_embeddings) = load_dataset_yelp(args)
    elif 'amazon' in args.data_directory:
        (train_users, train_items, train_ratings,
         val_users, val_items, val_ratings,
         user_documents, item_documents,
         word_dict, word_embeddings) = load_dataset_amazon(args)
    else:
        raise NotImplementedError
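
    # Optionally collapse explicit ratings into one-class (implicit) feedback by
    # thresholding at one_class_threshold (see binarize_dataset in utility.model_helper).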
    if args.one_class:
        train_users, train_items, train_ratings = binarize_dataset(
            args.one_class_threshold, train_users, train_items, train_ratings)
        val_users, val_items, val_ratings = binarize_dataset(
            args.one_class_threshold, val_users, val_items, val_ratings)

    train = (train_users, train_items, train_ratings)
    validation = (val_users, val_items, val_ratings)
    document_data = (user_documents, item_documents, word_dict, word_embeddings)

    # The model consumes everything as keyword arguments, so the data tuples are
    # attached to the parsed-argument dictionary before the call.
    args_dictionary['train'] = train
    args_dictionary['val'] = validation
    args_dictionary['document_data'] = document_data
    training_result = tafa(**args_dictionary)
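
    # The trainer returns its best scores in training_result['best_prediction'];
    # these are ranked and evaluated below.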
    progress.section("Predict")
    # Rebuild dense train/validation rating matrices from the triples, then convert
    # them to sparse CSR format for ranking and evaluation.
    num_users = len(user_documents)
    num_items = len(item_documents)
    matrix_train = convert_to_rating_matrix_from_lists(
        num_users, num_items, train_users, train_items, train_ratings, True)
    matrix_val = convert_to_rating_matrix_from_lists(
        num_users, num_items, val_users, val_items, val_ratings, True)
    matrix_train_csr = csr_matrix(matrix_train)
    matrix_val_csr = csr_matrix(matrix_val)

    # Produce top-k recommendations from the best predictions; the training matrix is
    # passed in, presumably so already-observed interactions can be masked out.
    prediction = predict(training_result['best_prediction'], None, None, args.top_k, matrix_train_csr)
    # np.save('{2}/R_{0}_{1}.npy'.format(args.model, args.rank, 'latent'), prediction)
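
    # Compute ranking metrics on the validation matrix at several cutoff depths.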
    progress.subsection("Evaluation")
    metric_names = ['R-Precision', 'NDCG', 'Clicks', 'Recall', 'Precision', 'MAP']
    result = evaluate(prediction, matrix_val_csr, metric_names, [5, 10, 15, 20, 30, 40, 50])
    for metric, value in result.items():
        print("{0}:{1}".format(metric, value))
    print('best iteration is {0}'.format(training_result['best_iteration']))


if __name__ == "__main__":
    # Command-line arguments
    parser = argparse.ArgumentParser('main function')
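    # check_int_positive / check_float_positive are validators from
    # utility.argument_check; by their names, they reject non-positive values.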
    parser.add_argument('--data_directory', type=str, default="/home/joey/Documents/tafa/amazon_music/")
    parser.add_argument('--model', type=str, default="rnn_nceautorec_ee")
    parser.add_argument('--iteration', type=check_int_positive, default=200)
    parser.add_argument('--lam', type=check_float_positive, default=1)
    parser.add_argument('--rank', type=check_int_positive, default=500)
    parser.add_argument('--optimizer', type=str, default='Adam')
    parser.add_argument('--rec_learning_rate', type=check_float_positive, default=1e-4)
    parser.add_argument('--lang_learning_rate', type=check_float_positive, default=1e-4)
    parser.add_argument('--glove_embedding_size', type=check_int_positive, default=300)
    parser.add_argument('--elmo_embedding_size', type=int, default=None)
    parser.add_argument('--encoder_hidden_size', type=check_int_positive, default=64)
    parser.add_argument('--attention_size', type=check_int_positive, default=256)
    parser.add_argument('--dropout_p', type=float, default=0.5)
    parser.add_argument('--encoder_num_layers', type=check_int_positive, default=1)
    parser.add_argument('--encoder_dropout_rate', type=float, default=0.0)
    # Note: the default is a class (nn.LSTM); a command-line value would arrive as a
    # plain string, so this is effectively configurable in code only.
    parser.add_argument('--encoder_rnn_type', default=nn.LSTM)
    parser.add_argument('--encoder_concat_layers', type=int, default=0)
    parser.add_argument('--attention_hidden_size', type=check_int_positive, default=256)
    parser.add_argument('--attention_dropout_rate', type=float, default=0.0)
    parser.add_argument('--mc_times', type=int, default=5)
    parser.add_argument('--separate', type=int, default=1)
    parser.add_argument('--decoder_hidden_size', type=check_int_positive, default=64)
    parser.add_argument('--decoder_dropout_rate', type=float, default=0.0)
    # Likewise, the decoder loss default is a module instance, not a string.
    parser.add_argument('--decoder_loss', default=nn.CrossEntropyLoss(reduction='none'))
    parser.add_argument('--feature_mask', type=int, default=0)
    parser.add_argument('--custom_mask', type=int, default=0)
    parser.add_argument('--num_heads', type=check_int_positive, default=2)
    parser.add_argument('--activation_function', type=str, default='relu')
    parser.add_argument('--loss_function', type=str, default='mse')
    parser.add_argument('--one_class', type=int, default=1)
    parser.add_argument('--one_class_threshold', type=int, default=3)  # binarize threshold
    parser.add_argument('--threshold', type=int, default=-1)  # for nce
    parser.add_argument('--root', type=check_float_positive, default=1)
    parser.add_argument('--mode', type=str, default='joint')
    parser.add_argument('--nce_loss_positive_only', type=int, default=0)
    parser.add_argument('--predict_loss_positive_only', type=int, default=0)
    parser.add_argument('--sample_strategy', type=str, default='random_batch')
    parser.add_argument('--distance_dropout_prob', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--cml_embedding_dim', type=check_int_positive, default=5)
    parser.add_argument('--norm_factor', type=float, default=1.0)
    parser.add_argument('--momentum', type=float, default=0.0)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--max_len', type=check_int_positive, default=302)
    parser.add_argument('--rec_batch_size', type=int, default=100)
    parser.add_argument('--lang_feature_batch_size', type=int, default=32)
    parser.add_argument('--max_lang_iterations', type=int, default=128)
    parser.add_argument('--gradient_clipping', type=float, default=10.0)
    parser.add_argument('--rec_epoch', type=int, default=-1)
    parser.add_argument('--fix_encoder', type=int, default=0)
    parser.add_argument('--criteria', type=str, default='NDCG')
    parser.add_argument('--top_k', type=check_int_positive, default=50)
    args = parser.parse_args()
    main(args)
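
# Example invocation (illustrative; the script name, paths, and flag values here
# are assumptions -- adjust them to your environment):
#   python main.py --data_directory /path/to/amazon_music/ --model rnn_nceautorec_ee \
#       --one_class 1 --one_class_threshold 3 --top_k 50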