[SPARK-24763][SS] Remove redundant key data from value in streaming aggregation #21733
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala

@@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.streaming
import scala.reflect.ClassTag

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
@@ -81,4 +85,221 @@ package object state {
        storeCoordinator)
    }
  }

  /**
   * Base trait for a state manager intended to be used from streaming aggregations.
   */
  sealed trait StreamingAggregationStateManager extends Serializable {

    /**
     * Extract the key columns from the input row, and return a new row containing only
     * the key columns.
     *
     * @param row The input row.
     * @return The row instance which only contains key columns.
     */
    def getKey(row: InternalRow): UnsafeRow

    /**
     * Calculate the schema for the value of state. The schema is mainly passed to the
     * StateStoreRDD.
     *
     * @return An instance of StructType representing the schema for the value of state.
     */
    def getStateValueSchema: StructType

    /**
     * Get the current value of a non-null key from the target state store.
     *
     * @param store The target StateStore instance.
     * @param key The key whose associated value is to be returned.
     * @return A non-null row if the key exists in the store, otherwise null.
     */
    def get(store: StateStore, key: UnsafeRow): UnsafeRow

    /**
     * Put a new value for a non-null key into the target state store. Note that the key
     * will be extracted from the input row, and will be the same as the result of
     * getKey(row).
     *
     * @param store The target StateStore instance.
     * @param row The input row.
     */
    def put(store: StateStore, row: UnsafeRow): Unit

    /**
     * Commit all the updates that have been made to the target state store, and return
     * the new version.
     *
     * @param store The target StateStore instance.
     * @return The new state version.
     */
    def commit(store: StateStore): Long

    /**
     * Remove a single non-null key from the target state store.
     *
     * @param store The target StateStore instance.
     * @param key The key whose associated value is to be removed.
     */
    def remove(store: StateStore, key: UnsafeRow): Unit

    /**
     * Return an iterator containing all the key-value pairs in the target state store.
     */
    def iterator(store: StateStore): Iterator[UnsafeRowPair]

    /**
     * Return an iterator containing all the keys in the target state store.
     */
    def keys(store: StateStore): Iterator[UnsafeRow]

    /**
     * Return an iterator containing all the values in the target state store.
     */
    def values(store: StateStore): Iterator[UnsafeRow]
  }

  object StreamingAggregationStateManager extends Logging {
    val supportedVersions = Seq(1, 2)
    val legacyVersion = 1

    def createStateManager(
        keyExpressions: Seq[Attribute],
        inputRowAttributes: Seq[Attribute],
        stateFormatVersion: Int): StreamingAggregationStateManager = {
      stateFormatVersion match {
        case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes)
        case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes)
        case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
      }
    }
  }

  abstract class StreamingAggregationStateManagerBaseImpl(
      protected val keyExpressions: Seq[Attribute],
      protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {

    @transient protected lazy val keyProjector =
      GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)

    override def getKey(row: InternalRow): UnsafeRow = keyProjector(row)

    override def commit(store: StateStore): Long = store.commit()

    override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key)

    override def keys(store: StateStore): Iterator[UnsafeRow] = {
      // discard and don't convert values to avoid computation
      store.getRange(None, None).map(_.key)
    }
  }

  /**
   * The implementation of StreamingAggregationStateManager for state version 1.
   * In state version 1, the schemas of the key and the value in state are as follows:
   *
   * - key: Same as key expressions.
   * - value: Same as input row attributes. The schema of the value contains key expressions
   *   as well.
   *
   * This implementation only works when input row attributes contain all the key attributes.
   *
   * @param keyExpressions The attributes of keys.
   * @param inputRowAttributes The attributes of the input row.
   */
  class StreamingAggregationStateManagerImplV1(
      keyExpressions: Seq[Attribute],
      inputRowAttributes: Seq[Attribute])
    extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

    override def getStateValueSchema: StructType = inputRowAttributes.toStructType

    override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
      store.get(key)
    }

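    // The full input row, which still contains the key columns, is stored as the value,
    // duplicating the key data in state. This is the redundancy that state version 2 removes.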
    override def put(store: StateStore, row: UnsafeRow): Unit = {
      store.put(getKey(row), row)
    }

    override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
      store.iterator()
    }

    override def values(store: StateStore): Iterator[UnsafeRow] = {
      store.iterator().map(_.value)
    }
  }

  /**
   * The implementation of StreamingAggregationStateManager for state version 2.
   * In state version 2, the schemas of the key and the value in state are as follows:
   *
   * - key: Same as key expressions.
   * - value: The diff between input row attributes and key expressions.
   *
   * The schema of the value is changed to optimize the memory/space usage in state, by
   * removing the columns duplicated between the key and the value. Hence key columns are
   * excluded from the schema of the value.
   *
   * This implementation only works when input row attributes contain all the key attributes.
   *
   * @param keyExpressions The attributes of keys.
   * @param inputRowAttributes The attributes of the input row.
   */
  class StreamingAggregationStateManagerImplV2(
      keyExpressions: Seq[Attribute],
      inputRowAttributes: Seq[Attribute])
    extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

    private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions)
    private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
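    // The row restored from state is laid out as (key columns ++ value columns); if that
    // ordering already matches the input row attributes, the joined row can be returned
    // as-is, otherwise an extra projection is needed to restore the original column order.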
    private val needToProjectToRestoreValue: Boolean =
      keyValueJoinedExpressions != inputRowAttributes

    @transient private lazy val valueProjector =
      GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes)

    @transient private lazy val joiner =
      GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
        StructType.fromAttributes(valueExpressions))

    @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
      keyValueJoinedExpressions, inputRowAttributes)

    override def getStateValueSchema: StructType = valueExpressions.toStructType

    override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
      val savedState = store.get(key)
      if (savedState == null) {
        return savedState
      }

      val joinedRow = joiner.join(key, savedState)
      if (needToProjectToRestoreValue) {
        restoreValueProjector(joinedRow)
      } else {
        joinedRow
      }
    }

    override def put(store: StateStore, row: UnsafeRow): Unit = {
      val key = keyProjector(row)
      val value = valueProjector(row)
      store.put(key, value)
    }

    override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
      store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginRow(rowPair)))
    }

    override def values(store: StateStore): Iterator[UnsafeRow] = {
      store.iterator().map(rowPair => restoreOriginRow(rowPair))
    }

    private def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
      val joinedRow = joiner.join(rowPair.key, rowPair.value)
      if (needToProjectToRestoreValue) {
        restoreValueProjector(joinedRow)
      } else {
        joinedRow
      }
    }
  }
}
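Taken together, the V2 read/write path is just three generated projections around the store. Below is a minimal, self-contained sketch (not part of this PR) that reproduces the split-and-rejoin trick using the same Spark-internal Catalyst APIs the diff relies on. The object name, the (word, count) schema, and the printed sizes are illustrative only, and these APIs are internal, Spark 2.4-era signatures:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical demo object; not part of the PR.
object StateLayoutDemo extends App {
  // Input row: (word: String, count: Long), grouped by `word`.
  // Reuse the same attribute instances so exprId-based resolution and diff() work.
  val wordAttr = 'word.string
  val countAttr = 'count.long
  val keyAttrs: Seq[Attribute] = Seq(wordAttr)
  val inputAttrs: Seq[Attribute] = Seq(wordAttr, countAttr)
  val valueAttrs: Seq[Attribute] = inputAttrs.diff(keyAttrs) // V2 value schema: only `count`

  val toUnsafe = UnsafeProjection.create(inputAttrs.map(_.dataType).toArray)
  val inputRow = toUnsafe(InternalRow(UTF8String.fromString("spark"), 10L)).copy()

  val keyProjector = UnsafeProjection.create(keyAttrs, inputAttrs)
  val valueProjector = UnsafeProjection.create(valueAttrs, inputAttrs)

  val key = keyProjector(inputRow).copy()       // stored as the state key in both versions
  val v1Value = inputRow                        // V1 stores the full row, duplicating `word`
  val v2Value = valueProjector(inputRow).copy() // V2 stores only `count`

  // On read, V2 joins (key ++ value) back into the original row layout.
  val joiner = GenerateUnsafeRowJoiner.create(
    StructType.fromAttributes(keyAttrs), StructType.fromAttributes(valueAttrs))
  val restored = joiner.join(key, v2Value)

  assert(restored.getUTF8String(0).toString == "spark")
  assert(restored.getLong(1) == 10L)
  println(s"V1 value: ${v1Value.getSizeInBytes} bytes, V2 value: ${v2Value.getSizeInBytes} bytes")
}
```

In this common case the key (`word`) is a prefix of the input row, so the joined row already matches the original layout and `needToProjectToRestoreValue` would be false; the extra `restoreValueProjector` only kicks in when the key columns are not a prefix of the input attributes.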
Review comment:
If you intend to change the default to the new version, then you HAVE TO add a test that ensures that existing streaming aggregation checkpoints (generated in Spark 2.3.1 for example) will not fail to recover.
Similar to this test - https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala#L883
Reply:
Nice suggestion. Will add the test.
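For reference, a rough sketch of the kind of test being requested, modeled on the linked FlatMapGroupsWithStateSuite test: restart a streaming aggregation from a checkpoint written by Spark 2.3.1 and verify it keeps answering rather than failing on the old state layout. This assumes a suite extending StreamTest (with testImplicits, org.apache.spark.sql.functions._, java.io.File, commons-io FileUtils, and Spark's Utils in scope); the resource path, the replayed data, and the expected counts are placeholders, since the checkpoint would have to be generated on Spark 2.3.1 and committed as a test resource:

```scala
test("SPARK-24763: recover from streaming aggregation checkpoint written by Spark 2.3.1") {
  // Checkpoint produced by running the same query on Spark 2.3.1 and committed as a
  // test resource (hypothetical path).
  val resourceUri = this.getClass.getResource(
    "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregation/").toURI
  val checkpointDir = Utils.createTempDir().getCanonicalFile
  FileUtils.copyDirectory(new File(resourceUri), checkpointDir)

  val inputData = MemoryStream[Int]
  // Replay the data the old checkpoint already processed so the restarted query
  // lines up with the committed offsets (illustrative: one batch of 1, 2, 3).
  inputData.addData(1, 2, 3)

  val aggregated = inputData.toDF()
    .groupBy("value")
    .agg(count("*"))
    .as[(Int, Long)]

  testStream(aggregated, OutputMode.Complete)(
    // Recovery must keep reading the legacy (version 1) state layout without failing.
    StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
    AddData(inputData, 3),
    CheckLastBatch((1, 1), (2, 1), (3, 2))
  )
}
```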