diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
index 08be456f090e2..6e071eb48fe47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
@@ -139,7 +139,8 @@ class RelationResolution(override val catalogManager: CatalogManager)
             ident,
             table,
             u.clearWritePrivileges.options,
-            u.isStreaming
+            u.isStreaming,
+            finalTimeTravelSpec
           )
           loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
           u.getTagValue(LogicalPlan.PLAN_ID_TAG)
@@ -162,7 +163,8 @@ class RelationResolution(override val catalogManager: CatalogManager)
       ident: Identifier,
       table: Option[Table],
       options: CaseInsensitiveStringMap,
-      isStreaming: Boolean): Option[LogicalPlan] = {
+      isStreaming: Boolean,
+      timeTravelSpec: Option[TimeTravelSpec]): Option[LogicalPlan] = {
     table.map {
       // To utilize this code path to execute V1 commands, e.g. INSERT,
       // either it must be session catalog, or tracksPartitionsInCatalog
@@ -189,6 +191,7 @@ class RelationResolution(override val catalogManager: CatalogManager)
       case table =>
         if (isStreaming) {
+          assert(timeTravelSpec.isEmpty, "time travel is not allowed in streaming")
           val v1Fallback = table match {
             case withFallback: V2TableWithV1Fallback =>
               Some(UnresolvedCatalogRelation(withFallback.v1Table, isStreaming = true))
@@ -210,7 +213,7 @@
         } else {
           SubqueryAlias(
             catalog.name +: ident.asMultipartIdentifier,
-            DataSourceV2Relation.create(table, Some(catalog), Some(ident), options)
+            DataSourceV2Relation.create(table, Some(catalog), Some(ident), options, timeTravelSpec)
           )
         }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala
index fecec238145e1..977b1624271e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala
@@ -27,8 +27,13 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 sealed trait TimeTravelSpec
 
-case class AsOfTimestamp(timestamp: Long) extends TimeTravelSpec
-case class AsOfVersion(version: String) extends TimeTravelSpec
+case class AsOfTimestamp(timestamp: Long) extends TimeTravelSpec {
+  override def toString: String = s"TIMESTAMP AS OF $timestamp"
+}
+
+case class AsOfVersion(version: String) extends TimeTravelSpec {
+  override def toString: String = s"VERSION AS OF '$version'"
+}
 
 object TimeTravelSpec {
   def create(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 9f58a78568d0f..2d03ec6399351 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -748,6 +748,13 @@ class SessionCatalog(
     getRawLocalOrGlobalTempView(toNameParts(name)).map(getTempViewPlan)
   }
 
+  /**
+   * Generate a [[View]] operator from the local or global temporary view stored.
+   */
+  def getLocalOrGlobalTempView(name: Seq[String]): Option[View] = {
+    getRawLocalOrGlobalTempView(name).map(getTempViewPlan)
+  }
+
   /**
    * Return the raw logical plan of a temporary local or global view for the given name.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 180a14df865b2..7e7990c317aa9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
+import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation, TimeTravelSpec}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
@@ -45,7 +45,8 @@ abstract class DataSourceV2RelationBase(
     output: Seq[AttributeReference],
     catalog: Option[CatalogPlugin],
     identifier: Option[Identifier],
-    options: CaseInsensitiveStringMap)
+    options: CaseInsensitiveStringMap,
+    timeTravelSpec: Option[TimeTravelSpec] = None)
   extends LeafNode with MultiInstanceRelation with NamedRelation {
 
   import DataSourceV2Implicits._
@@ -65,7 +66,12 @@ abstract class DataSourceV2RelationBase(
   override def skipSchemaResolution: Boolean = table.supports(TableCapability.ACCEPT_ANY_SCHEMA)
 
   override def simpleString(maxFields: Int): String = {
-    s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name"
+    val outputString = truncatedString(output, "[", ", ", "]", maxFields)
+    val nameWithTimeTravelSpec = timeTravelSpec match {
+      case Some(spec) => s"$name $spec"
+      case _ => name
+    }
+    s"RelationV2$outputString $nameWithTimeTravelSpec"
   }
 
   override def computeStats(): Statistics = {
@@ -96,8 +102,9 @@ case class DataSourceV2Relation(
     override val output: Seq[AttributeReference],
     catalog: Option[CatalogPlugin],
     identifier: Option[Identifier],
-    options: CaseInsensitiveStringMap)
-  extends DataSourceV2RelationBase(table, output, catalog, identifier, options)
+    options: CaseInsensitiveStringMap,
+    timeTravelSpec: Option[TimeTravelSpec] = None)
+  extends DataSourceV2RelationBase(table, output, catalog, identifier, options, timeTravelSpec)
   with ExposesMetadataColumns {
 
   import DataSourceV2Implicits._
@@ -117,7 +124,7 @@ case class DataSourceV2Relation(
   def withMetadataColumns(): DataSourceV2Relation = {
     val newMetadata = metadataOutput.filterNot(outputSet.contains)
     if (newMetadata.nonEmpty) {
-      DataSourceV2Relation(table, output ++ newMetadata, catalog, identifier, options)
+      copy(output = output ++ newMetadata)
     } else {
       this
     }
@@ -151,7 +158,12 @@ case class DataSourceV2ScanRelation(
   override def name: String = relation.name
 
   override def simpleString(maxFields: Int): String = {
-    s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name"
+    val outputString = truncatedString(output, "[", ", ", "]", maxFields)
+    val nameWithTimeTravelSpec = relation.timeTravelSpec match {
+      case Some(spec) => s"$name $spec"
+      case _ => name
+    }
+ s"RelationV2$outputString $nameWithTimeTravelSpec" } override def computeStats(): Statistics = { @@ -235,17 +247,29 @@ object ExtractV2Table { def unapply(relation: DataSourceV2Relation): Option[Table] = Some(relation.table) } +object ExtractV2CatalogAndIdentifier { + def unapply(relation: DataSourceV2Relation): Option[(CatalogPlugin, Identifier)] = { + relation match { + case DataSourceV2Relation(_, _, Some(catalog), Some(identifier), _, _) => + Some((catalog, identifier)) + case _ => + None + } + } +} + object DataSourceV2Relation { def create( table: Table, catalog: Option[CatalogPlugin], identifier: Option[Identifier], - options: CaseInsensitiveStringMap): DataSourceV2Relation = { + options: CaseInsensitiveStringMap, + timeTravelSpec: Option[TimeTravelSpec] = None): DataSourceV2Relation = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ // The v2 source may return schema containing char/varchar type. We replace char/varchar // with "annotated" string type here as the query engine doesn't support char/varchar yet. val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(table.columns.asSchema) - DataSourceV2Relation(table, toAttributes(schema), catalog, identifier, options) + DataSourceV2Relation(table, toAttributes(schema), catalog, identifier, options, timeTravelSpec) } def create( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 5faf71551586c..d66ba5a23cc84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -725,6 +725,10 @@ abstract class InMemoryBaseTable( } } } + + def copy(): Table = { + throw new UnsupportedOperationException(s"copy is not supported for ${getClass.getName}") + } } object InMemoryBaseTable { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 90e13ad5b1754..46169a9db4914 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -125,6 +125,41 @@ class InMemoryTable( new InMemoryWriterBuilderWithOverWrite(info) } + override def copy(): Table = { + val copiedTable = new InMemoryTable( + name, + columns(), + partitioning, + properties, + constraints, + distribution, + ordering, + numPartitions, + advisoryPartitionSize, + isDistributionStrictlyRequired, + numRowsPerSplit) + + dataMap.synchronized { + dataMap.foreach { case (key, splits) => + val copiedSplits = splits.map { bufferedRows => + val copiedBufferedRows = new BufferedRows(bufferedRows.key, bufferedRows.schema) + copiedBufferedRows.rows ++= bufferedRows.rows.map(_.copy()) + copiedBufferedRows + } + copiedTable.dataMap.put(key, copiedSplits) + } + } + + copiedTable.commits ++= commits.map(_.copy()) + + copiedTable.setCurrentVersion(currentVersion()) + if (validatedVersion() != null) { + copiedTable.setValidatedVersion(validatedVersion()) + } + + copiedTable + } + class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo) extends InMemoryWriterBuilder(info) with SupportsOverwrite { diff --git 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
index 5ecb74fd938a1..1da0882ec211a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -66,6 +66,19 @@ class BasicInMemoryTableCatalog extends TableCatalog {
     }
   }
 
+  def pinTable(ident: Identifier, version: String): Unit = {
+    Option(tables.get(ident)) match {
+      case Some(table: InMemoryBaseTable) =>
+        val versionIdent = Identifier.of(ident.namespace, ident.name + version)
+        val versionTable = table.copy()
+        tables.put(versionIdent, versionTable)
+      case Some(table) =>
+        throw new UnsupportedOperationException(s"Can't pin ${table.getClass.getName}")
+      case _ =>
+        throw new NoSuchTableException(ident.asMultipartIdentifier)
+    }
+  }
+
   override def loadTable(ident: Identifier, version: String): Table = {
     val versionIdent = Identifier.of(ident.namespace, ident.name + version)
     Option(tables.get(versionIdent)) match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala
index 45f494f65c300..71d12bf09b079 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper,
 import org.apache.spark.sql.connector.catalog.CatalogV2Util.v2ColumnsToStructType
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.command.{ShowNamespacesCommand, ShowTablesCommand}
+import org.apache.spark.sql.execution.command.CommandUtils
 import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.connector.V1Function
@@ -810,20 +811,13 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
    * @since 2.0.0
    */
  override def uncacheTable(tableName: String): Unit = {
-    // We first try to parse `tableName` to see if it is 2 part name. If so, then in HMS we check
-    // if it is a temp view and uncache the temp view from HMS, otherwise we uncache it from the
-    // cache manager.
-    // if `tableName` is not 2 part name, then we directly uncache it from the cache manager.
-    try {
-      val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
-      sessionCatalog.getLocalOrGlobalTempView(tableIdent).map(uncacheView).getOrElse {
-        sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
-          cascade = true)
-      }
-    } catch {
-      case e: org.apache.spark.sql.catalyst.parser.ParseException =>
-        sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
-          cascade = true)
+    // Parse the table name and check whether it refers to a temp view (1-2 name parts).
+    // Temp views are uncached via uncacheView, which respects view text semantics (SPARK-33142).
+    // All other tables, including those with 3+ part names, are uncached through CommandUtils.
+    val nameParts = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName)
+    sessionCatalog.getLocalOrGlobalTempView(nameParts).map(uncacheView).getOrElse {
+      val relation = resolveRelation(tableName)
+      CommandUtils.uncacheTableOrView(sparkSession, relation, cascade = true)
     }
   }
 
@@ -868,7 +862,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
    * @since 2.0.0
    */
  override def refreshTable(tableName: String): Unit = {
-    val relation = sparkSession.table(tableName).queryExecution.analyzed
+    val relation = resolveRelation(tableName)
 
     relation.refresh()
 
@@ -891,7 +885,11 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
     // Note this is a no-op for the relation itself if it's not cached, but will clear all
     // caches referencing this relation. If this relation is cached as an InMemoryRelation,
     // this will clear the relation cache and caches of all its dependents.
-    sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation)
+    CommandUtils.recacheTableOrView(sparkSession, relation)
+  }
+
+  private def resolveRelation(tableName: String): LogicalPlan = {
+    sparkSession.table(tableName).queryExecution.analyzed
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index e49ae04e6af6e..671fcb765648d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPla
 import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.classic.{Dataset, SparkSession}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.command.CommandUtils
@@ -83,6 +84,11 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
     cachedData.isEmpty
   }
 
+  // Test-only
+  private[sql] def numCachedEntries: Int = {
+    cachedData.size
+  }
+
   // Test-only
   def cacheQuery(query: Dataset[_]): Unit = {
     cacheQuery(query, tableName = None, storageLevel = MEMORY_AND_DISK)
   }
@@ -215,12 +221,23 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
     uncacheByCondition(spark, _.sameResult(plan), cascade, blocking)
   }
 
-  def uncacheTableOrView(spark: SparkSession, name: Seq[String], cascade: Boolean): Unit = {
+  def uncacheTableOrView(
+      spark: SparkSession,
+      name: Seq[String],
+      cascade: Boolean,
+      blocking: Boolean = false): Unit = {
     uncacheByCondition(
-      spark,
isMatchedTableOrView(_, name, spark.sessionState.conf), cascade, blocking = false) + spark, + isMatchedTableOrView(_, name, spark.sessionState.conf, includeTimeTravel = true), + cascade, + blocking) } - private def isMatchedTableOrView(plan: LogicalPlan, name: Seq[String], conf: SQLConf): Boolean = { + private def isMatchedTableOrView( + plan: LogicalPlan, + name: Seq[String], + conf: SQLConf, + includeTimeTravel: Boolean): Boolean = { def isSameName(nameInCache: Seq[String]): Boolean = { nameInCache.length == name.length && nameInCache.zip(name).forall(conf.resolver.tupled) } @@ -229,9 +246,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { case LogicalRelationWithTable(_, Some(catalogTable)) => isSameName(catalogTable.identifier.nameParts) - case DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _) => - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper - isSameName(v2Ident.toQualifiedNameParts(catalog)) + case DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _, timeTravelSpec) => + val nameInCache = v2Ident.toQualifiedNameParts(catalog) + isSameName(nameInCache) && (includeTimeTravel || timeTravelSpec.isEmpty) case v: View => isSameName(v.desc.identifier.nameParts) @@ -304,6 +321,19 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { recacheByCondition(spark, _.plan.exists(_.sameResult(normalized))) } + /** + * Re-caches all cache entries that reference the given table name. + */ + def recacheTableOrView( + spark: SparkSession, + name: Seq[String], + includeTimeTravel: Boolean = true): Unit = { + def shouldInvalidate(entry: CachedData): Boolean = { + entry.plan.exists(isMatchedTableOrView(_, name, spark.sessionState.conf, includeTimeTravel)) + } + recacheByCondition(spark, shouldInvalidate) + } + /** * Re-caches all the cache entries that satisfies the given `condition`. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 9a86357ca0b76..e1ff1ae730942 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{COUNT, DATABASE_NAME, ERROR, TABLE_NAME, TIME} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -35,9 +36,11 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{QueryExecution, RemoveShuffleFiles} import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex} +import org.apache.spark.sql.execution.datasources.v2.ExtractV2CatalogAndIdentifier import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.sql.types._ @@ -466,16 +469,55 @@ object CommandUtils extends Logging { } def uncacheTableOrView(sparkSession: SparkSession, ident: ResolvedIdentifier): Unit = { - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper - uncacheTableOrView(sparkSession, ident.identifier.toQualifiedNameParts(ident.catalog)) + val nameParts = ident.identifier.toQualifiedNameParts(ident.catalog) + uncacheTableOrView(sparkSession, nameParts, cascade = true) } def uncacheTableOrView(sparkSession: SparkSession, ident: TableIdentifier): Unit = { - uncacheTableOrView(sparkSession, ident.nameParts) + uncacheTableOrView(sparkSession, ident.nameParts, cascade = true) } - private def uncacheTableOrView(sparkSession: SparkSession, name: Seq[String]): Unit = { - sparkSession.sharedState.cacheManager.uncacheTableOrView(sparkSession, name, cascade = true) + // uncaches plans that reference the provided table/view by plan + // if the passed relation is a DSv2 relation without time travel, + // this method invalidates all cache entries for the given table by name (including time travel) + def uncacheTableOrView( + sparkSession: SparkSession, + relation: LogicalPlan, + cascade: Boolean): Unit = { + EliminateSubqueryAliases(relation) match { + case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if r.timeTravelSpec.isEmpty => + val nameParts = ident.toQualifiedNameParts(catalog) + uncacheTableOrView(sparkSession, nameParts, cascade) + case _ => + uncacheQuery(sparkSession, relation, cascade) + } + } + + private def uncacheTableOrView( + sparkSession: SparkSession, + name: Seq[String], + cascade: Boolean): Unit = { + sparkSession.sharedState.cacheManager.uncacheTableOrView(sparkSession, name, cascade) + } + + 
private def uncacheQuery( + sparkSession: SparkSession, + plan: LogicalPlan, + cascade: Boolean): Unit = { + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, plan, cascade) + } + + // recaches all plans that reference the provided table/view by plan + // if the passed relation is a DSv2 relation without time travel, + // this method recaches all cache entries for the given table by name (including time travel) + def recacheTableOrView(sparkSession: SparkSession, relation: LogicalPlan): Unit = { + EliminateSubqueryAliases(relation) match { + case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if r.timeTravelSpec.isEmpty => + val nameParts = ident.toQualifiedNameParts(catalog) + sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) + case _ => + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) + } } def calculateRowCountsPerPartition( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala index a28b40dc7cbf5..730f26ea7f7a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.classic.Dataset import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.execution.command.{CreateViewCommand, DropTempViewCommand} +import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -127,7 +128,7 @@ case class UncacheTableExec( relation: LogicalPlan, cascade: Boolean) extends LeafV2CommandExec { override def run(): Seq[InternalRow] = { - session.sharedState.cacheManager.uncacheQuery(session, relation, cascade) + CommandUtils.uncacheTableOrView(session, relation, cascade) Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 98ea63862ac24..12b8c75adaa74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, ResolveTableConstraints, V2ExpressionBuilder} import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable} +import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, TableCapability, TableCatalog, TruncatableTable} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} @@ -59,23 +59,32 @@ 
class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat import DataSourceV2Implicits._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + private def cacheManager = session.sharedState.cacheManager + private def hadoopConf = session.sessionState.newHadoopConf() - private def refreshCache(r: DataSourceV2Relation)(): Unit = { - session.sharedState.cacheManager.recacheByPlan(session, r) + // recaches all cache entries without time travel for the given table + // after a write operation that moves the state of the table forward (e.g. append, overwrite) + // this method accounts for V2 tables loaded via TableProvider (no catalog/identifier) + private def refreshCache(r: DataSourceV2Relation)(): Unit = r match { + case ExtractV2CatalogAndIdentifier(catalog, ident) => + val nameParts = ident.toQualifiedNameParts(catalog) + cacheManager.recacheTableOrView(session, nameParts, includeTimeTravel = false) + case _ => + cacheManager.recacheByPlan(session, r) } - private def recacheTable(r: ResolvedTable)(): Unit = { - val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), Some(r.identifier)) - session.sharedState.cacheManager.recacheByPlan(session, v2Relation) + private def recacheTable(r: ResolvedTable, includeTimeTravel: Boolean)(): Unit = { + val nameParts = r.identifier.toQualifiedNameParts(r.catalog) + cacheManager.recacheTableOrView(session, nameParts, includeTimeTravel) } // Invalidates the cache associated with the given table. If the invalidated cache matches the // given table, the cache's storage level is returned. private def invalidateTableCache(r: ResolvedTable)(): Option[StorageLevel] = { val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), Some(r.identifier)) - val cache = session.sharedState.cacheManager.lookupCachedData(session, v2Relation) - session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true) + val cache = cacheManager.lookupCachedData(session, v2Relation) + invalidateCache(r.catalog, r.identifier) if (cache.isDefined) { val cacheLevel = cache.get.cachedRepresentation.cacheBuilder.storageLevel Some(cacheLevel) @@ -84,9 +93,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } } - private def invalidateCache(catalog: TableCatalog, table: Table, ident: Identifier): Unit = { - val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) - session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true) + private def invalidateCache(catalog: TableCatalog, ident: Identifier): Unit = { + val nameParts = ident.toQualifiedNameParts(catalog) + cacheManager.uncacheTableOrView(session, nameParts, cascade = true) } private def makeQualifiedDBObjectPath(location: String): String = { @@ -216,7 +225,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } case RefreshTable(r: ResolvedTable) => - RefreshTableExec(r.catalog, r.identifier, recacheTable(r)) :: Nil + RefreshTableExec(r.catalog, r.identifier, recacheTable(r, includeTimeTravel = true)) :: Nil case c @ ReplaceTable( ResolvedIdentifier(catalog, ident), columns, parts, tableSpec: TableSpec, orCreate) => @@ -449,7 +458,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat table, parts.asResolvedPartitionSpecs, ignoreIfExists, - recacheTable(r)) :: Nil + recacheTable(r, includeTimeTravel = false)) :: Nil case DropPartitions( r @ ResolvedTable(_, _, table: SupportsPartitionManagement, _), @@ -461,7 +470,7 
@@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat parts.asResolvedPartitionSpecs, ignoreIfNotExists, purge, - recacheTable(r)) :: Nil + recacheTable(r, includeTimeTravel = false)) :: Nil case RenamePartitions( r @ ResolvedTable(_, _, table: SupportsPartitionManagement, _), from, to) => @@ -469,7 +478,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat table, Seq(from).asResolvedPartitionSpecs.head, Seq(to).asResolvedPartitionSpecs.head, - recacheTable(r)) :: Nil + recacheTable(r, includeTimeTravel = false)) :: Nil case RecoverPartitions(_: ResolvedTable) => throw QueryCompilationErrors.alterTableRecoverPartitionsNotSupportedForV2TablesError() @@ -489,13 +498,13 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case TruncateTable(r: ResolvedTable) => TruncateTableExec( r.table.asTruncatable, - recacheTable(r)) :: Nil + recacheTable(r, includeTimeTravel = false)) :: Nil case TruncatePartition(r: ResolvedTable, part) => TruncatePartitionExec( r.table.asPartitionable, Seq(part).asResolvedPartitionSpecs.head, - recacheTable(r)) :: Nil + recacheTable(r, includeTimeTravel = false)) :: Nil case ShowColumns(resolvedTable: ResolvedTable, ns, output) => ns match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index a819f836a1dd6..7ce95ced0d242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.TableSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, TableCatalog, TableInfo} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.metric.SQLMetric @@ -36,14 +36,13 @@ case class ReplaceTableExec( partitioning: Seq[Transform], tableSpec: TableSpec, orCreate: Boolean, - invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec { + invalidateCache: (TableCatalog, Identifier) => Unit) extends LeafV2CommandExec { val tableProperties = CatalogV2Util.convertTableProperties(tableSpec) override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(ident)) { - val table = catalog.loadTable(ident) - invalidateCache(catalog, table, ident) + invalidateCache(catalog, ident) catalog.dropTable(ident) } else if (!orCreate) { throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) @@ -68,7 +67,7 @@ case class AtomicReplaceTableExec( partitioning: Seq[Transform], tableSpec: TableSpec, orCreate: Boolean, - invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec { + invalidateCache: (TableCatalog, Identifier) => Unit) extends LeafV2CommandExec { val tableProperties = CatalogV2Util.convertTableProperties(tableSpec) @@ -77,8 +76,7 @@ case class AtomicReplaceTableExec( override protected def 
run(): Seq[InternalRow] = { if (catalog.tableExists(identifier)) { - val table = catalog.loadTable(identifier) - invalidateCache(catalog, table, identifier) + invalidateCache(catalog, identifier) } val staged = if (orCreate) { val tableInfo = new TableInfo.Builder() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 2a3a3441accc8..7a2795b729f5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -153,7 +153,7 @@ case class ReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Table, Identifier) => Unit) + invalidateCache: (TableCatalog, Identifier) => Unit) extends V2CreateTableAsSelectBaseExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -168,8 +168,7 @@ case class ReplaceTableAsSelectExec( // 2. Writing to the new table fails, // 3. The table returned by catalog.createTable doesn't support writing. if (catalog.tableExists(ident)) { - val table = catalog.loadTable(ident) - invalidateCache(catalog, table, ident) + invalidateCache(catalog, ident) catalog.dropTable(ident) } else if (!orCreate) { throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) @@ -205,7 +204,7 @@ case class AtomicReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Table, Identifier) => Unit) + invalidateCache: (TableCatalog, Identifier) => Unit) extends V2CreateTableAsSelectBaseExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -216,8 +215,7 @@ case class AtomicReplaceTableAsSelectExec( override protected def run(): Seq[InternalRow] = { val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema) if (catalog.tableExists(ident)) { - val table = catalog.loadTable(ident) - invalidateCache(catalog, table, ident) + invalidateCache(catalog, ident) } val staged = if (orCreate) { val tableInfo = new TableInfo.Builder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 178e1aead43a5..191587888ab81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -31,14 +31,21 @@ import org.apache.spark.executor.DataReadMethod._ import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.AsOfVersion import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join, JoinStrategyHint, SHUFFLE_HASH} import org.apache.spark.sql.catalyst.util.DateTimeConstants +import org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog +import org.apache.spark.sql.connector.catalog.CatalogPlugin +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper +import org.apache.spark.sql.connector.catalog.Identifier import 
org.apache.spark.sql.connector.catalog.InMemoryCatalog import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan, SparkPlanInfo} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation} import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.command.CommandUtils +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate import org.apache.spark.sql.functions._ @@ -1894,4 +1901,607 @@ class CachedTableSuite extends QueryTest with SQLTestUtils assertNotCached(sql("SELECT * FROM v")) } } + + test("cache DSv2 table with time travel") { + val t = "testcat.tbl" + val version = "v1" + withTable(t, "cached_tt") { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin current version + pinTable("testcat", Identifier.of(Array(), "tbl"), version) + + // cache pinned version + sql(s"CACHE TABLE cached_tt AS SELECT * FROM $t VERSION AS OF '$version'") + assertCached(sql("SELECT * FROM cached_tt")) + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version'")) + checkAnswer(sql("SELECT * FROM cached_tt"), Seq(Row(1, "a"), Row(2, "b"))) + + // add more data to base table + sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')") + + // verify lookupCachedData finds time travel cache + val timeTravelDF = sql(s"SELECT * FROM $t VERSION AS OF '$version'") + assert(cacheManager.lookupCachedData(timeTravelDF).isDefined) + + // verify base table is not cached + assertNotCached(sql(s"SELECT * FROM $t")) + assert(!spark.catalog.isCached(t)) + + // verify lookupCachedData does NOT match base table with time travel cache + val baseDF = sql(s"SELECT * FROM $t") + assert(cacheManager.lookupCachedData(baseDF).isEmpty) + } + } + + test("uncache DSv2 table by name to invalidate base and time travel plans") { + val t = "testcat.tbl" + val version = "v1" + withTable(t, "cached_tt") { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin current version + pinTable("testcat", Identifier.of(Array(), "tbl"), version) + + // insert more data to base table + sql(s"INSERT INTO $t VALUES (3, 'c'), (2, 'b')") + + // cache base table + sql(s"CACHE TABLE $t") + assertCached(sql(s"SELECT * FROM $t")) + assert(spark.catalog.isCached(t)) + + // cache pinned version + sql(s"CACHE TABLE cached_tt AS SELECT * FROM $t VERSION AS OF '$version'") + assertCached(sql("SELECT * FROM cached_tt")) + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version'")) + + // verify lookupCachedData finds separate entries for base and time travel plans + val baseDF = sql(s"SELECT * FROM $t") + val timeTravelDF = sql(s"SELECT * FROM $t VERSION AS OF '$version'") + assert(cacheManager.lookupCachedData(baseDF).isDefined) + assert(cacheManager.lookupCachedData(timeTravelDF).isDefined) + assert(cacheManager.lookupCachedData(baseDF) != cacheManager.lookupCachedData(timeTravelDF)) + + // uncaching base table by name should affect ALL time-traveled cache entries + spark.catalog.uncacheTable(t) + assertNotCached(sql(s"SELECT * FROM $t")) + assertNotCached(sql("SELECT * FROM cached_tt")) + + // verify lookupCachedData returns None after uncaching + assert(cacheManager.lookupCachedData(baseDF).isEmpty) + 
assert(cacheManager.lookupCachedData(timeTravelDF).isEmpty) + } + } + + test("uncache DSv2 table with time travel by plan") { + val t = "testcat.tbl" + val version = "v1" + withTable(t) { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin current version + pinTable("testcat", Identifier.of(Array(), "tbl"), version) + + // insert more data to base table + sql(s"INSERT INTO $t VALUES (3, 'c'), (2, 'b')") + + // cache pinned version + val timeTravelDF = sql(s"SELECT * FROM $t VERSION AS OF '$version'") + timeTravelDF.cache() + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version'")) + + // verify base table is not affected + assertNotCached(sql(s"SELECT * FROM $t")) + + // verify lookupCachedData finds the cache before uncaching + assert(cacheManager.lookupCachedData(timeTravelDF).isDefined) + + // uncache pinned version by plan + cacheManager.uncacheQuery(timeTravelDF, cascade = false) + assertNotCached(sql(s"SELECT * FROM $t VERSION AS OF '$version'")) + + // verify lookupCachedData returns None after uncaching + assert(cacheManager.lookupCachedData(timeTravelDF).isEmpty) + } + } + + test("uncache DSv2 table by plan should not affect time travel") { + val t = "testcat.tbl" + val version = "v1" + withTable(t) { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin current version + pinTable("testcat", Identifier.of(Array(), "tbl"), version) + + // insert more data to base table + sql(s"INSERT INTO $t VALUES (3, 'c'), (2, 'b')") + + // cache base table + val baseDF = sql(s"SELECT * FROM $t") + baseDF.cache() + assertCached(sql(s"SELECT * FROM $t")) + + // cache pinned version + val timeTravelDF = sql(s"SELECT * FROM $t VERSION AS OF '$version'") + timeTravelDF.cache() + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version'")) + + // uncache base table by plan + baseDF.unpersist(blocking = true) + + // verify only base table plan is affected + assertNotCached(baseDF) + assertNotCached(sql(s"SELECT * FROM $t")) + assertCached(timeTravelDF) + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version'")) + } + } + + test("look up DSv2 table by relation with multiple time travel versions") { + val t = "testcat.tbl" + val ident = Identifier.of(Array(), "tbl") + val version1 = "v1" + val version2 = "v2" + withTable(t, "cached_tt1", "cached_tt2") { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin v1 + pinTable("testcat", ident, version1) + + sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')") + + // pin v2 + pinTable("testcat", ident, version2) + + sql(s"INSERT INTO $t VALUES (5, 'e'), (6, 'f')") + + // cache base + sql(s"CACHE TABLE $t") + assertCached(sql(s"SELECT * FROM $t")) + checkAnswer( + sql(s"SELECT * FROM $t"), + Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"), Row(4, "d"), Row(5, "e"), Row(6, "f"))) + + // cache v1 + sql(s"CACHE TABLE cached_tt1 AS SELECT * FROM $t VERSION AS OF '$version1'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'")) + checkAnswer( + sql(s"SELECT * FROM $t VERSION AS OF '$version1'"), + Seq(Row(1, "a"), Row(2, "b"))) + + // cache v2 + sql(s"CACHE TABLE cached_tt2 AS SELECT * FROM $t VERSION AS OF '$version2'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'")) + checkAnswer( + sql(s"SELECT * FROM $t VERSION AS OF '$version2'"), + Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"), Row(4, "d"))) + + // verify lookupCachedData finds 
distinct entries for each version + val baseDF = sql(s"SELECT * FROM $t") + assert(cacheManager.lookupCachedData(baseDF).isDefined) + val v1DF = sql(s"SELECT * FROM $t VERSION AS OF '$version1'") + assert(cacheManager.lookupCachedData(v1DF).isDefined) + val v2DF = sql(s"SELECT * FROM $t VERSION AS OF '$version2'") + assert(cacheManager.lookupCachedData(v2DF).isDefined) + + // look up cache using DataSourceV2Relation directly + val cat = catalog("testcat").asTableCatalog + val baseTable = cat.loadTable(ident) + val baseRelation = DataSourceV2Relation.create(baseTable, Some(cat), Some(ident)) + assert(cacheManager.lookupCachedData(spark, baseRelation).isDefined) + val v1Table = cat.loadTable(ident, version1) + val v1Relation = baseRelation.copy( + table = v1Table, + timeTravelSpec = Some(AsOfVersion(version1))) + assert(cacheManager.lookupCachedData(spark, v1Relation).isDefined) + val v2Table = cat.loadTable(ident, version2) + val v2Relation = baseRelation.copy( + table = v2Table, + timeTravelSpec = Some(AsOfVersion(version2))) + assert(cacheManager.lookupCachedData(spark, v2Relation).isDefined) + + // uncache using DataSourceV2Relation directly + CommandUtils.uncacheTableOrView(spark, baseRelation, cascade = true) + assert(cacheManager.lookupCachedData(spark, baseRelation).isEmpty) + assert(cacheManager.lookupCachedData(spark, v1Relation).isEmpty) + assert(cacheManager.lookupCachedData(spark, v2Relation).isEmpty) + + // verify queries don't use cache + assertNotCached(sql(s"SELECT * FROM $t")) + assertNotCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'")) + assertNotCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'")) + } + } + + test("uncache DSv2 table using SQL") { + val t = "testcat.tbl" + val ident = Identifier.of(Array(), "tbl") + val version1 = "v1" + val version2 = "v2" + withTable(t, "cached_tt1", "cached_tt2") { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin v1 + pinTable("testcat", ident, version1) + + sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')") + + // pin v2 + pinTable("testcat", ident, version2) + + sql(s"INSERT INTO $t VALUES (5, 'e'), (6, 'f')") + + // cache base and both pinned versions + sql(s"CACHE TABLE $t") + assertCached(sql(s"SELECT * FROM $t")) + sql(s"CACHE TABLE cached_tt1 AS SELECT * FROM $t VERSION AS OF '$version1'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'")) + sql(s"CACHE TABLE cached_tt2 AS SELECT * FROM $t VERSION AS OF '$version2'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'")) + + // uncache all plans using SQL + sql(s"UNCACHE TABLE $t") + + // verify queries don't use cache + assertNotCached(sql(s"SELECT * FROM $t")) + assertNotCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'")) + assertNotCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'")) + } + } + + test("uncache DSv2 table using uncacheTableOrView") { + val t = "testcat.tbl" + val ident = Identifier.of(Array(), "tbl") + val version1 = "v1" + val version2 = "v2" + withTable(t, "cached_tt1", "cached_tt2") { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin v1 + pinTable("testcat", ident, version1) + + sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')") + + // pin v2 + pinTable("testcat", ident, version2) + + sql(s"INSERT INTO $t VALUES (5, 'e'), (6, 'f')") + + // cache base and both pinned versions + sql(s"CACHE TABLE $t") + assertCached(sql(s"SELECT * FROM $t")) + sql(s"CACHE TABLE 
cached_tt1 AS SELECT * FROM $t VERSION AS OF '$version1'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'")) + sql(s"CACHE TABLE cached_tt2 AS SELECT * FROM $t VERSION AS OF '$version2'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'")) + + // uncache all plans using uncacheTableOrView + cacheManager.uncacheTableOrView(spark, Seq("testcat", "tbl"), cascade = true) + + // verify queries don't use cache + assertNotCached(sql(s"SELECT * FROM $t")) + assertNotCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'")) + assertNotCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'")) + } + } + + test("REFRESH TABLE refreshes time travel plans correctly") { + val t = "testcat.tbl" + val ident = Identifier.of(Array(), "tbl") + val version1 = "v1" + val version2 = "v2" + withTable(t, "cached_tt1", "cached_tt2") { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin v1 + pinTable("testcat", ident, version1) + + sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')") + + // pin v2 + pinTable("testcat", ident, version2) + + sql(s"INSERT INTO $t VALUES (5, 'e'), (6, 'f')") + + // cache base and both versions + sql(s"CACHE TABLE $t") + assertCached(sql(s"SELECT * FROM $t")) + sql(s"CACHE TABLE cached_tt1 AS SELECT * FROM $t VERSION AS OF '$version1'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'")) + sql(s"CACHE TABLE cached_tt2 AS SELECT * FROM $t VERSION AS OF '$version2'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'")) + + // must have 3 cache entries + assert(cacheManager.numCachedEntries == 3) + checkCacheLoading(sql(s"SELECT * FROM $t"), isLoaded = true) + checkCacheLoading(sql(s"SELECT * FROM cached_tt1"), isLoaded = true) + checkCacheLoading(sql(s"SELECT * FROM cached_tt2"), isLoaded = true) + + // refresh table by name to invalidate all plans + sql(s"REFRESH TABLE $t") + + // all entries must be refreshed + checkCacheLoading(sql(s"SELECT * FROM $t"), isLoaded = false) + checkCacheLoading(sql(s"SELECT * FROM cached_tt1"), isLoaded = false) + checkCacheLoading(sql(s"SELECT * FROM cached_tt2"), isLoaded = false) + } + } + + test("recacheByTableName with time travel plans") { + val t = "testcat.tbl" + val ident = Identifier.of(Array(), "tbl") + val version1 = "v1" + val version2 = "v2" + withTable(t, "cached_tt1", "cached_tt2") { + sql(s"CREATE TABLE $t (id int, data string) USING foo") + sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')") + + // pin v1 + pinTable("testcat", ident, version1) + + sql(s"INSERT INTO $t VALUES (5, 'e'), (6, 'f')") + + // pin v2 + pinTable("testcat", ident, version2) + + sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')") + + // cache base and both versions + sql(s"CACHE TABLE $t") + assertCached(sql(s"SELECT * FROM $t")) + sql(s"CACHE TABLE cached_tt1 AS SELECT * FROM $t VERSION AS OF '$version1'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'")) + sql(s"CACHE TABLE cached_tt2 AS SELECT * FROM $t VERSION AS OF '$version2'") + assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'")) + + // must have 3 cache entries + assert(cacheManager.numCachedEntries == 3) + checkCacheLoading(sql(s"SELECT * FROM $t"), isLoaded = true) + checkCacheLoading(sql(s"SELECT * FROM cached_tt1"), isLoaded = true) + checkCacheLoading(sql(s"SELECT * FROM cached_tt2"), isLoaded = true) + + // refresh base, keep pinned versions cached + cacheManager.recacheTableOrView(spark, Seq("testcat", "tbl"), includeTimeTravel = false) + + // time 
travel entries must NOT be refreshed
+      checkCacheLoading(sql(s"SELECT * FROM $t"), isLoaded = false)
+      checkCacheLoading(sql(s"SELECT * FROM cached_tt1"), isLoaded = true)
+      checkCacheLoading(sql(s"SELECT * FROM cached_tt2"), isLoaded = true)
+
+      // refresh all
+      cacheManager.recacheTableOrView(spark, Seq("testcat", "tbl"))
+
+      // all plans must be refreshed
+      checkCacheLoading(sql(s"SELECT * FROM $t"), isLoaded = false)
+      checkCacheLoading(sql(s"SELECT * FROM cached_tt1"), isLoaded = false)
+      checkCacheLoading(sql(s"SELECT * FROM cached_tt2"), isLoaded = false)
+    }
+  }
+
+  private def checkCacheLoading(ds: Dataset[_], isLoaded: Boolean): Unit = {
+    cacheManager.lookupCachedData(ds) match {
+      case Some(entry) =>
+        assert(entry.cachedRepresentation.cacheBuilder.isCachedColumnBuffersLoaded == isLoaded)
+      case _ =>
+        fail("dataset is not cached")
+    }
+  }
+
+  test("RENAME TABLE manages cache with time travel plans correctly") {
+    val t = "testcat.tbl"
+    val tRenamed = "testcat.tbl_renamed"
+    val ident = Identifier.of(Array(), "tbl")
+    val version1 = "v1"
+    val version2 = "v2"
+    withTable(t, tRenamed, "cached_tt1", "cached_tt2") {
+      sql(s"CREATE TABLE $t (id int, data string) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')")
+
+      // pin v1
+      pinTable("testcat", ident, version1)
+
+      sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')")
+
+      // pin v2
+      pinTable("testcat", ident, version2)
+
+      sql(s"INSERT INTO $t VALUES (4, 'e'), (5, 'f')")
+
+      // cache base and both versions
+      sql(s"CACHE TABLE $t")
+      assertCached(sql(s"SELECT * FROM $t"))
+      sql(s"CACHE TABLE cached_tt1 AS SELECT * FROM $t VERSION AS OF '$version1'")
+      assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'"))
+      sql(s"CACHE TABLE cached_tt2 AS SELECT * FROM $t VERSION AS OF '$version2'")
+      assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'"))
+
+      // must have 3 cache entries
+      assert(cacheManager.numCachedEntries == 3)
+
+      // rename base table
+      sql(s"ALTER TABLE $t RENAME TO tbl_renamed")
+
+      // assert cache was cleared and renamed table (current version) was cached again
+      assert(cacheManager.numCachedEntries == 1)
+      assertCached(sql(s"SELECT * FROM $tRenamed"))
+    }
+  }
+
+  test("DROP TABLE invalidates time travel cache entries") {
+    checkTimeTravelPlanInvalidation { t =>
+      sql(s"DROP TABLE $t")
+    }
+  }
+
+  test("REPLACE TABLE invalidates time travel cache entries") {
+    checkTimeTravelPlanInvalidation { t =>
+      sql(s"REPLACE TABLE $t (a INT COMMENT 'test', b STRING NOT NULL)")
+    }
+  }
+
+  private def checkTimeTravelPlanInvalidation(action: String => Unit): Unit = {
+    val t = "testcat.tbl"
+    val ident = Identifier.of(Array(), "tbl")
+    val version1 = "v1"
+    val version2 = "v2"
+    withTable(t, "cached_tt1", "cached_tt2") {
+      sql(s"CREATE TABLE $t (id int, data string) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')")
+
+      // pin v1
+      pinTable("testcat", ident, version1)
+
+      sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')")
+
+      // pin v2
+      pinTable("testcat", ident, version2)
+
+      sql(s"INSERT INTO $t VALUES (5, 'e'), (6, 'f')")
+
+      sql(s"CACHE TABLE $t")
+      assertCached(sql(s"SELECT * FROM $t"))
+      sql(s"CACHE TABLE cached_tt1 AS SELECT * FROM $t VERSION AS OF '$version1'")
+      assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version1'"))
+      sql(s"CACHE TABLE cached_tt2 AS SELECT * FROM $t VERSION AS OF '$version2'")
+      assertCached(sql(s"SELECT * FROM $t VERSION AS OF '$version2'"))
+
+      assert(cacheManager.numCachedEntries == 3)
+
+      action(t)
+
+      assert(cacheManager.isEmpty)
+    }
+  }
+
+  test("recache views with logical plans on top of DSv2 tables on changes to table") {
+    val t = "testcat.tbl"
+    withTable(t, "v") {
+      sql(s"CREATE TABLE $t (id int, data string) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')")
+
+      // create and cache view
+      spark.table(t).select("id").createOrReplaceTempView("v")
+      sql("SELECT * FROM v").cache()
+
+      // verify view is cached
+      assertCached(sql("SELECT * FROM v"))
+      checkAnswer(sql("SELECT * FROM v"), Seq(Row(1), Row(2)))
+
+      // insert data into base table
+      sql(s"INSERT INTO $t VALUES (3, 'c'), (4, 'd')")
+
+      // verify cache was refreshed and will pick up new data
+      checkCacheLoading(sql(s"SELECT * FROM v"), isLoaded = false)
+
+      // verify view is recached correctly
+      assertCached(sql("SELECT * FROM v"))
+      checkAnswer(
+        sql("SELECT * FROM v"),
+        Seq(Row(1), Row(2), Row(3), Row(4)))
+    }
+  }
+
+  test("uncache DSv2 table using SQL triggers uncaching of views with logical plans") {
+    val t = "testcat.tbl"
+    withTable(t, "v") {
+      sql(s"CREATE TABLE $t (id int, data string) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')")
+
+      // cache table
+      sql(s"CACHE TABLE $t")
+      assertCached(sql(s"SELECT * FROM $t"))
+      checkAnswer(sql(s"SELECT * FROM $t"), Seq(Row(1, "a"), Row(2, "b")))
+
+      // create and cache view
+      spark.table(t).select("id").createOrReplaceTempView("v")
+      sql("SELECT * FROM v").cache()
+      assertCached(sql("SELECT * FROM v"))
+      checkAnswer(sql("SELECT * FROM v"), Seq(Row(1), Row(2)))
+
+      // uncache table must invalidate view cache (cascading)
+      sql(s"UNCACHE TABLE $t")
+
+      // verify view is not cached anymore
+      assertNotCached(sql("SELECT * FROM v"))
+    }
+  }
+
+  test("uncache DSv2 table must qualify table names") {
+    val t = "testcat.ns1.ns2.tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id int, data string) USING foo")
+      sql(s"INSERT INTO $t VALUES (1, 'a'), (2, 'b')")
+
+      sql(s"CACHE TABLE $t")
+      assertCached(sql(s"SELECT * FROM $t"))
+      sql("USE testcat")
+      sql("USE NAMESPACE ns1.ns2")
+      sql("UNCACHE TABLE tbl")
+      assertNotCached(sql(s"SELECT * FROM $t"))
+
+      sql(s"CACHE TABLE $t")
+      assertCached(sql(s"SELECT * FROM $t"))
+      sql("USE testcat.ns1.ns2")
+      spark.catalog.uncacheTable("tbl")
+      assertNotCached(sql(s"SELECT * FROM $t"))
+    }
+  }
+
+  test("uncache persistent table via catalog API") {
+    withTable("tbl1") {
+      sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
+      sql("INSERT INTO tbl1 VALUES ('Alice', 30), ('Bob', 25)")
+
+      // cache the table
+      spark.catalog.cacheTable("tbl1")
+      assert(spark.catalog.isCached("tbl1"))
+      assertCached(spark.table("tbl1"))
+
+      // uncache the table using catalog API
+      spark.catalog.uncacheTable("tbl1")
+
+      // verify it's actually uncached
+      assert(!spark.catalog.isCached("tbl1"))
+      assertNotCached(spark.table("tbl1"))
+    }
+  }
+
+  test("uncache non-existent table") {
+    checkError(
+      exception = intercept[AnalysisException] { spark.catalog.uncacheTable("non_existent") },
+      condition = "TABLE_OR_VIEW_NOT_FOUND",
+      parameters = Map("relationName" -> "`non_existent`"))
+
+    checkError(
+      exception = intercept[AnalysisException] { sql("UNCACHE TABLE non_existent") },
+      condition = "TABLE_OR_VIEW_NOT_FOUND",
+      parameters = Map("relationName" -> "`non_existent`"),
+      context = ExpectedContext("non_existent", 14, 25))
+  }
+
+  private def cacheManager = spark.sharedState.cacheManager
+
+  private def pinTable(
+      catalogName: String,
+      ident: Identifier,
+      version: String): Unit = {
+    catalog(catalogName) match {
+      case inMemory: BasicInMemoryTableCatalog => inMemory.pinTable(ident, version)
+      case _ => fail(s"can't pin $ident in $catalogName")
+    }
+  }
+
+  private def catalog(name: String): CatalogPlugin = {
+    spark.sessionState.catalogManager.catalog(name)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 3adefa6b4d535..21538ec8e44ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -24,6 +24,7 @@ import java.util.Locale
 import scala.concurrent.duration.MICROSECONDS
 import scala.jdk.CollectionConverters._
+import scala.util.matching.Regex
 
 import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException, SparkUnsupportedOperationException}
 import org.apache.spark.sql._
@@ -84,12 +85,43 @@ abstract class DataSourceV2SQLSuite
     val t = "testcat.tbl"
     withTable(t) {
       spark.sql(s"CREATE TABLE $t (id int, data string)")
-      val explain = spark.sql(s"EXPLAIN EXTENDED SELECT * FROM $t").head().getString(0)
-      val relationPattern = raw".*RelationV2\[[^\]]*]\s+$t\s*$$".r
-      val relations = explain.split("\n").filter(_.contains("RelationV2"))
-      assert(relations.nonEmpty && relations.forall(line => relationPattern.matches(line.trim)))
+      checkExplain(
+        query = s"SELECT * FROM $t",
+        relationPattern = raw".*RelationV2\[[^]]*]\s$t$$".r)
     }
   }
+
+  test("EXPLAIN with time travel (version)") {
+    val t = "testcat.tbl"
+    val version = "snapshot1"
+    val tWithVersion = t + version
+    withTable(tWithVersion) {
+      spark.sql(s"CREATE TABLE $tWithVersion (id int, data string)")
+      val tableWithVersionPattern = raw"$t\sVERSION\sAS\sOF\s'$version'"
+      val relationPattern = raw".*RelationV2\[[^]]*]\s$tableWithVersionPattern$$".r
+      checkExplain(s"SELECT * FROM $t VERSION AS OF '$version'", relationPattern)
+    }
+  }
+
+  test("EXPLAIN with time travel (timestamp)") {
+    val t = "testcat.tbl"
+    val ts = DateTimeUtils.stringToTimestampAnsi(
+      UTF8String.fromString("2019-01-29 00:37:58"),
+      DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
+    val tWithTs = t + ts
+    withTable(tWithTs) {
+      spark.sql(s"CREATE TABLE $tWithTs (id int, data string)")
+      val tableWithTsPattern = raw"$t\s+TIMESTAMP\s+AS\s+OF\s+$ts"
+      val relationPattern = raw".*RelationV2\[[^]]*]\s$tableWithTsPattern$$".r
+      checkExplain(s"SELECT * FROM $t TIMESTAMP AS OF '2019-01-29 00:37:58'", relationPattern)
+    }
+  }
+
+  private def checkExplain(query: String, relationPattern: Regex): Unit = {
+    val explain = spark.sql(s"EXPLAIN EXTENDED $query").head().getString(0)
+    val relations = explain.split("\n").filter(_.contains("RelationV2"))
+    assert(relations.nonEmpty && relations.forall(line => relationPattern.matches(line.trim)))
+  }
 }
 
 class DataSourceV2SQLSuiteV1Filter
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index ecc293a5acc2a..9ea8b9130ba8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -27,7 +27,7 @@ import org.mockito.invocation.InvocationOnMock
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.{AnalysisException, SaveMode}
 import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedFieldPosition, ResolvedIdentifier, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedFieldPosition, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, AsOfVersion, EmptyFunctionRegistry, NoSuchTableException, RelationResolution, ResolvedFieldName, ResolvedFieldPosition, ResolvedIdentifier, ResolvedTable, ResolveSessionCatalog, TimeTravelSpec, UnresolvedAttribute, UnresolvedFieldPosition, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog, TempVariableManager}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
@@ -163,6 +163,15 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
         case name => throw new NoSuchTableException(Seq(name))
       }
     })
+    when(newCatalog.loadTable(any(), any[String]())).thenAnswer((invocation: InvocationOnMock) => {
+      val ident = invocation.getArguments()(0).asInstanceOf[Identifier]
+      val version = invocation.getArguments()(1).asInstanceOf[String]
+      (ident.name, version) match {
+        case ("tab", "v1") => table
+        case ("tab", _) => throw new RuntimeException("Unknown version: " + version)
+        case _ => throw new NoSuchTableException(Seq(ident.name))
+      }
+    })
     when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]()))
       .thenCallRealMethod()
     when(newCatalog.name()).thenReturn("testcat")
@@ -3182,6 +3191,98 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
     assert(cmdAnalyzed.children.isEmpty)
   }
 
+  test("relation resolution - cache key behavior with time travel") {
+    AnalysisContext.withNewAnalysisContext {
+      val ctx = AnalysisContext.get
+      assert(ctx.relationCache.isEmpty)
+
+      // create two unresolved relations without time travel
+      val unresolved1 = UnresolvedRelation(Seq("testcat", "tab"))
+      val unresolved2 = UnresolvedRelation(Seq("testcat", "tab"))
+
+      // resolve both relations
+      val resolved1 = resolve(unresolved1)
+      val resolved2 = resolve(unresolved2)
+
+      // relations without time travel should have None for timeTravelSpec
+      assert(resolved1.timeTravelSpec.isEmpty)
+      assert(resolved2.timeTravelSpec.isEmpty)
+
+      // after first resolution, cache should have 1 entry (without time travel)
+      assert(ctx.relationCache.size == 1)
+      assert(ctx.relationCache.keys.head._2.isEmpty)
+
+      // create unresolved relation with time travel spec
+      val timeTravelSpec = AsOfVersion("v1")
+      val unresolved3 = UnresolvedRelation(Seq("testcat", "tab"))
+
+      // resolve with time travel
+      val resolved3 = resolve(unresolved3, Some(timeTravelSpec))
+
+      // relation with time travel should preserve the timeTravelSpec
+      assert(resolved3.timeTravelSpec.isDefined)
+      assert(resolved3.timeTravelSpec.get == timeTravelSpec)
+
+      // after time travel resolution, cache should have 2 entries (with and without time travel)
+      assert(ctx.relationCache.size == 2)
+    }
+  }
+
+  test("relation resolution - plan ID cloning on cache hit with time travel") {
+    AnalysisContext.withNewAnalysisContext {
+      val ctx = AnalysisContext.get
+      assert(ctx.relationCache.isEmpty)
+
+      val timeTravelSpec = AsOfVersion("v1")
+
+      // create first unresolved relation with a plan ID
+      val unresolved1 = UnresolvedRelation(Seq("testcat", "tab"))
+      val planId1 = 12345L
+      unresolved1.setTagValue(LogicalPlan.PLAN_ID_TAG, planId1)
+
+      // resolve first relation (this should populate the cache)
+      val resolved1 = resolve(unresolved1, Some(timeTravelSpec), planId = Some(planId1))
+
+      // cache should have 1 entry now
+      assert(ctx.relationCache.size == 1)
+
+      // create second unresolved relation with a different plan ID
+      val unresolved2 = UnresolvedRelation(Seq("testcat", "tab"))
+      val planId2 = 67890L
+      unresolved2.setTagValue(LogicalPlan.PLAN_ID_TAG, planId2)
+
+      // resolve second relation (this should hit the cache)
+      val resolved2 = resolve(unresolved2, Some(timeTravelSpec), planId = Some(planId2))
+
+      // cache should still have 1 entry (cache hit)
+      assert(ctx.relationCache.size == 1)
+
+      // verify the plans are different instances (cloned)
+      assert(resolved1 ne resolved2)
+
+      // verify the underlying table, catalog, identifier, and time travel spec are equal
+      assert(resolved1.table == resolved2.table)
+      assert(resolved1.catalog == resolved2.catalog)
+      assert(resolved1.identifier == resolved2.identifier)
+      assert(resolved1.timeTravelSpec == resolved2.timeTravelSpec)
+    }
+  }
+
+  private def resolve(
+      unresolvedRelation: UnresolvedRelation,
+      timeTravelSpec: Option[TimeTravelSpec] = None,
+      planId: Option[Long] = None): DataSourceV2Relation = {
+    val rule = new RelationResolution(catalogManagerWithDefault)
+    rule.resolveRelation(unresolvedRelation, timeTravelSpec) match {
+      case Some(p @ AsDataSourceV2Relation(relation)) =>
+        assert(unresolvedRelation.getTagValue(LogicalPlan.PLAN_ID_TAG) == planId)
+        assert(p.getTagValue(LogicalPlan.PLAN_ID_TAG) == planId)
+        relation
+      case _ =>
+        fail(s"failed to resolve $unresolvedRelation as v2 table")
+    }
+  }
+
   // TODO: add tests for more commands.
 }