@@ -337,6 +337,7 @@ object FunctionRegistry {
expression[SparkPartitionID]("spark_partition_id"),
expression[InputFileName]("input_file_name"),
expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
expression[CurrentDatabase]("current_database"),

// grouping sets
expression[Cube]("cube"),
@@ -512,3 +512,15 @@ object XxHash64Function extends InterpretedHashFunction {
XXH64.hashUnsafeBytes(base, offset, len, seed)
}
}

/**
* Returns the current database of the SessionCatalog.
*/
@ExpressionDescription(
usage = "_FUNC_() - Returns the current database.",
extended = "> SELECT _FUNC_()")
private[sql] case class CurrentDatabase() extends LeafExpression with Unevaluable {
override def dataType: DataType = StringType
override def foldable: Boolean = true
override def nullable: Boolean = false
}
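
Since `CurrentDatabase` is marked `Unevaluable`, it is never executed directly; the `GetCurrentDatabase` rule added to the optimizer further down in this diff folds it into a string literal first. A minimal usage sketch, assuming an existing `SQLContext` named `sqlContext` (the returned value depends on the session's current database):

```scala
// With the FunctionRegistry entry above, current_database() parses into the
// CurrentDatabase expression and gets folded to a literal during optimization.
val db = sqlContext.sql("SELECT current_database()").head().getString(0)
println(db)  // typically "default" for a fresh session
```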
@@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.annotation.tailrec
import scala.collection.immutable.HashSet

import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases}
import org.apache.spark.sql.catalyst.{CatalystConf, EmptyConf}
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@@ -34,7 +36,9 @@ import org.apache.spark.sql.types._
* Abstract class all optimizers should inherit of, contains the standard batches (extending
* Optimizers can override this.
*/
abstract class Optimizer extends RuleExecutor[LogicalPlan] {
abstract class Optimizer(
conf: CatalystConf,
sessionCatalog: SessionCatalog) extends RuleExecutor[LogicalPlan] {
def batches: Seq[Batch] = {
// Technically some of the rules in Finish Analysis are not optimizer rules and belong more
// in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
@@ -43,6 +47,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
Batch("Finish Analysis", Once,
EliminateSubqueryAliases,
ComputeCurrentTime,
GetCurrentDatabase(sessionCatalog),
DistinctAggregationRewriter) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
@@ -117,7 +122,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
* To ensure extendability, we leave the standard rules in the abstract optimizer rules, while
* specific rules go to the subclasses
*/
object DefaultOptimizer extends Optimizer
object DefaultOptimizer
extends Optimizer(
EmptyConf,
new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, EmptyConf))

/**
* Pushes operations down into a Sample.
@@ -1399,6 +1407,16 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}

/** Replaces the expression of CurrentDatabase with the current database name. */
case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan transformAllExpressions {
case CurrentDatabase() =>
Literal.create(sessionCatalog.getCurrentDatabase, StringType)
}
}
}

/**
* Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
* [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
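Because `GetCurrentDatabase` runs in the `Finish Analysis` batch, every `CurrentDatabase()` node is replaced once per query with whatever database the `SessionCatalog` is in at optimization time. A rough sketch of exercising the rule in isolation, assuming the snippet lives somewhere under the `org.apache.spark.sql` package (the expression is `private[sql]`); the catalog construction mirrors the `DefaultOptimizer` call above, and "default" is the catalog's assumed initial database:

```scala
import org.apache.spark.sql.catalyst.EmptyConf
import org.apache.spark.sql.catalyst.analysis.EmptyFunctionRegistry
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDatabase}
import org.apache.spark.sql.catalyst.optimizer.GetCurrentDatabase
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}

// A fresh SessionCatalog starts in the "default" database, so the rule should
// rewrite CurrentDatabase() into Literal("default") wherever it appears.
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, EmptyConf)
val plan = Project(Seq(Alias(CurrentDatabase(), "db")()), OneRowRelation)
val rewritten = GetCurrentDatabase(catalog).apply(plan)

// After the rewrite, no CurrentDatabase node should remain in the plan.
assert(rewritten.expressions.forall(_.collect { case CurrentDatabase() => () }.isEmpty))
```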
@@ -18,6 +18,8 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.EmptyFunctionRegistry
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -38,7 +40,10 @@ class OptimizerExtendableSuite extends SparkFunSuite {
* This class represents a dummy extended optimizer that takes the batches of the
* Optimizer and adds custom ones.
*/
class ExtendedOptimizer extends Optimizer {
class ExtendedOptimizer
extends Optimizer(
EmptyConf,
new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, EmptyConf)) {

// rules set to DummyRule, would not be executed anyways
val myBatches: Seq[Batch] = {
@@ -18,9 +18,14 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer

class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer {
class SparkOptimizer(
conf: CatalystConf,
sessionCatalog: SessionCatalog,
experimentalMethods: ExperimentalMethods) extends Optimizer(conf, sessionCatalog) {
override def batches: Seq[Batch] = super.batches :+ Batch(
"User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*)
}
@@ -80,7 +80,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Logical query plan optimizer.
*/
lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods)
lazy val optimizer: Optimizer = new SparkOptimizer(conf, catalog, experimentalMethods)

/**
* Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -60,19 +60,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
* Returns the current database of metadataHive.
*/
private[hive] case class CurrentDatabase(ctx: HiveContext)
extends LeafExpression with CodegenFallback {
override def dataType: DataType = StringType
override def foldable: Boolean = true
override def nullable: Boolean = false
override def eval(input: InternalRow): Any = {
UTF8String.fromString(ctx.sessionState.catalog.getCurrentDatabase)
}
}

/**
* An instance of the Spark SQL execution engine that integrates with data stored in Hive.
* Configuration for Hive is read from hive-site.xml on the classpath.
@@ -133,11 +120,6 @@ class HiveContext private[hive](
@transient
protected[sql] override lazy val sessionState = new HiveSessionState(self)

// The Hive UDF current_database() is foldable, will be evaluated by optimizer,
// but the optimizer can't access the SessionState of metadataHive.
sessionState.functionRegistry.registerFunction(
"current_database", (e: Seq[Expression]) => new CurrentDatabase(self))

/**
* When true, enables an experimental feature where metastore tables that use the parquet SerDe
* are automatically converted to use the Spark SQL parquet table scan, instead of the Hive