44 commits
789045c
Add VertexAiMultiPoolConfig to support multiple worker pools
kmonte Sep 24, 2025
13e30c4
Merge branch 'main' into kmonte/add-multipool-vai
kmontemayor2-sc Sep 25, 2025
76f94ce
Merge branch 'main' into kmonte/add-multipool-vai
kmonte Sep 26, 2025
d810e45
typo
kmonte Sep 26, 2025
5ff13b2
to more explicit configs
kmonte Sep 26, 2025
b4d35ef
wip
kmonte Sep 30, 2025
5a27057
wip
kmonte Sep 30, 2025
f8c4ab7
works
kmonte Oct 1, 2025
3a6768f
Merge branch 'main' into kmonte/launch-multipool-vai
kmonte Oct 1, 2025
6ea5b3f
tests
kmonte Oct 1, 2025
90d651c
remove
kmonte Oct 1, 2025
94037af
fix typecheck
kmonte Oct 1, 2025
298d19d
comments
kmontemayor2-sc Oct 3, 2025
7ba700b
Merge branch 'main' into kmonte/launch-multipool-vai
kmonte Oct 6, 2025
1752d91
Add get_graph_store_info to setup graph store clusters
kmonte Oct 6, 2025
6429617
add integration tests for get_graph_store_info
kmonte Oct 6, 2025
603ca6a
[AUTOMATED] Update dep.vars, and other relevant files with new image …
github-actions[bot] Oct 6, 2025
d78e938
bleg
kmonte Oct 6, 2025
de00fd2
Merge branch 'kmonte/multipool-utils' of https://github.com/Snapchat/…
kmonte Oct 6, 2025
fc9d0d0
wip
kmonte Oct 6, 2025
704cbfd
Merge branch 'main' into kmonte/multipool-utils
kmonte Oct 7, 2025
fb91f1a
bleh
kmonte Oct 7, 2025
c2b5607
fix
kmonte Oct 7, 2025
694d72b
Nightly
kmonte Oct 7, 2025
74d8df1
Add utils to parse VAI CLUSTER_SPEC
kmonte Oct 7, 2025
de1de6a
comments
kmonte Oct 7, 2025
fc0dca4
rename
kmonte Oct 7, 2025
d3319d6
fixes
kmonte Oct 7, 2025
0905664
fixes
kmonte Oct 7, 2025
112d0ad
fix
kmonte Oct 7, 2025
f621bc7
address comments
kmonte Oct 8, 2025
fee17c1
Merge branch 'main' into kmonte/parse-cluster-spec
kmonte Oct 8, 2025
9b99706
reword
kmonte Oct 8, 2025
b9a766c
Merge branch 'main' into kmonte/multipool-utils
kmonte Oct 8, 2025
c4ec660
Merge branch 'kmonte/parse-cluster-spec' into kmonte/multipool-utils
kmonte Oct 8, 2025
86eb8eb
merges
kmonte Oct 8, 2025
026301e
merge
kmonte Oct 9, 2025
af12a00
fix
kmonte Oct 9, 2025
2c13526
test fix
kmonte Oct 9, 2025
a3dea31
fix test
kmonte Oct 9, 2025
aaceeee
fixes
kmonte Oct 10, 2025
1805a8d
[AUTOMATED] Bumped version to v0.0.10
github-actions[bot] Oct 10, 2025
11390ea
fix
kmonte Oct 10, 2025
a30e0b0
Merge branch 'release/v0.0.10' into kmonte/multipool-utils
kmonte Oct 10, 2025
@@ -3,7 +3,7 @@ substitutions:
options:
logging: CLOUD_LOGGING_ONLY
steps:
- name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1
- name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1
entrypoint: /bin/bash
args:
- -c
16 changes: 8 additions & 8 deletions dep_vars.env
@@ -1,13 +1,13 @@
# Note this file only supports static key value pairs so it can be loaded by make, bash, python, and sbt without any additional parsing.
DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cuda-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1
DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1
DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1
DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cuda-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1
DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1
DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1

DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.9
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.9
DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu-dataflow:0.0.9
DEFAULT_GIGL_RELEASE_DEV_WORKBENCH_IMAGE=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dev-workbench:0.0.9
DEFAULT_GIGL_RELEASE_KFP_PIPELINE_PATH=gs://public-gigl/releases/pipelines/gigl-pipeline-0.0.9.yaml
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.10
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.10
DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu-dataflow:0.0.10
DEFAULT_GIGL_RELEASE_DEV_WORKBENCH_IMAGE=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dev-workbench:0.0.10
DEFAULT_GIGL_RELEASE_KFP_PIPELINE_PATH=gs://public-gigl/releases/pipelines/gigl-pipeline-0.0.10.yaml

SPARK_31_TFRECORD_JAR_GCS_PATH=gs://public-gigl/tools/scala/spark_packages/spark-custom-tfrecord_2.12-0.5.0.jar
SPARK_35_TFRECORD_JAR_GCS_PATH=gs://public-gigl/tools/scala/spark_packages/spark_3.5.0-custom-tfrecord_2.12-0.6.1.jar
2 changes: 1 addition & 1 deletion python/gigl/__init__.py
@@ -1 +1 @@
__version__ = "0.0.9"
__version__ = "0.0.10"
21 changes: 13 additions & 8 deletions python/gigl/common/utils/vertex_ai_context.py
@@ -8,7 +8,6 @@
from typing import Optional

import omegaconf
from google.cloud.aiplatform_v1.types import CustomJobSpec

from gigl.common import GcsUri
from gigl.common.logger import Logger
@@ -183,29 +182,35 @@ class ClusterSpec:
cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists
environment: str # The environment string (e.g., "cloud")
task: TaskInfo # Information about the current task
# The CustomJobSpec for the current job
# See the docs for more info:
# https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec
job: Optional[CustomJobSpec] = None

# We use a custom method for parsing, because CustomJobSpec is a protobuf message.
# DESPITE what the docs say, this is *not* a CustomJobSpec.
# It's *sort of* like a PythonPackageSpec, but it's not.
# It has `jobArgs` instead of `args`.
# See an example:
# {"python_module":"","package_uris":[],"job_args":[]}
job: Optional[dict] = None

# We use a custom method for parsing: the "job" field is actually a serialized JSON string.
@classmethod
def from_json(cls, json_str: str) -> "ClusterSpec":
"""Instantiates ClusterSpec from a JSON string."""
cluster_spec_json = json.loads(json_str)
if "job" in cluster_spec_json and cluster_spec_json["job"] is not None:
job_spec = CustomJobSpec(**cluster_spec_json.pop("job"))
logger.info(f"Job spec: {cluster_spec_json['job']}")
job_spec = json.loads(cluster_spec_json.pop("job"))
else:
job_spec = None
conf = omegaconf.OmegaConf.create(cluster_spec_json)
if isinstance(conf, omegaconf.ListConfig):
raise ValueError("ListConfig is not supported")
return cls(
cluster_spec = cls(
cluster=conf.cluster,
environment=conf.environment,
task=conf.task,
job=job_spec,
)
logger.info(f"Cluster spec: {cluster_spec}")
return cluster_spec


def get_cluster_spec() -> ClusterSpec:
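Reviewer note: to make the "job"-is-a-string behavior concrete, a minimal sketch of the new parsing path; the payload values are illustrative, not from a real job:

import json

from gigl.common.utils.vertex_ai_context import ClusterSpec

# Illustrative CLUSTER_SPEC payload. Note that "job" arrives as a JSON
# *string* (doubly encoded), not a nested object, which is why from_json
# applies a second json.loads to it.
raw = json.dumps(
    {
        "cluster": {"workerpool0": ["host-0:2222"], "workerpool2": ["host-1:2222"]},
        "environment": "cloud",
        "task": {"type": "workerpool0", "index": 0, "trial": "trial-0"},
        "job": '{"python_module": "", "package_uris": [], "job_args": []}',
    }
)

spec = ClusterSpec.from_json(raw)
assert spec.job == {"python_module": "", "package_uris": [], "job_args": []}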
6 changes: 5 additions & 1 deletion python/gigl/distributed/utils/__init__.py
@@ -3,10 +3,12 @@
"""

__all__ = [
"GraphStoreInfo",
"get_available_device",
"get_free_port",
"get_free_ports_from_master_node",
"get_free_ports_from_node",
"get_free_port",
"get_graph_store_info",
"get_internal_ip_from_all_ranks",
"get_internal_ip_from_master_node",
"get_internal_ip_from_node",
@@ -20,9 +22,11 @@
init_neighbor_loader_worker,
)
from .networking import (
GraphStoreInfo,
get_free_port,
get_free_ports_from_master_node,
get_free_ports_from_node,
get_graph_store_info,
get_internal_ip_from_all_ranks,
get_internal_ip_from_master_node,
get_internal_ip_from_node,
57 changes: 57 additions & 0 deletions python/gigl/distributed/utils/networking.py
@@ -4,6 +4,11 @@
import torch

from gigl.common.logger import Logger
from gigl.common.utils.vertex_ai_context import (
get_cluster_spec,
is_currently_running_in_vertex_ai_job,
)
from gigl.env.distributed import GraphStoreInfo

logger = Logger()

@@ -179,3 +184,55 @@ def get_internal_ip_from_all_ranks() -> list[str]:
assert all(ip for ip in ip_list), "Could not retrieve all ranks' internal IPs"

return ip_list


def get_graph_store_info() -> GraphStoreInfo:
"""
Get the information about the graph store cluster.

Returns:
GraphStoreInfo: The information about the graph store cluster.

Raises:
ValueError: If a torch distributed environment is not initialized.
ValueError: If not running in a supported environment.
"""
if not torch.distributed.is_initialized():
raise ValueError("Distributed environment must be initialized")
if is_currently_running_in_vertex_ai_job():
cluster_spec = get_cluster_spec()
# We set up the VAI cluster such that the compute nodes come first, followed by the storage nodes.
if "workerpool1" in cluster_spec.cluster:
num_compute_nodes = len(cluster_spec.cluster["workerpool0"]) + len(
cluster_spec.cluster["workerpool1"]
)
else:
num_compute_nodes = len(cluster_spec.cluster["workerpool0"])
num_storage_nodes = len(cluster_spec.cluster["workerpool2"])
else:
raise ValueError(
"Must be running on a vertex AI job to get graph store cluster info!"
)

cluster_master_ip = get_internal_ip_from_master_node()
# We assume that the compute cluster nodes come first, followed by the storage nodes.
compute_cluster_master_ip = get_internal_ip_from_node(node_rank=0)
storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes)

cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0]
compute_cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0]
storage_cluster_master_port = get_free_ports_from_node(
num_ports=1, node_rank=num_compute_nodes
)[0]

return GraphStoreInfo(
num_cluster_nodes=num_storage_nodes + num_compute_nodes,
num_storage_nodes=num_storage_nodes,
num_compute_nodes=num_compute_nodes,
cluster_master_ip=cluster_master_ip,
storage_cluster_master_ip=storage_cluster_master_ip,
compute_cluster_master_ip=compute_cluster_master_ip,
cluster_master_port=cluster_master_port,
storage_cluster_master_port=storage_cluster_master_port,
compute_cluster_master_port=compute_cluster_master_port,
)
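Reviewer note: a small sketch of the pool-counting convention assumed above, using hypothetical replica hostnames; workerpool0 (plus workerpool1, when present) holds the compute nodes and workerpool2 holds the storage nodes:

# Hypothetical "cluster" mapping from a VAI CLUSTER_SPEC for a
# 3-compute / 1-storage job.
cluster = {
    "workerpool0": ["wp0-replica-0:2222"],
    "workerpool1": ["wp1-replica-0:2222", "wp1-replica-1:2222"],
    "workerpool2": ["wp2-replica-0:2222"],
}

# Mirrors the counting in get_graph_store_info above.
num_compute_nodes = len(cluster["workerpool0"]) + len(cluster.get("workerpool1", []))
num_storage_nodes = len(cluster["workerpool2"])
assert (num_compute_nodes, num_storage_nodes) == (3, 1)

# Node ranks are assigned compute-first, so the storage master is the node
# at node_rank == num_compute_nodes (global rank 3 here).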
26 changes: 26 additions & 0 deletions python/gigl/env/distributed.py
@@ -19,3 +19,29 @@ class DistributedContext:

# Total number of machines
global_world_size: int


@dataclass(frozen=True)
class GraphStoreInfo:
"""Information about a graph store cluster."""

# Number of nodes in the whole cluster
num_cluster_nodes: int
# Number of nodes in the storage cluster
num_storage_nodes: int
# Number of nodes in the compute cluster
num_compute_nodes: int

# IP address of the master node for the whole cluster
cluster_master_ip: str
# IP address of the master node for the storage cluster
storage_cluster_master_ip: str
# IP address of the master node for the compute cluster
compute_cluster_master_ip: str

# Port of the master node for the whole cluster
cluster_master_port: int
# Port of the master node for the storage cluster
storage_cluster_master_port: int
# Port of the master node for the compute cluster
compute_cluster_master_port: int
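Reviewer note: a hypothetical consumer of GraphStoreInfo, sketching how a node could pick its rendezvous endpoint from these fields; this helper is illustrative only and not part of the PR:

from gigl.env.distributed import GraphStoreInfo

def rendezvous_endpoint(info: GraphStoreInfo, node_rank: int) -> str:
    # Storage nodes are ranked after the compute nodes, so any rank at or
    # beyond num_compute_nodes belongs to the storage sub-cluster.
    if node_rank >= info.num_compute_nodes:
        return f"{info.storage_cluster_master_ip}:{info.storage_cluster_master_port}"
    return f"{info.compute_cluster_master_ip}:{info.compute_cluster_master_port}"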
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"
name = "gigl"
description = "GIgantic Graph Learning Library"
readme = "README.md"
version = "0.0.9"
version = "0.0.10"
requires-python = ">=3.9,<3.10" # Currently we only support python 3.9 as per deps setup below
classifiers = [
"Programming Language :: Python",
Empty file.
83 changes: 83 additions & 0 deletions python/tests/integration/distributed/utils/networking_test.py
@@ -0,0 +1,83 @@
import unittest
import uuid
from textwrap import dedent

from parameterized import param, parameterized

from gigl.common.constants import DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService
from gigl.env.pipelines_config import get_resource_config


class NetworkingUtilsIntegrationTest(unittest.TestCase):
def setUp(self):
self._resource_config = get_resource_config()
self._project = self._resource_config.project
self._location = self._resource_config.region
self._service_account = self._resource_config.service_account_email
self._staging_bucket = (
self._resource_config.temp_assets_regional_bucket_path.uri
)
self._vertex_ai_service = VertexAIService(
project=self._project,
location=self._location,
service_account=self._service_account,
staging_bucket=self._staging_bucket,
)
super().setUp()

@parameterized.expand(
[
param(
"Test with 1 compute node and 1 storage node",
compute_nodes=1,
storage_nodes=1,
),
param(
"Test with 2 compute nodes and 2 storage nodes",
compute_nodes=2,
storage_nodes=2,
),
]
)
def test_get_graph_store_info(self, _, storage_nodes, compute_nodes):
job_name = f"GiGL-Integration-Test-Graph-Store-{uuid.uuid4()}"
command = [
"python",
"-c",
dedent(
f"""
import torch
from gigl.distributed.utils import get_graph_store_info
torch.distributed.init_process_group(backend="gloo")
info = get_graph_store_info()
assert info.num_storage_nodes == {storage_nodes}, f"Expected {storage_nodes} storage nodes, but got {{ info.num_storage_nodes }}"
assert info.num_compute_nodes == {compute_nodes}, f"Expected {compute_nodes} compute nodes, but got {{ info.num_compute_nodes }}"
assert info.num_cluster_nodes == {storage_nodes + compute_nodes}, f"Expected {storage_nodes + compute_nodes} cluster nodes, but got {{ info.num_cluster_nodes }}"
assert info.cluster_master_ip is not None, f"Cluster master IP is None"
assert info.storage_cluster_master_ip is not None, f"Storage cluster master IP is None"
assert info.compute_cluster_master_ip is not None, f"Compute cluster master IP is None"
assert info.cluster_master_port is not None, f"Cluster master port is None"
assert info.storage_cluster_master_port is not None, f"Storage cluster master port is None"
assert info.compute_cluster_master_port is not None, f"Compute cluster master port is None"
"""
),
]
compute_cluster_config = VertexAiJobConfig(
job_name=job_name,
container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
replica_count=compute_nodes,
command=command,
machine_type="n2-standard-8",
)
storage_cluster_config = VertexAiJobConfig(
job_name=job_name,
container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
replica_count=storage_nodes,
machine_type="n1-standard-4",
command=command,
)

self._vertex_ai_service.launch_graph_store_job(
compute_cluster_config, storage_cluster_config
)
14 changes: 4 additions & 10 deletions python/tests/unit/common/utils/vertex_ai_context_test.py
@@ -3,8 +3,6 @@
import unittest
from unittest.mock import call, patch

from google.cloud.aiplatform_v1.types import CustomJobSpec

from gigl.common import GcsUri
from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY
from gigl.common.utils.vertex_ai_context import (
@@ -129,11 +127,7 @@ def test_parse_cluster_spec_success(self):
},
"task": {"type": "workerpool0", "index": 1, "trial": "trial-123"},
"environment": "cloud",
"job": {
"worker_pool_specs": [
{"machine_spec": {"machine_type": "n1-standard-4"}}
]
},
"job": '{ "worker_pool_specs": [ {"machine_spec": {"machine_type": "n1-standard-4"}}]}',
}
)

@@ -150,11 +144,11 @@ def test_parse_cluster_spec_success(self):
},
environment="cloud",
task=TaskInfo(type="workerpool0", index=1, trial="trial-123"),
job=CustomJobSpec(
worker_pool_specs=[
job={
"worker_pool_specs": [
{"machine_spec": {"machine_type": "n1-standard-4"}}
]
),
},
)
self.assertEqual(cluster_spec, expected_cluster_spec)
