Skip to content

Commit

Permalink
Merge pull request #374 from PaddlePaddle/graph4rec
Browse files Browse the repository at this point in the history
Graph4Rec: add a demo with slot features
  • Loading branch information
Yelrose authored Mar 22, 2022
2 parents 0121d96 + df71104 commit 229873d
Show file tree
Hide file tree
Showing 8 changed files with 2,125 additions and 17 deletions.
26 changes: 13 additions & 13 deletions apps/Graph4Rec/env_run/src/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pgl.distributed import DistGraphClient, DistGraphServer
from pgl.utils.data import Dataloader, StreamDataset
from pgl.distributed import helper
from pgl.utils.logger import log

from utils.config import prepare_config
from datasets.node import NodeGenerator
Expand Down Expand Up @@ -227,6 +228,11 @@ def _distcpu_train(self):


def index2segment_id(num_count):
"""
num_count: [1, 2, 1]
return:
[0, 1, 1, 2]
"""
index = np.cumsum(num_count, dtype="int64")
index = np.insert(index, 0, 0)

Expand Down Expand Up @@ -256,16 +262,13 @@ def __call__(self, batch_data):

for slot in self.config.slots:
feed_dict[slot].extend(src.feature[slot][0])
# lod_id2segment_id
segments = index2segment_id(src.feature[slot][1])
segments = segments + offset
# print(src.node_id, slot, src.feature[slot][1], segments)
segments = np.array(
src.feature[slot][1], dtype="int64") + offset
feed_dict["%s_info" % slot].append(segments)

feed_dict[slot].extend(pos.feature[slot][0])
# lod_id2segment_id
segments = index2segment_id(pos.feature[slot][1])
segments = segments + offset + 1
segments = np.array(
pos.feature[slot][1], dtype="int64") + offset + 1
feed_dict["%s_info" % slot].append(segments)

offset += 2
Expand Down Expand Up @@ -503,15 +506,12 @@ def __call__(self, batch_data):

for slot in self.config.slots:
feed_dict[slot].extend(src.feature[slot][0])
# lod_id2segment_id
segments = index2segment_id(src.feature[slot][1])
segments = segments + offset
# print(src.node_id, slot, src.feature[slot][1], segments)
segments = np.array(
src.feature[slot][1], dtype="int64") + offset
feed_dict["%s_info" % slot].append(segments)

feed_dict[slot].extend(pos.feature[slot][0])
# lod_id2segment_id
segments = index2segment_id(pos.feature[slot][1])
segments = np.array(pos.feature[slot][1], dtype="int64")
segments = segments + offset + len(
src.node_id) # for pos offset
feed_dict["%s_info" % slot].append(segments)
Expand Down
23 changes: 19 additions & 4 deletions apps/Graph4Rec/env_run/src/datasets/ego_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def ego_graph_sample(graph, node_ids, samples, edge_types):
node_feat={"node_id": np.array(
ego.node_id, dtype="int64"), })
ego.graph = pg

return ego_graph_list, list(unique_nodes)


Expand Down Expand Up @@ -166,35 +165,51 @@ def get_slots_feat(graph, nodes, slots):

def make_slot_feat(node_id, slots, node_feat_dict):
    """Flatten per-node slot features into parallel value / segment-id lists.

    For every slot the result holds a pair ``(values, segments)`` of equal
    length: ``values`` is the concatenation of the slot's feature values over
    all nodes, and ``segments[i]`` is the position in ``node_id`` of the node
    that ``values[i]`` belongs to.

    Args:
        node_id: ordered iterable of node ids; each must be a key of
            ``node_feat_dict``.
        slots: iterable of slot names to extract.
        node_feat_dict: mapping ``node id -> {slot name -> sequence of values}``.

    Returns:
        dict mapping each slot name to ``(values, segments)``.

    Raises:
        KeyError: if a node id is missing from ``node_feat_dict``.
    """
    slot_dict = {slot: ([], []) for slot in slots}

    # The segment id of a node is simply its position in ``node_id``; it
    # advances once per node for every slot, whether or not the node carries
    # that slot.
    for seg_id, n in enumerate(node_id):
        nf = node_feat_dict[n]
        for slot in slots:
            values, segments = slot_dict[slot]
            if slot in nf:
                feat = nf[slot]
                values.extend(feat)
                # One segment id per value so both lists stay aligned.
                segments.extend([seg_id] * len(feat))
            else:
                # Node lacks this slot: emit a single padding value 0 so the
                # node still owns one entry in its segment.
                values.append(0)
                segments.append(seg_id)

    return slot_dict


class EgoGraphGenerator(object):
def __init__(self, config, graph, **kwargs):
    """Keep references to the config and graph and read sampling options.

    Optional keyword arguments read here:
        rank: stored as ``self.rank`` (default 0).
        nrank: stored as ``self.nrank`` (default 1).
        sample_list: stored as ``self.sample_num_list``; falls back to
            ``config.sample_num_list`` when absent.
    All keyword arguments are also kept verbatim in ``self.kwargs``.
    """
    self.config = config
    self.graph = graph
    self.kwargs = kwargs
    self.rank = kwargs.get("rank", 0)
    self.nrank = kwargs.get("nrank", 1)
    self.edge_types = graph.get_edge_types()

    # The config default is looked up eagerly, even when "sample_list" is given.
    fallback_samples = config.sample_num_list
    self.sample_num_list = kwargs.get("sample_list", fallback_samples)
    log.info("sample_num_list is %s" % repr(self.sample_num_list))

def __call__(self, generator):
    """Yield the items of ``base_ego_generator`` through an AsynchronousGenerator.

    ``generator`` is stored on the instance so that ``base_ego_generator``
    can read it; the wrapper is created with ``maxsize=10000``.
    """
    self.generator = generator
    buffered = AsynchronousGenerator(self.base_ego_generator, maxsize=10000)
    yield from buffered()

def base_ego_generator(self):
"""Input Batch of Walks
"""
for walks in generator():
for walks in self.generator():
# unique walk
nodes = []
for walk in walks:
Expand Down
9 changes: 9 additions & 0 deletions apps/Graph4Rec/env_run/src/datasets/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from utils.config import prepare_config
from datasets.node import NodeGenerator
from datasets.walk import WalkGenerator
from datasets.helper import stream_shuffle_generator, AsynchronousGenerator


class PairGenerator(object):
Expand All @@ -50,6 +51,14 @@ def __call__(self, generator):
"""
self.generator = generator

pair_generator = self.base_pair_generator
pair_generator = AsynchronousGenerator(pair_generator, maxsize=10000)

for pair in pair_generator():
yield pair

def base_pair_generator(self):

iterval = 20000000 * 24 // self.config.walk_len
pair_count = 0
for walks in self.generator():
Expand Down
Loading

0 comments on commit 229873d

Please sign in to comment.