diff --git a/python/ray/workflow/api.py b/python/ray/workflow/api.py
index 32c7670aba0d..14989712c3b6 100644
--- a/python/ray/workflow/api.py
+++ b/python/ray/workflow/api.py
@@ -99,6 +99,9 @@ def step(*args, **kwargs):
     name = kwargs.pop("name", None)
     if name is not None:
         step_options["name"] = name
+    metadata = kwargs.pop("metadata", None)
+    if metadata is not None:
+        step_options["metadata"] = metadata
     if len(kwargs) != 0:
         step_options["ray_options"] = kwargs
     return make_step_decorator(step_options)
diff --git a/python/ray/workflow/common.py b/python/ray/workflow/common.py
index 169b318ed5d7..eb648e2b1cc6 100644
--- a/python/ray/workflow/common.py
+++ b/python/ray/workflow/common.py
@@ -128,6 +128,8 @@ class WorkflowData:
     ray_options: Dict[str, Any]
     # name of the step
     name: str
+    # user metadata to store with the step
+    user_metadata: Dict[str, Any]

     def to_metadata(self) -> Dict[str, Any]:
         f = self.func_body
@@ -139,6 +141,7 @@ def to_metadata(self) -> Dict[str, Any]:
             "workflow_refs": [wr.step_id for wr in self.inputs.workflow_refs],
             "catch_exceptions": self.catch_exceptions,
             "ray_options": self.ray_options,
+            "user_metadata": self.user_metadata,
         }
         return metadata

@@ -261,7 +264,9 @@ def __reduce__(self):
                 "remote, or stored in Ray objects.")

     @PublicAPI(stability="beta")
-    def run(self, workflow_id: Optional[str] = None) -> Any:
+    def run(self,
+            workflow_id: Optional[str] = None,
+            metadata: Optional[Dict[str, Any]] = None) -> Any:
         """Run a workflow.

         If the workflow with the given id already exists, it will be resumed.
@@ -288,11 +293,18 @@ def run(self, workflow_id: Optional[str] = None) -> Any:
         Args:
             workflow_id: A unique identifier that can be used to resume the
                 workflow. If not specified, a random id will be generated.
+            metadata: Metadata to attach to the workflow. The values must be
+                JSON serializable.
+
+        Returns:
+            The result of the workflow execution.
         """
-        return ray.get(self.run_async(workflow_id))
+        return ray.get(self.run_async(workflow_id, metadata))

     @PublicAPI(stability="beta")
-    def run_async(self, workflow_id: Optional[str] = None) -> ObjectRef:
+    def run_async(self,
+                  workflow_id: Optional[str] = None,
+                  metadata: Optional[Dict[str, Any]] = None) -> ObjectRef:
         """Run a workflow asynchronously.

         If the workflow with the given id already exists, it will be resumed.
@@ -319,8 +331,14 @@ def run_async(self, workflow_id: Optional[str] = None) -> ObjectRef:
         Args:
             workflow_id: A unique identifier that can be used to resume the
                 workflow. If not specified, a random id will be generated.
+            metadata: Metadata to attach to the workflow. The values must be
+                JSON serializable.
+
+        Returns:
+            The workflow result as a ray.ObjectRef.
         """
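For context, here is a minimal usage sketch of the API added above (the `metadata=` argument on `@workflow.step` and on `run()`). The step, workflow id, and key names are illustrative, and it assumes Ray and a workflow storage have already been set up via `workflow.init()`:

```python
from ray import workflow

workflow.init()  # assumes Ray and a workflow storage are configured


@workflow.step(name="ingest", metadata={"owner": "data-team"})
def ingest():
    return 0


# Workflow-level metadata is passed at run time; all values must be
# JSON serializable (see the validation added in execution.py below).
result = ingest.step().run("ingest_run", metadata={"priority": "high"})
```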
+ """ # TODO(suquark): avoid cyclic importing from ray.workflow.execution import run self._step_id = None - return run(self, workflow_id) + return run(self, workflow_id, metadata) diff --git a/python/ray/workflow/execution.py b/python/ray/workflow/execution.py index b660c65fa8a9..58168cb049b0 100644 --- a/python/ray/workflow/execution.py +++ b/python/ray/workflow/execution.py @@ -1,11 +1,11 @@ import asyncio +import json import logging import time -from typing import Set, List, Tuple, Optional, TYPE_CHECKING +from typing import Set, List, Tuple, Optional, TYPE_CHECKING, Dict import uuid import ray - from ray.workflow import workflow_context from ray.workflow import workflow_storage from ray.workflow.common import (Workflow, WorkflowStatus, WorkflowMetaData, @@ -23,9 +23,21 @@ def run(entry_workflow: Workflow, - workflow_id: Optional[str] = None) -> ray.ObjectRef: + workflow_id: Optional[str] = None, + metadata: Optional[Dict] = None) -> ray.ObjectRef: """Run a workflow asynchronously. """ + if metadata is not None: + if not isinstance(metadata, dict): + raise ValueError("metadata must be a dict.") + for k, v in metadata.items(): + try: + json.dumps(v) + except TypeError as e: + raise ValueError("metadata values must be JSON serializable, " + "however '{}' has a value whose {}.".format( + k, e)) + store = get_global_storage() assert ray.is_initialized() if workflow_id is None: @@ -40,6 +52,7 @@ def run(entry_workflow: Workflow, store.storage_url): # checkpoint the workflow ws = workflow_storage.get_workflow_storage(workflow_id) + ws.save_workflow_user_metadata(metadata) wf_exists = True try: diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py index b5416b5a4021..42b22a3080df 100644 --- a/python/ray/workflow/step_executor.py +++ b/python/ray/workflow/step_executor.py @@ -1,3 +1,4 @@ +import time import asyncio from dataclasses import dataclass import logging @@ -179,6 +180,9 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage, # TODO (Alex): Handle the json case better? 
diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py
index b5416b5a4021..42b22a3080df 100644
--- a/python/ray/workflow/step_executor.py
+++ b/python/ray/workflow/step_executor.py
@@ -1,3 +1,4 @@
+import time
 import asyncio
 from dataclasses import dataclass
 import logging
@@ -179,6 +180,9 @@ async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage,
             # TODO (Alex): Handle the json case better?
             wf_storage._put(
                 wf_storage._key_step_input_metadata(step_id), metadata, True),
+            wf_storage._put(
+                wf_storage._key_step_user_metadata(step_id),
+                inputs.user_metadata, True),
             serialization.dump_to_storage(
                 wf_storage._key_step_function_body(step_id), inputs.func_body,
                 workflow_id, storage),
@@ -326,9 +330,13 @@ def _workflow_step_executor(step_type: StepType, func: Callable,
     args, kwargs = _resolve_step_inputs(baked_inputs)
     store = workflow_storage.get_workflow_storage()
     try:
+        step_prerun_metadata = {"start_time": time.time()}
+        store.save_step_prerun_metadata(step_id, step_prerun_metadata)
         persisted_output, volatile_output = _wrap_run(
             func, step_type, step_id, catch_exceptions, max_retries, *args,
             **kwargs)
+        step_postrun_metadata = {"end_time": time.time()}
+        store.save_step_postrun_metadata(step_id, step_postrun_metadata)
     except Exception as e:
         commit_step(store, step_id, None, e)
         raise e
diff --git a/python/ray/workflow/step_function.py b/python/ray/workflow/step_function.py
index e2a4467f189b..671b804aa6f3 100644
--- a/python/ray/workflow/step_function.py
+++ b/python/ray/workflow/step_function.py
@@ -1,5 +1,6 @@
 import functools
-from typing import Callable
+import json
+from typing import Callable, Dict, Any

 from ray._private import signature
 from ray.workflow import serialization_context
@@ -16,11 +17,22 @@ def __init__(self,
                  max_retries=3,
                  catch_exceptions=False,
                  name=None,
+                 metadata=None,
                  ray_options=None):
         if not isinstance(max_retries, int) or max_retries < 1:
             raise ValueError("max_retries should be greater or equal to 1.")
         if ray_options is not None and not isinstance(ray_options, dict):
             raise ValueError("ray_options must be a dict.")
+        if metadata is not None:
+            if not isinstance(metadata, dict):
+                raise ValueError("metadata must be a dict.")
+            for k, v in metadata.items():
+                try:
+                    json.dumps(v)
+                except TypeError as e:
+                    raise ValueError(
+                        "metadata values must be JSON serializable; key "
+                        "'{}' has a non-serializable value: {}".format(k, e))

         self._func = func
         self._max_retries = max_retries
@@ -28,6 +40,7 @@ def __init__(self,
         self._ray_options = ray_options or {}
         self._func_signature = signature.extract_signature(func)
         self._name = name or ""
+        self._user_metadata = metadata or {}

         # Override signature and docstring
         @functools.wraps(func)
@@ -48,7 +61,7 @@ def prepare_inputs():
                 catch_exceptions=self._catch_exceptions,
                 ray_options=self._ray_options,
                 name=self._name,
-            )
+                user_metadata=self._user_metadata)
             return Workflow(workflow_data, prepare_inputs)

         self.step = _build_workflow
@@ -64,6 +77,7 @@ def options(self,
                 max_retries: int = 3,
                 catch_exceptions: bool = False,
                 name: str = None,
+                metadata: Dict[str, Any] = None,
                 **ray_options) -> "WorkflowStepFunction":
         """This function set how the step function is going to be executed.
@@ -79,6 +93,7 @@ def options(self,
                 generate the step_id of the step. The name will be used
                 directly as the step id if possible, otherwise deduplicated by
                 appending .N suffixes.
+            metadata: JSON-serializable metadata to attach to the step.
             **ray_options: All parameters in this fields will be passed to ray
                 remote function options.
@@ -86,4 +101,4 @@ def options(self,
             The step function itself.
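Step-level metadata can also be supplied through `.options()`, which is the pattern exercised by `test_all_metadata` below; a brief sketch with illustrative names:

```python
from ray import workflow


@workflow.step
def fetch():
    return 0


# .options() returns a new WorkflowStepFunction with the given name and
# user metadata attached; both are checkpointed with the step's inputs.
fetch.options(name="fetch_step", metadata={"team": "infra"}).step().run(
    "fetch_run")
```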
""" return WorkflowStepFunction(self._func, max_retries, catch_exceptions, - name, ray_options) + name, metadata, ray_options) diff --git a/python/ray/workflow/tests/test_metadata_put.py b/python/ray/workflow/tests/test_metadata_put.py new file mode 100644 index 000000000000..6d7653858186 --- /dev/null +++ b/python/ray/workflow/tests/test_metadata_put.py @@ -0,0 +1,128 @@ +import asyncio + +from ray import workflow +from ray.tests.conftest import * # noqa +from ray.workflow import workflow_storage +from ray.workflow.storage import get_global_storage + +import pytest + + +def get_metadata(paths, is_json=True): + store = get_global_storage() + key = store.make_key(*paths) + return asyncio.get_event_loop().run_until_complete(store.get(key, is_json)) + + +def test_step_user_metadata(workflow_start_regular): + + metadata = {"k1": "v1"} + step_name = "simple_step" + workflow_id = "simple" + + @workflow.step(name=step_name, metadata=metadata) + def simple(): + return 0 + + simple.step().run(workflow_id) + + checkpointed_metadata = get_metadata( + [workflow_id, "steps", step_name, workflow_storage.STEP_USER_METADATA]) + assert metadata == checkpointed_metadata + + +def test_step_runtime_metadata(workflow_start_regular): + + step_name = "simple_step" + workflow_id = "simple" + + @workflow.step(name=step_name) + def simple(): + return 0 + + simple.step().run(workflow_id) + + prerun_meta = get_metadata([ + workflow_id, "steps", step_name, workflow_storage.STEP_PRERUN_METADATA + ]) + postrun_meta = get_metadata([ + workflow_id, "steps", step_name, workflow_storage.STEP_POSTRUN_METADATA + ]) + assert "start_time" in prerun_meta + assert "end_time" in postrun_meta + + +def test_workflow_user_metadata(workflow_start_regular): + + metadata = {"k1": "v1"} + workflow_id = "simple" + + @workflow.step + def simple(): + return 0 + + simple.step().run(workflow_id, metadata=metadata) + + checkpointed_metadata = get_metadata( + [workflow_id, workflow_storage.WORKFLOW_USER_METADATA]) + assert metadata == checkpointed_metadata + + +def test_workflow_runtime_metadata(workflow_start_regular): + + workflow_id = "simple" + + @workflow.step + def simple(): + return 0 + + simple.step().run(workflow_id) + + prerun_meta = get_metadata( + [workflow_id, workflow_storage.WORKFLOW_PRERUN_METADATA]) + postrun_meta = get_metadata( + [workflow_id, workflow_storage.WORKFLOW_POSTRUN_METADATA]) + assert "start_time" in prerun_meta + assert "end_time" in postrun_meta + + +def test_all_metadata(workflow_start_regular): + + user_step_metadata = {"k1": "v1"} + user_run_metadata = {"k2": "v2"} + step_name = "simple_step" + workflow_id = "simple" + + @workflow.step + def simple(): + return 0 + + simple.options( + name=step_name, metadata=user_step_metadata).step().run( + workflow_id, metadata=user_run_metadata) + + checkpointed_user_step_metadata = get_metadata( + [workflow_id, "steps", step_name, workflow_storage.STEP_USER_METADATA]) + checkpointed_user_run_metadata = get_metadata( + [workflow_id, workflow_storage.WORKFLOW_USER_METADATA]) + checkpointed_pre_step_meta = get_metadata([ + workflow_id, "steps", step_name, workflow_storage.STEP_PRERUN_METADATA + ]) + checkpointed_post_step_meta = get_metadata([ + workflow_id, "steps", step_name, workflow_storage.STEP_POSTRUN_METADATA + ]) + checkpointed_pre_run_meta = get_metadata( + [workflow_id, workflow_storage.WORKFLOW_PRERUN_METADATA]) + checkpointed_post_run_meta = get_metadata( + [workflow_id, workflow_storage.WORKFLOW_POSTRUN_METADATA]) + assert user_step_metadata == 
diff --git a/python/ray/workflow/virtual_actor_class.py b/python/ray/workflow/virtual_actor_class.py
index 3f0f4368ecdc..c7c03d2bc301 100644
--- a/python/ray/workflow/virtual_actor_class.py
+++ b/python/ray/workflow/virtual_actor_class.py
@@ -217,6 +217,7 @@ def step(method_name, method, *args, **kwargs):
             catch_exceptions=False,
             ray_options={},
             name=None,
+            user_metadata=None,
         )
         wf = Workflow(workflow_data)
         return wf
diff --git a/python/ray/workflow/workflow_access.py b/python/ray/workflow/workflow_access.py
index c1b5d78d253a..c6591041c786 100644
--- a/python/ray/workflow/workflow_access.py
+++ b/python/ray/workflow/workflow_access.py
@@ -1,4 +1,5 @@
 import logging
+import time
 from typing import Any, Dict, List, Tuple, Optional, TYPE_CHECKING

 from dataclasses import dataclass
@@ -169,6 +170,8 @@ def run_or_resume(self, workflow_id: str, ignore_existing: bool = False
             raise RuntimeError(f"The output of workflow[id={workflow_id}] "
                                "already exists.")
         wf_store = workflow_storage.WorkflowStorage(workflow_id, self._store)
+        workflow_prerun_metadata = {"start_time": time.time()}
+        wf_store.save_workflow_prerun_metadata(workflow_prerun_metadata)
         step_id = wf_store.get_entrypoint_step_id()
         try:
             current_output = self._workflow_outputs[workflow_id].output
@@ -229,6 +232,8 @@ def update_step_status(self, workflow_id: str, step_id: str,
             wf_store.save_workflow_meta(
                 common.WorkflowMetaData(common.WorkflowStatus.SUCCESSFUL))
             self._step_status.pop(workflow_id)
+            workflow_postrun_metadata = {"end_time": time.time()}
+            wf_store.save_workflow_postrun_metadata(workflow_postrun_metadata)

     def cancel_workflow(self, workflow_id: str) -> None:
         self._step_status.pop(workflow_id)
diff --git a/python/ray/workflow/workflow_storage.py b/python/ray/workflow/workflow_storage.py
index 5a188cca1a3f..a702d72b9ba5 100644
--- a/python/ray/workflow/workflow_storage.py
+++ b/python/ray/workflow/workflow_storage.py
@@ -29,6 +29,9 @@
 OBJECTS_DIR = "objects"
 STEPS_DIR = "steps"
 STEP_INPUTS_METADATA = "inputs.json"
+STEP_USER_METADATA = "user_step_metadata.json"
+STEP_PRERUN_METADATA = "pre_step_metadata.json"
+STEP_POSTRUN_METADATA = "post_step_metadata.json"
 STEP_OUTPUTS_METADATA = "outputs.json"
 STEP_ARGS = "args.pkl"
 STEP_OUTPUT = "output.pkl"
@@ -36,6 +39,9 @@
 STEP_FUNC_BODY = "func_body.pkl"
 CLASS_BODY = "class_body.pkl"
 WORKFLOW_META = "workflow_meta.json"
+WORKFLOW_USER_METADATA = "user_run_metadata.json"
+WORKFLOW_PRERUN_METADATA = "pre_run_metadata.json"
+WORKFLOW_POSTRUN_METADATA = "post_run_metadata.json"
 WORKFLOW_PROGRESS = "progress.json"
 # Without this counter, we're going to scan all steps to get the number of
 # steps with a given name. This can be very expensive if there are too
@@ -372,6 +378,76 @@ def save_actor_class_body(self, cls: type) -> None:
         """
         asyncio_run(self._put(self._key_class_body(), cls))

+    def save_step_prerun_metadata(self, step_id: StepID,
+                                  metadata: Dict[str, Any]):
+        """Save pre-run metadata of the current step.
+
+        Args:
+            step_id: ID of the workflow step.
+            metadata: pre-run metadata of the current step.
+
+        Raises:
+            DataSaveError: if we fail to save the pre-run metadata.
+ """ + + asyncio_run( + self._put(self._key_step_prerun_metadata(step_id), metadata, True)) + + def save_step_postrun_metadata(self, step_id: StepID, + metadata: Dict[str, Any]): + """Save post-run metadata of the current step. + + Args: + step_id: ID of the workflow step. + metadata: post-run metadata of the current step. + + Raises: + DataSaveError: if we fail to save the post-run metadata. + """ + + asyncio_run( + self._put( + self._key_step_postrun_metadata(step_id), metadata, True)) + + def save_workflow_user_metadata(self, metadata: Dict[str, Any]): + """Save user metadata of the current workflow. + + Args: + metadata: user metadata of the current workflow. + + Raises: + DataSaveError: if we fail to save the user metadata. + """ + + asyncio_run( + self._put(self._key_workflow_user_metadata(), metadata, True)) + + def save_workflow_prerun_metadata(self, metadata: Dict[str, Any]): + """Save pre-run metadata of the current workflow. + + Args: + metadata: pre-run metadata of the current workflow. + + Raises: + DataSaveError: if we fail to save the pre-run metadata. + """ + + asyncio_run( + self._put(self._key_workflow_prerun_metadata(), metadata, True)) + + def save_workflow_postrun_metadata(self, metadata: Dict[str, Any]): + """Save post-run metadata of the current workflow. + + Args: + metadata: post-run metadata of the current workflow. + + Raises: + DataSaveError: if we fail to save the post-run metadata. + """ + + asyncio_run( + self._put(self._key_workflow_postrun_metadata(), metadata, True)) + def save_workflow_meta(self, metadata: WorkflowMetaData) -> None: """Save the metadata of the current workflow. @@ -517,6 +593,15 @@ def _key_workflow_progress(self): def _key_step_input_metadata(self, step_id): return [self._workflow_id, STEPS_DIR, step_id, STEP_INPUTS_METADATA] + def _key_step_user_metadata(self, step_id): + return [self._workflow_id, STEPS_DIR, step_id, STEP_USER_METADATA] + + def _key_step_prerun_metadata(self, step_id): + return [self._workflow_id, STEPS_DIR, step_id, STEP_PRERUN_METADATA] + + def _key_step_postrun_metadata(self, step_id): + return [self._workflow_id, STEPS_DIR, step_id, STEP_POSTRUN_METADATA] + def _key_step_output(self, step_id): return [self._workflow_id, STEPS_DIR, step_id, STEP_OUTPUT] @@ -544,6 +629,15 @@ def _key_class_body(self): def _key_workflow_metadata(self): return [self._workflow_id, WORKFLOW_META] + def _key_workflow_user_metadata(self): + return [self._workflow_id, WORKFLOW_USER_METADATA] + + def _key_workflow_prerun_metadata(self): + return [self._workflow_id, WORKFLOW_PRERUN_METADATA] + + def _key_workflow_postrun_metadata(self): + return [self._workflow_id, WORKFLOW_POSTRUN_METADATA] + def _key_num_steps_with_name(self, name): return [self._workflow_id, DUPLICATE_NAME_COUNTER, name]