@@ -22,8 +22,10 @@ import java.util.Properties
import scala.collection.JavaConverters._

import org.apache.spark.annotation.Stable
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.types.StructType

/**
@@ -324,6 +326,20 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
format("json").load(paths: _*)
}

/**
* Loads a `Dataset[String]` storing JSON objects (<a href="http://jsonlines.org/">JSON Lines
* text format or newline-delimited JSON</a>) and returns the result as a `DataFrame`.
*
* Unless the schema is specified using the `schema` function, this function goes through the
* input once to determine the input schema.
*
* @param jsonDataset
* input Dataset with one JSON object per record
* @since 3.4.0
*/
def json(jsonDataset: Dataset[String]): DataFrame =
parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON)
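
A minimal usage sketch of this new overload, for reference only (it is not part of the diff; it assumes an active Spark Connect session `spark` with `import spark.implicits._`, and the sample data is purely illustrative):

  val jsonLines = Seq(
    """{"name":"Kong","age":73}""",
    """{"name":"Meng","age":84}""").toDS()
  // No schema was supplied, so the server infers one by scanning the JSON strings once.
  val df = spark.read.json(jsonLines)
  df.show()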

/**
* Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other
* overloaded `csv()` method for more details.
@@ -351,6 +367,29 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
@scala.annotation.varargs
def csv(paths: String*): DataFrame = format("csv").load(paths: _*)

/**
* Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
*
* If the schema is not specified using the `schema` function and the `inferSchema` option is
* enabled, this function goes through the input once to determine the input schema.
*
* If the schema is not specified using the `schema` function and the `inferSchema` option is
* disabled, all columns are treated as string types and only the first line is read to
* determine the names and the number of fields.
*
* If `enforceSchema` is set to `false`, only the CSV header in the first line is checked to
* conform to the specified or inferred schema.
*
* @note
*   if the `header` option is set to `true` when calling this API, all lines identical to the
*   header will be removed if they exist.
* @param csvDataset
* input Dataset with one CSV row per record
* @since 3.4.0
*/
def csv(csvDataset: Dataset[String]): DataFrame =
parse(csvDataset, ParseFormat.PARSE_FORMAT_CSV)
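
A matching usage sketch for the CSV overload (same assumptions as the JSON sketch above; illustrative data only):

  val csvLines = Seq(
    "name,age",
    "Alice,29",
    "Bob,31").toDS()
  val df = spark.read
    .option("header", "true")      // the first line supplies the column names; identical lines are dropped
    .option("inferSchema", "true") // one extra pass over the input to infer the column types
    .csv(csvLines)
  df.show()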

/**
* Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the
* other overloaded `parquet()` method for more details.
@@ -504,6 +543,19 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
}
}

private def parse(ds: Dataset[String], format: ParseFormat): DataFrame = {
sparkSession.newDataFrame { builder =>
val parseBuilder = builder.getParseBuilder
.setInput(ds.plan.getRoot)
.setFormat(format)
userSpecifiedSchema.foreach(schema =>
parseBuilder.setSchema(DataTypeProtoConverter.toConnectProtoType(schema)))
extraOptions.foreach { case (k, v) =>
parseBuilder.putOptions(k, v)
}
}
}

///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
@@ -27,6 +27,9 @@ import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics

import org.apache.spark.SPARK_VERSION
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
import org.apache.spark.sql.functions.{aggregate, array, broadcast, col, count, lit, rand, sequence, shuffle, struct, transform, udf}
import org.apache.spark.sql.types._
@@ -644,6 +647,67 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper
.collect()
assert(result sameElements expected)
}

test("json from Dataset[String] inferSchema") {
val session = spark
import session.implicits._
val expected = Seq(
new GenericRowWithSchema(
Array(73, "Shandong", "Kong"),
new StructType().add("age", LongType).add("city", StringType).add("name", StringType)))
val ds = Seq("""{"name":"Kong","age":73,"city":'Shandong'}""").toDS()
val result = spark.read.option("allowSingleQuotes", "true").json(ds)
checkSameResult(expected, result)
@LuciferYang (author), Mar 8, 2023:
def json(jsonDataset: Dataset[String]): DataFrame = {
val parsedOptions = new JSONOptions(
extraOptions.toMap,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
userSpecifiedSchema.foreach(checkJsonSchema)
val schema = userSpecifiedSchema.map {
case s if !SQLConf.get.getConf(
SQLConf.LEGACY_RESPECT_NULLABILITY_IN_TEXT_DATASET_CONVERSION) => s.asNullable
case other => other
}.getOrElse {
TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
}

From the server-side code, userSpecifiedSchema is an Option[StructType] that defaults to None, so I think we can use this function without specifying userSpecifiedSchema. Or is my test case not the correct scenario?

@zhengruifeng

Contributor:
Make sense, you are right

@LuciferYang (author):
Thanks ~

Contributor:
Probably we should add the user provided schema in the message? Or always discard it?

Contributor:
Will inferFromDataset trigger a job? If so, I think we'd better skip it when possible.

@LuciferYang (author):
Yes, I think you are right; we should add the schema to the message if it exists, thanks ~ I will update it later.
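
For context, a hedged usage-level illustration of the trade-off settled above (the session and data are illustrative): when no schema is set, the server must infer one from the dataset; when a schema is supplied, the client carries it in the Parse message so the server can skip the inference pass.

  val ds = Seq("""{"name":"Kong","age":73}""").toDS()
  // No schema: the server infers it by scanning the JSON strings.
  val inferred = spark.read.json(ds)
  // Explicit schema: encoded into Parse.schema by the client, so no inference is needed.
  val declared = spark.read.schema("name STRING, age LONG").json(ds)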

}

test("json from Dataset[String] with schema") {
val session = spark
import session.implicits._
val schema = new StructType().add("city", StringType).add("name", StringType)
val expected = Seq(new GenericRowWithSchema(Array("Shandong", "Kong"), schema))
val ds = Seq("""{"name":"Kong","age":73,"city":'Shandong'}""").toDS()
val result = spark.read.schema(schema).option("allowSingleQuotes", "true").json(ds)
checkSameResult(expected, result)
}

test("json from Dataset[String] with invalid schema") {
val message = intercept[ParseException] {
spark.read.schema("123").json(spark.createDataset(Seq.empty[String])(StringEncoder))
}.getMessage
assert(message.contains("PARSE_SYNTAX_ERROR"))
}

test("csv from Dataset[String] inferSchema") {
val session = spark
import session.implicits._
val expected = Seq(
new GenericRowWithSchema(
Array("Meng", 84, "Shandong"),
new StructType().add("name", StringType).add("age", LongType).add("city", StringType)))
val ds = Seq("name,age,city", """"Meng",84,"Shandong"""").toDS()
val result = spark.read
.option("header", "true")
.option("inferSchema", "true")
.csv(ds)
checkSameResult(expected, result)
}

test("csv from Dataset[String] with schema") {
val session = spark
import session.implicits._
val schema = new StructType().add("name", StringType).add("age", LongType)
val expected = Seq(new GenericRowWithSchema(Array("Meng", 84), schema))
val ds = Seq(""""Meng",84,"Shandong"""").toDS()
val result = spark.read.schema(schema).csv(ds)
checkSameResult(expected, result)
}

test("csv from Dataset[String] with invalid schema") {
val message = intercept[ParseException] {
spark.read.schema("123").csv(spark.createDataset(Seq.empty[String])(StringEncoder))
}.getMessage
assert(message.contains("PARSE_SYNTAX_ERROR"))
}
}

private[sql] case class MyType(id: Long, a: Double, b: Double)
@@ -32,6 +32,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{functions => fn}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.connect.client.util.ConnectFunSuite
import org.apache.spark.sql.expressions.Window
@@ -254,6 +255,13 @@ class PlanGenerationTestSuite
session.read.json(testDataPath.resolve("people.json").toString)
}

test("json from dataset") {
session.read
.schema(new StructType().add("c1", StringType).add("c2", IntegerType))
.option("allowSingleQuotes", "true")
.json(session.emptyDataset(StringEncoder))
}

test("toJSON") {
complex.toJSON
}
@@ -262,6 +270,13 @@
session.read.csv(testDataPath.resolve("people.csv").toString)
}

test("csv from dataset") {
session.read
.schema(new StructType().add("c1", StringType).add("c2", IntegerType))
.option("header", "true")
.csv(session.emptyDataset(StringEncoder))
}

test("read parquet") {
session.read.parquet(testDataPath.resolve("users.parquet").toString)
}
@@ -131,7 +131,6 @@ object CheckConnectJvmClientCompatibility {

// DataFrame Reader & Writer
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.json"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.csv"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.jdbc"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameWriter.jdbc"),

@@ -62,6 +62,7 @@ message Relation {
RepartitionByExpression repartition_by_expression = 27;
FrameMap frame_map = 28;
CollectMetrics collect_metrics = 29;
Parse parse = 30;

// NA functions
NAFill fill_na = 90;
@@ -798,3 +799,21 @@ message CollectMetrics {
// (Required) The metric sequence.
repeated Expression metrics = 3;
}

message Parse {
// (Required) Input relation to Parse. The input is expected to have a single text column.
Relation input = 1;
// (Required) The expected format of the text.
ParseFormat format = 2;

// (Optional) DataType representing the schema. If not set, Spark will infer the schema.
optional DataType schema = 3;

// Options for the csv/json parser. The map key is case insensitive.
map<string, string> options = 4;
enum ParseFormat {
PARSE_FORMAT_UNSPECIFIED = 0;
PARSE_FORMAT_CSV = 1;
PARSE_FORMAT_JSON = 2;
}
}
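
For readers less familiar with the generated protobuf builders, a hedged sketch of how a client could populate this message (it mirrors the parse helper in the Scala client above; inputRelation stands for a hypothetical, already-built Relation with a single string column):

  import org.apache.spark.connect.proto

  val relation = proto.Relation.newBuilder()
  val parseBuilder = relation.getParseBuilder
    .setInput(inputRelation)                              // (Required) the text to parse
    .setFormat(proto.Parse.ParseFormat.PARSE_FORMAT_JSON) // or PARSE_FORMAT_CSV
  parseBuilder.putOptions("allowSingleQuotes", "true")    // option keys are case-insensitive
  // setSchema(...) is optional; when it is unset, the server infers the schema.
  val plan = relation.build()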
@@ -0,0 +1 @@
LogicalRDD [c1#0, c2#0], false
Contributor:
Oh, this makes me sad. Why are we using RDDs here?

@LuciferYang (author), Mar 8, 2023:
val parsed = jsonDataset.rdd.mapPartitions { iter =>
val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = true)
val parser = new FailureSafeParser[String](
input => rawParser.parse(input, createParser, UTF8String.fromString),
parsedOptions.parseMode,
schema,
parsedOptions.columnNameOfCorruptRecord)
iter.flatMap(parser.parse)
}
sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming)

val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
val headerChecker = new CSVHeaderChecker(
actualSchema,
parsedOptions,
source = s"CSV source: $csvDataset")
headerChecker.checkHeaderColumnNames(firstLine)
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
}.getOrElse(filteredLines.rdd)
val parsed = linesWithoutHeader.mapPartitions { iter =>
val rawParser = new UnivocityParser(actualSchema, parsedOptions)
val parser = new FailureSafeParser[String](
input => rawParser.parse(input),
parsedOptions.parseMode,
schema,
parsedOptions.columnNameOfCorruptRecord)
iter.flatMap(parser.parse)
}
sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming)

private[sql] def internalCreateDataFrame(
catalystRows: RDD[InternalRow],
schema: StructType,
isStreaming: Boolean = false): DataFrame = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(
schema.toAttributes,
catalystRows,
isStreaming = isStreaming)(self)
Dataset.ofRows(self, logicalPlan)
}

@LuciferYang (author), Mar 8, 2023:
On the server side, the input csvDataset and jsonDataset are still LocalRelations, and the code path above (sparkSession.internalCreateDataFrame) converts them to LogicalRDD.

@@ -0,0 +1 @@
LogicalRDD [c1#0, c2#0], false
@@ -0,0 +1,38 @@
{
"common": {
"planId": "1"
},
"parse": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "{\"type\":\"struct\",\"fields\":[{\"name\":\"value\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]}"
}
},
"format": "PARSE_FORMAT_CSV",
"schema": {
"struct": {
"fields": [{
"name": "c1",
"dataType": {
"string": {
}
},
"nullable": true
}, {
"name": "c2",
"dataType": {
"integer": {
}
},
"nullable": true
}]
}
},
"options": {
"header": "true"
}
}
}
Binary file not shown.
@@ -0,0 +1,38 @@
{
"common": {
"planId": "1"
},
"parse": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "{\"type\":\"struct\",\"fields\":[{\"name\":\"value\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]}"
}
},
"format": "PARSE_FORMAT_JSON",
"schema": {
"struct": {
"fields": [{
"name": "c1",
"dataType": {
"string": {
}
},
"nullable": true
}, {
"name": "c2",
"dataType": {
"integer": {
}
},
"nullable": true
}]
}
},
"options": {
"allowsinglequotes": "true"
}
}
}
Binary file not shown.
@@ -29,6 +29,7 @@ import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -117,6 +118,7 @@ class SparkConnectPlanner(val session: SparkSession) {
transformFrameMap(rel.getFrameMap)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")

@@ -733,6 +735,30 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}

private def transformParse(rel: proto.Parse): LogicalPlan = {
def dataFrameReader = {
val localMap = CaseInsensitiveMap[String](rel.getOptionsMap.asScala.toMap)
val reader = session.read
if (rel.hasSchema) {
DataTypeProtoConverter.toCatalystType(rel.getSchema) match {
case s: StructType => reader.schema(s)
case other => throw InvalidPlanInput(s"Invalid schema dataType $other")
}
}
localMap.foreach { case (key, value) => reader.option(key, value) }
reader
}
def ds: Dataset[String] = Dataset(session, transformRelation(rel.getInput))(Encoders.STRING)

rel.getFormat match {
case ParseFormat.PARSE_FORMAT_CSV =>
dataFrameReader.csv(ds).queryExecution.analyzed
case ParseFormat.PARSE_FORMAT_JSON =>
dataFrameReader.json(ds).queryExecution.analyzed
case _ => throw InvalidPlanInput("Does not support " + rel.getFormat.name())
}
}

private def transformFilter(rel: proto.Filter): LogicalPlan = {
assert(rel.hasInput)
val baseRel = transformRelation(rel.getInput)
248 changes: 140 additions & 108 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.
