
Commit e844636

Add tests for StatefulOperatorsHelper as well
1 parent 63dfb5d

3 files changed: +175 −23 lines
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming.state

import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.sql.catalyst.expressions.UnsafeRow

class MemoryStateStore extends StateStore() {
  import scala.collection.JavaConverters._
  private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]

  override def iterator(): Iterator[UnsafeRowPair] = {
    map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
  }

  override def get(key: UnsafeRow): UnsafeRow = map.get(key)

  override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
    map.put(key.copy(), newValue.copy())
  }

  override def remove(key: UnsafeRow): Unit = {
    map.remove(key)
  }

  override def commit(): Long = version + 1

  override def abort(): Unit = {}

  override def id: StateStoreId = null

  override def version: Long = 0

  override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty)

  override def hasCommitted: Boolean = true
}
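
For context, a minimal usage sketch of the new in-memory store (not part of the commit; the object name and the single-column schema are illustrative, and the row-building pattern mirrors TestMaterial below):

import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.execution.streaming.state.MemoryStateStore
import org.apache.spark.sql.types.{IntegerType, StructType}

// Hypothetical driver object, for illustration only.
object MemoryStateStoreSketch {
  def main(args: Array[String]): Unit = {
    val schema = new StructType().add("id", IntegerType, nullable = false)
    val toUnsafe = UnsafeProjection.create(schema)

    // Build a one-column UnsafeRow; copy() because the projection reuses its buffer.
    val row = toUnsafe(new SpecificInternalRow(schema)).copy()
    row.setInt(0, 42)

    val store = new MemoryStateStore()
    store.put(row, row) // the store copies key and value defensively
    assert(store.get(row).getInt(0) == 42)
    assert(store.iterator().size == 1)

    store.remove(row)
    assert(store.iterator().isEmpty)
    assert(store.commit() == 1) // version is fixed at 0, so commit always reports 1
  }
}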
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming.state

import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class StatefulOperatorsHelperSuite extends StreamTest {
  import TestMaterial._

  test("StateManager v1 - get, put, iter") {
    val stateManager = newStateManager(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 1)

    // In V1, the input row is stored as the value.
    testGetPutIterOnStateManager(stateManager, OUTPUT_ATTRIBUTES, TEST_ROW, TEST_KEY_ROW, TEST_ROW)
  }

  // ============================ StateManagerImplV2 ============================
  test("StateManager v2 - get, put, iter") {
    val stateManager = newStateManager(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 2)

    // In V2, only the value columns (the input row minus its key columns) are stored as the
    // value, so the stored value has no key part; state manager V2 still provides the same
    // output as V1 when getting the row for a key.
    testGetPutIterOnStateManager(stateManager, VALUES_ATTRIBUTES, TEST_ROW, TEST_KEY_ROW,
      TEST_VALUE_ROW)
  }

  private def newStateManager(
      keysAttributes: Seq[Attribute],
      outputAttributes: Seq[Attribute],
      version: Int): StreamingAggregationStateManager = {
    StreamingAggregationStateManager.createStateManager(keysAttributes, outputAttributes, version)
  }

  private def testGetPutIterOnStateManager(
      stateManager: StreamingAggregationStateManager,
      expectedValueExpressions: Seq[Attribute],
      inputRow: UnsafeRow,
      expectedStateKey: UnsafeRow,
      expectedStateValue: UnsafeRow): Unit = {

    assert(stateManager.getValueExpressions === expectedValueExpressions)

    val memoryStateStore = new MemoryStateStore()
    stateManager.put(memoryStateStore, inputRow)

    assert(memoryStateStore.iterator().size === 1)

    val keyRow = stateManager.extractKey(inputRow)
    assert(keyRow === expectedStateKey)

    // Iterate the state store and verify that the key and value are stored in the expected format.
    val pair = memoryStateStore.iterator().next()
    assert(pair.key === keyRow)
    assert(pair.value === expectedStateValue)
    assert(stateManager.restoreOriginRow(pair) === inputRow)

    // Verify the stored value once again via get.
    assert(memoryStateStore.get(keyRow) === expectedStateValue)

    // The state manager should return a row identical to the input row, regardless of
    // format version.
    assert(inputRow === stateManager.get(memoryStateStore, keyRow))
  }
}

object TestMaterial {
  val KEYS: Seq[String] = Seq("key1", "key2")
  val VALUES: Seq[String] = Seq("sum(key1)", "sum(key2)")

  val OUTPUT_SCHEMA: StructType = StructType(
    KEYS.map(createIntegerField) ++ VALUES.map(createIntegerField))

  val OUTPUT_ATTRIBUTES: Seq[Attribute] = OUTPUT_SCHEMA.toAttributes
  val KEYS_ATTRIBUTES: Seq[Attribute] = OUTPUT_ATTRIBUTES.filter { p =>
    KEYS.contains(p.name)
  }
  val VALUES_ATTRIBUTES: Seq[Attribute] = OUTPUT_ATTRIBUTES.filter { p =>
    VALUES.contains(p.name)
  }

  val TEST_ROW: UnsafeRow = {
    val unsafeRowProjection = UnsafeProjection.create(OUTPUT_SCHEMA)
    val row = unsafeRowProjection(new SpecificInternalRow(OUTPUT_SCHEMA))
    (KEYS ++ VALUES).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) }
    row
  }

  val TEST_KEY_ROW: UnsafeRow = {
    val keyProjector = GenerateUnsafeProjection.generate(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES)
    keyProjector(TEST_ROW)
  }

  val TEST_VALUE_ROW: UnsafeRow = {
    val valueProjector = GenerateUnsafeProjection.generate(VALUES_ATTRIBUTES, OUTPUT_ATTRIBUTES)
    valueProjector(TEST_ROW)
  }

  private def createIntegerField(name: String): StructField = {
    StructField(name, IntegerType, nullable = false)
  }
}
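
To make the difference between the two format versions concrete, here is a short sketch (not part of the commit; the object name is hypothetical) of the layouts the tests above assert, reusing TestMaterial:

import org.apache.spark.sql.execution.streaming.state.TestMaterial._

object FormatVersionLayoutSketch {
  def main(args: Array[String]): Unit = {
    // TEST_ROW carries [0, 1, 2, 3] for (key1, key2, sum(key1), sum(key2)).
    //
    // V1 stores the whole input row as the state value:
    //   key = [0, 1], value = [0, 1, 2, 3]   (key columns duplicated in the value)
    // V2 strips the key columns out of the value, shrinking each state row:
    //   key = [0, 1], value = [2, 3]
    assert(TEST_KEY_ROW.numFields == 2)
    assert(TEST_ROW.numFields == 4)       // what V1 stores as the value
    assert(TEST_VALUE_ROW.numFields == 2) // what V2 stores as the value

    // Either way, readers get the original row back: the suite verifies this via
    // restoreOriginRow(pair) and stateManager.get(store, key).
  }
}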

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala

Lines changed: 1 addition & 23 deletions
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming
 
 import java.io.File
 import java.sql.Date
-import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.commons.io.FileUtils
 import org.scalatest.BeforeAndAfterAll
@@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -1286,27 +1285,6 @@ object FlatMapGroupsWithStateSuite {
 
   var failInTask = true
 
-  class MemoryStateStore extends StateStore() {
-    import scala.collection.JavaConverters._
-    private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
-
-    override def iterator(): Iterator[UnsafeRowPair] = {
-      map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
-    }
-
-    override def get(key: UnsafeRow): UnsafeRow = map.get(key)
-    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
-      map.put(key.copy(), newValue.copy())
-    }
-    override def remove(key: UnsafeRow): Unit = { map.remove(key) }
-    override def commit(): Long = version + 1
-    override def abort(): Unit = { }
-    override def id: StateStoreId = null
-    override def version: Long = 0
-    override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty)
-    override def hasCommitted: Boolean = true
-  }
-
   def assertCanGetProcessingTime(predicate: => Boolean): Unit = {
     if (!predicate) throw new TestFailedException("Could not get processing time", 20)
   }