@@ -39,8 +39,8 @@ object UnsupportedOperationChecker extends Logging {

case f: FlatMapGroupsWithState =>
if (f.hasInitialState) {
throwError("Batch [flatMap|map]GroupsWithState queries should not" +
" pass an initial state.")(f)
throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
" operation on a batch DataFrame/Dataset")(f)
}

case _ =>
@@ -240,7 +240,9 @@ object UnsupportedOperationChecker extends Logging {

if (m.initialState.isStreaming) {
// initial state has to be a batch relation
Contributor:
Non-streaming DataFrame/Dataset is not supported as the initial state in [flatMap|map]GroupsWithState operation on a streaming DataFrame/Dataset

Contributor Author:
done

throwError("Initial state cannot be a streaming DataFrame/Dataset.")
throwError("Non-streaming DataFrame/Dataset is not supported as the" +
" initial state in [flatMap|map]GroupsWithState operation on a streaming" +
" DataFrame/Dataset")
}
if (m.isMapGroupsWithState) { // check mapGroupsWithState
// allowed only in update query output mode and without aggregation
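For context, a minimal sketch of the pattern the check above rejects — this assumes a SparkSession named `spark`, and uses the rate source only as a stand-in streaming input; the error surfaces when the streaming query is analyzed at start:

import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
import spark.implicits._

// Both the input and the would-be initial state are streaming here.
val input = spark.readStream.format("rate").load()
  .select("value").as[Long]
  .groupByKey(identity)
val streamingInitialState = spark.readStream.format("rate").load()
  .select("value").as[Long]
  .groupByKey(identity)   // KeyValueGroupedDataset[Long, Long]

// Rejected: the initial state must be a batch (non-streaming) Dataset, so
// starting this query fails with the analysis error above, e.g. via
// result.writeStream.outputMode("update").format("console").start()
val result = input.mapGroupsWithState(
  GroupStateTimeout.NoTimeout, streamingInitialState)(
  (key: Long, values: Iterator[Long], state: GroupState[Long]) => key)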
@@ -532,8 +532,8 @@ case class FlatMapGroupsWithState(
isMapGroupsWithState: Boolean = false,
timeout: GroupStateTimeout,
hasInitialState: Boolean = false,
initialStateGroupAttrs: Seq[Attribute] = Seq.empty,
initialStateDataAttrs: Seq[Attribute] = Seq.empty,
initialStateGroupAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialStateDeserializer: Expression,
initialState: LogicalPlan,
child: LogicalPlan) extends BinaryNode with ObjectProducer {
@@ -469,10 +469,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @param initialState The user provided state that will be initialized when the first batch
* of data is processed in the streaming query. The user defined function
* will be called on the state data even if there are no other values in
* the group. To convert a Dataset ds of type Dataset[(K, S)] to a
* KeyValueGroupedDataset[K, S]
* do {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
*
the group. To convert a Dataset `ds` of type `Dataset[(K, S)]`
* to a `KeyValueGroupedDataset[K, S]`, use
* {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 3.2.0
*/
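A worked sketch of the conversion recipe described in the doc above, feeding the result in as initial state — assuming a SparkSession named `spark`; the socket host/port are placeholders for a streaming input:

import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
import spark.implicits._

// Batch Dataset[(String, Long)] of seed counts, converted as described above.
val initialState = Seq(("apple", 2L), ("orange", 1L)).toDS()
  .groupByKey(x => x._1).mapValues(_._2)   // KeyValueGroupedDataset[String, Long]

val fruit = spark.readStream.format("socket")
  .option("host", "localhost").option("port", 9999).load()
  .as[String]
  .groupByKey(identity)

// Running counts per fruit, seeded from the initial state.
val counts = fruit.flatMapGroupsWithState(
  OutputMode.Update(), GroupStateTimeout.NoTimeout, initialState)(
  (key: String, values: Iterator[String], state: GroupState[Long]) => {
    val total = state.getOption.getOrElse(0L) + values.size
    state.update(total)
    Iterator((key, total))
  })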
@@ -549,7 +548,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @param initialState The user provided state that will be initialized when the first batch
* of data is processed in the streaming query. The user defined function
* will be called on the state data even if there are no other values in
* the group.
the group. To convert a Dataset `ds` of type `Dataset[(K, S)]`
* to a `KeyValueGroupedDataset[K, S]`, use
* {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 3.2.0
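The mapGroupsWithState counterpart, reusing the `fruit` and `initialState` definitions from the sketch above; the function returns a single value per key per trigger instead of an iterator:

// Same seeded counting, but one output row per key.
val latest = fruit.mapGroupsWithState(
  GroupStateTimeout.NoTimeout, initialState)(
  (key: String, values: Iterator[String], state: GroupState[Long]) => {
    val total = state.getOption.getOrElse(0L) + values.size
    state.update(total)
    (key, total)
  })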
Expand Up @@ -46,6 +46,8 @@
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.LongAccumulator;
import scala.collection.immutable.Range;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.expr;
import static org.apache.spark.sql.types.DataTypes.*;
@@ -160,6 +162,72 @@ public void testReduce() {
Assert.assertEquals(6, reduced);
}

@Test
public void testInitialStateFlatMapGroupsWithState() {
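// Group strings by their length and seed an Integer-keyed initial state;
// on a batch Dataset, both calls below are expected to fail analysis.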
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Dataset<Tuple2<Integer, Long>> initialStateDS = spark.createDataset(
Arrays.asList(new Tuple2<Integer, Long>(2, 2L)),
Encoders.tuple(Encoders.INT(), Encoders.LONG())
);

KeyValueGroupedDataset<Integer, Tuple2<Integer, Long>> kvInitStateDS = initialStateDS.groupByKey(
(MapFunction<Tuple2<Integer, Long>, Integer>) f -> f._1,
Encoders.INT()
);

KeyValueGroupedDataset<Integer, Long> kvInitStateMappedDS = kvInitStateDS.mapValues(
(MapFunction<Tuple2<Integer, Long>, Long>) f -> f._2,
Encoders.LONG()
);

KeyValueGroupedDataset<Integer, String> grouped =
ds.groupByKey((MapFunction<String, Integer>) String::length, Encoders.INT());

Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
(FlatMapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
StringBuilder sb = new StringBuilder(key.toString());
while (values.hasNext()) {
sb.append(values.next());
}
return Collections.singletonList(sb.toString()).iterator();
},
OutputMode.Append(),
Encoders.LONG(),
Encoders.STRING(),
GroupStateTimeout.NoTimeout(),
kvInitStateMappedDS);

Assert.assertThrows(
"Initial state is not supported in [flatMap|map]GroupsWithState " +
"operation on a batch DataFrame/Dataset",
AnalysisException.class,
() -> {
flatMapped2.collectAsList();
}
);
Dataset<String> mapped2 = grouped.mapGroupsWithState(
(MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
StringBuilder sb = new StringBuilder(key.toString());
while (values.hasNext()) {
sb.append(values.next());
}
return sb.toString();
},
Encoders.LONG(),
Encoders.STRING(),
GroupStateTimeout.NoTimeout(),
kvInitStateMappedDS);
Assert.assertThrows(
"Initial state is not supported in [flatMap|map]GroupsWithState " +
"operation on a batch DataFrame/Dataset",
AnalysisException.class,
() -> {
mapped2.collectAsList();
}
);
}

@Test
public void testIllegalTestGroupStateCreations() {
// SPARK-35800: test code throws upon illegal TestGroupState create() calls