From 789045c1df85325cc8e40a7a5ff3055e4d87993b Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 24 Sep 2025 20:40:56 +0000 Subject: [PATCH 01/33] Add VertexAiMultiPoolConfig to support multiple worker pools --- .../research/gbml/gigl_resource_config.proto | 39 +++++ .../research/gbml/gigl_resource_config_pb2.py | 44 +++-- .../gbml/gigl_resource_config_pb2.pyi | 73 ++++++++- .../DistributedTrainerConfig.scala | 8 +- .../GiglResourceConfig.scala | 8 +- .../GiglResourceConfigProto.scala | 41 +++-- .../InferencerResourceConfig.scala | 37 ++++- .../SharedResourceConfig.scala | 40 ++--- .../TrainerResourceConfig.scala | 37 ++++- .../VertexAiMultiPoolConfig.scala | 154 ++++++++++++++++++ .../VertexAiMultiPoolTrainerConfig.scala | 154 ++++++++++++++++++ .../VertexAiResourceConfig.scala | 3 +- .../DistributedTrainerConfig.scala | 8 +- .../GiglResourceConfig.scala | 8 +- .../GiglResourceConfigProto.scala | 41 +++-- .../InferencerResourceConfig.scala | 37 ++++- .../SharedResourceConfig.scala | 40 ++--- .../TrainerResourceConfig.scala | 37 ++++- .../VertexAiMultiPoolConfig.scala | 154 ++++++++++++++++++ .../VertexAiMultiPoolTrainerConfig.scala | 154 ++++++++++++++++++ .../VertexAiResourceConfig.scala | 3 +- 21 files changed, 988 insertions(+), 132 deletions(-) create mode 100644 scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala create mode 100644 scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolTrainerConfig.scala create mode 100644 scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala create mode 100644 scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolTrainerConfig.scala diff --git a/proto/snapchat/research/gbml/gigl_resource_config.proto b/proto/snapchat/research/gbml/gigl_resource_config.proto index bab99bd27..9aecb4d78 100644 --- 
a/proto/snapchat/research/gbml/gigl_resource_config.proto +++ b/proto/snapchat/research/gbml/gigl_resource_config.proto @@ -116,6 +116,43 @@ message VertexAiResourceConfig { uint32 num_workers = 1; } + // Configuration for Mutlipool Vertex AI jobs. + // See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. + // NOTE: The first worker pool will be split into the primary replica and "Workers". + // For example: + // pools = [ + // { + // "machine_type": "n1-standard-8", + // "num_replicas": 16 + // }, + // { + // "machine_type": "n1-standard-8", + // "gpu_type": "nvidia-tesla-v100", + // "gpu_limit": 1, + // "num_replicas": 16 + // } + // ] + // Will have the Primary be: {} + // { + // "machine_type": "n1-standard-8", + // "num_replicas": 1 + // } + // And the Workers be: + // { + // "machine_type": "n1-standard-8", + // "num_replicas": 15 + // } + // And the parameter servers be: + // { + // "machine_type": "n1-standard-8", + // "gpu_type": "nvidia-tesla-v100", + // "gpu_limit": 1, + // "num_replicas": 16 + // } + message VertexAiMultiPoolConfig { + repeated VertexAiResourceConfig pools = 1; + } + // (deprecated) // Configuration for distributed training resources message DistributedTrainerConfig { @@ -132,6 +169,7 @@ message TrainerResourceConfig { VertexAiResourceConfig vertex_ai_trainer_config = 1; KFPResourceConfig kfp_trainer_config = 2; LocalResourceConfig local_trainer_config = 3; + VertexAiMultiPoolConfig vertex_ai_multi_pool_trainer_config = 4; } } @@ -141,6 +179,7 @@ message InferencerResourceConfig { VertexAiResourceConfig vertex_ai_inferencer_config = 1; DataflowResourceConfig dataflow_inferencer_config = 2; LocalResourceConfig local_inferencer_config = 3; + VertexAiMultiPoolConfig vertex_ai_multi_pool_inferencer_config = 4; } } diff --git a/python/snapchat/research/gbml/gigl_resource_config_pb2.py b/python/snapchat/research/gbml/gigl_resource_config_pb2.py index bef644bfb..0ce7bd9e0 100644 --- 
a/python/snapchat/research/gbml/gigl_resource_config_pb2.py +++ b/python/snapchat/research/gbml/gigl_resource_config_pb2.py @@ -15,7 +15,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n1snapchat/research/gbml/gigl_resource_config.proto\x12\x16snapchat.research.gbml\"Y\n\x13SparkResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x16\n\x0enum_local_ssds\x18\x02 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x03 \x01(\r\"r\n\x16\x44\x61taflowResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\x12\x17\n\x0fmax_num_workers\x18\x02 \x01(\r\x12\x14\n\x0cmachine_type\x18\x03 \x01(\t\x12\x14\n\x0c\x64isk_size_gb\x18\x04 \x01(\r\"\xbc\x01\n\x16\x44\x61taPreprocessorConfig\x12P\n\x18\x65\x64ge_preprocessor_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\x12P\n\x18node_preprocessor_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\"h\n\x15VertexAiTrainerConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\"z\n\x10KFPTrainerConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\")\n\x12LocalTrainerConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\x97\x01\n\x16VertexAiResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12\x1b\n\x13gcp_region_override\x18\x06 \x01(\t\"{\n\x11KFPResourceConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\"*\n\x13LocalResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 
\x01(\r\"\x93\x02\n\x18\x44istributedTrainerConfig\x12Q\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32-.snapchat.research.gbml.VertexAiTrainerConfigH\x00\x12\x46\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32(.snapchat.research.gbml.KFPTrainerConfigH\x00\x12J\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32*.snapchat.research.gbml.LocalTrainerConfigH\x00\x42\x10\n\x0etrainer_config\"\x93\x02\n\x15TrainerResourceConfig\x12R\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12G\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32).snapchat.research.gbml.KFPResourceConfigH\x00\x12K\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x42\x10\n\x0etrainer_config\"\xac\x02\n\x18InferencerResourceConfig\x12U\n\x1bvertex_ai_inferencer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12T\n\x1a\x64\x61taflow_inferencer_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigH\x00\x12N\n\x17local_inferencer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x42\x13\n\x11inferencer_config\"\xa3\x04\n\x14SharedResourceConfig\x12Y\n\x0fresource_labels\x18\x01 \x03(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry\x12_\n\x15\x63ommon_compute_config\x18\x02 \x01(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig\x1a\x97\x02\n\x13\x43ommonComputeConfig\x12\x0f\n\x07project\x18\x01 \x01(\t\x12\x0e\n\x06region\x18\x02 \x01(\t\x12\x1a\n\x12temp_assets_bucket\x18\x03 \x01(\t\x12#\n\x1btemp_regional_assets_bucket\x18\x04 \x01(\t\x12\x1a\n\x12perm_assets_bucket\x18\x05 \x01(\t\x12#\n\x1btemp_assets_bq_dataset_name\x18\x06 \x01(\t\x12!\n\x19\x65mbedding_bq_dataset_name\x18\x07 \x01(\t\x12!\n\x19gcp_service_account_email\x18\x08 \x01(\t\x12\x17\n\x0f\x64\x61taflow_runner\x18\x0b \x01(\t\x1a\x35\n\x13ResourceLabelsEntry\x12\x0b\n\x03key\x18\x01 
\x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc8\x05\n\x12GiglResourceConfig\x12$\n\x1ashared_resource_config_uri\x18\x01 \x01(\tH\x00\x12N\n\x16shared_resource_config\x18\x02 \x01(\x0b\x32,.snapchat.research.gbml.SharedResourceConfigH\x00\x12K\n\x13preprocessor_config\x18\x0c \x01(\x0b\x32..snapchat.research.gbml.DataPreprocessorConfig\x12L\n\x17subgraph_sampler_config\x18\r \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12K\n\x16split_generator_config\x18\x0e \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12L\n\x0etrainer_config\x18\x0f \x01(\x0b\x32\x30.snapchat.research.gbml.DistributedTrainerConfigB\x02\x18\x01\x12M\n\x11inferencer_config\x18\x10 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigB\x02\x18\x01\x12N\n\x17trainer_resource_config\x18\x11 \x01(\x0b\x32-.snapchat.research.gbml.TrainerResourceConfig\x12T\n\x1ainferencer_resource_config\x18\x12 \x01(\x0b\x32\x30.snapchat.research.gbml.InferencerResourceConfigB\x11\n\x0fshared_resource*\xf3\x01\n\tComponent\x12\x15\n\x11\x43omponent_Unknown\x10\x00\x12\x1e\n\x1a\x43omponent_Config_Validator\x10\x01\x12\x1e\n\x1a\x43omponent_Config_Populator\x10\x02\x12\x1f\n\x1b\x43omponent_Data_Preprocessor\x10\x03\x12\x1e\n\x1a\x43omponent_Subgraph_Sampler\x10\x04\x12\x1d\n\x19\x43omponent_Split_Generator\x10\x05\x12\x15\n\x11\x43omponent_Trainer\x10\x06\x12\x18\n\x14\x43omponent_Inferencer\x10\x07\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n1snapchat/research/gbml/gigl_resource_config.proto\x12\x16snapchat.research.gbml\"Y\n\x13SparkResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x16\n\x0enum_local_ssds\x18\x02 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x03 \x01(\r\"r\n\x16\x44\x61taflowResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\x12\x17\n\x0fmax_num_workers\x18\x02 \x01(\r\x12\x14\n\x0cmachine_type\x18\x03 \x01(\t\x12\x14\n\x0c\x64isk_size_gb\x18\x04 
\x01(\r\"\xbc\x01\n\x16\x44\x61taPreprocessorConfig\x12P\n\x18\x65\x64ge_preprocessor_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\x12P\n\x18node_preprocessor_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\"h\n\x15VertexAiTrainerConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\"z\n\x10KFPTrainerConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\")\n\x12LocalTrainerConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\x97\x01\n\x16VertexAiResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12\x1b\n\x13gcp_region_override\x18\x06 \x01(\t\"{\n\x11KFPResourceConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\"*\n\x13LocalResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"X\n\x17VertexAiMultiPoolConfig\x12=\n\x05pools\x18\x01 \x03(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\"\x93\x02\n\x18\x44istributedTrainerConfig\x12Q\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32-.snapchat.research.gbml.VertexAiTrainerConfigH\x00\x12\x46\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32(.snapchat.research.gbml.KFPTrainerConfigH\x00\x12J\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32*.snapchat.research.gbml.LocalTrainerConfigH\x00\x42\x10\n\x0etrainer_config\"\xf3\x02\n\x15TrainerResourceConfig\x12R\n\x18vertex_ai_trainer_config\x18\x01 
\x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12G\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32).snapchat.research.gbml.KFPResourceConfigH\x00\x12K\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12^\n#vertex_ai_multi_pool_trainer_config\x18\x04 \x01(\x0b\x32/.snapchat.research.gbml.VertexAiMultiPoolConfigH\x00\x42\x10\n\x0etrainer_config\"\x8f\x03\n\x18InferencerResourceConfig\x12U\n\x1bvertex_ai_inferencer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12T\n\x1a\x64\x61taflow_inferencer_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigH\x00\x12N\n\x17local_inferencer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12\x61\n&vertex_ai_multi_pool_inferencer_config\x18\x04 \x01(\x0b\x32/.snapchat.research.gbml.VertexAiMultiPoolConfigH\x00\x42\x13\n\x11inferencer_config\"\xa3\x04\n\x14SharedResourceConfig\x12Y\n\x0fresource_labels\x18\x01 \x03(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry\x12_\n\x15\x63ommon_compute_config\x18\x02 \x01(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig\x1a\x97\x02\n\x13\x43ommonComputeConfig\x12\x0f\n\x07project\x18\x01 \x01(\t\x12\x0e\n\x06region\x18\x02 \x01(\t\x12\x1a\n\x12temp_assets_bucket\x18\x03 \x01(\t\x12#\n\x1btemp_regional_assets_bucket\x18\x04 \x01(\t\x12\x1a\n\x12perm_assets_bucket\x18\x05 \x01(\t\x12#\n\x1btemp_assets_bq_dataset_name\x18\x06 \x01(\t\x12!\n\x19\x65mbedding_bq_dataset_name\x18\x07 \x01(\t\x12!\n\x19gcp_service_account_email\x18\x08 \x01(\t\x12\x17\n\x0f\x64\x61taflow_runner\x18\x0b \x01(\t\x1a\x35\n\x13ResourceLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc8\x05\n\x12GiglResourceConfig\x12$\n\x1ashared_resource_config_uri\x18\x01 \x01(\tH\x00\x12N\n\x16shared_resource_config\x18\x02 
\x01(\x0b\x32,.snapchat.research.gbml.SharedResourceConfigH\x00\x12K\n\x13preprocessor_config\x18\x0c \x01(\x0b\x32..snapchat.research.gbml.DataPreprocessorConfig\x12L\n\x17subgraph_sampler_config\x18\r \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12K\n\x16split_generator_config\x18\x0e \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12L\n\x0etrainer_config\x18\x0f \x01(\x0b\x32\x30.snapchat.research.gbml.DistributedTrainerConfigB\x02\x18\x01\x12M\n\x11inferencer_config\x18\x10 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigB\x02\x18\x01\x12N\n\x17trainer_resource_config\x18\x11 \x01(\x0b\x32-.snapchat.research.gbml.TrainerResourceConfig\x12T\n\x1ainferencer_resource_config\x18\x12 \x01(\x0b\x32\x30.snapchat.research.gbml.InferencerResourceConfigB\x11\n\x0fshared_resource*\xf3\x01\n\tComponent\x12\x15\n\x11\x43omponent_Unknown\x10\x00\x12\x1e\n\x1a\x43omponent_Config_Validator\x10\x01\x12\x1e\n\x1a\x43omponent_Config_Populator\x10\x02\x12\x1f\n\x1b\x43omponent_Data_Preprocessor\x10\x03\x12\x1e\n\x1a\x43omponent_Subgraph_Sampler\x10\x04\x12\x1d\n\x19\x43omponent_Split_Generator\x10\x05\x12\x15\n\x11\x43omponent_Trainer\x10\x06\x12\x18\n\x14\x43omponent_Inferencer\x10\x07\x62\x06proto3') _COMPONENT = DESCRIPTOR.enum_types_by_name['Component'] Component = enum_type_wrapper.EnumTypeWrapper(_COMPONENT) @@ -38,6 +38,7 @@ _VERTEXAIRESOURCECONFIG = DESCRIPTOR.message_types_by_name['VertexAiResourceConfig'] _KFPRESOURCECONFIG = DESCRIPTOR.message_types_by_name['KFPResourceConfig'] _LOCALRESOURCECONFIG = DESCRIPTOR.message_types_by_name['LocalResourceConfig'] +_VERTEXAIMULTIPOOLCONFIG = DESCRIPTOR.message_types_by_name['VertexAiMultiPoolConfig'] _DISTRIBUTEDTRAINERCONFIG = DESCRIPTOR.message_types_by_name['DistributedTrainerConfig'] _TRAINERRESOURCECONFIG = DESCRIPTOR.message_types_by_name['TrainerResourceConfig'] _INFERENCERRESOURCECONFIG = DESCRIPTOR.message_types_by_name['InferencerResourceConfig'] @@ -108,6 +109,13 @@ }) 
_sym_db.RegisterMessage(LocalResourceConfig) +VertexAiMultiPoolConfig = _reflection.GeneratedProtocolMessageType('VertexAiMultiPoolConfig', (_message.Message,), { + 'DESCRIPTOR' : _VERTEXAIMULTIPOOLCONFIG, + '__module__' : 'snapchat.research.gbml.gigl_resource_config_pb2' + # @@protoc_insertion_point(class_scope:snapchat.research.gbml.VertexAiMultiPoolConfig) + }) +_sym_db.RegisterMessage(VertexAiMultiPoolConfig) + DistributedTrainerConfig = _reflection.GeneratedProtocolMessageType('DistributedTrainerConfig', (_message.Message,), { 'DESCRIPTOR' : _DISTRIBUTEDTRAINERCONFIG, '__module__' : 'snapchat.research.gbml.gigl_resource_config_pb2' @@ -168,8 +176,8 @@ _GIGLRESOURCECONFIG.fields_by_name['trainer_config']._serialized_options = b'\030\001' _GIGLRESOURCECONFIG.fields_by_name['inferencer_config']._options = None _GIGLRESOURCECONFIG.fields_by_name['inferencer_config']._serialized_options = b'\030\001' - _COMPONENT._serialized_start=3196 - _COMPONENT._serialized_end=3439 + _COMPONENT._serialized_start=3481 + _COMPONENT._serialized_end=3724 _SPARKRESOURCECONFIG._serialized_start=77 _SPARKRESOURCECONFIG._serialized_end=166 _DATAFLOWRESOURCECONFIG._serialized_start=168 @@ -188,18 +196,20 @@ _KFPRESOURCECONFIG._serialized_end=1025 _LOCALRESOURCECONFIG._serialized_start=1027 _LOCALRESOURCECONFIG._serialized_end=1069 - _DISTRIBUTEDTRAINERCONFIG._serialized_start=1072 - _DISTRIBUTEDTRAINERCONFIG._serialized_end=1347 - _TRAINERRESOURCECONFIG._serialized_start=1350 - _TRAINERRESOURCECONFIG._serialized_end=1625 - _INFERENCERRESOURCECONFIG._serialized_start=1628 - _INFERENCERRESOURCECONFIG._serialized_end=1928 - _SHAREDRESOURCECONFIG._serialized_start=1931 - _SHAREDRESOURCECONFIG._serialized_end=2478 - _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_start=2144 - _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_end=2423 - _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_start=2425 - _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_end=2478 - 
_GIGLRESOURCECONFIG._serialized_start=2481 - _GIGLRESOURCECONFIG._serialized_end=3193 + _VERTEXAIMULTIPOOLCONFIG._serialized_start=1071 + _VERTEXAIMULTIPOOLCONFIG._serialized_end=1159 + _DISTRIBUTEDTRAINERCONFIG._serialized_start=1162 + _DISTRIBUTEDTRAINERCONFIG._serialized_end=1437 + _TRAINERRESOURCECONFIG._serialized_start=1440 + _TRAINERRESOURCECONFIG._serialized_end=1811 + _INFERENCERRESOURCECONFIG._serialized_start=1814 + _INFERENCERRESOURCECONFIG._serialized_end=2213 + _SHAREDRESOURCECONFIG._serialized_start=2216 + _SHAREDRESOURCECONFIG._serialized_end=2763 + _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_start=2429 + _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_end=2708 + _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_start=2710 + _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_end=2763 + _GIGLRESOURCECONFIG._serialized_start=2766 + _GIGLRESOURCECONFIG._serialized_end=3478 # @@protoc_insertion_point(module_scope) diff --git a/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi b/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi index 00ae6817f..1f540b611 100644 --- a/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi +++ b/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi @@ -230,8 +230,9 @@ class VertexAiResourceConfig(google.protobuf.message.Message): https://github.com/googleapis/python-aiplatform/blob/58fbabdeeefd1ccf1a9d0c22eeb5606aeb9c2266/google/cloud/aiplatform/jobs.py#L2252-L2253 """ gcp_region_override: builtins.str - """Region override. + """Region override If provided, then the Vertex AI Job will be launched in the provided region. 
+ Otherwise, will launch jobs in the region specified at CommonComputeConfig.region ex: "us-west1" NOTE: If set, then there may be data egress costs from CommonComputeConfig.region -> gcp_region_override """ @@ -298,6 +299,56 @@ class LocalResourceConfig(google.protobuf.message.Message): global___LocalResourceConfig = LocalResourceConfig +class VertexAiMultiPoolConfig(google.protobuf.message.Message): + """Configuration for Mutlipool Vertex AI jobs. + See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. + NOTE: The first worker pool will be split into the primary replica and "Workers". + For example: + pools = [ + { + "machine_type": "n1-standard-8", + "num_replicas": 16 + }, + { + "machine_type": "n1-standard-8", + "gpu_type": "nvidia-tesla-v100", + "gpu_limit": 1, + "num_replicas": 16 + } + ] + Will have the Primary be: {} + { + "machine_type": "n1-standard-8", + "num_replicas": 1 + } + And the Workers be: + { + "machine_type": "n1-standard-8", + "num_replicas": 15 + } + And the parameter servers be: + { + "machine_type": "n1-standard-8", + "gpu_type": "nvidia-tesla-v100", + "gpu_limit": 1, + "num_replicas": 16 + } + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + POOLS_FIELD_NUMBER: builtins.int + @property + def pools(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___VertexAiResourceConfig]: ... + def __init__( + self, + *, + pools: collections.abc.Iterable[global___VertexAiResourceConfig] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["pools", b"pools"]) -> None: ... 
+ +global___VertexAiMultiPoolConfig = VertexAiMultiPoolConfig + class DistributedTrainerConfig(google.protobuf.message.Message): """(deprecated) Configuration for distributed training resources @@ -335,22 +386,26 @@ class TrainerResourceConfig(google.protobuf.message.Message): VERTEX_AI_TRAINER_CONFIG_FIELD_NUMBER: builtins.int KFP_TRAINER_CONFIG_FIELD_NUMBER: builtins.int LOCAL_TRAINER_CONFIG_FIELD_NUMBER: builtins.int + VERTEX_AI_MULTI_POOL_TRAINER_CONFIG_FIELD_NUMBER: builtins.int @property def vertex_ai_trainer_config(self) -> global___VertexAiResourceConfig: ... @property def kfp_trainer_config(self) -> global___KFPResourceConfig: ... @property def local_trainer_config(self) -> global___LocalResourceConfig: ... + @property + def vertex_ai_multi_pool_trainer_config(self) -> global___VertexAiMultiPoolConfig: ... def __init__( self, *, vertex_ai_trainer_config: global___VertexAiResourceConfig | None = ..., kfp_trainer_config: global___KFPResourceConfig | None = ..., local_trainer_config: global___LocalResourceConfig | None = ..., + vertex_ai_multi_pool_trainer_config: global___VertexAiMultiPoolConfig | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["trainer_config", b"trainer_config"]) -> typing_extensions.Literal["vertex_ai_trainer_config", "kfp_trainer_config", "local_trainer_config"] | None: ... 
+ def HasField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_multi_pool_trainer_config", b"vertex_ai_multi_pool_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_multi_pool_trainer_config", b"vertex_ai_multi_pool_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["trainer_config", b"trainer_config"]) -> typing_extensions.Literal["vertex_ai_trainer_config", "kfp_trainer_config", "local_trainer_config", "vertex_ai_multi_pool_trainer_config"] | None: ... global___TrainerResourceConfig = TrainerResourceConfig @@ -362,22 +417,26 @@ class InferencerResourceConfig(google.protobuf.message.Message): VERTEX_AI_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int LOCAL_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int + VERTEX_AI_MULTI_POOL_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int @property def vertex_ai_inferencer_config(self) -> global___VertexAiResourceConfig: ... @property def dataflow_inferencer_config(self) -> global___DataflowResourceConfig: ... @property def local_inferencer_config(self) -> global___LocalResourceConfig: ... + @property + def vertex_ai_multi_pool_inferencer_config(self) -> global___VertexAiMultiPoolConfig: ... 
def __init__( self, *, vertex_ai_inferencer_config: global___VertexAiResourceConfig | None = ..., dataflow_inferencer_config: global___DataflowResourceConfig | None = ..., local_inferencer_config: global___LocalResourceConfig | None = ..., + vertex_ai_multi_pool_inferencer_config: global___VertexAiMultiPoolConfig | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["inferencer_config", b"inferencer_config"]) -> typing_extensions.Literal["vertex_ai_inferencer_config", "dataflow_inferencer_config", "local_inferencer_config"] | None: ... + def HasField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config", "vertex_ai_multi_pool_inferencer_config", b"vertex_ai_multi_pool_inferencer_config"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config", "vertex_ai_multi_pool_inferencer_config", b"vertex_ai_multi_pool_inferencer_config"]) -> None: ... 
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["inferencer_config", b"inferencer_config"]) -> typing_extensions.Literal["vertex_ai_inferencer_config", "dataflow_inferencer_config", "local_inferencer_config", "vertex_ai_multi_pool_inferencer_config"] | None: ... global___InferencerResourceConfig = InferencerResourceConfig diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala index 88a76bd40..403157363 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala @@ -39,7 +39,7 @@ final case class DistributedTrainerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { trainerConfig.vertexAiTrainerConfig.foreach { __v => @@ -131,8 +131,8 @@ object DistributedTrainerConfig extends scalapb.GeneratedMessageCompanion[snapch ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(9) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(9) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(10) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(10) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -166,7 +166,7 @@ object 
DistributedTrainerConfig extends scalapb.GeneratedMessageCompanion[snapch override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class VertexAiTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiTrainerConfig) extends snapchat.research.gbml.gigl_resource_config.DistributedTrainerConfig.TrainerConfig { type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiTrainerConfig diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala index de92ed17e..4d672733a 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala @@ -86,7 +86,7 @@ final case class GiglResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { sharedResource.sharedResourceConfigUri.foreach { __v => @@ -275,8 +275,8 @@ object GiglResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.res ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(13) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(13) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(14) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(14) def messageCompanionForFieldNumber(__number: _root_.scala.Int): 
_root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -320,7 +320,7 @@ object GiglResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.res override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class SharedResourceConfigUri(value: _root_.scala.Predef.String) extends snapchat.research.gbml.gigl_resource_config.GiglResourceConfig.SharedResource { type ValueType = _root_.scala.Predef.String diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala index 55b0dfcf7..c5e49c887 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala @@ -18,6 +18,7 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig, snapchat.research.gbml.gigl_resource_config.KFPResourceConfig, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig, + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig, snapchat.research.gbml.gigl_resource_config.DistributedTrainerConfig, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig, @@ -53,23 +54,29 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { VJlcXVlc3RSDW1lbW9yeVJlcXVlc3QSJwoIZ3B1X3R5cGUYAyABKAlCDOI/CRIHZ3B1VHlwZVIHZ3B1VHlwZRIqCglncHVfbGlta XQYBCABKA1CDeI/ChIIZ3B1TGltaXRSCGdwdUxpbWl0EjMKDG51bV9yZXBsaWNhcxgFIAEoDUIQ4j8NEgtudW1SZXBsaWNhc1ILb 
nVtUmVwbGljYXMiRwoTTG9jYWxSZXNvdXJjZUNvbmZpZxIwCgtudW1fd29ya2VycxgBIAEoDUIP4j8MEgpudW1Xb3JrZXJzUgpud - W1Xb3JrZXJzIp0DChhEaXN0cmlidXRlZFRyYWluZXJDb25maWcShAEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzItL - nNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlUcmFpbmVyQ29uZmlnQhriPxcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAU - hV2ZXJ0ZXhBaVRyYWluZXJDb25maWcSbwoSa2ZwX3RyYWluZXJfY29uZmlnGAIgASgLMiguc25hcGNoYXQucmVzZWFyY2guZ2Jtb - C5LRlBUcmFpbmVyQ29uZmlnQhXiPxISEGtmcFRyYWluZXJDb25maWdIAFIQa2ZwVHJhaW5lckNvbmZpZxJ3ChRsb2NhbF90cmFpb - mVyX2NvbmZpZxgDIAEoCzIqLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxUcmFpbmVyQ29uZmlnQhfiPxQSEmxvY2FsVHJha - W5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZXJDb25maWdCEAoOdHJhaW5lcl9jb25maWcinQMKFVRyYWluZXJSZXNvdXJjZUNvbmZpZ - xKFAQoYdmVydGV4X2FpX3RyYWluZXJfY29uZmlnGAEgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291c - mNlQ29uZmlnQhriPxcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyYWluZXJDb25maWcScAoSa2ZwX3RyYWluZ - XJfY29uZmlnGAIgASgLMikuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5LRlBSZXNvdXJjZUNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ - 29uZmlnSABSEGtmcFRyYWluZXJDb25maWcSeAoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKy5zbmFwY2hhdC5yZXNlYXJja - C5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCF+I/FBISbG9jYWxUcmFpbmVyQ29uZmlnSABSEmxvY2FsVHJhaW5lckNvbmZpZ0IQC - g50cmFpbmVyX2NvbmZpZyLUAwoYSW5mZXJlbmNlclJlc291cmNlQ29uZmlnEo4BCht2ZXJ0ZXhfYWlfaW5mZXJlbmNlcl9jb25ma - WcYASABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpUmVzb3VyY2VDb25maWdCHeI/GhIYdmVydGV4QWlJbmZlc - mVuY2VyQ29uZmlnSABSGHZlcnRleEFpSW5mZXJlbmNlckNvbmZpZxKNAQoaZGF0YWZsb3dfaW5mZXJlbmNlcl9jb25maWcYAiABK - AsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhdGFmbG93UmVzb3VyY2VDb25maWdCHeI/GhIYZGF0YWZsb3dJbmZlcmVuY2VyQ - 29uZmlnSABSGGRhdGFmbG93SW5mZXJlbmNlckNvbmZpZxKBAQoXbG9jYWxfaW5mZXJlbmNlcl9jb25maWcYAyABKAsyKy5zbmFwY - 2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCGuI/FxIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnSABSFWxvY2FsS + W1Xb3JrZXJzImsKF1ZlcnRleEFpTXVsdGlQb29sQ29uZmlnElAKBXBvb2xzGAEgAygLMi4uc25hcGNoYXQucmVzZWFyY2guZ2Jtb + 
C5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQgriPwcSBXBvb2xzUgVwb29scyKdAwoYRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnEoQBC + hh2ZXJ0ZXhfYWlfdHJhaW5lcl9jb25maWcYASABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpVHJhaW5lckNvb + mZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEm8KEmtmcF90cmFpbmVyX2Nvb + mZpZxgCIAEoCzIoLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS0ZQVHJhaW5lckNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnS + ABSEGtmcFRyYWluZXJDb25maWcSdwoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sL + kxvY2FsVHJhaW5lckNvbmZpZ0IX4j8UEhJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ29uZmlnQhAKDnRyYWluZ + XJfY29uZmlnIsMEChVUcmFpbmVyUmVzb3VyY2VDb25maWcShQEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzIuLnNuY + XBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVd + mVydGV4QWlUcmFpbmVyQ29uZmlnEnAKEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIpLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS + 0ZQUmVzb3VyY2VDb25maWdCFeI/EhIQa2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEngKFGxvY2FsX3RyYWluZ + XJfY29uZmlnGAMgASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhfiPxQSEmxvY2FsVHJha + W5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZXJDb25maWcSowEKI3ZlcnRleF9haV9tdWx0aV9wb29sX3RyYWluZXJfY29uZmlnGAQgA + SgLMi8uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaU11bHRpUG9vbENvbmZpZ0Ij4j8gEh52ZXJ0ZXhBaU11bHRpUG9vb + FRyYWluZXJDb25maWdIAFIedmVydGV4QWlNdWx0aVBvb2xUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIoMFChhJbmZlc + mVuY2VyUmVzb3VyY2VDb25maWcSjgEKG3ZlcnRleF9haV9pbmZlcmVuY2VyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2Vhc + mNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Id4j8aEhh2ZXJ0ZXhBaUluZmVyZW5jZXJDb25maWdIAFIYdmVydGV4QWlJb + mZlcmVuY2VyQ29uZmlnEo0BChpkYXRhZmxvd19pbmZlcmVuY2VyX2NvbmZpZxgCIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdib + WwuRGF0YWZsb3dSZXNvdXJjZUNvbmZpZ0Id4j8aEhhkYXRhZmxvd0luZmVyZW5jZXJDb25maWdIAFIYZGF0YWZsb3dJbmZlcmVuY + 2VyQ29uZmlnEoEBChdsb2NhbF9pbmZlcmVuY2VyX2NvbmZpZxgDIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxSZ + 
XNvdXJjZUNvbmZpZ0Ia4j8XEhVsb2NhbEluZmVyZW5jZXJDb25maWdIAFIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnEqwBCiZ2ZXJ0Z + XhfYWlfbXVsdGlfcG9vbF9pbmZlcmVuY2VyX2NvbmZpZxgEIAEoCzIvLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlNd + Wx0aVBvb2xDb25maWdCJuI/IxIhdmVydGV4QWlNdWx0aVBvb2xJbmZlcmVuY2VyQ29uZmlnSABSIXZlcnRleEFpTXVsdGlQb29sS W5mZXJlbmNlckNvbmZpZ0ITChFpbmZlcmVuY2VyX2NvbmZpZyKXCAoUU2hhcmVkUmVzb3VyY2VDb25maWcSfgoPcmVzb3VyY2Vfb GFiZWxzGAEgAygLMkAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZy5SZXNvdXJjZUxhYmVsc0Vud HJ5QhPiPxASDnJlc291cmNlTGFiZWxzUg5yZXNvdXJjZUxhYmVscxKOAQoVY29tbW9uX2NvbXB1dGVfY29uZmlnGAIgASgLMkAuc diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala index 33df88e8d..e222970eb 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala @@ -28,6 +28,10 @@ final case class InferencerResourceConfig( val __value = inferencerConfig.localInferencerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + if (inferencerConfig.vertexAiMultiPoolInferencerConfig.isDefined) { + val __value = inferencerConfig.vertexAiMultiPoolInferencerConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -38,7 +42,7 @@ final case class InferencerResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { inferencerConfig.vertexAiInferencerConfig.foreach { __v => @@ -59,6 +63,12 @@ final case class InferencerResourceConfig( 
_output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + inferencerConfig.vertexAiMultiPoolInferencerConfig.foreach { __v => + val __m = __v + _output__.writeTag(4, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; unknownFields.writeTo(_output__) } def getVertexAiInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = inferencerConfig.vertexAiInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -67,6 +77,8 @@ final case class InferencerResourceConfig( def withDataflowInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(__v)) def getLocalInferencerConfig: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig = inferencerConfig.localInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.LocalResourceConfig.defaultInstance) def withLocalInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__v)) + def getVertexAiMultiPoolInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = inferencerConfig.vertexAiMultiPoolInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.defaultInstance) + def withVertexAiMultiPoolInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(__v)) def clearInferencerConfig: InferencerResourceConfig = copy(inferencerConfig = 
snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) def withInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig): InferencerResourceConfig = copy(inferencerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -76,6 +88,7 @@ final case class InferencerResourceConfig( case 1 => inferencerConfig.vertexAiInferencerConfig.orNull case 2 => inferencerConfig.dataflowInferencerConfig.orNull case 3 => inferencerConfig.localInferencerConfig.orNull + case 4 => inferencerConfig.vertexAiMultiPoolInferencerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -84,6 +97,7 @@ final case class InferencerResourceConfig( case 1 => inferencerConfig.vertexAiInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => inferencerConfig.dataflowInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => inferencerConfig.localInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 4 => inferencerConfig.vertexAiMultiPoolInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -107,6 +121,8 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(__inferencerConfig.dataflowInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 26 => __inferencerConfig = 
snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__inferencerConfig.localInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 34 => + __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(__inferencerConfig.vertexAiMultiPoolInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -126,18 +142,20 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch inferencerConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(_)) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(_))) + 
.orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(11) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig case 2 => __out = snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig + case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig } __out } @@ -152,9 +170,11 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch def isVertexAiInferencerConfig: _root_.scala.Boolean = false def isDataflowInferencerConfig: _root_.scala.Boolean = false def isLocalInferencerConfig: _root_.scala.Boolean = false + def isVertexAiMultiPoolInferencerConfig: _root_.scala.Boolean = false def 
vertexAiInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def dataflowInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = _root_.scala.None def localInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None + def vertexAiMultiPoolInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scala.None } object InferencerConfig { @SerialVersionUID(0L) @@ -165,7 +185,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class VertexAiInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig @@ -187,16 +207,25 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch override def localInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = Some(value) override def number: _root_.scala.Int = 3 } + @SerialVersionUID(0L) + final case class VertexAiMultiPoolInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + override def isVertexAiMultiPoolInferencerConfig: _root_.scala.Boolean = true + override def vertexAiMultiPoolInferencerConfig: 
_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = Some(value) + override def number: _root_.scala.Int = 4 + } } implicit class InferencerResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig](_l) { def vertexAiInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(f_))) def dataflowInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = field(_.getDataflowInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(f_))) def localInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(f_))) + def vertexAiMultiPoolInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = field(_.getVertexAiMultiPoolInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(f_))) def inferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig] = field(_.inferencerConfig)((c_, f_) => c_.copy(inferencerConfig = f_)) } 
final val VERTEX_AI_INFERENCER_CONFIG_FIELD_NUMBER = 1 final val DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_INFERENCER_CONFIG_FIELD_NUMBER = 3 + final val VERTEX_AI_MULTI_POOL_INFERENCER_CONFIG_FIELD_NUMBER = 4 def of( inferencerConfig: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig( diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala index ba4299fbf..48bf65537 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala @@ -35,7 +35,7 @@ final case class SharedResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { resourceLabels.foreach { __v => @@ -116,8 +116,8 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(13) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(13) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: 
_root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -172,63 +172,63 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = project if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = region if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = tempAssetsBucket if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } }; - + { val __value = tempRegionalAssetsBucket if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(4, __value) } }; - + { val __value = permAssetsBucket if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(5, __value) } }; - + { val __value = tempAssetsBqDatasetName if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(6, __value) } }; - + { val __value = embeddingBqDatasetName if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(7, __value) } }; - + { val __value = gcpServiceAccountEmail if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(8, __value) } }; - + { val __value = dataflowRunner if (!__value.isEmpty) { @@ -245,7 +245,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -373,7 +373,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: 
snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig.type = snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig]) } - + object CommonComputeConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig = { @@ -505,7 +505,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig]) } - + @SerialVersionUID(0L) final case class ResourceLabelsEntry( key: _root_.scala.Predef.String = "", @@ -516,14 +516,14 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -540,7 +540,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -584,7 +584,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry.type = 
snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry]) } - + object ResourceLabelsEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry = { @@ -649,7 +649,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry]) } - + implicit class SharedResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.SharedResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.SharedResourceConfig](_l) { def resourceLabels: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = field(_.resourceLabels)((c_, f_) => c_.copy(resourceLabels = f_)) def commonComputeConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig] = field(_.getCommonComputeConfig)((c_, f_) => c_.copy(commonComputeConfig = Option(f_))) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala index 9c84a78b6..8a72b5e87 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala +++ 
b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala @@ -28,6 +28,10 @@ final case class TrainerResourceConfig( val __value = trainerConfig.localTrainerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + if (trainerConfig.vertexAiMultiPoolTrainerConfig.isDefined) { + val __value = trainerConfig.vertexAiMultiPoolTrainerConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -38,7 +42,7 @@ final case class TrainerResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { trainerConfig.vertexAiTrainerConfig.foreach { __v => @@ -59,6 +63,12 @@ final case class TrainerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + trainerConfig.vertexAiMultiPoolTrainerConfig.foreach { __v => + val __m = __v + _output__.writeTag(4, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; unknownFields.writeTo(_output__) } def getVertexAiTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = trainerConfig.vertexAiTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -67,6 +77,8 @@ final case class TrainerResourceConfig( def withKfpTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.KFPResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(__v)) def getLocalTrainerConfig: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig = 
trainerConfig.localTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.LocalResourceConfig.defaultInstance) def withLocalTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__v)) + def getVertexAiMultiPoolTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = trainerConfig.vertexAiMultiPoolTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.defaultInstance) + def withVertexAiMultiPoolTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(__v)) def clearTrainerConfig: TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) def withTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig): TrainerResourceConfig = copy(trainerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -76,6 +88,7 @@ final case class TrainerResourceConfig( case 1 => trainerConfig.vertexAiTrainerConfig.orNull case 2 => trainerConfig.kfpTrainerConfig.orNull case 3 => trainerConfig.localTrainerConfig.orNull + case 4 => trainerConfig.vertexAiMultiPoolTrainerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -84,6 +97,7 @@ final case class TrainerResourceConfig( case 1 => trainerConfig.vertexAiTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => trainerConfig.kfpTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => 
trainerConfig.localTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 4 => trainerConfig.vertexAiMultiPoolTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -107,6 +121,8 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(__trainerConfig.kfpTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 26 => __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__trainerConfig.localTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 34 => + __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(__trainerConfig.vertexAiMultiPoolTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -126,18 +142,20 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
trainerConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(_)) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(10) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(10) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = 
GiglResourceConfigProto.scalaDescriptor.messages(11) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig case 2 => __out = snapchat.research.gbml.gigl_resource_config.KFPResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig + case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig } __out } @@ -152,9 +170,11 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. def isVertexAiTrainerConfig: _root_.scala.Boolean = false def isKfpTrainerConfig: _root_.scala.Boolean = false def isLocalTrainerConfig: _root_.scala.Boolean = false + def isVertexAiMultiPoolTrainerConfig: _root_.scala.Boolean = false def vertexAiTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def kfpTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = _root_.scala.None def localTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None + def vertexAiMultiPoolTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scala.None } object TrainerConfig { @SerialVersionUID(0L) @@ -165,7 +185,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class VertexAiTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig @@ -187,16 +207,25 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. override def localTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = Some(value) override def number: _root_.scala.Int = 3 } + @SerialVersionUID(0L) + final case class VertexAiMultiPoolTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + override def isVertexAiMultiPoolTrainerConfig: _root_.scala.Boolean = true + override def vertexAiMultiPoolTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = Some(value) + override def number: _root_.scala.Int = 4 + } } implicit class TrainerResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig](_l) { def vertexAiTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(f_))) def kfpTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, 
snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = field(_.getKfpTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(f_))) def localTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(f_))) + def vertexAiMultiPoolTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = field(_.getVertexAiMultiPoolTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(f_))) def trainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig] = field(_.trainerConfig)((c_, f_) => c_.copy(trainerConfig = f_)) } final val VERTEX_AI_TRAINER_CONFIG_FIELD_NUMBER = 1 final val KFP_TRAINER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_TRAINER_CONFIG_FIELD_NUMBER = 3 + final val VERTEX_AI_MULTI_POOL_TRAINER_CONFIG_FIELD_NUMBER = 4 def of( trainerConfig: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig( diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala new file mode 100644 index 000000000..dac501e3e --- /dev/null +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala @@ -0,0 +1,154 @@ +// Generated by the Scala Plugin for the 
Protocol Buffer Compiler. +// Do not edit! +// +// Protofile syntax: PROTO3 + +package snapchat.research.gbml.gigl_resource_config + +/** Configuration for Mutlipool Vertex AI jobs. + * See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. + * NOTE: The first worker pool will be split into the primary replica and "Workers". + * For example: + * pools = [ + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 16 + * }, + * { + * "machine_type": "n1-standard-8", + * "gpu_type": "nvidia-tesla-v100", + * "gpu_limit": 1, + * "num_replicas": 16 + * } + * ] + * Will have the Primary be: {} + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 1 + * } + * And the Workers be: + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 15 + * } + * And the parameter servers be: + * { + * "machine_type": "n1-standard-8", + * "gpu_type": "nvidia-tesla-v100", + * "gpu_limit": 1, + * "num_replicas": 16 + * } + */ +@SerialVersionUID(0L) +final case class VertexAiMultiPoolConfig( + pools: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.Seq.empty, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[VertexAiMultiPoolConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + pools.foreach { __item => + val __value = __item + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + } + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: 
_root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + pools.foreach { __v => + val __m = __v + _output__.writeTag(1, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + unknownFields.writeTo(_output__) + } + def clearPools = copy(pools = _root_.scala.Seq.empty) + def addPools(__vs: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig *): VertexAiMultiPoolConfig = addAllPools(__vs) + def addAllPools(__vs: Iterable[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]): VertexAiMultiPoolConfig = copy(pools = pools ++ __vs) + def withPools(__v: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]): VertexAiMultiPoolConfig = copy(pools = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => pools + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PRepeated(pools.iterator.map(_.toPMessage).toVector) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.type = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.VertexAiMultiPoolConfig]) +} + +object VertexAiMultiPoolConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] { + implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = { + val __pools: _root_.scala.collection.immutable.VectorBuilder[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = new _root_.scala.collection.immutable.VectorBuilder[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __pools += _root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__) + case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig( + pools = __pools.result(), + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig( + pools = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).getOrElse(_root_.scala.Seq.empty) + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: 
_root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(9) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(9) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { + var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null + (__number: @_root_.scala.unchecked) match { + case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig + } + __out + } + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig( + pools = _root_.scala.Seq.empty + ) + implicit class VertexAiMultiPoolConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_l) { + def pools: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.pools)((c_, f_) => c_.copy(pools = f_)) + } + final val POOLS_FIELD_NUMBER = 1 + def of( + pools: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + ): _root_.snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = _root_.snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig( + pools + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.VertexAiMultiPoolConfig]) +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolTrainerConfig.scala 
b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolTrainerConfig.scala new file mode 100644 index 000000000..fb43c370b --- /dev/null +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolTrainerConfig.scala @@ -0,0 +1,154 @@ +// Generated by the Scala Plugin for the Protocol Buffer Compiler. +// Do not edit! +// +// Protofile syntax: PROTO3 + +package snapchat.research.gbml.gigl_resource_config + +/** Configuration for Mutlipool Vertex AI jobs. + * See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. + * NOTE: The first worker pool will be split into the primary replica and "Workers". + * For example: + * pools = [ + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 16 + * }, + * { + * "machine_type": "n1-standard-8", + * "gpu_type": "nvidia-tesla-v100", + * "gpu_limit": 1, + * "num_replicas": 16 + * } + * ] + * Will have the Primary be: {} + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 1 + * } + * And the Workers be: + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 15 + * } + * And the parameter servers be: + * { + * "machine_type": "n1-standard-8", + * "gpu_type": "nvidia-tesla-v100", + * "gpu_limit": 1, + * "num_replicas": 16 + * } + */ +@SerialVersionUID(0L) +final case class VertexAiMultiPoolTrainerConfig( + pools: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.Seq.empty, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[VertexAiMultiPoolTrainerConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + pools.foreach { __item => + val __value = __item + __size += 1 + 
_root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + } + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + pools.foreach { __v => + val __m = __v + _output__.writeTag(1, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + unknownFields.writeTo(_output__) + } + def clearPools = copy(pools = _root_.scala.Seq.empty) + def addPools(__vs: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig *): VertexAiMultiPoolTrainerConfig = addAllPools(__vs) + def addAllPools(__vs: Iterable[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]): VertexAiMultiPoolTrainerConfig = copy(pools = pools ++ __vs) + def withPools(__v: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]): VertexAiMultiPoolTrainerConfig = copy(pools = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => pools + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PRepeated(pools.iterator.map(_.toPMessage).toVector) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: 
snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig.type = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.VertexAiMultiPoolTrainerConfig]) +} + +object VertexAiMultiPoolTrainerConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig = { + val __pools: _root_.scala.collection.immutable.VectorBuilder[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = new _root_.scala.collection.immutable.VectorBuilder[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __pools += _root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__) + case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig( + pools = __pools.result(), + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + 
_root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig( + pools = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).getOrElse(_root_.scala.Seq.empty) + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(9) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(9) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { + var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null + (__number: @_root_.scala.unchecked) match { + case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig + } + __out + } + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig( + pools = _root_.scala.Seq.empty + ) + implicit class VertexAiMultiPoolTrainerConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig](_l) { + def pools: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.pools)((c_, f_) => c_.copy(pools = f_)) + } + final val 
POOLS_FIELD_NUMBER = 1 + def of( + pools: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + ): _root_.snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig = _root_.snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig( + pools + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.VertexAiMultiPoolTrainerConfig]) +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiResourceConfig.scala index 152f96abb..0cac6898d 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiResourceConfig.scala @@ -19,8 +19,9 @@ package snapchat.research.gbml.gigl_resource_config * Timeout in seconds for the job. If unset or zero, will use the default @ google.cloud.aiplatform.CustomJob, which is 7 days: * https://github.com/googleapis/python-aiplatform/blob/58fbabdeeefd1ccf1a9d0c22eeb5606aeb9c2266/google/cloud/aiplatform/jobs.py#L2252-L2253 * @param gcpRegionOverride - * Region override. + * Region override * If provided, then the Vertex AI Job will be launched in the provided region. 
+ * Otherwise, will launch jobs in the region specified at CommonComputeConfig.region * ex: "us-west1" * NOTE: If set, then there may be data egress costs from CommonComputeConfig.region -> gcp_region_override */ diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala index 88a76bd40..403157363 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala @@ -39,7 +39,7 @@ final case class DistributedTrainerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { trainerConfig.vertexAiTrainerConfig.foreach { __v => @@ -131,8 +131,8 @@ object DistributedTrainerConfig extends scalapb.GeneratedMessageCompanion[snapch ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(9) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(9) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(10) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(10) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -166,7 +166,7 @@ object DistributedTrainerConfig extends scalapb.GeneratedMessageCompanion[snapch override def number: 
_root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class VertexAiTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiTrainerConfig) extends snapchat.research.gbml.gigl_resource_config.DistributedTrainerConfig.TrainerConfig { type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiTrainerConfig diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala index de92ed17e..4d672733a 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala @@ -86,7 +86,7 @@ final case class GiglResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { sharedResource.sharedResourceConfigUri.foreach { __v => @@ -275,8 +275,8 @@ object GiglResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.res ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(13) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(13) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(14) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(14) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: 
_root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -320,7 +320,7 @@ object GiglResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.res override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class SharedResourceConfigUri(value: _root_.scala.Predef.String) extends snapchat.research.gbml.gigl_resource_config.GiglResourceConfig.SharedResource { type ValueType = _root_.scala.Predef.String diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala index 55b0dfcf7..c5e49c887 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala @@ -18,6 +18,7 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig, snapchat.research.gbml.gigl_resource_config.KFPResourceConfig, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig, + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig, snapchat.research.gbml.gigl_resource_config.DistributedTrainerConfig, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig, @@ -53,23 +54,29 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { VJlcXVlc3RSDW1lbW9yeVJlcXVlc3QSJwoIZ3B1X3R5cGUYAyABKAlCDOI/CRIHZ3B1VHlwZVIHZ3B1VHlwZRIqCglncHVfbGlta XQYBCABKA1CDeI/ChIIZ3B1TGltaXRSCGdwdUxpbWl0EjMKDG51bV9yZXBsaWNhcxgFIAEoDUIQ4j8NEgtudW1SZXBsaWNhc1ILb 
nVtUmVwbGljYXMiRwoTTG9jYWxSZXNvdXJjZUNvbmZpZxIwCgtudW1fd29ya2VycxgBIAEoDUIP4j8MEgpudW1Xb3JrZXJzUgpud - W1Xb3JrZXJzIp0DChhEaXN0cmlidXRlZFRyYWluZXJDb25maWcShAEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzItL - nNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlUcmFpbmVyQ29uZmlnQhriPxcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAU - hV2ZXJ0ZXhBaVRyYWluZXJDb25maWcSbwoSa2ZwX3RyYWluZXJfY29uZmlnGAIgASgLMiguc25hcGNoYXQucmVzZWFyY2guZ2Jtb - C5LRlBUcmFpbmVyQ29uZmlnQhXiPxISEGtmcFRyYWluZXJDb25maWdIAFIQa2ZwVHJhaW5lckNvbmZpZxJ3ChRsb2NhbF90cmFpb - mVyX2NvbmZpZxgDIAEoCzIqLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxUcmFpbmVyQ29uZmlnQhfiPxQSEmxvY2FsVHJha - W5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZXJDb25maWdCEAoOdHJhaW5lcl9jb25maWcinQMKFVRyYWluZXJSZXNvdXJjZUNvbmZpZ - xKFAQoYdmVydGV4X2FpX3RyYWluZXJfY29uZmlnGAEgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291c - mNlQ29uZmlnQhriPxcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyYWluZXJDb25maWcScAoSa2ZwX3RyYWluZ - XJfY29uZmlnGAIgASgLMikuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5LRlBSZXNvdXJjZUNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ - 29uZmlnSABSEGtmcFRyYWluZXJDb25maWcSeAoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKy5zbmFwY2hhdC5yZXNlYXJja - C5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCF+I/FBISbG9jYWxUcmFpbmVyQ29uZmlnSABSEmxvY2FsVHJhaW5lckNvbmZpZ0IQC - g50cmFpbmVyX2NvbmZpZyLUAwoYSW5mZXJlbmNlclJlc291cmNlQ29uZmlnEo4BCht2ZXJ0ZXhfYWlfaW5mZXJlbmNlcl9jb25ma - WcYASABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpUmVzb3VyY2VDb25maWdCHeI/GhIYdmVydGV4QWlJbmZlc - mVuY2VyQ29uZmlnSABSGHZlcnRleEFpSW5mZXJlbmNlckNvbmZpZxKNAQoaZGF0YWZsb3dfaW5mZXJlbmNlcl9jb25maWcYAiABK - AsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhdGFmbG93UmVzb3VyY2VDb25maWdCHeI/GhIYZGF0YWZsb3dJbmZlcmVuY2VyQ - 29uZmlnSABSGGRhdGFmbG93SW5mZXJlbmNlckNvbmZpZxKBAQoXbG9jYWxfaW5mZXJlbmNlcl9jb25maWcYAyABKAsyKy5zbmFwY - 2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCGuI/FxIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnSABSFWxvY2FsS + W1Xb3JrZXJzImsKF1ZlcnRleEFpTXVsdGlQb29sQ29uZmlnElAKBXBvb2xzGAEgAygLMi4uc25hcGNoYXQucmVzZWFyY2guZ2Jtb + 
C5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQgriPwcSBXBvb2xzUgVwb29scyKdAwoYRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnEoQBC + hh2ZXJ0ZXhfYWlfdHJhaW5lcl9jb25maWcYASABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpVHJhaW5lckNvb + mZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEm8KEmtmcF90cmFpbmVyX2Nvb + mZpZxgCIAEoCzIoLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS0ZQVHJhaW5lckNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnS + ABSEGtmcFRyYWluZXJDb25maWcSdwoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sL + kxvY2FsVHJhaW5lckNvbmZpZ0IX4j8UEhJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ29uZmlnQhAKDnRyYWluZ + XJfY29uZmlnIsMEChVUcmFpbmVyUmVzb3VyY2VDb25maWcShQEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzIuLnNuY + XBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVd + mVydGV4QWlUcmFpbmVyQ29uZmlnEnAKEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIpLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS + 0ZQUmVzb3VyY2VDb25maWdCFeI/EhIQa2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEngKFGxvY2FsX3RyYWluZ + XJfY29uZmlnGAMgASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhfiPxQSEmxvY2FsVHJha + W5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZXJDb25maWcSowEKI3ZlcnRleF9haV9tdWx0aV9wb29sX3RyYWluZXJfY29uZmlnGAQgA + SgLMi8uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaU11bHRpUG9vbENvbmZpZ0Ij4j8gEh52ZXJ0ZXhBaU11bHRpUG9vb + FRyYWluZXJDb25maWdIAFIedmVydGV4QWlNdWx0aVBvb2xUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIoMFChhJbmZlc + mVuY2VyUmVzb3VyY2VDb25maWcSjgEKG3ZlcnRleF9haV9pbmZlcmVuY2VyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2Vhc + mNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Id4j8aEhh2ZXJ0ZXhBaUluZmVyZW5jZXJDb25maWdIAFIYdmVydGV4QWlJb + mZlcmVuY2VyQ29uZmlnEo0BChpkYXRhZmxvd19pbmZlcmVuY2VyX2NvbmZpZxgCIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdib + WwuRGF0YWZsb3dSZXNvdXJjZUNvbmZpZ0Id4j8aEhhkYXRhZmxvd0luZmVyZW5jZXJDb25maWdIAFIYZGF0YWZsb3dJbmZlcmVuY + 2VyQ29uZmlnEoEBChdsb2NhbF9pbmZlcmVuY2VyX2NvbmZpZxgDIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxSZ + 
XNvdXJjZUNvbmZpZ0Ia4j8XEhVsb2NhbEluZmVyZW5jZXJDb25maWdIAFIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnEqwBCiZ2ZXJ0Z + XhfYWlfbXVsdGlfcG9vbF9pbmZlcmVuY2VyX2NvbmZpZxgEIAEoCzIvLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlNd + Wx0aVBvb2xDb25maWdCJuI/IxIhdmVydGV4QWlNdWx0aVBvb2xJbmZlcmVuY2VyQ29uZmlnSABSIXZlcnRleEFpTXVsdGlQb29sS W5mZXJlbmNlckNvbmZpZ0ITChFpbmZlcmVuY2VyX2NvbmZpZyKXCAoUU2hhcmVkUmVzb3VyY2VDb25maWcSfgoPcmVzb3VyY2Vfb GFiZWxzGAEgAygLMkAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZy5SZXNvdXJjZUxhYmVsc0Vud HJ5QhPiPxASDnJlc291cmNlTGFiZWxzUg5yZXNvdXJjZUxhYmVscxKOAQoVY29tbW9uX2NvbXB1dGVfY29uZmlnGAIgASgLMkAuc diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala index 33df88e8d..e222970eb 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala @@ -28,6 +28,10 @@ final case class InferencerResourceConfig( val __value = inferencerConfig.localInferencerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + if (inferencerConfig.vertexAiMultiPoolInferencerConfig.isDefined) { + val __value = inferencerConfig.vertexAiMultiPoolInferencerConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -38,7 +42,7 @@ final case class InferencerResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { inferencerConfig.vertexAiInferencerConfig.foreach { __v => @@ -59,6 +63,12 @@ final case class 
InferencerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + inferencerConfig.vertexAiMultiPoolInferencerConfig.foreach { __v => + val __m = __v + _output__.writeTag(4, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; unknownFields.writeTo(_output__) } def getVertexAiInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = inferencerConfig.vertexAiInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -67,6 +77,8 @@ final case class InferencerResourceConfig( def withDataflowInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(__v)) def getLocalInferencerConfig: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig = inferencerConfig.localInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.LocalResourceConfig.defaultInstance) def withLocalInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__v)) + def getVertexAiMultiPoolInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = inferencerConfig.vertexAiMultiPoolInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.defaultInstance) + def withVertexAiMultiPoolInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(__v)) def clearInferencerConfig: InferencerResourceConfig 
= copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) def withInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig): InferencerResourceConfig = copy(inferencerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -76,6 +88,7 @@ final case class InferencerResourceConfig( case 1 => inferencerConfig.vertexAiInferencerConfig.orNull case 2 => inferencerConfig.dataflowInferencerConfig.orNull case 3 => inferencerConfig.localInferencerConfig.orNull + case 4 => inferencerConfig.vertexAiMultiPoolInferencerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -84,6 +97,7 @@ final case class InferencerResourceConfig( case 1 => inferencerConfig.vertexAiInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => inferencerConfig.dataflowInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => inferencerConfig.localInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 4 => inferencerConfig.vertexAiMultiPoolInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -107,6 +121,8 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(__inferencerConfig.dataflowInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 26 => __inferencerConfig = 
snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__inferencerConfig.localInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 34 => + __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(__inferencerConfig.vertexAiMultiPoolInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -126,18 +142,20 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch inferencerConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(_)) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(_))) + 
.orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(11) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig case 2 => __out = snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig + case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig } __out } @@ -152,9 +170,11 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch def isVertexAiInferencerConfig: _root_.scala.Boolean = false def isDataflowInferencerConfig: _root_.scala.Boolean = false def isLocalInferencerConfig: _root_.scala.Boolean = false + def isVertexAiMultiPoolInferencerConfig: _root_.scala.Boolean = false def 
vertexAiInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def dataflowInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = _root_.scala.None def localInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None + def vertexAiMultiPoolInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scala.None } object InferencerConfig { @SerialVersionUID(0L) @@ -165,7 +185,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class VertexAiInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig @@ -187,16 +207,25 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch override def localInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = Some(value) override def number: _root_.scala.Int = 3 } + @SerialVersionUID(0L) + final case class VertexAiMultiPoolInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + override def isVertexAiMultiPoolInferencerConfig: _root_.scala.Boolean = true + override def vertexAiMultiPoolInferencerConfig: 
_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = Some(value) + override def number: _root_.scala.Int = 4 + } } implicit class InferencerResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig](_l) { def vertexAiInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(f_))) def dataflowInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = field(_.getDataflowInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(f_))) def localInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(f_))) + def vertexAiMultiPoolInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = field(_.getVertexAiMultiPoolInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(f_))) def inferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig] = field(_.inferencerConfig)((c_, f_) => c_.copy(inferencerConfig = f_)) } 
final val VERTEX_AI_INFERENCER_CONFIG_FIELD_NUMBER = 1 final val DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_INFERENCER_CONFIG_FIELD_NUMBER = 3 + final val VERTEX_AI_MULTI_POOL_INFERENCER_CONFIG_FIELD_NUMBER = 4 def of( inferencerConfig: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig( diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala index ba4299fbf..48bf65537 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala @@ -35,7 +35,7 @@ final case class SharedResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { resourceLabels.foreach { __v => @@ -116,8 +116,8 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(13) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(13) def messageCompanionForFieldNumber(__number: _root_.scala.Int): 
_root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -172,63 +172,63 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = project if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = region if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = tempAssetsBucket if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } }; - + { val __value = tempRegionalAssetsBucket if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(4, __value) } }; - + { val __value = permAssetsBucket if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(5, __value) } }; - + { val __value = tempAssetsBqDatasetName if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(6, __value) } }; - + { val __value = embeddingBqDatasetName if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(7, __value) } }; - + { val __value = gcpServiceAccountEmail if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(8, __value) } }; - + { val __value = dataflowRunner if (!__value.isEmpty) { @@ -245,7 +245,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -373,7 +373,7 @@ object SharedResourceConfig extends 
scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig.type = snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig]) } - + object CommonComputeConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig = { @@ -505,7 +505,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig]) } - + @SerialVersionUID(0L) final case class ResourceLabelsEntry( key: _root_.scala.Predef.String = "", @@ -516,14 +516,14 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -540,7 +540,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -584,7 +584,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: 
snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry.type = snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry]) } - + object ResourceLabelsEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.ResourceLabelsEntry = { @@ -649,7 +649,7 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry]) } - + implicit class SharedResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.SharedResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.SharedResourceConfig](_l) { def resourceLabels: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = field(_.resourceLabels)((c_, f_) => c_.copy(resourceLabels = f_)) def commonComputeConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.SharedResourceConfig.CommonComputeConfig] = field(_.getCommonComputeConfig)((c_, f_) => c_.copy(commonComputeConfig = Option(f_))) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala index 9c84a78b6..8a72b5e87 100644 
--- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala @@ -28,6 +28,10 @@ final case class TrainerResourceConfig( val __value = trainerConfig.localTrainerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + if (trainerConfig.vertexAiMultiPoolTrainerConfig.isDefined) { + val __value = trainerConfig.vertexAiMultiPoolTrainerConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -38,7 +42,7 @@ final case class TrainerResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { trainerConfig.vertexAiTrainerConfig.foreach { __v => @@ -59,6 +63,12 @@ final case class TrainerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + trainerConfig.vertexAiMultiPoolTrainerConfig.foreach { __v => + val __m = __v + _output__.writeTag(4, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; unknownFields.writeTo(_output__) } def getVertexAiTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = trainerConfig.vertexAiTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -67,6 +77,8 @@ final case class TrainerResourceConfig( def withKfpTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.KFPResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(__v)) def getLocalTrainerConfig: 
snapchat.research.gbml.gigl_resource_config.LocalResourceConfig = trainerConfig.localTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.LocalResourceConfig.defaultInstance) def withLocalTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__v)) + def getVertexAiMultiPoolTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = trainerConfig.vertexAiMultiPoolTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.defaultInstance) + def withVertexAiMultiPoolTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(__v)) def clearTrainerConfig: TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) def withTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig): TrainerResourceConfig = copy(trainerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -76,6 +88,7 @@ final case class TrainerResourceConfig( case 1 => trainerConfig.vertexAiTrainerConfig.orNull case 2 => trainerConfig.kfpTrainerConfig.orNull case 3 => trainerConfig.localTrainerConfig.orNull + case 4 => trainerConfig.vertexAiMultiPoolTrainerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -84,6 +97,7 @@ final case class TrainerResourceConfig( case 1 => trainerConfig.vertexAiTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => 
trainerConfig.kfpTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => trainerConfig.localTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 4 => trainerConfig.vertexAiMultiPoolTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -107,6 +121,8 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(__trainerConfig.kfpTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 26 => __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__trainerConfig.localTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 34 => + __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(__trainerConfig.vertexAiMultiPoolTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -126,18 +142,20 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
trainerConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(_)) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(10) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(10) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = 
GiglResourceConfigProto.scalaDescriptor.messages(11) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig case 2 => __out = snapchat.research.gbml.gigl_resource_config.KFPResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig + case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig } __out } @@ -152,9 +170,11 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. def isVertexAiTrainerConfig: _root_.scala.Boolean = false def isKfpTrainerConfig: _root_.scala.Boolean = false def isLocalTrainerConfig: _root_.scala.Boolean = false + def isVertexAiMultiPoolTrainerConfig: _root_.scala.Boolean = false def vertexAiTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def kfpTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = _root_.scala.None def localTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None + def vertexAiMultiPoolTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scala.None } object TrainerConfig { @SerialVersionUID(0L) @@ -165,7 +185,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class VertexAiTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig @@ -187,16 +207,25 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. override def localTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = Some(value) override def number: _root_.scala.Int = 3 } + @SerialVersionUID(0L) + final case class VertexAiMultiPoolTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + override def isVertexAiMultiPoolTrainerConfig: _root_.scala.Boolean = true + override def vertexAiMultiPoolTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = Some(value) + override def number: _root_.scala.Int = 4 + } } implicit class TrainerResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig](_l) { def vertexAiTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(f_))) def kfpTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, 
snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = field(_.getKfpTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(f_))) def localTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(f_))) + def vertexAiMultiPoolTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = field(_.getVertexAiMultiPoolTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(f_))) def trainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig] = field(_.trainerConfig)((c_, f_) => c_.copy(trainerConfig = f_)) } final val VERTEX_AI_TRAINER_CONFIG_FIELD_NUMBER = 1 final val KFP_TRAINER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_TRAINER_CONFIG_FIELD_NUMBER = 3 + final val VERTEX_AI_MULTI_POOL_TRAINER_CONFIG_FIELD_NUMBER = 4 def of( trainerConfig: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig( diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala new file mode 100644 index 000000000..dac501e3e --- /dev/null +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala @@ -0,0 +1,154 @@ +// Generated by the 
Scala Plugin for the Protocol Buffer Compiler. +// Do not edit! +// +// Protofile syntax: PROTO3 + +package snapchat.research.gbml.gigl_resource_config + +/** Configuration for Mutlipool Vertex AI jobs. + * See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. + * NOTE: The first worker pool will be split into the primary replica and "Workers". + * For example: + * pools = [ + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 16 + * }, + * { + * "machine_type": "n1-standard-8", + * "gpu_type": "nvidia-tesla-v100", + * "gpu_limit": 1, + * "num_replicas": 16 + * } + * ] + * Will have the Primary be: {} + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 1 + * } + * And the Workers be: + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 15 + * } + * And the parameter servers be: + * { + * "machine_type": "n1-standard-8", + * "gpu_type": "nvidia-tesla-v100", + * "gpu_limit": 1, + * "num_replicas": 16 + * } + */ +@SerialVersionUID(0L) +final case class VertexAiMultiPoolConfig( + pools: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.Seq.empty, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[VertexAiMultiPoolConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + pools.foreach { __item => + val __value = __item + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + } + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def 
writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + pools.foreach { __v => + val __m = __v + _output__.writeTag(1, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + unknownFields.writeTo(_output__) + } + def clearPools = copy(pools = _root_.scala.Seq.empty) + def addPools(__vs: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig *): VertexAiMultiPoolConfig = addAllPools(__vs) + def addAllPools(__vs: Iterable[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]): VertexAiMultiPoolConfig = copy(pools = pools ++ __vs) + def withPools(__v: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]): VertexAiMultiPoolConfig = copy(pools = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => pools + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PRepeated(pools.iterator.map(_.toPMessage).toVector) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.type = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.VertexAiMultiPoolConfig]) +} + +object VertexAiMultiPoolConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] { + implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = { + val __pools: _root_.scala.collection.immutable.VectorBuilder[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = new _root_.scala.collection.immutable.VectorBuilder[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __pools += _root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__) + case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig( + pools = __pools.result(), + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig( + pools = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).getOrElse(_root_.scala.Seq.empty) + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: 
_root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(9) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(9) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { + var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null + (__number: @_root_.scala.unchecked) match { + case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig + } + __out + } + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig( + pools = _root_.scala.Seq.empty + ) + implicit class VertexAiMultiPoolConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_l) { + def pools: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.pools)((c_, f_) => c_.copy(pools = f_)) + } + final val POOLS_FIELD_NUMBER = 1 + def of( + pools: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + ): _root_.snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = _root_.snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig( + pools + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.VertexAiMultiPoolConfig]) +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolTrainerConfig.scala 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolTrainerConfig.scala new file mode 100644 index 000000000..fb43c370b --- /dev/null +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolTrainerConfig.scala @@ -0,0 +1,154 @@ +// Generated by the Scala Plugin for the Protocol Buffer Compiler. +// Do not edit! +// +// Protofile syntax: PROTO3 + +package snapchat.research.gbml.gigl_resource_config + +/** Configuration for Mutlipool Vertex AI jobs. + * See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. + * NOTE: The first worker pool will be split into the primary replica and "Workers". + * For example: + * pools = [ + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 16 + * }, + * { + * "machine_type": "n1-standard-8", + * "gpu_type": "nvidia-tesla-v100", + * "gpu_limit": 1, + * "num_replicas": 16 + * } + * ] + * Will have the Primary be: {} + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 1 + * } + * And the Workers be: + * { + * "machine_type": "n1-standard-8", + * "num_replicas": 15 + * } + * And the parameter servers be: + * { + * "machine_type": "n1-standard-8", + * "gpu_type": "nvidia-tesla-v100", + * "gpu_limit": 1, + * "num_replicas": 16 + * } + */ +@SerialVersionUID(0L) +final case class VertexAiMultiPoolTrainerConfig( + pools: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.Seq.empty, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[VertexAiMultiPoolTrainerConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + pools.foreach { __item => + val __value = __item + __size += 1 + 
_root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + } + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + pools.foreach { __v => + val __m = __v + _output__.writeTag(1, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + unknownFields.writeTo(_output__) + } + def clearPools = copy(pools = _root_.scala.Seq.empty) + def addPools(__vs: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig *): VertexAiMultiPoolTrainerConfig = addAllPools(__vs) + def addAllPools(__vs: Iterable[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]): VertexAiMultiPoolTrainerConfig = copy(pools = pools ++ __vs) + def withPools(__v: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]): VertexAiMultiPoolTrainerConfig = copy(pools = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => pools + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PRepeated(pools.iterator.map(_.toPMessage).toVector) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: 
snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig.type = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.VertexAiMultiPoolTrainerConfig]) +} + +object VertexAiMultiPoolTrainerConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig = { + val __pools: _root_.scala.collection.immutable.VectorBuilder[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = new _root_.scala.collection.immutable.VectorBuilder[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __pools += _root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__) + case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig( + pools = __pools.result(), + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + 
_root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig( + pools = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).getOrElse(_root_.scala.Seq.empty) + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(9) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(9) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { + var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null + (__number: @_root_.scala.unchecked) match { + case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig + } + __out + } + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig( + pools = _root_.scala.Seq.empty + ) + implicit class VertexAiMultiPoolTrainerConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig](_l) { + def pools: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.pools)((c_, f_) => c_.copy(pools = f_)) + } + final val 
POOLS_FIELD_NUMBER = 1 + def of( + pools: _root_.scala.Seq[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + ): _root_.snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig = _root_.snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolTrainerConfig( + pools + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.VertexAiMultiPoolTrainerConfig]) +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiResourceConfig.scala index 152f96abb..0cac6898d 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiResourceConfig.scala @@ -19,8 +19,9 @@ package snapchat.research.gbml.gigl_resource_config * Timeout in seconds for the job. If unset or zero, will use the default @ google.cloud.aiplatform.CustomJob, which is 7 days: * https://github.com/googleapis/python-aiplatform/blob/58fbabdeeefd1ccf1a9d0c22eeb5606aeb9c2266/google/cloud/aiplatform/jobs.py#L2252-L2253 * @param gcpRegionOverride - * Region override. + * Region override * If provided, then the Vertex AI Job will be launched in the provided region. 
+ * Otherwise, will launch jobs in the region specified at CommonComputeConfig.region * ex: "us-west1" * NOTE: If set, then there may be data egress costs from CommonComputeConfig.region -> gcp_region_override */ From d810e45c9dad1744613df059a9a2810dcca8badf Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 26 Sep 2025 18:43:58 +0000 Subject: [PATCH 02/33] typo --- proto/snapchat/research/gbml/gigl_resource_config.proto | 2 +- python/snapchat/research/gbml/gigl_resource_config_pb2.pyi | 2 +- .../gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala | 2 +- .../gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/proto/snapchat/research/gbml/gigl_resource_config.proto b/proto/snapchat/research/gbml/gigl_resource_config.proto index 9aecb4d78..90bd933b7 100644 --- a/proto/snapchat/research/gbml/gigl_resource_config.proto +++ b/proto/snapchat/research/gbml/gigl_resource_config.proto @@ -116,7 +116,7 @@ message VertexAiResourceConfig { uint32 num_workers = 1; } - // Configuration for Mutlipool Vertex AI jobs. + // Configuration for Multipool Vertex AI jobs. // See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. // NOTE: The first worker pool will be split into the primary replica and "Workers". // For example: diff --git a/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi b/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi index 1f540b611..bd5622a76 100644 --- a/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi +++ b/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi @@ -300,7 +300,7 @@ class LocalResourceConfig(google.protobuf.message.Message): global___LocalResourceConfig = LocalResourceConfig class VertexAiMultiPoolConfig(google.protobuf.message.Message): - """Configuration for Mutlipool Vertex AI jobs. + """Configuration for Multipool Vertex AI jobs. 
See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. NOTE: The first worker pool will be split into the primary replica and "Workers". For example: diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala index dac501e3e..94cc89fce 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala @@ -5,7 +5,7 @@ package snapchat.research.gbml.gigl_resource_config -/** Configuration for Mutlipool Vertex AI jobs. +/** Configuration for Multipool Vertex AI jobs. * See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. * NOTE: The first worker pool will be split into the primary replica and "Workers". * For example: diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala index dac501e3e..94cc89fce 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiMultiPoolConfig.scala @@ -5,7 +5,7 @@ package snapchat.research.gbml.gigl_resource_config -/** Configuration for Mutlipool Vertex AI jobs. +/** Configuration for Multipool Vertex AI jobs. * See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. * NOTE: The first worker pool will be split into the primary replica and "Workers". 
* For example: From 5ff13b26d4e6670629a1e73a7f65f0269ed23fb5 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 26 Sep 2025 21:00:12 +0000 Subject: [PATCH 03/33] to more explicit configs --- .../research/gbml/gigl_resource_config.proto | 47 ++---- .../research/gbml/gigl_resource_config_pb2.py | 48 +++--- .../gbml/gigl_resource_config_pb2.pyi | 78 ++++----- .../GiglResourceConfigProto.scala | 124 +++++++------- .../InferencerResourceConfig.scala | 36 ++-- .../TrainerResourceConfig.scala | 36 ++-- .../VertexAiGraphStoreConfig.scala | 155 ++++++++++++++++++ .../GiglResourceConfigProto.scala | 124 +++++++------- .../InferencerResourceConfig.scala | 36 ++-- .../TrainerResourceConfig.scala | 36 ++-- .../VertexAiGraphStoreConfig.scala | 155 ++++++++++++++++++ 11 files changed, 570 insertions(+), 305 deletions(-) create mode 100644 scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala create mode 100644 scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala diff --git a/proto/snapchat/research/gbml/gigl_resource_config.proto b/proto/snapchat/research/gbml/gigl_resource_config.proto index 90bd933b7..1712ef61b 100644 --- a/proto/snapchat/research/gbml/gigl_resource_config.proto +++ b/proto/snapchat/research/gbml/gigl_resource_config.proto @@ -116,43 +116,16 @@ message VertexAiResourceConfig { uint32 num_workers = 1; } - // Configuration for Multipool Vertex AI jobs. + // Configuration for lauching Vertex AI clusters with both graph store and compute pools + // Under the hood, this uses Vertex AI Multi-Pool Training // See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. - // NOTE: The first worker pool will be split into the primary replica and "Workers". 
- // For example: - // pools = [ - // { - // "machine_type": "n1-standard-8", - // "num_replicas": 16 - // }, - // { - // "machine_type": "n1-standard-8", - // "gpu_type": "nvidia-tesla-v100", - // "gpu_limit": 1, - // "num_replicas": 16 - // } - // ] - // Will have the Primary be: {} - // { - // "machine_type": "n1-standard-8", - // "num_replicas": 1 - // } - // And the Workers be: - // { - // "machine_type": "n1-standard-8", - // "num_replicas": 15 - // } - // And the parameter servers be: - // { - // "machine_type": "n1-standard-8", - // "gpu_type": "nvidia-tesla-v100", - // "gpu_limit": 1, - // "num_replicas": 16 - // } - message VertexAiMultiPoolConfig { - repeated VertexAiResourceConfig pools = 1; + // This cluster setup should be used when you want store your graph on separate machines from the compute machines + // e.g. you can get lots of big memory machines and separate gpu machines individually, + // but getting lots of gpu machines with lots of memory is challenging. + message VertexAiGraphStoreConfig { + VertexAiResourceConfig graph_store_pool = 1; + VertexAiResourceConfig compute_pool = 2; } - // (deprecated) // Configuration for distributed training resources message DistributedTrainerConfig { @@ -169,7 +142,7 @@ message TrainerResourceConfig { VertexAiResourceConfig vertex_ai_trainer_config = 1; KFPResourceConfig kfp_trainer_config = 2; LocalResourceConfig local_trainer_config = 3; - VertexAiMultiPoolConfig vertex_ai_multi_pool_trainer_config = 4; + VertexAiGraphStoreConfig vertex_ai_graph_store_trainer_config = 4; } } @@ -179,7 +152,7 @@ message InferencerResourceConfig { VertexAiResourceConfig vertex_ai_inferencer_config = 1; DataflowResourceConfig dataflow_inferencer_config = 2; LocalResourceConfig local_inferencer_config = 3; - VertexAiMultiPoolConfig vertex_ai_multi_pool_inferencer_config = 4; + VertexAiGraphStoreConfig vertex_ai_graph_store_inferencer_config = 4; } } diff --git a/python/snapchat/research/gbml/gigl_resource_config_pb2.py 
b/python/snapchat/research/gbml/gigl_resource_config_pb2.py index 0ce7bd9e0..55fe245e8 100644 --- a/python/snapchat/research/gbml/gigl_resource_config_pb2.py +++ b/python/snapchat/research/gbml/gigl_resource_config_pb2.py @@ -15,7 +15,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n1snapchat/research/gbml/gigl_resource_config.proto\x12\x16snapchat.research.gbml\"Y\n\x13SparkResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x16\n\x0enum_local_ssds\x18\x02 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x03 \x01(\r\"r\n\x16\x44\x61taflowResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\x12\x17\n\x0fmax_num_workers\x18\x02 \x01(\r\x12\x14\n\x0cmachine_type\x18\x03 \x01(\t\x12\x14\n\x0c\x64isk_size_gb\x18\x04 \x01(\r\"\xbc\x01\n\x16\x44\x61taPreprocessorConfig\x12P\n\x18\x65\x64ge_preprocessor_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\x12P\n\x18node_preprocessor_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\"h\n\x15VertexAiTrainerConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\"z\n\x10KFPTrainerConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\")\n\x12LocalTrainerConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\x97\x01\n\x16VertexAiResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12\x1b\n\x13gcp_region_override\x18\x06 \x01(\t\"{\n\x11KFPResourceConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 
\x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\"*\n\x13LocalResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"X\n\x17VertexAiMultiPoolConfig\x12=\n\x05pools\x18\x01 \x03(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\"\x93\x02\n\x18\x44istributedTrainerConfig\x12Q\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32-.snapchat.research.gbml.VertexAiTrainerConfigH\x00\x12\x46\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32(.snapchat.research.gbml.KFPTrainerConfigH\x00\x12J\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32*.snapchat.research.gbml.LocalTrainerConfigH\x00\x42\x10\n\x0etrainer_config\"\xf3\x02\n\x15TrainerResourceConfig\x12R\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12G\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32).snapchat.research.gbml.KFPResourceConfigH\x00\x12K\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12^\n#vertex_ai_multi_pool_trainer_config\x18\x04 \x01(\x0b\x32/.snapchat.research.gbml.VertexAiMultiPoolConfigH\x00\x42\x10\n\x0etrainer_config\"\x8f\x03\n\x18InferencerResourceConfig\x12U\n\x1bvertex_ai_inferencer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12T\n\x1a\x64\x61taflow_inferencer_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigH\x00\x12N\n\x17local_inferencer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12\x61\n&vertex_ai_multi_pool_inferencer_config\x18\x04 \x01(\x0b\x32/.snapchat.research.gbml.VertexAiMultiPoolConfigH\x00\x42\x13\n\x11inferencer_config\"\xa3\x04\n\x14SharedResourceConfig\x12Y\n\x0fresource_labels\x18\x01 \x03(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry\x12_\n\x15\x63ommon_compute_config\x18\x02 \x01(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig\x1a\x97\x02\n\x13\x43ommonComputeConfig\x12\x0f\n\x07project\x18\x01 
\x01(\t\x12\x0e\n\x06region\x18\x02 \x01(\t\x12\x1a\n\x12temp_assets_bucket\x18\x03 \x01(\t\x12#\n\x1btemp_regional_assets_bucket\x18\x04 \x01(\t\x12\x1a\n\x12perm_assets_bucket\x18\x05 \x01(\t\x12#\n\x1btemp_assets_bq_dataset_name\x18\x06 \x01(\t\x12!\n\x19\x65mbedding_bq_dataset_name\x18\x07 \x01(\t\x12!\n\x19gcp_service_account_email\x18\x08 \x01(\t\x12\x17\n\x0f\x64\x61taflow_runner\x18\x0b \x01(\t\x1a\x35\n\x13ResourceLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc8\x05\n\x12GiglResourceConfig\x12$\n\x1ashared_resource_config_uri\x18\x01 \x01(\tH\x00\x12N\n\x16shared_resource_config\x18\x02 \x01(\x0b\x32,.snapchat.research.gbml.SharedResourceConfigH\x00\x12K\n\x13preprocessor_config\x18\x0c \x01(\x0b\x32..snapchat.research.gbml.DataPreprocessorConfig\x12L\n\x17subgraph_sampler_config\x18\r \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12K\n\x16split_generator_config\x18\x0e \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12L\n\x0etrainer_config\x18\x0f \x01(\x0b\x32\x30.snapchat.research.gbml.DistributedTrainerConfigB\x02\x18\x01\x12M\n\x11inferencer_config\x18\x10 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigB\x02\x18\x01\x12N\n\x17trainer_resource_config\x18\x11 \x01(\x0b\x32-.snapchat.research.gbml.TrainerResourceConfig\x12T\n\x1ainferencer_resource_config\x18\x12 \x01(\x0b\x32\x30.snapchat.research.gbml.InferencerResourceConfigB\x11\n\x0fshared_resource*\xf3\x01\n\tComponent\x12\x15\n\x11\x43omponent_Unknown\x10\x00\x12\x1e\n\x1a\x43omponent_Config_Validator\x10\x01\x12\x1e\n\x1a\x43omponent_Config_Populator\x10\x02\x12\x1f\n\x1b\x43omponent_Data_Preprocessor\x10\x03\x12\x1e\n\x1a\x43omponent_Subgraph_Sampler\x10\x04\x12\x1d\n\x19\x43omponent_Split_Generator\x10\x05\x12\x15\n\x11\x43omponent_Trainer\x10\x06\x12\x18\n\x14\x43omponent_Inferencer\x10\x07\x62\x06proto3') +DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n1snapchat/research/gbml/gigl_resource_config.proto\x12\x16snapchat.research.gbml\"Y\n\x13SparkResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x16\n\x0enum_local_ssds\x18\x02 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x03 \x01(\r\"r\n\x16\x44\x61taflowResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\x12\x17\n\x0fmax_num_workers\x18\x02 \x01(\r\x12\x14\n\x0cmachine_type\x18\x03 \x01(\t\x12\x14\n\x0c\x64isk_size_gb\x18\x04 \x01(\r\"\xbc\x01\n\x16\x44\x61taPreprocessorConfig\x12P\n\x18\x65\x64ge_preprocessor_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\x12P\n\x18node_preprocessor_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\"h\n\x15VertexAiTrainerConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\"z\n\x10KFPTrainerConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\")\n\x12LocalTrainerConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\x97\x01\n\x16VertexAiResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12\x1b\n\x13gcp_region_override\x18\x06 \x01(\t\"{\n\x11KFPResourceConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\"*\n\x13LocalResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\xaa\x01\n\x18VertexAiGraphStoreConfig\x12H\n\x10graph_store_pool\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\x12\x44\n\x0c\x63ompute_pool\x18\x02 
\x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\"\x93\x02\n\x18\x44istributedTrainerConfig\x12Q\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32-.snapchat.research.gbml.VertexAiTrainerConfigH\x00\x12\x46\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32(.snapchat.research.gbml.KFPTrainerConfigH\x00\x12J\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32*.snapchat.research.gbml.LocalTrainerConfigH\x00\x42\x10\n\x0etrainer_config\"\xf5\x02\n\x15TrainerResourceConfig\x12R\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12G\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32).snapchat.research.gbml.KFPResourceConfigH\x00\x12K\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12`\n$vertex_ai_graph_store_trainer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x42\x10\n\x0etrainer_config\"\x91\x03\n\x18InferencerResourceConfig\x12U\n\x1bvertex_ai_inferencer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12T\n\x1a\x64\x61taflow_inferencer_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigH\x00\x12N\n\x17local_inferencer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12\x63\n\'vertex_ai_graph_store_inferencer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x42\x13\n\x11inferencer_config\"\xa3\x04\n\x14SharedResourceConfig\x12Y\n\x0fresource_labels\x18\x01 \x03(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry\x12_\n\x15\x63ommon_compute_config\x18\x02 \x01(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig\x1a\x97\x02\n\x13\x43ommonComputeConfig\x12\x0f\n\x07project\x18\x01 \x01(\t\x12\x0e\n\x06region\x18\x02 \x01(\t\x12\x1a\n\x12temp_assets_bucket\x18\x03 \x01(\t\x12#\n\x1btemp_regional_assets_bucket\x18\x04 \x01(\t\x12\x1a\n\x12perm_assets_bucket\x18\x05 
\x01(\t\x12#\n\x1btemp_assets_bq_dataset_name\x18\x06 \x01(\t\x12!\n\x19\x65mbedding_bq_dataset_name\x18\x07 \x01(\t\x12!\n\x19gcp_service_account_email\x18\x08 \x01(\t\x12\x17\n\x0f\x64\x61taflow_runner\x18\x0b \x01(\t\x1a\x35\n\x13ResourceLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc8\x05\n\x12GiglResourceConfig\x12$\n\x1ashared_resource_config_uri\x18\x01 \x01(\tH\x00\x12N\n\x16shared_resource_config\x18\x02 \x01(\x0b\x32,.snapchat.research.gbml.SharedResourceConfigH\x00\x12K\n\x13preprocessor_config\x18\x0c \x01(\x0b\x32..snapchat.research.gbml.DataPreprocessorConfig\x12L\n\x17subgraph_sampler_config\x18\r \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12K\n\x16split_generator_config\x18\x0e \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12L\n\x0etrainer_config\x18\x0f \x01(\x0b\x32\x30.snapchat.research.gbml.DistributedTrainerConfigB\x02\x18\x01\x12M\n\x11inferencer_config\x18\x10 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigB\x02\x18\x01\x12N\n\x17trainer_resource_config\x18\x11 \x01(\x0b\x32-.snapchat.research.gbml.TrainerResourceConfig\x12T\n\x1ainferencer_resource_config\x18\x12 \x01(\x0b\x32\x30.snapchat.research.gbml.InferencerResourceConfigB\x11\n\x0fshared_resource*\xf3\x01\n\tComponent\x12\x15\n\x11\x43omponent_Unknown\x10\x00\x12\x1e\n\x1a\x43omponent_Config_Validator\x10\x01\x12\x1e\n\x1a\x43omponent_Config_Populator\x10\x02\x12\x1f\n\x1b\x43omponent_Data_Preprocessor\x10\x03\x12\x1e\n\x1a\x43omponent_Subgraph_Sampler\x10\x04\x12\x1d\n\x19\x43omponent_Split_Generator\x10\x05\x12\x15\n\x11\x43omponent_Trainer\x10\x06\x12\x18\n\x14\x43omponent_Inferencer\x10\x07\x62\x06proto3') _COMPONENT = DESCRIPTOR.enum_types_by_name['Component'] Component = enum_type_wrapper.EnumTypeWrapper(_COMPONENT) @@ -38,7 +38,7 @@ _VERTEXAIRESOURCECONFIG = DESCRIPTOR.message_types_by_name['VertexAiResourceConfig'] _KFPRESOURCECONFIG = DESCRIPTOR.message_types_by_name['KFPResourceConfig'] 
_LOCALRESOURCECONFIG = DESCRIPTOR.message_types_by_name['LocalResourceConfig'] -_VERTEXAIMULTIPOOLCONFIG = DESCRIPTOR.message_types_by_name['VertexAiMultiPoolConfig'] +_VERTEXAIGRAPHSTORECONFIG = DESCRIPTOR.message_types_by_name['VertexAiGraphStoreConfig'] _DISTRIBUTEDTRAINERCONFIG = DESCRIPTOR.message_types_by_name['DistributedTrainerConfig'] _TRAINERRESOURCECONFIG = DESCRIPTOR.message_types_by_name['TrainerResourceConfig'] _INFERENCERRESOURCECONFIG = DESCRIPTOR.message_types_by_name['InferencerResourceConfig'] @@ -109,12 +109,12 @@ }) _sym_db.RegisterMessage(LocalResourceConfig) -VertexAiMultiPoolConfig = _reflection.GeneratedProtocolMessageType('VertexAiMultiPoolConfig', (_message.Message,), { - 'DESCRIPTOR' : _VERTEXAIMULTIPOOLCONFIG, +VertexAiGraphStoreConfig = _reflection.GeneratedProtocolMessageType('VertexAiGraphStoreConfig', (_message.Message,), { + 'DESCRIPTOR' : _VERTEXAIGRAPHSTORECONFIG, '__module__' : 'snapchat.research.gbml.gigl_resource_config_pb2' - # @@protoc_insertion_point(class_scope:snapchat.research.gbml.VertexAiMultiPoolConfig) + # @@protoc_insertion_point(class_scope:snapchat.research.gbml.VertexAiGraphStoreConfig) }) -_sym_db.RegisterMessage(VertexAiMultiPoolConfig) +_sym_db.RegisterMessage(VertexAiGraphStoreConfig) DistributedTrainerConfig = _reflection.GeneratedProtocolMessageType('DistributedTrainerConfig', (_message.Message,), { 'DESCRIPTOR' : _DISTRIBUTEDTRAINERCONFIG, @@ -176,8 +176,8 @@ _GIGLRESOURCECONFIG.fields_by_name['trainer_config']._serialized_options = b'\030\001' _GIGLRESOURCECONFIG.fields_by_name['inferencer_config']._options = None _GIGLRESOURCECONFIG.fields_by_name['inferencer_config']._serialized_options = b'\030\001' - _COMPONENT._serialized_start=3481 - _COMPONENT._serialized_end=3724 + _COMPONENT._serialized_start=3568 + _COMPONENT._serialized_end=3811 _SPARKRESOURCECONFIG._serialized_start=77 _SPARKRESOURCECONFIG._serialized_end=166 _DATAFLOWRESOURCECONFIG._serialized_start=168 @@ -196,20 +196,20 @@ 
_KFPRESOURCECONFIG._serialized_end=1025 _LOCALRESOURCECONFIG._serialized_start=1027 _LOCALRESOURCECONFIG._serialized_end=1069 - _VERTEXAIMULTIPOOLCONFIG._serialized_start=1071 - _VERTEXAIMULTIPOOLCONFIG._serialized_end=1159 - _DISTRIBUTEDTRAINERCONFIG._serialized_start=1162 - _DISTRIBUTEDTRAINERCONFIG._serialized_end=1437 - _TRAINERRESOURCECONFIG._serialized_start=1440 - _TRAINERRESOURCECONFIG._serialized_end=1811 - _INFERENCERRESOURCECONFIG._serialized_start=1814 - _INFERENCERRESOURCECONFIG._serialized_end=2213 - _SHAREDRESOURCECONFIG._serialized_start=2216 - _SHAREDRESOURCECONFIG._serialized_end=2763 - _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_start=2429 - _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_end=2708 - _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_start=2710 - _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_end=2763 - _GIGLRESOURCECONFIG._serialized_start=2766 - _GIGLRESOURCECONFIG._serialized_end=3478 + _VERTEXAIGRAPHSTORECONFIG._serialized_start=1072 + _VERTEXAIGRAPHSTORECONFIG._serialized_end=1242 + _DISTRIBUTEDTRAINERCONFIG._serialized_start=1245 + _DISTRIBUTEDTRAINERCONFIG._serialized_end=1520 + _TRAINERRESOURCECONFIG._serialized_start=1523 + _TRAINERRESOURCECONFIG._serialized_end=1896 + _INFERENCERRESOURCECONFIG._serialized_start=1899 + _INFERENCERRESOURCECONFIG._serialized_end=2300 + _SHAREDRESOURCECONFIG._serialized_start=2303 + _SHAREDRESOURCECONFIG._serialized_end=2850 + _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_start=2516 + _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_end=2795 + _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_start=2797 + _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_end=2850 + _GIGLRESOURCECONFIG._serialized_start=2853 + _GIGLRESOURCECONFIG._serialized_end=3565 # @@protoc_insertion_point(module_scope) diff --git a/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi b/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi index 
bd5622a76..c5bb6e842 100644 --- a/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi +++ b/python/snapchat/research/gbml/gigl_resource_config_pb2.pyi @@ -299,55 +299,33 @@ class LocalResourceConfig(google.protobuf.message.Message): global___LocalResourceConfig = LocalResourceConfig -class VertexAiMultiPoolConfig(google.protobuf.message.Message): - """Configuration for Multipool Vertex AI jobs. +class VertexAiGraphStoreConfig(google.protobuf.message.Message): + """Configuration for lauching Vertex AI clusters with both graph store and compute pools + Under the hood, this uses Vertex AI Multi-Pool Training See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. - NOTE: The first worker pool will be split into the primary replica and "Workers". - For example: - pools = [ - { - "machine_type": "n1-standard-8", - "num_replicas": 16 - }, - { - "machine_type": "n1-standard-8", - "gpu_type": "nvidia-tesla-v100", - "gpu_limit": 1, - "num_replicas": 16 - } - ] - Will have the Primary be: {} - { - "machine_type": "n1-standard-8", - "num_replicas": 1 - } - And the Workers be: - { - "machine_type": "n1-standard-8", - "num_replicas": 15 - } - And the parameter servers be: - { - "machine_type": "n1-standard-8", - "gpu_type": "nvidia-tesla-v100", - "gpu_limit": 1, - "num_replicas": 16 - } + This cluster setup should be used when you want store your graph on separate machines from the compute machines + e.g. you can get lots of big memory machines and separate gpu machines individually, + but getting lots of gpu machines with lots of memory is challenging. """ DESCRIPTOR: google.protobuf.descriptor.Descriptor - POOLS_FIELD_NUMBER: builtins.int + GRAPH_STORE_POOL_FIELD_NUMBER: builtins.int + COMPUTE_POOL_FIELD_NUMBER: builtins.int @property - def pools(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___VertexAiResourceConfig]: ... + def graph_store_pool(self) -> global___VertexAiResourceConfig: ... 
+ @property + def compute_pool(self) -> global___VertexAiResourceConfig: ... def __init__( self, *, - pools: collections.abc.Iterable[global___VertexAiResourceConfig] | None = ..., + graph_store_pool: global___VertexAiResourceConfig | None = ..., + compute_pool: global___VertexAiResourceConfig | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["pools", b"pools"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["compute_pool", b"compute_pool", "graph_store_pool", b"graph_store_pool"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["compute_pool", b"compute_pool", "graph_store_pool", b"graph_store_pool"]) -> None: ... -global___VertexAiMultiPoolConfig = VertexAiMultiPoolConfig +global___VertexAiGraphStoreConfig = VertexAiGraphStoreConfig class DistributedTrainerConfig(google.protobuf.message.Message): """(deprecated) @@ -386,7 +364,7 @@ class TrainerResourceConfig(google.protobuf.message.Message): VERTEX_AI_TRAINER_CONFIG_FIELD_NUMBER: builtins.int KFP_TRAINER_CONFIG_FIELD_NUMBER: builtins.int LOCAL_TRAINER_CONFIG_FIELD_NUMBER: builtins.int - VERTEX_AI_MULTI_POOL_TRAINER_CONFIG_FIELD_NUMBER: builtins.int + VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG_FIELD_NUMBER: builtins.int @property def vertex_ai_trainer_config(self) -> global___VertexAiResourceConfig: ... @property @@ -394,18 +372,18 @@ class TrainerResourceConfig(google.protobuf.message.Message): @property def local_trainer_config(self) -> global___LocalResourceConfig: ... @property - def vertex_ai_multi_pool_trainer_config(self) -> global___VertexAiMultiPoolConfig: ... + def vertex_ai_graph_store_trainer_config(self) -> global___VertexAiGraphStoreConfig: ... 
def __init__( self, *, vertex_ai_trainer_config: global___VertexAiResourceConfig | None = ..., kfp_trainer_config: global___KFPResourceConfig | None = ..., local_trainer_config: global___LocalResourceConfig | None = ..., - vertex_ai_multi_pool_trainer_config: global___VertexAiMultiPoolConfig | None = ..., + vertex_ai_graph_store_trainer_config: global___VertexAiGraphStoreConfig | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_multi_pool_trainer_config", b"vertex_ai_multi_pool_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_multi_pool_trainer_config", b"vertex_ai_multi_pool_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["trainer_config", b"trainer_config"]) -> typing_extensions.Literal["vertex_ai_trainer_config", "kfp_trainer_config", "local_trainer_config", "vertex_ai_multi_pool_trainer_config"] | None: ... + def HasField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_graph_store_trainer_config", b"vertex_ai_graph_store_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> builtins.bool: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_graph_store_trainer_config", b"vertex_ai_graph_store_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["trainer_config", b"trainer_config"]) -> typing_extensions.Literal["vertex_ai_trainer_config", "kfp_trainer_config", "local_trainer_config", "vertex_ai_graph_store_trainer_config"] | None: ... global___TrainerResourceConfig = TrainerResourceConfig @@ -417,7 +395,7 @@ class InferencerResourceConfig(google.protobuf.message.Message): VERTEX_AI_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int LOCAL_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int - VERTEX_AI_MULTI_POOL_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int + VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int @property def vertex_ai_inferencer_config(self) -> global___VertexAiResourceConfig: ... @property @@ -425,18 +403,18 @@ class InferencerResourceConfig(google.protobuf.message.Message): @property def local_inferencer_config(self) -> global___LocalResourceConfig: ... @property - def vertex_ai_multi_pool_inferencer_config(self) -> global___VertexAiMultiPoolConfig: ... + def vertex_ai_graph_store_inferencer_config(self) -> global___VertexAiGraphStoreConfig: ... def __init__( self, *, vertex_ai_inferencer_config: global___VertexAiResourceConfig | None = ..., dataflow_inferencer_config: global___DataflowResourceConfig | None = ..., local_inferencer_config: global___LocalResourceConfig | None = ..., - vertex_ai_multi_pool_inferencer_config: global___VertexAiMultiPoolConfig | None = ..., + vertex_ai_graph_store_inferencer_config: global___VertexAiGraphStoreConfig | None = ..., ) -> None: ... 
- def HasField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config", "vertex_ai_multi_pool_inferencer_config", b"vertex_ai_multi_pool_inferencer_config"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config", "vertex_ai_multi_pool_inferencer_config", b"vertex_ai_multi_pool_inferencer_config"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["inferencer_config", b"inferencer_config"]) -> typing_extensions.Literal["vertex_ai_inferencer_config", "dataflow_inferencer_config", "local_inferencer_config", "vertex_ai_multi_pool_inferencer_config"] | None: ... + def HasField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_graph_store_inferencer_config", b"vertex_ai_graph_store_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_graph_store_inferencer_config", b"vertex_ai_graph_store_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config"]) -> None: ... 
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["inferencer_config", b"inferencer_config"]) -> typing_extensions.Literal["vertex_ai_inferencer_config", "dataflow_inferencer_config", "local_inferencer_config", "vertex_ai_graph_store_inferencer_config"] | None: ... global___InferencerResourceConfig = InferencerResourceConfig diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala index c5e49c887..24306c158 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala @@ -18,7 +18,7 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig, snapchat.research.gbml.gigl_resource_config.KFPResourceConfig, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig, - snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig, + snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig, snapchat.research.gbml.gigl_resource_config.DistributedTrainerConfig, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig, @@ -54,66 +54,68 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { VJlcXVlc3RSDW1lbW9yeVJlcXVlc3QSJwoIZ3B1X3R5cGUYAyABKAlCDOI/CRIHZ3B1VHlwZVIHZ3B1VHlwZRIqCglncHVfbGlta XQYBCABKA1CDeI/ChIIZ3B1TGltaXRSCGdwdUxpbWl0EjMKDG51bV9yZXBsaWNhcxgFIAEoDUIQ4j8NEgtudW1SZXBsaWNhc1ILb nVtUmVwbGljYXMiRwoTTG9jYWxSZXNvdXJjZUNvbmZpZxIwCgtudW1fd29ya2VycxgBIAEoDUIP4j8MEgpudW1Xb3JrZXJzUgpud - W1Xb3JrZXJzImsKF1ZlcnRleEFpTXVsdGlQb29sQ29uZmlnElAKBXBvb2xzGAEgAygLMi4uc25hcGNoYXQucmVzZWFyY2guZ2Jtb - 
C5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQgriPwcSBXBvb2xzUgVwb29scyKdAwoYRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnEoQBC - hh2ZXJ0ZXhfYWlfdHJhaW5lcl9jb25maWcYASABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpVHJhaW5lckNvb - mZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEm8KEmtmcF90cmFpbmVyX2Nvb - mZpZxgCIAEoCzIoLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS0ZQVHJhaW5lckNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnS - ABSEGtmcFRyYWluZXJDb25maWcSdwoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sL - kxvY2FsVHJhaW5lckNvbmZpZ0IX4j8UEhJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ29uZmlnQhAKDnRyYWluZ - XJfY29uZmlnIsMEChVUcmFpbmVyUmVzb3VyY2VDb25maWcShQEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzIuLnNuY - XBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVd - mVydGV4QWlUcmFpbmVyQ29uZmlnEnAKEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIpLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS - 0ZQUmVzb3VyY2VDb25maWdCFeI/EhIQa2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEngKFGxvY2FsX3RyYWluZ - XJfY29uZmlnGAMgASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhfiPxQSEmxvY2FsVHJha - W5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZXJDb25maWcSowEKI3ZlcnRleF9haV9tdWx0aV9wb29sX3RyYWluZXJfY29uZmlnGAQgA - SgLMi8uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaU11bHRpUG9vbENvbmZpZ0Ij4j8gEh52ZXJ0ZXhBaU11bHRpUG9vb - FRyYWluZXJDb25maWdIAFIedmVydGV4QWlNdWx0aVBvb2xUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIoMFChhJbmZlc - mVuY2VyUmVzb3VyY2VDb25maWcSjgEKG3ZlcnRleF9haV9pbmZlcmVuY2VyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2Vhc - mNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Id4j8aEhh2ZXJ0ZXhBaUluZmVyZW5jZXJDb25maWdIAFIYdmVydGV4QWlJb - mZlcmVuY2VyQ29uZmlnEo0BChpkYXRhZmxvd19pbmZlcmVuY2VyX2NvbmZpZxgCIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdib - WwuRGF0YWZsb3dSZXNvdXJjZUNvbmZpZ0Id4j8aEhhkYXRhZmxvd0luZmVyZW5jZXJDb25maWdIAFIYZGF0YWZsb3dJbmZlcmVuY - 2VyQ29uZmlnEoEBChdsb2NhbF9pbmZlcmVuY2VyX2NvbmZpZxgDIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxSZ - 
XNvdXJjZUNvbmZpZ0Ia4j8XEhVsb2NhbEluZmVyZW5jZXJDb25maWdIAFIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnEqwBCiZ2ZXJ0Z - XhfYWlfbXVsdGlfcG9vbF9pbmZlcmVuY2VyX2NvbmZpZxgEIAEoCzIvLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlNd - Wx0aVBvb2xDb25maWdCJuI/IxIhdmVydGV4QWlNdWx0aVBvb2xJbmZlcmVuY2VyQ29uZmlnSABSIXZlcnRleEFpTXVsdGlQb29sS - W5mZXJlbmNlckNvbmZpZ0ITChFpbmZlcmVuY2VyX2NvbmZpZyKXCAoUU2hhcmVkUmVzb3VyY2VDb25maWcSfgoPcmVzb3VyY2Vfb - GFiZWxzGAEgAygLMkAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZy5SZXNvdXJjZUxhYmVsc0Vud - HJ5QhPiPxASDnJlc291cmNlTGFiZWxzUg5yZXNvdXJjZUxhYmVscxKOAQoVY29tbW9uX2NvbXB1dGVfY29uZmlnGAIgASgLMkAuc - 25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZy5Db21tb25Db21wdXRlQ29uZmlnQhjiPxUSE2NvbW1vb - kNvbXB1dGVDb25maWdSE2NvbW1vbkNvbXB1dGVDb25maWcalAUKE0NvbW1vbkNvbXB1dGVDb25maWcSJgoHcHJvamVjdBgBIAEoC - UIM4j8JEgdwcm9qZWN0Ugdwcm9qZWN0EiMKBnJlZ2lvbhgCIAEoCUIL4j8IEgZyZWdpb25SBnJlZ2lvbhJDChJ0ZW1wX2Fzc2V0c - 19idWNrZXQYAyABKAlCFeI/EhIQdGVtcEFzc2V0c0J1Y2tldFIQdGVtcEFzc2V0c0J1Y2tldBJcCht0ZW1wX3JlZ2lvbmFsX2Fzc - 2V0c19idWNrZXQYBCABKAlCHeI/GhIYdGVtcFJlZ2lvbmFsQXNzZXRzQnVja2V0Uhh0ZW1wUmVnaW9uYWxBc3NldHNCdWNrZXQSQ - woScGVybV9hc3NldHNfYnVja2V0GAUgASgJQhXiPxISEHBlcm1Bc3NldHNCdWNrZXRSEHBlcm1Bc3NldHNCdWNrZXQSWgobdGVtc - F9hc3NldHNfYnFfZGF0YXNldF9uYW1lGAYgASgJQhziPxkSF3RlbXBBc3NldHNCcURhdGFzZXROYW1lUhd0ZW1wQXNzZXRzQnFEY - XRhc2V0TmFtZRJWChllbWJlZGRpbmdfYnFfZGF0YXNldF9uYW1lGAcgASgJQhviPxgSFmVtYmVkZGluZ0JxRGF0YXNldE5hbWVSF - mVtYmVkZGluZ0JxRGF0YXNldE5hbWUSVgoZZ2NwX3NlcnZpY2VfYWNjb3VudF9lbWFpbBgIIAEoCUIb4j8YEhZnY3BTZXJ2aWNlQ - WNjb3VudEVtYWlsUhZnY3BTZXJ2aWNlQWNjb3VudEVtYWlsEjwKD2RhdGFmbG93X3J1bm5lchgLIAEoCUIT4j8QEg5kYXRhZmxvd - 1J1bm5lclIOZGF0YWZsb3dSdW5uZXIaVwoTUmVzb3VyY2VMYWJlbHNFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSI - AoFdmFsdWUYAiABKAlCCuI/BxIFdmFsdWVSBXZhbHVlOgI4ASL3CAoSR2lnbFJlc291cmNlQ29uZmlnElsKGnNoYXJlZF9yZXNvd - XJjZV9jb25maWdfdXJpGAEgASgJQhziPxkSF3NoYXJlZFJlc291cmNlQ29uZmlnVXJpSABSF3NoYXJlZFJlc291cmNlQ29uZmlnV - 
XJpEn8KFnNoYXJlZF9yZXNvdXJjZV9jb25maWcYAiABKAsyLC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlNoYXJlZFJlc291cmNlQ - 29uZmlnQhniPxYSFHNoYXJlZFJlc291cmNlQ29uZmlnSABSFHNoYXJlZFJlc291cmNlQ29uZmlnEngKE3ByZXByb2Nlc3Nvcl9jb - 25maWcYDCABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhdGFQcmVwcm9jZXNzb3JDb25maWdCF+I/FBIScHJlcHJvY2Vzc - 29yQ29uZmlnUhJwcmVwcm9jZXNzb3JDb25maWcSfwoXc3ViZ3JhcGhfc2FtcGxlcl9jb25maWcYDSABKAsyKy5zbmFwY2hhdC5yZ - XNlYXJjaC5nYm1sLlNwYXJrUmVzb3VyY2VDb25maWdCGuI/FxIVc3ViZ3JhcGhTYW1wbGVyQ29uZmlnUhVzdWJncmFwaFNhbXBsZ - XJDb25maWcSfAoWc3BsaXRfZ2VuZXJhdG9yX2NvbmZpZxgOIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU3BhcmtSZXNvd - XJjZUNvbmZpZ0IZ4j8WEhRzcGxpdEdlbmVyYXRvckNvbmZpZ1IUc3BsaXRHZW5lcmF0b3JDb25maWcSbQoOdHJhaW5lcl9jb25ma - WcYDyABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRpc3RyaWJ1dGVkVHJhaW5lckNvbmZpZ0IUGAHiPw8SDXRyYWluZXJDb - 25maWdSDXRyYWluZXJDb25maWcSdAoRaW5mZXJlbmNlcl9jb25maWcYECABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhd - GFmbG93UmVzb3VyY2VDb25maWdCFxgB4j8SEhBpbmZlcmVuY2VyQ29uZmlnUhBpbmZlcmVuY2VyQ29uZmlnEoEBChd0cmFpbmVyX - 3Jlc291cmNlX2NvbmZpZxgRIAEoCzItLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVHJhaW5lclJlc291cmNlQ29uZmlnQhriPxcSF - XRyYWluZXJSZXNvdXJjZUNvbmZpZ1IVdHJhaW5lclJlc291cmNlQ29uZmlnEo0BChppbmZlcmVuY2VyX3Jlc291cmNlX2NvbmZpZ - xgSIAEoCzIwLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuSW5mZXJlbmNlclJlc291cmNlQ29uZmlnQh3iPxoSGGluZmVyZW5jZXJSZ - XNvdXJjZUNvbmZpZ1IYaW5mZXJlbmNlclJlc291cmNlQ29uZmlnQhEKD3NoYXJlZF9yZXNvdXJjZSrjAwoJQ29tcG9uZW50Ei0KE - UNvbXBvbmVudF9Vbmtub3duEAAaFuI/ExIRQ29tcG9uZW50X1Vua25vd24SPwoaQ29tcG9uZW50X0NvbmZpZ19WYWxpZGF0b3IQA - Rof4j8cEhpDb21wb25lbnRfQ29uZmlnX1ZhbGlkYXRvchI/ChpDb21wb25lbnRfQ29uZmlnX1BvcHVsYXRvchACGh/iPxwSGkNvb - XBvbmVudF9Db25maWdfUG9wdWxhdG9yEkEKG0NvbXBvbmVudF9EYXRhX1ByZXByb2Nlc3NvchADGiDiPx0SG0NvbXBvbmVudF9EY - XRhX1ByZXByb2Nlc3NvchI/ChpDb21wb25lbnRfU3ViZ3JhcGhfU2FtcGxlchAEGh/iPxwSGkNvbXBvbmVudF9TdWJncmFwaF9TY - W1wbGVyEj0KGUNvbXBvbmVudF9TcGxpdF9HZW5lcmF0b3IQBRoe4j8bEhlDb21wb25lbnRfU3BsaXRfR2VuZXJhdG9yEi0KEUNvb - 
XBvbmVudF9UcmFpbmVyEAYaFuI/ExIRQ29tcG9uZW50X1RyYWluZXISMwoUQ29tcG9uZW50X0luZmVyZW5jZXIQBxoZ4j8WEhRDb - 21wb25lbnRfSW5mZXJlbmNlcmIGcHJvdG8z""" + W1Xb3JrZXJzIu4BChhWZXJ0ZXhBaUdyYXBoU3RvcmVDb25maWcSbQoQZ3JhcGhfc3RvcmVfcG9vbBgBIAEoCzIuLnNuYXBjaGF0L + nJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0IT4j8QEg5ncmFwaFN0b3JlUG9vbFIOZ3JhcGhTdG9yZVBvb2wSY + woMY29tcHV0ZV9wb29sGAIgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQhDiPw0SC + 2NvbXB1dGVQb29sUgtjb21wdXRlUG9vbCKdAwoYRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnEoQBChh2ZXJ0ZXhfYWlfdHJhaW5lc + l9jb25maWcYASABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpVHJhaW5lckNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBa + VRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEm8KEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIoLnNuYXBja + GF0LnJlc2VhcmNoLmdibWwuS0ZQVHJhaW5lckNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnSABSEGtmcFRyYWluZXJDb25ma + WcSdwoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsVHJhaW5lckNvbmZpZ + 0IX4j8UEhJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIscEChVUcmFpb + mVyUmVzb3VyY2VDb25maWcShQEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdib + WwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZ + mlnEnAKEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIpLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS0ZQUmVzb3VyY2VDb25maWdCF + eI/EhIQa2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEngKFGxvY2FsX3RyYWluZXJfY29uZmlnGAMgASgLMisuc + 25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhfiPxQSEmxvY2FsVHJhaW5lckNvbmZpZ0gAUhJsb2Nhb + FRyYWluZXJDb25maWcSpwEKJHZlcnRleF9haV9ncmFwaF9zdG9yZV90cmFpbmVyX2NvbmZpZxgEIAEoCzIwLnNuYXBjaGF0LnJlc + 2VhcmNoLmdibWwuVmVydGV4QWlHcmFwaFN0b3JlQ29uZmlnQiTiPyESH3ZlcnRleEFpR3JhcGhTdG9yZVRyYWluZXJDb25maWdIA + FIfdmVydGV4QWlHcmFwaFN0b3JlVHJhaW5lckNvbmZpZ0IQCg50cmFpbmVyX2NvbmZpZyKHBQoYSW5mZXJlbmNlclJlc291cmNlQ + 29uZmlnEo4BCht2ZXJ0ZXhfYWlfaW5mZXJlbmNlcl9jb25maWcYASABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRle + 
EFpUmVzb3VyY2VDb25maWdCHeI/GhIYdmVydGV4QWlJbmZlcmVuY2VyQ29uZmlnSABSGHZlcnRleEFpSW5mZXJlbmNlckNvbmZpZ + xKNAQoaZGF0YWZsb3dfaW5mZXJlbmNlcl9jb25maWcYAiABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhdGFmbG93UmVzb + 3VyY2VDb25maWdCHeI/GhIYZGF0YWZsb3dJbmZlcmVuY2VyQ29uZmlnSABSGGRhdGFmbG93SW5mZXJlbmNlckNvbmZpZxKBAQoXb + G9jYWxfaW5mZXJlbmNlcl9jb25maWcYAyABKAsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCG + uI/FxIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnSABSFWxvY2FsSW5mZXJlbmNlckNvbmZpZxKwAQondmVydGV4X2FpX2dyYXBoX3N0b + 3JlX2luZmVyZW5jZXJfY29uZmlnGAQgASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaUdyYXBoU3RvcmVDb25ma + WdCJ+I/JBIidmVydGV4QWlHcmFwaFN0b3JlSW5mZXJlbmNlckNvbmZpZ0gAUiJ2ZXJ0ZXhBaUdyYXBoU3RvcmVJbmZlcmVuY2VyQ + 29uZmlnQhMKEWluZmVyZW5jZXJfY29uZmlnIpcIChRTaGFyZWRSZXNvdXJjZUNvbmZpZxJ+Cg9yZXNvdXJjZV9sYWJlbHMYASADK + AsyQC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlNoYXJlZFJlc291cmNlQ29uZmlnLlJlc291cmNlTGFiZWxzRW50cnlCE+I/EBIOc + mVzb3VyY2VMYWJlbHNSDnJlc291cmNlTGFiZWxzEo4BChVjb21tb25fY29tcHV0ZV9jb25maWcYAiABKAsyQC5zbmFwY2hhdC5yZ + XNlYXJjaC5nYm1sLlNoYXJlZFJlc291cmNlQ29uZmlnLkNvbW1vbkNvbXB1dGVDb25maWdCGOI/FRITY29tbW9uQ29tcHV0ZUNvb + mZpZ1ITY29tbW9uQ29tcHV0ZUNvbmZpZxqUBQoTQ29tbW9uQ29tcHV0ZUNvbmZpZxImCgdwcm9qZWN0GAEgASgJQgziPwkSB3Byb + 2plY3RSB3Byb2plY3QSIwoGcmVnaW9uGAIgASgJQgviPwgSBnJlZ2lvblIGcmVnaW9uEkMKEnRlbXBfYXNzZXRzX2J1Y2tldBgDI + AEoCUIV4j8SEhB0ZW1wQXNzZXRzQnVja2V0UhB0ZW1wQXNzZXRzQnVja2V0ElwKG3RlbXBfcmVnaW9uYWxfYXNzZXRzX2J1Y2tld + BgEIAEoCUId4j8aEhh0ZW1wUmVnaW9uYWxBc3NldHNCdWNrZXRSGHRlbXBSZWdpb25hbEFzc2V0c0J1Y2tldBJDChJwZXJtX2Fzc + 2V0c19idWNrZXQYBSABKAlCFeI/EhIQcGVybUFzc2V0c0J1Y2tldFIQcGVybUFzc2V0c0J1Y2tldBJaCht0ZW1wX2Fzc2V0c19ic + V9kYXRhc2V0X25hbWUYBiABKAlCHOI/GRIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWVSF3RlbXBBc3NldHNCcURhdGFzZXROYW1lE + lYKGWVtYmVkZGluZ19icV9kYXRhc2V0X25hbWUYByABKAlCG+I/GBIWZW1iZWRkaW5nQnFEYXRhc2V0TmFtZVIWZW1iZWRkaW5nQ + nFEYXRhc2V0TmFtZRJWChlnY3Bfc2VydmljZV9hY2NvdW50X2VtYWlsGAggASgJQhviPxgSFmdjcFNlcnZpY2VBY2NvdW50RW1ha + 
WxSFmdjcFNlcnZpY2VBY2NvdW50RW1haWwSPAoPZGF0YWZsb3dfcnVubmVyGAsgASgJQhPiPxASDmRhdGFmbG93UnVubmVyUg5kY + XRhZmxvd1J1bm5lchpXChNSZXNvdXJjZUxhYmVsc0VudHJ5EhoKA2tleRgBIAEoCUII4j8FEgNrZXlSA2tleRIgCgV2YWx1ZRgCI + AEoCUIK4j8HEgV2YWx1ZVIFdmFsdWU6AjgBIvcIChJHaWdsUmVzb3VyY2VDb25maWcSWwoac2hhcmVkX3Jlc291cmNlX2NvbmZpZ + 191cmkYASABKAlCHOI/GRIXc2hhcmVkUmVzb3VyY2VDb25maWdVcmlIAFIXc2hhcmVkUmVzb3VyY2VDb25maWdVcmkSfwoWc2hhc + mVkX3Jlc291cmNlX2NvbmZpZxgCIAEoCzIsLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWdCGeI/F + hIUc2hhcmVkUmVzb3VyY2VDb25maWdIAFIUc2hhcmVkUmVzb3VyY2VDb25maWcSeAoTcHJlcHJvY2Vzc29yX2NvbmZpZxgMIAEoC + zIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YVByZXByb2Nlc3NvckNvbmZpZ0IX4j8UEhJwcmVwcm9jZXNzb3JDb25maWdSE + nByZXByb2Nlc3NvckNvbmZpZxJ/ChdzdWJncmFwaF9zYW1wbGVyX2NvbmZpZxgNIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdib + WwuU3BhcmtSZXNvdXJjZUNvbmZpZ0Ia4j8XEhVzdWJncmFwaFNhbXBsZXJDb25maWdSFXN1YmdyYXBoU2FtcGxlckNvbmZpZxJ8C + hZzcGxpdF9nZW5lcmF0b3JfY29uZmlnGA4gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TcGFya1Jlc291cmNlQ29uZmlnQ + hniPxYSFHNwbGl0R2VuZXJhdG9yQ29uZmlnUhRzcGxpdEdlbmVyYXRvckNvbmZpZxJtCg50cmFpbmVyX2NvbmZpZxgPIAEoCzIwL + nNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnQhQYAeI/DxINdHJhaW5lckNvbmZpZ1INdHJha + W5lckNvbmZpZxJ0ChFpbmZlcmVuY2VyX2NvbmZpZxgQIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YWZsb3dSZXNvd + XJjZUNvbmZpZ0IXGAHiPxISEGluZmVyZW5jZXJDb25maWdSEGluZmVyZW5jZXJDb25maWcSgQEKF3RyYWluZXJfcmVzb3VyY2VfY + 29uZmlnGBEgASgLMi0uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5UcmFpbmVyUmVzb3VyY2VDb25maWdCGuI/FxIVdHJhaW5lclJlc + 291cmNlQ29uZmlnUhV0cmFpbmVyUmVzb3VyY2VDb25maWcSjQEKGmluZmVyZW5jZXJfcmVzb3VyY2VfY29uZmlnGBIgASgLMjAuc + 25hcGNoYXQucmVzZWFyY2guZ2JtbC5JbmZlcmVuY2VyUmVzb3VyY2VDb25maWdCHeI/GhIYaW5mZXJlbmNlclJlc291cmNlQ29uZ + mlnUhhpbmZlcmVuY2VyUmVzb3VyY2VDb25maWdCEQoPc2hhcmVkX3Jlc291cmNlKuMDCglDb21wb25lbnQSLQoRQ29tcG9uZW50X + 1Vua25vd24QABoW4j8TEhFDb21wb25lbnRfVW5rbm93bhI/ChpDb21wb25lbnRfQ29uZmlnX1ZhbGlkYXRvchABGh/iPxwSGkNvb + 
XBvbmVudF9Db25maWdfVmFsaWRhdG9yEj8KGkNvbXBvbmVudF9Db25maWdfUG9wdWxhdG9yEAIaH+I/HBIaQ29tcG9uZW50X0Nvb + mZpZ19Qb3B1bGF0b3ISQQobQ29tcG9uZW50X0RhdGFfUHJlcHJvY2Vzc29yEAMaIOI/HRIbQ29tcG9uZW50X0RhdGFfUHJlcHJvY + 2Vzc29yEj8KGkNvbXBvbmVudF9TdWJncmFwaF9TYW1wbGVyEAQaH+I/HBIaQ29tcG9uZW50X1N1YmdyYXBoX1NhbXBsZXISPQoZQ + 29tcG9uZW50X1NwbGl0X0dlbmVyYXRvchAFGh7iPxsSGUNvbXBvbmVudF9TcGxpdF9HZW5lcmF0b3ISLQoRQ29tcG9uZW50X1RyY + WluZXIQBhoW4j8TEhFDb21wb25lbnRfVHJhaW5lchIzChRDb21wb25lbnRfSW5mZXJlbmNlchAHGhniPxYSFENvbXBvbmVudF9Jb + mZlcmVuY2VyYgZwcm90bzM=""" ).mkString) lazy val scalaDescriptor: _root_.scalapb.descriptors.FileDescriptor = { val scalaProto = com.google.protobuf.descriptor.FileDescriptorProto.parseFrom(ProtoBytes) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala index e222970eb..60315338f 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala @@ -28,8 +28,8 @@ final case class InferencerResourceConfig( val __value = inferencerConfig.localInferencerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; - if (inferencerConfig.vertexAiMultiPoolInferencerConfig.isDefined) { - val __value = inferencerConfig.vertexAiMultiPoolInferencerConfig.get + if (inferencerConfig.vertexAiGraphStoreInferencerConfig.isDefined) { + val __value = inferencerConfig.vertexAiGraphStoreInferencerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; __size += unknownFields.serializedSize @@ -63,7 +63,7 @@ final case class InferencerResourceConfig( 
_output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; - inferencerConfig.vertexAiMultiPoolInferencerConfig.foreach { __v => + inferencerConfig.vertexAiGraphStoreInferencerConfig.foreach { __v => val __m = __v _output__.writeTag(4, 2) _output__.writeUInt32NoTag(__m.serializedSize) @@ -77,8 +77,8 @@ final case class InferencerResourceConfig( def withDataflowInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(__v)) def getLocalInferencerConfig: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig = inferencerConfig.localInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.LocalResourceConfig.defaultInstance) def withLocalInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__v)) - def getVertexAiMultiPoolInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = inferencerConfig.vertexAiMultiPoolInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.defaultInstance) - def withVertexAiMultiPoolInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(__v)) + def getVertexAiGraphStoreInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = inferencerConfig.vertexAiGraphStoreInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.defaultInstance) + def withVertexAiGraphStoreInferencerConfig(__v: 
snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(__v)) def clearInferencerConfig: InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) def withInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig): InferencerResourceConfig = copy(inferencerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -88,7 +88,7 @@ final case class InferencerResourceConfig( case 1 => inferencerConfig.vertexAiInferencerConfig.orNull case 2 => inferencerConfig.dataflowInferencerConfig.orNull case 3 => inferencerConfig.localInferencerConfig.orNull - case 4 => inferencerConfig.vertexAiMultiPoolInferencerConfig.orNull + case 4 => inferencerConfig.vertexAiGraphStoreInferencerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -97,7 +97,7 @@ final case class InferencerResourceConfig( case 1 => inferencerConfig.vertexAiInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => inferencerConfig.dataflowInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => inferencerConfig.localInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) - case 4 => inferencerConfig.vertexAiMultiPoolInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 4 => inferencerConfig.vertexAiGraphStoreInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -122,7 +122,7 @@ object InferencerResourceConfig extends 
scalapb.GeneratedMessageCompanion[snapch case 26 => __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__inferencerConfig.localInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 34 => - __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(__inferencerConfig.vertexAiMultiPoolInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(__inferencerConfig.vertexAiGraphStoreInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -142,7 +142,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch inferencerConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(_)) 
.orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(_))) - .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") @@ -155,7 +155,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig case 2 => __out = snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig case 3 => __out = 
snapchat.research.gbml.gigl_resource_config.LocalResourceConfig - case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig } __out } @@ -170,11 +170,11 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch def isVertexAiInferencerConfig: _root_.scala.Boolean = false def isDataflowInferencerConfig: _root_.scala.Boolean = false def isLocalInferencerConfig: _root_.scala.Boolean = false - def isVertexAiMultiPoolInferencerConfig: _root_.scala.Boolean = false + def isVertexAiGraphStoreInferencerConfig: _root_.scala.Boolean = false def vertexAiInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def dataflowInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = _root_.scala.None def localInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None - def vertexAiMultiPoolInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scala.None + def vertexAiGraphStoreInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scala.None } object InferencerConfig { @SerialVersionUID(0L) @@ -208,10 +208,10 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch override def number: _root_.scala.Int = 3 } @SerialVersionUID(0L) - final case class VertexAiMultiPoolInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { - type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig - override def isVertexAiMultiPoolInferencerConfig: _root_.scala.Boolean = true 
- override def vertexAiMultiPoolInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = Some(value) + final case class VertexAiGraphStoreInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + override def isVertexAiGraphStoreInferencerConfig: _root_.scala.Boolean = true + override def vertexAiGraphStoreInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = Some(value) override def number: _root_.scala.Int = 4 } } @@ -219,13 +219,13 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch def vertexAiInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(f_))) def dataflowInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = field(_.getDataflowInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(f_))) def localInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(f_))) - def vertexAiMultiPoolInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = 
field(_.getVertexAiMultiPoolInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(f_))) + def vertexAiGraphStoreInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = field(_.getVertexAiGraphStoreInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(f_))) def inferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig] = field(_.inferencerConfig)((c_, f_) => c_.copy(inferencerConfig = f_)) } final val VERTEX_AI_INFERENCER_CONFIG_FIELD_NUMBER = 1 final val DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_INFERENCER_CONFIG_FIELD_NUMBER = 3 - final val VERTEX_AI_MULTI_POOL_INFERENCER_CONFIG_FIELD_NUMBER = 4 + final val VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG_FIELD_NUMBER = 4 def of( inferencerConfig: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig( diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala index 8a72b5e87..4d6dbeaf5 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala @@ -28,8 +28,8 @@ final case class TrainerResourceConfig( val __value = trainerConfig.localTrainerConfig.get __size += 1 + 
_root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; - if (trainerConfig.vertexAiMultiPoolTrainerConfig.isDefined) { - val __value = trainerConfig.vertexAiMultiPoolTrainerConfig.get + if (trainerConfig.vertexAiGraphStoreTrainerConfig.isDefined) { + val __value = trainerConfig.vertexAiGraphStoreTrainerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; __size += unknownFields.serializedSize @@ -63,7 +63,7 @@ final case class TrainerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; - trainerConfig.vertexAiMultiPoolTrainerConfig.foreach { __v => + trainerConfig.vertexAiGraphStoreTrainerConfig.foreach { __v => val __m = __v _output__.writeTag(4, 2) _output__.writeUInt32NoTag(__m.serializedSize) @@ -77,8 +77,8 @@ final case class TrainerResourceConfig( def withKfpTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.KFPResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(__v)) def getLocalTrainerConfig: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig = trainerConfig.localTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.LocalResourceConfig.defaultInstance) def withLocalTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__v)) - def getVertexAiMultiPoolTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = trainerConfig.vertexAiMultiPoolTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.defaultInstance) - def withVertexAiMultiPoolTrainerConfig(__v: 
snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(__v)) + def getVertexAiGraphStoreTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = trainerConfig.vertexAiGraphStoreTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.defaultInstance) + def withVertexAiGraphStoreTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(__v)) def clearTrainerConfig: TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) def withTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig): TrainerResourceConfig = copy(trainerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -88,7 +88,7 @@ final case class TrainerResourceConfig( case 1 => trainerConfig.vertexAiTrainerConfig.orNull case 2 => trainerConfig.kfpTrainerConfig.orNull case 3 => trainerConfig.localTrainerConfig.orNull - case 4 => trainerConfig.vertexAiMultiPoolTrainerConfig.orNull + case 4 => trainerConfig.vertexAiGraphStoreTrainerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -97,7 +97,7 @@ final case class TrainerResourceConfig( case 1 => trainerConfig.vertexAiTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => trainerConfig.kfpTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => trainerConfig.localTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) - 
case 4 => trainerConfig.vertexAiMultiPoolTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 4 => trainerConfig.vertexAiGraphStoreTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -122,7 +122,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. case 26 => __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__trainerConfig.localTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 34 => - __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(__trainerConfig.vertexAiMultiPoolTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(__trainerConfig.vertexAiGraphStoreTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -142,7 +142,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
trainerConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(_)) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(_))) - .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") @@ -155,7 +155,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig case 2 => __out = snapchat.research.gbml.gigl_resource_config.KFPResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig - case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig } __out } @@ -170,11 +170,11 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. def isVertexAiTrainerConfig: _root_.scala.Boolean = false def isKfpTrainerConfig: _root_.scala.Boolean = false def isLocalTrainerConfig: _root_.scala.Boolean = false - def isVertexAiMultiPoolTrainerConfig: _root_.scala.Boolean = false + def isVertexAiGraphStoreTrainerConfig: _root_.scala.Boolean = false def vertexAiTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def kfpTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = _root_.scala.None def localTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None - def vertexAiMultiPoolTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scala.None + def vertexAiGraphStoreTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scala.None } object TrainerConfig { @SerialVersionUID(0L) @@ -208,10 +208,10 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
override def number: _root_.scala.Int = 3 } @SerialVersionUID(0L) - final case class VertexAiMultiPoolTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { - type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig - override def isVertexAiMultiPoolTrainerConfig: _root_.scala.Boolean = true - override def vertexAiMultiPoolTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = Some(value) + final case class VertexAiGraphStoreTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + override def isVertexAiGraphStoreTrainerConfig: _root_.scala.Boolean = true + override def vertexAiGraphStoreTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = Some(value) override def number: _root_.scala.Int = 4 } } @@ -219,13 +219,13 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
def vertexAiTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(f_))) def kfpTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = field(_.getKfpTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(f_))) def localTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(f_))) - def vertexAiMultiPoolTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = field(_.getVertexAiMultiPoolTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(f_))) + def vertexAiGraphStoreTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = field(_.getVertexAiGraphStoreTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(f_))) def trainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig] = field(_.trainerConfig)((c_, f_) => c_.copy(trainerConfig = f_)) } final val VERTEX_AI_TRAINER_CONFIG_FIELD_NUMBER = 1 final val KFP_TRAINER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_TRAINER_CONFIG_FIELD_NUMBER = 3 - final val VERTEX_AI_MULTI_POOL_TRAINER_CONFIG_FIELD_NUMBER = 
4 + final val VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG_FIELD_NUMBER = 4 def of( trainerConfig: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig( diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala new file mode 100644 index 000000000..c07f1a3cb --- /dev/null +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala @@ -0,0 +1,155 @@ +// Generated by the Scala Plugin for the Protocol Buffer Compiler. +// Do not edit! +// +// Protofile syntax: PROTO3 + +package snapchat.research.gbml.gigl_resource_config + +/** Configuration for lauching Vertex AI clusters with both graph store and compute pools + * Under the hood, this uses Vertex AI Multi-Pool Training + * See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. + * This cluster setup should be used when you want store your graph on separate machines from the compute machines + * e.g. you can get lots of big memory machines and separate gpu machines individually, + * but getting lots of gpu machines with lots of memory is challenging. 
+ */ +@SerialVersionUID(0L) +final case class VertexAiGraphStoreConfig( + graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None, + computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[VertexAiGraphStoreConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + if (graphStorePool.isDefined) { + val __value = graphStorePool.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; + if (computePool.isDefined) { + val __value = computePool.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + graphStorePool.foreach { __v => + val __m = __v + _output__.writeTag(1, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + computePool.foreach { __v => + val __m = __v + _output__.writeTag(2, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + unknownFields.writeTo(_output__) + } + def getGraphStorePool: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = graphStorePool.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) + def 
clearGraphStorePool: VertexAiGraphStoreConfig = copy(graphStorePool = _root_.scala.None) + def withGraphStorePool(__v: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig): VertexAiGraphStoreConfig = copy(graphStorePool = Option(__v)) + def getComputePool: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = computePool.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) + def clearComputePool: VertexAiGraphStoreConfig = copy(computePool = _root_.scala.None) + def withComputePool(__v: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig): VertexAiGraphStoreConfig = copy(computePool = Option(__v)) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => graphStorePool.orNull + case 2 => computePool.orNull + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => graphStorePool.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 2 => computePool.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.type = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.VertexAiGraphStoreConfig]) +} + +object VertexAiGraphStoreConfig extends 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = { + var __graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None + var __computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __graphStorePool = Option(__graphStorePool.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 18 => + __computePool = Option(__computePool.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( + graphStorePool = __graphStorePool, + computePool = __computePool, + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + 
_root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( + graphStorePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]), + computePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]) + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(9) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(9) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { + var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null + (__number: @_root_.scala.unchecked) match { + case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig + case 2 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig + } + __out + } + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( + graphStorePool = _root_.scala.None, + computePool = _root_.scala.None + ) + implicit class VertexAiGraphStoreConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, 
snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_l) { + def graphStorePool: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getGraphStorePool)((c_, f_) => c_.copy(graphStorePool = Option(f_))) + def optionalGraphStorePool: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.graphStorePool)((c_, f_) => c_.copy(graphStorePool = f_)) + def computePool: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getComputePool)((c_, f_) => c_.copy(computePool = Option(f_))) + def optionalComputePool: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.computePool)((c_, f_) => c_.copy(computePool = f_)) + } + final val GRAPH_STORE_POOL_FIELD_NUMBER = 1 + final val COMPUTE_POOL_FIELD_NUMBER = 2 + def of( + graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig], + computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + ): _root_.snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = _root_.snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( + graphStorePool, + computePool + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.VertexAiGraphStoreConfig]) +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala index c5e49c887..24306c158 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala +++ 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala @@ -18,7 +18,7 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig, snapchat.research.gbml.gigl_resource_config.KFPResourceConfig, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig, - snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig, + snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig, snapchat.research.gbml.gigl_resource_config.DistributedTrainerConfig, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig, @@ -54,66 +54,68 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { VJlcXVlc3RSDW1lbW9yeVJlcXVlc3QSJwoIZ3B1X3R5cGUYAyABKAlCDOI/CRIHZ3B1VHlwZVIHZ3B1VHlwZRIqCglncHVfbGlta XQYBCABKA1CDeI/ChIIZ3B1TGltaXRSCGdwdUxpbWl0EjMKDG51bV9yZXBsaWNhcxgFIAEoDUIQ4j8NEgtudW1SZXBsaWNhc1ILb nVtUmVwbGljYXMiRwoTTG9jYWxSZXNvdXJjZUNvbmZpZxIwCgtudW1fd29ya2VycxgBIAEoDUIP4j8MEgpudW1Xb3JrZXJzUgpud - W1Xb3JrZXJzImsKF1ZlcnRleEFpTXVsdGlQb29sQ29uZmlnElAKBXBvb2xzGAEgAygLMi4uc25hcGNoYXQucmVzZWFyY2guZ2Jtb - C5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQgriPwcSBXBvb2xzUgVwb29scyKdAwoYRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnEoQBC - hh2ZXJ0ZXhfYWlfdHJhaW5lcl9jb25maWcYASABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpVHJhaW5lckNvb - mZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEm8KEmtmcF90cmFpbmVyX2Nvb - mZpZxgCIAEoCzIoLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS0ZQVHJhaW5lckNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnS - ABSEGtmcFRyYWluZXJDb25maWcSdwoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sL - kxvY2FsVHJhaW5lckNvbmZpZ0IX4j8UEhJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ29uZmlnQhAKDnRyYWluZ - XJfY29uZmlnIsMEChVUcmFpbmVyUmVzb3VyY2VDb25maWcShQEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzIuLnNuY - 
XBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVd - mVydGV4QWlUcmFpbmVyQ29uZmlnEnAKEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIpLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS - 0ZQUmVzb3VyY2VDb25maWdCFeI/EhIQa2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEngKFGxvY2FsX3RyYWluZ - XJfY29uZmlnGAMgASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhfiPxQSEmxvY2FsVHJha - W5lckNvbmZpZ0gAUhJsb2NhbFRyYWluZXJDb25maWcSowEKI3ZlcnRleF9haV9tdWx0aV9wb29sX3RyYWluZXJfY29uZmlnGAQgA - SgLMi8uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaU11bHRpUG9vbENvbmZpZ0Ij4j8gEh52ZXJ0ZXhBaU11bHRpUG9vb - FRyYWluZXJDb25maWdIAFIedmVydGV4QWlNdWx0aVBvb2xUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIoMFChhJbmZlc - mVuY2VyUmVzb3VyY2VDb25maWcSjgEKG3ZlcnRleF9haV9pbmZlcmVuY2VyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2Vhc - mNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Id4j8aEhh2ZXJ0ZXhBaUluZmVyZW5jZXJDb25maWdIAFIYdmVydGV4QWlJb - mZlcmVuY2VyQ29uZmlnEo0BChpkYXRhZmxvd19pbmZlcmVuY2VyX2NvbmZpZxgCIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdib - WwuRGF0YWZsb3dSZXNvdXJjZUNvbmZpZ0Id4j8aEhhkYXRhZmxvd0luZmVyZW5jZXJDb25maWdIAFIYZGF0YWZsb3dJbmZlcmVuY - 2VyQ29uZmlnEoEBChdsb2NhbF9pbmZlcmVuY2VyX2NvbmZpZxgDIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxSZ - XNvdXJjZUNvbmZpZ0Ia4j8XEhVsb2NhbEluZmVyZW5jZXJDb25maWdIAFIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnEqwBCiZ2ZXJ0Z - XhfYWlfbXVsdGlfcG9vbF9pbmZlcmVuY2VyX2NvbmZpZxgEIAEoCzIvLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlNd - Wx0aVBvb2xDb25maWdCJuI/IxIhdmVydGV4QWlNdWx0aVBvb2xJbmZlcmVuY2VyQ29uZmlnSABSIXZlcnRleEFpTXVsdGlQb29sS - W5mZXJlbmNlckNvbmZpZ0ITChFpbmZlcmVuY2VyX2NvbmZpZyKXCAoUU2hhcmVkUmVzb3VyY2VDb25maWcSfgoPcmVzb3VyY2Vfb - GFiZWxzGAEgAygLMkAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZy5SZXNvdXJjZUxhYmVsc0Vud - HJ5QhPiPxASDnJlc291cmNlTGFiZWxzUg5yZXNvdXJjZUxhYmVscxKOAQoVY29tbW9uX2NvbXB1dGVfY29uZmlnGAIgASgLMkAuc - 25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZy5Db21tb25Db21wdXRlQ29uZmlnQhjiPxUSE2NvbW1vb - 
kNvbXB1dGVDb25maWdSE2NvbW1vbkNvbXB1dGVDb25maWcalAUKE0NvbW1vbkNvbXB1dGVDb25maWcSJgoHcHJvamVjdBgBIAEoC - UIM4j8JEgdwcm9qZWN0Ugdwcm9qZWN0EiMKBnJlZ2lvbhgCIAEoCUIL4j8IEgZyZWdpb25SBnJlZ2lvbhJDChJ0ZW1wX2Fzc2V0c - 19idWNrZXQYAyABKAlCFeI/EhIQdGVtcEFzc2V0c0J1Y2tldFIQdGVtcEFzc2V0c0J1Y2tldBJcCht0ZW1wX3JlZ2lvbmFsX2Fzc - 2V0c19idWNrZXQYBCABKAlCHeI/GhIYdGVtcFJlZ2lvbmFsQXNzZXRzQnVja2V0Uhh0ZW1wUmVnaW9uYWxBc3NldHNCdWNrZXQSQ - woScGVybV9hc3NldHNfYnVja2V0GAUgASgJQhXiPxISEHBlcm1Bc3NldHNCdWNrZXRSEHBlcm1Bc3NldHNCdWNrZXQSWgobdGVtc - F9hc3NldHNfYnFfZGF0YXNldF9uYW1lGAYgASgJQhziPxkSF3RlbXBBc3NldHNCcURhdGFzZXROYW1lUhd0ZW1wQXNzZXRzQnFEY - XRhc2V0TmFtZRJWChllbWJlZGRpbmdfYnFfZGF0YXNldF9uYW1lGAcgASgJQhviPxgSFmVtYmVkZGluZ0JxRGF0YXNldE5hbWVSF - mVtYmVkZGluZ0JxRGF0YXNldE5hbWUSVgoZZ2NwX3NlcnZpY2VfYWNjb3VudF9lbWFpbBgIIAEoCUIb4j8YEhZnY3BTZXJ2aWNlQ - WNjb3VudEVtYWlsUhZnY3BTZXJ2aWNlQWNjb3VudEVtYWlsEjwKD2RhdGFmbG93X3J1bm5lchgLIAEoCUIT4j8QEg5kYXRhZmxvd - 1J1bm5lclIOZGF0YWZsb3dSdW5uZXIaVwoTUmVzb3VyY2VMYWJlbHNFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSI - AoFdmFsdWUYAiABKAlCCuI/BxIFdmFsdWVSBXZhbHVlOgI4ASL3CAoSR2lnbFJlc291cmNlQ29uZmlnElsKGnNoYXJlZF9yZXNvd - XJjZV9jb25maWdfdXJpGAEgASgJQhziPxkSF3NoYXJlZFJlc291cmNlQ29uZmlnVXJpSABSF3NoYXJlZFJlc291cmNlQ29uZmlnV - XJpEn8KFnNoYXJlZF9yZXNvdXJjZV9jb25maWcYAiABKAsyLC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlNoYXJlZFJlc291cmNlQ - 29uZmlnQhniPxYSFHNoYXJlZFJlc291cmNlQ29uZmlnSABSFHNoYXJlZFJlc291cmNlQ29uZmlnEngKE3ByZXByb2Nlc3Nvcl9jb - 25maWcYDCABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhdGFQcmVwcm9jZXNzb3JDb25maWdCF+I/FBIScHJlcHJvY2Vzc - 29yQ29uZmlnUhJwcmVwcm9jZXNzb3JDb25maWcSfwoXc3ViZ3JhcGhfc2FtcGxlcl9jb25maWcYDSABKAsyKy5zbmFwY2hhdC5yZ - XNlYXJjaC5nYm1sLlNwYXJrUmVzb3VyY2VDb25maWdCGuI/FxIVc3ViZ3JhcGhTYW1wbGVyQ29uZmlnUhVzdWJncmFwaFNhbXBsZ - XJDb25maWcSfAoWc3BsaXRfZ2VuZXJhdG9yX2NvbmZpZxgOIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU3BhcmtSZXNvd - XJjZUNvbmZpZ0IZ4j8WEhRzcGxpdEdlbmVyYXRvckNvbmZpZ1IUc3BsaXRHZW5lcmF0b3JDb25maWcSbQoOdHJhaW5lcl9jb25ma - 
WcYDyABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRpc3RyaWJ1dGVkVHJhaW5lckNvbmZpZ0IUGAHiPw8SDXRyYWluZXJDb - 25maWdSDXRyYWluZXJDb25maWcSdAoRaW5mZXJlbmNlcl9jb25maWcYECABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhd - GFmbG93UmVzb3VyY2VDb25maWdCFxgB4j8SEhBpbmZlcmVuY2VyQ29uZmlnUhBpbmZlcmVuY2VyQ29uZmlnEoEBChd0cmFpbmVyX - 3Jlc291cmNlX2NvbmZpZxgRIAEoCzItLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVHJhaW5lclJlc291cmNlQ29uZmlnQhriPxcSF - XRyYWluZXJSZXNvdXJjZUNvbmZpZ1IVdHJhaW5lclJlc291cmNlQ29uZmlnEo0BChppbmZlcmVuY2VyX3Jlc291cmNlX2NvbmZpZ - xgSIAEoCzIwLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuSW5mZXJlbmNlclJlc291cmNlQ29uZmlnQh3iPxoSGGluZmVyZW5jZXJSZ - XNvdXJjZUNvbmZpZ1IYaW5mZXJlbmNlclJlc291cmNlQ29uZmlnQhEKD3NoYXJlZF9yZXNvdXJjZSrjAwoJQ29tcG9uZW50Ei0KE - UNvbXBvbmVudF9Vbmtub3duEAAaFuI/ExIRQ29tcG9uZW50X1Vua25vd24SPwoaQ29tcG9uZW50X0NvbmZpZ19WYWxpZGF0b3IQA - Rof4j8cEhpDb21wb25lbnRfQ29uZmlnX1ZhbGlkYXRvchI/ChpDb21wb25lbnRfQ29uZmlnX1BvcHVsYXRvchACGh/iPxwSGkNvb - XBvbmVudF9Db25maWdfUG9wdWxhdG9yEkEKG0NvbXBvbmVudF9EYXRhX1ByZXByb2Nlc3NvchADGiDiPx0SG0NvbXBvbmVudF9EY - XRhX1ByZXByb2Nlc3NvchI/ChpDb21wb25lbnRfU3ViZ3JhcGhfU2FtcGxlchAEGh/iPxwSGkNvbXBvbmVudF9TdWJncmFwaF9TY - W1wbGVyEj0KGUNvbXBvbmVudF9TcGxpdF9HZW5lcmF0b3IQBRoe4j8bEhlDb21wb25lbnRfU3BsaXRfR2VuZXJhdG9yEi0KEUNvb - XBvbmVudF9UcmFpbmVyEAYaFuI/ExIRQ29tcG9uZW50X1RyYWluZXISMwoUQ29tcG9uZW50X0luZmVyZW5jZXIQBxoZ4j8WEhRDb - 21wb25lbnRfSW5mZXJlbmNlcmIGcHJvdG8z""" + W1Xb3JrZXJzIu4BChhWZXJ0ZXhBaUdyYXBoU3RvcmVDb25maWcSbQoQZ3JhcGhfc3RvcmVfcG9vbBgBIAEoCzIuLnNuYXBjaGF0L + nJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0IT4j8QEg5ncmFwaFN0b3JlUG9vbFIOZ3JhcGhTdG9yZVBvb2wSY + woMY29tcHV0ZV9wb29sGAIgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQhDiPw0SC + 2NvbXB1dGVQb29sUgtjb21wdXRlUG9vbCKdAwoYRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnEoQBChh2ZXJ0ZXhfYWlfdHJhaW5lc + l9jb25maWcYASABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpVHJhaW5lckNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBa + VRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZmlnEm8KEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIoLnNuYXBja + 
GF0LnJlc2VhcmNoLmdibWwuS0ZQVHJhaW5lckNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnSABSEGtmcFRyYWluZXJDb25ma + WcSdwoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsVHJhaW5lckNvbmZpZ + 0IX4j8UEhJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ29uZmlnQhAKDnRyYWluZXJfY29uZmlnIscEChVUcmFpb + mVyUmVzb3VyY2VDb25maWcShQEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdib + WwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV2ZXJ0ZXhBaVRyYWluZXJDb25maWdIAFIVdmVydGV4QWlUcmFpbmVyQ29uZ + mlnEnAKEmtmcF90cmFpbmVyX2NvbmZpZxgCIAEoCzIpLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuS0ZQUmVzb3VyY2VDb25maWdCF + eI/EhIQa2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEngKFGxvY2FsX3RyYWluZXJfY29uZmlnGAMgASgLMisuc + 25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhfiPxQSEmxvY2FsVHJhaW5lckNvbmZpZ0gAUhJsb2Nhb + FRyYWluZXJDb25maWcSpwEKJHZlcnRleF9haV9ncmFwaF9zdG9yZV90cmFpbmVyX2NvbmZpZxgEIAEoCzIwLnNuYXBjaGF0LnJlc + 2VhcmNoLmdibWwuVmVydGV4QWlHcmFwaFN0b3JlQ29uZmlnQiTiPyESH3ZlcnRleEFpR3JhcGhTdG9yZVRyYWluZXJDb25maWdIA + FIfdmVydGV4QWlHcmFwaFN0b3JlVHJhaW5lckNvbmZpZ0IQCg50cmFpbmVyX2NvbmZpZyKHBQoYSW5mZXJlbmNlclJlc291cmNlQ + 29uZmlnEo4BCht2ZXJ0ZXhfYWlfaW5mZXJlbmNlcl9jb25maWcYASABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRle + EFpUmVzb3VyY2VDb25maWdCHeI/GhIYdmVydGV4QWlJbmZlcmVuY2VyQ29uZmlnSABSGHZlcnRleEFpSW5mZXJlbmNlckNvbmZpZ + xKNAQoaZGF0YWZsb3dfaW5mZXJlbmNlcl9jb25maWcYAiABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkRhdGFmbG93UmVzb + 3VyY2VDb25maWdCHeI/GhIYZGF0YWZsb3dJbmZlcmVuY2VyQ29uZmlnSABSGGRhdGFmbG93SW5mZXJlbmNlckNvbmZpZxKBAQoXb + G9jYWxfaW5mZXJlbmNlcl9jb25maWcYAyABKAsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsUmVzb3VyY2VDb25maWdCG + uI/FxIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnSABSFWxvY2FsSW5mZXJlbmNlckNvbmZpZxKwAQondmVydGV4X2FpX2dyYXBoX3N0b + 3JlX2luZmVyZW5jZXJfY29uZmlnGAQgASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaUdyYXBoU3RvcmVDb25ma + WdCJ+I/JBIidmVydGV4QWlHcmFwaFN0b3JlSW5mZXJlbmNlckNvbmZpZ0gAUiJ2ZXJ0ZXhBaUdyYXBoU3RvcmVJbmZlcmVuY2VyQ + 
29uZmlnQhMKEWluZmVyZW5jZXJfY29uZmlnIpcIChRTaGFyZWRSZXNvdXJjZUNvbmZpZxJ+Cg9yZXNvdXJjZV9sYWJlbHMYASADK + AsyQC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlNoYXJlZFJlc291cmNlQ29uZmlnLlJlc291cmNlTGFiZWxzRW50cnlCE+I/EBIOc + mVzb3VyY2VMYWJlbHNSDnJlc291cmNlTGFiZWxzEo4BChVjb21tb25fY29tcHV0ZV9jb25maWcYAiABKAsyQC5zbmFwY2hhdC5yZ + XNlYXJjaC5nYm1sLlNoYXJlZFJlc291cmNlQ29uZmlnLkNvbW1vbkNvbXB1dGVDb25maWdCGOI/FRITY29tbW9uQ29tcHV0ZUNvb + mZpZ1ITY29tbW9uQ29tcHV0ZUNvbmZpZxqUBQoTQ29tbW9uQ29tcHV0ZUNvbmZpZxImCgdwcm9qZWN0GAEgASgJQgziPwkSB3Byb + 2plY3RSB3Byb2plY3QSIwoGcmVnaW9uGAIgASgJQgviPwgSBnJlZ2lvblIGcmVnaW9uEkMKEnRlbXBfYXNzZXRzX2J1Y2tldBgDI + AEoCUIV4j8SEhB0ZW1wQXNzZXRzQnVja2V0UhB0ZW1wQXNzZXRzQnVja2V0ElwKG3RlbXBfcmVnaW9uYWxfYXNzZXRzX2J1Y2tld + BgEIAEoCUId4j8aEhh0ZW1wUmVnaW9uYWxBc3NldHNCdWNrZXRSGHRlbXBSZWdpb25hbEFzc2V0c0J1Y2tldBJDChJwZXJtX2Fzc + 2V0c19idWNrZXQYBSABKAlCFeI/EhIQcGVybUFzc2V0c0J1Y2tldFIQcGVybUFzc2V0c0J1Y2tldBJaCht0ZW1wX2Fzc2V0c19ic + V9kYXRhc2V0X25hbWUYBiABKAlCHOI/GRIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWVSF3RlbXBBc3NldHNCcURhdGFzZXROYW1lE + lYKGWVtYmVkZGluZ19icV9kYXRhc2V0X25hbWUYByABKAlCG+I/GBIWZW1iZWRkaW5nQnFEYXRhc2V0TmFtZVIWZW1iZWRkaW5nQ + nFEYXRhc2V0TmFtZRJWChlnY3Bfc2VydmljZV9hY2NvdW50X2VtYWlsGAggASgJQhviPxgSFmdjcFNlcnZpY2VBY2NvdW50RW1ha + WxSFmdjcFNlcnZpY2VBY2NvdW50RW1haWwSPAoPZGF0YWZsb3dfcnVubmVyGAsgASgJQhPiPxASDmRhdGFmbG93UnVubmVyUg5kY + XRhZmxvd1J1bm5lchpXChNSZXNvdXJjZUxhYmVsc0VudHJ5EhoKA2tleRgBIAEoCUII4j8FEgNrZXlSA2tleRIgCgV2YWx1ZRgCI + AEoCUIK4j8HEgV2YWx1ZVIFdmFsdWU6AjgBIvcIChJHaWdsUmVzb3VyY2VDb25maWcSWwoac2hhcmVkX3Jlc291cmNlX2NvbmZpZ + 191cmkYASABKAlCHOI/GRIXc2hhcmVkUmVzb3VyY2VDb25maWdVcmlIAFIXc2hhcmVkUmVzb3VyY2VDb25maWdVcmkSfwoWc2hhc + mVkX3Jlc291cmNlX2NvbmZpZxgCIAEoCzIsLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWdCGeI/F + hIUc2hhcmVkUmVzb3VyY2VDb25maWdIAFIUc2hhcmVkUmVzb3VyY2VDb25maWcSeAoTcHJlcHJvY2Vzc29yX2NvbmZpZxgMIAEoC + zIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YVByZXByb2Nlc3NvckNvbmZpZ0IX4j8UEhJwcmVwcm9jZXNzb3JDb25maWdSE + 
nByZXByb2Nlc3NvckNvbmZpZxJ/ChdzdWJncmFwaF9zYW1wbGVyX2NvbmZpZxgNIAEoCzIrLnNuYXBjaGF0LnJlc2VhcmNoLmdib + WwuU3BhcmtSZXNvdXJjZUNvbmZpZ0Ia4j8XEhVzdWJncmFwaFNhbXBsZXJDb25maWdSFXN1YmdyYXBoU2FtcGxlckNvbmZpZxJ8C + hZzcGxpdF9nZW5lcmF0b3JfY29uZmlnGA4gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TcGFya1Jlc291cmNlQ29uZmlnQ + hniPxYSFHNwbGl0R2VuZXJhdG9yQ29uZmlnUhRzcGxpdEdlbmVyYXRvckNvbmZpZxJtCg50cmFpbmVyX2NvbmZpZxgPIAEoCzIwL + nNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGlzdHJpYnV0ZWRUcmFpbmVyQ29uZmlnQhQYAeI/DxINdHJhaW5lckNvbmZpZ1INdHJha + W5lckNvbmZpZxJ0ChFpbmZlcmVuY2VyX2NvbmZpZxgQIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YWZsb3dSZXNvd + XJjZUNvbmZpZ0IXGAHiPxISEGluZmVyZW5jZXJDb25maWdSEGluZmVyZW5jZXJDb25maWcSgQEKF3RyYWluZXJfcmVzb3VyY2VfY + 29uZmlnGBEgASgLMi0uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5UcmFpbmVyUmVzb3VyY2VDb25maWdCGuI/FxIVdHJhaW5lclJlc + 291cmNlQ29uZmlnUhV0cmFpbmVyUmVzb3VyY2VDb25maWcSjQEKGmluZmVyZW5jZXJfcmVzb3VyY2VfY29uZmlnGBIgASgLMjAuc + 25hcGNoYXQucmVzZWFyY2guZ2JtbC5JbmZlcmVuY2VyUmVzb3VyY2VDb25maWdCHeI/GhIYaW5mZXJlbmNlclJlc291cmNlQ29uZ + mlnUhhpbmZlcmVuY2VyUmVzb3VyY2VDb25maWdCEQoPc2hhcmVkX3Jlc291cmNlKuMDCglDb21wb25lbnQSLQoRQ29tcG9uZW50X + 1Vua25vd24QABoW4j8TEhFDb21wb25lbnRfVW5rbm93bhI/ChpDb21wb25lbnRfQ29uZmlnX1ZhbGlkYXRvchABGh/iPxwSGkNvb + XBvbmVudF9Db25maWdfVmFsaWRhdG9yEj8KGkNvbXBvbmVudF9Db25maWdfUG9wdWxhdG9yEAIaH+I/HBIaQ29tcG9uZW50X0Nvb + mZpZ19Qb3B1bGF0b3ISQQobQ29tcG9uZW50X0RhdGFfUHJlcHJvY2Vzc29yEAMaIOI/HRIbQ29tcG9uZW50X0RhdGFfUHJlcHJvY + 2Vzc29yEj8KGkNvbXBvbmVudF9TdWJncmFwaF9TYW1wbGVyEAQaH+I/HBIaQ29tcG9uZW50X1N1YmdyYXBoX1NhbXBsZXISPQoZQ + 29tcG9uZW50X1NwbGl0X0dlbmVyYXRvchAFGh7iPxsSGUNvbXBvbmVudF9TcGxpdF9HZW5lcmF0b3ISLQoRQ29tcG9uZW50X1RyY + WluZXIQBhoW4j8TEhFDb21wb25lbnRfVHJhaW5lchIzChRDb21wb25lbnRfSW5mZXJlbmNlchAHGhniPxYSFENvbXBvbmVudF9Jb + mZlcmVuY2VyYgZwcm90bzM=""" ).mkString) lazy val scalaDescriptor: _root_.scalapb.descriptors.FileDescriptor = { val scalaProto = com.google.protobuf.descriptor.FileDescriptorProto.parseFrom(ProtoBytes) diff --git 
a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala index e222970eb..60315338f 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala @@ -28,8 +28,8 @@ final case class InferencerResourceConfig( val __value = inferencerConfig.localInferencerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; - if (inferencerConfig.vertexAiMultiPoolInferencerConfig.isDefined) { - val __value = inferencerConfig.vertexAiMultiPoolInferencerConfig.get + if (inferencerConfig.vertexAiGraphStoreInferencerConfig.isDefined) { + val __value = inferencerConfig.vertexAiGraphStoreInferencerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; __size += unknownFields.serializedSize @@ -63,7 +63,7 @@ final case class InferencerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; - inferencerConfig.vertexAiMultiPoolInferencerConfig.foreach { __v => + inferencerConfig.vertexAiGraphStoreInferencerConfig.foreach { __v => val __m = __v _output__.writeTag(4, 2) _output__.writeUInt32NoTag(__m.serializedSize) @@ -77,8 +77,8 @@ final case class InferencerResourceConfig( def withDataflowInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(__v)) def getLocalInferencerConfig: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig = 
inferencerConfig.localInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.LocalResourceConfig.defaultInstance) def withLocalInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__v)) - def getVertexAiMultiPoolInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = inferencerConfig.vertexAiMultiPoolInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.defaultInstance) - def withVertexAiMultiPoolInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(__v)) + def getVertexAiGraphStoreInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = inferencerConfig.vertexAiGraphStoreInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.defaultInstance) + def withVertexAiGraphStoreInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(__v)) def clearInferencerConfig: InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) def withInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig): InferencerResourceConfig = copy(inferencerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -88,7 +88,7 @@ final case class 
InferencerResourceConfig( case 1 => inferencerConfig.vertexAiInferencerConfig.orNull case 2 => inferencerConfig.dataflowInferencerConfig.orNull case 3 => inferencerConfig.localInferencerConfig.orNull - case 4 => inferencerConfig.vertexAiMultiPoolInferencerConfig.orNull + case 4 => inferencerConfig.vertexAiGraphStoreInferencerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -97,7 +97,7 @@ final case class InferencerResourceConfig( case 1 => inferencerConfig.vertexAiInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => inferencerConfig.dataflowInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => inferencerConfig.localInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) - case 4 => inferencerConfig.vertexAiMultiPoolInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 4 => inferencerConfig.vertexAiGraphStoreInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -122,7 +122,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch case 26 => __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__inferencerConfig.localInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 34 => - __inferencerConfig = 
snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(__inferencerConfig.vertexAiMultiPoolInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(__inferencerConfig.vertexAiGraphStoreInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -142,7 +142,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch inferencerConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(_)) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(_))) - 
.orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") @@ -155,7 +155,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig case 2 => __out = snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig - case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig } __out } @@ -170,11 +170,11 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch def isVertexAiInferencerConfig: _root_.scala.Boolean = false def isDataflowInferencerConfig: _root_.scala.Boolean = false def isLocalInferencerConfig: _root_.scala.Boolean = false - def isVertexAiMultiPoolInferencerConfig: _root_.scala.Boolean = false + def isVertexAiGraphStoreInferencerConfig: _root_.scala.Boolean = false def vertexAiInferencerConfig: 
_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def dataflowInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = _root_.scala.None def localInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None - def vertexAiMultiPoolInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scala.None + def vertexAiGraphStoreInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scala.None } object InferencerConfig { @SerialVersionUID(0L) @@ -208,10 +208,10 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch override def number: _root_.scala.Int = 3 } @SerialVersionUID(0L) - final case class VertexAiMultiPoolInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { - type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig - override def isVertexAiMultiPoolInferencerConfig: _root_.scala.Boolean = true - override def vertexAiMultiPoolInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = Some(value) + final case class VertexAiGraphStoreInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + override def isVertexAiGraphStoreInferencerConfig: _root_.scala.Boolean = true + override def vertexAiGraphStoreInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = Some(value) 
override def number: _root_.scala.Int = 4 } } @@ -219,13 +219,13 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch def vertexAiInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(f_))) def dataflowInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = field(_.getDataflowInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(f_))) def localInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(f_))) - def vertexAiMultiPoolInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = field(_.getVertexAiMultiPoolInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiMultiPoolInferencerConfig(f_))) + def vertexAiGraphStoreInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = field(_.getVertexAiGraphStoreInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(f_))) def inferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig] = 
field(_.inferencerConfig)((c_, f_) => c_.copy(inferencerConfig = f_)) } final val VERTEX_AI_INFERENCER_CONFIG_FIELD_NUMBER = 1 final val DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_INFERENCER_CONFIG_FIELD_NUMBER = 3 - final val VERTEX_AI_MULTI_POOL_INFERENCER_CONFIG_FIELD_NUMBER = 4 + final val VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG_FIELD_NUMBER = 4 def of( inferencerConfig: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig( diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala index 8a72b5e87..4d6dbeaf5 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala @@ -28,8 +28,8 @@ final case class TrainerResourceConfig( val __value = trainerConfig.localTrainerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; - if (trainerConfig.vertexAiMultiPoolTrainerConfig.isDefined) { - val __value = trainerConfig.vertexAiMultiPoolTrainerConfig.get + if (trainerConfig.vertexAiGraphStoreTrainerConfig.isDefined) { + val __value = trainerConfig.vertexAiGraphStoreTrainerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; __size += unknownFields.serializedSize @@ -63,7 +63,7 @@ final case class TrainerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; - trainerConfig.vertexAiMultiPoolTrainerConfig.foreach { __v 
=> + trainerConfig.vertexAiGraphStoreTrainerConfig.foreach { __v => val __m = __v _output__.writeTag(4, 2) _output__.writeUInt32NoTag(__m.serializedSize) @@ -77,8 +77,8 @@ final case class TrainerResourceConfig( def withKfpTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.KFPResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(__v)) def getLocalTrainerConfig: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig = trainerConfig.localTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.LocalResourceConfig.defaultInstance) def withLocalTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__v)) - def getVertexAiMultiPoolTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig = trainerConfig.vertexAiMultiPoolTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig.defaultInstance) - def withVertexAiMultiPoolTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(__v)) + def getVertexAiGraphStoreTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = trainerConfig.vertexAiGraphStoreTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.defaultInstance) + def withVertexAiGraphStoreTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(__v)) def 
clearTrainerConfig: TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) def withTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig): TrainerResourceConfig = copy(trainerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -88,7 +88,7 @@ final case class TrainerResourceConfig( case 1 => trainerConfig.vertexAiTrainerConfig.orNull case 2 => trainerConfig.kfpTrainerConfig.orNull case 3 => trainerConfig.localTrainerConfig.orNull - case 4 => trainerConfig.vertexAiMultiPoolTrainerConfig.orNull + case 4 => trainerConfig.vertexAiGraphStoreTrainerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -97,7 +97,7 @@ final case class TrainerResourceConfig( case 1 => trainerConfig.vertexAiTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 2 => trainerConfig.kfpTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => trainerConfig.localTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) - case 4 => trainerConfig.vertexAiMultiPoolTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 4 => trainerConfig.vertexAiGraphStoreTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -122,7 +122,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
case 26 => __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__trainerConfig.localTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 34 => - __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(__trainerConfig.vertexAiMultiPoolTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(__trainerConfig.vertexAiGraphStoreTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -142,7 +142,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
trainerConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(_)) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(_))) - .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") @@ -155,7 +155,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig case 2 => __out = snapchat.research.gbml.gigl_resource_config.KFPResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig - case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig + case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig } __out } @@ -170,11 +170,11 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. def isVertexAiTrainerConfig: _root_.scala.Boolean = false def isKfpTrainerConfig: _root_.scala.Boolean = false def isLocalTrainerConfig: _root_.scala.Boolean = false - def isVertexAiMultiPoolTrainerConfig: _root_.scala.Boolean = false + def isVertexAiGraphStoreTrainerConfig: _root_.scala.Boolean = false def vertexAiTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def kfpTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = _root_.scala.None def localTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None - def vertexAiMultiPoolTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = _root_.scala.None + def vertexAiGraphStoreTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scala.None } object TrainerConfig { @SerialVersionUID(0L) @@ -208,10 +208,10 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
override def number: _root_.scala.Int = 3 } @SerialVersionUID(0L) - final case class VertexAiMultiPoolTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { - type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig - override def isVertexAiMultiPoolTrainerConfig: _root_.scala.Boolean = true - override def vertexAiMultiPoolTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = Some(value) + final case class VertexAiGraphStoreTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + override def isVertexAiGraphStoreTrainerConfig: _root_.scala.Boolean = true + override def vertexAiGraphStoreTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = Some(value) override def number: _root_.scala.Int = 4 } } @@ -219,13 +219,13 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
def vertexAiTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(f_))) def kfpTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = field(_.getKfpTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(f_))) def localTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(f_))) - def vertexAiMultiPoolTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiMultiPoolConfig] = field(_.getVertexAiMultiPoolTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiMultiPoolTrainerConfig(f_))) + def vertexAiGraphStoreTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = field(_.getVertexAiGraphStoreTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(f_))) def trainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig] = field(_.trainerConfig)((c_, f_) => c_.copy(trainerConfig = f_)) } final val VERTEX_AI_TRAINER_CONFIG_FIELD_NUMBER = 1 final val KFP_TRAINER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_TRAINER_CONFIG_FIELD_NUMBER = 3 - final val VERTEX_AI_MULTI_POOL_TRAINER_CONFIG_FIELD_NUMBER = 
4 + final val VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG_FIELD_NUMBER = 4 def of( trainerConfig: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig( diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala new file mode 100644 index 000000000..c07f1a3cb --- /dev/null +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiGraphStoreConfig.scala @@ -0,0 +1,155 @@ +// Generated by the Scala Plugin for the Protocol Buffer Compiler. +// Do not edit! +// +// Protofile syntax: PROTO3 + +package snapchat.research.gbml.gigl_resource_config + +/** Configuration for launching Vertex AI clusters with both graph store and compute pools + * Under the hood, this uses Vertex AI Multi-Pool Training + * See https://cloud.google.com/vertex-ai/docs/training/distributed-training for more info. + * This cluster setup should be used when you want to store your graph on separate machines from the compute machines + * e.g. you can get lots of big memory machines and separate gpu machines individually, + * but getting lots of gpu machines with lots of memory is challenging.
+ */ +@SerialVersionUID(0L) +final case class VertexAiGraphStoreConfig( + graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None, + computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[VertexAiGraphStoreConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + if (graphStorePool.isDefined) { + val __value = graphStorePool.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; + if (computePool.isDefined) { + val __value = computePool.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + graphStorePool.foreach { __v => + val __m = __v + _output__.writeTag(1, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + computePool.foreach { __v => + val __m = __v + _output__.writeTag(2, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + unknownFields.writeTo(_output__) + } + def getGraphStorePool: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = graphStorePool.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) + def 
clearGraphStorePool: VertexAiGraphStoreConfig = copy(graphStorePool = _root_.scala.None) + def withGraphStorePool(__v: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig): VertexAiGraphStoreConfig = copy(graphStorePool = Option(__v)) + def getComputePool: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = computePool.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) + def clearComputePool: VertexAiGraphStoreConfig = copy(computePool = _root_.scala.None) + def withComputePool(__v: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig): VertexAiGraphStoreConfig = copy(computePool = Option(__v)) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => graphStorePool.orNull + case 2 => computePool.orNull + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => graphStorePool.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 2 => computePool.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.type = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.VertexAiGraphStoreConfig]) +} + +object VertexAiGraphStoreConfig extends 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = { + var __graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None + var __computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __graphStorePool = Option(__graphStorePool.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 18 => + __computePool = Option(__computePool.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( + graphStorePool = __graphStorePool, + computePool = __computePool, + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + 
_root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( + graphStorePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]), + computePool = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]]) + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(9) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(9) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { + var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null + (__number: @_root_.scala.unchecked) match { + case 1 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig + case 2 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig + } + __out + } + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( + graphStorePool = _root_.scala.None, + computePool = _root_.scala.None + ) + implicit class VertexAiGraphStoreConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, 
snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_l) { + def graphStorePool: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getGraphStorePool)((c_, f_) => c_.copy(graphStorePool = Option(f_))) + def optionalGraphStorePool: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.graphStorePool)((c_, f_) => c_.copy(graphStorePool = f_)) + def computePool: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getComputePool)((c_, f_) => c_.copy(computePool = Option(f_))) + def optionalComputePool: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig]] = field(_.computePool)((c_, f_) => c_.copy(computePool = f_)) + } + final val GRAPH_STORE_POOL_FIELD_NUMBER = 1 + final val COMPUTE_POOL_FIELD_NUMBER = 2 + def of( + graphStorePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig], + computePool: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] + ): _root_.snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = _root_.snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig( + graphStorePool, + computePool + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.VertexAiGraphStoreConfig]) +} From b4d35efd849b487f7a8dba7a204de158ec4d58a0 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 30 Sep 2025 17:24:08 +0000 Subject: [PATCH 04/33] wip --- python/gigl/common/services/vertex_ai.py | 136 ++++++++++++++++-- .../src/training/v1/lib/training_process.py | 6 +- .../common/services/vertex_ai_test.py | 81 ++++++----- 3 files changed, 172 insertions(+), 51 deletions(-) diff --git a/python/gigl/common/services/vertex_ai.py 
b/python/gigl/common/services/vertex_ai.py index 72d974246..38293d16e 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -82,6 +82,13 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name str ] = "LEADER_WORKER_INTERNAL_IP_FILE_PATH" +STORAGE_CLUSTER_MASTER_KEY: Final[ + str +] = "GIGL_STORAGE_CLUSTER_MASTER_KEY" +COMPUTE_CLUSTER_MASTER_KEY: Final[ + str +] = "GIGL_COMPUTE_CLUSTER_MASTER_KEY" + DEFAULT_PIPELINE_TIMEOUT_S: Final[int] = 60 * 60 * 36 # 36 hours DEFAULT_CUSTOM_JOB_TIMEOUT_S: Final[int] = 60 * 60 * 24 # 24 hours @@ -151,11 +158,7 @@ def launch_job(self, job_config: VertexAiJobConfig) -> None: """ logger.info(f"Running Vertex AI job: {job_config.job_name}") - machine_spec = MachineSpec( - machine_type=job_config.machine_type, - accelerator_type=job_config.accelerator_type, - accelerator_count=job_config.accelerator_count, - ) + machine_spec = _get_machine_spec(job_config) # This file is used to store the leader worker's internal IP address. 
# Whenever `connect_worker_pool()` is called, the leader worker will @@ -175,17 +178,9 @@ def launch_job(self, job_config: VertexAiJobConfig) -> None: ) ] - container_spec = ContainerSpec( - image_uri=job_config.container_uri, - command=job_config.command, - args=job_config.args, - env=env_vars, - ) + container_spec = _get_container_spec(job_config, env_vars) - disk_spec = DiskSpec( - boot_disk_type=job_config.boot_disk_type, - boot_disk_size_gb=job_config.boot_disk_size_gb, - ) + disk_spec = _get_disk_spec(job_config) assert ( job_config.replica_count >= 1 @@ -246,6 +241,90 @@ def launch_job(self, job_config: VertexAiJobConfig) -> None: ) job.wait_for_completion() + def launch_graph_store_job(self, storage_cluster: VertexAiJobConfig, compute_cluster: VertexAiJobConfig) -> None: + """Launch a Vertex AI Graph Store job.""" + storage_machine_spec = _get_machine_spec(storage_cluster) + compute_machine_spec = _get_machine_spec(compute_cluster) + storage_disk_spec = _get_disk_spec(storage_cluster) + compute_disk_spec = _get_disk_spec(compute_cluster) + + # This file is used to store the leader worker's internal IP address. + # Whenever `connect_worker_pool()` is called, the leader worker will + # write its internal IP address to this file. The other workers will + # read this file to get the leader worker's internal IP address. + # See connect_worker_pool() implementation for more details. 
+ leader_worker_internal_ip_file_path = GcsUri.join( + self._staging_bucket, + storage_cluster.job_name, + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"), + "leader_worker_internal_ip.txt", + ) + env_vars: list[env_var.EnvVar] = [ + env_var.EnvVar( + name=LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY, + value=leader_worker_internal_ip_file_path.uri, + ), + env_var.EnvVar( + name=STORAGE_CLUSTER_MASTER_KEY, + value="0", + ), + env_var.EnvVar( + name=COMPUTE_CLUSTER_MASTER_KEY, + value=str(storage_cluster.replica_count) + ), + ] + + storage_container_spec = _get_container_spec(storage_cluster, env_vars) + compute_container_spec = _get_container_spec(compute_cluster, env_vars) + + worker_pool_specs: list[WorkerPoolSpec] = [] + + leader_worker_spec = WorkerPoolSpec( + machine_spec=storage_machine_spec, + container_spec=storage_container_spec, + disk_spec=storage_disk_spec, + replica_count=1, + ) + worker_pool_specs.append(leader_worker_spec) + if storage_cluster.replica_count > 1: + worker_spec = WorkerPoolSpec( + machine_spec=storage_machine_spec, + container_spec=storage_container_spec, + disk_spec=storage_disk_spec, + replica_count=storage_cluster.replica_count - 1, + ) + worker_pool_specs.append(worker_spec) + worker_spec = WorkerPoolSpec( + machine_spec=compute_machine_spec, + container_spec=compute_container_spec, + disk_spec=compute_disk_spec, + replica_count=compute_cluster.replica_count, + ) + worker_pool_specs.append(worker_spec) + + job = aiplatform.CustomJob( + display_name=storage_cluster.job_name, + worker_pool_specs=worker_pool_specs, + project=self._project, + location=self._location, + labels=storage_cluster.labels, + staging_bucket=self._staging_bucket, + ) + job.submit( + service_account=self._service_account, + timeout=storage_cluster.timeout_s, + enable_web_access=storage_cluster.enable_web_access, + ) + job.wait_for_resource_creation() + logger.info(f"Created job: {job.resource_name}") + # Copying 
https://github.com/googleapis/python-aiplatform/blob/v1.48.0/google/cloud/aiplatform/jobs.py#L207-L215 + # Since for some reason upgrading from VertexAI v1.27.1 to v1.48.0 + # caused the logs to occasionally not be printed. + logger.info( + f"See job logs at: https://console.cloud.google.com/ai/platform/locations/{self._location}/training/{job.name}?project={self._project}" + ) + job.wait_for_completion() + def run_pipeline( self, display_name: str, @@ -355,3 +434,30 @@ def wait_for_run_completion( f"Vertex AI run stopped with status: {run.state}. " f"Please check the Vertex AI page to trace down the error." ) + +def _get_machine_spec(job_config: VertexAiJobConfig) -> MachineSpec: + """Get the machine spec for a job config.""" + machine_spec = MachineSpec( + machine_type=job_config.machine_type, + accelerator_type=job_config.accelerator_type, + accelerator_count=job_config.accelerator_count, + ) + return machine_spec + +def _get_container_spec(job_config: VertexAiJobConfig, env_vars: list[env_var.EnvVar]) -> ContainerSpec: + """Get the container spec for a job config.""" + container_spec = ContainerSpec( + image_uri=job_config.container_uri, + command=job_config.command, + args=job_config.args, + env=env_vars, + ) + return container_spec + +def _get_disk_spec(job_config: VertexAiJobConfig) -> DiskSpec: + """Get the disk spec for a job config.""" + disk_spec = DiskSpec( + boot_disk_type=job_config.boot_disk_type, + boot_disk_size_gb=job_config.boot_disk_size_gb, + ) + return disk_spec diff --git a/python/gigl/src/training/v1/lib/training_process.py b/python/gigl/src/training/v1/lib/training_process.py index 25a31091a..cda91a5b2 100644 --- a/python/gigl/src/training/v1/lib/training_process.py +++ b/python/gigl/src/training/v1/lib/training_process.py @@ -1,5 +1,6 @@ import argparse import contextlib +import datetime import multiprocessing as mp import sys import tempfile @@ -354,8 +355,9 @@ def __setup_training_env(self, device: torch.device): use_cuda = device.type 
!= "cpu" if should_distribute(): distributed_backend = get_distributed_backend(use_cuda=use_cuda) - logger.info(f"Using distributed PyTorch with {distributed_backend}") - torch.distributed.init_process_group(backend=distributed_backend) + timeout = datetime.timedelta(minutes=45) + logger.info(f"Using distributed PyTorch with {distributed_backend} and timeout {timeout}") + torch.distributed.init_process_group(backend=distributed_backend, timeout=timeout) logger.info("Successfully initiated distributed backend!") @flushes_metrics(get_metrics_service_instance_fn=get_metrics_service_instance) diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index b044eb051..8b42ed420 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -6,7 +6,7 @@ import kfp from gigl.common import UriFactory -from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService +from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService, STORAGE_CLUSTER_MASTER_KEY, COMPUTE_CLUSTER_MASTER_KEY from gigl.env.pipelines_config import get_resource_config @@ -46,12 +46,24 @@ def get_pipeline_that_fails() -> float: class VertexAIPipelineIntegrationTest(unittest.TestCase): - def test_launch_job(self): - resource_config = get_resource_config() - project = resource_config.project - location = resource_config.region - service_account = resource_config.service_account_email - staging_bucket = resource_config.temp_assets_regional_bucket_path.uri + def setUp(self): + self.resource_config = get_resource_config() + self.project = self.resource_config.project + self.location = self.resource_config.region + self.service_account = self.resource_config.service_account_email + self.staging_bucket = self.resource_config.temp_assets_regional_bucket_path.uri + self.vertex_ai_service = VertexAIService( + project=self.project, + 
location=self.location, + service_account=self.service_account, + staging_bucket=self.staging_bucket, + ) + super().setUp() + + def tearDown(self): + super().tearDown() + + def _test_launch_job(self): job_name = f"GiGL-Integration-Test-{uuid.uuid4()}" container_uri = "condaforge/miniforge3:25.3.0-1" command = ["python", "-c", "import logging; logging.info('Hello, World!')"] @@ -60,27 +72,35 @@ def test_launch_job(self): job_name=job_name, container_uri=container_uri, command=command ) - vertex_ai_service = VertexAIService( - project=project, - location=location, - service_account=service_account, - staging_bucket=staging_bucket, + self.vertex_ai_service.launch_job(job_config) + + def test_launch_graph_store_job(self): + command = ["python", "-c", f"import os; import logging; logging.info(f'Graph cluster master: {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']}}, compute cluster master: {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']}}')"] + job_name = f"GiGL-Integration-Test-Graph-Store-{uuid.uuid4()}" + storage_cluster_config = VertexAiJobConfig( + job_name=job_name, + container_uri="condaforge/miniforge3:25.3.0-1", + replica_count=2, + #machine_type="n1-standard-4", + command=command, ) + compute_cluster_config = VertexAiJobConfig( + job_name=job_name, + container_uri="condaforge/miniforge3:25.3.0-1", + replica_count=2, + command=command, + machine_type="n1-standard-32", + accelerator_type="NVIDIA_TESLA_T4", + accelerator_count=2, + ) + self.vertex_ai_service.launch_graph_store_job(storage_cluster_config, compute_cluster_config) - vertex_ai_service.launch_job(job_config) - def test_run_pipeline(self): + def _test_run_pipeline(self): with tempfile.TemporaryDirectory() as tmpdir: pipeline_def = os.path.join(tmpdir, "pipeline.yaml") kfp.compiler.Compiler().compile(get_pipeline, pipeline_def) - resource_config = get_resource_config() - ps = VertexAIService( - project=resource_config.project, - location=resource_config.region, - 
service_account=resource_config.service_account_email, - staging_bucket=resource_config.temp_assets_regional_bucket_path.uri, - ) - job = ps.run_pipeline( + job = self.vertex_ai_service.run_pipeline( display_name="integration-test-pipeline", template_path=UriFactory.create_uri(pipeline_def), run_keyword_args={}, @@ -89,27 +109,20 @@ def test_run_pipeline(self): ) # Wait for the run to complete, 30 minutes is probably too long but # we don't want this test to be flaky. - ps.wait_for_run_completion( + self.vertex_ai_service.wait_for_run_completion( job.resource_name, timeout=60 * 30, polling_period_s=10 ) # Also verify that we can fetch a pipeline. - run = ps.get_pipeline_job_from_job_name(job.name) + run = self.vertex_ai_service.get_pipeline_job_from_job_name(job.name) self.assertEqual(run.resource_name, job.resource_name) self.assertEqual(run.labels["gigl-integration-test"], "true") - def test_run_pipeline_fails(self): + def _test_run_pipeline_fails(self): with tempfile.TemporaryDirectory() as tmpdir: pipeline_def = os.path.join(tmpdir, "pipeline_that_fails.yaml") kfp.compiler.Compiler().compile(get_pipeline_that_fails, pipeline_def) - resource_config = get_resource_config() - ps = VertexAIService( - project=resource_config.project, - location=resource_config.region, - service_account=resource_config.service_account_email, - staging_bucket=resource_config.temp_assets_regional_bucket_path.uri, - ) - job = ps.run_pipeline( + job = self.vertex_ai_service.run_pipeline( display_name="integration-test-pipeline-that-fails", template_path=UriFactory.create_uri(pipeline_def), run_keyword_args={}, @@ -117,7 +130,7 @@ def test_run_pipeline_fails(self): labels={"gigl-integration-test": "true"}, ) with self.assertRaises(RuntimeError): - ps.wait_for_run_completion( + self.vertex_ai_service.wait_for_run_completion( job.resource_name, timeout=60 * 30, polling_period_s=10 ) From 5a270575f50261e66787a647ba1052588c8db80a Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 30 Sep 
2025 18:38:35 +0000 Subject: [PATCH 05/33] wip --- parse.py | 61 +++++++++++++++++++ python/gigl/common/services/vertex_ai.py | 43 +++++++++---- .../src/training/v1/lib/training_process.py | 8 ++- .../common/services/vertex_ai_test.py | 44 +++++++++---- 4 files changed, 129 insertions(+), 27 deletions(-) create mode 100644 parse.py diff --git a/parse.py b/parse.py new file mode 100644 index 000000000..06942449d --- /dev/null +++ b/parse.py @@ -0,0 +1,61 @@ +import sqlglot + +import sqlparse + +query = """ +CREATE OR REPLACE TABLE `{edge_table}` AS +WITH agg AS ( + SELECT + product_id, + ARRAY_AGG(STRUCT(said)) AS users + FROM + `{engagement_table}` + GROUP BY + product_id +), +user_pairs AS ( + SELECT + from_user.said AS from_user_said, + to_user.said AS to_user_said + FROM + agg, + UNNEST(users) AS from_user WITH OFFSET AS i, + UNNEST(users) AS to_user WITH OFFSET AS j + WHERE + i < j +) +SELECT + from_user_said, + to_user_said +FROM + user_pairs +GROUP BY + from_user_said, + to_user_said +""" +res = sqlparse.parse(query) +print(f"Parsed query: {res[0].tokens}") +for token in res[0].tokens: + print(f"Token: {token.ttype} {token.value}") +exit() + +res = sqlglot.parse(query) +print(f"Parsed query: {res}") +for res in res: + select_expr = res.find(sqlglot.exp.Select) + print(f"select_expr: {select_expr}") + if select_expr: + for expr in select_expr.expressions: + print(f"expr: {expr}") + else: + print(f"No select expression found in {res}") + +res_one = sqlglot.parse_one(query).find_all(sqlglot.exp.Select) +print(f"Parsed query: {res_one}") +for res_one in res_one: + print(f"select_expr: {select_expr}") + if select_expr: + for expr in select_expr.expressions: + print(f"expr: {expr}") + else: + print(f"No select expression found in {res_one}") diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index 38293d16e..cd0644dbe 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ 
-82,12 +82,8 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name str ] = "LEADER_WORKER_INTERNAL_IP_FILE_PATH" -STORAGE_CLUSTER_MASTER_KEY: Final[ - str -] = "GIGL_STORAGE_CLUSTER_MASTER_KEY" -COMPUTE_CLUSTER_MASTER_KEY: Final[ - str -] = "GIGL_COMPUTE_CLUSTER_MASTER_KEY" +STORAGE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_STORAGE_CLUSTER_MASTER_KEY" +COMPUTE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_COMPUTE_CLUSTER_MASTER_KEY" DEFAULT_PIPELINE_TIMEOUT_S: Final[int] = 60 * 60 * 36 # 36 hours @@ -147,7 +143,7 @@ def project(self) -> str: """The GCP project that is being used for this service.""" return self._project - def launch_job(self, job_config: VertexAiJobConfig) -> None: + def launch_job(self, job_config: VertexAiJobConfig) -> aiplatform.CustomJob: """ Launch a Vertex AI CustomJob. See the docs for more info. @@ -155,6 +151,9 @@ def launch_job(self, job_config: VertexAiJobConfig) -> None: Args: job_config (VertexAiJobConfig): The configuration for the job. + + Returns: + The completed CustomJob. """ logger.info(f"Running Vertex AI job: {job_config.job_name}") @@ -240,9 +239,20 @@ def launch_job(self, job_config: VertexAiJobConfig) -> None: f"See job logs at: https://console.cloud.google.com/ai/platform/locations/{self._location}/training/{job.name}?project={self._project}" ) job.wait_for_completion() + return job + + def launch_graph_store_job( + self, storage_cluster: VertexAiJobConfig, compute_cluster: VertexAiJobConfig + ) -> aiplatform.CustomJob: + """Launch a Vertex AI Graph Store job. + + Args: + storage_cluster (VertexAiJobConfig): The configuration for the storage cluster. + compute_cluster (VertexAiJobConfig): The configuration for the compute cluster. - def launch_graph_store_job(self, storage_cluster: VertexAiJobConfig, compute_cluster: VertexAiJobConfig) -> None: - """Launch a Vertex AI Graph Store job.""" + Returns: + The completed CustomJob. 
+ """ storage_machine_spec = _get_machine_spec(storage_cluster) compute_machine_spec = _get_machine_spec(compute_cluster) storage_disk_spec = _get_disk_spec(storage_cluster) @@ -270,7 +280,7 @@ def launch_graph_store_job(self, storage_cluster: VertexAiJobConfig, compute_clu ), env_var.EnvVar( name=COMPUTE_CLUSTER_MASTER_KEY, - value=str(storage_cluster.replica_count) + value=str(storage_cluster.replica_count), ), ] @@ -294,6 +304,9 @@ def launch_graph_store_job(self, storage_cluster: VertexAiJobConfig, compute_clu replica_count=storage_cluster.replica_count - 1, ) worker_pool_specs.append(worker_spec) + else: + worker_pool_specs.append({}) + worker_pool_specs.append({}) worker_spec = WorkerPoolSpec( machine_spec=compute_machine_spec, container_spec=compute_container_spec, @@ -324,6 +337,7 @@ def launch_graph_store_job(self, storage_cluster: VertexAiJobConfig, compute_clu f"See job logs at: https://console.cloud.google.com/ai/platform/locations/{self._location}/training/{job.name}?project={self._project}" ) job.wait_for_completion() + return job def run_pipeline( self, @@ -435,16 +449,20 @@ def wait_for_run_completion( f"Please check the Vertex AI page to trace down the error." 
) + def _get_machine_spec(job_config: VertexAiJobConfig) -> MachineSpec: """Get the machine spec for a job config.""" machine_spec = MachineSpec( - machine_type=job_config.machine_type, + machine_type=job_config.machine_type, accelerator_type=job_config.accelerator_type, accelerator_count=job_config.accelerator_count, ) return machine_spec -def _get_container_spec(job_config: VertexAiJobConfig, env_vars: list[env_var.EnvVar]) -> ContainerSpec: + +def _get_container_spec( + job_config: VertexAiJobConfig, env_vars: list[env_var.EnvVar] +) -> ContainerSpec: """Get the container spec for a job config.""" container_spec = ContainerSpec( image_uri=job_config.container_uri, @@ -454,6 +472,7 @@ def _get_container_spec(job_config: VertexAiJobConfig, env_vars: list[env_var.En ) return container_spec + def _get_disk_spec(job_config: VertexAiJobConfig) -> DiskSpec: """Get the disk spec for a job config.""" disk_spec = DiskSpec( diff --git a/python/gigl/src/training/v1/lib/training_process.py b/python/gigl/src/training/v1/lib/training_process.py index cda91a5b2..48649c697 100644 --- a/python/gigl/src/training/v1/lib/training_process.py +++ b/python/gigl/src/training/v1/lib/training_process.py @@ -356,8 +356,12 @@ def __setup_training_env(self, device: torch.device): if should_distribute(): distributed_backend = get_distributed_backend(use_cuda=use_cuda) timeout = datetime.timedelta(minutes=45) - logger.info(f"Using distributed PyTorch with {distributed_backend} and timeout {timeout}") - torch.distributed.init_process_group(backend=distributed_backend, timeout=timeout) + logger.info( + f"Using distributed PyTorch with {distributed_backend} and timeout {timeout}" + ) + torch.distributed.init_process_group( + backend=distributed_backend, timeout=timeout + ) logger.info("Successfully initiated distributed backend!") @flushes_metrics(get_metrics_service_instance_fn=get_metrics_service_instance) diff --git a/python/tests/integration/common/services/vertex_ai_test.py 
b/python/tests/integration/common/services/vertex_ai_test.py index 8b42ed420..ae8b3f08c 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -4,9 +4,16 @@ import uuid import kfp +from parameterized import param, parameterized from gigl.common import UriFactory -from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService, STORAGE_CLUSTER_MASTER_KEY, COMPUTE_CLUSTER_MASTER_KEY +from gigl.common.constants import DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU +from gigl.common.services.vertex_ai import ( + COMPUTE_CLUSTER_MASTER_KEY, + STORAGE_CLUSTER_MASTER_KEY, + VertexAiJobConfig, + VertexAIService, +) from gigl.env.pipelines_config import get_resource_config @@ -74,27 +81,38 @@ def _test_launch_job(self): self.vertex_ai_service.launch_job(job_config) - def test_launch_graph_store_job(self): - command = ["python", "-c", f"import os; import logging; logging.info(f'Graph cluster master: {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']}}, compute cluster master: {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']}}')"] + @parameterized.expand( + [ + param("one server, one client", num_servers=1, num_clients=1), + param("one server, two clients", num_servers=1, num_clients=2), + param("two servers, one client", num_servers=2, num_clients=1), + param("two servers, two clients", num_servers=2, num_clients=2), + ] + ) + def test_launch_graph_store_job(self, _, num_servers, num_clients): + command = [ + "python", + "-c", + f"import os; import logging; logging.info(f'Graph cluster master: {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']}}, compute cluster master: {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']}}')", + ] job_name = f"GiGL-Integration-Test-Graph-Store-{uuid.uuid4()}" storage_cluster_config = VertexAiJobConfig( job_name=job_name, - container_uri="condaforge/miniforge3:25.3.0-1", - replica_count=2, - #machine_type="n1-standard-4", + container_uri="condaforge/miniforge3:25.3.0-1", # 
different images for storage and compute + replica_count=num_servers, + machine_type="n1-standard-4", # Different machine shapes - ideally we would test with GPU too but want to save on costs command=command, ) compute_cluster_config = VertexAiJobConfig( job_name=job_name, - container_uri="condaforge/miniforge3:25.3.0-1", - replica_count=2, + container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, # different image for storage and compute + replica_count=num_clients, command=command, - machine_type="n1-standard-32", - accelerator_type="NVIDIA_TESLA_T4", - accelerator_count=2, + machine_type="n2-standard-8", # Different machine shapes - ideally we would test with GPU too but want to save on costs + ) + self.vertex_ai_service.launch_graph_store_job( + storage_cluster_config, compute_cluster_config ) - self.vertex_ai_service.launch_graph_store_job(storage_cluster_config, compute_cluster_config) - def _test_run_pipeline(self): with tempfile.TemporaryDirectory() as tmpdir: From f8c4ab7e319e8b4660459262e58a137c68b6987d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 1 Oct 2025 01:32:41 +0000 Subject: [PATCH 06/33] works --- .../common/services/vertex_ai_test.py | 117 ++++++++++++++++-- 1 file changed, 110 insertions(+), 7 deletions(-) diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index ae8b3f08c..d939722ea 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -83,17 +83,100 @@ def _test_launch_job(self): @parameterized.expand( [ - param("one server, one client", num_servers=1, num_clients=1), - param("one server, two clients", num_servers=1, num_clients=2), - param("two servers, one client", num_servers=2, num_clients=1), - param("two servers, two clients", num_servers=2, num_clients=2), + param( + "one server, one client", + num_servers=1, + num_clients=1, + expected_worker_pool_specs=[ + { + 
"machine_type": "n1-standard-4", + "num_replicas": 1, + "image_uri": "condaforge/miniforge3:25.3.0-1", + }, + {}, + {}, + { + "machine_type": "n2-standard-8", + "num_replicas": 1, + "image_uri": DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + }, + ], + ), + param( + "one server, two clients", + num_servers=1, + num_clients=2, + expected_worker_pool_specs=[ + { + "machine_type": "n1-standard-4", + "num_replicas": 1, + "image_uri": "condaforge/miniforge3:25.3.0-1", + }, + {}, + {}, + { + "machine_type": "n2-standard-8", + "num_replicas": 2, + "image_uri": DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + }, + ], + ), + param( + "two servers, one client", + num_servers=2, + num_clients=1, + expected_worker_pool_specs=[ + { + "machine_type": "n1-standard-4", + "num_replicas": 1, + "image_uri": "condaforge/miniforge3:25.3.0-1", + }, + { + "machine_type": "n1-standard-4", + "num_replicas": 1, + "image_uri": "condaforge/miniforge3:25.3.0-1", + }, + {}, + { + "machine_type": "n2-standard-8", + "num_replicas": 1, + "image_uri": DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + }, + ], + ), + param( + "two servers, two clients", + num_servers=2, + num_clients=2, + expected_worker_pool_specs=[ + { + "machine_type": "n1-standard-4", + "num_replicas": 1, + "image_uri": "condaforge/miniforge3:25.3.0-1", + }, + { + "machine_type": "n1-standard-4", + "num_replicas": 1, + "image_uri": "condaforge/miniforge3:25.3.0-1", + }, + {}, + { + "machine_type": "n2-standard-8", + "num_replicas": 2, + "image_uri": DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + }, + ], + ), ] ) - def test_launch_graph_store_job(self, _, num_servers, num_clients): + def test_launch_graph_store_job( + self, _, num_servers, num_clients, expected_worker_pool_specs + ): + env_checks = f"logging.info(f'Graph cluster master: {{os.environ[\"{STORAGE_CLUSTER_MASTER_KEY}\"]}}, compute cluster master: {{os.environ[\"{COMPUTE_CLUSTER_MASTER_KEY}\"]}}')" command = [ "python", "-c", - f"import os; import logging; logging.info(f'Graph cluster master: 
{{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']}}, compute cluster master: {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']}}')", + f"import os; import logging; {env_checks}", ] job_name = f"GiGL-Integration-Test-Graph-Store-{uuid.uuid4()}" storage_cluster_config = VertexAiJobConfig( @@ -110,10 +193,30 @@ def test_launch_graph_store_job(self, _, num_servers, num_clients): command=command, machine_type="n2-standard-8", # Different machine shapes - ideally we would test with GPU too but want to save on costs ) - self.vertex_ai_service.launch_graph_store_job( + + job = self.vertex_ai_service.launch_graph_store_job( storage_cluster_config, compute_cluster_config ) + self.assertEqual( + len(job.job_spec.worker_pool_specs), len(expected_worker_pool_specs) + ) + for i, worker_pool_spec in enumerate(job.job_spec.worker_pool_specs): + expected_worker_pool_spec = expected_worker_pool_specs[i] + if expected_worker_pool_spec: + self.assertEqual( + worker_pool_spec.machine_spec.machine_type, + expected_worker_pool_spec["machine_type"], + ) + self.assertEqual( + worker_pool_spec.replica_count, + expected_worker_pool_spec["num_replicas"], + ) + self.assertEqual( + worker_pool_spec.container_spec.image_uri, + expected_worker_pool_spec["image_uri"], + ) + def _test_run_pipeline(self): with tempfile.TemporaryDirectory() as tmpdir: pipeline_def = os.path.join(tmpdir, "pipeline.yaml") From 6ea5b3f28b4bace9533d9aabf0d18ebb4fd7bfa1 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 1 Oct 2025 17:20:11 +0000 Subject: [PATCH 07/33] tests --- python/gigl/common/services/vertex_ai.py | 20 +++++ .../common/services/vertex_ai_test.py | 75 +++++++++++++------ 2 files changed, 72 insertions(+), 23 deletions(-) diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index cd0644dbe..01f403709 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -246,6 +246,23 @@ def launch_graph_store_job( ) -> 
aiplatform.CustomJob: """Launch a Vertex AI Graph Store job. + + This launches one Vertex AI CustomJob with two worker pools, see + https://cloud.google.com/vertex-ai/docs/training/distributed-training + for more details. + + These jobs will have the following env variables set + - GIGL_STORAGE_CLUSTER_MASTER_KEY + - GIGL_COMPUTE_CLUSTER_MASTER_KEY + Whose values are the cluster ranks of the leaders for the storage and compute clusters respectively. + For example, if there are 2 nodes in the storage cluster, and 3 nodes in the compute cluster, + Then, + - GIGL_STORAGE_CLUSTER_MASTER_KEY = 0 + - GIGL_COMPUTE_CLUSTER_MASTER_KEY = 2 # e.g. the "first" worker in the compute cluster pool is the leader. + + NOTE: + We use the job_name, timeout, and enable_web_access from the storage cluster. + These fields, if set on the compute cluster, will be ignored. + Args: + storage_cluster (VertexAiJobConfig): The configuration for the storage cluster. + compute_cluster (VertexAiJobConfig): The configuration for the compute cluster. @@ -254,7 +271,7 @@ def launch_graph_store_job( Returns: The completed CustomJob. @@ -306,6 +323,9 @@ def launch_graph_store_job( worker_pool_specs.append(worker_spec) else: worker_pool_specs.append({}) + # For whatever reason, VAI errors out (with a nondescript error) if we put the compute cluster as anything other + # than the fourth worker pool. + # So we need to pad. 
worker_pool_specs.append({}) worker_spec = WorkerPoolSpec( machine_spec=compute_machine_spec, diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index d939722ea..b1c46f79e 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -54,23 +54,25 @@ def get_pipeline_that_fails() -> float: class VertexAIPipelineIntegrationTest(unittest.TestCase): def setUp(self): - self.resource_config = get_resource_config() - self.project = self.resource_config.project - self.location = self.resource_config.region - self.service_account = self.resource_config.service_account_email - self.staging_bucket = self.resource_config.temp_assets_regional_bucket_path.uri - self.vertex_ai_service = VertexAIService( - project=self.project, - location=self.location, - service_account=self.service_account, - staging_bucket=self.staging_bucket, + self._resource_config = get_resource_config() + self._project = self._resource_config.project + self._location = self._resource_config.region + self._service_account = self._resource_config.service_account_email + self._staging_bucket = ( + self._resource_config.temp_assets_regional_bucket_path.uri + ) + self._vertex_ai_service = VertexAIService( + project=self._project, + location=self._location, + service_account=self._service_account, + staging_bucket=self._staging_bucket, ) super().setUp() def tearDown(self): super().tearDown() - def _test_launch_job(self): + def test_launch_job(self): job_name = f"GiGL-Integration-Test-{uuid.uuid4()}" container_uri = "condaforge/miniforge3:25.3.0-1" command = ["python", "-c", "import logging; logging.info('Hello, World!')"] @@ -79,7 +81,7 @@ def _test_launch_job(self): job_name=job_name, container_uri=container_uri, command=command ) - self.vertex_ai_service.launch_job(job_config) + self._vertex_ai_service.launch_job(job_config) @parameterized.expand( [ @@ 
-87,6 +89,11 @@ def _test_launch_job(self): "one server, one client", num_servers=1, num_clients=1, + env_var_checks=[ + "import os", + f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", + f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '1'", + ], expected_worker_pool_specs=[ { "machine_type": "n1-standard-4", @@ -106,6 +113,11 @@ def _test_launch_job(self): "one server, two clients", num_servers=1, num_clients=2, + env_var_checks=[ + "import os", + f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", + f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '1'", + ], expected_worker_pool_specs=[ { "machine_type": "n1-standard-4", @@ -125,6 +137,11 @@ def _test_launch_job(self): "two servers, one client", num_servers=2, num_clients=1, + env_var_checks=[ + "import os", + f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", + f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '2'", + ], expected_worker_pool_specs=[ { "machine_type": "n1-standard-4", @@ -148,6 +165,11 @@ def _test_launch_job(self): "two servers, two clients", num_servers=2, num_clients=2, + env_var_checks=[ + "import os", + f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", + f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '2'", + ], expected_worker_pool_specs=[ { "machine_type": "n1-standard-4", @@ -170,14 +192,21 @@ def _test_launch_job(self): ] ) def test_launch_graph_store_job( - self, _, num_servers, 
num_clients, expected_worker_pool_specs + self, + _, + num_servers, + num_clients, + env_var_checks, + expected_worker_pool_specs, ): - env_checks = f"logging.info(f'Graph cluster master: {{os.environ[\"{STORAGE_CLUSTER_MASTER_KEY}\"]}}, compute cluster master: {{os.environ[\"{COMPUTE_CLUSTER_MASTER_KEY}\"]}}')" + # Tests that the env variables are set correctly. + # If they are not populated, then the job will fail. + env_checks = f'logging.info(f\'Graph cluster master: {{os.environ["{STORAGE_CLUSTER_MASTER_KEY}"]}}, compute cluster master: {{os.environ["{COMPUTE_CLUSTER_MASTER_KEY}"]}}\')' command = [ "python", "-c", f"import os; import logging; {env_checks}", - ] + ] + env_var_checks job_name = f"GiGL-Integration-Test-Graph-Store-{uuid.uuid4()}" storage_cluster_config = VertexAiJobConfig( job_name=job_name, @@ -194,7 +223,7 @@ def test_launch_graph_store_job( machine_type="n2-standard-8", # Different machine shapes - ideally we would test with GPU too but want to save on costs ) - job = self.vertex_ai_service.launch_graph_store_job( + job = self._vertex_ai_service.launch_graph_store_job( storage_cluster_config, compute_cluster_config ) @@ -217,11 +246,11 @@ def test_launch_graph_store_job( expected_worker_pool_spec["image_uri"], ) - def _test_run_pipeline(self): + def test_run_pipeline(self): with tempfile.TemporaryDirectory() as tmpdir: pipeline_def = os.path.join(tmpdir, "pipeline.yaml") kfp.compiler.Compiler().compile(get_pipeline, pipeline_def) - job = self.vertex_ai_service.run_pipeline( + job = self._vertex_ai_service.run_pipeline( display_name="integration-test-pipeline", template_path=UriFactory.create_uri(pipeline_def), run_keyword_args={}, @@ -230,20 +259,20 @@ def _test_run_pipeline(self): ) # Wait for the run to complete, 30 minutes is probably too long but # we don't want this test to be flaky. 
- self.vertex_ai_service.wait_for_run_completion( + self._vertex_ai_service.wait_for_run_completion( job.resource_name, timeout=60 * 30, polling_period_s=10 ) # Also verify that we can fetch a pipeline. - run = self.vertex_ai_service.get_pipeline_job_from_job_name(job.name) + run = self._vertex_ai_service.get_pipeline_job_from_job_name(job.name) self.assertEqual(run.resource_name, job.resource_name) self.assertEqual(run.labels["gigl-integration-test"], "true") - def _test_run_pipeline_fails(self): + def test_run_pipeline_fails(self): with tempfile.TemporaryDirectory() as tmpdir: pipeline_def = os.path.join(tmpdir, "pipeline_that_fails.yaml") kfp.compiler.Compiler().compile(get_pipeline_that_fails, pipeline_def) - job = self.vertex_ai_service.run_pipeline( + job = self._vertex_ai_service.run_pipeline( display_name="integration-test-pipeline-that-fails", template_path=UriFactory.create_uri(pipeline_def), run_keyword_args={}, @@ -251,7 +280,7 @@ def _test_run_pipeline_fails(self): labels={"gigl-integration-test": "true"}, ) with self.assertRaises(RuntimeError): - self.vertex_ai_service.wait_for_run_completion( + self._vertex_ai_service.wait_for_run_completion( job.resource_name, timeout=60 * 30, polling_period_s=10 ) From 90d651cda24a8269334f114b643572d45bc76cdd Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 1 Oct 2025 17:20:28 +0000 Subject: [PATCH 08/33] remove --- parse.py | 61 -------------------------------------------------------- 1 file changed, 61 deletions(-) delete mode 100644 parse.py diff --git a/parse.py b/parse.py deleted file mode 100644 index 06942449d..000000000 --- a/parse.py +++ /dev/null @@ -1,61 +0,0 @@ -import sqlglot - -import sqlparse - -query = """ -CREATE OR REPLACE TABLE `{edge_table}` AS -WITH agg AS ( - SELECT - product_id, - ARRAY_AGG(STRUCT(said)) AS users - FROM - `{engagement_table}` - GROUP BY - product_id -), -user_pairs AS ( - SELECT - from_user.said AS from_user_said, - to_user.said AS to_user_said - FROM - agg, - 
UNNEST(users) AS from_user WITH OFFSET AS i, - UNNEST(users) AS to_user WITH OFFSET AS j - WHERE - i < j -) -SELECT - from_user_said, - to_user_said -FROM - user_pairs -GROUP BY - from_user_said, - to_user_said -""" -res = sqlparse.parse(query) -print(f"Parsed query: {res[0].tokens}") -for token in res[0].tokens: - print(f"Token: {token.ttype} {token.value}") -exit() - -res = sqlglot.parse(query) -print(f"Parsed query: {res}") -for res in res: - select_expr = res.find(sqlglot.exp.Select) - print(f"select_expr: {select_expr}") - if select_expr: - for expr in select_expr.expressions: - print(f"expr: {expr}") - else: - print(f"No select expression found in {res}") - -res_one = sqlglot.parse_one(query).find_all(sqlglot.exp.Select) -print(f"Parsed query: {res_one}") -for res_one in res_one: - print(f"select_expr: {select_expr}") - if select_expr: - for expr in select_expr.expressions: - print(f"expr: {expr}") - else: - print(f"No select expression found in {res_one}") From 94037af63e7eb81ca63efa07adf8a45b7486b5b8 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 1 Oct 2025 18:54:04 +0000 Subject: [PATCH 09/33] fix typecheck --- python/gigl/common/services/vertex_ai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index 01f403709..afe46139a 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -62,7 +62,7 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name import datetime import time from dataclasses import dataclass -from typing import Final, Optional +from typing import Final, Optional, Union from google.cloud import aiplatform from google.cloud.aiplatform_v1.types import ( @@ -304,7 +304,7 @@ def launch_graph_store_job( storage_container_spec = _get_container_spec(storage_cluster, env_vars) compute_container_spec = _get_container_spec(compute_cluster, env_vars) - worker_pool_specs: 
list[WorkerPoolSpec] = [] + worker_pool_specs: list[Union[WorkerPoolSpec, dict]] = [] leader_worker_spec = WorkerPoolSpec( machine_spec=storage_machine_spec, From 298d19d290a8409a1615bd87aa3c419fbcd838c8 Mon Sep 17 00:00:00 2001 From: Kyle Montemayor Date: Fri, 3 Oct 2025 16:30:04 +0000 Subject: [PATCH 10/33] comments --- python/gigl/common/services/vertex_ai.py | 39 ++++++++---------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index afe46139a..0734b34cc 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -157,7 +157,7 @@ def launch_job(self, job_config: VertexAiJobConfig) -> aiplatform.CustomJob: """ logger.info(f"Running Vertex AI job: {job_config.job_name}") - machine_spec = _get_machine_spec(job_config) + machine_spec = _create_machine_spec(job_config) # This file is used to store the leader worker's internal IP address. # Whenever `connect_worker_pool()` is called, the leader worker will @@ -177,9 +177,9 @@ def launch_job(self, job_config: VertexAiJobConfig) -> aiplatform.CustomJob: ) ] - container_spec = _get_container_spec(job_config, env_vars) + container_spec = _create_container_spec(job_config, env_vars) - disk_spec = _get_disk_spec(job_config) + disk_spec = _create_disk_spec(job_config) assert ( job_config.replica_count >= 1 @@ -270,27 +270,12 @@ def launch_graph_store_job( Returns: The completed CustomJob. 
""" - storage_machine_spec = _get_machine_spec(storage_cluster) - compute_machine_spec = _get_machine_spec(compute_cluster) - storage_disk_spec = _get_disk_spec(storage_cluster) - compute_disk_spec = _get_disk_spec(compute_cluster) + storage_machine_spec = _create_machine_spec(storage_cluster) + compute_machine_spec = _create_machine_spec(compute_cluster) + storage_disk_spec = _create_disk_spec(storage_cluster) + compute_disk_spec = _create_disk_spec(compute_cluster) - # This file is used to store the leader worker's internal IP address. - # Whenever `connect_worker_pool()` is called, the leader worker will - # write its internal IP address to this file. The other workers will - # read this file to get the leader worker's internal IP address. - # See connect_worker_pool() implementation for more details. - leader_worker_internal_ip_file_path = GcsUri.join( - self._staging_bucket, - storage_cluster.job_name, - datetime.datetime.now().strftime("%Y%m%d-%H%M%S"), - "leader_worker_internal_ip.txt", - ) env_vars: list[env_var.EnvVar] = [ - env_var.EnvVar( - name=LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY, - value=leader_worker_internal_ip_file_path.uri, - ), env_var.EnvVar( name=STORAGE_CLUSTER_MASTER_KEY, value="0", @@ -301,8 +286,8 @@ def launch_graph_store_job( ), ] - storage_container_spec = _get_container_spec(storage_cluster, env_vars) - compute_container_spec = _get_container_spec(compute_cluster, env_vars) + storage_container_spec = _create_container_spec(storage_cluster, env_vars) + compute_container_spec = _create_container_spec(compute_cluster, env_vars) worker_pool_specs: list[Union[WorkerPoolSpec, dict]] = [] @@ -470,7 +455,7 @@ def wait_for_run_completion( ) -def _get_machine_spec(job_config: VertexAiJobConfig) -> MachineSpec: +def _create_machine_spec(job_config: VertexAiJobConfig) -> MachineSpec: """Get the machine spec for a job config.""" machine_spec = MachineSpec( machine_type=job_config.machine_type, @@ -480,7 +465,7 @@ def 
_get_machine_spec(job_config: VertexAiJobConfig) -> MachineSpec: return machine_spec -def _get_container_spec( +def _create_container_spec( job_config: VertexAiJobConfig, env_vars: list[env_var.EnvVar] ) -> ContainerSpec: """Get the container spec for a job config.""" @@ -493,7 +478,7 @@ def _get_container_spec( return container_spec -def _get_disk_spec(job_config: VertexAiJobConfig) -> DiskSpec: +def _create_disk_spec(job_config: VertexAiJobConfig) -> DiskSpec: """Get the disk spec for a job config.""" disk_spec = DiskSpec( boot_disk_type=job_config.boot_disk_type, From 1752d918aaf812cf9a73de599cb3027efbe95f8b Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 6 Oct 2025 17:35:00 +0000 Subject: [PATCH 11/33] Add get_graph_store_info to setup graph store clusters --- python/gigl/common/services/vertex_ai.py | 7 +- python/gigl/distributed/utils/__init__.py | 6 +- python/gigl/distributed/utils/networking.py | 82 ++++++ .../gigl/src/common/constants/distributed.py | 7 + .../common/services/vertex_ai_test.py | 7 +- .../unit/distributed/utils/networking_test.py | 250 ++++++++++++++++++ 6 files changed, 351 insertions(+), 8 deletions(-) create mode 100644 python/gigl/src/common/constants/distributed.py diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index 0734b34cc..76bc1aa82 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -75,6 +75,10 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name from gigl.common import GcsUri, Uri from gigl.common.logger import Logger +from gigl.src.common.constants.distributed import ( + COMPUTE_CLUSTER_MASTER_KEY, + STORAGE_CLUSTER_MASTER_KEY, +) logger = Logger() @@ -82,9 +86,6 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name str ] = "LEADER_WORKER_INTERNAL_IP_FILE_PATH" -STORAGE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_STORAGE_CLUSTER_MASTER_KEY" -COMPUTE_CLUSTER_MASTER_KEY: 
Final[str] = "GIGL_COMPUTE_CLUSTER_MASTER_KEY" - DEFAULT_PIPELINE_TIMEOUT_S: Final[int] = 60 * 60 * 36 # 36 hours DEFAULT_CUSTOM_JOB_TIMEOUT_S: Final[int] = 60 * 60 * 24 # 24 hours diff --git a/python/gigl/distributed/utils/__init__.py b/python/gigl/distributed/utils/__init__.py index b73d00b8b..363fb1470 100644 --- a/python/gigl/distributed/utils/__init__.py +++ b/python/gigl/distributed/utils/__init__.py @@ -3,10 +3,12 @@ """ __all__ = [ + "GraphStoreInfo", "get_available_device", + "get_free_port", "get_free_ports_from_master_node", "get_free_ports_from_node", - "get_free_port", + "get_graph_store_info", "get_internal_ip_from_all_ranks", "get_internal_ip_from_master_node", "get_internal_ip_from_node", @@ -20,9 +22,11 @@ init_neighbor_loader_worker, ) from .networking import ( + GraphStoreInfo, get_free_port, get_free_ports_from_master_node, get_free_ports_from_node, + get_graph_store_info, get_internal_ip_from_all_ranks, get_internal_ip_from_master_node, get_internal_ip_from_node, diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 0ef68dbeb..825fad498 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -1,9 +1,15 @@ +import os import socket +from dataclasses import dataclass from typing import Optional import torch from gigl.common.logger import Logger +from gigl.src.common.constants.distributed import ( + COMPUTE_CLUSTER_MASTER_KEY, + STORAGE_CLUSTER_MASTER_KEY, +) logger = Logger() @@ -179,3 +185,79 @@ def get_internal_ip_from_all_ranks() -> list[str]: assert all(ip for ip in ip_list), "Could not retrieve all ranks' internal IPs" return ip_list + + +@dataclass(frozen=True) +class GraphStoreInfo: + """Information about a graph store cluster.""" + + # Number of nodes in the whole cluster + num_cluster_nodes: int + # Number of nodes in the storage cluster + num_storage_nodes: int + # Number of nodes in the compute cluster + num_compute_nodes: int + 
+ # IP address of the master node for the whole cluster + cluster_master_ip: str + # IP address of the master node for the storage cluster + storage_cluster_master_ip: str + # IP address of the master node for the compute cluster + compute_cluster_master_ip: str + + # Port of the master node for the whole cluster + cluster_master_port: int + # Port of the master node for the storage cluster + storage_cluster_master_port: int + # Port of the master node for the compute cluster + compute_cluster_master_port: int + + +def get_graph_store_info() -> GraphStoreInfo: + """ + Get the information about the graph store cluster. + + Returns: + GraphStoreInfo: The information about the graph store cluster. + + Raises: + ValueError: If a torch distributed environment is not initialized. + ValueError: If the storage cluster master key or the compute cluster master key is not set as an environment variable. + """ + if not torch.distributed.is_initialized(): + raise ValueError("Distributed environment must be initialized") + + if not STORAGE_CLUSTER_MASTER_KEY in os.environ: + raise ValueError( + f"{STORAGE_CLUSTER_MASTER_KEY} must be set as an environment variable" + ) + if not COMPUTE_CLUSTER_MASTER_KEY in os.environ: + raise ValueError( + f"{COMPUTE_CLUSTER_MASTER_KEY} must be set as an environment variable" + ) + + num_storage_nodes = int(os.environ[STORAGE_CLUSTER_MASTER_KEY]) + num_compute_nodes = int(os.environ[COMPUTE_CLUSTER_MASTER_KEY]) + + cluster_master_ip = get_internal_ip_from_master_node() + # We assume that the storage cluster nodes come first. 
+ storage_cluster_master_ip = get_internal_ip_from_node(node_rank=0) + compute_cluster_master_ip = get_internal_ip_from_node(node_rank=num_storage_nodes) + + cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] + storage_cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] + compute_cluster_master_port = get_free_ports_from_node( + num_ports=1, node_rank=num_storage_nodes + )[0] + + return GraphStoreInfo( + num_cluster_nodes=num_storage_nodes + num_compute_nodes, + num_storage_nodes=num_storage_nodes, + num_compute_nodes=num_compute_nodes, + cluster_master_ip=cluster_master_ip, + storage_cluster_master_ip=storage_cluster_master_ip, + compute_cluster_master_ip=compute_cluster_master_ip, + cluster_master_port=cluster_master_port, + storage_cluster_master_port=storage_cluster_master_port, + compute_cluster_master_port=compute_cluster_master_port, + ) diff --git a/python/gigl/src/common/constants/distributed.py b/python/gigl/src/common/constants/distributed.py new file mode 100644 index 000000000..2032dc759 --- /dev/null +++ b/python/gigl/src/common/constants/distributed.py @@ -0,0 +1,7 @@ +"""Constants for distributed workloads.""" +from typing import Final + +# The env vars where the ranks of the leader workers are stored for the storage and compute clusters +# Only applicable in multipool workloads. 
+STORAGE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_STORAGE_CLUSTER_NUM_NODES" +COMPUTE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_COMPUTE_CLUSTER_NUM_NODES" diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index b1c46f79e..67e660e5c 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -8,13 +8,12 @@ from gigl.common import UriFactory from gigl.common.constants import DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU -from gigl.common.services.vertex_ai import ( +from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService +from gigl.env.pipelines_config import get_resource_config +from gigl.src.common.constants.distributed import ( COMPUTE_CLUSTER_MASTER_KEY, STORAGE_CLUSTER_MASTER_KEY, - VertexAiJobConfig, - VertexAIService, ) -from gigl.env.pipelines_config import get_resource_config @kfp.dsl.component diff --git a/python/tests/unit/distributed/utils/networking_test.py b/python/tests/unit/distributed/utils/networking_test.py index 0a2fcf6be..4f9271e5a 100644 --- a/python/tests/unit/distributed/utils/networking_test.py +++ b/python/tests/unit/distributed/utils/networking_test.py @@ -1,5 +1,7 @@ +import os import subprocess import unittest +from typing import Optional from unittest.mock import patch import torch @@ -8,11 +10,17 @@ from parameterized import param, parameterized from gigl.distributed.utils import ( + GraphStoreInfo, get_free_ports_from_master_node, get_free_ports_from_node, + get_graph_store_info, get_internal_ip_from_master_node, get_internal_ip_from_node, ) +from gigl.src.common.constants.distributed import ( + COMPUTE_CLUSTER_MASTER_KEY, + STORAGE_CLUSTER_MASTER_KEY, +) from tests.test_assets.distributed.utils import get_process_group_init_method @@ -289,3 +297,245 @@ def test_get_internal_ip_from_master_node_fails_if_process_group_not_initialized msg="An error should be raised since 
the `dist.init_process_group` is not initialized", ): get_internal_ip_from_master_node() + + +def _test_get_graph_store_info_in_dist_context( + rank: int, + world_size: int, + init_process_group_init_method: str, + storage_nodes: int, + compute_nodes: int, +): + """Test get_graph_store_info in a real distributed context.""" + # Initialize distributed process group + dist.init_process_group( + backend="gloo", + init_method=init_process_group_init_method, + world_size=world_size, + rank=rank, + ) + try: + # Call get_graph_store_info + graph_store_info = get_graph_store_info() + + # Verify the result is a GraphStoreInfo instance + assert isinstance( + graph_store_info, GraphStoreInfo + ), "Result should be a GraphStoreInfo instance" + + # Verify cluster sizes + assert ( + graph_store_info.num_storage_nodes == storage_nodes + ), f"Expected {storage_nodes} storage nodes" + assert ( + graph_store_info.num_compute_nodes == compute_nodes + ), f"Expected {compute_nodes} compute nodes" + assert ( + graph_store_info.num_cluster_nodes == storage_nodes + compute_nodes + ), "Total nodes should be sum of storage and compute nodes" + + # Verify IP addresses are strings and not empty + assert isinstance( + graph_store_info.cluster_master_ip, str + ), "Cluster master IP should be a string" + assert ( + len(graph_store_info.cluster_master_ip) > 0 + ), "Cluster master IP should not be empty" + assert isinstance( + graph_store_info.storage_cluster_master_ip, str + ), "Storage cluster master IP should be a string" + assert ( + len(graph_store_info.storage_cluster_master_ip) > 0 + ), "Storage cluster master IP should not be empty" + assert isinstance( + graph_store_info.compute_cluster_master_ip, str + ), "Compute cluster master IP should be a string" + assert ( + len(graph_store_info.compute_cluster_master_ip) > 0 + ), "Compute cluster master IP should not be empty" + + # Verify ports are positive integers + assert isinstance( + graph_store_info.cluster_master_port, int + ), "Cluster 
master port should be an integer" + assert ( + graph_store_info.cluster_master_port > 0 + ), "Cluster master port should be positive" + assert isinstance( + graph_store_info.storage_cluster_master_port, int + ), "Storage cluster master port should be an integer" + assert ( + graph_store_info.storage_cluster_master_port > 0 + ), "Storage cluster master port should be positive" + assert isinstance( + graph_store_info.compute_cluster_master_port, int + ), "Compute cluster master port should be an integer" + assert ( + graph_store_info.compute_cluster_master_port > 0 + ), "Compute cluster master port should be positive" + + # Verify all ranks get the same result (since they should all get the same broadcasted values) + gathered_info: list[Optional[GraphStoreInfo]] = [None] * world_size + dist.all_gather_object(gathered_info, graph_store_info) + + # All ranks should have the same GraphStoreInfo + for i, info in enumerate(gathered_info): + assert info is not None + assert ( + info.num_cluster_nodes == graph_store_info.num_cluster_nodes + ), f"Rank {i} should have same cluster nodes" + assert ( + info.num_storage_nodes == graph_store_info.num_storage_nodes + ), f"Rank {i} should have same storage nodes" + assert ( + info.num_compute_nodes == graph_store_info.num_compute_nodes + ), f"Rank {i} should have same compute nodes" + assert ( + info.cluster_master_ip == graph_store_info.cluster_master_ip + ), f"Rank {i} should have same cluster master IP" + assert ( + info.storage_cluster_master_ip + == graph_store_info.storage_cluster_master_ip + ), f"Rank {i} should have same storage master IP" + assert ( + info.compute_cluster_master_ip + == graph_store_info.compute_cluster_master_ip + ), f"Rank {i} should have same compute master IP" + assert ( + info.cluster_master_port == graph_store_info.cluster_master_port + ), f"Rank {i} should have same cluster master port" + assert ( + info.storage_cluster_master_port + == graph_store_info.storage_cluster_master_port + ), f"Rank {i} 
should have same storage master port" + assert ( + info.compute_cluster_master_port + == graph_store_info.compute_cluster_master_port + ), f"Rank {i} should have same compute master port" + + finally: + dist.destroy_process_group() + + +class TestGetGraphStoreInfo(unittest.TestCase): + """Test suite for get_graph_store_info function.""" + + def tearDown(self): + """Clean up after each test.""" + if dist.is_initialized(): + dist.destroy_process_group() + + def test_get_graph_store_info_fails_when_distributed_not_initialized(self): + """Test that get_graph_store_info fails when distributed environment is not initialized.""" + with patch.dict( + os.environ, + {STORAGE_CLUSTER_MASTER_KEY: "2", COMPUTE_CLUSTER_MASTER_KEY: "3"}, + ): + with self.assertRaises(ValueError) as context: + get_graph_store_info() + + self.assertIn( + "Distributed environment must be initialized", str(context.exception) + ) + + def test_get_graph_store_info_fails_when_storage_cluster_key_missing(self): + """Test that get_graph_store_info fails when STORAGE_CLUSTER_MASTER_KEY is not set.""" + with patch.dict(os.environ, {COMPUTE_CLUSTER_MASTER_KEY: "3"}, clear=False): + init_process_group_init_method = get_process_group_init_method() + dist.init_process_group( + backend="gloo", + init_method=init_process_group_init_method, + world_size=1, + rank=0, + ) + + with self.assertRaises(ValueError) as context: + get_graph_store_info() + + self.assertIn( + f"{STORAGE_CLUSTER_MASTER_KEY} must be set as an environment variable", + str(context.exception), + ) + + def test_get_graph_store_info_fails_when_compute_cluster_key_missing(self): + """Test that get_graph_store_info fails when COMPUTE_CLUSTER_MASTER_KEY is not set.""" + with patch.dict(os.environ, {STORAGE_CLUSTER_MASTER_KEY: "2"}, clear=False): + init_process_group_init_method = get_process_group_init_method() + dist.init_process_group( + backend="gloo", + init_method=init_process_group_init_method, + world_size=1, + rank=0, + ) + + with 
self.assertRaises(ValueError) as context: + get_graph_store_info() + + self.assertIn( + f"{COMPUTE_CLUSTER_MASTER_KEY} must be set as an environment variable", + str(context.exception), + ) + + def test_get_graph_store_info_fails_when_both_cluster_keys_missing(self): + """Test that get_graph_store_info fails when both cluster keys are not set.""" + with patch.dict(os.environ, {}, clear=True): + init_process_group_init_method = get_process_group_init_method() + dist.init_process_group( + backend="gloo", + init_method=init_process_group_init_method, + world_size=1, + rank=0, + ) + + with self.assertRaises(ValueError) as context: + get_graph_store_info() + + # Should fail on the first missing key (storage cluster key) + self.assertIn( + f"{STORAGE_CLUSTER_MASTER_KEY} must be set as an environment variable", + str(context.exception), + ) + + @parameterized.expand( + [ + param( + "Test with 1 storage node and 1 compute node", + storage_nodes=1, + compute_nodes=1, + ), + param( + "Test with 2 storage nodes and 3 compute nodes", + storage_nodes=2, + compute_nodes=3, + ), + param( + "Test with 3 storage nodes and 2 compute nodes", + storage_nodes=3, + compute_nodes=2, + ), + ] + ) + def test_get_graph_store_info_success_in_distributed_context( + self, _name, storage_nodes, compute_nodes + ): + """Test successful execution of get_graph_store_info in a real distributed context.""" + init_process_group_init_method = get_process_group_init_method() + world_size = storage_nodes + compute_nodes + with patch.dict( + os.environ, + { + STORAGE_CLUSTER_MASTER_KEY: str(storage_nodes), + COMPUTE_CLUSTER_MASTER_KEY: str(compute_nodes), + }, + clear=False, + ): + mp.spawn( + fn=_test_get_graph_store_info_in_dist_context, + args=( + world_size, + init_process_group_init_method, + storage_nodes, + compute_nodes, + ), + nprocs=world_size, + ) From 64296177d7a8214cc5077dc9fddd9696adfdaaf2 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 6 Oct 2025 17:58:38 +0000 Subject: [PATCH 12/33] 
add intergration tests for get_graph_store_info --- python/gigl/common/services/vertex_ai.py | 10 +++ python/gigl/distributed/utils/networking.py | 31 ++++++-- .../gigl/src/common/constants/distributed.py | 7 +- .../common/services/vertex_ai_test.py | 10 +++ .../integration/distributed/utils/__init__.py | 0 .../distributed/utils/networking_test.py | 76 +++++++++++++++++++ 6 files changed, 126 insertions(+), 8 deletions(-) create mode 100644 python/tests/integration/distributed/utils/__init__.py create mode 100644 python/tests/integration/distributed/utils/networking_test.py diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index 76bc1aa82..834a67389 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -77,7 +77,9 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name from gigl.common.logger import Logger from gigl.src.common.constants.distributed import ( COMPUTE_CLUSTER_MASTER_KEY, + COMPUTE_CLUSTER_NUM_NODES_KEY, STORAGE_CLUSTER_MASTER_KEY, + STORAGE_CLUSTER_NUM_NODES_KEY, ) logger = Logger() @@ -285,6 +287,14 @@ def launch_graph_store_job( name=COMPUTE_CLUSTER_MASTER_KEY, value=str(storage_cluster.replica_count), ), + env_var.EnvVar( + name=STORAGE_CLUSTER_NUM_NODES_KEY, + value=str(storage_cluster.replica_count), + ), + env_var.EnvVar( + name=COMPUTE_CLUSTER_NUM_NODES_KEY, + value=str(compute_cluster.replica_count), + ), ] storage_container_spec = _create_container_spec(storage_cluster, env_vars) diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 825fad498..39a1228a9 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -8,7 +8,9 @@ from gigl.common.logger import Logger from gigl.src.common.constants.distributed import ( COMPUTE_CLUSTER_MASTER_KEY, + COMPUTE_CLUSTER_NUM_NODES_KEY, STORAGE_CLUSTER_MASTER_KEY, + 
STORAGE_CLUSTER_NUM_NODES_KEY, ) logger = Logger() @@ -235,21 +237,38 @@ def get_graph_store_info() -> GraphStoreInfo: raise ValueError( f"{COMPUTE_CLUSTER_MASTER_KEY} must be set as an environment variable" ) + if not STORAGE_CLUSTER_NUM_NODES_KEY in os.environ: + raise ValueError( + f"{STORAGE_CLUSTER_NUM_NODES_KEY} must be set as an environment variable" + ) + if not COMPUTE_CLUSTER_NUM_NODES_KEY in os.environ: + raise ValueError( + f"{COMPUTE_CLUSTER_NUM_NODES_KEY} must be set as an environment variable" + ) - num_storage_nodes = int(os.environ[STORAGE_CLUSTER_MASTER_KEY]) - num_compute_nodes = int(os.environ[COMPUTE_CLUSTER_MASTER_KEY]) + storage_cluster_master_rank = int(os.environ[STORAGE_CLUSTER_MASTER_KEY]) + compute_cluster_master_rank = int(os.environ[COMPUTE_CLUSTER_MASTER_KEY]) cluster_master_ip = get_internal_ip_from_master_node() # We assume that the storage cluster nodes come first. - storage_cluster_master_ip = get_internal_ip_from_node(node_rank=0) - compute_cluster_master_ip = get_internal_ip_from_node(node_rank=num_storage_nodes) + storage_cluster_master_ip = get_internal_ip_from_node( + node_rank=storage_cluster_master_rank + ) + compute_cluster_master_ip = get_internal_ip_from_node( + node_rank=compute_cluster_master_rank + ) cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] - storage_cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] + storage_cluster_master_port = get_free_ports_from_node( + num_ports=1, node_rank=storage_cluster_master_rank + )[0] compute_cluster_master_port = get_free_ports_from_node( - num_ports=1, node_rank=num_storage_nodes + num_ports=1, node_rank=compute_cluster_master_rank )[0] + num_storage_nodes = int(os.environ[STORAGE_CLUSTER_NUM_NODES_KEY]) + num_compute_nodes = int(os.environ[COMPUTE_CLUSTER_NUM_NODES_KEY]) + return GraphStoreInfo( num_cluster_nodes=num_storage_nodes + num_compute_nodes, num_storage_nodes=num_storage_nodes, diff --git 
a/python/gigl/src/common/constants/distributed.py b/python/gigl/src/common/constants/distributed.py index 2032dc759..0b4951564 100644 --- a/python/gigl/src/common/constants/distributed.py +++ b/python/gigl/src/common/constants/distributed.py @@ -3,5 +3,8 @@ # The env vars where the ranks of the leader workers are stored for the storage and compute clusters # Only applicable in multipool workloads. -STORAGE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_STORAGE_CLUSTER_NUM_NODES" -COMPUTE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_COMPUTE_CLUSTER_NUM_NODES" +STORAGE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_STORAGE_CLUSTER_MASTER_RANK" +COMPUTE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_COMPUTE_CLUSTER_MASTER_RANK" + +STORAGE_CLUSTER_NUM_NODES_KEY: Final[str] = "GIGL_STORAGE_CLUSTER_NUM_NODES" +COMPUTE_CLUSTER_NUM_NODES_KEY: Final[str] = "GIGL_COMPUTE_CLUSTER_NUM_NODES" diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index 67e660e5c..d5d68433f 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -12,7 +12,9 @@ from gigl.env.pipelines_config import get_resource_config from gigl.src.common.constants.distributed import ( COMPUTE_CLUSTER_MASTER_KEY, + COMPUTE_CLUSTER_NUM_NODES_KEY, STORAGE_CLUSTER_MASTER_KEY, + STORAGE_CLUSTER_NUM_NODES_KEY, ) @@ -92,6 +94,8 @@ def test_launch_job(self): "import os", f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '1'", + f"assert os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}'] == '1', Expected {{os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}']=}} to be '1'", + f"assert os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}']=}} to be 
'1'", ], expected_worker_pool_specs=[ { @@ -116,6 +120,8 @@ def test_launch_job(self): "import os", f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '1'", + f"assert os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}'] == '1', Expected {{os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}']=}} to be '1'", + f"assert os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}']=}} to be '2'", ], expected_worker_pool_specs=[ { @@ -140,6 +146,8 @@ def test_launch_job(self): "import os", f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '2'", + f"assert os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}'] == '2', Expected {{os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}']=}} to be '2'", + f"assert os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}']=}} to be '1'", ], expected_worker_pool_specs=[ { @@ -168,6 +176,8 @@ def test_launch_job(self): "import os", f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '2'", + f"assert os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}'] == '2', Expected {{os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}']=}} to be '2'", + f"assert os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}']=}} to be '2'", ], expected_worker_pool_specs=[ { diff --git a/python/tests/integration/distributed/utils/__init__.py 
b/python/tests/integration/distributed/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/tests/integration/distributed/utils/networking_test.py b/python/tests/integration/distributed/utils/networking_test.py new file mode 100644 index 000000000..0d8873be1 --- /dev/null +++ b/python/tests/integration/distributed/utils/networking_test.py @@ -0,0 +1,76 @@ +import unittest +import uuid + +from parameterized import param, parameterized + +from gigl.common.constants import DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU +from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService +from gigl.env.pipelines_config import get_resource_config + + +class NetworkingUtlsIntegrationTest(unittest.TestCase): + def setUp(self): + self._resource_config = get_resource_config() + self._project = self._resource_config.project + self._location = self._resource_config.region + self._service_account = self._resource_config.service_account_email + self._staging_bucket = ( + self._resource_config.temp_assets_regional_bucket_path.uri + ) + self._vertex_ai_service = VertexAIService( + project=self._project, + location=self._location, + service_account=self._service_account, + staging_bucket=self._staging_bucket, + ) + super().setUp() + + @parameterized.expand( + [ + param( + "Test with 1 storage node and 1 compute node", + storage_nodes=1, + compute_nodes=1, + ), + param( + "Test with 2 storage nodes and 1 compute nodes", + storage_nodes=2, + compute_nodes=1, + ), + param( + "Test with 1 storage nodes and 2 compute nodes", + storage_nodes=1, + compute_nodes=2, + ), + param( + "Test with 2 storage nodes and 2 compute nodes", + storage_nodes=2, + compute_nodes=2, + ), + ] + ) + def test_get_graph_store_info(self, _, storage_nodes, compute_nodes): + job_name = f"GiGL-Integration-Test-Graph-Store-{uuid.uuid4()}" + command = [ + "python", + "-c", + "from gigl.distributed.utils import get_graph_store_info; get_graph_store_info()", + ] + storage_cluster_config 
= VertexAiJobConfig( + job_name=job_name, + container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + replica_count=storage_nodes, + machine_type="n1-standard-4", + command=command, + ) + compute_cluster_config = VertexAiJobConfig( + job_name=job_name, + container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, # different image for storage and compute + replica_count=compute_nodes, + command=command, + machine_type="n2-standard-8", + ) + + self._vertex_ai_service.launch_graph_store_job( + storage_cluster_config, compute_cluster_config + ) From 603ca6ac518fe71f28db124665fcbdd91bc87ca8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 6 Oct 2025 18:24:37 +0000 Subject: [PATCH 13/33] [AUTOMATED] Update dep.vars, and other relevant files with new image names --- .github/cloud_builder/run_command_on_active_checkout.yaml | 2 +- dep_vars.env | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/cloud_builder/run_command_on_active_checkout.yaml b/.github/cloud_builder/run_command_on_active_checkout.yaml index eb905fb99..5e898d388 100644 --- a/.github/cloud_builder/run_command_on_active_checkout.yaml +++ b/.github/cloud_builder/run_command_on_active_checkout.yaml @@ -3,7 +3,7 @@ substitutions: options: logging: CLOUD_LOGGING_ONLY steps: - - name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1 + - name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 entrypoint: /bin/bash args: - -c diff --git a/dep_vars.env b/dep_vars.env index 0b6ca38ef..31915b440 100644 --- a/dep_vars.env +++ b/dep_vars.env @@ -1,7 +1,7 @@ # Note this file only supports static key value pairs so it can be loaded by make, bash, python, and sbt without any additional parsing. 
-DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cuda-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1 -DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1 -DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1 +DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cuda-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 +DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 +DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.9 DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.9 From d78e938dfc597fd4b92215905d643f822bf616ba Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 6 Oct 2025 21:16:00 +0000 Subject: [PATCH 14/33] bleg --- python/gigl/common/services/vertex_ai.py | 37 +------- python/gigl/common/utils/vertex_ai_context.py | 94 +++++++++++++++++++ python/gigl/distributed/utils/networking.py | 69 +++----------- python/gigl/env/distributed.py | 29 ++++++ .../gigl/src/common/constants/distributed.py | 10 -- .../common/services/vertex_ai_test.py | 41 +------- .../unit/distributed/utils/networking_test.py | 4 - 7 files changed, 138 insertions(+), 146 deletions(-) create mode 100644 python/gigl/env/distributed.py delete mode 100644 
python/gigl/src/common/constants/distributed.py diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index 834a67389..e13f281ee 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -75,12 +75,6 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name from gigl.common import GcsUri, Uri from gigl.common.logger import Logger -from gigl.src.common.constants.distributed import ( - COMPUTE_CLUSTER_MASTER_KEY, - COMPUTE_CLUSTER_NUM_NODES_KEY, - STORAGE_CLUSTER_MASTER_KEY, - STORAGE_CLUSTER_NUM_NODES_KEY, -) logger = Logger() @@ -253,14 +247,6 @@ def launch_graph_store_job( https://cloud.google.com/vertex-ai/docs/training/distributed-training for more details. - These jobs will have the follow env variables set - - GIGL_STORAGE_CLUSTER_MASTER_KEY - - GIGL_COMPUTE_CLUSTER_MASTER_KEY - Whose values are the cluster ranks of the leaders for the storage and compute clusters respectively. - For example, if if there are 2 nodes in the storage cluster, and 3 nodes in the compute cluster, - Then, - - GIGL_STORAGE_CLUSTER_MASTER_KEY = 0 - - GIGL_COMPUTE_CLUSTER_MASTER_KEY = 2 # e.g. the "first" worker in the computer cluster pool is the leader. NOTE: We use the job_name, timeout, and enable_web_access from the storage cluster. 
@@ -278,27 +264,8 @@ def launch_graph_store_job( storage_disk_spec = _create_disk_spec(storage_cluster) compute_disk_spec = _create_disk_spec(compute_cluster) - env_vars: list[env_var.EnvVar] = [ - env_var.EnvVar( - name=STORAGE_CLUSTER_MASTER_KEY, - value="0", - ), - env_var.EnvVar( - name=COMPUTE_CLUSTER_MASTER_KEY, - value=str(storage_cluster.replica_count), - ), - env_var.EnvVar( - name=STORAGE_CLUSTER_NUM_NODES_KEY, - value=str(storage_cluster.replica_count), - ), - env_var.EnvVar( - name=COMPUTE_CLUSTER_NUM_NODES_KEY, - value=str(compute_cluster.replica_count), - ), - ] - - storage_container_spec = _create_container_spec(storage_cluster, env_vars) - compute_container_spec = _create_container_spec(compute_cluster, env_vars) + storage_container_spec = _create_container_spec(storage_cluster, []) + compute_container_spec = _create_container_spec(compute_cluster, []) worker_pool_specs: list[Union[WorkerPoolSpec, dict]] = [] diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index dfdf569c1..3e36ee7a6 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -1,14 +1,20 @@ """Utility functions to be used by machines running on Vertex AI.""" +import json import os import subprocess import time +from dataclasses import dataclass +from typing import Dict, List, Optional + +from google.cloud.aiplatform_v1.types import CustomJobSpec from gigl.common import GcsUri from gigl.common.logger import Logger from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.gcs import GcsUtils from gigl.distributed import DistributedContext +from gigl.env.distributed import GraphStoreInfo logger = Logger() @@ -155,6 +161,94 @@ def connect_worker_pool() -> DistributedContext: global_world_size=global_world_size, ) +def get_num_storage_and_compute_nodes() -> tuple[int, int]: + """ + Returns the number of storage and compute 
nodes for a Vertex AI job. + + Raises: + ValueError: If not running in a Vertex AI job. + """ + if not is_currently_running_in_vertex_ai_job(): + raise ValueError("Not running in a Vertex AI job.") + + cluster_spec = _parse_cluster_spec() + if len(cluster_spec.cluster) != 4: + raise ValueError(f"Cluster specification must have 4 worker pools to fetch the number of storage and compute nodes. Found {len(cluster_spec.cluster)} worker pools.") + num_storage_nodes = len(cluster_spec.cluster["workerpool0"]) + len(cluster_spec.cluster["workerpool1"]) + num_compute_nodes = len(cluster_spec.cluster["workerpool3"]) + + return num_storage_nodes, num_compute_nodes + +@dataclass +class TaskInfo: + """Information about the current task running on this node.""" + type: str # The type of worker pool this task is running in (e.g., "workerpool0") + index: int # The zero-based index of the task + trial: Optional[str] = None # Hyperparameter tuning trial identifier (if applicable) + + +@dataclass +class ClusterSpec: + """Represents the cluster specification for a Vertex AI custom job.""" + cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists + environment: str # The environment string (e.g., "cloud") + task: TaskInfo # Information about the current task + job: Optional[CustomJobSpec] = None # The CustomJobSpec for the current job + + +def _parse_cluster_spec() -> ClusterSpec: + """ + Parse the cluster specification from the CLUSTER_SPEC environment variable. + + Returns: + ClusterSpec: Parsed cluster specification data. + + Raises: + ValueError: If not running in a Vertex AI job or CLUSTER_SPEC is not found. + json.JSONDecodeError: If CLUSTER_SPEC contains invalid JSON. 
+ """ + if not is_currently_running_in_vertex_ai_job(): + raise ValueError("Not running in a Vertex AI job.") + + cluster_spec_json = os.environ.get("CLUSTER_SPEC") + if not cluster_spec_json: + raise ValueError("CLUSTER_SPEC not found in environment variables.") + + try: + cluster_spec_data = json.loads(cluster_spec_json) + except json.JSONDecodeError as e: + raise json.JSONDecodeError(f"Failed to parse CLUSTER_SPEC JSON: {e.msg}", e.doc, e.pos) + + # Parse the task information + task_data = cluster_spec_data.get("task", {}) + task_info = TaskInfo( + type=task_data.get("type", ""), + index=task_data.get("index", 0), + trial=task_data.get("trial") + ) + + # Parse the cluster specification + cluster_data = cluster_spec_data.get("cluster", {}) + + # Parse the environment + environment = cluster_spec_data.get("environment", "cloud") + + + # Parse the job specification (optional) + job_data = cluster_spec_data.get("job") + job_spec = None + if job_data: + # Convert the dictionary to CustomJobSpec + # Note: This assumes the job_data is already in the correct format + # You may need to adjust this based on the actual structure + job_spec = CustomJobSpec(**job_data) + + return ClusterSpec( + cluster=cluster_data, + environment=environment, + task=task_info, + job=job_spec + ) def _get_leader_worker_internal_ip_file_path() -> str: """ diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 39a1228a9..c165bca8b 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -6,12 +6,8 @@ import torch from gigl.common.logger import Logger -from gigl.src.common.constants.distributed import ( - COMPUTE_CLUSTER_MASTER_KEY, - COMPUTE_CLUSTER_NUM_NODES_KEY, - STORAGE_CLUSTER_MASTER_KEY, - STORAGE_CLUSTER_NUM_NODES_KEY, -) +from gigl.env.distributed import GraphStoreInfo +from gigl.common.utils.vertex_ai_context import is_currently_running_in_vertex_ai_job, 
get_num_storage_and_compute_nodes logger = Logger() @@ -189,31 +185,6 @@ def get_internal_ip_from_all_ranks() -> list[str]: return ip_list -@dataclass(frozen=True) -class GraphStoreInfo: - """Information about a graph store cluster.""" - - # Number of nodes in the whole cluster - num_cluster_nodes: int - # Number of nodes in the storage cluster - num_storage_nodes: int - # Number of nodes in the compute cluster - num_compute_nodes: int - - # IP address of the master node for the whole cluster - cluster_master_ip: str - # IP address of the master node for the storage cluster - storage_cluster_master_ip: str - # IP address of the master node for the compute cluster - compute_cluster_master_ip: str - - # Port of the master node for the whole cluster - cluster_master_port: int - # Port of the master node for the storage cluster - storage_cluster_master_port: int - # Port of the master node for the compute cluster - compute_cluster_master_port: int - def get_graph_store_info() -> GraphStoreInfo: """ @@ -224,50 +195,32 @@ def get_graph_store_info() -> GraphStoreInfo: Raises: ValueError: If a torch distributed environment is not initialized. - ValueError: If the storage cluster master key or the compute cluster master key is not set as an environment variable. + ValueError: If not running running in a supported environment. 
""" if not torch.distributed.is_initialized(): raise ValueError("Distributed environment must be initialized") - - if not STORAGE_CLUSTER_MASTER_KEY in os.environ: - raise ValueError( - f"{STORAGE_CLUSTER_MASTER_KEY} must be set as an environment variable" - ) - if not COMPUTE_CLUSTER_MASTER_KEY in os.environ: - raise ValueError( - f"{COMPUTE_CLUSTER_MASTER_KEY} must be set as an environment variable" - ) - if not STORAGE_CLUSTER_NUM_NODES_KEY in os.environ: - raise ValueError( - f"{STORAGE_CLUSTER_NUM_NODES_KEY} must be set as an environment variable" - ) - if not COMPUTE_CLUSTER_NUM_NODES_KEY in os.environ: - raise ValueError( - f"{COMPUTE_CLUSTER_NUM_NODES_KEY} must be set as an environment variable" - ) - - storage_cluster_master_rank = int(os.environ[STORAGE_CLUSTER_MASTER_KEY]) - compute_cluster_master_rank = int(os.environ[COMPUTE_CLUSTER_MASTER_KEY]) + if is_currently_running_in_vertex_ai_job(): + num_storage_nodes, num_compute_nodes = get_num_storage_and_compute_nodes() + else: + raise ValueError("Must be running on a vertex AI job to get graph store cluster info!") cluster_master_ip = get_internal_ip_from_master_node() # We assume that the storage cluster nodes come first. 
storage_cluster_master_ip = get_internal_ip_from_node( - node_rank=storage_cluster_master_rank + node_rank=0 ) compute_cluster_master_ip = get_internal_ip_from_node( - node_rank=compute_cluster_master_rank + node_rank=num_storage_nodes ) cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] storage_cluster_master_port = get_free_ports_from_node( - num_ports=1, node_rank=storage_cluster_master_rank + num_ports=1, node_rank=0 )[0] compute_cluster_master_port = get_free_ports_from_node( - num_ports=1, node_rank=compute_cluster_master_rank + num_ports=1, node_rank=num_storage_nodes )[0] - num_storage_nodes = int(os.environ[STORAGE_CLUSTER_NUM_NODES_KEY]) - num_compute_nodes = int(os.environ[COMPUTE_CLUSTER_NUM_NODES_KEY]) return GraphStoreInfo( num_cluster_nodes=num_storage_nodes + num_compute_nodes, diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py new file mode 100644 index 000000000..722fe7778 --- /dev/null +++ b/python/gigl/env/distributed.py @@ -0,0 +1,29 @@ +"""Information about the distributed setup.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class GraphStoreInfo: + """Information about a graph store cluster.""" + + # Number of nodes in the whole cluster + num_cluster_nodes: int + # Number of nodes in the storage cluster + num_storage_nodes: int + # Number of nodes in the compute cluster + num_compute_nodes: int + + # IP address of the master node for the whole cluster + cluster_master_ip: str + # IP address of the master node for the storage cluster + storage_cluster_master_ip: str + # IP address of the master node for the compute cluster + compute_cluster_master_ip: str + + # Port of the master node for the whole cluster + cluster_master_port: int + # Port of the master node for the storage cluster + storage_cluster_master_port: int + # Port of the master node for the compute cluster + compute_cluster_master_port: int diff --git a/python/gigl/src/common/constants/distributed.py 
b/python/gigl/src/common/constants/distributed.py deleted file mode 100644 index 0b4951564..000000000 --- a/python/gigl/src/common/constants/distributed.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Constants for distributed workloads.""" -from typing import Final - -# The env vars where the ranks of the leader workers are stored for the storage and compute clusters -# Only applicable in multipool workloads. -STORAGE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_STORAGE_CLUSTER_MASTER_RANK" -COMPUTE_CLUSTER_MASTER_KEY: Final[str] = "GIGL_COMPUTE_CLUSTER_MASTER_RANK" - -STORAGE_CLUSTER_NUM_NODES_KEY: Final[str] = "GIGL_STORAGE_CLUSTER_NUM_NODES" -COMPUTE_CLUSTER_NUM_NODES_KEY: Final[str] = "GIGL_COMPUTE_CLUSTER_NUM_NODES" diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index d5d68433f..f7c07df26 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -10,13 +10,6 @@ from gigl.common.constants import DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService from gigl.env.pipelines_config import get_resource_config -from gigl.src.common.constants.distributed import ( - COMPUTE_CLUSTER_MASTER_KEY, - COMPUTE_CLUSTER_NUM_NODES_KEY, - STORAGE_CLUSTER_MASTER_KEY, - STORAGE_CLUSTER_NUM_NODES_KEY, -) - @kfp.dsl.component def source() -> int: @@ -90,13 +83,6 @@ def test_launch_job(self): "one server, one client", num_servers=1, num_clients=1, - env_var_checks=[ - "import os", - f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", - f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '1'", - f"assert os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}'] == '1', Expected {{os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}']=}} to be '1'", - 
f"assert os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}']=}} to be '1'", - ], expected_worker_pool_specs=[ { "machine_type": "n1-standard-4", @@ -116,13 +102,6 @@ def test_launch_job(self): "one server, two clients", num_servers=1, num_clients=2, - env_var_checks=[ - "import os", - f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", - f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '1'", - f"assert os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}'] == '1', Expected {{os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}']=}} to be '1'", - f"assert os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}']=}} to be '2'", - ], expected_worker_pool_specs=[ { "machine_type": "n1-standard-4", @@ -142,13 +121,6 @@ def test_launch_job(self): "two servers, one client", num_servers=2, num_clients=1, - env_var_checks=[ - "import os", - f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", - f"assert os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '2'", - f"assert os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}'] == '2', Expected {{os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}']=}} to be '2'", - f"assert os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}'] == '1', Expected {{os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}']=}} to be '1'", - ], expected_worker_pool_specs=[ { "machine_type": "n1-standard-4", @@ -172,13 +144,6 @@ def test_launch_job(self): "two servers, two clients", num_servers=2, num_clients=2, - env_var_checks=[ - "import os", - f"assert os.environ['{STORAGE_CLUSTER_MASTER_KEY}'] == '0', Expected {{os.environ['{STORAGE_CLUSTER_MASTER_KEY}']=}} to be '0'", - f"assert 
os.environ['{COMPUTE_CLUSTER_MASTER_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_MASTER_KEY}']=}} to be '2'", - f"assert os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}'] == '2', Expected {{os.environ['{STORAGE_CLUSTER_NUM_NODES_KEY}']=}} to be '2'", - f"assert os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}'] == '2', Expected {{os.environ['{COMPUTE_CLUSTER_NUM_NODES_KEY}']=}} to be '2'", - ], expected_worker_pool_specs=[ { "machine_type": "n1-standard-4", @@ -205,17 +170,15 @@ def test_launch_graph_store_job( _, num_servers, num_clients, - env_var_checks, expected_worker_pool_specs, ): # Tests that the env variables are set correctly. # If they are not populated, then the job will fail. - env_checks = f'logging.info(f\'Graph cluster master: {{os.environ["{STORAGE_CLUSTER_MASTER_KEY}"]}}, compute cluster master: {{os.environ["{COMPUTE_CLUSTER_MASTER_KEY}"]}}\')' command = [ "python", "-c", - f"import os; import logging; {env_checks}", - ] + env_var_checks + f"import os; import logging; logging.info('Hello, World!')", + ] job_name = f"GiGL-Integration-Test-Graph-Store-{uuid.uuid4()}" storage_cluster_config = VertexAiJobConfig( job_name=job_name, diff --git a/python/tests/unit/distributed/utils/networking_test.py b/python/tests/unit/distributed/utils/networking_test.py index 4f9271e5a..10d253055 100644 --- a/python/tests/unit/distributed/utils/networking_test.py +++ b/python/tests/unit/distributed/utils/networking_test.py @@ -17,10 +17,6 @@ get_internal_ip_from_master_node, get_internal_ip_from_node, ) -from gigl.src.common.constants.distributed import ( - COMPUTE_CLUSTER_MASTER_KEY, - STORAGE_CLUSTER_MASTER_KEY, -) from tests.test_assets.distributed.utils import get_process_group_init_method From fc9d0d058d6127d4c3d33606cb7d0211496ca2b4 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 6 Oct 2025 22:53:52 +0000 Subject: [PATCH 15/33] wip --- python/gigl/common/utils/vertex_ai_context.py | 2 + .../unit/distributed/utils/networking_test.py | 100 
++++++------------ 2 files changed, 35 insertions(+), 67 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index 3e36ee7a6..bec1e043b 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -199,6 +199,8 @@ class ClusterSpec: def _parse_cluster_spec() -> ClusterSpec: """ Parse the cluster specification from the CLUSTER_SPEC environment variable. + Based on the spec given at: + https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables Returns: ClusterSpec: Parsed cluster specification data. diff --git a/python/tests/unit/distributed/utils/networking_test.py b/python/tests/unit/distributed/utils/networking_test.py index 10d253055..cdc2752d4 100644 --- a/python/tests/unit/distributed/utils/networking_test.py +++ b/python/tests/unit/distributed/utils/networking_test.py @@ -2,6 +2,7 @@ import subprocess import unittest from typing import Optional +import json from unittest.mock import patch import torch @@ -412,6 +413,18 @@ def _test_get_graph_store_info_in_dist_context( finally: dist.destroy_process_group() +def _get_cluster_spec_for_test(worker_pool_sizes: list[int]): + cluster_spec: dict = { + "environment": "cloud", + "task": { + "type": "workerpool0", + "index": 0, + }, + "cluster": {}, + } + for i,worker_pool_size in enumerate(worker_pool_sizes): + cluster_spec["cluster"][f"workerpool{worker_pool_size}"] = [f"workerpool{i}-{j}:2222" for j in range(worker_pool_size)] + return cluster_spec class TestGetGraphStoreInfo(unittest.TestCase): """Test suite for get_graph_store_info function.""" @@ -423,74 +436,28 @@ def tearDown(self): def test_get_graph_store_info_fails_when_distributed_not_initialized(self): """Test that get_graph_store_info fails when distributed environment is not initialized.""" - with patch.dict( - os.environ, - {STORAGE_CLUSTER_MASTER_KEY: "2", COMPUTE_CLUSTER_MASTER_KEY: "3"}, - ): - with 
self.assertRaises(ValueError) as context: - get_graph_store_info() - - self.assertIn( - "Distributed environment must be initialized", str(context.exception) - ) - - def test_get_graph_store_info_fails_when_storage_cluster_key_missing(self): - """Test that get_graph_store_info fails when STORAGE_CLUSTER_MASTER_KEY is not set.""" - with patch.dict(os.environ, {COMPUTE_CLUSTER_MASTER_KEY: "3"}, clear=False): - init_process_group_init_method = get_process_group_init_method() - dist.init_process_group( - backend="gloo", - init_method=init_process_group_init_method, - world_size=1, - rank=0, - ) - - with self.assertRaises(ValueError) as context: - get_graph_store_info() - - self.assertIn( - f"{STORAGE_CLUSTER_MASTER_KEY} must be set as an environment variable", - str(context.exception), - ) - - def test_get_graph_store_info_fails_when_compute_cluster_key_missing(self): - """Test that get_graph_store_info fails when COMPUTE_CLUSTER_MASTER_KEY is not set.""" - with patch.dict(os.environ, {STORAGE_CLUSTER_MASTER_KEY: "2"}, clear=False): - init_process_group_init_method = get_process_group_init_method() - dist.init_process_group( - backend="gloo", - init_method=init_process_group_init_method, - world_size=1, - rank=0, - ) - - with self.assertRaises(ValueError) as context: - get_graph_store_info() + with self.assertRaises(ValueError) as context: + get_graph_store_info() - self.assertIn( - f"{COMPUTE_CLUSTER_MASTER_KEY} must be set as an environment variable", - str(context.exception), - ) - - def test_get_graph_store_info_fails_when_both_cluster_keys_missing(self): - """Test that get_graph_store_info fails when both cluster keys are not set.""" - with patch.dict(os.environ, {}, clear=True): - init_process_group_init_method = get_process_group_init_method() - dist.init_process_group( - backend="gloo", - init_method=init_process_group_init_method, - world_size=1, - rank=0, - ) + self.assertIn( + "Distributed environment must be initialized", str(context.exception) + ) - with 
self.assertRaises(ValueError) as context: - get_graph_store_info() + def test_get_graph_store_info_fails_when_not_running_in_vertex_ai_job(self): + """Test that get_graph_store_info fails when not running in a Vertex AI job.""" + init_process_group_init_method = get_process_group_init_method() + torch.distributed.init_process_group( + backend="gloo", + init_method=init_process_group_init_method, + world_size=1, + rank=0, + ) + with self.assertRaises(ValueError) as context: + get_graph_store_info() - # Should fail on the first missing key (storage cluster key) - self.assertIn( - f"{STORAGE_CLUSTER_MASTER_KEY} must be set as an environment variable", - str(context.exception), - ) + self.assertIn( + "Must be running on a vertex AI job to get graph store cluster info!", str(context.exception) + ) @parameterized.expand( [ @@ -520,8 +487,7 @@ def test_get_graph_store_info_success_in_distributed_context( with patch.dict( os.environ, { - STORAGE_CLUSTER_MASTER_KEY: str(storage_nodes), - COMPUTE_CLUSTER_MASTER_KEY: str(compute_nodes), + "CLUSTER_SPEC": json.dumps(_get_cluster_spec_for_test([min(1, storage_nodes - 1), min(0, storage_nodes - 1), 0, compute_nodes])), }, clear=False, ): From fb91f1a8878696da9c630de5ec6e6a29571ffaf3 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 02:45:41 +0000 Subject: [PATCH 16/33] bleh --- python/gigl/common/utils/vertex_ai_context.py | 34 ++++++++++------ python/gigl/distributed/utils/networking.py | 25 +++++------- .../common/services/vertex_ai_test.py | 1 + .../distributed/utils/networking_test.py | 17 +++++++- .../unit/distributed/utils/networking_test.py | 40 ++++++++++++------- 5 files changed, 74 insertions(+), 43 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index bec1e043b..a5ff2558e 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -5,7 +5,7 @@ import subprocess import time from 
dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Optional from google.cloud.aiplatform_v1.types import CustomJobSpec @@ -13,8 +13,7 @@ from gigl.common.logger import Logger from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.gcs import GcsUtils -from gigl.distributed import DistributedContext -from gigl.env.distributed import GraphStoreInfo +from gigl.distributed.dist_context import DistributedContext logger = Logger() @@ -161,6 +160,7 @@ def connect_worker_pool() -> DistributedContext: global_world_size=global_world_size, ) + def get_num_storage_and_compute_nodes() -> tuple[int, int]: """ Returns the number of storage and compute nodes for a Vertex AI job. @@ -173,23 +173,32 @@ def get_num_storage_and_compute_nodes() -> tuple[int, int]: cluster_spec = _parse_cluster_spec() if len(cluster_spec.cluster) != 4: - raise ValueError(f"Cluster specification must have 4 worker pools to fetch the number of storage and compute nodes. Found {len(cluster_spec.cluster)} worker pools.") - num_storage_nodes = len(cluster_spec.cluster["workerpool0"]) + len(cluster_spec.cluster["workerpool1"]) + raise ValueError( + f"Cluster specification must have 4 worker pools to fetch the number of storage and compute nodes. Found {len(cluster_spec.cluster)} worker pools." 
+ ) + num_storage_nodes = len(cluster_spec.cluster["workerpool0"]) + len( + cluster_spec.cluster["workerpool1"] + ) num_compute_nodes = len(cluster_spec.cluster["workerpool3"]) return num_storage_nodes, num_compute_nodes + @dataclass class TaskInfo: """Information about the current task running on this node.""" + type: str # The type of worker pool this task is running in (e.g., "workerpool0") index: int # The zero-based index of the task - trial: Optional[str] = None # Hyperparameter tuning trial identifier (if applicable) + trial: Optional[ + str + ] = None # Hyperparameter tuning trial identifier (if applicable) @dataclass class ClusterSpec: """Represents the cluster specification for a Vertex AI custom job.""" + cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists environment: str # The environment string (e.g., "cloud") task: TaskInfo # Information about the current task @@ -219,14 +228,16 @@ def _parse_cluster_spec() -> ClusterSpec: try: cluster_spec_data = json.loads(cluster_spec_json) except json.JSONDecodeError as e: - raise json.JSONDecodeError(f"Failed to parse CLUSTER_SPEC JSON: {e.msg}", e.doc, e.pos) + raise json.JSONDecodeError( + f"Failed to parse CLUSTER_SPEC JSON: {e.msg}", e.doc, e.pos + ) # Parse the task information task_data = cluster_spec_data.get("task", {}) task_info = TaskInfo( type=task_data.get("type", ""), index=task_data.get("index", 0), - trial=task_data.get("trial") + trial=task_data.get("trial"), ) # Parse the cluster specification @@ -235,7 +246,6 @@ def _parse_cluster_spec() -> ClusterSpec: # Parse the environment environment = cluster_spec_data.get("environment", "cloud") - # Parse the job specification (optional) job_data = cluster_spec_data.get("job") job_spec = None @@ -246,12 +256,10 @@ def _parse_cluster_spec() -> ClusterSpec: job_spec = CustomJobSpec(**job_data) return ClusterSpec( - cluster=cluster_data, - environment=environment, - task=task_info, - job=job_spec + cluster=cluster_data, 
environment=environment, task=task_info, job=job_spec ) + def _get_leader_worker_internal_ip_file_path() -> str: """ Get the file path to the leader worker's internal IP address. diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index c165bca8b..9241ee400 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -1,13 +1,14 @@ -import os import socket -from dataclasses import dataclass from typing import Optional import torch from gigl.common.logger import Logger +from gigl.common.utils.vertex_ai_context import ( + get_num_storage_and_compute_nodes, + is_currently_running_in_vertex_ai_job, +) from gigl.env.distributed import GraphStoreInfo -from gigl.common.utils.vertex_ai_context import is_currently_running_in_vertex_ai_job, get_num_storage_and_compute_nodes logger = Logger() @@ -185,7 +186,6 @@ def get_internal_ip_from_all_ranks() -> list[str]: return ip_list - def get_graph_store_info() -> GraphStoreInfo: """ Get the information about the graph store cluster. @@ -202,26 +202,21 @@ def get_graph_store_info() -> GraphStoreInfo: if is_currently_running_in_vertex_ai_job(): num_storage_nodes, num_compute_nodes = get_num_storage_and_compute_nodes() else: - raise ValueError("Must be running on a vertex AI job to get graph store cluster info!") + raise ValueError( + "Must be running on a vertex AI job to get graph store cluster info!" + ) cluster_master_ip = get_internal_ip_from_master_node() # We assume that the storage cluster nodes come first. 
- storage_cluster_master_ip = get_internal_ip_from_node( - node_rank=0 - ) - compute_cluster_master_ip = get_internal_ip_from_node( - node_rank=num_storage_nodes - ) + storage_cluster_master_ip = get_internal_ip_from_node(node_rank=0) + compute_cluster_master_ip = get_internal_ip_from_node(node_rank=num_storage_nodes) cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] - storage_cluster_master_port = get_free_ports_from_node( - num_ports=1, node_rank=0 - )[0] + storage_cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] compute_cluster_master_port = get_free_ports_from_node( num_ports=1, node_rank=num_storage_nodes )[0] - return GraphStoreInfo( num_cluster_nodes=num_storage_nodes + num_compute_nodes, num_storage_nodes=num_storage_nodes, diff --git a/python/tests/integration/common/services/vertex_ai_test.py b/python/tests/integration/common/services/vertex_ai_test.py index f7c07df26..371bac424 100644 --- a/python/tests/integration/common/services/vertex_ai_test.py +++ b/python/tests/integration/common/services/vertex_ai_test.py @@ -11,6 +11,7 @@ from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService from gigl.env.pipelines_config import get_resource_config + @kfp.dsl.component def source() -> int: return 42 diff --git a/python/tests/integration/distributed/utils/networking_test.py b/python/tests/integration/distributed/utils/networking_test.py index 0d8873be1..f4dbd04c2 100644 --- a/python/tests/integration/distributed/utils/networking_test.py +++ b/python/tests/integration/distributed/utils/networking_test.py @@ -1,5 +1,6 @@ import unittest import uuid +from textwrap import dedent from parameterized import param, parameterized @@ -54,7 +55,21 @@ def test_get_graph_store_info(self, _, storage_nodes, compute_nodes): command = [ "python", "-c", - "from gigl.distributed.utils import get_graph_store_info; get_graph_store_info()", + dedent( + f""" + from gigl.distributed.utils import 
get_graph_store_info + info = get_graph_store_info() + assert info.num_storage_nodes == {storage_nodes} + assert info.num_compute_nodes == {compute_nodes} + assert info.num_cluster_nodes == {storage_nodes + compute_nodes} + assert info.cluster_master_ip is not None + assert info.storage_cluster_master_ip is not None + assert info.compute_cluster_master_ip is not None + assert info.cluster_master_port is not None + assert info.storage_cluster_master_port is not None + assert info.compute_cluster_master_port is not None + """ + ), ] storage_cluster_config = VertexAiJobConfig( job_name=job_name, diff --git a/python/tests/unit/distributed/utils/networking_test.py b/python/tests/unit/distributed/utils/networking_test.py index cdc2752d4..825a587e9 100644 --- a/python/tests/unit/distributed/utils/networking_test.py +++ b/python/tests/unit/distributed/utils/networking_test.py @@ -1,8 +1,8 @@ +import json import os import subprocess import unittest from typing import Optional -import json from unittest.mock import patch import torch @@ -195,7 +195,7 @@ def tearDown(self): ), ] ) - def test_get_free_ports_from_master_node_two_ranks( + def _test_get_free_ports_from_master_node_two_ranks( self, _name, num_ports, world_size ): init_process_group_init_method = get_process_group_init_method() @@ -223,7 +223,7 @@ def test_get_free_ports_from_master_node_two_ranks( ), ] ) - def test_get_free_ports_from_master_node_two_ranks_custom_master_node_rank( + def _test_get_free_ports_from_master_node_two_ranks_custom_master_node_rank( self, _name, num_ports, world_size, master_node_rank, ports ): init_process_group_init_method = get_process_group_init_method() @@ -239,14 +239,14 @@ def test_get_free_ports_from_master_node_two_ranks_custom_master_node_rank( nprocs=world_size, ) - def test_get_free_ports_from_master_fails_if_process_group_not_initialized(self): + def _test_get_free_ports_from_master_fails_if_process_group_not_initialized(self): with self.assertRaises( AssertionError, msg="An 
error should be raised since the `dist.init_process_group` is not initialized", ): get_free_ports_from_master_node(num_ports=1) - def test_get_internal_ip_from_master_node(self): + def _test_get_internal_ip_from_master_node(self): init_process_group_init_method = get_process_group_init_method() expected_host_ip = subprocess.check_output(["hostname", "-i"]).decode().strip() world_size = 2 @@ -270,7 +270,7 @@ def test_get_internal_ip_from_master_node(self): ), ] ) - def test_get_internal_ip_from_master_node_with_master_node_rank( + def _test_get_internal_ip_from_master_node_with_master_node_rank( self, _, world_size, master_node_rank ): init_process_group_init_method = get_process_group_init_method() @@ -286,7 +286,7 @@ def test_get_internal_ip_from_master_node_with_master_node_rank( nprocs=world_size, ) - def test_get_internal_ip_from_master_node_fails_if_process_group_not_initialized( + def _test_get_internal_ip_from_master_node_fails_if_process_group_not_initialized( self, ): with self.assertRaises( @@ -413,7 +413,8 @@ def _test_get_graph_store_info_in_dist_context( finally: dist.destroy_process_group() -def _get_cluster_spec_for_test(worker_pool_sizes: list[int]): + +def _get_cluster_spec_for_test(worker_pool_sizes: list[int]) -> dict: cluster_spec: dict = { "environment": "cloud", "task": { @@ -422,10 +423,13 @@ def _get_cluster_spec_for_test(worker_pool_sizes: list[int]): }, "cluster": {}, } - for i,worker_pool_size in enumerate(worker_pool_sizes): - cluster_spec["cluster"][f"workerpool{worker_pool_size}"] = [f"workerpool{i}-{j}:2222" for j in range(worker_pool_size)] + for i, worker_pool_size in enumerate(worker_pool_sizes): + cluster_spec["cluster"][f"workerpool{i}"] = [ + f"workerpool{i}-{j}:2222" for j in range(worker_pool_size) + ] return cluster_spec + class TestGetGraphStoreInfo(unittest.TestCase): """Test suite for get_graph_store_info function.""" @@ -434,7 +438,7 @@ def tearDown(self): if dist.is_initialized(): dist.destroy_process_group() - def 
test_get_graph_store_info_fails_when_distributed_not_initialized(self): + def _test_get_graph_store_info_fails_when_distributed_not_initialized(self): """Test that get_graph_store_info fails when distributed environment is not initialized.""" with self.assertRaises(ValueError) as context: get_graph_store_info() @@ -443,7 +447,7 @@ def test_get_graph_store_info_fails_when_distributed_not_initialized(self): "Distributed environment must be initialized", str(context.exception) ) - def test_get_graph_store_info_fails_when_not_running_in_vertex_ai_job(self): + def _test_get_graph_store_info_fails_when_not_running_in_vertex_ai_job(self): """Test that get_graph_store_info fails when not running in a Vertex AI job.""" init_process_group_init_method = get_process_group_init_method() torch.distributed.init_process_group( @@ -456,7 +460,8 @@ def test_get_graph_store_info_fails_when_not_running_in_vertex_ai_job(self): get_graph_store_info() self.assertIn( - "Must be running on a vertex AI job to get graph store cluster info!", str(context.exception) + "Must be running on a vertex AI job to get graph store cluster info!", + str(context.exception), ) @parameterized.expand( @@ -484,10 +489,17 @@ def test_get_graph_store_info_success_in_distributed_context( """Test successful execution of get_graph_store_info in a real distributed context.""" init_process_group_init_method = get_process_group_init_method() world_size = storage_nodes + compute_nodes + if storage_nodes == 1: + worker_pool_sizes = [1, 0, 0, compute_nodes] + else: + worker_pool_sizes = [1, storage_nodes - 1, 0, compute_nodes] with patch.dict( os.environ, { - "CLUSTER_SPEC": json.dumps(_get_cluster_spec_for_test([min(1, storage_nodes - 1), min(0, storage_nodes - 1), 0, compute_nodes])), + "CLUSTER_SPEC": json.dumps( + _get_cluster_spec_for_test(worker_pool_sizes) + ), + "CLOUD_ML_JOB_ID": "test_job_id", }, clear=False, ): From c2b5607e9a65172ab3e07adc15f9a59efdcaed80 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: 
Tue, 7 Oct 2025 03:42:32 +0000 Subject: [PATCH 17/33] fix --- python/gigl/common/utils/vertex_ai_context.py | 2 +- python/gigl/distributed/dist_context.py | 20 +-- python/gigl/env/distributed.py | 18 ++ .../common/utils/vertex_ai_context_test.py | 159 +++++++++++++++++- 4 files changed, 176 insertions(+), 23 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index a5ff2558e..d155e29ac 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -13,7 +13,7 @@ from gigl.common.logger import Logger from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.gcs import GcsUtils -from gigl.distributed.dist_context import DistributedContext +from gigl.env.distributed import DistributedContext logger = Logger() diff --git a/python/gigl/distributed/dist_context.py b/python/gigl/distributed/dist_context.py index da513ab87..d9efa883d 100644 --- a/python/gigl/distributed/dist_context.py +++ b/python/gigl/distributed/dist_context.py @@ -1,19 +1,3 @@ -from dataclasses import dataclass +from gigl.env.distributed import DistributedContext - -@dataclass(frozen=True) -class DistributedContext: - """ - GiGL Distributed Context - """ - - # TODO (mkolodner-sc): Investigate adding local rank and local world size - - # Main Worker's IP Address for RPC communication - main_worker_ip_address: str - - # Rank of machine - global_rank: int - - # Total number of machines - global_world_size: int +__all__ = ["DistributedContext"] diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py index 722fe7778..2ed94cba0 100644 --- a/python/gigl/env/distributed.py +++ b/python/gigl/env/distributed.py @@ -27,3 +27,21 @@ class GraphStoreInfo: storage_cluster_master_port: int # Port of the master node for the compute cluster compute_cluster_master_port: int + + +@dataclass(frozen=True) +class DistributedContext: + 
""" + GiGL Distributed Context + """ + + # TODO (mkolodner-sc): Investigate adding local rank and local world size + + # Main Worker's IP Address for RPC communication + main_worker_ip_address: str + + # Rank of machine + global_rank: int + + # Total number of machines + global_world_size: int diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index d3a5132b2..7c4649e14 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -1,3 +1,4 @@ +import json import os import unittest from unittest.mock import call, patch @@ -5,17 +6,17 @@ from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( - DistributedContext, + _parse_cluster_spec, connect_worker_pool, get_host_name, get_leader_hostname, get_leader_port, + get_num_storage_and_compute_nodes, get_rank, get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, ) -from gigl.distributed import DistributedContext class TestVertexAIContext(unittest.TestCase): @@ -76,7 +77,7 @@ def test_throws_if_not_on_vai(self): }, ) def test_connect_worker_pool_leader(self, mock_upload, mock_sleep, mock_subprocess): - distributed_context: DistributedContext = connect_worker_pool() + distributed_context = connect_worker_pool() self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1") self.assertEqual(distributed_context.global_rank, 0) self.assertEqual(distributed_context.global_world_size, 2) @@ -102,7 +103,7 @@ def test_connect_worker_pool_worker( self, mock_upload, mock_read, mock_sleep, mock_subprocess, mock_ping_host ): mock_ping_host.side_effect = [False, True] - distributed_context: DistributedContext = connect_worker_pool() + distributed_context = connect_worker_pool() self.assertEqual(distributed_context.main_worker_ip_address, 
"127.0.0.1") self.assertEqual(distributed_context.global_rank, 1) self.assertEqual(distributed_context.global_world_size, 2) @@ -113,6 +114,156 @@ def test_connect_worker_pool_worker( ] ) + def test_get_num_storage_and_compute_nodes_success(self): + """Test successful retrieval of storage and compute node counts.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0", "replica-1"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0", "replica-1", "replica-2"], + }, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + num_storage, num_compute = get_num_storage_and_compute_nodes() + self.assertEqual(num_storage, 3) # workerpool0 (2) + workerpool1 (1) + self.assertEqual(num_compute, 3) # workerpool3 (3) + + def test_get_num_storage_and_compute_nodes_not_on_vai(self): + """Test that function raises ValueError when not running in Vertex AI.""" + with self.assertRaises(ValueError) as context: + get_num_storage_and_compute_nodes() + self.assertIn("Not running in a Vertex AI job", str(context.exception)) + + def test_get_num_storage_and_compute_nodes_invalid_worker_pools(self): + """Test that function raises ValueError when cluster doesn't have 4 worker pools.""" + cluster_spec_json = json.dumps( + { + "cluster": {"workerpool0": ["replica-0"], "workerpool1": ["replica-0"]}, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + with self.assertRaises(ValueError) as context: + get_num_storage_and_compute_nodes() + self.assertIn( + "Cluster specification must have 4 worker pools", str(context.exception) + ) + self.assertIn("Found 2 worker pools", str(context.exception)) + + def test_parse_cluster_spec_success(self): + """Test successful parsing of cluster 
specification.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0", "replica-1"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0", "replica-1"], + }, + "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, + "environment": "cloud", + "job": { + "worker_pool_specs": [ + {"machine_spec": {"machine_type": "n1-standard-4"}} + ] + }, + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + cluster_spec = _parse_cluster_spec() + + # Test cluster data + self.assertEqual(len(cluster_spec.cluster), 4) + self.assertEqual( + cluster_spec.cluster["workerpool0"], ["replica-0", "replica-1"] + ) + self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica-0"]) + + # Test task info + self.assertEqual(cluster_spec.task.type, "workerpool0") + self.assertEqual(cluster_spec.task.index, 1) + self.assertEqual(cluster_spec.task.trial, "trial-123") + + # Test environment + self.assertEqual(cluster_spec.environment, "cloud") + + # Test job spec + self.assertIsNotNone(cluster_spec.job) + + def test_parse_cluster_spec_minimal(self): + """Test parsing of minimal cluster specification without optional fields.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0"], + }, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + cluster_spec = _parse_cluster_spec() + + # Test cluster data + self.assertEqual(len(cluster_spec.cluster), 4) + + # Test task info with defaults + self.assertEqual(cluster_spec.task.type, "workerpool0") + self.assertEqual(cluster_spec.task.index, 0) + self.assertIsNone(cluster_spec.task.trial) + + # Test environment + self.assertEqual(cluster_spec.environment, "cloud") + + # Test job 
spec (should be None when not provided) + self.assertIsNone(cluster_spec.job) + + def test_parse_cluster_spec_not_on_vai(self): + """Test that function raises ValueError when not running in Vertex AI.""" + with self.assertRaises(ValueError) as context: + _parse_cluster_spec() + self.assertIn("Not running in a Vertex AI job", str(context.exception)) + + def test_parse_cluster_spec_missing_cluster_spec(self): + """Test that function raises ValueError when CLUSTER_SPEC is missing.""" + with patch.dict(os.environ, self.VAI_JOB_ENV): + with self.assertRaises(ValueError) as context: + _parse_cluster_spec() + self.assertIn( + "CLUSTER_SPEC not found in environment variables", + str(context.exception), + ) + + def test_parse_cluster_spec_invalid_json(self): + """Test that function raises JSONDecodeError for invalid JSON.""" + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"} + ): + with self.assertRaises(json.JSONDecodeError) as context: + _parse_cluster_spec() + self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) + if __name__ == "__main__": unittest.main() From 694d72b5c029584ae5d000e873d710aa43c9a021 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 03:45:17 +0000 Subject: [PATCH 18/33] Nightly --- dep_vars.env | 10 +++++----- python/gigl/__init__.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dep_vars.env b/dep_vars.env index 31915b440..81c35dd17 100644 --- a/dep_vars.env +++ b/dep_vars.env @@ -3,11 +3,11 @@ DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external- DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 
-DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.9 -DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.9 -DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu-dataflow:0.0.9 -DEFAULT_GIGL_RELEASE_DEV_WORKBENCH_IMAGE=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dev-workbench:0.0.9 -DEFAULT_GIGL_RELEASE_KFP_PIPELINE_PATH=gs://public-gigl/releases/pipelines/gigl-pipeline-0.0.9.yaml +DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.9.dev20251007030821 +DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.9.dev20251007030821 +DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu-dataflow:0.0.9.dev20251007030821 +DEFAULT_GIGL_RELEASE_DEV_WORKBENCH_IMAGE=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dev-workbench:0.0.9.dev20251007030821 +DEFAULT_GIGL_RELEASE_KFP_PIPELINE_PATH=gs://public-gigl/releases/pipelines/gigl-pipeline-0.0.9.dev20251007030821.yaml SPARK_31_TFRECORD_JAR_GCS_PATH=gs://public-gigl/tools/scala/spark_packages/spark-custom-tfrecord_2.12-0.5.0.jar SPARK_35_TFRECORD_JAR_GCS_PATH=gs://public-gigl/tools/scala/spark_packages/spark_3.5.0-custom-tfrecord_2.12-0.6.1.jar diff --git a/python/gigl/__init__.py b/python/gigl/__init__.py index 00ec2dcdb..f1b2287b4 100644 --- a/python/gigl/__init__.py +++ b/python/gigl/__init__.py @@ -1 +1 @@ -__version__ = "0.0.9" +__version__ = "0.0.9.dev20251007030821" From 74d8df17b7e8f8301a45cbd3957e45bc06dece95 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:30:13 +0000 Subject: [PATCH 19/33] Add utils to parse VAI CLUSTER_SPEC --- 
python/gigl/common/utils/vertex_ai_context.py | 106 +++++++++++- python/gigl/env/distributed.py | 21 +++ .../common/utils/vertex_ai_context_test.py | 159 +++++++++++++++++- 3 files changed, 281 insertions(+), 5 deletions(-) create mode 100644 python/gigl/env/distributed.py diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index dfdf569c1..d155e29ac 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -1,14 +1,19 @@ """Utility functions to be used by machines running on Vertex AI.""" +import json import os import subprocess import time +from dataclasses import dataclass +from typing import Optional + +from google.cloud.aiplatform_v1.types import CustomJobSpec from gigl.common import GcsUri from gigl.common.logger import Logger from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.gcs import GcsUtils -from gigl.distributed import DistributedContext +from gigl.env.distributed import DistributedContext logger = Logger() @@ -156,6 +161,105 @@ def connect_worker_pool() -> DistributedContext: ) +def get_num_storage_and_compute_nodes() -> tuple[int, int]: + """ + Returns the number of storage and compute nodes for a Vertex AI job. + + Raises: + ValueError: If not running in a Vertex AI job. + """ + if not is_currently_running_in_vertex_ai_job(): + raise ValueError("Not running in a Vertex AI job.") + + cluster_spec = _parse_cluster_spec() + if len(cluster_spec.cluster) != 4: + raise ValueError( + f"Cluster specification must have 4 worker pools to fetch the number of storage and compute nodes. Found {len(cluster_spec.cluster)} worker pools." 
+ ) + num_storage_nodes = len(cluster_spec.cluster["workerpool0"]) + len( + cluster_spec.cluster["workerpool1"] + ) + num_compute_nodes = len(cluster_spec.cluster["workerpool3"]) + + return num_storage_nodes, num_compute_nodes + + +@dataclass +class TaskInfo: + """Information about the current task running on this node.""" + + type: str # The type of worker pool this task is running in (e.g., "workerpool0") + index: int # The zero-based index of the task + trial: Optional[ + str + ] = None # Hyperparameter tuning trial identifier (if applicable) + + +@dataclass +class ClusterSpec: + """Represents the cluster specification for a Vertex AI custom job.""" + + cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists + environment: str # The environment string (e.g., "cloud") + task: TaskInfo # Information about the current task + job: Optional[CustomJobSpec] = None # The CustomJobSpec for the current job + + +def _parse_cluster_spec() -> ClusterSpec: + """ + Parse the cluster specification from the CLUSTER_SPEC environment variable. + Based on the spec given at: + https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables + + Returns: + ClusterSpec: Parsed cluster specification data. + + Raises: + ValueError: If not running in a Vertex AI job or CLUSTER_SPEC is not found. + json.JSONDecodeError: If CLUSTER_SPEC contains invalid JSON. 
+ """ + if not is_currently_running_in_vertex_ai_job(): + raise ValueError("Not running in a Vertex AI job.") + + cluster_spec_json = os.environ.get("CLUSTER_SPEC") + if not cluster_spec_json: + raise ValueError("CLUSTER_SPEC not found in environment variables.") + + try: + cluster_spec_data = json.loads(cluster_spec_json) + except json.JSONDecodeError as e: + raise json.JSONDecodeError( + f"Failed to parse CLUSTER_SPEC JSON: {e.msg}", e.doc, e.pos + ) + + # Parse the task information + task_data = cluster_spec_data.get("task", {}) + task_info = TaskInfo( + type=task_data.get("type", ""), + index=task_data.get("index", 0), + trial=task_data.get("trial"), + ) + + # Parse the cluster specification + cluster_data = cluster_spec_data.get("cluster", {}) + + # Parse the environment + environment = cluster_spec_data.get("environment", "cloud") + + # Parse the job specification (optional) + job_data = cluster_spec_data.get("job") + job_spec = None + if job_data: + # Convert the dictionary to CustomJobSpec + # Note: This assumes the job_data is already in the correct format + # You may need to adjust this based on the actual structure + job_spec = CustomJobSpec(**job_data) + + return ClusterSpec( + cluster=cluster_data, environment=environment, task=task_info, job=job_spec + ) + + def _get_leader_worker_internal_ip_file_path() -> str: """ Get the file path to the leader worker's internal IP address. 
diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py new file mode 100644 index 000000000..84466dde4 --- /dev/null +++ b/python/gigl/env/distributed.py @@ -0,0 +1,21 @@ +"""Information about distributed environments.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class DistributedContext: + """ + GiGL Distributed Context + """ + + # TODO (mkolodner-sc): Investigate adding local rank and local world size + + # Main Worker's IP Address for RPC communication + main_worker_ip_address: str + + # Rank of machine + global_rank: int + + # Total number of machines + global_world_size: int diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index d3a5132b2..7c4649e14 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -1,3 +1,4 @@ +import json import os import unittest from unittest.mock import call, patch @@ -5,17 +6,17 @@ from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( - DistributedContext, + _parse_cluster_spec, connect_worker_pool, get_host_name, get_leader_hostname, get_leader_port, + get_num_storage_and_compute_nodes, get_rank, get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, ) -from gigl.distributed import DistributedContext class TestVertexAIContext(unittest.TestCase): @@ -76,7 +77,7 @@ def test_throws_if_not_on_vai(self): }, ) def test_connect_worker_pool_leader(self, mock_upload, mock_sleep, mock_subprocess): - distributed_context: DistributedContext = connect_worker_pool() + distributed_context = connect_worker_pool() self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1") self.assertEqual(distributed_context.global_rank, 0) self.assertEqual(distributed_context.global_world_size, 2) @@ -102,7 
+103,7 @@ def test_connect_worker_pool_worker( self, mock_upload, mock_read, mock_sleep, mock_subprocess, mock_ping_host ): mock_ping_host.side_effect = [False, True] - distributed_context: DistributedContext = connect_worker_pool() + distributed_context = connect_worker_pool() self.assertEqual(distributed_context.main_worker_ip_address, "127.0.0.1") self.assertEqual(distributed_context.global_rank, 1) self.assertEqual(distributed_context.global_world_size, 2) @@ -113,6 +114,156 @@ def test_connect_worker_pool_worker( ] ) + def test_get_num_storage_and_compute_nodes_success(self): + """Test successful retrieval of storage and compute node counts.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0", "replica-1"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0", "replica-1", "replica-2"], + }, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + num_storage, num_compute = get_num_storage_and_compute_nodes() + self.assertEqual(num_storage, 3) # workerpool0 (2) + workerpool1 (1) + self.assertEqual(num_compute, 3) # workerpool3 (3) + + def test_get_num_storage_and_compute_nodes_not_on_vai(self): + """Test that function raises ValueError when not running in Vertex AI.""" + with self.assertRaises(ValueError) as context: + get_num_storage_and_compute_nodes() + self.assertIn("Not running in a Vertex AI job", str(context.exception)) + + def test_get_num_storage_and_compute_nodes_invalid_worker_pools(self): + """Test that function raises ValueError when cluster doesn't have 4 worker pools.""" + cluster_spec_json = json.dumps( + { + "cluster": {"workerpool0": ["replica-0"], "workerpool1": ["replica-0"]}, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + 
): + with self.assertRaises(ValueError) as context: + get_num_storage_and_compute_nodes() + self.assertIn( + "Cluster specification must have 4 worker pools", str(context.exception) + ) + self.assertIn("Found 2 worker pools", str(context.exception)) + + def test_parse_cluster_spec_success(self): + """Test successful parsing of cluster specification.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0", "replica-1"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0", "replica-1"], + }, + "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, + "environment": "cloud", + "job": { + "worker_pool_specs": [ + {"machine_spec": {"machine_type": "n1-standard-4"}} + ] + }, + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + cluster_spec = _parse_cluster_spec() + + # Test cluster data + self.assertEqual(len(cluster_spec.cluster), 4) + self.assertEqual( + cluster_spec.cluster["workerpool0"], ["replica-0", "replica-1"] + ) + self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica-0"]) + + # Test task info + self.assertEqual(cluster_spec.task.type, "workerpool0") + self.assertEqual(cluster_spec.task.index, 1) + self.assertEqual(cluster_spec.task.trial, "trial-123") + + # Test environment + self.assertEqual(cluster_spec.environment, "cloud") + + # Test job spec + self.assertIsNotNone(cluster_spec.job) + + def test_parse_cluster_spec_minimal(self): + """Test parsing of minimal cluster specification without optional fields.""" + cluster_spec_json = json.dumps( + { + "cluster": { + "workerpool0": ["replica-0"], + "workerpool1": ["replica-0"], + "workerpool2": ["replica-0"], + "workerpool3": ["replica-0"], + }, + "task": {"type": "workerpool0", "index": 0}, + "environment": "cloud", + } + ) + + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} + ): + cluster_spec = _parse_cluster_spec() + + # Test 
cluster data + self.assertEqual(len(cluster_spec.cluster), 4) + + # Test task info with defaults + self.assertEqual(cluster_spec.task.type, "workerpool0") + self.assertEqual(cluster_spec.task.index, 0) + self.assertIsNone(cluster_spec.task.trial) + + # Test environment + self.assertEqual(cluster_spec.environment, "cloud") + + # Test job spec (should be None when not provided) + self.assertIsNone(cluster_spec.job) + + def test_parse_cluster_spec_not_on_vai(self): + """Test that function raises ValueError when not running in Vertex AI.""" + with self.assertRaises(ValueError) as context: + _parse_cluster_spec() + self.assertIn("Not running in a Vertex AI job", str(context.exception)) + + def test_parse_cluster_spec_missing_cluster_spec(self): + """Test that function raises ValueError when CLUSTER_SPEC is missing.""" + with patch.dict(os.environ, self.VAI_JOB_ENV): + with self.assertRaises(ValueError) as context: + _parse_cluster_spec() + self.assertIn( + "CLUSTER_SPEC not found in environment variables", + str(context.exception), + ) + + def test_parse_cluster_spec_invalid_json(self): + """Test that function raises JSONDecodeError for invalid JSON.""" + with patch.dict( + os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"} + ): + with self.assertRaises(json.JSONDecodeError) as context: + _parse_cluster_spec() + self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) + if __name__ == "__main__": unittest.main() From de1de6ac470a05634cb65ec5af439c936566b69d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:33:35 +0000 Subject: [PATCH 20/33] comments --- python/gigl/common/utils/vertex_ai_context.py | 38 +++--------- .../common/utils/vertex_ai_context_test.py | 61 ++----------------- 2 files changed, 15 insertions(+), 84 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index d155e29ac..f8572906f 100644 --- a/python/gigl/common/utils/vertex_ai_context.py 
+++ b/python/gigl/common/utils/vertex_ai_context.py @@ -161,29 +161,6 @@ def connect_worker_pool() -> DistributedContext: ) -def get_num_storage_and_compute_nodes() -> tuple[int, int]: - """ - Returns the number of storage and compute nodes for a Vertex AI job. - - Raises: - ValueError: If not running in a Vertex AI job. - """ - if not is_currently_running_in_vertex_ai_job(): - raise ValueError("Not running in a Vertex AI job.") - - cluster_spec = _parse_cluster_spec() - if len(cluster_spec.cluster) != 4: - raise ValueError( - f"Cluster specification must have 4 worker pools to fetch the number of storage and compute nodes. Found {len(cluster_spec.cluster)} worker pools." - ) - num_storage_nodes = len(cluster_spec.cluster["workerpool0"]) + len( - cluster_spec.cluster["workerpool1"] - ) - num_compute_nodes = len(cluster_spec.cluster["workerpool3"]) - - return num_storage_nodes, num_compute_nodes - - @dataclass class TaskInfo: """Information about the current task running on this node.""" @@ -197,15 +174,21 @@ class TaskInfo: @dataclass class ClusterSpec: - """Represents the cluster specification for a Vertex AI custom job.""" + """Represents the cluster specification for a Vertex AI custom job. + See the docs for more info: + https://cloud.google.com/vertex-ai/docs/training/distributed-training#cluster-variables + """ cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists environment: str # The environment string (e.g., "cloud") task: TaskInfo # Information about the current task - job: Optional[CustomJobSpec] = None # The CustomJobSpec for the current job + # The CustomJobSpec for the current job + # See the docs for more info: + # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec + job: Optional[CustomJobSpec] = None -def _parse_cluster_spec() -> ClusterSpec: +def parse_cluster_spec() -> ClusterSpec: """ Parse the cluster specification from the CLUSTER_SPEC environment variable. 
Based on the spec given at: @@ -250,9 +233,6 @@ def _parse_cluster_spec() -> ClusterSpec: job_data = cluster_spec_data.get("job") job_spec = None if job_data: - # Convert the dictionary to CustomJobSpec - # Note: This assumes the job_data is already in the correct format - # You may need to adjust this based on the actual structure job_spec = CustomJobSpec(**job_data) return ClusterSpec( diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index 7c4649e14..bcd3a4ca4 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -6,16 +6,15 @@ from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( - _parse_cluster_spec, connect_worker_pool, get_host_name, get_leader_hostname, get_leader_port, - get_num_storage_and_compute_nodes, get_rank, get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, + parse_cluster_spec, ) @@ -114,54 +113,6 @@ def test_connect_worker_pool_worker( ] ) - def test_get_num_storage_and_compute_nodes_success(self): - """Test successful retrieval of storage and compute node counts.""" - cluster_spec_json = json.dumps( - { - "cluster": { - "workerpool0": ["replica-0", "replica-1"], - "workerpool1": ["replica-0"], - "workerpool2": ["replica-0"], - "workerpool3": ["replica-0", "replica-1", "replica-2"], - }, - "task": {"type": "workerpool0", "index": 0}, - "environment": "cloud", - } - ) - - with patch.dict( - os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} - ): - num_storage, num_compute = get_num_storage_and_compute_nodes() - self.assertEqual(num_storage, 3) # workerpool0 (2) + workerpool1 (1) - self.assertEqual(num_compute, 3) # workerpool3 (3) - - def test_get_num_storage_and_compute_nodes_not_on_vai(self): - """Test that function raises ValueError 
when not running in Vertex AI.""" - with self.assertRaises(ValueError) as context: - get_num_storage_and_compute_nodes() - self.assertIn("Not running in a Vertex AI job", str(context.exception)) - - def test_get_num_storage_and_compute_nodes_invalid_worker_pools(self): - """Test that function raises ValueError when cluster doesn't have 4 worker pools.""" - cluster_spec_json = json.dumps( - { - "cluster": {"workerpool0": ["replica-0"], "workerpool1": ["replica-0"]}, - "task": {"type": "workerpool0", "index": 0}, - "environment": "cloud", - } - ) - - with patch.dict( - os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} - ): - with self.assertRaises(ValueError) as context: - get_num_storage_and_compute_nodes() - self.assertIn( - "Cluster specification must have 4 worker pools", str(context.exception) - ) - self.assertIn("Found 2 worker pools", str(context.exception)) - def test_parse_cluster_spec_success(self): """Test successful parsing of cluster specification.""" cluster_spec_json = json.dumps( @@ -185,7 +136,7 @@ def test_parse_cluster_spec_success(self): with patch.dict( os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): - cluster_spec = _parse_cluster_spec() + cluster_spec = parse_cluster_spec() # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) @@ -223,7 +174,7 @@ def test_parse_cluster_spec_minimal(self): with patch.dict( os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): - cluster_spec = _parse_cluster_spec() + cluster_spec = parse_cluster_spec() # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) @@ -242,14 +193,14 @@ def test_parse_cluster_spec_minimal(self): def test_parse_cluster_spec_not_on_vai(self): """Test that function raises ValueError when not running in Vertex AI.""" with self.assertRaises(ValueError) as context: - _parse_cluster_spec() + parse_cluster_spec() self.assertIn("Not running in a Vertex AI job", str(context.exception)) def 
test_parse_cluster_spec_missing_cluster_spec(self): """Test that function raises ValueError when CLUSTER_SPEC is missing.""" with patch.dict(os.environ, self.VAI_JOB_ENV): with self.assertRaises(ValueError) as context: - _parse_cluster_spec() + parse_cluster_spec() self.assertIn( "CLUSTER_SPEC not found in environment variables", str(context.exception), @@ -261,7 +212,7 @@ def test_parse_cluster_spec_invalid_json(self): os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"} ): with self.assertRaises(json.JSONDecodeError) as context: - _parse_cluster_spec() + parse_cluster_spec() self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) From fc0dca428596d4c8843aa893ca7c38a34d475c52 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:34:23 +0000 Subject: [PATCH 21/33] rename --- python/gigl/common/utils/vertex_ai_context.py | 2 +- .../unit/common/utils/vertex_ai_context_test.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index f8572906f..0f3f032bc 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -188,7 +188,7 @@ class ClusterSpec: job: Optional[CustomJobSpec] = None -def parse_cluster_spec() -> ClusterSpec: +def get_cluster_spec() -> ClusterSpec: """ Parse the cluster specification from the CLUSTER_SPEC environment variable. 
Based on the spec given at: diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index bcd3a4ca4..8234725f6 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -14,7 +14,7 @@ get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, - parse_cluster_spec, + get_cluster_spec, ) @@ -136,7 +136,7 @@ def test_parse_cluster_spec_success(self): with patch.dict( os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): - cluster_spec = parse_cluster_spec() + cluster_spec = get_cluster_spec() # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) @@ -174,7 +174,7 @@ def test_parse_cluster_spec_minimal(self): with patch.dict( os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): - cluster_spec = parse_cluster_spec() + cluster_spec = get_cluster_spec() # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) @@ -193,14 +193,14 @@ def test_parse_cluster_spec_minimal(self): def test_parse_cluster_spec_not_on_vai(self): """Test that function raises ValueError when not running in Vertex AI.""" with self.assertRaises(ValueError) as context: - parse_cluster_spec() + get_cluster_spec() self.assertIn("Not running in a Vertex AI job", str(context.exception)) def test_parse_cluster_spec_missing_cluster_spec(self): """Test that function raises ValueError when CLUSTER_SPEC is missing.""" with patch.dict(os.environ, self.VAI_JOB_ENV): with self.assertRaises(ValueError) as context: - parse_cluster_spec() + get_cluster_spec() self.assertIn( "CLUSTER_SPEC not found in environment variables", str(context.exception), @@ -212,7 +212,7 @@ def test_parse_cluster_spec_invalid_json(self): os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": "invalid json"} ): with self.assertRaises(json.JSONDecodeError) as context: - parse_cluster_spec() + get_cluster_spec() 
self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) From d3319d61c7148ab467f19c94063bbf5bbec2e513 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:40:46 +0000 Subject: [PATCH 22/33] fixes --- python/gigl/common/utils/vertex_ai_context.py | 4 ++-- python/gigl/distributed/dist_context.py | 22 +++++-------------- .../common/utils/vertex_ai_context_test.py | 2 +- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index 0f3f032bc..c21571091 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -161,7 +161,7 @@ def connect_worker_pool() -> DistributedContext: ) -@dataclass +@dataclass(frozen=True) class TaskInfo: """Information about the current task running on this node.""" @@ -172,7 +172,7 @@ class TaskInfo: ] = None # Hyperparameter tuning trial identifier (if applicable) -@dataclass +@dataclass(frozen=True) class ClusterSpec: """Represents the cluster specification for a Vertex AI custom job. See the docs for more info: diff --git a/python/gigl/distributed/dist_context.py b/python/gigl/distributed/dist_context.py index da513ab87..0f222b956 100644 --- a/python/gigl/distributed/dist_context.py +++ b/python/gigl/distributed/dist_context.py @@ -1,19 +1,9 @@ -from dataclasses import dataclass +from gigl.env.distributed import DistributedContext +# TODO (mkolodner-sc): Deprecate this file. 
+__all__ = [ + "DeprecatedDistributedContext", +] -@dataclass(frozen=True) -class DistributedContext: - """ - GiGL Distributed Context - """ - # TODO (mkolodner-sc): Investigate adding local rank and local world size - - # Main Worker's IP Address for RPC communication - main_worker_ip_address: str - - # Rank of machine - global_rank: int - - # Total number of machines - global_world_size: int +DeprecatedDistributedContext = DistributedContext diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index 8234725f6..ee2f39ecd 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -7,6 +7,7 @@ from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( connect_worker_pool, + get_cluster_spec, get_host_name, get_leader_hostname, get_leader_port, @@ -14,7 +15,6 @@ get_vertex_ai_job_id, get_world_size, is_currently_running_in_vertex_ai_job, - get_cluster_spec, ) From 090566427621d031021444c28fb61300560a3167 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:42:28 +0000 Subject: [PATCH 23/33] fixes --- python/gigl/distributed/dist_context.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/gigl/distributed/dist_context.py b/python/gigl/distributed/dist_context.py index 0f222b956..1078bee5a 100644 --- a/python/gigl/distributed/dist_context.py +++ b/python/gigl/distributed/dist_context.py @@ -2,8 +2,5 @@ # TODO (mkolodner-sc): Deprecate this file. 
__all__ = [ - "DeprecatedDistributedContext", + "DistributedContext", ] - - -DeprecatedDistributedContext = DistributedContext From 112d0ad605ba62b10dc541b70705e3c0d716f9a0 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 7 Oct 2025 15:44:27 +0000 Subject: [PATCH 24/33] fix --- .../common/utils/vertex_ai_context_test.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index ee2f39ecd..28e98b4e6 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -118,10 +118,10 @@ def test_parse_cluster_spec_success(self): cluster_spec_json = json.dumps( { "cluster": { - "workerpool0": ["replica-0", "replica-1"], - "workerpool1": ["replica-0"], - "workerpool2": ["replica-0"], - "workerpool3": ["replica-0", "replica-1"], + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], }, "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, "environment": "cloud", @@ -141,9 +141,13 @@ def test_parse_cluster_spec_success(self): # Test cluster data self.assertEqual(len(cluster_spec.cluster), 4) self.assertEqual( - cluster_spec.cluster["workerpool0"], ["replica-0", "replica-1"] + cluster_spec.cluster["workerpool0"], ["replica0-0", "replica0-1"] + ) + self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica1-0"]) + self.assertEqual(cluster_spec.cluster["workerpool2"], ["replica2-0"]) + self.assertEqual( + cluster_spec.cluster["workerpool3"], ["replica3-0", "replica3-1"] ) - self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica-0"]) # Test task info self.assertEqual(cluster_spec.task.type, "workerpool0") @@ -161,10 +165,10 @@ def test_parse_cluster_spec_minimal(self): cluster_spec_json = json.dumps( { "cluster": { - 
"workerpool0": ["replica-0"], - "workerpool1": ["replica-0"], - "workerpool2": ["replica-0"], - "workerpool3": ["replica-0"], + "workerpool0": ["replica0-0"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0"], }, "task": {"type": "workerpool0", "index": 0}, "environment": "cloud", From f621bc7f120fca86d4760856d7c340d8b726d43e Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Oct 2025 16:13:58 +0000 Subject: [PATCH 25/33] address comments --- python/gigl/common/utils/vertex_ai_context.py | 56 ++++++-------- .../common/utils/vertex_ai_context_test.py | 75 +++++++++---------- 2 files changed, 59 insertions(+), 72 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index c21571091..667044435 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Optional +import omegaconf from google.cloud.aiplatform_v1.types import CustomJobSpec from gigl.common import GcsUri @@ -187,6 +188,25 @@ class ClusterSpec: # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec job: Optional[CustomJobSpec] = None + # We use a custom method for parsing, because we need to handle the DictConfig -> Proto conversion + @classmethod + def from_json(cls, json_str: str) -> "ClusterSpec": + """Instantiates ClusterSpec from an OmegaConf DictConfig.""" + cluster_spec_json = json.loads(json_str) + if "job" in cluster_spec_json and cluster_spec_json["job"] is not None: + job_spec = CustomJobSpec(**cluster_spec_json.pop("job")) + else: + job_spec = None + conf = omegaconf.OmegaConf.create(cluster_spec_json) + if isinstance(conf, omegaconf.ListConfig): + raise ValueError("ListConfig is not supported") + return cls( + cluster=conf.cluster, + environment=conf.environment, + task=conf.task, + job=job_spec, + ) + def get_cluster_spec() -> ClusterSpec: 
""" @@ -204,40 +224,12 @@ def get_cluster_spec() -> ClusterSpec: if not is_currently_running_in_vertex_ai_job(): raise ValueError("Not running in a Vertex AI job.") - cluster_spec_json = os.environ.get("CLUSTER_SPEC") - if not cluster_spec_json: + cluster_spec_str = os.environ.get("CLUSTER_SPEC") + if not cluster_spec_str: raise ValueError("CLUSTER_SPEC not found in environment variables.") - try: - cluster_spec_data = json.loads(cluster_spec_json) - except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Failed to parse CLUSTER_SPEC JSON: {e.msg}", e.doc, e.pos - ) - - # Parse the task information - task_data = cluster_spec_data.get("task", {}) - task_info = TaskInfo( - type=task_data.get("type", ""), - index=task_data.get("index", 0), - trial=task_data.get("trial"), - ) - - # Parse the cluster specification - cluster_data = cluster_spec_data.get("cluster", {}) - - # Parse the environment - environment = cluster_spec_data.get("environment", "cloud") - - # Parse the job specification (optional) - job_data = cluster_spec_data.get("job") - job_spec = None - if job_data: - job_spec = CustomJobSpec(**job_data) - - return ClusterSpec( - cluster=cluster_data, environment=environment, task=task_info, job=job_spec - ) + cluster_spec = ClusterSpec.from_json(cluster_spec_str) + return cluster_spec def _get_leader_worker_internal_ip_file_path() -> str: diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index 28e98b4e6..db6ad7cc7 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -3,9 +3,13 @@ import unittest from unittest.mock import call, patch +from google.cloud.aiplatform_v1.types import CustomJobSpec + from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( + ClusterSpec, + TaskInfo, 
connect_worker_pool, get_cluster_spec, get_host_name, @@ -137,40 +141,34 @@ def test_parse_cluster_spec_success(self): os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): cluster_spec = get_cluster_spec() - - # Test cluster data - self.assertEqual(len(cluster_spec.cluster), 4) - self.assertEqual( - cluster_spec.cluster["workerpool0"], ["replica0-0", "replica0-1"] - ) - self.assertEqual(cluster_spec.cluster["workerpool1"], ["replica1-0"]) - self.assertEqual(cluster_spec.cluster["workerpool2"], ["replica2-0"]) - self.assertEqual( - cluster_spec.cluster["workerpool3"], ["replica3-0", "replica3-1"] + expected_cluster_spec = ClusterSpec( + cluster={ + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], + }, + environment="cloud", + task=TaskInfo(type="workerpool0", index=1, trial="trial-123"), + job=CustomJobSpec( + worker_pool_specs=[ + {"machine_spec": {"machine_type": "n1-standard-4"}} + ] + ), ) + self.assertEqual(cluster_spec, expected_cluster_spec) - # Test task info - self.assertEqual(cluster_spec.task.type, "workerpool0") - self.assertEqual(cluster_spec.task.index, 1) - self.assertEqual(cluster_spec.task.trial, "trial-123") - - # Test environment - self.assertEqual(cluster_spec.environment, "cloud") - - # Test job spec - self.assertIsNotNone(cluster_spec.job) - - def test_parse_cluster_spec_minimal(self): - """Test parsing of minimal cluster specification without optional fields.""" + def test_parse_cluster_spec_success_without_job(self): + """Test successful parsing of cluster specification.""" cluster_spec_json = json.dumps( { "cluster": { - "workerpool0": ["replica0-0"], + "workerpool0": ["replica0-0", "replica0-1"], "workerpool1": ["replica1-0"], "workerpool2": ["replica2-0"], - "workerpool3": ["replica3-0"], + "workerpool3": ["replica3-0", "replica3-1"], }, - "task": {"type": "workerpool0", "index": 0}, + "task": {"type": 
"workerpool0", "index": 1, "trial": "trial-123"}, "environment": "cloud", } ) @@ -179,20 +177,18 @@ def test_parse_cluster_spec_minimal(self): os.environ, self.VAI_JOB_ENV | {"CLUSTER_SPEC": cluster_spec_json} ): cluster_spec = get_cluster_spec() + expected_cluster_spec = ClusterSpec( + cluster={ + "workerpool0": ["replica0-0", "replica0-1"], + "workerpool1": ["replica1-0"], + "workerpool2": ["replica2-0"], + "workerpool3": ["replica3-0", "replica3-1"], + }, + environment="cloud", + task=TaskInfo(type="workerpool0", index=1, trial="trial-123"), + ) - # Test cluster data - self.assertEqual(len(cluster_spec.cluster), 4) - - # Test task info with defaults - self.assertEqual(cluster_spec.task.type, "workerpool0") - self.assertEqual(cluster_spec.task.index, 0) - self.assertIsNone(cluster_spec.task.trial) - - # Test environment - self.assertEqual(cluster_spec.environment, "cloud") - - # Test job spec (should be None when not provided) - self.assertIsNone(cluster_spec.job) + self.assertEqual(cluster_spec, expected_cluster_spec) def test_parse_cluster_spec_not_on_vai(self): """Test that function raises ValueError when not running in Vertex AI.""" @@ -217,7 +213,6 @@ def test_parse_cluster_spec_invalid_json(self): ): with self.assertRaises(json.JSONDecodeError) as context: get_cluster_spec() - self.assertIn("Failed to parse CLUSTER_SPEC JSON", str(context.exception)) if __name__ == "__main__": From 9b9970638b34097e97459ca8f9a614d8012c575f Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Oct 2025 16:17:18 +0000 Subject: [PATCH 26/33] reword --- python/gigl/common/utils/vertex_ai_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index 667044435..35d1d90ad 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -188,10 +188,10 @@ class ClusterSpec: # 
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec job: Optional[CustomJobSpec] = None - # We use a custom method for parsing, because we need to handle the DictConfig -> Proto conversion + # We use a custom method for parsing, because CustomJobSpec is a protobuf message. @classmethod def from_json(cls, json_str: str) -> "ClusterSpec": - """Instantiates ClusterSpec from an OmegaConf DictConfig.""" + """Instantiates ClusterSpec from a JSON string.""" cluster_spec_json = json.loads(json_str) if "job" in cluster_spec_json and cluster_spec_json["job"] is not None: job_spec = CustomJobSpec(**cluster_spec_json.pop("job")) From 86eb8eb82b8859996b30dda93946a13a29231aad Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 8 Oct 2025 22:38:53 +0000 Subject: [PATCH 27/33] merges --- python/gigl/distributed/utils/networking.py | 21 ++++++++++++------- .../unit/distributed/utils/networking_test.py | 11 +++++----- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 9241ee400..fa0eb9818 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -5,7 +5,7 @@ from gigl.common.logger import Logger from gigl.common.utils.vertex_ai_context import ( - get_num_storage_and_compute_nodes, + get_cluster_spec, is_currently_running_in_vertex_ai_job, ) from gigl.env.distributed import GraphStoreInfo @@ -200,21 +200,26 @@ def get_graph_store_info() -> GraphStoreInfo: if not torch.distributed.is_initialized(): raise ValueError("Distributed environment must be initialized") if is_currently_running_in_vertex_ai_job(): - num_storage_nodes, num_compute_nodes = get_num_storage_and_compute_nodes() + cluster_spec = get_cluster_spec() + # We setup the VAI cluster such that the compute nodes come first, followed by the storage nodes. 
+ num_compute_nodes = len(cluster_spec.cluster["workerpool0"]) + len( + cluster_spec.cluster["workerpool1"] + ) + num_storage_nodes = len(cluster_spec.cluster["workerpool2"]) else: raise ValueError( "Must be running on a vertex AI job to get graph store cluster info!" ) cluster_master_ip = get_internal_ip_from_master_node() - # We assume that the storage cluster nodes come first. - storage_cluster_master_ip = get_internal_ip_from_node(node_rank=0) - compute_cluster_master_ip = get_internal_ip_from_node(node_rank=num_storage_nodes) + # We assume that the compute cluster nodes come first, followed by the storage nodes. + compute_cluster_master_ip = get_internal_ip_from_node(node_rank=0) + storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes) cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] - storage_cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] - compute_cluster_master_port = get_free_ports_from_node( - num_ports=1, node_rank=num_storage_nodes + compute_cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] + storage_cluster_master_port = get_free_ports_from_node( + num_ports=1, node_rank=num_compute_nodes )[0] return GraphStoreInfo( diff --git a/python/tests/unit/distributed/utils/networking_test.py b/python/tests/unit/distributed/utils/networking_test.py index 825a587e9..b96319e09 100644 --- a/python/tests/unit/distributed/utils/networking_test.py +++ b/python/tests/unit/distributed/utils/networking_test.py @@ -319,7 +319,6 @@ def _test_get_graph_store_info_in_dist_context( assert isinstance( graph_store_info, GraphStoreInfo ), "Result should be a GraphStoreInfo instance" - # Verify cluster sizes assert ( graph_store_info.num_storage_nodes == storage_nodes @@ -472,9 +471,9 @@ def _test_get_graph_store_info_fails_when_not_running_in_vertex_ai_job(self): compute_nodes=1, ), param( - "Test with 2 storage nodes and 3 compute nodes", + "Test with 2 storage nodes 
and 1 compute nodes", storage_nodes=2, - compute_nodes=3, + compute_nodes=1, ), param( "Test with 3 storage nodes and 2 compute nodes", @@ -489,10 +488,10 @@ def test_get_graph_store_info_success_in_distributed_context( """Test successful execution of get_graph_store_info in a real distributed context.""" init_process_group_init_method = get_process_group_init_method() world_size = storage_nodes + compute_nodes - if storage_nodes == 1: - worker_pool_sizes = [1, 0, 0, compute_nodes] + if compute_nodes == 1: + worker_pool_sizes = [1, 0, storage_nodes] else: - worker_pool_sizes = [1, storage_nodes - 1, 0, compute_nodes] + worker_pool_sizes = [1, compute_nodes - 1, storage_nodes] with patch.dict( os.environ, { From af12a007d473a5fab48da662e76cc3ec36d89230 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 9 Oct 2025 18:11:42 +0000 Subject: [PATCH 28/33] fix --- python/gigl/env/distributed.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py index f21b8a2d1..e8999be67 100644 --- a/python/gigl/env/distributed.py +++ b/python/gigl/env/distributed.py @@ -21,11 +21,6 @@ class DistributedContext: global_world_size: int -"""Information about the distributed setup.""" - -from dataclasses import dataclass - - @dataclass(frozen=True) class GraphStoreInfo: """Information about a graph store cluster.""" @@ -50,21 +45,3 @@ class GraphStoreInfo: storage_cluster_master_port: int # Port of the master node for the compute cluster compute_cluster_master_port: int - - -@dataclass(frozen=True) -class DistributedContext: - """ - GiGL Distributed Context - """ - - # TODO (mkolodner-sc): Investigate adding local rank and local world size - - # Main Worker's IP Address for RPC communication - main_worker_ip_address: str - - # Rank of machine - global_rank: int - - # Total number of machines - global_world_size: int From 2c13526695b2ebd36d6da81d2dec33fdf21dab9e Mon Sep 17 00:00:00 2001 From: kmontemayor 
Date: Thu, 9 Oct 2025 22:35:29 +0000 Subject: [PATCH 29/33] test fix --- python/tests/integration/distributed/utils/networking_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tests/integration/distributed/utils/networking_test.py b/python/tests/integration/distributed/utils/networking_test.py index f4dbd04c2..c577d042c 100644 --- a/python/tests/integration/distributed/utils/networking_test.py +++ b/python/tests/integration/distributed/utils/networking_test.py @@ -57,7 +57,9 @@ def test_get_graph_store_info(self, _, storage_nodes, compute_nodes): "-c", dedent( f""" + import torch from gigl.distributed.utils import get_graph_store_info + torch.distributed.init_process_group(backend="gloo") info = get_graph_store_info() assert info.num_storage_nodes == {storage_nodes} assert info.num_compute_nodes == {compute_nodes} From a3dea3101ab59183fda3a1ae6fc2ae3507d72453 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 9 Oct 2025 22:37:04 +0000 Subject: [PATCH 30/33] fix test --- .../distributed/utils/networking_test.py | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/python/tests/integration/distributed/utils/networking_test.py b/python/tests/integration/distributed/utils/networking_test.py index c577d042c..212ae20e5 100644 --- a/python/tests/integration/distributed/utils/networking_test.py +++ b/python/tests/integration/distributed/utils/networking_test.py @@ -29,24 +29,14 @@ def setUp(self): @parameterized.expand( [ param( - "Test with 1 storage node and 1 compute node", - storage_nodes=1, - compute_nodes=1, - ), - param( - "Test with 2 storage nodes and 1 compute nodes", - storage_nodes=2, + "Test with 1 compute node and 1 storage node", compute_nodes=1, - ), - param( - "Test with 1 storage nodes and 2 compute nodes", storage_nodes=1, - compute_nodes=2, ), param( - "Test with 2 storage nodes and 2 compute nodes", - storage_nodes=2, + "Test with 2 compute nodes and 2 storage nodes", compute_nodes=2, + 
storage_nodes=2, ), ] ) @@ -73,13 +63,6 @@ def test_get_graph_store_info(self, _, storage_nodes, compute_nodes): """ ), ] - storage_cluster_config = VertexAiJobConfig( - job_name=job_name, - container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, - replica_count=storage_nodes, - machine_type="n1-standard-4", - command=command, - ) compute_cluster_config = VertexAiJobConfig( job_name=job_name, container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, # different image for storage and compute @@ -87,7 +70,14 @@ def test_get_graph_store_info(self, _, storage_nodes, compute_nodes): command=command, machine_type="n2-standard-8", ) + storage_cluster_config = VertexAiJobConfig( + job_name=job_name, + container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + replica_count=storage_nodes, + machine_type="n1-standard-4", + command=command, + ) self._vertex_ai_service.launch_graph_store_job( - storage_cluster_config, compute_cluster_config + compute_cluster_config, storage_cluster_config ) From aaceeee553edf71ed1fe3f52653a3aa030943464 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 10 Oct 2025 16:56:26 +0000 Subject: [PATCH 31/33] fixes --- python/gigl/common/utils/vertex_ai_context.py | 21 +++++++++------- python/gigl/distributed/utils/networking.py | 9 ++++--- .../distributed/utils/networking_test.py | 24 +++++++++---------- .../common/utils/vertex_ai_context_test.py | 14 ++++------- 4 files changed, 35 insertions(+), 33 deletions(-) diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index 35d1d90ad..bc4a93d02 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -8,7 +8,6 @@ from typing import Optional import omegaconf -from google.cloud.aiplatform_v1.types import CustomJobSpec from gigl.common import GcsUri from gigl.common.logger import Logger @@ -183,29 +182,35 @@ class ClusterSpec: cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists 
environment: str # The environment string (e.g., "cloud") task: TaskInfo # Information about the current task - # The CustomJobSpec for the current job - # See the docs for more info: - # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec - job: Optional[CustomJobSpec] = None - # We use a custom method for parsing, because CustomJobSpec is a protobuf message. + # DESPITE what the docs say, this is *not* a CustomJobSpec. + # It's *sort of* like a PythonPackageSpec, but it's not. + # It has `jobArgs` instead of `args`. + # See an example: + # {"python_module":"","package_uris":[],"job_args":[]} + job: Optional[dict] = None + + # We use a custom method for parsing, the "job" is actually a serialized json string. @classmethod def from_json(cls, json_str: str) -> "ClusterSpec": """Instantiates ClusterSpec from a JSON string.""" cluster_spec_json = json.loads(json_str) if "job" in cluster_spec_json and cluster_spec_json["job"] is not None: - job_spec = CustomJobSpec(**cluster_spec_json.pop("job")) + logger.info(f"Job spec: {cluster_spec_json['job']}") + job_spec = json.loads(cluster_spec_json.pop("job")) else: job_spec = None conf = omegaconf.OmegaConf.create(cluster_spec_json) if isinstance(conf, omegaconf.ListConfig): raise ValueError("ListConfig is not supported") - return cls( + cluster_spec = cls( cluster=conf.cluster, environment=conf.environment, task=conf.task, job=job_spec, ) + logger.info(f"Cluster spec: {cluster_spec}") + return cluster_spec def get_cluster_spec() -> ClusterSpec: diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index fa0eb9818..0f1d8e77a 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -202,9 +202,12 @@ def get_graph_store_info() -> GraphStoreInfo: if is_currently_running_in_vertex_ai_job(): cluster_spec = get_cluster_spec() # We setup the VAI cluster such that the compute nodes come first, followed by the 
storage nodes. - num_compute_nodes = len(cluster_spec.cluster["workerpool0"]) + len( - cluster_spec.cluster["workerpool1"] - ) + if "workerpool1" in cluster_spec.cluster: + num_compute_nodes = len(cluster_spec.cluster["workerpool0"]) + len( + cluster_spec.cluster["workerpool1"] + ) + else: + num_compute_nodes = len(cluster_spec.cluster["workerpool0"]) num_storage_nodes = len(cluster_spec.cluster["workerpool2"]) else: raise ValueError( diff --git a/python/tests/integration/distributed/utils/networking_test.py b/python/tests/integration/distributed/utils/networking_test.py index 212ae20e5..bbbf10283 100644 --- a/python/tests/integration/distributed/utils/networking_test.py +++ b/python/tests/integration/distributed/utils/networking_test.py @@ -4,9 +4,9 @@ from parameterized import param, parameterized -from gigl.common.constants import DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService from gigl.env.pipelines_config import get_resource_config +from gigl.common.constants import GIGL_RELEASE_SRC_IMAGE_CPU class NetworkingUtlsIntegrationTest(unittest.TestCase): @@ -51,28 +51,28 @@ def test_get_graph_store_info(self, _, storage_nodes, compute_nodes): from gigl.distributed.utils import get_graph_store_info torch.distributed.init_process_group(backend="gloo") info = get_graph_store_info() - assert info.num_storage_nodes == {storage_nodes} - assert info.num_compute_nodes == {compute_nodes} - assert info.num_cluster_nodes == {storage_nodes + compute_nodes} - assert info.cluster_master_ip is not None - assert info.storage_cluster_master_ip is not None - assert info.compute_cluster_master_ip is not None - assert info.cluster_master_port is not None - assert info.storage_cluster_master_port is not None - assert info.compute_cluster_master_port is not None + assert info.num_storage_nodes == {storage_nodes}, f"Expected {storage_nodes} storage nodes, but got {{ info.num_storage_nodes }}" + assert info.num_compute_nodes 
== {compute_nodes}, f"Expected {compute_nodes} compute nodes, but got {{ info.num_compute_nodes }}" + assert info.num_cluster_nodes == {storage_nodes + compute_nodes}, f"Expected {storage_nodes + compute_nodes} cluster nodes, but got {{ info.num_cluster_nodes }}" + assert info.cluster_master_ip is not None, f"Cluster master IP is None" + assert info.storage_cluster_master_ip is not None, f"Storage cluster master IP is None" + assert info.compute_cluster_master_ip is not None, f"Compute cluster master IP is None" + assert info.cluster_master_port is not None, f"Cluster master port is None" + assert info.storage_cluster_master_port is not None, f"Storage cluster master port is None" + assert info.compute_cluster_master_port is not None, f"Compute cluster master port is None" """ ), ] compute_cluster_config = VertexAiJobConfig( job_name=job_name, - container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, # different image for storage and compute + container_uri=GIGL_RELEASE_SRC_IMAGE_CPU, replica_count=compute_nodes, command=command, machine_type="n2-standard-8", ) storage_cluster_config = VertexAiJobConfig( job_name=job_name, - container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, + container_uri=GIGL_RELEASE_SRC_IMAGE_CPU, replica_count=storage_nodes, machine_type="n1-standard-4", command=command, diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index db6ad7cc7..aa7dba6f0 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -3,8 +3,6 @@ import unittest from unittest.mock import call, patch -from google.cloud.aiplatform_v1.types import CustomJobSpec - from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( @@ -129,11 +127,7 @@ def test_parse_cluster_spec_success(self): }, "task": {"type": "workerpool0", 
"index": 1, "trial": "trial-123"}, "environment": "cloud", - "job": { - "worker_pool_specs": [ - {"machine_spec": {"machine_type": "n1-standard-4"}} - ] - }, + "job": '{ "worker_pool_specs": [ {"machine_spec": {"machine_type": "n1-standard-4"}}]}', } ) @@ -150,11 +144,11 @@ def test_parse_cluster_spec_success(self): }, environment="cloud", task=TaskInfo(type="workerpool0", index=1, trial="trial-123"), - job=CustomJobSpec( - worker_pool_specs=[ + job={ + "worker_pool_specs": [ {"machine_spec": {"machine_type": "n1-standard-4"}} ] - ), + }, ) self.assertEqual(cluster_spec, expected_cluster_spec) From 1805a8d2d79a2d1ae7adf7ae650feb5c5e6c7a7b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 10 Oct 2025 17:20:39 +0000 Subject: [PATCH 32/33] [AUTOMATED] Bumped version to v0.0.10 --- dep_vars.env | 10 +++++----- python/gigl/__init__.py | 2 +- python/pyproject.toml | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dep_vars.env b/dep_vars.env index 81c35dd17..c5aa972b5 100644 --- a/dep_vars.env +++ b/dep_vars.env @@ -3,11 +3,11 @@ DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external- DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 -DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.9.dev20251007030821 -DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.9.dev20251007030821 -DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu-dataflow:0.0.9.dev20251007030821 
-DEFAULT_GIGL_RELEASE_DEV_WORKBENCH_IMAGE=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dev-workbench:0.0.9.dev20251007030821 -DEFAULT_GIGL_RELEASE_KFP_PIPELINE_PATH=gs://public-gigl/releases/pipelines/gigl-pipeline-0.0.9.dev20251007030821.yaml +DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.10 +DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.10 +DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu-dataflow:0.0.10 +DEFAULT_GIGL_RELEASE_DEV_WORKBENCH_IMAGE=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dev-workbench:0.0.10 +DEFAULT_GIGL_RELEASE_KFP_PIPELINE_PATH=gs://public-gigl/releases/pipelines/gigl-pipeline-0.0.10.yaml SPARK_31_TFRECORD_JAR_GCS_PATH=gs://public-gigl/tools/scala/spark_packages/spark-custom-tfrecord_2.12-0.5.0.jar SPARK_35_TFRECORD_JAR_GCS_PATH=gs://public-gigl/tools/scala/spark_packages/spark_3.5.0-custom-tfrecord_2.12-0.6.1.jar diff --git a/python/gigl/__init__.py b/python/gigl/__init__.py index f1b2287b4..9b36b86cf 100644 --- a/python/gigl/__init__.py +++ b/python/gigl/__init__.py @@ -1 +1 @@ -__version__ = "0.0.9.dev20251007030821" +__version__ = "0.0.10" diff --git a/python/pyproject.toml b/python/pyproject.toml index 5cda3a3cf..e1fde652e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta" name = "gigl" description = "GIgantic Graph Learning Library" readme = "README.md" -version = "0.0.9" +version = "0.0.10" requires-python = ">=3.9,<3.10" # Currently we only support python 3.9 as per deps setup below classifiers = [ "Programming Language :: Python", From 11390ea34845a5c77f08810fa3eef80a0ee456e0 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 10 Oct 2025 20:06:38 +0000 Subject: [PATCH 33/33] fix --- 
.../tests/integration/distributed/utils/networking_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tests/integration/distributed/utils/networking_test.py b/python/tests/integration/distributed/utils/networking_test.py index bbbf10283..e14eb101b 100644 --- a/python/tests/integration/distributed/utils/networking_test.py +++ b/python/tests/integration/distributed/utils/networking_test.py @@ -4,9 +4,9 @@ from parameterized import param, parameterized +from gigl.common.constants import DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService from gigl.env.pipelines_config import get_resource_config -from gigl.common.constants import GIGL_RELEASE_SRC_IMAGE_CPU class NetworkingUtlsIntegrationTest(unittest.TestCase): @@ -65,14 +65,14 @@ def test_get_graph_store_info(self, _, storage_nodes, compute_nodes): ] compute_cluster_config = VertexAiJobConfig( job_name=job_name, - container_uri=GIGL_RELEASE_SRC_IMAGE_CPU, + container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, replica_count=compute_nodes, command=command, machine_type="n2-standard-8", ) storage_cluster_config = VertexAiJobConfig( job_name=job_name, - container_uri=GIGL_RELEASE_SRC_IMAGE_CPU, + container_uri=DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU, replica_count=storage_nodes, machine_type="n1-standard-4", command=command,