Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flink-python/pyflink/datastream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@
KeyedCoProcessFunction, AggregateFunction, WindowFunction,
ProcessWindowFunction, BroadcastProcessFunction,
KeyedBroadcastProcessFunction, AsyncFunction,
AsyncRetryPredicate, AsyncRetryStrategy)
AsyncBatchFunction, AsyncRetryPredicate, AsyncRetryStrategy)
from pyflink.datastream.slot_sharing_group import SlotSharingGroup, MemorySize
from pyflink.datastream.state_backend import (StateBackend, CustomStateBackend,
PredefinedOptions, HashMapStateBackend,
Expand Down Expand Up @@ -314,6 +314,7 @@
'BroadcastProcessFunction',
'KeyedBroadcastProcessFunction',
'AsyncFunction',
'AsyncBatchFunction',
'RuntimeContext',
'TimerService',
'CheckpointingMode',
Expand Down
165 changes: 163 additions & 2 deletions flink-python/pyflink/datastream/async_data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@
from pyflink.common import Time, TypeInformation
from pyflink.datastream import async_retry_strategies
from pyflink.datastream.data_stream import DataStream, _get_one_input_stream_operator
from pyflink.datastream.functions import AsyncFunctionDescriptor, AsyncFunction, AsyncRetryStrategy
from pyflink.datastream.functions import (AsyncFunctionDescriptor, AsyncFunction,
AsyncRetryStrategy, AsyncBatchFunction,
AsyncBatchFunctionDescriptor)
from pyflink.java_gateway import get_gateway
from pyflink.util.java_utils import get_j_env_configuration


class AsyncDataStream(object):
"""
A helper class to apply :class:`~AsyncFunction` to a data stream.
A helper class to apply :class:`~AsyncFunction` or :class:`~AsyncBatchFunction`
to a data stream.

.. versionchanged:: 2.1.0
Added batch async support via :func:`unordered_wait_batch` and
:func:`ordered_wait_batch`.
"""

@staticmethod
Expand Down Expand Up @@ -89,6 +95,78 @@ def unordered_wait_with_retry(
j_output_type_info,
j_python_data_stream_function_operator))

@staticmethod
def unordered_wait_batch(
        data_stream: DataStream,
        async_batch_function: AsyncBatchFunction,
        timeout: Time,
        batch_size: int,
        batch_timeout: Time = None,
        capacity: int = 100,
        output_type: TypeInformation = None) -> 'DataStream':
    """
    Applies an async batch function to the data stream using batch-oriented async
    execution. Output records are not guaranteed to keep the order of the input records.

    Intended for AI/ML inference and other high-latency external service calls where
    batching can significantly improve throughput. The Java-side AsyncBatchWaitOperator
    groups input elements into batches according to ``batch_size`` and ``batch_timeout``
    and then invokes the Python async batch function on each batch.

    Example usage::

        class MyAsyncBatchFunction(AsyncBatchFunction):

            async def async_invoke_batch(self, inputs: List[Row]) -> List[int]:
                # Process batch of inputs together (e.g., ML model inference)
                results = await model.predict_batch(inputs)
                return results

        ds = AsyncDataStream.unordered_wait_batch(
            ds,
            MyAsyncBatchFunction(),
            timeout=Time.seconds(10),
            batch_size=32,
            batch_timeout=Time.milliseconds(100),
            output_type=Types.INT()
        )

    :param data_stream: The input data stream.
    :param async_batch_function: The async batch function to apply.
    :param timeout: The overall timeout for asynchronous operations.
    :param batch_size: Maximum number of elements collected into one batch before the
                       async function is triggered. Must be positive.
    :param batch_timeout: Maximum time to wait before a partial batch is flushed.
                          If None, only batch_size triggers flushing.
    :param capacity: The max number of async operations that can be in-flight.
    :param output_type: The output data type.
    :return: The transformed DataStream.

    .. versionadded:: 2.1.0

    .. note:: This is a :class:`PublicEvolving` API and may change in future versions.
    """
    # Reject invalid arguments before anything is built.
    AsyncDataStream._validate_batch(
        data_stream, async_batch_function, timeout, batch_size, batch_timeout)

    from pyflink.fn_execution import flink_fn_execution_pb2

    # Bundle the user function with its batching configuration; UNORDERED selects the
    # output mode of this variant.
    descriptor = AsyncBatchFunctionDescriptor(
        async_batch_function,
        timeout,
        batch_size,
        batch_timeout,
        capacity,
        AsyncBatchFunctionDescriptor.OutputMode.UNORDERED)
    function_type = \
        flink_fn_execution_pb2.UserDefinedDataStreamFunction.ASYNC_BATCH  # type: ignore

    j_operator, j_type_info = _get_one_input_stream_operator(
        data_stream, descriptor, function_type, output_type)
    j_transformed = data_stream._j_data_stream.transform(
        "async batch wait operator", j_type_info, j_operator)
    return DataStream(j_transformed)

@staticmethod
def ordered_wait(
data_stream: DataStream,
Expand Down Expand Up @@ -148,6 +226,61 @@ def ordered_wait_with_retry(
j_output_type_info,
j_python_data_stream_function_operator))

@staticmethod
def ordered_wait_batch(
        data_stream: DataStream,
        async_batch_function: AsyncBatchFunction,
        timeout: Time,
        batch_size: int,
        batch_timeout: Time = None,
        capacity: int = 100,
        output_type: TypeInformation = None) -> 'DataStream':
    """
    Applies an async batch function to the data stream using batch-oriented async
    execution. Output records are emitted in the same order as the input records.

    Intended for AI/ML inference and other high-latency external service calls where
    batching can significantly improve throughput while output order must be kept.
    The Java-side AsyncBatchWaitOperator groups input elements into batches according
    to ``batch_size`` and ``batch_timeout``, invokes the Python async batch function on
    each batch, and emits the results in input order.

    :param data_stream: The input data stream.
    :param async_batch_function: The async batch function to apply.
    :param timeout: The overall timeout for asynchronous operations.
    :param batch_size: Maximum number of elements collected into one batch before the
                       async function is triggered. Must be positive.
    :param batch_timeout: Maximum time to wait before a partial batch is flushed.
                          If None, only batch_size triggers flushing.
    :param capacity: The max number of async operations that can be in-flight.
    :param output_type: The output data type.
    :return: The transformed DataStream.

    .. versionadded:: 2.1.0

    .. note:: This is a :class:`PublicEvolving` API and may change in future versions.
    """
    # Reject invalid arguments before anything is built.
    AsyncDataStream._validate_batch(
        data_stream, async_batch_function, timeout, batch_size, batch_timeout)

    from pyflink.fn_execution import flink_fn_execution_pb2

    # Bundle the user function with its batching configuration; ORDERED selects the
    # output mode of this variant.
    descriptor = AsyncBatchFunctionDescriptor(
        async_batch_function,
        timeout,
        batch_size,
        batch_timeout,
        capacity,
        AsyncBatchFunctionDescriptor.OutputMode.ORDERED)
    function_type = \
        flink_fn_execution_pb2.UserDefinedDataStreamFunction.ASYNC_BATCH  # type: ignore

    j_operator, j_type_info = _get_one_input_stream_operator(
        data_stream, descriptor, function_type, output_type)
    j_transformed = data_stream._j_data_stream.transform(
        "async batch wait operator (ordered)", j_type_info, j_operator)
    return DataStream(j_transformed)

@staticmethod
def _validate(data_stream: DataStream, async_function: AsyncFunction,
timeout: Time, async_retry_strategy: AsyncRetryStrategy) -> None:
Expand All @@ -170,3 +303,31 @@ def _validate(data_stream: DataStream, async_function: AsyncFunction,
j_conf.get(gateway.jvm.org.apache.flink.python.PythonOptions.PYTHON_EXECUTION_MODE))
if python_execution_mode == 'thread':
raise Exception("AsyncFunction is still not supported for 'thread' mode.")

@staticmethod
def _validate_batch(data_stream: DataStream, async_batch_function: AsyncBatchFunction,
                    timeout: Time, batch_size: int, batch_timeout: Time) -> None:
    """
    Checks the arguments of an async batch operation, raising on the first violation.

    The checks run in this order:

    1. ``async_invoke_batch`` of the function must be a coroutine function.
    2. ``batch_size`` must be greater than zero.
    3. ``timeout`` must be set and greater than zero milliseconds.
    4. ``batch_timeout``, when given, must not be negative.
    5. The configured Python execution mode must not be ``'thread'``.

    :param data_stream: The input data stream; its execution environment supplies the
                        configuration used for the execution mode check.
    :param async_batch_function: The async batch function to check.
    :param timeout: The overall timeout for asynchronous operations.
    :param batch_size: Maximum number of elements per batch.
    :param batch_timeout: Maximum time to wait before flushing a partial batch, or None.
    :raises Exception: If any of the checks above fails.
    """
    batch_method = async_batch_function.async_invoke_batch
    if not inspect.iscoroutinefunction(batch_method):
        raise Exception(
            "Method 'async_invoke_batch' of class '%s' should be declared as 'async def'."
            % type(async_batch_function))

    if batch_size <= 0:
        raise Exception("Batch size must be positive, got: %d" % batch_size)

    if timeout is None or timeout.to_milliseconds() <= 0:
        raise Exception("Timeout must be positive for async batch operations.")

    # A batch timeout is optional; zero is accepted, only negative values are rejected.
    if batch_timeout is not None:
        batch_timeout_ms = batch_timeout.to_milliseconds()
        if batch_timeout_ms < 0:
            raise Exception("Batch timeout cannot be negative, got: %d ms" % batch_timeout_ms)

    # Look up the Python execution mode in the environment configuration.
    jvm_gateway = get_gateway()
    j_env = data_stream._j_data_stream.getExecutionEnvironment()
    j_configuration = get_j_env_configuration(j_env)
    mode_option = jvm_gateway.jvm.org.apache.flink.python.PythonOptions.PYTHON_EXECUTION_MODE
    if j_configuration.get(mode_option) == 'thread':
        raise Exception("AsyncBatchFunction is still not supported for 'thread' mode.")
119 changes: 119 additions & 0 deletions flink-python/pyflink/datastream/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
'KeyedBroadcastProcessFunction',
'AsyncFunction',
'AsyncFunctionDescriptor',
'AsyncBatchFunction',
'AsyncBatchFunctionDescriptor',
'AsyncRetryPredicate',
'AsyncRetryStrategy',
]
Expand Down Expand Up @@ -1025,6 +1027,123 @@ def __init__(self, async_function, timeout, capacity, async_retry_strategy, outp
self.output_mode = output_mode


class AsyncBatchFunction(Function, Generic[IN, OUT]):
    """
    A function that triggers an async I/O operation for a whole batch of elements.

    It targets AI/ML inference and other high-latency external service calls where
    batching can significantly improve throughput.

    Where :class:`AsyncFunction` handles one element at a time, this function is handed
    a batch of input elements and processes them together. This is particularly
    beneficial for:

    - Machine learning model inference where batching improves GPU utilization
    - External service calls that support batch APIs
    - Database queries that can be batched for efficiency

    One async operation is triggered per batch via :func:`async_invoke_batch`.
    Batches are formed by the Java-side AsyncBatchWaitOperator according to the
    configured batch size and timeout parameters.

    Example usage::

        class MyAsyncBatchFunction(AsyncBatchFunction):

            async def async_invoke_batch(self, inputs: List[Row]) -> List[int]:
                # Process batch of inputs together
                results = []
                for value in inputs:
                    # Simulate async processing
                    await asyncio.sleep(0.1)
                    results.append(value[0] + value[1])
                return results

        # Apply to data stream
        ds = AsyncDataStream.unordered_wait_batch(
            ds, MyAsyncBatchFunction(),
            timeout=Time.seconds(10),
            batch_size=32,
            batch_timeout=Time.milliseconds(100)
        )

    .. versionadded:: 2.1.0

    .. note:: This is a :class:`PublicEvolving` API and may change in future versions.
    """

    @abstractmethod
    async def async_invoke_batch(self, inputs: List[IN]) -> List[OUT]:
        """
        Trigger the async operation for one batch of stream inputs.

        The runtime forms the batch from the configured batch_size and batch_timeout.
        Implementations should process every input of the batch and return the results
        as a list.

        Important notes:
        - The returned list should have the same length as the input list
        - Each result corresponds to the input at the same index
        - In case of a user code error, you can raise an exception to make the task fail
          and trigger the fail-over process

        :param inputs: List of input elements collected into a batch.
        :return: List of output elements, one for each input element.
        """

    def timeout_batch(self, inputs: List[IN]) -> List[OUT]:
        """
        Called when :func:`async_invoke_batch` times out.

        The default implementation raises a :class:`TimeoutError`. Override it to
        handle timeouts differently, for example by returning default values.

        :param inputs: The batch of input elements that timed out.
        :return: List of output elements to emit as fallback results.
        :raises TimeoutError: Always, unless overridden.
        """
        message = "Async batch function call has timed out for inputs: " + str(inputs)
        raise TimeoutError(message)


class AsyncBatchFunctionDescriptor(object):
    """
    Holds an AsyncBatchFunction together with its configuration.

    Used internally to hand the batch function and its parameters to the Python
    worker for execution.

    .. versionadded:: 2.1.0
    """

    class OutputMode(Enum):
        # Whether results are emitted in input order (ORDERED) or not (UNORDERED).
        ORDERED = 0
        UNORDERED = 1

    def __init__(self,
                 async_batch_function: AsyncBatchFunction,
                 timeout,
                 batch_size: int,
                 batch_timeout,
                 capacity: int,
                 output_mode: 'AsyncBatchFunctionDescriptor.OutputMode'):
        """
        Creates a new AsyncBatchFunctionDescriptor; all arguments are stored as given.

        :param async_batch_function: The AsyncBatchFunction to execute.
        :param timeout: The overall timeout for async operations.
        :param batch_size: Maximum number of elements to batch together.
        :param batch_timeout: Maximum time to wait before flushing a partial batch.
        :param capacity: The max number of async operations that can be triggered.
        :param output_mode: Whether to emit results in order or unordered.
        """
        # The user function and how its results are emitted.
        self.async_batch_function = async_batch_function
        self.output_mode = output_mode

        # Time limits.
        self.timeout = timeout
        self.batch_timeout = batch_timeout

        # Size limits.
        self.batch_size = batch_size
        self.capacity = capacity


class WindowFunction(Function, Generic[IN, OUT, KEY, W]):
"""
Base interface for functions that are evaluated over keyed (grouped) windows.
Expand Down
Loading