Skip to content

Commit

Permalink
porting to nowplaying
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII committed Mar 18, 2020
1 parent 9c10283 commit 43cbc10
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 37 deletions.
23 changes: 14 additions & 9 deletions examples/pytorch/pinsage/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
import scipy.sparse as ssp

def train_test_split_by_time(g, column):
def train_test_split_by_time(g, column, etype, itype):
n_edges = g.number_of_edges(etype)
with g.local_scope():
def splits(edges):
num_edges, count = edges.data['train_mask'].shape
Expand All @@ -27,15 +28,15 @@ def splits(edges):
val_mask[x, sorted_idx[:, -2]] = True
return {'train_mask': train_mask, 'val_mask': val_mask, 'test_mask': test_mask}

g.edges['watched'].data['train_mask'] = torch.ones(n_edges, dtype=torch.bool)
g.edges['watched'].data['val_mask'] = torch.zeros(n_edges, dtype=torch.bool)
g.edges['watched'].data['test_mask'] = torch.zeros(n_edges, dtype=torch.bool)
g.nodes['movie'].data['count'] = g.in_degrees(etype='watched')
g.group_apply_edges('src', splits, etype='watched')
g.edges[etype].data['train_mask'] = torch.ones(n_edges, dtype=torch.bool)
g.edges[etype].data['val_mask'] = torch.zeros(n_edges, dtype=torch.bool)
g.edges[etype].data['test_mask'] = torch.zeros(n_edges, dtype=torch.bool)
g.nodes[itype].data['count'] = g.in_degrees(etype=etype)
g.group_apply_edges('src', splits, etype=etype)

train_indices = g.filter_edges(lambda edges: edges.data['train_mask'], etype='watched')
val_indices = g.filter_edges(lambda edges: edges.data['val_mask'], etype='watched')
test_indices = g.filter_edges(lambda edges: edges.data['test_mask'], etype='watched')
train_indices = g.filter_edges(lambda edges: edges.data['train_mask'], etype=etype)
val_indices = g.filter_edges(lambda edges: edges.data['val_mask'], etype=etype)
test_indices = g.filter_edges(lambda edges: edges.data['test_mask'], etype=etype)

return train_indices, val_indices, test_indices

Expand Down Expand Up @@ -70,3 +71,7 @@ def build_val_test_matrix(g, val_indices, test_indices, utype, itype, etype):
test_matrix = ssp.coo_matrix((np.ones_like(test_src), (test_src, test_dst)), (n_users, n_items))

return val_matrix, test_matrix

def linear_normalize(values):
return (values - values.min(0, keepdims=True)) / \
(values.max(0, keepdims=True) - values.min(0, keepdims=True))
2 changes: 1 addition & 1 deletion examples/pytorch/pinsage/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(self, hidden_dims, n_layers):

def forward(self, blocks, h):
for layer, block in zip(self.convs, blocks):
h_dst = h[:block.number_of_nodes(block.dsttype)]
h_dst = h[:block.number_of_nodes('DST/' + block.ntypes[0])]
h = layer(block, (h, h_dst), block.edata['weights'])
return h

Expand Down
24 changes: 7 additions & 17 deletions examples/pytorch/pinsage/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext
import dgl
import tqdm
Expand Down Expand Up @@ -98,11 +99,13 @@ def train(dataset, args):
# Sampler
batch_sampler = sampler_module.ItemToItemBatchSampler(
g, user_ntype, item_ntype, args.batch_size)
batch_sampler_it = iter(batch_sampler)
neighbor_sampler = sampler_module.NeighborSampler(
g, user_ntype, item_ntype, args.random_walk_length,
args.random_walk_restart_prob, args.num_random_walks, args.num_neighbors,
args.num_layers)
collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
dataloader = DataLoader(batch_sampler, collate_fn=collator, num_workers=args.num_workers)
dataloader_it = iter(dataloader)

# Model
model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers).to(device)
Expand All @@ -116,18 +119,7 @@ def train(dataset, args):
for epoch_id in range(args.num_epochs):
model.train()
for batch_id in tqdm.trange(args.batches_per_epoch):
heads, tails, neg_tails = next(batch_sampler_it)
# Train
# Construct multilayer neighborhood via PinSAGE...
pos_graph, neg_graph, blocks = \
neighbor_sampler.sample_from_item_pairs(heads, tails, neg_tails)
# For the first block (which is closest to the input), copy the features from
# the original graph as well as the texts.
sampler_module.assign_simple_node_features(blocks[0].srcdata, g, item_ntype)
sampler_module.assign_textual_node_features(blocks[0].srcdata, textset, item_ntype)
sampler_module.assign_simple_node_features(blocks[-1].dstdata, g, item_ntype)
sampler_module.assign_textual_node_features(blocks[-1].dstdata, textset, item_ntype)

pos_graph, neg_graph, blocks = next(dataloader_it)
# Copy to GPU
for i in range(len(blocks)):
blocks[i] = blocks[i].to(device)
Expand All @@ -145,10 +137,7 @@ def train(dataset, args):
h_item_batches = []
for item_batch in item_batches:
blocks = neighbor_sampler.sample_blocks(item_batch)
sampler_module.assign_simple_node_features(blocks[0].srcdata, g, item_ntype)
sampler_module.assign_textual_node_features(blocks[0].srcdata, textset, item_ntype)
sampler_module.assign_simple_node_features(blocks[-1].dstdata, g, item_ntype)
sampler_module.assign_textual_node_features(blocks[-1].dstdata, textset, item_ntype)
sampler_module.assign_features_to_blocks(blocks, g, textset, item_ntype)

for i in range(len(blocks)):
blocks[i] = blocks[i].to(device)
Expand All @@ -173,6 +162,7 @@ def train(dataset, args):
parser.add_argument('--device', type=str, default='cpu') # can also be "cuda:0"
parser.add_argument('--num-epochs', type=int, default=1)
parser.add_argument('--batches-per-epoch', type=int, default=20000)
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--lr', type=float, default=3e-5)
parser.add_argument('-k', type=int, default=10)
args = parser.parse_args()
Expand Down
3 changes: 1 addition & 2 deletions examples/pytorch/pinsage/process_movielens1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@
# Train-validation-test split
# This is a little bit tricky as we want to select the last interaction for test, and the
# second-to-last interaction for validation.
n_edges = g.number_of_edges('watched')
train_indices, val_indices, test_indices = train_test_split_by_time(g, 'timestamp')
train_indices, val_indices, test_indices = train_test_split_by_time(g, 'timestamp', 'watched', 'movie')

# Build the graph with training interactions only.
train_g = build_train_graph(g, train_indices, 'user', 'movie', 'watched', 'watched-by')
Expand Down
21 changes: 14 additions & 7 deletions examples/pytorch/pinsage/process_nowplaying_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import argparse
import pandas as pd
import scipy.sparse as ssp
import pickle
from data_utils import *
from builder import PandasGraphBuilder

Expand All @@ -18,13 +19,13 @@
output_path = args.output_path

data = pd.read_csv(os.path.join(directory, 'context_content_features.csv'))
track_feature_cols = list(x.columns[1:13])
track_feature_cols = list(data.columns[1:13])
data = data[['user_id', 'track_id', 'created_at'] + track_feature_cols].dropna()

users = data[['user_id']].drop_duplicates()
tracks = data[['track_id'] + list(x.columns[1:13])].drop_duplicates()
tracks = data[['track_id'] + track_feature_cols].drop_duplicates()
assert tracks['track_id'].value_counts().max() == 1
tracks = tracks.astype({'mode': 'int64', 'key': 'int64'})
tracks = tracks.astype({'mode': 'int64', 'key': 'int64', 'artist_id': 'category'})
events = data[['user_id', 'track_id', 'created_at']]
events['created_at'] = events['created_at'].values.astype('datetime64[s]').astype('int64')

Expand All @@ -36,16 +37,22 @@

g = graph_builder.build()

float_cols = []
for col in tracks.columns:
if col == 'artist_id':
if col == 'track_id':
continue
elif col == 'artist_id':
g.nodes['track'].data[col] = torch.LongTensor(tracks[col].cat.codes.values)
elif tracks.dtypes[col] == 'float64':
g.nodes['track'].data[col] = torch.FloatTensor(tracks[col].values)
float_cols.append(col)
else:
g.nodes['track'].data[col] = torch.LongTensor(tracks[col].values)
g.nodes['track'].data['song_features'] = torch.FloatTensor(linear_normalize(tracks[float_cols].values))
g.edges['listened'].data['created_at'] = torch.LongTensor(events['created_at'].values)
g.edges['listened-by'].data['created_at'] = torch.LongTensor(events['created_at'].values)

n_edges = g.number_of_edges('listened')
train_indices, val_indices, test_indices = train_test_split_by_time(g, 'created_at')
train_indices, val_indices, test_indices = train_test_split_by_time(g, 'created_at', 'listened', 'track')
train_g = build_train_graph(g, train_indices, 'user', 'track', 'listened', 'listened-by')
val_matrix, test_matrix = build_val_test_matrix(
g, val_indices, test_indices, 'user', 'track', 'listened')
Expand All @@ -54,7 +61,7 @@
'train-graph': train_g,
'val-matrix': val_matrix,
'test-matrix': test_matrix,
'item-texts': None,
'item-texts': {},
'item-images': None,
'user-type': 'user',
'item-type': 'track',
Expand Down
26 changes: 25 additions & 1 deletion examples/pytorch/pinsage/sampler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import numpy as np
import dgl
import torch
from torch.utils.data import IterableDataset, DataLoader

def compact_and_copy(frontier, seeds):
block = dgl.to_block(frontier, seeds)
for col, data in frontier.edata.items():
block.edata[col] = data
return block

class ItemToItemBatchSampler(object):
class ItemToItemBatchSampler(IterableDataset):
def __init__(self, g, user_type, item_type, batch_size):
self.g = g
self.user_type = user_type
Expand Down Expand Up @@ -117,3 +118,26 @@ def assign_textual_node_features(ndata, textset, ntype):

ndata[field_name] = tokens
ndata[field_name + '__len'] = lengths

def assign_features_to_blocks(blocks, g, textset, ntype):
# For the first block (which is closest to the input), copy the features from
# the original graph as well as the texts.
assign_simple_node_features(blocks[0].srcdata, g, ntype)
assign_textual_node_features(blocks[0].srcdata, textset, ntype)
assign_simple_node_features(blocks[-1].dstdata, g, ntype)
assign_textual_node_features(blocks[-1].dstdata, textset, ntype)

class PinSAGECollator(object):
def __init__(self, sampler, g, ntype, textset):
self.sampler = sampler
self.ntype = ntype
self.g = g
self.textset = textset

def __call__(self, batches):
heads, tails, neg_tails = batches[0]
# Construct multilayer neighborhood via PinSAGE...
pos_graph, neg_graph, blocks = self.sampler.sample_from_item_pairs(heads, tails, neg_tails)
assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)

return pos_graph, neg_graph, blocks

0 comments on commit 43cbc10

Please sign in to comment.