2 changes: 1 addition & 1 deletion examples/link_prediction/heterogeneous_training.py
@@ -365,6 +365,7 @@ def _training_process(
         remove_accidental_hits=True,
     )
 
+    training_start_time = time.time()  # Initialize before conditional block
     if not should_skip_training:
         train_main_loader, train_random_negative_loader = _setup_dataloaders(
             dataset=dataset,
@@ -425,7 +426,6 @@ def _training_process(
         torch.distributed.barrier()
 
         # Entering the training loop
-        training_start_time = time.time()
         batch_idx = 0
         avg_train_loss = 0.0
         last_n_batch_avg_loss: list[float] = []
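
Note: this hunk is the template for most of the PR. `training_start_time` was bound only inside `if not should_skip_training:` but read after the block, so mypy's possibly-undefined check (enabled in mypy.ini below) flags it; on the skip path the old code could hit a NameError at its first read. A minimal sketch of the before/after shape, with hypothetical names rather than code from this repo:

    import time

    def process(skip_training: bool) -> float:
        # Before (flagged): start was bound only when skip_training was False.
        # After: bind it unconditionally, then branch.
        start = time.time()
        if not skip_training:
            pass  # training loop would run here
        return time.time() - start  # always defined
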
2 changes: 1 addition & 1 deletion examples/link_prediction/homogeneous_training.py
@@ -327,6 +327,7 @@ def _training_process(
         remove_accidental_hits=True,
     )
 
+    training_start_time = time.time()  # Initialize before conditional block
     if not should_skip_training:
         train_main_loader, train_random_negative_loader = _setup_dataloaders(
             dataset=dataset,
@@ -387,7 +388,6 @@ def _training_process(
         torch.distributed.barrier()
 
         # Entering the training loop
-        training_start_time = time.time()
         batch_idx = 0
         avg_train_loss = 0.0
         last_n_batch_avg_loss: list[float] = []
2 changes: 2 additions & 0 deletions examples/tutorial/KDD_2025/heterogeneous_training.py
@@ -171,6 +171,7 @@ def train(
     )
     logger.info(f"Process {process_number} initialized model: {model}")
     optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)
+    loss = torch.tensor(0.0)  # Initialize loss in case loop doesn't execute
     for batch_idx, main_data in enumerate(train_loader):
         if batch_idx >= max_training_batches:
             break
@@ -249,6 +250,7 @@ def train(
     )
     assert isinstance(dataset.train_node_ids, Mapping)
     process_count = int(args.process_count)
+    max_training_batches = 0  # Initialize in case loop doesn't execute
     for node_type, node_ids in dataset.train_node_ids.items():
         logger.info(f"Training node type {node_type} has {node_ids.size(0)} nodes.")
         max_training_batches = node_ids.size(0) // (
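
Note: both additions here cover the case where a `for` loop body never runs: `enumerate(train_loader)` can yield zero batches, and `dataset.train_node_ids.items()` can be empty, leaving `loss` and `max_training_batches` unbound at their later uses. A minimal sketch of the pattern, with hypothetical names:

    import torch

    def train(batches: list[torch.Tensor]) -> torch.Tensor:
        loss = torch.tensor(0.0)  # bound even if `batches` is empty
        for batch in batches:
            loss = (batch.float() ** 2).mean()  # stand-in for a real loss
        return loss  # safe on the empty-loader path
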
1 change: 1 addition & 0 deletions mypy.ini
@@ -1,6 +1,7 @@
 # Global options:
 [mypy]
 python_version = 3.9
+enable_error_code = possibly-undefined
 
 # Ignore modules that don't have any existing stubs
 
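
Note: this one-line config change is what drives the rest of the PR. `possibly-undefined` is an opt-in mypy error code (available since mypy 0.991) that reports any variable bound on some but not all paths to a use site. A minimal reproduction, assuming a hypothetical demo.py:

    def f(cond: bool) -> int:
        if cond:
            x = 1
        return x  # error: Name "x" may be undefined  [possibly-undefined]

Running `mypy --enable-error-code possibly-undefined demo.py` (or setting `enable_error_code = possibly-undefined` in mypy.ini, as above) surfaces the error; binding `x` before the `if`, or adding an `else` that binds or raises, clears it.
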
14 changes: 8 additions & 6 deletions python/gigl/common/data/load_torch_tensors.py
@@ -272,6 +272,7 @@ def load_torch_tensors_from_tf_record(
         },
     )
 
+    positive_label_data_loading_process = None
     if serialized_graph_metadata.positive_label_entity_info is not None:
         positive_label_data_loading_process = ctx.Process(
             target=_data_loading_process,
@@ -287,6 +288,7 @@ def load_torch_tensors_from_tf_record(
     else:
         logger.info(f"No positive labels detected from input data")
 
+    negative_label_data_loading_process = None
     if serialized_graph_metadata.negative_label_entity_info is not None:
         negative_label_data_loading_process = ctx.Process(
             target=_data_loading_process,
@@ -307,16 +309,16 @@ def load_torch_tensors_from_tf_record(
         logger.info("Loading Serialized TFRecord Data in Parallel ...")
         node_data_loading_process.start()
         edge_data_loading_process.start()
-        if serialized_graph_metadata.positive_label_entity_info is not None:
+        if positive_label_data_loading_process is not None:
             positive_label_data_loading_process.start()
-        if serialized_graph_metadata.negative_label_entity_info is not None:
+        if negative_label_data_loading_process is not None:
             negative_label_data_loading_process.start()
 
         node_data_loading_process.join()
         edge_data_loading_process.join()
-        if serialized_graph_metadata.positive_label_entity_info is not None:
+        if positive_label_data_loading_process is not None:
             positive_label_data_loading_process.join()
-        if serialized_graph_metadata.negative_label_entity_info is not None:
+        if negative_label_data_loading_process is not None:
             negative_label_data_loading_process.join()
     else:
         # In this setting, we start and join each process one-at-a-time in order to achieve sequential tensor loading
@@ -329,10 +331,10 @@ def load_torch_tensors_from_tf_record(
         edge_data_loading_process.join()
         node_data_loading_process.start()
         node_data_loading_process.join()
-        if serialized_graph_metadata.positive_label_entity_info is not None:
+        if positive_label_data_loading_process is not None:
            positive_label_data_loading_process.start()
            positive_label_data_loading_process.join()
-        if serialized_graph_metadata.negative_label_entity_info is not None:
+        if negative_label_data_loading_process is not None:
            negative_label_data_loading_process.start()
            negative_label_data_loading_process.join()
 
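
Note: besides initializing the two handles, the start/join guards now test the handle itself (`... is not None`) instead of re-testing the metadata fields, so process creation and lifecycle can never drift out of sync, and mypy can narrow the Optional inside each guard. A minimal sketch of the pattern, with hypothetical names:

    import multiprocessing as mp
    from typing import Optional

    def load(has_labels: bool) -> None:
        label_proc: Optional[mp.Process] = None
        if has_labels:
            label_proc = mp.Process(target=print, args=("loading labels",))

        # Guard on the handle, not on the condition that created it.
        if label_proc is not None:
            label_proc.start()
        if label_proc is not None:
            label_proc.join()
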
2 changes: 2 additions & 0 deletions python/gigl/common/metrics/decorators.py
@@ -22,6 +22,7 @@ def __safely_flush_metrics(
         Callable[[], Optional[OpsMetricPublisher]]
     ]
 ) -> None:
+    metrics_instance = None
     if get_metrics_service_instance_fn is not None:
         metrics_instance = get_metrics_service_instance_fn()
     if metrics_instance is not None:
@@ -81,6 +82,7 @@ def profileit(
     def inner(func: F) -> F:
         def wrap(*args: Any, **kwargs: Any) -> Any:
             raised_exception: Optional[Exception] = None
+            result = None
             started_at = time.time()
             try:
                 result = func(*args, **kwargs)
6 changes: 6 additions & 0 deletions python/gigl/distributed/dist_ablp_neighborloader.py
@@ -280,6 +280,12 @@ def __init__(
             anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
             supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE
             supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
+        else:
+            raise ValueError(f"Unsupported input_nodes type: {type(input_nodes)}")
+            # These assignments are unreachable but help mypy understand the variables are defined
+            anchor_node_ids = torch.tensor([])
+            anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
+            supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
 
         missing_edge_types = set([supervision_edge_type]) - set(dataset.graph.keys())
         if missing_edge_types:
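
Note: two devices are combined in this hunk: an explicit `else: raise` makes the `if`/`elif` chain exhaustive, and the assignments after the `raise` are, as the inline comment says, dead code kept purely to reassure the checker. A minimal sketch of the exhaustiveness half, with hypothetical names (the decoder.py change further down uses the same shape):

    def node_dim_for(kind: str) -> int:
        if kind == "homogeneous":
            node_dim = 1
        elif kind == "heterogeneous":
            node_dim = 2
        else:
            raise ValueError(f"Unsupported kind: {kind}")
        return node_dim  # every non-raising path binds node_dim

mypy generally treats `raise` as terminating a branch, so the fallback assignments after it should not be required; they are harmless either way, since they never execute.
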
1 change: 1 addition & 0 deletions python/gigl/distributed/dist_partitioner.py
@@ -1678,6 +1678,7 @@ def partition(
             )
         else:
             partitioned_node_features = None
+            partitioned_node_labels = None
 
         if self._positive_label_edge_index is not None:
             partitioned_positive_edge_index = self.partition_labels(
7 changes: 5 additions & 2 deletions python/gigl/distributed/dist_range_partitioner.py
@@ -289,6 +289,7 @@ def edge_partition_fn(rank_indices, _):
             dim=0,
         )
 
+        partitioned_edge_features = None
         if edge_feat_dim is not None:
             if len(res_list) == 0:
                 partitioned_edge_features = torch.empty(0, edge_feat_dim)
@@ -325,8 +326,10 @@ def edge_partition_fn(rank_indices, _):
             edge_index=partitioned_edge_index,
             edge_ids=partitioned_edge_ids,
         )
-        current_feat_part = FeaturePartitionData(
-            feats=partitioned_edge_features, ids=None
+        current_feat_part = (
+            FeaturePartitionData(feats=partitioned_edge_features, ids=None)
+            if partitioned_edge_features is not None
+            else None
         )
         logger.info(
             f"Got edge range-based partition book for edge type {edge_type} on rank {self._rank} with partition bounds: {edge_partition_book.partition_bounds}"
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import cast
+from typing import List, Optional, cast
 
 import torch
 from google.protobuf.json_format import ParseDict
@@ -21,6 +21,10 @@
 from gigl.experimental.knowledge_graph_embedding.lib.config.run import RunConfig
 from gigl.experimental.knowledge_graph_embedding.lib.config.training import TrainConfig
 from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
+from gigl.src.data_preprocessor.lib.ingest.bigquery import (
+    BigqueryEdgeDataReference,
+    BigqueryNodeDataReference,
+)
 from snapchat.research.gbml import graph_schema_pb2
 
 logger = Logger()
@@ -78,9 +82,11 @@ def from_omegaconf(config: DictConfig) -> HeterogeneousGraphSparseEmbeddingConfig
         raw_graph_data = OmegaConf.select(
             config, "dataset.raw_graph_data", default=None
         )
+        graph_data: Optional[RawGraphData] = None
         if raw_graph_data:
             raw_node_data = [instantiate(entry) for entry in raw_graph_data.node_data]
             raw_edge_data = [instantiate(entry) for entry in raw_graph_data.edge_data]
+            graph_data = RawGraphData(node_data=raw_node_data, edge_data=raw_edge_data)
 
         enumerated_graph_data = OmegaConf.select(
             config, "dataset.enumerated_graph_data", default=None
@@ -92,19 +98,14 @@ def from_omegaconf(config: DictConfig) -> HeterogeneousGraphSparseEmbeddingConfig
             enumerated_edge_data = [
                 instantiate(entry) for entry in enumerated_graph_data.edge_data
             ]
+            enumerated_graph_data = EnumeratedGraphData(
+                node_data=enumerated_node_data, edge_data=enumerated_edge_data
+            )
 
         graph_config = GraphConfig(
             metadata=graph_metadata,
-            raw_graph_data=RawGraphData(
-                node_data=raw_node_data, edge_data=raw_edge_data
-            )
-            if raw_graph_data
-            else None,
-            enumerated_graph_data=EnumeratedGraphData(
-                node_data=enumerated_node_data, edge_data=enumerated_edge_data
-            )
-            if enumerated_graph_data
-            else None,
+            raw_graph_data=graph_data,
+            enumerated_graph_data=enumerated_graph_data,
         )
 
         # Build the RunConfig
2 changes: 2 additions & 0 deletions python/gigl/src/common/models/layers/decoder.py
@@ -67,4 +67,6 @@ def forward(self, query_embeddings, candidate_embeddings) -> torch.Tensor:
         elif self.decoder_type.value == "hadamard_MLP":
             hadamard_scores = query_embeddings.unsqueeze(dim=1) * candidate_embeddings
             scores = self.mlp_decoder(hadamard_scores).sum(dim=-1)
+        else:
+            raise ValueError(f"Unsupported decoder type: {self.decoder_type.value}")
         return scores
@@ -128,6 +128,7 @@ def forward(
 
         # For each head, project edge features to out_dim and correct NaNs.
         # Output shape: [num_edges, num_heads, edge_in_dim]
+        edge_emb = None
         if edge_feat is not None and self.edge_in_dim is not None:
             edge_emb = self.efeat_drop(edge_feat)
             edge_emb = torch.matmul(edge_emb, self.W_efeat).view(
@@ -150,7 +151,7 @@ def forward(
 
         h_efeat_term = (
             0
-            if edge_feat is None or self.edge_in_dim is None
+            if edge_feat is None or self.edge_in_dim is None or edge_emb is None
             else (self.a_efeat * edge_emb).sum(dim=-1)
         )
         alpha = self.leakyrelu(h_l_term + h_r_term + h_etype_term + h_efeat_term)
@@ -275,6 +275,9 @@ def trainer_config(
        """
 
        if not self._trainer_config:
+            _trainer_config: Union[
+                VertexAiResourceConfig, KFPResourceConfig, LocalResourceConfig
+            ]
            # TODO: (svij) Marked for deprecation
            if self.resource_config.HasField("trainer_config"):
                logger.warning(
@@ -286,9 +289,6 @@ def trainer_config(
                deprecated_config: DistributedTrainerConfig = (
                    self.resource_config.trainer_config
                )
-                _trainer_config: Union[
-                    VertexAiResourceConfig, KFPResourceConfig, LocalResourceConfig
-                ]
                if deprecated_config.WhichOneof(_TRAINER_CONFIG_FIELD) == _VERTEX_AI_TRAINER_CONFIG:  # type: ignore[arg-type]
                    logger.info(
                        f"Casting VertexAiTrainerConfig: ({deprecated_config.vertex_ai_trainer_config}) to VertexAiResourceConfig"
@@ -337,7 +337,8 @@ def trainer_config(
                raise ValueError(
                    f"Trainer config not found in resource config; neither trainer_config nor trainer_resource_config is set: {self.resource_config}"
                )
-        return _trainer_config
+            self._trainer_config = _trainer_config
+        return self._trainer_config
 
    @property
    def inferencer_config(
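
Note: moving the `_trainer_config: Union[...]` declaration above the deprecated-config branch lets one annotation cover every arm that assigns it, and the new tail stores the result on `self._trainer_config`, so the property only computes once. A minimal sketch of that memoized-property shape, with hypothetical names:

    from typing import Optional, Union

    class Wrapper:
        def __init__(self, use_numeric: bool) -> None:
            self._use_numeric = use_numeric
            self._cfg: Optional[Union[int, str]] = None

        @property
        def cfg(self) -> Union[int, str]:
            if self._cfg is None:
                _cfg: Union[int, str]  # one declaration covers both branches
                if self._use_numeric:
                    _cfg = 1
                else:
                    _cfg = "local"
                self._cfg = _cfg  # cache so later reads skip the branches
            return self._cfg
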
1 change: 1 addition & 0 deletions python/gigl/src/common/utils/bq.py
@@ -109,6 +109,7 @@ def count_number_of_rows_in_bq_table(
            SELECT count(1) AS ct FROM `{bq_table}`
        """
        result = self.run_query(query=ROW_COUNTING_QUERY, labels=labels)
+        n_rows = 0
        for row in result:
            n_rows = row["ct"]
        return n_rows
6 changes: 4 additions & 2 deletions python/gigl/src/inference/v1/lib/utils.py
@@ -181,6 +181,8 @@ def get_inferencer_pipeline_component_for_single_node_type(
            )
        )
    )
+    predictions = None
+    embeddings = None
    if temp_predictions_gcs_path is not None:
        predictions = (
            outputs[PREDICTION_TAGGED_OUTPUT_KEY],
@@ -201,7 +203,7 @@ def get_inferencer_pipeline_component_for_single_node_type(
        )
        predictions = outputs[PREDICTION_TAGGED_OUTPUT_KEY]
        embeddings = outputs[EMBEDDING_TAGGED_OUTPUT_KEY]
-    if temp_predictions_gcs_path is not None:
+    if temp_predictions_gcs_path is not None and predictions is not None:
        logger.info(
            f"Writing node type {node_type} temp predictions to gcs path {temp_predictions_gcs_path.uri}"
        )
@@ ... @@
            file_name_suffix=".json",
        )
    )
-    if temp_embeddings_gcs_path is not None:
+    if temp_embeddings_gcs_path is not None and embeddings is not None:
        logger.info(
            f"Writing node type {node_type} temp embeddings to gcs path {temp_embeddings_gcs_path.uri}"
        )
15 changes: 8 additions & 7 deletions python/gigl/src/mocking/lib/pyg_to_training_samples.py
@@ -342,6 +342,14 @@ def build_node_anchor_link_prediction_samples_from_pyg_heterodata(
        else None
    )
 
+    edge_label_index = hetero_data[
+        (
+            str(sample_edge_type.src_node_type),
+            str(sample_edge_type.relation),
+            str(sample_edge_type.dst_node_type),
+        )
+    ].edge_label_index
+
    if user_defined_pos_edges is not None:
        pos_node_map = sample_hydrate_user_def_edge(
            mocked_dataset_info=mocked_dataset_info,
@@ -350,13 +358,6 @@ def build_node_anchor_link_prediction_samples_from_pyg_heterodata(
    else:
        pos_node_map = defaultdict(list)
    # Create map to track each node's candidate neighbors.
-    edge_label_index = hetero_data[
-        (
-            str(sample_edge_type.src_node_type),
-            str(sample_edge_type.relation),
-            str(sample_edge_type.dst_node_type),
-        )
-    ].edge_label_index
    for src, dst in zip(edge_label_index[0].tolist(), edge_label_index[1].tolist()):
        pos_node_map[src].append(dst)
 
7 changes: 5 additions & 2 deletions python/gigl/utils/data_splitters.py
@@ -350,9 +350,12 @@ def __call__(
        # Set device explicitly here so we don't default to CPU.
        # TODO(kmonte): We should add tests for this - but we need to enable accelerators on our CI/CD first.
        # Also, maybe swap setting device until later?
-        node_id_count = torch.zeros(
-            max_node_id, dtype=torch.uint8, device=anchor_nodes.device
+        device = (
+            collected_anchor_nodes[0].device
+            if collected_anchor_nodes
+            else torch.device("cpu")
        )
+        node_id_count = torch.zeros(max_node_id, dtype=torch.uint8, device=device)
        for anchor_nodes in collected_anchor_nodes:
            # Clamp here to avoid overflow to 0
            # If we don't clamp, and the counts add up to X mod 255,
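
Note: unlike most hunks in this PR, this one looks like a live bug rather than just checker appeasement: the deleted line read `anchor_nodes.device` above the `for anchor_nodes in collected_anchor_nodes:` loop, i.e. before the loop variable could have been bound, so that path would raise NameError unless some earlier binding happened to exist. The fix derives the device from the collection itself, with a CPU fallback. A simplified sketch (hypothetical names; the real code counts into uint8 with clamping, which is omitted here):

    import torch

    def count_ids(collected: list[torch.Tensor], max_id: int) -> torch.Tensor:
        # Take the device from the data; fall back to CPU when the list is empty.
        device = collected[0].device if collected else torch.device("cpu")
        counts = torch.zeros(max_id, dtype=torch.int64, device=device)
        for ids in collected:
            counts += torch.bincount(ids, minlength=max_id)  # ids assumed < max_id
        return counts
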
@@ -1033,6 +1033,12 @@ def __run_and_check_nablp_sgs_on_homogeneous_toy_graph(
        supervision_edge_types = (
            gbml_config_pb_wrapper.task_metadata_pb_wrapper.task_metadata_pb.node_anchor_based_link_prediction_task_metadata.supervision_edge_types
        )
+        # Initialize src_node_type from the first supervision edge type
+        src_node_type = (
+            NodeType(supervision_edge_types[0].src_node_type)
+            if supervision_edge_types
+            else NodeType("default")
+        )
        if (
            not gbml_config_pb_wrapper.shared_config.should_include_isolated_nodes_in_training
        ):
@@ -290,6 +290,11 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo)
            graph_metadata_pb_wrapper.edge_types,
            serialized_positive_label_info_iterable,
        ):
+            condensed_edge_type = (
+                graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[
+                    edge_type
+                ]
+            )
            if preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb.condensed_edge_type_to_preprocessed_metadata[
                condensed_edge_type
            ].HasField(
@@ -384,6 +389,11 @@ def test_translator_correctness(self, _, mocked_dataset_info: MockedDatasetInfo)
            graph_metadata_pb_wrapper.edge_types,
            serialized_negative_label_info_iterable,
        ):
+            condensed_edge_type = (
+                graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[
+                    edge_type
+                ]
+            )
            if preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb.condensed_edge_type_to_preprocessed_metadata[
                condensed_edge_type
            ].HasField(