@@ -515,8 +515,7 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
*/
object ColumnPruning extends Rule[LogicalPlan] {
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
-    output1.size == output2.size &&
-      output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
+    output1.size == output2.size && output1.zip(output2).forall(pair => pair._1 == pair._2)
Member: I think we still need to check that the exprIds they refer to are the same.
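A minimal sketch of what this comment is getting at (illustration only, not part of the PR, assuming the catalyst internals of the Spark release this change targets; the value names are made up): `semanticEquals` keys on the exprId and ignores the attribute name, while plain `==` also compares the case-sensitive name but still distinguishes exprIds.

```scala
// Illustration only -- requires spark-catalyst on the classpath.
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.LongType

val id  = AttributeReference("id", LongType)()   // gets a fresh exprId
val ID  = id.withName("ID")                      // same exprId, different name
val id2 = AttributeReference("id", LongType)()   // same name, different exprId

id.semanticEquals(ID)    // true:  semanticEquals ignores the name and keys on the exprId
id == ID                 // false: case-class equality also compares the (case-sensitive) name
id.semanticEquals(id2)   // false: different exprIds
id == id2                // false: `==` still distinguishes exprIds, which is the concern above
```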


def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
// Prunes the unused columns from project list of Project/Aggregate/Expand
@@ -649,17 +648,20 @@ object CollapseProject extends Rule[LogicalPlan] {
}
}

-  private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
-    AttributeMap(projectList.collect {
-      case a: Alias => a.toAttribute -> a
+  private def collectAliases(
+      upper: Seq[NamedExpression], lower: Seq[NamedExpression]): AttributeMap[Alias] = {
+    AttributeMap(lower.zipWithIndex.collect {
+      case (a: Alias, index: Int) =>
+        a.toAttribute ->
+          a.copy(name = upper(index).name)(a.exprId, a.qualifier, a.explicitMetadata)
Member (Author): Do we have a better way to do this than Alias.copy(...)?
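For context, a hedged sketch of the `Alias.copy(...)` pattern the diff above relies on (illustration only; `lowerAlias` and `renamed` are made-up names, and the curried argument list matches the Spark version this PR targets): the lower alias is renamed to the upper projection's name while its exprId, qualifier and explicit metadata are preserved.

```scala
// Illustration only -- mirrors the copy(...) call in the diff above.
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}

val lowerAlias = Alias(Literal(1), "col1")()      // alias as produced by the lower Project
val renamed = lowerAlias.copy(name = "COL1")(     // take the name used by the upper Project
  lowerAlias.exprId,                              // keep the exprId so references stay valid
  lowerAlias.qualifier,
  lowerAlias.explicitMetadata)
```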

})
}

private def haveCommonNonDeterministicOutput(
upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
// Create a map of Aliases to their values from the lower projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
-    val aliases = collectAliases(lower)
+    val aliases = collectAliases(upper, lower)

// Collapse upper and lower Projects if and only if their overlapped expressions are all
// deterministic.
@@ -673,7 +675,7 @@ object CollapseProject extends Rule[LogicalPlan] {
lower: Seq[NamedExpression]): Seq[NamedExpression] = {
// Create a map of Aliases to their values from the lower projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
-    val aliases = collectAliases(lower)
+    val aliases = collectAliases(upper, lower)

// Substitute any attributes that are produced by the lower projection, so that we safely
// eliminate it.
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (76 additions, 0 deletions)
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
@@ -2853,6 +2854,81 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("Insert overwrite table command should output correct schema: basic") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).toDF("id")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Insert overwrite table command should output correct schema: complex") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " +
"BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " +
"FROM view1 CLUSTER BY COL3")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(
StructField("COL1", LongType, true),
StructField("COL3", IntegerType, true),
StructField("COL2", IntegerType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Create table as select command should output correct schema: basic") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).toDF("id")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Create table as select command should output correct schema: complex") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " +
"CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(
StructField("COL1", LongType, true),
StructField("COL3", IntegerType, true),
StructField("COL2", IntegerType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("SPARK-25144 'distinct' causes memory leak") {
val ds = List(Foo(Some("bar"))).toDS
val result = ds.flatMap(_.bar).distinct