diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index c7f8f59a7679..e770b9c7053e 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -324,6 +324,12 @@
     ],
     "sqlState" : "0AKD0"
   },
+  "CANNOT_RESOLVE_DATAFRAME_COLUMN" : {
+    "message" : [
+      "Cannot resolve dataframe column <name>. It's probably because of illegal references like `df1.select(df2.col(\"a\"))`."
+    ],
+    "sqlState" : "42704"
+  },
   "CANNOT_RESOLVE_STAR_EXPAND" : {
     "message" : [
       "Cannot resolve <targetString>.* given input columns <columns>. Please check that the specified table or struct exists and is accessible in the input columns."
@@ -6843,11 +6849,6 @@
       "Cannot modify the value of a static config: <k>"
     ]
   },
-  "_LEGACY_ERROR_TEMP_3051" : {
-    "message" : [
-      "When resolving <u>, fail to find subplan with plan_id=<planId> in <q>"
-    ]
-  },
   "_LEGACY_ERROR_TEMP_3052" : {
     "message" : [
       "Unexpected resolved action: <resolvedAction>"
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 0740334724e8..288964a084ba 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -894,7 +894,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
       // df1("i") is not ambiguous, but it's not valid in the projected df.
       df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect()
     }
-    assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT"))
+    assert(e1.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION"))
 
     checkSameResult(
       Seq(Row(1, "a")),
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index f58b7f607a0b..db8ecf5b2a30 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -282,6 +282,12 @@ Cannot recognize hive type string: `<fieldType>`, column: `<fieldName>`. The spe
 
 Renaming a `<type>` across schemas is not allowed.
 
+### CANNOT_RESOLVE_DATAFRAME_COLUMN
+
+[SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Cannot resolve dataframe column `<name>`. It's probably because of illegal references like `df1.select(df2.col("a"))`.
+
 ### CANNOT_RESOLVE_STAR_EXPAND
 
 [SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
 
 Cannot resolve `<targetString>`.* given input columns `<columns>`. Please check that the specified table or struct exists and is accessible in the input columns.
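
The new error condition is easiest to check end to end from Python. The following is a minimal sketch, not part of this patch: it assumes a Spark Connect session (the `sc://localhost` URL is a placeholder) and reproduces the illegal cross-DataFrame reference that the message above describes.

```python
from pyspark.sql import Row, SparkSession
from pyspark.errors import AnalysisException

# Placeholder connect URL; any Spark Connect session works here.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df1 = spark.createDataFrame([Row(a=1, b=2, c=3)])
df2 = spark.createDataFrame([Row(a=1, b=2)])

try:
    # df2.a carries df2's plan id, which appears nowhere in df1's plan.
    df1.select(df2.a).schema
except AnalysisException as e:
    # With this change: CANNOT_RESOLVE_DATAFRAME_COLUMN
    # (previously the internal _LEGACY_ERROR_TEMP_3051).
    print(e.getErrorClass())
```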
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 045ba8f0060d..a1cd00e79e1a 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -543,10 +543,21 @@ def test_invalid_column(self):
         with self.assertRaises(AnalysisException):
             cdf2.withColumn("x", cdf1.a + 1).schema
 
-        with self.assertRaisesRegex(AnalysisException, "attribute.*missing"):
+        # Can find the target plan node, but fails to resolve the column with it
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+        ):
             cdf3 = cdf1.select(cdf1.a)
             cdf3.select(cdf1.b).schema
 
+        # Cannot find the target plan node by plan id
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            cdf1.select(cdf2.a).schema
+
     def test_collect(self):
         cdf = self.connect.read.table(self.tbl_name)
         sdf = self.spark.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index f1d690751ead..c77e7fd89d01 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -26,7 +26,6 @@
 import io
 from contextlib import redirect_stdout
 
-from pyspark import StorageLevel
 from pyspark.sql import SparkSession, Row, functions
 from pyspark.sql.functions import col, lit, count, sum, mean, struct
 from pyspark.sql.types import (
@@ -70,6 +69,14 @@ def test_range(self):
         self.assertEqual(self.spark.range(-2).count(), 0)
         self.assertEqual(self.spark.range(3).count(), 3)
 
+    def test_self_join(self):
+        df1 = self.spark.range(10).withColumn("a", lit(0))
+        df2 = df1.withColumnRenamed("a", "b")
+        df = df1.join(df2, df1["a"] == df2["b"])
+        self.assertTrue(df.count() == 100)
+        df = df2.join(df1, df2["b"] == df1["a"])
+        self.assertTrue(df.count() == 100)
+
     def test_duplicated_column_names(self):
         df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
         row = df.select("*").first()
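
The updated test distinguishes two failure modes that the old `attribute.*missing` regex lumped together. A hedged sketch of the difference, assuming an existing Spark Connect session `spark` (the DataFrame contents are illustrative):

```python
from pyspark.errors import AnalysisException

df1 = spark.createDataFrame([(1, "x")], ["a", "b"])
df2 = spark.createDataFrame([(2, "y")], ["a", "b"])

# Mode 1: the node tagged with df1's plan id is found and `b` resolves against
# it, but the attribute is filtered out by the projection's output, so the
# fallback name-based resolution fails with a suggestion.
df3 = df1.select(df1.a)
try:
    df3.select(df1.b).schema
except AnalysisException as e:
    print(e.getErrorClass())  # UNRESOLVED_COLUMN.WITH_SUGGESTION

# Mode 2: df2's plan id occurs nowhere under df1's plan, so no node matches
# and the new error condition is raised.
try:
    df1.select(df2.a).schema
except AnalysisException as e:
    print(e.getErrorClass())  # CANNOT_RESOLVE_DATAFRAME_COLUMN
```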
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index a90c61565039..3261aa51b9be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -426,7 +426,7 @@
       throws: Boolean = false,
       includeLastResort: Boolean = false): Expression = {
     resolveExpression(
-      tryResolveColumnByPlanId(expr, plan),
+      tryResolveDataFrameColumns(expr, Seq(plan)),
       resolveColumnByName = nameParts => {
         plan.resolve(nameParts, conf.resolver)
       },
@@ -448,7 +448,7 @@
       q: LogicalPlan,
       includeLastResort: Boolean = false): Expression = {
     resolveExpression(
-      tryResolveColumnByPlanId(e, q),
+      tryResolveDataFrameColumns(e, q.children),
       resolveColumnByName = nameParts => {
         q.resolveChildren(nameParts, conf.resolver)
       },
@@ -485,80 +485,107 @@
   // 4. if more than one matching node is found, fail due to an ambiguous column reference;
   // 5. resolve the expression with the matching node; if any error occurs here, return the
   //    original expression as it is.
-  private def tryResolveColumnByPlanId(
+  private def tryResolveDataFrameColumns(
       e: Expression,
-      q: LogicalPlan,
-      idToPlan: mutable.HashMap[Long, LogicalPlan] = mutable.HashMap.empty): Expression = e match {
+      q: Seq[LogicalPlan]): Expression = e match {
     case u: UnresolvedAttribute =>
-      resolveUnresolvedAttributeByPlanId(
-        u, q, idToPlan: mutable.HashMap[Long, LogicalPlan]
-      ).getOrElse(u)
+      resolveDataFrameColumn(u, q).getOrElse(u)
     case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) =>
-      e.mapChildren(c => tryResolveColumnByPlanId(c, q, idToPlan))
+      e.mapChildren(c => tryResolveDataFrameColumns(c, q))
     case _ => e
   }
 
-  private def resolveUnresolvedAttributeByPlanId(
+  private def resolveDataFrameColumn(
       u: UnresolvedAttribute,
-      q: LogicalPlan,
-      idToPlan: mutable.HashMap[Long, LogicalPlan]): Option[NamedExpression] = {
+      q: Seq[LogicalPlan]): Option[NamedExpression] = {
     val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
     if (planIdOpt.isEmpty) return None
     val planId = planIdOpt.get
     logDebug(s"Extract plan_id $planId from $u")
 
-    val plan = idToPlan.getOrElseUpdate(planId, {
-      findPlanById(u, planId, q).getOrElse {
-        // For example:
-        // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
-        // df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
-        // df1.select(df2.a) <- illegal reference df2.a
-        throw new AnalysisException(
-          errorClass = "_LEGACY_ERROR_TEMP_3051",
-          messageParameters = Map(
-            "u" -> u.toString,
-            "planId" -> planId.toString,
-            "q" -> q.toString))
-      }
-    })
+    val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty
+    val (resolved, matched) = resolveDataFrameColumnByPlanId(u, planId, isMetadataAccess, q)
+    if (!matched) {
+      // Cannot find the target plan node with plan id, e.g.
+      // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)])
+      // df2 = spark.createDataFrame([Row(a = 1, b = 2)])
+      // df1.select(df2.a) <- illegal reference df2.a
+      throw QueryCompilationErrors.cannotResolveColumn(u)
+    }
+    resolved
+  }
 
-    val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined
-    try {
-      if (!isMetadataAccess) {
-        plan.resolve(u.nameParts, conf.resolver)
-      } else if (u.nameParts.size == 1) {
-        plan.getMetadataAttributeByNameOpt(u.nameParts.head)
-      } else {
-        None
+  private def resolveDataFrameColumnByPlanId(
+      u: UnresolvedAttribute,
+      id: Long,
+      isMetadataAccess: Boolean,
+      q: Seq[LogicalPlan]): (Option[NamedExpression], Boolean) = {
+    q.iterator.map(resolveDataFrameColumnRecursively(u, id, isMetadataAccess, _))
+      .foldLeft((Option.empty[NamedExpression], false)) {
+        case ((r1, m1), (r2, m2)) =>
+          if (r1.nonEmpty && r2.nonEmpty) {
+            throw QueryCompilationErrors.ambiguousColumnReferences(u)
+          }
+          (if (r1.nonEmpty) r1 else r2, m1 | m2)
       }
-    } catch {
-      case e: AnalysisException =>
-        logDebug(s"Fail to resolve $u with $plan due to $e")
-        None
-    }
   }
 
-  private def findPlanById(
+  private def resolveDataFrameColumnRecursively(
       u: UnresolvedAttribute,
       id: Long,
-      plan: LogicalPlan): Option[LogicalPlan] = {
-    if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
-      Some(plan)
-    } else if (plan.children.length == 1) {
-      findPlanById(u, id, plan.children.head)
-    } else if (plan.children.length > 1) {
-      val matched = plan.children.flatMap(findPlanById(u, id, _))
-      if (matched.length > 1) {
-        throw new AnalysisException(
-          errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
-          messageParameters = Map("name" -> toSQLId(u.nameParts)),
-          origin = u.origin
-        )
-      } else {
-        matched.headOption
+      isMetadataAccess: Boolean,
+      p: LogicalPlan): (Option[NamedExpression], Boolean) = {
+    val (resolved, matched) = if (p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
+      val resolved = try {
+        if (!isMetadataAccess) {
+          p.resolve(u.nameParts, conf.resolver)
+        } else if (u.nameParts.size == 1) {
+          p.getMetadataAttributeByNameOpt(u.nameParts.head)
+        } else {
+          None
+        }
+      } catch {
+        case e: AnalysisException =>
+          logDebug(s"Fail to resolve $u with $p due to $e")
+          None
+      }
+      (resolved, true)
     } else {
-      None
+      resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, p.children)
+    }
+
+    // In a self-join case like:
+    //
+    //   df1 = spark.range(10).withColumn("a", sf.lit(0))
+    //   df2 = df1.withColumnRenamed("a", "b")
+    //   df1.join(df2, df1["a"] == df2["b"])
+    //
+    // the logical plan looks like:
+    //
+    //   'Join Inner, '`==`('a, 'b)                  [plan_id=5]
+    //   :- Project [id#22L, 0 AS a#25]              [plan_id=1]
+    //   :  +- Range (0, 10, step=1, splits=Some(12))
+    //   +- Project [id#28L, a#31 AS b#36]           [plan_id=2]
+    //      +- Project [id#28L, 0 AS a#31]           [plan_id=1]
+    //         +- Range (0, 10, step=1, splits=Some(12))
+    //
+    // When resolving the column reference df1.a, the target node with plan_id=1
+    // can be found on both sides of the Join node.
+    // To resolve df1.a correctly, the analyzer discards the attribute resolved
+    // on the right side, by filtering the result against the output attributes
+    // of the Project with plan_id=2.
+    //
+    // However, some analyzer rules (e.g. ResolveReferencesInSort) support
+    // missing-column resolution, so a validly resolved attribute may be
+    // filtered out here. In that case, resolveDataFrameColumnByPlanId
+    // returns None, the dataframe column remains unresolved, and the analyzer
+    // will try to resolve it without the plan id later.
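
The self-join walkthrough in the comment above can be exercised directly. This sketch mirrors the new `test_self_join` test (assuming an existing session `spark`); both join keys are the literal 0, so every left row matches every right row:

```python
import pyspark.sql.functions as sf

df1 = spark.range(10).withColumn("a", sf.lit(0))
df2 = df1.withColumnRenamed("a", "b")

# df1["a"] targets the node tagged plan_id=1, which occurs on both sides of
# the join; the Project renaming `a` to `b` (plan_id=2) filters the
# right-side candidate out, leaving exactly one match.
joined = df1.join(df2, df1["a"] == df2["b"])
assert joined.count() == 100  # 10 x 10 rows: every a == every b == 0
```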
+    val filtered = resolved.filter { r =>
+      if (isMetadataAccess) {
+        r.references.subsetOf(AttributeSet(p.output ++ p.metadataOutput))
+      } else {
+        r.references.subsetOf(p.outputSet)
+      }
+    }
+    (filtered, matched)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 387064695770..91d18788fd4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkThrowable, SparkUnsupportedOperationException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, FunctionIdentifier, InternalRow, QualifiedTableName, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, FunctionAlreadyExistsException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException, UnresolvedRegex}
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, FunctionAlreadyExistsException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException, UnresolvedAttribute, UnresolvedRegex}
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, InvalidUDFClassException}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
@@ -3940,4 +3940,20 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
         "dsSchema" -> toSQLType(dsSchema),
         "expectedSchema" -> toSQLType(expectedSchema)))
   }
+
+  def cannotResolveColumn(u: UnresolvedAttribute): Throwable = {
+    new AnalysisException(
+      errorClass = "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+      messageParameters = Map("name" -> toSQLId(u.nameParts)),
+      origin = u.origin
+    )
+  }
+
+  def ambiguousColumnReferences(u: UnresolvedAttribute): Throwable = {
+    new AnalysisException(
+      errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
+      messageParameters = Map("name" -> toSQLId(u.nameParts)),
+      origin = u.origin
+    )
+  }
 }
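
For the second helper, here is a sketch of a reference that should stay ambiguous under the fold in `resolveDataFrameColumnByPlanId`: `id` is not renamed on either join side, so, as far as I can tell, a candidate survives the output filter on both sides. The session `spark` and the `sf` alias are assumed as before.

```python
import pyspark.sql.functions as sf
from pyspark.errors import AnalysisException

df1 = spark.range(10)
df2 = df1.withColumn("a", sf.lit(0))
df3 = df1.withColumn("b", sf.lit(0))

try:
    # df1's node sits under both join sides and both sides still expose `id`,
    # so two non-empty results reach the fold and it throws.
    df2.join(df3, df1["id"] == df1["id"]).schema
except AnalysisException as e:
    print(e.getErrorClass())  # AMBIGUOUS_COLUMN_REFERENCE
```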