diff --git a/api/python/ai/chronon/cli/compile/parse_teams.py b/api/python/ai/chronon/cli/compile/parse_teams.py index 9affd2bba8..5a86fbaee5 100644 --- a/api/python/ai/chronon/cli/compile/parse_teams.py +++ b/api/python/ai/chronon/cli/compile/parse_teams.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Optional, Union from ai.chronon.api.common.ttypes import ( + ClusterConfigProperties, ConfigProperties, EnvironmentVariables, ExecutionInfo, @@ -40,7 +41,6 @@ def import_module_from_file(file_path): def load_teams(conf_root: str, print: bool = True) -> Dict[str, Team]: - teams_file = os.path.join(conf_root, "teams.py") assert os.path.exists( @@ -69,9 +69,7 @@ def load_teams(conf_root: str, print: bool = True) -> Dict[str, Team]: return team_dict - def update_metadata(obj: Any, team_dict: Dict[str, Team]): - assert obj is not None, "Cannot update metadata None object" metadata = obj.metaData @@ -82,15 +80,15 @@ def update_metadata(obj: Any, team_dict: Dict[str, Team]): team = obj.metaData.team assert ( - team is not None + team is not None ), f"Team name is required in metadata for {name}. This usually set by compiler. Internal error." assert ( - team in team_dict + team in team_dict ), f"Team '{team}' not found in teams.py. Please add an entry 🙏" assert ( - _DEFAULT_CONF_TEAM in team_dict + _DEFAULT_CONF_TEAM in team_dict ), f"'{_DEFAULT_CONF_TEAM}' team not found in teams.py, please add an entry 🙏." # Only set the outputNamespace if it hasn't been set already @@ -99,6 +97,7 @@ def update_metadata(obj: Any, team_dict: Dict[str, Team]): if isinstance(obj, Join): join_namespace = obj.metaData.outputNamespace + # set the metadata for each join part and labelParts def set_group_by_metadata(join_part_gb, output_namespace): if join_part_gb is not None: @@ -124,6 +123,7 @@ def set_group_by_metadata(join_part_gb, output_namespace): merge_team_execution_info(metadata, team_dict, team) + def merge_team_execution_info(metadata: MetaData, team_dict: Dict[str, Team], team_name: str): default_team = team_dict.get(_DEFAULT_CONF_TEAM) if not metadata.executionInfo: @@ -143,6 +143,13 @@ def merge_team_execution_info(metadata: MetaData, team_dict: Dict[str, Team], te env_or_config_attribute=EnvOrConfigAttribute.CONFIG, ) + metadata.executionInfo.clusterConf = _merge_mode_maps( + default_team.clusterConf if default_team else {}, + team_dict[team_name].clusterConf, + metadata.executionInfo.clusterConf, + env_or_config_attribute=EnvOrConfigAttribute.CLUSTER_CONFIG, + ) + def _merge_maps(*maps: Optional[Dict[str, str]]): """ @@ -165,10 +172,12 @@ def _merge_maps(*maps: Optional[Dict[str, str]]): class EnvOrConfigAttribute(str, Enum): ENV = "modeEnvironments" CONFIG = "modeConfigs" + CLUSTER_CONFIG = "modeClusterConfigs" + def _merge_mode_maps( - *mode_maps: Optional[Union[EnvironmentVariables, ConfigProperties]], - env_or_config_attribute: EnvOrConfigAttribute, + *mode_maps: Optional[Union[EnvironmentVariables, ConfigProperties, ClusterConfigProperties]], + env_or_config_attribute: EnvOrConfigAttribute, ): """ Merges multiple environment variables into one - with the later maps overriding the earlier ones. 
@@ -178,14 +187,13 @@ def _merge_mode_maps( def push_common_to_modes(mode_map: Union[EnvironmentVariables, ConfigProperties], mode_key: EnvOrConfigAttribute): final_mode_map = deepcopy(mode_map) common = final_mode_map.common - modes = getattr(final_mode_map, mode_key) + modes = getattr(final_mode_map, mode_key) for _ in modes: modes[_] = _merge_maps( common, modes[_] ) return final_mode_map - filtered_mode_maps = [m for m in mode_maps if m] # Initialize the result with the first mode map @@ -212,7 +220,6 @@ def push_common_to_modes(mode_map: Union[EnvironmentVariables, ConfigProperties] all_modes_keys = list(set(current_modes_keys + incoming_modes_keys)) for mode in all_modes_keys: - current_mode = current_modes.get(mode, {}) # if the incoming_mode is not found, we NEED to default to incoming_common diff --git a/api/python/ai/chronon/repo/cluster.py b/api/python/ai/chronon/repo/cluster.py new file mode 100644 index 0000000000..1efe87993b --- /dev/null +++ b/api/python/ai/chronon/repo/cluster.py @@ -0,0 +1,65 @@ +import json + + +def generate_dataproc_cluster_config(num_workers, project_id, artifact_prefix, master_host_type="n2-highmem-64", + worker_host_type="n2-highmem-16", + subnetwork="default", idle_timeout="7200s", initialization_actions=None, tags=None): + """ + Create a configuration for a Dataproc cluster. + :return: A json string representing the configuration. + """ + if initialization_actions is None: + initialization_actions = [] + return json.dumps({ + "gceClusterConfig": { + "subnetworkUri": subnetwork, + "serviceAccount": "dataproc@" + project_id + ".iam.gserviceaccount.com", + "serviceAccountScopes": [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cloud.useraccounts.readonly", + "https://www.googleapis.com/auth/devstorage.read_write", + "https://www.googleapis.com/auth/logging.write" + ], + "metadata": { + "hive-version": "3.1.2", + "SPARK_BQ_CONNECTOR_URL": "gs://spark-lib/bigquery/spark-3.5-bigquery-0.42.1.jar", + "artifact_prefix": artifact_prefix.rstrip("/"), + }, + "tags": tags or [] + }, + "masterConfig": { + "numInstances": 1, + "machineTypeUri": master_host_type, + "diskConfig": { + "bootDiskType": "pd-standard", + "bootDiskSizeGb": 1024 + } + }, + "workerConfig": { + "numInstances": num_workers, + "machineTypeUri": worker_host_type, + "diskConfig": { + "bootDiskType": "pd-standard", + "bootDiskSizeGb": 64, + "numLocalSsds": 2 + } + }, + "softwareConfig": { + "imageVersion": "2.2.50-debian12", + "optionalComponents": [ + "FLINK", + "JUPYTER", + ], + "properties": { + + } + }, + "initializationActions": [{"executable_file": initialization_action} for initialization_action in ( + (initialization_actions or []) + [artifact_prefix.rstrip("/")+"/scripts/copy_java_security.sh"])], + "endpointConfig": { + "enableHttpPortAccess": True, + }, + "lifecycleConfig": { + "idleDeleteTtl": idle_timeout, + } + }) diff --git a/api/python/ai/chronon/types.py b/api/python/ai/chronon/types.py index 3b974e2e54..04118d1363 100644 --- a/api/python/ai/chronon/types.py +++ b/api/python/ai/chronon/types.py @@ -50,6 +50,7 @@ EnvironmentVariables = common.EnvironmentVariables ConfigProperties = common.ConfigProperties +ClusterConfigProperties = common.ClusterConfigProperties ExecutionInfo = common.ExecutionInfo TableDependency = common.TableDependency diff --git a/api/python/test/canary/teams.py b/api/python/test/canary/teams.py index 2cbf15c45f..2cf816a03b 100644 --- a/api/python/test/canary/teams.py +++ b/api/python/test/canary/teams.py @@ -1,6 +1,7 
@@ from ai.chronon.api.ttypes import Team +from ai.chronon.repo.cluster import generate_dataproc_cluster_config from ai.chronon.repo.constants import RunMode -from ai.chronon.types import ConfigProperties, EnvironmentVariables +from ai.chronon.types import ClusterConfigProperties, ConfigProperties, EnvironmentVariables default = Team( description="Default team", @@ -30,7 +31,6 @@ ), ) - test = Team( outputNamespace="test", env=EnvironmentVariables( @@ -64,6 +64,11 @@ "GCP_DATAPROC_CLUSTER_NAME": "zipline-canary-cluster", "GCP_BIGTABLE_INSTANCE_ID": "zipline-canary-instance", }, + modeEnvironments={ + RunMode.UPLOAD: { + "GCP_DATAPROC_CLUSTER_NAME": "zipline-transient-upload-cluster" + } + } ), conf=ConfigProperties( common={ @@ -100,6 +105,15 @@ } } ), + clusterConf=ClusterConfigProperties( + modeClusterConfigs={ + RunMode.UPLOAD: { + "dataproc.config": generate_dataproc_cluster_config(2, "canary-443022", "gs://zipline-artifacts-canary", + worker_host_type="n2-highmem-4", + master_host_type="n2-highmem-8") + } + } + ), ) aws = Team( diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift index 698f0af733..c670a9f545 100644 --- a/api/thrift/api.thrift +++ b/api/thrift/api.thrift @@ -517,6 +517,7 @@ struct Team { 20: optional common.EnvironmentVariables env 21: optional common.ConfigProperties conf + 22: optional common.ClusterConfigProperties clusterConf } enum DataModel { diff --git a/api/thrift/common.thrift b/api/thrift/common.thrift index 5691af998d..f6caf4bebd 100644 --- a/api/thrift/common.thrift +++ b/api/thrift/common.thrift @@ -57,6 +57,22 @@ struct ConfigProperties { 2: map> modeConfigs = {} } +/** +* Cluster config for different modes of execution as a json string - with "common" applying to all modes +* These are settings for creating a new cluster for running the job +* +* these confs are layered in order of priority +* 1. company file defaults specified in teams.py - in the "common" team +* 2. team wide defaults that apply to all objects in the team folder +* 3. 
object specific defaults - applies only to the object that declares them
+*
+* All the maps from the above three places are merged to create the final cluster config
+**/
+struct ClusterConfigProperties {
+    1: map<string, string> common = {}
+    2: map<string, map<string, string>> modeClusterConfigs = {}
+}
+
 struct TableInfo {
     // fully qualified table name
     1: optional string table
@@ -122,6 +138,7 @@ struct ExecutionInfo {
     2: optional ConfigProperties conf
     3: optional i64 dependencyPollIntervalMillis
     4: optional i64 healthCheckIntervalMillis
+    5: optional ClusterConfigProperties clusterConf
 
     # relevant for batch jobs
     # temporal workflow nodes maintain their own cron schedule
diff --git a/cloud_gcp/BUILD.bazel b/cloud_gcp/BUILD.bazel
index 28464cd7d5..6b236fa217 100644
--- a/cloud_gcp/BUILD.bazel
+++ b/cloud_gcp/BUILD.bazel
@@ -30,6 +30,7 @@ shared_deps = [
     maven_artifact("com.google.api:gax"),
     maven_artifact("com.google.guava:guava"),
     maven_artifact("com.google.protobuf:protobuf-java"),
+    maven_artifact("com.google.protobuf:protobuf-java-util"),
    maven_artifact("org.yaml:snakeyaml"),
     maven_artifact("io.grpc:grpc-netty-shaded"),
     maven_artifact("org.slf4j:slf4j-api"),
diff --git a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala
index 798d02b58e..26ce1e0c34 100644
--- a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala
+++ b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala
@@ -3,9 +3,11 @@ import ai.chronon.spark.submission.JobSubmitterConstants._
 import ai.chronon.spark.submission.{JobSubmitter, JobType, FlinkJob => TypeFlinkJob, SparkJob => TypeSparkJob}
 import com.google.api.gax.rpc.ApiException
 import com.google.cloud.dataproc.v1._
+import com.google.protobuf.util.JsonFormat
 import org.apache.hadoop.fs.Path
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
+import org.slf4j.LoggerFactory
 import org.yaml.snakeyaml.Yaml
 
 import scala.io.Source
@@ -348,16 +350,19 @@
 
   }
 
-  private def initializeDataprocSubmitter(): DataprocSubmitter = {
+  private def initializeDataprocSubmitter(clusterName: String,
+                                          maybeClusterConfig: Option[Map[String, String]]): DataprocSubmitter = {
     val projectId = sys.env.getOrElse(GcpProjectIdEnvVar, throw new Exception(s"$GcpProjectIdEnvVar not set"))
     val region = sys.env.getOrElse(GcpRegionEnvVar, throw new Exception(s"$GcpRegionEnvVar not set"))
-    val clusterName = sys.env
-      .getOrElse(GcpDataprocClusterNameEnvVar, throw new Exception(s"$GcpDataprocClusterNameEnvVar not set"))
+    val dataprocClient = ClusterControllerClient.create(
+      ClusterControllerSettings.newBuilder().setEndpoint(s"$region-dataproc.googleapis.com:443").build())
+
+    val submitterClusterName = getOrCreateCluster(clusterName, maybeClusterConfig, projectId, region, dataprocClient)
     val submitterConf = SubmitterConf(
       projectId,
       region,
-      clusterName
+      submitterClusterName
     )
     DataprocSubmitter(submitterConf)
   }
@@ -517,6 +522,137 @@ object DataprocSubmitter {
     }
   }
 
+  private[cloud_gcp] def getOrCreateCluster(clusterName: String,
+                                            maybeClusterConfig: Option[Map[String, String]],
+                                            projectId: String,
+                                            region: String,
+                                            dataprocClient: ClusterControllerClient): String = {
+    if (clusterName != "") {
+      try {
+        val cluster = dataprocClient.getCluster(projectId, region, clusterName)
+        if (cluster != null && cluster.getStatus.getState == ClusterStatus.State.RUNNING) {
+          println(s"Dataproc cluster $clusterName already exists and is running.")
+ clusterName + } else if (maybeClusterConfig.isDefined && maybeClusterConfig.get.contains("dataproc.config")) { + // Print to stderr so that it flushes immediately + System.err.println( + s"Dataproc cluster $clusterName does not exist or is not running. Creating it with the provided config.") + createDataprocCluster(clusterName, + projectId, + region, + dataprocClient, + maybeClusterConfig.get.getOrElse("dataproc.config", "")) + } else { + throw new Exception(s"Dataproc cluster $clusterName does not exist and no cluster config provided.") + } + } catch { + case _: ApiException if maybeClusterConfig.isDefined && maybeClusterConfig.get.contains("dataproc.config") => + // Print to stderr so that it flushes immediately + System.err.println(s"Dataproc cluster $clusterName does not exist. Creating it with the provided config.") + createDataprocCluster(clusterName, + projectId, + region, + dataprocClient, + maybeClusterConfig.get.getOrElse("dataproc.config", "")) + case _: ApiException => + throw new Exception(s"Dataproc cluster $clusterName does not exist and no cluster config provided.") + } + } else if (maybeClusterConfig.isDefined && maybeClusterConfig.get.contains("dataproc.config")) { + // Print to stderr so that it flushes immediately + System.err.println(s"Creating a transient dataproc cluster based on config.") + val transientClusterName = s"zipline-${java.util.UUID.randomUUID()}" + createDataprocCluster(transientClusterName, + projectId, + region, + dataprocClient, + maybeClusterConfig.get.getOrElse("dataproc.config", "")) + } else { + throw new Exception( + s"$GcpDataprocClusterNameEnvVar is not set and no cluster config was provided. " + + s"Please set $GcpDataprocClusterNameEnvVar or provide a cluster config in teams.py.") + } + } + + /** Creates a Dataproc cluster with the given name, project ID, region, and configuration. + * + * @param clusterName The name of the cluster to create. + * @param projectId The GCP project ID. + * @param region The region where the cluster will be created. + * @param dataprocClient The ClusterControllerClient to interact with the Dataproc API. + * @param clusterConfigStr The JSON string representing the cluster configuration. + * @return The name of the created cluster. 
+ */ + private[cloud_gcp] def createDataprocCluster(clusterName: String, + projectId: String, + region: String, + dataprocClient: ClusterControllerClient, + clusterConfigStr: String): String = { + + val builder = ClusterConfig.newBuilder() + val clusterConfig = + try { + JsonFormat.parser().merge(clusterConfigStr, builder) + builder.build() + } catch { + case e: Exception => + throw new IllegalArgumentException(s"Failed to parse JSON: ${e.getMessage}", e) + } + + val cluster: Cluster = Cluster + .newBuilder() + .setClusterName(clusterName) + .setProjectId(projectId) + .setConfig(clusterConfig) + .build() + + val createRequest = CreateClusterRequest + .newBuilder() + .setProjectId(projectId) + .setRegion(region) + .setCluster(cluster) + .build() + + // Asynchronously create the cluster and wait for it to be ready + try { + val operation = dataprocClient + .createClusterAsync(createRequest) + .get(15, java.util.concurrent.TimeUnit.MINUTES) + if (operation == null) { + throw new RuntimeException("Failed to create Dataproc cluster.") + } + println(s"Created Dataproc cluster: $clusterName") + } catch { + case e: java.util.concurrent.TimeoutException => + throw new RuntimeException(s"Timeout waiting for cluster creation: ${e.getMessage}", e) + case e: Exception => + throw new RuntimeException(s"Error creating Dataproc cluster: ${e.getMessage}", e) + } + + // Check status of the cluster creation + var currentStatus = dataprocClient.getCluster(projectId, region, clusterName).getStatus + var currentState = currentStatus.getState + while ( + currentState != ClusterStatus.State.RUNNING && + currentState != ClusterStatus.State.ERROR && + currentState != ClusterStatus.State.STOPPING + ) { + println(s"Waiting for Dataproc cluster $clusterName to be in RUNNING state. 
Current state: $currentState") + Thread.sleep(30000) // Wait for 30 seconds before checking again + currentStatus = dataprocClient.getCluster(projectId, region, clusterName).getStatus + currentState = currentStatus.getState + } + currentState match { + case ClusterStatus.State.RUNNING => + println(s"Dataproc cluster $clusterName is running.") + clusterName + case ClusterStatus.State.ERROR => + throw new RuntimeException( + s"Failed to create Dataproc cluster $clusterName: ERROR state: ${currentStatus.toString}") + case _ => + throw new RuntimeException(s"Dataproc cluster $clusterName is in unexpected state: $currentState.") + } + } + private[cloud_gcp] def run(args: Array[String], submitter: DataprocSubmitter, envMap: Map[String, Option[String]] = Map.empty): Unit = { @@ -596,7 +732,10 @@ object DataprocSubmitter { } def main(args: Array[String]): Unit = { - val submitter = initializeDataprocSubmitter() + val clusterName = sys.env + .getOrElse(GcpDataprocClusterNameEnvVar, "") + val maybeClusterConfig = JobSubmitter.getClusterConfig(args) + val submitter = initializeDataprocSubmitter(clusterName, maybeClusterConfig) val envMap = Map( GcpBigtableInstanceIdEnvVar -> sys.env.get(GcpBigtableInstanceIdEnvVar), GcpProjectIdEnvVar -> sys.env.get(GcpProjectIdEnvVar) diff --git a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala index d7026f963a..584f85514d 100644 --- a/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala +++ b/cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala @@ -3,6 +3,9 @@ package ai.chronon.integrations.cloud_gcp import ai.chronon.spark import ai.chronon.spark.submission import ai.chronon.spark.submission.JobSubmitterConstants._ +import com.google.api.core.ApiFuture +import com.google.api.gax.longrunning.{OperationFuture, OperationSnapshot} +import com.google.api.gax.retrying.RetryingFuture import com.google.api.gax.rpc.UnaryCallable import com.google.cloud.dataproc.v1.JobControllerClient.ListJobsPagedResponse import com.google.cloud.dataproc.v1._ @@ -15,9 +18,18 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.mockito.MockitoSugar import java.nio.file.Paths +import java.util.concurrent.TimeUnit import scala.jdk.CollectionConverters._ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar { + def setEnv(key: String, value: String): Unit = { + val env = System.getenv() + val field = env.getClass.getDeclaredField("m") + field.setAccessible(true) + val writableEnv = field.get(env).asInstanceOf[java.util.Map[String, String]] + writableEnv.put(key, value) + } + it should "test buildFlinkJob with the expected flinkStateUri and savepointUri" in { val submitter = new DataprocSubmitter(jobControllerClient = mock[JobControllerClient], conf = SubmitterConf("test-project", "test-region", "test-cluster")) @@ -715,6 +727,92 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar { } + it should "create a Dataproc cluster successfully with a given config" in { + val mockDataprocClient = mock[ClusterControllerClient] + + val mockOperationFuture = mock[OperationFuture[Cluster, ClusterOperationMetadata]] + val mockRetryingFuture = mock[RetryingFuture[OperationSnapshot]] + val mockMetadataFuture = mock[ApiFuture[ClusterOperationMetadata]] + val mockCluster = Cluster + .newBuilder() + .setStatus(ClusterStatus.newBuilder().setState(ClusterStatus.State.RUNNING)) 
+ .build() + + when(mockDataprocClient.createClusterAsync(any[CreateClusterRequest])) + .thenReturn(mockOperationFuture) + when(mockOperationFuture.getPollingFuture).thenReturn(mockRetryingFuture) + when(mockOperationFuture.peekMetadata()).thenReturn(mockMetadataFuture) + when(mockOperationFuture.get(anyLong(), any[TimeUnit])).thenReturn(mockCluster) + + when(mockDataprocClient.createClusterAsync(any[CreateClusterRequest])) + .thenReturn(mockOperationFuture) + + when(mockDataprocClient.getCluster(any[String], any[String], any[String])).thenReturn(mockCluster) + + val region = "test-region" + val projectId = "test-project" + + val clusterConfigStr = """{ + "masterConfig": { + "numInstances": 1, + "machineTypeUri": "n1-standard-4" + }, + "workerConfig": { + "numInstances": 2, + "machineTypeUri": "n1-standard-4" + } + }""" + + val clusterName = + DataprocSubmitter.getOrCreateCluster("", Option(Map("dataproc.config" -> clusterConfigStr)), projectId, region, mockDataprocClient) + + assert(clusterName.startsWith("zipline-")) + verify(mockDataprocClient).createClusterAsync(any()) + } + + it should "not create a new cluster if given name exists" in { + val mockDataprocClient = mock[ClusterControllerClient] + + val mockOperationFuture = mock[OperationFuture[Cluster, ClusterOperationMetadata]] + val mockRetryingFuture = mock[RetryingFuture[OperationSnapshot]] + val mockMetadataFuture = mock[ApiFuture[ClusterOperationMetadata]] + val mockCluster = Cluster + .newBuilder() + .setStatus(ClusterStatus.newBuilder().setState(ClusterStatus.State.RUNNING)) + .build() + + when(mockDataprocClient.createClusterAsync(any[CreateClusterRequest])) + .thenReturn(mockOperationFuture) + when(mockOperationFuture.getPollingFuture).thenReturn(mockRetryingFuture) + when(mockOperationFuture.peekMetadata()).thenReturn(mockMetadataFuture) + when(mockOperationFuture.get(anyLong(), any[TimeUnit])).thenReturn(mockCluster) + + when(mockDataprocClient.createClusterAsync(any[CreateClusterRequest])) + .thenReturn(mockOperationFuture) + + when(mockDataprocClient.getCluster(any[String], any[String], any[String])).thenReturn(mockCluster) + + val region = "test-region" + val projectId = "test-project" + + val clusterConfigStr = """{ + "masterConfig": { + "numInstances": 1, + "machineTypeUri": "n1-standard-4" + }, + "workerConfig": { + "numInstances": 2, + "machineTypeUri": "n1-standard-4" + } + }""" + + val clusterName = + DataprocSubmitter.getOrCreateCluster("test-cluster", Option(Map("dataproc.config" -> clusterConfigStr)), projectId, region, mockDataprocClient) + + assert(clusterName.equals("test-cluster")) + verify(mockDataprocClient, never()).createClusterAsync(any()) + } + it should "test getZiplineVersionOfDataprocJob successfully" in { val jobId = "mock-job-id" val mockJob = mock[Job] diff --git a/spark/src/main/scala/ai/chronon/spark/submission/JobSubmitter.scala b/spark/src/main/scala/ai/chronon/spark/submission/JobSubmitter.scala index 99ebe0cd4d..3c61fbfc36 100644 --- a/spark/src/main/scala/ai/chronon/spark/submission/JobSubmitter.scala +++ b/spark/src/main/scala/ai/chronon/spark/submission/JobSubmitter.scala @@ -98,6 +98,37 @@ object JobSubmitter { modeConfigProperties } + + def getClusterConfig(args: Array[String]): Option[Map[String, String]] = { + val maybeMetadata = getMetadata(args) + val clusterConfig = if (maybeMetadata.isDefined) { + val metadata = maybeMetadata.get + + val executionInfo = Option(metadata.getExecutionInfo) + + if (executionInfo.isEmpty) { + None + } else { + val originalMode = getArgValue(args, 
OriginalModeArgKeyword) + + (Option(executionInfo.get.clusterConf), originalMode) match { + case (Some(clusterConf), Some(mode)) => + val modeConfig = + if (clusterConf.isSetModeClusterConfigs && clusterConf.getModeClusterConfigs.containsKey(mode)) { + clusterConf.getModeClusterConfigs.get(mode).toScala + } else if (clusterConf.isSetCommon) { + clusterConf.getCommon.toScala + } else { + Map[String, String]() + } + Option(modeConfig) + case _ => None + } + } + } else None + clusterConfig + } + } abstract class JobAuth {
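
For context on how the layered clusterConf described in common.thrift is expected to resolve, a rough illustrative sketch follows (plain Python dicts standing in for the ClusterConfigProperties struct, and not the actual _merge_mode_maps implementation): later layers override earlier ones, and within each layer the mode-specific map overrides "common".

# Illustrative sketch only: approximates the precedence applied in parse_teams.py
# (default team -> team -> object), using plain dicts instead of thrift structs.
def layer_cluster_conf(mode, *layers):
    merged = {}
    for layer in layers:  # later layers override earlier ones
        common = (layer or {}).get("common", {})
        per_mode = (layer or {}).get("modeClusterConfigs", {}).get(mode, {})
        merged.update(common)    # "common" applies to every mode
        merged.update(per_mode)  # mode-specific keys win within a layer
    return merged

# Hypothetical inputs: the team's upload-mode config overrides the company default.
default_team = {"common": {"dataproc.config": "<company default json>"}}
team = {"modeClusterConfigs": {"upload": {"dataproc.config": "<team upload json>"}}}
obj = {}
assert layer_cluster_conf("upload", default_team, team, obj)["dataproc.config"] == "<team upload json>"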