From 11ac4438e12dc01ba252304da8793077280f3067 Mon Sep 17 00:00:00 2001 From: Nathan Howell Date: Wed, 7 Dec 2016 23:32:14 +0000 Subject: [PATCH] [SPARK-18772][SQL] NaN/Infinite float parsing in JSON is inconsistent --- .../sql/catalyst/json/JacksonParser.scala | 39 +++++++++---------- .../datasources/json/JsonSuite.scala | 33 ++++++++++++++++ 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index e476cb11a351..893a9d1bf745 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -155,6 +155,17 @@ class JacksonParser( case _ => makeConverter(dataType) } + private object SpecialDouble { + def unapply(value: String): Option[Double] = { + value.toLowerCase match { + case "nan" => Some(Double.NaN) + case "infinity" | "+infinity" | "inf" | "+inf" => Some(Double.PositiveInfinity) + case "-infinity" | "-inf" => Some(Double.NegativeInfinity) + case _ => None + } + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. @@ -193,16 +204,10 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. - val value = parser.getText - val lowerCaseValue = value.toLowerCase - if (lowerCaseValue.equals("nan") || - lowerCaseValue.equals("infinity") || - lowerCaseValue.equals("-infinity") || - lowerCaseValue.equals("inf") || - lowerCaseValue.equals("-inf")) { - value.toFloat - } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") + parser.getText match { + case SpecialDouble(value) => value.toFloat + case _ => throw new SparkSQLJsonProcessingException( + s"Cannot parse ${parser.getText} as FloatType.") } } @@ -213,16 +218,10 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. - val value = parser.getText - val lowerCaseValue = value.toLowerCase - if (lowerCaseValue.equals("nan") || - lowerCaseValue.equals("infinity") || - lowerCaseValue.equals("-infinity") || - lowerCaseValue.equals("inf") || - lowerCaseValue.equals("-inf")) { - value.toDouble - } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") + parser.getText match { + case SpecialDouble(value) => value + case _ => throw new SparkSQLJsonProcessingException( + s"Cannot parse ${parser.getText} as DoubleType.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 598e44ec8c19..3a516de15c32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1764,4 +1764,37 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val df2 = spark.read.option("PREfersdecimaL", "true").json(records) assert(df2.schema == schema) } + + test("SPARK-18772: Special floats") { + val records = sparkContext + .parallelize( + """{"a": "NaN"}""" :: + """{"a": "nAn"}""" :: + """{"a": "-iNf"}""" :: + """{"a": "inF"}""" :: + """{"a": "+Inf"}""" :: + """{"a": "-iNfInity"}""" :: + """{"a": "InFiNiTy"}""" :: + """{"a": "+InfiNitY"}""" :: + """{"a": "+Infi"}""" :: + Nil) + + for (dt <- Seq(FloatType, DoubleType)) { + val res = spark.read + .schema(StructType(Seq(StructField("a", dt)))) + .json(records) + .select($"a".cast(DoubleType).as[java.lang.Double]) + .collect() + assert(res.length === 9) + assert(res(0).isNaN) + assert(res(1).isNaN) + assert(res(2).toDouble.isNegInfinity) + assert(res(3).toDouble.isPosInfinity) + assert(res(4).toDouble.isPosInfinity) + assert(res(5).toDouble.isNegInfinity) + assert(res(6).toDouble.isPosInfinity) + assert(res(7).toDouble.isPosInfinity) + assert(res(8) eq null) + } + } }