diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 14265f03795f..f6dbeadd96f4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import java.io.{DataInputStream, DataOutputStream, EOFException, File, InputStream} import java.net.{InetAddress, InetSocketAddress, SocketException} +import java.net.SocketTimeoutException import java.nio.channels._ import java.util.Arrays import java.util.concurrent.TimeUnit @@ -184,10 +185,18 @@ private[spark] class PythonWorkerFactory( redirectStreamsToStderr(workerProcess.getInputStream, workerProcess.getErrorStream) // Wait for it to connect to our socket, and validate the auth secret. - serverSocketChannel.socket().setSoTimeout(10000) - try { - val socketChannel = serverSocketChannel.accept() + // Wait up to 10 seconds for client to connect. + serverSocketChannel.configureBlocking(false) + val serverSelector = Selector.open() + serverSocketChannel.register(serverSelector, SelectionKey.OP_ACCEPT) + val socketChannel = + if (serverSelector.select(10 * 1000) > 0) { // Wait up to 10 seconds. 
+ serverSocketChannel.accept() + } else { + throw new SocketTimeoutException( + "Timed out while waiting for the Python worker to connect back") + } authHelper.authClient(socketChannel.socket()) val pid = workerProcess.toHandle.pid() if (pid < 0) { diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala new file mode 100644 index 000000000000..34c10bd95ed7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.net.SocketTimeoutException + +// scalastyle:off executioncontextglobal +import scala.concurrent.ExecutionContext.Implicits.global +// scalastyle:on executioncontextglobal +import scala.concurrent.Future +import scala.concurrent.duration._ + +import org.scalatest.matchers.must.Matchers + +import org.apache.spark.SharedSparkContext +import org.apache.spark.SparkException +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ThreadUtils + +// Tests for PythonWorkerFactory. 
+class PythonWorkerFactorySuite extends SparkFunSuite with Matchers with SharedSparkContext {
+
+  test("createSimpleWorker() fails with a timeout error if worker does not connect back") {
+    // Verifies that the server side gives up waiting when the worker never connects back,
+    // e.g. because the worker process fails before it tries to connect.
+
+    val workerFactory = new PythonWorkerFactory(
+      "python3", "pyspark.testing.non_existing_worker_module", Map.empty
+    )
+
+    // Create the worker in a separate thread so that if there is a bug where it does not
+    // return (accept() used to be blocking), the test doesn't hang for a long time.
+    val createFuture = Future {
+      val ex = intercept[SparkException] {
+        // NOTE: Takes ~10 seconds (the accept timeout in PythonWorkerFactory); a longish test.
+        workerFactory.createSimpleWorker(blockingMode = true) // blockingMode doesn't matter.
+      }
+      assert(ex.getMessage.contains("Python worker failed to connect back"))
+      assert(ex.getCause.isInstanceOf[SocketTimeoutException])
+    }
+
+    // awaitResult (not awaitReady) rethrows a failure of the future, so a failed assertion
+    // above fails the test; the 5 minute limit guards against createSimpleWorker() hanging.
+    ThreadUtils.awaitResult(createFuture, 5.minutes)
+  }
+}