Commit 452d24a

merging cherry picked commit 21fde57… (#188)
* Merged cherry-picked commit 21fde57 from apache/spark master to support multi-line JSON parsing.
* Provided a single-argument JSONOptions constructor so that existing SnappyData code written for Spark 2.1 continues to work.
* Fixed a Scala style failure.
1 parent b3531f6 commit 452d24a

File tree: 17 files changed (+879, -254 lines)

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 34 additions & 0 deletions
@@ -152,6 +152,40 @@ public void writeTo(ByteBuffer buffer) {
     buffer.position(pos + numBytes);
   }
 
+  /**
+   * Returns a {@link ByteBuffer} wrapping the base object if it is a byte array
+   * or a copy of the data if the base object is not a byte array.
+   *
+   * Unlike getBytes this will not create a copy the array if this is a slice.
+   */
+  public @Nonnull ByteBuffer getByteBuffer() {
+    if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) {
+      final byte[] bytes = (byte[]) base;
+
+      // the offset includes an object header... this is only needed for unsafe copies
+      final long arrayOffset = offset - BYTE_ARRAY_OFFSET;
+
+      // verify that the offset and length points somewhere inside the byte array
+      // and that the offset can safely be truncated to a 32-bit integer
+      if ((long) bytes.length < arrayOffset + numBytes) {
+        throw new ArrayIndexOutOfBoundsException();
+      }
+
+      return ByteBuffer.wrap(bytes, (int) arrayOffset, numBytes);
+    } else {
+      return ByteBuffer.wrap(getBytes());
+    }
+  }
+
+  public void writeTo(OutputStream out) throws IOException {
+    final ByteBuffer bb = this.getByteBuffer();
+    assert(bb.hasArray());
+
+    // similar to Utils.writeByteBuffer but without the spark-core dependency
+    out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining());
+  }
+
   /**
    * Returns the number of bytes for a code point with the first byte as `b`
    * @param b The first byte of a code point
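
Note: the new getByteBuffer and writeTo(OutputStream) methods let callers stream a UTF8String without an extra copy when it is backed by a byte array. A minimal sketch of how they could be exercised (the sample string and ByteArrayOutputStream are illustrative, not part of the commit):

import java.io.ByteArrayOutputStream
import org.apache.spark.unsafe.types.UTF8String

// When the base object is a byte array, getByteBuffer wraps it directly instead of copying,
// and writeTo streams those bytes to the output.
val s = UTF8String.fromString("""{"name": "Michael"}""")
val out = new ByteArrayOutputStream()
s.writeTo(out)
assert(new String(out.toByteArray, "UTF-8") == """{"name": "Michael"}""")
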

core/src/main/scala/org/apache/spark/input/PortableDataStream.scala

Lines changed: 7 additions & 0 deletions
@@ -29,6 +29,7 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFil
 
 import org.apache.spark.internal.config
 import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Since
 
 /**
  * A general format for reading whole files in as streams, byte arrays,
@@ -175,6 +176,7 @@
   * Create a new DataInputStream from the split and context. The user of this method is responsible
   * for closing the stream after usage.
   */
+  @Since("1.2.0")
   def open(): DataInputStream = {
     val pathp = split.getPath(index)
     val fs = pathp.getFileSystem(conf)
@@ -184,6 +186,7 @@
   /**
   * Read the file as a byte array
   */
+  @Since("1.2.0")
   def toArray(): Array[Byte] = {
     val stream = open()
     try {
@@ -193,6 +196,10 @@
     }
   }
 
+  @Since("1.2.0")
   def getPath(): String = path
+
+  @Since("2.2.0")
+  def getConfiguration: Configuration = conf
 }
 
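
Note: the @Since annotations only document availability; the functional addition is getConfiguration, which exposes the Hadoop Configuration of the underlying split. A rough sketch of how it might be used, assuming an existing SparkContext named sc and an input directory of your own:

// sc.binaryFiles yields an RDD[(String, PortableDataStream)]; getConfiguration (new in this
// commit) returns the Hadoop Configuration the stream will use when opened.
val streams = sc.binaryFiles("/path/to/files")
streams.take(1).foreach { case (path, pds) =>
  val hadoopConf = pds.getConfiguration
  println(s"$path uses fs ${hadoopConf.get("fs.defaultFS")}")
}
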
python/pyspark/sql/readwriter.py

Lines changed: 15 additions & 6 deletions
@@ -158,11 +158,14 @@ def load(self, path=None, format=None, schema=None, **options):
     def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
              allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
              allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
-             mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None):
+             mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
+             timeZone=None, wholeFile=None):
+
         """
-        Loads a JSON file (`JSON Lines text format or newline-delimited JSON
-        <http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per
-        record) and returns the result as a :class`DataFrame`.
+        Loads a JSON file and returns the results as a :class:`DataFrame`.
+
+        Both JSON (one record per file) and `JSON Lines <http://jsonlines.org/>`_
+        (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter.
 
         If the ``schema`` parameter is not specified, this function goes
         through the input once to determine the input schema.
@@ -208,7 +211,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
         :param timestampFormat: sets the string that indicates a timestamp format. Custom date
                                 formats follow the formats at ``java.text.SimpleDateFormat``.
                                 This applies to timestamp type. If None is set, it uses the
-                                default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+
+                                default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+        :param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
+                         If None is set, it uses the default value, session local timezone.
+        :param wholeFile: parse one record, which may span multiple lines, per file. If None is
+                          set, it uses the default value, ``false``.
 
         >>> df1 = spark.read.json('python/test_support/sql/people.json')
         >>> df1.dtypes
@@ -225,7 +233,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
             allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
             mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
-            timestampFormat=timestampFormat)
+
+            timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile)
         if isinstance(path, basestring):
            path = [path]
         if type(path) == list:
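
Note: on the Scala side the same behaviour is reached through the wholeFile read option rather than a keyword argument. A small sketch, assuming a SparkSession named spark and using the test file added by this commit:

// Read a file containing a single JSON value (here an array) that spans multiple lines.
val people = spark.read
  .option("wholeFile", "true")
  .json("python/test_support/sql/people_array.json")
people.show()
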

python/pyspark/sql/streaming.py

Lines changed: 16 additions & 6 deletions
@@ -428,11 +428,14 @@ def load(self, path=None, format=None, schema=None, **options):
     def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
              allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
              allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
-             mode=None, columnNameOfCorruptRecord=None, dateFormat=None,
-             timestampFormat=None):
+
+             mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
+             timeZone=None, wholeFile=None):
         """
-        Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON
-        <http://jsonlines.org/>`_) and returns a :class`DataFrame`.
+        Loads a JSON file stream and returns the results as a :class:`DataFrame`.
+
+        Both JSON (one record per file) and `JSON Lines <http://jsonlines.org/>`_
+        (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter.
 
         If the ``schema`` parameter is not specified, this function goes
         through the input once to determine the input schema.
@@ -480,7 +483,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
         :param timestampFormat: sets the string that indicates a timestamp format. Custom date
                                 formats follow the formats at ``java.text.SimpleDateFormat``.
                                 This applies to timestamp type. If None is set, it uses the
-                                default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+
+                                default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+        :param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
+                         If None is set, it uses the default value, session local timezone.
+        :param wholeFile: parse one record, which may span multiple lines, per file. If None is
+                          set, it uses the default value, ``false``.
+
 
         >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
         >>> json_sdf.isStreaming
@@ -494,7 +503,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
             allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
             mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
-            timestampFormat=timestampFormat)
+            timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile)
+
         if isinstance(path, basestring):
             return self._df(self._jreader.json(path))
         else:
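
Note: the streaming reader accepts the same option. A sketch, assuming a SparkSession named spark, a user-defined schema, and a monitored input directory of your own (streaming sources require an explicit schema):

import org.apache.spark.sql.types._

// Each new file dropped into the directory is parsed as one multi-line JSON record.
val schema = new StructType().add("name", StringType).add("age", IntegerType)
val jsonStream = spark.readStream
  .schema(schema)
  .option("wholeFile", "true")
  .json("/path/to/json-dir")
assert(jsonStream.isStreaming)
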

python/pyspark/sql/tests.py

Lines changed: 7 additions & 0 deletions
@@ -427,6 +427,13 @@ def test_udf_with_order_by_and_limit(self):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])
 
+    def test_wholefile_json(self):
+        from pyspark.sql.types import StringType
+        people1 = self.spark.read.json("python/test_support/sql/people.json")
+        people_array = self.spark.read.json("python/test_support/sql/people_array.json",
+                                            wholeFile=True)
+        self.assertEqual(people1.collect(), people_array.collect())
+
     def test_udf_with_input_file_name(self):
         from pyspark.sql.functions import udf, input_file_name
         from pyspark.sql.types import StringType

python/test_support/sql/people_array.json

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+[
+  {
+    "name": "Michael"
+  },
+  {
+    "name": "Andy",
+    "age": 30
+  },
+  {
+    "name": "Justin",
+    "age": 19
+  }
+]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 8 additions & 3 deletions
@@ -491,13 +491,18 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
   lazy val parser =
     new JacksonParser(
       schema,
-      "invalid", // Not used since we force fail fast. Invalid rows will be set to `null`.
-      new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE)))
+
+      new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE)))
 
   override def dataType: DataType = schema
 
   override def nullSafeEval(json: Any): Any = {
-    try parser.parse(json.toString).headOption.orNull catch {
+    try {
+      parser.parse(
+        json.asInstanceOf[UTF8String],
+        CreateJacksonParser.utf8String,
+        identity[UTF8String]).headOption.orNull
+    } catch {
       case _: SparkSQLJsonProcessingException => null
     }
   }
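
Note: JsonToStruct is the expression behind the from_json function, so this change means the incoming UTF8String is handed to Jackson directly instead of going through json.toString. A sketch of exercising it from the public API, assuming a SparkSession named spark (the schema and sample record are illustrative):

import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types._
import spark.implicits._

// from_json parses a string column into a struct using the given schema.
val schema = new StructType().add("name", StringType).add("age", IntegerType)
val parsed = Seq("""{"name": "Andy", "age": 30}""").toDF("value")
  .select(from_json($"value", schema).as("person"))
parsed.show(false)
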

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.json
+
+import java.io.InputStream
+
+import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
+import org.apache.hadoop.io.Text
+
+import org.apache.spark.unsafe.types.UTF8String
+
+private[sql] object CreateJacksonParser extends Serializable {
+  def string(jsonFactory: JsonFactory, record: String): JsonParser = {
+    jsonFactory.createParser(record)
+  }
+
+  def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
+    val bb = record.getByteBuffer
+    assert(bb.hasArray)
+
+    jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+  }
+
+  def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
+    jsonFactory.createParser(record.getBytes, 0, record.getLength)
+  }
+
+  def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
+    jsonFactory.createParser(record)
+  }
+}
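
Note: each factory method builds a Jackson JsonParser for a different record type, which is what lets the row-level JSON parser stay generic over String, UTF8String, Hadoop Text, and whole-file InputStreams. A sketch of the UTF8String path (CreateJacksonParser is private[sql], so outside Spark's sql packages this is illustrative only):

import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.sql.catalyst.json.CreateJacksonParser
import org.apache.spark.unsafe.types.UTF8String

// Build a parser straight from the UTF8String's backing bytes, avoiding a String copy.
val factory = new JsonFactory()
val record = UTF8String.fromString("""{"name": "Justin", "age": 19}""")
val jsonParser = CreateJacksonParser.utf8String(factory, record)
assert(jsonParser.nextToken() != null) // START_OBJECT
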

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala

Lines changed: 24 additions & 3 deletions
@@ -31,10 +31,28 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs
  * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
  */
 private[sql] class JSONOptions(
-    @transient private val parameters: CaseInsensitiveMap)
+
+    @transient private val parameters: CaseInsensitiveMap,
+
+    defaultColumnNameOfCorruptRecord: String)
   extends Logging with Serializable {
 
-  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
+  def this(
+      parameters: Map[String, String],
+
+      defaultColumnNameOfCorruptRecord: String = "") = {
+    this(
+      new CaseInsensitiveMap(parameters),
+      defaultColumnNameOfCorruptRecord)
+  }
+
+  // provided a constructor so that existing code of snappydata compatible with spark 2.1 continues
+  // to work
+  def this(
+      parameters: Map[String, String]) = {
+    this(
+      new CaseInsensitiveMap(parameters), "")
+  }
 
   val samplingRatio =
     parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
@@ -56,7 +74,8 @@
     parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
   val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
   private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
-  val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord")
+  val columnNameOfCorruptRecord =
+    parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
 
   // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
   val dateFormat: FastDateFormat =
@@ -66,6 +85,8 @@
     FastDateFormat.getInstance(
       parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US)
 
+  val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)
+
   // Parse mode flags
   if (!ParseModes.isValidMode(parseMode)) {
     logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
0 commit comments