diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index b5f0ecd9a2fd..2db369273b50
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -22,9 +22,6 @@ import java.util.Properties
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration._
-import scala.util.{Failure, Success}
 
 import io.grpc.StatusRuntimeException
 import org.apache.commons.io.FileUtils
@@ -32,7 +29,6 @@ import org.apache.commons.io.output.TeeOutputStream
 import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.scalactic.TolerantNumerics
 import org.scalatest.PrivateMethodTester
-import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.{SPARK_VERSION, SparkException}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
@@ -45,7 +41,6 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.ThreadUtils
 
 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
 
@@ -952,71 +947,6 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
       }
     }
   }
-
-  test("interrupt all - background queries, foreground interrupt") {
-    val session = spark
-    import session.implicits._
-    implicit val ec = ExecutionContext.global
-    val q1 = Future {
-      spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
-    }
-    val q2 = Future {
-      spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
-    }
-    var q1Interrupted = false
-    var q2Interrupted = false
-    var error: Option[String] = None
-    q1.onComplete {
-      case Success(_) =>
-        error = Some("q1 shouldn't have finished!")
-      case Failure(t) if t.getMessage.contains("cancelled") =>
-        q1Interrupted = true
-      case Failure(t) =>
-        error = Some("unexpected failure in q1: " + t.toString)
-    }
-    q2.onComplete {
-      case Success(_) =>
-        error = Some("q2 shouldn't have finished!")
-      case Failure(t) if t.getMessage.contains("cancelled") =>
-        q2Interrupted = true
-      case Failure(t) =>
-        error = Some("unexpected failure in q2: " + t.toString)
-    }
-    // 20 seconds is < 30 seconds the queries should be running,
-    // because it should be interrupted sooner
-    eventually(timeout(20.seconds), interval(1.seconds)) {
-      // keep interrupting every second, until both queries get interrupted.
-      spark.interruptAll()
-      assert(error.isEmpty, s"Error not empty: $error")
-      assert(q1Interrupted)
-      assert(q2Interrupted)
-    }
-  }
-
-  test("interrupt all - foreground queries, background interrupt") {
-    val session = spark
-    import session.implicits._
-    implicit val ec = ExecutionContext.global
-
-    @volatile var finished = false
-    val interruptor = Future {
-      eventually(timeout(20.seconds), interval(1.seconds)) {
-        spark.interruptAll()
-        assert(finished)
-      }
-      finished
-    }
-    val e1 = intercept[io.grpc.StatusRuntimeException] {
-      spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
-    }
-    assert(e1.getMessage.contains("cancelled"), s"Unexpected exception: $e1")
-    val e2 = intercept[io.grpc.StatusRuntimeException] {
-      spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
-    }
-    assert(e2.getMessage.contains("cancelled"), s"Unexpected exception: $e2")
-    finished = true
-    assert(ThreadUtils.awaitResult(interruptor, 10.seconds) == true)
-  }
 }
 
 private[sql] case class MyType(id: Long, a: Double, b: Double)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
new file mode 100644
index 000000000000..fb295c00edc0
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
+import scala.concurrent.duration._
+import scala.util.{Failure, Success}
+
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * NOTE: Do not import classes that exist only in `spark-connect-client-jvm.jar` into this
+ * class, whether explicitly or implicitly, as doing so triggers a UDF deserialization error
+ * during the Maven build/test.
+ */ +class SparkSessionE2ESuite extends RemoteSparkSession { + + test("interrupt all - background queries, foreground interrupt") { + val session = spark + import session.implicits._ + implicit val ec: ExecutionContextExecutor = ExecutionContext.global + val q1 = Future { + spark.range(10).map(n => { Thread.sleep(30000); n }).collect() + } + val q2 = Future { + spark.range(10).map(n => { Thread.sleep(30000); n }).collect() + } + var q1Interrupted = false + var q2Interrupted = false + var error: Option[String] = None + q1.onComplete { + case Success(_) => + error = Some("q1 shouldn't have finished!") + case Failure(t) if t.getMessage.contains("cancelled") => + q1Interrupted = true + case Failure(t) => + error = Some("unexpected failure in q1: " + t.toString) + } + q2.onComplete { + case Success(_) => + error = Some("q2 shouldn't have finished!") + case Failure(t) if t.getMessage.contains("cancelled") => + q2Interrupted = true + case Failure(t) => + error = Some("unexpected failure in q2: " + t.toString) + } + // 20 seconds is < 30 seconds the queries should be running, + // because it should be interrupted sooner + eventually(timeout(20.seconds), interval(1.seconds)) { + // keep interrupting every second, until both queries get interrupted. + spark.interruptAll() + assert(error.isEmpty, s"Error not empty: $error") + assert(q1Interrupted) + assert(q2Interrupted) + } + } + + test("interrupt all - foreground queries, background interrupt") { + val session = spark + import session.implicits._ + implicit val ec: ExecutionContextExecutor = ExecutionContext.global + + @volatile var finished = false + val interruptor = Future { + eventually(timeout(20.seconds), interval(1.seconds)) { + spark.interruptAll() + assert(finished) + } + finished + } + val e1 = intercept[io.grpc.StatusRuntimeException] { + spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect() + } + assert(e1.getMessage.contains("cancelled"), s"Unexpected exception: $e1") + val e2 = intercept[io.grpc.StatusRuntimeException] { + spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect() + } + assert(e2.getMessage.contains("cancelled"), s"Unexpected exception: $e2") + finished = true + assert(ThreadUtils.awaitResult(interruptor, 10.seconds)) + } +}