Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ package object dsl {
def desc: SortOrder = SortOrder(expr, Descending)
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty)
def as(alias: String): NamedExpression = Alias(expr, alias)()
def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
}

trait ExpressionConversions {
Expand All @@ -166,9 +165,6 @@ package object dsl {
implicit def instantToLiteral(i: Instant): Literal = Literal(i)
implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a)

implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute =
analysis.UnresolvedAttribute(s.name)

/** Converts $"col name" into an [[analysis.UnresolvedAttribute]]. */
implicit class StringToAttributeConversionHelper(val sc: StringContext) {
// Note that if we make ExpressionConversions an object rather than a trait, we can
Expand Down Expand Up @@ -244,7 +240,6 @@ package object dsl {
def windowExpr(windowFunc: Expression, windowSpec: WindowSpecDefinition): WindowExpression =
WindowExpression(windowFunc, windowSpec)

implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
override def expr: Expression = Literal(s)
Expand Down Expand Up @@ -308,10 +303,10 @@ package object dsl {
AttributeReference(s, arrayType)()

/** Creates a new AttributeReference of type map */
def map(keyType: DataType, valueType: DataType): AttributeReference =
map(MapType(keyType, valueType))
def mapAttr(keyType: DataType, valueType: DataType): AttributeReference =
mapAttr(MapType(keyType, valueType))

def map(mapType: MapType): AttributeReference =
def mapAttr(mapType: MapType): AttributeReference =
AttributeReference(s, mapType, nullable = true)()

/** Creates a new AttributeReference of type struct */
Expand Down Expand Up @@ -414,8 +409,6 @@ package object dsl {
orderSpec: Seq[SortOrder]): LogicalPlan =
Window(windowExpressions, partitionSpec, orderSpec, logicalPlan)

def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan)

def except(otherPlan: LogicalPlan, isAll: Boolean): LogicalPlan =
Except(logicalPlan, otherPlan, isAll)

Expand Down Expand Up @@ -443,6 +436,7 @@ package object dsl {
InsertIntoStatement(table, partition, Nil, logicalPlan, overwrite, ifPartitionNotExists)

def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)
def subquery(alias: String): LogicalPlan = as(alias)

def coalesce(num: Integer): LogicalPlan =
Repartition(num, shuffle = false, logicalPlan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,20 +624,20 @@ class AnalysisErrorSuite extends AnalysisTest {
}

test("Join can work on binary types but can't work on map types") {
val left = LocalRelation(Symbol("a").binary, Symbol("b").map(StringType, StringType))
val right = LocalRelation(Symbol("c").binary, Symbol("d").map(StringType, StringType))
val left = LocalRelation("a".attr.binary, "b".attr.mapAttr(StringType, StringType))
val right = LocalRelation("c".attr.binary, "d".attr.mapAttr(StringType, StringType))

val plan1 = left.join(
right,
joinType = Cross,
condition = Some(Symbol("a") === Symbol("c")))
condition = Some("a".attr === "c".attr))

assertAnalysisSuccess(plan1)

val plan2 = left.join(
right,
joinType = Cross,
condition = Some(Symbol("b") === Symbol("d")))
condition = Some("b".attr === "d".attr))
assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil)
}

Expand Down Expand Up @@ -705,7 +705,7 @@ class AnalysisErrorSuite extends AnalysisTest {
test("Error on filter condition containing aggregate expressions") {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val plan = Filter(Symbol("a") === UnresolvedFunction("max", Seq(b), true), LocalRelation(a, b))
val plan = Filter("a".attr === UnresolvedFunction("max", Seq(b), true), LocalRelation(a, b))
assertAnalysisError(plan,
"Aggregate/Window/Generate expressions are not valid in where clause of the query" :: Nil)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
test("fail for unresolved plan") {
intercept[AnalysisException] {
// `testRelation` does not have column `b`.
testRelation.select('b).analyze
testRelation.select("b".attr).analyze
}
}

Expand Down Expand Up @@ -285,7 +285,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
CreateNamedStruct(Seq(
Literal(att1.name), att1,
Literal("a_plus_1"), (att1 + 1))),
Symbol("col").struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull
"col".attr.struct(prevPlan.output(0).dataType.asInstanceOf[StructType]).notNull
)).as("arr")
)

Expand Down Expand Up @@ -426,15 +426,15 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

test("SPARK-12102: Ignore nullability when comparing two sides of case") {
val relation = LocalRelation(Symbol("a").struct(Symbol("x").int),
Symbol("b").struct(Symbol("x").int.withNullability(false)))
val relation = LocalRelation("a".attr.struct("x".attr.int),
"b".attr.struct("x".attr.int.withNullability(false)))
val plan = relation.select(
CaseWhen(Seq((Literal(true), Symbol("a").attr)), Symbol("b")).as("val"))
CaseWhen(Seq((Literal(true), "a".attr)), "b".attr).as("val"))
assertAnalysisSuccess(plan)
}

test("Keep attribute qualifiers after dedup") {
val input = LocalRelation(Symbol("key").int, Symbol("value").string)
val input = LocalRelation("key".attr.int, "value".attr.string)

val query =
Project(Seq($"x.key", $"y.key"),
Expand Down Expand Up @@ -561,13 +561,13 @@ class AnalysisSuite extends AnalysisTest with Matchers {

test("SPARK-20963 Support aliases for join relations in FROM clause") {
def joinRelationWithAliases(outputNames: Seq[String]): LogicalPlan = {
val src1 = LocalRelation(Symbol("id").int, Symbol("v1").string).as("s1")
val src2 = LocalRelation(Symbol("id").int, Symbol("v2").string).as("s2")
val src1 = LocalRelation("id".attr.int, "v1".attr.string).as("s1")
val src2 = LocalRelation("id".attr.int, "v2".attr.string).as("s2")
UnresolvedSubqueryColumnAliases(
outputNames,
SubqueryAlias(
"dst",
src1.join(src2, Inner, Option(Symbol("s1.id") === Symbol("s2.id"))))
src1.join(src2, Inner, Option("s1.id".attr === "s2.id".attr)))
).select(star())
}
assertAnalysisSuccess(joinRelationWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil))
Expand All @@ -591,12 +591,12 @@ class AnalysisSuite extends AnalysisTest with Matchers {

checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20))
checkPartitioning[HashPartitioning](numPartitions = 10,
exprs = Symbol("a").attr, Symbol("b").attr)
exprs = "a".attr, "b".attr)

checkPartitioning[RangePartitioning](numPartitions = 10,
exprs = SortOrder(Literal(10), Ascending))
checkPartitioning[RangePartitioning](numPartitions = 10,
exprs = SortOrder(Symbol("a").attr, Ascending), SortOrder(Symbol("b").attr, Descending))
exprs = SortOrder("a".attr, Ascending), SortOrder("b".attr, Descending))

checkPartitioning[RoundRobinPartitioning](numPartitions = 10, exprs = Seq.empty: _*)

Expand All @@ -608,7 +608,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}
intercept[IllegalArgumentException] {
checkPartitioning(numPartitions = 10, exprs =
SortOrder(Symbol("a").attr, Ascending), Symbol("b").attr)
SortOrder("a".attr, Ascending), "b".attr)
}
}

Expand Down Expand Up @@ -779,7 +779,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
"Multiple definitions of observed metrics" :: "evt1" :: Nil)

// Different children, same metrics - fail
val b = Symbol("b").string
val b = "b".attr.string
val tblB = LocalRelation(b)
assertAnalysisError(Union(
CollectMetrics("evt1", count :: Nil, testRelation) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical._

class DSLHintSuite extends AnalysisTest {
lazy val a = Symbol("a").int
lazy val b = Symbol("b").string
lazy val c = Symbol("c").string
lazy val a = "a".attr.int
lazy val b = "b".attr.string
lazy val c = "c".attr.string
lazy val r1 = LocalRelation(a, b, c)

test("various hint parameters") {
Expand Down
Loading