@@ -871,6 +871,16 @@ object SQLConf {
.intConf
.createWithDefault(2)

val STREAMING_AGGREGATION_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.aggregation.stateFormatVersion")
.internal()
.doc("State format version used by streaming aggregation operations in a streaming query. " +
"State between versions are tend to be incompatible, so state format version shouldn't " +
"be modified after running.")
.intConf
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)
Contributor:
If you intend to change the default to the new version, then you HAVE TO add a test that ensures that existing streaming aggregation checkpoints (generated in Spark 2.3.1 for example) will not fail to recover.

Similar to this test - https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala#L883

Contributor (author):
Nice suggestion. Will add the test.
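
For reference, such a test could look roughly like the following. This is only a sketch in the style of the linked FlatMapGroupsWithStateSuite test, written against a suite extending StreamTest; the resource path, query shape, and expected answers are placeholders that depend on the checkpoint actually committed to test resources (assumed imports: java.io.File, org.apache.commons.io.FileUtils, org.apache.spark.util.Utils, org.apache.spark.sql.functions.count).

test("recover from a Spark 2.3.1 streaming aggregation checkpoint") {
  // Hypothetical resource: a checkpoint directory generated by running this
  // exact query on Spark 2.3.1 and committed under test resources.
  val resourceUri = this.getClass.getResource(
    "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate/").toURI
  val checkpointDir = Utils.createTempDir().getCanonicalFile
  // Copy the checkpoint so the test never mutates the committed resource.
  FileUtils.copyDirectory(new File(resourceUri), checkpointDir)

  val inputData = MemoryStream[Int]
  val aggregated = inputData.toDF().groupBy($"value").agg(count("*"))
  inputData.addData(3, 3, 2)  // must match the data recorded in the checkpoint

  testStream(aggregated, OutputMode.Complete)(
    // Restarting from the old checkpoint must not fail, and must keep using
    // the state format version recorded there (the legacy version 1).
    StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
    AddData(inputData, 3),
    CheckLastBatch((3, 3), (2, 1))
  )
}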


val UNSUPPORTED_OPERATION_CHECK_ENABLED =
buildConf("spark.sql.streaming.unsupportedOperationCheck")
.internal()
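In practice, a user or a test would pin the legacy format only when starting a brand-new query, e.g. (a minimal sketch; for an existing checkpoint the value recorded in the offset log takes precedence, see the OffsetSeqMetadata change below):

// Pin the legacy (version 1) state format for newly started queries.
spark.conf.set("spark.sql.streaming.aggregation.stateFormatVersion", 1)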
@@ -328,10 +328,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"Streaming aggregation doesn't support group aggregate pandas UDF")
}

val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)

aggregate.AggUtils.planStreamingAggregation(
namedGroupingExpressions,
aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
rewrittenResultExpressions,
stateVersion,
planLater(child))

case _ => Nil
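For reference, this planning path is taken by any streaming aggregation, e.g. (a sketch, assuming a MemoryStream[String] named inputData):

val wordCounts = inputData.toDF()
  .groupBy($"value")
  .count()  // planned via AggUtils.planStreamingAggregation with stateVersion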
@@ -256,6 +256,7 @@ object AggUtils {
groupingExpressions: Seq[NamedExpression],
functionsWithoutDistinct: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression],
stateFormatVersion: Int,
child: SparkPlan): Seq[SparkPlan] = {

val groupingAttributes = groupingExpressions.map(_.toAttribute)
@@ -287,7 +288,8 @@
child = partialAggregate)
}

- val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
+ val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion,
+ partialMerged1)

val partialMerged2: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
@@ -311,6 +313,7 @@
stateInfo = None,
outputMode = None,
eventTimeWatermark = None,
stateFormatVersion = stateFormatVersion,
partialMerged2)

val finalAndCompleteAggregate: SparkPlan = {
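Taken together, planStreamingAggregation now produces roughly the following operator tree (a sketch with abbreviated operator names; the exact shape depends on the query):

HashAggregate (final)
  StateStoreSaveExec [groupingAttributes], stateFormatVersion
    HashAggregate (PartialMerge)
      StateStoreRestoreExec [groupingAttributes], stateFormatVersion
        HashAggregate (PartialMerge)
          HashAggregate (Partial)
            child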
@@ -100,19 +100,21 @@ class IncrementalExecution(
val state = new Rule[SparkPlan] {

override def apply(plan: SparkPlan): SparkPlan = plan transform {
- case StateStoreSaveExec(keys, None, None, None,
+ case StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
UnaryExecNode(agg,
- StateStoreRestoreExec(_, None, child))) =>
+ StateStoreRestoreExec(_, None, _, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
StateStoreSaveExec(
keys,
Some(aggStateInfo),
Some(outputMode),
Some(offsetSeqMetadata.batchWatermarkMs),
stateFormatVersion,
agg.withNewChildren(
StateStoreRestoreExec(
keys,
Some(aggStateInfo),
stateFormatVersion,
child) :: Nil))

case StreamingDeduplicateExec(keys, child, None, None) =>
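In other words, the rule fills in the placeholders that AggUtils planned with None (a sketch, names abbreviated):

// before: StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
//           agg(StateStoreRestoreExec(keys, None, stateFormatVersion, child)))
// after:  StateStoreSaveExec(keys, Some(stateInfo), Some(outputMode),
//           Some(batchWatermarkMs), stateFormatVersion,
//           agg(StateStoreRestoreExec(keys, Some(stateInfo), stateFormatVersion, child)))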
@@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
- import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
+ import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager}
import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}

/**
@@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging {
private implicit val format = Serialization.formats(NoTypeHints)
private val relevantSQLConfs = Seq(
SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
- FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
+ FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION)

/**
* Default values of relevant configurations that are used for backward compatibility.
@@ -104,7 +104,9 @@ object OffsetSeqMetadata extends Logging {
private val relevantSQLConfDefaultValues = Map[String, String](
STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME,
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
- FlatMapGroupsWithStateExecHelper.legacyVersion.toString
+ FlatMapGroupsWithStateExecHelper.legacyVersion.toString,
+ STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
+ StreamingAggregationStateManager.legacyVersion.toString
)

def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
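For illustration, the offset log metadata might look like this before and after the change (a sketch with made-up values):

// Written by an older Spark, no entry for the new conf: on recovery the
// default from relevantSQLConfDefaultValues applies, i.e. legacy version 1.
{"batchWatermarkMs":0,"batchTimestampMs":1533784347136,
 "conf":{"spark.sql.shuffle.partitions":"200"}}

// Written after this change: the recorded version wins over the session conf.
{"batchWatermarkMs":0,"batchTimestampMs":1533784347136,
 "conf":{"spark.sql.shuffle.partitions":"200",
         "spark.sql.streaming.aggregation.stateFormatVersion":"2"}}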
@@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.streaming
import scala.reflect.ClassTag

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType

@@ -81,4 +85,221 @@ package object state {
storeCoordinator)
}
}

/**
Contributor:
Ummm, why is this in the package object and not in a separate file? Is there any reason it has to be in the state package object when not all stateful operators require it, only streaming aggregation?

Contributor (author):
Maybe I misinterpreted your earlier suggestion. I thought you were suggesting moving it into the state package object. Will place it in a separate file.

* Base trait for a state manager intended to be used by streaming aggregations.
*/
sealed trait StreamingAggregationStateManager extends Serializable {

/**
* Extract the key columns from the input row and return a new row containing only those columns.
*
* @param row The input row.
* @return A row containing only the key columns.
Contributor:
nit: a lot of the @param and @return tags in the docs are a bit superfluous, as they just repeat what the main statement already says.

Contributor (author):
Will just remove all the @param and @return tags where they merely repeat the description.

*/
def getKey(row: InternalRow): UnsafeRow
Contributor:
nit: why is the input typed as InternalRow when everything else is UnsafeRow? Seems inconsistent.

Contributor (author):
getKey was basically an UnsafeProjection in statefulOperator, so it didn't necessarily require an UnsafeRow. I just followed that usage to make it less restrictive, but we know that in reality the row will always be an UnsafeRow, so it's OK to fix if it provides consistency.


/**
* Calculate the schema of the state value. The schema is mainly passed to the StateStoreRDD.
*
* @return A StructType representing the schema of the state value.
*/
def getStateValueSchema: StructType

/**
* Get the current value of a non-null key from the target state store.
*
* @param store The target StateStore instance.
* @param key The key whose associated value is to be returned.
* @return A non-null row if the key exists in the store, otherwise null.
*/
def get(store: StateStore, key: UnsafeRow): UnsafeRow

/**
* Put a new value for a non-null key into the target state store. Note that the key will be
* extracted from the input row, and will be the same as the result of getKey(row).
*
* @param store The target StateStore instance.
* @param row The input row.
*/
def put(store: StateStore, row: UnsafeRow): Unit

/**
* Commit all the updates that have been made to the target state store, and return the
* new version.
*
* @param store The target StateStore instance.
Contributor:
Superfluous; the main statement alone has all the information.

* @return The new state version.
*/
def commit(store: StateStore): Long

/**
* Remove a single non-null key from the target state store.
*
* @param store The target StateStore instance.
* @param key The key whose associated value is to be removed.
*/
def remove(store: StateStore, key: UnsafeRow): Unit

/**
* Return an iterator containing all the key-value pairs in the target state store.
Contributor:
super nit: some of these can be compressed into a single-line doc.

Contributor (author):
Will address.

*/
def iterator(store: StateStore): Iterator[UnsafeRowPair]

/**
* Return an iterator containing all the keys in the target state store.
*/
def keys(store: StateStore): Iterator[UnsafeRow]

/**
* Return an iterator containing all the values in the target state store.
*/
def values(store: StateStore): Iterator[UnsafeRow]
}

object StreamingAggregationStateManager extends Logging {
val supportedVersions = Seq(1, 2)
val legacyVersion = 1

def createStateManager(
keyExpressions: Seq[Attribute],
inputRowAttributes: Seq[Attribute],
stateFormatVersion: Int): StreamingAggregationStateManager = {
stateFormatVersion match {
case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes)
case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes)
case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
}
}
}
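
To illustrate the intended usage from a stateful operator (a sketch; store, inputRow, keyExpressions, and inputRowAttributes are assumed to be provided by the surrounding operator):

val stateManager = StreamingAggregationStateManager.createStateManager(
  keyExpressions, inputRowAttributes, stateFormatVersion = 2)

val key = stateManager.getKey(inputRow)       // project out the key columns
stateManager.put(store, inputRow)             // V2 strips key columns from the stored value
val current = stateManager.get(store, key)    // V2 re-joins key and value on read
val newVersion = stateManager.commit(store)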

abstract class StreamingAggregationStateManagerBaseImpl(
protected val keyExpressions: Seq[Attribute],
protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {

@transient protected lazy val keyProjector =
GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)

def getKey(row: InternalRow): UnsafeRow = keyProjector(row)

override def commit(store: StateStore): Long = store.commit()
Contributor:
This really does not need to be in this interface, as it is not customized and is unlikely to ever be customized across implementations.

Contributor (author):
This change is actually based on your review comment: always use the state manager and don't access the state store directly whenever possible. If your suggestion only applies to row operations, I can remove commit() from this interface.


override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key)

override def keys(store: StateStore): Iterator[UnsafeRow] = {
// discard the values and skip converting them, to avoid unnecessary computation
store.getRange(None, None).map(_.key)
}
}

/**
* The implementation of StreamingAggregationStateManager for state version 1.
* In state version 1, the schemas of the key and the value in the state are as follows:
*
* - key: Same as the key expressions.
* - value: Same as the input row attributes; the value schema contains the key columns as well.
*
* This implementation only works when input row attributes contain all the key attributes.
Contributor:
Is there any particular reason for saying this? Can there be a situation where this is not true?

Contributor (author):
I intended the sentence as a requirement/precondition for possible future usage, but if you think we don't need to state it explicitly, I can remove it.

*
* @param keyExpressions The attributes of keys.
* @param inputRowAttributes The attributes of input row.
*/
class StreamingAggregationStateManagerImplV1(
keyExpressions: Seq[Attribute],
inputRowAttributes: Seq[Attribute])
extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

override def getStateValueSchema: StructType = inputRowAttributes.toStructType

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
store.get(key)
}

override def put(store: StateStore, row: UnsafeRow): Unit = {
store.put(getKey(row), row)
}

override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
store.iterator()
}

override def values(store: StateStore): Iterator[UnsafeRow] = {
store.iterator().map(_.value)
}
}
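
Concretely, for a query like df.groupBy("a").count(), V1 stores rows like (illustrative values):

// key:   [a = "foo"]
// value: [a = "foo", count = 42]   // the key column is duplicated in the value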

/**
* The implementation of StreamingAggregationStateManager for state version 2.
* In state version 2, the schemas of the key and the value in the state are as follows:
*
* - key: Same as the key expressions.
* - value: The input row attributes minus the key expressions.
*
* The value schema is changed to optimize memory/space usage in the state by removing the
* columns duplicated between the key and the value. Hence the key columns are excluded
* from the value schema.
*
* This implementation only works when input row attributes contain all the key attributes.
Contributor:
Same question as above.

*
* @param keyExpressions The attributes of keys.
* @param inputRowAttributes The attributes of input row.
*/
class StreamingAggregationStateManagerImplV2(
keyExpressions: Seq[Attribute],
inputRowAttributes: Seq[Attribute])
extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions)
private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
private val needToProjectToRestoreValue: Boolean =
Contributor:
add docs on what this means (that, if the fields in the joined row are not in the expected order, an additional projection is used)

Contributor (author):
Will add.

keyValueJoinedExpressions != inputRowAttributes

@transient private lazy val valueProjector =
GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes)

@transient private lazy val joiner =
GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
StructType.fromAttributes(valueExpressions))
@transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
keyValueJoinedExpressions, inputRowAttributes)
Contributor:
Are you sure this is right??

def generate(expressions: InType, inputSchema: Seq[Attribute])

So the 2nd param is the input schema of the rows the projection is applied to. This projection is applied to the joined rows, which have the schema keyValueJoinedExpressions. So I think these two should be flipped.

Contributor:
I am wondering why this does not fail any test. Is it because needToProjectToRestoreValue is always false?

Contributor (author):
My bad, you're right. Will fix. Btw, needToProjectToRestoreValue is always false unless the sequence of key and value columns gets mixed up.


override def getStateValueSchema: StructType = valueExpressions.toStructType

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
val savedState = store.get(key)
if (savedState == null) {
return savedState
}

val joinedRow = joiner.join(key, savedState)
Contributor:
Can't you dedup this code with the restoreOriginRow method?

Contributor (author):
Missed that spot. Will leverage restoreOriginRow.

if (needToProjectToRestoreValue) {
restoreValueProjector(joinedRow)
} else {
joinedRow
}
}

override def put(store: StateStore, row: UnsafeRow): Unit = {
val key = keyProjector(row)
val value = valueProjector(row)
store.put(key, value)
}

override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginRow(rowPair)))
}

override def values(store: StateStore): Iterator[UnsafeRow] = {
store.iterator().map(rowPair => restoreOriginRow(rowPair))
}

private def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
Contributor:
rename to restoreOriginalRow

Contributor (author):
Will rename.

val joinedRow = joiner.join(rowPair.key, rowPair.value)
if (needToProjectToRestoreValue) {
restoreValueProjector(joinedRow)
} else {
joinedRow
}
}
}

}
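
For the same df.groupBy("a").count() example, V2 stores rows like (illustrative values):

// key:   [a = "foo"]
// value: [count = 42]   // key columns are stripped before store.put()
//
// On read, get()/iterator() rebuild the original row with the joiner:
//   joiner.join([a = "foo"], [count = 42])  ==>  [a = "foo", count = 42]
// and restoreValueProjector reorders the columns only when the joined order
// differs from inputRowAttributes (i.e. needToProjectToRestoreValue).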