diff --git a/server/src/main/scala/org/apache/livy/server/batch/CreateBatchRequest.scala b/server/src/main/scala/org/apache/livy/server/batch/CreateBatchRequest.scala index 53b5e1b76..8129dde70 100644 --- a/server/src/main/scala/org/apache/livy/server/batch/CreateBatchRequest.scala +++ b/server/src/main/scala/org/apache/livy/server/batch/CreateBatchRequest.scala @@ -27,9 +27,9 @@ class CreateBatchRequest { var pyFiles: List[String] = List() var files: List[String] = List() var driverMemory: Option[String] = None - var driverCores: Option[Int] = None + var driverCores: Option[Double] = None var executorMemory: Option[String] = None - var executorCores: Option[Int] = None + var executorCores: Option[Double] = None var numExecutors: Option[Int] = None var archives: List[String] = List() var queue: Option[String] = None diff --git a/server/src/main/scala/org/apache/livy/server/interactive/CreateInteractiveRequest.scala b/server/src/main/scala/org/apache/livy/server/interactive/CreateInteractiveRequest.scala index b2f34b008..a7b4a7931 100644 --- a/server/src/main/scala/org/apache/livy/server/interactive/CreateInteractiveRequest.scala +++ b/server/src/main/scala/org/apache/livy/server/interactive/CreateInteractiveRequest.scala @@ -26,9 +26,9 @@ class CreateInteractiveRequest { var pyFiles: List[String] = List() var files: List[String] = List() var driverMemory: Option[String] = None - var driverCores: Option[Int] = None + var driverCores: Option[Double] = None var executorMemory: Option[String] = None - var executorCores: Option[Int] = None + var executorCores: Option[Double] = None var numExecutors: Option[Int] = None var archives: List[String] = List() var queue: Option[String] = None diff --git a/server/src/main/scala/org/apache/livy/utils/SparkProcessBuilder.scala b/server/src/main/scala/org/apache/livy/utils/SparkProcessBuilder.scala index 01cbb4c3c..e57361af6 100644 --- a/server/src/main/scala/org/apache/livy/utils/SparkProcessBuilder.scala +++ b/server/src/main/scala/org/apache/livy/utils/SparkProcessBuilder.scala @@ -92,7 +92,7 @@ class SparkProcessBuilder(livyConf: LivyConf) extends Logging { this } - def driverCores(driverCores: Int): SparkProcessBuilder = { + def driverCores(driverCores: Double): SparkProcessBuilder = { this.driverCores(driverCores.toString) } @@ -104,7 +104,7 @@ class SparkProcessBuilder(livyConf: LivyConf) extends Logging { conf("spark.driver.cores", driverCores) } - def executorCores(executorCores: Int): SparkProcessBuilder = { + def executorCores(executorCores: Double): SparkProcessBuilder = { this.executorCores(executorCores.toString) } diff --git a/server/src/test/scala/org/apache/livy/server/batch/CreateBatchRequestSpec.scala b/server/src/test/scala/org/apache/livy/server/batch/CreateBatchRequestSpec.scala index 7fef3c2ff..239d2cb32 100644 --- a/server/src/test/scala/org/apache/livy/server/batch/CreateBatchRequestSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/batch/CreateBatchRequestSpec.scala @@ -50,6 +50,31 @@ class CreateBatchRequestSpec extends FunSpec with LivyBaseUnitTestSuite { assert(req.conf === Map()) } + it("should support integer cores") { + val json = """{ "driverCores" : 1, "executorCores": 2 }""" + val req = mapper.readValue(json, classOf[CreateBatchRequest]) + assert(req.driverCores.get === 1) + assert(req.executorCores.get === 2) + } + + it("should support float cores") { + val json = """{ "driverCores" : 0.1, "executorCores": 0.2 }""" + val req = mapper.readValue(json, classOf[CreateBatchRequest]) + assert(req.driverCores.get === 0.1) + assert(req.executorCores.get === 0.2) + } + + it("should not support string cores") { + val json = """{ "driverCores" : "asdf", "executorCores": "0.2" }""" + val req = mapper.readValue(json, classOf[CreateBatchRequest]) + intercept[ClassCastException] { + req.driverCores.get + } + intercept[ClassCastException] { + req.executorCores.get + } + } + } } diff --git a/server/src/test/scala/org/apache/livy/server/interactive/CreateInteractiveRequestSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/CreateInteractiveRequestSpec.scala index b84d98a9c..5ee78c5b9 100644 --- a/server/src/test/scala/org/apache/livy/server/interactive/CreateInteractiveRequestSpec.scala +++ b/server/src/test/scala/org/apache/livy/server/interactive/CreateInteractiveRequestSpec.scala @@ -50,6 +50,31 @@ class CreateInteractiveRequestSpec extends FunSpec with LivyBaseUnitTestSuite { assert(req.conf === Map()) } + it("should support integer cores") { + val json = """{ "driverCores" : 1, "executorCores": 2 }""" + val req = mapper.readValue(json, classOf[CreateInteractiveRequest]) + assert(req.driverCores.get === 1) + assert(req.executorCores.get === 2) + } + + it("should support float cores") { + val json = """{ "driverCores" : 0.1, "executorCores": 0.2 }""" + val req = mapper.readValue(json, classOf[CreateInteractiveRequest]) + assert(req.driverCores.get === 0.1) + assert(req.executorCores.get === 0.2) + } + + it("should not support string cores") { + val json = """{ "driverCores" : "asdf", "executorCores": "0.2" }""" + val req = mapper.readValue(json, classOf[CreateInteractiveRequest]) + intercept[ClassCastException] { + req.driverCores.get + } + intercept[ClassCastException] { + req.executorCores.get + } + } + } }