sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -17,28 +17,29 @@

package org.apache.spark.sql.catalyst.analysis

-import scala.util.control.NonFatal
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
+import org.apache.spark.sql.catalyst.optimizer.EvalInlineTables
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.TypeUtils.{toSQLExpr, toSQLId}
import org.apache.spark.sql.types.{StructField, StructType}

/**
- * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
+ * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[ResolvedInlineTable]].
*/
object ResolveInlineTables extends Rule[LogicalPlan]
with CastSupport with AliasHelper with EvalHelper {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
-    AlwaysProcess.fn, ruleId) {
-    case table: UnresolvedInlineTable if table.expressionsResolved =>
-      validateInputDimension(table)
-      validateInputEvaluable(table)
-      convert(table)
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.resolveOperatorsWithPruning(AlwaysProcess.fn, ruleId) {
+      case table: UnresolvedInlineTable if table.expressionsResolved =>
+        validateInputDimension(table)
+        validateInputEvaluable(table)
+        val resolvedTable = findCommonTypesAndCast(table)
+        earlyEvalIfPossible(resolvedTable)
+    }
  }

/**
@@ -74,7 +75,10 @@ object ResolveInlineTables extends Rule[LogicalPlan]
table.rows.foreach { row =>
row.foreach { e =>
// Note that nondeterministic expressions are not supported since they are not foldable.
-      if (!e.resolved || !trimAliases(prepareForEval(e)).foldable) {
+      // The only exception is CURRENT_LIKE expressions, which are replaced by a
+      // literal in later stages.
+      if ((!e.resolved && !e.containsPattern(CURRENT_LIKE))
+          || !trimAliases(prepareForEval(e)).foldable) {
e.failAnalysis(
errorClass = "INVALID_INLINE_TABLE.CANNOT_EVALUATE_EXPRESSION_IN_INLINE_TABLE",
messageParameters = Map("expr" -> toSQLExpr(e)))
@@ -84,14 +88,12 @@
}

/**
- * Convert a valid (with right shape and foldable inputs) [[UnresolvedInlineTable]]
- * into a [[LocalRelation]].
- *
 * This function attempts to coerce inputs into consistent types.
 *
 * This is package visible for unit testing.
 */
-  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
+  private[analysis] def findCommonTypesAndCast(table: UnresolvedInlineTable):
+      ResolvedInlineTable = {
// For each column, traverse all the values and find a common data type and nullability.
val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
val inputTypes = column.map(_.dataType)
@@ -105,26 +107,30 @@
val attributes = DataTypeUtils.toAttributes(StructType(fields))
assert(fields.size == table.names.size)

-    val newRows: Seq[InternalRow] = table.rows.map { row =>
-      InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
-        val targetType = fields(ci).dataType
-        try {
+    val castedRows: Seq[Seq[Expression]] = table.rows.map { row =>
[Review thread]
beliefer (Dec 21, 2023): It seems we only need the Seq[Expression] here.
Reply: it's a table (rows X columns)
beliefer: I know that. You mean the number of columns can differ for each row?
beliefer: I got it now. Thank you!
+      row.zipWithIndex.map {
+        case (e, ci) =>
+          val targetType = fields(ci).dataType
          val castedExpr = if (DataTypeUtils.sameType(e.dataType, targetType)) {
            e
          } else {
            cast(e, targetType)
          }
-          prepareForEval(castedExpr).eval()
-        } catch {
-          case NonFatal(ex) =>
-            table.failAnalysis(
-              errorClass = "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION",
-              messageParameters = Map("sqlExpr" -> toSQLExpr(e)),
-              cause = ex)
-        }
-      })
+          castedExpr
+      }
}

-    LocalRelation(attributes, newRows)
+    ResolvedInlineTable(castedRows, attributes)
}

  /**
   * This function attempts to early evaluate rows in an inline table.
   * If the rows do not rely on CURRENT_LIKE expressions (which are only replaced
   * by literals in later stages), they are evaluated and inlined as a
   * [[LocalRelation]]; otherwise the [[ResolvedInlineTable]] is kept as is.
   * This is package visible for unit testing.
   */
  private[analysis] def earlyEvalIfPossible(table: ResolvedInlineTable): LogicalPlan = {
    val earlyEvalPossible = table.rows.flatten.forall(!_.containsPattern(CURRENT_LIKE))
    if (earlyEvalPossible) EvalInlineTables(table) else table
  }
}
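To make the new two-phase flow concrete, here is a spark-shell style sketch that mirrors what ResolveInlineTablesSuite exercises later in this diff. The local lit helper is an assumption added for brevity (the suite defines its own); everything else uses the names introduced by this PR, so treat it as a sketch against this branch rather than a released API.

  import org.apache.spark.sql.catalyst.analysis.{ResolvedInlineTable, ResolveInlineTables, UnresolvedInlineTable}
  import org.apache.spark.sql.catalyst.expressions.{CurrentTimestamp, Literal}
  import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

  def lit(v: Any): Literal = Literal(v) // hypothetical shorthand, as in the suite

  // Foldable rows: a common type (bigint here) is found, rows are cast, and
  // earlyEvalIfPossible folds them straight into a LocalRelation.
  val foldable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
  assert(ResolveInlineTables(foldable).isInstanceOf[LocalRelation])

  // CURRENT_LIKE rows: validation and casting still run, but evaluation is
  // deferred, so analysis keeps a ResolvedInlineTable for EvalInlineTables.
  val withNow = UnresolvedInlineTable(Seq("c1"), Seq(Seq(CurrentTimestamp())))
  assert(ResolveInlineTables(withNow).isInstanceOf[ResolvedInlineTable])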
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -131,6 +131,21 @@ case class UnresolvedInlineTable(
lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved))
}

/**
 * A resolved inline table that holds all the expressions that were checked for
 * the right shape and common data types.
 * This is a preparation step for [[org.apache.spark.sql.catalyst.optimizer.EvalInlineTables]],
 * which will produce a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]]
 * for this inline table.
 *
 * @param rows expressions for the data rows
 * @param output list of column attributes
 */
case class ResolvedInlineTable(rows: Seq[Seq[Expression]], output: Seq[Attribute])
  extends LeafNode {
  final override val nodePatterns: Seq[TreePattern] = Seq(INLINE_TABLE_EVAL)
}

[Review thread]
beliefer (Dec 21, 2023): Shall we simplify rows: Seq[Seq[Expression]] to exprs: Seq[Expression]?
beliefer: @dbatomic After reviewing this PR again, I'm sorry for the above comment.

/**
* A table-valued function, e.g.
* {{{
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -293,7 +293,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
SpecialDatetimeValues,
-      RewriteAsOfJoin)
+      RewriteAsOfJoin,
+      EvalInlineTables
+    )

override def apply(plan: LogicalPlan): LogicalPlan = {
rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -19,14 +19,19 @@ package org.apache.spark.sql.catalyst.optimizer

import java.time.{Instant, LocalDateTime, ZoneId}

-import org.apache.spark.sql.catalyst.CurrentUserContext
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.catalyst.{CurrentUserContext, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.{CastSupport, ResolvedInlineTable}
+import org.apache.spark.sql.catalyst.analysis.ResolveInlineTables.prepareForEval
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.trees.TreePatternBits
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
+import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.types._

@@ -70,6 +75,34 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
}
}

/**
 * Computes expressions in inline tables. This rule is supposed to be called at the very end
 * of the analysis phase, given that all the expressions need to be fully resolved/replaced
 * at this point.
 */
object EvalInlineTables extends Rule[LogicalPlan] with CastSupport {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(INLINE_TABLE_EVAL)) {
      case table: ResolvedInlineTable =>
        val newRows: Seq[InternalRow] = table.rows.map { row =>
          InternalRow.fromSeq(row.map { e =>
            try {
              prepareForEval(e).eval()
            } catch {
              case NonFatal(ex) =>
                table.failAnalysis(
                  errorClass = "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION",
                  messageParameters = Map("sqlExpr" -> toSQLExpr(e)),
                  cause = ex)
            }
          })
        }

        LocalRelation(table.output, newRows)
    }
  }
}
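The placement of this rule is the crux of the change: in the FinishAnalysis batch (see the Optimizer.scala hunk above), ComputeCurrentTime runs first and pins every CURRENT_LIKE expression in the query to a single query-start literal, and only then does EvalInlineTables materialize the rows. A suite-style sketch of that interaction, using only names from this PR:

  import org.apache.spark.sql.catalyst.analysis.{ResolveInlineTables, UnresolvedInlineTable}
  import org.apache.spark.sql.catalyst.expressions.CurrentTimestamp
  import org.apache.spark.sql.catalyst.optimizer.{ComputeCurrentTime, EvalInlineTables}
  import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

  val analyzed = ResolveInlineTables(UnresolvedInlineTable(
    Seq("ct"), Seq(Seq(CurrentTimestamp()), Seq(CurrentTimestamp()))))
  // ComputeCurrentTime replaces both timestamps with the same literal, so the
  // folded rows must agree; this is what the new suite test below asserts.
  val LocalRelation(_, data, _) = EvalInlineTables(ComputeCurrentTime(analyzed))
  assert(data(0).getLong(0) == data(1).getLong(0))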

/**
* Computes the current date and time to make sure we return the same result in a single query.
*/
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -168,6 +168,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps" ::
"org.apache.spark.sql.catalyst.optimizer.TransposeWindow" ::
"org.apache.spark.sql.catalyst.optimizer.EvalInlineTables" ::
"org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison" :: Nil
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePattern.scala
@@ -54,6 +54,7 @@ object TreePattern extends Enumeration {
val IF: Value = Value
val IN: Value = Value
val IN_SUBQUERY: Value = Value
val INLINE_TABLE_EVAL: Value = Value
val INSET: Value = Value
val INTERSECT: Value = Value
val INVOKE: Value = Value
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal, Rand}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CurrentTimestamp, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.optimizer.{ComputeCurrentTime, EvalInlineTables}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}

@@ -83,21 +84,39 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
assert(ResolveInlineTables(table) == table)
}

test("convert") {
test("cast and execute") {
val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
val converted = ResolveInlineTables.convert(table)
val resolved = ResolveInlineTables.findCommonTypesAndCast(table)
val converted = ResolveInlineTables.earlyEvalIfPossible(resolved).asInstanceOf[LocalRelation]

assert(converted.output.map(_.dataType) == Seq(LongType))
assert(converted.data.size == 2)
assert(converted.data(0).getLong(0) == 1L)
assert(converted.data(1).getLong(0) == 2L)
}

test("cast and execute CURRENT_LIKE expressions") {
val table = UnresolvedInlineTable(Seq("c1"), Seq(
Seq(CurrentTimestamp()), Seq(CurrentTimestamp())))
val casted = ResolveInlineTables.findCommonTypesAndCast(table)
val earlyEval = ResolveInlineTables.earlyEvalIfPossible(casted)
// Early eval should keep it in expression form.
assert(earlyEval.isInstanceOf[ResolvedInlineTable])

EvalInlineTables(ComputeCurrentTime(earlyEval)) match {
case LocalRelation(output, data, _) =>
assert(output.map(_.dataType) == Seq(TimestampType))
assert(data.size == 2)
// Make sure that both CURRENT_TIMESTAMP expressions are evaluated to the same value.
assert(data(0).getLong(0) == data(1).getLong(0))
}
}

test("convert TimeZoneAwareExpression") {
val table = UnresolvedInlineTable(Seq("c1"),
Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
val withTimeZone = ResolveTimeZone.apply(table)
-    val LocalRelation(output, data, _) = ResolveInlineTables.apply(withTimeZone)
+    val LocalRelation(output, data, _) = EvalInlineTables(ResolveInlineTables.apply(withTimeZone))
val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
.withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
assert(output.map(_.dataType) == Seq(TimestampType))
@@ -107,11 +126,11 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

test("nullability inference in convert") {
val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
-    val converted1 = ResolveInlineTables.convert(table1)
+    val converted1 = ResolveInlineTables.findCommonTypesAndCast(table1)
assert(!converted1.schema.fields(0).nullable)

val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
-    val converted2 = ResolveInlineTables.convert(table2)
+    val converted2 = ResolveInlineTables.findCommonTypesAndCast(table2)
assert(converted2.schema.fields(0).nullable)
}
}
sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
@@ -73,9 +73,7 @@
-- !query
select a from values ("one", current_timestamp) as data(a, b)
-- !query analysis
-Project [a#x]
-+- SubqueryAlias data
-   +- LocalRelation [a#x, b#x]
+[Analyzer test output redacted due to nondeterminism]


-- !query
Expand Down Expand Up @@ -246,3 +244,15 @@ select * from values (10 + try_divide(5, 0))
-- !query analysis
Project [col1#x]
+- LocalRelation [col1#x]


-- !query
select count(distinct ct) from values now(), now(), now() as data(ct)
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
select count(distinct ct) from values current_timestamp(), current_timestamp() as data(ct)
-- !query analysis
[Analyzer test output redacted due to nondeterminism]
sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/create_view.sql.out
@@ -1674,7 +1674,7 @@ select * from tt7a left join tt8a using (x), tt8a tt8ax
:- Project [a#x, b#x, c#x, d#x, e#x]
: +- SubqueryAlias v
: +- Project [col1#x AS a#x, col2#x AS b#x, col3#x AS c#x, col4#x AS d#x, col5#x AS e#x]
-: +- LocalRelation [col1#x, col2#x, col3#x, col4#x, col5#x]
+: +- ResolvedInlineTable [[now(), 2, 3, now(), 5]], [col1#x, col2#x, col3#x, col4#x, col5#x]
+- Project [cast(x#x as timestamp) AS x#x, y#x, z#x, x#x, z#x]
+- Project [x#x, y#x, z#x, x#x, z#x]
+- Join Inner
sql/core/src/test/resources/sql-tests/inputs/inline-table.sql (6 additions, 0 deletions)
@@ -60,3 +60,9 @@
select * from values (try_add(5, 0));
select * from values (try_divide(5, 0));
select * from values (10 + try_divide(5, 0));

-- now() should be kept as an unevaluated inline expression during analysis.
select count(distinct ct) from values now(), now(), now() as data(ct);

-- current_timestamp() should be kept as an unevaluated inline expression during analysis.
select count(distinct ct) from values current_timestamp(), current_timestamp() as data(ct);
[Review thread]
Reviewer: Shall we add tests that mix current_timestamp with other deterministic functions?
Reply: it's testing the correct value using count distinct.
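If that mixed coverage were added, it could look like the following sketch. This is hypothetical, not part of the PR; it assumes a SparkSession named spark on a build that contains this change.

  // Hypothetical mixed test: current_timestamp() collapses to one distinct
  // query-start value, while the deterministic column keeps both of its values.
  val row = spark.sql(
    """SELECT count(DISTINCT ts) AS ts_cnt, count(DISTINCT x) AS x_cnt
      |FROM VALUES (current_timestamp(), 1), (current_timestamp(), 2) AS data(ts, x)
      |""".stripMargin).collect().head
  assert(row.getLong(0) == 1L && row.getLong(1) == 2L)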

sql/core/src/test/resources/sql-tests/results/inline-table.sql.out (16 additions, 0 deletions)
@@ -271,3 +271,19 @@
struct<col1:double>
-- !query output
NULL


-- !query
select count(distinct ct) from values now(), now(), now() as data(ct)
-- !query schema
struct<count(DISTINCT ct):bigint>
-- !query output
1


-- !query
select count(distinct ct) from values current_timestamp(), current_timestamp() as data(ct)
-- !query schema
struct<count(DISTINCT ct):bigint>
-- !query output
1
sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -1306,4 +1306,18 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
}
}
}

test("Inline table with current time expression") {
withView("v1") {
sql("CREATE VIEW v1 (t1, t2) AS SELECT * FROM VALUES (now(), now())")
val r1 = sql("select t1, t2 from v1").collect()(0)
val ts1 = (r1.getTimestamp(0), r1.getTimestamp(1))
assert(ts1._1 == ts1._2)
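      // Sleep so the second query runs with a strictly later query-start timestamp.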
Thread.sleep(1)
val r2 = sql("select t1, t2 from v1").collect()(0)
val ts2 = (r2.getTimestamp(0), r2.getTimestamp(1))
assert(ts2._1 == ts2._2)
assert(ts1._1.getTime < ts2._1.getTime)
}
}
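For completeness, here is the user-visible behavior the whole change targets, as a minimal runnable sketch; it assumes a local build that includes this PR, and the printed output is illustrative.

  import org.apache.spark.sql.SparkSession

  object InlineTableNowDemo {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
      // Before this change, each now() in the VALUES list could be evaluated at a
      // different instant during analysis; after it, one query sees one timestamp.
      spark.sql("SELECT count(DISTINCT ct) FROM VALUES now(), now(), now() AS data(ct)").show()
      // +------------------+
      // |count(DISTINCT ct)|
      // +------------------+
      // |                 1|
      // +------------------+
      spark.stop()
    }
  }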
}