Merged (29 commits; changes shown from 11 commits)
b9e2bed  checkpoint 1 (david-zlai, Feb 26, 2025)
cfd2c7e  add additional confs (david-zlai, Feb 26, 2025)
aedf7e8  add some logging and bootstrap actions (david-zlai, Feb 26, 2025)
8b6f853  successful join job (david-zlai, Feb 26, 2025)
5e65b01  cleaned up code (david-zlai, Feb 27, 2025)
4907271  made bootstrap an explicit script (david-zlai, Feb 27, 2025)
d698de8  Add shell script (david-zlai, Feb 27, 2025)
24f96d5  ignore test (david-zlai, Feb 27, 2025)
9219d42  tests (david-zlai, Feb 27, 2025)
a36db4a  fixes (david-zlai, Feb 27, 2025)
0640477  remove unneeded file (david-zlai, Feb 27, 2025)
236cb7c  Support submission to existing cluster. (david-zlai, Mar 3, 2025)
457e81a  Merge branch 'main' into davidhan/emr_submitter (david-zlai, Mar 3, 2025)
7922adf  Fix test (david-zlai, Mar 3, 2025)
cab5877  ignore test (david-zlai, Mar 3, 2025)
f0a8fb5  couple of more tests (david-zlai, Mar 3, 2025)
057ccc3  support files arg and createcluster and jobflowid (david-zlai, Mar 6, 2025)
cb0ea83  jobflowid should come from env var (david-zlai, Mar 6, 2025)
dbbbcaa  reduce idle timeout to 1h (david-zlai, Mar 6, 2025)
0049d0d  update role name (david-zlai, Mar 6, 2025)
f514245  rename to cluster id (david-zlai, Mar 6, 2025)
ba46125  Merge branch 'main' into davidhan/emr_submitter (david-zlai, Mar 6, 2025)
2fdff7b  resync maven (david-zlai, Mar 6, 2025)
a791589  Merge branch 'main' into davidhan/emr_submitter (david-zlai, Mar 7, 2025)
0a1f8b7  rerun maven (david-zlai, Mar 7, 2025)
1977ce3  todos (david-zlai, Mar 7, 2025)
006645c  pr comments (david-zlai, Mar 7, 2025)
f677231  fix test (david-zlai, Mar 7, 2025)
2caa711  fmt (david-zlai, Mar 7, 2025)
6 changes: 6 additions & 0 deletions cloud_aws/BUILD.bazel
@@ -12,6 +12,7 @@ scala_library(
maven_artifact("software.amazon.awssdk:aws-core"),
maven_artifact("software.amazon.awssdk:sdk-core"),
maven_artifact("software.amazon.awssdk:utils"),
maven_artifact("software.amazon.awssdk:emr"),
maven_artifact("com.google.guava:guava"),
maven_artifact("org.slf4j:slf4j-api"),
scala_artifact_with_suffix("org.scala-lang.modules:scala-collection-compat"),
@@ -26,16 +27,21 @@ test_deps = [
":cloud_aws_lib",
"//api:lib",
"//online:lib",
"//spark:lib",
maven_artifact("software.amazon.awssdk:dynamodb"),
maven_artifact("software.amazon.awssdk:regions"),
maven_artifact("software.amazon.awssdk:aws-core"),
maven_artifact("software.amazon.awssdk:sdk-core"),
maven_artifact("software.amazon.awssdk:utils"),
maven_artifact("software.amazon.awssdk:auth"),
maven_artifact("software.amazon.awssdk:emr"),
maven_artifact("software.amazon.awssdk:identity-spi"),
scala_artifact_with_suffix("org.typelevel:cats-core"),
maven_artifact("com.amazonaws:DynamoDBLocal"),
scala_artifact_with_suffix("com.chuusai:shapeless"),
maven_artifact("junit:junit"),
scala_artifact_with_suffix("org.mockito:mockito-scala"),
maven_artifact("org.mockito:mockito-core"),
] + _CIRCE_DEPS + _SCALA_TEST_DEPS

scala_library(
244 changes: 244 additions & 0 deletions cloud_aws/src/main/scala/ai/chronon/integrations/aws/EmrSubmitter.scala
@@ -0,0 +1,244 @@
package ai.chronon.integrations.aws

import ai.chronon.integrations.aws.EmrSubmitter.DefaultClusterIdleTimeout
import ai.chronon.integrations.aws.EmrSubmitter.DefaultClusterInstanceCount
import ai.chronon.integrations.aws.EmrSubmitter.DefaultClusterInstanceType
import ai.chronon.spark.JobSubmitter
import ai.chronon.spark.JobSubmitterConstants._
import ai.chronon.spark.JobType
import ai.chronon.spark.SparkJob
import ai.chronon.spark.{SparkJob => TypeSparkJob}
import software.amazon.awssdk.services.emr.EmrClient
import software.amazon.awssdk.services.emr.model.Application
import software.amazon.awssdk.services.emr.model.AutoTerminationPolicy
import software.amazon.awssdk.services.emr.model.BootstrapActionConfig
import software.amazon.awssdk.services.emr.model.CancelStepsRequest
import software.amazon.awssdk.services.emr.model.Configuration
import software.amazon.awssdk.services.emr.model.DescribeStepRequest
import software.amazon.awssdk.services.emr.model.HadoopJarStepConfig
import software.amazon.awssdk.services.emr.model.JobFlowInstancesConfig
import software.amazon.awssdk.services.emr.model.RunJobFlowRequest
import software.amazon.awssdk.services.emr.model.ScriptBootstrapActionConfig
import software.amazon.awssdk.services.emr.model.StepConfig
import scala.collection.JavaConverters._

class EmrSubmitter(customerId: String, emrClient: EmrClient) extends JobSubmitter {

private val ClusterApplications = List(
"Flink",
"Zeppelin",
"JupyterEnterpriseGateway",
"Hive",
"Hadoop",
"Livy",
"Spark"
)

private val EmrReleaseLabel = "emr-7.2.0"
Contributor:
Something to flag - with EMR 7.2.0 we'll be on Flink 1.20.0, which is different from our GCP Flink version. So we'll need to either build jars with 1.17 and 1.20, downgrade EMR, or install Flink 1.20 manually on our GCP clusters.

Contributor Author:
Thanks for flagging. I'm leaning towards keeping the EMR release label hardcoded and set for users. Don't want users to accidentally set the wrong EMR release and run into weird issues.


// Customer specific infra configurations
private val CustomerToSubnetIdMap = Map(
"canary" -> "subnet-085b2af531b50db44"
)
private val CustomerToSecurityGroupIdMap = Map(
"canary" -> "sg-04fb79b5932a41298"
)

private val CopyS3FilesToMntScript = "copy_s3_files.sh"

override def submit(jobType: JobType,
jobProperties: Map[String, String],
files: List[String],
args: String*): String = {

val runJobFlowRequestBuilder = RunJobFlowRequest
Collaborator:
Could we split out the cluster creation into its own utility class / methods?

Contributor Author:
Yeah we can, but it's pretty piecemeal at the moment. For a single RunJobFlowRequest, the cluster creation configurations include:

  • autoTerminationPolicy
  • configurations
  • applications
  • instances
  • releaseLabel

Contributor Author:
Broke out to a new function. Can do more cleanup though if we want.

Collaborator:
I think eventually we will want this as a separate class / utility object, and make that be a separate verb in run.py or the control plane. That way we can separate the job submission from the cluster creation. But this is fine for now!

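A minimal sketch of how that split could look, assuming a hypothetical helper object (the names and signature below are illustrative, not part of this PR):

```scala
import software.amazon.awssdk.services.emr.model.Application
import software.amazon.awssdk.services.emr.model.AutoTerminationPolicy
import software.amazon.awssdk.services.emr.model.JobFlowInstancesConfig
import software.amazon.awssdk.services.emr.model.RunJobFlowRequest

// Hypothetical helper: builds only the cluster-level pieces of a RunJobFlowRequest,
// leaving step submission to EmrSubmitter.submit.
object EmrClusterBuilder {
  def clusterRequest(instanceType: String,
                     instanceCount: Int,
                     idleTimeoutSeconds: Long,
                     releaseLabel: String): RunJobFlowRequest.Builder =
    RunJobFlowRequest
      .builder()
      .name(s"cluster-${java.util.UUID.randomUUID}")
      .autoTerminationPolicy(
        AutoTerminationPolicy.builder().idleTimeout(idleTimeoutSeconds).build())
      .applications(Application.builder().name("Spark").build())
      .releaseLabel(releaseLabel)
      .instances(
        JobFlowInstancesConfig
          .builder()
          .masterInstanceType(instanceType)
          .slaveInstanceType(instanceType)
          .instanceCount(instanceCount)
          .keepJobFlowAliveWhenNoSteps(true)
          .build())
}
```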
.builder()
.name(s"job-${java.util.UUID.randomUUID.toString}")

// Cluster infra configurations:
val customerSecurityGroupId = CustomerToSecurityGroupIdMap.getOrElse(
customerId,
throw new RuntimeException(s"No security group id found for $customerId"))
runJobFlowRequestBuilder
.autoTerminationPolicy(
AutoTerminationPolicy
.builder()
.idleTimeout(jobProperties.getOrElse(ClusterIdleTimeout, s"$DefaultClusterIdleTimeout").toLong)
.build())
.configurations(
Configuration.builder
.classification("spark-hive-site")
.properties(Map(
"hive.metastore.client.factory.class" -> "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory").asJava)
.build()
)
.applications(ClusterApplications.map(app => Application.builder().name(app).build()): _*)
// TODO: Could make this generalizable. Have logs saved where users want it
.logUri(s"s3://zipline-warehouse-${customerId}/emr/")
.instances(
JobFlowInstancesConfig
.builder()
// We may want to make master and slave instance types different in the future
.masterInstanceType(jobProperties.getOrElse(ClusterInstanceType, DefaultClusterInstanceType))
.slaveInstanceType(jobProperties.getOrElse(ClusterInstanceType, DefaultClusterInstanceType))
// Hack: We hardcode the subnet ID and sg id for each customer of Zipline. The subnet gets created from
// Terraform so we'll need to be careful that these don't get accidentally destroyed.
.ec2SubnetId(
CustomerToSubnetIdMap.getOrElse(customerId,
throw new RuntimeException(s"No subnet id found for $customerId")))
.emrManagedMasterSecurityGroup(customerSecurityGroupId)
.emrManagedSlaveSecurityGroup(customerSecurityGroupId)
.instanceCount(jobProperties.getOrElse(ClusterInstanceCount, DefaultClusterInstanceCount).toInt)
Contributor Author:
TODO: This is creating static clusters at the moment.

I should change this to use instance groups like the Terraform does: https://github.com/zipline-ai/infrastructure/blob/main/base-aws/emr.tf#L29-L34

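For reference, a rough sketch of what the instance-group variant could look like with the same SDK; the instance types and counts below are placeholders, not what the PR currently configures:

```scala
import software.amazon.awssdk.services.emr.model.InstanceGroupConfig
import software.amazon.awssdk.services.emr.model.InstanceRoleType
import software.amazon.awssdk.services.emr.model.JobFlowInstancesConfig

// One master group plus a core group, instead of master/slave instance types.
val masterGroup = InstanceGroupConfig
  .builder()
  .instanceRole(InstanceRoleType.MASTER)
  .instanceType("m5.xlarge") // placeholder
  .instanceCount(1)
  .build()

val coreGroup = InstanceGroupConfig
  .builder()
  .instanceRole(InstanceRoleType.CORE)
  .instanceType("m5.xlarge") // placeholder
  .instanceCount(2)
  .build()

val instances = JobFlowInstancesConfig
  .builder()
  .instanceGroups(masterGroup, coreGroup)
  .keepJobFlowAliveWhenNoSteps(true)
  .build()
```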
.keepJobFlowAliveWhenNoSteps(true) // Keep the cluster alive after the job is done
.build())
// TODO: need to double check that this is how we want our role names to be
Collaborator:
is this still a todo?

Contributor Author:
not anymore. removing

.serviceRole(s"zipline_${customerId}_emr_service_role")
.jobFlowRole(s"zipline_${customerId}_emr_profile")
Contributor Author:
@chewy-zlai @tchow-zlai - thinking ahead, but for each individual customer, would it be suitable to expect IAM roles in this format?

Collaborator:
Yes. This is how they should be named via the Terraform.

Collaborator:
Actually, that should be zipline_${customerId}_emr_profile_role

.releaseLabel(EmrReleaseLabel)

// Add single step (spark job) to run:
val sparkSubmitArgs =
Seq("spark-submit",
"--class",
jobProperties(MainClass),
jobProperties(JarURI)) ++ args // For EMR, we explicitly spark-submit the job
val stepConfig = StepConfig
.builder()
.name("Zipline Job")
.actionOnFailure("CANCEL_AND_WAIT") // want the cluster to not terminate if the step fails
.hadoopJarStep(
jobType match {
case SparkJob =>
HadoopJarStepConfig
.builder()
// Using command-runner.jar from AWS:
// https://docs.aws.amazon.com/en_us/emr/latest/ReleaseGuide/emr-spark-submit-step.html
.jar("command-runner.jar")
.args(sparkSubmitArgs: _*)
.build()
// TODO: add flink
case _ => throw new IllegalArgumentException("Unsupported job type")
}
)
.build()
runJobFlowRequestBuilder.steps(stepConfig)

// Add bootstrap actions if any
if (files.nonEmpty) {
val artifactsBucket = s"s3://zipline-artifacts-${customerId}/"
val bootstrapActionConfig = BootstrapActionConfig
.builder()
.name("EMR Submitter: Copy S3 Files")
.scriptBootstrapAction(
ScriptBootstrapActionConfig
.builder()
.path(artifactsBucket + CopyS3FilesToMntScript)
.args(files: _*)
.build())
.build()
runJobFlowRequestBuilder.bootstrapActions(bootstrapActionConfig)
Contributor Author (david-zlai, Feb 27, 2025):
I tried not doing the bootstrap and just passing the S3 URIs, but our code in Chronon expects the files to be locally present: https://docs.google.com/document/d/1bQnTOK8P3Spga2sm9Y1eacQY8H3uwv4asrCnuV_jOws/edit?tab=t.0#heading=h.3urfz1in6chi

@tchow-zlai

Collaborator:
What about additional-confs.yaml?

Contributor Author:
That file was coded into the copy_s3_files.sh script, but I'm making some changes now to explicitly pass that.

Contributor Author:
Removed that shell script and instead just running the aws s3 cp command as part of the direct args to command-runner.jar.

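A hedged sketch of that alternative (a command-runner.jar step that runs aws s3 cp directly); the exact args are assumptions, since the final version of this change isn't shown in this diff:

```scala
// Hypothetical: copy the files in a command-runner.jar step instead of a bootstrap action.
val copyFilesArgs: Seq[String] =
  Seq("bash", "-c", files.map(f => s"aws s3 cp $f /mnt/zipline/").mkString(" && "))

val copyFilesStep = StepConfig
  .builder()
  .name("Copy S3 Files")
  .actionOnFailure("CANCEL_AND_WAIT")
  .hadoopJarStep(
    HadoopJarStepConfig
      .builder()
      .jar("command-runner.jar")
      .args(copyFilesArgs: _*)
      .build())
  .build()
```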
}

val jobFlowResponse = emrClient.runJobFlow(
runJobFlowRequestBuilder.build()
)

jobFlowResponse.jobFlowId()
}

override def status(jobId: String): Unit = {
val describeStepResponse = emrClient.describeStep(DescribeStepRequest.builder().stepId(jobId).build())
val status = describeStepResponse.step().status()
println(status)
}

override def kill(jobId: String): Unit = {
emrClient.cancelSteps(CancelStepsRequest.builder().stepIds(jobId).build())
}
Contributor (coderabbitai bot, Mar 6, 2025):
⚠️ Potential issue

Status/kill methods ignore cluster IDs.
AWS requires both clusterId and stepId. This might fail if you pass a cluster ID here.

Contributor Author:
job is an overloaded term huh

Contributor (coderabbitai bot):
@david-zlai Yes, exactly! The term "job" is definitely overloaded in this context. In AWS EMR:

  • "job" can refer to a Step (a single Spark execution)
  • "job" can also refer to a JobFlow/Cluster
  • The JobSubmitter interface expects a single jobId, but EMR operations need both clusterId and stepId

This terminology mismatch creates the potential issue I flagged. You could address this by either:

  1. Modifying the interface to support both IDs
  2. Using a composite ID format like "clusterId:stepId"
  3. Maintaining an internal mapping in the EmrSubmitter

Which approach would work best for your implementation?

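As an illustration of option 2 (a composite "clusterId:stepId" identifier), a minimal sketch might look like this; the format and method bodies are assumptions, not what this PR implements:

```scala
// Hypothetical composite-ID variant of the status/kill methods.
override def status(jobId: String): Unit = {
  val Array(clusterId, stepId) = jobId.split(":", 2)
  val resp = emrClient.describeStep(
    DescribeStepRequest.builder().clusterId(clusterId).stepId(stepId).build())
  println(resp.step().status())
}

override def kill(jobId: String): Unit = {
  val Array(clusterId, stepId) = jobId.split(":", 2)
  emrClient.cancelSteps(
    CancelStepsRequest.builder().clusterId(clusterId).stepIds(stepId).build())
}
```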

}

object EmrSubmitter {
def apply(): EmrSubmitter = {
val customerId = sys.env.getOrElse("CUSTOMER_ID", throw new Exception("CUSTOMER_ID not set")).toLowerCase

new EmrSubmitter(customerId,
EmrClient
.builder()
.build())
}

private val ClusterInstanceTypeArgKeyword = "--cluster-instance-type"
private val ClusterInstanceCountArgKeyword = "--cluster-instance-count"
private val ClusterIdleTimeoutArgKeyword = "--cluster-idle-timeout"

private val DefaultClusterInstanceType = "m5.xlarge"
private val DefaultClusterInstanceCount = "3"
private val DefaultClusterIdleTimeout = 60 * 60 * 24 * 2 // 2 days in seconds
Contributor Author:
2 days of idleness, then the cluster terminates itself. Is that too long for a default?

Contributor:
Yeah, this is on the high end. Is this setting required to keep the cluster alive for a bit for us to debug any failed jobs?

Contributor Author:
Yeah.


def main(args: Array[String]): Unit = {

// List of args that are not application args
val internalArgs = Set(
JarUriArgKeyword,
JobTypeArgKeyword,
MainClassKeyword,
FlinkMainJarUriArgKeyword,
FlinkSavepointUriArgKeyword,
ClusterInstanceTypeArgKeyword,
ClusterInstanceCountArgKeyword,
ClusterIdleTimeoutArgKeyword
)

val userArgs = args.filter(arg => !internalArgs.exists(arg.startsWith))

val jarUri =
Collaborator:
Let's sync up around this - I think there's a way we can cleanly do this consistently without needing to do a bunch of parsing ourselves (and save ourselves a headache).

Contributor Author:
todo: follow-up PR to use the Spark arg parser

args.find(_.startsWith(JarUriArgKeyword)).map(_.split("=")(1)).getOrElse(throw new Exception("Jar URI not found"))
val mainClass = args
.find(_.startsWith(MainClassKeyword))
.map(_.split("=")(1))
.getOrElse(throw new Exception("Main class not found"))
val jobTypeValue = args
.find(_.startsWith(JobTypeArgKeyword))
.map(_.split("=")(1))
.getOrElse(throw new Exception("Job type not found"))
val clusterInstanceType =
args.find(_.startsWith(ClusterInstanceTypeArgKeyword)).map(_.split("=")(1)).getOrElse(DefaultClusterInstanceType)
val clusterInstanceCount = args
.find(_.startsWith(ClusterInstanceCountArgKeyword))
.map(_.split("=")(1))
.getOrElse(DefaultClusterInstanceCount)
val clusterIdleTimeout = args
.find(_.startsWith(ClusterIdleTimeoutArgKeyword))
.map(_.split("=")(1))
.getOrElse(DefaultClusterIdleTimeout.toString)

val (jobType, jobProps) = jobTypeValue.toLowerCase match {
case "spark" => {
val baseProps = Map(
MainClass -> mainClass,
JarURI -> jarUri,
ClusterInstanceType -> clusterInstanceType,
ClusterInstanceCount -> clusterInstanceCount,
ClusterIdleTimeout -> clusterIdleTimeout
)
(TypeSparkJob, baseProps)
}
// TODO: add flink
case _ => throw new Exception("Invalid job type")
}

val finalArgs = userArgs

val emrSubmitter = EmrSubmitter()
val jobId = emrSubmitter.submit(
jobType,
jobProps,
List.empty,
finalArgs: _*
)

println("EMR job id: " + jobId)
println(s"Safe to exit. Follow the job status at: https://console.aws.amazon.com/emr/home#/clusterDetails/$jobId")

}
}
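For context, a small self-contained example of how the internal-args filter in main separates cluster flags from application args. The application flags here ("--conf-path", "--end-date") are made up for illustration, and the jar/main-class/job-type flag spellings come from JobSubmitterConstants and are not shown in this diff:

```scala
// Illustrative only: internal flags are consumed by EmrSubmitter, the rest are
// forwarded to spark-submit as application args.
val args = Array(
  "--cluster-instance-count=5",
  "--cluster-idle-timeout=3600",
  "--conf-path=/mnt/zipline/joins/my_join", // hypothetical application arg
  "--end-date=2025-03-07"                   // hypothetical application arg
)

val internalArgs =
  Set("--cluster-instance-type", "--cluster-instance-count", "--cluster-idle-timeout")
val userArgs = args.filter(arg => !internalArgs.exists(arg.startsWith))
// userArgs == Array("--conf-path=/mnt/zipline/joins/my_join", "--end-date=2025-03-07")
```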
8 changes: 8 additions & 0 deletions copy_s3_files.sh
@@ -0,0 +1,8 @@
#!/bin/bash
set -euxo pipefail
pwd

# Loop through all provided arguments (files). Copies files from S3 to /mnt/zipline/
Contributor:
Is this to copy jars?

Contributor Author:
This is to copy files we need to run the job onto the cluster, such as additional-conf.yaml and the conf path.

for s3_file in "$@"; do
    aws s3 cp $s3_file /mnt/zipline/
done
Contributor (coderabbitai bot):
⚠️ Potential issue

Fix potential command injection vulnerability.

The unquoted variable could cause issues with spaces or special characters.

-    aws s3 cp $s3_file /mnt/zipline/
+    aws s3 cp "$s3_file" /mnt/zipline/
