diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 20f4080c9859..a4ca1a0a72ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1482,7 +1482,7 @@ object SQLConf {
       " register class names for which data source V2 write paths are disabled. Writes from these" +
       " sources will fall back to the V1 sources.")
     .stringConf
-    .createWithDefault("orc")
+    .createWithDefault("csv,orc")
 
   val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
     .doc("A comma-separated list of fully qualified data source register class names for which" +
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 7cdfddc5e7aa..b68618755258 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,4 +1,4 @@
-org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
+org.apache.spark.sql.execution.datasources.v2.csv.CSVDataSourceV2
 org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
 org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 org.apache.spark.sql.execution.datasources.noop.NoopDataSource
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 8b70e336c14b..08d6dc62e354 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier}
 import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils}
-import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.csv.CSVDataSourceV2
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -238,7 +238,7 @@ case class AlterTableAddColumnsCommand(
       // TextFileFormat only default to one column "value"
       // Hive type is already considered as hive serde table, so the logic will not
       // come in here.
-      case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat | _: OrcDataSourceV2 =>
+      case _: JsonFileFormat | _: CSVDataSourceV2 | _: ParquetFileFormat | _: OrcDataSourceV2 =>
       case s if s.getClass.getCanonicalName.endsWith("OrcFileFormat") =>
       case s =>
         throw new AnalysisException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index d08a54cc9b1f..c8de53a17aca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -118,7 +118,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       throw new AnalysisException(
         "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" +
           "referenced columns only include the internal corrupt record column\n" +
-          s"(named _corrupt_record by default). For example:\n" +
+          "(named _corrupt_record by default). For example:\n" +
           "spark.read.schema(schema).csv(file).filter($\"_corrupt_record\".isNotNull).count()\n" +
           "and spark.read.schema(schema).csv(file).select(\"_corrupt_record\").show().\n" +
           "Instead, you can cache or save the parsed results and then send the same query.\n" +
@@ -163,31 +163,3 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
 }
 
-private[csv] class CsvOutputWriter(
-    path: String,
-    dataSchema: StructType,
-    context: TaskAttemptContext,
-    params: CSVOptions) extends OutputWriter with Logging {
-
-  private var univocityGenerator: Option[UnivocityGenerator] = None
-
-  if (params.headerFlag) {
-    val gen = getGen()
-    gen.writeHeaders()
-  }
-
-  private def getGen(): UnivocityGenerator = univocityGenerator.getOrElse {
-    val charset = Charset.forName(params.charset)
-    val os = CodecStreams.createOutputStreamWriter(context, new Path(path), charset)
-    val newGen = new UnivocityGenerator(dataSchema, os, params)
-    univocityGenerator = Some(newGen)
-    newGen
-  }
-
-  override def write(row: InternalRow): Unit = {
-    val gen = getGen()
-    gen.write(row)
-  }
-
-  override def close(): Unit = univocityGenerator.foreach(_.close())
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala
new file mode 100644
index 000000000000..3ff36bfde3cc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.nio.charset.Charset
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityGenerator}
+import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}
+import org.apache.spark.sql.types.StructType
+
+class CsvOutputWriter(
+    path: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext,
+    params: CSVOptions) extends OutputWriter with Logging {
+
+  private var univocityGenerator: Option[UnivocityGenerator] = None
+
+  if (params.headerFlag) {
+    val gen = getGen()
+    gen.writeHeaders()
+  }
+
+  private def getGen(): UnivocityGenerator = univocityGenerator.getOrElse {
+    val charset = Charset.forName(params.charset)
+    val os = CodecStreams.createOutputStreamWriter(context, new Path(path), charset)
+    val newGen = new UnivocityGenerator(dataSchema, os, params)
+    univocityGenerator = Some(newGen)
+    newGen
+  }
+
+  override def write(row: InternalRow): Unit = {
+    val gen = getGen()
+    gen.write(row)
+  }
+
+  override def close(): Unit = univocityGenerator.foreach(_.close())
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
index e9c7a1bb749d..ebe7fee312e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
@@ -47,4 +47,8 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister {
       Option(map.get("path")).toSeq
     }
   }
+
+  protected def getTableName(paths: Seq[String]): String = {
+    shortName() + ":" + paths.mkString(";")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala
new file mode 100644
index 000000000000..f9dfcf448a3e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.sources.v2.reader.PartitionReader
+
+class PartitionReaderFromIterator[InternalRow](
+    iter: Iterator[InternalRow]) extends PartitionReader[InternalRow] {
+  private var currentValue: InternalRow = _
+
+  override def next(): Boolean = {
+    if (iter.hasNext) {
+      currentValue = iter.next()
+      true
+    } else {
+      false
+    }
+  }
+
+  override def get(): InternalRow = currentValue
+
+  override def close(): Unit = {}
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala
new file mode 100644
index 000000000000..8d9cc68417ef
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec}
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+abstract class TextBasedFileScan(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    readSchema: StructType,
+    options: CaseInsensitiveStringMap)
+  extends FileScan(sparkSession, fileIndex, readSchema, options) {
+  private var codecFactory: CompressionCodecFactory = _
+
+  override def isSplitable(path: Path): Boolean = {
+    if (codecFactory == null) {
+      codecFactory = new CompressionCodecFactory(
+        sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap))
+    }
+    val codec = codecFactory.getCodec(path)
+    codec == null || codec.isInstanceOf[SplittableCompressionCodec]
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
new file mode 100644
index 000000000000..4ecd9cdc32ac
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.csv
+
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
+import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
+import org.apache.spark.sql.sources.v2.Table
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class CSVDataSourceV2 extends FileDataSourceV2 {
+
+  override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[CSVFileFormat]
+
+  override def shortName(): String = "csv"
+
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    val paths = getPaths(options)
+    val tableName = getTableName(paths)
+    CSVTable(tableName, sparkSession, options, paths, None)
+  }
+
+  override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
+    val paths = getPaths(options)
+    val tableName = getTableName(paths)
+    CSVTable(tableName, sparkSession, options, paths, Some(schema))
+  }
+}
+
+object CSVDataSourceV2 {
+  def supportsDataType(dataType: DataType): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+
+    case _ => false
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
new file mode 100644
index 000000000000..e2d50282e9cb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.csv
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
+import org.apache.spark.sql.execution.datasources.v2._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.reader.PartitionReader
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * A factory used to create CSV readers.
+ *
+ * @param sqlConf SQL configuration.
+ * @param broadcastedConf Broadcasted serializable Hadoop Configuration.
+ * @param dataSchema Schema of CSV files.
+ * @param partitionSchema Schema of partitions.
+ * @param readSchema Required schema in the batch scan.
+ * @param parsedOptions Options for parsing CSV files.
+ */
+case class CSVPartitionReaderFactory(
+    sqlConf: SQLConf,
+    broadcastedConf: Broadcast[SerializableConfiguration],
+    dataSchema: StructType,
+    partitionSchema: StructType,
+    readSchema: StructType,
+    parsedOptions: CSVOptions) extends FilePartitionReaderFactory {
+  private val columnPruning = sqlConf.csvColumnPruning
+  private val readDataSchema =
+    getReadDataSchema(readSchema, partitionSchema, sqlConf.caseSensitiveAnalysis)
+
+  override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
+    val conf = broadcastedConf.value.value
+
+    val parser = new UnivocityParser(
+      StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
+      StructType(readDataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
+      parsedOptions)
+    val schema = if (columnPruning) readDataSchema else dataSchema
+    val isStartOfFile = file.start == 0
+    val headerChecker = new CSVHeaderChecker(
+      schema, parsedOptions, source = s"CSV file: ${file.filePath}", isStartOfFile)
+    val iter = CSVDataSource(parsedOptions).readFile(
+      conf,
+      file,
+      parser,
+      headerChecker,
+      readDataSchema)
+    val fileReader = new PartitionReaderFromIterator[InternalRow](iter)
+    new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
+      partitionSchema, file.partitionValues)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
new file mode 100644
index 000000000000..35c6a668f22a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.csv
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.csv.CSVOptions
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
+import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
+import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
+
+case class CSVScan(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    dataSchema: StructType,
+    readSchema: StructType,
+    options: CaseInsensitiveStringMap)
+  extends TextBasedFileScan(sparkSession, fileIndex, readSchema, options) {
+
+  private lazy val parsedOptions: CSVOptions = new CSVOptions(
+    options.asScala.toMap,
+    columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
+    sparkSession.sessionState.conf.sessionLocalTimeZone,
+    sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+
+  override def isSplitable(path: Path): Boolean = {
+    CSVDataSource(parsedOptions).isSplitable && super.isSplitable(path)
+  }
+
+  override def createReaderFactory(): PartitionReaderFactory = {
+    // Check the corrupt-record column requirement here so the exception is thrown on the driver side.
+    ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
+
+    if (readSchema.length == 1 &&
+      readSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
+      throw new AnalysisException(
+        "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" +
+          "referenced columns only include the internal corrupt record column\n" +
+          "(named _corrupt_record by default). For example:\n" +
+          "spark.read.schema(schema).csv(file).filter($\"_corrupt_record\".isNotNull).count()\n" +
+          "and spark.read.schema(schema).csv(file).select(\"_corrupt_record\").show().\n" +
+          "Instead, you can cache or save the parsed results and then send the same query.\n" +
+          "For example, val df = spark.read.schema(schema).csv(file).cache() and then\n" +
+          "df.filter($\"_corrupt_record\".isNotNull).count()."
+      )
+    }
+
+    val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
+    // Hadoop Configurations are case sensitive.
+    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
+    val broadcastedConf = sparkSession.sparkContext.broadcast(
+      new SerializableConfiguration(hadoopConf))
+    CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
+      dataSchema, fileIndex.partitionSchema, readSchema, parsedOptions)
+  }
+
+  override def supportsDataType(dataType: DataType): Boolean = {
+    CSVDataSourceV2.supportsDataType(dataType)
+  }
+
+  override def formatName: String = "CSV"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala
new file mode 100644
index 000000000000..dbb3c03ca981
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2.csv
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
+import org.apache.spark.sql.sources.v2.reader.Scan
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+case class CSVScanBuilder(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    schema: StructType,
+    dataSchema: StructType,
+    options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) {
+
+  override def build(): Scan = {
+    CSVScan(sparkSession, fileIndex, dataSchema, readSchema, options)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
new file mode 100644
index 000000000000..bf4b8ba868f2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.csv
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.FileStatus
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.csv.CSVOptions
+import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
+import org.apache.spark.sql.execution.datasources.v2.FileTable
+import org.apache.spark.sql.sources.v2.writer.WriteBuilder
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+case class CSVTable(
+    name: String,
+    sparkSession: SparkSession,
+    options: CaseInsensitiveStringMap,
+    paths: Seq[String],
+    userSpecifiedSchema: Option[StructType])
+  extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
+  override def newScanBuilder(options: CaseInsensitiveStringMap): CSVScanBuilder =
+    CSVScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+
+  override def inferSchema(files: Seq[FileStatus]): Option[StructType] = {
+    val parsedOptions = new CSVOptions(
+      options.asScala.toMap,
+      columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
+      sparkSession.sessionState.conf.sessionLocalTimeZone)
+
+    CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
+  }
+
+  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
+    new CSVWriteBuilder(options, paths)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala
new file mode 100644
index 000000000000..bb26d2f92d74
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.csv
+
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+
+import org.apache.spark.sql.catalyst.csv.CSVOptions
+import org.apache.spark.sql.catalyst.util.CompressionCodecs
+import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, OutputWriterFactory}
+import org.apache.spark.sql.execution.datasources.csv.CsvOutputWriter
+import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
+  extends FileWriteBuilder(options, paths) {
+  override def prepareWrite(
+      sqlConf: SQLConf,
+      job: Job,
+      options: Map[String, String],
+      dataSchema: StructType): OutputWriterFactory = {
+    val conf = job.getConfiguration
+    val csvOptions = new CSVOptions(
+      options,
+      columnPruning = sqlConf.csvColumnPruning,
+      sqlConf.sessionLocalTimeZone)
+    csvOptions.compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(conf, codec)
+    }
+
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String,
+          dataSchema: StructType,
+          context: TaskAttemptContext): OutputWriter = {
+        new CsvOutputWriter(path, dataSchema, context, csvOptions)
+      }
+
+      override def getFileExtension(context: TaskAttemptContext): String = {
+        ".csv" + CodecStreams.getCompressionExtension(context)
+      }
+    }
+  }
+
+  override def supportsDataType(dataType: DataType): Boolean = {
+    CSVDataSourceV2.supportsDataType(dataType)
+  }
+
+  override def formatName: String = "CSV"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
index 900c94e937ff..36e7e12e41ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
@@ -29,10 +29,6 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
 
   override def shortName(): String = "orc"
 
-  private def getTableName(paths: Seq[String]): String = {
-    shortName() + ":" + paths.mkString(";")
-  }
-
   override def getTable(options: CaseInsensitiveStringMap): Table = {
     val paths = getPaths(options)
     val tableName = getTableName(paths)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 58522f7b1376..1d30cbfbaf1a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -329,27 +329,27 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
   test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") {
     withTempDir { dir =>
       val tempDir = new File(dir, "files").getCanonicalPath
+      // TODO: test file source V2 after write path is fixed.
       Seq(true).foreach { useV1 =>
         val useV1List = if (useV1) {
-          "orc"
+          "csv,orc"
         } else {
           ""
         }
-        def errorMessage(format: String, isWrite: Boolean): String = {
-          if (isWrite && (useV1 || format != "orc")) {
-            "cannot save interval data type into external storage."
-          } else {
-            s"$format data source does not support calendarinterval data type."
-          }
+        def validateErrorMessage(msg: String): Unit = {
+          val msg1 = "cannot save interval data type into external storage."
+          val msg2 = "data source does not support calendarinterval data type."
+          assert(msg.toLowerCase(Locale.ROOT).contains(msg1) ||
+            msg.toLowerCase(Locale.ROOT).contains(msg2))
         }
 
         withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
           // write path
           Seq("csv", "json", "parquet", "orc").foreach { format =>
-            var msg = intercept[AnalysisException] {
+            val msg = intercept[AnalysisException] {
               sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
             }.getMessage
-            assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true)))
+            validateErrorMessage(msg)
           }
 
           // read path
@@ -359,14 +359,14 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
              spark.read.schema(schema).format(format).load(tempDir).collect()
            }.getMessage
-           assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
+           validateErrorMessage(msg)
 
            msg = intercept[AnalysisException] {
              val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
              spark.read.schema(schema).format(format).load(tempDir).collect()
            }.getMessage
-           assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
+           validateErrorMessage(msg)
          }
        }
      }
@@ -374,9 +374,10 @@
   }
 
   test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") {
+    // TODO: test file source V2 after write path is fixed.
     Seq(true).foreach { useV1 =>
       val useV1List = if (useV1) {
-        "orc"
+        "csv,orc"
       } else {
         ""
       }
@@ -470,22 +471,25 @@
   }
 
   test("SPARK-25237 compute correct input metrics in FileScanRDD") {
-    withTempPath { p =>
-      val path = p.getAbsolutePath
-      spark.range(1000).repartition(1).write.csv(path)
-      val bytesReads = new mutable.ArrayBuffer[Long]()
-      val bytesReadListener = new SparkListener() {
-        override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
-          bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead
+    // TODO: Test CSV V2 as well after it implements [[SupportsReportStatistics]].
+    withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "csv") {
+      withTempPath { p =>
+        val path = p.getAbsolutePath
+        spark.range(1000).repartition(1).write.csv(path)
+        val bytesReads = new mutable.ArrayBuffer[Long]()
+        val bytesReadListener = new SparkListener() {
+          override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+            bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead
+          }
+        }
+        sparkContext.addSparkListener(bytesReadListener)
+        try {
+          spark.read.csv(path).limit(1).collect()
+          sparkContext.listenerBus.waitUntilEmpty(1000L)
+          assert(bytesReads.sum === 7860)
+        } finally {
+          sparkContext.removeSparkListener(bytesReadListener)
         }
-      }
-      sparkContext.addSparkListener(bytesReadListener)
-      try {
-        spark.read.csv(path).limit(1).collect()
-        sparkContext.listenerBus.waitUntilEmpty(1000L)
-        assert(bytesReads.sum === 7860)
-      } finally {
-        sparkContext.removeSparkListener(bytesReadListener)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index d9e5d7af1967..e369596a716b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1343,15 +1343,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
         .collect()
     }.getMessage
     assert(msg.contains("only include the internal corrupt record column"))
-    intercept[org.apache.spark.sql.catalyst.errors.TreeNodeException[_]] {
-      spark
-        .read
-        .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
-        .schema(schema)
-        .csv(testFile(valueMalformedFile))
-        .filter($"_corrupt_record".isNotNull)
-        .count()
-    }
+
     // workaround
     val df = spark
       .read
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 9f969473da61..2569085bec08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -428,7 +428,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     val message = intercept[AnalysisException] {
       testRead(spark.read.csv(), Seq.empty, schema)
     }.getMessage
-    assert(message.contains("Unable to infer schema for CSV. It must be specified manually."))
+    assert(message.toLowerCase(Locale.ROOT).contains("unable to infer schema for csv"))
 
     testRead(spark.read.csv(dir), data, schema)
     testRead(spark.read.csv(dir, dir), data ++ data, schema)