Merged

29 commits:
b9e2bed
checkpoint 1
david-zlai Feb 26, 2025
cfd2c7e
add additional confs
david-zlai Feb 26, 2025
aedf7e8
add some logging and bootstrap actions
david-zlai Feb 26, 2025
8b6f853
successful join job
david-zlai Feb 26, 2025
5e65b01
cleaned up code
david-zlai Feb 27, 2025
4907271
made bootstrap an explicit script
david-zlai Feb 27, 2025
d698de8
Add shell script
david-zlai Feb 27, 2025
24f96d5
ignore test
david-zlai Feb 27, 2025
9219d42
tests
david-zlai Feb 27, 2025
a36db4a
fixes
david-zlai Feb 27, 2025
0640477
remove unneeded file
david-zlai Feb 27, 2025
236cb7c
Support submission to existing cluster.
david-zlai Mar 3, 2025
457e81a
Merge branch 'main' into davidhan/emr_submitter
david-zlai Mar 3, 2025
7922adf
Fix test
david-zlai Mar 3, 2025
cab5877
ignore test
david-zlai Mar 3, 2025
f0a8fb5
couple of more tests
david-zlai Mar 3, 2025
057ccc3
support files arg and createcluster and jobflowid
david-zlai Mar 6, 2025
cb0ea83
jobflowid should come from env var
david-zlai Mar 6, 2025
dbbbcaa
reduce idle timeout to 1h
david-zlai Mar 6, 2025
0049d0d
update role name
david-zlai Mar 6, 2025
f514245
rename to cluster id
david-zlai Mar 6, 2025
ba46125
Merge branch 'main' into davidhan/emr_submitter
david-zlai Mar 6, 2025
2fdff7b
resync maven
david-zlai Mar 6, 2025
a791589
Merge branch 'main' into davidhan/emr_submitter
david-zlai Mar 7, 2025
0a1f8b7
rerun maven
david-zlai Mar 7, 2025
1977ce3
todos
david-zlai Mar 7, 2025
006645c
pr comments
david-zlai Mar 7, 2025
f677231
fix test
david-zlai Mar 7, 2025
2caa711
fmt
david-zlai Mar 7, 2025
1 change: 1 addition & 0 deletions cloud_aws/BUILD.bazel
@@ -4,6 +4,7 @@ shared_libs = [
maven_artifact("software.amazon.awssdk:aws-core"),
maven_artifact("software.amazon.awssdk:sdk-core"),
maven_artifact("software.amazon.awssdk:utils"),
maven_artifact("software.amazon.awssdk:emr"),
maven_artifact("com.google.guava:guava"),
maven_artifact("org.slf4j:slf4j-api"),
maven_artifact("org.apache.hudi:hudi-aws-bundle"),
310 additions & 0 deletions: cloud_aws/src/main/scala/ai/chronon/integrations/aws/EmrSubmitter.scala (new file)
@@ -0,0 +1,310 @@
package ai.chronon.integrations.aws

import ai.chronon.integrations.aws.EmrSubmitter.DefaultClusterIdleTimeout
import ai.chronon.integrations.aws.EmrSubmitter.DefaultClusterInstanceCount
import ai.chronon.integrations.aws.EmrSubmitter.DefaultClusterInstanceType
import ai.chronon.spark.JobSubmitter
import ai.chronon.spark.JobSubmitterConstants._
import ai.chronon.spark.JobType
import ai.chronon.spark.{SparkJob => TypeSparkJob}
import software.amazon.awssdk.services.emr.EmrClient
import software.amazon.awssdk.services.emr.model.ActionOnFailure
import software.amazon.awssdk.services.emr.model.AddJobFlowStepsRequest
import software.amazon.awssdk.services.emr.model.Application
import software.amazon.awssdk.services.emr.model.AutoTerminationPolicy
import software.amazon.awssdk.services.emr.model.CancelStepsRequest
import software.amazon.awssdk.services.emr.model.ComputeLimits
import software.amazon.awssdk.services.emr.model.ComputeLimitsUnitType
import software.amazon.awssdk.services.emr.model.Configuration
import software.amazon.awssdk.services.emr.model.DescribeStepRequest
import software.amazon.awssdk.services.emr.model.HadoopJarStepConfig
import software.amazon.awssdk.services.emr.model.InstanceGroupConfig
import software.amazon.awssdk.services.emr.model.InstanceRoleType
import software.amazon.awssdk.services.emr.model.JobFlowInstancesConfig
import software.amazon.awssdk.services.emr.model.ManagedScalingPolicy
import software.amazon.awssdk.services.emr.model.RunJobFlowRequest
import software.amazon.awssdk.services.emr.model.StepConfig

import scala.collection.JavaConverters._

class EmrSubmitter(customerId: String, emrClient: EmrClient) extends JobSubmitter {

private val ClusterApplications = List(
"Flink",
"Zeppelin",
"JupyterEnterpriseGateway",
"Hive",
"Hadoop",
"Livy",
"Spark"
)

// TODO: test if this works for Flink
private val DefaultEmrReleaseLabel = "emr-7.2.0"

// Customer specific infra configurations
private val CustomerToSubnetIdMap = Map(
"canary" -> "subnet-085b2af531b50db44"
)
private val CustomerToSecurityGroupIdMap = Map(
"canary" -> "sg-04fb79b5932a41298"
)

private def createClusterRequestBuilder(emrReleaseLabel: String = DefaultEmrReleaseLabel,
clusterIdleTimeout: Int = DefaultClusterIdleTimeout,
masterInstanceType: String = DefaultClusterInstanceType,
slaveInstanceType: String = DefaultClusterInstanceType,
instanceCount: Int = DefaultClusterInstanceCount) = {
Comment on lines +53 to +57
Contributor:

⚠️ Potential issue

Consider separate parameters for core/master instance types.
The slaveInstanceType parameter is never used; the core group is also using masterInstanceType.

Apply this diff to fix the mismatched instance type:

```diff
                                                   slaveInstanceType: String = DefaultClusterInstanceType,
 ...
               .instanceType(masterInstanceType)
 ...
-              .instanceType(masterInstanceType)
+              .instanceType(slaveInstanceType)
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```scala
  private def createClusterRequestBuilder(emrReleaseLabel: String = DefaultEmrReleaseLabel,
                                          clusterIdleTimeout: Int = DefaultClusterIdleTimeout,
                                          masterInstanceType: String = DefaultClusterInstanceType,
                                          slaveInstanceType: String = DefaultClusterInstanceType,
                                          instanceCount: Int = DefaultClusterInstanceCount) = {
    val builder = ClusterRequestBuilder.create()
      .withReleaseLabel(emrReleaseLabel)
      .withMasterInstances(
        InstanceGroup.builder()
          .instanceType(masterInstanceType)
          .instanceCount(1)
          .build()
      )
      .withCoreInstances(
        InstanceGroup.builder()
          // Previously using masterInstanceType, now fixed to use slaveInstanceType
          .instanceType(slaveInstanceType)
          .instanceCount(instanceCount - 1)
          .build()
      )
    builder
  }
```

val runJobFlowRequestBuilder = RunJobFlowRequest
Collaborator:

could we split out the cluster creation into its own utility class / methods?

Contributor Author:

yeah we can but it's pretty piecemeal at the moment. Like for a single RunJobFlowRequest, the cluster creation configurations include:

  • autoTerminationPolicy
  • configurations
  • applications
  • instances
  • releaseLabel

Contributor Author:

broke out to a new function. can do more cleanup though if we want

Collaborator:

I think eventually we will want this as a separate class / utility object, and make that be a separate verb in run.py or the control plane. That way we can separate the job submission from the cluster creation. But this is fine for now!
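For reference, a minimal sketch of what that separate cluster-creation utility could look like, reusing the AWS SDK v2 builders already imported in this file (the `EmrClusterUtil` and `ClusterConfig` names are hypothetical, not part of this PR):

```scala
// Hypothetical utility bundling the cluster-shaped knobs listed above,
// leaving step submission to the JobSubmitter path.
object EmrClusterUtil {
  case class ClusterConfig(releaseLabel: String,
                           idleTimeoutSeconds: Long,
                           applications: List[String],
                           subnetId: String,
                           securityGroupId: String)

  // Builds only the cluster-creation portion of a RunJobFlowRequest.
  def clusterRequest(conf: ClusterConfig): RunJobFlowRequest.Builder =
    RunJobFlowRequest
      .builder()
      .releaseLabel(conf.releaseLabel)
      .autoTerminationPolicy(
        AutoTerminationPolicy.builder().idleTimeout(conf.idleTimeoutSeconds).build())
      .applications(conf.applications.map(a => Application.builder().name(a).build()): _*)
      .instances(
        JobFlowInstancesConfig
          .builder()
          .ec2SubnetId(conf.subnetId)
          .emrManagedMasterSecurityGroup(conf.securityGroupId)
          .emrManagedSlaveSecurityGroup(conf.securityGroupId)
          .build())
}
```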

.builder()
.name(s"job-${java.util.UUID.randomUUID.toString}")

// Cluster infra configurations:
val customerSecurityGroupId = CustomerToSecurityGroupIdMap.getOrElse(
customerId,
throw new RuntimeException(s"No security group id found for $customerId"))
runJobFlowRequestBuilder
.autoTerminationPolicy(
AutoTerminationPolicy
.builder()
.idleTimeout(clusterIdleTimeout.toLong)
.build())
.configurations(
Configuration.builder
.classification("spark-hive-site")
.properties(Map(
"hive.metastore.client.factory.class" -> "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory").asJava)
.build()
)
.applications(ClusterApplications.map(app => Application.builder().name(app).build()): _*)
// TODO: Could make this generalizable. Have logs saved where users want it
.logUri(s"s3://zipline-warehouse-${customerId}/emr/")
.instances(
JobFlowInstancesConfig
.builder()
// Hack: We hardcode the subnet ID and sg id for each customer of Zipline. The subnet gets created from
// Terraform so we'll need to be careful that these don't get accidentally destroyed.
.ec2SubnetId(
CustomerToSubnetIdMap.getOrElse(customerId,
throw new RuntimeException(s"No subnet id found for $customerId")))
.emrManagedMasterSecurityGroup(customerSecurityGroupId)
.emrManagedSlaveSecurityGroup(customerSecurityGroupId)
.instanceGroups(
InstanceGroupConfig
.builder()
.instanceRole(InstanceRoleType.MASTER)
.instanceType(masterInstanceType)
.instanceCount(1)
.build(),
InstanceGroupConfig
.builder()
.instanceRole(InstanceRoleType.CORE)
.instanceType(masterInstanceType)
.instanceCount(1)
.build()
)
.keepJobFlowAliveWhenNoSteps(true) // Keep the cluster alive after the job is done
.build())
.managedScalingPolicy(
ManagedScalingPolicy
.builder()
.computeLimits(
ComputeLimits
.builder()
.maximumCapacityUnits(instanceCount)
.minimumCapacityUnits(1)
.unitType(ComputeLimitsUnitType.INSTANCES)
.build()
)
.build()
Comment on lines +112 to +119
coderabbitai bot (Contributor), Mar 6, 2025:

🛠️ Refactor suggestion

Align instance groups with compute limits.
You set each group to 1 instance but set maximumCapacityUnits to instanceCount. This can confuse EMR scaling.

Contributor Author:

guess we'll see in practice

Contributor:

@david-zlai Understood! That's a pragmatic approach. When testing, you might want to monitor the EMR scaling events to see how the cluster responds when jobs require more resources. Let me know if you encounter any unexpected scaling behavior during your testing.
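One way to align the two, sketched here rather than taken from this PR, is to derive the core group size from instanceCount instead of hardcoding 1, so the managed-scaling ceiling matches the provisioned groups:

```scala
// Sketch: 1 master + (instanceCount - 1) core nodes, so that
// maximumCapacityUnits = instanceCount agrees with the instance groups.
val coreInstanceCount = math.max(1, instanceCount - 1)
InstanceGroupConfig
  .builder()
  .instanceRole(InstanceRoleType.CORE)
  .instanceType(slaveInstanceType)
  .instanceCount(coreInstanceCount)
  .build()
```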

)
// TODO: need to double check that this is how we want our role names to be
Collaborator:

is this still a todo?

Contributor Author:

not anymore. removing

.serviceRole(s"zipline_${customerId}_emr_service_role")
.jobFlowRole(s"zipline_${customerId}_emr_profile_role")
.releaseLabel(emrReleaseLabel)

}

private def createStepConfig(filesToMount: List[String],
mainClass: String,
jarUri: String,
args: String*): StepConfig = {
// Copy files from s3 to cluster
val awsS3CpArgs = filesToMount.map(file => s"aws s3 cp $file /mnt/zipline/")
Collaborator:

maybe add a todo here?

val sparkSubmitArgs =
List(s"spark-submit --class $mainClass $jarUri ${args.mkString(" ")}")
val finalArgs = List(
"bash",
"-c",
(awsS3CpArgs ++ sparkSubmitArgs).mkString("; \n")
)
println(finalArgs)
StepConfig
.builder()
.name("Run Zipline Job")
.actionOnFailure(ActionOnFailure.CANCEL_AND_WAIT)
.hadoopJarStep(
HadoopJarStepConfig
.builder()
// Using command-runner.jar from AWS:
// https://docs.aws.amazon.com/en_us/emr/latest/ReleaseGuide/emr-spark-submit-step.html
.jar("command-runner.jar")
.args(finalArgs: _*)
.build()
)
.build()
}
Contributor:

⚠️ Potential issue

Shell injection risk.
User input is baked into shell commands; sanitize or validate these args to prevent security issues.
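A common mitigation, sketched here under the assumption that the bash -c composition stays, is to single-quote-escape every interpolated value before it reaches the shell:

```scala
// Sketch of POSIX single-quote escaping: wrap the value in '...' and rewrite
// embedded single quotes as '\'' so arbitrary input cannot break out.
def shellQuote(s: String): String = "'" + s.replace("'", "'\\''") + "'"

// Hypothetical use inside createStepConfig:
val awsS3CpArgs = filesToMount.map(file => s"aws s3 cp ${shellQuote(file)} /mnt/zipline/")
val sparkSubmitArgs =
  List(s"spark-submit --class ${shellQuote(mainClass)} ${shellQuote(jarUri)} ${args.map(shellQuote).mkString(" ")}")
```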


override def submit(jobType: JobType,
jobProperties: Map[String, String],
files: List[String],
args: String*): String = {
if (jobProperties.get(ShouldCreateCluster).exists(_.toBoolean)) {
// create cluster
val runJobFlowBuilder = createClusterRequestBuilder(
emrReleaseLabel = jobProperties.getOrElse(EmrReleaseLabel, DefaultEmrReleaseLabel),
clusterIdleTimeout = jobProperties.getOrElse(ClusterIdleTimeout, DefaultClusterIdleTimeout.toString).toInt,
masterInstanceType = jobProperties.getOrElse(ClusterInstanceType, DefaultClusterInstanceType),
slaveInstanceType = jobProperties.getOrElse(ClusterInstanceType, DefaultClusterInstanceType),
instanceCount = jobProperties.getOrElse(ClusterInstanceCount, DefaultClusterInstanceCount.toString).toInt
)

runJobFlowBuilder.steps(createStepConfig(files, jobProperties(MainClass), jobProperties(JarURI), args: _*))

val responseJobId = emrClient.runJobFlow(runJobFlowBuilder.build()).jobFlowId()
println("EMR job id: " + responseJobId)
println(
s"Safe to exit. Follow the job status at: https://console.aws.amazon.com/emr/home#/clusterDetails/$responseJobId")
responseJobId

} else {
// use existing cluster
val existingJobId = jobProperties.getOrElse(ClusterId, throw new RuntimeException("JobFlowId not found"))
val request = AddJobFlowStepsRequest
.builder()
.jobFlowId(existingJobId)
.steps(createStepConfig(files, jobProperties(MainClass), jobProperties(JarURI), args: _*))
.build()

val responseStepId = emrClient.addJobFlowSteps(request).stepIds().get(0)

println("EMR step id: " + responseStepId)
println(
s"Safe to exit. Follow the job status at: https://console.aws.amazon.com/emr/home#/clusterDetails/$existingJobId")
responseStepId
}
}
Comment on lines +158 to +196
Contributor:

🛠️ Refactor suggestion

Clarify returned IDs.
This returns a cluster ID or step ID depending on ShouldCreateCluster. This ambiguity can confuse callers. Consider unifying or documenting ID usage.
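One option, assuming the JobSubmitter interface can evolve, is a small ADT so callers cannot confuse the two kinds of ID; this sketch is hypothetical, not part of the PR:

```scala
// Hypothetical return type distinguishing the two submission outcomes.
sealed trait SubmissionId { def value: String }
case class ClusterSubmission(value: String) extends SubmissionId // jobFlowId of a newly created cluster
case class StepSubmission(value: String) extends SubmissionId    // stepId added to an existing cluster
```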


override def status(jobId: String): Unit = {
val describeStepResponse = emrClient.describeStep(DescribeStepRequest.builder().stepId(jobId).build())
val status = describeStepResponse.step().status()
println(status)
}

override def kill(jobId: String): Unit = {
emrClient.cancelSteps(CancelStepsRequest.builder().stepIds(jobId).build())
}
coderabbitai bot (Contributor), Mar 6, 2025:

⚠️ Potential issue

Status/kill methods ignore cluster IDs.
AWS requires both clusterId and stepId. This might fail if you pass a cluster ID here.

Contributor Author:

job is an overloaded term huh

Contributor:

@david-zlai Yes, exactly! The term "job" is definitely overloaded in this context. In AWS EMR:

  • "job" can refer to a Step (a single Spark execution)
  • "job" can also refer to a JobFlow/Cluster
  • The JobSubmitter interface expects a single jobId, but EMR operations need both clusterId and stepId

This terminology mismatch creates the potential issue I flagged. You could address this by either:

  1. Modifying the interface to support both IDs
  2. Using a composite ID format like "clusterId:stepId"
  3. Maintaining an internal mapping in the EmrSubmitter

Which approach would work best for your implementation?
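A minimal sketch of option 2 (a composite "clusterId:stepId" string), given that both DescribeStepRequest and CancelStepsRequest accept a cluster ID alongside the step ID; the helper name is illustrative:

```scala
// Sketch: encode both IDs in the single jobId the interface allows.
private def splitIds(jobId: String): (String, String) = jobId.split(":", 2) match {
  case Array(clusterId, stepId) => (clusterId, stepId)
  case _ => throw new IllegalArgumentException(s"Expected clusterId:stepId, got $jobId")
}

override def status(jobId: String): Unit = {
  val (clusterId, stepId) = splitIds(jobId)
  val response = emrClient.describeStep(
    DescribeStepRequest.builder().clusterId(clusterId).stepId(stepId).build())
  println(response.step().status())
}

override def kill(jobId: String): Unit = {
  val (clusterId, stepId) = splitIds(jobId)
  emrClient.cancelSteps(CancelStepsRequest.builder().clusterId(clusterId).stepIds(stepId).build())
}
```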


✏️ Learnings added
Learnt from: david-zlai
PR: zipline-ai/chronon#439
File: cloud_aws/src/main/scala/ai/chronon/integrations/aws/EmrSubmitter.scala:198-206
Timestamp: 2025-03-07T20:41:11.439Z
Learning: In AWS EMR, the term "job" is ambiguous and can refer to either a Step (single Spark execution) or a JobFlow/Cluster. EMR operations typically require both a clusterId and a stepId, while the JobSubmitter interface expects a single jobId parameter.


}

object EmrSubmitter {
def apply(): EmrSubmitter = {
val customerId = sys.env.getOrElse("CUSTOMER_ID", throw new Exception("CUSTOMER_ID not set")).toLowerCase

new EmrSubmitter(customerId,
EmrClient
.builder()
.build())
}
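Because the primary constructor takes the EmrClient directly, tests can inject a stub without touching the CUSTOMER_ID environment variable; a hypothetical example using Mockito:

```scala
// Hypothetical test setup: inject a mocked client instead of a real one.
import org.mockito.Mockito.mock

val fakeEmr = mock(classOf[EmrClient])
val submitter = new EmrSubmitter("canary", fakeEmr)
```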

private val ClusterInstanceTypeArgKeyword = "--cluster-instance-type"
private val ClusterInstanceCountArgKeyword = "--cluster-instance-count"
private val ClusterIdleTimeoutArgKeyword = "--cluster-idle-timeout"
private val CreateClusterArgKeyword = "--create-cluster"

private val DefaultClusterInstanceType = "m5.xlarge"
private val DefaultClusterInstanceCount = 3
private val DefaultClusterIdleTimeout = 60 * 60 * 1 // 1h in seconds
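For reference, a hypothetical invocation of main showing the key=value flag format the parser below expects; the flag spellings imported from JobSubmitterConstants and the class/jar values are assumptions, only --cluster-instance-count and --create-cluster are defined in this file:

```scala
// Hypothetical invocation; every internal flag uses the key=value form.
EmrSubmitter.main(
  Array(
    "--job-type=spark",                     // assumed spelling of JobTypeArgKeyword
    "--main-class=ai.chronon.spark.Driver", // assumed spelling of MainClassKeyword
    "--jar-uri=s3://bucket/chronon.jar",    // assumed spelling of JarUriArgKeyword
    "--cluster-instance-count=5",
    "--create-cluster",
    "join-arg-1" // anything unrecognized is passed through to the Spark job
  ))
```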

def main(args: Array[String]): Unit = {
// List of args that are not application args
val internalArgs = Set(
JarUriArgKeyword,
JobTypeArgKeyword,
MainClassKeyword,
FlinkMainJarUriArgKeyword,
FlinkSavepointUriArgKeyword,
ClusterInstanceTypeArgKeyword,
ClusterInstanceCountArgKeyword,
ClusterIdleTimeoutArgKeyword,
FilesArgKeyword,
CreateClusterArgKeyword
)

val userArgs = args.filter(arg => !internalArgs.exists(arg.startsWith))

val jarUri =
Collaborator:

let's sync up around this - I think there's a way we can cleanly do this consistently without needing to do a bunch of parsing ourselves (and save ourselves headache).

Contributor Author:

todo: follow up pr to use the spark arg parser
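Until that follow-up lands, a small helper could at least centralize the repeated split("=")(1) pattern below (which also silently truncates values containing "="); a sketch, not the planned Spark arg parser:

```scala
// Sketch: one place to parse --key=value flags instead of repeating split logic.
// substring (unlike split("=")(1)) preserves values that themselves contain '='.
private def flagValue(args: Array[String], keyword: String): Option[String] =
  args.collectFirst {
    case arg if arg.startsWith(keyword + "=") => arg.substring(keyword.length + 1)
  }

// e.g. flagValue(args, ClusterInstanceTypeArgKeyword).getOrElse(DefaultClusterInstanceType)
```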

args.find(_.startsWith(JarUriArgKeyword)).map(_.split("=")(1)).getOrElse(throw new Exception("Jar URI not found"))
val mainClass = args
.find(_.startsWith(MainClassKeyword))
.map(_.split("=")(1))
.getOrElse(throw new Exception("Main class not found"))
val jobTypeValue = args
.find(_.startsWith(JobTypeArgKeyword))
.map(_.split("=")(1))
.getOrElse(throw new Exception("Job type not found"))
val clusterInstanceType =
args.find(_.startsWith(ClusterInstanceTypeArgKeyword)).map(_.split("=")(1)).getOrElse(DefaultClusterInstanceType)
val clusterInstanceCount = args
.find(_.startsWith(ClusterInstanceCountArgKeyword))
.map(_.split("=")(1))
.getOrElse(DefaultClusterInstanceCount.toString)
val clusterIdleTimeout = args
.find(_.startsWith(ClusterIdleTimeoutArgKeyword))
.map(_.split("=")(1))
.getOrElse(DefaultClusterIdleTimeout.toString)
val createCluster = args.exists(_.startsWith(CreateClusterArgKeyword))

val clusterId = sys.env.get("EMR_CLUSTER_ID")

// search args array for the files arg prefix (FilesArgKeyword)
val filesArgs = args.filter(_.startsWith(FilesArgKeyword))
assert(filesArgs.length == 0 || filesArgs.length == 1)

val files = if (filesArgs.isEmpty) {
Array.empty[String]
} else {
filesArgs(0).split("=")(1).split(",")
}

val (jobType, jobProps) = jobTypeValue.toLowerCase match {
case "spark" => {
val baseProps = Map(
MainClass -> mainClass,
JarURI -> jarUri,
ClusterInstanceType -> clusterInstanceType,
ClusterInstanceCount -> clusterInstanceCount,
ClusterIdleTimeout -> clusterIdleTimeout,
ShouldCreateCluster -> createCluster.toString
)

if (!createCluster && clusterId.isDefined) {
(TypeSparkJob, baseProps + (ClusterId -> clusterId.get))
} else {
(TypeSparkJob, baseProps)
}
}
// TODO: add flink
case _ => throw new Exception("Invalid job type")
}

val finalArgs = userArgs

val emrSubmitter = EmrSubmitter()
emrSubmitter.submit(
jobType,
jobProps,
files.toList,
finalArgs: _*
)
}
}