5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
@@ -458,7 +458,10 @@ object MimaExcludes {
    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.this"),

    // [SPARK-28957][SQL] Copy any "spark.hive.foo=bar" spark properties into hadoop conf as "hive.foo=bar"
    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.appendS3AndSparkHadoopConfigurations")
    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.appendS3AndSparkHadoopConfigurations"),

    // [SPARK-29348] Add observable metrics.
    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryProgress.this")
  )

  // Exclude rules for 2.4.x
@@ -2432,6 +2432,10 @@ class Analyzer(
          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
        }.copy(child = newChild)

      // Don't touch collect metrics. Top-level metrics are not supported (check analysis will
      // fail) and we want to retain them inside the aggregate functions.
      case m: CollectMetrics => m

      // TODO: it's hard to write a general rule to pull out nondeterministic expressions
      // from LogicalPlan; currently we only do it for UnaryNode which has the same output
      // schema as its child.
@@ -2932,6 +2936,12 @@ object CleanupAliases extends Rule[LogicalPlan] {
      Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

    case CollectMetrics(name, metrics, child) =>
      val cleanedMetrics = metrics.map {
        e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]
      }
      CollectMetrics(name, cleanedMetrics, child)

    // Operators that operate on objects should only have expressions from encoders, which should
    // never have extra aliases.
    case o: ObjectConsumer => o
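
To make the CollectMetrics branch of CleanupAliases concrete, here is a minimal sketch of the intended effect. The attribute `a` and the child plan `relation` are hypothetical placeholders; trimNonTopLevelAliases and CollectMetrics are the ones in this patch:

  // Before cleanup: an alias ("x") nested inside the aggregate, plus a
  // top-level alias ("total"). trimNonTopLevelAliases keeps "total" and
  // strips the nested "x".
  val before = CollectMetrics(
    "event",
    Seq(Sum(a.as("x")).toAggregateExpression().as("total")),  // hypothetical attribute `a`
    relation)                                                 // hypothetical child plan
  // After CleanupAliases (sketch of the expected shape):
  val after = CollectMetrics(
    "event",
    Seq(Sum(a).toAggregateExpression().as("total")),
    relation)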
@@ -14,9 +14,10 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@@ -280,6 +281,41 @@ trait CheckAnalysis extends PredicateHelper {
            groupingExprs.foreach(checkValidGroupingExprs)
            aggregateExprs.foreach(checkValidAggregateExpression)

          case CollectMetrics(name, metrics, _) =>
            if (name == null || name.isEmpty) {
              operator.failAnalysis(s"observed metrics should be named: $operator")
            }
            // Check if an expression is a valid metric. A metric must meet the following criteria:
            // - Is not a window function;
            // - Is not a nested aggregate function;
            // - Is not a distinct aggregate function;
            // - Contains non-deterministic functions only when nested inside an aggregate function;
            // - Contains attributes only when nested inside an aggregate function.
            def checkMetric(s: Expression, e: Expression, seenAggregate: Boolean = false): Unit = {
              e match {
                case _: WindowExpression =>
                  e.failAnalysis(
                    "window expressions are not allowed in observed metrics, but found: " + s.sql)
                case _ if !e.deterministic && !seenAggregate =>
                  e.failAnalysis(s"non-deterministic expression ${s.sql} can only be used " +
                    "as an argument to an aggregate function.")
                case _: AggregateExpression if seenAggregate =>
                  e.failAnalysis(
                    "nested aggregates are not allowed in observed metrics, but found: " + s.sql)
                case a: AggregateExpression if a.isDistinct =>
                  e.failAnalysis(
                    "distinct aggregates are not allowed in observed metrics, but found: " + s.sql)
                case _: Attribute if !seenAggregate =>
                  e.failAnalysis(s"attribute ${s.sql} can only be used as an argument to an " +
                    "aggregate function.")
                case _: AggregateExpression =>
                  e.children.foreach(checkMetric(s, _, seenAggregate = true))
                case _ =>
                  e.children.foreach(checkMetric(s, _, seenAggregate))
              }
            }
            metrics.foreach(m => checkMetric(m, m))

          case Sort(orders, _, _) =>
            orders.foreach { order =>
              if (!RowOrdering.isOrderable(order.dataType)) {
@@ -534,6 +570,7 @@ trait CheckAnalysis extends PredicateHelper {
        case _ => // Analysis successful!
      }
    }
    checkCollectedMetrics(plan)
    extendedCheckRules.foreach(_(plan))
    plan.foreachUp {
      case o if !o.resolved =>
@@ -627,6 +664,38 @@
    checkCorrelationsInSubquery(expr.plan)
  }

  /**
   * Validate that collected metrics names are unique. The same name cannot be used for metrics
   * with different results. However, multiple instances of metrics with the same result and name
   * are allowed (e.g. self-joins).
   */

Contributor: will we eliminate the duplicated metrics (same name and result)?
Contributor (Author): I will do that in a follow-up.
  private def checkCollectedMetrics(plan: LogicalPlan): Unit = {
    val metricsMap = mutable.Map.empty[String, LogicalPlan]
    def check(plan: LogicalPlan): Unit = plan.foreach { node =>
      node match {
        case metrics @ CollectMetrics(name, _, _) =>
          metricsMap.get(name) match {
            case Some(other) =>
              // Exact duplicates are allowed. They can be the result
              // of a CTE that is used multiple times or a self join.
              if (!metrics.sameResult(other)) {
                failAnalysis(
                  s"Multiple definitions of observed metrics named '$name': $plan")
              }
            case None =>
              metricsMap.put(name, metrics)
          }
        case _ =>
      }
      node.expressions.foreach(_.foreach {
        case subquery: SubqueryExpression =>
          check(subquery.plan)
        case _ =>
      })
    }
    check(plan)
  }

  /**
   * Validates to make sure the outer references appearing inside the subquery
   * are allowed.
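
To make the checkMetric rules above concrete, here is a small sketch mirroring the AnalysisSuite cases further down (testRelation and the expression DSL are assumed from the catalyst test harness):

  val a = testRelation.output.head
  // Passes analysis: a literal plus an attribute nested inside an aggregate function.
  CollectMetrics("event",
    Literal(1).as("lit") :: Sum(a).toAggregateExpression().as("sum") :: Nil,
    testRelation)
  // Fails analysis: a bare attribute would report per-row values instead of an aggregate.
  CollectMetrics("event", a :: Nil, testRelation)
  // Fails analysis: distinct aggregates are explicitly rejected.
  CollectMetrics("event",
    Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil,
    testRelation)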
@@ -43,7 +43,9 @@ object PlanHelper {
      case e: WindowExpression
        if !plan.isInstanceOf[Window] => e
      case e: AggregateExpression
        if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => e
        if !(plan.isInstanceOf[Aggregate] ||
          plan.isInstanceOf[Window] ||
          plan.isInstanceOf[CollectMetrics]) => e
      case e: Generator
        if !plan.isInstanceOf[Generate] => e
    }
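
A sketch of what the widened guard buys us: an AggregateExpression hosted by CollectMetrics is no longer reported as misplaced. The surrounding method is assumed here to be PlanHelper.specialExpressionsInUnsupportedOperator (treat the exact name as an assumption), with testRelation from the test harness:

  val plan = CollectMetrics(
    "event",
    Sum(a).toAggregateExpression().as("sum") :: Nil,  // aggregate hosted by CollectMetrics
    testRelation)
  // Previously the Sum would have been collected as a "special" (misplaced) expression.
  assert(PlanHelper.specialExpressionsInUnsupportedOperator(plan).isEmpty)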
@@ -971,3 +971,25 @@ case class Deduplicate(
 * This is used to whitelist such commands in the subquery-related checks.
 */
trait SupportsSubquery extends LogicalPlan

/**
 * Collect arbitrary (named) metrics from a dataset. As soon as the query reaches a completion
 * point (a batch query completes or a streaming query epoch completes) an event is emitted on
 * the driver which can be observed by attaching a listener to the Spark session. The metrics
 * are named so we can collect metrics at multiple places in a single dataset.
 *
 * This node behaves like a global aggregate. All the metrics collected must be aggregate
 * functions or literals.
 */
case class CollectMetrics(
    name: String,
    metrics: Seq[NamedExpression],
    child: LogicalPlan)
  extends UnaryNode {

  override lazy val resolved: Boolean = {
    name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
  }

  override def output: Seq[Attribute] = child.output
}
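
A quick sketch of the node's contract as described in the scaladoc above: rows pass through untouched, and resolution requires a non-empty name, at least one resolved metric, and a resolved child (testRelation assumed from the test harness):

  val node = CollectMetrics(
    "stats",
    Count(Literal(1)).toAggregateExpression().as("cnt") :: Nil,
    testRelation)
  assert(node.output == testRelation.output)  // behaves like a pass-through on rows
  assert(node.resolved)                       // name, metrics and child all resolved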
@@ -160,7 +160,12 @@ object SQLConf {
        confGetter.get()()
      }
    } else {
      confGetter.get()()
      // Prefer a conf already pinned to this thread, falling back to the
      // registered getter when none is set.
      val conf = existingConf.get()
      if (conf != null) {
        conf
      } else {
        confGetter.get()()
      }
    }
  }
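
The change above makes SQLConf.get prefer a conf pinned to the current thread over the registered getter. A minimal sketch, assuming SQLConf.withExistingConf (present in Spark 3.x) is what populates existingConf:

  val conf = new SQLConf
  SQLConf.withExistingConf(conf) {
    // Inside this block the thread-local conf wins over confGetter.
    assert(SQLConf.get eq conf)
  }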
@@ -21,9 +21,14 @@ import java.util.Locale

import scala.util.control.NonFatal

import com.fasterxml.jackson.core.{JsonGenerator, JsonParser}
import com.fasterxml.jackson.databind.{DeserializationContext, JsonDeserializer, JsonSerializer, SerializerProvider}
import com.fasterxml.jackson.databind.`type`.TypeFactory
import com.fasterxml.jackson.databind.annotation.{JsonDeserialize, JsonSerialize}
import org.json4s._
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.{JValueDeserializer, JValueSerializer}
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.Stable
@@ -40,7 +45,10 @@ import org.apache.spark.util.Utils
 *
 * @since 1.3.0
 */

@Stable
@JsonSerialize(using = classOf[DataTypeJsonSerializer])
@JsonDeserialize(using = classOf[DataTypeJsonDeserializer])
abstract class DataType extends AbstractDataType {
  /**
   * Enables matching against DataType for expressions:
@@ -475,3 +483,30 @@ object DataType {
    }
  }
}

/**
 * Jackson serializer for [[DataType]]. Internally this delegates to json4s-based serialization.
 */
class DataTypeJsonSerializer extends JsonSerializer[DataType] {
  private val delegate = new JValueSerializer

  override def serialize(
      value: DataType,
      gen: JsonGenerator,
      provider: SerializerProvider): Unit = {
    delegate.serialize(value.jsonValue, gen, provider)
  }
}

/**
 * Jackson deserializer for [[DataType]]. Internally this delegates to json4s-based
 * deserialization.
 */
class DataTypeJsonDeserializer extends JsonDeserializer[DataType] {
  private val delegate = new JValueDeserializer(classOf[Any])

  override def deserialize(
      jsonParser: JsonParser,
      deserializationContext: DeserializationContext): DataType = {
    val json = delegate.deserialize(jsonParser, deserializationContext)
    DataType.parseDataType(json.asInstanceOf[JValue])
  }
}
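
With the annotations added above, a plain Jackson ObjectMapper can round-trip a DataType without json4s appearing in user code. A minimal sketch (the wire format is whatever DataType.json already produces, e.g. "integer" for IntegerType):

  import com.fasterxml.jackson.databind.ObjectMapper
  import org.apache.spark.sql.types.{DataType, IntegerType}

  val mapper = new ObjectMapper
  val json = mapper.writeValueAsString(IntegerType)       // "integer"
  val parsed = mapper.readValue(json, classOf[DataType])  // back to IntegerType
  assert(parsed == IntegerType)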
@@ -29,11 +29,11 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan
import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning,
  RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
@@ -650,4 +650,87 @@ class AnalysisSuite extends AnalysisTest with Matchers {
    assertAnalysisError(parsePlan("INSERT INTO test VALUES (1)"),
      Seq("Table not found: test"))
  }

test("check CollectMetrics resolved") {
val a = testRelation.output.head
val sum = Sum(a).toAggregateExpression().as("sum")
val random_sum = Sum(Rand(1L)).toAggregateExpression().as("rand_sum")
val literal = Literal(1).as("lit")

// Ok
assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation).resolved)

// Bad name
assert(!CollectMetrics("", sum :: Nil, testRelation).resolved)
assertAnalysisError(CollectMetrics("", sum :: Nil, testRelation),
"observed metrics should be named" :: Nil)

// No columns
assert(!CollectMetrics("evt", Nil, testRelation).resolved)

def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors)
}

// Unwrapped attribute
checkAnalysisError(
a :: Nil,
"Attribute", "can only be used as an argument to an aggregate function")

// Unwrapped non-deterministic expression
checkAnalysisError(
Rand(10).as("rnd") :: Nil,
"non-deterministic expression", "can only be used as an argument to an aggregate function")

// Distinct aggregate
checkAnalysisError(
Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil,
"distinct aggregates are not allowed in observed metrics, but found")

// Nested aggregate
checkAnalysisError(
Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil,
"nested aggregates are not allowed in observed metrics, but found")

// Windowed aggregate
val windowExpr = WindowExpression(
RowNumber(),
WindowSpecDefinition(Nil, a.asc :: Nil,
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
checkAnalysisError(
windowExpr.as("rn") :: Nil,
"window expressions are not allowed in observed metrics, but found")
}

test("check CollectMetrics duplicates") {
val a = testRelation.output.head
val sum = Sum(a).toAggregateExpression().as("sum")
val count = Count(Literal(1)).toAggregateExpression().as("cnt")

// Same result - duplicate names are allowed
assertAnalysisSuccess(Union(
CollectMetrics("evt1", count :: Nil, testRelation) ::
CollectMetrics("evt1", count :: Nil, testRelation) :: Nil))

// Same children, structurally different metrics - fail
assertAnalysisError(Union(
CollectMetrics("evt1", count :: Nil, testRelation) ::
CollectMetrics("evt1", sum :: Nil, testRelation) :: Nil),
"Multiple definitions of observed metrics" :: "evt1" :: Nil)

// Different children, same metrics - fail
val b = 'b.string
val tblB = LocalRelation(b)
assertAnalysisError(Union(
CollectMetrics("evt1", count :: Nil, testRelation) ::
CollectMetrics("evt1", count :: Nil, tblB) :: Nil),
"Multiple definitions of observed metrics" :: "evt1" :: Nil)

// Subquery different tree - fail
val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation))
val query = Project(
b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil,
CollectMetrics("evt1", count :: Nil, tblB))
assertAnalysisError(query, "Multiple definitions of observed metrics" :: "evt1" :: Nil)
}
}