@@ -131,13 +131,25 @@ import org.apache.spark.util.SparkClassUtils
class Dataset[T] private[sql] (
val sparkSession: SparkSession,
@DeveloperApi val plan: proto.Plan,
val encoder: Encoder[T])
val encoder: Encoder[T],
@DeveloperApi carryOverObservationsOpt: Option[Map[String, Observation]] = None)
extends Serializable {
// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)

private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder)

private var observationsOpt: Option[mutable.Map[String, Observation]] = {
carryOverObservationsOpt match {
case Some(observations) =>
Some(mutable.Map.newBuilder[String, Observation].addAll(observations).result())
case None => None
}
}

private def getObservationsMapOpt: Option[Map[String, Observation]] =
observationsOpt.map(_.toMap)

override def toString: String = {
try {
val builder = new mutable.StringBuilder
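Note that every transformation below now threads getObservationsMapOpt into sparkSession.newDataset, so observations registered on a parent Dataset are carried over to every Dataset derived from it. A minimal sketch of the intended behavior, assuming a Spark Connect SparkSession named `spark` (the session and column names are illustrative, not part of this change):

import org.apache.spark.sql.functions.{col, count, lit}

val rowCount = Observation("row_count")
val base = spark.range(100).observe(rowCount, count(lit(1)).as("n"))
// limit() and filter() both forward getObservationsMapOpt to newDataset,
// so the derived Dataset still carries the "row_count" observation.
val derived = base.limit(50).filter(col("id") > 5)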
@@ -536,7 +548,7 @@ class Dataset[T] private[sql] (
* @since 3.4.0
*/
def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = {
val df = sparkSession.newDataset(StringEncoder) { builder =>
val df = sparkSession.newDataset(StringEncoder, getObservationsMapOpt) { builder =>
builder.getShowStringBuilder
.setInput(plan.getRoot)
.setNumRows(numRows)
@@ -844,7 +856,7 @@ class Dataset[T] private[sql] (
}

private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getSortBuilder
.setInput(plan.getRoot)
.setIsGlobal(global)
@@ -898,7 +910,7 @@ class Dataset[T] private[sql] (
EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)),
None)

sparkSession.newDataset(tupleEncoder) { builder =>
sparkSession.newDataset(tupleEncoder, getObservationsMapOpt) { builder =>
val joinBuilder = builder.getJoinBuilder
joinBuilder
.setLeft(plan.getRoot)
@@ -1028,7 +1040,7 @@ class Dataset[T] private[sql] (
*/
@scala.annotation.varargs
def hint(name: String, parameters: Any*): Dataset[T] =
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getHintBuilder
.setInput(plan.getRoot)
.setName(name)
@@ -1089,10 +1101,12 @@ class Dataset[T] private[sql] (
* @group typedrel
* @since 3.4.0
*/
def as(alias: String): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSubqueryAliasBuilder
.setInput(plan.getRoot)
.setAlias(alias)
def as(alias: String): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getSubqueryAliasBuilder
.setInput(plan.getRoot)
.setAlias(alias)
}
}

/**
@@ -1183,7 +1197,7 @@ class Dataset[T] private[sql] (
} else {
c1.expr
}
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(encoder, getObservationsMapOpt) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
.addExpressions(expr)
@@ -1205,7 +1219,7 @@
* methods and typed select methods is the encoder used to build the return dataset.
*/
private def selectUntyped(encoder: AgnosticEncoder[_], cols: Seq[Column]): Dataset[_] = {
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(encoder, getObservationsMapOpt) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
.addAllExpressions(cols.map(_.expr).asJava)
@@ -1271,10 +1285,10 @@ class Dataset[T] private[sql] (
* @group typedrel
* @since 3.4.0
*/
def filter(condition: Column): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
def filter(condition: Column): Dataset[T] =
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
}
}

/**
* Filters rows using the given SQL expression.
@@ -1394,7 +1408,7 @@ class Dataset[T] private[sql] (
val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr

val result = sparkSession
.newDataset(agnosticEncoder) { builder =>
.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getAggregateBuilder
.setInput(plan.getRoot)
.addAggregateExpressions(reduceExpr)
@@ -1787,10 +1801,12 @@ class Dataset[T] private[sql] (
* @group typedrel
* @since 3.4.0
*/
def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getLimitBuilder
.setInput(plan.getRoot)
.setLimit(n)
def limit(n: Int): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getLimitBuilder
.setInput(plan.getRoot)
.setLimit(n)
}
}

/**
@@ -1799,16 +1815,18 @@
* @group typedrel
* @since 3.4.0
*/
def offset(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getOffsetBuilder
.setInput(plan.getRoot)
.setOffset(n)
def offset(n: Int): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getOffsetBuilder
.setInput(plan.getRoot)
.setOffset(n)
}
}

private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)(
f: proto.SetOperation.Builder => Unit): Dataset[T] = {
checkSameSparkSession(right)
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
f(
builder.getSetOpBuilder
.setSetOpType(setOpType)
@@ -2081,7 +2099,7 @@ class Dataset[T] private[sql] (
* @since 3.4.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getSampleBuilder
.setInput(plan.getRoot)
.setWithReplacement(withReplacement)
@@ -2149,7 +2167,7 @@ class Dataset[T] private[sql] (
normalizedCumWeights
.sliding(2)
.map { case Array(low, high) =>
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getSampleBuilder
.setInput(sortedInput)
.setWithReplacement(false)
@@ -2476,8 +2494,8 @@ class Dataset[T] private[sql] (

private def buildDropDuplicates(
columns: Option[Seq[String]],
withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
withinWaterMark: Boolean): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
val dropBuilder = builder.getDeduplicateBuilder
.setInput(plan.getRoot)
.setWithinWatermark(withinWaterMark)
@@ -2486,6 +2504,7 @@
} else {
dropBuilder.setAllColumnsAsKeys(true)
}
}
}

/**
@@ -2709,7 +2728,7 @@ class Dataset[T] private[sql] (
function = func,
inputEncoders = agnosticEncoder :: Nil,
outputEncoder = PrimitiveBooleanEncoder)
sparkSession.newDataset[T](agnosticEncoder) { builder =>
sparkSession.newDataset[T](agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getFilterBuilder
.setInput(plan.getRoot)
.setCondition(udf.apply(col("*")).expr)
@@ -2762,7 +2781,7 @@ class Dataset[T] private[sql] (
function = func,
inputEncoders = agnosticEncoder :: Nil,
outputEncoder = outputEncoder)
sparkSession.newDataset(outputEncoder) { builder =>
sparkSession.newDataset(outputEncoder, getObservationsMapOpt) { builder =>
builder.getMapPartitionsBuilder
.setInput(plan.getRoot)
.setFunc(udf.apply(col("*")).expr.getCommonInlineUserDefinedFunction)
@@ -2930,7 +2949,7 @@ class Dataset[T] private[sql] (
* @since 3.4.0
*/
def tail(n: Int): Array[T] = {
val lastN = sparkSession.newDataset(agnosticEncoder) { builder =>
val lastN = sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getTailBuilder
.setInput(plan.getRoot)
.setLimit(n)
@@ -3001,7 +3020,7 @@ class Dataset[T] private[sql] (
}

private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getRepartitionBuilder
.setInput(plan.getRoot)
.setNumPartitions(numPartitions)
@@ -3011,12 +3030,13 @@

private def buildRepartitionByExpression(
numPartitions: Option[Int],
partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
partitionExprs: Seq[Column]): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
.addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
}
}

/**
@@ -3329,7 +3349,7 @@ class Dataset[T] private[sql] (
* @since 3.5.0
*/
def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getWithWatermarkBuilder
.setInput(plan.getRoot)
.setEventTime(eventTime)
@@ -3338,7 +3358,25 @@
}

def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = {
throw new UnsupportedOperationException("observe is not implemented.")
sparkSession.newDataset(agnosticEncoder, getObservationsMapOpt) { builder =>
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
.addAllMetrics((expr +: exprs).map(_.expr).asJava)
}
}

def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
observationsOpt match {
case Some(obs) =>
if (obs.contains(observation.name)) {
throw new IllegalArgumentException(s"Observation ${observation.name} already exists.")
}
obs += (observation.name -> observation)
case None =>
observationsOpt = Some(mutable.Map(observation.name -> observation))
}
observe(observation.name, expr, exprs: _*)
}

def checkpoint(): Dataset[T] = {
Expand Down Expand Up @@ -3397,7 +3435,11 @@ class Dataset[T] private[sql] (
sparkSession.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA)
}

def collectResult(): SparkResult[T] = sparkSession.execute(plan, agnosticEncoder)
def collectResult(): SparkResult[T] =
sparkSession.execute(plan, agnosticEncoder, getObservationsMapOpt)

def collectObservations(): Map[String, Map[String, Any]] =
collectResult().getObservedMetrics

private[sql] def withResult[E](f: SparkResult[T] => E): E = {
val result = collectResult()
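Taken together, observe() adds a CollectMetrics node on top of the current plan, collectResult() passes the registered observations to sparkSession.execute, and collectObservations() exposes the observed metric values keyed by observation name. A usage sketch under the same illustrative assumptions as above (a Connect SparkSession `spark`; the table name is a placeholder):

import org.apache.spark.sql.functions.{col, count, lit, max}

val stats = Observation("stats")
val observed = spark.read.table("events") // placeholder table name
  .observe(stats, count(lit(1)).as("rows"), max(col("value")).as("max_value"))
  .filter(col("value") > 0)

// Runs the query and returns the observed metrics for the named observation.
val metrics: Map[String, Any] = observed.collectObservations()("stats")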
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.UUID

class Observation(name: String) extends ObservationBase(name) {

/**
* Create an Observation instance without providing a name. This generates a random name.
*/
def this() = this(UUID.randomUUID().toString)
}

/**
* (Scala-specific) Create instances of Observation via Scala `apply`.
* @since 3.3.0
*/
object Observation {

/**
* Observation constructor for creating an anonymous observation.
*/
def apply(): Observation = new Observation()

/**
* Observation constructor for creating a named observation.
*/
def apply(name: String): Observation = new Observation(name)

}
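A small illustration of the two constructors above: the named form uses the caller-supplied key under which metrics are later reported, while the anonymous form falls back to a random UUID-based name (variable names are illustrative):

val named = Observation("ingest_stats") // metrics are keyed by "ingest_stats"
val anonymous = Observation()           // name is a randomly generated UUID string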