1 change: 0 additions & 1 deletion R/pkg/NAMESPACE
@@ -313,7 +313,6 @@ exportMethods("%<=>%",
"lower",
"lpad",
"ltrim",
"map_entries",
"map_from_arrays",
"map_keys",
"map_values",
15 changes: 1 addition & 14 deletions R/pkg/R/functions.R
@@ -219,7 +219,7 @@ NULL
#' head(select(tmp, sort_array(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
#' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3)))
#' head(select(tmp3, map_keys(tmp3$v3), map_values(tmp3$v3)))
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
@@ -3252,19 +3252,6 @@ setMethod("flatten",
column(jc)
})

#' @details
#' \code{map_entries}: Returns an unordered array of all entries in the given map.
#'
#' @rdname column_collection_functions
#' @aliases map_entries map_entries,Column-method
#' @note map_entries since 2.4.0
setMethod("map_entries",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc)
column(jc)
})

#' @details
#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for
#' keys. The array in the second column is used for values. All elements in the array for key
4 changes: 0 additions & 4 deletions R/pkg/R/generics.R
@@ -1076,10 +1076,6 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") })
#' @name NULL
setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_entries", function(x) { standardGeneric("map_entries") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") })
7 changes: 1 addition & 6 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1570,13 +1570,8 @@ test_that("column functions", {
result <- collect(select(df, flatten(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L)))

# Test map_entries(), map_keys(), map_values() and element_at()
# Test map_keys(), map_values() and element_at()
df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
result <- collect(select(df, map_entries(df$map)))[[1]]
expected_entries <- list(listToStruct(list(key = "x", value = 1)),
listToStruct(list(key = "y", value = 2)))
expect_equal(result, list(expected_entries))

result <- collect(select(df, map_keys(df$map)))[[1]]
expect_equal(result, list(list("x", "y")))

20 changes: 0 additions & 20 deletions python/pyspark/sql/functions.py
@@ -2540,26 +2540,6 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))


@since(2.4)
def map_entries(col):
"""
Collection function: Returns an unordered array of all entries in the given map.

:param col: name of column or expression

>>> from pyspark.sql.functions import map_entries
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_entries("data").alias("entries")).show()
+----------------+
| entries|
+----------------+
|[[1, a], [2, b]]|
+----------------+
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.map_entries(_to_java_column(col)))


@since(2.4)
def map_from_entries(col):
"""
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -414,7 +414,6 @@ object FunctionRegistry {
expression[MapFromArrays]("map_from_arrays"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[MapFromEntries]("map_from_entries"),
expression[MapConcat]("map_concat"),
expression[Size]("size"),
@@ -433,13 +432,9 @@
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[TransformValues]("transform_values"),
expression[TransformKeys]("transform_keys"),
expression[MapZipWith]("map_zip_with"),
expression[ZipWith]("zip_with"),

CreateStruct.registryEntry,
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -54,7 +54,6 @@ object TypeCoercion {
BooleanEquality ::
FunctionArgumentConversion ::
ConcatCoercion(conf) ::
MapZipWithCoercion ::
EltCoercion(conf) ::
CaseWhenCoercion ::
IfCoercion ::
@@ -763,30 +762,6 @@
}
}

/**
* Coerces key types of two different [[MapType]] arguments of the [[MapZipWith]] expression
* to a common type.
*/
object MapZipWithCoercion extends TypeCoercionRule {
override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Lambda function isn't resolved when the rule is executed.
case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved &&
MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) =>
findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match {
case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) &&
!Cast.forceNullable(m.rightKeyType, finalKeyType) =>
val newLeft = castIfNotSameType(
left,
MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull))
val newRight = castIfNotSameType(
right,
MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull))
MapZipWith(newLeft, newRight, function)
case _ => m
}
}
}
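
For context, a minimal spark-shell sketch (assuming a SparkSession named spark, on a Spark build that still ships map_zip_with) of the behaviour this rule provided; the query below only analyzes because MapZipWithCoercion widens the integer keys of the first map so they line up with the 1.0D key of the second map:

// Sketch only: the rule casts the int keys of the first map to double so both maps
// share a key type, and the query returns map(1.0 -> "ab").
spark.sql(
  "SELECT map_zip_with(map(1, 'a'), map(1.0D, 'b'), (k, v1, v2) -> concat(v1, v2))"
).show(false)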

/**
* Coerces the types of [[Elt]] children to expected ones.
*
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -340,174 +340,6 @@ case class MapValues(child: Expression)
override def prettyName: String = "map_values"
}

/**
* Returns an unordered array of all entries in the given map.
*/
@ExpressionDescription(
usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
[{"key":1,"value":"a"},{"key":2,"value":"b"}]
""",
since = "2.4.0")
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

@transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]

override def dataType: DataType = {
ArrayType(
StructType(
StructField("key", childDataType.keyType, false) ::
StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
Nil),
false)
}

override protected def nullSafeEval(input: Any): Any = {
val childMap = input.asInstanceOf[MapData]
val keys = childMap.keyArray()
val values = childMap.valueArray()
val length = childMap.numElements()
val resultData = new Array[AnyRef](length)
var i = 0
while (i < length) {
val key = keys.get(i, childDataType.keyType)
val value = values.get(i, childDataType.valueType)
val row = new GenericInternalRow(Array[Any](key, value))
resultData.update(i, row)
i += 1
}
new GenericArrayData(resultData)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val arrayData = ctx.freshName("arrayData")
val numElements = ctx.freshName("numElements")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)

val wordSize = UnsafeRow.WORD_SIZE
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
val (isPrimitive, elementSize) = if (isKeyPrimitive && isValuePrimitive) {
(true, structSize + wordSize)
} else {
(false, -1)
}

val allocation =
s"""
|ArrayData $arrayData = ArrayData.allocateArrayData(
| $elementSize, $numElements, " $prettyName failed.");
""".stripMargin

val code = if (isPrimitive) {
val genCodeForPrimitive = genCodeForPrimitiveElements(
ctx, arrayData, keys, values, ev.value, numElements, structSize)
s"""
|if ($arrayData instanceof UnsafeArrayData) {
| $genCodeForPrimitive
|} else {
| ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}
|}
""".stripMargin
} else {
s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}"
}

s"""
|final int $numElements = $c.numElements();
|final ArrayData $keys = $c.keyArray();
|final ArrayData $values = $c.valueArray();
|$allocation
|$code
""".stripMargin
})
}

private def getKey(varName: String, index: String) =
CodeGenerator.getValue(varName, childDataType.keyType, index)

private def getValue(varName: String, index: String) =
CodeGenerator.getValue(varName, childDataType.valueType, index)

private def genCodeForPrimitiveElements(
ctx: CodegenContext,
arrayData: String,
keys: String,
values: String,
resultArrayData: String,
numElements: String,
structSize: Int): String = {
val unsafeArrayData = ctx.freshName("unsafeArrayData")
val baseObject = ctx.freshName("baseObject")
val unsafeRow = ctx.freshName("unsafeRow")
val structsOffset = ctx.freshName("structsOffset")
val offset = ctx.freshName("offset")
val z = ctx.freshName("z")
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val wordSize = UnsafeRow.WORD_SIZE
val structSizeAsLong = s"${structSize}L"

val setKey = CodeGenerator.setColumn(unsafeRow, childDataType.keyType, 0, getKey(keys, z))

val valueAssignmentChecked = CodeGenerator.createArrayAssignment(
unsafeRow, childDataType.valueType, values, "1", z, childDataType.valueContainsNull)

s"""
|UnsafeArrayData $unsafeArrayData = (UnsafeArrayData)$arrayData;
|Object $baseObject = $unsafeArrayData.getBaseObject();
|final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
|UnsafeRow $unsafeRow = new UnsafeRow(2);
|for (int $z = 0; $z < $numElements; $z++) {
| long $offset = $structsOffset + $z * $structSizeAsLong;
| $unsafeArrayData.setLong($z, ($offset << 32) + $structSizeAsLong);
| $unsafeRow.pointTo($baseObject, $baseOffset + $offset, $structSize);
| $setKey;
| $valueAssignmentChecked
|}
|$resultArrayData = $arrayData;
""".stripMargin
}

private def genCodeForAnyElements(
ctx: CodegenContext,
arrayData: String,
keys: String,
values: String,
resultArrayData: String,
numElements: String): String = {
val z = ctx.freshName("z")
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
s"$values.isNullAt($z) ? null : (Object)${getValue(values, z)}"
} else {
getValue(values, z)
}

val rowClass = classOf[GenericInternalRow].getName
val genericArrayDataClass = classOf[GenericArrayData].getName
val genericArrayData = ctx.freshName("genericArrayData")
val rowObject = s"new $rowClass(new Object[]{${getKey(keys, z)}, $getValueWithCheck})"
s"""
|$genericArrayDataClass $genericArrayData = ($genericArrayDataClass)$arrayData;
|for (int $z = 0; $z < $numElements; $z++) {
| $genericArrayData.update($z, $rowObject);
|}
|$resultArrayData = $arrayData;
""".stripMargin
}

override def prettyName: String = "map_entries"
}
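
For reference, a hedged spark-shell sketch (again assuming a SparkSession named spark, on a build that still ships the function) of what the removed expression computes: one key/value struct per map entry, as nullSafeEval and the generated code above both produce:

// Sketch only: map_entries is backed by the MapEntries expression above; the result
// type is array<struct<key, value>> with the key field always non-nullable.
val df = spark.sql("SELECT map('x', 1, 'y', 2) AS m")
df.selectExpr("map_entries(m) AS entries").show(false)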

/**
* Returns the union of all the given maps.
*/