Skip to content

Commit 6bf72bc

Browse files
committed
address comments
1 parent 3f880c3 commit 6bf72bc

File tree

2 files changed

+53
-46
lines changed

2 files changed

+53
-46
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,5 +200,5 @@ case class UnresolvedGetField(child: Expression, fieldExpr: Expression) extends
200200
override def eval(input: Row = null): EvaluatedType =
201201
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
202202

203-
override def toString: String = s"$child.getField($fieldExpr)"
203+
override def toString: String = s"$child[$fieldExpr]"
204204
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GetField.scala

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ object GetField {
2727
/**
2828
* Returns the resolved `GetField`. It will return one kind of concrete `GetField`,
2929
* depend on the type of `child` and `fieldExpr`.
30+
*
31+
* `child` | `fieldExpr` | concrete `GetField`
32+
* -------------------------------------------------------------
33+
* Struct | Literal String | SimpleStructGetField
34+
* Array[Struct] | Literal String | ArrayStructGetField
35+
* Array | Integral type | ArrayOrdinalGetField
36+
* Map | Any type | MapOrdinalGetField
3037
*/
3138
def apply(
3239
child: Expression,
@@ -45,22 +52,27 @@ object GetField {
4552
case (_: MapType, _) =>
4653
MapOrdinalGetField(child, fieldExpr)
4754
case (otherType, _) =>
48-
throw new AnalysisException(
49-
"GetField is not valid on child of type " +
50-
s"$otherType with fieldExpr of type ${fieldExpr.dataType}")
55+
val errorMsg = otherType match {
56+
case StructType(_) | ArrayType(StructType(_), _) =>
57+
s"Field name should be String Literal, but it's $fieldExpr"
58+
case _: ArrayType =>
59+
s"Array index should be integral type, but it's ${fieldExpr.dataType}"
60+
case other =>
61+
s"Can't get field on $child"
62+
}
63+
throw new AnalysisException(errorMsg)
5164
}
5265
}
5366

5467
def unapply(g: GetField): Option[(Expression, Expression)] = {
5568
g match {
56-
case _: StructGetField => Some((g.child, null))
5769
case o: OrdinalGetField => Some((o.child, o.ordinal))
58-
case _ => None
70+
case _ => Some((g.child, null))
5971
}
6072
}
6173

6274
/**
63-
* find the ordinal of StructField, report error if no desired field or over one
75+
* Find the ordinal of StructField, report error if no desired field or over one
6476
* desired fields are found.
6577
*/
6678
private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = {
@@ -84,51 +96,16 @@ trait GetField extends UnaryExpression {
8496
type EvaluatedType = Any
8597
}
8698

87-
abstract class StructGetField extends GetField {
88-
self: Product =>
89-
90-
def field: StructField
91-
92-
override def foldable: Boolean = child.foldable
93-
override def toString: String = s"$child.${field.name}"
94-
}
95-
96-
abstract class OrdinalGetField extends GetField {
97-
self: Product =>
98-
99-
def ordinal: Expression
100-
101-
/** `Null` is returned for invalid ordinals. */
102-
override def nullable: Boolean = true
103-
override def foldable: Boolean = child.foldable && ordinal.foldable
104-
override def toString: String = s"$child[$ordinal]"
105-
override def children: Seq[Expression] = child :: ordinal :: Nil
106-
107-
override def eval(input: Row): Any = {
108-
val value = child.eval(input)
109-
if (value == null) {
110-
null
111-
} else {
112-
val o = ordinal.eval(input)
113-
if (o == null) {
114-
null
115-
} else {
116-
evalNotNull(value, o)
117-
}
118-
}
119-
}
120-
121-
protected def evalNotNull(value: Any, ordinal: Any): Any
122-
}
123-
12499
/**
125100
* Returns the value of fields in the Struct `child`.
126101
*/
127102
case class SimpleStructGetField(child: Expression, field: StructField, ordinal: Int)
128-
extends StructGetField {
103+
extends GetField {
129104

130105
override def dataType: DataType = field.dataType
131106
override def nullable: Boolean = child.nullable || field.nullable
107+
override def foldable: Boolean = child.foldable
108+
override def toString: String = s"$child.${field.name}"
132109

133110
override def eval(input: Row): Any = {
134111
val baseValue = child.eval(input).asInstanceOf[Row]
@@ -143,10 +120,12 @@ case class ArrayStructGetField(
143120
child: Expression,
144121
field: StructField,
145122
ordinal: Int,
146-
containsNull: Boolean) extends StructGetField {
123+
containsNull: Boolean) extends GetField {
147124

148125
override def dataType: DataType = ArrayType(field.dataType, containsNull)
149126
override def nullable: Boolean = child.nullable
127+
override def foldable: Boolean = child.foldable
128+
override def toString: String = s"$child.${field.name}"
150129

151130
override def eval(input: Row): Any = {
152131
val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
@@ -158,6 +137,34 @@ case class ArrayStructGetField(
158137
}
159138
}
160139

140+
abstract class OrdinalGetField extends GetField {
141+
self: Product =>
142+
143+
def ordinal: Expression
144+
145+
/** `Null` is returned for invalid ordinals. */
146+
override def nullable: Boolean = true
147+
override def foldable: Boolean = child.foldable && ordinal.foldable
148+
override def toString: String = s"$child[$ordinal]"
149+
override def children: Seq[Expression] = child :: ordinal :: Nil
150+
151+
override def eval(input: Row): Any = {
152+
val value = child.eval(input)
153+
if (value == null) {
154+
null
155+
} else {
156+
val o = ordinal.eval(input)
157+
if (o == null) {
158+
null
159+
} else {
160+
evalNotNull(value, o)
161+
}
162+
}
163+
}
164+
165+
protected def evalNotNull(value: Any, ordinal: Any): Any
166+
}
167+
161168
/**
162169
* Returns the field at `ordinal` in the Array `child`
163170
*/

0 commit comments

Comments
 (0)