Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
* @since 3.0.0
*/
@Unstable
public final class CalendarInterval implements Serializable {
public final class CalendarInterval implements Serializable, Comparable<CalendarInterval> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Please check the behavior
#27262
I'm not sure. @yaooqinn

// NOTE: If you're moving or renaming this file, you should also update Unidoc configuration
// specified in 'SparkBuild.scala'.
public final int months;
Expand Down Expand Up @@ -127,4 +127,15 @@ private void appendUnit(StringBuilder sb, long value, String unit) {
* @throws ArithmeticException if a numeric overflow occurs
*/
public Duration extractAsDuration() { return Duration.of(microseconds, ChronoUnit.MICROS); }

@Override
public int compareTo(CalendarInterval o) {
if (this.months != o.months) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing intervals does not necessarily short circuits via months. We could result in 1 month > 0 months 32 days, which is wrong, obviously.

Besides, 1 month can be 28 ~ 30 days, making the legacy calendar interval type uncomparable

Copy link
Contributor

@cloud-fan cloud-fan Jan 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add some comments to explain that this is alphabet ordering. It does not have actual meaning but just makes it possible to find identical interval instances.

We should do the same thing for map type so that we can group by map values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stefankandic did you generate this using IDEA?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the comments.

@cloud-fan method was generated by intellij but I implemented the logic

return Integer.compare(this.months, o.months);
} else if (this.days != o.days) {
return Integer.compare(this.days, o.days);
} else {
return Long.compare(this.microseconds, o.microseconds);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ public void toStringTest() {
i.toString());
}

@Test
public void compareToTest() {
CalendarInterval i = new CalendarInterval(0, 0, 0);

assertEquals(i.compareTo(new CalendarInterval(0, 0, 0)), 0);
assertEquals(i.compareTo(new CalendarInterval(0, 0, 1)), -1);
assertEquals(i.compareTo(new CalendarInterval(0, 1, 0)), -1);
assertEquals(i.compareTo(new CalendarInterval(0, 1, -1)), -1);
assertEquals(i.compareTo(new CalendarInterval(1, 0, 0)), -1);
assertEquals(i.compareTo(new CalendarInterval(1, 0, -1)), -1);
assertEquals(i.compareTo(new CalendarInterval(0, 0, -1)), 1);
assertEquals(i.compareTo(new CalendarInterval(0, -1, 0)), 1);
assertEquals(i.compareTo(new CalendarInterval(-1, 0, 0)), 1);
assertEquals(i.compareTo(new CalendarInterval(-1, 0, 1)), 1);
}

@Test
public void periodAndDurationTest() {
CalendarInterval interval = new CalendarInterval(120, -40, 123456);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
import org.apache.spark.sql.types.{CalendarIntervalType, DataType, MapType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

object ExprUtils extends QueryErrorsBase {
Expand Down Expand Up @@ -193,8 +193,8 @@ object ExprUtils extends QueryErrorsBase {
messageParameters = Map("sqlExpr" -> expr.sql))
}

// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
// Check if the data type of expr can be used in group by
if (!canBeUsedInGroupBy(expr.dataType)) {
expr.failAnalysis(
errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
messageParameters = Map(
Expand All @@ -217,4 +217,12 @@ object ExprUtils extends QueryErrorsBase {
a.groupingExpressions.foreach(checkValidGroupingExprs)
a.aggregateExpressions.foreach(checkValidAggregateExpression)
}

/**
* Returns whether the data type can be used in group by
*/
def canBeUsedInGroupBy(dt: DataType): Boolean = dt match {
case CalendarIntervalType => true
case _ => RowOrdering.isOrderable(dt)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ class CodegenContext extends Logging {
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case CalendarIntervalType => s"$c1.equals($c2)"
case NullType => "false"
case _ =>
throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError(
Expand All @@ -652,6 +653,7 @@ class CodegenContext extends Logging {
// use c1 - c2 may overflow
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)"
case CalendarIntervalType => s"$c1.compareTo($c2)"
case NullType => "0"
case array: ArrayType =>
val elementType = array.elementType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalNumericType}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.types._
Expand All @@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
object TypeUtils extends QueryErrorsBase {

def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
if (RowOrdering.isOrderable(dt)) {
if (ExprUtils.canBeUsedInGroupBy(dt)) {
TypeCheckResult.TypeCheckSuccess
} else {
DataTypeMismatch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ abstract class HashMapGenerator(
"""
}
case StringType => hashBytes(s"$input.getBytes()")
case CalendarIntervalType => hashInt(s"$input.hashCode()")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.tags.SlowHiveTest
import org.apache.spark.unsafe.types.CalendarInterval

@SlowHiveTest
class ObjectHashAggregateSuite
Expand Down Expand Up @@ -457,4 +458,31 @@ class ObjectHashAggregateSuite
)
}
}

test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
withSQLConf(
SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"
) {
val numRows = 50

assert(
(1 to numRows)
.map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
.toDF("c0")
.groupBy("c0")
.agg(count("*"))
.count() == 1
)

assert(
(1 to numRows)
.map(i => Tuple1(new CalendarInterval(i, i, i)))
.toDF("c0")
.groupBy("c0")
.agg(count("*"))
.count() == numRows
)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.tags.SlowHiveTest
import org.apache.spark.unsafe.types.CalendarInterval

@SlowHiveTest
class SortAggregateSuite
extends QueryTest
with SQLTestUtils
with TestHiveSingleton
with ExpressionEvalHelper
with AdaptiveSparkPlanHelper {

import testImplicits._

test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
// forces the use of sort aggregate by using min/max functions

val numRows = 50
val numRepeat = 25

val df = (0 to numRows)
.map(i => Tuple1(new CalendarInterval(i, i, i)))
.toDF("c0")

for (_ <- 0 until numRepeat) {
val shuffledDf = df.orderBy(rand())

checkAnswer(
shuffledDf.agg(max("c0")),
Row(new CalendarInterval(numRows, numRows, numRows))
)

checkAnswer(
shuffledDf.agg(min("c0")),
Row(new CalendarInterval(0, 0, 0))
)
}
}
}