Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,10 @@ package object dsl {
AttributeReference(s, arrayType)()

/** Creates a new AttributeReference of type map */
def map(keyType: DataType, valueType: DataType): AttributeReference =
map(MapType(keyType, valueType))
def mapAttr(keyType: DataType, valueType: DataType): AttributeReference =
mapAttr(MapType(keyType, valueType))

def map(mapType: MapType): AttributeReference =
def mapAttr(mapType: MapType): AttributeReference =
AttributeReference(s, mapType, nullable = true)()

/** Creates a new AttributeReference of type struct */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,8 @@ class AnalysisErrorSuite extends AnalysisTest {
}

test("Join can work on binary types but can't work on map types") {
val left = LocalRelation(Symbol("a").binary, Symbol("b").map(StringType, StringType))
val right = LocalRelation(Symbol("c").binary, Symbol("d").map(StringType, StringType))
val left = LocalRelation(Symbol("a").binary, Symbol("b").mapAttr(StringType, StringType))
val right = LocalRelation(Symbol("c").binary, Symbol("d").mapAttr(StringType, StringType))

val plan1 = left.join(
right,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
Symbol("booleanField").boolean,
Symbol("decimalField").decimal(8, 0),
Symbol("arrayField").array(StringType),
Symbol("mapField").map(StringType, LongType))
Symbol("mapField").mapAttr(StringType, LongType))

def assertError(expr: Expression, errorMessage: String): Unit = {
val e = intercept[AnalysisException] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,38 +54,38 @@ class EncoderResolutionSuite extends PlanTest {
val encoder = ExpressionEncoder[StringLongClass]

// int type can be up cast to long type
val attrs1 = Seq('a.string, 'b.int)
val attrs1 = Seq("a".attr.string, "b".attr.int)
testFromRow(encoder, attrs1, InternalRow(str, 1))

// int type can be up cast to string type
val attrs2 = Seq('a.int, 'b.long)
val attrs2 = Seq("a".attr.int, "b".attr.long)
testFromRow(encoder, attrs2, InternalRow(1, 2L))
}

test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
val attrs = Seq("a".attr.int, "b".attr.struct("a".attr.int, "b".attr.long))
testFromRow(encoder, attrs, InternalRow(1, InternalRow(2, 3L)))
}

test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
val attrs = Seq("a".attr.struct("a".attr.string, "b".attr.byte), "b".attr.int)
testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2))
}

test("real type doesn't match encoder schema but they are compatible: primitive array") {
val encoder = ExpressionEncoder[PrimitiveArrayClass]
val attrs = Seq('arr.array(IntegerType))
val attrs = Seq("arr".attr.array(IntegerType))
val array = new GenericArrayData(Array(1, 2, 3))
testFromRow(encoder, attrs, InternalRow(array))
}

test("the real type is not compatible with encoder schema: primitive array") {
val encoder = ExpressionEncoder[PrimitiveArrayClass]
val attrs = Seq('arr.array(StringType))
val attrs = Seq("arr".attr.array(StringType))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
s"""
|Cannot up cast array element from string to bigint.
Expand All @@ -99,7 +99,8 @@ class EncoderResolutionSuite extends PlanTest {

test("real type doesn't match encoder schema but they are compatible: array") {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
val attrs =
Seq("arr".attr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
testFromRow(encoder, attrs, InternalRow(array))
}
Expand All @@ -116,14 +117,14 @@ class EncoderResolutionSuite extends PlanTest {

test("the real type is not compatible with encoder schema: non-array field") {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.int)
val attrs = Seq("arr".attr.int)
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"need an array field but got int")
}

test("the real type is not compatible with encoder schema: array element type") {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("c", "int")))
val attrs = Seq("arr".attr.array(new StructType().add("c", "int")))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"No such struct field a in c")
}
Expand All @@ -147,7 +148,7 @@ class EncoderResolutionSuite extends PlanTest {

test("nullability of array type element should not fail analysis") {
val encoder = ExpressionEncoder[Seq[Int]]
val attrs = 'a.array(IntegerType) :: Nil
val attrs = "a".attr.array(IntegerType) :: Nil

// It should pass analysis
val fromRow = encoder.resolveAndBind(attrs).createDeserializer()
Expand All @@ -166,14 +167,14 @@ class EncoderResolutionSuite extends PlanTest {
val encoder = ExpressionEncoder[(String, Long)]

{
val attrs = Seq('a.string, 'b.long, 'c.int)
val attrs = Seq("a".attr.string, "b".attr.long, "c".attr.int)
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
"but failed as the number of fields does not line up.")
}

{
val attrs = Seq('a.string)
val attrs = Seq("a".attr.string)
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"Try to map struct<a:string> to Tuple2, " +
"but failed as the number of fields does not line up.")
Expand All @@ -184,14 +185,15 @@ class EncoderResolutionSuite extends PlanTest {
val encoder = ExpressionEncoder[(String, (Long, String))]

{
val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
val attrs =
Seq("a".attr.string, "b".attr.struct("x".attr.long, "y".attr.string, "z".attr.int))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
"but failed as the number of fields does not line up.")
}

{
val attrs = Seq('a.string, 'b.struct('x.long))
val attrs = Seq("a".attr.string, "b".attr.struct("x".attr.long))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"Try to map struct<x:bigint> to Tuple2, " +
"but failed as the number of fields does not line up.")
Expand All @@ -200,14 +202,16 @@ class EncoderResolutionSuite extends PlanTest {

test("nested case class can have different number of fields from the real schema") {
val encoder = ExpressionEncoder[(String, StringIntClass)]
val attrs = Seq('a.string, 'b.struct('a.string, 'b.int, 'c.int))
val attrs =
Seq("a".attr.string, "b".attr.struct("a".attr.string, "b".attr.int, "c".attr.int))
encoder.resolveAndBind(attrs)
}

test("SPARK-28497: complex type is not compatible with string encoder schema") {
val encoder = ExpressionEncoder[String]

Seq('a.struct('x.long), 'a.array(StringType), 'a.map(StringType, StringType)).foreach { attr =>
Seq("a".attr.struct("x".attr.long), "a".attr.array(StringType),
"a".attr.mapAttr(StringType, StringType)).foreach { attr =>
val attrs = Seq(attr)
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
s"""
Expand All @@ -221,7 +225,7 @@ class EncoderResolutionSuite extends PlanTest {

test("throw exception if real type is not compatible with encoder schema") {
val msg1 = intercept[AnalysisException] {
ExpressionEncoder[StringIntClass].resolveAndBind(Seq('a.string, 'b.long))
ExpressionEncoder[StringIntClass].resolveAndBind(Seq("a".attr.string, "b".attr.long))
}.message
assert(msg1 ==
s"""
Expand All @@ -234,7 +238,8 @@ class EncoderResolutionSuite extends PlanTest {

val msg2 = intercept[AnalysisException] {
val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT)
ExpressionEncoder[ComplexClass].resolveAndBind(Seq('a.long, 'b.struct(structType)))
ExpressionEncoder[ComplexClass].resolveAndBind(
Seq("a".attr.long, "b".attr.struct(structType)))
}.message
assert(msg2 ==
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,16 +386,16 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {

test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
val c3 = 'c.int.at(2)
val c1 = "a".attr.int.at(0)
val c3 = "c".attr.int.at(2)
checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null))
}

test("CreateNamedStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
val c3 = 'c.int.at(2)
val c1 = "a".attr.int.at(0)
val c3 = "c".attr.int.at(2)
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row)
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")),
create_row(1, UTF8String.fromString("y")), row)
Expand All @@ -410,11 +410,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
ExtractValue(u.child, u.extraction, _ == _)
}

checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")),
checkEvaluation(
quickResolve("c".attr.mapAttr(MapType(StringType, StringType)).at(0).getItem("a")),
"b", create_row(Map("a" -> "b")))
checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)),
checkEvaluation(quickResolve("c".attr.array(StringType).at(0).getItem(1)),
"b", create_row(Seq("a", "b")))
checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")),
checkEvaluation(quickResolve("c".attr.struct("a".attr.int).at(0).getField("a")),
1, create_row(create_row(1)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// with dummy input, resolve the plan by the analyzer, and replace the dummy input
// with a literal for tests.
val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer)
val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType)))
val dummyInputPlan = LocalRelation("value".attr.mapAttr(MapType(IntegerType, StringType)))
val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan)

val analyzedPlan = SimpleAnalyzer.execute(plan)
Expand Down
Loading