diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index a090bdf2bebf..27225b4ac74a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -888,6 +888,158 @@ case class MapFromEntries(child: Expression)
     copy(child = newChild)
 }
 
+/**
+ * Sorts the entries of a map by key, in the order given by the key type's physical ordering.
+ * The input must be a map whose key type is orderable (see checkInputDataTypes).
+ */
+case class MapSort(base: Expression)
+  extends UnaryExpression with NullIntolerant with QueryErrorsBase {
+
+  // Lazy on purpose: an eager cast would throw ClassCastException while the expression is
+  // constructed over a non-map child, before checkInputDataTypes() gets a chance to report it.
+  lazy val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
+  lazy val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType
+
+  override def child: Expression = base
+
+  override def dataType: DataType = base.dataType
+
+  /** Accepts only map inputs whose key type is orderable; anything else is a mismatch. */
+  override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
+    case m: MapType if RowOrdering.isOrderable(m.keyType) =>
+      TypeCheckResult.TypeCheckSuccess
+    case _: MapType =>
+      DataTypeMismatch(
+        errorSubClass = "INVALID_ORDERING_TYPE",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "dataType" -> toSQLType(base.dataType)
+        )
+      )
+    case _ =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> ordinalNumber(0),
+          "requiredType" -> toSQLType(MapType),
+          "inputSql" -> toSQLExpr(base),
+          "inputType" -> toSQLType(base.dataType))
+      )
+  }
+
+  /** Interpreted path: returns a new map holding the input's entries ordered by key. */
+  override def nullSafeEval(input: Any): Any = {
+    // Pair each key with its value, sort the pairs by key, then split them back into
+    // parallel key and value arrays.
+    val mapData = input.asInstanceOf[MapData]
+    val numElements = mapData.numElements()
+    val keys = mapData.keyArray()
+    val values = mapData.valueArray()
+
+    val ordering = PhysicalDataType.ordering(keyType)
+
+    val sortedMap = Array
+      .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any],
+        values.get(i, valueType).asInstanceOf[Any]))
+      .sortBy(_._1)(ordering)
+
+    new ArrayBasedMapData(new GenericArrayData(sortedMap.map(_._1)),
+      new GenericArrayData(sortedMap.map(_._2)))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, b => sortCodegen(ctx, ev, b))
+  }
+
+  /**
+   * Emits Java that copies the map entries into an array, sorts that array by key and
+   * rebuilds the map from the sorted entries, assigning the result to `ev.value`.
+   *
+   * @param base name of the generated variable holding the input map
+   */
+  private def sortCodegen(ctx: CodegenContext, ev: ExprCode,
+      base: String): String = {
+
+    val arrayBasedMapData = classOf[ArrayBasedMapData].getName
+    val genericArrayData = classOf[GenericArrayData].getName
+
+    val numElements = ctx.freshName("numElements")
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    val value = ctx.freshName("value")
+    val sortArray = ctx.freshName("sortArray")
+    val i = ctx.freshName("i")
+    val o1 = ctx.freshName("o1")
+    val o1entry = ctx.freshName("o1entry")
+    val o2 = ctx.freshName("o2")
+    val o2entry = ctx.freshName("o2entry")
+    val c = ctx.freshName("c")
+    val newKeys = ctx.freshName("newKeys")
+    val newValues = ctx.freshName("newValues")
+
+    val boxedKeyType = CodeGenerator.boxedType(keyType)
+    val boxedValueType = CodeGenerator.boxedType(valueType)
+    val javaKeyType = CodeGenerator.javaType(keyType)
+
+    val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, $boxedValueType>"
+
+    // Complete Java statements that declare `int $c` as the comparison of keys $o1 and $o2.
+    // Primitive keys are unboxed into locals before being handed to genComp.
+    val comp = if (CodeGenerator.isPrimitiveType(keyType)) {
+      val v1 = ctx.freshName("v1")
+      val v2 = ctx.freshName("v2")
+      s"""
+         |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value();
+         |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value();
+         |int $c = ${ctx.genComp(keyType, v1, v2)};
+       """.stripMargin
+    } else {
+      s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)", s"(($javaKeyType) $o2)")};"
+    }
+
+    // A value is read through the typed accessor only when its slot is non-null, so null map
+    // values stay null. NOTE(review): confirm null handling of CodeGenerator.getValue.
+    s"""
+       |final int $numElements = $base.numElements();
+       |ArrayData $keys = $base.keyArray();
+       |ArrayData $values = $base.valueArray();
+       |
+       |Object[] $sortArray = new Object[$numElements];
+       |
+       |for (int $i = 0; $i < $numElements; $i++) {
+       |  Object $value = null;
+       |  if (!$values.isNullAt($i)) {
+       |    $value = ${CodeGenerator.getValue(values, valueType, i)};
+       |  }
+       |  $sortArray[$i] = new $simpleEntryType(
+       |    ${CodeGenerator.getValue(keys, keyType, i)}, $value);
+       |}
+       |
+       |java.util.Arrays.sort($sortArray, new java.util.Comparator() {
+       |  @Override public int compare(Object $o1entry, Object $o2entry) {
+       |    Object $o1 = (($simpleEntryType) $o1entry).getKey();
+       |    Object $o2 = (($simpleEntryType) $o2entry).getKey();
+       |    $comp
+       |    return $c;
+       |  }
+       |});
+       |
+       |Object[] $newKeys = new Object[$numElements];
+       |Object[] $newValues = new Object[$numElements];
+       |
+       |for (int $i = 0; $i < $numElements; $i++) {
+       |  $newKeys[$i] = (($simpleEntryType) $sortArray[$i]).getKey();
+       |  $newValues[$i] = (($simpleEntryType) $sortArray[$i]).getValue();
+       |}
+       |
+       |${ev.value} = new $arrayBasedMapData(
+       |  new $genericArrayData($newKeys), new $genericArrayData($newValues));
+       |""".stripMargin
+  }
+
+  override protected def withNewChildInternal(newChild: Expression)
+    : MapSort = copy(base = newChild)
+}
 
 /**
  * Common base class for [[SortArray]] and [[ArraySort]].
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 133e27c5b0a6..d14118eb3f1d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -421,6 +421,31 @@ class CollectionExpressionsSuite
     )
   }
 
+  test("Sort Map") {
+    // One (unsorted input literal, expected evaluation result) pair per key type under test.
+    val cases: Seq[(Literal, Any)] = Seq(
+      (Literal.create(Map(2 -> 2, 1 -> 1, 3 -> 3), MapType(IntegerType, IntegerType)),
+        Map(1 -> 1, 2 -> 2, 3 -> 3)),
+      (Literal.create(Map(true -> 2, false -> 1), MapType(BooleanType, IntegerType)),
+        Map(false -> 1, true -> 2)),
+      (Literal.create(Map("2" -> 2, "1" -> 1, "3" -> 3), MapType(StringType, IntegerType)),
+        Map("1" -> 1, "2" -> 2, "3" -> 3)),
+      (Literal.create(Map(Seq(2) -> 2, Seq(1) -> 1, Seq(3) -> 3),
+        MapType(ArrayType(IntegerType), IntegerType)),
+        Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)),
+      (Literal.create(Map(Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1, Seq(Seq(3)) -> 3),
+        MapType(ArrayType(ArrayType(IntegerType)), IntegerType)),
+        Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3)),
+      (Literal.create(Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3),
+        MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType)),
+        Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3))
+    )
+
+    cases.foreach { case (unsorted, expected) =>
+      checkEvaluation(MapSort(unsorted), expected)
+    }
+  }
+
   test("Sort Array") {
     val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))