Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
val castedLit = lit.dataType match {
case CalendarIntervalType =>
val calendarInterval = lit.value.asInstanceOf[CalendarInterval]
if (calendarInterval.months > 0) {
if (calendarInterval.months != 0) {
invalid = true
logWarning(
s"Failed to extract state value watermark from condition $exprToCollectFrom " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ object TimeWindow {
*/
private def getIntervalInMicroSeconds(interval: String): Long = {
val cal = CalendarInterval.fromCaseInsensitiveString(interval)
if (cal.months > 0) {
if (cal.months != 0) {
throw new IllegalArgumentException(
s"Intervals greater than a month is not supported ($interval).")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types.MetadataBuilder
import org.apache.spark.unsafe.types.CalendarInterval

Expand All @@ -29,8 +30,7 @@ object EventTimeWatermark {

def getDelayMs(delay: CalendarInterval): Long = {
// We define month as `31 days` to simplify calculation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment can be removed, as no 31 can be seen here.

val millisPerMonth = TimeUnit.MICROSECONDS.toMillis(CalendarInterval.MICROS_PER_DAY) * 31
delay.milliseconds + delay.months * millisPerMonth
IntervalUtils.getDuration(delay, 31, TimeUnit.MILLISECONDS)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.util

import java.util.concurrent.TimeUnit

import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -88,4 +90,32 @@ object IntervalUtils {
result += MICROS_PER_MONTH * (interval.months % MONTHS_PER_YEAR)
Decimal(result, 18, 6)
}

/**
* Gets interval duration
*
* @param cal - the interval to get duration
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: seems "interval" is a better name.

* @param daysPerMonth - the number of days per one month
* @param targetUnit - time units of the result
* @return duration in the specified time units
*/
def getDuration(
cal: CalendarInterval,
daysPerMonth: Int,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could also be an optional final argument and default to 31. For now, when it's always 31, maybe that's simpler.

targetUnit: TimeUnit): Long = {
val monthsDuration = Math.multiplyExact(daysPerMonth * DateTimeUtils.MICROS_PER_DAY, cal.months)
val result = Math.addExact(cal.microseconds, monthsDuration)
targetUnit.convert(result, TimeUnit.MICROSECONDS)
}

/**
* Checks the interval is negative
*
* @param cal - the checked interval
* @param daysPerMonth - the number of days per one month
* @return true if duration of the given interval is less than 0 otherwise false
*/
def isNegative(cal: CalendarInterval, daysPerMonth: Int): Boolean = {
getDuration(cal, daysPerMonth, TimeUnit.MICROSECONDS) < 0
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import java.util.concurrent.TimeUnit

import org.scalatest.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.unsafe.types.CalendarInterval

class IntervalUtilsSuite extends SparkFunSuite with Matchers {
test("interval duration") {
def duration(s: String, daysPerMonth: Int, unit: TimeUnit): Long = {
getDuration(CalendarInterval.fromString(s), daysPerMonth, unit)
}

assert(duration("0 seconds", 31, TimeUnit.MILLISECONDS) === 0)
assert(duration("1 month", 31, TimeUnit.DAYS) === 31)
assert(duration("1 microsecond", 30, TimeUnit.MICROSECONDS) === 1)
assert(duration("1 month -30 days", 31, TimeUnit.DAYS) === 1)

try {
duration(Integer.MAX_VALUE + " month", 31, TimeUnit.SECONDS)
fail("Expected to throw an exception for the invalid input")
} catch {
case e: ArithmeticException =>
assert(e.getMessage.contains("overflow"))
}
}

test("negative interval") {
def isNegative(s: String, daysPerMonth: Int): Boolean = {
IntervalUtils.isNegative(CalendarInterval.fromString(s), daysPerMonth)
}

assert(isNegative("-1 months", 28))
assert(isNegative("-1 microsecond", 30))
assert(isNegative("-1 month 30 days", 31))
assert(isNegative("2 months -61 days", 30))
assert(isNegative("-1 year -2 seconds", 30))
assert(!isNegative("0 months", 28))
assert(!isNegative("1 year -360 days", 31))
assert(!isNegative("-1 year 380 days", 31))

}
}
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
import org.apache.spark.sql.execution.command._
Expand Down Expand Up @@ -731,7 +732,7 @@ class Dataset[T] private[sql](
s"Unable to parse time delay '$delayThreshold'",
cause = Some(e))
}
require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0,
require(!IntervalUtils.isNegative(parsedDelay, 31),
s"delay threshold ($delayThreshold) should not be negative.")
EliminateEventTimeWatermark(
EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.sql.Date
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.execution.streaming.GroupStateImpl._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
import org.apache.spark.unsafe.types.CalendarInterval
Expand Down Expand Up @@ -160,12 +161,12 @@ private[sql] class GroupStateImpl[S] private(

private def parseDuration(duration: String): Long = {
val cal = CalendarInterval.fromCaseInsensitiveString(duration)
if (cal.milliseconds < 0 || cal.months < 0) {
throw new IllegalArgumentException(s"Provided duration ($duration) is not positive")
val daysPerMonth = 31
if (IntervalUtils.isNegative(cal, daysPerMonth)) {
throw new IllegalArgumentException(s"Provided duration ($duration) is negative")
}

val millisPerMonth = TimeUnit.MICROSECONDS.toMillis(CalendarInterval.MICROS_PER_DAY) * 31
cal.milliseconds + cal.months * millisPerMonth
IntervalUtils.getDuration(cal, daysPerMonth, TimeUnit.MILLISECONDS)
}

private def checkTimeoutTimestampAllowed(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ private object Triggers {

def convert(interval: String): Long = {
val cal = CalendarInterval.fromCaseInsensitiveString(interval)
if (cal.months > 0) {
if (cal.months != 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like another way of converting interval to duration: make sure the months field is 0. Shall we also take it into account in the new getDuration method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can change getDuration() to:

  def getDuration(
      interval: CalendarInterval,
      targetUnit: TimeUnit,
      daysPerMonth: Option[Int] = Some(31)): Long = {
    val monthsDuration = daysPerMonth
      .map { days =>
        Math.multiplyExact(days * DateTimeUtils.MICROS_PER_DAY, interval.months)
      }.getOrElse {
        if (interval.months == 0) {
          0L
        } else {
          throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
        }
      }
    val result = Math.addExact(interval.microseconds, monthsDuration)
    targetUnit.convert(result, TimeUnit.MICROSECONDS)
  }

and call getDuration(cal, TimeUnit.MILLISECONDS, None)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but I am not sure that this check should be inside of getDuration()

throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
}
TimeUnit.MICROSECONDS.toMillis(cal.microseconds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false)
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
state.setTimeoutDuration("-1 month 31 days 1 second")
assert(state.getTimeoutTimestamp === 2000)
state.setTimeoutDuration(500)
assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
Expand Down Expand Up @@ -225,8 +227,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
testIllegalTimeout {
state.setTimeoutDuration("-1 month")
}

testIllegalTimeout {
state.setTimeoutDuration("1 month -1 day")
state.setTimeoutDuration("1 month -31 day")
}

state = GroupStateImpl.createForStreaming(
Expand All @@ -241,7 +244,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
state.setTimeoutTimestamp(10000, "-1 month")
}
testIllegalTimeout {
state.setTimeoutTimestamp(10000, "1 month -1 day")
state.setTimeoutTimestamp(10000, "1 month -32 day")
}
testIllegalTimeout {
state.setTimeoutTimestamp(new Date(-10000))
Expand All @@ -253,7 +256,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
state.setTimeoutTimestamp(new Date(-10000), "-1 month")
}
testIllegalTimeout {
state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day")
state.setTimeoutTimestamp(new Date(-10000), "1 month -32 day")
}
}

Expand Down