A PyTorch implementation of DeePRed (WWW'21), an algorithm for next-item prediction in temporal interaction networks.
Zekarias T. Kefato, Sarunas Girdzijauskas, Nasrullah Sheikh, and Alberto Montresor. 2021. Dynamic Embeddings
for Interaction Prediction. In Proceedings of the Web Conference 2021 (WWW’21), April 19–23, 2021, Ljubljana,
Slovenia. ACM, New York, NY, USA, 10 pages. https://doi.org/10.1145/3442381.3450020
- Python 3.6+
- PyTorch 1.4+
- NumPy 1.17.2+
- NetworkX 2.3+
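The dependencies can be installed with pip; a minimal sketch, assuming the usual PyPI package names (adjust the PyTorch install to your CUDA setup if needed):
$ pip install torch numpy networkx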
The following trains DeePRed on the wikipedia dataset:
$ cd deepred
$ python train.py --name wikipedia
The trained model can then be evaluated as:
$ cd deepred
$ python evaluate.py --name wikipedia --epoch 2
This evaluates the model saved after training DeePRed
for 2 epochs on the wikipedia dataset.
To reproduce the results for the three datasets used in the paper, set the `--name` argument to one of `reddit`, `wikipedia`, or `lastfm`.
You do not need a local copy of these datasets; they are downloaded automatically on first use.
Otherwise, if you have them locally, please follow the Using local datasets guideline below.
Alternatively, the following shell scripts can be used to achieve the above:
$ bash train.sh wikipedia
$ bash eval.sh wikipedia
$ bash deepred.sh wikipedia
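To reproduce all three datasets in one go, a simple loop over the dataset names works (a sketch, assuming each script takes the dataset name as its only argument, as above):
$ for name in reddit wikipedia lastfm; do bash deepred.sh $name; done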
To use DeePRed with datasets other than the above three, please follow the Using local datasets guideline below.
The following arguments can be used for training and evaluating DeePRed.

- `--root`: A path to the directory where the dataset will be saved. Default is `../../../data/deepred`.
- `--name`: The name of the dataset. Default is `wikipedia`.
- `--batch-size`: Batch size. Default is 512.
- `--workers`: Number of parallel workers. Default is 32.
- `--verbose`: Make the logger verbose. Default is False.
- `--nbr-size`: The neighborhood size, i.e. the value of k (the number of short-term events) in the paper. Default is 100.
- `--dim`: The size of the embedding dimension. Default is 128.
- `--lr`: Learning rate. Default is 0.0001.
- `--reg-cof`: A regularization coefficient to avoid collapse into a subspace.
- `--dropout`: A dropout rate to avoid overfitting. Default is 0.5.
- `--epochs`: The number of epochs. Default is 2.
- `--temporal`: A flag to indicate that the dataset is a temporal interaction network. Default is True.
- `--static`: A flag to indicate that the dataset is a static interaction network. Default is False. Support for static networks is temporarily suspended; if you wish to use DeePRed for static networks, please use the legacy code inside the `legacy` directory.
- `--k`: k for recall@k. Default is 10.
- `--epoch`: The specific epoch you want to evaluate. Default is 1.
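For example, to train on the reddit dataset with a larger neighborhood and then evaluate the resulting model (a usage sketch; all flags are documented above, and omitted flags keep their defaults):
$ python train.py --name reddit --nbr-size 200 --epochs 5
$ python evaluate.py --name reddit --epoch 5 --k 10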
Below are the directory structures of the program and the data, for ease of navigation.
./
├── deepred/
│ ├── utils/
│ │ ├── datasets.py
│ │ └── helpers.py
│ ├── model/
│ │ └── deepred.py
│ ├── train.py
│ ├── evaluate.py
│ ├── train.sh
│ ├── eval.sh
│ ├── deepred.sh
│ └── README.md
/<root>/
├── <name>/
│ ├── raw/
│ │ └── <name>.csv
│ ├── processed/
│ │ ├── data.pt
│ │ ├── nodes.pt
│ │ └── splits.pt
│ ├── model/
│ │ ├── deepred.epoch.<epoch_num>.pt
│ │ └── deepred.config.<name>.json
│ └── result/
│     └── next.item.prediction.result.epoch.<epoch_num>.txt
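The files under `processed/` appear to be PyTorch-serialized (`.pt`), so they can be inspected with torch.load. A minimal sketch, assuming the files deserialize into plain Python/PyTorch objects (their exact contents are defined by the preprocessing code in utils/datasets.py):

```python
import os
import torch

# Matches the --root and --name defaults documented above.
root, name = "../../../data/deepred", "wikipedia"
processed = os.path.join(root, name, "processed")

# Load the preprocessed interactions, node index, and train/test splits.
data = torch.load(os.path.join(processed, "data.pt"))
nodes = torch.load(os.path.join(processed, "nodes.pt"))
splits = torch.load(os.path.join(processed, "splits.pt"))

print(type(data), type(nodes), type(splits))
```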
If you have a local copy of the three datasets (`reddit`, `wikipedia`, `lastfm`), or you would like to use DeePRed for other datasets, just put a `<name>.csv` file under the `/<root>/<name>/raw/` directory. For example, if you are working with the wikipedia dataset, you should have the following file:
/<root>/wikipedia/raw/wikipedia.csv
It should be a comma-separated temporal interaction network.
Example:
user,item,timestamp
5,16,1
40,6,4
14,5,7
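If your interactions are not already in this format, a short script can write the expected file. A minimal sketch, assuming your events are (user, item, timestamp) triples (the interactions list below is a hypothetical placeholder):

```python
import csv
import os

# Matches the --root and --name defaults documented above.
root, name = "../../../data/deepred", "wikipedia"

# Hypothetical placeholder events: (user, item, timestamp) triples.
interactions = [(5, 16, 1), (40, 6, 4), (14, 5, 7)]

raw_dir = os.path.join(root, name, "raw")
os.makedirs(raw_dir, exist_ok=True)
with open(os.path.join(raw_dir, f"{name}.csv"), "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["user", "item", "timestamp"])  # header row as in the example above
    writer.writerows(interactions)
```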
Then, simply run
$ cd deepred
$ python train.py --name wikipedia
to train DeePRed on the wikipedia dataset, and run
$ cd deepred
$ python evaluate.py --name wikipedia
to evaluate it.
If you find our work useful, please consider citing it as:
@misc{kefato2021dynamic,
title={Dynamic Embeddings for Interaction Prediction},
author={Zekarias T. Kefato and Sarunas Girdzijauskas and Nasrullah Sheikh and Alberto Montresor},
year={2021},
eprint={2011.05208},
archivePrefix={arXiv},
primaryClass={cs.LG}
}