Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.language.existentials

import com.google.common.reflect.TypeToken

import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
Expand Down Expand Up @@ -280,15 +280,15 @@ object JavaTypeInference {
Invoke(
UnresolvedMapObjects(
p => deserializerFor(keyType, p),
GetKeyArrayFromMap(path)),
MapKeys(path)),
"array",
ObjectType(classOf[Array[Any]]))

val valueData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(valueType, p),
GetValueArrayFromMap(path)),
MapValues(path)),
"array",
ObjectType(classOf[Array[Any]]))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1788,78 +1788,3 @@ case class ValidateExternalType(child: Expression, expected: DataType)
ev.copy(code = code, isNull = input.isNull)
}
}

object GetKeyArrayFromMap {

/**
* Construct an instance of GetArrayFromMap case class
* extracting a key array from a Map expression.
*
* @param child a Map expression to extract a key array from
*/
def apply(child: Expression): Expression = {
GetArrayFromMap(
child,
"keyArray",
_.keyArray(),
{ case MapType(kt, _, _) => kt })
}
}

object GetValueArrayFromMap {

/**
* Construct an instance of GetArrayFromMap case class
* extracting a value array from a Map expression.
*
* @param child a Map expression to extract a value array from
*/
def apply(child: Expression): Expression = {
GetArrayFromMap(
child,
"valueArray",
_.valueArray(),
{ case MapType(_, vt, _) => vt })
}
}

/**
* Extracts a key/value array from a Map expression.
*
* @param child a Map expression to extract an array from
* @param functionName name of the function that is invoked to extract an array
* @param arrayGetter function extracting `ArrayData` from `MapData`
* @param elementTypeGetter function extracting array element `DataType` from `MapType`
*/
case class GetArrayFromMap private(
child: Expression,
functionName: String,
arrayGetter: MapData => ArrayData,
elementTypeGetter: MapType => DataType) extends UnaryExpression with NonSQLExpression {

private lazy val encodedFunctionName: String = TermName(functionName).encodedName.toString

lazy val dataType: DataType = {
val mt: MapType = child.dataType.asInstanceOf[MapType]
ArrayType(elementTypeGetter(mt))
}

override def checkInputDataTypes(): TypeCheckResult = {
if (child.dataType.isInstanceOf[MapType]) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"Can't extract array from $child: need map type but got ${child.dataType.catalogString}")
}
}

override def nullSafeEval(input: Any): Any = {
arrayGetter(input.asInstanceOf[MapData])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, childValue => s"$childValue.$encodedFunctionName()")
}

override def toString: String = s"$child.$functionName"
}