Analyzer.scala
@@ -860,6 +860,17 @@ class Analyzer(
lookupTempView(ident)
.map(view => i.copy(table = view))
.getOrElse(i)
+// TODO (SPARK-27484): handle streaming write commands when we have them.
+case write: V2WriteCommand =>
+write.table match {
+case UnresolvedRelation(ident, _, false) =>
+lookupTempView(ident).map(EliminateSubqueryAliases(_)).map {
+case r: DataSourceV2Relation => write.withNewTable(r)
+case _ => throw new AnalysisException("Cannot write into temp view " +
+s"${ident.quoted} as it's not a data source v2 relation.")
+}.getOrElse(write)
+case _ => write
+}
case u @ UnresolvedTable(ident) =>
lookupTempView(ident).foreach { _ =>
u.failAnalysis(s"${ident.quoted} is a temp view not table.")
@@ -942,6 +953,18 @@ class Analyzer(
.map(v2Relation => i.copy(table = v2Relation))
.getOrElse(i)

+// TODO (SPARK-27484): handle streaming write commands when we have them.
+case write: V2WriteCommand =>
+write.table match {
+case u: UnresolvedRelation if !u.isStreaming =>
+lookupV2Relation(u.multipartIdentifier, u.options, false).map {
+case r: DataSourceV2Relation => write.withNewTable(r)
+case other => throw new IllegalStateException(
+"[BUG] unexpected plan returned by `lookupV2Relation`: " + other)
+}.getOrElse(write)
+case _ => write
+}
+
case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) =>
CatalogV2Util.loadRelation(u.catalog, u.tableName)
.map(rel => alter.copy(table = rel))
@@ -1019,6 +1042,24 @@ class Analyzer(
case other => i.copy(table = other)
}

+// TODO (SPARK-27484): handle streaming write commands when we have them.
+case write: V2WriteCommand =>
+write.table match {
+case u: UnresolvedRelation if !u.isStreaming =>
+lookupRelation(u.multipartIdentifier, u.options, false)
+.map(EliminateSubqueryAliases(_))
+.map {
+case v: View => write.failAnalysis(
+s"Writing into a view is not allowed. View: ${v.desc.identifier}.")
+case u: UnresolvedCatalogRelation => write.failAnalysis(
+"Cannot write into v1 table: " + u.tableMeta.identifier)
+case r: DataSourceV2Relation => write.withNewTable(r)
+case other => throw new IllegalStateException(
+"[BUG] unexpected plan returned by `lookupRelation`: " + other)
+}.getOrElse(write)
+case _ => write
+}
+
case u: UnresolvedRelation =>
lookupRelation(u.multipartIdentifier, u.options, u.isStreaming)
.map(resolveViews).getOrElse(u)
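Note: the three new `V2WriteCommand` cases in `ResolveTempViews`, `ResolveTables`, and `ResolveRelations` deliberately share one shape: if the write target is still an `UnresolvedRelation`, try this rule's lookup and swap the result in via `withNewTable`; otherwise pass the command along untouched so a later rule (or `CheckAnalysis`) can handle it. A self-contained sketch of that shared shape, not the exact Spark source (`lookup` stands in for the rule-specific resolution):

```scala
import org.apache.spark.sql.catalyst.analysis.{NamedRelation, UnresolvedRelation}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}

// Sketch: one resolution step over a write command. Returns the command
// unchanged when this rule's lookup cannot resolve the target.
def resolveWrite(
    plan: LogicalPlan,
    lookup: UnresolvedRelation => Option[NamedRelation]): LogicalPlan = plan match {
  case write: V2WriteCommand =>
    write.table match {
      case u: UnresolvedRelation if !u.isStreaming =>
        lookup(u).map(write.withNewTable).getOrElse(write)  // defer if not found here
      case _ => write  // target already resolved, or a streaming relation
    }
  case other => other
}
```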
CheckAnalysis.scala
@@ -108,6 +108,11 @@ trait CheckAnalysis extends PredicateHelper {
case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) =>
failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}")

+// TODO (SPARK-27484): handle streaming write commands when we have them.
+case write: V2WriteCommand if write.table.isInstanceOf[UnresolvedRelation] =>
+val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier
+write.table.failAnalysis(s"Table or view not found: ${tblName.quoted}")
+
case u: UnresolvedV2Relation if isView(u.originalNameParts) =>
u.failAnalysis(
s"Invalid command: '${u.originalNameParts.quoted}' is a view not a table.")
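Note: this `CheckAnalysis` case is the terminal safety net: if none of the resolution rules replaced `write.table`, analysis fails with a uniform "Table or view not found" error instead of letting an unresolved plan escape the analyzer. It is also why the tests below switch from `NoSuchTableException` (previously thrown eagerly inside `DataFrameWriterV2`) to `AnalysisException`. Illustrative expectation, mirroring the updated suite:

```scala
// The writer no longer throws NoSuchTableException itself; the missing
// table is reported when the analyzed plan is checked.
val exc = intercept[AnalysisException] {
  spark.table("source").writeTo("testcat.table_name").append()  // table never created
}
assert(exc.getMessage.contains("Table or view not found: testcat.table_name"))
```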
v2Commands.scala
@@ -53,6 +53,7 @@ trait V2WriteCommand {
}

def withNewQuery(newQuery: LogicalPlan): V2WriteCommand
+def withNewTable(newTable: NamedRelation): V2WriteCommand
}

/**
@@ -64,6 +65,7 @@ case class AppendData(
writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand {
override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery)
+override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable)
}

object AppendData {
@@ -97,6 +99,9 @@ case class OverwriteByExpression(
override def withNewQuery(newQuery: LogicalPlan): OverwriteByExpression = {
copy(query = newQuery)
}
+override def withNewTable(newTable: NamedRelation): OverwriteByExpression = {
+copy(table = newTable)
+}
}

object OverwriteByExpression {
@@ -128,6 +133,9 @@ case class OverwritePartitionsDynamic(
override def withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = {
copy(query = newQuery)
}
+override def withNewTable(newTable: NamedRelation): OverwritePartitionsDynamic = {
+copy(table = newTable)
+}
}

object OverwritePartitionsDynamic {
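Note: `withNewTable` mirrors the existing `withNewQuery`, so analyzer rules can replace the write target without matching on each concrete command. A minimal sketch of the payoff (the helper name is illustrative, not part of this change):

```scala
import org.apache.spark.sql.catalyst.analysis.NamedRelation
import org.apache.spark.sql.catalyst.plans.logical.V2WriteCommand

// One helper covers AppendData, OverwriteByExpression and
// OverwritePartitionsDynamic, because withNewTable is declared on the trait.
def resolveTarget(write: V2WriteCommand, resolved: NamedRelation): V2WriteCommand =
  write.withNewTable(resolved)
```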
DataFrameWriterV2.scala
@@ -21,12 +21,11 @@ import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelectStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelectStatement}
import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform}
import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.IntegerType

/**
@@ -38,21 +37,12 @@ import org.apache.spark.sql.types.IntegerType
final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
extends CreateTableWriter[T] {

-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-import org.apache.spark.sql.connector.catalog.CatalogV2Util._
-import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier
-
private val df: DataFrame = ds.toDF()

private val sparkSession = ds.sparkSession

private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)

-private val (catalog, identifier) = {
-val CatalogAndIdentifier(catalog, identifier) = tableName
-(catalog.asTableCatalog, identifier)
-}
-
private val logicalPlan = df.queryExecution.logical

private var provider: Option[String] = None
@@ -153,15 +143,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
*/
@throws(classOf[NoSuchTableException])
def append(): Unit = {
-val append = loadTable(catalog, identifier) match {
-case Some(t) =>
-AppendData.byName(
-DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
-logicalPlan, options.toMap)
-case _ =>
-throw new NoSuchTableException(identifier)
-}
-
+val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap)
runCommand("append")(append)
}

@@ -177,15 +159,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
*/
@throws(classOf[NoSuchTableException])
def overwrite(condition: Column): Unit = {
-val overwrite = loadTable(catalog, identifier) match {
-case Some(t) =>
-OverwriteByExpression.byName(
-DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
-logicalPlan, condition.expr, options.toMap)
-case _ =>
-throw new NoSuchTableException(identifier)
-}
-
+val overwrite = OverwriteByExpression.byName(
+UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap)
runCommand("overwrite")(overwrite)
}

@@ -204,15 +179,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
*/
@throws(classOf[NoSuchTableException])
def overwritePartitions(): Unit = {
-val dynamicOverwrite = loadTable(catalog, identifier) match {
-case Some(t) =>
-OverwritePartitionsDynamic.byName(
-DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
-logicalPlan, options.toMap)
-case _ =>
-throw new NoSuchTableException(identifier)
-}
-
+val dynamicOverwrite = OverwritePartitionsDynamic.byName(
+UnresolvedRelation(tableName), logicalPlan, options.toMap)
runCommand("overwritePartitions")(dynamicOverwrite)
}

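Note: `DataFrameWriterV2` previously resolved the target eagerly, splitting the name into a catalog and identifier and calling `loadTable` itself; it now just emits `UnresolvedRelation(tableName)` and defers resolution to the analyzer. That deferral is what makes temp view targets work at all: a temp view name has no catalog/identifier split, so the old `CatalogAndIdentifier` path could never see one. Illustrative usage, matching the new tests (assumes the suite's `testcat` catalog and tables):

```scala
// Writing through a temp view defined over a v2 table now resolves correctly.
spark.table("testcat.table_name").createOrReplaceTempView("temp_view")
spark.table("source").writeTo("temp_view").append()
```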
DataFrameWriterV2Suite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.sources.FakeSourceOne
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType}
import org.apache.spark.sql.util.QueryExecutionListener
@@ -57,6 +58,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter
}

after {
+spark.sessionState.catalog.reset()
spark.sessionState.catalogManager.reset()
spark.sessionState.conf.clear()
}
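Note: the added `spark.sessionState.catalog.reset()` is presumably needed because several of the new negative tests create objects in the session catalog (the `CREATE VIEW v` and v1 `table_name` cases below), which would otherwise leak across tests.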
@@ -118,6 +120,18 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter
Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
}

test("Append: write to a temp view of v2 relation") {
spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
spark.table("testcat.table_name").createOrReplaceTempView("temp_view")
spark.table("source").writeTo("temp_view").append()
checkAnswer(
spark.table("testcat.table_name"),
Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
checkAnswer(
spark.table("temp_view"),
Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
}

test("Append: by name not position") {
spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")

@@ -136,11 +150,36 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter
}

test("Append: fail if table does not exist") {
-val exc = intercept[NoSuchTableException] {
+val exc = intercept[AnalysisException] {
spark.table("source").writeTo("testcat.table_name").append()
}

assert(exc.getMessage.contains("table_name"))
assert(exc.getMessage.contains("Table or view not found: testcat.table_name"))
}
+
+test("Append: fail if it writes to a temp view that is not v2 relation") {
+spark.range(10).createOrReplaceTempView("temp_view")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("temp_view").append()
+}
+assert(exc.getMessage.contains("Cannot write into temp view temp_view as it's not a " +
+"data source v2 relation"))
+}
+
+test("Append: fail if it writes to a view") {
+spark.sql("CREATE VIEW v AS SELECT 1")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("v").append()
+}
+assert(exc.getMessage.contains("Writing into a view is not allowed"))
+}
+
+test("Append: fail if it writes to a v1 table") {
+sql(s"CREATE TABLE table_name USING ${classOf[FakeSourceOne].getName}")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("table_name").append()
+}
+assert(exc.getMessage.contains("Cannot write into v1 table: `default`.`table_name`"))
+}

test("Overwrite: overwrite by expression: true") {
@@ -181,6 +220,20 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter
Seq(Row(1L, "a"), Row(2L, "b"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
}

test("Overwrite: write to a temp view of v2 relation") {
spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
spark.table("source").writeTo("testcat.table_name").append()
spark.table("testcat.table_name").createOrReplaceTempView("temp_view")

spark.table("source2").writeTo("testcat.table_name").overwrite(lit(true))
checkAnswer(
spark.table("testcat.table_name"),
Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
checkAnswer(
spark.table("temp_view"),
Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
}

test("Overwrite: by name not position") {
spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")

@@ -200,11 +253,36 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter
}

test("Overwrite: fail if table does not exist") {
-val exc = intercept[NoSuchTableException] {
+val exc = intercept[AnalysisException] {
spark.table("source").writeTo("testcat.table_name").overwrite(lit(true))
}

assert(exc.getMessage.contains("table_name"))
assert(exc.getMessage.contains("Table or view not found: testcat.table_name"))
}
+
+test("Overwrite: fail if it writes to a temp view that is not v2 relation") {
+spark.range(10).createOrReplaceTempView("temp_view")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("temp_view").overwrite(lit(true))
+}
+assert(exc.getMessage.contains("Cannot write into temp view temp_view as it's not a " +
+"data source v2 relation"))
+}
+
+test("Overwrite: fail if it writes to a view") {
+spark.sql("CREATE VIEW v AS SELECT 1")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("v").overwrite(lit(true))
+}
+assert(exc.getMessage.contains("Writing into a view is not allowed"))
+}
+
+test("Overwrite: fail if it writes to a v1 table") {
+sql(s"CREATE TABLE table_name USING ${classOf[FakeSourceOne].getName}")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("table_name").overwrite(lit(true))
+}
+assert(exc.getMessage.contains("Cannot write into v1 table: `default`.`table_name`"))
+}

test("OverwritePartitions: overwrite conflicting partitions") {
@@ -245,6 +323,20 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter
Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
}

test("OverwritePartitions: write to a temp view of v2 relation") {
spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
spark.table("source").writeTo("testcat.table_name").append()
spark.table("testcat.table_name").createOrReplaceTempView("temp_view")

spark.table("source2").writeTo("testcat.table_name").overwritePartitions()
checkAnswer(
spark.table("testcat.table_name"),
Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
checkAnswer(
spark.table("temp_view"),
Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
}

test("OverwritePartitions: by name not position") {
spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")

@@ -264,11 +356,36 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter
}

test("OverwritePartitions: fail if table does not exist") {
-val exc = intercept[NoSuchTableException] {
+val exc = intercept[AnalysisException] {
spark.table("source").writeTo("testcat.table_name").overwritePartitions()
}

assert(exc.getMessage.contains("table_name"))
assert(exc.getMessage.contains("Table or view not found: testcat.table_name"))
}
+
+test("OverwritePartitions: fail if it writes to a temp view that is not v2 relation") {
+spark.range(10).createOrReplaceTempView("temp_view")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("temp_view").overwritePartitions()
+}
+assert(exc.getMessage.contains("Cannot write into temp view temp_view as it's not a " +
+"data source v2 relation"))
+}
+
+test("OverwritePartitions: fail if it writes to a view") {
+spark.sql("CREATE VIEW v AS SELECT 1")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("v").overwritePartitions()
+}
+assert(exc.getMessage.contains("Writing into a view is not allowed"))
+}
+
+test("OverwritePartitions: fail if it writes to a v1 table") {
+sql(s"CREATE TABLE table_name USING ${classOf[FakeSourceOne].getName}")
+val exc = intercept[AnalysisException] {
+spark.table("source").writeTo("table_name").overwritePartitions()
+}
+assert(exc.getMessage.contains("Cannot write into v1 table: `default`.`table_name`"))
+}

test("Create: basic behavior") {