diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 911c5c7f646f..28fb64f7cd0e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.types; import java.io.Serializable; -import java.util.Locale; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -73,45 +72,53 @@ private static long toLong(String s) { * This method is case-insensitive. */ public static CalendarInterval fromString(String s) { - if (s == null) { - return null; - } - s = s.trim(); - Matcher m = p.matcher(s); - if (!m.matches() || s.compareToIgnoreCase("interval") == 0) { + try { + return fromCaseInsensitiveString(s); + } catch (IllegalArgumentException e) { return null; - } else { - long months = toLong(m.group(1)) * 12 + toLong(m.group(2)); - long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK; - microseconds += toLong(m.group(4)) * MICROS_PER_DAY; - microseconds += toLong(m.group(5)) * MICROS_PER_HOUR; - microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE; - microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; - microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; - microseconds += toLong(m.group(9)); - return new CalendarInterval((int) months, microseconds); } } /** - * Convert a string to CalendarInterval. Unlike fromString, this method can handle + * Convert a string to CalendarInterval. This method can handle * strings without the `interval` prefix and throws IllegalArgumentException * when the input string is not a valid interval. * * @throws IllegalArgumentException if the string is not a valid internal. */ public static CalendarInterval fromCaseInsensitiveString(String s) { - if (s == null || s.trim().isEmpty()) { - throw new IllegalArgumentException("Interval cannot be null or blank."); + if (s == null) { + throw new IllegalArgumentException("Interval cannot be null"); } - String sInLowerCase = s.trim().toLowerCase(Locale.ROOT); - String interval = - sInLowerCase.startsWith("interval ") ? sInLowerCase : "interval " + sInLowerCase; - CalendarInterval cal = fromString(interval); - if (cal == null) { + String trimmed = s.trim(); + if (trimmed.isEmpty()) { + throw new IllegalArgumentException("Interval cannot be blank"); + } + String prefix = "interval"; + String intervalStr = trimmed; + // Checks the given interval string does not start with the `interval` prefix + if (!intervalStr.regionMatches(true, 0, prefix, 0, prefix.length())) { + // Prepend `interval` if it does not present because + // the regular expression strictly require it. + intervalStr = prefix + " " + trimmed; + } else if (intervalStr.length() == prefix.length()) { + throw new IllegalArgumentException("Interval string must have time units"); + } + + Matcher m = p.matcher(intervalStr); + if (!m.matches()) { throw new IllegalArgumentException("Invalid interval: " + s); } - return cal; + + long months = toLong(m.group(1)) * 12 + toLong(m.group(2)); + long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK; + microseconds += toLong(m.group(4)) * MICROS_PER_DAY; + microseconds += toLong(m.group(5)) * MICROS_PER_HOUR; + microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE; + microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; + microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; + microseconds += toLong(m.group(9)); + return new CalendarInterval((int) months, microseconds); } public static long toLongWithRange(String fieldName, diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 6ccc65f7d174..587071332ce4 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -19,6 +19,8 @@ import org.junit.Test; +import java.util.Arrays; + import static org.junit.Assert.*; import static org.apache.spark.unsafe.types.CalendarInterval.*; @@ -72,36 +74,26 @@ public void fromStringTest() { testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI); testSingleUnit("microsecond", 3, 0, 3); - String input; - - input = "interval -5 years 23 month"; CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); - assertEquals(fromString(input), result); - - input = "interval -5 years 23 month "; - assertEquals(fromString(input), result); - - input = " interval -5 years 23 month "; - assertEquals(fromString(input), result); + Arrays.asList( + "interval -5 years 23 month", + " -5 years 23 month", + "interval -5 years 23 month ", + " -5 years 23 month ", + " interval -5 years 23 month ").forEach(input -> + assertEquals(fromString(input), result) + ); // Error cases - input = "interval 3month 1 hour"; - assertNull(fromString(input)); - - input = "interval 3 moth 1 hour"; - assertNull(fromString(input)); - - input = "interval"; - assertNull(fromString(input)); - - input = "int"; - assertNull(fromString(input)); - - input = ""; - assertNull(fromString(input)); - - input = null; - assertNull(fromString(input)); + Arrays.asList( + "interval 3month 1 hour", + "3month 1 hour", + "interval 3 moth 1 hour", + "3 moth 1 hour", + "interval", + "int", + "", + null).forEach(input -> assertNull(fromString(input))); } @Test @@ -115,7 +107,9 @@ public void fromCaseInsensitiveStringTest() { fromCaseInsensitiveString(input); fail("Expected to throw an exception for the invalid input"); } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("cannot be null or blank")); + String msg = e.getMessage(); + if (input == null) assertTrue(msg.contains("cannot be null")); + else assertTrue(msg.contains("cannot be blank")); } } @@ -124,7 +118,12 @@ public void fromCaseInsensitiveStringTest() { fromCaseInsensitiveString(input); fail("Expected to throw an exception for the invalid input"); } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("Invalid interval")); + String msg = e.getMessage(); + if (input.trim().equalsIgnoreCase("interval")) { + assertTrue(msg.contains("Interval string must have time units")); + } else { + assertTrue(msg.contains("Invalid interval:")); + } } } } @@ -268,11 +267,13 @@ public void subtractTest() { } private static void testSingleUnit(String unit, int number, int months, long microseconds) { - String input1 = "interval " + number + " " + unit; - String input2 = "interval " + number + " " + unit + "s"; - CalendarInterval result = new CalendarInterval(months, microseconds); - assertEquals(fromString(input1), result); - assertEquals(fromString(input2), result); + Arrays.asList("interval ", "").forEach(prefix -> { + String input1 = prefix + number + " " + unit; + String input2 = prefix + number + " " + unit + "s"; + CalendarInterval result = new CalendarInterval(months, microseconds); + assertEquals(fromString(input1), result); + assertEquals(fromString(input2), result); + }); } @Test diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 96ef3a558b85..fc7a0d3af4e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -672,6 +672,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { "interval 1 years 3 months -3 days") checkEvaluation(Cast(Literal("INTERVAL 1 Second 1 microsecond"), CalendarIntervalType), new CalendarInterval(0, 1000001)) + checkEvaluation(Cast(Literal("1 MONTH 1 Microsecond"), CalendarIntervalType), + new CalendarInterval(1, 1)) } test("cast string to boolean") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 5da2bf059758..c2e80c639f43 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -432,8 +432,9 @@ class ExpressionParserSuite extends AnalysisTest { intercept("timestamP '2016-33-11 20:54:00.000'") // Interval. - assertEqual("InterVal 'interval 3 month 1 hour'", - Literal(CalendarInterval.fromString("interval 3 month 1 hour"))) + val intervalLiteral = Literal(CalendarInterval.fromString("interval 3 month 1 hour")) + assertEqual("InterVal 'interval 3 month 1 hour'", intervalLiteral) + assertEqual("INTERVAL '3 month 1 hour'", intervalLiteral) assertEqual("Interval 'interval 3 monthsss 1 hoursss'", Literal(null, CalendarIntervalType)) diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index e1e8d685e878..aef23963da37 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -435,6 +435,6 @@ interval 3 years 1 hours -- !query 45 select interval '3 year 1 hour' -- !query 45 schema -struct +struct -- !query 45 output -NULL +interval 3 years 1 hours