Skip to content

Commit 51cf915

Browse files
committed
SPARK-3178 - Validate that the memory is greater than zero when set from the SPARK_WORKER_MEMORY environment variable or from the command line without a G or M suffix. Added unit tests. If memory is 0, an IllegalStateException is thrown. Updated unit tests to mock environment variables by subclassing SparkConf (tip provided by Josh Rosen). Updated WorkerArguments to use SparkConf.getenv instead of System.getenv for reading the SPARK_WORKER_MEMORY environment variable.
1 parent 110fb8b commit 51cf915

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
4141
if (System.getenv("SPARK_WORKER_CORES") != null) {
4242
cores = System.getenv("SPARK_WORKER_CORES").toInt
4343
}
44-
if (System.getenv("SPARK_WORKER_MEMORY") != null) {
45-
memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY"))
44+
if (conf.getenv("SPARK_WORKER_MEMORY") != null) {
45+
memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY"))
4646
}
4747
if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
4848
webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
@@ -56,6 +56,8 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
5656

5757
parse(args.toList)
5858

59+
checkWorkerMemory()
60+
5961
def parse(args: List[String]): Unit = args match {
6062
case ("--ip" | "-i") :: value :: tail =>
6163
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -153,4 +155,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
153155
// Leave out 1 GB for the operating system, but don't return a negative memory size
154156
math.max(totalMb - 1024, 512)
155157
}
158+
159+
/**
 * Validates the resolved worker memory setting.
 *
 * A non-positive value usually means the user supplied a bare number without an
 * M/G suffix (e.g. "10000" instead of "10000M"), which parses to 0 MB.
 *
 * @throws IllegalStateException if `memory` is zero or negative
 */
def checkWorkerMemory(): Unit = {
  val message = "Memory can't be 0, missing a M or G on the end of the memory specification?"
  if (memory <= 0) {
    throw new IllegalStateException(message)
  }
}
156165
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
19+
package org.apache.spark.deploy.worker
20+
21+
import org.apache.spark.SparkConf
22+
import org.scalatest.FunSuite
23+
24+
25+
/**
 * Tests for [[WorkerArguments]] memory validation: bare numbers (no M/G suffix)
 * must be rejected with an IllegalStateException, while suffixed values must be
 * converted to megabytes correctly.
 */
class WorkerArgumentsTest extends FunSuite {

  /**
   * Builds a SparkConf whose SPARK_WORKER_MEMORY environment lookup is stubbed
   * to return `value`; all other variables fall through to the real environment.
   */
  private def confWithWorkerMemory(value: String): SparkConf = {
    class StubConf extends SparkConf(false) {
      override def getenv(name: String) =
        if (name == "SPARK_WORKER_MEMORY") value else super.getenv(name)

      override def clone: SparkConf = new StubConf().setAll(settings)
    }
    new StubConf()
  }

  test("Memory can't be set to 0 when cmd line args leave off M or G") {
    // "10000" with no suffix parses to 0 MB and must be rejected.
    val cmdLine = Array("-m", "10000", "spark://localhost:0000 ")
    intercept[IllegalStateException] {
      new WorkerArguments(cmdLine, new SparkConf)
    }
  }

  test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") {
    intercept[IllegalStateException] {
      new WorkerArguments(Array("spark://localhost:0000 "), confWithWorkerMemory("50000"))
    }
  }

  test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") {
    // 5 GiB should resolve to 5120 MB.
    val parsed = new WorkerArguments(Array("spark://localhost:0000 "), confWithWorkerMemory("5G"))
    assert(parsed.memory === 5120)
  }

  test("Memory correctly set from args with M appended to memory value") {
    val parsed = new WorkerArguments(Array("-m", "10000M", "spark://localhost:0000 "), new SparkConf)
    assert(parsed.memory === 10000)
  }
}

0 commit comments

Comments
 (0)