Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3051,7 +3051,7 @@ private[spark] object Utils extends Logging {
* and return the trailing part after the last dollar sign in the middle
*/
@scala.annotation.tailrec
private def stripDollars(s: String): String = {
def stripDollars(s: String): String = {
val lastDollarIndex = s.lastIndexOf('$')
if (lastDollarIndex < s.length - 1) {
// The last char is not a dollar sign
Expand Down
192 changes: 192 additions & 0 deletions python/pyspark/sql/streaming/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
from typing import Tuple, Optional

from pyspark.sql.types import DateType, Row, StructType

__all__ = ["GroupStateImpl", "GroupStateTimeout"]


class GroupStateTimeout:
    """
    String names of the timeout configurations accepted by :class:`GroupStateImpl`.

    Each attribute's value is identical to its own name.
    """

    # No timeout is configured.
    NoTimeout: str = "NoTimeout"
    # Timeout driven by processing time (see GroupStateImpl.setTimeoutDuration).
    ProcessingTimeTimeout: str = "ProcessingTimeTimeout"
    # Timeout driven by event time (see GroupStateImpl.setTimeoutTimestamp).
    EventTimeTimeout: str = "EventTimeTimeout"


class GroupStateImpl:
    """
    Python-side representation of the state kept for a single group.

    The object carries the state value together with the timing/timeout metadata
    that is exchanged with the JVM, and records whether the state has been
    updated or removed and which timeout timestamp has been requested.
    """

    # Sentinel watermark value meaning "no timestamp"; setTimeoutTimestamp skips
    # the watermark comparison when the event time watermark equals it.
    NO_TIMESTAMP: int = -1

    def __init__(
        self,
        # JVM Constructor
        optionalValue: Row,
        batchProcessingTimeMs: int,
        eventTimeWatermarkMs: int,
        timeoutConf: str,
        hasTimedOut: bool,
        watermarkPresent: bool,
        # JVM internal state.
        defined: bool,
        updated: bool,
        removed: bool,
        timeoutTimestamp: int,
        # Python internal state.
        keyAsUnsafe: bytes,
        valueSchema: StructType,
    ) -> None:
        """
        Store the given values as the initial state.

        ``timeoutConf`` must be one of the names declared in
        :class:`GroupStateTimeout`. ``timeoutTimestamp`` is additionally remembered
        as the "old" timeout timestamp, which is never modified afterwards.
        """
        self._keyAsUnsafe = keyAsUnsafe
        self._value = optionalValue
        self._batch_processing_time_ms = batchProcessingTimeMs
        self._event_time_watermark_ms = eventTimeWatermarkMs

        assert timeoutConf in [
            GroupStateTimeout.NoTimeout,
            GroupStateTimeout.ProcessingTimeTimeout,
            GroupStateTimeout.EventTimeTimeout,
        ]
        self._timeout_conf = timeoutConf

        self._has_timed_out = hasTimedOut
        self._watermark_present = watermarkPresent

        self._defined = defined
        self._updated = updated
        self._removed = removed
        self._timeout_timestamp = timeoutTimestamp
        # Python internal state.
        # Snapshot of the timeout timestamp as it was at construction time.
        self._old_timeout_timestamp = timeoutTimestamp

        self._value_schema = valueSchema

    @property
    def exists(self) -> bool:
        """Whether a state value is currently defined."""
        return self._defined

    @property
    def get(self) -> Tuple:
        """
        The state value as a tuple.

        Raises ValueError when the state is not defined.
        """
        if self.exists:
            return tuple(self._value)
        else:
            raise ValueError("State is either not defined or has already been removed")

    @property
    def getOption(self) -> Optional[Tuple]:
        """The state value as a tuple, or None when the state is not defined."""
        if self.exists:
            return tuple(self._value)
        else:
            return None

    @property
    def hasTimedOut(self) -> bool:
        """The ``hasTimedOut`` flag this object was constructed with."""
        return self._has_timed_out

    # NOTE: this function is only available to PySpark implementation due to underlying
    # implementation, do not port to Scala implementation!
    @property
    def oldTimeoutTimestamp(self) -> int:
        """The timeout timestamp this object was constructed with."""
        return self._old_timeout_timestamp

    def update(self, newValue: Tuple) -> None:
        """
        Replace the state value with ``newValue`` and mark the state as
        defined and updated (and no longer removed).

        Raises ValueError when ``newValue`` is None.
        """
        if newValue is None:
            raise ValueError("'None' is not a valid state value")

        self._value = Row(*newValue)
        self._defined = True
        self._updated = True
        self._removed = False

    def remove(self) -> None:
        """Mark the state as removed: no longer defined and not updated."""
        self._defined = False
        self._updated = False
        self._removed = True

    def setTimeoutDuration(self, durationMs: int) -> None:
        """
        Set the timeout timestamp to ``durationMs`` after the batch processing time.

        Raises ValueError for a string or non-positive ``durationMs`` and
        RuntimeError when the timeout configuration is not ProcessingTimeTimeout.
        """
        if isinstance(durationMs, str):
            # TODO(SPARK-40437): Support string representation of durationMs.
            raise ValueError("durationMs should be int but get :%s" % type(durationMs))

        if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout:
            raise RuntimeError(
                "Cannot set timeout duration without enabling processing time timeout in "
                "applyInPandasWithState"
            )

        if durationMs <= 0:
            raise ValueError("Timeout duration must be positive")
        self._timeout_timestamp = durationMs + self._batch_processing_time_ms

    # TODO(SPARK-40438): Implement additionalDuration parameter.
    def setTimeoutTimestamp(self, timestampMs: int) -> None:
        """
        Set the timeout timestamp to ``timestampMs`` (milliseconds since the epoch).

        A ``datetime.datetime`` is also accepted and converted to epoch
        milliseconds; a naive datetime is interpreted in local time, as done by
        ``datetime.timestamp()``.

        Raises RuntimeError when the timeout configuration is not EventTimeTimeout,
        and ValueError when the timestamp is not positive or is earlier than the
        current event time watermark.
        """
        if self._timeout_conf != GroupStateTimeout.EventTimeTimeout:
            # This method guards the event time timeout; the message must not
            # talk about "duration" / "processing time" like setTimeoutDuration does.
            raise RuntimeError(
                "Cannot set timeout timestamp without enabling event time timeout in "
                "applyInPandasWithState"
            )

        if isinstance(timestampMs, datetime.datetime):
            # Convert to epoch milliseconds. DateType.toInternal would produce a
            # number of days since the epoch, which is not comparable with the
            # millisecond values used below. Whole seconds and the sub-second
            # part are combined as integers to avoid float rounding.
            timestampMs = (
                int(timestampMs.replace(microsecond=0).timestamp()) * 1000
                + timestampMs.microsecond // 1000
            )

        if timestampMs <= 0:
            raise ValueError("Timeout timestamp must be positive")

        if (
            self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP
            and timestampMs < self._event_time_watermark_ms
        ):
            raise ValueError(
                "Timeout timestamp (%s) cannot be earlier than the "
                "current watermark (%s)" % (timestampMs, self._event_time_watermark_ms)
            )

        self._timeout_timestamp = timestampMs

    def getCurrentWatermarkMs(self) -> int:
        """
        Return the event time watermark in milliseconds.

        Raises RuntimeError when no watermark is present.
        """
        if not self._watermark_present:
            raise RuntimeError(
                "Cannot get event time watermark timestamp without setting watermark before "
                "applyInPandasWithState"
            )
        return self._event_time_watermark_ms

    def getCurrentProcessingTimeMs(self) -> int:
        """Return the batch processing time in milliseconds."""
        return self._batch_processing_time_ms

    def __str__(self) -> str:
        if self.exists:
            return "GroupState(%s)" % (self.get,)
        else:
            return "GroupState(<undefined>)"

    def json(self) -> str:
        """
        Serialize the metadata of this state to a JSON string.

        The state value itself is not included: ``optionalValue`` is always
        emitted as null because it is serialized separately.
        """
        return json.dumps(
            {
                # Constructor
                "optionalValue": None,  # Note that optionalValue will be manually serialized.
                "batchProcessingTimeMs": self._batch_processing_time_ms,
                "eventTimeWatermarkMs": self._event_time_watermark_ms,
                "timeoutConf": self._timeout_conf,
                "hasTimedOut": self._has_timed_out,
                "watermarkPresent": self._watermark_present,
                # JVM internal state.
                "defined": self._defined,
                "updated": self._updated,
                "removed": self._removed,
                "timeoutTimestamp": self._timeout_timestamp,
            }
        )
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
@Experimental
@Evolving
public class GroupStateTimeout {
// NOTE: if you're adding a new type of timeout, you should also update the places below:
// - Scala:
// org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
// - Python: pyspark.sql.streaming.state.GroupStateTimeout

/**
* Timeout based on processing time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ package org.apache.spark.sql.execution.streaming
import java.sql.Date
import java.util.concurrent.TimeUnit

import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.api.java.Optional
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.streaming.GroupStateImpl._
import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
* Internal implementation of the [[TestGroupState]] interface. Methods are not thread-safe.
Expand All @@ -46,6 +50,9 @@ private[sql] class GroupStateImpl[S] private(
timeoutConf: GroupStateTimeout,
override val hasTimedOut: Boolean,
watermarkPresent: Boolean) extends TestGroupState[S] {
// NOTE: if you're adding new properties here, also update:
// - the `json` method of this class and the `fromJson` method of its companion object in Scala
// - pyspark.sql.streaming.state.GroupStateImpl in Python

private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
private var defined: Boolean = optionalValue.isDefined
Expand Down Expand Up @@ -173,6 +180,22 @@ private[sql] class GroupStateImpl[S] private(
throw QueryExecutionErrors.cannotSetTimeoutTimestampError()
}
}

/**
 * Renders the metadata of this state as a compact JSON string.
 *
 * The state value is not part of the output: `optionalValue` is always written as JSON null
 * because the value is serialized separately. The timeout configuration is written as the
 * simple class name of `timeoutConf` with dollar signs stripped.
 */
private[sql] def json(): String = {
  val timeoutConfName = Utils.stripDollars(Utils.getSimpleName(timeoutConf.getClass))

  val fields = List[(String, JValue)](
    // Constructor
    "optionalValue" -> JNull, // Note that optionalValue will be manually serialized.
    "batchProcessingTimeMs" -> JLong(batchProcessingTimeMs),
    "eventTimeWatermarkMs" -> JLong(eventTimeWatermarkMs),
    "timeoutConf" -> JString(timeoutConfName),
    "hasTimedOut" -> JBool(hasTimedOut),
    "watermarkPresent" -> JBool(watermarkPresent),

    // Internal state
    "defined" -> JBool(defined),
    "updated" -> JBool(updated),
    "removed" -> JBool(removed),
    "timeoutTimestamp" -> JLong(timeoutTimestamp)
  )

  compact(render(new JObject(fields)))
}
}


Expand Down Expand Up @@ -214,4 +237,35 @@ private[sql] object GroupStateImpl {
hasTimedOut = false,
watermarkPresent)
}

/**
 * Maps the name of a timeout configuration to the corresponding [[GroupStateTimeout]].
 *
 * @param clazz one of "NoTimeout", "EventTimeTimeout" or "ProcessingTimeTimeout"
 * @throws IllegalStateException if `clazz` is none of the names above
 */
def groupStateTimeoutFromString(clazz: String): GroupStateTimeout = {
  clazz match {
    case "NoTimeout" => GroupStateTimeout.NoTimeout
    case "EventTimeTimeout" => GroupStateTimeout.EventTimeTimeout
    case "ProcessingTimeTimeout" => GroupStateTimeout.ProcessingTimeTimeout
    case unknown =>
      throw new IllegalStateException(s"Invalid string for GroupStateTimeout: $unknown")
  }
}

/**
 * Rebuilds a [[GroupStateImpl]] from the given state value and a JSON object carrying the
 * remaining fields.
 *
 * The JSON object is read as a map and must contain the keys "batchProcessingTimeMs",
 * "eventTimeWatermarkMs", "timeoutConf", "hasTimedOut", "watermarkPresent", "defined",
 * "updated", "removed" and "timeoutTimestamp". The state value is taken from `value`, not
 * from the JSON.
 *
 * @param value the state value to put into the new instance
 * @param json the JSON object holding the constructor arguments and internal state
 */
def fromJson[S](value: Option[S], json: JValue): GroupStateImpl[S] = {
  implicit val formats: Formats = DefaultFormats

  val fields = json.extract[Map[String, Any]]

  // Small typed accessors over the untyped map extracted from the JSON object.
  def longField(name: String): Long = fields(name).asInstanceOf[Number].longValue()
  def boolField(name: String): Boolean = fields(name).asInstanceOf[Boolean]

  // Constructor
  val batchTimeMs = longField("batchProcessingTimeMs")
  val watermarkMs = longField("eventTimeWatermarkMs")
  val timeout = groupStateTimeoutFromString(fields("timeoutConf").asInstanceOf[String])
  val timedOut = boolField("hasTimedOut")
  val hasWatermark = boolField("watermarkPresent")

  val restored = new GroupStateImpl[S](
    value, batchTimeMs, watermarkMs, timeout, timedOut, hasWatermark)

  // Internal state
  restored.defined = boolField("defined")
  restored.updated = boolField("updated")
  restored.removed = boolField("removed")
  restored.timeoutTimestamp = longField("timeoutTimestamp")

  restored
}
}