3 changes: 3 additions & 0 deletions python/ray/workflow/api.py
@@ -99,6 +99,9 @@ def step(*args, **kwargs):
name = kwargs.pop("name", None)
if name is not None:
step_options["name"] = name
metadata = kwargs.pop("metadata", None)
if metadata is not None:
step_options["metadata"] = metadata
if len(kwargs) != 0:
step_options["ray_options"] = kwargs
return make_step_decorator(step_options)
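
With the decorator plumbing above, a step can carry user-defined metadata. A minimal usage sketch (step and workflow names are illustrative; assumes Ray and workflow storage are initialized with defaults):

import ray
from ray import workflow

ray.init()
workflow.init()

@workflow.step(name="ingest", metadata={"owner": "data-team"})
def ingest() -> int:
    return 0

# The metadata dict is checkpointed next to the step's other inputs
# when the workflow runs.
ingest.step().run(workflow_id="ingest_wf")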
26 changes: 22 additions & 4 deletions python/ray/workflow/common.py
@@ -128,6 +128,8 @@ class WorkflowData:
ray_options: Dict[str, Any]
# name of the step
name: str
# metadata to store
user_metadata: Dict[str, Any]

def to_metadata(self) -> Dict[str, Any]:
f = self.func_body
@@ -139,6 +141,7 @@ def to_metadata(self) -> Dict[str, Any]:
"workflow_refs": [wr.step_id for wr in self.inputs.workflow_refs],
"catch_exceptions": self.catch_exceptions,
"ray_options": self.ray_options,
"user_metadata": self.user_metadata
}
return metadata

@@ -261,7 +264,9 @@ def __reduce__(self):
"remote, or stored in Ray objects.")

@PublicAPI(stability="beta")
def run(self, workflow_id: Optional[str] = None) -> Any:
def run(self,
workflow_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None) -> Any:
"""Run a workflow.

If the workflow with the given id already exists, it will be resumed.
@@ -288,11 +293,18 @@ def run(self, workflow_id: Optional[str] = None) -> Any:
Args:
workflow_id: A unique identifier that can be used to resume the
workflow. If not specified, a random id will be generated.
metadata: The metadata to add to the workflow. It has to be
JSON serializable.

Returns:
The running result.
"""
return ray.get(self.run_async(workflow_id))
return ray.get(self.run_async(workflow_id, metadata))

@PublicAPI(stability="beta")
def run_async(self, workflow_id: Optional[str] = None) -> ObjectRef:
def run_async(self,
workflow_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None) -> ObjectRef:
"""Run a workflow asynchronously.

If the workflow with the given id already exists, it will be resumed.
@@ -319,8 +331,14 @@ def run_async(self, workflow_id: Optional[str] = None) -> ObjectRef:
Args:
workflow_id: A unique identifier that can be used to resume the
workflow. If not specified, a random id will be generated.
metadata: The metadata to add to the workflow. It has to be
JSON serializable.

Returns:
The running result as ray.ObjectRef.

"""
# TODO(suquark): avoid cyclic importing
from ray.workflow.execution import run
self._step_id = None
return run(self, workflow_id)
return run(self, workflow_id, metadata)
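
The same hook exists at the workflow level: run() and run_async() forward metadata to execution.run(), which checkpoints it once per workflow. A hedged sketch, reusing the ingest step from the previous example:

# Blocking form: the metadata is checkpointed, then the result is awaited.
result = ingest.step().run("wf_1", metadata={"run_by": "alice"})

# Async form: identical metadata handling, but returns an ObjectRef.
ref = ingest.step().run_async("wf_2", metadata={"run_by": "alice"})
result = ray.get(ref)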
19 changes: 16 additions & 3 deletions python/ray/workflow/execution.py
@@ -1,11 +1,11 @@
import asyncio
import json
import logging
import time
from typing import Set, List, Tuple, Optional, TYPE_CHECKING
from typing import Set, List, Tuple, Optional, TYPE_CHECKING, Dict
import uuid

import ray

from ray.workflow import workflow_context
from ray.workflow import workflow_storage
from ray.workflow.common import (Workflow, WorkflowStatus, WorkflowMetaData,
@@ -23,9 +23,21 @@


def run(entry_workflow: Workflow,
workflow_id: Optional[str] = None) -> ray.ObjectRef:
workflow_id: Optional[str] = None,
metadata: Optional[Dict] = None) -> ray.ObjectRef:
"""Run a workflow asynchronously.
"""
if metadata is not None:
if not isinstance(metadata, dict):
raise ValueError("metadata must be a dict.")
for k, v in metadata.items():
try:
json.dumps(v)
except TypeError as e:
raise ValueError("metadata values must be JSON serializable, "
"however '{}' has a value whose {}.".format(
k, e))

store = get_global_storage()
assert ray.is_initialized()
if workflow_id is None:
@@ -40,6 +52,7 @@ def run(entry_workflow: Workflow,
store.storage_url):
# checkpoint the workflow
ws = workflow_storage.get_workflow_storage(workflow_id)
ws.save_workflow_user_metadata(metadata)

wf_exists = True
try:
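
Because every value is probed with json.dumps before anything is persisted, a bad metadata dict fails fast. An illustrative failure, again using the ingest step from above:

# Raises ValueError: a set is not JSON serializable, so the run is
# rejected before any checkpointing or execution happens.
ingest.step().run("wf_bad", metadata={"tags": {"a", "b"}})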
8 changes: 8 additions & 0 deletions python/ray/workflow/step_executor.py
@@ -1,3 +1,4 @@
import time
import asyncio
from dataclasses import dataclass
import logging
@@ -179,6 +180,9 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage,
# TODO (Alex): Handle the json case better?
wf_storage._put(
wf_storage._key_step_input_metadata(step_id), metadata, True),
wf_storage._put(
wf_storage._key_step_user_metadata(step_id), inputs.user_metadata,
True),
serialization.dump_to_storage(
wf_storage._key_step_function_body(step_id), inputs.func_body,
workflow_id, storage),
@@ -326,9 +330,13 @@ def _workflow_step_executor(step_type: StepType, func: Callable,
args, kwargs = _resolve_step_inputs(baked_inputs)
store = workflow_storage.get_workflow_storage()
try:
step_prerun_metadata = {"start_time": time.time()}
store.save_step_prerun_metadata(step_id, step_prerun_metadata)
persisted_output, volatile_output = _wrap_run(
func, step_type, step_id, catch_exceptions, max_retries, *args,
**kwargs)
step_postrun_metadata = {"end_time": time.time()}
store.save_step_postrun_metadata(step_id, step_postrun_metadata)
except Exception as e:
commit_step(store, step_id, None, e)
raise e
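
The executor now brackets each step with wall-clock timestamps (start_time before _wrap_run, end_time after), so a step's duration can be recovered from storage. A minimal sketch with hypothetical fetched values; the tests below show how these records are actually read back:

# Dicts as they would come back from STEP_PRERUN_METADATA and
# STEP_POSTRUN_METADATA (timestamps are illustrative).
prerun_meta = {"start_time": 1633000000.0}
postrun_meta = {"end_time": 1633000042.5}
duration_seconds = postrun_meta["end_time"] - prerun_meta["start_time"]  # 42.5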
21 changes: 18 additions & 3 deletions python/ray/workflow/step_function.py
@@ -1,5 +1,6 @@
import functools
from typing import Callable
import json
from typing import Callable, Dict, Any

from ray._private import signature
from ray.workflow import serialization_context
@@ -16,18 +17,30 @@ def __init__(self,
max_retries=3,
catch_exceptions=False,
name=None,
metadata=None,
ray_options=None):
if not isinstance(max_retries, int) or max_retries < 1:
raise ValueError("max_retries should be greater or equal to 1.")
if ray_options is not None and not isinstance(ray_options, dict):
raise ValueError("ray_options must be a dict.")
if metadata is not None:
if not isinstance(metadata, dict):
raise ValueError("metadata must be a dict.")
for k, v in metadata.items():
try:
json.dumps(v)
except TypeError as e:
raise ValueError(
"metadata values must be JSON serializable, "
"however '{}' has a value whose {}.".format(k, e))

self._func = func
self._max_retries = max_retries
self._catch_exceptions = catch_exceptions
self._ray_options = ray_options or {}
self._func_signature = signature.extract_signature(func)
self._name = name or ""
self._user_metadata = metadata or {}

# Override signature and docstring
@functools.wraps(func)
@@ -48,7 +61,7 @@ def prepare_inputs():
catch_exceptions=self._catch_exceptions,
ray_options=self._ray_options,
name=self._name,
)
user_metadata=self._user_metadata)
return Workflow(workflow_data, prepare_inputs)

self.step = _build_workflow
@@ -64,6 +77,7 @@ def options(self,
max_retries: int = 3,
catch_exceptions: bool = False,
name: str = None,
metadata: Dict[str, Any] = None,
**ray_options) -> "WorkflowStepFunction":
"""This function sets how the step function is going to be executed.

@@ -79,11 +93,12 @@ def options(self,
generate the step_id of the step. The name will be used
directly as the step id if possible, otherwise deduplicated by
appending .N suffixes.
metadata: Metadata to add to the step.
**ray_options: All parameters in this fields will be passed
to ray remote function options.

Returns:
The step function itself.
"""
return WorkflowStepFunction(self._func, max_retries, catch_exceptions,
name, ray_options)
name, metadata, ray_options)
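
Since WorkflowStepFunction.__init__ performs the same JSON check, an invalid dict passed through .options() is rejected at construction time, before .step() or .run() is ever called. Illustrative, reusing the ingest step from above:

# Raises ValueError immediately: object() cannot be JSON serialized.
ingest.options(name="bad_step", metadata={"handle": object()})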
128 changes: 128 additions & 0 deletions python/ray/workflow/tests/test_metadata_put.py
@@ -0,0 +1,128 @@
import asyncio

from ray import workflow
from ray.tests.conftest import * # noqa
from ray.workflow import workflow_storage
from ray.workflow.storage import get_global_storage

import pytest


def get_metadata(paths, is_json=True):
store = get_global_storage()
key = store.make_key(*paths)
return asyncio.get_event_loop().run_until_complete(store.get(key, is_json))


def test_step_user_metadata(workflow_start_regular):

metadata = {"k1": "v1"}
step_name = "simple_step"
workflow_id = "simple"

@workflow.step(name=step_name, metadata=metadata)
def simple():
return 0

simple.step().run(workflow_id)

checkpointed_metadata = get_metadata(
[workflow_id, "steps", step_name, workflow_storage.STEP_USER_METADATA])
assert metadata == checkpointed_metadata


def test_step_runtime_metadata(workflow_start_regular):

step_name = "simple_step"
workflow_id = "simple"

@workflow.step(name=step_name)
def simple():
return 0

simple.step().run(workflow_id)

prerun_meta = get_metadata([
workflow_id, "steps", step_name, workflow_storage.STEP_PRERUN_METADATA
])
postrun_meta = get_metadata([
workflow_id, "steps", step_name, workflow_storage.STEP_POSTRUN_METADATA
])
assert "start_time" in prerun_meta
assert "end_time" in postrun_meta


def test_workflow_user_metadata(workflow_start_regular):

metadata = {"k1": "v1"}
workflow_id = "simple"

@workflow.step
def simple():
return 0

simple.step().run(workflow_id, metadata=metadata)

checkpointed_metadata = get_metadata(
[workflow_id, workflow_storage.WORKFLOW_USER_METADATA])
assert metadata == checkpointed_metadata


def test_workflow_runtime_metadata(workflow_start_regular):

workflow_id = "simple"

@workflow.step
def simple():
return 0

simple.step().run(workflow_id)

prerun_meta = get_metadata(
[workflow_id, workflow_storage.WORKFLOW_PRERUN_METADATA])
postrun_meta = get_metadata(
[workflow_id, workflow_storage.WORKFLOW_POSTRUN_METADATA])
assert "start_time" in prerun_meta
assert "end_time" in postrun_meta


def test_all_metadata(workflow_start_regular):

user_step_metadata = {"k1": "v1"}
user_run_metadata = {"k2": "v2"}
step_name = "simple_step"
workflow_id = "simple"

@workflow.step
def simple():
return 0

simple.options(
name=step_name, metadata=user_step_metadata).step().run(
workflow_id, metadata=user_run_metadata)

checkpointed_user_step_metadata = get_metadata(
[workflow_id, "steps", step_name, workflow_storage.STEP_USER_METADATA])
checkpointed_user_run_metadata = get_metadata(
[workflow_id, workflow_storage.WORKFLOW_USER_METADATA])
checkpointed_pre_step_meta = get_metadata([
workflow_id, "steps", step_name, workflow_storage.STEP_PRERUN_METADATA
])
checkpointed_post_step_meta = get_metadata([
workflow_id, "steps", step_name, workflow_storage.STEP_POSTRUN_METADATA
])
checkpointed_pre_run_meta = get_metadata(
[workflow_id, workflow_storage.WORKFLOW_PRERUN_METADATA])
checkpointed_post_run_meta = get_metadata(
[workflow_id, workflow_storage.WORKFLOW_POSTRUN_METADATA])
assert user_step_metadata == checkpointed_user_step_metadata
assert user_run_metadata == checkpointed_user_run_metadata
assert "start_time" in checkpointed_pre_step_meta
assert "start_time" in checkpointed_pre_run_meta
assert "end_time" in checkpointed_post_step_meta
assert "end_time" in checkpointed_post_run_meta


if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))
1 change: 1 addition & 0 deletions python/ray/workflow/virtual_actor_class.py
@@ -217,6 +217,7 @@ def step(method_name, method, *args, **kwargs):
catch_exceptions=False,
ray_options={},
name=None,
user_metadata=None,
)
wf = Workflow(workflow_data)
return wf
5 changes: 5 additions & 0 deletions python/ray/workflow/workflow_access.py
@@ -1,4 +1,5 @@
import logging
import time
from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING

from dataclasses import dataclass
@@ -169,6 +170,8 @@ def run_or_resume(self, workflow_id: str, ignore_existing: bool = False
raise RuntimeError(f"The output of workflow[id={workflow_id}] "
"already exists.")
wf_store = workflow_storage.WorkflowStorage(workflow_id, self._store)
workflow_prerun_metadata = {"start_time": time.time()}
wf_store.save_workflow_prerun_metadata(workflow_prerun_metadata)
step_id = wf_store.get_entrypoint_step_id()
try:
current_output = self._workflow_outputs[workflow_id].output
@@ -229,6 +232,8 @@ def update_step_status(self, workflow_id: str, step_id: str,
wf_store.save_workflow_meta(
common.WorkflowMetaData(common.WorkflowStatus.SUCCESSFUL))
self._step_status.pop(workflow_id)
workflow_postrun_metadata = {"end_time": time.time()}
wf_store.save_workflow_postrun_metadata(workflow_postrun_metadata)

def cancel_workflow(self, workflow_id: str) -> None:
self._step_status.pop(workflow_id)