diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
index 0c8d9646a2b4..286d9c04f5b9 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
@@ -32,6 +32,7 @@ import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.annotation.{DeveloperApi, Since, Unstable}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.k8s.Config.KUBERNETES_FILE_UPLOAD_PATH
+import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep
 import org.apache.spark.internal.Logging
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.resource.ResourceUtils
@@ -381,4 +382,32 @@ object KubernetesUtils extends Logging {
       }
     }
   }
+
+  @Since("3.3.0")
+  def loadFeatureStep(conf: KubernetesConf, className: String): KubernetesFeatureConfigStep = {
+    val constructors = Utils.classForName(className).getConstructors
+    // Try to find constructor with only type matched conf or only KubernetesConf conf
+    val confConstructor = constructors.find { constructor =>
+      constructor.getParameterCount == 1 &&
+        (constructor.getParameterTypes()(0) == conf.getClass ||
+          constructor.getParameterTypes()(0) == classOf[KubernetesConf])
+    }
+    // Try to find no param constructor
+    val noParamConstructor = constructors.find { constructor =>
+      constructor.getParameterCount == 0
+    }
+    // Throw SparkException if no expected constructor found
+    if (noParamConstructor.isEmpty && confConstructor.isEmpty) {
+      throw new SparkException(s"Failed to load feature step: $className, " +
+        s"the constructor of the feature step should be no param or with only " +
+        s"${conf.getClass.getSimpleName}/KubernetesConf param.")
+    }
+
+    val constructor = confConstructor.map { confConstructor =>
+      (conf: KubernetesConf) => confConstructor.newInstance(conf)
+    }.getOrElse {
+      (_: KubernetesConf) => noParamConstructor.get.newInstance()
+    }
+    constructor(conf).asInstanceOf[KubernetesFeatureConfigStep]
+  }
 }
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
index f0c78f371d6d..98fbdb175894 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
@@ -20,7 +20,6 @@ import io.fabric8.kubernetes.client.KubernetesClient
 
 import org.apache.spark.deploy.k8s._
 import org.apache.spark.deploy.k8s.features._
-import org.apache.spark.util.Utils
 
 private[spark] class KubernetesDriverBuilder {
 
@@ -39,7 +38,7 @@ private[spark] class KubernetesDriverBuilder {
 
     val userFeatures = conf.get(Config.KUBERNETES_DRIVER_POD_FEATURE_STEPS)
       .map { className =>
-        Utils.classForName(className).newInstance().asInstanceOf[KubernetesFeatureConfigStep]
+        KubernetesUtils.loadFeatureStep(conf, className)
       }
 
     val features = Seq(
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
index 1a62d08a7b41..6bd7ead9e615 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
@@ -22,7 +22,6 @@ import org.apache.spark.SecurityManager
 import org.apache.spark.deploy.k8s._
 import org.apache.spark.deploy.k8s.features._
 import org.apache.spark.resource.ResourceProfile
-import org.apache.spark.util.Utils
 
 private[spark] class KubernetesExecutorBuilder {
 
@@ -43,7 +42,7 @@ private[spark] class KubernetesExecutorBuilder {
 
     val userFeatures = conf.get(Config.KUBERNETES_EXECUTOR_POD_FEATURE_STEPS)
       .map { className =>
-        Utils.classForName(className).newInstance().asInstanceOf[KubernetesFeatureConfigStep]
+        KubernetesUtils.loadFeatureStep(conf, className)
       }
 
     val features = Seq(
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala
index ef57a4b86150..283b4a206c7e 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala
@@ -21,7 +21,8 @@ import scala.collection.JavaConverters._
 
 import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder}
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep
 
 class KubernetesUtilsSuite extends SparkFunSuite {
   private val HOST = "test-host"
@@ -65,4 +66,42 @@ class KubernetesUtilsSuite extends SparkFunSuite {
     assert(sparkPodWithNoContainerName.pod.getSpec.getHostname == HOST)
     assert(sparkPodWithNoContainerName.container.getName == null)
   }
+
+  test("SPARK-37145: feature step load test") {
+    val execConf: KubernetesExecutorConf = KubernetesTestConf.createExecutorConf()
+    val driverConf: KubernetesDriverConf = KubernetesTestConf.createDriverConf()
+    val basicFeatureNames = Seq(
+      "org.apache.spark.deploy.k8s.TestStep",
+      "org.apache.spark.deploy.k8s.TestStepWithConf"
+    )
+    val driverFeatureNames = "org.apache.spark.deploy.k8s.TestStepWithDrvConf"
+    val execFeatureNames = "org.apache.spark.deploy.k8s.TestStepWithExecConf"
+
+    (basicFeatureNames :+ driverFeatureNames).foreach { featureName =>
+      val drvFeatureStep = KubernetesUtils.loadFeatureStep(driverConf, featureName)
+      assert(drvFeatureStep.isInstanceOf[KubernetesFeatureConfigStep])
+    }
+
+    (basicFeatureNames :+ execFeatureNames).foreach { featureName =>
+      val execFeatureStep = KubernetesUtils.loadFeatureStep(execConf, featureName)
+      assert(execFeatureStep.isInstanceOf[KubernetesFeatureConfigStep])
+    }
+
+    val e1 = intercept[SparkException] {
+      KubernetesUtils.loadFeatureStep(execConf, driverFeatureNames)
+    }
+    assert(e1.getMessage.contains(s"Failed to load feature step: $driverFeatureNames"))
+    assert(e1.getMessage.contains("with only KubernetesExecutorConf/KubernetesConf param"))
+
+    val e2 = intercept[SparkException] {
+      KubernetesUtils.loadFeatureStep(driverConf, execFeatureNames)
+    }
+    assert(e2.getMessage.contains(s"Failed to load feature step: $execFeatureNames"))
+    assert(e2.getMessage.contains("with only KubernetesDriverConf/KubernetesConf param"))
+
+    val e3 = intercept[ClassNotFoundException] {
+      KubernetesUtils.loadFeatureStep(execConf, "unknow.class")
+    }
+    assert(e3.getMessage.contains("unknow.class"))
+  }
 }
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala
index a8a3ca4eea96..3123211d79a8 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala
@@ -66,6 +66,40 @@ abstract class PodBuilderSuite extends SparkFunSuite {
     assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "so_long_two"))
   }
 
+  test("SPARK-37145: configure a custom test step with base config") {
+    val client = mockKubernetesClient()
+    val sparkConf = baseConf.clone()
+      .set(userFeatureStepsConf.key,
+        "org.apache.spark.deploy.k8s.TestStepWithConf")
+      .set(templateFileConf.key, "template-file.yaml")
+      .set("test-features-key", "test-features-value")
+    val pod = buildPod(sparkConf, client)
+    verifyPod(pod)
+    val metadata = pod.pod.getMetadata
+    assert(metadata.getAnnotations.containsKey("test-user-feature-annotation"))
+  }
+
+  test("SPARK-37145: configure a custom test step with driver or executor config") {
+    val client = mockKubernetesClient()
+    val (featureSteps, annotationKey) = this.getClass.getSimpleName match {
+      case "KubernetesDriverBuilderSuite" =>
+        ("org.apache.spark.deploy.k8s.TestStepWithDrvConf",
+          "test-drv-user-feature-annotation")
+      case "KubernetesExecutorBuilderSuite" =>
+        ("org.apache.spark.deploy.k8s.TestStepWithExecConf",
+          "test-exec-user-feature-annotation")
+    }
+    val sparkConf = baseConf.clone()
+      .set(templateFileConf.key, "template-file.yaml")
+      .set(userFeatureStepsConf.key, featureSteps)
+      .set("test-features-key", "test-features-value")
+    val pod = buildPod(sparkConf, client)
+    verifyPod(pod)
+    val metadata = pod.pod.getMetadata
+    assert(metadata.getAnnotations.containsKey(annotationKey))
+    assert(metadata.getAnnotations.get(annotationKey) === "test-features-value")
+  }
+
   test("complain about misconfigured pod template") {
     val client = mockKubernetesClient(
       new PodBuilder()
@@ -249,3 +283,51 @@ class TestStepTwo extends KubernetesFeatureConfigStep {
     SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts)
   }
 }
+
+/**
+ * A test user feature step.
+ */
+class TestStepWithConf(conf: KubernetesConf) extends KubernetesFeatureConfigStep {
+  import io.fabric8.kubernetes.api.model._
+
+  override def configurePod(pod: SparkPod): SparkPod = {
+    val k8sPodBuilder = new PodBuilder(pod.pod)
+      .editOrNewMetadata()
+        .addToAnnotations("test-user-feature-annotation", conf.get("test-features-key"))
+      .endMetadata()
+    val k8sPod = k8sPodBuilder.build()
+    SparkPod(k8sPod, pod.container)
+  }
+}
+
+/**
+ * A test driver user feature step.
+ */
+class TestStepWithDrvConf(conf: KubernetesDriverConf) extends KubernetesFeatureConfigStep {
+  import io.fabric8.kubernetes.api.model._
+
+  override def configurePod(pod: SparkPod): SparkPod = {
+    val k8sPodBuilder = new PodBuilder(pod.pod)
+      .editOrNewMetadata()
+        .addToAnnotations("test-drv-user-feature-annotation", conf.get("test-features-key"))
+      .endMetadata()
+    val k8sPod = k8sPodBuilder.build()
+    SparkPod(k8sPod, pod.container)
+  }
+}
+
+/**
+ * A test executor user feature step.
+ */
+class TestStepWithExecConf(conf: KubernetesExecutorConf) extends KubernetesFeatureConfigStep {
+  import io.fabric8.kubernetes.api.model._
+
+  override def configurePod(pod: SparkPod): SparkPod = {
+    val k8sPodBuilder = new PodBuilder(pod.pod)
+      .editOrNewMetadata()
+        .addToAnnotations("test-exec-user-feature-annotation", conf.get("test-features-key"))
+      .endMetadata()
+    val k8sPod = k8sPodBuilder.build()
+    SparkPod(k8sPod, pod.container)
+  }
+}
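Usage sketch (not part of the patch): with the loadFeatureStep contract introduced above, a user-supplied feature step may now declare a single constructor parameter of type KubernetesConf, KubernetesDriverConf, or KubernetesExecutorConf instead of being limited to a no-argument constructor. The package, class name, label key, and the "spark.team.name" property below are hypothetical, and the sketch assumes the k8s conf classes are visible as developer APIs in the Spark version being built against.

    package com.example.steps  // hypothetical package

    import io.fabric8.kubernetes.api.model.PodBuilder

    import org.apache.spark.deploy.k8s.{KubernetesDriverConf, SparkPod}
    import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep

    // Hypothetical driver-side step: the single KubernetesDriverConf parameter is
    // resolved by KubernetesUtils.loadFeatureStep, so the step can read Spark conf
    // values while building the pod instead of relying on a no-arg constructor.
    class TeamLabelDriverStep(conf: KubernetesDriverConf) extends KubernetesFeatureConfigStep {
      override def configurePod(pod: SparkPod): SparkPod = {
        val labeledPod = new PodBuilder(pod.pod)
          .editOrNewMetadata()
            // "spark.team.name" is an illustrative, user-defined property with a fallback.
            .addToLabels("team", conf.sparkConf.get("spark.team.name", "unknown"))
          .endMetadata()
          .build()
        SparkPod(labeledPod, pod.container)
      }
    }

Such a step would be wired in through the existing feature-step configs, e.g. --conf spark.kubernetes.driver.pod.featureSteps=com.example.steps.TeamLabelDriverStep; a step meant for both roles would take KubernetesConf and could also be registered under spark.kubernetes.executor.pod.featureSteps.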