Skip to content

Commit 7215aa7

Browse files
JoshRosen authored and Andrew Or committed
[SPARK-6209] Clean up connections in ExecutorClassLoader after failing to load classes (master branch PR)
ExecutorClassLoader does not ensure proper cleanup of network connections that it opens. If it fails to load a class, it may leak partially-consumed InputStreams that are connected to the REPL's HTTP class server, causing that server to exhaust its thread pool, which can cause the entire job to hang. See [SPARK-6209](https://issues.apache.org/jira/browse/SPARK-6209) for more details, including a bug reproduction. This patch fixes this issue by ensuring proper cleanup of these resources. It also adds logging for unexpected error cases. This PR is an extended version of #4935 and adds a regression test. Author: Josh Rosen <[email protected]> Closes #4944 from JoshRosen/executorclassloader-leak-master-branch and squashes the following commits: e0e3c25 [Josh Rosen] Wrap try block around getResponseCode; re-enable keep-alive by closing error stream 961c284 [Josh Rosen] Roll back changes that were added to get the regression test to fail 7ee2261 [Josh Rosen] Add a failing regression test e2d70a3 [Josh Rosen] Properly clean up after errors in ExecutorClassLoader
1 parent a8f51b8 commit 7215aa7

File tree

3 files changed

+140
-20
lines changed

3 files changed

+140
-20
lines changed

repl/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@
8484
<artifactId>scalacheck_${scala.binary.version}</artifactId>
8585
<scope>test</scope>
8686
</dependency>
87+
<dependency>
88+
<groupId>org.mockito</groupId>
89+
<artifactId>mockito-all</artifactId>
90+
<scope>test</scope>
91+
</dependency>
8792

8893
<!-- Explicit listing of transitive deps that are shaded. Otherwise, odd compiler crashes. -->
8994
<dependency>

repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.repl
1919

20-
import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException}
21-
import java.net.{URI, URL, URLEncoder}
22-
import java.util.concurrent.{Executors, ExecutorService}
20+
import java.io.{IOException, ByteArrayOutputStream, InputStream}
21+
import java.net.{HttpURLConnection, URI, URL, URLEncoder}
22+
23+
import scala.util.control.NonFatal
2324

2425
import org.apache.hadoop.fs.{FileSystem, Path}
2526

@@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
4344

4445
val parentLoader = new ParentClassLoader(parent)
4546

47+
// Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
48+
private[repl] var httpUrlConnectionTimeoutMillis: Int = -1
49+
4650
// Hadoop FileSystem object for our URI, if it isn't using HTTP
4751
var fileSystem: FileSystem = {
4852
if (Set("http", "https", "ftp").contains(uri.getScheme)) {
@@ -71,37 +75,82 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
7175
}
7276
}
7377

78+
private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
79+
val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
80+
val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
81+
val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
82+
newuri.toURL
83+
} else {
84+
new URL(classUri + "/" + urlEncode(pathInDirectory))
85+
}
86+
val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(),
87+
SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection]
88+
// Set the connection timeouts (for testing purposes)
89+
if (httpUrlConnectionTimeoutMillis != -1) {
90+
connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
91+
connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
92+
}
93+
connection.connect()
94+
try {
95+
if (connection.getResponseCode != 200) {
96+
// Close the error stream so that the connection is eligible for re-use
97+
try {
98+
connection.getErrorStream.close()
99+
} catch {
100+
case ioe: IOException =>
101+
logError("Exception while closing error stream", ioe)
102+
}
103+
throw new ClassNotFoundException(s"Class file not found at URL $url")
104+
} else {
105+
connection.getInputStream
106+
}
107+
} catch {
108+
case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
109+
connection.disconnect()
110+
throw e
111+
}
112+
}
113+
114+
private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = {
115+
val path = new Path(directory, pathInDirectory)
116+
if (fileSystem.exists(path)) {
117+
fileSystem.open(path)
118+
} else {
119+
throw new ClassNotFoundException(s"Class file not found at path $path")
120+
}
121+
}
122+
74123
def findClassLocally(name: String): Option[Class[_]] = {
124+
val pathInDirectory = name.replace('.', '/') + ".class"
125+
var inputStream: InputStream = null
75126
try {
76-
val pathInDirectory = name.replace('.', '/') + ".class"
77-
val inputStream = {
127+
inputStream = {
78128
if (fileSystem != null) {
79-
fileSystem.open(new Path(directory, pathInDirectory))
129+
getClassFileInputStreamFromFileSystem(pathInDirectory)
80130
} else {
81-
val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
82-
val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
83-
val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
84-
newuri.toURL
85-
} else {
86-
new URL(classUri + "/" + urlEncode(pathInDirectory))
87-
}
88-
89-
Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager)
90-
.getInputStream
131+
getClassFileInputStreamFromHttpServer(pathInDirectory)
91132
}
92133
}
93134
val bytes = readAndTransformClass(name, inputStream)
94-
inputStream.close()
95135
Some(defineClass(name, bytes, 0, bytes.length))
96136
} catch {
97-
case e: FileNotFoundException =>
137+
case e: ClassNotFoundException =>
98138
// We did not find the class
99139
logDebug(s"Did not load class $name from REPL class server at $uri", e)
100140
None
101141
case e: Exception =>
102142
// Something bad happened while checking if the class exists
103143
logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
104144
None
145+
} finally {
146+
if (inputStream != null) {
147+
try {
148+
inputStream.close()
149+
} catch {
150+
case e: Exception =>
151+
logError("Exception while closing inputStream", e)
152+
}
153+
}
105154
}
106155
}
107156

repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,33 @@ package org.apache.spark.repl
2020
import java.io.File
2121
import java.net.{URL, URLClassLoader}
2222

23+
import scala.concurrent.duration._
24+
import scala.language.implicitConversions
25+
import scala.language.postfixOps
26+
2327
import org.scalatest.BeforeAndAfterAll
2428
import org.scalatest.FunSuite
29+
import org.scalatest.concurrent.Interruptor
30+
import org.scalatest.concurrent.Timeouts._
31+
import org.scalatest.mock.MockitoSugar
32+
import org.mockito.Mockito._
2533

26-
import org.apache.spark.{SparkConf, TestUtils}
34+
import org.apache.spark._
2735
import org.apache.spark.util.Utils
2836

29-
class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
37+
class ExecutorClassLoaderSuite
38+
extends FunSuite
39+
with BeforeAndAfterAll
40+
with MockitoSugar
41+
with Logging {
3042

3143
val childClassNames = List("ReplFakeClass1", "ReplFakeClass2")
3244
val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3")
3345
var tempDir1: File = _
3446
var tempDir2: File = _
3547
var url1: String = _
3648
var urls2: Array[URL] = _
49+
var classServer: HttpServer = _
3750

3851
override def beforeAll() {
3952
super.beforeAll()
@@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
4760

4861
override def afterAll() {
4962
super.afterAll()
63+
if (classServer != null) {
64+
classServer.stop()
65+
}
5066
Utils.deleteRecursively(tempDir1)
5167
Utils.deleteRecursively(tempDir2)
68+
SparkEnv.set(null)
5269
}
5370

5471
test("child first") {
@@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
83100
}
84101
}
85102

103+
test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") {
104+
// This is a regression test for SPARK-6209, a bug where each failed attempt to load a class
105+
// from the driver's class server would leak a HTTP connection, causing the class server's
106+
// thread / connection pool to be exhausted.
107+
val conf = new SparkConf()
108+
val securityManager = new SecurityManager(conf)
109+
classServer = new HttpServer(conf, tempDir1, securityManager)
110+
classServer.start()
111+
// ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this
112+
val mockEnv = mock[SparkEnv]
113+
when(mockEnv.securityManager).thenReturn(securityManager)
114+
SparkEnv.set(mockEnv)
115+
// Create an ExecutorClassLoader that's configured to load classes from the HTTP server
116+
val parentLoader = new URLClassLoader(Array.empty, null)
117+
val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false)
118+
classLoader.httpUrlConnectionTimeoutMillis = 500
119+
// Check that this class loader can actually load classes that exist
120+
val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
121+
val fakeClassVersion = fakeClass.toString
122+
assert(fakeClassVersion === "1")
123+
// Try to perform a full GC now, since GC during the test might mask resource leaks
124+
System.gc()
125+
// When the original bug occurs, the test thread becomes blocked in a classloading call
126+
// and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to
127+
// shut down the HTTP server when the test times out
128+
val interruptor: Interruptor = new Interruptor {
129+
override def apply(thread: Thread): Unit = {
130+
classServer.stop()
131+
classServer = null
132+
thread.interrupt()
133+
}
134+
}
135+
def tryAndFailToLoadABunchOfClasses(): Unit = {
136+
// The number of trials here should be much larger than Jetty's thread / connection limit
137+
// in order to expose thread or connection leaks
138+
for (i <- 1 to 1000) {
139+
if (Thread.currentThread().isInterrupted) {
140+
throw new InterruptedException()
141+
}
142+
// Incorporate the iteration number into the class name in order to avoid any response
143+
// caching that might be added in the future
144+
intercept[ClassNotFoundException] {
145+
classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance()
146+
}
147+
}
148+
}
149+
failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor)
150+
}
151+
86152
}

0 commit comments

Comments
 (0)