6 changes: 6 additions & 0 deletions python/gigl/distributed/dist_ablp_neighborloader.py
@@ -280,6 +280,8 @@ def __init__(
anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE
supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
logger.info(f"local rank: {local_rank}, node rank: {node_rank}, anchor node type: {anchor_node_type}, "
f"supervision edge type: {supervision_edge_type}, supervision node type: {supervision_node_type}")

missing_edge_types = set([supervision_edge_type]) - set(dataset.graph.keys())
if missing_edge_types:
@@ -296,13 +298,15 @@ def __init__(
self._negative_label_edge_type,
) = select_label_edge_types(supervision_edge_type, dataset.graph.keys())
self._supervision_edge_type = supervision_edge_type
logger.info(f"Local rank: {local_rank}, node rank: {node_rank}, supervision edge type: {supervision_edge_type}")

positive_labels, negative_labels = get_labels_for_anchor_nodes(
dataset=dataset,
node_ids=anchor_node_ids,
positive_label_edge_type=self._positive_label_edge_type,
negative_label_edge_type=self._negative_label_edge_type,
)
logger.info(f"Local rank: {local_rank}, node rank: {node_rank}, Got labels for anchor nodes")

self.to_device = (
pin_memory_device
@@ -315,12 +319,14 @@ def __init__(
num_neighbors = patch_fanout_for_sampling(
dataset.get_edge_types(), num_neighbors
)
logger.info(f"Local rank: {local_rank}, node rank: {node_rank}, Number of neighbors: {num_neighbors}")

curr_process_nodes = shard_nodes_by_process(
input_nodes=anchor_node_ids,
local_process_rank=local_rank,
local_process_world_size=local_world_size,
)
logger.info(f"local rank: {local_rank}, node rank: {node_rank}, current process nodes: {curr_process_nodes}`")

self._node_feature_info = dataset.node_feature_info
self._edge_feature_info = dataset.edge_feature_info
34 changes: 33 additions & 1 deletion python/gigl/utils/data_splitters.py
@@ -26,11 +26,18 @@
message_passing_to_positive_label,
reverse_edge_type,
)
import gc  # needed for the gc.collect() calls added below
import os

import psutil

logger = Logger()

PADDING_NODE: Final[torch.Tensor] = torch.tensor(-1, dtype=torch.int64)

def _debug_memory_usage(prefix: str):
# Logs the resident set size (RSS) of the current process alongside total system memory.
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
logger.info(
f"{prefix} Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB "
f"(out of {psutil.virtual_memory().total / 1024 / 1024:.2f} MB)"
)


class NodeAnchorLinkSplitter(Protocol):
"""Protocol that should be satisfied for anything that is used to split on edges.
@@ -648,6 +655,8 @@ def _get_padded_labels(
# and indices is the COL_INDEX of a CSR matrix.
# See https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
# Note that GLT defaults to CSR under the hood, if this changes, we will need to update this.
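# Illustrative example (not from the codebase): for a 3-node graph with edges
# 0->1, 0->2, 1->2, CSR gives indptr=[0, 2, 3, 3] and indices=[1, 2, 2];
# node i's neighbors are indices[indptr[i]:indptr[i+1]].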
_debug_memory_usage("Before indptr and indices")

indptr = topo.indptr # [N]
indices = topo.indices # [M]
extra_nodes_to_pad = 0
@@ -657,27 +666,50 @@
anchor_node_ids = anchor_node_ids[valid_ids]
starts = indptr[anchor_node_ids] # [N]
ends = indptr[anchor_node_ids + 1] # [N]
_debug_memory_usage("After starts and ends")

max_range = int(torch.max(ends - starts).item())
logger.info(f"max range: {max_range}")
# Compute the padding mask and the max index up front so that `ends` can be
# freed early; slots where the mask is True are filled with `PADDING_NODE` below.
mask = torch.arange(max_range) >= (ends - starts).unsqueeze(1)
max_end_value = ends.max().item()
_debug_memory_usage("After max_end_value")
del ends
gc.collect()
_debug_memory_usage("After ends gc")


logger.info(f"Local Rank {torch.distributed.get_rank() % torch.distributed.get_world_size()}, "
f"Node Rank {torch.distributed.get_rank() // torch.distributed.get_world_size()}: "
f"Get padded labels")
# Sample all labels based on the CSR start/stop indices.
# Creates "indices" for us to use, e.g. [[0, 1], [2, 3]]
ranges = starts.unsqueeze(1) + torch.arange(max_range) # [N, max_range]
_debug_memory_usage("After ranges")
del starts
gc.collect()
_debug_memory_usage("After starts gc")

# Clamp the ranges to be valid indices into `indices`.
ranges.clamp_(min=0, max=max_end_value - 1)
_debug_memory_usage("After clamp")
# `mask` was computed above, before `starts` and `ends` were freed.
labels = torch.where(
mask, torch.full_like(ranges, PADDING_NODE.item()), indices[ranges]
)
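# Illustrative example (hypothetical values): starts=[0, 2], ends=[2, 3],
# max_range=2, indices=[1, 2, 2] gives ranges=[[0, 1], [2, 3]], clamped to
# [[0, 1], [2, 2]], and mask=[[False, False], [False, True]], so
# labels=[[1, 2], [2, -1]]: the out-of-range slot becomes PADDING_NODE.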
_debug_memory_usage("After labels")
del ranges
gc.collect()
_debug_memory_usage("After ranges gc")
labels = torch.cat(
[
labels,
torch.ones(extra_nodes_to_pad, max_range, dtype=torch.int64) * PADDING_NODE,
],
dim=0,
)
_debug_memory_usage("After cat")
return labels

