From 4e76ddeb9509c3bb7f94343fd434ced1adc54b44 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Mon, 18 Feb 2019 17:46:07 +0800
Subject: [PATCH] SupportsDirectWrite

---
 .../apache/spark/sql/internal/SQLConf.scala   |  2 +-
 .../sql/sources/v2/SupportsDirectWrite.java   | 33 ++++++++++++++
 .../apache/spark/sql/DataFrameWriter.scala    | 45 +++++++++++--------
 .../execution/datasources/v2/FileTable.scala  |  4 +-
 .../spark/sql/FileBasedDataSourceSuite.scala  | 19 ++++++++
 5 files changed, 81 insertions(+), 22 deletions(-)
 create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsDirectWrite.java

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0b7b67ed56d2..d285e007dac1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1452,7 +1452,7 @@ object SQLConf {
       " register class names for which data source V2 write paths are disabled. Writes from these" +
       " sources will fall back to the V1 sources.")
     .stringConf
-    .createWithDefault("orc")
+    .createWithDefault("")
 
   val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
     .doc("A comma-separated list of fully qualified data source register class names for which" +
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsDirectWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsDirectWrite.java
new file mode 100644
index 000000000000..ab93b1a887b8
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsDirectWrite.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.sources.v2.writer.WriteBuilder;
+
+/**
+ * An empty mix-in interface for {@link Table}, to indicate this table supports direct write
+ * without validation against the table schema.
+ * <p>
+ * If a {@link Table} implements this interface, the
+ * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder}
+ * with {@link WriteBuilder#buildForBatch()} implemented.
+ * </p>
+ */
+@Evolving
+public interface SupportsDirectWrite extends SupportsWrite {}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 450828172b93..e2b0c32a6ae7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, Logi
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode
+import org.apache.spark.sql.sources.v2.writer.{SupportsSaveMode, WriteBuilder}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -264,6 +264,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       val options = sessionOptions ++ extraOptions + checkFilesExistsOption
       val dsOptions = new DataSourceOptions(options.asJava)
       provider.getTable(dsOptions) match {
+        case table: SupportsDirectWrite =>
+          writeToDataSourceV2(table.newWriteBuilder(dsOptions), table.name)
+
         case table: SupportsBatchWrite =>
           lazy val relation = DataSourceV2Relation.create(table, options)
           mode match {
@@ -279,24 +282,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
               }
 
             case _ =>
-              table.newWriteBuilder(dsOptions) match {
-                case writeBuilder: SupportsSaveMode =>
-                  val write = writeBuilder.mode(mode)
-                    .withQueryId(UUID.randomUUID().toString)
-                    .withInputDataSchema(df.logicalPlan.schema)
-                    .buildForBatch()
-                  // It can only return null with `SupportsSaveMode`. We can clean it up after
-                  // removing `SupportsSaveMode`.
-                  if (write != null) {
-                    runCommand(df.sparkSession, "save") {
-                      WriteToDataSourceV2(write, df.logicalPlan)
-                    }
-                  }
-
-                case _ =>
-                  throw new AnalysisException(
-                    s"data source ${table.name} does not support SaveMode $mode")
-              }
+              writeToDataSourceV2(table.newWriteBuilder(dsOptions), table.name)
           }
 
         // Streaming also uses the data source V2 API. So it may be that the data source implements
@@ -309,6 +295,27 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     }
   }
 
+  private def writeToDataSourceV2(writeBuilder: WriteBuilder, name: String): Unit = {
+    writeBuilder match {
+      case writeBuilder: SupportsSaveMode =>
+        val write = writeBuilder.mode(mode)
+          .withQueryId(UUID.randomUUID().toString)
+          .withInputDataSchema(df.logicalPlan.schema)
+          .buildForBatch()
+        // It can only return null with `SupportsSaveMode`. We can clean it up after
+        // removing `SupportsSaveMode`.
+        if (write != null) {
+          runCommand(df.sparkSession, "save") {
+            WriteToDataSourceV2(write, df.logicalPlan)
+          }
+        }
+
+      case _ =>
+        throw new AnalysisException(
+          s"data source ${name} does not support SaveMode $mode")
+    }
+  }
+
   private def saveToV1Source(): Unit = {
     // Code path for data source v1.
     runCommand(df.sparkSession, "save") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
index 21d3e5e29cfb..9e48bb827b51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
@@ -22,14 +22,14 @@ import org.apache.hadoop.fs.FileStatus
 
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table}
+import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.types.StructType
 
 abstract class FileTable(
     sparkSession: SparkSession,
     options: DataSourceOptions,
     userSpecifiedSchema: Option[StructType])
-  extends Table with SupportsBatchRead with SupportsBatchWrite {
+  extends Table with SupportsBatchRead with SupportsBatchWrite with SupportsDirectWrite {
   lazy val fileIndex: PartitioningAwareFileIndex = {
     val filePaths = options.paths()
     val hadoopConf =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index e0c0484593d9..b6176a371f98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -469,6 +469,25 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
     }
   }
 
+  test("File data sources V2 supports overwriting with different schema") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "") {
+      Seq("orc", "parquet", "json").foreach { format =>
+        withTempPath { p =>
+          val path = p.getCanonicalPath
+          spark.range(10).write.format(format).save(path)
+          val newDF = spark.range(20).map(id => (id.toDouble, id.toString)).toDF("double", "string")
+          newDF.write.format(format).mode("overwrite").save(path)
+
+          val readDF = spark.read.format(format).load(path)
+          val expectedSchema = StructType(Seq(
+            StructField("double", DoubleType, true), StructField("string", StringType, true)))
+          assert(readDF.schema == expectedSchema)
+          checkAnswer(readDF, newDF)
+        }
+      }
+    }
+  }
+
   test("SPARK-25237 compute correct input metrics in FileScanRDD") {
     withTempPath { p =>
       val path = p.getAbsolutePath