@@ -176,7 +176,7 @@ private[spark] class BarrierCoordinator(
       logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
         s"$taskId, current progress: ${requesters.size}/$numTasks.")
       if (requesters.size == numTasks) {
-        requesters.foreach(_.reply(messages))
+        requesters.foreach(_.reply(messages.clone()))
         // Finished current barrier() call successfully, clean up ContextBarrierState and
         // increase the barrier epoch.
         logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
@@ -367,4 +367,27 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
     // double check we kill task success
     assert(System.currentTimeMillis() - startTime < 5000)
   }
+
+  test("SPARK-40932, messages of allGather should not been overridden " +
+    "by the following barrier APIs") {
+
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local[2]"))
+    sc.setLogLevel("INFO")
+    val rdd = sc.makeRDD(1 to 10, 2)
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
+      // Sleep for a random time before global sync.
+      Thread.sleep(Random.nextInt(1000))
+      // Pass partitionId message in
+      val message: String = context.partitionId().toString
+      val messages: Array[String] = context.allGather(message)
+      context.barrier()
+      Iterator.single(messages.toList)
+    }
+    val messages = rdd2.collect()
+    // All the task partitionIds are shared across all tasks
+    assert(messages.length === 2)
+    assert(messages.forall(_ == List("0", "1")))
+  }
+
 }
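Why the clone() in the coordinator hunk matters: the coordinator replies to every waiting task with the same messages array it keeps for the current barrier state, and that buffer is written into again by the barrier round that follows, so a later barrier()/allGather() call can overwrite entries while tasks are still holding the result of the previous allGather(). Replying with a defensive copy decouples what each task received from the coordinator's mutable buffer, which is exactly what the new test exercises. A minimal standalone sketch of that aliasing hazard, using hypothetical names rather than the Spark classes:

// Hypothetical demo of the bug class fixed above: handing out the internal buffer
// lets the next round clobber an earlier caller's result; clone() keeps it stable.
object SharedBufferDemo {
  final class Coordinator(numSlots: Int) {
    private val buffer = new Array[String](numSlots)

    // Records one message and returns the gathered result, optionally as a copy.
    def gather(slot: Int, msg: String, copy: Boolean): Array[String] = {
      buffer(slot) = msg
      if (copy) buffer.clone() else buffer // the fix is equivalent to always copying
    }
  }

  def main(args: Array[String]): Unit = {
    val coord = new Coordinator(2)
    val noCopy = coord.gather(0, "round1-task0", copy = false)
    coord.gather(1, "round1-task1", copy = false)
    val withCopy = coord.gather(0, "round1-task0", copy = true)
    // A following round reuses the same buffer slots.
    coord.gather(0, "round2-task0", copy = false)
    println(noCopy.mkString(", "))   // round2-task0, round1-task1  <- round 1 result overwritten
    println(withCopy.mkString(", ")) // round1-task0, round1-task1  <- the copy is unaffected
  }
}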
1 change: 1 addition & 0 deletions dev/requirements.txt
@@ -13,6 +13,7 @@ matplotlib<3.3.0
 
 # PySpark test dependencies
 unittest-xml-reporting
+openpyxl
 
 # PySpark test dependencies (optional)
 coverage
15 changes: 10 additions & 5 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -18,14 +18,18 @@
 import unittest
 import shutil
 import tempfile
+from pyspark.testing.sqlutils import have_pandas
 
-import pandas
+if have_pandas:
+    import pandas
 
 from pyspark.sql import SparkSession, Row
 from pyspark.sql.types import StructType, StructField, LongType, StringType
-from pyspark.sql.connect.client import RemoteSparkSession
-from pyspark.sql.connect.function_builder import udf
-from pyspark.sql.connect.functions import lit
+
+if have_pandas:
+    from pyspark.sql.connect.client import RemoteSparkSession
+    from pyspark.sql.connect.function_builder import udf
+    from pyspark.sql.connect.functions import lit
 from pyspark.sql.dataframe import DataFrame
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import ReusedPySparkTestCase
@@ -36,7 +40,8 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
     """Parent test fixture class for all Spark Connect related
     test cases."""
 
-    connect: RemoteSparkSession
+    if have_pandas:
+        connect: RemoteSparkSession
     tbl_name: str
     df_text: "DataFrame"
 
@@ -15,14 +15,20 @@
 # limitations under the License.
 #
+
+from typing import cast
 import unittest
 from pyspark.testing.connectutils import PlanOnlyTestFixture
-from pyspark.sql.connect.proto import Expression as ProtoExpression
-import pyspark.sql.connect as c
-import pyspark.sql.connect.plan as p
-import pyspark.sql.connect.column as col
-import pyspark.sql.connect.functions as fun
+from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
+
+if have_pandas:
+    from pyspark.sql.connect.proto import Expression as ProtoExpression
+    import pyspark.sql.connect as c
+    import pyspark.sql.connect.plan as p
+    import pyspark.sql.connect.column as col
+    import pyspark.sql.connect.functions as fun
 
 
+@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
 class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
     def test_simple_column_expressions(self):
         df = c.DataFrame.withPlan(p.Read("table"))
13 changes: 9 additions & 4 deletions python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -14,15 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import cast
 import unittest
 
 from pyspark.testing.connectutils import PlanOnlyTestFixture
-import pyspark.sql.connect.proto as proto
-from pyspark.sql.connect.readwriter import DataFrameReader
-from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
-from pyspark.sql.types import StringType
+from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
+
+if have_pandas:
+    import pyspark.sql.connect.proto as proto
+    from pyspark.sql.connect.readwriter import DataFrameReader
+    from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
+    from pyspark.sql.types import StringType
 
 
+@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
 class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
     """These test cases exercise the interface to the proto plan
     generation but do not call Spark."""
15 changes: 11 additions & 4 deletions python/pyspark/sql/tests/connect/test_connect_select_ops.py
@@ -14,13 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import cast
+import unittest
+
 from pyspark.testing.connectutils import PlanOnlyTestFixture
-from pyspark.sql.connect import DataFrame
-from pyspark.sql.connect.functions import col
-from pyspark.sql.connect.plan import Read
-import pyspark.sql.connect.proto as proto
+from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
+
+if have_pandas:
+    from pyspark.sql.connect import DataFrame
+    from pyspark.sql.connect.functions import col
+    from pyspark.sql.connect.plan import Read
+    import pyspark.sql.connect.proto as proto
 
 
+@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
 class SparkConnectToProtoSuite(PlanOnlyTestFixture):
     def test_select_with_columns_and_strings(self):
         df = DataFrame.withPlan(Read("table"))
15 changes: 10 additions & 5 deletions python/pyspark/testing/connectutils.py
@@ -18,13 +18,18 @@
 from typing import Any, Dict
 import functools
 import unittest
+from pyspark.testing.sqlutils import have_pandas
 
-from pyspark.sql.connect import DataFrame
-from pyspark.sql.connect.plan import Read
-from pyspark.testing.utils import search_jar
+if have_pandas:
+    from pyspark.sql.connect import DataFrame
+    from pyspark.sql.connect.plan import Read
+    from pyspark.testing.utils import search_jar
+
+    connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
+else:
+    connect_jar = None
 
 
-connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
 if connect_jar is None:
     connect_requirement_message = (
         "Skipping all Spark Connect Python tests as the optional Spark Connect project was "
@@ -38,7 +43,7 @@
     os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args])
     connect_requirement_message = None  # type: ignore
 
-should_test_connect = connect_requirement_message is None
+should_test_connect = connect_requirement_message is None and have_pandas
 
 
 class MockRemoteSession:
@@ -119,8 +119,7 @@ class AnalysisErrorSuite extends AnalysisTest {
       messageParameters: Map[String, String],
       caseSensitive: Boolean = true): Unit = {
     test(name) {
-      assertAnalysisErrorClass(plan, errorClass, messageParameters,
-        caseSensitive = true, line = -1, pos = -1)
+      assertAnalysisErrorClass(plan, errorClass, messageParameters, caseSensitive = caseSensitive)
     }
   }
 
@@ -899,9 +898,8 @@ class AnalysisErrorSuite extends AnalysisTest {
         "inputSql" -> inputSql,
         "inputType" -> inputType,
         "requiredType" -> "(\"INT\" or \"BIGINT\")"),
-      caseSensitive = false,
-      line = -1,
-      pos = -1)
+      caseSensitive = false
+    )
   }
 }
 
@@ -52,8 +52,8 @@ class AnalysisExceptionPositionSuite extends AnalysisTest {
      parsePlan("SHOW COLUMNS FROM unknown IN db"),
      "TABLE_OR_VIEW_NOT_FOUND",
      Map("relationName" -> "`db`.`unknown`"),
-      line = 1,
-      pos = 18)
+      Array(ExpectedContext("unknown", 18, 24))
+    )
    verifyTableOrViewPosition("ALTER TABLE unknown RENAME TO t", "unknown")
    verifyTableOrViewPosition("ALTER VIEW unknown RENAME TO v", "unknown")
  }
@@ -92,13 +92,13 @@ class AnalysisExceptionPositionSuite extends AnalysisTest {
  }
 
  private def verifyPosition(sql: String, table: String): Unit = {
-    val expectedPos = sql.indexOf(table)
-    assert(expectedPos != -1)
+    val startPos = sql.indexOf(table)
+    assert(startPos != -1)
    assertAnalysisErrorClass(
      parsePlan(sql),
      "TABLE_OR_VIEW_NOT_FOUND",
      Map("relationName" -> s"`$table`"),
-      line = 1,
-      pos = expectedPos)
+      Array(ExpectedContext(table, startPos, startPos + table.length - 1))
+    )
  }
 }
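A note on the numbers passed to ExpectedContext above: the second and third arguments are inclusive character offsets of the SQL fragment within the statement, which is why verifyPosition derives the start with indexOf and the stop as startPos + table.length - 1. In "SHOW COLUMNS FROM unknown IN db" the fragment "unknown" starts at offset 18 and ends at 24, matching ExpectedContext("unknown", 18, 24). A small standalone sketch of that arithmetic (the helper name is illustrative, not part of the patch):

// Illustrative helper mirroring verifyPosition above: derive the inclusive
// start/stop offsets that ExpectedContext expects from the SQL text itself.
object ExpectedContextOffsets {
  def offsets(sql: String, fragment: String): (Int, Int) = {
    val start = sql.indexOf(fragment)
    require(start != -1, s"fragment '$fragment' not found in: $sql")
    (start, start + fragment.length - 1) // stop offset is inclusive
  }

  def main(args: Array[String]): Unit = {
    // Prints (18,24), the values used with ExpectedContext("unknown", 18, 24) above.
    println(offsets("SHOW COLUMNS FROM unknown IN db", "unknown"))
  }
}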
@@ -104,10 +104,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
      Project(Seq(UnresolvedAttribute("tBl.a")),
        SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))),
      "UNRESOLVED_COLUMN.WITH_SUGGESTION",
-      Map("objectName" -> "`tBl`.`a`", "proposal" -> "`TbL`.`a`"),
-      caseSensitive = true,
-      line = -1,
-      pos = -1)
+      Map("objectName" -> "`tBl`.`a`", "proposal" -> "`TbL`.`a`")
+    )
 
    checkAnalysisWithoutViewWrapper(
      Project(Seq(UnresolvedAttribute("TbL.a")),
@@ -716,9 +714,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
    assertAnalysisErrorClass(parsePlan("WITH t(x) AS (SELECT 1) SELECT * FROM t WHERE y = 1"),
      "UNRESOLVED_COLUMN.WITH_SUGGESTION",
      Map("objectName" -> "`y`", "proposal" -> "`t`.`x`"),
-      caseSensitive = true,
-      line = -1,
-      pos = -1)
+      Array(ExpectedContext("y", 46, 46))
+    )
  }
 
  test("CTE with non-matching column alias") {
@@ -729,7 +726,8 @@
 
  test("SPARK-28251: Insert into non-existing table error message is user friendly") {
    assertAnalysisErrorClass(parsePlan("INSERT INTO test VALUES (1)"),
-      "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`test`"))
+      "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`test`"),
+      Array(ExpectedContext("test", 12, 15)))
  }
 
  test("check CollectMetrics resolved") {
@@ -1157,9 +1155,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
        |""".stripMargin),
      "UNRESOLVED_COLUMN.WITH_SUGGESTION",
      Map("objectName" -> "`c`.`y`", "proposal" -> "`x`"),
-      caseSensitive = true,
-      line = -1,
-      pos = -1)
+      Array(ExpectedContext("c.y", 123, 125))
+    )
  }
 
  test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") {
@@ -1178,7 +1175,9 @@
        "inputSql" -> "\"c\"",
        "inputType" -> "\"BOOLEAN\"",
        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
-      caseSensitive = false)
+      queryContext = Array(ExpectedContext("mean(t.c)", 65, 73)),
+      caseSensitive = false
+    )
 
    assertAnalysisErrorClass(
      inputPlan = parsePlan(
@@ -1195,6 +1194,7 @@
        "inputSql" -> "\"c\"",
        "inputType" -> "\"BOOLEAN\"",
        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      queryContext = Array(ExpectedContext("mean(c)", 91, 97)),
      caseSensitive = false)
 
    assertAnalysisErrorClass(
@@ -1213,9 +1213,9 @@
        "inputType" -> "\"BOOLEAN\"",
        "requiredType" ->
          "(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR TO MONTH\")"),
-      caseSensitive = false,
-      line = -1,
-      pos = -1)
+      queryContext = Array(ExpectedContext("abs(t.c)", 65, 72)),
+      caseSensitive = false
+    )
 
    assertAnalysisErrorClass(
      inputPlan = parsePlan(
@@ -1233,9 +1233,9 @@
        "inputType" -> "\"BOOLEAN\"",
        "requiredType" ->
          "(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR TO MONTH\")"),
-      caseSensitive = false,
-      line = -1,
-      pos = -1)
+      queryContext = Array(ExpectedContext("abs(c)", 91, 96)),
+      caseSensitive = false
+    )
  }
 
  test("SPARK-39354: should be [TABLE_OR_VIEW_NOT_FOUND]") {
@@ -1246,7 +1246,8 @@
        |FROM t1
        |JOIN t2 ON t1.user_id = t2.user_id
        |WHERE t1.dt >= DATE_SUB('2020-12-27', 90)""".stripMargin),
-      "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`t2`"))
+      "TABLE_OR_VIEW_NOT_FOUND", Map("relationName" -> "`t2`"),
+      Array(ExpectedContext("t2", 84, 85)))
  }
 
  test("SPARK-39144: nested subquery expressions deduplicate relations should be done bottom up") {
@@ -34,6 +34,8 @@ import org.apache.spark.sql.types.StructType
 
 trait AnalysisTest extends PlanTest {
 
+  import org.apache.spark.QueryContext
+
  protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil
 
  protected def createTempView(
@@ -174,40 +176,19 @@ trait AnalysisTest extends PlanTest {
      inputPlan: LogicalPlan,
      expectedErrorClass: String,
      expectedMessageParameters: Map[String, String],
-      caseSensitive: Boolean = true,
-      line: Int = -1,
-      pos: Int = -1): Unit = {
+      queryContext: Array[QueryContext] = Array.empty,
+      caseSensitive: Boolean = true): Unit = {
    withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
      val analyzer = getAnalyzer
      val e = intercept[AnalysisException] {
        analyzer.checkAnalysis(analyzer.execute(inputPlan))
      }
-
-      if (e.getErrorClass != expectedErrorClass ||
-          e.messageParameters != expectedMessageParameters ||
-          (line >= 0 && e.line.getOrElse(-1) != line) ||
-          (pos >= 0) && e.startPosition.getOrElse(-1) != pos) {
-        var failMsg = ""
-        if (e.getErrorClass != expectedErrorClass) {
-          failMsg +=
-            s"""Error class should be: ${expectedErrorClass}
-               |Actual error class: ${e.getErrorClass}
-             """.stripMargin
-        }
-        if (e.messageParameters != expectedMessageParameters) {
-          failMsg +=
-            s"""Message parameters should be: ${expectedMessageParameters.mkString("\n  ")}
-               |Actual message parameters: ${e.messageParameters.mkString("\n  ")}
-             """.stripMargin
-        }
-        if (e.line.getOrElse(-1) != line || e.startPosition.getOrElse(-1) != pos) {
-          failMsg +=
-            s"""Line/position should be: $line, $pos
-               |Actual line/position: ${e.line.getOrElse(-1)}, ${e.startPosition.getOrElse(-1)}
-             """.stripMargin
-        }
-        fail(failMsg)
-      }
+      checkError(
+        exception = e,
+        errorClass = expectedErrorClass,
+        parameters = expectedMessageParameters,
+        queryContext = queryContext
+      )
    }
  }
 
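The net effect of the AnalysisTest change above is that expectations about where an error points into the SQL text are now expressed as ExpectedContext values (fragment plus inclusive start/stop offsets) passed through to checkError, rather than as bare line/pos integers compared by hand in the old failMsg block. A rough standalone sketch of that comparison idea, with hypothetical types rather than Spark's own test helpers:

// Hypothetical mini-version of the queryContext comparison delegated to checkError:
// match the error's captured SQL fragment and inclusive offsets against expectations.
object QueryContextCheckSketch {
  final case class Ctx(fragment: String, start: Int, stop: Int)

  def assertSameContext(actual: Seq[Ctx], expected: Seq[Ctx]): Unit = {
    require(actual == expected,
      s"Query context mismatch.\n  expected: $expected\n  actual:   $actual")
  }

  def main(args: Array[String]): Unit = {
    val sql = "INSERT INTO test VALUES (1)"
    val start = sql.indexOf("test")                     // 12
    val actual = Seq(Ctx("test", start, start + 3))     // fragment "test" spans offsets 12..15
    assertSameContext(actual, Seq(Ctx("test", 12, 15))) // mirrors ExpectedContext("test", 12, 15)
    println("query context matched")
  }
}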