diff --git a/.style.yapf b/.style.yapf index 8572636c2268..255598756db1 100644 --- a/.style.yapf +++ b/.style.yapf @@ -1,4 +1,5 @@ [style] +based_on_style=chromium # Align closing bracket with visual indentation. align_closing_bracket_with_visual_indent=True @@ -14,9 +15,19 @@ allow_multiline_dictionary_keys=False # Allow lambdas to be formatted on more than one line. allow_multiline_lambdas=False +# Allow splits before the dictionary value. +allow_split_before_dict_value=True + +# Number of blank lines surrounding top-level function and class +# definitions. +blank_lines_around_top_level_definition=2 + # Insert a blank line before a class-level docstring. blank_line_before_class_docstring=False +# Insert a blank line before a module docstring. +blank_line_before_module_docstring=False + # Insert a blank line before a 'def' or 'class' immediately nested # within another 'def' or 'class'. For example: # @@ -45,7 +56,23 @@ blank_line_before_nested_class_or_def=False coalesce_brackets=True # The column limit. -column_limit=79 +column_limit=80 + +# The style for continuation alignment. Possible values are: +# +# - SPACE: Use spaces for continuation alignment. This is default behavior. +# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns +# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs) for continuation +# alignment. +# - LESS: Slightly left if cannot vertically align continuation lines with +# indent characters. +# - VALIGN-RIGHT: Vertically align continuation lines with indent +# characters. Slightly right (one more indent character) if cannot +# vertically align continuation lines with indent characters. +# +# For options FIXED, and VALIGN-RIGHT are only available when USE_TABS is +# enabled. +continuation_align_style=SPACE # Indent width used for line continuations. continuation_indent_width=4 @@ -66,7 +93,7 @@ continuation_indent_width=4 # start_ts=now()-timedelta(days=3), # end_ts=now(), # ) # <--- this bracket is dedented and on a separate line -dedent_closing_brackets=False +dedent_closing_brackets=True # Place each dictionary entry onto its own line. each_dict_entry_on_separate_line=True @@ -79,7 +106,7 @@ i18n_comment= # The i18n function call names. The presence of this function stops # reformattting on that line, because the string it has cannot be moved # away from the i18n comment. -i18n_function_call= +i18n_function_call=[''] # Indent the dictionary value if it cannot fit on the same line as the # dictionary key. For example: @@ -90,13 +117,13 @@ i18n_function_call= # 'key2': value1 + # value2, # } -indent_dictionary_value=True +indent_dictionary_value=False # The number of columns to use for indentation. indent_width=4 # Join short lines into one line. E.g., single line 'if' statements. -join_multiple_lines=True +join_multiple_lines=False # Do not include spaces around selected binary operators. For example: # @@ -106,7 +133,7 @@ join_multiple_lines=True # # 1 + 2*3 - 4/5 # -no_spaces_around_selected_binary_operators=set([]) +no_spaces_around_selected_binary_operators={'set()'} # Use spaces around default or named assigns. spaces_around_default_or_named_assign=False @@ -123,12 +150,16 @@ space_between_ending_comma_and_closing_bracket=True # Split before arguments if the argument list is terminated by a # comma. -split_arguments_when_comma_terminated=False +split_arguments_when_comma_terminated=True # Set to True to prefer splitting before '&', '|' or '^' rather than # after. 
split_before_bitwise_operator=True +# Split before the closing bracket if a list or dict literal doesn't fit on +# a single line. +split_before_closing_bracket=True + # Split before a dictionary or set generator (comp_for). For example, note # the split before the 'for': # @@ -138,6 +169,10 @@ split_before_bitwise_operator=True # } split_before_dict_set_generator=True +# Split after the opening paren which surrounds an expression if it doesn't +# fit on a single line. +split_before_expression_after_opening_paren=True + # If an argument / parameter list is going to be split, then split before # the first argument. split_before_first_argument=False @@ -149,6 +184,22 @@ split_before_logical_operator=True # Split named assignments onto individual lines. split_before_named_assigns=True +# Set to True to split list comprehensions and generators that have +# non-trivial expressions and multiple clauses before each of these +# clauses. For example: +# +# result = [ +# a_long_var + 100 for a_long_var in xrange(1000) +# if a_long_var % 10] +# +# would reformat to something like: +# +# result = [ +# a_long_var + 100 +# for a_long_var in xrange(1000) +# if a_long_var % 10] +split_complex_comprehension=True + # The penalty for splitting right after the opening bracket. split_penalty_after_opening_bracket=30 @@ -162,8 +213,12 @@ split_penalty_before_if_expr=0 # operators. split_penalty_bitwise_operator=300 +# The penalty for splitting a list comprehension or generator +# expression. +split_penalty_comprehension=80 + # The penalty for characters over the column limit. -split_penalty_excess_character=4500 +split_penalty_excess_character=1000 # The penalty incurred by adding a line split to the unwrapped line. The # more line splits added the higher the penalty. @@ -187,3 +242,5 @@ split_penalty_logical_operator=300 # Use the Tab character for indentation. use_tabs=False + + diff --git a/.travis/yapf.sh b/.travis/yapf.sh index b8af8656c040..94fbb5fad5f7 100755 --- a/.travis/yapf.sh +++ b/.travis/yapf.sh @@ -1,27 +1,30 @@ #!/usr/bin/env bash # Cause the script to exit if a single command fails -set -e - -ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) - -pushd $ROOT_DIR/../test - find . -name '*.py' -type f -exec yapf --style=pep8 -i -r {} \; -popd - -pushd $ROOT_DIR/../python - find . -name '*.py' -type f -not -path './ray/dataframe/*' -not -path './ray/rllib/*' -not -path './ray/cloudpickle/*' -exec yapf --style=pep8 -i -r {} \; -popd - -CHANGED_FILES=(`git diff --name-only`) -if [ "$CHANGED_FILES" ]; then - echo 'Reformatted staged files. Please review and stage the changes.' - echo - echo 'Files updated:' - for file in ${CHANGED_FILES[@]}; do - echo " $file" - done - exit 1 -else - exit 0 +set -eo pipefail + +# this stops git rev-parse from failing if we run this from the .git directory +builtin cd "$(dirname "${BASH_SOURCE:-$0}")" + +ROOT="$(git rev-parse --show-toplevel)" +builtin cd "$ROOT" + +yapf \ + --style "$ROOT/.style.yapf" \ + --in-place --recursive --parallel \ + --exclude 'python/ray/cloudpickle/' \ + -- \ + 'test/' 'python/' + +CHANGED_FILES=($(git diff --name-only)) + +if [[ "${#CHANGED_FILES[@]}" -gt 0 ]]; then + echo 'Reformatted staged files. Please review and stage the changes.' 
+ echo 'Files updated:' + + for file in "${CHANGED_FILES[@]}"; do + echo "$file" + done + + exit 1 fi diff --git a/python/ray/__init__.py b/python/ray/__init__.py index eee900c1d7dc..4f678c7ecdab 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -6,14 +6,17 @@ import sys if "pyarrow" in sys.modules: - raise ImportError("Ray must be imported before pyarrow because Ray " - "requires a specific version of pyarrow (which is " - "packaged along with Ray).") + raise ImportError( + "Ray must be imported before pyarrow because Ray " + "requires a specific version of pyarrow (which is " + "packaged along with Ray)." + ) # Add the directory containing pyarrow to the Python path so that we find the # pyarrow version packaged with ray and not a pre-existing pyarrow. pyarrow_path = os.path.join( - os.path.abspath(os.path.dirname(__file__)), "pyarrow_files") + os.path.abspath(os.path.dirname(__file__)), "pyarrow_files" +) sys.path.insert(0, pyarrow_path) # See https://github.com/ray-project/ray/issues/131. @@ -27,15 +30,21 @@ try: import pyarrow # noqa: F401 except ImportError as e: - if ((hasattr(e, "msg") and isinstance(e.msg, str) - and ("libstdc++" in e.msg or "CXX" in e.msg))): + if (( + hasattr(e, "msg") and isinstance(e.msg, str) + and ("libstdc++" in e.msg or "CXX" in e.msg) + )): # This code path should be taken with Python 3. e.msg += helpful_message - elif (hasattr(e, "message") and isinstance(e.message, str) - and ("libstdc++" in e.message or "CXX" in e.message)): + elif ( + hasattr(e, "message") and isinstance(e.message, str) + and ("libstdc++" in e.message or "CXX" in e.message) + ): # This code path should be taken with Python 2. - condition = (hasattr(e, "args") and isinstance(e.args, tuple) - and len(e.args) == 1 and isinstance(e.args[0], str)) + condition = ( + hasattr(e, "args") and isinstance(e.args, tuple) + and len(e.args) == 1 and isinstance(e.args[0], str) + ) if condition: e.args = (e.args[0] + helpful_message, ) else: @@ -47,12 +56,13 @@ raise from ray.local_scheduler import _config # noqa: E402 -from ray.worker import (error_info, init, connect, disconnect, get, put, wait, - remote, log_event, log_span, flush_log, get_gpu_ids, - get_webui_url, - register_custom_serializer) # noqa: E402 -from ray.worker import (SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, - SILENT_MODE) # noqa: E402 +from ray.worker import ( + error_info, init, connect, disconnect, get, put, wait, remote, log_event, + log_span, flush_log, get_gpu_ids, get_webui_url, register_custom_serializer +) # noqa: E402 +from ray.worker import ( + SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE +) # noqa: E402 from ray.worker import global_state # noqa: E402 # We import ray.actor because some code is run in actor.py which initializes # some functions in the worker. 
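The import reformatting in this file is representative of what the new .style.yapf produces throughout the repository. As a minimal, hypothetical sketch of the settings that change the code's appearance the most — the 80-column limit together with dedent_closing_brackets=True (the function and argument names below are made up for illustration):

# Under the old pep8-based style (79 columns), a long call wrapped with the
# continuation aligned to the opening parenthesis:
value = compute_something(first_argument, second_argument,
                          third_argument)

# Under the new chromium-based style (column_limit=80,
# dedent_closing_brackets=True), the arguments move onto their own indented
# line and the closing bracket is dedented back to the statement's level:
value = compute_something(
    first_argument, second_argument, third_argument
)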
@@ -66,9 +76,9 @@ __all__ = [ "error_info", "init", "connect", "disconnect", "get", "put", "wait", "remote", "log_event", "log_span", "flush_log", "actor", "method", - "get_gpu_ids", "get_webui_url", "register_custom_serializer", - "SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", "global_state", - "_config", "__version__" + "get_gpu_ids", "get_webui_url", "register_custom_serializer", "SCRIPT_MODE", + "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", "global_state", "_config", + "__version__" ] import ctypes # noqa: E402 diff --git a/python/ray/actor.py b/python/ray/actor.py index 505dd07c6a53..2543bf5d5950 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -12,8 +12,9 @@ import ray.local_scheduler import ray.signature as signature import ray.worker -from ray.utils import (FunctionProperties, _random_string, is_cython, - push_error_to_driver) +from ray.utils import ( + FunctionProperties, _random_string, is_cython, push_error_to_driver +) def compute_actor_handle_id(actor_handle_id, num_forks): @@ -68,8 +69,9 @@ def compute_actor_method_function_id(class_name, attr): return ray.local_scheduler.ObjectID(function_id) -def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, - frontier): +def set_actor_checkpoint( + worker, actor_id, checkpoint_index, checkpoint, frontier +): """Set the most recent checkpoint associated with a given actor ID. Args: @@ -85,7 +87,8 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, "checkpoint_index": checkpoint_index, "checkpoint": checkpoint, "frontier": frontier, - }) + } + ) def get_actor_checkpoint(worker, actor_id): @@ -104,7 +107,8 @@ def get_actor_checkpoint(worker, actor_id): """ actor_key = b"Actor:" + actor_id checkpoint_index, checkpoint, frontier = worker.redis_client.hmget( - actor_key, ["checkpoint_index", "checkpoint", "frontier"]) + actor_key, ["checkpoint_index", "checkpoint", "frontier"] + ) if checkpoint_index is not None: checkpoint_index = int(checkpoint_index) return checkpoint_index, checkpoint, frontier @@ -131,7 +135,8 @@ def save_and_log_checkpoint(worker, actor): data={ "actor_class": actor.__class__.__name__, "function_name": actor.__ray_checkpoint__.__name__ - }) + } + ) def restore_and_log_checkpoint(worker, actor): @@ -155,7 +160,8 @@ def restore_and_log_checkpoint(worker, actor): data={ "actor_class": actor.__class__.__name__, "function_name": actor.__ray_checkpoint_restore__.__name__ - }) + } + ) return checkpoint_resumed @@ -197,15 +203,18 @@ def actor_method_executor(dummy_return_id, actor, *args): return # Determine whether we should checkpoint the actor. - checkpointing_on = (actor_imported - and worker.actor_checkpoint_interval > 0) + checkpointing_on = ( + actor_imported and worker.actor_checkpoint_interval > 0 + ) # We should checkpoint the actor if user checkpointing is on, we've # executed checkpoint_interval tasks since the last checkpoint, and the # method we're about to execute is not a checkpoint. save_checkpoint = ( - checkpointing_on and - (worker.actor_task_counter % worker.actor_checkpoint_interval == 0 - and method_name != "__ray_checkpoint__")) + checkpointing_on and ( + worker.actor_task_counter % worker.actor_checkpoint_interval == 0 + and method_name != "__ray_checkpoint__" + ) + ) # Execute the assigned method and save a checkpoint if necessary. try: @@ -238,21 +247,24 @@ def fetch_and_register_actor(actor_class_key, resources, worker): worker: The worker to use. 
""" actor_id_str = worker.actor_id - (driver_id, class_id, class_name, module, pickled_class, - checkpoint_interval, actor_method_names, - actor_method_num_return_vals) = worker.redis_client.hmget( - actor_class_key, [ - "driver_id", "class_id", "class_name", "module", "class", - "checkpoint_interval", "actor_method_names", - "actor_method_num_return_vals" - ]) + ( + driver_id, class_id, class_name, module, pickled_class, + checkpoint_interval, actor_method_names, actor_method_num_return_vals + ) = worker.redis_client.hmget( + actor_class_key, [ + "driver_id", "class_id", "class_name", "module", "class", + "checkpoint_interval", "actor_method_names", + "actor_method_num_return_vals" + ] + ) actor_name = class_name.decode("ascii") module = module.decode("ascii") checkpoint_interval = int(checkpoint_interval) actor_method_names = json.loads(actor_method_names.decode("ascii")) actor_method_num_return_vals = json.loads( - actor_method_num_return_vals.decode("ascii")) + actor_method_num_return_vals.decode("ascii") + ) # Create a temporary actor with some temporary methods so that if the actor # fails to be unpickled, the temporary actor can be used (just to produce @@ -264,23 +276,29 @@ class TemporaryActor(object): worker.actor_checkpoint_interval = checkpoint_interval def temporary_actor_method(*xs): - raise Exception("The actor with name {} failed to be imported, and so " - "cannot execute this method".format(actor_name)) + raise Exception( + "The actor with name {} failed to be imported, and so " + "cannot execute this method".format(actor_name) + ) # Register the actor method signatures. - register_actor_signatures(worker, driver_id, class_id, class_name, - actor_method_names, actor_method_num_return_vals) + register_actor_signatures( + worker, driver_id, class_id, class_name, actor_method_names, + actor_method_num_return_vals + ) # Register the actor method executors. for actor_method_name in actor_method_names: - function_id = compute_actor_method_function_id(class_name, - actor_method_name).id() + function_id = compute_actor_method_function_id( + class_name, actor_method_name + ).id() temporary_executor = make_actor_method_executor( worker, actor_method_name, temporary_actor_method, - actor_imported=False) - worker.functions[driver_id][function_id] = (actor_method_name, - temporary_executor) + actor_imported=False + ) + worker.functions[driver_id][function_id + ] = (actor_method_name, temporary_executor) worker.num_task_executions[driver_id][function_id] = 0 try: @@ -296,7 +314,8 @@ def temporary_actor_method(*xs): "register_actor_signatures", traceback_str, driver_id, - data={"actor_id": actor_id_str}) + data={"actor_id": actor_id_str} + ) # TODO(rkn): In the future, it might make sense to have the worker exit # here. However, currently that would lead to hanging if someone calls # ray.get on a method invoked on the actor. 
@@ -306,30 +325,35 @@ def temporary_actor_method(*xs): worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) def pred(x): - return (inspect.isfunction(x) or inspect.ismethod(x) - or is_cython(x)) + return ( + inspect.isfunction(x) or inspect.ismethod(x) or is_cython(x) + ) actor_methods = inspect.getmembers(unpickled_class, predicate=pred) for actor_method_name, actor_method in actor_methods: function_id = compute_actor_method_function_id( - class_name, actor_method_name).id() + class_name, actor_method_name + ).id() executor = make_actor_method_executor( - worker, actor_method_name, actor_method, actor_imported=True) - worker.functions[driver_id][function_id] = (actor_method_name, - executor) + worker, actor_method_name, actor_method, actor_imported=True + ) + worker.functions[driver_id][function_id + ] = (actor_method_name, executor) # We do not set worker.function_properties[driver_id][function_id] # because we currently do need the actor worker to submit new tasks # for the actor. -def register_actor_signatures(worker, - driver_id, - class_id, - class_name, - actor_method_names, - actor_method_num_return_vals, - actor_creation_resources=None, - actor_method_cpus=None): +def register_actor_signatures( + worker, + driver_id, + class_id, + class_name, + actor_method_names, + actor_method_num_return_vals, + actor_creation_resources=None, + actor_method_cpus=None +): """Register an actor's method signatures in the worker. Args: @@ -346,12 +370,14 @@ def register_actor_signatures(worker, """ assert len(actor_method_names) == len(actor_method_num_return_vals) for actor_method_name, num_return_vals in zip( - actor_method_names, actor_method_num_return_vals): + actor_method_names, actor_method_num_return_vals + ): # TODO(rkn): When we create a second actor, we are probably overwriting # the values from the first actor here. This may or may not be a # problem. - function_id = compute_actor_method_function_id(class_name, - actor_method_name).id() + function_id = compute_actor_method_function_id( + class_name, actor_method_name + ).id() worker.function_properties[driver_id][function_id] = ( # The extra return value is an actor dummy object. # In the cases where actor_method_cpus is None, that value should @@ -359,7 +385,9 @@ def register_actor_signatures(worker, FunctionProperties( num_return_vals=num_return_vals + 1, resources={"CPU": actor_method_cpus}, - max_calls=0)) + max_calls=0 + ) + ) if actor_creation_resources is not None: # Also register the actor creation task. @@ -369,7 +397,9 @@ def register_actor_signatures(worker, FunctionProperties( num_return_vals=0 + 1, resources=actor_creation_resources, - max_calls=0)) + max_calls=0 + ) + ) def publish_actor_class_to_key(key, actor_class_info, worker): @@ -391,9 +421,10 @@ def publish_actor_class_to_key(key, actor_class_info, worker): worker.redis_client.rpush("Exports", key) -def export_actor_class(class_id, Class, actor_method_names, - actor_method_num_return_vals, checkpoint_interval, - worker): +def export_actor_class( + class_id, Class, actor_method_names, actor_method_num_return_vals, + checkpoint_interval, worker +): key = b"ActorClass:" + class_id actor_class_info = { "class_name": Class.__name__, @@ -411,7 +442,9 @@ def export_actor_class(class_id, Class, actor_method_names, # called. 
assert worker.cached_remote_functions_and_actors is not None worker.cached_remote_functions_and_actors.append( - ("actor", (key, actor_class_info))) + ("actor", + (key, actor_class_info)) + ) # This caching code path is currently not used because we only export # actor class definitions lazily when we instantiate the actor for the # first time. @@ -423,9 +456,11 @@ def export_actor_class(class_id, Class, actor_method_names, # https://github.com/ray-project/ray/issues/1146. -def export_actor(actor_id, class_id, class_name, actor_method_names, - actor_method_num_return_vals, actor_creation_resources, - actor_method_cpus, worker): +def export_actor( + actor_id, class_id, class_name, actor_method_names, + actor_method_num_return_vals, actor_creation_resources, actor_method_cpus, + worker +): """Export an actor to redis. Args: @@ -441,8 +476,10 @@ def export_actor(actor_id, class_id, class_name, actor_method_names, """ ray.worker.check_main_thread() if worker.mode is None: - raise Exception("Actors cannot be created before Ray has been " - "started. You can start Ray with 'ray.init()'.") + raise Exception( + "Actors cannot be created before Ray has been " + "started. You can start Ray with 'ray.init()'." + ) driver_id = worker.task_driver_id.id() register_actor_signatures( @@ -453,7 +490,8 @@ def export_actor(actor_id, class_id, class_name, actor_method_names, actor_method_names, actor_method_num_return_vals, actor_creation_resources=actor_creation_resources, - actor_method_cpus=actor_method_cpus) + actor_method_cpus=actor_method_cpus + ) args = [class_id] function_id = compute_actor_creation_function_id(class_id) @@ -481,17 +519,21 @@ def __init__(self, actor, method_name): self._method_name = method_name def __call__(self, *args, **kwargs): - raise Exception("Actor methods cannot be called directly. Instead " - "of running 'object.{}()', try " - "'object.{}.remote()'.".format(self._method_name, - self._method_name)) + raise Exception( + "Actor methods cannot be called directly. Instead " + "of running 'object.{}()', try " + "'object.{}.remote()'.".format( + self._method_name, self._method_name + ) + ) def remote(self, *args, **kwargs): return self._actor._actor_method_call( self._method_name, args=args, kwargs=kwargs, - dependency=self._actor._ray_actor_cursor) + dependency=self._actor._ray_actor_cursor + ) class ActorHandleWrapper(object): @@ -501,12 +543,12 @@ class ActorHandleWrapper(object): can tell that an argument is an ActorHandle. """ - def __init__(self, actor_id, class_id, actor_handle_id, actor_cursor, - actor_counter, actor_method_names, - actor_method_num_return_vals, method_signatures, - checkpoint_interval, class_name, - actor_creation_dummy_object_id, actor_creation_resources, - actor_method_cpus): + def __init__( + self, actor_id, class_id, actor_handle_id, actor_cursor, actor_counter, + actor_method_names, actor_method_num_return_vals, method_signatures, + checkpoint_interval, class_name, actor_creation_dummy_object_id, + actor_creation_resources, actor_method_cpus + ): # TODO(rkn): Some of these fields are probably not necessary. We should # strip out the unnecessary fields to keep actor handles lightweight. 
self.actor_id = actor_id @@ -538,8 +580,9 @@ def wrap_actor_handle(actor_handle): wrapper = ActorHandleWrapper( actor_handle._ray_actor_id, actor_handle._ray_class_id, - compute_actor_handle_id(actor_handle._ray_actor_handle_id, - actor_handle._ray_actor_forks), + compute_actor_handle_id( + actor_handle._ray_actor_handle_id, actor_handle._ray_actor_forks + ), actor_handle._ray_actor_cursor, 0, # Reset the actor counter. actor_handle._ray_actor_method_names, @@ -549,7 +592,8 @@ def wrap_actor_handle(actor_handle): actor_handle._ray_class_name, actor_handle._ray_actor_creation_dummy_object_id, actor_handle._ray_actor_creation_resources, - actor_handle._ray_actor_method_cpus) + actor_handle._ray_actor_method_cpus + ) actor_handle._ray_actor_forks += 1 return wrapper @@ -568,17 +612,18 @@ def unwrap_actor_handle(worker, wrapper): register_actor_signatures( worker, driver_id, wrapper.class_id, wrapper.class_name, wrapper.actor_method_names, wrapper.actor_method_num_return_vals, - wrapper.actor_creation_resources, wrapper.actor_method_cpus) + wrapper.actor_creation_resources, wrapper.actor_method_cpus + ) actor_handle_class = make_actor_handle_class(wrapper.class_name) actor_object = actor_handle_class.__new__(actor_handle_class) actor_object._manual_init( wrapper.actor_id, wrapper.class_id, wrapper.actor_handle_id, - wrapper.actor_cursor, wrapper.actor_counter, - wrapper.actor_method_names, wrapper.actor_method_num_return_vals, - wrapper.method_signatures, wrapper.checkpoint_interval, - wrapper.actor_creation_dummy_object_id, - wrapper.actor_creation_resources, wrapper.actor_method_cpus) + wrapper.actor_cursor, wrapper.actor_counter, wrapper.actor_method_names, + wrapper.actor_method_num_return_vals, wrapper.method_signatures, + wrapper.checkpoint_interval, wrapper.actor_creation_dummy_object_id, + wrapper.actor_creation_resources, wrapper.actor_method_cpus + ) return actor_object @@ -594,20 +639,26 @@ class ActorHandleParent(object): def make_actor_handle_class(class_name): class ActorHandle(ActorHandleParent): def __init__(self, *args, **kwargs): - raise Exception("Actor classes cannot be instantiated directly. " - "Instead of running '{}()', try '{}.remote()'." - .format(class_name, class_name)) + raise Exception( + "Actor classes cannot be instantiated directly. " + "Instead of running '{}()', try '{}.remote()'." + .format(class_name, class_name) + ) @classmethod def remote(cls, *args, **kwargs): - raise NotImplementedError("The classmethod remote() can only be " - "called on the original Class.") - - def _manual_init(self, actor_id, class_id, actor_handle_id, - actor_cursor, actor_counter, actor_method_names, - actor_method_num_return_vals, method_signatures, - checkpoint_interval, actor_creation_dummy_object_id, - actor_creation_resources, actor_method_cpus): + raise NotImplementedError( + "The classmethod remote() can only be " + "called on the original Class." 
+ ) + + def _manual_init( + self, actor_id, class_id, actor_handle_id, actor_cursor, + actor_counter, actor_method_names, actor_method_num_return_vals, + method_signatures, checkpoint_interval, + actor_creation_dummy_object_id, actor_creation_resources, + actor_method_cpus + ): self._ray_actor_id = actor_id self._ray_class_id = class_id self._ray_actor_handle_id = actor_handle_id @@ -615,21 +666,21 @@ def _manual_init(self, actor_id, class_id, actor_handle_id, self._ray_actor_counter = actor_counter self._ray_actor_method_names = actor_method_names self._ray_actor_method_num_return_vals = ( - actor_method_num_return_vals) + actor_method_num_return_vals + ) self._ray_method_signatures = method_signatures self._ray_checkpoint_interval = checkpoint_interval self._ray_class_name = class_name self._ray_actor_forks = 0 self._ray_actor_creation_dummy_object_id = ( - actor_creation_dummy_object_id) + actor_creation_dummy_object_id + ) self._ray_actor_creation_resources = actor_creation_resources self._ray_actor_method_cpus = actor_method_cpus - def _actor_method_call(self, - method_name, - args=None, - kwargs=None, - dependency=None): + def _actor_method_call( + self, method_name, args=None, kwargs=None, dependency=None + ): """Method execution stub for an actor handle. This is the function that executes when @@ -666,7 +717,8 @@ def _actor_method_call(self, if ray.worker.global_worker.mode == ray.PYTHON_MODE: return getattr( ray.worker.global_worker.actors[self._ray_actor_id], - method_name)(*copy.deepcopy(args)) + method_name + )(*copy.deepcopy(args)) # Add the execution dependency. if dependency is None: @@ -677,7 +729,8 @@ def _actor_method_call(self, is_actor_checkpoint_method = (method_name == "__ray_checkpoint__") function_id = compute_actor_method_function_id( - self._ray_class_name, method_name) + self._ray_class_name, method_name + ) object_ids = ray.worker.global_worker.submit_task( function_id, args, @@ -686,8 +739,10 @@ def _actor_method_call(self, actor_counter=self._ray_actor_counter, is_actor_checkpoint_method=is_actor_checkpoint_method, actor_creation_dummy_object_id=( - self._ray_actor_creation_dummy_object_id), - execution_dependencies=execution_dependencies) + self._ray_actor_creation_dummy_object_id + ), + execution_dependencies=execution_dependencies + ) # Update the actor counter and cursor to reflect the most recent # invocation. self._ray_actor_counter += 1 @@ -708,7 +763,8 @@ def __getattribute__(self, attr): try: # Check whether this is an actor method. actor_method_names = object.__getattribute__( - self, "_ray_actor_method_names") + self, "_ray_actor_method_names" + ) if attr in actor_method_names: # We create the ActorMethod on the fly here so that the # ActorHandle doesn't need a reference to the ActorMethod. @@ -736,18 +792,23 @@ def __del__(self): # TODO(swang): Also clean up forked actor handles. # Kill the worker if this is the original actor handle, created # with Class.remote(). - if (ray.worker.global_worker.connected and - self._ray_actor_handle_id.id() == ray.worker.NIL_ACTOR_ID): + if ( + ray.worker.global_worker.connected + and self._ray_actor_handle_id.id() == ray.worker.NIL_ACTOR_ID + ): # TODO(rkn): Should we be passing in the actor cursor as a # dependency here? 
self._actor_method_call( - "__ray_terminate__", args=[self._ray_actor_id.id()]) + "__ray_terminate__", args=[self._ray_actor_id.id()] + ) return ActorHandle -def actor_handle_from_class(Class, class_id, actor_creation_resources, - checkpoint_interval, actor_method_cpus): +def actor_handle_from_class( + Class, class_id, actor_creation_resources, checkpoint_interval, + actor_method_cpus +): class_name = Class.__name__.encode("ascii") actor_handle_class = make_actor_handle_class(class_name) exported = [] @@ -756,14 +817,17 @@ class ActorHandle(actor_handle_class): @classmethod def remote(cls, *args, **kwargs): if ray.worker.global_worker.mode is None: - raise Exception("Actors cannot be created before ray.init() " - "has been called.") + raise Exception( + "Actors cannot be created before ray.init() " + "has been called." + ) actor_id = ray.local_scheduler.ObjectID(_random_string()) # The ID for this instance of ActorHandle. These should be unique # across instances with the same _ray_actor_id. actor_handle_id = ray.local_scheduler.ObjectID( - ray.worker.NIL_ACTOR_ID) + ray.worker.NIL_ACTOR_ID + ) # The actor cursor is a dummy object representing the most recent # actor method invocation. For each subsequent method invocation, # the current cursor should be added as a dependency, and then @@ -774,8 +838,10 @@ def remote(cls, *args, **kwargs): # Get the actor methods of the given class. def pred(x): - return (inspect.isfunction(x) or inspect.ismethod(x) - or is_cython(x)) + return ( + inspect.isfunction(x) or inspect.ismethod(x) + or is_cython(x) + ) actor_methods = inspect.getmembers(Class, predicate=pred) # Extract the signatures of each of the methods. This will be used @@ -790,7 +856,8 @@ def pred(x): # it. signature.check_signature_supported(v, warn=True) method_signatures[k] = signature.extract_signature( - v, ignore_first=True) + v, ignore_first=True + ) actor_method_names = [ method_name for method_name, _ in actor_methods @@ -799,7 +866,8 @@ def pred(x): for _, method in actor_methods: if hasattr(method, "__ray_num_return_vals__"): actor_method_num_return_vals.append( - method.__ray_num_return_vals__) + method.__ray_num_return_vals__ + ) else: actor_method_num_return_vals.append(1) # Do not export the actor class or the actor if run in PYTHON_MODE @@ -807,19 +875,22 @@ def pred(x): # global_worker's dictionary if ray.worker.global_worker.mode == ray.PYTHON_MODE: ray.worker.global_worker.actors[actor_id] = ( - Class.__new__(Class)) + Class.__new__(Class) + ) else: # Export the actor. if not exported: - export_actor_class(class_id, Class, actor_method_names, - actor_method_num_return_vals, - checkpoint_interval, - ray.worker.global_worker) + export_actor_class( + class_id, Class, actor_method_names, + actor_method_num_return_vals, checkpoint_interval, + ray.worker.global_worker + ) exported.append(0) actor_cursor = export_actor( actor_id, class_id, class_name, actor_method_names, actor_method_num_return_vals, actor_creation_resources, - actor_method_cpus, ray.worker.global_worker) + actor_method_cpus, ray.worker.global_worker + ) # Increment the actor counter to account for the creation task. 
actor_counter += 1 @@ -827,10 +898,10 @@ def pred(x): actor_object = cls.__new__(cls) actor_object._manual_init( actor_id, class_id, actor_handle_id, actor_cursor, - actor_counter, actor_method_names, - actor_method_num_return_vals, method_signatures, - checkpoint_interval, actor_cursor, actor_creation_resources, - actor_method_cpus) + actor_counter, actor_method_names, actor_method_num_return_vals, + method_signatures, checkpoint_interval, actor_cursor, + actor_creation_resources, actor_method_cpus + ) # Call __init__ as a remote function. if "__init__" in actor_object._ray_actor_method_names: @@ -838,7 +909,8 @@ def pred(x): "__init__", args=args, kwargs=kwargs, - dependency=actor_cursor) + dependency=actor_cursor + ) else: print("WARNING: this object has no __init__ method.") @@ -858,8 +930,9 @@ def __ray_terminate__(self, actor_id): # Record that this actor has been removed so that if this node # dies later, the actor won't be recreated. Alternatively, we could # remove the actor key from Redis here. - ray.worker.global_worker.redis_client.hset(b"Actor:" + actor_id, - "removed", True) + ray.worker.global_worker.redis_client.hset( + b"Actor:" + actor_id, "removed", True + ) # Disconnect the worker from the local scheduler. The point of this # is so that when the worker kills itself below, the local # scheduler won't push an error message to the driver. @@ -905,11 +978,13 @@ def __ray_checkpoint__(self): # on checkpoint resumption. actor_id = ray.local_scheduler.ObjectID(worker.actor_id) frontier = worker.local_scheduler_client.get_actor_frontier( - actor_id) + actor_id + ) # Save the checkpoint in Redis. TODO(rkn): Checkpoints # should not be stored in Redis. Fix this. - set_actor_checkpoint(worker, worker.actor_id, checkpoint_index, - checkpoint, frontier) + set_actor_checkpoint( + worker, worker.actor_id, checkpoint_index, checkpoint, frontier + ) def __ray_checkpoint_restore__(self): """Restore a checkpoint. @@ -924,14 +999,16 @@ def __ray_checkpoint_restore__(self): worker = ray.worker.global_worker # Get the most recent checkpoint stored, if any. checkpoint_index, checkpoint, frontier = get_actor_checkpoint( - worker, worker.actor_id) + worker, worker.actor_id + ) # Try to resume from the checkpoint. checkpoint_resumed = False if checkpoint_index is not None: # Load the actor state from the checkpoint. worker.actors[worker.actor_id] = ( - worker.actor_class.__ray_restore_from_checkpoint__( - checkpoint)) + worker.actor_class. + __ray_restore_from_checkpoint__(checkpoint) + ) # Set the number of tasks executed so far. worker.actor_task_counter = checkpoint_index # Set the actor frontier in the local scheduler. @@ -945,8 +1022,9 @@ def __ray_checkpoint_restore__(self): class_id = _random_string() - return actor_handle_from_class(Class, class_id, resources, - checkpoint_interval, actor_method_cpus) + return actor_handle_from_class( + Class, class_id, resources, checkpoint_interval, actor_method_cpus + ) ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index d5a5336f41b8..ad1e37be6a1b 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -59,7 +59,8 @@ "module": (str, OPTIONAL), # module, if using external node provider }, - REQUIRED), + REQUIRED + ), # How Ray will authenticate with newly launched nodes. "auth": ( @@ -67,7 +68,8 @@ "ssh_user": (str, REQUIRED), # e.g. 
ubuntu "ssh_private_key": (str, OPTIONAL), }, - REQUIRED), + REQUIRED + ), # Docker configuration. If this is specified, all setup and start commands # will be executed in the container. @@ -76,7 +78,8 @@ "image": (str, OPTIONAL), # e.g. tensorflow/tensorflow:1.5.0-py3 "container_name": (str, OPTIONAL), # e.g., ray_docker }, - OPTIONAL), + OPTIONAL + ), # Provider-specific config for the head node, e.g. instance type. "head_node": (dict, OPTIONAL), @@ -144,8 +147,11 @@ def prune(mapping): for unwanted_key in unwanted: del mapping[unwanted_key] if unwanted: - print("Removed {} stale ip mappings: {} not in {}".format( - len(unwanted), unwanted, active_ips)) + print( + "Removed {} stale ip mappings: {} not in {}".format( + len(unwanted), unwanted, active_ips + ) + ) prune(self.last_used_time_by_ip) prune(self.static_resources_by_ip) @@ -155,8 +161,11 @@ def approx_workers_used(self): return self._info()["NumNodesUsed"] def debug_string(self): - return " - {}".format("\n - ".join( - ["{}: {}".format(k, v) for k, v in sorted(self._info().items())])) + return " - {}".format( + "\n - ".join([ + "{}: {}".format(k, v) for k, v in sorted(self._info().items()) + ]) + ) def _info(self): nodes_used = 0.0 @@ -185,8 +194,8 @@ def _info(self): ", ".join([ "{}/{} {}".format( round(resources_used[rid], 2), - round(resources_total[rid], 2), rid) - for rid in sorted(resources_used) + round(resources_total[rid], 2), rid + ) for rid in sorted(resources_used) ]), "NumNodesConnected": len(self.static_resources_by_ip), @@ -196,7 +205,8 @@ def _info(self): "Min={} Mean={} Max={}".format( int(np.min(idle_times)) if idle_times else -1, int(np.mean(idle_times)) if idle_times else -1, - int(np.max(idle_times)) if idle_times else -1), + int(np.max(idle_times)) if idle_times else -1 + ), } @@ -218,20 +228,23 @@ class StandardAutoscaler(object): until the target cluster size is met). """ - def __init__(self, - config_path, - load_metrics, - max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES, - max_failures=AUTOSCALER_MAX_NUM_FAILURES, - process_runner=subprocess, - verbose_updates=True, - node_updater_cls=NodeUpdaterProcess, - update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S): + def __init__( + self, + config_path, + load_metrics, + max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES, + max_failures=AUTOSCALER_MAX_NUM_FAILURES, + process_runner=subprocess, + verbose_updates=True, + node_updater_cls=NodeUpdaterProcess, + update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S + ): self.config_path = config_path self.reload_config(errors_fatal=True) self.load_metrics = load_metrics - self.provider = get_node_provider(self.config["provider"], - self.config["cluster_name"]) + self.provider = get_node_provider( + self.config["provider"], self.config["cluster_name"] + ) self.max_failures = max_failures self.max_concurrent_launches = max_concurrent_launches @@ -257,8 +270,10 @@ def update(self): self.reload_config(errors_fatal=False) self._update() except Exception as e: - print("StandardAutoscaler: Error during autoscaling: {}", - traceback.format_exc()) + print( + "StandardAutoscaler: Error during autoscaling: {}", + traceback.format_exc() + ) self.num_failures += 1 if self.num_failures > self.max_failures: print("*** StandardAutoscaler: Too many errors, abort. 
***") @@ -273,8 +288,9 @@ def _update(self): self.last_update_time = time.time() nodes = self.workers() print(self.debug_string(nodes)) - self.load_metrics.prune_active_ips( - [self.provider.internal_ip(node_id) for node_id in nodes]) + self.load_metrics.prune_active_ips([ + self.provider.internal_ip(node_id) for node_id in nodes + ]) # Terminate any idle or out of date nodes last_used = self.load_metrics.last_used_time_by_ip @@ -285,13 +301,17 @@ def _update(self): if node_ip in last_used and last_used[node_ip] < horizon and \ len(nodes) - num_terminated > self.config["min_workers"]: num_terminated += 1 - print("StandardAutoscaler: Terminating idle node: " - "{}".format(node_id)) + print( + "StandardAutoscaler: Terminating idle node: " + "{}".format(node_id) + ) self.provider.terminate_node(node_id) elif not self.launch_config_ok(node_id): num_terminated += 1 - print("StandardAutoscaler: Terminating outdated node: " - "{}".format(node_id)) + print( + "StandardAutoscaler: Terminating outdated node: " + "{}".format(node_id) + ) self.provider.terminate_node(node_id) if num_terminated > 0: nodes = self.workers() @@ -301,8 +321,10 @@ def _update(self): num_terminated = 0 while len(nodes) > self.config["max_workers"]: num_terminated += 1 - print("StandardAutoscaler: Terminating unneeded node: " - "{}".format(nodes[-1])) + print( + "StandardAutoscaler: Terminating unneeded node: " + "{}".format(nodes[-1]) + ) self.provider.terminate_node(nodes[-1]) nodes = nodes[:-1] if num_terminated > 0: @@ -313,7 +335,8 @@ def _update(self): target_num = self.target_num_workers() if len(nodes) < target_num: self.launch_new_node( - min(self.max_concurrent_launches, target_num - len(nodes))) + min(self.max_concurrent_launches, target_num - len(nodes)) + ) print(self.debug_string()) # Process any completed updates @@ -347,13 +370,16 @@ def reload_config(self, errors_fatal=False): with open(self.config_path) as f: new_config = yaml.load(f.read()) validate_config(new_config) - new_launch_hash = hash_launch_conf(new_config["worker_nodes"], - new_config["auth"]) - new_runtime_hash = hash_runtime_conf(new_config["file_mounts"], [ - new_config["setup_commands"], - new_config["worker_setup_commands"], - new_config["worker_start_ray_commands"] - ]) + new_launch_hash = hash_launch_conf( + new_config["worker_nodes"], new_config["auth"] + ) + new_runtime_hash = hash_runtime_conf( + new_config["file_mounts"], [ + new_config["setup_commands"], + new_config["worker_setup_commands"], + new_config["worker_start_ray_commands"] + ] + ) self.config = new_config self.launch_hash = new_launch_hash self.runtime_hash = new_runtime_hash @@ -361,19 +387,23 @@ def reload_config(self, errors_fatal=False): if errors_fatal: raise e else: - print("StandardAutoscaler: Error parsing config: {}", - traceback.format_exc()) + print( + "StandardAutoscaler: Error parsing config: {}", + traceback.format_exc() + ) def target_num_workers(self): target_frac = self.config["target_utilization_fraction"] cur_used = self.load_metrics.approx_workers_used() ideal_num_workers = int(np.ceil(cur_used / float(target_frac))) - return min(self.config["max_workers"], - max(self.config["min_workers"], ideal_num_workers)) + return min( + self.config["max_workers"], + max(self.config["min_workers"], ideal_num_workers) + ) def launch_config_ok(self, node_id): - launch_conf = self.provider.node_tags(node_id).get( - TAG_RAY_LAUNCH_CONFIG) + launch_conf = self.provider.node_tags(node_id + ).get(TAG_RAY_LAUNCH_CONFIG) if self.launch_hash != launch_conf: return False return 
True @@ -383,7 +413,9 @@ def files_up_to_date(self, node_id): if applied != self.runtime_hash: print( "StandardAutoscaler: {} has runtime state {}, want {}".format( - node_id, applied, self.runtime_hash)) + node_id, applied, self.runtime_hash + ) + ) return False return True @@ -391,7 +423,8 @@ def recover_if_needed(self, node_id): if not self.can_update(node_id): return last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip.get( - self.provider.internal_ip(node_id), 0) + self.provider.internal_ip(node_id), 0 + ) if time.time() - last_heartbeat_time < AUTOSCALER_HEARTBEAT_TIMEOUT_S: return print("StandardAutoscaler: Restarting Ray on {}".format(node_id)) @@ -403,7 +436,8 @@ def recover_if_needed(self, node_id): with_head_node_ip(self.config["worker_start_ray_commands"]), self.runtime_hash, redirect_output=not self.verbose_updates, - process_runner=self.process_runner) + process_runner=self.process_runner + ) updater.start() self.updaters[node_id] = updater @@ -414,12 +448,16 @@ def update_if_needed(self, node_id): return if self.config.get("no_restart", False) and \ self.num_successful_updates.get(node_id, 0) > 0: - init_commands = (self.config["setup_commands"] + - self.config["worker_setup_commands"]) + init_commands = ( + self.config["setup_commands"] + + self.config["worker_setup_commands"] + ) else: - init_commands = (self.config["setup_commands"] + - self.config["worker_setup_commands"] + - self.config["worker_start_ray_commands"]) + init_commands = ( + self.config["setup_commands"] + + self.config["worker_setup_commands"] + + self.config["worker_start_ray_commands"] + ) updater = self.node_updater_cls( node_id, self.config["provider"], @@ -429,7 +467,8 @@ def update_if_needed(self, node_id): with_head_node_ip(init_commands), self.runtime_hash, redirect_output=not self.verbose_updates, - process_runner=self.process_runner) + process_runner=self.process_runner + ) updater.start() self.updaters[node_id] = updater @@ -453,7 +492,8 @@ def launch_new_node(self, count): TAG_RAY_NODE_TYPE: "Worker", TAG_RAY_NODE_STATUS: "Uninitialized", TAG_RAY_LAUNCH_CONFIG: self.launch_hash, - }, count) + }, count + ) # TODO(ekl) be less conservative in this check assert len(self.workers()) > num_before, \ "Num nodes failed to increase after creating a new node" @@ -471,10 +511,12 @@ def debug_string(self, nodes=None): suffix += " ({} updating)".format(len(self.updaters)) if self.num_failed_updates: suffix += " ({} failed to update)".format( - len(self.num_failed_updates)) + len(self.num_failed_updates) + ) return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format( datetime.now(), len(nodes), self.target_num_workers(), suffix, - self.load_metrics.debug_string()) + self.load_metrics.debug_string() + ) def typename(v): @@ -497,7 +539,9 @@ def check_required(config, schema): type_str = typename(v) raise ValueError( "Missing required config key `{}` of type {}".format( - k, type_str)) + k, type_str + ) + ) if not isinstance(v, type): check_required(config[k], v) @@ -508,8 +552,11 @@ def check_extraneous(config, schema): raise ValueError("Config {} is not a dictionary".format(config)) for k in config: if k not in schema: - raise ValueError("Unexpected config key `{}` not in {}".format( - k, list(schema.keys()))) + raise ValueError( + "Unexpected config key `{}` not in {}".format( + k, list(schema.keys()) + ) + ) v, kreq = schema[k] if v is None: continue @@ -518,7 +565,9 @@ def check_extraneous(config, schema): raise ValueError( "Config key `{}` has wrong type {}, expected {}".format( k, - 
type(config[k]).__name__, v.__name__)) + type(config[k]).__name__, v.__name__ + ) + ) else: check_extraneous(config[k], v) @@ -549,8 +598,7 @@ def with_head_node_ip(cmds): def hash_launch_conf(node_conf, auth): hasher = hashlib.sha1() - hasher.update( - json.dumps([node_conf, auth], sort_keys=True).encode("utf-8")) + hasher.update(json.dumps([node_conf, auth], sort_keys=True).encode("utf-8")) return hasher.hexdigest() diff --git a/python/ray/autoscaler/aws/config.py b/python/ray/autoscaler/aws/config.py index 3e3facdcf86f..4da00342d57b 100644 --- a/python/ray/autoscaler/aws/config.py +++ b/python/ray/autoscaler/aws/config.py @@ -25,10 +25,14 @@ def key_pair(i, region): """Returns the ith default (aws_key_pair_name, key_pair_path).""" if i == 0: - return ("{}_{}".format(RAY, region), - os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region))) - return ("{}_{}_{}".format(RAY, i, region), - os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region))) + return ( + "{}_{}".format(RAY, region), + os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)) + ) + return ( + "{}_{}_{}".format(RAY, i, region), + os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)) + ) # Suppress excessive connection dropped logs from boto @@ -60,11 +64,14 @@ def _configure_iam_role(config): profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config) if profile is None: - print("Creating new instance profile {}".format( - DEFAULT_RAY_INSTANCE_PROFILE)) + print( + "Creating new instance profile {}". + format(DEFAULT_RAY_INSTANCE_PROFILE) + ) client = _client("iam", config) client.create_instance_profile( - InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE) + InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE + ) profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config) time.sleep(15) # wait for propagation @@ -87,13 +94,16 @@ def _configure_iam_role(config): "Action": "sts:AssumeRole", }, ], - })) + }) + ) role = _get_role(DEFAULT_RAY_IAM_ROLE, config) assert role is not None, "Failed to create role" role.attach_policy( - PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess") + PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess" + ) role.attach_policy( - PolicyArn="arn:aws:iam::aws:policy/AmazonS3FullAccess") + PolicyArn="arn:aws:iam::aws:policy/AmazonS3FullAccess" + ) profile.add_role(RoleName=role.name) time.sleep(15) # wait for propagation @@ -150,25 +160,28 @@ def _configure_subnet(config): if s.state == "available" and s.map_public_ip_on_launch ], reverse=True, # sort from Z-A - key=lambda subnet: subnet.availability_zone) + key=lambda subnet: subnet.availability_zone + ) if not subnets: raise Exception( "No usable subnets found, try manually creating an instance in " "your specified region to populate the list of subnets " "and trying this again. Note that the subnet must map public IPs " - "on instance launch.") + "on instance launch." + ) if "availability_zone" in config["provider"]: default_subnet = next(( s for s in subnets - if s.availability_zone == config["provider"]["availability_zone"]), - None) + if s.availability_zone == config["provider"]["availability_zone"] + ), None) if not default_subnet: raise Exception( "No usable subnets matching availability zone {} " "found. Choose a different availability zone or try " "manually creating an instance in your specified region " "to populate the list of subnets and trying this again." 
- .format(config["provider"]["availability_zone"])) + .format(config["provider"]["availability_zone"]) + ) else: default_subnet = subnets[0] @@ -176,15 +189,21 @@ def _configure_subnet(config): assert default_subnet.map_public_ip_on_launch, \ "The chosen subnet must map nodes with public IPs on launch" config["head_node"]["SubnetId"] = default_subnet.id - print("SubnetId not specified for head node, using {} in {}".format( - default_subnet.id, default_subnet.availability_zone)) + print( + "SubnetId not specified for head node, using {} in {}".format( + default_subnet.id, default_subnet.availability_zone + ) + ) if "SubnetId" not in config["worker_nodes"]: assert default_subnet.map_public_ip_on_launch, \ "The chosen subnet must map nodes with public IPs on launch" config["worker_nodes"]["SubnetId"] = default_subnet.id - print("SubnetId not specified for workers, using {} in {}".format( - default_subnet.id, default_subnet.availability_zone)) + print( + "SubnetId not specified for workers, using {} in {}".format( + default_subnet.id, default_subnet.availability_zone + ) + ) return config @@ -204,7 +223,8 @@ def _configure_security_group(config): client.create_security_group( Description="Auto-created security group for Ray workers", GroupName=group_name, - VpcId=subnet.vpc_id) + VpcId=subnet.vpc_id + ) security_group = _get_security_group(config, subnet.vpc_id, group_name) assert security_group, "Failed to create security group" @@ -224,16 +244,23 @@ def _configure_security_group(config): "IpRanges": [{ "CidrIp": "0.0.0.0/0" }] - }]) + }] + ) if "SecurityGroupIds" not in config["head_node"]: - print("SecurityGroupIds not specified for head node, using {}".format( - security_group.group_name)) + print( + "SecurityGroupIds not specified for head node, using {}".format( + security_group.group_name + ) + ) config["head_node"]["SecurityGroupIds"] = [security_group.id] if "SecurityGroupIds" not in config["worker_nodes"]: - print("SecurityGroupIds not specified for workers, using {}".format( - security_group.group_name)) + print( + "SecurityGroupIds not specified for workers, using {}".format( + security_group.group_name + ) + ) config["worker_nodes"]["SecurityGroupIds"] = [security_group.id] return config @@ -242,10 +269,13 @@ def _configure_security_group(config): def _get_subnet_or_die(config, subnet_id): ec2 = _resource("ec2", config) subnet = list( - ec2.subnets.filter(Filters=[{ - "Name": "subnet-id", - "Values": [subnet_id] - }])) + ec2.subnets.filter( + Filters=[{ + "Name": "subnet-id", + "Values": [subnet_id] + }] + ) + ) assert len(subnet) == 1, "Subnet not found" subnet = subnet[0] return subnet @@ -254,10 +284,13 @@ def _get_subnet_or_die(config, subnet_id): def _get_security_group(config, vpc_id, group_name): ec2 = _resource("ec2", config) existing_groups = list( - ec2.security_groups.filter(Filters=[{ - "Name": "vpc-id", - "Values": [vpc_id] - }])) + ec2.security_groups.filter( + Filters=[{ + "Name": "vpc-id", + "Values": [vpc_id] + }] + ) + ) for sg in existing_groups: if sg.group_name == group_name: return sg @@ -285,10 +318,12 @@ def _get_instance_profile(profile_name, config): def _get_key(key_name, config): ec2 = _resource("ec2", config) - for key in ec2.key_pairs.filter(Filters=[{ + for key in ec2.key_pairs.filter( + Filters=[{ "Name": "key-name", "Values": [key_name] - }]): + }] + ): if key.name == key_name: return key @@ -301,4 +336,5 @@ def _client(name, config): def _resource(name, config): boto_config = Config(retries=dict(max_attempts=BOTO_MAX_RETRIES)) return 
boto3.resource( - name, config["provider"]["region"], config=boto_config) + name, config["provider"]["region"], config=boto_config + ) diff --git a/python/ray/autoscaler/aws/node_provider.py b/python/ray/autoscaler/aws/node_provider.py index a31d3e51b854..3a73f4f7ee80 100644 --- a/python/ray/autoscaler/aws/node_provider.py +++ b/python/ray/autoscaler/aws/node_provider.py @@ -15,7 +15,8 @@ def __init__(self, provider_config, cluster_name): NodeProvider.__init__(self, provider_config, cluster_name) config = Config(retries=dict(max_attempts=BOTO_MAX_RETRIES)) self.ec2 = boto3.resource( - "ec2", region_name=provider_config["region"], config=config) + "ec2", region_name=provider_config["region"], config=config + ) # Cache of node objects from the last nodes() call. This avoids # excessive DescribeInstances requests. diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 5ccc8eaac60b..3df27f6b2a57 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -23,8 +23,9 @@ from ray.autoscaler.updater import NodeUpdaterProcess -def create_or_update_cluster(config_file, override_min_workers, - override_max_workers, no_restart, yes): +def create_or_update_cluster( + config_file, override_min_workers, override_max_workers, no_restart, yes +): """Create or updates an autoscaling Ray cluster from a config json.""" config = yaml.load(open(config_file).read()) @@ -38,8 +39,9 @@ def create_or_update_cluster(config_file, override_min_workers, importer = NODE_PROVIDERS.get(config["provider"]["type"]) if not importer: - raise NotImplementedError("Unsupported provider {}".format( - config["provider"])) + raise NotImplementedError( + "Unsupported provider {}".format(config["provider"]) + ) bootstrap_config, _ = importer() config = bootstrap_config(config) @@ -91,7 +93,8 @@ def get_or_create_head_node(config, no_restart, yes): launch_hash = hash_launch_conf(config["head_node"], config["auth"]) if head_node is None or provider.node_tags(head_node).get( - TAG_RAY_LAUNCH_CONFIG) != launch_hash: + TAG_RAY_LAUNCH_CONFIG + ) != launch_hash: if head_node is not None: confirm("Head node config out-of-date. 
It will be terminated", yes) print("Terminating outdated head node {}".format(head_node)) @@ -124,7 +127,8 @@ def get_or_create_head_node(config, no_restart, yes): # Now inject the rewritten config and SSH key into the head node remote_config_file = tempfile.NamedTemporaryFile( - "w", prefix="ray-bootstrap-") + "w", prefix="ray-bootstrap-" + ) remote_config_file.write(json.dumps(remote_config)) remote_config_file.flush() config["file_mounts"].update({ @@ -136,11 +140,13 @@ def get_or_create_head_node(config, no_restart, yes): if no_restart: init_commands = ( - config["setup_commands"] + config["head_setup_commands"]) + config["setup_commands"] + config["head_setup_commands"] + ) else: init_commands = ( config["setup_commands"] + config["head_setup_commands"] + - config["head_start_ray_commands"]) + config["head_start_ray_commands"] + ) updater = NodeUpdaterProcess( head_node, @@ -150,7 +156,8 @@ def get_or_create_head_node(config, no_restart, yes): config["file_mounts"], init_commands, runtime_hash, - redirect_output=False) + redirect_output=False + ) updater.start() updater.join() @@ -158,27 +165,39 @@ def get_or_create_head_node(config, no_restart, yes): provider.nodes(head_node_tags) if updater.exitcode != 0: - print("Error: updating {} failed".format( - provider.external_ip(head_node))) + print( + "Error: updating {} failed".format(provider.external_ip(head_node)) + ) sys.exit(1) - print("Head node up-to-date, IP address is: {}".format( - provider.external_ip(head_node))) + print( + "Head node up-to-date, IP address is: {}".format( + provider.external_ip(head_node) + ) + ) monitor_str = "tail -f /tmp/raylogs/monitor-*" for s in init_commands: - if ("ray start" in s and "docker exec" in s - and "--autoscaling-config" in s): + if ( + "ray start" in s and "docker exec" in s + and "--autoscaling-config" in s + ): monitor_str = "docker exec {} /bin/sh -c {}".format( - config["docker"]["container_name"], quote(monitor_str)) - print("To monitor auto-scaling activity, you can run:\n\n" - " ssh -i {} {}@{} {}\n".format(config["auth"]["ssh_private_key"], - config["auth"]["ssh_user"], - provider.external_ip(head_node), - quote(monitor_str))) - print("To login to the cluster, run:\n\n" - " ssh -i {} {}@{}\n".format(config["auth"]["ssh_private_key"], - config["auth"]["ssh_user"], - provider.external_ip(head_node))) + config["docker"]["container_name"], quote(monitor_str) + ) + print( + "To monitor auto-scaling activity, you can run:\n\n" + " ssh -i {} {}@{} {}\n".format( + config["auth"]["ssh_private_key"], config["auth"]["ssh_user"], + provider.external_ip(head_node), quote(monitor_str) + ) + ) + print( + "To login to the cluster, run:\n\n" + " ssh -i {} {}@{}\n".format( + config["auth"]["ssh_private_key"], config["auth"]["ssh_user"], + provider.external_ip(head_node) + ) + ) def get_head_node_ip(config_file): @@ -194,8 +213,11 @@ def get_head_node_ip(config_file): head_node = nodes[0] return provider.external_ip(head_node) else: - print("Head node of cluster ({}) not found!".format( - config["cluster_name"])) + print( + "Head node of cluster ({}) not found!".format( + config["cluster_name"] + ) + ) sys.exit(1) diff --git a/python/ray/autoscaler/docker.py b/python/ray/autoscaler/docker.py index 8d8e2e8a2b10..82380958a742 100644 --- a/python/ray/autoscaler/docker.py +++ b/python/ray/autoscaler/docker.py @@ -23,21 +23,27 @@ def dockerize_if_needed(config): docker_mounts = {dst: dst for dst in config["file_mounts"]} config["setup_commands"] = ( docker_install_cmds() + docker_start_cmds( - 
config["auth"]["ssh_user"], docker_image, docker_mounts, cname) + - with_docker_exec(config["setup_commands"], container_name=cname)) + config["auth"]["ssh_user"], docker_image, docker_mounts, cname + ) + with_docker_exec(config["setup_commands"], container_name=cname) + ) config["head_setup_commands"] = with_docker_exec( - config["head_setup_commands"], container_name=cname) + config["head_setup_commands"], container_name=cname + ) config["head_start_ray_commands"] = ( docker_autoscaler_setup(cname) + with_docker_exec( - config["head_start_ray_commands"], container_name=cname)) + config["head_start_ray_commands"], container_name=cname + ) + ) config["worker_setup_commands"] = with_docker_exec( - config["worker_setup_commands"], container_name=cname) + config["worker_setup_commands"], container_name=cname + ) config["worker_start_ray_commands"] = with_docker_exec( config["worker_start_ray_commands"], container_name=cname, - env_vars=["RAY_HEAD_IP"]) + env_vars=["RAY_HEAD_IP"] + ) return config @@ -45,11 +51,13 @@ def dockerize_if_needed(config): def with_docker_exec(cmds, container_name, env_vars=None): env_str = "" if env_vars: - env_str = " ".join( - ["-e {env}=${env}".format(env=env) for env in env_vars]) + env_str = " ".join([ + "-e {env}=${env}".format(env=env) for env in env_vars + ]) return [ - "docker exec {} {} /bin/sh -c {} ".format(env_str, container_name, - quote(cmd)) for cmd in cmds + "docker exec {} {} /bin/sh -c {} ".format( + env_str, container_name, quote(cmd) + ) for cmd in cmds ] @@ -61,10 +69,12 @@ def docker_install_cmds(): def aptwait_cmd(): - return ("while sudo fuser" - " /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock" - " >/dev/null 2>&1; " - "do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done") + return ( + "while sudo fuser" + " /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock" + " >/dev/null 2>&1; " + "do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done" + ) def docker_start_cmds(user, image, mount, cname): @@ -81,13 +91,15 @@ def docker_start_cmds(user, image, mount, cname): "-p {port}:{port}".format(port=port) for port in ["6379", "8076", "4321"] ]) - mount_flags = " ".join( - ["-v {src}:{dest}".format(src=k, dest=v) for k, v in mount.items()]) + mount_flags = " ".join([ + "-v {src}:{dest}".format(src=k, dest=v) for k, v in mount.items() + ]) # for click, used in ray cli env_vars = {"LC_ALL": "C.UTF-8", "LANG": "C.UTF-8"} - env_flags = " ".join( - ["-e {name}={val}".format(name=k, val=v) for k, v in env_vars.items()]) + env_flags = " ".join([ + "-e {name}={val}".format(name=k, val=v) for k, v in env_vars.items() + ]) # docker run command docker_run = [ @@ -108,10 +120,13 @@ def docker_autoscaler_setup(cname): for path in ["~/ray_bootstrap_config.yaml", "~/ray_bootstrap_key.pem"]: # needed because docker doesn't allow relative paths base_path = os.path.basename(path) - cmds.append("docker cp {path} {cname}:{dpath}".format( - path=path, dpath=base_path, cname=cname)) + cmds.append( + "docker cp {path} {cname}:{dpath}".format( + path=path, dpath=base_path, cname=cname + ) + ) cmds.extend( - with_docker_exec( - ["cp {} {}".format("/" + base_path, path)], - container_name=cname)) + with_docker_exec(["cp {} {}".format("/" + base_path, path)], + container_name=cname) + ) return cmds diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index 9aa4bb994345..cac0db191a7f 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -56,7 +56,8 @@ def 
load_class(path): class_data = path.split(".") if len(class_data) < 2: raise ValueError( - "You need to pass a valid path like mymodule.provider_class") + "You need to pass a valid path like mymodule.provider_class" + ) module_path = ".".join(class_data[:-1]) class_str = class_data[-1] module = importlib.import_module(module_path) @@ -71,8 +72,9 @@ def get_node_provider(provider_config, cluster_name): importer = NODE_PROVIDERS.get(provider_config["type"]) if importer is None: - raise NotImplementedError("Unsupported node provider: {}".format( - provider_config["type"])) + raise NotImplementedError( + "Unsupported node provider: {}".format(provider_config["type"]) + ) _, provider_cls = importer() return provider_cls(provider_config, cluster_name) @@ -82,8 +84,9 @@ def get_default_config(provider_config): return {} load_config = DEFAULT_CONFIGS.get(provider_config["type"]) if load_config is None: - raise NotImplementedError("Unsupported node provider: {}".format( - provider_config["type"])) + raise NotImplementedError( + "Unsupported node provider: {}".format(provider_config["type"]) + ) path_to_default = load_config() with open(path_to_default) as f: defaults = yaml.load(f) diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index af13f0fecf7b..0bc009a29c5a 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -26,16 +26,18 @@ def pretty_cmd(cmd_str): class NodeUpdater(object): """A process for syncing files and running init commands on a node.""" - def __init__(self, - node_id, - provider_config, - auth_config, - cluster_name, - file_mounts, - setup_cmds, - runtime_hash, - redirect_output=True, - process_runner=subprocess): + def __init__( + self, + node_id, + provider_config, + auth_config, + cluster_name, + file_mounts, + setup_cmds, + runtime_hash, + redirect_output=True, + process_runner=subprocess + ): self.daemon = True self.process_runner = process_runner self.provider = get_node_provider(provider_config, cluster_name) @@ -48,7 +50,8 @@ def __init__(self, self.runtime_hash = runtime_hash if redirect_output: self.logfile = tempfile.NamedTemporaryFile( - mode="w", prefix="node-updater-", delete=False) + mode="w", prefix="node-updater-", delete=False + ) self.output_name = self.logfile.name self.stdout = self.logfile self.stderr = self.logfile @@ -59,39 +62,52 @@ def __init__(self, self.stderr = sys.stderr def run(self): - print("NodeUpdater: Updating {} to {}, logging to {}".format( - self.node_id, self.runtime_hash, self.output_name)) + print( + "NodeUpdater: Updating {} to {}, logging to {}".format( + self.node_id, self.runtime_hash, self.output_name + ) + ) try: self.do_update() except Exception as e: error_str = str(e) if hasattr(e, "cmd"): error_str = "(Exit Status {}) {}".format( - e.returncode, pretty_cmd(" ".join(e.cmd))) + e.returncode, pretty_cmd(" ".join(e.cmd)) + ) print( "NodeUpdater: Error updating {}" "See {} for remote logs.".format(error_str, self.output_name), - file=self.stdout) - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "UpdateFailed"}) + file=self.stdout + ) + self.provider.set_node_tags( + self.node_id, {TAG_RAY_NODE_STATUS: "UpdateFailed"} + ) if self.logfile is not None: - print("----- BEGIN REMOTE LOGS -----\n" + open( - self.logfile.name).read() + "\n----- END REMOTE LOGS -----" - ) + print( + "----- BEGIN REMOTE LOGS -----\n" + + open(self.logfile.name).read() + + "\n----- END REMOTE LOGS -----" + ) raise e self.provider.set_node_tags( - self.node_id, { + self.node_id, + 
{ TAG_RAY_NODE_STATUS: "Up-to-date", TAG_RAY_RUNTIME_CONFIG: self.runtime_hash - }) + } + ) print( "NodeUpdater: Applied config {} to node {}".format( - self.runtime_hash, self.node_id), - file=self.stdout) + self.runtime_hash, self.node_id + ), + file=self.stdout + ) def do_update(self): - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "WaitingForSSH"}) + self.provider.set_node_tags( + self.node_id, {TAG_RAY_NODE_STATUS: "WaitingForSSH"} + ) deadline = time.time() + NODE_START_WAIT_S # Wait for external IP @@ -99,7 +115,8 @@ def do_update(self): not self.provider.is_terminated(self.node_id): print( "NodeUpdater: Waiting for IP of {}...".format(self.node_id), - file=self.stdout) + file=self.stdout + ) self.ssh_ip = self.provider.external_ip(self.node_id) if self.ssh_ip is not None: break @@ -113,36 +130,44 @@ def do_update(self): try: print( "NodeUpdater: Waiting for SSH to {}...".format( - self.node_id), - file=self.stdout) + self.node_id + ), + file=self.stdout + ) if not self.provider.is_running(self.node_id): raise Exception("Node not running yet...") self.ssh_cmd( "uptime", connect_timeout=5, - redirect=open("/dev/null", "w")) + redirect=open("/dev/null", "w") + ) ssh_ok = True except Exception as e: retry_str = str(e) if hasattr(e, "cmd"): retry_str = "(Exit Status {}): {}".format( - e.returncode, pretty_cmd(" ".join(e.cmd))) + e.returncode, pretty_cmd(" ".join(e.cmd)) + ) print( "NodeUpdater: SSH not up, retrying: {}".format(retry_str), - file=self.stdout) + file=self.stdout + ) time.sleep(5) else: break assert ssh_ok, "Unable to SSH to node" # Rsync file mounts - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "SyncingFiles"}) + self.provider.set_node_tags( + self.node_id, {TAG_RAY_NODE_STATUS: "SyncingFiles"} + ) for remote_path, local_path in self.file_mounts.items(): print( "NodeUpdater: Syncing {} to {}...".format( - local_path, remote_path), - file=self.stdout) + local_path, remote_path + ), + file=self.stdout + ) assert os.path.exists(local_path) if os.path.isdir(local_path): if not local_path.endswith("/"): @@ -150,19 +175,19 @@ def do_update(self): if not remote_path.endswith("/"): remote_path += "/" self.ssh_cmd("mkdir -p {}".format(os.path.dirname(remote_path))) - self.process_runner.check_call( - [ - "rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) + - "-o ConnectTimeout=120s -o StrictHostKeyChecking=no", - "--delete", "-avz", "{}".format(local_path), - "{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path) - ], - stdout=self.stdout, - stderr=self.stderr) + self.process_runner.check_call([ + "rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) + + "-o ConnectTimeout=120s -o StrictHostKeyChecking=no", + "--delete", "-avz", "{}".format(local_path), + "{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path) + ], + stdout=self.stdout, + stderr=self.stderr) # Run init commands - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "SettingUp"}) + self.provider.set_node_tags( + self.node_id, {TAG_RAY_NODE_STATUS: "SettingUp"} + ) for cmd in self.setup_cmds: self.ssh_cmd(cmd, verbose=True) @@ -170,19 +195,19 @@ def ssh_cmd(self, cmd, connect_timeout=120, redirect=None, verbose=False): if verbose: print( "NodeUpdater: running {} on {}...".format( - pretty_cmd(cmd), self.ssh_ip), - file=self.stdout) + pretty_cmd(cmd), self.ssh_ip + ), + file=self.stdout + ) force_interactive = "set -i && source ~/.bashrc && " - self.process_runner.check_call( - [ - "ssh", "-o", "ConnectTimeout={}s".format(connect_timeout), - 
"-o", "StrictHostKeyChecking=no", - "-i", self.ssh_private_key, "{}@{}".format( - self.ssh_user, self.ssh_ip), "bash --login -c {}".format( - pipes.quote(force_interactive + cmd)) - ], - stdout=redirect or self.stdout, - stderr=redirect or self.stderr) + self.process_runner.check_call([ + "ssh", "-o", "ConnectTimeout={}s".format(connect_timeout), "-o", + "StrictHostKeyChecking=no", "-i", self.ssh_private_key, + "{}@{}".format(self.ssh_user, self.ssh_ip), + "bash --login -c {}".format(pipes.quote(force_interactive + cmd)) + ], + stdout=redirect or self.stdout, + stderr=redirect or self.stderr) class NodeUpdaterProcess(NodeUpdater, Process): diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py index 3fb0425e0b27..b20a44070cdb 100644 --- a/python/ray/common/redis_module/runtest.py +++ b/python/ray/common/redis_module/runtest.py @@ -65,224 +65,295 @@ def testInvalidObjectTableAdd(self): with self.assertRaises(redis.ResponseError): self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello") with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", - "one", "hash2", "manager_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id2", "one", "hash2", + "manager_id1" + ) with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, - "hash2", "manager_id1", - "extra argument") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id2", 1, "hash2", "manager_id1", + "extra argument" + ) # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an # object ID that is already present with a different hash. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id1"}) with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id2") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2" + ) # Check that the second manager was added, even though the hash was # mismatched. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) # Check that it is fine if we add the same object ID multiple times # with the most recent hash. 
- self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, - "hash2", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 2, "hash2", "manager_id2" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) def testObjectTableAddAndLookup(self): # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not # been added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(response, None) # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) # Add a manager that already exists again and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) # Check that we properly handle NULL characters. In the past, NULL # characters were handled improperly causing a "hash mismatch" error if # two object IDs that agreed up to the NULL character were inserted # with different hashes. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, - "hash2", "manager_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, "hash1", "manager_id1" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, "hash2", "manager_id1" + ) # Check that NULL characters in the hash are handled properly. 
- self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, - "\x00hash1", "manager_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash1", "manager_id1" + ) with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, - "\x00hash2", "manager_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id3", 1, "\x00hash2", + "manager_id1" + ) def testObjectTableAddAndRemove(self): # Try removing a manager from an object ID that has not been added yet. with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1" + ) # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not # been added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(response, None) # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) # Remove a manager that doesn't exist, and make sure we still have the # same set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) # Remove a manager that does exist. Make sure it gets removed the first # time and does nothing the second time. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id2"}) - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id1" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), {b"manager_id2"}) # Remove the last manager, and make sure we have an empty set. 
- self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id2" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), set()) # Remove a manager from an empty set, and make sure we now have an # empty set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3" + ) + response = self.redis.execute_command( + "RAY.OBJECT_TABLE_LOOKUP", "object_id1" + ) self.assertEqual(set(response), set()) def testObjectTableSubscribeToNotifications(self): # Define a helper method for checking the contents of object # notifications. - def check_object_notification(notification_message, object_id, - object_size, manager_ids): - notification_object = (SubscribeToNotificationsReply. - GetRootAsSubscribeToNotificationsReply( - notification_message, 0)) + def check_object_notification( + notification_message, object_id, object_size, manager_ids + ): + notification_object = ( + SubscribeToNotificationsReply. + GetRootAsSubscribeToNotificationsReply(notification_message, 0) + ) self.assertEqual(notification_object.ObjectId(), object_id) self.assertEqual(notification_object.ObjectSize(), object_size) - self.assertEqual(notification_object.ManagerIdsLength(), - len(manager_ids)) + self.assertEqual( + notification_object.ManagerIdsLength(), len(manager_ids) + ) for i in range(len(manager_ids)): self.assertEqual( - notification_object.ManagerIds(i), manager_ids[i]) + notification_object.ManagerIds(i), manager_ids[i] + ) data_size = 0xf1f0 p = self.redis.pubsub() # Subscribe to an object ID. p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX)) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", - data_size, "hash1", "manager_id2") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", data_size, "hash1", + "manager_id2" + ) # Receive the acknowledgement message. self.assertEqual(get_next_message(p)["data"], 1) # Request a notification and receive the data. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", + "object_id1" + ) # Verify that the notification is correct. check_object_notification( get_next_message(p)["data"], b"object_id1", data_size, - [b"manager_id2"]) + [b"manager_id2"] + ) # Request a notification for an object that isn't there. Then add the # object and receive the data. Only the first call to # RAY.OBJECT_TABLE_ADD should trigger notifications. 
- self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id2", "object_id3") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id3") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", + "object_id2", "object_id3" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", + "manager_id1" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", + "manager_id2" + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", + "manager_id3" + ) # Verify that the notification is correct. check_object_notification( get_next_message(p)["data"], b"object_id3", data_size, - [b"manager_id1"]) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", - data_size, "hash1", "manager_id3") + [b"manager_id1"] + ) + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", + "manager_id3" + ) # Verify that the notification is correct. check_object_notification( get_next_message(p)["data"], b"object_id2", data_size, - [b"manager_id3"]) + [b"manager_id3"] + ) # Request notifications for object_id3 again. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id3") + self.redis.execute_command( + "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", + "object_id3" + ) # Verify that the notification is correct. check_object_notification( get_next_message(p)["data"], b"object_id3", data_size, - [b"manager_id1", b"manager_id2", b"manager_id3"]) + [b"manager_id1", b"manager_id2", b"manager_id3"] + ) def testResultTableAddAndLookup(self): def check_result_table_entry(message, task_id, is_put): result_table_reply = ResultTableReply.GetRootAsResultTableReply( - message, 0) + message, 0 + ) self.assertEqual(result_table_reply.TaskId(), task_id) self.assertEqual(result_table_reply.IsPut(), is_put) # Try looking up something in the result table before anything is # added. - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") + response = self.redis.execute_command( + "RAY.RESULT_TABLE_LOOKUP", "object_id1" + ) self.assertIsNone(response) # Adding the object to the object table should have no effect. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1" + ) + response = self.redis.execute_command( + "RAY.RESULT_TABLE_LOOKUP", "object_id1" + ) self.assertIsNone(response) # Add the result to the result table. The lookup now returns the task # ID. task_id = b"task_id1" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", - task_id, 0) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") + self.redis.execute_command( + "RAY.RESULT_TABLE_ADD", "object_id1", task_id, 0 + ) + response = self.redis.execute_command( + "RAY.RESULT_TABLE_LOOKUP", "object_id1" + ) check_result_table_entry(response, task_id, False) # Doing it again should still work. 
- response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") + response = self.redis.execute_command( + "RAY.RESULT_TABLE_LOOKUP", "object_id1" + ) check_result_table_entry(response, task_id, False) # Try another result table lookup. This should succeed. task_id = b"task_id2" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", - task_id, 1) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id2") + self.redis.execute_command( + "RAY.RESULT_TABLE_ADD", "object_id2", task_id, 1 + ) + response = self.redis.execute_command( + "RAY.RESULT_TABLE_LOOKUP", "object_id2" + ) check_result_table_entry(response, task_id, True) def testInvalidTaskTableAdd(self): @@ -293,16 +364,20 @@ def testInvalidTaskTableAdd(self): with self.assertRaises(redis.ResponseError): self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello") with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, - "node_id") + self.redis.execute_command( + "RAY.TASK_TABLE_ADD", "task_id", 3, "node_id" + ) with self.assertRaises(redis.ResponseError): # Non-integer scheduling states should not be added. - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", - "invalid_state", "node_id", "task_spec") + self.redis.execute_command( + "RAY.TASK_TABLE_ADD", "task_id", "invalid_state", "node_id", + "task_spec" + ) with self.assertRaises(redis.ResponseError): # Should not be able to update a non-existent task. - self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10, - "node_id", b"") + self.redis.execute_command( + "RAY.TASK_TABLE_UPDATE", "task_id", 10, "node_id", b"" + ) def testTaskTableAddAndLookup(self): TASK_STATUS_WAITING = 1 @@ -315,71 +390,84 @@ def testTaskTableAddAndLookup(self): p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) def check_task_reply(message, task_args, updated=False): - (task_status, local_scheduler_id, execution_dependencies_string, - spillback_count, task_spec) = task_args + ( + task_status, local_scheduler_id, execution_dependencies_string, + spillback_count, task_spec + ) = task_args task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) self.assertEqual(task_reply_object.State(), task_status) - self.assertEqual(task_reply_object.LocalSchedulerId(), - local_scheduler_id) - self.assertEqual(task_reply_object.SpillbackCount(), - spillback_count) + self.assertEqual( + task_reply_object.LocalSchedulerId(), local_scheduler_id + ) + self.assertEqual( + task_reply_object.SpillbackCount(), spillback_count + ) self.assertEqual(task_reply_object.TaskSpec(), task_spec) self.assertEqual(task_reply_object.Updated(), updated) # Check that task table adds, updates, and lookups work correctly. task_args = [TASK_STATUS_WAITING, b"node_id", b"", 0, b"task_spec"] - response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", - *task_args) + response = self.redis.execute_command( + "RAY.TASK_TABLE_ADD", "task_id", *task_args + ) response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") check_task_reply(response, task_args) task_args[0] = TASK_STATUS_SCHEDULED - self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", - *task_args[:4]) + self.redis.execute_command( + "RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:4] + ) response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") check_task_reply(response, task_args) # If the current value, test value, and set value are all the same, the # update happens, and the response is still the same task. 
task_args = [task_args[0]] + task_args - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) + response = self.redis.execute_command( + "RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3] + ) check_task_reply(response, task_args[1:], updated=True) # Check that the task entry is still the same. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") + get_response = self.redis.execute_command( + "RAY.TASK_TABLE_GET", "task_id" + ) check_task_reply(get_response, task_args[1:]) # If the current value is the same as the test value, and the set value # is different, the update happens, and the response is the entire # task. task_args[1] = TASK_STATUS_QUEUED - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) + response = self.redis.execute_command( + "RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3] + ) check_task_reply(response, task_args[1:], updated=True) # Check that the update happened. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") + get_response = self.redis.execute_command( + "RAY.TASK_TABLE_GET", "task_id" + ) check_task_reply(get_response, task_args[1:]) # If the current value is no longer the same as the test value, the # response is the same task as before the test-and-set. new_task_args = task_args[:] new_task_args[1] = TASK_STATUS_WAITING - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *new_task_args[:3]) + response = self.redis.execute_command( + "RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *new_task_args[:3] + ) check_task_reply(response, task_args[1:], updated=False) # Check that the update did not happen. - get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") + get_response2 = self.redis.execute_command( + "RAY.TASK_TABLE_GET", "task_id" + ) self.assertEqual(get_response2, get_response) # If the test value is a bitmask that matches the current value, the # update happens. task_args = new_task_args task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) + response = self.redis.execute_command( + "RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3] + ) check_task_reply(response, task_args[1:], updated=True) # If the test value is a bitmask that does not match the current value, @@ -388,12 +476,14 @@ def check_task_reply(message, task_args, updated=False): new_task_args = task_args[:] new_task_args[0] = TASK_STATUS_SCHEDULED old_response = response - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *new_task_args[:3]) + response = self.redis.execute_command( + "RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *new_task_args[:3] + ) check_task_reply(response, task_args[1:], updated=False) # Check that the update did not happen. 
- get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") + get_response = self.redis.execute_command( + "RAY.TASK_TABLE_GET", "task_id" + ) self.assertNotEqual(get_response, old_response) check_task_reply(get_response, task_args[1:]) @@ -410,8 +500,9 @@ def check_task_subscription(self, p, scheduling_state, local_scheduler_id): self.assertEqual(notification_object.TaskId(), task_args[0]) self.assertEqual(notification_object.State(), task_args[1]) self.assertEqual(notification_object.LocalSchedulerId(), task_args[2]) - self.assertEqual(notification_object.ExecutionDependencies(), - task_args[3]) + self.assertEqual( + notification_object.ExecutionDependencies(), task_args[3] + ) self.assertEqual(notification_object.TaskSpec(), task_args[-1]) def testTaskTableSubscribe(self): @@ -428,23 +519,35 @@ def testTaskTableSubscribe(self): # Receive acknowledgment. self.assertEqual(get_next_message(p)["data"], 0) - p.psubscribe("{prefix}*:{state}".format( - prefix=TASK_PREFIX, state=scheduling_state)) + p.psubscribe( + "{prefix}*:{state}".format( + prefix=TASK_PREFIX, state=scheduling_state + ) + ) # Receive acknowledgment. self.assertEqual(get_next_message(p)["data"], 1) self.check_task_subscription(p, scheduling_state, local_scheduler_id) - p.punsubscribe("{prefix}*:{state}".format( - prefix=TASK_PREFIX, state=scheduling_state)) + p.punsubscribe( + "{prefix}*:{state}".format( + prefix=TASK_PREFIX, state=scheduling_state + ) + ) # Receive acknowledgment. self.assertEqual(get_next_message(p)["data"], 0) - p.psubscribe("{prefix}{local_scheduler_id}:*".format( - prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) + p.psubscribe( + "{prefix}{local_scheduler_id}:*".format( + prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id + ) + ) # Receive acknowledgment. self.assertEqual(get_next_message(p)["data"], 1) self.check_task_subscription(p, scheduling_state, local_scheduler_id) - p.punsubscribe("{prefix}{local_scheduler_id}:*".format( - prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) + p.punsubscribe( + "{prefix}{local_scheduler_id}:*".format( + prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id + ) + ) # Receive acknowledgment. self.assertEqual(get_next_message(p)["data"], 0) diff --git a/python/ray/common/test/test.py b/python/ray/common/test/test.py index 5892d289fa73..ecce9cfecce9 100644 --- a/python/ray/common/test/test.py +++ b/python/ray/common/test/test.py @@ -47,8 +47,10 @@ def random_task_id(): TUPLE_SIMPLE_OBJECTS = [(obj, ) for obj in BASE_SIMPLE_OBJECTS] DICT_SIMPLE_OBJECTS = [{(): obj} for obj in BASE_SIMPLE_OBJECTS] -SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS + - TUPLE_SIMPLE_OBJECTS + DICT_SIMPLE_OBJECTS) +SIMPLE_OBJECTS = ( + BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS + TUPLE_SIMPLE_OBJECTS + + DICT_SIMPLE_OBJECTS +) # Create some complex objects that cannot be serialized by value in tasks. 
@@ -71,8 +73,10 @@ def __init__(self): TUPLE_COMPLEX_OBJECTS = [(obj, ) for obj in BASE_COMPLEX_OBJECTS] DICT_COMPLEX_OBJECTS = [{(): obj} for obj in BASE_COMPLEX_OBJECTS] -COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS + - TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS) +COMPLEX_OBJECTS = ( + BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS + TUPLE_COMPLEX_OBJECTS + + DICT_COMPLEX_OBJECTS +) class TestSerialization(unittest.TestCase): @@ -168,8 +172,9 @@ def test_create_and_serialize_task(self): object_ids + 100 * ["a"] + object_ids] for args in args_list: for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(driver_id, function_id, args, - num_return_vals, parent_id, 0) + task = local_scheduler.Task( + driver_id, function_id, args, num_return_vals, parent_id, 0 + ) self.check_task(task, function_id, num_return_vals, args) data = local_scheduler.task_to_string(task) task2 = local_scheduler.task_from_string(data) diff --git a/python/ray/dataframe/__init__.py b/python/ray/dataframe/__init__.py index 7eea37f99f2a..ea5faf724f0b 100644 --- a/python/ray/dataframe/__init__.py +++ b/python/ray/dataframe/__init__.py @@ -11,8 +11,10 @@ pd_minor = int(pd_version.split(".")[1]) if pd_major == 0 and pd_minor < 22: - raise Exception("In order to use Pandas on Ray, please upgrade your Pandas" - " version to >= 0.22.") + raise Exception( + "In order to use Pandas on Ray, please upgrade your Pandas" + " version to >= 0.22." + ) DEFAULT_NPARTITIONS = 8 @@ -30,10 +32,21 @@ def get_npartitions(): # because they depend on npartitions. from .dataframe import DataFrame # noqa: 402 from .series import Series # noqa: 402 -from .io import (read_csv, read_parquet, read_json, read_html, # noqa: 402 - read_clipboard, read_excel, read_hdf, read_feather, # noqa: 402 - read_msgpack, read_stata, read_sas, read_pickle, # noqa: 402 - read_sql) # noqa: 402 +from .io import ( + read_csv, + read_parquet, + read_json, + read_html, # noqa: 402 + read_clipboard, + read_excel, + read_hdf, + read_feather, # noqa: 402 + read_msgpack, + read_stata, + read_sas, + read_pickle, # noqa: 402 + read_sql +) # noqa: 402 from .concat import concat # noqa: 402 __all__ = [ diff --git a/python/ray/dataframe/concat.py b/python/ray/dataframe/concat.py index 952e326edc1f..1f8c9064c024 100644 --- a/python/ray/dataframe/concat.py +++ b/python/ray/dataframe/concat.py @@ -8,9 +8,18 @@ from .utils import _reindex_helper -def concat(objs, axis=0, join='outer', join_axes=None, ignore_index=False, - keys=None, levels=None, names=None, verify_integrity=False, - copy=True): +def concat( + objs, + axis=0, + join='outer', + join_axes=None, + ignore_index=False, + keys=None, + levels=None, + names=None, + verify_integrity=False, + copy=True +): if keys is not None: objs = [objs[k] for k in keys] @@ -26,39 +35,46 @@ def concat(objs, axis=0, join='outer', join_axes=None, ignore_index=False, raise ValueError("All objects passed were None") try: - type_check = next(obj for obj in objs - if not isinstance(obj, (pandas.Series, - pandas.DataFrame, - DataFrame))) + type_check = next( + obj for obj in objs if + not isinstance(obj, + (pandas.Series, pandas.DataFrame, DataFrame)) + ) except StopIteration: type_check = None if type_check is not None: - raise ValueError("cannot concatenate object of type \"{0}\"; only " - "pandas.Series, pandas.DataFrame, " - "and ray.dataframe.DataFrame objs are " - "valid", type(type_check)) - - all_series = all([isinstance(obj, pandas.Series) - for obj in objs]) + raise ValueError( + "cannot concatenate 
object of type \"{0}\"; only " + "pandas.Series, pandas.DataFrame, " + "and ray.dataframe.DataFrame objs are " + "valid", type(type_check) + ) + + all_series = all([isinstance(obj, pandas.Series) for obj in objs]) if all_series: - return pandas.concat(objs, axis, join, join_axes, - ignore_index, keys, levels, names, - verify_integrity, copy) + return pandas.concat( + objs, axis, join, join_axes, ignore_index, keys, levels, names, + verify_integrity, copy + ) if isinstance(objs, dict): raise NotImplementedError( - "Obj as dicts not implemented. To contribute to " - "Pandas on Ray, please visit github.com/ray-project/ray.") + "Obj as dicts not implemented. To contribute to " + "Pandas on Ray, please visit github.com/ray-project/ray." + ) axis = pandas.DataFrame()._get_axis_number(axis) if join not in ['inner', 'outer']: - raise ValueError("Only can inner (intersect) or outer (union) join the" - " other axis") + raise ValueError( + "Only can inner (intersect) or outer (union) join the" + " other axis" + ) # We need this in a list because we use it later. - all_index, all_columns = list(zip(*[(obj.index, obj.columns) - for obj in objs])) + all_index, all_columns = list( + zip(*[(obj.index, obj.columns) for obj in objs]) + ) def series_to_df(series, columns): df = pandas.DataFrame(series) @@ -69,8 +85,10 @@ def series_to_df(series, columns): # true regardless of the existence of another column named 0 in the # concat. if axis == 0: - objs = [series_to_df(obj, [0]) - if isinstance(obj, pandas.Series) else obj for obj in objs] + objs = [ + series_to_df(obj, [0]) if isinstance(obj, pandas.Series) else obj + for obj in objs + ] else: # Pandas starts the count at 0 so this will increment the names as # long as there's a new nameless Series being added. @@ -80,9 +98,11 @@ def name_incrementer(i): return val i = [0] - objs = [series_to_df(obj, obj.name if obj.name is not None - else name_incrementer(i)) - if isinstance(obj, pandas.Series) else obj for obj in objs] + objs = [ + series_to_df( + obj, obj.name if obj.name is not None else name_incrementer(i) + ) if isinstance(obj, pandas.Series) else obj for obj in objs + ] # Using concat on the columns and index is fast because they're empty, # and it forces the error checking. It also puts the columns in the @@ -103,31 +123,39 @@ def name_incrementer(i): # Put all of the DataFrames into Ray format # TODO just partition the DataFrames instead of building a new Ray DF. - objs = [DataFrame(obj) if isinstance(obj, (pandas.DataFrame, - pandas.Series)) else obj - for obj in objs] + objs = [ + DataFrame(obj) if isinstance(obj, + (pandas.DataFrame, pandas.Series)) else obj + for obj in objs + ] # Here we reuse all_columns/index so we don't have to materialize objects # from remote memory built in the previous line. In the future, we won't be # building new DataFrames, rather just partitioning the DataFrames. 
if axis == 0: - new_blocks = np.array([_reindex_helper._submit( - args=tuple([all_columns[i], final_columns, axis, - len(objs[0]._block_partitions)] + part.tolist()), - num_return_vals=len(objs[0]._block_partitions)) - for i in range(len(objs)) - for part in objs[i]._block_partitions]) + new_blocks = np.array([ + _reindex_helper._submit( + args=tuple([ + all_columns[i], final_columns, axis, + len(objs[0]._block_partitions) + ] + part.tolist()), + num_return_vals=len(objs[0]._block_partitions) + ) for i in range(len(objs)) for part in objs[i]._block_partitions + ]) else: # Transposing the columns is necessary because the remote task treats # everything like rows and returns in row-major format. Luckily, this # operation is cheap in numpy. - new_blocks = np.array([_reindex_helper._submit( - args=tuple([all_index[i], final_index, axis, - len(objs[0]._block_partitions.T)] + part.tolist()), - num_return_vals=len(objs[0]._block_partitions.T)) - for i in range(len(objs)) - for part in objs[i]._block_partitions.T]).T - - return DataFrame(block_partitions=new_blocks, - columns=final_columns, - index=final_index) + new_blocks = np.array([ + _reindex_helper._submit( + args=tuple([ + all_index[i], final_index, axis, + len(objs[0]._block_partitions.T) + ] + part.tolist()), + num_return_vals=len(objs[0]._block_partitions.T) + ) for i in range(len(objs)) for part in objs[i]._block_partitions.T + ]).T + + return DataFrame( + block_partitions=new_blocks, columns=final_columns, index=final_index + ) diff --git a/python/ray/dataframe/dataframe.py b/python/ray/dataframe/dataframe.py index b96c4c836453..264bbe3039eb 100644 --- a/python/ray/dataframe/dataframe.py +++ b/python/ray/dataframe/dataframe.py @@ -12,10 +12,8 @@ from pandas.compat import lzip, string_types, cPickle as pkl import pandas.core.common as com from pandas.core.dtypes.common import ( - is_bool_dtype, - is_list_like, - is_numeric_dtype, - is_timedelta64_dtype) + is_bool_dtype, is_list_like, is_numeric_dtype, is_timedelta64_dtype +) from pandas.core.indexing import check_bool_indexer import warnings @@ -28,26 +26,29 @@ from .groupby import DataFrameGroupBy from .utils import ( - _deploy_func, - _map_partitions, - _partition_pandas_dataframe, - to_pandas, - _blocks_to_col, - _blocks_to_row, - _create_block_partitions, - _inherit_docstrings, - _reindex_helper, - _co_op_helper) + _deploy_func, _map_partitions, _partition_pandas_dataframe, to_pandas, + _blocks_to_col, _blocks_to_row, _create_block_partitions, + _inherit_docstrings, _reindex_helper, _co_op_helper +) from . import get_npartitions from .index_metadata import _IndexMetadata @_inherit_docstrings(pd.DataFrame) class DataFrame(object): - - def __init__(self, data=None, index=None, columns=None, dtype=None, - copy=False, col_partitions=None, row_partitions=None, - block_partitions=None, row_metadata=None, col_metadata=None): + def __init__( + self, + data=None, + index=None, + columns=None, + dtype=None, + copy=False, + col_partitions=None, + row_partitions=None, + block_partitions=None, + row_metadata=None, + col_metadata=None + ): """Distributed DataFrame object backed by Pandas dataframes. 
Args: @@ -74,12 +75,14 @@ def __init__(self, data=None, index=None, columns=None, dtype=None, self._row_metadata = self._col_metadata = None # Check type of data and use appropriate constructor - if data is not None or (col_partitions is None and - row_partitions is None and - block_partitions is None): + if data is not None or ( + col_partitions is None and row_partitions is None + and block_partitions is None + ): - pd_df = pd.DataFrame(data=data, index=index, columns=columns, - dtype=dtype, copy=copy) + pd_df = pd.DataFrame( + data=data, index=index, columns=columns, dtype=dtype, copy=copy + ) # TODO convert _partition_pandas_dataframe to block partitioning. row_partitions = \ @@ -126,23 +129,25 @@ def __init__(self, data=None, index=None, columns=None, dtype=None, # problematic for building blocks from the partitions, so we # add whatever dimension we're missing from the input. if self._block_partitions.ndim < 2: - self._block_partitions = np.expand_dims(self._block_partitions, - axis=axis ^ 1) + self._block_partitions = np.expand_dims( + self._block_partitions, axis=axis ^ 1 + ) assert self._block_partitions.ndim == 2, "Block Partitions must be 2D." # Create the row and column index objects for using our partitioning. # If the objects haven't been inherited, then generate them if self._row_metadata is None: - self._row_metadata = _IndexMetadata(self._block_partitions[:, 0], - index=index, axis=0) + self._row_metadata = _IndexMetadata( + self._block_partitions[:, 0], index=index, axis=0 + ) if self._col_metadata is None: - self._col_metadata = _IndexMetadata(self._block_partitions[0, :], - index=columns, axis=1) + self._col_metadata = _IndexMetadata( + self._block_partitions[0, :], index=columns, axis=1 + ) def _get_row_partitions(self): - return [_blocks_to_row.remote(*part) - for part in self._block_partitions] + return [_blocks_to_row.remote(*part) for part in self._block_partitions] def _set_row_partitions(self, new_row_partitions): self._block_partitions = \ @@ -152,8 +157,10 @@ def _set_row_partitions(self, new_row_partitions): _row_partitions = property(_get_row_partitions, _set_row_partitions) def _get_col_partitions(self): - return [_blocks_to_col.remote(*self._block_partitions[:, i]) - for i in range(self._block_partitions.shape[1])] + return [ + _blocks_to_col.remote(*self._block_partitions[:, i]) + for i in range(self._block_partitions.shape[1]) + ] def _set_col_partitions(self, new_col_partitions): self._block_partitions = \ @@ -175,8 +182,7 @@ def head(df, n, get_local_head=False): if get_local_head: return df.head(n) - new_dfs = _map_partitions(lambda df: df.head(n), - df) + new_dfs = _map_partitions(lambda df: df.head(n), df) index = self.index[:n] pd_head = pd.concat(ray.get(new_dfs), axis=1, copy=False) @@ -189,8 +195,7 @@ def tail(df, n, get_local_tail=False): if get_local_tail: return df.tail(n) - new_dfs = _map_partitions(lambda df: df.tail(n), - df) + new_dfs = _map_partitions(lambda df: df.tail(n), df) index = self.index[-n:] pd_tail = pd.concat(ray.get(new_dfs), axis=1, copy=False) @@ -203,7 +208,7 @@ def front(df, n): cum_col_lengths = self._col_metadata._lengths.cumsum() index = np.argmax(cum_col_lengths >= 10) - pd_front = pd.concat(ray.get(x[:index+1]), axis=1, copy=False) + pd_front = pd.concat(ray.get(x[:index + 1]), axis=1, copy=False) pd_front = pd_front.iloc[:, :n] pd_front.index = self.index pd_front.columns = self.columns[:n] @@ -212,10 +217,11 @@ def front(df, n): def back(df, n): """Get last n columns without creating a new Dataframe""" - 
cum_col_lengths = np.flip(self._col_metadata._lengths, - axis=0).cumsum() + cum_col_lengths = np.flip( + self._col_metadata._lengths, axis=0 + ).cumsum() index = np.argmax(cum_col_lengths >= 10) - pd_back = pd.concat(ray.get(x[-(index+1):]), axis=1, copy=False) + pd_back = pd.concat(ray.get(x[-(index + 1):]), axis=1, copy=False) pd_back = pd_back.iloc[:, -n:] pd_back.index = self.index pd_back.columns = self.columns[-n:] @@ -230,8 +236,7 @@ def back(df, n): front = front(x, 10) back = back(x, 10) - col_dots = pd.Series(["..." - for _ in range(len(self.index))]) + col_dots = pd.Series(["..." for _ in range(len(self.index))]) col_dots.index = self.index col_dots.name = "..." x = pd.concat([front, col_dots, back], axis=1) @@ -244,8 +249,7 @@ def back(df, n): tail = tail(x, 30, get_local_head) # Make the dots in between the head and tail - row_dots = pd.Series(["..." - for _ in range(len(head.columns))]) + row_dots = pd.Series(["..." for _ in range(len(head.columns))]) row_dots.index = head.columns row_dots.name = "..." @@ -328,9 +332,12 @@ def _arithmetic_helper(self, remote_func, axis, level=None): axis = pd.DataFrame()._get_axis_number(axis) if axis is not None \ else 0 - oid_series = ray.get(_map_partitions(remote_func, - self._col_partitions if axis == 0 - else self._row_partitions)) + oid_series = ray.get( + _map_partitions( + remote_func, self._col_partitions + if axis == 0 else self._row_partitions + ) + ) if axis == 0: # We use the index to get the internal index. @@ -344,7 +351,8 @@ def _arithmetic_helper(self, remote_func, axis, level=None): this_partition[this_partition.isin(df.index)].index result_series = pd.concat([obj[0] for obj in oid_series], - axis=0, copy=False) + axis=0, + copy=False) else: result_series = pd.concat(oid_series, axis=0, copy=False) result_series.index = self.index @@ -361,8 +369,10 @@ def _validate_eval_query(self, expr, **kwargs): raise ValueError("expr cannot be an empty string") if isinstance(expr, str) and '@' in expr: - raise NotImplementedError("Local variables not yet supported in " - "eval.") + raise NotImplementedError( + "Local variables not yet supported in " + "eval." + ) if isinstance(expr, str) and 'not' in expr: if 'parser' in kwargs and kwargs['parser'] == 'python': @@ -386,8 +396,9 @@ def ndim(self): """ # The number of dimensions is common across all partitions. # The first partition will be enough. - return ray.get(_deploy_func.remote(lambda df: df.ndim, - self._row_partitions[0])) + return ray.get( + _deploy_func.remote(lambda df: df.ndim, self._row_partitions[0]) + ) @property def ftypes(self): @@ -398,8 +409,9 @@ def ftypes(self): """ # The ftypes are common across all partitions. # The first partition will be enough. - result = ray.get(_deploy_func.remote(lambda df: df.ftypes, - self._row_partitions[0])) + result = ray.get( + _deploy_func.remote(lambda df: df.ftypes, self._row_partitions[0]) + ) result.index = self.columns return result @@ -412,8 +424,9 @@ def dtypes(self): """ # The dtypes are common across all partitions. # The first partition will be enough. - result = ray.get(_deploy_func.remote(lambda df: df.dtypes, - self._row_partitions[0])) + result = ray.get( + _deploy_func.remote(lambda df: df.dtypes, self._row_partitions[0]) + ) result.index = self.columns return result @@ -425,8 +438,9 @@ def empty(self): True if the DataFrame is empty. False otherwise. 
""" - all_empty = ray.get(_map_partitions( - lambda df: df.empty, self._row_partitions)) + all_empty = ray.get( + _map_partitions(lambda df: df.empty, self._row_partitions) + ) return False not in all_empty @property @@ -436,8 +450,11 @@ def values(self): Returns: The numpy representation of this DataFrame. """ - return np.concatenate(ray.get(_map_partitions( - lambda df: df.values, self._row_partitions))) + return np.concatenate( + ray.get( + _map_partitions(lambda df: df.values, self._row_partitions) + ) + ) @property def axes(self): @@ -457,9 +474,16 @@ def shape(self): """ return len(self.index), len(self.columns) - def _update_inplace(self, row_partitions=None, col_partitions=None, - block_partitions=None, columns=None, index=None, - col_metadata=None, row_metadata=None): + def _update_inplace( + self, + row_partitions=None, + col_partitions=None, + block_partitions=None, + columns=None, + index=None, + col_metadata=None, + row_metadata=None + ): """Updates the current DataFrame inplace. Behavior should be similar to the constructor, given the corresponding @@ -501,13 +525,15 @@ def _update_inplace(self, row_partitions=None, col_partitions=None, assert columns is not None, \ "Columns must be passed without col_metadata" self._col_metadata = _IndexMetadata( - self._block_partitions[0, :], index=columns, axis=1) + self._block_partitions[0, :], index=columns, axis=1 + ) if row_metadata is not None: self._row_metadata = row_metadata else: # Index can be None for default index, so we don't check self._row_metadata = _IndexMetadata( - self._block_partitions[:, 0], index=index, axis=0) + self._block_partitions[:, 0], index=index, axis=0 + ) def add_prefix(self, prefix): """Add a prefix to each of the column names. @@ -516,9 +542,11 @@ def add_prefix(self, prefix): A new DataFrame containing the new column names. """ new_cols = self.columns.map(lambda x: str(prefix) + str(x)) - return DataFrame(block_partitions=self._block_partitions, - columns=new_cols, - index=self.index) + return DataFrame( + block_partitions=self._block_partitions, + columns=new_cols, + index=self.index + ) def add_suffix(self, suffix): """Add a suffix to each of the column names. @@ -527,9 +555,11 @@ def add_suffix(self, suffix): A new DataFrame containing the new column names. """ new_cols = self.columns.map(lambda x: str(x) + str(suffix)) - return DataFrame(block_partitions=self._block_partitions, - columns=new_cols, - index=self.index) + return DataFrame( + block_partitions=self._block_partitions, + columns=new_cols, + index=self.index + ) def applymap(self, func): """Apply a function to a DataFrame elementwise. @@ -539,15 +569,19 @@ def applymap(self, func): """ if not callable(func): raise ValueError( - "\'{0}\' object is not callable".format(type(func))) + "\'{0}\' object is not callable".format(type(func)) + ) new_block_partitions = np.array([ _map_partitions(lambda df: df.applymap(func), block) - for block in self._block_partitions]) + for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index) + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index + ) def copy(self, deep=True): """Creates a shallow copy of the DataFrame. @@ -555,12 +589,23 @@ def copy(self, deep=True): Returns: A new DataFrame pointing to the same partitions as this one. 
""" - return DataFrame(block_partitions=self._block_partitions, - columns=self.columns, - index=self.index) - - def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True, - group_keys=True, squeeze=False, **kwargs): + return DataFrame( + block_partitions=self._block_partitions, + columns=self.columns, + index=self.index + ) + + def groupby( + self, + by=None, + axis=0, + level=None, + as_index=True, + sort=True, + group_keys=True, + squeeze=False, + **kwargs + ): """Apply a groupby to this DataFrame. See _groupby() remote task. Args: by: The value to groupby. @@ -584,12 +629,14 @@ def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True, if all([obj in self for obj in by]) and mismatch: raise NotImplementedError( - "Groupby with lists of columns not yet supported.") + "Groupby with lists of columns not yet supported." + ) elif mismatch: raise KeyError(next(x for x in by if x not in self)) - return DataFrameGroupBy(self, by, axis, level, as_index, sort, - group_keys, squeeze, **kwargs) + return DataFrameGroupBy( + self, by, axis, level, as_index, sort, group_keys, squeeze, **kwargs + ) def sum(self, axis=None, skipna=True, level=None, numeric_only=None): """Perform a sum across the DataFrame. @@ -601,9 +648,14 @@ def sum(self, axis=None, skipna=True, level=None, numeric_only=None): Returns: The sum of the DataFrame. """ + def remote_func(df): - return df.sum(axis=axis, skipna=skipna, level=level, - numeric_only=numeric_only) + return df.sum( + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only + ) return self._arithmetic_helper(remote_func, axis, level) @@ -618,13 +670,16 @@ def abs(self): # TODO Give a more accurate error to Pandas raise TypeError("bad operand type for abs():", "str") - new_block_partitions = np.array([_map_partitions(lambda df: df.abs(), - block) - for block in self._block_partitions]) + new_block_partitions = np.array([ + _map_partitions(lambda df: df.abs(), block) + for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index) + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index + ) def isin(self, values): """Fill a DataFrame with booleans for cells contained in values. @@ -638,13 +693,16 @@ def isin(self, values): True: cell is contained in values. False: otherwise """ - new_block_partitions = np.array([_map_partitions( - lambda df: df.isin(values), block) - for block in self._block_partitions]) + new_block_partitions = np.array([ + _map_partitions(lambda df: df.isin(values), block) + for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index) + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index + ) def isna(self): """Fill a DataFrame with booleans for cells containing NA. @@ -655,14 +713,18 @@ def isna(self): True: cell contains NA. False: otherwise. 
""" - new_block_partitions = np.array([_map_partitions( - lambda df: df.isna(), block) for block in self._block_partitions]) - - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index, - row_metadata=self._row_metadata, - col_metadata=self._col_metadata) + new_block_partitions = np.array([ + _map_partitions(lambda df: df.isna(), block) + for block in self._block_partitions + ]) + + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index, + row_metadata=self._row_metadata, + col_metadata=self._col_metadata + ) def isnull(self): """Fill a DataFrame with booleans for cells containing a null value. @@ -673,13 +735,16 @@ def isnull(self): True: cell contains null. False: otherwise. """ - new_block_partitions = np.array([_map_partitions( - lambda df: df.isnull(), block) - for block in self._block_partitions]) + new_block_partitions = np.array([ + _map_partitions(lambda df: df.isnull(), block) + for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index) + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index + ) def keys(self): """Get the info axis for the DataFrame. @@ -696,12 +761,16 @@ def transpose(self, *args, **kwargs): Returns: A new DataFrame transposed from this DataFrame. """ - new_block_partitions = np.array([_map_partitions( - lambda df: df.T, block) for block in self._block_partitions]) + new_block_partitions = np.array([ + _map_partitions(lambda df: df.T, block) + for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions.T, - columns=self.index, - index=self.columns) + return DataFrame( + block_partitions=new_block_partitions.T, + columns=self.index, + index=self.columns + ) T = property(transpose) @@ -738,8 +807,9 @@ def add(self, other, axis='columns', level=None, fill_value=None): Returns: A new DataFrame with the applied addition. """ - return self._operator_helper(pd.DataFrame.add, other, axis, level, - fill_value) + return self._operator_helper( + pd.DataFrame.add, other, axis, level, fill_value + ) def agg(self, func, axis=0, *args, **kwargs): return self.aggregate(func, axis, *args, **kwargs) @@ -774,16 +844,17 @@ def _aggregate(self, arg, *args, **kwargs): elif isinstance(arg, dict): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) elif is_list_like(arg): from .concat import concat - x = [self._aggregate(func, *args, **kwargs) - for func in arg] + x = [self._aggregate(func, *args, **kwargs) for func in arg] - new_dfs = [x[i] if not isinstance(x[i], pd.Series) - else pd.DataFrame(x[i], columns=[arg[i]]).T - for i in range(len(x))] + new_dfs = [ + x[i] if not isinstance(x[i], pd.Series) else + pd.DataFrame(x[i], columns=[arg[i]]).T for i in range(len(x)) + ] return concat(new_dfs) elif callable(arg): @@ -802,9 +873,9 @@ def _string_function(self, func, *args, **kwargs): return f(*args, **kwargs) assert len(args) == 0 - assert len([kwarg - for kwarg in kwargs - if kwarg not in ['axis', '_level']]) == 0 + assert len([ + kwarg for kwarg in kwargs if kwarg not in ['axis', '_level'] + ]) == 0 return f f = getattr(np, func, None) @@ -849,8 +920,10 @@ def agg_helper(df, arg, *args, **kwargs): if is_transform: if is_scalar(new_df) or len(new_df) != len(df): - raise ValueError("transforms cannot produce " - "aggregated results") + raise ValueError( + "transforms cannot produce " + "aggregated results" + ) return is_series, new_df, index, columns @@ -886,51 +959,74 @@ def agg_helper(df, arg, *args, **kwargs): columns = ray.get(columns) columns = columns[0].append(columns[1:]) - return DataFrame(col_partitions=new_parts, - columns=columns, - index=self.index if new_index is None - else new_index) + return DataFrame( + col_partitions=new_parts, + columns=columns, + index=self.index if new_index is None else new_index + ) else: new_index = ray.get(index[0]) columns = ray.get(columns) columns = columns[0].append(columns[1:]) - return DataFrame(row_partitions=new_parts, - columns=columns, - index=self.index if new_index is None - else new_index) - - def align(self, other, join='outer', axis=None, level=None, copy=True, - fill_value=None, method=None, limit=None, fill_axis=0, - broadcast_axis=None): + return DataFrame( + row_partitions=new_parts, + columns=columns, + index=self.index if new_index is None else new_index + ) + + def align( + self, + other, + join='outer', + axis=None, + level=None, + copy=True, + fill_value=None, + method=None, + limit=None, + fill_axis=0, + broadcast_axis=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def all(self, axis=None, bool_only=None, skipna=None, level=None, - **kwargs): + def all(self, axis=None, bool_only=None, skipna=None, level=None, **kwargs): """Return whether all elements are True over requested axis Note: If axis=None or axis=0, this call applies df.all(axis=1) to the transpose of df. 
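# Illustrative, pure-pandas sketch of the list-like branch of _aggregate
# above: run each requested function on its own, promote Series results to
# one-row DataFrames labelled by the function name, then concatenate the
# pieces (to_frame is used here instead of the hunk's constructor call).
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
arg = ["sum", "mean"]

results = [df.agg(func) for func in arg]
frames = [
    r.to_frame(name).T if isinstance(r, pd.Series) else r
    for r, name in zip(results, arg)
]
combined = pd.concat(frames)  # rows: 'sum' and 'mean'; columns: 'a' and 'b'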
""" + def remote_func(df): - return df.all(axis=axis, bool_only=bool_only, skipna=skipna, - level=level, **kwargs) + return df.all( + axis=axis, + bool_only=bool_only, + skipna=skipna, + level=level, + **kwargs + ) return self._arithmetic_helper(remote_func, axis, level) - def any(self, axis=None, bool_only=None, skipna=None, level=None, - **kwargs): + def any(self, axis=None, bool_only=None, skipna=None, level=None, **kwargs): """Return whether any elements are True over requested axis Note: If axis=None or axis=0, this call applies on the column partitions, otherwise operates on row partitions """ + def remote_func(df): - return df.any(axis=axis, bool_only=bool_only, skipna=skipna, - level=level, **kwargs) + return df.any( + axis=axis, + bool_only=bool_only, + skipna=skipna, + level=level, + **kwargs + ) return self._arithmetic_helper(remote_func, axis, level) @@ -949,8 +1045,10 @@ def append(self, other, ignore_index=False, verify_integrity=False): if isinstance(other, dict): other = pd.Series(other) if other.name is None and not ignore_index: - raise TypeError('Can only append a Series if ignore_index=True' - ' or if the Series has a name') + raise TypeError( + 'Can only append a Series if ignore_index=True' + ' or if the Series has a name' + ) if other.name is None: index = None @@ -960,11 +1058,14 @@ def append(self, other, ignore_index=False, verify_integrity=False): index = pd.Index([other.name], name=self.index.name) combined_columns = self.columns.tolist() + self.columns.union( - other.index).difference(self.columns).tolist() + other.index + ).difference(self.columns).tolist() other = other.reindex(combined_columns, copy=False) - other = pd.DataFrame(other.values.reshape((1, len(other))), - index=index, - columns=combined_columns) + other = pd.DataFrame( + other.values.reshape((1, len(other))), + index=index, + columns=combined_columns + ) other = other._convert(datetime=True, timedelta=True) elif isinstance(other, list) and not isinstance(other[0], DataFrame): other = pd.DataFrame(other) @@ -977,11 +1078,22 @@ def append(self, other, ignore_index=False, verify_integrity=False): else: to_concat = [self, other] - return concat(to_concat, ignore_index=ignore_index, - verify_integrity=verify_integrity) - - def apply(self, func, axis=0, broadcast=False, raw=False, reduce=None, - args=(), **kwds): + return concat( + to_concat, + ignore_index=ignore_index, + verify_integrity=verify_integrity + ) + + def apply( + self, + func, + axis=0, + broadcast=False, + raw=False, + reduce=None, + args=(), + **kwds + ): """Apply a function along input axis of DataFrame. Args: @@ -996,11 +1108,12 @@ def apply(self, func, axis=0, broadcast=False, raw=False, reduce=None, """ axis = pd.DataFrame()._get_axis_number(axis) - if is_list_like(func) and not all([isinstance(obj, str) - for obj in func]): + if is_list_like(func + ) and not all([isinstance(obj, str) for obj in func]): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) if axis == 0 and is_list_like(func): return self.aggregate(func, axis, *args, **kwds) @@ -1013,58 +1126,71 @@ def apply(self, func, axis=0, broadcast=False, raw=False, reduce=None, else: raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def as_blocks(self, copy=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def as_matrix(self, columns=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def asfreq(self, freq, method=None, how=None, normalize=False, - fill_value=None): + def asfreq( + self, freq, method=None, how=None, normalize=False, fill_value=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def asof(self, where, subset=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def assign(self, **kwargs): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def astype(self, dtype, copy=True, errors='raise', **kwargs): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def at_time(self, time, asof=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def between_time(self, start_time, end_time, include_start=True, - include_end=True): + def between_time( + self, start_time, end_time, include_start=True, include_end=True + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def bfill(self, axis=None, inplace=False, limit=None, downcast=None): """Synonym for DataFrame.fillna(method='bfill') """ - new_df = self.fillna(method='bfill', - axis=axis, - limit=limit, - downcast=downcast, - inplace=inplace) + new_df = self.fillna( + method='bfill', + axis=axis, + limit=limit, + downcast=downcast, + inplace=inplace + ) if not inplace: return new_df @@ -1076,73 +1202,102 @@ def bool(self): element is not boolean """ shape = self.shape - if shape != (1,) and shape != (1, 1): - raise ValueError("""The PandasObject does not have exactly + if shape != (1, ) and shape != (1, 1): + raise ValueError( + """The PandasObject does not have exactly 1 element. Return the bool of a single element PandasObject. The truth value is ambiguous. Use a.empty, a.item(), a.any() - or a.all().""") + or a.all().""" + ) else: return to_pandas(self).bool() - def boxplot(self, column=None, by=None, ax=None, fontsize=None, rot=0, - grid=True, figsize=None, layout=None, return_type=None, - **kwds): + def boxplot( + self, + column=None, + by=None, + ax=None, + fontsize=None, + rot=0, + grid=True, + figsize=None, + layout=None, + return_type=None, + **kwds + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def clip(self, lower=None, upper=None, axis=None, inplace=False, *args, - **kwargs): + def clip( + self, lower=None, upper=None, axis=None, inplace=False, *args, **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def clip_lower(self, threshold, axis=None, inplace=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def clip_upper(self, threshold, axis=None, inplace=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def combine(self, other, func, fill_value=None, overwrite=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def combine_first(self, other): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def compound(self, axis=None, skipna=None, level=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def consolidate(self, inplace=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def convert_objects(self, convert_dates=True, convert_numeric=False, - convert_timedeltas=True, copy=True): + "github.com/ray-project/ray." + ) + + def convert_objects( + self, + convert_dates=True, + convert_numeric=False, + convert_timedeltas=True, + copy=True + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def corr(self, method='pearson', min_periods=1): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def corrwith(self, other, axis=0, drop=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def count(self, axis=0, level=None, numeric_only=False): """Get the count of non-null objects in the DataFrame. @@ -1156,6 +1311,7 @@ def count(self, axis=0, level=None, numeric_only=False): Returns: The count, in a Series (or DataFrame if level is specified). """ + def remote_func(df): return df.count(axis=axis, level=level, numeric_only=numeric_only) @@ -1164,7 +1320,8 @@ def remote_func(df): def cov(self, min_periods=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def _cumulative_helper(self, func, axis): axis = pd.DataFrame()._get_axis_number(axis) if axis is not None \ @@ -1172,14 +1329,14 @@ def _cumulative_helper(self, func, axis): if axis == 0: new_cols = _map_partitions(func, self._col_partitions) - return DataFrame(col_partitions=new_cols, - columns=self.columns, - index=self.index) + return DataFrame( + col_partitions=new_cols, columns=self.columns, index=self.index + ) else: new_rows = _map_partitions(func, self._row_partitions) - return DataFrame(row_partitions=new_rows, - columns=self.columns, - index=self.index) + return DataFrame( + row_partitions=new_rows, columns=self.columns, index=self.index + ) def cummax(self, axis=None, skipna=True, *args, **kwargs): """Perform a cumulative maximum across the DataFrame. @@ -1191,6 +1348,7 @@ def cummax(self, axis=None, skipna=True, *args, **kwargs): Returns: The cumulative maximum of the DataFrame. 
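# Sketch of why _cumulative_helper above can split cleanly by axis: a
# cumulative op along axis=0 only needs whole columns, so column partitions
# (each holding all rows for a slice of columns) can be processed
# independently and reassembled.
import pandas as pd

col_partitions = [pd.DataFrame({"a": [3, 1, 2]}), pd.DataFrame({"b": [4, 6, 5]})]
cummax_parts = [part.cummax(axis=0) for part in col_partitions]
cummax_df = pd.concat(cummax_parts, axis=1)   # a: 3,3,3   b: 4,6,6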
""" + def remote_func(df): return df.cummax(axis=axis, skipna=skipna, *args, **kwargs) @@ -1206,6 +1364,7 @@ def cummin(self, axis=None, skipna=True, *args, **kwargs): Returns: The cumulative minimum of the DataFrame. """ + def remote_func(df): return df.cummin(axis=axis, skipna=skipna, *args, **kwargs) @@ -1221,6 +1380,7 @@ def cumprod(self, axis=None, skipna=True, *args, **kwargs): Returns: The cumulative product of the DataFrame. """ + def remote_func(df): return df.cumprod(axis=axis, skipna=skipna, *args, **kwargs) @@ -1236,6 +1396,7 @@ def cumsum(self, axis=None, skipna=True, *args, **kwargs): Returns: The cumulative sum of the DataFrame. """ + def remote_func(df): return df.cumsum(axis=axis, skipna=skipna, *args, **kwargs) @@ -1254,13 +1415,13 @@ def describe(self, percentiles=None, include=None, exclude=None): Returns: Series/DataFrame of summary statistics """ + def describe_helper(df): """This to ensure nothing goes on with non-numeric columns""" try: return df.select_dtypes(exclude='object').describe( - percentiles=percentiles, - include=include, - exclude=exclude) + percentiles=percentiles, include=include, exclude=exclude + ) # This exception is thrown when there are only non-numeric columns # in this partition except ValueError: @@ -1282,7 +1443,8 @@ def describe_helper(df): def diff(self, periods=1, axis=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def div(self, other, axis='columns', level=None, fill_value=None): """Divides this DataFrame against another DataFrame/Series/scalar. @@ -1296,8 +1458,9 @@ def div(self, other, axis='columns', level=None, fill_value=None): Returns: A new DataFrame with the Divide applied. """ - return self._operator_helper(pd.DataFrame.add, other, axis, level, - fill_value) + return self._operator_helper( + pd.DataFrame.add, other, axis, level, fill_value + ) def divide(self, other, axis='columns', level=None, fill_value=None): """Synonym for div. @@ -1316,10 +1479,19 @@ def divide(self, other, axis='columns', level=None, fill_value=None): def dot(self, other): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def drop(self, labels=None, axis=0, index=None, columns=None, level=None, - inplace=False, errors='raise'): + "github.com/ray-project/ray." + ) + + def drop( + self, + labels=None, + axis=0, + index=None, + columns=None, + level=None, + inplace=False, + errors='raise' + ): """Return new object with labels in requested axis removed. Args: labels: Index or column labels to drop. 
@@ -1346,17 +1518,21 @@ def drop(self, labels=None, axis=0, index=None, columns=None, level=None, inplace = validate_bool_kwarg(inplace, "inplace") if labels is not None: if index is not None or columns is not None: - raise ValueError("Cannot specify both 'labels' and " - "'index'/'columns'") + raise ValueError( + "Cannot specify both 'labels' and " + "'index'/'columns'" + ) axis = pd.DataFrame()._get_axis_name(axis) axes = {axis: labels} elif index is not None or columns is not None: - axes, _ = pd.DataFrame()._construct_axes_from_arguments((index, - columns), - {}) + axes, _ = pd.DataFrame()._construct_axes_from_arguments( + (index, columns), {} + ) else: - raise ValueError("Need to specify at least one of 'labels', " - "'index' or 'columns'") + raise ValueError( + "Need to specify at least one of 'labels', " + "'index' or 'columns'" + ) obj = self.copy() def drop_helper(obj, axis, label): @@ -1432,15 +1608,17 @@ def drop_helper(obj, axis, label): for label in labels: if errors != 'ignore' and label and \ label not in getattr(self, axis): - raise ValueError("The label [{}] is not in the [{}]", - label, axis) + raise ValueError( + "The label [{}] is not in the [{}]", label, axis + ) else: obj = drop_helper(obj, axis, label) else: if errors != 'ignore' and labels and \ labels not in getattr(self, axis): - raise ValueError("The label [{}] is not in the [{}]", - labels, axis) + raise ValueError( + "The label [{}] is not in the [{}]", labels, axis + ) else: obj = drop_helper(obj, axis, labels) @@ -1454,12 +1632,14 @@ def drop_helper(obj, axis, label): def drop_duplicates(self, subset=None, keep='first', inplace=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def duplicated(self, subset=None, keep='first'): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def eq(self, other, axis='columns', level=None): """Checks element-wise that this is equal to other. 
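# Pure-Python sketch of the argument normalization drop() performs above:
# callers may pass either labels+axis or index=/columns=, never both, and the
# result is a mapping from axis name to the labels to remove. The axis-name
# lookup here stands in for the pandas _get_axis_name /
# _construct_axes_from_arguments helpers used in the hunk.
def _normalize_drop_args(labels=None, axis=0, index=None, columns=None):
    if labels is not None:
        if index is not None or columns is not None:
            raise ValueError("Cannot specify both 'labels' and 'index'/'columns'")
        axis_name = "index" if axis in (0, "index") else "columns"
        return {axis_name: labels}
    if index is None and columns is None:
        raise ValueError(
            "Need to specify at least one of 'labels', 'index' or 'columns'")
    return {"index": index, "columns": columns}

axes = _normalize_drop_args(columns=["a"])   # {'index': None, 'columns': ['a']}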
@@ -1481,6 +1661,7 @@ def equals(self, other): Returns: Boolean: True if equal, otherwise False """ + # TODO(kunalgosar): Implement Copartition and use to implement equals def helper(df, index, other_series): return df.iloc[index['index_within_partition']] \ @@ -1499,10 +1680,9 @@ def helper(df, index, other_series): other_series = other_df.iloc[idx['index_within_partition']] curr_index = self._row_metadata._coord_df.iloc[i] curr_df = self._row_partitions[int(curr_index['partition'])] - results.append(_deploy_func.remote(helper, - curr_df, - curr_index, - other_series)) + results.append( + _deploy_func.remote(helper, curr_df, curr_index, other_series) + ) for r in results: if not ray.get(r): @@ -1569,8 +1749,9 @@ def eval_helper(df): inplace = validate_bool_kwarg(inplace, "inplace") new_rows = _map_partitions(eval_helper, self._row_partitions) - result_type = ray.get(_deploy_func.remote(lambda df: type(df), - new_rows[0])) + result_type = ray.get( + _deploy_func.remote(lambda df: type(df), new_rows[0]) + ) if result_type is pd.Series: new_series = pd.concat(ray.get(new_rows), axis=0) new_series.index = self.index @@ -1585,30 +1766,52 @@ def eval_helper(df): else: return DataFrame(columns=columns, row_partitions=new_rows) - def ewm(self, com=None, span=None, halflife=None, alpha=None, - min_periods=0, freq=None, adjust=True, ignore_na=False, axis=0): + def ewm( + self, + com=None, + span=None, + halflife=None, + alpha=None, + min_periods=0, + freq=None, + adjust=True, + ignore_na=False, + axis=0 + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def expanding(self, min_periods=1, freq=None, center=False, axis=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def ffill(self, axis=None, inplace=False, limit=None, downcast=None): """Synonym for DataFrame.fillna(method='ffill') """ - new_df = self.fillna(method='ffill', - axis=axis, - limit=limit, - downcast=downcast, - inplace=inplace) + new_df = self.fillna( + method='ffill', + axis=axis, + limit=limit, + downcast=downcast, + inplace=inplace + ) if not inplace: return new_df - def fillna(self, value=None, method=None, axis=None, inplace=False, - limit=None, downcast=None, **kwargs): + def fillna( + self, + value=None, + method=None, + axis=None, + inplace=False, + limit=None, + downcast=None, + **kwargs + ): """Fill NA/NaN values using the specified method. Args: @@ -1641,8 +1844,10 @@ def fillna(self, value=None, method=None, axis=None, inplace=False, """ # TODO implement value passed as DataFrame if isinstance(value, pd.DataFrame): - raise NotImplementedError("Passing a DataFrame as the value for " - "fillna is not yet supported.") + raise NotImplementedError( + "Passing a DataFrame as the value for " + "fillna is not yet supported." 
+ ) inplace = validate_bool_kwarg(inplace, 'inplace') @@ -1651,14 +1856,17 @@ def fillna(self, value=None, method=None, axis=None, inplace=False, else 0 if isinstance(value, (list, tuple)): - raise TypeError('"value" parameter must be a scalar or dict, but ' - 'you passed a "{0}"'.format(type(value).__name__)) + raise TypeError( + '"value" parameter must be a scalar or dict, but ' + 'you passed a "{0}"'.format(type(value).__name__) + ) if value is None and method is None: raise ValueError('must specify a fill method or value') if value is not None and method is not None: raise ValueError('cannot specify both a fill method and value') - if method is not None and method not in ['backfill', 'bfill', 'pad', - 'ffill']: + if method is not None and method not in [ + 'backfill', 'bfill', 'pad', 'ffill' + ]: expecting = 'pad (ffill) or backfill (bfill)' msg = 'Invalid fill method. Expecting {expecting}. Got {method}'\ .format(expecting=expecting, method=method) @@ -1696,9 +1904,10 @@ def fillna(self, value=None, method=None, axis=None, inplace=False, # Not every partition was changed, so we put everything back that # was not changed and update those that were. - new_parts = [parts[i] if coords_obj.index[i] not in new_vals - else new_vals[coords_obj.index[i]] - for i in range(len(parts))] + new_parts = [ + parts[i] if coords_obj.index[i] not in new_vals else + new_vals[coords_obj.index[i]] for i in range(len(parts)) + ] else: new_parts = _map_partitions(lambda df: df.fillna( value=value, @@ -1710,25 +1919,31 @@ def fillna(self, value=None, method=None, axis=None, inplace=False, **kwargs), parts) if axis == 0: - new_obj._update_inplace(col_partitions=new_parts, - columns=self.columns, - index=self.index) + new_obj._update_inplace( + col_partitions=new_parts, + columns=self.columns, + index=self.index + ) else: - new_obj._update_inplace(row_partitions=new_parts, - columns=self.columns, - index=self.index) + new_obj._update_inplace( + row_partitions=new_parts, + columns=self.columns, + index=self.index + ) if not inplace: return new_obj def filter(self, items=None, like=None, regex=None, axis=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def first(self, offset): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def first_valid_index(self): """Return index for first non-NA/null value. @@ -1750,35 +1965,55 @@ def floordiv(self, other, axis='columns', level=None, fill_value=None): Returns: A new DataFrame with the Divide applied. """ - return self._operator_helper(pd.DataFrame.floordiv, other, axis, level, - fill_value) + return self._operator_helper( + pd.DataFrame.floordiv, other, axis, level, fill_value + ) @classmethod - def from_csv(self, path, header=0, sep=', ', index_col=0, - parse_dates=True, encoding=None, tupleize_cols=None, - infer_datetime_format=False): + def from_csv( + self, + path, + header=0, + sep=', ', + index_col=0, + parse_dates=True, + encoding=None, + tupleize_cols=None, + infer_datetime_format=False + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) @classmethod def from_dict(self, data, orient='columns', dtype=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) @classmethod def from_items(self, items, columns=None, orient='columns'): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) @classmethod - def from_records(self, data, index=None, exclude=None, columns=None, - coerce_float=False, nrows=None): + def from_records( + self, + data, + index=None, + exclude=None, + columns=None, + coerce_float=False, + nrows=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def ge(self, other, axis='columns', level=None): """Checks element-wise that this is greater than or equal to other. @@ -1816,8 +2051,11 @@ def get_dtype_counts(self): Returns: The counts of dtypes in this object. """ - return ray.get(_deploy_func.remote(lambda df: df.get_dtype_counts(), - self._row_partitions[0])) + return ray.get( + _deploy_func.remote( + lambda df: df.get_dtype_counts(), self._row_partitions[0] + ) + ) def get_ftype_counts(self): """Get the counts of ftypes in this object. @@ -1825,18 +2063,23 @@ def get_ftype_counts(self): Returns: The counts of ftypes in this object. """ - return ray.get(_deploy_func.remote(lambda df: df.get_ftype_counts(), - self._row_partitions[0])) + return ray.get( + _deploy_func.remote( + lambda df: df.get_ftype_counts(), self._row_partitions[0] + ) + ) def get_value(self, index, col, takeable=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def get_values(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def gt(self, other, axis='columns', level=None): """Checks element-wise that this is greater than other. @@ -1863,21 +2106,36 @@ def head(self, n=5): if n >= len(self._row_metadata): return self.copy() - new_dfs = _map_partitions(lambda df: df.head(n), - self._col_partitions) + new_dfs = _map_partitions(lambda df: df.head(n), self._col_partitions) index = self._row_metadata.index[:n] - return DataFrame(col_partitions=new_dfs, - columns=self.columns, - index=index) - - def hist(self, data, column=None, by=None, grid=True, xlabelsize=None, - xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False, - sharey=False, figsize=None, layout=None, bins=10, **kwds): + return DataFrame( + col_partitions=new_dfs, columns=self.columns, index=index + ) + + def hist( + self, + data, + column=None, + by=None, + grid=True, + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, + ax=None, + sharex=False, + sharey=False, + figsize=None, + layout=None, + bins=10, + **kwds + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def idxmax(self, axis=0, skipna=True): """Get the index of the first occurrence of the max value of the axis. 
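# Sketch of head(n) over column partitions, as reformatted above: each column
# partition holds every row for its slice of columns, so taking the first n
# rows of each partition and gluing them side by side reproduces df.head(n).
import pandas as pd

col_partitions = [
    pd.DataFrame({"a": range(10)}),
    pd.DataFrame({"b": range(10, 20)}),
]
n = 5
head_df = pd.concat([part.head(n) for part in col_partitions], axis=1)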
@@ -1892,7 +2150,8 @@ def idxmax(self, axis=0, skipna=True): """ if not all([d != np.dtype('O') for d in self.dtypes]): raise TypeError( - "reduction operation 'argmax' not allowed for this dtype") + "reduction operation 'argmax' not allowed for this dtype" + ) def remote_func(df): return df.idxmax(axis=axis, skipna=skipna) @@ -1914,7 +2173,8 @@ def idxmin(self, axis=0, skipna=True): """ if not all([d != np.dtype('O') for d in self.dtypes]): raise TypeError( - "reduction operation 'argmax' not allowed for this dtype") + "reduction operation 'argmax' not allowed for this dtype" + ) def remote_func(df): return df.idxmin(axis=axis, skipna=skipna) @@ -1926,23 +2186,32 @@ def remote_func(df): def infer_objects(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def info(self, verbose=None, buf=None, max_cols=None, memory_usage=None, - null_counts=None): - + "github.com/ray-project/ray." + ) + + def info( + self, + verbose=None, + buf=None, + max_cols=None, + memory_usage=None, + null_counts=None + ): def info_helper(df): output_buffer = io.StringIO() - df.info(verbose=verbose, - buf=output_buffer, - max_cols=max_cols, - memory_usage=memory_usage, - null_counts=null_counts) + df.info( + verbose=verbose, + buf=output_buffer, + max_cols=max_cols, + memory_usage=memory_usage, + null_counts=null_counts + ) return output_buffer.getvalue() # Combine the per-partition info and split into lines - result = ''.join(ray.get(_map_partitions(info_helper, - self._col_partitions))) + result = ''.join( + ray.get(_map_partitions(info_helper, self._col_partitions)) + ) lines = result.split('\n') # Class denoted in info() output @@ -1953,7 +2222,8 @@ def info_helper(df): # A column header is needed in the inf() output col_header = 'Data columns (total {0} columns):\n'.format( - len(self.columns)) + len(self.columns) + ) # Parse the per-partition values to get the per-column details # Find all the lines in the output that start with integers @@ -1961,9 +2231,10 @@ def info_helper(df): col_lines = [prog.match(line) for line in lines] cols = [c.group(0) for c in col_lines if c is not None] # replace the partition columns names with real column names - columns = ["{0}\t{1}\n".format(self.columns[i], - cols[i].split(" ", 1)[1]) - for i in range(len(cols))] + columns = [ + "{0}\t{1}\n".format(self.columns[i], cols[i].split(" ", 1)[1]) + for i in range(len(cols)) + ] col_string = ''.join(columns) + '\n' # A summary of the dtypes in the dataframe @@ -1976,22 +2247,26 @@ def info_helper(df): # Parse lines for memory usage number prog = re.compile('^memory+.+') mems = [prog.match(line) for line in lines] - mem_vals = [float(re.search(r'\d+', m.group(0)).group()) - for m in mems if m is not None] + mem_vals = [ + float(re.search(r'\d+', m.group(0)).group()) + for m in mems + if m is not None + ] memory_string = "" if len(mem_vals) != 0: # Sum memory usage from each partition if memory_usage != 'deep': - memory_string = 'memory usage: {0}+ bytes'.format( - sum(mem_vals)) + memory_string = 'memory usage: {0}+ bytes'.format(sum(mem_vals)) else: memory_string = 'memory usage: {0} bytes'.format(sum(mem_vals)) # Combine all the components of the info() output - result = ''.join([class_string, index_string, col_header, - col_string, dtypes_string, memory_string]) + result = ''.join([ + class_string, index_string, col_header, col_string, dtypes_string, + memory_string + ]) # Write to specified output buffer if buf: @@ -2012,15 +2287,15 @@ def insert(self, loc, 
column, value, allow_duplicates=False): value = np.full(len(self.index), value) if len(value) != len(self.index): - raise ValueError( - "Length of values does not match length of index") + raise ValueError("Length of values does not match length of index") if not allow_duplicates and column in self.columns: - raise ValueError( - "cannot insert {0}, already exists".format(column)) + raise ValueError("cannot insert {0}, already exists".format(column)) if loc > len(self.columns): raise IndexError( "index {0} is out of bounds for axis 0 with size {1}".format( - loc, len(self.columns))) + loc, len(self.columns) + ) + ) if loc < 0: raise ValueError("unbounded slice") @@ -2033,21 +2308,31 @@ def insert_col_part(df): df.insert(index_within_partition, column, value, allow_duplicates) return df - new_obj = _deploy_func.remote(insert_col_part, - self._col_partitions[partition]) - new_cols = [self._col_partitions[i] - if i != partition - else new_obj - for i in range(len(self._col_partitions))] + new_obj = _deploy_func.remote( + insert_col_part, self._col_partitions[partition] + ) + new_cols = [ + self._col_partitions[i] if i != partition else new_obj + for i in range(len(self._col_partitions)) + ] new_col_names = self.columns.insert(loc, column) self._update_inplace(col_partitions=new_cols, columns=new_col_names) - def interpolate(self, method='linear', axis=0, limit=None, inplace=False, - limit_direction='forward', downcast=None, **kwargs): + def interpolate( + self, + method='linear', + axis=0, + limit=None, + inplace=False, + limit_direction='forward', + downcast=None, + **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def iterrows(self): """Iterate over DataFrame rows as (index, Series) pairs. @@ -2060,15 +2345,17 @@ def iterrows(self): Returns: A generator that iterates over the rows of the frame. """ + def update_iterrow(series, i): """Helper function to correct the columns + name of the Series.""" series.index = self.columns series.name = list(self.index)[i] return series - iters = ray.get([_deploy_func.remote( - lambda df: list(df.iterrows()), part) - for part in self._row_partitions]) + iters = ray.get([ + _deploy_func.remote(lambda df: list(df.iterrows()), part) + for part in self._row_partitions + ]) iters = itertools.chain.from_iterable(iters) series = map(lambda s: update_iterrow(s[1][1], s[0]), enumerate(iters)) @@ -2085,9 +2372,10 @@ def items(self): Returns: A generator that iterates over the columns of the frame. 
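# Sketch of the iterrows()/items() pattern above: collect each partition's
# iterator eagerly, chain them into one generator, and patch the
# DataFrame-wide row labels back in afterwards (the column fix-up done by
# update_iterrow is elided here).
import itertools
import pandas as pd

full_index = ["r0", "r1", "r2", "r3"]
row_partitions = [
    pd.DataFrame({"a": [1, 2]}),
    pd.DataFrame({"a": [3, 4]}),
]

per_partition = [list(part.iterrows()) for part in row_partitions]
chained = itertools.chain.from_iterable(per_partition)

def _with_global_labels(chained_rows):
    for label, (_, row) in zip(full_index, chained_rows):
        row.name = label           # restore the DataFrame-wide row label
        yield label, row

rows = list(_with_global_labels(chained))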
""" - iters = ray.get([_deploy_func.remote( - lambda df: list(df.items()), part) - for part in self._row_partitions]) + iters = ray.get([ + _deploy_func.remote(lambda df: list(df.items()), part) + for part in self._row_partitions + ]) def concat_iters(iterables): for partitions in enumerate(zip(*iterables)): @@ -2127,8 +2415,9 @@ def itertuples(self, index=True, name='Pandas'): """ iters = ray.get([ _deploy_func.remote( - lambda df: list(df.itertuples(index=index, name=name)), - part) for part in self._row_partitions]) + lambda df: list(df.itertuples(index=index, name=name)), part + ) for part in self._row_partitions + ]) iters = itertools.chain.from_iterable(iters) def _replace_index(row_tuple, idx): @@ -2137,15 +2426,16 @@ def _replace_index(row_tuple, idx): try: row_tuple = row_tuple._replace(Index=idx) except AttributeError: # Tuple not namedtuple - row_tuple = (idx,) + row_tuple[1:] + row_tuple = (idx, ) + row_tuple[1:] return row_tuple if index: iters = itertools.starmap(_replace_index, zip(iters, self.index)) return iters - def join(self, other, on=None, how='left', lsuffix='', rsuffix='', - sort=False): + def join( + self, other, on=None, how='left', lsuffix='', rsuffix='', sort=False + ): """Join two or more DataFrames, or a DataFrame with a collection. Args: @@ -2181,25 +2471,28 @@ def join(self, other, on=None, how='left', lsuffix='', rsuffix='', .join(pd.DataFrame(columns=other.columns), lsuffix=lsuffix, rsuffix=rsuffix).columns - new_partition_num = max(len(self._block_partitions.T), - len(other._block_partitions.T)) + new_partition_num = max( + len(self._block_partitions.T), len(other._block_partitions.T) + ) # Join is a concat once we have shuffled the data internally. # We shuffle the data by computing the correct order. # Another important thing to note: We set the current self index # to the index variable which may be 'on'. new_self = np.array([ - _reindex_helper._submit(args=tuple([index, new_index, 1, - new_partition_num] + - block.tolist()), - num_return_vals=new_partition_num) - for block in self._block_partitions.T]) + _reindex_helper._submit( + args=tuple([index, new_index, 1, new_partition_num] + + block.tolist()), + num_return_vals=new_partition_num + ) for block in self._block_partitions.T + ]) new_other = np.array([ - _reindex_helper._submit(args=tuple([other.index, new_index, 1, - new_partition_num] + - block.tolist()), - num_return_vals=new_partition_num) - for block in other._block_partitions.T]) + _reindex_helper._submit( + args=tuple([other.index, new_index, 1, new_partition_num] + + block.tolist()), + num_return_vals=new_partition_num + ) for block in other._block_partitions.T + ]) # Append the blocks together (i.e. concat) new_block_parts = np.concatenate((new_self, new_other)).T @@ -2209,67 +2502,86 @@ def join(self, other, on=None, how='left', lsuffix='', rsuffix='', new_index = None # TODO join the two metadata tables for performance. - return DataFrame(block_partitions=new_block_parts, - index=new_index, - columns=new_column_labels) + return DataFrame( + block_partitions=new_block_parts, + index=new_index, + columns=new_column_labels + ) else: # This constraint carried over from Pandas. if on is not None: - raise ValueError("Joining multiple DataFrames only supported" - " for joining on index") + raise ValueError( + "Joining multiple DataFrames only supported" + " for joining on index" + ) # Joining the empty DataFrames with either index or columns is # fast. 
It gives us proper error checking for the edge cases that # would otherwise require a lot more logic. - new_index = pd.DataFrame(index=self.index).join( - [pd.DataFrame(index=obj.index) for obj in other], - how=how, sort=sort).index - - new_column_labels = pd.DataFrame(columns=self.columns).join( - [pd.DataFrame(columns=obj.columns) for obj in other], - lsuffix=lsuffix, rsuffix=rsuffix).columns - - new_partition_num = max([len(self._block_partitions.T)] + - [len(obj._block_partitions.T) - for obj in other]) + new_index = pd.DataFrame(index=self.index).join([ + pd.DataFrame(index=obj.index) for obj in other + ], + how=how, + sort=sort).index + + new_column_labels = pd.DataFrame( + columns=self.columns + ).join([pd.DataFrame(columns=obj.columns) for obj in other], + lsuffix=lsuffix, + rsuffix=rsuffix).columns + + new_partition_num = max( + [len(self._block_partitions.T)] + + [len(obj._block_partitions.T) for obj in other] + ) new_self = np.array([ - _reindex_helper._submit(args=tuple([self.index, new_index, 1, - new_partition_num] + - block.tolist()), - num_return_vals=new_partition_num) - for block in self._block_partitions.T]) - - new_others = np.array([_reindex_helper._submit( - args=tuple([obj.index, new_index, 1, new_partition_num] + - block.tolist()), - num_return_vals=new_partition_num - ) for obj in other for block in obj._block_partitions.T]) + _reindex_helper._submit( + args=tuple([self.index, new_index, 1, new_partition_num] + + block.tolist()), + num_return_vals=new_partition_num + ) for block in self._block_partitions.T + ]) + + new_others = np.array([ + _reindex_helper._submit( + args=tuple([obj.index, new_index, 1, new_partition_num] + + block.tolist()), + num_return_vals=new_partition_num + ) for obj in other for block in obj._block_partitions.T + ]) # Append the columns together (i.e. concat) new_block_parts = np.concatenate((new_self, new_others)).T # TODO join the two metadata tables for performance. - return DataFrame(block_partitions=new_block_parts, - index=new_index, - columns=new_column_labels) - - def kurt(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + return DataFrame( + block_partitions=new_block_parts, + index=new_index, + columns=new_column_labels + ) + + def kurt( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def kurtosis(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def kurtosis( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def last(self, offset): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def last_valid_index(self): """Return index for last non-NA/null value. @@ -2295,7 +2607,8 @@ def le(self, other, axis='columns', level=None): def lookup(self, row_labels, col_labels): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def lt(self, other, axis='columns', level=None): """Checks element-wise that this is less than other. 
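# Pure-pandas sketch of the join strategy in the hunks above: compute the
# joined index cheaply by joining empty frames (the same trick the code uses
# for error checking), reindex both operands to it, then concatenate
# column-wise.
import pandas as pd

left = pd.DataFrame({"a": [1, 2, 3]}, index=[0, 1, 2])
right = pd.DataFrame({"b": [4.0, 5.0]}, index=[1, 2])

new_index = pd.DataFrame(index=left.index).join(
    pd.DataFrame(index=right.index), how="left").index
joined = pd.concat(
    [left.reindex(new_index), right.reindex(new_index)], axis=1)
# Equivalent to left.join(right, how="left") for this example.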
@@ -2313,16 +2626,28 @@ def lt(self, other, axis='columns', level=None): def mad(self, axis=None, skipna=None, level=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def mask(self, cond, other=np.nan, inplace=False, axis=None, level=None, - errors='raise', try_cast=False, raise_on_error=None): + "github.com/ray-project/ray." + ) + + def mask( + self, + cond, + other=np.nan, + inplace=False, + axis=None, + level=None, + errors='raise', + try_cast=False, + raise_on_error=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def max(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def max( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): """Perform max across the DataFrame. Args: @@ -2332,14 +2657,21 @@ def max(self, axis=None, skipna=None, level=None, numeric_only=None, Returns: The max of the DataFrame. """ + def remote_func(df): - return df.max(axis=axis, skipna=skipna, level=level, - numeric_only=numeric_only, **kwargs) + return df.max( + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only, + **kwargs + ) return self._arithmetic_helper(remote_func, axis, level) - def mean(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def mean( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): """Computes mean across the DataFrame. Args: @@ -2349,14 +2681,21 @@ def mean(self, axis=None, skipna=None, level=None, numeric_only=None, Returns: The mean of the DataFrame. (Pandas series) """ + def remote_func(df): - return df.mean(axis=axis, skipna=skipna, level=level, - numeric_only=numeric_only, **kwargs) + return df.mean( + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only, + **kwargs + ) return self._arithmetic_helper(remote_func, axis, level) - def median(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def median( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): """Computes median across the DataFrame. Args: @@ -2366,20 +2705,32 @@ def median(self, axis=None, skipna=None, level=None, numeric_only=None, Returns: The median of the DataFrame. (Pandas series) """ + def remote_func(df): - return df.median(axis=axis, skipna=skipna, level=level, - numeric_only=numeric_only, **kwargs) + return df.median( + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only, + **kwargs + ) return self._arithmetic_helper(remote_func, axis, level) - def melt(self, id_vars=None, value_vars=None, var_name=None, - value_name='value', col_level=None): + def melt( + self, + id_vars=None, + value_vars=None, + var_name=None, + value_name='value', + col_level=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def memory_usage(self, index=True, deep=False): - def remote_func(df): return df.memory_usage(index=False, deep=deep) @@ -2392,16 +2743,29 @@ def remote_func(df): return result - def merge(self, right, how='inner', on=None, left_on=None, right_on=None, - left_index=False, right_index=False, sort=False, - suffixes=('_x', '_y'), copy=True, indicator=False, - validate=None): + def merge( + self, + right, + how='inner', + on=None, + left_on=None, + right_on=None, + left_index=False, + right_index=False, + sort=False, + suffixes=('_x', '_y'), + copy=True, + indicator=False, + validate=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def min(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def min( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): """Perform min across the DataFrame. Args: @@ -2411,9 +2775,15 @@ def min(self, axis=None, skipna=None, level=None, numeric_only=None, Returns: The min of the DataFrame. """ + def remote_func(df): - return df.min(axis=axis, skipna=skipna, level=level, - numeric_only=numeric_only, **kwargs) + return df.min( + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only, + **kwargs + ) return self._arithmetic_helper(remote_func, axis, level) @@ -2429,13 +2799,15 @@ def mod(self, other, axis='columns', level=None, fill_value=None): Returns: A new DataFrame with the Mod applied. """ - return self._operator_helper(pd.DataFrame.mod, other, axis, level, - fill_value) + return self._operator_helper( + pd.DataFrame.mod, other, axis, level, fill_value + ) def mode(self, axis=0, numeric_only=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def mul(self, other, axis='columns', level=None, fill_value=None): """Multiplies this DataFrame against another DataFrame/Series/scalar. @@ -2449,8 +2821,9 @@ def mul(self, other, axis='columns', level=None, fill_value=None): Returns: A new DataFrame with the Multiply applied. """ - return self._operator_helper(pd.DataFrame.mul, other, axis, level, - fill_value) + return self._operator_helper( + pd.DataFrame.mul, other, axis, level, fill_value + ) def multiply(self, other, axis='columns', level=None, fill_value=None): """Synonym for mul. @@ -2482,7 +2855,8 @@ def ne(self, other, axis='columns', level=None): def nlargest(self, n, columns, keep='first'): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def notna(self): """Perform notna across the DataFrame. @@ -2494,12 +2868,16 @@ def notna(self): Boolean DataFrame where value is False if corresponding value is NaN, True otherwise """ - new_block_partitions = np.array([_map_partitions( - lambda df: df.notna(), block) for block in self._block_partitions]) + new_block_partitions = np.array([ + _map_partitions(lambda df: df.notna(), block) + for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index) + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index + ) def notnull(self): """Perform notnull across the DataFrame. 
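# Sketch of the distributed memory_usage() above: each row partition reports
# its own usage with index=False, and the per-column byte counts are summed
# across partitions (how the index itself is charged is not shown in the hunk
# and is left out here).
import pandas as pd

row_partitions = [
    pd.DataFrame({"a": range(5), "b": range(5)}),
    pd.DataFrame({"a": range(5, 10), "b": range(5, 10)}),
]
per_partition = [p.memory_usage(index=False, deep=False) for p in row_partitions]
total_bytes = sum(per_partition)   # Series indexed by column name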
@@ -2511,57 +2889,102 @@ def notnull(self): Boolean DataFrame where value is False if corresponding value is NaN, True otherwise """ - new_block_partitions = np.array([_map_partitions( - lambda df: df.notnull(), block) - for block in self._block_partitions]) + new_block_partitions = np.array([ + _map_partitions(lambda df: df.notnull(), block) + for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index) + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index + ) def nsmallest(self, n, columns, keep='first'): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def nunique(self, axis=0, dropna=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def pct_change(self, periods=1, fill_method='pad', limit=None, freq=None, - **kwargs): + def pct_change( + self, periods=1, fill_method='pad', limit=None, freq=None, **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def pipe(self, func, *args, **kwargs): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def pivot(self, index=None, columns=None, values=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def pivot_table(self, values=None, index=None, columns=None, - aggfunc='mean', fill_value=None, margins=False, - dropna=True, margins_name='All'): + "github.com/ray-project/ray." + ) + + def pivot_table( + self, + values=None, + index=None, + columns=None, + aggfunc='mean', + fill_value=None, + margins=False, + dropna=True, + margins_name='All' + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def plot(self, x=None, y=None, kind='line', ax=None, subplots=False, - sharex=None, sharey=False, layout=None, figsize=None, - use_index=True, title=None, grid=None, legend=True, style=None, - logx=False, logy=False, loglog=False, xticks=None, yticks=None, - xlim=None, ylim=None, rot=None, fontsize=None, colormap=None, - table=False, yerr=None, xerr=None, secondary_y=False, - sort_columns=False, **kwds): + "github.com/ray-project/ray." + ) + + def plot( + self, + x=None, + y=None, + kind='line', + ax=None, + subplots=False, + sharex=None, + sharey=False, + layout=None, + figsize=None, + use_index=True, + title=None, + grid=None, + legend=True, + style=None, + logx=False, + logy=False, + loglog=False, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + rot=None, + fontsize=None, + colormap=None, + table=False, + yerr=None, + xerr=None, + secondary_y=False, + sort_columns=False, + **kwds + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def pop(self, item): """Pops an item from this DataFrame and returns it. @@ -2589,23 +3012,41 @@ def pow(self, other, axis='columns', level=None, fill_value=None): Returns: A new DataFrame with the Pow applied. 
""" - return self._operator_helper(pd.DataFrame.pow, other, axis, level, - fill_value) - - def prod(self, axis=None, skipna=None, level=None, numeric_only=None, - min_count=0, **kwargs): + return self._operator_helper( + pd.DataFrame.pow, other, axis, level, fill_value + ) + + def prod( + self, + axis=None, + skipna=None, + level=None, + numeric_only=None, + min_count=0, + **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def product(self, axis=None, skipna=None, level=None, numeric_only=None, - min_count=0, **kwargs): + "github.com/ray-project/ray." + ) + + def product( + self, + axis=None, + skipna=None, + level=None, + numeric_only=None, + min_count=0, + **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def quantile(self, q=0.5, axis=0, numeric_only=True, - interpolation='linear'): + def quantile( + self, q=0.5, axis=0, numeric_only=True, interpolation='linear' + ): """Return values at the given quantile over requested axis, a la numpy.percentile. @@ -2629,8 +3070,12 @@ def quantile(self, q=0.5, axis=0, numeric_only=True, def quantile_helper(df, q, axis, numeric_only, interpolation): try: - return df.quantile(q=q, axis=axis, numeric_only=numeric_only, - interpolation=interpolation) + return df.quantile( + q=q, + axis=axis, + numeric_only=numeric_only, + interpolation=interpolation + ) except ValueError: return pd.Series() @@ -2639,10 +3084,15 @@ def quantile_helper(df, q, axis, numeric_only, interpolation): # TODO Revisit for performance quantiles = [] for q_i in q: + def remote_func(df): - return quantile_helper(df, q=q_i, axis=axis, - numeric_only=numeric_only, - interpolation=interpolation) + return quantile_helper( + df, + q=q_i, + axis=axis, + numeric_only=numeric_only, + interpolation=interpolation + ) result = self._arithmetic_helper(remote_func, axis) result.name = q_i @@ -2650,10 +3100,15 @@ def remote_func(df): return pd.concat(quantiles, axis=1).T else: + def remote_func(df): - return quantile_helper(df, q=q, axis=axis, - numeric_only=numeric_only, - interpolation=interpolation) + return quantile_helper( + df, + q=q, + axis=axis, + numeric_only=numeric_only, + interpolation=interpolation + ) result = self._arithmetic_helper(remote_func, axis) result.name = q @@ -2676,8 +3131,7 @@ def query_helper(df): df.columns = pd.RangeIndex(0, len(df.columns)) return df - new_rows = _map_partitions(query_helper, - self._row_partitions) + new_rows = _map_partitions(query_helper, self._row_partitions) if inplace: self._update_inplace(row_partitions=new_rows) @@ -2687,38 +3141,77 @@ def query_helper(df): def radd(self, other, axis='columns', level=None, fill_value=None): return self.add(other, axis, level, fill_value) - def rank(self, axis=0, method='average', numeric_only=None, - na_option='keep', ascending=True, pct=False): + def rank( + self, + axis=0, + method='average', + numeric_only=None, + na_option='keep', + ascending=True, + pct=False + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def rdiv(self, other, axis='columns', level=None, fill_value=None): return self._single_df_op_helper( - lambda df: df.rdiv(other, axis, level, fill_value), - other, axis, level) - - def reindex(self, labels=None, index=None, columns=None, axis=None, - method=None, copy=True, level=None, fill_value=np.nan, - limit=None, tolerance=None): + lambda df: df.rdiv(other, axis, level, fill_value), other, axis, + level + ) + + def reindex( + self, + labels=None, + index=None, + columns=None, + axis=None, + method=None, + copy=True, + level=None, + fill_value=np.nan, + limit=None, + tolerance=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def reindex_axis(self, labels, axis=0, method=None, level=None, copy=True, - limit=None, fill_value=np.nan): + "github.com/ray-project/ray." + ) + + def reindex_axis( + self, + labels, + axis=0, + method=None, + level=None, + copy=True, + limit=None, + fill_value=np.nan + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def reindex_like(self, other, method=None, copy=True, limit=None, - tolerance=None): + def reindex_like( + self, other, method=None, copy=True, limit=None, tolerance=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def rename(self, mapper=None, index=None, columns=None, axis=None, - copy=True, inplace=False, level=None): + "github.com/ray-project/ray." + ) + + def rename( + self, + mapper=None, + index=None, + columns=None, + axis=None, + copy=True, + inplace=False, + level=None + ): """Alters axes labels. Args: @@ -2738,8 +3231,11 @@ def rename(self, mapper=None, index=None, columns=None, axis=None, # kwargs. It doesn't ignore None values passed in, so we have to filter # them ourselves. args = locals() - kwargs = {k: v for k, v in args.items() - if v is not None and k != "self"} + kwargs = { + k: v + for k, v in args.items() + if v is not None and k != "self" + } # inplace should always be true because this is just a copy, and we # will use the results after. kwargs['inplace'] = True @@ -2792,23 +3288,48 @@ def _set_axis_name(self, name, axis=0, inplace=False): def reorder_levels(self, order, axis=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def replace(self, to_replace=None, value=None, inplace=False, limit=None, - regex=False, method='pad', axis=None): + "github.com/ray-project/ray." + ) + + def replace( + self, + to_replace=None, + value=None, + inplace=False, + limit=None, + regex=False, + method='pad', + axis=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def resample(self, rule, how=None, axis=0, fill_method=None, closed=None, - label=None, convention='start', kind=None, loffset=None, - limit=None, base=0, on=None, level=None): + "github.com/ray-project/ray." + ) + + def resample( + self, + rule, + how=None, + axis=0, + fill_method=None, + closed=None, + label=None, + convention='start', + kind=None, + loffset=None, + limit=None, + base=0, + on=None, + level=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) - def reset_index(self, level=None, drop=False, inplace=False, col_level=0, - col_fill=''): + def reset_index( + self, level=None, drop=False, inplace=False, col_level=0, col_fill='' + ): """Reset this index to default and create column from current index. Args: @@ -2856,7 +3377,8 @@ def _maybe_casted_values(index, labels=None): values = values.take(labels) if mask.any(): values, changed = maybe_upcast_putmask( - values, mask, np.nan) + values, mask, np.nan + ) return values # We're building a new default index dataframe for use later. @@ -2871,8 +3393,10 @@ def _maybe_casted_values(index, labels=None): if not drop: if isinstance(self.index, pd.MultiIndex): - names = [n if n is not None else ('level_%d' % i) - for (i, n) in enumerate(self.index.names)] + names = [ + n if n is not None else ('level_%d' % i) + for (i, n) in enumerate(self.index.names) + ] to_insert = lzip(self.index.levels, self.index.labels) else: default = 'index' @@ -2881,9 +3405,9 @@ def _maybe_casted_values(index, labels=None): default = 'level_{}'.format(i) i += 1 - names = ([default] if self.index.name is None - else [self.index.name]) - to_insert = ((self.index, None),) + names = ([default] + if self.index.name is None else [self.index.name]) + to_insert = ((self.index, None), ) multi_col = isinstance(self.columns, pd.MultiIndex) for i, (lev, lab) in reversed(list(enumerate(to_insert))): @@ -2891,13 +3415,16 @@ def _maybe_casted_values(index, labels=None): continue name = names[i] if multi_col: - col_name = (list(name) if isinstance(name, tuple) - else [name]) + col_name = ( + list(name) if isinstance(name, tuple) else [name] + ) if col_fill is None: if len(col_name) not in (1, self.columns.nlevels): - raise ValueError("col_fill=None is incompatible " - "with incomplete column name " - "{}".format(name)) + raise ValueError( + "col_fill=None is incompatible " + "with incomplete column name " + "{}".format(name) + ) col_fill = col_name[0] lev_num = self.columns._get_level_number(col_level) @@ -2916,49 +3443,75 @@ def _maybe_casted_values(index, labels=None): def rfloordiv(self, other, axis='columns', level=None, fill_value=None): return self._single_df_op_helper( - lambda df: df.rfloordiv(other, axis, level, fill_value), - other, axis, level) + lambda df: df.rfloordiv(other, axis, level, fill_value), other, + axis, level + ) def rmod(self, other, axis='columns', level=None, fill_value=None): return self._single_df_op_helper( - lambda df: df.rmod(other, axis, level, fill_value), - other, axis, level) + lambda df: df.rmod(other, axis, level, fill_value), other, axis, + level + ) def rmul(self, other, axis='columns', level=None, fill_value=None): return self.mul(other, axis, level, fill_value) - def rolling(self, window, min_periods=None, freq=None, center=False, - win_type=None, on=None, axis=0, closed=None): + def rolling( + self, + window, + min_periods=None, + freq=None, + center=False, + win_type=None, + on=None, + axis=0, + closed=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def round(self, decimals=0, *args, **kwargs): - new_block_partitions = np.array([_map_partitions( - lambda df: df.round(decimals=decimals, *args, **kwargs), block) - for block in self._block_partitions]) + new_block_partitions = np.array([ + _map_partitions( + lambda df: df.round(decimals=decimals, *args, **kwargs), block + ) for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index) + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index + ) def rpow(self, other, axis='columns', level=None, fill_value=None): return self._single_df_op_helper( - lambda df: df.rpow(other, axis, level, fill_value), - other, axis, level) + lambda df: df.rpow(other, axis, level, fill_value), other, axis, + level + ) def rsub(self, other, axis='columns', level=None, fill_value=None): return self._single_df_op_helper( - lambda df: df.rsub(other, axis, level, fill_value), - other, axis, level) + lambda df: df.rsub(other, axis, level, fill_value), other, axis, + level + ) def rtruediv(self, other, axis='columns', level=None, fill_value=None): return self._single_df_op_helper( - lambda df: df.rtruediv(other, axis, level, fill_value), - other, axis, level) - - def sample(self, n=None, frac=None, replace=False, weights=None, - random_state=None, axis=None): + lambda df: df.rtruediv(other, axis, level, fill_value), other, axis, + level + ) + + def sample( + self, + n=None, + frac=None, + replace=False, + weights=None, + random_state=None, + axis=None + ): """Returns a random sample of items from an axis of object. Args: @@ -3006,25 +3559,33 @@ def sample(self, n=None, frac=None, replace=False, weights=None, try: weights = self[weights] except KeyError: - raise KeyError("String passed to weights not a " - "valid column") + raise KeyError( + "String passed to weights not a " + "valid column" + ) else: - raise ValueError("Strings can only be passed to " - "weights when sampling from rows on " - "a DataFrame") + raise ValueError( + "Strings can only be passed to " + "weights when sampling from rows on " + "a DataFrame" + ) weights = pd.Series(weights, dtype='float64') if len(weights) != axis_length: - raise ValueError("Weights and axis to be sampled must be of " - "same length") + raise ValueError( + "Weights and axis to be sampled must be of " + "same length" + ) if (weights == np.inf).any() or (weights == -np.inf).any(): raise ValueError("weight vector may not include `inf` values") if (weights < 0).any(): - raise ValueError("weight vector many not include negative " - "values") + raise ValueError( + "weight vector many not include negative " + "values" + ) # weights cannot be NaN when sampling, so we must set all nan # values to 0 @@ -3054,18 +3615,24 @@ def sample(self, n=None, frac=None, replace=False, weights=None, elif n is not None and frac is not None: # Pandas specification does not allow both n and frac to be passed # in - raise ValueError('Please enter a value for `frac` OR `n`, not ' - 'both') + raise ValueError( + 'Please enter a value for `frac` OR `n`, not ' + 'both' + ) if n < 0: - raise ValueError("A negative number of rows requested. Please " - "provide positive value.") + raise ValueError( + "A negative number of rows requested. Please " + "provide positive value." + ) if n == 0: # An Empty DataFrame is returned if the number of samples is 0. # The Empty Dataframe should have either columns or index specified # depending on which axis is passed in. 
- return DataFrame(columns=[] if axis == 1 else self.columns, - index=self.index if axis == 1 else []) + return DataFrame( + columns=[] if axis == 1 else self.columns, + index=self.index if axis == 1 else [] + ) if axis == 1: axis_labels = self.columns @@ -3085,56 +3652,73 @@ def sample(self, n=None, frac=None, replace=False, weights=None, random_num_gen = random_state else: # random_state must be an int or a numpy RandomState object - raise ValueError("Please enter an `int` OR a " - "np.random.RandomState for random_state") + raise ValueError( + "Please enter an `int` OR a " + "np.random.RandomState for random_state" + ) # choose random numbers and then get corresponding labels from # chosen axis sample_indices = random_num_gen.randint( - low=0, - high=len(partition_metadata), - size=n) + low=0, high=len(partition_metadata), size=n + ) samples = axis_labels[sample_indices] else: # randomly select labels from chosen axis - samples = np.random.choice(a=axis_labels, size=n, - replace=replace, p=weights) + samples = np.random.choice( + a=axis_labels, size=n, replace=replace, p=weights + ) # create an array of (partition, index_within_partition) tuples for # each sample - part_ind_tuples = [partition_metadata[sample] - for sample in samples] + part_ind_tuples = [partition_metadata[sample] for sample in samples] if axis == 1: # tup[0] refers to the partition number and tup[1] is the index # within that partition - new_cols = [_deploy_func.remote(lambda df: df.iloc[:, [tup[1]]], - partitions[tup[0]]) for tup in part_ind_tuples] - return DataFrame(col_partitions=new_cols, - columns=samples, - index=self.index) + new_cols = [ + _deploy_func.remote( + lambda df: df.iloc[:, [tup[1]]], partitions[tup[0]] + ) for tup in part_ind_tuples + ] + return DataFrame( + col_partitions=new_cols, columns=samples, index=self.index + ) else: - new_rows = [_deploy_func.remote(lambda df: df.loc[[tup[1]]], - partitions[tup[0]]) for tup in part_ind_tuples] - return DataFrame(row_partitions=new_rows, - columns=self.columns, - index=samples) + new_rows = [ + _deploy_func.remote( + lambda df: df.loc[[tup[1]]], partitions[tup[0]] + ) for tup in part_ind_tuples + ] + return DataFrame( + row_partitions=new_rows, columns=self.columns, index=samples + ) def select(self, crit, axis=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def select_dtypes(self, include=None, exclude=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def sem(self, axis=None, skipna=None, level=None, ddof=1, - numeric_only=None, **kwargs): + "github.com/ray-project/ray." + ) + + def sem( + self, + axis=None, + skipna=None, + level=None, + ddof=1, + numeric_only=None, + **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def set_axis(self, labels, axis=0, inplace=None): """Assign desired index to given axis. @@ -3153,7 +3737,9 @@ def set_axis(self, labels, axis=0, inplace=None): '"axis" as named parameter. 
The old form, with "axis" as ' 'first parameter and \"labels\" as second, is still supported ' 'but will be deprecated in a future version of pandas.', - FutureWarning, stacklevel=2) + FutureWarning, + stacklevel=2 + ) labels, axis = axis, labels if inplace is None: @@ -3161,7 +3747,9 @@ def set_axis(self, labels, axis=0, inplace=None): 'set_axis currently defaults to operating inplace.\nThis ' 'will change in a future version of pandas, use ' 'inplace=True to avoid this warning.', - FutureWarning, stacklevel=2) + FutureWarning, + stacklevel=2 + ) inplace = True if inplace: setattr(self, pd.DataFrame()._get_axis_name(axis), labels) @@ -3170,8 +3758,14 @@ def set_axis(self, labels, axis=0, inplace=None): obj.set_axis(labels, axis=axis, inplace=True) return obj - def set_index(self, keys, drop=True, append=False, inplace=False, - verify_integrity=False): + def set_index( + self, + keys, + drop=True, + append=False, + inplace=False, + verify_integrity=False + ): """Set the DataFrame index using one or more existing columns. Args: @@ -3251,55 +3845,93 @@ def set_index(self, keys, drop=True, append=False, inplace=False, def set_value(self, index, col, value, takeable=False): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def shift(self, periods=1, freq=None, axis=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def skew(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def skew( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def slice_shift(self, periods=1, axis=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def sort_index(self, axis=0, level=None, ascending=True, inplace=False, - kind='quicksort', na_position='last', sort_remaining=True, - by=None): + "github.com/ray-project/ray." + ) + + def sort_index( + self, + axis=0, + level=None, + ascending=True, + inplace=False, + kind='quicksort', + na_position='last', + sort_remaining=True, + by=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def sort_values(self, by, axis=0, ascending=True, inplace=False, - kind='quicksort', na_position='last'): + "github.com/ray-project/ray." + ) + + def sort_values( + self, + by, + axis=0, + ascending=True, + inplace=False, + kind='quicksort', + na_position='last' + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def sortlevel(self, level=0, axis=0, ascending=True, inplace=False, - sort_remaining=True): + "github.com/ray-project/ray." + ) + + def sortlevel( + self, + level=0, + axis=0, + ascending=True, + inplace=False, + sort_remaining=True + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def squeeze(self, axis=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def stack(self, level=-1, dropna=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def std(self, axis=None, skipna=None, level=None, ddof=1, - numeric_only=None, **kwargs): + "github.com/ray-project/ray." + ) + + def std( + self, + axis=None, + skipna=None, + level=None, + ddof=1, + numeric_only=None, + **kwargs + ): """Computes standard deviation across the DataFrame. Args: @@ -3310,9 +3942,16 @@ def std(self, axis=None, skipna=None, level=None, ddof=1, Returns: The std of the DataFrame (Pandas Series) """ + def remote_func(df): - return df.std(axis=axis, skipna=skipna, level=level, ddof=ddof, - numeric_only=numeric_only, **kwargs) + return df.std( + axis=axis, + skipna=skipna, + level=level, + ddof=ddof, + numeric_only=numeric_only, + **kwargs + ) return self._arithmetic_helper(remote_func, axis, level) @@ -3328,8 +3967,9 @@ def sub(self, other, axis='columns', level=None, fill_value=None): Returns: A new DataFrame with the subtraciont applied. """ - return self._operator_helper(pd.DataFrame.sub, other, axis, level, - fill_value) + return self._operator_helper( + pd.DataFrame.sub, other, axis, level, fill_value + ) def subtract(self, other, axis='columns', level=None, fill_value=None): """Alias for sub. @@ -3348,12 +3988,14 @@ def subtract(self, other, axis='columns', level=None, fill_value=None): def swapaxes(self, axis1, axis2, copy=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def swaplevel(self, i=-2, j=-1, axis=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def tail(self, n=5): """Get the last n rows of the dataframe. @@ -3367,138 +4009,231 @@ def tail(self, n=5): if n >= len(self._row_metadata): return self - new_dfs = _map_partitions(lambda df: df.tail(n), - self._col_partitions) + new_dfs = _map_partitions(lambda df: df.tail(n), self._col_partitions) index = self._row_metadata.index[-n:] - return DataFrame(col_partitions=new_dfs, - columns=self.columns, - index=index) + return DataFrame( + col_partitions=new_dfs, columns=self.columns, index=index + ) def take(self, indices, axis=0, convert=None, is_copy=True, **kwargs): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def to_clipboard(self, excel=None, sep=None, **kwargs): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) port_frame.to_clipboard(excel, sep, **kwargs) - def to_csv(self, path_or_buf=None, sep=',', na_rep='', float_format=None, - columns=None, header=True, index=True, index_label=None, - mode='w', encoding=None, compression=None, quoting=None, - quotechar='"', line_terminator='\n', chunksize=None, - tupleize_cols=None, date_format=None, doublequote=True, - escapechar=None, decimal='.'): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + def to_csv( + self, + path_or_buf=None, + sep=',', + na_rep='', + float_format=None, + columns=None, + header=True, + index=True, + index_label=None, + mode='w', + encoding=None, + compression=None, + quoting=None, + quotechar='"', + line_terminator='\n', + chunksize=None, + tupleize_cols=None, + date_format=None, + doublequote=True, + escapechar=None, + decimal='.' + ): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) - port_frame.to_csv(path_or_buf, sep, na_rep, float_format, - columns, header, index, index_label, - mode, encoding, compression, quoting, - quotechar, line_terminator, chunksize, - tupleize_cols, date_format, doublequote, - escapechar, decimal) + port_frame.to_csv( + path_or_buf, sep, na_rep, float_format, columns, header, index, + index_label, mode, encoding, compression, quoting, quotechar, + line_terminator, chunksize, tupleize_cols, date_format, doublequote, + escapechar, decimal + ) def to_dense(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def to_dict(self, orient='dict', into=dict): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def to_excel(self, excel_writer, sheet_name='Sheet1', na_rep='', - float_format=None, columns=None, header=True, index=True, - index_label=None, startrow=0, startcol=0, engine=None, - merge_cells=True, encoding=None, inf_rep='inf', verbose=True, - freeze_panes=None): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + "github.com/ray-project/ray." 
+ ) + + def to_excel( + self, + excel_writer, + sheet_name='Sheet1', + na_rep='', + float_format=None, + columns=None, + header=True, + index=True, + index_label=None, + startrow=0, + startcol=0, + engine=None, + merge_cells=True, + encoding=None, + inf_rep='inf', + verbose=True, + freeze_panes=None + ): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) - port_frame.to_excel(excel_writer, sheet_name, na_rep, - float_format, columns, header, index, - index_label, startrow, startcol, engine, - merge_cells, encoding, inf_rep, verbose, - freeze_panes) + port_frame.to_excel( + excel_writer, sheet_name, na_rep, float_format, columns, header, + index, index_label, startrow, startcol, engine, merge_cells, + encoding, inf_rep, verbose, freeze_panes + ) def to_feather(self, fname): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) port_frame.to_feather(fname) - def to_gbq(self, destination_table, project_id, chunksize=10000, - verbose=True, reauth=False, if_exists='fail', - private_key=None): + def to_gbq( + self, + destination_table, + project_id, + chunksize=10000, + verbose=True, + reauth=False, + if_exists='fail', + private_key=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def to_hdf(self, path_or_buf, key, **kwargs): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) port_frame.to_hdf(path_or_buf, key, **kwargs) - def to_html(self, buf=None, columns=None, col_space=None, header=True, - index=True, na_rep='np.NaN', formatters=None, - float_format=None, sparsify=None, index_names=True, - justify=None, bold_rows=True, classes=None, escape=True, - max_rows=None, max_cols=None, show_dimensions=False, - notebook=False, decimal='.', border=None): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + def to_html( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep='np.NaN', + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + justify=None, + bold_rows=True, + classes=None, + escape=True, + max_rows=None, + max_cols=None, + show_dimensions=False, + notebook=False, + decimal='.', + border=None + ): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) - port_frame.to_html(buf, columns, col_space, header, - index, na_rep, formatters, - float_format, sparsify, index_names, - justify, bold_rows, classes, escape, - max_rows, max_cols, show_dimensions, - notebook, decimal, border) - - def to_json(self, path_or_buf=None, orient=None, date_format=None, - double_precision=10, force_ascii=True, date_unit='ms', - default_handler=None, lines=False, compression=None): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + port_frame.to_html( + buf, columns, col_space, header, index, na_rep, formatters, + float_format, sparsify, index_names, justify, bold_rows, classes, + escape, max_rows, max_cols, show_dimensions, notebook, decimal, + border + ) + + def to_json( + self, + path_or_buf=None, + orient=None, + date_format=None, + double_precision=10, + 
force_ascii=True, + date_unit='ms', + default_handler=None, + lines=False, + compression=None + ): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) - port_frame.to_json(path_or_buf, orient, date_format, - double_precision, force_ascii, date_unit, - default_handler, lines, compression) - - def to_latex(self, buf=None, columns=None, col_space=None, header=True, - index=True, na_rep='np.NaN', formatters=None, - float_format=None, sparsify=None, index_names=True, - bold_rows=False, column_format=None, longtable=None, - escape=None, encoding=None, decimal='.', multicolumn=None, - multicolumn_format=None, multirow=None): + port_frame.to_json( + path_or_buf, orient, date_format, double_precision, force_ascii, + date_unit, default_handler, lines, compression + ) + + def to_latex( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep='np.NaN', + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + bold_rows=False, + column_format=None, + longtable=None, + escape=None, + encoding=None, + decimal='.', + multicolumn=None, + multicolumn_format=None, + multirow=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) port_frame.to_msgpack(path_or_buf, encoding, **kwargs) @@ -3506,13 +4241,14 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs): def to_panel(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def to_parquet(self, fname, engine='auto', compression='snappy', - **kwargs): + def to_parquet(self, fname, engine='auto', compression='snappy', **kwargs): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) port_frame.to_parquet(fname, engine, compression, **kwargs) @@ -3520,13 +4256,16 @@ def to_parquet(self, fname, engine='auto', compression='snappy', def to_period(self, freq=None, axis=0, copy=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def to_pickle(self, path, compression='infer', - protocol=pkl.HIGHEST_PROTOCOL): + def to_pickle( + self, path, compression='infer', protocol=pkl.HIGHEST_PROTOCOL + ): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) port_frame.to_pickle(path, compression, protocol) @@ -3534,53 +4273,94 @@ def to_pickle(self, path, compression='infer', def to_records(self, index=True, convert_datetime64=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def to_sparse(self, fill_value=None, kind='block'): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def to_sql(self, name, con, flavor=None, schema=None, if_exists='fail', - index=True, index_label=None, chunksize=None, dtype=None): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + "github.com/ray-project/ray." + ) + + def to_sql( + self, + name, + con, + flavor=None, + schema=None, + if_exists='fail', + index=True, + index_label=None, + chunksize=None, + dtype=None + ): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) - port_frame.to_sql(name, con, flavor, schema, if_exists, - index, index_label, chunksize, dtype) - - def to_stata(self, fname, convert_dates=None, write_index=True, - encoding='latin-1', byteorder=None, time_stamp=None, - data_label=None, variable_labels=None): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + port_frame.to_sql( + name, con, flavor, schema, if_exists, index, index_label, chunksize, + dtype + ) + + def to_stata( + self, + fname, + convert_dates=None, + write_index=True, + encoding='latin-1', + byteorder=None, + time_stamp=None, + data_label=None, + variable_labels=None + ): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = to_pandas(self) - port_frame.to_stata(fname, convert_dates, write_index, - encoding, byteorder, time_stamp, - data_label, variable_labels) - - def to_string(self, buf=None, columns=None, col_space=None, header=True, - index=True, na_rep='np.NaN', formatters=None, - float_format=None, sparsify=None, index_names=True, - justify=None, line_width=None, max_rows=None, max_cols=None, - show_dimensions=False): + port_frame.to_stata( + fname, convert_dates, write_index, encoding, byteorder, time_stamp, + data_label, variable_labels + ) + + def to_string( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep='np.NaN', + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + justify=None, + line_width=None, + max_rows=None, + max_cols=None, + show_dimensions=False + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def to_timestamp(self, freq=None, how='start', axis=0, copy=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def to_xarray(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def transform(self, func, *args, **kwargs): kwargs["is_transform"] = True @@ -3604,43 +4384,62 @@ def truediv(self, other, axis='columns', level=None, fill_value=None): Returns: A new DataFrame with the Divide applied. """ - return self._operator_helper(pd.DataFrame.truediv, other, axis, level, - fill_value) + return self._operator_helper( + pd.DataFrame.truediv, other, axis, level, fill_value + ) def truncate(self, before=None, after=None, axis=None, copy=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def tshift(self, periods=1, freq=None, axis=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def tz_convert(self, tz, axis=0, level=None, copy=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) - def tz_localize(self, tz, axis=0, level=None, copy=True, - ambiguous='raise'): + def tz_localize(self, tz, axis=0, level=None, copy=True, ambiguous='raise'): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def unstack(self, level=-1, fill_value=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def update(self, other, join='left', overwrite=True, filter_func=None, - raise_conflict=False): + "github.com/ray-project/ray." + ) + + def update( + self, + other, + join='left', + overwrite=True, + filter_func=None, + raise_conflict=False + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") - - def var(self, axis=None, skipna=None, level=None, ddof=1, - numeric_only=None, **kwargs): + "github.com/ray-project/ray." + ) + + def var( + self, + axis=None, + skipna=None, + level=None, + ddof=1, + numeric_only=None, + **kwargs + ): """Computes variance across the DataFrame. Args: @@ -3651,22 +4450,40 @@ def var(self, axis=None, skipna=None, level=None, ddof=1, Returns: The variance of the DataFrame. """ + def remote_func(df): - return df.var(axis=axis, skipna=skipna, level=level, ddof=ddof, - numeric_only=numeric_only, **kwargs) + return df.var( + axis=axis, + skipna=skipna, + level=level, + ddof=ddof, + numeric_only=numeric_only, + **kwargs + ) return self._arithmetic_helper(remote_func, axis, level) - def where(self, cond, other=np.nan, inplace=False, axis=None, level=None, - errors='raise', try_cast=False, raise_on_error=None): + def where( + self, + cond, + other=np.nan, + inplace=False, + axis=None, + level=None, + errors='raise', + try_cast=False, + raise_on_error=None + ): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def xs(self, key, axis=0, level=None, drop_level=True): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __getitem__(self, key): """Get the column specified by key for this DataFrame. @@ -3695,12 +4512,16 @@ def __getitem__(self, key): if isinstance(key, (pd.Series, np.ndarray, pd.Index, list)): return self._getitem_array(key) elif isinstance(key, DataFrame): - raise NotImplementedError("To contribute to Pandas on Ray, please" - "visit github.com/ray-project/ray.") + raise NotImplementedError( + "To contribute to Pandas on Ray, please" + "visit github.com/ray-project/ray." + ) # return self._getitem_frame(key) elif is_mi_columns: - raise NotImplementedError("To contribute to Pandas on Ray, please" - "visit github.com/ray-project/ray.") + raise NotImplementedError( + "To contribute to Pandas on Ray, please" + "visit github.com/ray-project/ray." 
+ ) # return self._getitem_multilevel(key) else: return self._getitem_column(key) @@ -3717,36 +4538,47 @@ def _getitem_array(self, key): if com.is_bool_indexer(key): if isinstance(key, pd.Series) and \ not key.index.equals(self.index): - warnings.warn("Boolean Series key will be reindexed to match " - "DataFrame index.", UserWarning, stacklevel=3) + warnings.warn( + "Boolean Series key will be reindexed to match " + "DataFrame index.", + UserWarning, + stacklevel=3 + ) elif len(key) != len(self.index): - raise ValueError('Item wrong length {} instead of {}.'.format( - len(key), len(self.index))) + raise ValueError( + 'Item wrong length {} instead of {}.'.format( + len(key), len(self.index) + ) + ) key = check_bool_indexer(self.index, key) - new_parts = _map_partitions(lambda df: df[key], - self._col_partitions) + new_parts = _map_partitions( + lambda df: df[key], self._col_partitions + ) columns = self.columns index = self.index[key] - return DataFrame(col_partitions=new_parts, - columns=columns, - index=index) + return DataFrame( + col_partitions=new_parts, columns=columns, index=index + ) else: columns = self._col_metadata[key].index - indices_for_rows = [self.columns.index(new_col) - for new_col in columns] + indices_for_rows = [ + self.columns.index(new_col) for new_col in columns + ] - new_parts = [_deploy_func.remote( - lambda df: df.__getitem__(indices_for_rows), - part) for part in self._row_partitions] + new_parts = [ + _deploy_func.remote( + lambda df: df.__getitem__(indices_for_rows), part + ) for part in self._row_partitions + ] index = self.index - return DataFrame(row_partitions=new_parts, - columns=columns, - index=index) + return DataFrame( + row_partitions=new_parts, columns=columns, index=index + ) def _getitem_indiv_col(self, key, part): loc = self._col_metadata[key] @@ -3755,17 +4587,16 @@ def _getitem_indiv_col(self, key, part): else: index = loc[loc['partition'] == part]['index_within_partition'] return _deploy_func.remote( - lambda df: df.__getitem__(index), - self._col_partitions[part]) + lambda df: df.__getitem__(index), self._col_partitions[part] + ) def _getitem_slice(self, key): - new_cols = _map_partitions(lambda df: df[key], - self._col_partitions) + new_cols = _map_partitions(lambda df: df[key], self._col_partitions) index = self.index[key] - return DataFrame(col_partitions=new_cols, - index=index, - columns=self.columns) + return DataFrame( + col_partitions=new_cols, index=index, columns=self.columns + ) def __getattr__(self, key): """After regular attribute access, looks up the name in the columns @@ -3786,7 +4617,8 @@ def __getattr__(self, key): def __setitem__(self, key, value): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __len__(self): """Gets the length of the dataframe. @@ -3799,17 +4631,20 @@ def __len__(self): def __unicode__(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __invert__(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __hash__(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def __iter__(self): """Iterate over the columns @@ -3833,12 +4668,14 @@ def __contains__(self, key): def __nonzero__(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __bool__(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __abs__(self): """Creates a modified DataFrame by taking the absolute value. @@ -3851,7 +4688,8 @@ def __abs__(self): def __round__(self, decimals=0): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __array__(self, dtype=None): # TODO: This is very inefficient and needs fix @@ -3860,17 +4698,20 @@ def __array__(self, dtype=None): def __array_wrap__(self, result, context=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __getstate__(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __setstate__(self, state): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __delitem__(self, key): """Delete a column by key. `del a[key]` for example. @@ -3881,6 +4722,7 @@ def __delitem__(self, key): Args: key: key to delete """ + # Create helper method for deleting column(s) in row partition. def del_helper(df, to_delete): cols = df.columns[to_delete] # either int or an array of ints @@ -3897,7 +4739,8 @@ def del_helper(df, to_delete): to_delete = self.columns.get_loc(key) self._row_partitions = _map_partitions( - del_helper, self._row_partitions, to_delete) + del_helper, self._row_partitions, to_delete + ) # This structure is used to get the correct index inside the partition. del_df = self._col_metadata[key] @@ -3919,14 +4762,16 @@ def del_helper(df, to_delete): del_df[del_df['partition'] == i]['index_within_partition'] self._col_partitions[i] = _deploy_func.remote( - del_helper, self._col_partitions[i], to_delete_in_partition) + del_helper, self._col_partitions[i], to_delete_in_partition + ) self._col_metadata.reset_partition_coords(col_parts_to_del) def __finalize__(self, other, method=None, **kwargs): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __copy__(self, deep=True): """Make a copy using Ray.DataFrame.copy method @@ -3955,17 +4800,20 @@ def __deepcopy__(self, memo=None): def __and__(self, other): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __or__(self, other): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def __xor__(self, other): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." 
+ ) def __lt__(self, other): return self.lt(other) @@ -4027,8 +4875,7 @@ def __floordiv__(self, other): def __ifloordiv__(self, other): return self.floordiv(other) - def __rfloordiv__(self, other, axis="columns", level=None, - fill_value=None): + def __rfloordiv__(self, other, axis="columns", level=None, fill_value=None): return self.rfloordiv(other, axis, level, fill_value) def __truediv__(self, other): @@ -4062,47 +4909,57 @@ def __neg__(self): A modified DataFrame where every element is the negation of before """ for t in self.dtypes: - if not (is_bool_dtype(t) - or is_numeric_dtype(t) - or is_timedelta64_dtype(t)): - raise TypeError("Unary negative expects numeric dtype, not {}" - .format(t)) + if not ( + is_bool_dtype(t) or is_numeric_dtype(t) + or is_timedelta64_dtype(t) + ): + raise TypeError( + "Unary negative expects numeric dtype, not {}".format(t) + ) - new_block_partitions = np.array([_map_partitions( - lambda df: df.__neg__(), block) - for block in self._block_partitions]) + new_block_partitions = np.array([ + _map_partitions(lambda df: df.__neg__(), block) + for block in self._block_partitions + ]) - return DataFrame(block_partitions=new_block_partitions, - columns=self.columns, - index=self.index) + return DataFrame( + block_partitions=new_block_partitions, + columns=self.columns, + index=self.index + ) def __sizeof__(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) @property def __doc__(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) @property def blocks(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) @property def style(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def iat(self, axis=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) @property def loc(self): @@ -4113,23 +4970,27 @@ def loc(self): """ raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) @property def is_copy(self): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def at(self, axis=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def ix(self, axis=None): raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) @property def iloc(self): @@ -4140,7 +5001,8 @@ def iloc(self): """ raise NotImplementedError( "To contribute to Pandas on Ray, please visit " - "github.com/ray-project/ray.") + "github.com/ray-project/ray." + ) def _copartition(self, other, new_index): """Colocates the values of other with this for certain operations. 
@@ -4160,8 +5022,9 @@ def _copartition(self, other, new_index): new_index = ray.put(new_index) old_other_index = ray.put(other.index) - new_num_partitions = max(len(self._block_partitions.T), - len(other._block_partitions.T)) + new_num_partitions = max( + len(self._block_partitions.T), len(other._block_partitions.T) + ) new_partitions_self = \ np.array([_reindex_helper._submit( @@ -4183,17 +5046,20 @@ def _operator_helper(self, func, other, axis, level, *args): """Helper method for inter-dataframe and scalar operations""" if isinstance(other, DataFrame): return self._inter_df_op_helper( - lambda x, y: func(x, y, axis, level, *args), - other, axis, level) + lambda x, y: func(x, y, axis, level, *args), other, axis, level + ) else: return self._single_df_op_helper( - lambda df: func(df, other, axis, level, *args), - other, axis, level) + lambda df: func(df, other, axis, level, *args), other, axis, + level + ) def _inter_df_op_helper(self, func, other, axis, level): if level is not None: - raise NotImplementedError("Mutlilevel index not yet supported " - "in Pandas on Ray") + raise NotImplementedError( + "Mutlilevel index not yet supported " + "in Pandas on Ray" + ) axis = pd.DataFrame()._get_axis_number(axis) # Adding two DataFrames causes an outer join. @@ -4211,14 +5077,18 @@ def _inter_df_op_helper(self, func, other, axis, level): for part in copartitions]) # TODO join the Index Metadata objects together for performance. - return DataFrame(block_partitions=new_blocks, - columns=new_column_index, - index=new_index) + return DataFrame( + block_partitions=new_blocks, + columns=new_column_index, + index=new_index + ) def _single_df_op_helper(self, func, other, axis, level): if level is not None: - raise NotImplementedError("Multilevel index not yet supported " - "in Pandas on Ray") + raise NotImplementedError( + "Multilevel index not yet supported " + "in Pandas on Ray" + ) axis = pd.DataFrame()._get_axis_number(axis) if is_list_like(other): @@ -4232,20 +5102,23 @@ def _single_df_op_helper(self, func, other, axis, level): if len(other) != len(self.index): raise ValueError( "Unable to coerce to Series, length must be {0}: " - "given {1}".format(len(self.index), len(other))) + "given {1}".format(len(self.index), len(other)) + ) new_columns = _map_partitions(func, self._col_partitions) new_rows = None else: if len(other) != len(self.columns): raise ValueError( "Unable to coerce to Series, length must be {0}: " - "given {1}".format(len(self.columns), len(other))) + "given {1}".format(len(self.columns), len(other)) + ) new_rows = _map_partitions(func, self._row_partitions) new_columns = None else: - new_blocks = np.array([_map_partitions(func, block) - for block in self._block_partitions]) + new_blocks = np.array([ + _map_partitions(func, block) for block in self._block_partitions + ]) new_columns = None new_rows = None new_index = self.index @@ -4253,10 +5126,12 @@ def _single_df_op_helper(self, func, other, axis, level): new_col_metadata = self._col_metadata new_row_metadata = self._row_metadata - return DataFrame(col_partitions=new_columns, - row_partitions=new_rows, - block_partitions=new_blocks, - index=new_index, - columns=new_column_index, - col_metadata=new_col_metadata, - row_metadata=new_row_metadata) + return DataFrame( + col_partitions=new_columns, + row_partitions=new_rows, + block_partitions=new_blocks, + index=new_index, + columns=new_column_index, + col_metadata=new_col_metadata, + row_metadata=new_row_metadata + ) diff --git a/python/ray/dataframe/groupby.py 
b/python/ray/dataframe/groupby.py index 892bc8f74e19..528a154c7db1 100644 --- a/python/ray/dataframe/groupby.py +++ b/python/ray/dataframe/groupby.py @@ -14,9 +14,9 @@ @_inherit_docstrings(pandas.core.groupby.DataFrameGroupBy) class DataFrameGroupBy(object): - - def __init__(self, df, by, axis, level, as_index, sort, group_keys, - squeeze, **kwargs): + def __init__( + self, df, by, axis, level, as_index, sort, group_keys, squeeze, **kwargs + ): self._columns = df.columns self._index = df.index @@ -53,23 +53,29 @@ def _iter(self): from .dataframe import DataFrame if self._axis == 0: - return [(self._keys_and_values[i][0], - DataFrame(col_partitions=part, - columns=self._columns, - index=self._keys_and_values[i][1].index, - row_metadata=self._row_metadata[ - self._keys_and_values[i][1].index], - col_metadata=self._col_metadata)) - for i, part in enumerate(self._grouped_partitions)] + return [( + self._keys_and_values[i][0], + DataFrame( + col_partitions=part, + columns=self._columns, + index=self._keys_and_values[i][1].index, + row_metadata=self._row_metadata[self._keys_and_values[i][1] + .index], + col_metadata=self._col_metadata + ) + ) for i, part in enumerate(self._grouped_partitions)] else: - return [(self._keys_and_values[i][0], - DataFrame(row_partitions=part, - columns=self._keys_and_values[i][1].index, - index=self._index, - row_metadata=self._row_metadata, - col_metadata=self._col_metadata[ - self._keys_and_values[i][1].index])) - for i, part in enumerate(self._grouped_partitions)] + return [( + self._keys_and_values[i][0], + DataFrame( + row_partitions=part, + columns=self._keys_and_values[i][1].index, + index=self._index, + row_metadata=self._row_metadata, + col_metadata=self._col_metadata[self._keys_and_values[i][1] + .index] + ) + ) for i, part in enumerate(self._grouped_partitions)] @property def ngroups(self): @@ -125,9 +131,9 @@ def nth(self, n, dropna=None): raise NotImplementedError("Not Yet implemented.") def cumsum(self, axis=0, *args, **kwargs): - return self._apply_agg_function(lambda df: df.cumsum(axis, - *args, - **kwargs)) + return self._apply_agg_function( + lambda df: df.cumsum(axis, *args, **kwargs) + ) @property def indices(self): @@ -140,8 +146,9 @@ def filter(self, func, dropna=True, *args, **kwargs): raise NotImplementedError("Not Yet implemented.") def cummax(self, axis=0, **kwargs): - return self._apply_agg_function(lambda df: df.cummax(axis=axis, - **kwargs)) + return self._apply_agg_function( + lambda df: df.cummax(axis=axis, **kwargs) + ) def apply(self, func, *args, **kwargs): return self._apply_df_function(lambda df: df.apply(func, @@ -157,8 +164,7 @@ def dtypes(self): return self._apply_agg_function(lambda df: df.dtypes) def first(self, **kwargs): - return self._apply_agg_function(lambda df: df.first(offset=0, - **kwargs)) + return self._apply_agg_function(lambda df: df.first(offset=0, **kwargs)) def backfill(self, limit=None): return self.bfill(limit) @@ -168,8 +174,9 @@ def __getitem__(self, key): raise NotImplementedError("Not Yet implemented.") def cummin(self, axis=0, **kwargs): - return self._apply_agg_function(lambda df: df.cummin(axis=axis, - **kwargs)) + return self._apply_agg_function( + lambda df: df.cummin(axis=axis, **kwargs) + ) def bfill(self, limit=None): return self._apply_agg_function(lambda df: df.bfill(limit=limit)) @@ -181,8 +188,9 @@ def prod(self, **kwargs): return self._apply_agg_function(lambda df: df.prod(**kwargs)) def std(self, ddof=1, *args, **kwargs): - return self._apply_agg_function(lambda df: df.std(ddof=ddof, - *args, 
**kwargs)) + return self._apply_agg_function( + lambda df: df.std(ddof=ddof, *args, **kwargs) + ) def aggregate(self, arg, *args, **kwargs): return self._apply_df_function(lambda df: df.agg(arg, @@ -213,9 +221,9 @@ def max(self, **kwargs): return self._apply_agg_function(lambda df: df.max(**kwargs)) def var(self, ddof=1, *args, **kwargs): - return self._apply_agg_function(lambda df: df.var(ddof, - *args, - **kwargs)) + return self._apply_agg_function( + lambda df: df.var(ddof, *args, **kwargs) + ) def get_group(self, name, obj=None): raise NotImplementedError("Not Yet implemented.") @@ -230,8 +238,9 @@ def size(self): return self._apply_agg_function(lambda df: df.size) def sum(self, **kwargs): - return self._apply_agg_function(lambda df: - df.sum(axis=self._axis, **kwargs)) + return self._apply_agg_function( + lambda df: df.sum(axis=self._axis, **kwargs) + ) def __unicode__(self): raise NotImplementedError("Not Yet implemented.") @@ -239,8 +248,19 @@ def __unicode__(self): def describe(self, **kwargs): raise NotImplementedError("Not Yet implemented.") - def boxplot(self, grouped, subplots=True, column=None, fontsize=None, - rot=0, grid=True, ax=None, figsize=None, layout=None, **kwds): + def boxplot( + self, + grouped, + subplots=True, + column=None, + fontsize=None, + rot=0, + grid=True, + ax=None, + figsize=None, + layout=None, + **kwds + ): raise NotImplementedError("Not Yet implemented.") def ngroup(self, ascending=True): @@ -259,9 +279,9 @@ def head(self, n=5): return self._apply_df_function(lambda df: df.head(n)) def cumprod(self, axis=0, *args, **kwargs): - return self._apply_df_function(lambda df: df.cumprod(axis, - *args, - **kwargs)) + return self._apply_df_function( + lambda df: df.cumprod(axis, *args, **kwargs) + ) def __iter__(self): return self._iter.__iter__() @@ -272,8 +292,10 @@ def agg_help(df): return pd.DataFrame(df).T else: return df - x = [v.agg(arg, axis=self._axis, *args, **kwargs) - for k, v in self._iter] + + x = [ + v.agg(arg, axis=self._axis, *args, **kwargs) for k, v in self._iter + ] new_parts = _map_partitions(lambda df: agg_help(df), x) @@ -288,8 +310,9 @@ def cov(self): def transform(self, func, *args, **kwargs): from .concat import concat - new_parts = concat([v.transform(func, *args, **kwargs) - for k, v in self._iter]) + new_parts = concat([ + v.transform(func, *args, **kwargs) for k, v in self._iter + ]) return new_parts def corr(self, **kwargs): @@ -302,9 +325,9 @@ def count(self, **kwargs): return self._apply_agg_function(lambda df: df.count(**kwargs)) def pipe(self, func, *args, **kwargs): - return self._apply_df_function(lambda df: df.pipe(func, - *args, - **kwargs)) + return self._apply_df_function( + lambda df: df.pipe(func, *args, **kwargs) + ) def cumcount(self, ascending=True): raise NotImplementedError("Not Yet implemented.") @@ -365,10 +388,14 @@ def groupby(by, axis, level, as_index, sort, group_keys, squeeze, *df): df = pd.concat(df, axis=axis) - return [v for k, v in df.groupby(by=by, - axis=axis, - level=level, - as_index=as_index, - sort=sort, - group_keys=group_keys, - squeeze=squeeze)] + return [ + v for k, v in df.groupby( + by=by, + axis=axis, + level=level, + as_index=as_index, + sort=sort, + group_keys=group_keys, + squeeze=squeeze + ) + ] diff --git a/python/ray/dataframe/index_metadata.py b/python/ray/dataframe/index_metadata.py index 235809ec7a35..34cf196bec48 100644 --- a/python/ray/dataframe/index_metadata.py +++ b/python/ray/dataframe/index_metadata.py @@ -2,9 +2,7 @@ import numpy as np import ray -from .utils import ( - 
_build_index, - _build_columns) +from .utils import (_build_index, _build_columns) from pandas.core.indexing import convert_to_index_sliceable @@ -57,8 +55,17 @@ def coords_of(self, key): def __getitem__(self, key): return self.coords_of(key) - def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True, - group_keys=True, squeeze=False, **kwargs): + def groupby( + self, + by=None, + axis=0, + level=None, + as_index=True, + sort=True, + group_keys=True, + squeeze=False, + **kwargs + ): raise NotImplementedError() def __len__(self): @@ -70,8 +77,9 @@ def first_valid_index(self): def last_valid_index(self): return self._coord_df.last_valid_index() - def insert(self, key, loc=None, partition=None, - index_within_partition=None): + def insert( + self, key, loc=None, partition=None, index_within_partition=None + ): raise NotImplementedError() def drop(self, labels, errors='raise'): @@ -114,8 +122,9 @@ class _IndexMetadata(_IndexMetadataBase): partitions. """ - def __init__(self, dfs=None, index=None, axis=0, lengths_oid=None, - coord_df_oid=None): + def __init__( + self, dfs=None, index=None, axis=0, lengths_oid=None, coord_df_oid=None + ): """Inits a IndexMetadata from Ray DataFrame partitions Args: @@ -165,8 +174,17 @@ def coords_of(self, key): """ return self._coord_df.loc[key] - def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True, - group_keys=True, squeeze=False, **kwargs): + def groupby( + self, + by=None, + axis=0, + level=None, + as_index=True, + sort=True, + group_keys=True, + squeeze=False, + **kwargs + ): # TODO: Find out what this does, and write a docstring assignments_df = self._coord_df.groupby(by=by, axis=axis, level=level, as_index=as_index, sort=sort, @@ -193,19 +211,20 @@ def reset_partition_coords(self, partitions=None): # partition, we have to make sure that our reference to it is # updated as well. try: - self._coord_df.loc[partition_mask, - 'index_within_partition'] = [ - p for p in range(sum(partition_mask))] + self._coord_df.loc[partition_mask, 'index_within_partition'] = [ + p for p in range(sum(partition_mask)) + ] except ValueError: # Copy the arrow sealed dataframe so we can mutate it. # We only do this the first time we try to mutate the sealed. self._coord_df = self._coord_df.copy() - self._coord_df.loc[partition_mask, - 'index_within_partition'] = [ - p for p in range(sum(partition_mask))] + self._coord_df.loc[partition_mask, 'index_within_partition'] = [ + p for p in range(sum(partition_mask)) + ] - def insert(self, key, loc=None, partition=None, - index_within_partition=None): + def insert( + self, key, loc=None, partition=None, index_within_partition=None + ): """Inserts a key at a certain location in the index, or a certain coord in a partition. Called with either `loc` or `partition` and `index_within_partition`. If called with both, `loc` will be used. 
@@ -251,10 +270,13 @@ def insert(self, key, loc=None, partition=None, # TODO: Determine if there's a better way to do a row-index insert in # pandas, because this is very annoying/unsure of efficiency # Create new coord entry to insert - coord_to_insert = pd.DataFrame( - {'partition': partition, - 'index_within_partition': index_within_partition}, - index=[key]) + coord_to_insert = pd.DataFrame({ + 'partition': + partition, + 'index_within_partition': + index_within_partition + }, + index=[key]) # Insert into cached RangeIndex, and order by new column index self._coord_df = _coord_df_copy.append(coord_to_insert).loc[new_index] @@ -272,8 +294,9 @@ def squeeze(self, partition, index_within_partition): 'index_within_partition'] -= 1 def copy(self): - return _IndexMetadata(coord_df_oid=self._coord_df, - lengths_oid=self._lengths) + return _IndexMetadata( + coord_df_oid=self._coord_df, lengths_oid=self._lengths + ) class _WrappingIndexMetadata(_IndexMetadata): @@ -317,12 +340,22 @@ def coords_of(self, key): ret_obj['index_within_partition'] = loc_idxs return ret_obj - def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True, - group_keys=True, squeeze=False, **kwargs): + def groupby( + self, + by=None, + axis=0, + level=None, + as_index=True, + sort=True, + group_keys=True, + squeeze=False, + **kwargs + ): raise NotImplementedError() - def insert(self, key, loc=None, partition=None, - index_within_partition=None): + def insert( + self, key, loc=None, partition=None, index_within_partition=None + ): """Inserts a key at a certain location in the index, or a certain coord in a partition. Called with either `loc` or `partition` and `index_within_partition`. If called with both, `loc` will be used. @@ -343,5 +376,8 @@ def insert(self, key, loc=None, partition=None, self._coord_df = pd.DataFrame(index=new_index) # Shouldn't really need this, but here to maintain API consistency - return pd.DataFrame({'partition': 0, 'index_within_partition': loc}, + return pd.DataFrame({ + 'partition': 0, + 'index_within_partition': loc + }, index=[key]) diff --git a/python/ray/dataframe/indexing.py b/python/ray/dataframe/indexing.py index cba4ff8728fc..ec2e64c60449 100644 --- a/python/ray/dataframe/indexing.py +++ b/python/ray/dataframe/indexing.py @@ -20,8 +20,10 @@ def __getitem__(self, key): def _get_lookup_dict(self, ray_partition_idx): if ray_partition_idx.ndim == 1: # Single row matched - position = (ray_partition_idx['partition'], - ray_partition_idx['index_within_partition']) + position = ( + ray_partition_idx['partition'], + ray_partition_idx['index_within_partition'] + ) rows_to_lookup = {position[0]: [position[1]]} if ray_partition_idx.ndim == 2: # Multiple rows matched # We copy ray_partition_idx because it allows us to @@ -29,7 +31,8 @@ def _get_lookup_dict(self, ray_partition_idx): # And have room to optimize. 
ray_partition_idx = ray_partition_idx.copy() rows_to_lookup = ray_partition_idx.groupby('partition').aggregate( - lambda x: list(x)).to_dict()['index_within_partition'] + lambda x: list(x) + ).to_dict()['index_within_partition'] return rows_to_lookup def locate_2d(self, row_label, col_label): @@ -54,10 +57,10 @@ def retrieve_func(df, idx_lst, col_idx): return df.iloc[idx_lst, col_idx] retrieved_rows_remote = [ - _deploy_func.remote(retrieve_func, - self.df._row_partitions[partition], - idx_to_lookup, col_lst) - for partition, idx_to_lookup in lookup_dict.items() + _deploy_func.remote( + retrieve_func, self.df._row_partitions[partition], + idx_to_lookup, col_lst + ) for partition, idx_to_lookup in lookup_dict.items() ] return retrieved_rows_remote @@ -69,7 +72,8 @@ def locate_2d(self, row_label, col_label): index_loc = self.df._row_index.loc[row_label] lookup_dict = self._get_lookup_dict(index_loc) retrieved_rows_remote = self._map_partition( - lookup_dict, col_label, indexer='loc') + lookup_dict, col_label, indexer='loc' + ) joined_df = pd.concat(ray.get(retrieved_rows_remote)) if index_loc.ndim == 2: @@ -90,7 +94,8 @@ def locate_2d(self, row_idx, col_idx): index_loc = self.df._row_index.iloc[row_idx] lookup_dict = self._get_lookup_dict(index_loc) retrieved_rows_remote = self._map_partition( - lookup_dict, col_idx, indexer='iloc') + lookup_dict, col_idx, indexer='iloc' + ) joined_df = pd.concat(ray.get(retrieved_rows_remote)) if index_loc.ndim == 2: diff --git a/python/ray/dataframe/io.py b/python/ray/dataframe/io.py index c1abc0ec474c..295ec2c57492 100644 --- a/python/ray/dataframe/io.py +++ b/python/ray/dataframe/io.py @@ -43,8 +43,9 @@ def read_parquet(path, engine='auto', columns=None, **kwargs): _read_parquet_row_group.remote(path, columns, i, kwargs) for i in range(n_row_groups) ] - splited_dfs = ray.get( - [_split_df.remote(df, chunksize) for df in df_from_row_groups]) + splited_dfs = ray.get([ + _split_df.remote(df, chunksize) for df in df_from_row_groups + ]) df_remotes = list(chain.from_iterable(splited_dfs)) return DataFrame(row_partitions=df_remotes, columns=columns) @@ -129,60 +130,62 @@ def _read_csv_with_offset(fn, start, end, header=b'', kwargs={}): return pd.read_csv(BytesIO(to_read), **kwargs) -def read_csv(filepath, - sep=',', - delimiter=None, - header='infer', - names=None, - index_col=None, - usecols=None, - squeeze=False, - prefix=None, - mangle_dupe_cols=True, - dtype=None, - engine=None, - converters=None, - true_values=None, - false_values=None, - skipinitialspace=False, - skiprows=None, - nrows=None, - na_values=None, - keep_default_na=True, - na_filter=True, - verbose=False, - skip_blank_lines=True, - parse_dates=False, - infer_datetime_format=False, - keep_date_col=False, - date_parser=None, - dayfirst=False, - iterator=False, - chunksize=None, - compression='infer', - thousands=None, - decimal=b'.', - lineterminator=None, - quotechar='"', - quoting=0, - escapechar=None, - comment=None, - encoding=None, - dialect=None, - tupleize_cols=None, - error_bad_lines=True, - warn_bad_lines=True, - skipfooter=0, - skip_footer=0, - doublequote=True, - delim_whitespace=False, - as_recarray=None, - compact_ints=None, - use_unsigned=None, - low_memory=True, - buffer_lines=None, - memory_map=False, - float_precision=None): +def read_csv( + filepath, + sep=',', + delimiter=None, + header='infer', + names=None, + index_col=None, + usecols=None, + squeeze=False, + prefix=None, + mangle_dupe_cols=True, + dtype=None, + engine=None, + converters=None, + true_values=None, + 
false_values=None, + skipinitialspace=False, + skiprows=None, + nrows=None, + na_values=None, + keep_default_na=True, + na_filter=True, + verbose=False, + skip_blank_lines=True, + parse_dates=False, + infer_datetime_format=False, + keep_date_col=False, + date_parser=None, + dayfirst=False, + iterator=False, + chunksize=None, + compression='infer', + thousands=None, + decimal=b'.', + lineterminator=None, + quotechar='"', + quoting=0, + escapechar=None, + comment=None, + encoding=None, + dialect=None, + tupleize_cols=None, + error_bad_lines=True, + warn_bad_lines=True, + skipfooter=0, + skip_footer=0, + doublequote=True, + delim_whitespace=False, + as_recarray=None, + compact_ints=None, + use_unsigned=None, + low_memory=True, + buffer_lines=None, + memory_map=False, + float_precision=None +): """Read csv file from local disk. Args: @@ -245,7 +248,8 @@ def read_csv(filepath, low_memory=low_memory, buffer_lines=buffer_lines, memory_map=memory_map, - float_precision=float_precision) + float_precision=float_precision + ) offsets = _compute_offset(filepath, get_npartitions()) @@ -256,65 +260,75 @@ def read_csv(filepath, for start, end in offsets: if start != 0: df = _read_csv_with_offset.remote( - filepath, start, end, header=first_line, kwargs=kwargs) + filepath, start, end, header=first_line, kwargs=kwargs + ) else: df = _read_csv_with_offset.remote( - filepath, start, end, kwargs=kwargs) + filepath, start, end, kwargs=kwargs + ) df_obj_ids.append(df) return DataFrame(row_partitions=df_obj_ids, columns=columns) -def read_json(path_or_buf=None, - orient=None, - typ='frame', - dtype=True, - convert_axes=True, - convert_dates=True, - keep_default_dates=True, - numpy=False, - precise_float=False, - date_unit=None, - encoding=None, - lines=False, - chunksize=None, - compression='infer'): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) - - port_frame = pd.read_json(path_or_buf, orient, typ, dtype, - convert_axes, convert_dates, keep_default_dates, - numpy, precise_float, date_unit, encoding, - lines, chunksize, compression) +def read_json( + path_or_buf=None, + orient=None, + typ='frame', + dtype=True, + convert_axes=True, + convert_dates=True, + keep_default_dates=True, + numpy=False, + precise_float=False, + date_unit=None, + encoding=None, + lines=False, + chunksize=None, + compression='infer' +): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) + + port_frame = pd.read_json( + path_or_buf, orient, typ, dtype, convert_axes, convert_dates, + keep_default_dates, numpy, precise_float, date_unit, encoding, lines, + chunksize, compression + ) ray_frame = from_pandas(port_frame, get_npartitions()) return ray_frame -def read_html(io, - match='.+', - flavor=None, - header=None, - index_col=None, - skiprows=None, - attrs=None, - parse_dates=False, - tupleize_cols=None, - thousands=',', - encoding=None, - decimal='.', - converters=None, - na_values=None, - keep_default_na=True): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) - - port_frame = pd.read_html(io, match, flavor, header, index_col, - skiprows, attrs, parse_dates, tupleize_cols, - thousands, encoding, decimal, converters, - na_values, keep_default_na) +def read_html( + io, + match='.+', + flavor=None, + header=None, + index_col=None, + skiprows=None, + attrs=None, + parse_dates=False, + tupleize_cols=None, + thousands=',', + encoding=None, + decimal='.', + converters=None, + na_values=None, + keep_default_na=True +): + + 
warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) + + port_frame = pd.read_html( + io, match, flavor, header, index_col, skiprows, attrs, parse_dates, + tupleize_cols, thousands, encoding, decimal, converters, na_values, + keep_default_na + ) ray_frame = from_pandas(port_frame[0], get_npartitions()) return ray_frame @@ -322,8 +336,9 @@ def read_html(io, def read_clipboard(sep=r'\s+'): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = pd.read_clipboard(sep) ray_frame = from_pandas(port_frame, get_npartitions()) @@ -331,45 +346,47 @@ def read_clipboard(sep=r'\s+'): return ray_frame -def read_excel(io, - sheet_name=0, - header=0, - skiprows=None, - skip_footer=0, - index_col=None, - names=None, - usecols=None, - parse_dates=False, - date_parser=None, - na_values=None, - thousands=None, - convert_float=True, - converters=None, - dtype=None, - true_values=None, - false_values=None, - engine=None, - squeeze=False): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) - - port_frame = pd.read_excel(io, sheet_name, header, skiprows, skip_footer, - index_col, names, usecols, parse_dates, - date_parser, na_values, thousands, - convert_float, converters, dtype, true_values, - false_values, engine, squeeze) +def read_excel( + io, + sheet_name=0, + header=0, + skiprows=None, + skip_footer=0, + index_col=None, + names=None, + usecols=None, + parse_dates=False, + date_parser=None, + na_values=None, + thousands=None, + convert_float=True, + converters=None, + dtype=None, + true_values=None, + false_values=None, + engine=None, + squeeze=False +): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) + + port_frame = pd.read_excel( + io, sheet_name, header, skiprows, skip_footer, index_col, names, + usecols, parse_dates, date_parser, na_values, thousands, convert_float, + converters, dtype, true_values, false_values, engine, squeeze + ) ray_frame = from_pandas(port_frame, get_npartitions()) return ray_frame -def read_hdf(path_or_buf, - key=None, - mode='r'): +def read_hdf(path_or_buf, key=None, mode='r'): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = pd.read_hdf(path_or_buf, key, mode) ray_frame = from_pandas(port_frame, get_npartitions()) @@ -377,11 +394,11 @@ def read_hdf(path_or_buf, return ray_frame -def read_feather(path, - nthreads=1): +def read_feather(path, nthreads=1): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = pd.read_feather(path) ray_frame = from_pandas(port_frame, get_npartitions()) @@ -389,12 +406,11 @@ def read_feather(path, return ray_frame -def read_msgpack(path_or_buf, - encoding='utf-8', - iterator=False): +def read_msgpack(path_or_buf, encoding='utf-8', iterator=False): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = pd.read_msgpack(path_or_buf, encoding, iterator) ray_frame = from_pandas(port_frame, get_npartitions()) @@ -402,52 +418,60 @@ def read_msgpack(path_or_buf, return ray_frame -def read_stata(filepath_or_buffer, - 
convert_dates=True, - convert_categoricals=True, - encoding=None, - index_col=None, - convert_missing=False, - preserve_dtypes=True, - columns=None, - order_categoricals=True, - chunksize=None, - iterator=False): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) - - port_frame = pd.read_stata(filepath_or_buffer, convert_dates, - convert_categoricals, encoding, index_col, - convert_missing, preserve_dtypes, columns, - order_categoricals, chunksize, iterator) +def read_stata( + filepath_or_buffer, + convert_dates=True, + convert_categoricals=True, + encoding=None, + index_col=None, + convert_missing=False, + preserve_dtypes=True, + columns=None, + order_categoricals=True, + chunksize=None, + iterator=False +): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) + + port_frame = pd.read_stata( + filepath_or_buffer, convert_dates, convert_categoricals, encoding, + index_col, convert_missing, preserve_dtypes, columns, + order_categoricals, chunksize, iterator + ) ray_frame = from_pandas(port_frame, get_npartitions()) return ray_frame -def read_sas(filepath_or_buffer, - format=None, - index=None, - encoding=None, - chunksize=None, - iterator=False): +def read_sas( + filepath_or_buffer, + format=None, + index=None, + encoding=None, + chunksize=None, + iterator=False +): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) - port_frame = pd.read_sas(filepath_or_buffer, format, index, encoding, - chunksize, iterator) + port_frame = pd.read_sas( + filepath_or_buffer, format, index, encoding, chunksize, iterator + ) ray_frame = from_pandas(port_frame, get_npartitions()) return ray_frame -def read_pickle(path, - compression='infer'): +def read_pickle(path, compression='infer'): - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) port_frame = pd.read_pickle(path, compression) ray_frame = from_pandas(port_frame, get_npartitions()) @@ -455,20 +479,25 @@ def read_pickle(path, return ray_frame -def read_sql(sql, - con, - index_col=None, - coerce_float=True, - params=None, - parse_dates=None, - columns=None, - chunksize=None): - - warnings.warn("Defaulting to Pandas implementation", - PendingDeprecationWarning) - - port_frame = pd.read_sql(sql, con, index_col, coerce_float, params, - parse_dates, columns, chunksize) +def read_sql( + sql, + con, + index_col=None, + coerce_float=True, + params=None, + parse_dates=None, + columns=None, + chunksize=None +): + + warnings.warn( + "Defaulting to Pandas implementation", PendingDeprecationWarning + ) + + port_frame = pd.read_sql( + sql, con, index_col, coerce_float, params, parse_dates, columns, + chunksize + ) ray_frame = from_pandas(port_frame, get_npartitions()) return ray_frame diff --git a/python/ray/dataframe/pandas_code_gen.py b/python/ray/dataframe/pandas_code_gen.py index b5969281cfdb..f973dcbe223b 100644 --- a/python/ray/dataframe/pandas_code_gen.py +++ b/python/ray/dataframe/pandas_code_gen.py @@ -24,16 +24,18 @@ def code_gen(pd_obj, ray_obj, path): # let's not mess with these continue try: - outfile.write("\ndef " + func + - str(inspect.signature(getattr(pd_obj, func))) + - ":\n") + outfile.write( + "\ndef " + func + + str(inspect.signature(getattr(pd_obj, func))) + ":\n" + ) except TypeError: outfile.write("\n@property") outfile.write("\ndef " + 
func + "(self):\n") except ValueError: continue outfile.write( - " raise NotImplementedError(\"Not Yet implemented.\")\n") + " raise NotImplementedError(\"Not Yet implemented.\")\n" + ) def code_gen_test(ray_obj, path, name): @@ -50,7 +52,8 @@ def code_gen_test(ray_obj, path, name): outfile.write( " ray_" + name + " = create_test_" + name + "()\n\n" + " with pytest.raises(NotImplementedError):\n" + - " ray_" + name + "." + func) + " ray_" + name + "." + func + ) try: first = True param_num = \ @@ -86,8 +89,7 @@ def pandas_ray_diff(pd_obj, ray_obj): pd_funcs = dir(pd_obj) ray_funcs = dir(ray_obj) - pd_funcs = set(filter(lambda f: f[0] != "_" or f[1] == "_", - pd_funcs)) + pd_funcs = set(filter(lambda f: f[0] != "_" or f[1] == "_", pd_funcs)) diff = [x for x in pd_funcs if x not in set(ray_funcs)] return diff diff --git a/python/ray/dataframe/series.py b/python/ray/dataframe/series.py index dbbac7993c8c..ba5a7ad86e4e 100644 --- a/python/ray/dataframe/series.py +++ b/python/ray/dataframe/series.py @@ -16,7 +16,6 @@ def na_op(): @_inherit_docstrings(pd.Series) class Series(object): - def __init__(self, series_oids): """Constructor for a Series object. @@ -57,8 +56,15 @@ def __bool__(self): def __bytes__(self): raise NotImplementedError("Not Yet implemented.") - def __class__(self, data=None, index=None, dtype=None, name=None, - copy=False, fastpath=False): + def __class__( + self, + data=None, + index=None, + dtype=None, + name=None, + copy=False, + fastpath=False + ): raise NotImplementedError("Not Yet implemented.") def __contains__(self, key): @@ -212,17 +218,25 @@ def agg(self, func, axis=0, *args, **kwargs): def aggregate(self, func, axis=0, *args, **kwargs): raise NotImplementedError("Not Yet implemented.") - def align(self, other, join='outer', axis=None, level=None, copy=True, - fill_value=None, method=None, limit=None, fill_axis=0, - broadcast_axis=None): + def align( + self, + other, + join='outer', + axis=None, + level=None, + copy=True, + fill_value=None, + method=None, + limit=None, + fill_axis=0, + broadcast_axis=None + ): raise NotImplementedError("Not Yet implemented.") - def all(self, axis=None, bool_only=None, skipna=None, level=None, - **kwargs): + def all(self, axis=None, bool_only=None, skipna=None, level=None, **kwargs): raise NotImplementedError("Not Yet implemented.") - def any(self, axis=None, bool_only=None, skipna=None, level=None, - **kwargs): + def any(self, axis=None, bool_only=None, skipna=None, level=None, **kwargs): raise NotImplementedError("Not Yet implemented.") def append(self, to_append, ignore_index=False, verify_integrity=False): @@ -246,8 +260,9 @@ def as_blocks(self, copy=True): def as_matrix(self, columns=None): raise NotImplementedError("Not Yet implemented.") - def asfreq(self, freq, method=None, how=None, normalize=False, - fill_value=None): + def asfreq( + self, freq, method=None, how=None, normalize=False, fill_value=None + ): raise NotImplementedError("Not Yet implemented.") def asof(self, where, subset=None): @@ -268,8 +283,9 @@ def autocorr(self, lag=1): def between(self, left, right, inclusive=True): raise NotImplementedError("Not Yet implemented.") - def between_time(self, start_time, end_time, include_start=True, - include_end=True): + def between_time( + self, start_time, end_time, include_start=True, include_end=True + ): raise NotImplementedError("Not Yet implemented.") def bfill(self, axis=None, inplace=False, limit=None, downcast=None): @@ -302,8 +318,13 @@ def compress(self, condition, *args, **kwargs): def consolidate(self, 
inplace=False): raise NotImplementedError("Not Yet implemented.") - def convert_objects(self, convert_dates=True, convert_numeric=False, - convert_timedeltas=True, copy=True): + def convert_objects( + self, + convert_dates=True, + convert_numeric=False, + convert_timedeltas=True, + copy=True + ): raise NotImplementedError("Not Yet implemented.") def copy(self, deep=True): @@ -363,8 +384,18 @@ def eq(self, other, level=None, fill_value=None, axis=0): def equals(self, other): raise NotImplementedError("Not Yet implemented.") - def ewm(self, com=None, span=None, halflife=None, alpha=None, - min_periods=0, freq=None, adjust=True, ignore_na=False, axis=0): + def ewm( + self, + com=None, + span=None, + halflife=None, + alpha=None, + min_periods=0, + freq=None, + adjust=True, + ignore_na=False, + axis=0 + ): raise NotImplementedError("Not Yet implemented.") def expanding(self, min_periods=1, freq=None, center=False, axis=0): @@ -376,8 +407,16 @@ def factorize(self, sort=False, na_sentinel=-1): def ffill(self, axis=None, inplace=False, limit=None, downcast=None): raise NotImplementedError("Not Yet implemented.") - def fillna(self, value=None, method=None, axis=None, inplace=False, - limit=None, downcast=None, **kwargs): + def fillna( + self, + value=None, + method=None, + axis=None, + inplace=False, + limit=None, + downcast=None, + **kwargs + ): raise NotImplementedError("Not Yet implemented.") def filter(self, items=None, like=None, regex=None, axis=None): @@ -392,12 +431,27 @@ def first_valid_index(self): def floordiv(self, other, level=None, fill_value=None, axis=0): raise NotImplementedError("Not Yet implemented.") - def from_array(self, arr, index=None, name=None, dtype=None, copy=False, - fastpath=False): + def from_array( + self, + arr, + index=None, + name=None, + dtype=None, + copy=False, + fastpath=False + ): raise NotImplementedError("Not Yet implemented.") - def from_csv(self, path, sep=',', parse_dates=True, header=None, - index_col=0, encoding=None, infer_datetime_format=False): + def from_csv( + self, + path, + sep=',', + parse_dates=True, + header=None, + index_col=0, + encoding=None, + infer_datetime_format=False + ): raise NotImplementedError("Not Yet implemented.") def ge(self, other, level=None, fill_value=None, axis=0): @@ -418,8 +472,17 @@ def get_value(self, label, takeable=False): def get_values(self): raise NotImplementedError("Not Yet implemented.") - def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True, - group_keys=True, squeeze=False, **kwargs): + def groupby( + self, + by=None, + axis=0, + level=None, + as_index=True, + sort=True, + group_keys=True, + squeeze=False, + **kwargs + ): raise NotImplementedError("Not Yet implemented.") def gt(self, other, level=None, fill_value=None, axis=0): @@ -428,8 +491,19 @@ def gt(self, other, level=None, fill_value=None, axis=0): def head(self, n=5): raise NotImplementedError("Not Yet implemented.") - def hist(self, by=None, ax=None, grid=True, xlabelsize=None, xrot=None, - ylabelsize=None, yrot=None, figsize=None, bins=10, **kwds): + def hist( + self, + by=None, + ax=None, + grid=True, + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, + figsize=None, + bins=10, + **kwds + ): raise NotImplementedError("Not Yet implemented.") def iat(self, axis=None): @@ -444,8 +518,16 @@ def idxmin(self, axis=None, skipna=True, *args, **kwargs): def iloc(self, axis=None): raise NotImplementedError("Not Yet implemented.") - def interpolate(self, method='linear', axis=0, limit=None, inplace=False, - 
limit_direction='forward', downcast=None, **kwargs): + def interpolate( + self, + method='linear', + axis=0, + limit=None, + inplace=False, + limit_direction='forward', + downcast=None, + **kwargs + ): raise NotImplementedError("Not Yet implemented.") def isin(self, values): @@ -469,12 +551,14 @@ def ix(self, axis=None): def keys(self): raise NotImplementedError("Not Yet implemented.") - def kurt(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def kurt( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") - def kurtosis(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def kurtosis( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") def last(self, offset): @@ -498,27 +582,39 @@ def mad(self, axis=None, skipna=None, level=None): def map(self, arg, na_action=None): raise NotImplementedError("Not Yet implemented.") - def mask(self, cond, other=np.nan, inplace=False, axis=None, level=None, - try_cast=False, raise_on_error=True): + def mask( + self, + cond, + other=np.nan, + inplace=False, + axis=None, + level=None, + try_cast=False, + raise_on_error=True + ): raise NotImplementedError("Not Yet implemented.") - def max(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def max( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") - def mean(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def mean( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") - def median(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def median( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") def memory_usage(self, index=True, deep=False): raise NotImplementedError("Not Yet implemented.") - def min(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def min( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") def mod(self, other, level=None, fill_value=None, axis=0): @@ -551,18 +647,41 @@ def nsmallest(self, n=5, keep='first'): def nunique(self, dropna=True): raise NotImplementedError("Not Yet implemented.") - def pct_change(self, periods=1, fill_method='pad', limit=None, freq=None, - **kwargs): + def pct_change( + self, periods=1, fill_method='pad', limit=None, freq=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") def pipe(self, func, *args, **kwargs): raise NotImplementedError("Not Yet implemented.") - def plot(self, kind='line', ax=None, figsize=None, use_index=True, - title=None, grid=None, legend=False, style=None, logx=False, - logy=False, loglog=False, xticks=None, yticks=None, xlim=None, - ylim=None, rot=None, fontsize=None, colormap=None, table=False, - yerr=None, xerr=None, label=None, secondary_y=False, **kwds): + def plot( + self, + kind='line', + ax=None, + figsize=None, + use_index=True, + title=None, + grid=None, + legend=False, + style=None, + logx=False, + logy=False, + loglog=False, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + rot=None, + fontsize=None, + colormap=None, + table=False, + yerr=None, + xerr=None, + label=None, + secondary_y=False, + **kwds + ): raise 
NotImplementedError("Not Yet implemented.") def pop(self, item): @@ -571,16 +690,19 @@ def pop(self, item): def pow(self, other, level=None, fill_value=None, axis=0): raise NotImplementedError("Not Yet implemented.") - def prod(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def prod( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") - def product(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def product( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") - def ptp(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def ptp( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") def put(self, *args, **kwargs): @@ -592,8 +714,15 @@ def quantile(self, q=0.5, interpolation='linear'): def radd(self, other, level=None, fill_value=None, axis=0): raise NotImplementedError("Not Yet implemented.") - def rank(self, axis=0, method='average', numeric_only=None, - na_option='keep', ascending=True, pct=False): + def rank( + self, + axis=0, + method='average', + numeric_only=None, + na_option='keep', + ascending=True, + pct=False + ): raise NotImplementedError("Not Yet implemented.") def ravel(self, order='C'): @@ -608,8 +737,9 @@ def reindex(self, index=None, **kwargs): def reindex_axis(self, labels, axis=0, **kwargs): raise NotImplementedError("Not Yet implemented.") - def reindex_like(self, other, method=None, copy=True, limit=None, - tolerance=None): + def reindex_like( + self, other, method=None, copy=True, limit=None, tolerance=None + ): raise NotImplementedError("Not Yet implemented.") def rename(self, index=None, **kwargs): @@ -624,13 +754,34 @@ def reorder_levels(self, order): def repeat(self, repeats, *args, **kwargs): raise NotImplementedError("Not Yet implemented.") - def replace(self, to_replace=None, value=None, inplace=False, limit=None, - regex=False, method='pad', axis=None): - raise NotImplementedError("Not Yet implemented.") - - def resample(self, rule, how=None, axis=0, fill_method=None, closed=None, - label=None, convention='start', kind=None, loffset=None, - limit=None, base=0, on=None, level=None): + def replace( + self, + to_replace=None, + value=None, + inplace=False, + limit=None, + regex=False, + method='pad', + axis=None + ): + raise NotImplementedError("Not Yet implemented.") + + def resample( + self, + rule, + how=None, + axis=0, + fill_method=None, + closed=None, + label=None, + convention='start', + kind=None, + loffset=None, + limit=None, + base=0, + on=None, + level=None + ): raise NotImplementedError("Not Yet implemented.") def reset_index(self, level=None, drop=False, name=None, inplace=False): @@ -648,8 +799,17 @@ def rmod(self, other, level=None, fill_value=None, axis=0): def rmul(self, other, level=None, fill_value=None, axis=0): raise NotImplementedError("Not Yet implemented.") - def rolling(self, window, min_periods=None, freq=None, center=False, - win_type=None, on=None, axis=0, closed=None): + def rolling( + self, + window, + min_periods=None, + freq=None, + center=False, + win_type=None, + on=None, + axis=0, + closed=None + ): raise NotImplementedError("Not Yet implemented.") def round(self, decimals=0, *args, **kwargs): @@ -664,8 +824,15 @@ def rsub(self, other, level=None, fill_value=None, axis=0): def rtruediv(self, other, level=None, fill_value=None, 
axis=0): raise NotImplementedError("Not Yet implemented.") - def sample(self, n=None, frac=None, replace=False, weights=None, - random_state=None, axis=None): + def sample( + self, + n=None, + frac=None, + replace=False, + weights=None, + random_state=None, + axis=None + ): raise NotImplementedError("Not Yet implemented.") def searchsorted(self, value, side='left', sorter=None): @@ -674,8 +841,15 @@ def searchsorted(self, value, side='left', sorter=None): def select(self, crit, axis=0): raise NotImplementedError("Not Yet implemented.") - def sem(self, axis=None, skipna=None, level=None, ddof=1, - numeric_only=None, **kwargs): + def sem( + self, + axis=None, + skipna=None, + level=None, + ddof=1, + numeric_only=None, + **kwargs + ): raise NotImplementedError("Not Yet implemented.") def set_axis(self, axis, labels): @@ -687,19 +861,34 @@ def set_value(self, label, value, takeable=False): def shift(self, periods=1, freq=None, axis=0): raise NotImplementedError("Not Yet implemented.") - def skew(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def skew( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") def slice_shift(self, periods=1, axis=0): raise NotImplementedError("Not Yet implemented.") - def sort_index(self, axis=0, level=None, ascending=True, inplace=False, - kind='quicksort', na_position='last', sort_remaining=True): + def sort_index( + self, + axis=0, + level=None, + ascending=True, + inplace=False, + kind='quicksort', + na_position='last', + sort_remaining=True + ): raise NotImplementedError("Not Yet implemented.") - def sort_values(self, axis=0, ascending=True, inplace=False, - kind='quicksort', na_position='last'): + def sort_values( + self, + axis=0, + ascending=True, + inplace=False, + kind='quicksort', + na_position='last' + ): raise NotImplementedError("Not Yet implemented.") def sortlevel(self, level=0, ascending=True, sort_remaining=True): @@ -708,8 +897,15 @@ def sortlevel(self, level=0, ascending=True, sort_remaining=True): def squeeze(self, axis=None): raise NotImplementedError("Not Yet implemented.") - def std(self, axis=None, skipna=None, level=None, ddof=1, - numeric_only=None, **kwargs): + def std( + self, + axis=None, + skipna=None, + level=None, + ddof=1, + numeric_only=None, + **kwargs + ): raise NotImplementedError("Not Yet implemented.") def sub(self, other, level=None, fill_value=None, axis=0): @@ -718,8 +914,9 @@ def sub(self, other, level=None, fill_value=None, axis=0): def subtract(self, other, level=None, fill_value=None, axis=0): raise NotImplementedError("Not Yet implemented.") - def sum(self, axis=None, skipna=None, level=None, numeric_only=None, - **kwargs): + def sum( + self, axis=None, skipna=None, level=None, numeric_only=None, **kwargs + ): raise NotImplementedError("Not Yet implemented.") def swapaxes(self, axis1, axis2, copy=True): @@ -737,9 +934,20 @@ def take(self, indices, axis=0, convert=True, is_copy=False, **kwargs): def to_clipboard(self, excel=None, sep=None, **kwargs): raise NotImplementedError("Not Yet implemented.") - def to_csv(self, path=None, index=True, sep=',', na_rep='', - float_format=None, header=False, index_label=None, mode='w', - encoding=None, date_format=None, decimal='.'): + def to_csv( + self, + path=None, + index=True, + sep=',', + na_rep='', + float_format=None, + header=False, + index_label=None, + mode='w', + encoding=None, + date_format=None, + decimal='.' 
+ ): raise NotImplementedError("Not Yet implemented.") def to_dense(self): @@ -748,11 +956,24 @@ def to_dense(self): def to_dict(self): raise NotImplementedError("Not Yet implemented.") - def to_excel(self, excel_writer, sheet_name='Sheet1', na_rep='', - float_format=None, columns=None, header=True, index=True, - index_label=None, startrow=0, startcol=0, engine=None, - merge_cells=True, encoding=None, inf_rep='inf', - verbose=True): + def to_excel( + self, + excel_writer, + sheet_name='Sheet1', + na_rep='', + float_format=None, + columns=None, + header=True, + index=True, + index_label=None, + startrow=0, + startcol=0, + engine=None, + merge_cells=True, + encoding=None, + inf_rep='inf', + verbose=True + ): raise NotImplementedError("Not Yet implemented.") def to_frame(self, name=None): @@ -761,17 +982,41 @@ def to_frame(self, name=None): def to_hdf(self, path_or_buf, key, **kwargs): raise NotImplementedError("Not Yet implemented.") - def to_json(self, path_or_buf=None, orient=None, date_format=None, - double_precision=10, force_ascii=True, date_unit='ms', - default_handler=None, lines=False): - raise NotImplementedError("Not Yet implemented.") - - def to_latex(self, buf=None, columns=None, col_space=None, header=True, - index=True, na_rep='NaN', formatters=None, float_format=None, - sparsify=None, index_names=True, bold_rows=False, - column_format=None, longtable=None, escape=None, - encoding=None, decimal='.', multicolumn=None, - multicolumn_format=None, multirow=None): + def to_json( + self, + path_or_buf=None, + orient=None, + date_format=None, + double_precision=10, + force_ascii=True, + date_unit='ms', + default_handler=None, + lines=False + ): + raise NotImplementedError("Not Yet implemented.") + + def to_latex( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep='NaN', + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + bold_rows=False, + column_format=None, + longtable=None, + escape=None, + encoding=None, + decimal='.', + multicolumn=None, + multicolumn_format=None, + multirow=None + ): raise NotImplementedError("Not Yet implemented.") def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs): @@ -786,13 +1031,32 @@ def to_pickle(self, path, compression='infer'): def to_sparse(self, kind='block', fill_value=None): raise NotImplementedError("Not Yet implemented.") - def to_sql(self, name, con, flavor=None, schema=None, if_exists='fail', - index=True, index_label=None, chunksize=None, dtype=None): - raise NotImplementedError("Not Yet implemented.") - - def to_string(self, buf=None, na_rep='NaN', float_format=None, - header=True, index=True, length=False, dtype=False, - name=False, max_rows=None): + def to_sql( + self, + name, + con, + flavor=None, + schema=None, + if_exists='fail', + index=True, + index_label=None, + chunksize=None, + dtype=None + ): + raise NotImplementedError("Not Yet implemented.") + + def to_string( + self, + buf=None, + na_rep='NaN', + float_format=None, + header=True, + index=True, + length=False, + dtype=False, + name=False, + max_rows=None + ): raise NotImplementedError("Not Yet implemented.") def to_timestamp(self, freq=None, how='start', copy=True): @@ -822,8 +1086,7 @@ def tshift(self, periods=1, freq=None, axis=0): def tz_convert(self, tz, axis=0, level=None, copy=True): raise NotImplementedError("Not Yet implemented.") - def tz_localize(self, tz, axis=0, level=None, copy=True, - ambiguous='raise'): + def tz_localize(self, tz, axis=0, level=None, copy=True, ambiguous='raise'): 
raise NotImplementedError("Not Yet implemented.") def unique(self): @@ -838,19 +1101,40 @@ def update(self, other): def valid(self, inplace=False, **kwargs): raise NotImplementedError("Not Yet implemented.") - def value_counts(self, normalize=False, sort=True, ascending=False, - bins=None, dropna=True): + def value_counts( + self, + normalize=False, + sort=True, + ascending=False, + bins=None, + dropna=True + ): raise NotImplementedError("Not Yet implemented.") - def var(self, axis=None, skipna=None, level=None, ddof=1, - numeric_only=None, **kwargs): + def var( + self, + axis=None, + skipna=None, + level=None, + ddof=1, + numeric_only=None, + **kwargs + ): raise NotImplementedError("Not Yet implemented.") def view(self, dtype=None): raise NotImplementedError("Not Yet implemented.") - def where(self, cond, other=np.nan, inplace=False, axis=None, level=None, - try_cast=False, raise_on_error=True): + def where( + self, + cond, + other=np.nan, + inplace=False, + axis=None, + level=None, + try_cast=False, + raise_on_error=True + ): raise NotImplementedError("Not Yet implemented.") def xs(key, axis=0, level=None, drop_level=True): diff --git a/python/ray/dataframe/test/test_concat.py b/python/ray/dataframe/test/test_concat.py index 62e881d05b73..18dc0052d2a3 100644 --- a/python/ray/dataframe/test/test_concat.py +++ b/python/ray/dataframe/test/test_concat.py @@ -5,10 +5,7 @@ import pytest import pandas import ray.dataframe as pd -from ray.dataframe.utils import ( - to_pandas, - from_pandas -) +from ray.dataframe.utils import (to_pandas, from_pandas) @pytest.fixture @@ -18,33 +15,41 @@ def ray_df_equals_pandas(ray_df, pandas_df): @pytest.fixture def generate_dfs(): - df = pandas.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, 6, 7], - 'col3': [8, 9, 10, 11], - 'col4': [12, 13, 14, 15], - 'col5': [0, 0, 0, 0]}) - - df2 = pandas.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, 6, 7], - 'col3': [8, 9, 10, 11], - 'col6': [12, 13, 14, 15], - 'col7': [0, 0, 0, 0]}) + df = pandas.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, 6, 7], + 'col3': [8, 9, 10, 11], + 'col4': [12, 13, 14, 15], + 'col5': [0, 0, 0, 0] + }) + + df2 = pandas.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, 6, 7], + 'col3': [8, 9, 10, 11], + 'col6': [12, 13, 14, 15], + 'col7': [0, 0, 0, 0] + }) return df, df2 @pytest.fixture def generate_none_dfs(): - df = pandas.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, None, 7], - 'col3': [8, 9, 10, 11], - 'col4': [12, 13, 14, 15], - 'col5': [None, None, None, None]}) - - df2 = pandas.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, 6, 7], - 'col3': [8, 9, 10, 11], - 'col6': [12, 13, 14, 15], - 'col7': [0, 0, 0, 0]}) + df = pandas.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, None, 7], + 'col3': [8, 9, 10, 11], + 'col4': [12, 13, 14, 15], + 'col5': [None, None, None, None] + }) + + df2 = pandas.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, 6, 7], + 'col3': [8, 9, 10, 11], + 'col6': [12, 13, 14, 15], + 'col7': [0, 0, 0, 0] + }) return df, df2 @@ -52,41 +57,51 @@ def generate_none_dfs(): def test_df_concat(): df, df2 = generate_dfs() - assert(ray_df_equals_pandas(pd.concat([df, df2]), - pandas.concat([df, df2]))) + assert ( + ray_df_equals_pandas(pd.concat([df, df2]), pandas.concat([df, df2])) + ) def test_ray_concat(): df, df2 = generate_dfs() ray_df, ray_df2 = from_pandas(df, 2), from_pandas(df2, 2) - assert ray_df_equals_pandas(pd.concat([ray_df, ray_df2]), - pandas.concat([df, df2])) + assert ray_df_equals_pandas( + pd.concat([ray_df, ray_df2]), 
pandas.concat([df, df2]) + ) def test_ray_concat_on_index(): df, df2 = generate_dfs() ray_df, ray_df2 = from_pandas(df, 2), from_pandas(df2, 2) - assert ray_df_equals_pandas(pd.concat([ray_df, ray_df2], axis='index'), - pandas.concat([df, df2], axis='index')) + assert ray_df_equals_pandas( + pd.concat([ray_df, ray_df2], axis='index'), + pandas.concat([df, df2], axis='index') + ) - assert ray_df_equals_pandas(pd.concat([ray_df, ray_df2], axis='rows'), - pandas.concat([df, df2], axis='rows')) + assert ray_df_equals_pandas( + pd.concat([ray_df, ray_df2], axis='rows'), + pandas.concat([df, df2], axis='rows') + ) - assert ray_df_equals_pandas(pd.concat([ray_df, ray_df2], axis=0), - pandas.concat([df, df2], axis=0)) + assert ray_df_equals_pandas( + pd.concat([ray_df, ray_df2], axis=0), pandas.concat([df, df2], axis=0) + ) def test_ray_concat_on_column(): df, df2 = generate_dfs() ray_df, ray_df2 = from_pandas(df, 2), from_pandas(df2, 2) - assert ray_df_equals_pandas(pd.concat([ray_df, ray_df2], axis=1), - pandas.concat([df, df2], axis=1)) + assert ray_df_equals_pandas( + pd.concat([ray_df, ray_df2], axis=1), pandas.concat([df, df2], axis=1) + ) - assert ray_df_equals_pandas(pd.concat([ray_df, ray_df2], axis="columns"), - pandas.concat([df, df2], axis="columns")) + assert ray_df_equals_pandas( + pd.concat([ray_df, ray_df2], axis="columns"), + pandas.concat([df, df2], axis="columns") + ) def test_invalid_axis_errors(): @@ -103,8 +118,11 @@ def test_mixed_concat(): mixed_dfs = [from_pandas(df, 2), from_pandas(df2, 2), df3] - assert(ray_df_equals_pandas(pd.concat(mixed_dfs), - pandas.concat([df, df2, df3]))) + assert ( + ray_df_equals_pandas( + pd.concat(mixed_dfs), pandas.concat([df, df2, df3]) + ) + ) def test_mixed_inner_concat(): @@ -113,8 +131,12 @@ def test_mixed_inner_concat(): mixed_dfs = [from_pandas(df, 2), from_pandas(df2, 2), df3] - assert(ray_df_equals_pandas(pd.concat(mixed_dfs, join='inner'), - pandas.concat([df, df2, df3], join='inner'))) + assert ( + ray_df_equals_pandas( + pd.concat(mixed_dfs, join='inner'), + pandas.concat([df, df2, df3], join='inner') + ) + ) def test_mixed_none_concat(): @@ -123,5 +145,8 @@ def test_mixed_none_concat(): mixed_dfs = [from_pandas(df, 2), from_pandas(df2, 2), df3] - assert(ray_df_equals_pandas(pd.concat(mixed_dfs), - pandas.concat([df, df2, df3]))) + assert ( + ray_df_equals_pandas( + pd.concat(mixed_dfs), pandas.concat([df, df2, df3]) + ) + ) diff --git a/python/ray/dataframe/test/test_dataframe.py b/python/ray/dataframe/test/test_dataframe.py index 60d2862d9cf9..05a4edbaef48 100644 --- a/python/ray/dataframe/test/test_dataframe.py +++ b/python/ray/dataframe/test/test_dataframe.py @@ -7,9 +7,7 @@ import pandas as pd import pandas.util.testing as tm import ray.dataframe as rdf -from ray.dataframe.utils import ( - from_pandas, - to_pandas) +from ray.dataframe.utils import (from_pandas, to_pandas) from pandas.tests.frame.common import TestData @@ -33,38 +31,38 @@ def ray_df_equals(ray_df1, ray_df2): @pytest.fixture def test_roundtrip(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df, pandas_df)) + assert (ray_df_equals_pandas(ray_df, pandas_df)) @pytest.fixture def test_index(ray_df, pandas_df): - assert(ray_df.index.equals(pandas_df.index)) + assert (ray_df.index.equals(pandas_df.index)) ray_df_cp = ray_df.copy() pandas_df_cp = pandas_df.copy() ray_df_cp.index = [str(i) for i in ray_df_cp.index] pandas_df_cp.index = [str(i) for i in pandas_df_cp.index] - assert(ray_df_cp.index.sort_values().equals(pandas_df_cp.index)) + assert 
(ray_df_cp.index.sort_values().equals(pandas_df_cp.index)) @pytest.fixture def test_size(ray_df, pandas_df): - assert(ray_df.size == pandas_df.size) + assert (ray_df.size == pandas_df.size) @pytest.fixture def test_ndim(ray_df, pandas_df): - assert(ray_df.ndim == pandas_df.ndim) + assert (ray_df.ndim == pandas_df.ndim) @pytest.fixture def test_ftypes(ray_df, pandas_df): - assert(ray_df.ftypes.equals(pandas_df.ftypes)) + assert (ray_df.ftypes.equals(pandas_df.ftypes)) @pytest.fixture def test_dtypes(ray_df, pandas_df): - assert(ray_df.dtypes.equals(pandas_df.dtypes)) + assert (ray_df.dtypes.equals(pandas_df.dtypes)) @pytest.fixture @@ -80,7 +78,7 @@ def test_axes(ray_df, pandas_df): @pytest.fixture def test_shape(ray_df, pandas_df): - assert(ray_df.shape == pandas_df.shape) + assert (ray_df.shape == pandas_df.shape) @pytest.fixture @@ -88,7 +86,7 @@ def test_add_prefix(ray_df, pandas_df): test_prefix = "TEST" new_ray_df = ray_df.add_prefix(test_prefix) new_pandas_df = pandas_df.add_prefix(test_prefix) - assert(new_ray_df.columns.equals(new_pandas_df.columns)) + assert (new_ray_df.columns.equals(new_pandas_df.columns)) @pytest.fixture @@ -97,7 +95,7 @@ def test_add_suffix(ray_df, pandas_df): new_ray_df = ray_df.add_suffix(test_suffix) new_pandas_df = pandas_df.add_suffix(test_suffix) - assert(new_ray_df.columns.equals(new_pandas_df.columns)) + assert (new_ray_df.columns.equals(new_pandas_df.columns)) @pytest.fixture @@ -105,7 +103,7 @@ def test_applymap(ray_df, pandas_df, testfunc): new_ray_df = ray_df.applymap(testfunc) new_pandas_df = pandas_df.applymap(testfunc) - assert(ray_df_equals_pandas(new_ray_df, new_pandas_df)) + assert (ray_df_equals_pandas(new_ray_df, new_pandas_df)) @pytest.fixture @@ -113,82 +111,85 @@ def test_copy(ray_df): new_ray_df = ray_df.copy() assert new_ray_df is not ray_df - assert np.array_equal(new_ray_df._block_partitions, - ray_df._block_partitions) + assert np.array_equal( + new_ray_df._block_partitions, ray_df._block_partitions + ) @pytest.fixture def test_sum(ray_df, pandas_df): - assert(ray_df.sum().sort_index().equals(pandas_df.sum().sort_index())) + assert (ray_df.sum().sort_index().equals(pandas_df.sum().sort_index())) @pytest.fixture def test_abs(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df.abs(), pandas_df.abs())) + assert (ray_df_equals_pandas(ray_df.abs(), pandas_df.abs())) @pytest.fixture def test_keys(ray_df, pandas_df): - assert(ray_df.keys().equals(pandas_df.keys())) + assert (ray_df.keys().equals(pandas_df.keys())) @pytest.fixture def test_transpose(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df.T, pandas_df.T)) - assert(ray_df_equals_pandas(ray_df.transpose(), pandas_df.transpose())) + assert (ray_df_equals_pandas(ray_df.T, pandas_df.T)) + assert (ray_df_equals_pandas(ray_df.transpose(), pandas_df.transpose())) @pytest.fixture def test_get(ray_df, pandas_df, key): - assert(ray_df.get(key).equals(pandas_df.get(key))) + assert (ray_df.get(key).equals(pandas_df.get(key))) assert ray_df.get( - key, default='default').equals( - pandas_df.get(key, default='default')) + key, default='default' + ).equals(pandas_df.get(key, default='default')) @pytest.fixture def test_get_dtype_counts(ray_df, pandas_df): - assert(ray_df.get_dtype_counts().equals(pandas_df.get_dtype_counts())) + assert (ray_df.get_dtype_counts().equals(pandas_df.get_dtype_counts())) @pytest.fixture def test_get_ftype_counts(ray_df, pandas_df): - assert(ray_df.get_ftype_counts().equals(pandas_df.get_ftype_counts())) + assert 
(ray_df.get_ftype_counts().equals(pandas_df.get_ftype_counts())) @pytest.fixture def create_test_dataframe(): - df = pd.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, 6, 7], - 'col3': [8, 9, 10, 11], - 'col4': [12, 13, 14, 15], - 'col5': [0, 0, 0, 0]}) + df = pd.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, 6, 7], + 'col3': [8, 9, 10, 11], + 'col4': [12, 13, 14, 15], + 'col5': [0, 0, 0, 0] + }) return from_pandas(df, 2) def test_int_dataframe(): - pandas_df = pd.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, 6, 7], - 'col3': [8, 9, 10, 11], - 'col4': [12, 13, 14, 15], - 'col5': [0, 0, 0, 0]}) + pandas_df = pd.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, 6, 7], + 'col3': [8, 9, 10, 11], + 'col4': [12, 13, 14, 15], + 'col5': [0, 0, 0, 0] + }) ray_df = from_pandas(pandas_df, 2) - testfuncs = [lambda x: x + 1, - lambda x: str(x), - lambda x: x * x, - lambda x: x, - lambda x: False] + testfuncs = [ + lambda x: x + 1, lambda x: str(x), lambda x: x * x, lambda x: x, + lambda x: False + ] - query_funcs = ['col1 < col2', 'col3 > col4', 'col1 == col2', - '(col2 > col1) and (col1 < col3)'] + query_funcs = [ + 'col1 < col2', 'col3 > col4', 'col1 == col2', + '(col2 > col1) and (col1 < col3)' + ] - keys = ['col1', - 'col2', - 'col3', - 'col4'] + keys = ['col1', 'col2', 'col3', 'col4'] test_roundtrip(ray_df, pandas_df) test_index(ray_df, pandas_df) @@ -291,8 +292,9 @@ def test_int_dataframe(): test___array__(ray_df, pandas_df) - apply_agg_functions = ['sum', lambda df: df.sum(), ['sum', 'mean'], - ['sum', 'sum']] + apply_agg_functions = [ + 'sum', lambda df: df.sum(), ['sum', 'mean'], ['sum', 'sum'] + ] for func in apply_agg_functions: test_apply(ray_df, pandas_df, func, 0) test_aggregate(ray_df, pandas_df, func, 0) @@ -328,27 +330,27 @@ def test_int_dataframe(): def test_float_dataframe(): - pandas_df = pd.DataFrame({'col1': [0.0, 1.0, 2.0, 3.0], - 'col2': [4.0, 5.0, 6.0, 7.0], - 'col3': [8.0, 9.0, 10.0, 11.0], - 'col4': [12.0, 13.0, 14.0, 15.0], - 'col5': [0.0, 0.0, 0.0, 0.0]}) + pandas_df = pd.DataFrame({ + 'col1': [0.0, 1.0, 2.0, 3.0], + 'col2': [4.0, 5.0, 6.0, 7.0], + 'col3': [8.0, 9.0, 10.0, 11.0], + 'col4': [12.0, 13.0, 14.0, 15.0], + 'col5': [0.0, 0.0, 0.0, 0.0] + }) ray_df = from_pandas(pandas_df, 3) - testfuncs = [lambda x: x + 1, - lambda x: str(x), - lambda x: x * x, - lambda x: x, - lambda x: False] + testfuncs = [ + lambda x: x + 1, lambda x: str(x), lambda x: x * x, lambda x: x, + lambda x: False + ] - query_funcs = ['col1 < col2', 'col3 > col4', 'col1 == col2', - '(col2 > col1) and (col1 < col3)'] + query_funcs = [ + 'col1 < col2', 'col3 > col4', 'col1 == col2', + '(col2 > col1) and (col1 < col3)' + ] - keys = ['col1', - 'col2', - 'col3', - 'col4'] + keys = ['col1', 'col2', 'col3', 'col4'] test_roundtrip(ray_df, pandas_df) test_index(ray_df, pandas_df) @@ -453,8 +455,9 @@ def test_float_dataframe(): # TODO Nans are always not equal to each other, fix it # test___array__(ray_df, pandas_df) - apply_agg_functions = ['sum', lambda df: df.sum(), ['sum', 'mean'], - ['sum', 'sum']] + apply_agg_functions = [ + 'sum', lambda df: df.sum(), ['sum', 'mean'], ['sum', 'sum'] + ] for func in apply_agg_functions: test_apply(ray_df, pandas_df, func, 0) test_aggregate(ray_df, pandas_df, func, 0) @@ -493,22 +496,20 @@ def test_mixed_dtype_dataframe(): 'col1': [1, 2, 3, 4], 'col2': [4, 5, 6, 7], 'col3': [8.0, 9.4, 10.1, 11.3], - 'col4': ['a', 'b', 'c', 'd']}) + 'col4': ['a', 'b', 'c', 'd'] + }) ray_df = from_pandas(pandas_df, 2) - testfuncs = [lambda x: x + x, - lambda x: str(x), - 
lambda x: x, - lambda x: False] + testfuncs = [ + lambda x: x + x, lambda x: str(x), lambda x: x, lambda x: False + ] - query_funcs = ['col1 < col2', 'col1 == col2', - '(col2 > col1) and (col1 < col3)'] + query_funcs = [ + 'col1 < col2', 'col1 == col2', '(col2 > col1) and (col1 < col3)' + ] - keys = ['col1', - 'col2', - 'col3', - 'col4'] + keys = ['col1', 'col2', 'col3', 'col4'] test_roundtrip(ray_df, pandas_df) test_index(ray_df, pandas_df) @@ -653,22 +654,21 @@ def test_nan_dataframe(): 'col1': [1, 2, 3, np.nan], 'col2': [4, 5, np.nan, 7], 'col3': [8, np.nan, 10, 11], - 'col4': [np.nan, 13, 14, 15]}) + 'col4': [np.nan, 13, 14, 15] + }) ray_df = from_pandas(pandas_df, 2) - testfuncs = [lambda x: x + x, - lambda x: str(x), - lambda x: x, - lambda x: False] + testfuncs = [ + lambda x: x + x, lambda x: str(x), lambda x: x, lambda x: False + ] - query_funcs = ['col1 < col2', 'col3 > col4', 'col1 == col2', - '(col2 > col1) and (col1 < col3)'] + query_funcs = [ + 'col1 < col2', 'col3 > col4', 'col1 == col2', + '(col2 > col1) and (col1 < col3)' + ] - keys = ['col1', - 'col2', - 'col3', - 'col4'] + keys = ['col1', 'col2', 'col3', 'col4'] test_roundtrip(ray_df, pandas_df) test_index(ray_df, pandas_df) @@ -771,8 +771,9 @@ def test_nan_dataframe(): # TODO Nans are always not equal to each other, fix it # test___array__(ray_df, pandas_df) - apply_agg_functions = ['sum', lambda df: df.sum(), ['sum', 'mean'], - ['sum', 'sum']] + apply_agg_functions = [ + 'sum', lambda df: df.sum(), ['sum', 'mean'], ['sum', 'sum'] + ] for func in apply_agg_functions: test_apply(ray_df, pandas_df, func, 0) test_aggregate(ray_df, pandas_df, func, 0) @@ -808,69 +809,99 @@ def test_nan_dataframe(): @pytest.fixture def test_inter_df_math(op, simple=False): - ray_df = rdf.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + ray_df = rdf.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) - pandas_df = pd.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + pandas_df = pd.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) - ray_df_equals_pandas(getattr(ray_df, op)(ray_df), - getattr(pandas_df, op)(pandas_df)) - ray_df_equals_pandas(getattr(ray_df, op)(4), - getattr(pandas_df, op)(4)) - ray_df_equals_pandas(getattr(ray_df, op)(4.0), - getattr(pandas_df, op)(4.0)) + ray_df_equals_pandas( + getattr(ray_df, op)(ray_df), + getattr(pandas_df, op)(pandas_df) + ) + ray_df_equals_pandas(getattr(ray_df, op)(4), getattr(pandas_df, op)(4)) + ray_df_equals_pandas(getattr(ray_df, op)(4.0), getattr(pandas_df, op)(4.0)) ray_df2 = rdf.DataFrame({"A": [0, 2], "col1": [0, 19], "col2": [1, 1]}) pandas_df2 = pd.DataFrame({"A": [0, 2], "col1": [0, 19], "col2": [1, 1]}) - ray_df_equals_pandas(getattr(ray_df, op)(ray_df2), - getattr(pandas_df, op)(pandas_df2)) + ray_df_equals_pandas( + getattr(ray_df, op)(ray_df2), + getattr(pandas_df, op)(pandas_df2) + ) list_test = [0, 1, 2, 4] if not simple: - ray_df_equals_pandas(getattr(ray_df, op)(list_test, axis=1), - getattr(pandas_df, op)(list_test, axis=1)) + ray_df_equals_pandas( + getattr(ray_df, op)(list_test, axis=1), + getattr(pandas_df, op)(list_test, axis=1) + ) - ray_df_equals_pandas(getattr(ray_df, op)(list_test, axis=0), - getattr(pandas_df, op)(list_test, axis=0)) + ray_df_equals_pandas( + getattr(ray_df, op)(list_test, axis=0), + getattr(pandas_df, op)(list_test, axis=0) + ) 
@pytest.fixture def test_comparison_inter_ops(op): - ray_df = rdf.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + ray_df = rdf.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) - pandas_df = pd.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + pandas_df = pd.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) - ray_df_equals_pandas(getattr(ray_df, op)(ray_df), - getattr(pandas_df, op)(pandas_df)) - ray_df_equals_pandas(getattr(ray_df, op)(4), - getattr(pandas_df, op)(4)) - ray_df_equals_pandas(getattr(ray_df, op)(4.0), - getattr(pandas_df, op)(4.0)) + ray_df_equals_pandas( + getattr(ray_df, op)(ray_df), + getattr(pandas_df, op)(pandas_df) + ) + ray_df_equals_pandas(getattr(ray_df, op)(4), getattr(pandas_df, op)(4)) + ray_df_equals_pandas(getattr(ray_df, op)(4.0), getattr(pandas_df, op)(4.0)) ray_df2 = rdf.DataFrame({"A": [0, 2], "col1": [0, 19], "col2": [1, 1]}) pandas_df2 = pd.DataFrame({"A": [0, 2], "col1": [0, 19], "col2": [1, 1]}) - ray_df_equals_pandas(getattr(ray_df, op)(ray_df2), - getattr(pandas_df, op)(pandas_df2)) + ray_df_equals_pandas( + getattr(ray_df, op)(ray_df2), + getattr(pandas_df, op)(pandas_df2) + ) @pytest.fixture def test_inter_df_math_right_ops(op): - ray_df = rdf.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + ray_df = rdf.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) - pandas_df = pd.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + pandas_df = pd.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) - ray_df_equals_pandas(getattr(ray_df, op)(4), - getattr(pandas_df, op)(4)) - ray_df_equals_pandas(getattr(ray_df, op)(4.0), - getattr(pandas_df, op)(4.0)) + ray_df_equals_pandas(getattr(ray_df, op)(4), getattr(pandas_df, op)(4)) + ray_df_equals_pandas(getattr(ray_df, op)(4.0), getattr(pandas_df, op)(4.0)) def test_add(): @@ -917,11 +948,19 @@ def test_any(ray_df, pd_df): def test_append(): - ray_df = rdf.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + ray_df = rdf.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) - pandas_df = pd.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + pandas_df = pd.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) ray_df2 = rdf.DataFrame({"col5": [0], "col6": [1]}) @@ -929,8 +968,9 @@ def test_append(): print(ray_df.append(ray_df2)) - assert ray_df_equals_pandas(ray_df.append(ray_df2), - pandas_df.append(pandas_df2)) + assert ray_df_equals_pandas( + ray_df.append(ray_df2), pandas_df.append(pandas_df2) + ) with pytest.raises(ValueError): ray_df.append(ray_df2, verify_integrity=True) @@ -1008,10 +1048,7 @@ def test_bfill(num_partitions=2): test_data.tsframe['A'][:5] = np.nan test_data.tsframe['A'][-5:] = np.nan ray_df = from_pandas(test_data.tsframe, num_partitions) - assert ray_df_equals_pandas( - ray_df.bfill(), - test_data.tsframe.bfill() - ) + assert ray_df_equals_pandas(ray_df.bfill(), test_data.tsframe.bfill()) @pytest.fixture @@ -1118,27 +1155,27 
@@ def test_cov(): @pytest.fixture def test_cummax(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df.cummax(), pandas_df.cummax())) + assert (ray_df_equals_pandas(ray_df.cummax(), pandas_df.cummax())) @pytest.fixture def test_cummin(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df.cummin(), pandas_df.cummin())) + assert (ray_df_equals_pandas(ray_df.cummin(), pandas_df.cummin())) @pytest.fixture def test_cumprod(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df.cumprod(), pandas_df.cumprod())) + assert (ray_df_equals_pandas(ray_df.cumprod(), pandas_df.cumprod())) @pytest.fixture def test_cumsum(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df.cumsum(), pandas_df.cumsum())) + assert (ray_df_equals_pandas(ray_df.cumsum(), pandas_df.cumsum())) @pytest.fixture def test_describe(ray_df, pandas_df): - assert(ray_df.describe().equals(pandas_df.describe())) + assert (ray_df.describe().equals(pandas_df.describe())) def test_diff(): @@ -1168,12 +1205,15 @@ def test_drop(): simple = pd.DataFrame({"A": [1, 2, 3, 4], "B": [0, 1, 2, 3]}) ray_simple = from_pandas(simple, 2) assert ray_df_equals_pandas(ray_simple.drop("A", axis=1), simple[['B']]) - assert ray_df_equals_pandas(ray_simple.drop(["A", "B"], axis='columns'), - simple[[]]) - assert ray_df_equals_pandas(ray_simple.drop([0, 1, 3], axis=0), - simple.loc[[2], :]) - assert ray_df_equals_pandas(ray_simple.drop([0, 3], axis='index'), - simple.loc[[1, 2], :]) + assert ray_df_equals_pandas( + ray_simple.drop(["A", "B"], axis='columns'), simple[[]] + ) + assert ray_df_equals_pandas( + ray_simple.drop([0, 1, 3], axis=0), simple.loc[[2], :] + ) + assert ray_df_equals_pandas( + ray_simple.drop([0, 3], axis='index'), simple.loc[[1, 2], :] + ) pytest.raises(ValueError, ray_simple.drop, 5) pytest.raises(ValueError, ray_simple.drop, 'C', 1) @@ -1182,30 +1222,35 @@ def test_drop(): # errors = 'ignore' assert ray_df_equals_pandas(ray_simple.drop(5, errors='ignore'), simple) - assert ray_df_equals_pandas(ray_simple.drop([0, 5], errors='ignore'), - simple.loc[[1, 2, 3], :]) - assert ray_df_equals_pandas(ray_simple.drop('C', axis=1, errors='ignore'), - simple) - assert ray_df_equals_pandas(ray_simple.drop(['A', 'C'], axis=1, - errors='ignore'), - simple[['B']]) + assert ray_df_equals_pandas( + ray_simple.drop([0, 5], errors='ignore'), simple.loc[[1, 2, 3], :] + ) + assert ray_df_equals_pandas( + ray_simple.drop('C', axis=1, errors='ignore'), simple + ) + assert ray_df_equals_pandas( + ray_simple.drop(['A', 'C'], axis=1, errors='ignore'), simple[['B']] + ) # non-unique - wheee! 
-    nu_df = pd.DataFrame(pd.compat.lzip(range(3), range(-3, 1), list('abc')),
-                         columns=['a', 'a', 'b'])
+    nu_df = pd.DataFrame(
+        pd.compat.lzip(range(3), range(-3, 1), list('abc')),
+        columns=['a', 'a', 'b']
+    )
     ray_nu_df = from_pandas(nu_df, 3)
     assert ray_df_equals_pandas(ray_nu_df.drop('a', axis=1), nu_df[['b']])
-    assert ray_df_equals_pandas(ray_nu_df.drop('b', axis='columns'),
-                                nu_df['a'])
+    assert ray_df_equals_pandas(ray_nu_df.drop('b', axis='columns'), nu_df['a'])
     assert ray_df_equals_pandas(ray_nu_df.drop([]), nu_df)
 
     # GH 16398
     nu_df = nu_df.set_index(pd.Index(['X', 'Y', 'X']))
     nu_df.columns = list('abc')
     ray_nu_df = from_pandas(nu_df, 3)
-    assert ray_df_equals_pandas(ray_nu_df.drop('X', axis='rows'),
-                                nu_df.loc[["Y"], :])
-    assert ray_df_equals_pandas(ray_nu_df.drop(['X', 'Y'], axis=0),
-                                nu_df.loc[[], :])
+    assert ray_df_equals_pandas(
+        ray_nu_df.drop('X', axis='rows'), nu_df.loc[["Y"], :]
+    )
+    assert ray_df_equals_pandas(
+        ray_nu_df.drop(['X', 'Y'], axis=0), nu_df.loc[[], :]
+    )
 
     # inplace cache issue
     # GH 5628
@@ -1272,15 +1317,13 @@ def test_eq():
 
 
 def test_equals():
-    pandas_df1 = pd.DataFrame({'col1': [2.9, 3, 3, 3],
-                               'col2': [2, 3, 4, 1]})
+    pandas_df1 = pd.DataFrame({'col1': [2.9, 3, 3, 3], 'col2': [2, 3, 4, 1]})
     ray_df1 = from_pandas(pandas_df1, 2)
     ray_df2 = from_pandas(pandas_df1, 3)
     assert ray_df1.equals(ray_df2)
 
-    pandas_df2 = pd.DataFrame({'col1': [2.9, 3, 3, 3],
-                               'col2': [2, 3, 5, 1]})
+    pandas_df2 = pd.DataFrame({'col1': [2.9, 3, 3, 3], 'col2': [2, 3, 5, 1]})
     ray_df3 = from_pandas(pandas_df2, 4)
 
     assert not ray_df3.equals(ray_df1)
@@ -1288,29 +1331,33 @@ def test_equals():
 
 
 def test_eval_df_use_case():
-    df = pd.DataFrame({'a': np.random.randn(10),
-                       'b': np.random.randn(10)})
+    df = pd.DataFrame({'a': np.random.randn(10), 'b': np.random.randn(10)})
     ray_df = from_pandas(df, 2)
-    df.eval("e = arctan2(sin(a), b)",
-            engine='python',
-            parser='pandas', inplace=True)
-    ray_df.eval("e = arctan2(sin(a), b)",
-                engine='python',
-                parser='pandas', inplace=True)
+    df.eval(
+        "e = arctan2(sin(a), b)",
+        engine='python',
+        parser='pandas',
+        inplace=True
+    )
+    ray_df.eval(
+        "e = arctan2(sin(a), b)",
+        engine='python',
+        parser='pandas',
+        inplace=True
+    )
     # TODO: Use a series equality validator.
     assert ray_df_equals_pandas(ray_df, df)
 
 
 def test_eval_df_arithmetic_subexpression():
-    df = pd.DataFrame({'a': np.random.randn(10),
-                       'b': np.random.randn(10)})
+    df = pd.DataFrame({'a': np.random.randn(10), 'b': np.random.randn(10)})
     ray_df = from_pandas(df, 2)
-    df.eval("not_e = sin(a + b)",
-            engine='python',
-            parser='pandas', inplace=True)
-    ray_df.eval("not_e = sin(a + b)",
-                engine='python',
-                parser='pandas', inplace=True)
+    df.eval(
+        "not_e = sin(a + b)", engine='python', parser='pandas', inplace=True
+    )
+    ray_df.eval(
+        "not_e = sin(a + b)", engine='python', parser='pandas', inplace=True
+    )
     # TODO: Use a series equality validator.
assert ray_df_equals_pandas(ray_df, df) @@ -1336,10 +1383,7 @@ def test_ffill(num_partitions=2): test_data.tsframe['A'][-5:] = np.nan ray_df = from_pandas(test_data.tsframe, num_partitions) - assert ray_df_equals_pandas( - ray_df.ffill(), - test_data.tsframe.ffill() - ) + assert ray_df_equals_pandas(ray_df.ffill(), test_data.tsframe.ffill()) def test_fillna(): @@ -1377,8 +1421,7 @@ def test_fillna_sanity(num_partitions=2): assert ray_df_equals_pandas(ray_df, zero_filled) padded = test_data.tsframe.fillna(method='pad') - ray_df = from_pandas(test_data.tsframe, - num_partitions).fillna(method='pad') + ray_df = from_pandas(test_data.tsframe, num_partitions).fillna(method='pad') assert ray_df_equals_pandas(ray_df, padded) # mixed type @@ -1387,8 +1430,7 @@ def test_fillna_sanity(num_partitions=2): mf.loc[mf.index[-10:], 'A'] = np.nan result = test_data.mixed_frame.fillna(value=0) - ray_df = from_pandas(test_data.mixed_frame, - num_partitions).fillna(value=0) + ray_df = from_pandas(test_data.mixed_frame, num_partitions).fillna(value=0) assert ray_df_equals_pandas(ray_df, result) result = test_data.mixed_frame.fillna(method='pad') @@ -1397,12 +1439,12 @@ def test_fillna_sanity(num_partitions=2): assert ray_df_equals_pandas(ray_df, result) pytest.raises(ValueError, test_data.tsframe.fillna) - pytest.raises(ValueError, from_pandas(test_data.tsframe, - num_partitions).fillna) + pytest.raises( + ValueError, + from_pandas(test_data.tsframe, num_partitions).fillna + ) with pytest.raises(ValueError): - from_pandas(test_data.tsframe, num_partitions).fillna( - 5, method='ffill' - ) + from_pandas(test_data.tsframe, num_partitions).fillna(5, method='ffill') # mixed numeric (but no float16) mf = test_data.mixed_float.reindex(columns=['A', 'B', 'D']) @@ -1424,8 +1466,8 @@ def test_fillna_sanity(num_partitions=2): # df.x.fillna(method=m) # with different dtype (GH3386) - df = pd.DataFrame([['a', 'a', np.nan, 'a'], [ - 'b', 'b', np.nan, 'b'], ['c', 'c', np.nan, 'c']]) + df = pd.DataFrame([['a', 'a', np.nan, 'a'], ['b', 'b', np.nan, 'b'], + ['c', 'c', np.nan, 'c']]) result = df.fillna({2: 'foo'}) ray_df = from_pandas(df, num_partitions).fillna({2: 'foo'}) @@ -1454,9 +1496,7 @@ def test_fillna_sanity(num_partitions=2): 'Date2': [pd.Timestamp("2013-1-1"), pd.NaT] }) result = df.fillna(value={'Date': df['Date2']}) - ray_df = from_pandas(df, num_partitions).fillna( - value={'Date': df['Date2']} - ) + ray_df = from_pandas(df, num_partitions).fillna(value={'Date': df['Date2']}) assert ray_df_equals_pandas(ray_df, result) # TODO: Use this when Arrow issue resolves: @@ -1489,9 +1529,7 @@ def test_fillna_downcast(num_partitions=2): # infer int64 from float64 when fillna value is a dict df = pd.DataFrame({'a': [1., np.nan]}) result = df.fillna({'a': 0}, downcast='infer') - ray_df = from_pandas(df, num_partitions).fillna( - {'a': 0}, downcast='infer' - ) + ray_df = from_pandas(df, num_partitions).fillna({'a': 0}, downcast='infer') assert ray_df_equals_pandas(ray_df, result) @@ -1502,8 +1540,7 @@ def test_ffill2(num_partitions=2): test_data.tsframe['A'][-5:] = np.nan ray_df = from_pandas(test_data.tsframe, num_partitions) assert ray_df_equals_pandas( - ray_df.fillna(method='ffill'), - test_data.tsframe.fillna(method='ffill') + ray_df.fillna(method='ffill'), test_data.tsframe.fillna(method='ffill') ) @@ -1514,8 +1551,7 @@ def test_bfill2(num_partitions=2): test_data.tsframe['A'][-5:] = np.nan ray_df = from_pandas(test_data.tsframe, num_partitions) assert ray_df_equals_pandas( - ray_df.fillna(method='bfill'), - 
test_data.tsframe.fillna(method='bfill') + ray_df.fillna(method='bfill'), test_data.tsframe.fillna(method='bfill') ) @@ -1532,8 +1568,7 @@ def test_fillna_inplace(num_partitions=2): ray_df.fillna(value=0, inplace=True) assert ray_df_equals_pandas(ray_df, df) - ray_df = from_pandas(df, num_partitions).fillna(value={0: 0}, - inplace=True) + ray_df = from_pandas(df, num_partitions).fillna(value={0: 0}, inplace=True) assert ray_df is None df[1][:4] = np.nan @@ -1596,19 +1631,13 @@ def test_fillna_dtype_conversion(num_partitions=2): # empty block df = pd.DataFrame(index=range(3), columns=['A', 'B'], dtype='float64') ray_df = from_pandas(df, num_partitions) - assert ray_df_equals_pandas( - ray_df.fillna('nan'), - df.fillna('nan') - ) + assert ray_df_equals_pandas(ray_df.fillna('nan'), df.fillna('nan')) # equiv of replace df = pd.DataFrame(dict(A=[1, np.nan], B=[1., 2.])) ray_df = from_pandas(df, num_partitions) for v in ['', 1, np.nan, 1.0]: - assert ray_df_equals_pandas( - ray_df.fillna(v), - df.fillna(v) - ) + assert ray_df_equals_pandas(ray_df.fillna(v), df.fillna(v)) @pytest.fixture @@ -1619,57 +1648,66 @@ def test_fillna_skip_certain_blocks(num_partitions=2): ray_df = from_pandas(df, num_partitions) # it works! - assert ray_df_equals_pandas( - ray_df.fillna(np.nan), - df.fillna(np.nan) - ) + assert ray_df_equals_pandas(ray_df.fillna(np.nan), df.fillna(np.nan)) @pytest.fixture def test_fillna_dict_series(num_partitions=2): - df = pd.DataFrame({'a': [np.nan, 1, 2, np.nan, np.nan], - 'b': [1, 2, 3, np.nan, np.nan], - 'c': [np.nan, 1, 2, 3, 4]}) + df = pd.DataFrame({ + 'a': [np.nan, 1, 2, np.nan, np.nan], + 'b': [1, 2, 3, np.nan, np.nan], + 'c': [np.nan, 1, 2, 3, 4] + }) ray_df = from_pandas(df, num_partitions) assert ray_df_equals_pandas( - ray_df.fillna({'a': 0, 'b': 5}), - df.fillna({'a': 0, 'b': 5}) + ray_df.fillna({ + 'a': 0, + 'b': 5 + }), df.fillna({ + 'a': 0, + 'b': 5 + }) ) # it works assert ray_df_equals_pandas( - ray_df.fillna({'a': 0, 'b': 5, 'd': 7}), - df.fillna({'a': 0, 'b': 5, 'd': 7}) + ray_df.fillna({ + 'a': 0, + 'b': 5, + 'd': 7 + }), df.fillna({ + 'a': 0, + 'b': 5, + 'd': 7 + }) ) # Series treated same as dict - assert ray_df_equals_pandas( - ray_df.fillna(df.max()), - df.fillna(df.max()) - ) + assert ray_df_equals_pandas(ray_df.fillna(df.max()), df.fillna(df.max())) @pytest.fixture def test_fillna_dataframe(num_partitions=2): # GH 8377 - df = pd.DataFrame({'a': [np.nan, 1, 2, np.nan, np.nan], - 'b': [1, 2, 3, np.nan, np.nan], - 'c': [np.nan, 1, 2, 3, 4]}, + df = pd.DataFrame({ + 'a': [np.nan, 1, 2, np.nan, np.nan], + 'b': [1, 2, 3, np.nan, np.nan], + 'c': [np.nan, 1, 2, 3, 4] + }, index=list('VWXYZ')) ray_df = from_pandas(df, num_partitions) # df2 may have different index and columns - df2 = pd.DataFrame({'a': [np.nan, 10, 20, 30, 40], - 'b': [50, 60, 70, 80, 90], - 'foo': ['bar'] * 5}, + df2 = pd.DataFrame({ + 'a': [np.nan, 10, 20, 30, 40], + 'b': [50, 60, 70, 80, 90], + 'foo': ['bar'] * 5 + }, index=list('VWXuZ')) # only those columns and indices which are shared get filled - assert ray_df_equals_pandas( - ray_df.fillna(df2), - df.fillna(df2) - ) + assert ray_df_equals_pandas(ray_df.fillna(df2), df.fillna(df2)) @pytest.fixture @@ -1719,8 +1757,7 @@ def test_fillna_col_reordering(num_partitions=2): df = pd.DataFrame(index=range(20), columns=cols, data=data) ray_df = from_pandas(df, num_partitions) assert ray_df_equals_pandas( - ray_df.fillna(method='ffill'), - df.fillna(method='ffill') + ray_df.fillna(method='ffill'), df.fillna(method='ffill') ) @@ -1771,7 +1808,7 @@ def 
test_first(): @pytest.fixture def test_first_valid_index(ray_df, pandas_df): - assert(ray_df.first_valid_index() == (pandas_df.first_valid_index())) + assert (ray_df.first_valid_index() == (pandas_df.first_valid_index())) def test_floordiv(): @@ -1929,11 +1966,19 @@ def test_itertuples(ray_df, pandas_df): def test_join(): - ray_df = rdf.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + ray_df = rdf.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) - pandas_df = pd.DataFrame({"col1": [0, 1, 2, 3], "col2": [4, 5, 6, 7], - "col3": [8, 9, 0, 1], "col4": [2, 4, 5, 6]}) + pandas_df = pd.DataFrame({ + "col1": [0, 1, 2, 3], + "col2": [4, 5, 6, 7], + "col3": [8, 9, 0, 1], + "col4": [2, 4, 5, 6] + }) ray_df2 = rdf.DataFrame({"col5": [0], "col6": [1]}) @@ -1979,7 +2024,7 @@ def test_last(): @pytest.fixture def test_last_valid_index(ray_df, pandas_df): - assert(ray_df.last_valid_index() == (pandas_df.last_valid_index())) + assert (ray_df.last_valid_index() == (pandas_df.last_valid_index())) def test_le(): @@ -2013,8 +2058,8 @@ def test_mask(): @pytest.fixture def test_max(ray_df, pandas_df): - assert(ray_series_equals_pandas(ray_df.max(), pandas_df.max())) - assert(ray_series_equals_pandas(ray_df.max(axis=1), pandas_df.max(axis=1))) + assert (ray_series_equals_pandas(ray_df.max(), pandas_df.max())) + assert (ray_series_equals_pandas(ray_df.max(axis=1), pandas_df.max(axis=1))) @pytest.fixture @@ -2024,7 +2069,7 @@ def test_mean(ray_df, pandas_df): @pytest.fixture def test_median(ray_df, pandas_df): - assert(ray_df.median().equals(pandas_df.median())) + assert (ray_df.median().equals(pandas_df.median())) def test_melt(): @@ -2051,8 +2096,8 @@ def test_merge(): @pytest.fixture def test_min(ray_df, pandas_df): - assert(ray_series_equals_pandas(ray_df.min(), pandas_df.min())) - assert(ray_series_equals_pandas(ray_df.min(axis=1), pandas_df.min(axis=1))) + assert (ray_series_equals_pandas(ray_df.min(), pandas_df.min())) + assert (ray_series_equals_pandas(ray_df.min(axis=1), pandas_df.min(axis=1))) def test_mod(): @@ -2087,12 +2132,12 @@ def test_nlargest(): @pytest.fixture def test_notna(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df.notna(), pandas_df.notna())) + assert (ray_df_equals_pandas(ray_df.notna(), pandas_df.notna())) @pytest.fixture def test_notnull(ray_df, pandas_df): - assert(ray_df_equals_pandas(ray_df.notnull(), pandas_df.notnull())) + assert (ray_df_equals_pandas(ray_df.notnull(), pandas_df.notnull())) def test_nsmallest(): @@ -2174,7 +2219,7 @@ def test_product(): @pytest.fixture def test_quantile(ray_df, pandas_df, q): - assert(ray_df.quantile(q).equals(pandas_df.quantile(q))) + assert (ray_df.quantile(q).equals(pandas_df.quantile(q))) @pytest.fixture @@ -2222,6 +2267,7 @@ def test_reindex_like(): # Renaming + def test_rename(): test_rename_sanity() test_rename_multiindex() @@ -2234,47 +2280,40 @@ def test_rename(): @pytest.fixture def test_rename_sanity(num_partitions=2): test_data = TestData() - mapping = { - 'A': 'a', - 'B': 'b', - 'C': 'c', - 'D': 'd' - } + mapping = {'A': 'a', 'B': 'b', 'C': 'c', 'D': 'd'} ray_df = from_pandas(test_data.frame, num_partitions) assert ray_df_equals_pandas( - ray_df.rename(columns=mapping), - test_data.frame.rename(columns=mapping) + ray_df.rename(columns=mapping), test_data.frame.rename(columns=mapping) ) renamed2 = test_data.frame.rename(columns=str.lower) - assert ray_df_equals_pandas( - ray_df.rename(columns=str.lower), - 
renamed2 - ) + assert ray_df_equals_pandas(ray_df.rename(columns=str.lower), renamed2) ray_df = from_pandas(renamed2, num_partitions) assert ray_df_equals_pandas( - ray_df.rename(columns=str.upper), - renamed2.rename(columns=str.upper) + ray_df.rename(columns=str.upper), renamed2.rename(columns=str.upper) ) # index - data = { - 'A': {'foo': 0, 'bar': 1} - } + data = {'A': {'foo': 0, 'bar': 1}} # gets sorted alphabetical df = pd.DataFrame(data) ray_df = from_pandas(df, num_partitions) tm.assert_index_equal( - ray_df.rename(index={'foo': 'bar', 'bar': 'foo'}).index, - df.rename(index={'foo': 'bar', 'bar': 'foo'}).index + ray_df.rename(index={ + 'foo': 'bar', + 'bar': 'foo' + }).index, + df.rename(index={ + 'foo': 'bar', + 'bar': 'foo' + }).index ) tm.assert_index_equal( - ray_df.rename(index=str.upper).index, - df.rename(index=str.upper).index + ray_df.rename(index=str.upper).index, df.rename(index=str.upper).index ) # have to pass something @@ -2284,8 +2323,14 @@ def test_rename_sanity(num_partitions=2): renamed = test_data.frame.rename(columns={'C': 'foo', 'D': 'bar'}) ray_df = from_pandas(test_data.frame, num_partitions) tm.assert_index_equal( - ray_df.rename(columns={'C': 'foo', 'D': 'bar'}).index, - test_data.frame.rename(columns={'C': 'foo', 'D': 'bar'}).index + ray_df.rename(columns={ + 'C': 'foo', + 'D': 'bar' + }).index, + test_data.frame.rename(columns={ + 'C': 'foo', + 'D': 'bar' + }).index ) # TODO: Uncomment when transpose works @@ -2303,9 +2348,7 @@ def test_rename_sanity(num_partitions=2): ray_df = from_pandas(renamer, num_partitions) renamed = renamer.rename(index={'foo': 'bar', 'bar': 'foo'}) ray_renamed = ray_df.rename(index={'foo': 'bar', 'bar': 'foo'}) - tm.assert_index_equal( - renamed.index, ray_renamed.index - ) + tm.assert_index_equal(renamed.index, ray_renamed.index) assert renamed.index.name == ray_renamed.index.name @@ -2315,23 +2358,44 @@ def test_rename_multiindex(num_partitions=2): tuples_index = [('foo1', 'bar1'), ('foo2', 'bar2')] tuples_columns = [('fizz1', 'buzz1'), ('fizz2', 'buzz2')] index = pd.MultiIndex.from_tuples(tuples_index, names=['foo', 'bar']) - columns = pd.MultiIndex.from_tuples( - tuples_columns, names=['fizz', 'buzz']) + columns = pd.MultiIndex.from_tuples(tuples_columns, names=['fizz', 'buzz']) df = pd.DataFrame([(0, 0), (1, 1)], index=index, columns=columns) ray_df = from_pandas(df, num_partitions) # # without specifying level -> accross all levels - renamed = df.rename(index={'foo1': 'foo3', 'bar2': 'bar3'}, - columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}) - ray_renamed = ray_df.rename(index={'foo1': 'foo3', 'bar2': 'bar3'}, - columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}) - tm.assert_index_equal( - renamed.index, ray_renamed.index + renamed = df.rename( + index={ + 'foo1': 'foo3', + 'bar2': 'bar3' + }, + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + } + ) + ray_renamed = ray_df.rename( + index={ + 'foo1': 'foo3', + 'bar2': 'bar3' + }, + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + } + ) + tm.assert_index_equal(renamed.index, ray_renamed.index) + + renamed = df.rename( + index={ + 'foo1': 'foo3', + 'bar2': 'bar3' + }, + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + } ) - - renamed = df.rename(index={'foo1': 'foo3', 'bar2': 'bar3'}, - columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}) tm.assert_index_equal(renamed.columns, ray_renamed.columns) assert renamed.index.names == ray_renamed.index.names assert renamed.columns.names == ray_renamed.columns.names @@ -2340,26 +2404,48 @@ def test_rename_multiindex(num_partitions=2): # with 
specifying a level (GH13766) # dict - renamed = df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, - level=0) - ray_renamed = ray_df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, - level=0) + renamed = df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, level=0) + ray_renamed = ray_df.rename( + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + }, level=0 + ) tm.assert_index_equal(renamed.columns, ray_renamed.columns) - renamed = df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, - level='fizz') - ray_renamed = ray_df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, - level='fizz') + renamed = df.rename( + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + }, level='fizz' + ) + ray_renamed = ray_df.rename( + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + }, level='fizz' + ) tm.assert_index_equal(renamed.columns, ray_renamed.columns) - renamed = df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, - level=1) - ray_renamed = ray_df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, - level=1) + renamed = df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, level=1) + ray_renamed = ray_df.rename( + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + }, level=1 + ) tm.assert_index_equal(renamed.columns, ray_renamed.columns) - renamed = df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, - level='buzz') - ray_renamed = ray_df.rename(columns={'fizz1': 'fizz3', 'buzz2': 'buzz3'}, - level='buzz') + renamed = df.rename( + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + }, level='buzz' + ) + ray_renamed = ray_df.rename( + columns={ + 'fizz1': 'fizz3', + 'buzz2': 'buzz3' + }, level='buzz' + ) tm.assert_index_equal(renamed.columns, ray_renamed.columns) # function @@ -2379,10 +2465,8 @@ def test_rename_multiindex(num_partitions=2): tm.assert_index_equal(renamed.columns, ray_renamed.columns) # index - renamed = df.rename(index={'foo1': 'foo3', 'bar2': 'bar3'}, - level=0) - ray_renamed = ray_df.rename(index={'foo1': 'foo3', 'bar2': 'bar3'}, - level=0) + renamed = df.rename(index={'foo1': 'foo3', 'bar2': 'bar3'}, level=0) + ray_renamed = ray_df.rename(index={'foo1': 'foo3', 'bar2': 'bar3'}, level=0) tm.assert_index_equal(ray_renamed.index, renamed.index) @@ -2410,10 +2494,7 @@ def test_rename_inplace(num_partitions=2): frame.rename(columns={'C': 'foo'}, inplace=True) ray_frame.rename(columns={'C': 'foo'}, inplace=True) - assert ray_df_equals_pandas( - ray_frame, - frame - ) + assert ray_df_equals_pandas(ray_frame, frame) @pytest.fixture @@ -2434,10 +2515,7 @@ def test_rename_bug(num_partitions=2): # ray_df = ray_df.set_index(['a', 'b']) # ray_df.columns = ['2001-01-01'] - assert ray_df_equals_pandas( - ray_df, - df - ) + assert ray_df_equals_pandas(ray_df, df) def test_rename_axis(): @@ -2456,10 +2534,7 @@ def test_rename_axis_inplace(num_partitions=2): ray_no_return = ray_result.rename_axis('foo', inplace=True) assert no_return is ray_no_return - assert ray_df_equals_pandas( - ray_result, - result - ) + assert ray_df_equals_pandas(ray_result, result) result = test_frame.copy() ray_result = ray_df.copy() @@ -2467,10 +2542,7 @@ def test_rename_axis_inplace(num_partitions=2): ray_no_return = ray_result.rename_axis('bar', axis=1, inplace=True) assert no_return is ray_no_return - assert ray_df_equals_pandas( - ray_result, - result - ) + assert ray_df_equals_pandas(ray_result, result) def test_reorder_levels(): @@ -2498,7 +2570,8 @@ def test_resample(): def test_reset_index(ray_df, pandas_df, inplace=False): if not inplace: assert 
to_pandas(ray_df.reset_index(inplace=inplace)).equals( - pandas_df.reset_index(inplace=inplace)) + pandas_df.reset_index(inplace=inplace) + ) else: ray_df_cp = ray_df.copy() pd_df_cp = pandas_df.copy() @@ -2574,14 +2647,16 @@ def test_sem(): @pytest.fixture def test_set_axis(ray_df, pandas_df, label, axis): assert to_pandas(ray_df.set_axis(label, axis, inplace=False)).equals( - pandas_df.set_axis(label, axis, inplace=False)) + pandas_df.set_axis(label, axis, inplace=False) + ) @pytest.fixture def test_set_index(ray_df, pandas_df, keys, inplace=False): if not inplace: assert to_pandas(ray_df.set_index(keys)).equals( - pandas_df.set_index(keys)) + pandas_df.set_index(keys) + ) else: ray_df_cp = ray_df.copy() pd_df_cp = pandas_df.copy() @@ -2655,7 +2730,7 @@ def test_stack(): @pytest.fixture def test_std(ray_df, pandas_df): - assert(ray_df.std().equals(pandas_df.std())) + assert (ray_df.std().equals(pandas_df.std())) def test_sub(): @@ -2729,10 +2804,11 @@ def test_to_xarray(): @pytest.fixture def test_transform(ray_df, pandas_df): - ray_df_equals_pandas(ray_df.transform(lambda df: df.isna()), - pandas_df.transform(lambda df: df.isna())) - ray_df_equals_pandas(ray_df.transform('isna'), - pandas_df.transform('isna')) + ray_df_equals_pandas( + ray_df.transform(lambda df: df.isna()), + pandas_df.transform(lambda df: df.isna()) + ) + ray_df_equals_pandas(ray_df.transform('isna'), pandas_df.transform('isna')) def test_truediv(): @@ -2783,7 +2859,7 @@ def test_update(): @pytest.fixture def test_var(ray_df, pandas_df): - assert(ray_df.var().equals(pandas_df.var())) + assert (ray_df.var().equals(pandas_df.var())) def test_where(): @@ -2835,7 +2911,7 @@ def test___setitem__(): @pytest.fixture def test___len__(ray_df, pandas_df): - assert((len(ray_df) == len(pandas_df))) + assert ((len(ray_df) == len(pandas_df))) def test___unicode__(): @@ -2899,7 +2975,7 @@ def test___bool__(): @pytest.fixture def test___abs__(ray_df, pandas_df): - assert(ray_df_equals_pandas(abs(ray_df), abs(pandas_df))) + assert (ray_df_equals_pandas(abs(ray_df), abs(pandas_df))) def test___round__(): diff --git a/python/ray/dataframe/test/test_io.py b/python/ray/dataframe/test/test_io.py index c2ab544beefe..59f8c6ac53bd 100644 --- a/python/ray/dataframe/test/test_io.py +++ b/python/ray/dataframe/test/test_io.py @@ -45,22 +45,26 @@ def setup_parquet_file(row_size, force=False): @pytest.fixture def create_test_ray_dataframe(): - df = pd.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, 6, 7], - 'col3': [8, 9, 10, 11], - 'col4': [12, 13, 14, 15], - 'col5': [0, 0, 0, 0]}) + df = pd.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, 6, 7], + 'col3': [8, 9, 10, 11], + 'col4': [12, 13, 14, 15], + 'col5': [0, 0, 0, 0] + }) return df @pytest.fixture def create_test_pandas_dataframe(): - df = pandas.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, 6, 7], - 'col3': [8, 9, 10, 11], - 'col4': [12, 13, 14, 15], - 'col5': [0, 0, 0, 0]}) + df = pandas.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, 6, 7], + 'col3': [8, 9, 10, 11], + 'col4': [12, 13, 14, 15], + 'col5': [0, 0, 0, 0] + }) return df @@ -265,11 +269,13 @@ def setup_sql_file(conn, force=False): if os.path.exists(TEST_SQL_FILENAME) and not force: pass else: - df = pandas.DataFrame({'col1': [0, 1, 2, 3], - 'col2': [4, 5, 6, 7], - 'col3': [8, 9, 10, 11], - 'col4': [12, 13, 14, 15], - 'col5': [0, 0, 0, 0]}) + df = pandas.DataFrame({ + 'col1': [0, 1, 2, 3], + 'col2': [4, 5, 6, 7], + 'col3': [8, 9, 10, 11], + 'col4': [12, 13, 14, 15], + 'col5': [0, 0, 0, 0] + }) 
         df.to_sql(TEST_SQL_FILENAME.split(".")[0], conn)
@@ -453,7 +459,7 @@ def test_to_clipboard():
     pandas_df.to_clipboard()
     pandas_as_clip = pandas.read_clipboard()
 
-    assert(ray_as_clip.equals(pandas_as_clip))
+    assert (ray_as_clip.equals(pandas_as_clip))
 
 
 def test_to_csv():
@@ -466,8 +472,7 @@ def test_to_csv():
     ray_df.to_csv(TEST_CSV_DF_FILENAME)
     pandas_df.to_csv(TEST_CSV_pandas_FILENAME)
 
-    assert(test_files_eq(TEST_CSV_DF_FILENAME,
-                         TEST_CSV_pandas_FILENAME))
+    assert (test_files_eq(TEST_CSV_DF_FILENAME, TEST_CSV_pandas_FILENAME))
 
     teardown_test_file(TEST_CSV_pandas_FILENAME)
     teardown_test_file(TEST_CSV_DF_FILENAME)
@@ -503,8 +508,7 @@ def test_to_excel():
     ray_writer.save()
     pandas_writer.save()
 
-    assert(test_files_eq(TEST_EXCEL_DF_FILENAME,
-                         TEST_EXCEL_pandas_FILENAME))
+    assert (test_files_eq(TEST_EXCEL_DF_FILENAME, TEST_EXCEL_pandas_FILENAME))
 
     teardown_test_file(TEST_EXCEL_DF_FILENAME)
     teardown_test_file(TEST_EXCEL_pandas_FILENAME)
@@ -520,8 +524,9 @@ def test_to_feather():
     ray_df.to_feather(TEST_FEATHER_DF_FILENAME)
     pandas_df.to_feather(TEST_FEATHER_pandas_FILENAME)
 
-    assert(test_files_eq(TEST_FEATHER_DF_FILENAME,
-                         TEST_FEATHER_pandas_FILENAME))
+    assert (
+        test_files_eq(TEST_FEATHER_DF_FILENAME, TEST_FEATHER_pandas_FILENAME)
+    )
 
     teardown_test_file(TEST_FEATHER_pandas_FILENAME)
     teardown_test_file(TEST_FEATHER_DF_FILENAME)
@@ -545,8 +550,7 @@ def test_to_html():
     ray_df.to_html(TEST_HTML_DF_FILENAME)
     pandas_df.to_html(TEST_HTML_pandas_FILENAME)
 
-    assert(test_files_eq(TEST_HTML_DF_FILENAME,
-                         TEST_HTML_pandas_FILENAME))
+    assert (test_files_eq(TEST_HTML_DF_FILENAME, TEST_HTML_pandas_FILENAME))
 
     teardown_test_file(TEST_HTML_pandas_FILENAME)
     teardown_test_file(TEST_HTML_DF_FILENAME)
@@ -562,8 +566,7 @@ def test_to_json():
     ray_df.to_json(TEST_JSON_DF_FILENAME)
     pandas_df.to_json(TEST_JSON_pandas_FILENAME)
 
-    assert(test_files_eq(TEST_JSON_DF_FILENAME,
-                         TEST_JSON_pandas_FILENAME))
+    assert (test_files_eq(TEST_JSON_DF_FILENAME, TEST_JSON_pandas_FILENAME))
 
     teardown_test_file(TEST_JSON_pandas_FILENAME)
     teardown_test_file(TEST_JSON_DF_FILENAME)
@@ -586,8 +589,9 @@ def test_to_msgpack():
     ray_df.to_msgpack(TEST_MSGPACK_DF_FILENAME)
     pandas_df.to_msgpack(TEST_MSGPACK_pandas_FILENAME)
 
-    assert(test_files_eq(TEST_MSGPACK_DF_FILENAME,
-                         TEST_MSGPACK_pandas_FILENAME))
+    assert (
+        test_files_eq(TEST_MSGPACK_DF_FILENAME, TEST_MSGPACK_pandas_FILENAME)
+    )
 
     teardown_test_file(TEST_MSGPACK_pandas_FILENAME)
     teardown_test_file(TEST_MSGPACK_DF_FILENAME)
@@ -610,8 +614,9 @@ def test_to_parquet():
     ray_df.to_parquet(TEST_PARQUET_DF_FILENAME)
     pandas_df.to_parquet(TEST_PARQUET_pandas_FILENAME)
 
-    assert(test_files_eq(TEST_PARQUET_DF_FILENAME,
-                         TEST_PARQUET_pandas_FILENAME))
+    assert (
+        test_files_eq(TEST_PARQUET_DF_FILENAME, TEST_PARQUET_pandas_FILENAME)
+    )
 
     teardown_test_file(TEST_PARQUET_pandas_FILENAME)
     teardown_test_file(TEST_PARQUET_DF_FILENAME)
@@ -634,8 +639,7 @@ def test_to_pickle():
     ray_df.to_pickle(TEST_PICKLE_DF_FILENAME)
     pandas_df.to_pickle(TEST_PICKLE_pandas_FILENAME)
 
-    assert(test_files_eq(TEST_PICKLE_DF_FILENAME,
-                         TEST_PICKLE_pandas_FILENAME))
+    assert (test_files_eq(TEST_PICKLE_DF_FILENAME, TEST_PICKLE_pandas_FILENAME))
 
     teardown_test_file(TEST_PICKLE_pandas_FILENAME)
     teardown_test_file(TEST_PICKLE_DF_FILENAME)
@@ -651,8 +655,7 @@ def test_to_sql():
     ray_df.to_pickle(TEST_SQL_DF_FILENAME)
     pandas_df.to_pickle(TEST_SQL_pandas_FILENAME)
 
-    assert(test_files_eq(TEST_SQL_DF_FILENAME,
-                         TEST_SQL_pandas_FILENAME))
+    assert (test_files_eq(TEST_SQL_DF_FILENAME, TEST_SQL_pandas_FILENAME))
teardown_test_file(TEST_SQL_DF_FILENAME) teardown_test_file(TEST_SQL_pandas_FILENAME) @@ -668,8 +671,7 @@ def test_to_stata(): ray_df.to_stata(TEST_STATA_DF_FILENAME) pandas_df.to_stata(TEST_STATA_pandas_FILENAME) - assert(test_files_eq(TEST_STATA_DF_FILENAME, - TEST_STATA_pandas_FILENAME)) + assert (test_files_eq(TEST_STATA_DF_FILENAME, TEST_STATA_pandas_FILENAME)) teardown_test_file(TEST_STATA_pandas_FILENAME) teardown_test_file(TEST_STATA_DF_FILENAME) diff --git a/python/ray/dataframe/test/test_series.py b/python/ray/dataframe/test/test_series.py index ba8b50061bca..c0db4e92b55c 100644 --- a/python/ray/dataframe/test/test_series.py +++ b/python/ray/dataframe/test/test_series.py @@ -1352,9 +1352,11 @@ def test_plot(): ray_series = create_test_series() with pytest.raises(NotImplementedError): - ray_series.plot(None, None, None, None, None, None, None, None, None, - None, None, None, None, None, None, None, None, None, - None, None, None, None, None) + ray_series.plot( + None, None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None, None, None, None, None, + None + ) def test_pop(): @@ -1501,8 +1503,10 @@ def test_resample(): ray_series = create_test_series() with pytest.raises(NotImplementedError): - ray_series.resample(None, None, None, None, None, None, None, None, - None, None, None, None) + ray_series.resample( + None, None, None, None, None, None, None, None, None, None, None, + None + ) def test_reset_index(): @@ -1754,8 +1758,9 @@ def test_to_csv(): ray_series = create_test_series() with pytest.raises(NotImplementedError): - ray_series.to_csv(None, None, None, None, None, None, None, None, - None, None) + ray_series.to_csv( + None, None, None, None, None, None, None, None, None, None + ) def test_to_dense(): @@ -1776,8 +1781,10 @@ def test_to_excel(): ray_series = create_test_series() with pytest.raises(NotImplementedError): - ray_series.to_excel(None, None, None, None, None, None, None, None, - None, None, None, None, None, None) + ray_series.to_excel( + None, None, None, None, None, None, None, None, None, None, None, + None, None, None + ) def test_to_frame(): @@ -1805,9 +1812,10 @@ def test_to_latex(): ray_series = create_test_series() with pytest.raises(NotImplementedError): - ray_series.to_latex(None, None, None, None, None, None, None, None, - None, None, None, None, None, None, None, None, - None, None) + ray_series.to_latex( + None, None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None + ) def test_to_msgpack(): diff --git a/python/ray/dataframe/utils.py b/python/ray/dataframe/utils.py index 97c166d09413..83c2f8f5ef0b 100644 --- a/python/ray/dataframe/utils.py +++ b/python/ray/dataframe/utils.py @@ -95,9 +95,9 @@ def from_pandas(df, num_partitions=None, chunksize=None): row_partitions = \ _partition_pandas_dataframe(df, num_partitions, chunksize) - return DataFrame(row_partitions=row_partitions, - columns=df.columns, - index=df.index) + return DataFrame( + row_partitions=row_partitions, columns=df.columns, index=df.index + ) def to_pandas(df): @@ -110,8 +110,7 @@ def to_pandas(df): if df._row_partitions is not None: pd_df = pd.concat(ray.get(df._row_partitions)) else: - pd_df = pd.concat(ray.get(df._col_partitions), - axis=1) + pd_df = pd.concat(ray.get(df._col_partitions), axis=1) pd_df.index = df.index pd_df.columns = df.columns return pd_df @@ -145,25 +144,30 @@ def _map_partitions(func, partitions, *argslists): if partitions is None: return None - assert(callable(func)) + 
assert (callable(func)) if len(argslists) == 0: return [_deploy_func.remote(func, part) for part in partitions] elif len(argslists) == 1: - return [_deploy_func.remote(func, part, argslists[0]) - for part in partitions] + return [ + _deploy_func.remote(func, part, argslists[0]) for part in partitions + ] else: - assert(all([len(args) == len(partitions) for args in argslists])) - return [_deploy_func.remote(func, part, *args) - for part, args in zip(partitions, *argslists)] + assert (all([len(args) == len(partitions) for args in argslists])) + return [ + _deploy_func.remote(func, part, *args) + for part, args in zip(partitions, *argslists) + ] @ray.remote(num_return_vals=2) def _build_columns(df_col, columns): """Build columns and compute lengths for each partition.""" # Columns and width - widths = np.array(ray.get([_deploy_func.remote(_get_widths, d) - for d in df_col])) - dest_indices = [(p_idx, p_sub_idx) for p_idx in range(len(widths)) + widths = np.array( + ray.get([_deploy_func.remote(_get_widths, d) for d in df_col]) + ) + dest_indices = [(p_idx, p_sub_idx) + for p_idx in range(len(widths)) for p_sub_idx in range(widths[p_idx])] col_names = ("partition", "index_within_partition") @@ -176,10 +180,12 @@ def _build_columns(df_col, columns): def _build_index(df_row, index): """Build index and compute lengths for each partition.""" # Rows and length - lengths = np.array(ray.get([_deploy_func.remote(_get_lengths, d) - for d in df_row])) + lengths = np.array( + ray.get([_deploy_func.remote(_get_lengths, d) for d in df_row]) + ) - dest_indices = [(p_idx, p_sub_idx) for p_idx in range(len(lengths)) + dest_indices = [(p_idx, p_sub_idx) + for p_idx in range(len(lengths)) for p_sub_idx in range(lengths[p_idx])] col_names = ("partition", "index_within_partition") index_df = pd.DataFrame(dest_indices, index=index, columns=col_names) @@ -194,9 +200,11 @@ def _create_block_partitions(partitions, axis=0, length=None): else: npartitions = get_npartitions() - x = [create_blocks._submit(args=(partition, npartitions, axis), - num_return_vals=npartitions) - for partition in partitions] + x = [ + create_blocks._submit( + args=(partition, npartitions, axis), num_return_vals=npartitions + ) for partition in partitions + ] # In the case that axis is 1 we have to transpose because we build the # columns into rows. Fortunately numpy is efficient at this. @@ -221,10 +229,11 @@ def create_blocks_helper(df, npartitions, axis): # if not isinstance(df.columns, pd.RangeIndex): # df.columns = pd.RangeIndex(0, len(df.columns)) - blocks = [df.iloc[:, i * block_size: (i + 1) * block_size] - if axis == 0 - else df.iloc[i * block_size: (i + 1) * block_size, :] - for i in range(npartitions)] + blocks = [ + df.iloc[:, i * block_size:(i + 1) * block_size] + if axis == 0 else df.iloc[i * block_size:(i + 1) * block_size, :] + for i in range(npartitions) + ] for block in blocks: block.columns = pd.RangeIndex(0, len(block.columns)) @@ -267,6 +276,7 @@ def _inherit_docstrings(parent): function: decorator which replaces the decorated class' documentation parent's documentation. """ + def decorator(cls): # cls.__doc__ = parent.__doc__ for attr, obj in cls.__dict__.items(): diff --git a/python/ray/experimental/array/distributed/__init__.py b/python/ray/experimental/array/distributed/__init__.py index df0f8f59939b..c316cc789c29 100644 --- a/python/ray/experimental/array/distributed/__init__.py +++ b/python/ray/experimental/array/distributed/__init__.py @@ -4,9 +4,10 @@ from . import random from . 
import linalg -from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye, - triu, tril, blockwise_dot, dot, transpose, add, subtract, - numpy_to_dist, subblocks) +from .core import ( + BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye, triu, tril, + blockwise_dot, dot, transpose, add, subtract, numpy_to_dist, subblocks +) __all__ = [ "random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", diff --git a/python/ray/experimental/array/distributed/core.py b/python/ray/experimental/array/distributed/core.py index 59a5e024c8f0..3caf5f4a8666 100644 --- a/python/ray/experimental/array/distributed/core.py +++ b/python/ray/experimental/array/distributed/core.py @@ -21,25 +21,32 @@ def __init__(self, shape, objectids=None): else: self.objectids = np.empty(self.num_blocks, dtype=object) if self.num_blocks != list(self.objectids.shape): - raise Exception("The fields `num_blocks` and `objectids` are " - "inconsistent, `num_blocks` is {} and `objectids` " - "has shape {}".format(self.num_blocks, - list(self.objectids.shape))) + raise Exception( + "The fields `num_blocks` and `objectids` are " + "inconsistent, `num_blocks` is {} and `objectids` " + "has shape {}".format( + self.num_blocks, list(self.objectids.shape) + ) + ) @staticmethod def compute_block_lower(index, shape): if len(index) != len(shape): - raise Exception("The fields `index` and `shape` must have the " - "same length, but `index` is {} and `shape` is " - "{}.".format(index, shape)) + raise Exception( + "The fields `index` and `shape` must have the " + "same length, but `index` is {} and `shape` is " + "{}.".format(index, shape) + ) return [elem * BLOCK_SIZE for elem in index] @staticmethod def compute_block_upper(index, shape): if len(index) != len(shape): - raise Exception("The fields `index` and `shape` must have the " - "same length, but `index` is {} and `shape` is " - "{}.".format(index, shape)) + raise Exception( + "The fields `index` and `shape` must have the " + "same length, but `index` is {} and `shape` is " + "{}.".format(index, shape) + ) upper = [] for i in range(len(shape)): upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i])) @@ -64,7 +71,8 @@ def assemble(self): lower = DistArray.compute_block_lower(index, self.shape) upper = DistArray.compute_block_upper(index, self.shape) result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get( - self.objectids[index]) + self.objectids[index] + ) return result def __getitem__(self, sliced): @@ -87,7 +95,8 @@ def numpy_to_dist(a): lower = DistArray.compute_block_lower(index, a.shape) upper = DistArray.compute_block_upper(index, a.shape) result.objectids[index] = ray.put( - a[[slice(l, u) for (l, u) in zip(lower, upper)]]) + a[[slice(l, u) for (l, u) in zip(lower, upper)]] + ) return result @@ -96,7 +105,8 @@ def zeros(shape, dtype_name="float"): result = DistArray(shape) for index in np.ndindex(*result.num_blocks): result.objectids[index] = ra.zeros.remote( - DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) + DistArray.compute_block_shape(index, shape), dtype_name=dtype_name + ) return result @@ -105,7 +115,8 @@ def ones(shape, dtype_name="float"): result = DistArray(shape) for index in np.ndindex(*result.num_blocks): result.objectids[index] = ra.ones.remote( - DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) + DistArray.compute_block_shape(index, shape), dtype_name=dtype_name + ) return result @@ -128,18 +139,22 @@ def eye(dim1, dim2=-1, dtype_name="float"): block_shape = 
DistArray.compute_block_shape([i, j], shape) if i == j: result.objectids[i, j] = ra.eye.remote( - block_shape[0], block_shape[1], dtype_name=dtype_name) + block_shape[0], block_shape[1], dtype_name=dtype_name + ) else: result.objectids[i, j] = ra.zeros.remote( - block_shape, dtype_name=dtype_name) + block_shape, dtype_name=dtype_name + ) return result @ray.remote def triu(a): if a.ndim != 2: - raise Exception("Input must have 2 dimensions, but a.ndim is " - "{}.".format(a.ndim)) + raise Exception( + "Input must have 2 dimensions, but a.ndim is " + "{}.".format(a.ndim) + ) result = DistArray(a.shape) for (i, j) in np.ndindex(*result.num_blocks): if i < j: @@ -154,8 +169,10 @@ def triu(a): @ray.remote def tril(a): if a.ndim != 2: - raise Exception("Input must have 2 dimensions, but a.ndim is " - "{}.".format(a.ndim)) + raise Exception( + "Input must have 2 dimensions, but a.ndim is " + "{}.".format(a.ndim) + ) result = DistArray(a.shape) for (i, j) in np.ndindex(*result.num_blocks): if i > j: @@ -171,8 +188,10 @@ def tril(a): def blockwise_dot(*matrices): n = len(matrices) if n % 2 != 0: - raise Exception("blockwise_dot expects an even number of arguments, " - "but len(matrices) is {}.".format(n)) + raise Exception( + "blockwise_dot expects an even number of arguments, " + "but len(matrices) is {}.".format(n) + ) shape = (matrices[0].shape[0], matrices[n // 2].shape[1]) result = np.zeros(shape) for i in range(n // 2): @@ -183,15 +202,20 @@ def blockwise_dot(*matrices): @ray.remote def dot(a, b): if a.ndim != 2: - raise Exception("dot expects its arguments to be 2-dimensional, but " - "a.ndim = {}.".format(a.ndim)) + raise Exception( + "dot expects its arguments to be 2-dimensional, but " + "a.ndim = {}.".format(a.ndim) + ) if b.ndim != 2: - raise Exception("dot expects its arguments to be 2-dimensional, but " - "b.ndim = {}.".format(b.ndim)) + raise Exception( + "dot expects its arguments to be 2-dimensional, but " + "b.ndim = {}.".format(b.ndim) + ) if a.shape[1] != b.shape[0]: - raise Exception("dot expects a.shape[1] to equal b.shape[0], but " - "a.shape = {} and b.shape = {}.".format( - a.shape, b.shape)) + raise Exception( + "dot expects a.shape[1] to equal b.shape[0], but " + "a.shape = {} and b.shape = {}.".format(a.shape, b.shape) + ) shape = [a.shape[0], b.shape[1]] result = DistArray(shape) for (i, j) in np.ndindex(*result.num_blocks): @@ -214,43 +238,52 @@ def subblocks(a, *ranges): """ ranges = list(ranges) if len(ranges) != a.ndim: - raise Exception("sub_blocks expects to receive a number of ranges " - "equal to a.ndim, but it received {} ranges and " - "a.ndim = {}.".format(len(ranges), a.ndim)) + raise Exception( + "sub_blocks expects to receive a number of ranges " + "equal to a.ndim, but it received {} ranges and " + "a.ndim = {}.".format(len(ranges), a.ndim) + ) for i in range(len(ranges)): # We allow the user to pass in an empty list to indicate the full # range. 
if ranges[i] == []: ranges[i] = range(a.num_blocks[i]) if not np.alltrue(ranges[i] == np.sort(ranges[i])): - raise Exception("Ranges passed to sub_blocks must be sorted, but " - "the {}th range is {}.".format(i, ranges[i])) + raise Exception( + "Ranges passed to sub_blocks must be sorted, but " + "the {}th range is {}.".format(i, ranges[i]) + ) if ranges[i][0] < 0: - raise Exception("Values in the ranges passed to sub_blocks must " - "be at least 0, but the {}th range is {}.".format( - i, ranges[i])) + raise Exception( + "Values in the ranges passed to sub_blocks must " + "be at least 0, but the {}th range is {}.".format(i, ranges[i]) + ) if ranges[i][-1] >= a.num_blocks[i]: - raise Exception("Values in the ranges passed to sub_blocks must " - "be less than the relevant number of blocks, but " - "the {}th range is {}, and a.num_blocks = {}." - .format(i, ranges[i], a.num_blocks)) + raise Exception( + "Values in the ranges passed to sub_blocks must " + "be less than the relevant number of blocks, but " + "the {}th range is {}, and a.num_blocks = {}." + .format(i, ranges[i], a.num_blocks) + ) last_index = [r[-1] for r in ranges] last_block_shape = DistArray.compute_block_shape(last_index, a.shape) shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] for i in range(a.ndim)] result = DistArray(shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = a.objectids[tuple( - [ranges[i][index[i]] for i in range(a.ndim)])] + result.objectids[index] = a.objectids[tuple([ + ranges[i][index[i]] for i in range(a.ndim) + ])] return result @ray.remote def transpose(a): if a.ndim != 2: - raise Exception("transpose expects its argument to be 2-dimensional, " - "but a.ndim = {}, a.shape = {}.".format( - a.ndim, a.shape)) + raise Exception( + "transpose expects its argument to be 2-dimensional, " + "but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape) + ) result = DistArray([a.shape[1], a.shape[0]]) for i in range(result.num_blocks[0]): for j in range(result.num_blocks[1]): @@ -262,13 +295,17 @@ def transpose(a): @ray.remote def add(x1, x2): if x1.shape != x2.shape: - raise Exception("add expects arguments `x1` and `x2` to have the same " - "shape, but x1.shape = {}, and x2.shape = {}.".format( - x1.shape, x2.shape)) + raise Exception( + "add expects arguments `x1` and `x2` to have the same " + "shape, but x1.shape = {}, and x2.shape = {}.".format( + x1.shape, x2.shape + ) + ) result = DistArray(x1.shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.add.remote(x1.objectids[index], - x2.objectids[index]) + result.objectids[index] = ra.add.remote( + x1.objectids[index], x2.objectids[index] + ) return result @@ -276,11 +313,14 @@ def add(x1, x2): @ray.remote def subtract(x1, x2): if x1.shape != x2.shape: - raise Exception("subtract expects arguments `x1` and `x2` to have the " - "same shape, but x1.shape = {}, and x2.shape = {}." - .format(x1.shape, x2.shape)) + raise Exception( + "subtract expects arguments `x1` and `x2` to have the " + "same shape, but x1.shape = {}, and x2.shape = {}." 
+ .format(x1.shape, x2.shape) + ) result = DistArray(x1.shape) for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.subtract.remote(x1.objectids[index], - x2.objectids[index]) + result.objectids[index] = ra.subtract.remote( + x1.objectids[index], x2.objectids[index] + ) return result diff --git a/python/ray/experimental/array/distributed/linalg.py b/python/ray/experimental/array/distributed/linalg.py index 7f473715d022..8ab5a49683bd 100644 --- a/python/ray/experimental/array/distributed/linalg.py +++ b/python/ray/experimental/array/distributed/linalg.py @@ -28,11 +28,15 @@ def tsqr(a): - np.allclose(r, np.triu(r)) == True. """ if len(a.shape) != 2: - raise Exception("tsqr requires len(a.shape) == 2, but a.shape is " - "{}".format(a.shape)) + raise Exception( + "tsqr requires len(a.shape) == 2, but a.shape is " + "{}".format(a.shape) + ) if a.num_blocks[1] != 1: - raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks " - "is {}".format(a.num_blocks)) + raise Exception( + "tsqr requires a.num_blocks[1] == 1, but a.num_blocks " + "is {}".format(a.num_blocks) + ) num_blocks = a.num_blocks[0] K = int(np.ceil(np.log2(num_blocks))) + 1 @@ -76,10 +80,10 @@ def tsqr(a): lower = [a.shape[1], 0] upper = [2 * a.shape[1], core.BLOCK_SIZE] ith_index //= 2 - q_block_current = ra.dot.remote(q_block_current, - ra.subarray.remote( - q_tree[ith_index, j], lower, - upper)) + q_block_current = ra.dot.remote( + q_block_current, + ra.subarray.remote(q_tree[ith_index, j], lower, upper) + ) q_result.objectids[i] = q_block_current r = current_rs[0] return q_result, ray.get(r) @@ -115,8 +119,9 @@ def modified_lu(q): # Scale ith column of L by diagonal element. q_work[(i + 1):m, i] /= q_work[i, i] # Perform Schur complement update. - q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], - q_work[i, (i + 1):b]) + q_work[(i + 1):m, (i + 1):b] -= np.outer( + q_work[(i + 1):m, i], q_work[i, (i + 1):b] + ) L = np.tril(q_work) for i in range(b): @@ -147,8 +152,9 @@ def tsqr_hr(a): q, r_temp = tsqr.remote(a) y, u, s = modified_lu.remote(q) y_blocked = ray.get(y) - t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0], - a.shape[1]) + t, y_top = tsqr_hr_helper1.remote( + u, s, y_blocked.objectids[0, 0], a.shape[1] + ) r = tsqr_hr_helper2.remote(s, r_temp) return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r) @@ -188,7 +194,9 @@ def qr(a): # sense when a.num_blocks[1] > a.num_blocks[0]. 
for i in range(min(a.num_blocks[0], a.num_blocks[1])): sub_dist_array = core.subblocks.remote( - a_work, list(range(i, a_work.num_blocks[0])), [i]) + a_work, list(range(i, a_work.num_blocks[0])), + [i] + ) y, t, _, R = tsqr_hr.remote(sub_dist_array) y_val = ray.get(y) @@ -198,7 +206,8 @@ def qr(a): # in this case, R needs to be square R_shape = ray.get(ra.shape.remote(R)) eye_temp = ra.eye.remote( - R_shape[1], R_shape[0], dtype_name=result_dtype) + R_shape[1], R_shape[0], dtype_name=result_dtype + ) r_res.objectids[i, i] = ra.dot.remote(eye_temp, R) else: r_res.objectids[i, i] = R @@ -222,10 +231,13 @@ def qr(a): y_col_block = core.subblocks.remote(y_res, [], [i]) q = core.subtract.remote( q, - core.dot.remote(y_col_block, - core.dot.remote( - Ts[i], - core.dot.remote( - core.transpose.remote(y_col_block), q)))) + core.dot.remote( + y_col_block, + core.dot.remote( + Ts[i], + core.dot.remote(core.transpose.remote(y_col_block), q) + ) + ) + ) return ray.get(q), r_res diff --git a/python/ray/experimental/array/distributed/random.py b/python/ray/experimental/array/distributed/random.py index a946df90a8d8..bed3df584662 100644 --- a/python/ray/experimental/array/distributed/random.py +++ b/python/ray/experimental/array/distributed/random.py @@ -15,6 +15,7 @@ def normal(shape): objectids = np.empty(num_blocks, dtype=object) for index in np.ndindex(*num_blocks): objectids[index] = ra.random.normal.remote( - DistArray.compute_block_shape(index, shape)) + DistArray.compute_block_shape(index, shape) + ) result = DistArray(shape, objectids) return result diff --git a/python/ray/experimental/array/remote/__init__.py b/python/ray/experimental/array/remote/__init__.py index 1f2e4429fe74..31bde6f46f57 100644 --- a/python/ray/experimental/array/remote/__init__.py +++ b/python/ray/experimental/array/remote/__init__.py @@ -4,9 +4,10 @@ from . import random from . import linalg -from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray, - copy, tril, triu, diag, transpose, add, subtract, sum, - shape, sum_list) +from .core import ( + zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray, copy, tril, + triu, diag, transpose, add, subtract, sum, shape, sum_list +) __all__ = [ "random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot", "vstack", diff --git a/python/ray/experimental/features.py b/python/ray/experimental/features.py index 7db9d611b9b0..1ad455344a20 100644 --- a/python/ray/experimental/features.py +++ b/python/ray/experimental/features.py @@ -19,8 +19,10 @@ def flush_redis_unsafe(): flushed. """ if not hasattr(ray.worker.global_worker, "redis_client"): - raise Exception("ray.experimental.flush_redis_unsafe cannot be called " - "before ray.init() has been called.") + raise Exception( + "ray.experimental.flush_redis_unsafe cannot be called " + "before ray.init() has been called." + ) redis_client = ray.worker.global_worker.redis_client @@ -53,8 +55,10 @@ def flush_task_and_object_metadata_unsafe(): likely not work. """ if not hasattr(ray.worker.global_worker, "redis_client"): - raise Exception("ray.experimental.flush_redis_unsafe cannot be called " - "before ray.init() has been called.") + raise Exception( + "ray.experimental.flush_redis_unsafe cannot be called " + "before ray.init() has been called." + ) def flush_shard(redis_client): # Flush the task table. 
Note that this also flushes the driver tasks @@ -68,15 +72,19 @@ def flush_shard(redis_client): num_object_keys_deleted = 0 for key in redis_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"): num_object_keys_deleted += redis_client.delete(key) - print("Deleted {} object info keys from Redis.".format( - num_object_keys_deleted)) + print( + "Deleted {} object info keys from Redis.". + format(num_object_keys_deleted) + ) # Flush the object locations. num_object_location_keys_deleted = 0 for key in redis_client.scan_iter(match=OBJECT_LOCATION_PREFIX + b"*"): num_object_location_keys_deleted += redis_client.delete(key) - print("Deleted {} object location keys from Redis.".format( - num_object_location_keys_deleted)) + print( + "Deleted {} object location keys from Redis.". + format(num_object_location_keys_deleted) + ) # Loop over the shards and flush all of them. for redis_client in ray.worker.global_state.redis_clients: diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 964b0f71e290..fd98afc5cb50 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -11,8 +11,9 @@ import time import ray -from ray.utils import (decode, binary_to_object_id, binary_to_hex, - hex_to_binary) +from ray.utils import ( + decode, binary_to_object_id, binary_to_hex, hex_to_binary +) # Import flatbuffer bindings. from ray.core.generated.TaskReply import TaskReply @@ -76,17 +77,20 @@ def _check_connected(self): yet. """ if self.redis_client is None: - raise Exception("The ray.global_state API cannot be used before " - "ray.init has been called.") + raise Exception( + "The ray.global_state API cannot be used before " + "ray.init has been called." + ) if self.redis_clients is None: - raise Exception("The ray.global_state API cannot be used before " - "ray.init has been called.") - - def _initialize_global_state(self, - redis_ip_address, - redis_port, - timeout=20): + raise Exception( + "The ray.global_state API cannot be used before " + "ray.init has been called." + ) + + def _initialize_global_state( + self, redis_ip_address, redis_port, timeout=20 + ): """Initialize the GlobalState object by connecting to Redis. It's possible that certain keys in Redis may not have been fully @@ -101,7 +105,8 @@ def _initialize_global_state(self, wait for the keys in Redis to be populated. """ self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port + ) start_time = time.time() @@ -117,12 +122,15 @@ def _initialize_global_state(self, continue num_redis_shards = int(num_redis_shards) if (num_redis_shards < 1): - raise Exception("Expected at least one Redis shard, found " - "{}.".format(num_redis_shards)) + raise Exception( + "Expected at least one Redis shard, found " + "{}.".format(num_redis_shards) + ) # Attempt to get all of the Redis shards. ip_address_ports = self.redis_client.lrange( - "RedisShards", start=0, end=-1) + "RedisShards", start=0, end=-1 + ) if len(ip_address_ports) != num_redis_shards: print("Waiting longer for RedisShards to be populated.") time.sleep(1) @@ -133,17 +141,21 @@ def _initialize_global_state(self, # Check to see if we timed out. if time.time() - start_time >= timeout: - raise Exception("Timed out while attempting to initialize the " - "global state. num_redis_shards = {}, " - "ip_address_ports = {}".format( - num_redis_shards, ip_address_ports)) + raise Exception( + "Timed out while attempting to initialize the " + "global state. 
num_redis_shards = {}, " + "ip_address_ports = {}".format( + num_redis_shards, ip_address_ports + ) + ) # Get the rest of the information. self.redis_clients = [] for ip_address_port in ip_address_ports: shard_address, shard_port = ip_address_port.split(b":") self.redis_clients.append( - redis.StrictRedis(host=shard_address, port=shard_port)) + redis.StrictRedis(host=shard_address, port=shard_port) + ) def _execute_command(self, key, *args): """Execute a Redis command on the appropriate Redis shard based on key. @@ -155,8 +167,8 @@ def _execute_command(self, key, *args): Returns: The value returned by the Redis command. """ - client = self.redis_clients[key.redis_shard_hash() % len( - self.redis_clients)] + client = self.redis_clients[key.redis_shard_hash() % + len(self.redis_clients)] return client.execute_command(*args) def _keys(self, pattern): @@ -188,9 +200,9 @@ def _object_table(self, object_id): object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id)) # Return information about a single object ID. - object_locations = self._execute_command(object_id, - "RAY.OBJECT_TABLE_LOOKUP", - object_id.id()) + object_locations = self._execute_command( + object_id, "RAY.OBJECT_TABLE_LOOKUP", object_id.id() + ) if object_locations is not None: manager_ids = [ binary_to_hex(manager_id) for manager_id in object_locations @@ -199,9 +211,11 @@ def _object_table(self, object_id): manager_ids = None result_table_response = self._execute_command( - object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id()) + object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id() + ) result_table_message = ResultTableReply.GetRootAsResultTableReply( - result_table_response, 0) + result_table_response, 0 + ) result = { "ManagerIDs": manager_ids, @@ -236,11 +250,13 @@ def object_table(self, object_id=None): [key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + [ key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys - ]) + ] + ) results = {} for object_id_binary in object_ids_binary: results[binary_to_object_id(object_id_binary)] = ( - self._object_table(binary_to_object_id(object_id_binary))) + self._object_table(binary_to_object_id(object_id_binary)) + ) return results def _task_table(self, task_id): @@ -255,14 +271,17 @@ def _task_table(self, task_id): TASK_STATUS_MAPPING should be used to parse the "State" field into a human-readable string. 
""" - task_table_response = self._execute_command(task_id, - "RAY.TASK_TABLE_GET", - task_id.id()) + task_table_response = self._execute_command( + task_id, "RAY.TASK_TABLE_GET", task_id.id() + ) if task_table_response is None: - raise Exception("There is no entry for task ID {} in the task " - "table.".format(binary_to_hex(task_id.id()))) + raise Exception( + "There is no entry for task ID {} in the task " + "table.".format(binary_to_hex(task_id.id())) + ) task_table_message = TaskReply.GetRootAsTaskReply( - task_table_response, 0) + task_table_response, 0 + ) task_spec = task_table_message.TaskSpec() task_spec = ray.local_scheduler.task_from_string(task_spec) @@ -295,12 +314,14 @@ def _task_table(self, task_id): execution_dependencies_message = ( TaskExecutionDependencies.GetRootAsTaskExecutionDependencies( - task_table_message.ExecutionDependencies(), 0)) + task_table_message.ExecutionDependencies(), 0 + ) + ) execution_dependencies = [ ray.local_scheduler.ObjectID( - execution_dependencies_message.ExecutionDependencies(i)) - for i in range( - execution_dependencies_message.ExecutionDependenciesLength()) + execution_dependencies_message.ExecutionDependencies(i) + ) for i in + range(execution_dependencies_message.ExecutionDependenciesLength()) ] # TODO(rkn): The return fields ExecutionDependenciesString and @@ -343,7 +364,8 @@ def task_table(self, task_id=None): for key in task_table_keys: task_id_binary = key[len(TASK_PREFIX):] results[binary_to_hex(task_id_binary)] = self._task_table( - ray.local_scheduler.ObjectID(task_id_binary)) + ray.local_scheduler.ObjectID(task_id_binary) + ) return results def function_table(self, function_id=None): @@ -397,11 +419,13 @@ def client_table(self): client_info_parsed["AuxAddress"] = decode(value) elif field == b"local_scheduler_socket_name": client_info_parsed["LocalSchedulerSocketName"] = ( - decode(value)) + decode(value) + ) elif client_info[b"client_type"] == b"local_scheduler": # The remaining fields are resource types. 
client_info_parsed[field.decode("ascii")] = float( - decode(value)) + decode(value) + ) else: client_info_parsed[field.decode("ascii")] = decode(value) @@ -494,17 +518,21 @@ def task_profiles(self, num_tasks, start=None, end=None, fwd=True): if start is None and end is None: if fwd: event_list = self.redis_client.zrange( - event_log_set, **params) + event_log_set, **params + ) else: event_list = self.redis_client.zrevrange( - event_log_set, **params) + event_log_set, **params + ) else: if fwd: event_list = self.redis_client.zrangebyscore( - event_log_set, **params) + event_log_set, **params + ) else: event_list = self.redis_client.zrevrangebyscore( - event_log_set, **params) + event_log_set, **params + ) for (event, score) in event_list: event_dict = json.loads(event.decode()) @@ -524,11 +552,15 @@ def task_profiles(self, num_tasks, start=None, end=None, fwd=True): task_info[task_id]["get_task_start"] = event[0] if event[1] == "ray:get_task" and event[2] == 2: task_info[task_id]["get_task_end"] = event[0] - if (event[1] == "ray:import_remote_function" - and event[2] == 1): + if ( + event[1] == "ray:import_remote_function" + and event[2] == 1 + ): task_info[task_id]["import_remote_start"] = event[0] - if (event[1] == "ray:import_remote_function" - and event[2] == 2): + if ( + event[1] == "ray:import_remote_function" + and event[2] == 2 + ): task_info[task_id]["import_remote_end"] = event[0] if event[1] == "ray:acquire_lock" and event[2] == 1: task_info[task_id]["acquire_lock_start"] = event[0] @@ -550,7 +582,8 @@ def task_profiles(self, num_tasks, start=None, end=None, fwd=True): task_info[task_id]["worker_id"] = event[3]["worker_id"] if "function_name" in event[3]: task_info[task_id]["function_name"] = ( - event[3]["function_name"]) + event[3]["function_name"] + ) if heap_size > num_tasks: min_task, task_id_hex = heapq.heappop(heap) @@ -562,12 +595,9 @@ def task_profiles(self, num_tasks, start=None, end=None, fwd=True): return task_info - def dump_catapult_trace(self, - path, - task_info, - breakdowns=True, - task_dep=True, - obj_dep=True): + def dump_catapult_trace( + self, path, task_info, breakdowns=True, task_dep=True, obj_dep=True + ): """Dump task profiling information to a file. This information can be viewed as a timeline of profiling information @@ -595,8 +625,7 @@ def dump_catapult_trace(self, # slider should be correct to begin with, though. 
task_table[task_id] = self.task_table(task_id) task_table[task_id]["TaskSpec"]["Args"] = [ - repr(arg) - for arg in task_table[task_id]["TaskSpec"]["Args"] + repr(arg) for arg in task_table[task_id]["TaskSpec"]["Args"] ] except Exception as e: print("Could not find task {}".format(task_id)) @@ -637,16 +666,20 @@ def micros_rel(ts): ] total_info["LocalSchedulerID"] = task_t_info["LocalSchedulerID"] total_info["get_arguments"] = ( - info["get_arguments_end"] - info["get_arguments_start"]) + info["get_arguments_end"] - info["get_arguments_start"] + ) total_info["execute"] = ( - info["execute_end"] - info["execute_start"]) + info["execute_end"] - info["execute_start"] + ) total_info["store_outputs"] = ( - info["store_outputs_end"] - info["store_outputs_start"]) + info["store_outputs_end"] - info["store_outputs_start"] + ) total_info["function_name"] = info["function_name"] total_info["worker_id"] = info["worker_id"] parent_info = task_info.get( - task_table[task_id]["TaskSpec"]["ParentTaskID"]) + task_table[task_id]["TaskSpec"]["ParentTaskID"] + ) worker = workers[info["worker_id"]] # The catapult trace format documentation can be found here: # https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview # noqa: E501 @@ -670,8 +703,10 @@ def micros_rel(ts): "args": total_info, "dur": - micros(info["get_arguments_end"] - - info["get_arguments_start"]), + micros( + info["get_arguments_end"] - + info["get_arguments_start"] + ), "cname": "rail_idle" } @@ -696,8 +731,10 @@ def micros_rel(ts): "args": total_info, "dur": - micros(info["store_outputs_end"] - - info["store_outputs_start"]), + micros( + info["store_outputs_end"] - + info["store_outputs_start"] + ), "cname": "thread_state_runnable" } @@ -705,26 +742,17 @@ def micros_rel(ts): if "execute_end" in info: execute_trace = { - "cat": - "execute", - "pid": - "Node " + worker["node_ip_address"], - "tid": - info["worker_id"], - "id": - task_id, - "ts": - micros_rel(info["execute_start"]), - "ph": - "X", - "name": - info["function_name"] + ":execute", - "args": - total_info, + "cat": "execute", + "pid": "Node " + worker["node_ip_address"], + "tid": info["worker_id"], + "id": task_id, + "ts": micros_rel(info["execute_start"]), + "ph": "X", + "name": info["function_name"] + ":execute", + "args": total_info, "dur": micros(info["execute_end"] - info["execute_start"]), - "cname": - "rail_animation" + "cname": "rail_animation" } full_trace.append(execute_trace) @@ -733,7 +761,8 @@ def micros_rel(ts): parent_worker = workers[parent_info["worker_id"]] parent_times = self._get_times(parent_info) parent_profile = task_info.get( - task_table[task_id]["TaskSpec"]["ParentTaskID"]) + task_table[task_id]["TaskSpec"]["ParentTaskID"] + ) parent = { "cat": "submit_task", @@ -742,39 +771,35 @@ def micros_rel(ts): "tid": parent_info["worker_id"], "ts": - micros_rel(parent_profile - and parent_profile["get_arguments_start"] - or start_time), + micros_rel( + parent_profile + and parent_profile["get_arguments_start"] + or start_time + ), "ph": "s", "name": "SubmitTask", "args": {}, - "id": (parent_info["worker_id"] + - str(micros(min(parent_times)))) + "id": ( + parent_info["worker_id"] + + str(micros(min(parent_times))) + ) } full_trace.append(parent) task_trace = { - "cat": - "submit_task", - "pid": - "Node " + worker["node_ip_address"], - "tid": - info["worker_id"], - "ts": - micros_rel(info["get_arguments_start"]), - "ph": - "f", - "name": - "SubmitTask", + "cat": "submit_task", + "pid": "Node " + worker["node_ip_address"], + "tid": 
info["worker_id"], + "ts": micros_rel(info["get_arguments_start"]), + "ph": "f", + "name": "SubmitTask", "args": {}, "id": (info["worker_id"] + str(micros(min(parent_times)))), - "bp": - "e", - "cname": - "olive" + "bp": "e", + "cname": "olive" } full_trace.append(task_trace) @@ -796,8 +821,9 @@ def micros_rel(ts): "args": total_info, "dur": - micros(info["store_outputs_end"] - - info["get_arguments_start"]), + micros( + info["store_outputs_end"] - info["get_arguments_start"] + ), "cname": "thread_state_runnable" } @@ -808,7 +834,8 @@ def micros_rel(ts): parent_worker = workers[parent_info["worker_id"]] parent_times = self._get_times(parent_info) parent_profile = task_info.get( - task_table[task_id]["TaskSpec"]["ParentTaskID"]) + task_table[task_id]["TaskSpec"]["ParentTaskID"] + ) parent = { "cat": "submit_task", @@ -817,37 +844,34 @@ def micros_rel(ts): "tid": parent_info["worker_id"], "ts": - micros_rel(parent_profile - and parent_profile["get_arguments_start"] - or start_time), + micros_rel( + parent_profile + and parent_profile["get_arguments_start"] + or start_time + ), "ph": "s", "name": "SubmitTask", "args": {}, - "id": (parent_info["worker_id"] + - str(micros(min(parent_times)))) + "id": ( + parent_info["worker_id"] + + str(micros(min(parent_times))) + ) } full_trace.append(parent) task_trace = { - "cat": - "submit_task", - "pid": - "Node " + worker["node_ip_address"], - "tid": - info["worker_id"], - "ts": - micros_rel(info["get_arguments_start"]), - "ph": - "f", - "name": - "SubmitTask", + "cat": "submit_task", + "pid": "Node " + worker["node_ip_address"], + "tid": info["worker_id"], + "ts": micros_rel(info["get_arguments_start"]), + "ph": "f", + "name": "SubmitTask", "args": {}, "id": (info["worker_id"] + str(micros(min(parent_times)))), - "bp": - "e" + "bp": "e" } full_trace.append(task_trace) @@ -865,8 +889,9 @@ def micros_rel(ts): seen_obj[arg] += 1 owner_task = self._object_table(arg)["TaskID"] if owner_task in task_info: - owner_worker = (workers[task_info[owner_task][ - "worker_id"]]) + owner_worker = ( + workers[task_info[owner_task]["worker_id"]] + ) # Adding/subtracting 2 to the time associated # with the beginning/ending of the flow event # is necessary to make the flow events show up @@ -882,13 +907,14 @@ def micros_rel(ts): owner = { "cat": "obj_dependency", - "pid": ("Node " + - owner_worker["node_ip_address"]), + "pid": + ("Node " + owner_worker["node_ip_address"]), "tid": task_info[owner_task]["worker_id"], "ts": - micros_rel(task_info[owner_task] - ["store_outputs_end"]) - 2, + micros_rel( + task_info[owner_task]["store_outputs_end"] + ) - 2, "ph": "s", "name": @@ -971,8 +997,10 @@ def local_schedulers(self): local_schedulers = [] for ip_address, client_list in clients.items(): for client in client_list: - if (client["ClientType"] == "local_scheduler" - and not client["Deleted"]): + if ( + client["ClientType"] == "local_scheduler" + and not client["Deleted"] + ): local_schedulers.append(client) return local_schedulers @@ -988,19 +1016,21 @@ def workers(self): workers_data[worker_id] = { "local_scheduler_socket": (worker_info[b"local_scheduler_socket"].decode("ascii")), - "node_ip_address": (worker_info[b"node_ip_address"] - .decode("ascii")), - "plasma_manager_socket": (worker_info[b"plasma_manager_socket"] - .decode("ascii")), - "plasma_store_socket": (worker_info[b"plasma_store_socket"] - .decode("ascii")) + "node_ip_address": + (worker_info[b"node_ip_address"].decode("ascii")), + "plasma_manager_socket": + 
(worker_info[b"plasma_manager_socket"].decode("ascii")), + "plasma_store_socket": + (worker_info[b"plasma_store_socket"].decode("ascii")) } if b"stderr_file" in worker_info: workers_data[worker_id]["stderr_file"] = ( - worker_info[b"stderr_file"].decode("ascii")) + worker_info[b"stderr_file"].decode("ascii") + ) if b"stdout_file" in worker_info: workers_data[worker_id]["stdout_file"] = ( - worker_info[b"stdout_file"].decode("ascii")) + worker_info[b"stdout_file"].decode("ascii") + ) return workers_data def actors(self): @@ -1027,15 +1057,18 @@ def _job_length(self): num_tasks = 0 for event_log_set in event_log_sets: fwd_range = self.redis_client.zrange( - event_log_set, start=0, end=0, withscores=True) + event_log_set, start=0, end=0, withscores=True + ) overall_smallest = min(overall_smallest, fwd_range[0][1]) rev_range = self.redis_client.zrevrange( - event_log_set, start=0, end=0, withscores=True) + event_log_set, start=0, end=0, withscores=True + ) overall_largest = max(overall_largest, rev_range[0][1]) num_tasks += self.redis_client.zcount( - event_log_set, min=0, max=time.time()) + event_log_set, min=0, max=time.time() + ) if num_tasks is 0: return 0, 0, 0 return overall_smallest, overall_largest, num_tasks @@ -1056,8 +1089,8 @@ def cluster_resources(self): for local_scheduler in local_schedulers: for key, value in local_scheduler.items(): if key not in [ - "ClientType", "Deleted", "DBClientID", "AuxAddress", - "LocalSchedulerSocketName" + "ClientType", "Deleted", "DBClientID", "AuxAddress", + "LocalSchedulerSocketName" ]: resources[key] += value diff --git a/python/ray/experimental/tfutils.py b/python/ray/experimental/tfutils.py index 10d5fb4bc308..9ce2388147cd 100644 --- a/python/ray/experimental/tfutils.py +++ b/python/ray/experimental/tfutils.py @@ -92,7 +92,8 @@ def __init__(self, loss, sess=None, input_variables=None): self.placeholders[k] = tf.placeholder( var.value().dtype, var.get_shape().as_list(), - name="Placeholder_" + k) + name="Placeholder_" + k + ) self.assignment_nodes[k] = var.assign(self.placeholders[k]) def set_session(self, sess): @@ -115,10 +116,12 @@ def get_flat_size(self): def _check_sess(self): """Checks if the session is set, and if not throw an error message.""" - assert self.sess is not None, ("The session is not set. Set the " - "session either by passing it into the " - "TensorFlowVariables constructor or by " - "calling set_session(sess).") + assert self.sess is not None, ( + "The session is not set. Set the " + "session either by passing it into the " + "TensorFlowVariables constructor or by " + "calling set_session(sess)." + ) def get_flat(self): """Gets the weights and returns them as a flat array. @@ -128,8 +131,7 @@ def get_flat(self): """ self._check_sess() return np.concatenate([ - v.eval(session=self.sess).flatten() - for v in self.variables.values() + v.eval(session=self.sess).flatten() for v in self.variables.values() ]) def set_flat(self, new_weights): @@ -145,12 +147,11 @@ def set_flat(self, new_weights): self._check_sess() shapes = [v.get_shape().as_list() for v in self.variables.values()] arrays = unflatten(new_weights, shapes) - placeholders = [ - self.placeholders[k] for k, v in self.variables.items() - ] + placeholders = [self.placeholders[k] for k, v in self.variables.items()] self.sess.run( list(self.assignment_nodes.values()), - feed_dict=dict(zip(placeholders, arrays))) + feed_dict=dict(zip(placeholders, arrays)) + ) def get_weights(self): """Returns a dictionary containing the weights of the network. 
@@ -159,10 +160,7 @@ def get_weights(self): Dictionary mapping variable names to their weights. """ self._check_sess() - return { - k: v.eval(session=self.sess) - for k, v in self.variables.items() - } + return {k: v.eval(session=self.sess) for k, v in self.variables.items()} def set_weights(self, new_weights): """Sets the weights to new_weights. @@ -177,18 +175,22 @@ def set_weights(self, new_weights): """ self._check_sess() assign_list = [ - self.assignment_nodes[name] for name in new_weights.keys() + self.assignment_nodes[name] + for name in new_weights.keys() if name in self.assignment_nodes ] - assert assign_list, ("No variables in the input matched those in the " - "network. Possible cause: Two networks were " - "defined in the same TensorFlow graph. To fix " - "this, place each network definition in its own " - "tf.Graph.") + assert assign_list, ( + "No variables in the input matched those in the " + "network. Possible cause: Two networks were " + "defined in the same TensorFlow graph. To fix " + "this, place each network definition in its own " + "tf.Graph." + ) self.sess.run( assign_list, feed_dict={ self.placeholders[name]: value for (name, value) in new_weights.items() if name in self.placeholders - }) + } + ) diff --git a/python/ray/experimental/ui.py b/python/ray/experimental/ui.py index adae70692265..d42c8a2d8a35 100644 --- a/python/ray/experimental/ui.py +++ b/python/ray/experimental/ui.py @@ -67,7 +67,8 @@ def get_sliders(update): breakdown_opt = widgets.Dropdown( options=[total_time_value, total_tasks_value], value=total_tasks_value, - description="Selection Options:") + description="Selection Options:" + ) # Display box for layout. total_time_box = widgets.VBox([start_box, end_box]) @@ -101,9 +102,11 @@ def update_wrapper(event): if event == INIT_EVENT: if breakdown_opt.value == total_tasks_value: num_tasks_box.value = -min(10000, num_tasks) - range_slider.value = (int( - 100 - (100. * -num_tasks_box.value) / num_tasks), - 100) + range_slider.value = ( + int( + 100 - (100. * -num_tasks_box.value) / num_tasks + ), 100 + ) else: low, high = map(lambda x: x / 100., range_slider.value) start_box.value = round(diff * low, 2) @@ -116,8 +119,9 @@ def update_wrapper(event): elif start_box.value < 0: start_box.value = 0 low, high = range_slider.value - range_slider.value = (int((start_box.value * 100.) / diff), - high) + range_slider.value = ( + int((start_box.value * 100.) / diff), high + ) # Event was triggered by a change in the end_box value. elif event["owner"] == end_box: @@ -126,8 +130,9 @@ def update_wrapper(event): elif end_box.value > diff: end_box.value = diff low, high = range_slider.value - range_slider.value = (low, - int((end_box.value * 100.) / diff)) + range_slider.value = ( + low, int((end_box.value * 100.) / diff) + ) # Event was triggered by a change in the breakdown options # toggle. @@ -141,9 +146,10 @@ def update_wrapper(event): # Make CSS display go back to the default settings. num_tasks_box.layout.display = None num_tasks_box.value = min(10000, num_tasks) - range_slider.value = (int( - 100 - (100. * num_tasks_box.value) / num_tasks), - 100) + range_slider.value = ( + int(100 - (100. * num_tasks_box.value) / num_tasks), + 100 + ) else: start_box.disabled = False end_box.disabled = False @@ -154,7 +160,8 @@ def update_wrapper(event): num_tasks_box.layout.display = 'none' range_slider.value = ( int((start_box.value * 100.) / diff), - int((end_box.value * 100.) / diff)) + int((end_box.value * 100.) 
/ diff) + ) # Event was triggered by a change in the range_slider # value. @@ -166,7 +173,8 @@ def update_wrapper(event): if old_low != new_low: range_slider.value = (new_low, 100) num_tasks_box.value = ( - -(100. - new_low) / 100. * num_tasks) + -(100. - new_low) / 100. * num_tasks + ) else: range_slider.value = (0, new_high) num_tasks_box.value = new_high / 100. * num_tasks @@ -179,11 +187,15 @@ def update_wrapper(event): elif event["owner"] == num_tasks_box: if num_tasks_box.value > 0: range_slider.value = ( - 0, int( - 100 * float(num_tasks_box.value) / num_tasks)) + 0, + int(100 * float(num_tasks_box.value) / num_tasks) + ) elif num_tasks_box.value < 0: - range_slider.value = (100 + int( - 100 * float(num_tasks_box.value) / num_tasks), 100) + range_slider.value = ( + 100 + + int(100 * float(num_tasks_box.value) / num_tasks), + 100 + ) if not update: return @@ -200,18 +212,21 @@ def update_wrapper(event): if breakdown_opt.value == total_time_value: tasks = _truncated_task_profiles( start=(smallest + diff * low), - end=(smallest + diff * high)) + end=(smallest + diff * high) + ) # (Querying based on % of total number of tasks that were # run.) elif breakdown_opt.value == total_tasks_value: if range_slider.value[0] == 0: tasks = _truncated_task_profiles( - num_tasks=(int(num_tasks * high)), fwd=True) + num_tasks=(int(num_tasks * high)), fwd=True + ) else: tasks = _truncated_task_profiles( num_tasks=(int(num_tasks * (high - low))), - fwd=False) + fwd=False + ) update(smallest, largest, num_tasks, tasks) @@ -227,8 +242,10 @@ def update_wrapper(event): update_wrapper(INIT_EVENT) # Display sliders and search boxes - display(breakdown_opt, - widgets.HBox([range_slider, total_time_box, num_tasks_box])) + display( + breakdown_opt, + widgets.HBox([range_slider, total_time_box, num_tasks_box]) + ) # Return the sliders and text boxes return start_box, end_box, range_slider, breakdown_opt @@ -239,7 +256,8 @@ def object_search_bar(): value="", placeholder="Object ID", description="Search for an object:", - disabled=False) + disabled=False + ) display(object_search) def handle_submit(sender): @@ -254,7 +272,8 @@ def task_search_bar(): value="", placeholder="Task ID", description="Search for a task:", - disabled=False) + disabled=False + ) display(task_search) def handle_submit(sender): @@ -272,12 +291,17 @@ def handle_submit(sender): def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True): if num_tasks is None: num_tasks = MAX_TASKS_TO_VISUALIZE - print("Warning: at most {} tasks will be fetched within this " - "time range.".format(MAX_TASKS_TO_VISUALIZE)) + print( + "Warning: at most {} tasks will be fetched within this " + "time range.".format(MAX_TASKS_TO_VISUALIZE) + ) elif num_tasks > MAX_TASKS_TO_VISUALIZE: - print("Warning: too many tasks to visualize, " - "fetching only the first {} of {}.".format( - MAX_TASKS_TO_VISUALIZE, num_tasks)) + print( + "Warning: too many tasks to visualize, " + "fetching only the first {} of {}.".format( + MAX_TASKS_TO_VISUALIZE, num_tasks + ) + ) num_tasks = MAX_TASKS_TO_VISUALIZE return ray.global_state.task_profiles(num_tasks, start, end, fwd) @@ -286,7 +310,8 @@ def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True): # Prevents clashes in task trace files when multiple notebooks are running. 
def _get_temp_file_path(**kwargs): temp_file = tempfile.NamedTemporaryFile( - delete=False, dir=os.getcwd(), **kwargs) + delete=False, dir=os.getcwd(), **kwargs + ) temp_file_path = temp_file.name temp_file.close() return os.path.relpath(temp_file_path) @@ -304,16 +329,21 @@ def task_timeline(): disabled=False, ) obj_dep = widgets.Checkbox( - value=True, disabled=False, layout=widgets.Layout(width='20px')) + value=True, disabled=False, layout=widgets.Layout(width='20px') + ) task_dep = widgets.Checkbox( - value=True, disabled=False, layout=widgets.Layout(width='20px')) + value=True, disabled=False, layout=widgets.Layout(width='20px') + ) # Labels to bypass width limitation for descriptions. label_tasks = widgets.Label( - value='Task submissions', layout=widgets.Layout(width='110px')) + value='Task submissions', layout=widgets.Layout(width='110px') + ) label_objects = widgets.Label( - value='Object dependencies', layout=widgets.Layout(width='130px')) + value='Object dependencies', layout=widgets.Layout(width='130px') + ) label_options = widgets.Label( - value='View options:', layout=widgets.Layout(width='100px')) + value='View options:', layout=widgets.Layout(width='100px') + ) start_box, end_box, range_slider, time_opt = get_sliders(False) display(widgets.HBox([task_dep, label_tasks, obj_dep, label_objects])) display(widgets.HBox([label_options, breakdown_opt])) @@ -325,8 +355,9 @@ def task_timeline(): shutil.copy( os.path.join( os.path.dirname(os.path.abspath(__file__)), - "../core/src/catapult_files/trace_viewer_full.html"), - "trace_viewer_full.html") + "../core/src/catapult_files/trace_viewer_full.html" + ), "trace_viewer_full.html" + ) def handle_submit(sender): json_tmp = tempfile.mktemp() + ".json" @@ -337,8 +368,9 @@ def handle_submit(sender): elif breakdown_opt.value == breakdown_task: breakdown = True else: - raise ValueError("Unexpected breakdown value '{}'".format( - breakdown_opt.value)) + raise ValueError( + "Unexpected breakdown value '{}'".format(breakdown_opt.value) + ) low, high = map(lambda x: x / 100., range_slider.value) @@ -347,32 +379,40 @@ def handle_submit(sender): if time_opt.value == total_time_value: tasks = _truncated_task_profiles( - start=smallest + diff * low, end=smallest + diff * high) + start=smallest + diff * low, end=smallest + diff * high + ) elif time_opt.value == total_tasks_value: if range_slider.value[0] == 0: tasks = _truncated_task_profiles( - num_tasks=int(num_tasks * high), fwd=True) + num_tasks=int(num_tasks * high), fwd=True + ) else: tasks = _truncated_task_profiles( - num_tasks=int(num_tasks * (high - low)), fwd=False) + num_tasks=int(num_tasks * (high - low)), fwd=False + ) else: - raise ValueError("Unexpected time value '{}'".format( - time_opt.value)) + raise ValueError( + "Unexpected time value '{}'".format(time_opt.value) + ) # Write trace to a JSON file print("Collected profiles for {} tasks.".format(len(tasks))) - print("Dumping task profile data to {}, " - "this might take a while...".format(json_tmp)) + print( + "Dumping task profile data to {}, " + "this might take a while...".format(json_tmp) + ) ray.global_state.dump_catapult_trace( json_tmp, tasks, breakdowns=breakdown, obj_dep=obj_dep.value, - task_dep=task_dep.value) + task_dep=task_dep.value + ) print("Opening html file in browser...") trace_viewer_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), - "../core/src/catapult_files/index.html") + "../core/src/catapult_files/index.html" + ) html_file_path = _get_temp_file_path(suffix=".html") json_file_path = 
_get_temp_file_path(suffix=".json") @@ -393,8 +433,10 @@ def handle_submit(sender): # Display the task trace within the Jupyter notebook clear_output(wait=True) - print("To view fullscreen, open chrome://tracing in Google Chrome " - "and load `{}`".format(json_tmp)) + print( + "To view fullscreen, open chrome://tracing in Google Chrome " + "and load `{}`".format(json_tmp) + ) display(IFrame(html_file_path, 900, 800)) path_input.on_click(handle_submit) @@ -414,7 +456,8 @@ def task_completion_time_distribution(): tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"], background_fill_color="#FFFFFF", x_range=(0, 1), - y_range=(0, 1)) + y_range=(0, 1) + ) # Create the data source that the plot pulls from source = ColumnDataSource(data={"top": [], "left": [], "right": []}) @@ -427,7 +470,8 @@ def task_completion_time_distribution(): right="right", source=source, fill_color="#B3B3B3", - line_color="#033649") + line_color="#033649" + ) # Label the plot axes p.xaxis.axis_label = "Duration in seconds" @@ -439,12 +483,15 @@ def task_completion_time_distribution(): ncols=1, plot_width=500, plot_height=500, - toolbar_location="below"), - notebook_handle=True) + toolbar_location="below" + ), + notebook_handle=True + ) # Function to update the plot - def task_completion_time_update(abs_earliest, abs_latest, abs_num_tasks, - tasks): + def task_completion_time_update( + abs_earliest, abs_latest, abs_num_tasks, tasks + ): if len(tasks) == 0: return @@ -452,7 +499,8 @@ def task_completion_time_update(abs_earliest, abs_latest, abs_num_tasks, distr = [] for task_id, data in tasks.items(): distr.append( - data["store_outputs_end"] - data["get_arguments_start"]) + data["store_outputs_end"] - data["get_arguments_start"] + ) # Create a histogram from the distribution top, bin_edges = np.histogram(distr, bins="auto") @@ -462,8 +510,9 @@ def task_completion_time_update(abs_earliest, abs_latest, abs_num_tasks, source.data = {"top": top, "left": left, "right": right} # Set the x and y ranges - x_range = (min(left) if len(left) else 0, max(right) - if len(right) else 1) + x_range = ( + min(left) if len(left) else 0, max(right) if len(right) else 1 + ) y_range = (0, max(top) + 1 if len(top) else 1) x_range = helpers._get_range(x_range) @@ -480,12 +529,14 @@ def task_completion_time_update(abs_earliest, abs_latest, abs_num_tasks, get_sliders(task_completion_time_update) -def compute_utilizations(abs_earliest, - abs_latest, - num_tasks, - tasks, - num_buckets, - use_abs_times=False): +def compute_utilizations( + abs_earliest, + abs_latest, + num_tasks, + tasks, + num_buckets, + use_abs_times=False +): if len(tasks) == 0: return [], [], [] @@ -515,21 +566,24 @@ def compute_utilizations(abs_earliest, task_end_time = data["store_outputs_end"] start_bucket = int( - (task_start_time - earliest_time) / bucket_time_length) + (task_start_time - earliest_time) / bucket_time_length + ) end_bucket = int((task_end_time - earliest_time) / bucket_time_length) # Walk over each time bucket that this task intersects, adding the # amount of time that the task intersects within each bucket for bucket_idx in range(start_bucket, end_bucket + 1): - bucket_start_time = (( - earliest_time + bucket_idx) * bucket_time_length) + bucket_start_time = ((earliest_time + bucket_idx) * + bucket_time_length) bucket_end_time = ((earliest_time + (bucket_idx + 1)) * bucket_time_length) - task_start_time_within_bucket = max(task_start_time, - bucket_start_time) + task_start_time_within_bucket = max( + task_start_time, bucket_start_time + ) 
task_end_time_within_bucket = min(task_end_time, bucket_end_time) task_cpu_time_within_bucket = ( - task_end_time_within_bucket - task_start_time_within_bucket) + task_end_time_within_bucket - task_start_time_within_bucket + ) if bucket_idx > -1 and bucket_idx < num_buckets: cpu_time[bucket_idx] += task_cpu_time_within_bucket @@ -537,7 +591,8 @@ def compute_utilizations(abs_earliest, # Cpu_utilization is the average cpu utilization of the bucket, which # is just cpu_time divided by bucket_time_length. cpu_utilization = list( - map(lambda x: x / float(bucket_time_length), cpu_time)) + map(lambda x: x / float(bucket_time_length), cpu_time) + ) # Generate histogram bucket edges. Subtract out abs_earliest to get # relative time. @@ -577,11 +632,13 @@ def plot_utilization(): tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"], background_fill_color="#FFFFFF", x_range=[0, 1], - y_range=[0, 1]) + y_range=[0, 1] + ) # Create the data source that the plot will pull from time_series_source = ColumnDataSource( - data=dict(left=[], right=[], top=[])) + data=dict(left=[], right=[], top=[]) + ) # Plot the rectangles representing the distribution time_series_fig.quad( @@ -591,7 +648,8 @@ def plot_utilization(): bottom=0, source=time_series_source, fill_color="#B3B3B3", - line_color="#033649") + line_color="#033649" + ) # Label the plot axes time_series_fig.xaxis.axis_label = "Time in seconds" @@ -603,22 +661,23 @@ def plot_utilization(): ncols=1, plot_width=500, plot_height=500, - toolbar_location="below"), - notebook_handle=True) + toolbar_location="below" + ), + notebook_handle=True + ) def update_plot(abs_earliest, abs_latest, abs_num_tasks, tasks): num_buckets = 100 left, right, top = compute_utilizations( - abs_earliest, abs_latest, abs_num_tasks, tasks, num_buckets) + abs_earliest, abs_latest, abs_num_tasks, tasks, num_buckets + ) - time_series_source.data = { - "left": left, - "right": right, - "top": top - } + time_series_source.data = {"left": left, "right": right, "top": top} - x_range = (max(0, min(left)) if len(left) else 0, max(right) - if len(right) else 1) + x_range = ( + max(0, min(left)) if len(left) else 0, max(right) + if len(right) else 1 + ) y_range = (0, max(top) + 1 if len(top) else 1) # Define the axis ranges @@ -659,7 +718,8 @@ def cluster_usage(): "time": ['0.5'], "num_tasks": ['1'], "length": [1] - }) + } + ) # Define the color schema colors = [ @@ -678,7 +738,8 @@ def cluster_usage(): plot_width=900, plot_height=500, tools=TOOLS, - toolbar_location='below') + toolbar_location='below' + ) # Format the plot axes p.grid.grid_line_color = None @@ -699,7 +760,8 @@ def cluster_usage(): "field": "num_tasks", "transform": mapper }, - line_color=None) + line_color=None + ) # Add legend to the side of the plot color_bar = ColorBar( @@ -708,14 +770,15 @@ def cluster_usage(): ticker=BasicTicker(desired_num_ticks=len(colors)), label_standoff=6, border_line_color=None, - location=(0, 0)) + location=(0, 0) + ) p.add_layout(color_bar, "right") # Define hover tool - p.select_one(HoverTool).tooltips = [("Node IP Address", - "@node_ip_address"), - ("Number of tasks running", - "@num_tasks"), ("Time", "@time")] + p.select_one(HoverTool).tooltips = [ + ("Node IP Address", "@node_ip_address"), + ("Number of tasks running", "@num_tasks"), ("Time", "@time") + ] # Define the axis labels p.xaxis.axis_label = "Time in seconds" @@ -752,7 +815,8 @@ def heat_map_update(abs_earliest, abs_latest, abs_num_tasks, tasks): for node_ip, task_dict in node_to_tasks.items(): left, right, top = 
compute_utilizations( - earliest, latest, abs_num_tasks, task_dict, 100, True) + earliest, latest, abs_num_tasks, task_dict, 100, True + ) for (l, r, t) in zip(left, right, top): nodes.append(node_ip) times.append((l + r) / 2) diff --git a/python/ray/global_scheduler/global_scheduler_services.py b/python/ray/global_scheduler/global_scheduler_services.py index 7e3d019ffa98..f30d97cf3ae8 100644 --- a/python/ray/global_scheduler/global_scheduler_services.py +++ b/python/ray/global_scheduler/global_scheduler_services.py @@ -7,12 +7,14 @@ import time -def start_global_scheduler(redis_address, - node_ip_address, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None): +def start_global_scheduler( + redis_address, + node_ip_address, + use_valgrind=False, + use_profiler=False, + stdout_file=None, + stderr_file=None +): """Start a global scheduler process. Args: @@ -35,25 +37,24 @@ def start_global_scheduler(redis_address, raise Exception("Cannot use valgrind and profiler at the same time.") global_scheduler_executable = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "../core/src/global_scheduler/global_scheduler") + "../core/src/global_scheduler/global_scheduler" + ) command = [ global_scheduler_executable, "-r", redis_address, "-h", node_ip_address ] if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) + pid = subprocess.Popen([ + "valgrind", "--track-origins=yes", "--leak-check=full", + "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", + "--error-exitcode=1" + ] + command, + stdout=stdout_file, + stderr=stderr_file) time.sleep(1.0) elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) + pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, + stdout=stdout_file, + stderr=stderr_file) time.sleep(1.0) else: pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index a05b7a9c7a3c..5a66ce8defbd 100644 --- a/python/ray/global_scheduler/test/test.py +++ b/python/ray/global_scheduler/test/test.py @@ -59,8 +59,7 @@ class TestGlobalScheduler(unittest.TestCase): def setUp(self): # Start one Redis server and N pairs of (plasma, local_scheduler) self.node_ip_address = "127.0.0.1" - redis_address, redis_shards = services.start_redis( - self.node_ip_address) + redis_address, redis_shards = services.start_redis(self.node_ip_address) redis_port = services.get_port(redis_address) time.sleep(0.1) # Create a client for the global state store. @@ -69,7 +68,8 @@ def setUp(self): # Start one global scheduler. self.p1 = global_scheduler.start_global_scheduler( - redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND) + redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND + ) self.plasma_store_pids = [] self.plasma_manager_pids = [] self.local_scheduler_pids = [] @@ -83,14 +83,17 @@ def setUp(self): # Start the Plasma manager. # Assumption: Plasma manager name and port are randomly generated # by the plasma module. 
- manager_info = plasma.start_plasma_manager(plasma_store_name, - redis_address) + manager_info = plasma.start_plasma_manager( + plasma_store_name, redis_address + ) plasma_manager_name, p3, plasma_manager_port = manager_info self.plasma_manager_pids.append(p3) - plasma_address = "{}:{}".format(self.node_ip_address, - plasma_manager_port) - plasma_client = pa.plasma.connect(plasma_store_name, - plasma_manager_name, 64) + plasma_address = "{}:{}".format( + self.node_ip_address, plasma_manager_port + ) + plasma_client = pa.plasma.connect( + plasma_store_name, plasma_manager_name, 64 + ) self.plasma_clients.append(plasma_client) # Start the local scheduler. local_scheduler_name, p4 = local_scheduler.start_local_scheduler( @@ -98,10 +101,12 @@ def setUp(self): plasma_manager_name=plasma_manager_name, plasma_address=plasma_address, redis_address=redis_address, - static_resources={"CPU": 10}) + static_resources={"CPU": 10} + ) # Connect to the scheduler. local_scheduler_client = local_scheduler.LocalSchedulerClient( - local_scheduler_name, NIL_WORKER_ID, False) + local_scheduler_name, NIL_WORKER_ID, False + ) self.local_scheduler_clients.append(local_scheduler_client) self.local_scheduler_pids.append(p4) @@ -115,8 +120,8 @@ def tearDown(self): for p4 in self.local_scheduler_pids: self.assertEqual(p4.poll(), None) - redis_processes = services.all_processes[ - services.PROCESS_TYPE_REDIS_SERVER] + redis_processes = services.all_processes[services.PROCESS_TYPE_REDIS_SERVER + ] for redis_process in redis_processes: self.assertEqual(redis_process.poll(), None) @@ -165,7 +170,8 @@ def get_plasma_manager_id(self): def test_task_default_resources(self): task1 = local_scheduler.Task( random_driver_id(), random_function_id(), [random_object_id()], 0, - random_task_id(), 0) + random_task_id(), 0 + ) self.assertEqual(task1.required_resources(), {"CPU": 1}) task2 = local_scheduler.Task( random_driver_id(), random_function_id(), [random_object_id()], 0, @@ -175,7 +181,8 @@ def test_task_default_resources(self): local_scheduler.ObjectID(NIL_ACTOR_ID), 0, 0, [], { "CPU": 1, "GPU": 2 - }) + } + ) self.assertEqual(task2.required_resources(), {"CPU": 1, "GPU": 2}) def test_redis_only_single_task(self): @@ -188,19 +195,22 @@ def test_redis_only_single_task(self): # scheduler and one plasma per node. self.assertEqual( len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) + 2 * NUM_CLUSTER_NODES + 1 + ) db_client_id = self.get_plasma_manager_id() assert (db_client_id is not None) @unittest.skipIf( os.environ.get('RAY_USE_NEW_GCS', False), - "New GCS API doesn't have a Python API yet.") + "New GCS API doesn't have a Python API yet." + ) def test_integration_single_task(self): # There should be three db clients, the global scheduler, the local # scheduler, and the plasma manager. self.assertEqual( len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) + 2 * NUM_CLUSTER_NODES + 1 + ) num_return_vals = [0, 1, 2, 3, 5, 10] # Insert the object into Redis. @@ -208,15 +218,17 @@ def test_integration_single_task(self): metadata_size = 0x40 plasma_client = self.plasma_clients[0] object_dep, memory_buffer, metadata = create_object( - plasma_client, data_size, metadata_size, seal=True) + plasma_client, data_size, metadata_size, seal=True + ) # Sleep before submitting task to local scheduler. time.sleep(0.1) # Submit a task to Redis. 
task = local_scheduler.Task( random_driver_id(), random_function_id(), - [local_scheduler.ObjectID(object_dep.binary())], - num_return_vals[0], random_task_id(), 0) + [local_scheduler.ObjectID(object_dep.binary())], num_return_vals[0], + random_task_id(), 0 + ) self.local_scheduler_clients[0].submit(task) time.sleep(0.1) # There should now be a task in Redis, and it should get assigned to @@ -228,10 +240,12 @@ def test_integration_single_task(self): if len(task_entries) == 1: task_id, task = task_entries.popitem() task_status = task["State"] - self.assertTrue(task_status in [ - state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED, - state.TASK_STATUS_QUEUED - ]) + self.assertTrue( + task_status in [ + state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED, + state.TASK_STATUS_QUEUED + ] + ) if task_status == state.TASK_STATUS_QUEUED: break else: @@ -250,7 +264,8 @@ def integration_many_tasks_helper(self, timesync=True): # scheduler, and the plasma manager. self.assertEqual( len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) + 2 * NUM_CLUSTER_NODES + 1 + ) num_return_vals = [0, 1, 2, 3, 5, 10] # Submit a bunch of tasks to Redis. @@ -261,7 +276,8 @@ def integration_many_tasks_helper(self, timesync=True): metadata_size = np.random.randint(1 << 9) plasma_client = self.plasma_clients[0] object_dep, memory_buffer, metadata = create_object( - plasma_client, data_size, metadata_size, seal=True) + plasma_client, data_size, metadata_size, seal=True + ) if timesync: # Give 10ms for object info handler to fire (long enough to # yield CPU). @@ -269,7 +285,8 @@ def integration_many_tasks_helper(self, timesync=True): task = local_scheduler.Task( random_driver_id(), random_function_id(), [local_scheduler.ObjectID(object_dep.binary())], - num_return_vals[0], random_task_id(), 0) + num_return_vals[0], random_task_id(), 0 + ) self.local_scheduler_clients[0].submit(task) # Check that there are the correct number of tasks in Redis and that # they all get assigned to the local scheduler. @@ -281,8 +298,7 @@ def integration_many_tasks_helper(self, timesync=True): # First, check if all tasks made it to Redis. if len(task_entries) == num_tasks: task_statuses = [ - task_entry["State"] - for task_entry in task_entries.values() + task_entry["State"] for task_entry in task_entries.values() ] self.assertTrue( all([ @@ -291,20 +307,26 @@ def integration_many_tasks_helper(self, timesync=True): state.TASK_STATUS_SCHEDULED, state.TASK_STATUS_QUEUED ] for status in task_statuses - ])) + ]) + ) num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED) num_tasks_scheduled = task_statuses.count( - state.TASK_STATUS_SCHEDULED) + state.TASK_STATUS_SCHEDULED + ) num_tasks_waiting = task_statuses.count( - state.TASK_STATUS_WAITING) - print("tasks in Redis = {}, tasks waiting = {}, " - "tasks scheduled = {}, " - "tasks queued = {}, retries left = {}".format( - len(task_entries), num_tasks_waiting, - num_tasks_scheduled, num_tasks_done, num_retries)) + state.TASK_STATUS_WAITING + ) + print( + "tasks in Redis = {}, tasks waiting = {}, " + "tasks scheduled = {}, " + "tasks queued = {}, retries left = {}".format( + len(task_entries), num_tasks_waiting, + num_tasks_scheduled, num_tasks_done, num_retries + ) + ) if all([ - status == state.TASK_STATUS_QUEUED - for status in task_statuses + status == state.TASK_STATUS_QUEUED + for status in task_statuses ]): # We're done, so pass. 
break @@ -317,13 +339,15 @@ def integration_many_tasks_helper(self, timesync=True): @unittest.skipIf( os.environ.get('RAY_USE_NEW_GCS', False), - "New GCS API doesn't have a Python API yet.") + "New GCS API doesn't have a Python API yet." + ) def test_integration_many_tasks_handler_sync(self): self.integration_many_tasks_helper(timesync=True) @unittest.skipIf( os.environ.get('RAY_USE_NEW_GCS', False), - "New GCS API doesn't have a Python API yet.") + "New GCS API doesn't have a Python API yet." + ) def test_integration_many_tasks(self): # More realistic case: should handle out of order object and task # notifications. diff --git a/python/ray/local_scheduler/__init__.py b/python/ray/local_scheduler/__init__.py index d3018dd5bb88..dfaf7e222789 100644 --- a/python/ray/local_scheduler/__init__.py +++ b/python/ray/local_scheduler/__init__.py @@ -4,7 +4,8 @@ from ray.core.src.local_scheduler.liblocal_scheduler_library import ( Task, LocalSchedulerClient, ObjectID, check_simple_value, task_from_string, - task_to_string, _config, common_error) + task_to_string, _config, common_error +) from .local_scheduler_services import start_local_scheduler __all__ = [ diff --git a/python/ray/local_scheduler/local_scheduler_services.py b/python/ray/local_scheduler/local_scheduler_services.py index 1f6b79a2279b..b267cc42b275 100644 --- a/python/ray/local_scheduler/local_scheduler_services.py +++ b/python/ray/local_scheduler/local_scheduler_services.py @@ -14,18 +14,20 @@ def random_name(): return str(random.randint(0, 99999999)) -def start_local_scheduler(plasma_store_name, - plasma_manager_name=None, - worker_path=None, - plasma_address=None, - node_ip_address="127.0.0.1", - redis_address=None, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None, - static_resources=None, - num_workers=0): +def start_local_scheduler( + plasma_store_name, + plasma_manager_name=None, + worker_path=None, + plasma_address=None, + node_ip_address="127.0.0.1", + redis_address=None, + use_valgrind=False, + use_profiler=False, + stdout_file=None, + stderr_file=None, + static_resources=None, + num_workers=0 +): """Start a local scheduler process. Args: @@ -63,14 +65,17 @@ def start_local_scheduler(plasma_store_name, the local scheduler process. """ if (plasma_manager_name is None) != (redis_address is None): - raise Exception("If one of the plasma_manager_name and the " - "redis_address is provided, then both must be " - "provided.") + raise Exception( + "If one of the plasma_manager_name and the " + "redis_address is provided, then both must be " + "provided." 
+ ) if use_valgrind and use_profiler: raise Exception("Cannot use valgrind and profiler at the same time.") local_scheduler_executable = os.path.join( os.path.dirname(os.path.abspath(__file__)), - "../core/src/local_scheduler/local_scheduler") + "../core/src/local_scheduler/local_scheduler" + ) local_scheduler_name = "/tmp/scheduler{}".format(random_name()) command = [ local_scheduler_executable, "-s", local_scheduler_name, "-p", @@ -83,16 +88,17 @@ def start_local_scheduler(plasma_store_name, assert plasma_store_name is not None assert plasma_manager_name is not None assert redis_address is not None - start_worker_command = ("{} {} " - "--node-ip-address={} " - "--object-store-name={} " - "--object-store-manager-name={} " - "--local-scheduler-name={} " - "--redis-address={}".format( - sys.executable, worker_path, - node_ip_address, plasma_store_name, - plasma_manager_name, local_scheduler_name, - redis_address)) + start_worker_command = ( + "{} {} " + "--node-ip-address={} " + "--object-store-name={} " + "--object-store-manager-name={} " + "--local-scheduler-name={} " + "--redis-address={}".format( + sys.executable, worker_path, node_ip_address, plasma_store_name, + plasma_manager_name, local_scheduler_name, redis_address + ) + ) command += ["-w", start_worker_command] if redis_address is not None: command += ["-r", redis_address] @@ -101,8 +107,10 @@ def start_local_scheduler(plasma_store_name, if static_resources is not None: resource_argument = "" for resource_name, resource_quantity in static_resources.items(): - assert (isinstance(resource_quantity, int) - or isinstance(resource_quantity, float)) + assert ( + isinstance(resource_quantity, int) + or isinstance(resource_quantity, float) + ) resource_argument = ",".join([ resource_name + "," + str(resource_quantity) for resource_name, resource_quantity in static_resources.items() @@ -112,20 +120,18 @@ def start_local_scheduler(plasma_store_name, command += ["-c", resource_argument] if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) + pid = subprocess.Popen([ + "valgrind", "--track-origins=yes", "--leak-check=full", + "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", + "--error-exitcode=1" + ] + command, + stdout=stdout_file, + stderr=stderr_file) time.sleep(1.0) elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) + pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, + stdout=stdout_file, + stderr=stderr_file) time.sleep(1.0) else: pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) diff --git a/python/ray/local_scheduler/test/test.py b/python/ray/local_scheduler/test/test.py index b990676c8edc..90766536f390 100644 --- a/python/ray/local_scheduler/test/test.py +++ b/python/ray/local_scheduler/test/test.py @@ -43,10 +43,12 @@ def setUp(self): self.plasma_client = pa.plasma.connect(plasma_store_name, "", 0) # Start a local scheduler. scheduler_name, self.p2 = local_scheduler.start_local_scheduler( - plasma_store_name, use_valgrind=USE_VALGRIND) + plasma_store_name, use_valgrind=USE_VALGRIND + ) # Connect to the scheduler. 
self.local_scheduler_client = local_scheduler.LocalSchedulerClient( - scheduler_name, NIL_WORKER_ID, False) + scheduler_name, NIL_WORKER_ID, False + ) def tearDown(self): # Check that the processes are still alive. @@ -87,15 +89,18 @@ def test_submit_and_get_task(self): for args in args_list: for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, - args, num_return_vals, - random_task_id(), 0) + task = local_scheduler.Task( + random_driver_id(), function_id, args, num_return_vals, + random_task_id(), 0 + ) # Submit a task. self.local_scheduler_client.submit(task) # Get the task. new_task = self.local_scheduler_client.get_task() - self.assertEqual(task.function_id().id(), - new_task.function_id().id()) + self.assertEqual( + task.function_id().id(), + new_task.function_id().id() + ) retrieved_args = new_task.arguments() returns = new_task.returns() self.assertEqual(len(args), len(retrieved_args)) @@ -109,9 +114,10 @@ def test_submit_and_get_task(self): # Submit all of the tasks. for args in args_list: for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, - args, num_return_vals, - random_task_id(), 0) + task = local_scheduler.Task( + random_driver_id(), function_id, args, num_return_vals, + random_task_id(), 0 + ) self.local_scheduler_client.submit(task) # Get all of the tasks. for args in args_list: @@ -121,8 +127,10 @@ def test_submit_and_get_task(self): def test_scheduling_when_objects_ready(self): # Create a task and submit it. object_id = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [object_id], 0, random_task_id(), 0) + task = local_scheduler.Task( + random_driver_id(), random_function_id(), + [object_id], 0, random_task_id(), 0 + ) self.local_scheduler_client.submit(task) # Launch a thread to get the task. @@ -145,9 +153,10 @@ def test_scheduling_when_objects_evicted(self): # Create a task with two dependencies and submit it. object_id1 = random_object_id() object_id2 = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [object_id1, object_id2], 0, - random_task_id(), 0) + task = local_scheduler.Task( + random_driver_id(), random_function_id(), + [object_id1, object_id2], 0, random_task_id(), 0 + ) self.local_scheduler_client.submit(task) # Launch a thread to get the task. @@ -171,8 +180,10 @@ def get_task(): time.sleep(0.1) self.assertTrue(t.is_alive()) # Check that the first object dependency was evicted. - object1 = self.plasma_client.get_buffers( - [pa.plasma.ObjectID(object_id1.id())], timeout_ms=0) + object1 = self.plasma_client.get_buffers([ + pa.plasma.ObjectID(object_id1.id()) + ], + timeout_ms=0) self.assertEqual(object1, [None]) # Check that the thread is still waiting for a task. 
time.sleep(0.1) diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index 465d87ef3e31..2afd99435e58 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -31,7 +31,8 @@ def __init__(self, redis_ip_address, redis_port, node_ip_address): """Initialize the log monitor object.""" self.node_ip_address = node_ip_address self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port + ) self.log_files = {} self.log_file_handles = {} self.files_to_ignore = set() @@ -39,8 +40,10 @@ def __init__(self, redis_ip_address, redis_port, node_ip_address): def update_log_filenames(self): """Get the most up-to-date list of log files to monitor from Redis.""" num_current_log_files = len(self.log_files) - new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format( - self.node_ip_address), num_current_log_files, -1) + new_log_filenames = self.redis_client.lrange( + "LOG_FILENAMES:{}".format(self.node_ip_address), + num_current_log_files, -1 + ) for log_filename in new_log_filenames: print("Beginning to track file {}".format(log_filename)) assert log_filename not in self.log_files @@ -54,13 +57,14 @@ def check_log_files_and_push_updates(self): new_lines = [] while True: current_position = ( - self.log_file_handles[log_filename].tell()) + self.log_file_handles[log_filename].tell() + ) next_line = self.log_file_handles[log_filename].readline() if next_line != "": new_lines.append(next_line) else: - self.log_file_handles[log_filename].seek( - current_position) + self.log_file_handles[log_filename + ].seek(current_position) break # If there are any new lines, cache them and also push them to @@ -68,7 +72,8 @@ def check_log_files_and_push_updates(self): if len(new_lines) > 0: self.log_files[log_filename] += new_lines redis_key = "LOGFILE:{}:{}".format( - self.node_ip_address, log_filename.decode("ascii")) + self.node_ip_address, log_filename.decode("ascii") + ) self.redis_client.rpush(redis_key, *new_lines) # Pass if we already failed to open the log file. @@ -79,14 +84,19 @@ def check_log_files_and_push_updates(self): else: try: self.log_file_handles[log_filename] = open( - log_filename, "r") + log_filename, "r" + ) except IOError as e: if e.errno == os.errno.EMFILE: - print("Warning: Ignoring {} because there are too " - "many open files.".format(log_filename)) + print( + "Warning: Ignoring {} because there are too " + "many open files.".format(log_filename) + ) elif e.errno == os.errno.ENOENT: - print("Warning: The file {} was not " - "found.".format(log_filename)) + print( + "Warning: The file {} was not " + "found.".format(log_filename) + ) else: raise e @@ -107,24 +117,28 @@ def run(self): if __name__ == "__main__": parser = argparse.ArgumentParser( - description=("Parse Redis server for the " - "log monitor to connect " - "to.")) + description=( + "Parse Redis server for the " + "log monitor to connect " + "to." + ) + ) parser.add_argument( "--redis-address", required=True, type=str, - help="The address to use for Redis.") + help="The address to use for Redis." + ) parser.add_argument( "--node-ip-address", required=True, type=str, - help="The IP address of the node this process is on.") + help="The IP address of the node this process is on." 
+ ) args = parser.parse_args() redis_ip_address = get_ip_address(args.redis_address) redis_port = get_port(args.redis_address) - log_monitor = LogMonitor(redis_ip_address, redis_port, - args.node_ip_address) + log_monitor = LogMonitor(redis_ip_address, redis_port, args.node_ip_address) log_monitor.run() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 20c5ce17b4cb..de1e84b6b9d5 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -85,7 +85,8 @@ def __init__(self, redis_address, redis_port, autoscaling_config): self.state = ray.experimental.state.GlobalState() self.state._initialize_global_state(redis_address, redis_port) self.redis = redis.StrictRedis( - host=redis_address, port=redis_port, db=0) + host=redis_address, port=redis_port, db=0 + ) # TODO(swang): Update pubsub client to use ray.experimental.state once # subscriptions are implemented there. self.subscribe_client = self.redis.pubsub() @@ -100,8 +101,9 @@ def __init__(self, redis_address, redis_port, autoscaling_config): self.local_scheduler_id_to_ip_map = dict() self.load_metrics = LoadMetrics() if autoscaling_config: - self.autoscaler = StandardAutoscaler(autoscaling_config, - self.load_metrics) + self.autoscaler = StandardAutoscaler( + autoscaling_config, self.load_metrics + ) else: self.autoscaler = None @@ -151,10 +153,13 @@ def cleanup_task_table(self): for manager in manager_ids: ok = self.state._execute_command( dummy_object_id, "RAY.OBJECT_TABLE_REMOVE", - dummy_object_id.id(), hex_to_binary(manager)) + dummy_object_id.id(), hex_to_binary(manager) + ) if ok != b"OK": - log.warn("Failed to remove object location for " - "dead plasma manager.") + log.warn( + "Failed to remove object location for " + "dead plasma manager." + ) # If the task is scheduled on a dead local scheduler, mark the # task as lost. @@ -162,7 +167,8 @@ def cleanup_task_table(self): ok = self.state._execute_command( key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id), ray.experimental.state.TASK_STATUS_LOST, NIL_ID, - task["ExecutionDependenciesString"], task["SpillbackCount"]) + task["ExecutionDependenciesString"], task["SpillbackCount"] + ) if ok != b"OK": log.warn("Failed to update lost task for dead scheduler.") num_tasks_updated += 1 @@ -189,13 +195,15 @@ def cleanup_object_table(self): if manager in self.dead_plasma_managers: # If the object was on a dead plasma manager, remove that # location entry. - ok = self.state._execute_command(object_id, - "RAY.OBJECT_TABLE_REMOVE", - object_id.id(), - hex_to_binary(manager)) + ok = self.state._execute_command( + object_id, "RAY.OBJECT_TABLE_REMOVE", object_id.id(), + hex_to_binary(manager) + ) if ok != b"OK": - log.warn("Failed to remove object location for dead " - "plasma manager.") + log.warn( + "Failed to remove object location for dead " + "plasma manager." + ) num_objects_removed += 1 if num_objects_removed > 0: log.warn("Marked {} objects as lost.".format(num_objects_removed)) @@ -233,8 +241,11 @@ def db_client_notification_handler(self, unused_channel, data): the associated state in the state tables should be handled by the caller. """ - notification_object = (SubscribeToDBClientTableReply. 
- GetRootAsSubscribeToDBClientTableReply(data, 0)) + notification_object = ( + SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply( + data, 0 + ) + ) db_client_id = binary_to_hex(notification_object.DbClientId()) client_type = notification_object.ClientType() is_insertion = notification_object.IsInsertion() @@ -260,7 +271,8 @@ def local_scheduler_info_handler(self, unused_channel, data): """Handle a local scheduler heartbeat from Redis.""" message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage( - data, 0) + data, 0 + ) num_resources = message.DynamicResourcesLength() static_resources = {} dynamic_resources = {} @@ -276,8 +288,7 @@ def local_scheduler_info_handler(self, unused_channel, data): if ip: self.load_metrics.update(ip, static_resources, dynamic_resources) else: - print("Warning: could not find ip for client {}." - .format(client_id)) + print("Warning: could not find ip for client {}.".format(client_id)) def plasma_manager_heartbeat_handler(self, unused_channel, data): """Handle a plasma manager heartbeat from Redis. @@ -364,11 +375,13 @@ def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index): num_deleted = redis.delete(*keys) log.info( "Removed {} dead redis entries of the driver from redis shard {}.". - format(num_deleted, shard_index)) + format(num_deleted, shard_index) + ) if num_deleted != len(keys): log.warning( "Failed to remove {} relevant redis entries" - " from redis shard {}.".format(len(keys) - num_deleted)) + " from redis shard {}.".format(len(keys) - num_deleted) + ) def _clean_up_entries_for_driver(self, driver_id): """Remove this driver's object/task entries from all redis shards. @@ -405,7 +418,8 @@ def _clean_up_entries_for_driver(self, driver_id): def ToShardIndex(index): return binary_to_object_id(index).redis_shard_hash() % len( - self.state.redis_clients) + self.state.redis_clients + ) for object_id in driver_object_ids: object_ids_per_shard[ToShardIndex(object_id)].append(object_id) @@ -416,7 +430,8 @@ def ToShardIndex(index): for shard_index in range(len(self.state.redis_clients)): self._clean_up_entries_from_shard( object_ids_per_shard[shard_index], - task_ids_per_shard[shard_index], shard_index) + task_ids_per_shard[shard_index], shard_index + ) def driver_removed_handler(self, unused_channel, data): """Handle a notification that a driver has been removed. @@ -426,8 +441,7 @@ def driver_removed_handler(self, unused_channel, data): """ message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0) driver_id = message.DriverId() - log.info("Driver {} has been removed.".format( - binary_to_hex(driver_id))) + log.info("Driver {} has been removed.".format(binary_to_hex(driver_id))) self._clean_up_entries_for_driver(driver_id) @@ -504,12 +518,15 @@ def run(self): self.cleanup_task_table() if len(self.dead_plasma_managers) > 0: self.cleanup_object_table() - log.debug("{} dead local schedulers, {} plasma managers total, {} " - "dead plasma managers".format( - len(self.dead_local_schedulers), - (len(self.live_plasma_managers) + - len(self.dead_plasma_managers)), - len(self.dead_plasma_managers))) + log.debug( + "{} dead local schedulers, {} plasma managers total, {} " + "dead plasma managers".format( + len(self.dead_local_schedulers), ( + len(self.live_plasma_managers) + + len(self.dead_plasma_managers) + ), len(self.dead_plasma_managers) + ) + ) # Handle messages from the subscription channels. 
while True: @@ -542,7 +559,7 @@ def run(self): plasma_manager_ids = list(self.live_plasma_managers.keys()) for plasma_manager_id in plasma_manager_ids: if ((self.live_plasma_managers[plasma_manager_id]) >= - ray._config.num_heartbeats_timeout()): + ray._config.num_heartbeats_timeout()): log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE)) # Remove the plasma manager from the managers whose # heartbeats we're tracking. @@ -551,8 +568,9 @@ def run(self): # corresponding state in the object table will be cleaned # up once we receive the notification for this db_client # deletion. - self.redis.execute_command("RAY.DISCONNECT", - plasma_manager_id) + self.redis.execute_command( + "RAY.DISCONNECT", plasma_manager_id + ) # Increment the number of heartbeats that we've missed from each # plasma manager. @@ -571,17 +589,20 @@ def run(self): if __name__ == "__main__": parser = argparse.ArgumentParser( description=("Parse Redis server for the " - "monitor to connect to.")) + "monitor to connect to.") + ) parser.add_argument( "--redis-address", required=True, type=str, - help="the address to use for Redis") + help="the address to use for Redis" + ) parser.add_argument( "--autoscaling-config", required=False, type=str, - help="the path to the autoscaling config file") + help="the path to the autoscaling config file" + ) args = parser.parse_args() redis_ip_address = get_ip_address(args.redis_address) diff --git a/python/ray/plasma/__init__.py b/python/ray/plasma/__init__.py index 1ecd0c2af2dc..1db225a1c800 100644 --- a/python/ray/plasma/__init__.py +++ b/python/ray/plasma/__init__.py @@ -2,8 +2,9 @@ from __future__ import division from __future__ import print_function -from ray.plasma.plasma import (start_plasma_store, start_plasma_manager, - DEFAULT_PLASMA_STORE_MEMORY) +from ray.plasma.plasma import ( + start_plasma_store, start_plasma_manager, DEFAULT_PLASMA_STORE_MEMORY +) __all__ = [ "start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY" diff --git a/python/ray/plasma/plasma.py b/python/ray/plasma/plasma.py index 36498ea4c251..c22efa36a410 100644 --- a/python/ray/plasma/plasma.py +++ b/python/ray/plasma/plasma.py @@ -21,13 +21,15 @@ def random_name(): return str(random.randint(0, 99999999)) -def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None, - plasma_directory=None, - huge_pages=False): +def start_plasma_store( + plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, + use_valgrind=False, + use_profiler=False, + stdout_file=None, + stderr_file=None, + plasma_directory=None, + huge_pages=False +): """Start a plasma store process. Args: @@ -51,18 +53,22 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, if use_valgrind and use_profiler: raise Exception("Cannot use valgrind and profiler at the same time.") - if huge_pages and not (sys.platform == "linux" - or sys.platform == "linux2"): - raise Exception("The huge_pages argument is only supported on " - "Linux.") + if huge_pages and not (sys.platform == "linux" or sys.platform == "linux2"): + raise Exception( + "The huge_pages argument is only supported on " + "Linux." + ) if huge_pages and plasma_directory is None: - raise Exception("If huge_pages is True, then the " - "plasma_directory argument must be provided.") + raise Exception( + "If huge_pages is True, then the " + "plasma_directory argument must be provided." 
+ ) plasma_store_executable = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "../core/src/plasma/plasma_store") + "../core/src/plasma/plasma_store" + ) plasma_store_name = "/tmp/plasma_store{}".format(random_name()) command = [ plasma_store_executable, "-s", plasma_store_name, "-m", @@ -73,20 +79,18 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, if huge_pages: command += ["-h"] if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) + pid = subprocess.Popen([ + "valgrind", "--track-origins=yes", "--leak-check=full", + "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", + "--error-exitcode=1" + ] + command, + stdout=stdout_file, + stderr=stderr_file) time.sleep(1.0) elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) + pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, + stdout=stdout_file, + stderr=stderr_file) time.sleep(1.0) else: pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) @@ -98,15 +102,17 @@ def new_port(): return random.randint(10000, 65535) -def start_plasma_manager(store_name, - redis_address, - node_ip_address="127.0.0.1", - plasma_manager_port=None, - num_retries=20, - use_valgrind=False, - run_profiler=False, - stdout_file=None, - stderr_file=None): +def start_plasma_manager( + store_name, + redis_address, + node_ip_address="127.0.0.1", + plasma_manager_port=None, + num_retries=20, + use_valgrind=False, + run_profiler=False, + stdout_file=None, + stderr_file=None +): """Start a plasma manager and return the ports it listens on. Args: @@ -132,7 +138,8 @@ def start_plasma_manager(store_name, """ plasma_manager_executable = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "../core/src/plasma/plasma_manager") + "../core/src/plasma/plasma_manager" + ) plasma_manager_name = "/tmp/plasma_manager{}".format(random_name()) if plasma_manager_port is not None: if num_retries != 1: @@ -158,21 +165,21 @@ def start_plasma_manager(store_name, redis_address, ] if use_valgrind: - process = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) + process = subprocess.Popen([ + "valgrind", "--track-origins=yes", "--leak-check=full", + "--show-leak-kinds=all", "--error-exitcode=1" + ] + command, + stdout=stdout_file, + stderr=stderr_file) elif run_profiler: - process = subprocess.Popen( - (["valgrind", "--tool=callgrind"] + command), - stdout=stdout_file, - stderr=stderr_file) + process = subprocess.Popen((["valgrind", "--tool=callgrind"] + + command), + stdout=stdout_file, + stderr=stderr_file) else: process = subprocess.Popen( - command, stdout=stdout_file, stderr=stderr_file) + command, stdout=stdout_file, stderr=stderr_file + ) # This sleep is critical. If the plasma_manager fails to start because # the port is already in use, then we need it to fail within 0.1 # seconds. 
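For reference, every hunk in this patch applies the same mechanical wrapping pattern that is visible above: when a signature or call site no longer fits in the column limit, the arguments are split across lines and the closing bracket is dedented onto its own line, while short calls (such as the one-line LogMonitor construction at the top of this section) are joined back onto a single line. The sketch below is illustrative only and uses hypothetical names that do not appear in this patch; it is not part of the diff and simply restates that wrapping convention as standalone, runnable Python.

# Hypothetical example (names invented for illustration): a long signature
# and a long call site wrapped in the style used throughout these hunks --
# one argument per line where needed, closing bracket dedented on its own
# line, short calls left on a single line.


def start_example_worker(
    store_name,
    redis_address,
    node_ip_address="127.0.0.1",
    num_retries=20,
    use_valgrind=False
):
    """Collect the arguments; stands in for the reformatted helpers above."""
    return {
        "store_name": store_name,
        "redis_address": redis_address,
        "node_ip_address": node_ip_address,
        "num_retries": num_retries,
        "use_valgrind": use_valgrind,
    }


config = start_example_worker(
    "/tmp/plasma_store_example", "127.0.0.1:6379", num_retries=1
)

A call that fits within the limit, e.g. start_example_worker("/tmp/x", "y"), stays on one line under the same convention.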
diff --git a/python/ray/plasma/test/test.py b/python/ray/plasma/test/test.py index 8b0d62fe1d2a..5316d0608001 100644 --- a/python/ray/plasma/test/test.py +++ b/python/ray/plasma/test/test.py @@ -16,8 +16,9 @@ # The ray import must come before the pyarrow import because ray modifies the # python path so that the right version of pyarrow is found. import ray -from ray.plasma.utils import (random_object_id, create_object_with_id, - create_object) +from ray.plasma.utils import ( + random_object_id, create_object_with_id, create_object +) from ray import services import pyarrow as pa import pyarrow.plasma as plasma @@ -30,12 +31,9 @@ def random_name(): return str(random.randint(0, 99999999)) -def assert_get_object_equal(unit_test, - client1, - client2, - object_id, - memory_buffer=None, - metadata=None): +def assert_get_object_equal( + unit_test, client1, client2, object_id, memory_buffer=None, metadata=None +): client1_buff = client1.get_buffers([object_id])[0] client2_buff = client2.get_buffers([object_id])[0] client1_metadata = client1.get_metadata([object_id])[0] @@ -45,31 +43,37 @@ def assert_get_object_equal(unit_test, # Check that the buffers from the two clients are the same. assert_equal( np.frombuffer(client1_buff, dtype="uint8"), - np.frombuffer(client2_buff, dtype="uint8")) + np.frombuffer(client2_buff, dtype="uint8") + ) # Check that the metadata buffers from the two clients are the same. assert_equal( np.frombuffer(client1_metadata, dtype="uint8"), - np.frombuffer(client2_metadata, dtype="uint8")) + np.frombuffer(client2_metadata, dtype="uint8") + ) # If a reference buffer was provided, check that it is the same as well. if memory_buffer is not None: assert_equal( np.frombuffer(memory_buffer, dtype="uint8"), - np.frombuffer(client1_buff, dtype="uint8")) + np.frombuffer(client1_buff, dtype="uint8") + ) # If reference metadata was provided, check that it is the same as well. if metadata is not None: assert_equal( np.frombuffer(metadata, dtype="uint8"), - np.frombuffer(client1_metadata, dtype="uint8")) + np.frombuffer(client1_metadata, dtype="uint8") + ) DEFAULT_PLASMA_STORE_MEMORY = 10**9 -def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None): +def start_plasma_store( + plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, + use_valgrind=False, + use_profiler=False, + stdout_file=None, + stderr_file=None +): """Start a plasma store process. 
Args: use_valgrind (bool): True if the plasma store should be started inside @@ -93,20 +97,18 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, str(plasma_store_memory) ] if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) + pid = subprocess.Popen([ + "valgrind", "--track-origins=yes", "--leak-check=full", + "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", + "--error-exitcode=1" + ] + command, + stdout=stdout_file, + stderr=stderr_file) time.sleep(1.0) elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) + pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, + stdout=stdout_file, + stderr=stderr_file) time.sleep(1.0) else: pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) @@ -126,9 +128,11 @@ def setUp(self): redis_address, _ = services.start_redis("127.0.0.1") # Start two PlasmaManagers. manager_name1, self.p4, self.port1 = ray.plasma.start_plasma_manager( - store_name1, redis_address, use_valgrind=USE_VALGRIND) + store_name1, redis_address, use_valgrind=USE_VALGRIND + ) manager_name2, self.p5, self.port2 = ray.plasma.start_plasma_manager( - store_name2, redis_address, use_valgrind=USE_VALGRIND) + store_name2, redis_address, use_valgrind=USE_VALGRIND + ) # Connect two PlasmaClients. self.client1 = plasma.connect(store_name1, manager_name1, 64) self.client2 = plasma.connect(store_name2, manager_name2, 64) @@ -164,7 +168,8 @@ def test_fetch(self): for _ in range(10): # Create an object. object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) + self.client1, 2000, 2000 + ) self.client1.fetch([object_id1]) self.assertEqual(self.client1.contains(object_id1), True) self.assertEqual(self.client2.contains(object_id1), False) @@ -180,14 +185,16 @@ def test_fetch(self): self.client2, object_id1, memory_buffer=memory_buffer1, - metadata=metadata1) + metadata=metadata1 + ) # Test that we can call fetch on object IDs that don't exist yet. object_id2 = random_object_id() self.client1.fetch([object_id2]) self.assertEqual(self.client1.contains(object_id2), False) memory_buffer2, metadata2 = create_object_with_id( - self.client2, object_id2, 2000, 2000) + self.client2, object_id2, 2000, 2000 + ) # # Check that the object has been fetched. # self.assertEqual(self.client1.contains(object_id2), True) # Compare the two buffers. @@ -203,7 +210,8 @@ def test_fetch(self): self.client1.fetch([object_id3]) self.client2.fetch([object_id3]) memory_buffer3, metadata3 = create_object_with_id( - self.client1, object_id3, 2000, 2000) + self.client1, object_id3, 2000, 2000 + ) for _ in range(10): self.client1.fetch([object_id3]) self.client2.fetch([object_id3]) @@ -216,16 +224,19 @@ def test_fetch(self): self.client2, object_id3, memory_buffer=memory_buffer3, - metadata=metadata3) + metadata=metadata3 + ) def test_fetch_multiple(self): for _ in range(20): # Create two objects and a third fake one that doesn't exist. 
object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) + self.client1, 2000, 2000 + ) missing_object_id = random_object_id() object_id2, memory_buffer2, metadata2 = create_object( - self.client1, 2000, 2000) + self.client1, 2000, 2000 + ) object_ids = [object_id1, missing_object_id, object_id2] # Fetch the objects from the other plasma store. The second object # ID should timeout since it does not exist. @@ -241,14 +252,16 @@ def test_fetch_multiple(self): self.client2, object_id1, memory_buffer=memory_buffer1, - metadata=metadata1) + metadata=metadata1 + ) assert_get_object_equal( self, self.client1, self.client2, object_id2, memory_buffer=memory_buffer2, - metadata=metadata2) + metadata=metadata2 + ) # Fetch in the other direction. The fake object still does not # exist. self.client1.fetch(object_ids) @@ -258,32 +271,35 @@ def test_fetch_multiple(self): self.client1, object_id1, memory_buffer=memory_buffer1, - metadata=metadata1) + metadata=metadata1 + ) assert_get_object_equal( self, self.client2, self.client1, object_id2, memory_buffer=memory_buffer2, - metadata=metadata2) + metadata=metadata2 + ) # Check that we can call fetch with duplicated object IDs. object_id3 = random_object_id() self.client1.fetch([object_id3, object_id3]) object_id4, memory_buffer4, metadata4 = create_object( - self.client1, 2000, 2000) + self.client1, 2000, 2000 + ) time.sleep(0.1) # TODO(rkn): Right now we must wait for the object table to be updated. while not self.client2.contains(object_id4): - self.client2.fetch( - [object_id3, object_id3, object_id4, object_id4]) + self.client2.fetch([object_id3, object_id3, object_id4, object_id4]) assert_get_object_equal( self, self.client2, self.client1, object_id4, memory_buffer=memory_buffer4, - metadata=metadata4) + metadata=metadata4 + ) def test_wait(self): # Test timeout. @@ -295,8 +311,9 @@ def test_wait(self): obj_id1 = random_object_id() self.client1.create(obj_id1, 1000) self.client1.seal(obj_id1) - ready, waiting = self.client1.wait( - [obj_id1], timeout=100, num_returns=1) + ready, waiting = self.client1.wait([obj_id1], + timeout=100, + num_returns=1) self.assertEqual(set(ready), set([obj_id1])) self.assertEqual(waiting, []) @@ -305,8 +322,9 @@ def test_wait(self): obj_id2 = random_object_id() self.client1.create(obj_id2, 1000) # Don't seal. 
- ready, waiting = self.client1.wait( - [obj_id2, obj_id1], timeout=100, num_returns=1) + ready, waiting = self.client1.wait([obj_id2, obj_id1], + timeout=100, + num_returns=1) self.assertEqual(set(ready), set([obj_id1])) self.assertEqual(set(waiting), set([obj_id2])) @@ -319,8 +337,9 @@ def finish(): t = threading.Timer(0.1, finish) t.start() - ready, waiting = self.client1.wait( - [obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2) + ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], + timeout=1000, + num_returns=2) self.assertEqual(set(ready), set([obj_id1, obj_id3])) self.assertEqual(set(waiting), set([obj_id2])) @@ -352,12 +371,14 @@ def finish(): retrieved = [] for i in range(1, n + 1): ready, waiting = self.client1.wait( - waiting, timeout=1000, num_returns=i) + waiting, timeout=1000, num_returns=i + ) self.assertEqual(len(ready), i) retrieved += ready self.assertEqual(set(retrieved), set(object_ids)) ready, waiting = self.client1.wait( - object_ids, timeout=1000, num_returns=len(object_ids)) + object_ids, timeout=1000, num_returns=len(object_ids) + ) self.assertEqual(set(ready), set(object_ids)) self.assertEqual(waiting, []) # Try waiting for all of the object IDs on the second client. @@ -365,12 +386,14 @@ def finish(): retrieved = [] for i in range(1, n + 1): ready, waiting = self.client2.wait( - waiting, timeout=1000, num_returns=i) + waiting, timeout=1000, num_returns=i + ) self.assertEqual(len(ready), i) retrieved += ready self.assertEqual(set(retrieved), set(object_ids)) ready, waiting = self.client2.wait( - object_ids, timeout=1000, num_returns=len(object_ids)) + object_ids, timeout=1000, num_returns=len(object_ids) + ) self.assertEqual(set(ready), set(object_ids)) self.assertEqual(waiting, []) @@ -382,11 +405,13 @@ def finish(): random.shuffle(object_ids_perm) for i in range(10): if i % 2 == 0: - create_object_with_id(self.client1, object_ids_perm[i], 2000, - 2000) + create_object_with_id( + self.client1, object_ids_perm[i], 2000, 2000 + ) else: - create_object_with_id(self.client2, object_ids_perm[i], 2000, - 2000) + create_object_with_id( + self.client2, object_ids_perm[i], 2000, 2000 + ) ready, waiting = self.client1.wait(object_ids, num_returns=(i + 1)) self.assertEqual(set(ready), set(object_ids_perm[:(i + 1)])) self.assertEqual(set(waiting), set(object_ids_perm[(i + 1):])) @@ -396,14 +421,14 @@ def test_transfer(self): for _ in range(100): # Create an object. object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) + self.client1, 2000, 2000 + ) # Transfer the buffer to the the other Plasma store. There is a # race condition on the create and transfer of the object, so keep # trying until the object appears on the second Plasma store. for i in range(num_attempts): self.client1.transfer("127.0.0.1", self.port2, object_id1) - buff = self.client2.get_buffers( - [object_id1], timeout_ms=100)[0] + buff = self.client2.get_buffers([object_id1], timeout_ms=100)[0] if buff is not None: break self.assertNotEqual(buff, None) @@ -416,7 +441,8 @@ def test_transfer(self): self.client2, object_id1, memory_buffer=memory_buffer1, - metadata=metadata1) + metadata=metadata1 + ) # # Transfer the buffer again. # self.client1.transfer("127.0.0.1", self.port2, object_id1) # # Compare the two buffers. @@ -427,14 +453,14 @@ def test_transfer(self): # Create an object. object_id2, memory_buffer2, metadata2 = create_object( - self.client2, 20000, 20000) + self.client2, 20000, 20000 + ) # Transfer the buffer to the the other Plasma store. 
There is a # race condition on the create and transfer of the object, so keep # trying until the object appears on the second Plasma store. for i in range(num_attempts): self.client2.transfer("127.0.0.1", self.port1, object_id2) - buff = self.client1.get_buffers( - [object_id2], timeout_ms=100)[0] + buff = self.client1.get_buffers([object_id2], timeout_ms=100)[0] if buff is not None: break self.assertNotEqual(buff, None) @@ -447,7 +473,8 @@ def test_transfer(self): self.client2, object_id2, memory_buffer=memory_buffer2, - metadata=metadata2) + metadata=metadata2 + ) def test_illegal_functionality(self): # Create an object id string. @@ -478,13 +505,13 @@ def test_stresstest(self): class TestPlasmaManagerRecovery(unittest.TestCase): def setUp(self): # Start a Plasma store. - self.store_name, self.p2 = start_plasma_store( - use_valgrind=USE_VALGRIND) + self.store_name, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND) # Start a Redis server. self.redis_address, _ = services.start_redis("127.0.0.1") # Start a PlasmaManagers. manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager( - self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) + self.store_name, self.redis_address, use_valgrind=USE_VALGRIND + ) # Connect a PlasmaClient. self.client = plasma.connect(self.store_name, manager_name, 64) @@ -530,7 +557,8 @@ def test_delayed_start(self): # Start a second plasma manager attached to the same store. manager_name, self.p5, self.port2 = ray.plasma.start_plasma_manager( - self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) + self.store_name, self.redis_address, use_valgrind=USE_VALGRIND + ) self.processes_to_kill = [self.p5] + self.processes_to_kill # Check that the second manager knows about existing objects. @@ -538,7 +566,8 @@ def test_delayed_start(self): ready, waiting = [], object_ids while True: ready, waiting = client2.wait( - object_ids, num_returns=num_objects, timeout=0) + object_ids, num_returns=num_objects, timeout=0 + ) if len(ready) == len(object_ids): break diff --git a/python/ray/plasma/utils.py b/python/ray/plasma/utils.py index 956502e75a27..c96c0b591936 100644 --- a/python/ray/plasma/utils.py +++ b/python/ray/plasma/utils.py @@ -18,8 +18,9 @@ def generate_metadata(length): metadata_buffer[0] = random.randint(0, 255) metadata_buffer[-1] = random.randint(0, 255) for _ in range(100): - metadata_buffer[random.randint(0, length - 1)] = (random.randint( - 0, 255)) + metadata_buffer[random.randint(0, length - 1)] = ( + random.randint(0, 255) + ) return metadata_buffer @@ -32,11 +33,9 @@ def write_to_data_buffer(buff, length): array[random.randint(0, length - 1)] = random.randint(0, 255) -def create_object_with_id(client, - object_id, - data_size, - metadata_size, - seal=True): +def create_object_with_id( + client, object_id, data_size, metadata_size, seal=True +): metadata = generate_metadata(metadata_size) memory_buffer = client.create(object_id, data_size, metadata) write_to_data_buffer(memory_buffer, data_size) @@ -48,5 +47,6 @@ def create_object_with_id(client, def create_object(client, data_size, metadata_size, seal=True): object_id = random_object_id() memory_buffer, metadata = create_object_with_id( - client, object_id, data_size, metadata_size, seal=seal) + client, object_id, data_size, metadata_size, seal=seal + ) return object_id, memory_buffer, metadata diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index a2441f0b5bf6..dacc3e6c2064 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py 
@@ -8,9 +8,10 @@ def _register_all(): - for key in ["PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", - "DDPG2", "APEX_DDPG", "__fake", "__sigmoid_fake_data", - "__parameter_tuning"]: + for key in [ + "PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", "DDPG2", + "APEX_DDPG", "__fake", "__sigmoid_fake_data", "__parameter_tuning" + ]: from ray.rllib.agent import get_agent_class register_trainable(key, get_agent_class(key)) diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index fa10db9c918c..61c3c7c48c2b 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -15,7 +15,6 @@ from ray.tune.result import TrainingResult from ray.tune.trial import Resources - DEFAULT_CONFIG = { # Number of workers (excluding master) "num_workers": 4, @@ -73,53 +72,67 @@ class A3CAgent(Agent): def default_resource_request(cls, config): cf = dict(cls._default_config, **config) return Resources( - cpu=1, gpu=0, + cpu=1, + gpu=0, extra_cpu=cf["num_workers"], - extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0) + extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0 + ) def _init(self): self.local_evaluator = A3CEvaluator( - self.registry, self.env_creator, self.config, self.logdir, - start_sampler=False) + self.registry, + self.env_creator, + self.config, + self.logdir, + start_sampler=False + ) if self.config["use_gpu_for_workers"]: remote_cls = GPURemoteA3CEvaluator else: remote_cls = RemoteA3CEvaluator self.remote_evaluators = [ remote_cls.remote( - self.registry, self.env_creator, self.config, self.logdir) - for i in range(self.config["num_workers"])] + self.registry, self.env_creator, self.config, self.logdir + ) for i in range(self.config["num_workers"]) + ] self.optimizer = AsyncOptimizer( self.config["optimizer"], self.local_evaluator, - self.remote_evaluators) + self.remote_evaluators + ) def _train(self): self.optimizer.step() FilterManager.synchronize( - self.local_evaluator.filters, self.remote_evaluators) + self.local_evaluator.filters, self.remote_evaluators + ) res = self._fetch_metrics_from_remote_evaluators() return res def _fetch_metrics_from_remote_evaluators(self): episode_rewards = [] episode_lengths = [] - metric_lists = [a.get_completed_rollout_metrics.remote() - for a in self.remote_evaluators] + metric_lists = [ + a.get_completed_rollout_metrics.remote() + for a in self.remote_evaluators + ] for metrics in metric_lists: for episode in ray.get(metrics): episode_lengths.append(episode.episode_length) episode_rewards.append(episode.episode_reward) avg_reward = ( - np.mean(episode_rewards) if episode_rewards else float('nan')) + np.mean(episode_rewards) if episode_rewards else float('nan') + ) avg_length = ( - np.mean(episode_lengths) if episode_lengths else float('nan')) + np.mean(episode_lengths) if episode_lengths else float('nan') + ) timesteps = np.sum(episode_lengths) if episode_lengths else 0 result = TrainingResult( episode_reward_mean=avg_reward, episode_len_mean=avg_length, timesteps_this_iter=timesteps, - info={}) + info={} + ) return result @@ -130,20 +143,22 @@ def _stop(self): def _save(self, checkpoint_dir): checkpoint_path = os.path.join( - checkpoint_dir, "checkpoint-{}".format(self.iteration)) - agent_state = ray.get( - [a.save.remote() for a in self.remote_evaluators]) + checkpoint_dir, "checkpoint-{}".format(self.iteration) + ) + agent_state = ray.get([a.save.remote() for a in self.remote_evaluators]) extra_data = { "remote_state": agent_state, - "local_state": self.local_evaluator.save()} + 
"local_state": self.local_evaluator.save() + } pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) return checkpoint_path def _restore(self, checkpoint_path): extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) - ray.get( - [a.restore.remote(o) for a, o in zip( - self.remote_evaluators, extra_data["remote_state"])]) + ray.get([ + a.restore.remote(o) + for a, o in zip(self.remote_evaluators, extra_data["remote_state"]) + ]) self.local_evaluator.restore(extra_data["local_state"]) def compute_action(self, observation): diff --git a/python/ray/rllib/a3c/a3c_evaluator.py b/python/ray/rllib/a3c/a3c_evaluator.py index 9b0522dfdf56..c02a0c9e7afb 100644 --- a/python/ray/rllib/a3c/a3c_evaluator.py +++ b/python/ray/rllib/a3c/a3c_evaluator.py @@ -26,25 +26,33 @@ class A3CEvaluator(PolicyEvaluator): rollouts. logdir: Directory for logging. """ + def __init__( - self, registry, env_creator, config, logdir, start_sampler=True): + self, registry, env_creator, config, logdir, start_sampler=True + ): env = ModelCatalog.get_preprocessor_as_wrapper( - registry, env_creator(config["env_config"]), config["model"]) + registry, env_creator(config["env_config"]), config["model"] + ) self.env = env policy_cls = get_policy_cls(config) # TODO(rliaw): should change this to be just env.observation_space self.policy = policy_cls( - registry, env.observation_space.shape, env.action_space, config) + registry, env.observation_space.shape, env.action_space, config + ) self.config = config # Technically not needed when not remote self.obs_filter = get_filter( - config["observation_filter"], env.observation_space.shape) + config["observation_filter"], env.observation_space.shape + ) self.rew_filter = get_filter(config["reward_filter"], ()) - self.filters = {"obs_filter": self.obs_filter, - "rew_filter": self.rew_filter} - self.sampler = AsyncSampler(env, self.policy, self.obs_filter, - config["batch_size"]) + self.filters = { + "obs_filter": self.obs_filter, + "rew_filter": self.rew_filter + } + self.sampler = AsyncSampler( + env, self.policy, self.obs_filter, config["batch_size"] + ) if start_sampler and self.sampler.async: self.sampler.start() self.logdir = logdir @@ -52,8 +60,12 @@ def __init__( def sample(self): rollout = self.sampler.get_data() samples = process_rollout( - rollout, self.rew_filter, gamma=self.config["gamma"], - lambda_=self.config["lambda"], use_gae=True) + rollout, + self.rew_filter, + gamma=self.config["gamma"], + lambda_=self.config["lambda"], + use_gae=True + ) return samples def get_completed_rollout_metrics(self): @@ -79,9 +91,7 @@ def set_weights(self, params): def save(self): filters = self.get_filters(flush_after=True) weights = self.get_weights() - return pickle.dumps({ - "filters": filters, - "weights": weights}) + return pickle.dumps({"filters": filters, "weights": weights}) def restore(self, objs): objs = pickle.loads(objs) diff --git a/python/ray/rllib/a3c/policy.py b/python/ray/rllib/a3c/policy.py index 1e9639fd71af..2772a9b006f2 100644 --- a/python/ray/rllib/a3c/policy.py +++ b/python/ray/rllib/a3c/policy.py @@ -5,6 +5,7 @@ class Policy(object): """The policy base class.""" + def __init__(self, ob_space, action_space, name="local", summarize=True): pass diff --git a/python/ray/rllib/a3c/shared_model.py b/python/ray/rllib/a3c/shared_model.py index 8209be159ed4..67300ce7e151 100644 --- a/python/ray/rllib/a3c/shared_model.py +++ b/python/ray/rllib/a3c/shared_model.py @@ -14,26 +14,33 @@ class SharedModel(TFPolicy): is_recurrent = False def __init__(self, 
registry, ob_space, ac_space, config, **kwargs): - super(SharedModel, self).__init__( - registry, ob_space, ac_space, config, **kwargs) + super(SharedModel, + self).__init__(registry, ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space) self._model = ModelCatalog.get_model( - self.registry, self.x, self.logit_dim, self.config["model"]) + self.registry, self.x, self.logit_dim, self.config["model"] + ) self.logits = self._model.outputs self.curr_dist = dist_class(self.logits) - self.vf = tf.reshape(linear(self._model.last_layer, 1, "value", - normc_initializer(1.0)), [-1]) + self.vf = tf.reshape( + linear(self._model.last_layer, 1, "value", normc_initializer(1.0)), + [-1] + ) self.sample = self.curr_dist.sample() - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) + self.var_list = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name + ) self.global_step = tf.get_variable( - "global_step", [], tf.int32, + "global_step", [], + tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), - trainable=False) + trainable=False + ) def compute_gradients(self, samples): info = {} @@ -54,8 +61,7 @@ def compute_gradients(self, samples): return grad, info def compute(self, ob, *args): - action, vf = self.sess.run([self.sample, self.vf], - {self.x: [ob]}) + action, vf = self.sess.run([self.sample, self.vf], {self.x: [ob]}) return action[0], {"vf_preds": vf[0]} def value(self, ob, *args): diff --git a/python/ray/rllib/a3c/shared_model_lstm.py b/python/ray/rllib/a3c/shared_model_lstm.py index 37f71e490467..26f9017ffe8c 100644 --- a/python/ray/rllib/a3c/shared_model_lstm.py +++ b/python/ray/rllib/a3c/shared_model_lstm.py @@ -22,8 +22,8 @@ class SharedModelLSTM(TFPolicy): is_recurrent = True def __init__(self, registry, ob_space, ac_space, config, **kwargs): - super(SharedModelLSTM, self).__init__( - registry, ob_space, ac_space, config, **kwargs) + super(SharedModelLSTM, + self).__init__(registry, ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) @@ -38,16 +38,22 @@ def _setup_graph(self, ob_space, ac_space): self.curr_dist = dist_class(self.logits) # with tf.variable_scope("vf"): # vf_model = ModelCatalog.get_model(self.x, 1) - self.vf = tf.reshape(linear(self._model.last_layer, 1, "value", - normc_initializer(1.0)), [-1]) + self.vf = tf.reshape( + linear(self._model.last_layer, 1, "value", normc_initializer(1.0)), + [-1] + ) self.sample = self.curr_dist.sample() - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) + self.var_list = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name + ) self.global_step = tf.get_variable( - "global_step", [], tf.int32, + "global_step", [], + tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), - trainable=False) + trainable=False + ) def compute_gradients(self, samples): """Computing the gradient is actually model-dependent. 
@@ -75,15 +81,23 @@ def compute_gradients(self, samples): return grad, info def compute(self, ob, c, h): - action, vf, c, h = self.sess.run( - [self.sample, self.vf] + self.state_out, - {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h}) + action, vf, c, h = self.sess.run([self.sample, self.vf] + self.state_out, + { + self.x: [ob], + self.state_in[0]: c, + self.state_in[1]: h + }) return action[0], {"vf_preds": vf[0], "features": (c, h)} def value(self, ob, c, h): - vf = self.sess.run(self.vf, {self.x: [ob], - self.state_in[0]: c, - self.state_in[1]: h}) + vf = self.sess.run( + self.vf, + { + self.x: [ob], + self.state_in[0]: c, + self.state_in[1]: h + } + ) return vf[0] def get_initial_features(self): diff --git a/python/ray/rllib/a3c/shared_torch_policy.py b/python/ray/rllib/a3c/shared_torch_policy.py index 59b7a2577008..0030122bd87a 100644 --- a/python/ray/rllib/a3c/shared_torch_policy.py +++ b/python/ray/rllib/a3c/shared_torch_policy.py @@ -18,15 +18,17 @@ class SharedTorchPolicy(TorchPolicy): is_recurrent = False def __init__(self, ob_space, ac_space, config, **kwargs): - super(SharedTorchPolicy, self).__init__( - ob_space, ac_space, config, **kwargs) + super(SharedTorchPolicy, + self).__init__(ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): _, self.logit_dim = ModelCatalog.get_action_dist(ac_space) self._model = ModelCatalog.get_torch_model( - self.registry, ob_space, self.logit_dim, self.config["model"]) + self.registry, ob_space, self.logit_dim, self.config["model"] + ) self.optimizer = torch.optim.Adam( - self._model.parameters(), lr=self.config["lr"]) + self._model.parameters(), lr=self.config["lr"] + ) def compute(self, ob, *args): """Should take in a SINGLE ob""" @@ -70,9 +72,11 @@ def _backward(self, batch): value_err = 0.5 * (values - rs).pow(2).sum() self.optimizer.zero_grad() - overall_err = (pi_err + - value_err * self.config["vf_loss_coeff"] + - entropy * self.config["entropy_coeff"]) + overall_err = ( + pi_err + value_err * self.config["vf_loss_coeff"] + + entropy * self.config["entropy_coeff"] + ) overall_err.backward() torch.nn.utils.clip_grad_norm( - self._model.parameters(), self.config["grad_clip"]) + self._model.parameters(), self.config["grad_clip"] + ) diff --git a/python/ray/rllib/a3c/tfpolicy.py b/python/ray/rllib/a3c/tfpolicy.py index 4816a7fefb54..f3ed1f02c345 100644 --- a/python/ray/rllib/a3c/tfpolicy.py +++ b/python/ray/rllib/a3c/tfpolicy.py @@ -10,8 +10,16 @@ class TFPolicy(Policy): """The policy base class.""" - def __init__(self, registry, ob_space, action_space, config, - name="local", summarize=True): + + def __init__( + self, + registry, + ob_space, + action_space, + config, + name="local", + summarize=True + ): self.registry = registry self.local_steps = 0 self.config = config @@ -21,8 +29,10 @@ def __init__(self, registry, ob_space, action_space, config, with self.g.as_default(), tf.device(worker_device): with tf.variable_scope(name): self._setup_graph(ob_space, action_space) - assert all([hasattr(self, attr) - for attr in ["vf", "logits", "x", "var_list"]]) + assert all([ + hasattr(self, attr) + for attr in ["vf", "logits", "x", "var_list"] + ]) print("Setting up loss") self.setup_loss(action_space) self.setup_gradients() @@ -40,7 +50,8 @@ def setup_loss(self, action_space): else: raise NotImplementedError( "action space" + str(type(action_space)) + - "currently not supported") + "currently not supported" + ) self.adv = tf.placeholder(tf.float32, [None], name="adv") self.r = tf.placeholder(tf.float32, [None], 
name="r") @@ -50,14 +61,15 @@ def setup_loss(self, action_space): # gradient. Notice that self.ac is a placeholder that is provided # externally. adv will contain the advantages, as calculated in # process_rollout. - self.pi_loss = - tf.reduce_sum(log_prob * self.adv) + self.pi_loss = -tf.reduce_sum(log_prob * self.adv) delta = self.vf - self.r self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) self.entropy = tf.reduce_sum(self.curr_dist.entropy()) - self.loss = (self.pi_loss + - self.vf_loss * self.config["vf_loss_coeff"] + - self.entropy * self.config["entropy_coeff"]) + self.loss = ( + self.pi_loss + self.vf_loss * self.config["vf_loss_coeff"] + + self.entropy * self.config["entropy_coeff"] + ) def setup_gradients(self): grads = tf.gradients(self.loss, self.var_list) @@ -77,16 +89,21 @@ def initialize(self): self.summary_op = tf.summary.merge_all() # TODO(rliaw): Can consider exposing these parameters - self.sess = tf.Session(graph=self.g, config=tf.ConfigProto( - intra_op_parallelism_threads=1, inter_op_parallelism_threads=2, - gpu_options=tf.GPUOptions(allow_growth=True))) - self.variables = ray.experimental.TensorFlowVariables(self.loss, - self.sess) + self.sess = tf.Session( + graph=self.g, + config=tf.ConfigProto( + intra_op_parallelism_threads=1, + inter_op_parallelism_threads=2, + gpu_options=tf.GPUOptions(allow_growth=True) + ) + ) + self.variables = ray.experimental.TensorFlowVariables( + self.loss, self.sess + ) self.sess.run(tf.global_variables_initializer()) def apply_gradients(self, grads): - feed_dict = {self.grads[i]: grads[i] - for i in range(len(grads))} + feed_dict = {self.grads[i]: grads[i] for i in range(len(grads))} self.sess.run(self._apply_gradients, feed_dict=feed_dict) def get_weights(self): diff --git a/python/ray/rllib/a3c/torchpolicy.py b/python/ray/rllib/a3c/torchpolicy.py index 8c7d86a086c3..295dac660ced 100644 --- a/python/ray/rllib/a3c/torchpolicy.py +++ b/python/ray/rllib/a3c/torchpolicy.py @@ -15,8 +15,15 @@ class TorchPolicy(Policy): The model is a separate object than the policy. This could be changed in the future.""" - def __init__(self, registry, ob_space, action_space, config, - name="local", summarize=True): + def __init__( + self, + registry, + ob_space, + action_space, + config, + name="local", + summarize=True + ): self.registry = registry self.local_steps = 0 self.config = config diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index 5699022b2a8e..b390de433619 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -34,8 +34,7 @@ def _deep_update(original, new_dict, new_keys_allowed, whitelist): for k, value in new_dict.items(): if k not in original and k != "env": if not new_keys_allowed: - raise Exception( - "Unknown config parameter `{}` ".format(k)) + raise Exception("Unknown config parameter `{}` ".format(k)) if type(original.get(k)) is dict: if k in whitelist: _deep_update(original[k], value, True, []) @@ -69,11 +68,12 @@ def resource_help(cls, config): "\n\nYou can adjust the resource requests of RLlib agents by " "setting `num_workers` and other configs. See the " "DEFAULT_CONFIG defined by each agent for more info.\n\n" - "The config of this agent is: " + json.dumps(config)) + "The config of this agent is: " + json.dumps(config) + ) def __init__( - self, config=None, env=None, registry=None, - logger_creator=None): + self, config=None, env=None, registry=None, logger_creator=None + ): """Initialize an RLLib agent. 
Args: @@ -106,9 +106,10 @@ def _setup(self): # Merge the supplied config with the class default merged_config = self._default_config.copy() - merged_config = _deep_update(merged_config, self.config, - self._allow_unknown_configs, - self._allow_unknown_subkeys) + merged_config = _deep_update( + merged_config, self.config, self._allow_unknown_configs, + self._allow_unknown_subkeys + ) self.config = merged_config # TODO(ekl) setting the graph is unnecessary for PyTorch agents @@ -162,8 +163,11 @@ def _train(self): and (self.config["persistent_error"] or not self.restored): raise Exception("mock error") return TrainingResult( - episode_reward_mean=10, episode_len_mean=10, - timesteps_this_iter=10, info={}) + episode_reward_mean=10, + episode_len_mean=10, + timesteps_this_iter=10, + info={} + ) def _save(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "mock_agent.pkl") @@ -204,9 +208,12 @@ def _train(self): v = np.tanh(float(i) / self.config["width"]) v *= self.config["height"] return TrainingResult( - episode_reward_mean=v, episode_len_mean=v, + episode_reward_mean=v, + episode_len_mean=v, timesteps_this_iter=self.config["iter_timesteps"], - time_this_iter_s=self.config["iter_time"], info={}) + time_this_iter_s=self.config["iter_time"], + info={} + ) class _ParameterTuningAgent(_MockAgent): @@ -225,7 +232,9 @@ def _train(self): episode_reward_mean=self.config["reward_amt"] * self.iteration, episode_len_mean=self.config["reward_amt"], timesteps_this_iter=self.config["iter_timesteps"], - time_this_iter_s=self.config["iter_time"], info={}) + time_this_iter_s=self.config["iter_time"], + info={} + ) def get_agent_class(alg): @@ -271,5 +280,4 @@ def get_agent_class(alg): elif alg == "__parameter_tuning": return _ParameterTuningAgent else: - raise Exception( - ("Unknown algorithm {}.").format(alg)) + raise Exception(("Unknown algorithm {}.").format(alg)) diff --git a/python/ray/rllib/bc/bc.py b/python/ray/rllib/bc/bc.py index cdfc7ab98878..c8ddb3252639 100644 --- a/python/ray/rllib/bc/bc.py +++ b/python/ray/rllib/bc/bc.py @@ -57,29 +57,35 @@ def default_resource_request(cls, config): else: num_gpus_per_worker = 0 return Resources( - cpu=1, gpu=cf["gpu"] and 1 or 0, + cpu=1, + gpu=cf["gpu"] and 1 or 0, extra_cpu=cf["num_workers"], - extra_gpu=num_gpus_per_worker * cf["num_workers"]) + extra_gpu=num_gpus_per_worker * cf["num_workers"] + ) def _init(self): self.local_evaluator = BCEvaluator( - self.registry, self.env_creator, self.config, self.logdir) + self.registry, self.env_creator, self.config, self.logdir + ) if self.config["use_gpu_for_workers"]: remote_cls = GPURemoteBCEvaluator else: remote_cls = RemoteBCEvaluator self.remote_evaluators = [ remote_cls.remote( - self.registry, self.env_creator, self.config, self.logdir) - for _ in range(self.config["num_workers"])] + self.registry, self.env_creator, self.config, self.logdir + ) for _ in range(self.config["num_workers"]) + ] self.optimizer = AsyncOptimizer( self.config["optimizer"], self.local_evaluator, - self.remote_evaluators) + self.remote_evaluators + ) def _train(self): self.optimizer.step() - metric_lists = [re.get_metrics.remote() for re in - self.remote_evaluators] + metric_lists = [ + re.get_metrics.remote() for re in self.remote_evaluators + ] total_samples = 0 total_loss = 0 for metrics in metric_lists: diff --git a/python/ray/rllib/bc/bc_evaluator.py b/python/ray/rllib/bc/bc_evaluator.py index 8499ba1e023e..ec6fa9ce1126 100644 --- a/python/ray/rllib/bc/bc_evaluator.py +++ b/python/ray/rllib/bc/bc_evaluator.py @@ -14,12 
+14,14 @@ class BCEvaluator(PolicyEvaluator): def __init__(self, registry, env_creator, config, logdir): - env = ModelCatalog.get_preprocessor_as_wrapper(registry, env_creator( - config["env_config"]), config["model"]) + env = ModelCatalog.get_preprocessor_as_wrapper( + registry, env_creator(config["env_config"]), config["model"] + ) self.dataset = ExperienceDataset(config["dataset_path"]) # TODO(rliaw): should change this to be just env.observation_space - self.policy = BCPolicy(registry, env.observation_space.shape, - env.action_space, config) + self.policy = BCPolicy( + registry, env.observation_space.shape, env.action_space, config + ) self.config = config self.logdir = logdir self.metrics_queue = queue.Queue() @@ -29,8 +31,10 @@ def sample(self): def compute_gradients(self, samples): gradient, info = self.policy.compute_gradients(samples) - self.metrics_queue.put( - {"num_samples": info["num_samples"], "loss": info["loss"]}) + self.metrics_queue.put({ + "num_samples": info["num_samples"], + "loss": info["loss"] + }) return gradient, {} def apply_gradients(self, grads): @@ -44,8 +48,7 @@ def set_weights(self, params): def save(self): weights = self.get_weights() - return pickle.dumps({ - "weights": weights}) + return pickle.dumps({"weights": weights}) def restore(self, objs): objs = pickle.loads(objs) diff --git a/python/ray/rllib/bc/experience_dataset.py b/python/ray/rllib/bc/experience_dataset.py index ccf47bc31ee2..3600fdf6be4d 100644 --- a/python/ray/rllib/bc/experience_dataset.py +++ b/python/ray/rllib/bc/experience_dataset.py @@ -21,8 +21,11 @@ def __init__(self, dataset_path): elements. The file must be available on each machine used by a BCEvaluator. """ - self._dataset = list(itertools.chain.from_iterable( - pickle.load(open(dataset_path, "rb")))) + self._dataset = list( + itertools.chain.from_iterable( + pickle.load(open(dataset_path, "rb")) + ) + ) def sample(self, batch_size): indexes = np.random.choice(len(self._dataset), batch_size) diff --git a/python/ray/rllib/bc/policy.py b/python/ray/rllib/bc/policy.py index 7566422fa154..03c87029e2cb 100644 --- a/python/ray/rllib/bc/policy.py +++ b/python/ray/rllib/bc/policy.py @@ -9,8 +9,15 @@ class BCPolicy(Policy): - def __init__(self, registry, ob_space, action_space, config, name="local", - summarize=True): + def __init__( + self, + registry, + ob_space, + action_space, + config, + name="local", + summarize=True + ): super(BCPolicy, self).__init__(ob_space, action_space, name, summarize) self.registry = registry self.local_steps = 0 @@ -30,17 +37,20 @@ def _setup_graph(self, ob_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space) self._model = ModelCatalog.get_model( - self.registry, self.x, self.logit_dim, self.config["model"]) + self.registry, self.x, self.logit_dim, self.config["model"] + ) self.logits = self._model.outputs self.curr_dist = dist_class(self.logits) self.sample = self.curr_dist.sample() - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) + self.var_list = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name + ) def setup_loss(self, action_space): self.ac = tf.placeholder(tf.int64, [None], name="ac") log_prob = self.curr_dist.logp(self.ac) - self.pi_loss = - tf.reduce_sum(log_prob) + self.pi_loss = -tf.reduce_sum(log_prob) self.loss = self.pi_loss def setup_gradients(self): @@ -59,11 +69,17 @@ def initialize(self): self.summary_op = 
tf.summary.merge_all() # TODO(rliaw): Can consider exposing these parameters - self.sess = tf.Session(graph=self.g, config=tf.ConfigProto( - intra_op_parallelism_threads=1, inter_op_parallelism_threads=2, - gpu_options=tf.GPUOptions(allow_growth=True))) - self.variables = ray.experimental.TensorFlowVariables(self.loss, - self.sess) + self.sess = tf.Session( + graph=self.g, + config=tf.ConfigProto( + intra_op_parallelism_threads=1, + inter_op_parallelism_threads=2, + gpu_options=tf.GPUOptions(allow_growth=True) + ) + ) + self.variables = ray.experimental.TensorFlowVariables( + self.loss, self.sess + ) self.sess.run(tf.global_variables_initializer()) def compute_gradients(self, samples): @@ -75,8 +91,10 @@ def compute_gradients(self, samples): self.grads = [g for g in self.grads if g is not None] self.local_steps += 1 if self.summarize: - loss, grad, summ = self.sess.run( - [self.loss, self.grads, self.summary_op], feed_dict=feed_dict) + loss, grad, summ = self.sess.run([ + self.loss, self.grads, self.summary_op + ], + feed_dict=feed_dict) info["summary"] = summ else: loss, grad = self.sess.run([self.loss, self.grads], @@ -86,8 +104,7 @@ def compute_gradients(self, samples): return grad, info def apply_gradients(self, grads): - feed_dict = {self.grads[i]: grads[i] - for i in range(len(grads))} + feed_dict = {self.grads[i]: grads[i] for i in range(len(grads))} self.sess.run(self._apply_gradients, feed_dict=feed_dict) def get_weights(self): diff --git a/python/ray/rllib/ddpg/apex.py b/python/ray/rllib/ddpg/apex.py index c670198c3571..04d060abd9d4 100644 --- a/python/ray/rllib/ddpg/apex.py +++ b/python/ray/rllib/ddpg/apex.py @@ -4,28 +4,31 @@ from ray.rllib.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG as DDPG_CONFIG -APEX_DDPG_DEFAULT_CONFIG = dict(DDPG_CONFIG, - **dict( - optimizer_class="ApexOptimizer", - optimizer_config=dict( - DDPG_CONFIG["optimizer_config"], - **dict( - max_weight_sync_delay=400, - num_replay_buffer_shards=4, - debug=False, - )), - n_step=3, - num_workers=32, - buffer_size=2000000, - learning_starts=50000, - train_batch_size=512, - sample_batch_size=50, - max_weight_sync_delay=400, - target_network_update_freq=500000, - timesteps_per_iteration=25000, - per_worker_exploration=True, - worker_side_prioritization=True, - )) +APEX_DDPG_DEFAULT_CONFIG = dict( + DDPG_CONFIG, + **dict( + optimizer_class="ApexOptimizer", + optimizer_config=dict( + DDPG_CONFIG["optimizer_config"], + **dict( + max_weight_sync_delay=400, + num_replay_buffer_shards=4, + debug=False, + ) + ), + n_step=3, + num_workers=32, + buffer_size=2000000, + learning_starts=50000, + train_batch_size=512, + sample_batch_size=50, + max_weight_sync_delay=400, + target_network_update_freq=500000, + timesteps_per_iteration=25000, + per_worker_exploration=True, + worker_side_prioritization=True, + ) +) class ApexDDPGAgent(DDPGAgent): diff --git a/python/ray/rllib/ddpg/ddpg.py b/python/ray/rllib/ddpg/ddpg.py index 343b323948b3..b662b44157b2 100644 --- a/python/ray/rllib/ddpg/ddpg.py +++ b/python/ray/rllib/ddpg/ddpg.py @@ -129,7 +129,8 @@ # Whether to use a distribution of epsilons across workers for exploration. per_worker_exploration=False, # Whether to compute priorities on workers. 
- worker_side_prioritization=False) + worker_side_prioritization=False +) class DDPGAgent(Agent): @@ -140,15 +141,18 @@ class DDPGAgent(Agent): _default_config = DEFAULT_CONFIG def _init(self): - self.local_evaluator = DDPGEvaluator(self.registry, self.env_creator, - self.config, self.logdir, 0) + self.local_evaluator = DDPGEvaluator( + self.registry, self.env_creator, self.config, self.logdir, 0 + ) remote_cls = ray.remote( - num_cpus=1, - num_gpus=self.config["num_gpus_per_worker"])(DDPGEvaluator) + num_cpus=1, num_gpus=self.config["num_gpus_per_worker"] + )( + DDPGEvaluator + ) self.remote_evaluators = [ - remote_cls.remote(self.registry, self.env_creator, self.config, - self.logdir, i) - for i in range(self.config["num_workers"]) + remote_cls.remote( + self.registry, self.env_creator, self.config, self.logdir, i + ) for i in range(self.config["num_workers"]) ] for k in OPTIMIZER_SHARED_CONFIGS: @@ -157,7 +161,8 @@ def _init(self): self.optimizer = getattr(optimizers, self.config["optimizer_class"])( self.config["optimizer_config"], self.local_evaluator, - self.remote_evaluators) + self.remote_evaluators + ) self.saver = tf.train.Saver(max_to_keep=None) self.last_target_update_ts = 0 @@ -177,8 +182,10 @@ def update_target_if_needed(self): def _train(self): start_timestep = self.global_timestep - while (self.global_timestep - start_timestep < - self.config["timesteps_per_iteration"]): + while ( + self.global_timestep - start_timestep < + self.config["timesteps_per_iteration"] + ): self.optimizer.step() self.update_target_if_needed() @@ -227,7 +234,8 @@ def _train_stats(self, start_timestep): "min_exploration": min(explorations), "max_exploration": max(explorations), "num_target_updates": self.num_target_updates, - }, **opt_stats)) + }, **opt_stats) + ) return result @@ -240,7 +248,8 @@ def _save(self, checkpoint_dir): checkpoint_path = self.saver.save( self.local_evaluator.sess, os.path.join(checkpoint_dir, "checkpoint"), - global_step=self.iteration) + global_step=self.iteration + ) extra_data = [ self.local_evaluator.save(), ray.get([e.save.remote() for e in self.remote_evaluators]), @@ -263,6 +272,7 @@ def _restore(self, checkpoint_path): self.last_target_update_ts = extra_data[4] def compute_action(self, observation): - return self.local_evaluator.ddpg_graph.act(self.local_evaluator.sess, - np.array(observation)[None], - 0.0)[0] + return self.local_evaluator.ddpg_graph.act( + self.local_evaluator.sess, + np.array(observation)[None], 0.0 + )[0] diff --git a/python/ray/rllib/ddpg/ddpg_evaluator.py b/python/ray/rllib/ddpg/ddpg_evaluator.py index 5a68c4b583ee..76cf7ed06cab 100644 --- a/python/ray/rllib/ddpg/ddpg_evaluator.py +++ b/python/ray/rllib/ddpg/ddpg_evaluator.py @@ -31,7 +31,9 @@ def __init__(self, registry, env_creator, config, logdir, worker_index): if not isinstance(env.action_space, Box): raise UnsupportedSpaceException( "Action space {} is not supported for DDPG.".format( - env.action_space)) + env.action_space + ) + ) tf_config = tf.ConfigProto(**config["tf_session_args"]) self.sess = tf.Session(config=tf_config) @@ -41,15 +43,18 @@ def __init__(self, registry, env_creator, config, logdir, worker_index): if config["per_worker_exploration"]: assert config["num_workers"] > 1, "This requires multiple workers" self.exploration = ConstantSchedule( - config["noise_scale"] * 0.4 ** - (1 + worker_index / float(config["num_workers"] - 1) * 7)) + config["noise_scale"] * 0.4** + (1 + worker_index / float(config["num_workers"] - 1) * 7) + ) else: self.exploration = LinearSchedule( - 
schedule_timesteps=int(config["exploration_fraction"] * - config["schedule_max_timesteps"]), + schedule_timesteps=int( + config["exploration_fraction"] * + config["schedule_max_timesteps"] + ), initial_p=config["noise_scale"] * 1.0, - final_p=config["noise_scale"] * - config["exploration_final_eps"]) + final_p=config["noise_scale"] * config["exploration_final_eps"] + ) # Initialize the parameters and copy them to the target network. self.sess.run(tf.global_variables_initializer()) @@ -61,7 +66,8 @@ def __init__(self, registry, env_creator, config, logdir, worker_index): # Note that this encompasses both the policy and Q-value networks and # their corresponding target networks self.variables = ray.experimental.TensorFlowVariables( - tf.group(self.ddpg_graph.q_tp0, self.ddpg_graph.q_tp1), self.sess) + tf.group(self.ddpg_graph.q_tp0, self.ddpg_graph.q_tp1), self.sess + ) self.episode_rewards = [0.0] self.episode_lengths = [0.0] @@ -78,7 +84,8 @@ def update_target(self): def sample(self): obs, actions, rewards, new_obs, dones = [], [], [], [], [] for _ in range( - self.config["sample_batch_size"] + self.config["n_step"] - 1): + self.config["sample_batch_size"] + self.config["n_step"] - 1 + ): ob, act, rew, ob1, done = self._step(self.global_timestep) obs.append(ob) actions.append(act) @@ -90,8 +97,10 @@ def sample(self): if self.config["n_step"] > 1: # Adjust for steps lost from truncation self.local_timestep -= (self.config["n_step"] - 1) - adjust_nstep(self.config["n_step"], self.config["gamma"], obs, - actions, rewards, new_obs, dones) + adjust_nstep( + self.config["n_step"], self.config["gamma"], obs, actions, + rewards, new_obs, dones + ) batch = SampleBatch({ "obs": [pack(np.array(o)) for o in obs], @@ -107,9 +116,11 @@ def sample(self): if self.config["worker_side_prioritization"]: td_errors = self.ddpg_graph.compute_td_error( self.sess, obs, batch["actions"], batch["rewards"], new_obs, - batch["dones"], batch["weights"]) + batch["dones"], batch["weights"] + ) new_priorities = ( - np.abs(td_errors) + self.config["prioritized_replay_eps"]) + np.abs(td_errors) + self.config["prioritized_replay_eps"] + ) batch.data["weights"] = new_priorities return batch @@ -117,7 +128,8 @@ def sample(self): def compute_gradients(self, samples): td_err, grads = self.ddpg_graph.compute_gradients( self.sess, samples["obs"], samples["actions"], samples["rewards"], - samples["new_obs"], samples["dones"], samples["weights"]) + samples["new_obs"], samples["dones"], samples["weights"] + ) return grads, {"td_error": td_err} def apply_gradients(self, grads): @@ -126,7 +138,8 @@ def apply_gradients(self, grads): def compute_apply(self, samples): td_error = self.ddpg_graph.compute_apply( self.sess, samples["obs"], samples["actions"], samples["rewards"], - samples["new_obs"], samples["dones"], samples["weights"]) + samples["new_obs"], samples["dones"], samples["weights"] + ) return {"td_error": td_error} def get_weights(self): @@ -139,8 +152,8 @@ def _step(self, global_timestep): """Takes a single step, and returns the result of the step.""" action = self.ddpg_graph.act( self.sess, - np.array(self.obs)[None], - self.exploration.value(global_timestep))[0] + np.array(self.obs)[None], self.exploration.value(global_timestep) + )[0] new_obs, rew, done, _ = self.env.step(action) ret = (self.obs, action, rew, new_obs, float(done)) self.obs = new_obs diff --git a/python/ray/rllib/ddpg/models.py b/python/ray/rllib/ddpg/models.py index d58f37dc6417..147b4ab9c66b 100644 --- a/python/ray/rllib/ddpg/models.py +++ 
b/python/ray/rllib/ddpg/models.py @@ -21,19 +21,22 @@ def _build_p_network(registry, inputs, dim_actions, config): action_out = frontend.last_layer for hidden in hiddens: action_out = layers.fully_connected( - action_out, num_outputs=hidden, activation_fn=tf.nn.relu) + action_out, num_outputs=hidden, activation_fn=tf.nn.relu + ) # Use sigmoid layer to bound values within (0, 1) # shape of action_scores is [batch_size, dim_actions] action_scores = layers.fully_connected( - action_out, num_outputs=dim_actions, activation_fn=tf.nn.sigmoid) + action_out, num_outputs=dim_actions, activation_fn=tf.nn.sigmoid + ) return action_scores # As a stochastic policy for inference, but a deterministic policy for training # thus ignore batch_size issue when constructing a stochastic action -def _build_action_network(p_values, low_action, high_action, stochastic, eps, - theta, sigma): +def _build_action_network( + p_values, low_action, high_action, stochastic, eps, theta, sigma +): # shape is [None, dim_action] deterministic_actions = (high_action - low_action) * p_values + low_action @@ -41,17 +44,22 @@ def _build_action_network(p_values, low_action, high_action, stochastic, eps, name="ornstein_uhlenbeck", dtype=tf.float32, initializer=low_action.size * [.0], - trainable=False) + trainable=False + ) normal_sample = tf.random_normal( - shape=[low_action.size], mean=0.0, stddev=1.0) + shape=[low_action.size], mean=0.0, stddev=1.0 + ) exploration_value = tf.assign_add( exploration_sample, - theta * (.0 - exploration_sample) + sigma * normal_sample) + theta * (.0 - exploration_sample) + sigma * normal_sample + ) stochastic_actions = deterministic_actions + eps * ( - high_action - low_action) * exploration_value + high_action - low_action + ) * exploration_value - return tf.cond(stochastic, lambda: stochastic_actions, - lambda: deterministic_actions) + return tf.cond( + stochastic, lambda: stochastic_actions, lambda: deterministic_actions + ) def _build_q_network(registry, inputs, action_inputs, config): @@ -62,7 +70,8 @@ def _build_q_network(registry, inputs, action_inputs, config): q_out = tf.concat([frontend.last_layer, action_inputs], axis=1) for hidden in hiddens: q_out = layers.fully_connected( - q_out, num_outputs=hidden, activation_fn=tf.nn.relu) + q_out, num_outputs=hidden, activation_fn=tf.nn.relu + ) q_scores = layers.fully_connected(q_out, num_outputs=1, activation_fn=None) return q_scores @@ -72,7 +81,8 @@ def _huber_loss(x, delta=1.0): """Reference: https://en.wikipedia.org/wiki/Huber_loss""" return tf.where( tf.abs(x) < delta, - tf.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta)) + tf.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta) + ) def _minimize_and_clip(optimizer, objective, var_list, clip_val=10): @@ -108,7 +118,8 @@ def _scope_vars(scope, trainable_only=False): return tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.VARIABLES, - scope=scope if isinstance(scope, str) else scope.name) + scope=scope if isinstance(scope, str) else scope.name + ) class ModelAndLoss(object): @@ -118,16 +129,19 @@ class ModelAndLoss(object): to create towers on each device. 
""" - def __init__(self, registry, dim_actions, low_action, high_action, config, - obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): + def __init__( + self, registry, dim_actions, low_action, high_action, config, obs_t, + act_t, rew_t, obs_tp1, done_mask, importance_weights + ): # p network evaluation with tf.variable_scope("p_func", reuse=True) as scope: self.p_t = _build_p_network(registry, obs_t, dim_actions, config) # target p network evaluation with tf.variable_scope("target_p_func") as scope: - self.p_tp1 = _build_p_network(registry, obs_tp1, dim_actions, - config) + self.p_tp1 = _build_p_network( + registry, obs_tp1, dim_actions, config + ) self.target_p_func_vars = _scope_vars(scope.name) # Action outputs @@ -135,38 +149,43 @@ def __init__(self, registry, dim_actions, low_action, high_action, config, deterministic_flag = tf.constant(value=False, dtype=tf.bool) zero_eps = tf.constant(value=.0, dtype=tf.float32) output_actions = _build_action_network( - self.p_t, low_action, high_action, deterministic_flag, - zero_eps, config["exploration_theta"], - config["exploration_sigma"]) + self.p_t, low_action, high_action, deterministic_flag, zero_eps, + config["exploration_theta"], config["exploration_sigma"] + ) output_actions_estimated = _build_action_network( self.p_tp1, low_action, high_action, deterministic_flag, zero_eps, config["exploration_theta"], - config["exploration_sigma"]) + config["exploration_sigma"] + ) # q network evaluation with tf.variable_scope("q_func") as scope: self.q_t = _build_q_network(registry, obs_t, act_t, config) self.q_func_vars = _scope_vars(scope.name) with tf.variable_scope("q_func", reuse=True): - self.q_tp0 = _build_q_network(registry, obs_t, output_actions, - config) + self.q_tp0 = _build_q_network( + registry, obs_t, output_actions, config + ) # target q network evalution with tf.variable_scope("target_q_func") as scope: - self.q_tp1 = _build_q_network(registry, obs_tp1, - output_actions_estimated, config) + self.q_tp1 = _build_q_network( + registry, obs_tp1, output_actions_estimated, config + ) self.target_q_func_vars = _scope_vars(scope.name) q_t_selected = tf.squeeze(self.q_t, axis=len(self.q_t.shape) - 1) q_tp1_best = tf.squeeze( - input=self.q_tp1, axis=len(self.q_tp1.shape) - 1) + input=self.q_tp1, axis=len(self.q_tp1.shape) - 1 + ) q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best # compute RHS of bellman equation q_t_selected_target = ( - rew_t + config["gamma"]**config["n_step"] * q_tp1_best_masked) + rew_t + config["gamma"]**config["n_step"] * q_tp1_best_masked + ) # compute the error (potentially clipped) self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target) @@ -190,21 +209,25 @@ def __init__(self, registry, env, config, logdir): low_action = env.action_space.low high_action = env.action_space.high actor_optimizer = tf.train.AdamOptimizer( - learning_rate=config["actor_lr"]) + learning_rate=config["actor_lr"] + ) critic_optimizer = tf.train.AdamOptimizer( - learning_rate=config["critic_lr"]) + learning_rate=config["critic_lr"] + ) # Action inputs self.stochastic = tf.placeholder(tf.bool, (), name="stochastic") self.eps = tf.placeholder(tf.float32, (), name="eps") self.cur_observations = tf.placeholder( - tf.float32, shape=(None, ) + env.observation_space.shape) + tf.float32, shape=(None, ) + env.observation_space.shape + ) # Actor: P (policy) network p_scope_name = "p_func" with tf.variable_scope(p_scope_name) as scope: - p_values = _build_p_network(registry, self.cur_observations, - dim_actions, config) + p_values 
= _build_p_network( + registry, self.cur_observations, dim_actions, config + ) p_func_vars = _scope_vars(scope.name) # Action outputs @@ -212,32 +235,40 @@ def __init__(self, registry, env, config, logdir): with tf.variable_scope(a_scope_name): self.output_actions = _build_action_network( p_values, low_action, high_action, self.stochastic, self.eps, - config["exploration_theta"], config["exploration_sigma"]) + config["exploration_theta"], config["exploration_sigma"] + ) with tf.variable_scope(a_scope_name, reuse=True): exploration_sample = tf.get_variable(name="ornstein_uhlenbeck") - self.reset_noise_op = tf.assign(exploration_sample, - dim_actions * [.0]) + self.reset_noise_op = tf.assign( + exploration_sample, dim_actions * [.0] + ) # Replay inputs self.obs_t = tf.placeholder( tf.float32, shape=(None, ) + env.observation_space.shape, - name="observation") + name="observation" + ) self.act_t = tf.placeholder( - tf.float32, shape=(None, ) + env.action_space.shape, name="action") + tf.float32, shape=(None, ) + env.action_space.shape, name="action" + ) self.rew_t = tf.placeholder(tf.float32, [None], name="reward") self.obs_tp1 = tf.placeholder( - tf.float32, shape=(None, ) + env.observation_space.shape) + tf.float32, shape=(None, ) + env.observation_space.shape + ) self.done_mask = tf.placeholder(tf.float32, [None], name="done") self.importance_weights = tf.placeholder( - tf.float32, [None], name="weight") + tf.float32, [None], name="weight" + ) - def build_loss(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): - return ModelAndLoss(registry, dim_actions, low_action, high_action, - config, obs_t, act_t, rew_t, obs_tp1, - done_mask, importance_weights) + def build_loss( + obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights + ): + return ModelAndLoss( + registry, dim_actions, low_action, high_action, config, obs_t, + act_t, rew_t, obs_tp1, done_mask, importance_weights + ) self.loss_inputs = [ ("obs", self.obs_t), @@ -248,8 +279,10 @@ def build_loss(obs_t, act_t, rew_t, obs_tp1, done_mask, ("weights", self.importance_weights), ] - loss_obj = build_loss(self.obs_t, self.act_t, self.rew_t, self.obs_tp1, - self.done_mask, self.importance_weights) + loss_obj = build_loss( + self.obs_t, self.act_t, self.rew_t, self.obs_tp1, self.done_mask, + self.importance_weights + ) self.build_loss = build_loss @@ -270,8 +303,8 @@ def build_loss(obs_t, act_t, rew_t, obs_tp1, done_mask, actor_loss += config["l2_reg"] * 0.5 * tf.nn.l2_loss(var) for var in q_func_vars: if "bias" not in var.name: - weighted_error += config["l2_reg"] * 0.5 * tf.nn.l2_loss( - var) + weighted_error += config["l2_reg" + ] * 0.5 * tf.nn.l2_loss(var) # compute optimization op (potentially with gradient clipping) if config["grad_norm_clipping"] is not None: @@ -279,17 +312,21 @@ def build_loss(obs_t, act_t, rew_t, obs_tp1, done_mask, actor_optimizer, actor_loss, var_list=p_func_vars, - clip_val=config["grad_norm_clipping"]) + clip_val=config["grad_norm_clipping"] + ) self.critic_grads_and_vars = _minimize_and_clip( critic_optimizer, weighted_error, var_list=q_func_vars, - clip_val=config["grad_norm_clipping"]) + clip_val=config["grad_norm_clipping"] + ) else: self.actor_grads_and_vars = actor_optimizer.compute_gradients( - actor_loss, var_list=p_func_vars) + actor_loss, var_list=p_func_vars + ) self.critic_grads_and_vars = critic_optimizer.compute_gradients( - weighted_error, var_list=q_func_vars) + weighted_error, var_list=q_func_vars + ) self.actor_grads_and_vars = [(g, v) for (g, v) in self.actor_grads_and_vars if 
g is not None] @@ -297,12 +334,15 @@ def build_loss(obs_t, act_t, rew_t, obs_tp1, done_mask, for (g, v) in self.critic_grads_and_vars if g is not None] self.grads_and_vars = ( - self.actor_grads_and_vars + self.critic_grads_and_vars) + self.actor_grads_and_vars + self.critic_grads_and_vars + ) self.grads = [g for (g, v) in self.grads_and_vars] self.actor_train_expr = actor_optimizer.apply_gradients( - self.actor_grads_and_vars) + self.actor_grads_and_vars + ) self.critic_train_expr = critic_optimizer.apply_gradients( - self.critic_grads_and_vars) + self.critic_grads_and_vars + ) # update_target_fn will be called periodically to copy Q network to # target Q network @@ -310,24 +350,29 @@ def build_loss(obs_t, act_t, rew_t, obs_tp1, done_mask, self.tau = tf.placeholder(tf.float32, (), name="tau") update_target_expr = [] for var, var_target in zip( - sorted(q_func_vars, key=lambda v: v.name), - sorted(target_q_func_vars, key=lambda v: v.name)): + sorted(q_func_vars, key=lambda v: v.name), + sorted(target_q_func_vars, key=lambda v: v.name) + ): update_target_expr.append( - var_target.assign(self.tau * var + - (1.0 - self.tau) * var_target)) + var_target. + assign(self.tau * var + (1.0 - self.tau) * var_target) + ) for var, var_target in zip( - sorted(p_func_vars, key=lambda v: v.name), - sorted(target_p_func_vars, key=lambda v: v.name)): + sorted(p_func_vars, key=lambda v: v.name), + sorted(target_p_func_vars, key=lambda v: v.name) + ): update_target_expr.append( - var_target.assign(self.tau * var + - (1.0 - self.tau) * var_target)) + var_target. + assign(self.tau * var + (1.0 - self.tau) * var_target) + ) self.update_target_expr = tf.group(*update_target_expr) # support both hard and soft sync def update_target(self, sess, tau=None): return sess.run( self.update_target_expr, - feed_dict={self.tau: tau or self.tau_value}) + feed_dict={self.tau: tau or self.tau_value} + ) def act(self, sess, obs, eps, stochastic=True): return sess.run( @@ -336,24 +381,26 @@ def act(self, sess, obs, eps, stochastic=True): self.cur_observations: obs, self.stochastic: stochastic, self.eps: eps - }) - - def compute_gradients(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): - td_err, grads = sess.run( - [self.td_error, self.grads], - feed_dict={ - self.obs_t: obs_t, - self.act_t: act_t, - self.rew_t: rew_t, - self.obs_tp1: obs_tp1, - self.done_mask: done_mask, - self.importance_weights: importance_weights - }) + } + ) + + def compute_gradients( + self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights + ): + td_err, grads = sess.run([self.td_error, self.grads], + feed_dict={ + self.obs_t: obs_t, + self.act_t: act_t, + self.rew_t: rew_t, + self.obs_tp1: obs_tp1, + self.done_mask: done_mask, + self.importance_weights: importance_weights + }) return td_err, grads - def compute_td_error(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + def compute_td_error( + self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights + ): td_err = sess.run( self.td_error, feed_dict={ @@ -363,28 +410,30 @@ def compute_td_error(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, self.obs_tp1: [np.array(ob) for ob in obs_tp1], self.done_mask: done_mask, self.importance_weights: importance_weights - }) + } + ) return td_err def apply_gradients(self, sess, grads): assert len(grads) == len(self.grads_and_vars) feed_dict = {ph: g for (g, ph) in zip(grads, self.grads)} - sess.run( - [self.critic_train_expr, self.actor_train_expr], - feed_dict=feed_dict) - - def 
compute_apply(self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): - td_err, _, _ = sess.run( - [self.td_error, self.critic_train_expr, self.actor_train_expr], - feed_dict={ - self.obs_t: obs_t, - self.act_t: act_t, - self.rew_t: rew_t, - self.obs_tp1: obs_tp1, - self.done_mask: done_mask, - self.importance_weights: importance_weights - }) + sess.run([self.critic_train_expr, self.actor_train_expr], + feed_dict=feed_dict) + + def compute_apply( + self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights + ): + td_err, _, _ = sess.run([ + self.td_error, self.critic_train_expr, self.actor_train_expr + ], + feed_dict={ + self.obs_t: obs_t, + self.act_t: act_t, + self.rew_t: rew_t, + self.obs_tp1: obs_tp1, + self.done_mask: done_mask, + self.importance_weights: importance_weights + }) return td_err def reset_noise(self, sess): diff --git a/python/ray/rllib/ddpg2/ddpg.py b/python/ray/rllib/ddpg2/ddpg.py index 0de2a865f8ea..4f58e3e85ca0 100644 --- a/python/ray/rllib/ddpg2/ddpg.py +++ b/python/ray/rllib/ddpg2/ddpg.py @@ -37,7 +37,6 @@ "num_local_steps": 1, # Number of workers (excluding master) "num_workers": 0, - "optimizer": { # Replay buffer size "buffer_size": 10000, @@ -64,14 +63,17 @@ class DDPG2Agent(Agent): def _init(self): self.local_evaluator = DDPGEvaluator( - self.registry, self.env_creator, self.config) + self.registry, self.env_creator, self.config + ) self.remote_evaluators = [ RemoteDDPGEvaluator.remote( - self.registry, self.env_creator, self.config) - for _ in range(self.config["num_workers"])] + self.registry, self.env_creator, self.config + ) for _ in range(self.config["num_workers"]) + ] self.optimizer = LocalSyncReplayOptimizer( self.config["optimizer"], self.local_evaluator, - self.remote_evaluators) + self.remote_evaluators + ) def _train(self): for _ in range(self.config["train_steps"]): @@ -87,8 +89,10 @@ def _fetch_metrics(self): episode_rewards = [] episode_lengths = [] if self.config["num_workers"] > 0: - metric_lists = [a.get_completed_rollout_metrics.remote() - for a in self.remote_evaluators] + metric_lists = [ + a.get_completed_rollout_metrics.remote() + for a in self.remote_evaluators + ] for metrics in metric_lists: for episode in ray.get(metrics): episode_lengths.append(episode.episode_length) @@ -107,6 +111,7 @@ def _fetch_metrics(self): episode_reward_mean=avg_reward, episode_len_mean=avg_length, timesteps_this_iter=timesteps, - info={}) + info={} + ) return result diff --git a/python/ray/rllib/ddpg2/ddpg_evaluator.py b/python/ray/rllib/ddpg2/ddpg_evaluator.py index 8a5ab5ed3f3a..eab236d1b720 100644 --- a/python/ray/rllib/ddpg2/ddpg_evaluator.py +++ b/python/ray/rllib/ddpg2/ddpg_evaluator.py @@ -14,17 +14,21 @@ class DDPGEvaluator(PolicyEvaluator): - def __init__(self, registry, env_creator, config): self.env = ModelCatalog.get_preprocessor_as_wrapper( - registry, env_creator(config["env_config"])) + registry, env_creator(config["env_config"]) + ) # contains model, target_model self.model = DDPGModel(registry, self.env, config) self.sampler = SyncSampler( - self.env, self.model.model, NoFilter(), - config["num_local_steps"], horizon=config["horizon"]) + self.env, + self.model.model, + NoFilter(), + config["num_local_steps"], + horizon=config["horizon"] + ) def sample(self): """Returns a batch of samples.""" @@ -34,9 +38,7 @@ def sample(self): # since each sample is one step, no discounting needs to be applied; # this does not involve config["gamma"] - samples = process_rollout( - rollout, NoFilter(), - gamma=1.0, 
use_gae=False) + samples = process_rollout(rollout, NoFilter(), gamma=1.0, use_gae=False) return samples diff --git a/python/ray/rllib/ddpg2/models.py b/python/ray/rllib/ddpg2/models.py index e785f518f541..90d170233fea 100644 --- a/python/ray/rllib/ddpg2/models.py +++ b/python/ray/rllib/ddpg2/models.py @@ -16,11 +16,11 @@ def __init__(self, registry, env, config): self.sess = tf.Session() with tf.variable_scope("model"): - self.model = DDPGActorCritic( - registry, env, self.config, self.sess) + self.model = DDPGActorCritic(registry, env, self.config, self.sess) with tf.variable_scope("target_model"): self.target_model = DDPGActorCritic( - registry, env, self.config, self.sess) + registry, env, self.config, self.sess + ) self._setup_gradients() self._setup_target_updates() @@ -34,13 +34,15 @@ def _initialize_target_weights(self): """Set initial target weights to match model weights.""" a_updates = [] for var, target_var in zip( - self.model.actor_var_list, self.target_model.actor_var_list): + self.model.actor_var_list, self.target_model.actor_var_list + ): a_updates.append(tf.assign(target_var, var)) actor_updates = tf.group(*a_updates) c_updates = [] for var, target_var in zip( - self.model.critic_var_list, self.target_model.critic_var_list): + self.model.critic_var_list, self.target_model.critic_var_list + ): c_updates.append(tf.assign(target_var, var)) critic_updates = tf.group(*c_updates) self.sess.run([actor_updates, critic_updates]) @@ -48,16 +50,20 @@ def _initialize_target_weights(self): def _setup_gradients(self): """Setup critic and actor gradients.""" self.critic_grads = tf.gradients( - self.model.critic_loss, self.model.critic_var_list) - c_grads_and_vars = list(zip( - self.critic_grads, self.model.critic_var_list)) + self.model.critic_loss, self.model.critic_var_list + ) + c_grads_and_vars = list( + zip(self.critic_grads, self.model.critic_var_list) + ) c_opt = tf.train.AdamOptimizer(self.config["critic_lr"]) self._apply_c_gradients = c_opt.apply_gradients(c_grads_and_vars) self.actor_grads = tf.gradients( - -self.model.cn_for_loss, self.model.actor_var_list) - a_grads_and_vars = list(zip( - self.actor_grads, self.model.actor_var_list)) + -self.model.cn_for_loss, self.model.actor_var_list + ) + a_grads_and_vars = list( + zip(self.actor_grads, self.model.actor_var_list) + ) a_opt = tf.train.AdamOptimizer(self.config["actor_lr"]) self._apply_a_gradients = a_opt.apply_gradients(a_grads_and_vars) @@ -65,9 +71,9 @@ def compute_gradients(self, samples): """ Returns gradient w.r.t. 
samples.""" # actor gradients actor_actions = self.sess.run( - self.model.output_action, - feed_dict={self.model.obs: samples["obs"]} - ) + self.model.output_action, + feed_dict={self.model.obs: samples["obs"]} + ) actor_feed_dict = { self.model.obs: samples["obs"], @@ -78,16 +84,17 @@ def compute_gradients(self, samples): # feed samples into target actor target_Q_act = self.sess.run( - self.target_model.output_action, - feed_dict={self.target_model.obs: samples["new_obs"]} - ) + self.target_model.output_action, + feed_dict={self.target_model.obs: samples["new_obs"]} + ) target_Q_dict = { self.target_model.obs: samples["new_obs"], self.target_model.act: target_Q_act, } target_Q = self.sess.run( - self.target_model.critic_eval, feed_dict=target_Q_dict) + self.target_model.critic_eval, feed_dict=target_Q_dict + ) # critic gradients critic_feed_dict = { @@ -98,7 +105,8 @@ def compute_gradients(self, samples): } self.critic_grads = [g for g in self.critic_grads if g is not None] critic_grad = self.sess.run( - self.critic_grads, feed_dict=critic_feed_dict) + self.critic_grads, feed_dict=critic_feed_dict + ) return (critic_grad, actor_grad), {} def apply_gradients(self, grads): @@ -124,16 +132,20 @@ def _setup_target_updates(self): a_updates = [] tau = self.config["tau"] for var, target_var in zip( - self.model.actor_var_list, self.target_model.actor_var_list): - a_updates.append(tf.assign( - target_var, tau * var + (1. - tau) * target_var)) + self.model.actor_var_list, self.target_model.actor_var_list + ): + a_updates.append( + tf.assign(target_var, tau * var + (1. - tau) * target_var) + ) actor_updates = tf.group(*a_updates) c_updates = [] for var, target_var in zip( - self.model.critic_var_list, self.target_model.critic_var_list): - c_updates.append(tf.assign( - target_var, tau * var + (1. - tau) * target_var)) + self.model.critic_var_list, self.target_model.critic_var_list + ): + c_updates.append( + tf.assign(target_var, tau * var + (1. 
- tau) * target_var) + ) critic_updates = tf.group(*c_updates) self.target_updates = [actor_updates, critic_updates] @@ -166,26 +178,26 @@ def __init__(self, registry, env, config, sess): with tf.variable_scope("critic"): self.critic_var_list = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name - ) - self.critic_vars = TensorFlowVariables(self.critic_loss, - self.sess) + tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name + ) + self.critic_vars = TensorFlowVariables(self.critic_loss, self.sess) with tf.variable_scope("actor"): self.actor_var_list = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name - ) - self.actor_vars = TensorFlowVariables(self.output_action, - self.sess) + tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name + ) + self.actor_vars = TensorFlowVariables(self.output_action, self.sess) if (self.config["noise_add"]): params = self.config["noise_parameters"] - self.rand_process = OrnsteinUhlenbeckProcess(size=self.ac_size, - theta=params["theta"], - mu=params["mu"], - sigma=params["sigma"]) + self.rand_process = OrnsteinUhlenbeckProcess( + size=self.ac_size, + theta=params["theta"], + mu=params["mu"], + sigma=params["sigma"] + ) self.epsilon = 1.0 def _setup_critic_loss(self, action_space): @@ -196,8 +208,9 @@ def _setup_critic_loss(self, action_space): self.reward = tf.placeholder(tf.float32, [None], name="reward") self.critic_target = tf.expand_dims(self.reward, 1) + \ self.config['gamma'] * self.target_Q - self.critic_loss = tf.reduce_mean(tf.square( - self.critic_target - self.critic_eval)) + self.critic_loss = tf.reduce_mean( + tf.square(self.critic_target - self.critic_eval) + ) def _setup_critic_network(self, obs_space, ac_space): """Sets up Q network.""" @@ -206,15 +219,17 @@ def _setup_critic_network(self, obs_space, ac_space): self.critic_eval = self.critic_network.outputs with tf.variable_scope("critic", reuse=True): - self.cn_for_loss = DDPGCritic( - (self.obs, self.output_action), 1, {}).outputs + self.cn_for_loss = DDPGCritic((self.obs, self.output_action), 1, + {}).outputs def _setup_actor_network(self, obs_space, ac_space): """Sets up actor network.""" with tf.variable_scope("actor", reuse=tf.AUTO_REUSE): self.actor_network = DDPGActor( - self.obs, self.ac_size, - options={"action_bound": self.action_bound}) + self.obs, + self.ac_size, + options={"action_bound": self.action_bound} + ) self.output_action = self.actor_network.outputs def get_weights(self): diff --git a/python/ray/rllib/ddpg2/random_process.py b/python/ray/rllib/ddpg2/random_process.py index 0a969fd00303..57f9f66f2e19 100644 --- a/python/ray/rllib/ddpg2/random_process.py +++ b/python/ray/rllib/ddpg2/random_process.py @@ -37,13 +37,23 @@ def current_sigma(self): # Based on # http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess): - def __init__(self, theta, mu=0., sigma=1., dt=1e-2, - x0=None, size=1, sigma_min=None, n_steps_annealing=1000): + def __init__( + self, + theta, + mu=0., + sigma=1., + dt=1e-2, + x0=None, + size=1, + sigma_min=None, + n_steps_annealing=1000 + ): super(OrnsteinUhlenbeckProcess, self).__init__( mu=mu, sigma=sigma, sigma_min=sigma_min, - n_steps_annealing=n_steps_annealing) + n_steps_annealing=n_steps_annealing + ) self.theta = theta self.mu = mu self.dt = dt diff --git a/python/ray/rllib/dqn/apex.py b/python/ray/rllib/dqn/apex.py index b44fb85b4663..b75f268ca1fc 100644 --- 
a/python/ray/rllib/dqn/apex.py +++ b/python/ray/rllib/dqn/apex.py @@ -5,26 +5,32 @@ from ray.rllib.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG from ray.tune.trial import Resources -APEX_DEFAULT_CONFIG = dict(DQN_CONFIG, **dict( - optimizer_class="ApexOptimizer", - optimizer_config=dict(DQN_CONFIG["optimizer_config"], **dict( +APEX_DEFAULT_CONFIG = dict( + DQN_CONFIG, + **dict( + optimizer_class="ApexOptimizer", + optimizer_config=dict( + DQN_CONFIG["optimizer_config"], + **dict( + max_weight_sync_delay=400, + num_replay_buffer_shards=4, + debug=False, + ) + ), + n_step=3, + gpu=True, + num_workers=32, + buffer_size=2000000, + learning_starts=50000, + train_batch_size=512, + sample_batch_size=50, max_weight_sync_delay=400, - num_replay_buffer_shards=4, - debug=False, - )), - n_step=3, - gpu=True, - num_workers=32, - buffer_size=2000000, - learning_starts=50000, - train_batch_size=512, - sample_batch_size=50, - max_weight_sync_delay=400, - target_network_update_freq=500000, - timesteps_per_iteration=25000, - per_worker_exploration=True, - worker_side_prioritization=True, -)) + target_network_update_freq=500000, + timesteps_per_iteration=25000, + per_worker_exploration=True, + worker_side_prioritization=True, + ) +) class ApexAgent(DQNAgent): @@ -44,7 +50,8 @@ def default_resource_request(cls, config): cpu=1 + cf["optimizer_config"]["num_replay_buffer_shards"], gpu=cf["gpu"] and 1 or 0, extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) + extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"] + ) def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled diff --git a/python/ray/rllib/dqn/common/schedules.py b/python/ray/rllib/dqn/common/schedules.py index d9ceb2f76696..a1c899c88fe7 100644 --- a/python/ray/rllib/dqn/common/schedules.py +++ b/python/ray/rllib/dqn/common/schedules.py @@ -40,9 +40,8 @@ def linear_interpolation(l, r, alpha): class PiecewiseSchedule(object): def __init__( - self, endpoints, interpolation=linear_interpolation, - outside_value=None): - + self, endpoints, interpolation=linear_interpolation, outside_value=None + ): """Piecewise schedule. endpoints: [(int, int)] diff --git a/python/ray/rllib/dqn/common/wrappers.py b/python/ray/rllib/dqn/common/wrappers.py index a968888aab71..214ff4a2f4f6 100644 --- a/python/ray/rllib/dqn/common/wrappers.py +++ b/python/ray/rllib/dqn/common/wrappers.py @@ -15,6 +15,7 @@ def wrap_dqn(registry, env, options, random_starts): # TODO(ekl) this logic should be pushed to the catalog. 
if is_atari and "custom_preprocessor" not in options: return wrap_deepmind( - env, random_starts=random_starts, dim=options.get("dim", 80)) + env, random_starts=random_starts, dim=options.get("dim", 80) + ) return ModelCatalog.get_preprocessor_as_wrapper(registry, env, options) diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index cd7e85847a93..6977bc59d607 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -15,11 +15,11 @@ from ray.tune.result import TrainingResult from ray.tune.trial import Resources - OPTIMIZER_SHARED_CONFIGS = [ "buffer_size", "prioritized_replay", "prioritized_replay_alpha", "prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size", - "train_batch_size", "learning_starts", "clip_rewards"] + "train_batch_size", "learning_starts", "clip_rewards" +] DEFAULT_CONFIG = dict( # === Model === @@ -90,7 +90,9 @@ # === Tensorflow === # Arguments to pass to tensorflow tf_session_args={ - "device_count": {"CPU": 2}, + "device_count": { + "CPU": 2 + }, "log_device_placement": False, "allow_soft_placement": True, "gpu_options": { @@ -118,35 +120,42 @@ # Whether to use a distribution of epsilons across workers for exploration. per_worker_exploration=False, # Whether to compute priorities on workers. - worker_side_prioritization=False) + worker_side_prioritization=False +) class DQNAgent(Agent): _agent_name = "DQN" _allow_unknown_subkeys = [ - "model", "optimizer", "tf_session_args", "env_config"] + "model", "optimizer", "tf_session_args", "env_config" + ] _default_config = DEFAULT_CONFIG @classmethod def default_resource_request(cls, config): cf = dict(cls._default_config, **config) return Resources( - cpu=1, gpu=cf["gpu"] and 1 or 0, + cpu=1, + gpu=cf["gpu"] and 1 or 0, extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) + extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"] + ) def _init(self): self.local_evaluator = DQNEvaluator( - self.registry, self.env_creator, self.config, self.logdir, 0) + self.registry, self.env_creator, self.config, self.logdir, 0 + ) remote_cls = ray.remote( num_cpus=self.config["num_cpus_per_worker"], - num_gpus=self.config["num_gpus_per_worker"])( - DQNEvaluator) + num_gpus=self.config["num_gpus_per_worker"] + )( + DQNEvaluator + ) self.remote_evaluators = [ remote_cls.remote( - self.registry, self.env_creator, self.config, self.logdir, - i) - for i in range(self.config["num_workers"])] + self.registry, self.env_creator, self.config, self.logdir, i + ) for i in range(self.config["num_workers"]) + ] for k in OPTIMIZER_SHARED_CONFIGS: if k not in self.config["optimizer_config"]: @@ -154,7 +163,8 @@ def _init(self): self.optimizer = getattr(optimizers, self.config["optimizer_class"])( self.config["optimizer_config"], self.local_evaluator, - self.remote_evaluators) + self.remote_evaluators + ) self.saver = tf.train.Saver(max_to_keep=None) self.last_target_update_ts = 0 @@ -174,8 +184,10 @@ def update_target_if_needed(self): def _train(self): start_timestep = self.global_timestep - while (self.global_timestep - start_timestep < - self.config["timesteps_per_iteration"]): + while ( + self.global_timestep - start_timestep < + self.config["timesteps_per_iteration"] + ): self.optimizer.step() self.update_target_if_needed() @@ -188,8 +200,7 @@ def _train(self): def _train_stats(self, start_timestep): if self.remote_evaluators: - stats = ray.get([ - e.stats.remote() for e in self.remote_evaluators]) + stats = ray.get([e.stats.remote() 
for e in self.remote_evaluators]) else: stats = self.local_evaluator.stats() if not isinstance(stats, list): @@ -202,7 +213,7 @@ def _train_stats(self, start_timestep): if self.config["per_worker_exploration"]: # Return stats from workers with the lowest 20% of exploration - test_stats = stats[-int(max(1, len(stats)*0.2)):] + test_stats = stats[-int(max(1, len(stats) * 0.2)):] else: test_stats = stats @@ -225,7 +236,8 @@ def _train_stats(self, start_timestep): "min_exploration": min(explorations), "max_exploration": max(explorations), "num_target_updates": self.num_target_updates, - }, **opt_stats)) + }, **opt_stats) + ) return result @@ -238,13 +250,14 @@ def _save(self, checkpoint_dir): checkpoint_path = self.saver.save( self.local_evaluator.sess, os.path.join(checkpoint_dir, "checkpoint"), - global_step=self.iteration) + global_step=self.iteration + ) extra_data = [ self.local_evaluator.save(), ray.get([e.save.remote() for e in self.remote_evaluators]), - self.optimizer.save(), - self.num_target_updates, - self.last_target_update_ts] + self.optimizer.save(), self.num_target_updates, + self.last_target_update_ts + ] pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) return checkpoint_path @@ -253,12 +266,15 @@ def _restore(self, checkpoint_path): extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) self.local_evaluator.restore(extra_data[0]) ray.get([ - e.restore.remote(d) for (d, e) - in zip(extra_data[1], self.remote_evaluators)]) + e.restore.remote(d) + for (d, e) in zip(extra_data[1], self.remote_evaluators) + ]) self.optimizer.restore(extra_data[2]) self.num_target_updates = extra_data[3] self.last_target_update_ts = extra_data[4] def compute_action(self, observation): return self.local_evaluator.dqn_graph.act( - self.local_evaluator.sess, np.array(observation)[None], 0.0)[0] + self.local_evaluator.sess, + np.array(observation)[None], 0.0 + )[0] diff --git a/python/ray/rllib/dqn/dqn_evaluator.py b/python/ray/rllib/dqn/dqn_evaluator.py index 758dc5f819d4..f3f40ec0cb2d 100644 --- a/python/ray/rllib/dqn/dqn_evaluator.py +++ b/python/ray/rllib/dqn/dqn_evaluator.py @@ -34,7 +34,7 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): continue # episode end for j in range(1, n_step): new_obs[i] = new_obs[i + j] - rewards[i] += gamma ** j * rewards[i + j] + rewards[i] += gamma**j * rewards[i + j] if dones[i + j]: break # episode end # truncate ends of the trajectory @@ -57,7 +57,9 @@ def __init__(self, registry, env_creator, config, logdir, worker_index): if not isinstance(env.action_space, Discrete): raise UnsupportedSpaceException( "Action space {} is not supported for DQN.".format( - env.action_space)) + env.action_space + ) + ) tf_config = tf.ConfigProto(**config["tf_session_args"]) self.sess = tf.Session(config=tf_config) @@ -67,15 +69,17 @@ def __init__(self, registry, env_creator, config, logdir, worker_index): if config["per_worker_exploration"]: assert config["num_workers"] > 1, "This requires multiple workers" self.exploration = ConstantSchedule( - 0.4 ** ( - 1 + worker_index / float(config["num_workers"] - 1) * 7)) + 0.4**(1 + worker_index / float(config["num_workers"] - 1) * 7) + ) else: self.exploration = LinearSchedule( schedule_timesteps=int( config["exploration_fraction"] * - config["schedule_max_timesteps"]), + config["schedule_max_timesteps"] + ), initial_p=1.0, - final_p=config["exploration_final_eps"]) + final_p=config["exploration_final_eps"] + ) # Initialize the parameters and copy them to the target network. 
self.sess.run(tf.global_variables_initializer()) @@ -85,7 +89,8 @@ def __init__(self, registry, env_creator, config, logdir, worker_index): # Note that this encompasses both the Q and target network self.variables = ray.experimental.TensorFlowVariables( - tf.group(self.dqn_graph.q_t, self.dqn_graph.q_tp1), self.sess) + tf.group(self.dqn_graph.q_t, self.dqn_graph.q_tp1), self.sess + ) self.episode_rewards = [0.0] self.episode_lengths = [0.0] @@ -102,7 +107,8 @@ def update_target(self): def sample(self): obs, actions, rewards, new_obs, dones = [], [], [], [], [] for _ in range( - self.config["sample_batch_size"] + self.config["n_step"] - 1): + self.config["sample_batch_size"] + self.config["n_step"] - 1 + ): ob, act, rew, ob1, done = self._step(self.global_timestep) obs.append(ob) actions.append(act) @@ -115,23 +121,29 @@ def sample(self): # Adjust for steps lost from truncation self.local_timestep -= (self.config["n_step"] - 1) adjust_nstep( - self.config["n_step"], self.config["gamma"], - obs, actions, rewards, new_obs, dones) + self.config["n_step"], self.config["gamma"], obs, actions, + rewards, new_obs, dones + ) batch = SampleBatch({ - "obs": [pack(np.array(o)) for o in obs], "actions": actions, + "obs": [pack(np.array(o)) for o in obs], + "actions": actions, "rewards": rewards, - "new_obs": [pack(np.array(o)) for o in new_obs], "dones": dones, - "weights": np.ones_like(rewards)}) + "new_obs": [pack(np.array(o)) for o in new_obs], + "dones": dones, + "weights": np.ones_like(rewards) + }) assert (batch.count == self.config["sample_batch_size"]) # Prioritize on the worker side if self.config["worker_side_prioritization"]: td_errors = self.dqn_graph.compute_td_error( - self.sess, obs, batch["actions"], batch["rewards"], - new_obs, batch["dones"], batch["weights"]) + self.sess, obs, batch["actions"], batch["rewards"], new_obs, + batch["dones"], batch["weights"] + ) new_priorities = ( - np.abs(td_errors) + self.config["prioritized_replay_eps"]) + np.abs(td_errors) + self.config["prioritized_replay_eps"] + ) batch.data["weights"] = new_priorities return batch @@ -139,7 +151,8 @@ def sample(self): def compute_gradients(self, samples): td_err, grads = self.dqn_graph.compute_gradients( self.sess, samples["obs"], samples["actions"], samples["rewards"], - samples["new_obs"], samples["dones"], samples["weights"]) + samples["new_obs"], samples["dones"], samples["weights"] + ) return grads, {"td_error": td_err} def apply_gradients(self, grads): @@ -148,7 +161,8 @@ def apply_gradients(self, grads): def compute_apply(self, samples): td_error = self.dqn_graph.compute_apply( self.sess, samples["obs"], samples["actions"], samples["rewards"], - samples["new_obs"], samples["dones"], samples["weights"]) + samples["new_obs"], samples["dones"], samples["weights"] + ) return {"td_error": td_error} def get_weights(self): @@ -160,8 +174,9 @@ def set_weights(self, weights): def _step(self, global_timestep): """Takes a single step, and returns the result of the step.""" action = self.dqn_graph.act( - self.sess, np.array(self.obs)[None], - self.exploration.value(global_timestep))[0] + self.sess, + np.array(self.obs)[None], self.exploration.value(global_timestep) + )[0] new_obs, rew, done, _ = self.env.step(action) ret = (self.obs, action, rew, new_obs, float(done)) self.obs = new_obs @@ -189,13 +204,10 @@ def stats(self): def save(self): return [ - self.exploration, - self.episode_rewards, - self.episode_lengths, - self.saved_mean_reward, - self.obs, - self.global_timestep, - self.local_timestep] + 
self.exploration, self.episode_rewards, self.episode_lengths, + self.saved_mean_reward, self.obs, self.global_timestep, + self.local_timestep + ] def restore(self, data): self.exploration = data[0] diff --git a/python/ray/rllib/dqn/models.py b/python/ray/rllib/dqn/models.py index 6629b6126acf..758ffac786ac 100644 --- a/python/ray/rllib/dqn/models.py +++ b/python/ray/rllib/dqn/models.py @@ -21,47 +21,54 @@ def _build_q_network(registry, inputs, num_actions, config): action_out = frontend_out for hidden in hiddens: action_out = layers.fully_connected( - action_out, num_outputs=hidden, activation_fn=tf.nn.relu) + action_out, num_outputs=hidden, activation_fn=tf.nn.relu + ) action_scores = layers.fully_connected( - action_out, num_outputs=num_actions, activation_fn=None) + action_out, num_outputs=num_actions, activation_fn=None + ) if dueling: with tf.variable_scope("state_value"): state_out = frontend_out for hidden in hiddens: state_out = layers.fully_connected( - state_out, num_outputs=hidden, activation_fn=tf.nn.relu) + state_out, num_outputs=hidden, activation_fn=tf.nn.relu + ) state_score = layers.fully_connected( - state_out, num_outputs=1, activation_fn=None) + state_out, num_outputs=1, activation_fn=None + ) action_scores_mean = tf.reduce_mean(action_scores, 1) action_scores_centered = action_scores - tf.expand_dims( - action_scores_mean, 1) + action_scores_mean, 1 + ) return state_score + action_scores_centered else: return action_scores -def _build_action_network( - q_values, observations, num_actions, stochastic, eps): +def _build_action_network(q_values, observations, num_actions, stochastic, eps): deterministic_actions = tf.argmax(q_values, axis=1) batch_size = tf.shape(observations)[0] random_actions = tf.random_uniform( - tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64) + tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64 + ) chose_random = tf.random_uniform( - tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps + tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32 + ) < eps stochastic_actions = tf.where( - chose_random, random_actions, deterministic_actions) + chose_random, random_actions, deterministic_actions + ) return tf.cond( - stochastic, lambda: stochastic_actions, - lambda: deterministic_actions) + stochastic, lambda: stochastic_actions, lambda: deterministic_actions + ) def _huber_loss(x, delta=1.0): """Reference: https://en.wikipedia.org/wiki/Huber_loss""" return tf.where( tf.abs(x) < delta, - tf.square(x) * 0.5, - delta * (tf.abs(x) - 0.5 * delta)) + tf.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta) + ) def _minimize_and_clip(optimizer, objective, var_list, clip_val=10): @@ -97,7 +104,8 @@ def _scope_vars(scope, trainable_only=False): return tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.VARIABLES, - scope=scope if isinstance(scope, str) else scope.name) + scope=scope if isinstance(scope, str) else scope.name + ) class ModelAndLoss(object): @@ -108,8 +116,9 @@ class ModelAndLoss(object): """ def __init__( - self, registry, num_actions, config, - obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): + self, registry, num_actions, config, obs_t, act_t, rew_t, obs_tp1, + done_mask, importance_weights + ): # q network evaluation with tf.variable_scope("q_func", reuse=True): self.q_t = _build_q_network(registry, obs_t, num_actions, config) @@ -117,29 +126,34 @@ def __init__( # target q network evalution with tf.variable_scope("target_q_func") as scope: 
self.q_tp1 = _build_q_network( - registry, obs_tp1, num_actions, config) + registry, obs_tp1, num_actions, config + ) self.target_q_func_vars = _scope_vars(scope.name) # q scores for actions which we know were selected in the given state. q_t_selected = tf.reduce_sum( - self.q_t * tf.one_hot(act_t, num_actions), 1) + self.q_t * tf.one_hot(act_t, num_actions), 1 + ) # compute estimate of best possible value starting from state at t + 1 if config["double_q"]: with tf.variable_scope("q_func", reuse=True): q_tp1_using_online_net = _build_q_network( - registry, obs_tp1, num_actions, config) + registry, obs_tp1, num_actions, config + ) q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) q_tp1_best = tf.reduce_sum( - self.q_tp1 * tf.one_hot( - q_tp1_best_using_online_net, num_actions), 1) + self.q_tp1 * + tf.one_hot(q_tp1_best_using_online_net, num_actions), 1 + ) else: q_tp1_best = tf.reduce_max(self.q_tp1, 1) q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best # compute RHS of bellman equation q_t_selected_target = ( - rew_t + config["gamma"] ** config["n_step"] * q_tp1_best_masked) + rew_t + config["gamma"]**config["n_step"] * q_tp1_best_masked + ) # compute the error (potentially clipped) self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target) @@ -160,40 +174,44 @@ def __init__(self, registry, env, config, logdir): self.stochastic = tf.placeholder(tf.bool, (), name="stochastic") self.eps = tf.placeholder(tf.float32, (), name="eps") self.cur_observations = tf.placeholder( - tf.float32, shape=(None,) + env.observation_space.shape) + tf.float32, shape=(None, ) + env.observation_space.shape + ) # Action Q network q_scope_name = TOWER_SCOPE_NAME + "/q_func" with tf.variable_scope(q_scope_name) as scope: q_values = _build_q_network( - registry, self.cur_observations, num_actions, config) + registry, self.cur_observations, num_actions, config + ) q_func_vars = _scope_vars(scope.name) # Action outputs self.output_actions = _build_action_network( - q_values, - self.cur_observations, - num_actions, - self.stochastic, - self.eps) + q_values, self.cur_observations, num_actions, self.stochastic, + self.eps + ) # Replay inputs self.obs_t = tf.placeholder( - tf.float32, shape=(None,) + env.observation_space.shape) + tf.float32, shape=(None, ) + env.observation_space.shape + ) self.act_t = tf.placeholder(tf.int32, [None], name="action") self.rew_t = tf.placeholder(tf.float32, [None], name="reward") self.obs_tp1 = tf.placeholder( - tf.float32, shape=(None,) + env.observation_space.shape) + tf.float32, shape=(None, ) + env.observation_space.shape + ) self.done_mask = tf.placeholder(tf.float32, [None], name="done") self.importance_weights = tf.placeholder( - tf.float32, [None], name="weight") + tf.float32, [None], name="weight" + ) def build_loss( - obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): + obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights + ): return ModelAndLoss( - registry, - num_actions, config, - obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights) + registry, num_actions, config, obs_t, act_t, rew_t, obs_tp1, + done_mask, importance_weights + ) self.loss_inputs = [ ("obs", self.obs_t), @@ -207,7 +225,8 @@ def build_loss( with tf.variable_scope(TOWER_SCOPE_NAME): loss_obj = build_loss( self.obs_t, self.act_t, self.rew_t, self.obs_tp1, - self.done_mask, self.importance_weights) + self.done_mask, self.importance_weights + ) self.build_loss = build_loss @@ -220,13 +239,18 @@ def build_loss( # compute optimization op (potentially with gradient 
clipping) if config["grad_norm_clipping"] is not None: self.grads_and_vars = _minimize_and_clip( - optimizer, weighted_error, var_list=q_func_vars, - clip_val=config["grad_norm_clipping"]) + optimizer, + weighted_error, + var_list=q_func_vars, + clip_val=config["grad_norm_clipping"] + ) else: self.grads_and_vars = optimizer.compute_gradients( - weighted_error, var_list=q_func_vars) - self.grads_and_vars = [ - (g, v) for (g, v) in self.grads_and_vars if g is not None] + weighted_error, var_list=q_func_vars + ) + self.grads_and_vars = [(g, v) + for (g, v) in self.grads_and_vars + if g is not None] self.grads = [g for (g, v) in self.grads_and_vars] self.train_expr = optimizer.apply_gradients(self.grads_and_vars) @@ -235,7 +259,8 @@ def build_loss( update_target_expr = [] for var, var_target in zip( sorted(q_func_vars, key=lambda v: v.name), - sorted(target_q_func_vars, key=lambda v: v.name)): + sorted(target_q_func_vars, key=lambda v: v.name) + ): update_target_expr.append(var_target.assign(var)) self.update_target_expr = tf.group(*update_target_expr) @@ -249,26 +274,26 @@ def act(self, sess, obs, eps, stochastic=True): self.cur_observations: obs, self.stochastic: stochastic, self.eps: eps, - }) + } + ) def compute_gradients( - self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): - td_err, grads = sess.run( - [self.td_error, self.grads], - feed_dict={ - self.obs_t: obs_t, - self.act_t: act_t, - self.rew_t: rew_t, - self.obs_tp1: obs_tp1, - self.done_mask: done_mask, - self.importance_weights: importance_weights - }) + self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights + ): + td_err, grads = sess.run([self.td_error, self.grads], + feed_dict={ + self.obs_t: obs_t, + self.act_t: act_t, + self.rew_t: rew_t, + self.obs_tp1: obs_tp1, + self.done_mask: done_mask, + self.importance_weights: importance_weights + }) return td_err, grads def compute_td_error( - self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights + ): td_err = sess.run( self.td_error, feed_dict={ @@ -278,7 +303,8 @@ def compute_td_error( self.obs_tp1: [np.array(ob) for ob in obs_tp1], self.done_mask: done_mask, self.importance_weights: importance_weights - }) + } + ) return td_err def apply_gradients(self, sess, grads): @@ -287,16 +313,15 @@ def apply_gradients(self, sess, grads): sess.run(self.train_expr, feed_dict=feed_dict) def compute_apply( - self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): - td_err, _ = sess.run( - [self.td_error, self.train_expr], - feed_dict={ - self.obs_t: obs_t, - self.act_t: act_t, - self.rew_t: rew_t, - self.obs_tp1: obs_tp1, - self.done_mask: done_mask, - self.importance_weights: importance_weights - }) + self, sess, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights + ): + td_err, _ = sess.run([self.td_error, self.train_expr], + feed_dict={ + self.obs_t: obs_t, + self.act_t: act_t, + self.rew_t: rew_t, + self.obs_tp1: obs_tp1, + self.done_mask: done_mask, + self.importance_weights: importance_weights + }) return td_err diff --git a/python/ray/rllib/es/es.py b/python/ray/rllib/es/es.py index ca1bf4da69fe..c93bccad8e5d 100644 --- a/python/ray/rllib/es/es.py +++ b/python/ray/rllib/es/es.py @@ -20,12 +20,12 @@ from ray.rllib.es import tabular_logger as tlogger from ray.rllib.es import utils - -Result = namedtuple("Result", [ - "noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths", - "eval_returns", "eval_lengths" -]) - 
+Result = namedtuple( + "Result", [ + "noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths", + "eval_returns", "eval_lengths" + ] +) DEFAULT_CONFIG = dict( l2_coeff=0.005, @@ -38,7 +38,8 @@ stepsize=0.01, observation_filter="MeanStdFilter", noise_size=250000000, - env_config={}) + env_config={} +) @ray.remote @@ -63,8 +64,15 @@ def sample_index(self, dim): @ray.remote class Worker(object): - def __init__(self, registry, config, policy_params, env_creator, noise, - min_task_runtime=0.2): + def __init__( + self, + registry, + config, + policy_params, + env_creator, + noise, + min_task_runtime=0.2 + ): self.min_task_runtime = min_task_runtime self.config = config self.policy_params = policy_params @@ -73,17 +81,22 @@ def __init__(self, registry, config, policy_params, env_creator, noise, self.env = env_creator(config["env_config"]) from ray.rllib import models self.preprocessor = models.ModelCatalog.get_preprocessor( - registry, self.env) + registry, self.env + ) self.sess = utils.make_session(single_threaded=True) self.policy = policies.GenericPolicy( registry, self.sess, self.env.action_space, self.preprocessor, - config["observation_filter"], **policy_params) + config["observation_filter"], **policy_params + ) def rollout(self, timestep_limit, add_noise=True): rollout_rewards, rollout_length = policies.rollout( - self.policy, self.env, timestep_limit=timestep_limit, - add_noise=add_noise) + self.policy, + self.env, + timestep_limit=timestep_limit, + add_noise=add_noise + ) return rollout_rewards, rollout_length def do_rollouts(self, params, timestep_limit=None): @@ -95,8 +108,10 @@ def do_rollouts(self, params, timestep_limit=None): # Perform some rollouts with noise. task_tstart = time.time() - while (len(noise_indices) == 0 or - time.time() - task_tstart < self.min_task_runtime): + while ( + len(noise_indices) == 0 + or time.time() - task_tstart < self.min_task_runtime + ): if np.random.uniform() < self.config["eval_prob"]: # Do an evaluation run with no perturbation. @@ -109,7 +124,8 @@ def do_rollouts(self, params, timestep_limit=None): noise_index = self.noise.sample_index(self.policy.num_params) perturbation = self.config["noise_stdev"] * self.noise.get( - noise_index, self.policy.num_params) + noise_index, self.policy.num_params + ) # These two sampling steps could be done in parallel on # different actors letting us update twice as frequently. 
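The hunks immediately above and below evaluate every sampled noise vector in both directions (the rewards_pos / rewards_neg pair), a mirrored-sampling trick that reduces the variance of the ES gradient estimate. A minimal sketch of that idea only, not part of this patch, assuming hypothetical evaluate and noise_table helpers in place of the worker's rollout and shared noise table:

import numpy as np

def mirrored_returns(theta, noise_table, noise_index, noise_stdev, evaluate):
    # Look up a fixed noise vector for this index and scale it by the
    # exploration stdev, mirroring what the worker code above does.
    perturbation = noise_stdev * noise_table[noise_index:noise_index + len(theta)]
    # Evaluate the same perturbation in both directions; pairing the +eps and
    # -eps rollouts is what produces the rewards_pos / rewards_neg entries.
    rewards_pos = evaluate(theta + perturbation)
    rewards_neg = evaluate(theta - perturbation)
    return np.sum(rewards_pos), np.sum(rewards_neg)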
@@ -121,8 +137,10 @@ def do_rollouts(self, params, timestep_limit=None): noise_indices.append(noise_index) returns.append([rewards_pos.sum(), rewards_neg.sum()]) - sign_returns.append( - [np.sign(rewards_pos).sum(), np.sign(rewards_neg).sum()]) + sign_returns.append([ + np.sign(rewards_pos).sum(), + np.sign(rewards_neg).sum() + ]) lengths.append([lengths_pos, lengths_neg]) return Result( @@ -131,7 +149,8 @@ def do_rollouts(self, params, timestep_limit=None): sign_noisy_returns=sign_returns, noisy_lengths=lengths, eval_returns=eval_returns, - eval_lengths=eval_lengths) + eval_lengths=eval_lengths + ) class ESAgent(agent.Agent): @@ -145,19 +164,17 @@ def default_resource_request(cls, config): return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"]) def _init(self): - policy_params = { - "action_noise_std": 0.01 - } + policy_params = {"action_noise_std": 0.01} env = self.env_creator(self.config["env_config"]) from ray.rllib import models - preprocessor = models.ModelCatalog.get_preprocessor( - self.registry, env) + preprocessor = models.ModelCatalog.get_preprocessor(self.registry, env) self.sess = utils.make_session(single_threaded=False) self.policy = policies.GenericPolicy( self.registry, self.sess, env.action_space, preprocessor, - self.config["observation_filter"], **policy_params) + self.config["observation_filter"], **policy_params + ) self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"]) # Create the shared noise table. @@ -170,8 +187,9 @@ def _init(self): self.workers = [ Worker.remote( self.registry, self.config, policy_params, self.env_creator, - noise_id) - for _ in range(self.config["num_workers"])] + noise_id + ) for _ in range(self.config["num_workers"]) + ] self.episodes_so_far = 0 self.timesteps_so_far = 0 @@ -183,19 +201,24 @@ def _collect_results(self, theta_id, min_episodes, min_timesteps): while num_episodes < min_episodes or num_timesteps < min_timesteps: print( "Collected {} episodes {} timesteps so far this iter".format( - num_episodes, num_timesteps)) - rollout_ids = [worker.do_rollouts.remote(theta_id) - for worker in self.workers] + num_episodes, num_timesteps + ) + ) + rollout_ids = [ + worker.do_rollouts.remote(theta_id) for worker in self.workers + ] # Get the results of the rollouts. for result in ray.get(rollout_ids): results.append(result) # Update the number of episodes and the number of timesteps # keeping in mind that result.noisy_lengths is a list of lists, # where the inner lists have length 2. - num_episodes += sum([len(pair) for pair - in result.noisy_lengths]) - num_timesteps += sum([sum(pair) for pair - in result.noisy_lengths]) + num_episodes += sum([ + len(pair) for pair in result.noisy_lengths + ]) + num_timesteps += sum([ + sum(pair) for pair in result.noisy_lengths + ]) return results, num_episodes, num_timesteps def _train(self): @@ -210,9 +233,9 @@ def _train(self): # Use the actors to do rollouts, note that we pass in the ID of the # policy weights. 
results, num_episodes, num_timesteps = self._collect_results( - theta_id, - config["episodes_per_batch"], - config["timesteps_per_batch"]) + theta_id, config["episodes_per_batch"], + config["timesteps_per_batch"] + ) all_noise_indices = [] all_training_returns = [] @@ -230,8 +253,10 @@ def _train(self): all_training_lengths += result.noisy_lengths assert len(all_eval_returns) == len(all_eval_lengths) - assert (len(all_noise_indices) == len(all_training_returns) == - len(all_training_lengths)) + assert ( + len(all_noise_indices) == len(all_training_returns) == + len(all_training_lengths) + ) self.episodes_so_far += num_episodes self.timesteps_so_far += num_timesteps @@ -251,18 +276,21 @@ def _train(self): # Compute and take a step. g, count = utils.batched_weighted_sum( - proc_noisy_returns[:, 0] - proc_noisy_returns[:, 1], - (self.noise.get(index, self.policy.num_params) - for index in noise_indices), - batch_size=500) + proc_noisy_returns[:, 0] - proc_noisy_returns[:, 1], ( + self.noise.get(index, self.policy.num_params) + for index in noise_indices + ), + batch_size=500 + ) g /= noisy_returns.size assert ( - g.shape == (self.policy.num_params,) and - g.dtype == np.float32 and - count == len(noise_indices)) + g.shape == (self.policy.num_params, ) and g.dtype == np.float32 + and count == len(noise_indices) + ) # Compute the new weights theta. theta, update_ratio = self.optimizer.update( - -g + config["l2_coeff"] * theta) + -g + config["l2_coeff"] * theta + ) # Set the new weights in the local copy of the policy. self.policy.set_weights(theta) @@ -304,7 +332,8 @@ def _train(self): episode_reward_mean=eval_returns.mean(), episode_len_mean=eval_lengths.mean(), timesteps_this_iter=noisy_lengths.sum(), - info=info) + info=info + ) return result @@ -315,12 +344,10 @@ def _stop(self): def _save(self, checkpoint_dir): checkpoint_path = os.path.join( - checkpoint_dir, "checkpoint-{}".format(self.iteration)) + checkpoint_dir, "checkpoint-{}".format(self.iteration) + ) weights = self.policy.get_weights() - objects = [ - weights, - self.episodes_so_far, - self.timesteps_so_far] + objects = [weights, self.episodes_so_far, self.timesteps_so_far] pickle.dump(objects, open(checkpoint_path, "wb")) return checkpoint_path diff --git a/python/ray/rllib/es/optimizers.py b/python/ray/rllib/es/optimizers.py index f5ef4e10925f..316714ed7135 100644 --- a/python/ray/rllib/es/optimizers.py +++ b/python/ray/rllib/es/optimizers.py @@ -48,8 +48,9 @@ def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08): self.v = np.zeros(self.dim, dtype=np.float32) def _compute_step(self, globalg): - a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) / - (1 - self.beta1 ** self.t)) + a = self.stepsize * ( + np.sqrt(1 - self.beta2**self.t) / (1 - self.beta1**self.t) + ) self.m = self.beta1 * self.m + (1 - self.beta1) * globalg self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg) step = -a * self.m / (np.sqrt(self.v) + self.epsilon) diff --git a/python/ray/rllib/es/policies.py b/python/ray/rllib/es/policies.py index 36a404c4882c..80dfe515d514 100644 --- a/python/ray/rllib/es/policies.py +++ b/python/ray/rllib/es/policies.py @@ -21,8 +21,10 @@ def rollout(policy, env, timestep_limit=None, add_noise=False): noise drawn from that stream. Otherwise, no action noise will be added. 
""" env_timestep_limit = env.spec.max_episode_steps - timestep_limit = (env_timestep_limit if timestep_limit is None - else min(timestep_limit, env_timestep_limit)) + timestep_limit = ( + env_timestep_limit + if timestep_limit is None else min(timestep_limit, env_timestep_limit) + ) rews = [] t = 0 observation = env.reset() @@ -38,37 +40,45 @@ def rollout(policy, env, timestep_limit=None, add_noise=False): class GenericPolicy(object): - def __init__(self, registry, sess, action_space, preprocessor, - observation_filter, action_noise_std): + def __init__( + self, registry, sess, action_space, preprocessor, observation_filter, + action_noise_std + ): self.sess = sess self.action_space = action_space self.action_noise_std = action_noise_std self.preprocessor = preprocessor self.observation_filter = get_filter( - observation_filter, self.preprocessor.shape) + observation_filter, self.preprocessor.shape + ) self.inputs = tf.placeholder( - tf.float32, [None] + list(self.preprocessor.shape)) + tf.float32, [None] + list(self.preprocessor.shape) + ) # Policy network. dist_class, dist_dim = ModelCatalog.get_action_dist( - self.action_space, dist_type="deterministic") + self.action_space, dist_type="deterministic" + ) model = ModelCatalog.get_model(registry, self.inputs, dist_dim) dist = dist_class(model.outputs) self.sampler = dist.sample() self.variables = ray.experimental.TensorFlowVariables( - model.outputs, self.sess) + model.outputs, self.sess + ) - self.num_params = sum([np.prod(variable.shape.as_list()) - for _, variable - in self.variables.variables.items()]) + self.num_params = sum([ + np.prod(variable.shape.as_list()) + for _, variable in self.variables.variables.items() + ]) self.sess.run(tf.global_variables_initializer()) def compute(self, observation, add_noise=False, update=True): observation = self.preprocessor.transform(observation) observation = self.observation_filter(observation[None], update=update) - action = self.sess.run(self.sampler, - feed_dict={self.inputs: observation}) + action = self.sess.run( + self.sampler, feed_dict={self.inputs: observation} + ) if add_noise and isinstance(self.action_space, gym.spaces.Box): action += np.random.randn(*action.shape) * self.action_noise_std return action diff --git a/python/ray/rllib/es/tabular_logger.py b/python/ray/rllib/es/tabular_logger.py index 80e7b5b37aec..196b2935e573 100644 --- a/python/ray/rllib/es/tabular_logger.py +++ b/python/ray/rllib/es/tabular_logger.py @@ -25,18 +25,23 @@ class TbWriter(object): """Based on SummaryWriter, but changed to allow for a different prefix.""" + def __init__(self, dir, prefix): self.dir = dir # Start at 1, because EvWriter automatically generates an object with # step = 0. self.step = 1 self.evwriter = pywrap_tensorflow.EventsWriter( - compat.as_bytes(os.path.join(dir, prefix))) + compat.as_bytes(os.path.join(dir, prefix)) + ) def write_values(self, key2val): - summary = tf.Summary(value=[tf.Summary.Value(tag=k, - simple_value=float(v)) - for (k, v) in key2val.items()]) + summary = tf.Summary( + value=[ + tf.Summary.Value(tag=k, simple_value=float(v)) + for (k, v) in key2val.items() + ] + ) event = event_pb2.Event(wall_time=time.time(), summary=summary) event.step = self.step self.evwriter.WriteEvent(event) @@ -46,22 +51,27 @@ def write_values(self, key2val): def close(self): self.evwriter.Close() + # API def start(dir): if _Logger.CURRENT is not _Logger.DEFAULT: - sys.stderr.write("WARNING: You asked to start logging (dir=%s), but " - "you never stopped the previous logger (dir=%s)." 
- "\n" % (dir, _Logger.CURRENT.dir)) + sys.stderr.write( + "WARNING: You asked to start logging (dir=%s), but " + "you never stopped the previous logger (dir=%s)." + "\n" % (dir, _Logger.CURRENT.dir) + ) _Logger.CURRENT = _Logger(dir=dir) def stop(): if _Logger.CURRENT is _Logger.DEFAULT: - sys.stderr.write("WARNING: You asked to stop logging, but you never " - "started any previous logger." - "\n" % (dir, _Logger.CURRENT.dir)) + sys.stderr.write( + "WARNING: You asked to stop logging, but you never " + "started any previous logger." + "\n" % (dir, _Logger.CURRENT.dir) + ) return _Logger.CURRENT.close() _Logger.CURRENT = _Logger.DEFAULT @@ -126,6 +136,7 @@ def get_expt_dir(): sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n") return get_dir() + # Backend @@ -167,8 +178,10 @@ def dump_tabular(self): # Write to all text outputs self._write_text("-" * (keywidth + valwidth + 7), "\n") for (key, val) in key2str.items(): - self._write_text("| ", key, " " * (keywidth - len(key)), - " | ", val, " " * (valwidth - len(val)), " |\n") + self._write_text( + "| ", key, " " * (keywidth - len(key)), " | ", val, + " " * (valwidth - len(val)), " |\n" + ) self._write_text("-" * (keywidth + valwidth + 7), "\n") for f in self.text_outputs: try: @@ -202,7 +215,7 @@ def close(self): # Misc def _do_log(self, *args): - self._write_text(*args + ('\n',)) + self._write_text(*args + ('\n', )) for f in self.text_outputs: try: f.flush() diff --git a/python/ray/rllib/es/utils.py b/python/ray/rllib/es/utils.py index 6ea5d31acd25..4a85a365f07a 100644 --- a/python/ray/rllib/es/utils.py +++ b/python/ray/rllib/es/utils.py @@ -31,8 +31,11 @@ def compute_centered_ranks(x): def make_session(single_threaded): if not single_threaded: return tf.Session() - return tf.Session(config=tf.ConfigProto(inter_op_parallelism_threads=1, - intra_op_parallelism_threads=1)) + return tf.Session( + config=tf.ConfigProto( + inter_op_parallelism_threads=1, intra_op_parallelism_threads=1 + ) + ) def itergroups(items, group_size): @@ -50,10 +53,13 @@ def itergroups(items, group_size): def batched_weighted_sum(weights, vecs, batch_size): total = 0 num_items_summed = 0 - for batch_weights, batch_vecs in zip(itergroups(weights, batch_size), - itergroups(vecs, batch_size)): + for batch_weights, batch_vecs in zip( + itergroups(weights, batch_size), itergroups(vecs, batch_size) + ): assert len(batch_weights) == len(batch_vecs) <= batch_size - total += np.dot(np.asarray(batch_weights, dtype=np.float32), - np.asarray(batch_vecs, dtype=np.float32)) + total += np.dot( + np.asarray(batch_weights, dtype=np.float32), + np.asarray(batch_vecs, dtype=np.float32) + ) num_items_summed += len(batch_weights) return total, num_items_summed diff --git a/python/ray/rllib/examples/multiagent_mountaincar.py b/python/ray/rllib/examples/multiagent_mountaincar.py index 74f818d7e552..5d6a13ce5421 100644 --- a/python/ray/rllib/examples/multiagent_mountaincar.py +++ b/python/ray/rllib/examples/multiagent_mountaincar.py @@ -20,10 +20,10 @@ def pass_params_to_gym(env_name): global env_version_num register( - id=env_name, - entry_point='ray.rllib.examples:' + "MultiAgentMountainCarEnv", - max_episode_steps=200, - kwargs={} + id=env_name, + entry_point='ray.rllib.examples:' + "MultiAgentMountainCarEnv", + max_episode_steps=200, + kwargs={} ) @@ -46,10 +46,12 @@ def create_env(env_config): config["horizon"] = horizon config["use_gae"] = False config["model"].update({"fcnet_hiddens": [256, 256]}) - options = {"multiagent_obs_shapes": [2, 2], - 
"multiagent_act_shapes": [1, 1], - "multiagent_shared_model": False, - "multiagent_fcnet_hiddens": [[32, 32]] * 2} + options = { + "multiagent_obs_shapes": [2, 2], + "multiagent_act_shapes": [1, 1], + "multiagent_shared_model": False, + "multiagent_fcnet_hiddens": [[32, 32]] * 2 + } config["model"].update({"custom_options": options}) alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config) for i in range(1): diff --git a/python/ray/rllib/examples/multiagent_mountaincar_env.py b/python/ray/rllib/examples/multiagent_mountaincar_env.py index d454937acb04..7882af7b0aec 100644 --- a/python/ray/rllib/examples/multiagent_mountaincar_env.py +++ b/python/ray/rllib/examples/multiagent_mountaincar_env.py @@ -2,7 +2,6 @@ from gym.spaces import Box, Tuple, Discrete import numpy as np from gym.envs.classic_control.mountain_car import MountainCarEnv - """ Multiagent mountain car that sums and then averages its actions to produce the velocity @@ -23,7 +22,8 @@ def __init__(self): self.action_space = [Discrete(3) for _ in range(2)] self.observation_space = Tuple([ - Box(self.low, self.high, dtype=np.float32) for _ in range(2)]) + Box(self.low, self.high, dtype=np.float32) for _ in range(2) + ]) self.seed() self.reset() diff --git a/python/ray/rllib/examples/multiagent_pendulum.py b/python/ray/rllib/examples/multiagent_pendulum.py index 20cd5d7ace77..af042476942c 100644 --- a/python/ray/rllib/examples/multiagent_pendulum.py +++ b/python/ray/rllib/examples/multiagent_pendulum.py @@ -20,10 +20,10 @@ def pass_params_to_gym(env_name): global env_version_num register( - id=env_name, - entry_point='ray.rllib.examples:' + "MultiAgentPendulumEnv", - max_episode_steps=100, - kwargs={} + id=env_name, + entry_point='ray.rllib.examples:' + "MultiAgentPendulumEnv", + max_episode_steps=100, + kwargs={} ) @@ -46,10 +46,12 @@ def create_env(env_config): config["horizon"] = horizon config["use_gae"] = True config["model"].update({"fcnet_hiddens": [256, 256]}) - options = {"multiagent_obs_shapes": [3, 3], - "multiagent_act_shapes": [1, 1], - "multiagent_shared_model": True, - "multiagent_fcnet_hiddens": [[32, 32]] * 2} + options = { + "multiagent_obs_shapes": [3, 3], + "multiagent_act_shapes": [1, 1], + "multiagent_shared_model": True, + "multiagent_fcnet_hiddens": [[32, 32]] * 2 + } config["model"].update({"custom_options": options}) alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config) for i in range(1): diff --git a/python/ray/rllib/examples/multiagent_pendulum_env.py b/python/ray/rllib/examples/multiagent_pendulum_env.py index 44c86f4e6d2c..b0e5d877c010 100644 --- a/python/ray/rllib/examples/multiagent_pendulum_env.py +++ b/python/ray/rllib/examples/multiagent_pendulum_env.py @@ -2,7 +2,6 @@ from gym.utils import seeding from gym.envs.classic_control.pendulum import PendulumEnv import numpy as np - """ Multiagent pendulum that sums its torques to generate an action """ @@ -10,8 +9,8 @@ class MultiAgentPendulumEnv(PendulumEnv): metadata = { - 'render.modes': ['human', 'rgb_array'], - 'video.frames_per_second': 30 + 'render.modes': ['human', 'rgb_array'], + 'video.frames_per_second': 30 } def __init__(self): @@ -21,13 +20,17 @@ def __init__(self): self.viewer = None high = np.array([1., 1., self.max_speed]) - self.action_space = [Box(low=-self.max_torque / 2, - high=self.max_torque / 2, - shape=(1,), - dtype=np.float32) - for _ in range(2)] + self.action_space = [ + Box( + low=-self.max_torque / 2, + high=self.max_torque / 2, + shape=(1, ), + dtype=np.float32 + ) for _ in range(2) + ] 
self.observation_space = Tuple([ - Box(low=-high, high=high, dtype=np.float32) for _ in range(2)]) + Box(low=-high, high=high, dtype=np.float32) for _ in range(2) + ]) self.seed() @@ -49,8 +52,10 @@ def step(self, u): costs = self.angle_normalize(th) ** 2 + .1 * thdot ** 2 + \ .001 * (summed_u ** 2) - newthdot = thdot + (-3 * g / (2 * length) * np.sin(th + np.pi) + - 3. / (m * length ** 2) * summed_u) * dt + newthdot = thdot + ( + -3 * g / (2 * length) * np.sin(th + np.pi) + 3. / + (m * length**2) * summed_u + ) * dt newth = th + newthdot * dt newthdot = np.clip(newthdot, -self.max_speed, self.max_speed) @@ -65,8 +70,10 @@ def reset(self): def _get_obs(self): theta, thetadot = self.state - return [np.array([np.cos(theta), np.sin(theta), thetadot]) - for _ in range(2)] + return [ + np.array([np.cos(theta), np.sin(theta), thetadot]) + for _ in range(2) + ] def angle_normalize(self, x): return (((x + np.pi) % (2 * np.pi)) - np.pi) diff --git a/python/ray/rllib/models/__init__.py b/python/ray/rllib/models/__init__.py index af3ac81dcb98..017fff372130 100644 --- a/python/ray/rllib/models/__init__.py +++ b/python/ray/rllib/models/__init__.py @@ -1,14 +1,15 @@ from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.action_dist import (ActionDistribution, Categorical, - DiagGaussian, Deterministic) +from ray.rllib.models.action_dist import ( + ActionDistribution, Categorical, DiagGaussian, Deterministic +) from ray.rllib.models.model import Model from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.convnet import ConvolutionalNetwork from ray.rllib.models.lstm import LSTM from ray.rllib.models.multiagentfcnet import MultiAgentFullyConnectedNetwork - -__all__ = ["ActionDistribution", "ActionDistribution", "Categorical", - "DiagGaussian", "Deterministic", "ModelCatalog", "Model", - "FullyConnectedNetwork", "ConvolutionalNetwork", "LSTM", - "MultiAgentFullyConnectedNetwork"] +__all__ = [ + "ActionDistribution", "ActionDistribution", "Categorical", "DiagGaussian", + "Deterministic", "ModelCatalog", "Model", "FullyConnectedNetwork", + "ConvolutionalNetwork", "LSTM", "MultiAgentFullyConnectedNetwork" +] diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index 03e88bd1fc5e..002707dbe7fd 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -39,28 +39,33 @@ class Categorical(ActionDistribution): def logp(self, x): return -tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=self.inputs, labels=x) + logits=self.inputs, labels=x + ) def entropy(self): - a0 = self.inputs - tf.reduce_max(self.inputs, reduction_indices=[1], - keep_dims=True) + a0 = self.inputs - tf.reduce_max( + self.inputs, reduction_indices=[1], keep_dims=True + ) ea0 = tf.exp(a0) z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True) p0 = ea0 / z0 return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1]) def kl(self, other): - a0 = self.inputs - tf.reduce_max(self.inputs, reduction_indices=[1], - keep_dims=True) - a1 = other.inputs - tf.reduce_max(other.inputs, reduction_indices=[1], - keep_dims=True) + a0 = self.inputs - tf.reduce_max( + self.inputs, reduction_indices=[1], keep_dims=True + ) + a1 = other.inputs - tf.reduce_max( + other.inputs, reduction_indices=[1], keep_dims=True + ) ea0 = tf.exp(a0) ea1 = tf.exp(a1) z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True) z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True) p0 = ea0 / z0 - return tf.reduce_sum(p0 * (a0 - 
tf.log(z0) - a1 + tf.log(z1)), - reduction_indices=[1]) + return tf.reduce_sum( + p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1] + ) def sample(self): return tf.multinomial(self.inputs, 1)[0] @@ -81,22 +86,27 @@ def __init__(self, inputs): self.std = tf.exp(log_std) def logp(self, x): - return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), - reduction_indices=[1]) - - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - - tf.reduce_sum(self.log_std, reduction_indices=[1])) + return ( + -0.5 * tf.reduce_sum( + tf.square((x - self.mean) / self.std), reduction_indices=[1] + ) - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - + tf.reduce_sum(self.log_std, reduction_indices=[1]) + ) def kl(self, other): assert isinstance(other, DiagGaussian) - return tf.reduce_sum(other.log_std - self.log_std + - (tf.square(self.std) + - tf.square(self.mean - other.mean)) / - (2.0 * tf.square(other.std)) - 0.5, - reduction_indices=[1]) + return tf.reduce_sum( + other.log_std - self.log_std + + (tf.square(self.std) + tf.square(self.mean - other.mean)) / + (2.0 * tf.square(other.std)) - 0.5, + reduction_indices=[1] + ) def entropy(self): - return tf.reduce_sum(self.log_std + .5 * np.log(2.0 * np.pi * np.e), - reduction_indices=[1]) + return tf.reduce_sum( + self.log_std + .5 * np.log(2.0 * np.pi * np.e), + reduction_indices=[1] + ) def sample(self): return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) @@ -118,6 +128,7 @@ class MultiActionDistribution(ActionDistribution): Args: inputs (Tensor list): A list of tensors from which to compute samples. """ + def __init__(self, inputs, action_space, child_distributions): # you actually have to instantiate the child distributions self.reshaper = Reshaper(action_space.spaces) @@ -134,23 +145,24 @@ def logp(self, x): # Remove extra categorical dimension if isinstance(distribution, Categorical): split_list[i] = tf.squeeze(split_list[i], axis=-1) - log_list = np.asarray([distribution.logp(split_x) for - distribution, split_x in - zip(self.child_distributions, split_list)]) + log_list = np.asarray([ + distribution.logp(split_x) for distribution, split_x in + zip(self.child_distributions, split_list) + ]) return np.sum(log_list) def kl(self, other): """The KL-divergence between two action distributions.""" - kl_list = np.asarray([distribution.kl(other_distribution) for - distribution, other_distribution in - zip(self.child_distributions, - other.child_distributions)]) + kl_list = np.asarray([ + distribution.kl(other_distribution) + for distribution, other_distribution in + zip(self.child_distributions, other.child_distributions) + ]) return np.sum(kl_list) def entropy(self): """The entropy of the action distribution.""" - entropy_list = np.array([s.entropy() for s in - self.child_distributions]) + entropy_list = np.array([s.entropy() for s in self.child_distributions]) return np.sum(entropy_list) def sample(self): diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 8a423d309850..0b057ebc14cf 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -11,13 +11,13 @@ _default_registry from ray.rllib.models.action_dist import ( - Categorical, Deterministic, DiagGaussian, MultiActionDistribution) + Categorical, Deterministic, DiagGaussian, MultiActionDistribution +) from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork from 
ray.rllib.models.multiagentfcnet import MultiAgentFullyConnectedNetwork - MODEL_CONFIGS = [ # === Built-in options === "conv_filters", # Number of filters @@ -81,12 +81,15 @@ def get_action_dist(action_space, dist_type=None): dist, action_size = ModelCatalog.get_action_dist(action) child_dist.append(dist) size += action_size - return partial(MultiActionDistribution, - child_distributions=child_dist, - action_space=action_space), size + return partial( + MultiActionDistribution, + child_distributions=child_dist, + action_space=action_space + ), size raise NotImplementedError( - "Unsupported args: {} {}".format(action_space, dist_type)) + "Unsupported args: {} {}".format(action_space, dist_type) + ) @staticmethod def get_action_placeholder(action_space): @@ -104,9 +107,10 @@ def get_action_placeholder(action_space): if isinstance(action_space, gym.spaces.Box): return tf.placeholder( - tf.float32, shape=(None, action_space.shape[0])) + tf.float32, shape=(None, action_space.shape[0]) + ) elif isinstance(action_space, gym.spaces.Discrete): - return tf.placeholder(tf.int64, shape=(None,)) + return tf.placeholder(tf.int64, shape=(None, )) elif isinstance(action_space, gym.spaces.Tuple): size = 0 all_discrete = True @@ -117,10 +121,13 @@ def get_action_placeholder(action_space): all_discrete = False size += np.product(action_space.spaces[i].shape) return tf.placeholder( - tf.int64 if all_discrete else tf.float32, shape=(None, size)) + tf.int64 if all_discrete else tf.float32, shape=(None, size) + ) else: - raise NotImplementedError("action space {}" - " not supported".format(action_space)) + raise NotImplementedError( + "action space {}" + " not supported".format(action_space) + ) @staticmethod def get_model(registry, inputs, num_outputs, options=dict()): @@ -139,16 +146,17 @@ def get_model(registry, inputs, num_outputs, options=dict()): if "custom_model" in options: model = options["custom_model"] print("Using custom model {}".format(model)) - return registry.get(RLLIB_MODEL, model)( - inputs, num_outputs, options) + return registry.get(RLLIB_MODEL, + model)(inputs, num_outputs, options) obs_rank = len(inputs.shape) - 1 # num_outputs > 1 used to avoid hitting this with the value function - if isinstance(options.get("custom_options", {}).get( - "multiagent_fcnet_hiddens", 1), list) and num_outputs > 1: - return MultiAgentFullyConnectedNetwork(inputs, - num_outputs, options) + if isinstance( + options.get("custom_options", + {}).get("multiagent_fcnet_hiddens", 1), list + ) and num_outputs > 1: + return MultiAgentFullyConnectedNetwork(inputs, num_outputs, options) if obs_rank > 1: return VisionNetwork(inputs, num_outputs, options) @@ -170,15 +178,17 @@ def get_torch_model(registry, input_shape, num_outputs, options=dict()): model (Model): Neural network model. 
""" from ray.rllib.models.pytorch.fcnet import ( - FullyConnectedNetwork as PyTorchFCNet) + FullyConnectedNetwork as PyTorchFCNet + ) from ray.rllib.models.pytorch.visionnet import ( - VisionNetwork as PyTorchVisionNet) + VisionNetwork as PyTorchVisionNet + ) if "custom_model" in options: model = options["custom_model"] print("Using custom torch model {}".format(model)) - return registry.get(RLLIB_MODEL, model)( - input_shape, num_outputs, options) + return registry.get(RLLIB_MODEL, + model)(input_shape, num_outputs, options) obs_rank = len(input_shape) - 1 @@ -203,13 +213,15 @@ def get_preprocessor(registry, env, options=dict()): if k not in MODEL_CONFIGS: raise Exception( "Unknown config key `{}`, all keys: {}".format( - k, MODEL_CONFIGS)) + k, MODEL_CONFIGS + ) + ) if "custom_preprocessor" in options: preprocessor = options["custom_preprocessor"] print("Using custom preprocessor {}".format(preprocessor)) - return registry.get(RLLIB_PREPROCESSOR, preprocessor)( - env.observation_space, options) + return registry.get(RLLIB_PREPROCESSOR, + preprocessor)(env.observation_space, options) preprocessor = get_preprocessor(env.observation_space) return preprocessor(env.observation_space, options) @@ -242,7 +254,8 @@ def register_custom_preprocessor(preprocessor_name, preprocessor_class): preprocessor_class (type): Python class of the preprocessor. """ _default_registry.register( - RLLIB_PREPROCESSOR, preprocessor_name, preprocessor_class) + RLLIB_PREPROCESSOR, preprocessor_name, preprocessor_class + ) @staticmethod def register_custom_model(model_name, model_class): @@ -267,7 +280,8 @@ def __init__(self, env, preprocessor): from gym.spaces.box import Box self.observation_space = Box( - -1.0, 1.0, preprocessor.shape, dtype=np.float32) + -1.0, 1.0, preprocessor.shape, dtype=np.float32 + ) def observation(self, observation): return self.preprocessor.transform(observation) diff --git a/python/ray/rllib/models/convnet.py b/python/ray/rllib/models/convnet.py index 4074e0ad3777..8cd42ca62464 100644 --- a/python/ray/rllib/models/convnet.py +++ b/python/ray/rllib/models/convnet.py @@ -10,14 +10,17 @@ class ConvolutionalNetwork(Model): """Generic convolutional network.""" + # TODO(rliaw): converge on one generic ConvNet model def _init(self, inputs, num_outputs, options): x = inputs with tf.name_scope("convnet"): for i in range(4): - x = tf.nn.elu(conv2d(x, 32, "l{}".format(i+1), [3, 3], [2, 2])) + x = tf.nn.elu( + conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]) + ) r, c = x.shape[1].value, x.shape[2].value - x = tf.reshape(x, [-1, r*c*32]) + x = tf.reshape(x, [-1, r * c * 32]) fc1 = linear(x, 256, "fc1") fc2 = linear(x, num_outputs, "fc2", normc_initializer(0.01)) return fc2, fc1 diff --git a/python/ray/rllib/models/ddpgnet.py b/python/ray/rllib/models/ddpgnet.py index b881e6013dad..d925a3db76c3 100644 --- a/python/ray/rllib/models/ddpgnet.py +++ b/python/ray/rllib/models/ddpgnet.py @@ -17,13 +17,17 @@ def _init(self, inputs, num_outputs, options): ac_bound = options["action_bound"] net = slim.fully_connected( - inputs, 400, activation_fn=tf.nn.relu, - weights_initializer=w_normal) + inputs, 400, activation_fn=tf.nn.relu, weights_initializer=w_normal + ) net = slim.fully_connected( - net, 300, activation_fn=tf.nn.relu, weights_initializer=w_normal) + net, 300, activation_fn=tf.nn.relu, weights_initializer=w_normal + ) out = slim.fully_connected( - net, num_outputs, activation_fn=tf.nn.tanh, - weights_initializer=w_init) + net, + num_outputs, + activation_fn=tf.nn.tanh, + weights_initializer=w_init + ) 
scaled_out = tf.multiply(out, ac_bound) return scaled_out, net @@ -36,14 +40,21 @@ def _init(self, inputs, num_outputs, options): w_normal = tf.truncated_normal_initializer() w_init = tf.random_uniform_initializer(minval=-0.0003, maxval=0.0003) net = slim.fully_connected( - obs, 400, activation_fn=tf.nn.relu, weights_initializer=w_normal) + obs, 400, activation_fn=tf.nn.relu, weights_initializer=w_normal + ) t1 = slim.fully_connected( - net, 300, activation_fn=None, biases_initializer=None, - weights_initializer=w_normal) + net, + 300, + activation_fn=None, + biases_initializer=None, + weights_initializer=w_normal + ) t2 = slim.fully_connected( - action, 300, activation_fn=None, weights_initializer=w_normal) + action, 300, activation_fn=None, weights_initializer=w_normal + ) net = tf.nn.relu(tf.add(t1, t2)) out = slim.fully_connected( - net, 1, activation_fn=None, weights_initializer=w_init) + net, 1, activation_fn=None, weights_initializer=w_init + ) return out, net diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index ab40a6c6b875..b4e380031da4 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -27,14 +27,19 @@ def _init(self, inputs, num_outputs, options): for size in hiddens: label = "fc{}".format(i) last_layer = slim.fully_connected( - last_layer, size, + last_layer, + size, weights_initializer=normc_initializer(1.0), activation_fn=activation, - scope=label) + scope=label + ) i += 1 label = "fc_out" output = slim.fully_connected( - last_layer, num_outputs, + last_layer, + num_outputs, weights_initializer=normc_initializer(0.01), - activation_fn=None, scope=label) + activation_fn=None, + scope=label + ) return output, last_layer diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index 1d950506b0b7..a4bb111c8142 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -7,8 +7,7 @@ import tensorflow.contrib.rnn as rnn import distutils.version -from ray.rllib.models.misc import (conv2d, linear, flatten, - normc_initializer) +from ray.rllib.models.misc import (conv2d, linear, flatten, normc_initializer) from ray.rllib.models.model import Model @@ -18,8 +17,10 @@ class LSTM(Model): # TODO(rliaw): Add LSTM code for other algorithms def _init(self, inputs, num_outputs, options): - use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >= - distutils.version.LooseVersion("1.0.0")) + use_tf100_api = ( + distutils.version.LooseVersion(tf.VERSION) >= + distutils.version.LooseVersion("1.0.0") + ) self.x = x = inputs for i in range(4): @@ -46,10 +47,13 @@ def _init(self, inputs, num_outputs, options): state_in = rnn.LSTMStateTuple(c_in, h_in) else: state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in) - lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm, x, - initial_state=state_in, - sequence_length=step_size, - time_major=False) + lstm_out, lstm_state = tf.nn.dynamic_rnn( + lstm, + x, + initial_state=state_in, + sequence_length=step_size, + time_major=False + ) lstm_c, lstm_h = lstm_state x = tf.reshape(lstm_out, [-1, size]) logits = linear(x, num_outputs, "action", normc_initializer(0.01)) diff --git a/python/ray/rllib/models/misc.py b/python/ray/rllib/models/misc.py index a531bc07bf91..0b61f5601c28 100644 --- a/python/ray/rllib/models/misc.py +++ b/python/ray/rllib/models/misc.py @@ -11,15 +11,26 @@ def _initializer(shape, dtype=None, partition_info=None): out = np.random.randn(*shape).astype(np.float32) out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) 
return tf.constant(out) + return _initializer -def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", - dtype=tf.float32, collections=None): +def conv2d( + x, + num_filters, + name, + filter_size=(3, 3), + stride=(1, 1), + pad="SAME", + dtype=tf.float32, + collections=None +): with tf.variable_scope(name): stride_shape = [1, stride[0], stride[1], 1] - filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), - num_filters] + filter_shape = [ + filter_size[0], filter_size[1], + int(x.get_shape()[3]), num_filters + ] # There are "num input feature maps * filter height * filter width" # inputs to each hidden unit. @@ -30,20 +41,28 @@ def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", # Initialize weights with random weights. w_bound = np.sqrt(6 / (fan_in + fan_out)) - w = tf.get_variable("W", filter_shape, dtype, - tf.random_uniform_initializer(-w_bound, w_bound), - collections=collections) - b = tf.get_variable("b", [1, 1, 1, num_filters], - initializer=tf.constant_initializer(0.0), - collections=collections) + w = tf.get_variable( + "W", + filter_shape, + dtype, + tf.random_uniform_initializer(-w_bound, w_bound), + collections=collections + ) + b = tf.get_variable( + "b", [1, 1, 1, num_filters], + initializer=tf.constant_initializer(0.0), + collections=collections + ) return tf.nn.conv2d(x, w, stride_shape, pad) + b def linear(x, size, name, initializer=None, bias_init=0): - w = tf.get_variable(name + "/w", [x.get_shape()[1], size], - initializer=initializer) - b = tf.get_variable(name + "/b", [size], - initializer=tf.constant_initializer(bias_init)) + w = tf.get_variable( + name + "/w", [x.get_shape()[1], size], initializer=initializer + ) + b = tf.get_variable( + name + "/b", [size], initializer=tf.constant_initializer(bias_init) + ) return tf.matmul(x, w) + b diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index b1c5145d8877..bedbb66de459 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -34,13 +34,16 @@ def __init__(self, inputs, num_outputs, options): if options.get("free_log_std", False): assert num_outputs % 2 == 0 num_outputs = num_outputs // 2 - self.outputs, self.last_layer = self._init( - inputs, num_outputs, options) + self.outputs, self.last_layer = self._init(inputs, num_outputs, options) if options.get("free_log_std", False): - log_std = tf.get_variable(name="log_std", shape=[num_outputs], - initializer=tf.zeros_initializer) - self.outputs = tf.concat( - [self.outputs, 0.0 * self.outputs + log_std], 1) + log_std = tf.get_variable( + name="log_std", + shape=[num_outputs], + initializer=tf.zeros_initializer + ) + self.outputs = tf.concat([ + self.outputs, 0.0 * self.outputs + log_std + ], 1) def _init(self): """Builds and returns the output and last layer of the network.""" diff --git a/python/ray/rllib/models/multiagentfcnet.py b/python/ray/rllib/models/multiagentfcnet.py index daf88a7f2f02..f30458974dfd 100644 --- a/python/ray/rllib/models/multiagentfcnet.py +++ b/python/ray/rllib/models/multiagentfcnet.py @@ -23,8 +23,9 @@ def _init(self, inputs, num_outputs, options): num_actions = output_reshaper.split_number(num_outputs) custom_options = options["custom_options"] - hiddens = custom_options.get("multiagent_fcnet_hiddens", - [[256, 256]]*1) + hiddens = custom_options.get( + "multiagent_fcnet_hiddens", [[256, 256]] * 1 + ) # check for a shared model shared_model = custom_options.get("multiagent_shared_model", 0) @@ -37,7 +38,8 @@ def 
_init(self, inputs, num_outputs, options): sub_options.update({"fcnet_hiddens": hiddens[i]}) # TODO(ev) make this support arbitrary networks fcnet = FullyConnectedNetwork( - split_inputs[i], int(num_actions[i]), sub_options) + split_inputs[i], int(num_actions[i]), sub_options + ) output = fcnet.outputs outputs.append(output) overall_output = tf.concat(outputs, axis=1) diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py index f6ec1fa5e071..2da0f2697021 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -6,7 +6,7 @@ import gym ATARI_OBS_SHAPE = (210, 160, 3) -ATARI_RAM_OBS_SHAPE = (128,) +ATARI_RAM_OBS_SHAPE = (128, ) class Preprocessor(object): @@ -70,7 +70,7 @@ def transform(self, observation): class AtariRamPreprocessor(Preprocessor): def _init(self): - self.shape = (128,) + self.shape = (128, ) def transform(self, observation): return (observation - 128) / 128 @@ -78,7 +78,7 @@ def transform(self, observation): class OneHotPreprocessor(Preprocessor): def _init(self): - self.shape = (self._obs_space.n,) + self.shape = (self._obs_space.n, ) def transform(self, observation): arr = np.zeros(self._obs_space.n) @@ -111,13 +111,14 @@ def _init(self): preprocessor = get_preprocessor(space)(space, self._options) self.preprocessors.append(preprocessor) size += np.product(preprocessor.shape) - self.shape = (size,) + self.shape = (size, ) def transform(self, observation): assert len(observation) == len(self.preprocessors), observation return np.concatenate([ np.reshape(p.transform(o), [np.product(p.shape)]) - for (o, p) in zip(observation, self.preprocessors)]) + for (o, p) in zip(observation, self.preprocessors) + ]) def get_preprocessor(space): diff --git a/python/ray/rllib/models/pytorch/fcnet.py b/python/ray/rllib/models/pytorch/fcnet.py index b67f1365bc9c..7ce7e4f9ed16 100644 --- a/python/ray/rllib/models/pytorch/fcnet.py +++ b/python/ray/rllib/models/pytorch/fcnet.py @@ -9,6 +9,7 @@ class FullyConnectedNetwork(Model): """TODO(rliaw): Logits, Value should both be contained here""" + def _init(self, inputs, num_outputs, options): assert type(inputs) is int hiddens = options.get("fcnet_hiddens", [256, 256]) @@ -23,23 +24,31 @@ def _init(self, inputs, num_outputs, options): layers = [] last_layer_size = inputs for size in hiddens: - layers.append(SlimFC( - last_layer_size, size, - initializer=normc_initializer(1.0), - activation_fn=activation)) + layers.append( + SlimFC( + last_layer_size, + size, + initializer=normc_initializer(1.0), + activation_fn=activation + ) + ) last_layer_size = size self.hidden_layers = nn.Sequential(*layers) self.logits = SlimFC( - last_layer_size, num_outputs, + last_layer_size, + num_outputs, initializer=normc_initializer(0.01), - activation_fn=None) + activation_fn=None + ) self.probs = nn.Softmax() self.value_branch = SlimFC( - last_layer_size, 1, + last_layer_size, + 1, initializer=normc_initializer(1.0), - activation_fn=None) + activation_fn=None + ) def forward(self, obs): """ Internal method - pass in Variables, not numpy arrays diff --git a/python/ray/rllib/models/pytorch/misc.py b/python/ray/rllib/models/pytorch/misc.py index 5cb5a4718162..59235c810772 100644 --- a/python/ray/rllib/models/pytorch/misc.py +++ b/python/ray/rllib/models/pytorch/misc.py @@ -10,19 +10,16 @@ def convert_batch(trajectory, has_features=False): """Convert trajectory from numpy to PT variable""" - states = Variable(torch.from_numpy( - trajectory["observations"]).float()) - acs = 
Variable(torch.from_numpy( - trajectory["actions"])) - advs = Variable(torch.from_numpy( - trajectory["advantages"].copy()).float()) + states = Variable(torch.from_numpy(trajectory["observations"]).float()) + acs = Variable(torch.from_numpy(trajectory["actions"])) + advs = Variable(torch.from_numpy(trajectory["advantages"].copy()).float()) advs = advs.view(-1, 1) - rs = Variable(torch.from_numpy( - trajectory["value_targets"]).float()) + rs = Variable(torch.from_numpy(trajectory["value_targets"]).float()) rs = rs.view(-1, 1) if has_features: - features = [Variable(torch.from_numpy(f)) - for f in trajectory["features"]] + features = [ + Variable(torch.from_numpy(f)) for f in trajectory["features"] + ] else: features = trajectory["features"] return states, acs, advs, rs, features @@ -35,8 +32,8 @@ def var_to_np(var): def normc_initializer(std=1.0): def initializer(tensor): tensor.data.normal_(0, 1) - tensor.data *= std / torch.sqrt( - tensor.data.pow(2).sum(1, keepdim=True)) + tensor.data *= std / torch.sqrt(tensor.data.pow(2).sum(1, keepdim=True)) + return initializer @@ -61,9 +58,11 @@ def valid_padding(in_size, filter_size, stride_size): out_width = np.ceil(float(in_width) / float(stride_width)) pad_along_height = int( - ((out_height - 1) * stride_height + filter_height - in_height)) + ((out_height - 1) * stride_height + filter_height - in_height) + ) pad_along_width = int( - ((out_width - 1) * stride_width + filter_width - in_width)) + ((out_width - 1) * stride_width + filter_width - in_width) + ) pad_top = pad_along_height // 2 pad_bottom = pad_along_height - pad_top pad_left = pad_along_width // 2 diff --git a/python/ray/rllib/models/pytorch/model.py b/python/ray/rllib/models/pytorch/model.py index fd1577f332de..6768c1553205 100644 --- a/python/ray/rllib/models/pytorch/model.py +++ b/python/ray/rllib/models/pytorch/model.py @@ -29,9 +29,17 @@ def forward(self, obs): class SlimConv2d(nn.Module): """Simple mock of tf.slim Conv2d""" - def __init__(self, in_channels, out_channels, kernel, stride, padding, - initializer=nn.init.xavier_uniform, - activation_fn=nn.ReLU, bias_init=0): + def __init__( + self, + in_channels, + out_channels, + kernel, + stride, + padding, + initializer=nn.init.xavier_uniform, + activation_fn=nn.ReLU, + bias_init=0 + ): super(SlimConv2d, self).__init__() layers = [] if padding: @@ -53,8 +61,9 @@ def forward(self, x): class SlimFC(nn.Module): """Simple PyTorch of `linear` function""" - def __init__(self, in_size, size, initializer=None, - activation_fn=None, bias_init=0): + def __init__( + self, in_size, size, initializer=None, activation_fn=None, bias_init=0 + ): super(SlimFC, self).__init__() layers = [] linear = nn.Linear(in_size, size) diff --git a/python/ray/rllib/models/pytorch/visionnet.py b/python/ray/rllib/models/pytorch/visionnet.py index 99786a8d4287..9f5a116eb20c 100644 --- a/python/ray/rllib/models/pytorch/visionnet.py +++ b/python/ray/rllib/models/pytorch/visionnet.py @@ -18,32 +18,34 @@ def _init(self, inputs, num_outputs, options): inputs (tuple): (channels, rows/height, cols/width) num_outputs (int): logits size """ - filters = options.get("conv_filters", [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1] - ]) + filters = options.get( + "conv_filters", + [[16, [8, 8], 4], [32, [4, 4], 2], [512, [10, 10], 1]] + ) layers = [] in_channels, in_size = inputs[0], inputs[1:] for out_channels, kernel, stride in filters[:-1]: - padding, out_size = valid_padding( - in_size, kernel, [stride, stride]) - layers.append(SlimConv2d( - in_channels, 
out_channels, kernel, stride, padding)) + padding, out_size = valid_padding(in_size, kernel, [stride, stride]) + layers.append( + SlimConv2d(in_channels, out_channels, kernel, stride, padding) + ) in_channels = out_channels in_size = out_size out_channels, kernel, stride = filters[-1] - layers.append(SlimConv2d( - in_channels, out_channels, kernel, stride, None)) + layers.append( + SlimConv2d(in_channels, out_channels, kernel, stride, None) + ) self._convs = nn.Sequential(*layers) self.logits = SlimFC( - out_channels, num_outputs, initializer=nn.init.xavier_uniform) + out_channels, num_outputs, initializer=nn.init.xavier_uniform + ) self.probs = nn.Softmax() self.value_branch = SlimFC( - out_channels, 1, initializer=normc_initializer()) + out_channels, 1, initializer=normc_initializer() + ) def hidden_layers(self, obs): """ Internal method - pass in Variables, not numpy arrays diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 198f40762a15..9b99e0591d45 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -12,19 +12,27 @@ class VisionNetwork(Model): """Generic vision network.""" def _init(self, inputs, num_outputs, options): - filters = options.get("conv_filters", [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ]) + filters = options.get( + "conv_filters", [ + [16, [8, 8], 4], + [32, [4, 4], 2], + [512, [10, 10], 1], + ] + ) with tf.name_scope("vision_net"): for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): inputs = slim.conv2d( - inputs, out_size, kernel, stride, - scope="conv{}".format(i)) + inputs, out_size, kernel, stride, scope="conv{}".format(i) + ) out_size, kernel, stride = filters[-1] fc1 = slim.conv2d( - inputs, out_size, kernel, stride, padding="VALID", scope="fc1") - fc2 = slim.conv2d(fc1, num_outputs, [1, 1], activation_fn=None, - normalizer_fn=None, scope="fc2") + inputs, out_size, kernel, stride, padding="VALID", scope="fc1" + ) + fc2 = slim.conv2d( + fc1, + num_outputs, [1, 1], + activation_fn=None, + normalizer_fn=None, + scope="fc2" + ) return tf.squeeze(fc2, [1, 2]), tf.squeeze(fc1, [1, 2]) diff --git a/python/ray/rllib/optimizers/__init__.py b/python/ray/rllib/optimizers/__init__.py index 95be536c0a67..0239fb7541db 100644 --- a/python/ray/rllib/optimizers/__init__.py +++ b/python/ray/rllib/optimizers/__init__.py @@ -7,8 +7,8 @@ from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator, \ TFMultiGPUSupport - __all__ = [ "ApexOptimizer", "AsyncOptimizer", "LocalSyncOptimizer", "LocalSyncReplayOptimizer", "LocalMultiGPUOptimizer", "SampleBatch", - "PolicyEvaluator", "TFMultiGPUSupport"] + "PolicyEvaluator", "TFMultiGPUSupport" +] diff --git a/python/ray/rllib/optimizers/apex_optimizer.py b/python/ray/rllib/optimizers/apex_optimizer.py index ded738f622fc..40ac5761222f 100644 --- a/python/ray/rllib/optimizers/apex_optimizer.py +++ b/python/ray/rllib/optimizers/apex_optimizer.py @@ -35,9 +35,10 @@ class ReplayActor(object): may be created to increase parallelism.""" def __init__( - self, num_shards, learning_starts, buffer_size, train_batch_size, - prioritized_replay_alpha, prioritized_replay_beta, - prioritized_replay_eps, clip_rewards): + self, num_shards, learning_starts, buffer_size, train_batch_size, + prioritized_replay_alpha, prioritized_replay_beta, + prioritized_replay_eps, clip_rewards + ): self.replay_starts = learning_starts // num_shards self.buffer_size = buffer_size // num_shards self.train_batch_size = train_batch_size @@ -45,8 +46,10 
@@ def __init__( self.prioritized_replay_eps = prioritized_replay_eps self.replay_buffer = PrioritizedReplayBuffer( - self.buffer_size, alpha=prioritized_replay_alpha, - clip_rewards=clip_rewards) + self.buffer_size, + alpha=prioritized_replay_alpha, + clip_rewards=clip_rewards + ) # Metrics self.add_batch_timer = TimerStat() @@ -61,38 +64,45 @@ def add_batch(self, batch): for row in batch.rows(): self.replay_buffer.add( row["obs"], row["actions"], row["rewards"], row["new_obs"], - row["dones"], row["weights"]) + row["dones"], row["weights"] + ) def replay(self): with self.replay_timer: if len(self.replay_buffer) < self.replay_starts: return None - (obses_t, actions, rewards, obses_tp1, - dones, weights, batch_indexes) = self.replay_buffer.sample( - self.train_batch_size, - beta=self.prioritized_replay_beta) + ( + obses_t, actions, rewards, obses_tp1, dones, weights, + batch_indexes + ) = self.replay_buffer.sample( + self.train_batch_size, beta=self.prioritized_replay_beta + ) batch = SampleBatch({ - "obs": obses_t, "actions": actions, "rewards": rewards, - "new_obs": obses_tp1, "dones": dones, "weights": weights, - "batch_indexes": batch_indexes}) + "obs": obses_t, + "actions": actions, + "rewards": rewards, + "new_obs": obses_tp1, + "dones": dones, + "weights": weights, + "batch_indexes": batch_indexes + }) return batch def update_priorities(self, batch_indexes, td_errors): with self.update_priorities_timer: - new_priorities = ( - np.abs(td_errors) + self.prioritized_replay_eps) + new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps) self.replay_buffer.update_priorities(batch_indexes, new_priorities) def stats(self): stat = { - "add_batch_time_ms": round( - 1000 * self.add_batch_timer.mean, 3), - "replay_time_ms": round( - 1000 * self.replay_timer.mean, 3), - "update_priorities_time_ms": round( - 1000 * self.update_priorities_timer.mean, 3), + "add_batch_time_ms": + round(1000 * self.add_batch_timer.mean, 3), + "replay_time_ms": + round(1000 * self.replay_timer.mean, 3), + "update_priorities_time_ms": + round(1000 * self.update_priorities_timer.mean, 3), } stat.update(self.replay_buffer.stats()) return stat @@ -127,8 +137,7 @@ def step(self): ra, replay = self.inqueue.get() if replay is not None: with self.grad_timer: - td_error = self.local_evaluator.compute_apply(replay)[ - "td_error"] + td_error = self.local_evaluator.compute_apply(replay)["td_error"] self.outqueue.put((ra, replay, td_error)) self.learner_queue_size.push(self.inqueue.qsize()) self.weights_updated = True @@ -145,12 +154,20 @@ class ApexOptimizer(PolicyOptimizer): term will be used for sample prioritization.""" def _init( - self, learning_starts=1000, buffer_size=10000, - prioritized_replay=True, prioritized_replay_alpha=0.6, - prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6, - train_batch_size=512, sample_batch_size=50, - num_replay_buffer_shards=1, max_weight_sync_delay=400, - clip_rewards=True, debug=False): + self, + learning_starts=1000, + buffer_size=10000, + prioritized_replay=True, + prioritized_replay_alpha=0.6, + prioritized_replay_beta=0.4, + prioritized_replay_eps=1e-6, + train_batch_size=512, + sample_batch_size=50, + num_replay_buffer_shards=1, + max_weight_sync_delay=400, + clip_rewards=True, + debug=False + ): self.debug = debug self.replay_starts = learning_starts @@ -165,16 +182,22 @@ def _init( self.replay_actors = create_colocated( ReplayActor, - [num_replay_buffer_shards, learning_starts, buffer_size, - train_batch_size, prioritized_replay_alpha, - prioritized_replay_beta, 
prioritized_replay_eps, clip_rewards], - num_replay_buffer_shards) + [ + num_replay_buffer_shards, learning_starts, buffer_size, + train_batch_size, prioritized_replay_alpha, + prioritized_replay_beta, prioritized_replay_eps, clip_rewards + ], num_replay_buffer_shards + ) assert len(self.remote_evaluators) > 0 # Stats - self.timers = {k: TimerStat() for k in [ - "put_weights", "get_samples", "enqueue", "sample_processing", - "replay_processing", "update_priorities", "train", "sample"]} + self.timers = { + k: TimerStat() + for k in [ + "put_weights", "get_samples", "enqueue", "sample_processing", + "replay_processing", "update_priorities", "train", "sample" + ] + } self.num_weight_syncs = 0 self.learning_started = False @@ -219,8 +242,7 @@ def _step(self): sample_timesteps += self.sample_batch_size # Send the data to the replay buffer - random.choice(self.replay_actors).add_batch.remote( - sample_batch) + random.choice(self.replay_actors).add_batch.remote(sample_batch) # Update weights if needed self.steps_since_update[ev] += self.sample_batch_size @@ -231,7 +253,8 @@ def _step(self): self.learner.weights_updated = False with self.timers["put_weights"]: weights = ray.put( - self.local_evaluator.get_weights()) + self.local_evaluator.get_weights() + ) ev.set_weights.remote(weights) self.num_weight_syncs += 1 self.steps_since_update[ev] = 0 @@ -262,12 +285,14 @@ def stats(self): for k in self.timers } timing["learner_grad_time_ms"] = round( - 1000 * self.learner.grad_timer.mean, 3) + 1000 * self.learner.grad_timer.mean, 3 + ) timing["learner_dequeue_time_ms"] = round( - 1000 * self.learner.queue_timer.mean, 3) + 1000 * self.learner.queue_timer.mean, 3 + ) stats = { - "sample_throughput": round( - self.timers["sample"].mean_throughput, 3), + "sample_throughput": + round(self.timers["sample"].mean_throughput, 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, } diff --git a/python/ray/rllib/optimizers/async_optimizer.py b/python/ray/rllib/optimizers/async_optimizer.py index 93c363345695..e3c01b7c3575 100644 --- a/python/ray/rllib/optimizers/async_optimizer.py +++ b/python/ray/rllib/optimizers/async_optimizer.py @@ -14,6 +14,7 @@ class AsyncOptimizer(PolicyOptimizer): evaluators, sending updated weights back as needed. This pipelines the gradient computations on the remote workers. 
""" + def _init(self, grads_per_step=100, batch_size=10): self.apply_timer = TimerStat() self.wait_timer = TimerStat() @@ -54,8 +55,10 @@ def step(self): self.num_steps_trained += self.grads_per_step * self.batch_size def stats(self): - return dict(PolicyOptimizer.stats(), **{ - "wait_time_ms": round(1000 * self.wait_timer.mean, 3), - "apply_time_ms": round(1000 * self.apply_timer.mean, 3), - "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3), - }) + return dict( + PolicyOptimizer.stats(), **{ + "wait_time_ms": round(1000 * self.wait_timer.mean, 3), + "apply_time_ms": round(1000 * self.apply_timer.mean, 3), + "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3), + } + ) diff --git a/python/ray/rllib/optimizers/local_sync.py b/python/ray/rllib/optimizers/local_sync.py index 3f71cce4e988..98dd5e4f6ca6 100644 --- a/python/ray/rllib/optimizers/local_sync.py +++ b/python/ray/rllib/optimizers/local_sync.py @@ -34,8 +34,8 @@ def step(self): with self.sample_timer: if self.remote_evaluators: samples = SampleBatch.concat_samples( - ray.get( - [e.sample.remote() for e in self.remote_evaluators])) + ray.get([e.sample.remote() for e in self.remote_evaluators]) + ) else: samples = self.local_evaluator.sample() @@ -48,10 +48,14 @@ def step(self): self.num_steps_trained += samples.count def stats(self): - return dict(PolicyOptimizer.stats(self), **{ - "sample_time_ms": round(1000 * self.sample_timer.mean, 3), - "grad_time_ms": round(1000 * self.grad_timer.mean, 3), - "update_time_ms": round(1000 * self.update_weights_timer.mean, 3), - "opt_peak_throughput": round(self.grad_timer.mean_throughput, 3), - "opt_samples": round(self.grad_timer.mean_units_processed, 3), - }) + return dict( + PolicyOptimizer.stats(self), **{ + "sample_time_ms": round(1000 * self.sample_timer.mean, 3), + "grad_time_ms": round(1000 * self.grad_timer.mean, 3), + "update_time_ms": + round(1000 * self.update_weights_timer.mean, 3), + "opt_peak_throughput": + round(self.grad_timer.mean_throughput, 3), + "opt_samples": round(self.grad_timer.mean_units_processed, 3), + } + ) diff --git a/python/ray/rllib/optimizers/local_sync_replay.py b/python/ray/rllib/optimizers/local_sync_replay.py index ac430c6a1619..8361e97d141e 100644 --- a/python/ray/rllib/optimizers/local_sync_replay.py +++ b/python/ray/rllib/optimizers/local_sync_replay.py @@ -22,10 +22,17 @@ class LocalSyncReplayOptimizer(PolicyOptimizer): term will be used for sample prioritization.""" def _init( - self, learning_starts=1000, buffer_size=10000, - prioritized_replay=True, prioritized_replay_alpha=0.6, - prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6, - train_batch_size=32, sample_batch_size=4, clip_rewards=True): + self, + learning_starts=1000, + buffer_size=10000, + prioritized_replay=True, + prioritized_replay_alpha=0.6, + prioritized_replay_beta=0.4, + prioritized_replay_eps=1e-6, + train_batch_size=32, + sample_batch_size=4, + clip_rewards=True + ): self.replay_starts = learning_starts self.prioritized_replay_beta = prioritized_replay_beta @@ -42,8 +49,10 @@ def _init( # Set up replay buffer if prioritized_replay: self.replay_buffer = PrioritizedReplayBuffer( - buffer_size, alpha=prioritized_replay_alpha, - clip_rewards=clip_rewards) + buffer_size, + alpha=prioritized_replay_alpha, + clip_rewards=clip_rewards + ) else: self.replay_buffer = ReplayBuffer(buffer_size, clip_rewards) @@ -59,15 +68,15 @@ def step(self): with self.sample_timer: if self.remote_evaluators: batch = SampleBatch.concat_samples( - ray.get( - [e.sample.remote() for e in 
self.remote_evaluators])) + ray.get([e.sample.remote() for e in self.remote_evaluators]) + ) else: batch = self.local_evaluator.sample() for row in batch.rows(): self.replay_buffer.add( pack_if_needed(row["obs"]), row["actions"], row["rewards"], - pack_if_needed(row["new_obs"]), - row["dones"], row["weights"]) + pack_if_needed(row["new_obs"]), row["dones"], row["weights"] + ) if len(self.replay_buffer) >= self.replay_starts: self._optimize() @@ -77,39 +86,51 @@ def step(self): def _optimize(self): with self.replay_timer: if isinstance(self.replay_buffer, PrioritizedReplayBuffer): - (obses_t, actions, rewards, obses_tp1, - dones, weights, batch_indexes) = self.replay_buffer.sample( - self.train_batch_size, - beta=self.prioritized_replay_beta) + ( + obses_t, actions, rewards, obses_tp1, dones, weights, + batch_indexes + ) = self.replay_buffer.sample( + self.train_batch_size, beta=self.prioritized_replay_beta + ) else: (obses_t, actions, rewards, obses_tp1, - dones) = self.replay_buffer.sample( - self.train_batch_size) + dones) = self.replay_buffer.sample(self.train_batch_size) weights = np.ones_like(rewards) - batch_indexes = - np.ones_like(rewards) + batch_indexes = -np.ones_like(rewards) samples = SampleBatch({ - "obs": obses_t, "actions": actions, "rewards": rewards, - "new_obs": obses_tp1, "dones": dones, "weights": weights, - "batch_indexes": batch_indexes}) + "obs": obses_t, + "actions": actions, + "rewards": rewards, + "new_obs": obses_tp1, + "dones": dones, + "weights": weights, + "batch_indexes": batch_indexes + }) with self.grad_timer: info = self.local_evaluator.compute_apply(samples) if isinstance(self.replay_buffer, PrioritizedReplayBuffer): td_error = info["td_error"] new_priorities = ( - np.abs(td_error) + self.prioritized_replay_eps) + np.abs(td_error) + self.prioritized_replay_eps + ) self.replay_buffer.update_priorities( - samples["batch_indexes"], new_priorities) + samples["batch_indexes"], new_priorities + ) self.grad_timer.push_units_processed(samples.count) self.num_steps_trained += samples.count def stats(self): - return dict(PolicyOptimizer.stats(self), **{ - "sample_time_ms": round(1000 * self.sample_timer.mean, 3), - "replay_time_ms": round(1000 * self.replay_timer.mean, 3), - "grad_time_ms": round(1000 * self.grad_timer.mean, 3), - "update_time_ms": round(1000 * self.update_weights_timer.mean, 3), - "opt_peak_throughput": round(self.grad_timer.mean_throughput, 3), - "opt_samples": round(self.grad_timer.mean_units_processed, 3), - }) + return dict( + PolicyOptimizer.stats(self), **{ + "sample_time_ms": round(1000 * self.sample_timer.mean, 3), + "replay_time_ms": round(1000 * self.replay_timer.mean, 3), + "grad_time_ms": round(1000 * self.grad_timer.mean, 3), + "update_time_ms": + round(1000 * self.update_weights_timer.mean, 3), + "opt_peak_throughput": + round(self.grad_timer.mean_throughput, 3), + "opt_samples": round(self.grad_timer.mean_units_processed, 3), + } + ) diff --git a/python/ray/rllib/optimizers/multi_gpu.py b/python/ray/rllib/optimizers/multi_gpu.py index f9d3f4a85b5f..c59b8fceb5c9 100644 --- a/python/ray/rllib/optimizers/multi_gpu.py +++ b/python/ray/rllib/optimizers/multi_gpu.py @@ -53,12 +53,10 @@ def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10): tf.get_variable_scope().reuse_variables() self.par_opt = LocalSyncParallelOptimizer( - tf.train.AdamOptimizer(self.sgd_stepsize), - self.devices, - [ph for _, ph in self.loss_inputs], - self.per_device_batch_size, - lambda *ph: self.local_evaluator.build_tf_loss(ph), - os.getcwd()) + 
tf.train.AdamOptimizer(self.sgd_stepsize), self.devices, + [ph for _, ph in self.loss_inputs], self.per_device_batch_size, + lambda *ph: self.local_evaluator.build_tf_loss(ph), os.getcwd() + ) self.sess = self.local_evaluator.sess self.sess.run(tf.global_variables_initializer()) @@ -73,8 +71,8 @@ def step(self): with self.sample_timer: if self.remote_evaluators: samples = SampleBatch.concat_samples( - ray.get( - [e.sample.remote() for e in self.remote_evaluators])) + ray.get([e.sample.remote() for e in self.remote_evaluators]) + ) else: samples = self.local_evaluator.sample() assert isinstance(samples, SampleBatch) @@ -82,29 +80,35 @@ def step(self): with self.load_timer: tuples_per_device = self.par_opt.load_data( self.local_evaluator.sess, - samples.columns([key for key, _ in self.loss_inputs])) + samples.columns([key for key, _ in self.loss_inputs]) + ) with self.grad_timer: for i in range(self.num_sgd_iter): batch_index = 0 num_batches = ( - int(tuples_per_device) // int(self.per_device_batch_size)) + int(tuples_per_device) // int(self.per_device_batch_size) + ) permutation = np.random.permutation(num_batches) while batch_index < num_batches: # TODO(ekl) support ppo's debugging features, e.g. # printing the current loss and tracing self.par_opt.optimize( self.sess, - permutation[batch_index] * self.per_device_batch_size) + permutation[batch_index] * self.per_device_batch_size + ) batch_index += 1 self.num_steps_sampled += samples.count self.num_steps_trained += samples.count def stats(self): - return dict(PolicyOptimizer.stats(), **{ - "sample_time_ms": round(1000 * self.sample_timer.mean, 3), - "load_time_ms": round(1000 * self.load_timer.mean, 3), - "grad_time_ms": round(1000 * self.grad_timer.mean, 3), - "update_time_ms": round(1000 * self.update_weights_timer.mean, 3), - }) + return dict( + PolicyOptimizer.stats(), **{ + "sample_time_ms": round(1000 * self.sample_timer.mean, 3), + "load_time_ms": round(1000 * self.load_timer.mean, 3), + "grad_time_ms": round(1000 * self.grad_timer.mean, 3), + "update_time_ms": + round(1000 * self.update_weights_timer.mean, 3), + } + ) diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 1ff6bff3f05f..532a545b866b 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -8,7 +8,6 @@ from tensorflow.python.client import timeline import tensorflow as tf - # Variable scope in which created variables will be placed under TOWER_SCOPE_NAME = "tower" @@ -48,9 +47,16 @@ class LocalSyncParallelOptimizer(object): grad_norm_clipping: None or int stdev to clip grad norms by """ - def __init__(self, optimizer, devices, input_placeholders, - per_device_batch_size, build_loss, logdir, - grad_norm_clipping=None): + def __init__( + self, + optimizer, + devices, + input_placeholders, + per_device_batch_size, + build_loss, + logdir, + grad_norm_clipping=None + ): self.optimizer = optimizer self.devices = devices self.batch_size = per_device_batch_size * len(devices) @@ -69,12 +75,12 @@ def __init__(self, optimizer, devices, input_placeholders, # Split on the CPU in case the data doesn't fit in GPU memory. 
with tf.device("/cpu:0"): data_splits = zip( - *[tf.split(ph, len(devices)) for ph in input_placeholders]) + *[tf.split(ph, len(devices)) for ph in input_placeholders] + ) self._towers = [] for device, device_placeholders in zip(self.devices, data_splits): - self._towers.append(self._setup_device(device, - device_placeholders)) + self._towers.append(self._setup_device(device, device_placeholders)) avg = average_gradients([t.grads for t in self._towers]) if grad_norm_clipping: @@ -115,15 +121,15 @@ def load_data(self, sess, inputs, full_trace=False): run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) run_metadata = tf.RunMetadata() - sess.run( - [t.init_op for t in self._towers], - feed_dict=feed_dict, - options=run_options, - run_metadata=run_metadata) + sess.run([t.init_op for t in self._towers], + feed_dict=feed_dict, + options=run_options, + run_metadata=run_metadata) if full_trace: trace = timeline.Timeline(step_stats=run_metadata.step_stats) - trace_file = open(os.path.join(self.logdir, "timeline-load.json"), - "w") + trace_file = open( + os.path.join(self.logdir, "timeline-load.json"), "w" + ) trace_file.write(trace.generate_chrome_trace_format()) tuples_per_device = truncated_len / len(self.devices) @@ -135,8 +141,14 @@ def load_data(self, sess, inputs, full_trace=False): assert tuples_per_device % self.per_device_batch_size == 0 return tuples_per_device - def optimize(self, sess, batch_index, extra_ops=[], extra_feed_dict={}, - file_writer=None): + def optimize( + self, + sess, + batch_index, + extra_ops=[], + extra_feed_dict={}, + file_writer=None + ): """Run a single step of SGD. Runs a SGD step over a slice of the preloaded batch with size given by @@ -168,19 +180,20 @@ def optimize(self, sess, batch_index, extra_ops=[], extra_feed_dict={}, feed_dict = {self._batch_index: batch_index} feed_dict.update(extra_feed_dict) - outs = sess.run( - [self._train_op] + extra_ops, - feed_dict=feed_dict, - options=run_options, - run_metadata=run_metadata) + outs = sess.run([self._train_op] + extra_ops, + feed_dict=feed_dict, + options=run_options, + run_metadata=run_metadata) if file_writer: trace = timeline.Timeline(step_stats=run_metadata.step_stats) - trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), - "w") + trace_file = open( + os.path.join(self.logdir, "timeline-sgd.json"), "w" + ) trace_file.write(trace.generate_chrome_trace_format()) file_writer.add_run_metadata( - run_metadata, "sgd_train_{}".format(batch_index)) + run_metadata, "sgd_train_{}".format(batch_index) + ) return outs[1:] @@ -197,24 +210,29 @@ def _setup_device(self, device, device_input_placeholders): device_input_slices = [] for ph in device_input_placeholders: current_batch = tf.Variable( - ph, trainable=False, validate_shape=False, - collections=[]) + ph, + trainable=False, + validate_shape=False, + collections=[] + ) device_input_batches.append(current_batch) current_slice = tf.slice( current_batch, [self._batch_index] + [0] * len(ph.shape[1:]), - ([self.per_device_batch_size] + [-1] * - len(ph.shape[1:]))) + ([self.per_device_batch_size] + + [-1] * len(ph.shape[1:])) + ) current_slice.set_shape(ph.shape) device_input_slices.append(current_slice) device_loss_obj = self.build_loss(*device_input_slices) device_grads = self.optimizer.compute_gradients( - device_loss_obj.loss, colocate_gradients_with_ops=True) + device_loss_obj.loss, colocate_gradients_with_ops=True + ) return Tower( - tf.group(*[batch.initializer - for batch in device_input_batches]), - device_grads, - device_loss_obj) + tf. 
+ group(*[batch.initializer for batch in device_input_batches]), + device_grads, device_loss_obj + ) # Each tower is a copy of the loss graph pinned to a specific device. diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index 1e31edc66ea1..2b20599d5e25 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -32,8 +32,13 @@ class PolicyOptimizer(object): @classmethod def make( - cls, evaluator_cls, evaluator_args, num_workers, optimizer_config, - evaluator_resources={"num_cpus": None}): + cls, + evaluator_cls, + evaluator_args, + num_workers, + optimizer_config, + evaluator_resources={"num_cpus": None} + ): """Create evaluators and an optimizer instance using those evaluators. Args: @@ -48,8 +53,8 @@ def make( local_evaluator = evaluator_cls(*evaluator_args) remote_cls = ray.remote(**evaluator_resources)(evaluator_cls) remote_evaluators = [ - remote_cls.remote(*evaluator_args) - for _ in range(num_workers)] + remote_cls.remote(*evaluator_args) for _ in range(num_workers) + ] return cls(optimizer_config, local_evaluator, remote_evaluators) def __init__(self, config, local_evaluator, remote_evaluators): diff --git a/python/ray/rllib/optimizers/replay_buffer.py b/python/ray/rllib/optimizers/replay_buffer.py index d38014ba26ee..8d7cbc41a6ac 100644 --- a/python/ray/rllib/optimizers/replay_buffer.py +++ b/python/ray/rllib/optimizers/replay_buffer.py @@ -65,8 +65,10 @@ def _encode_sample(self, idxes): obses_tp1.append(np.array(unpack(obs_tp1), copy=False)) dones.append(done) self._hit_count[i] += 1 - return (np.array(obses_t), np.array(actions), np.array(rewards), - np.array(obses_tp1), np.array(dones)) + return ( + np.array(obses_t), np.array(actions), np.array(rewards), + np.array(obses_tp1), np.array(dones) + ) def sample(self, batch_size): """Sample a batch of experiences. @@ -90,8 +92,10 @@ def sample(self, batch_size): done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode and 0 otherwise. """ - idxes = [random.randint(0, len(self._storage) - 1) - for _ in range(batch_size)] + idxes = [ + random.randint(0, + len(self._storage) - 1) for _ in range(batch_size) + ] self._num_sampled += batch_size return self._encode_sample(idxes) @@ -142,19 +146,18 @@ def add(self, obs_t, action, reward, obs_tp1, done, weight): reward = np.sign(reward) idx = self._next_idx - super(PrioritizedReplayBuffer, self).add( - obs_t, action, reward, obs_tp1, done, weight) + super(PrioritizedReplayBuffer, + self).add(obs_t, action, reward, obs_tp1, done, weight) if weight is None: weight = self._max_priority - self._it_sum[idx] = weight ** self._alpha - self._it_min[idx] = weight ** self._alpha + self._it_sum[idx] = weight**self._alpha + self._it_min[idx] = weight**self._alpha def _sample_proportional(self, batch_size): res = [] for _ in range(batch_size): # TODO(szymon): should we ensure no repeats? 
- mass = random.random() * self._it_sum.sum(0, - len(self._storage) - 1) + mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) idx = self._it_sum.find_prefixsum_idx(mass) res.append(idx) return res @@ -202,11 +205,11 @@ def sample(self, batch_size, beta): weights = [] p_min = self._it_min.min() / self._it_sum.sum() - max_weight = (p_min * len(self._storage)) ** (-beta) + max_weight = (p_min * len(self._storage))**(-beta) for idx in idxes: p_sample = self._it_sum[idx] / self._it_sum.sum() - weight = (p_sample * len(self._storage)) ** (-beta) + weight = (p_sample * len(self._storage))**(-beta) weights.append(weight / max_weight) weights = np.array(weights) encoded_sample = self._encode_sample(idxes) @@ -231,10 +234,10 @@ def update_priorities(self, idxes, priorities): for idx, priority in zip(idxes, priorities): assert priority > 0 assert 0 <= idx < len(self._storage) - delta = priority ** self._alpha - self._it_sum[idx] + delta = priority**self._alpha - self._it_sum[idx] self._prio_change_stats.push(delta) - self._it_sum[idx] = priority ** self._alpha - self._it_min[idx] = priority ** self._alpha + self._it_sum[idx] = priority**self._alpha + self._it_min[idx] = priority**self._alpha self._max_priority = max(self._max_priority, priority) diff --git a/python/ray/rllib/optimizers/segment_tree.py b/python/ray/rllib/optimizers/segment_tree.py index b412a89bd349..ea5e81983aa1 100644 --- a/python/ray/rllib/optimizers/segment_tree.py +++ b/python/ray/rllib/optimizers/segment_tree.py @@ -48,13 +48,15 @@ def _reduce_helper(self, start, end, node, node_start, node_end): return self._reduce_helper(start, end, 2 * node, node_start, mid) else: if mid + 1 <= start: - return self._reduce_helper(start, end, 2 * node + 1, mid + 1, - node_end) + return self._reduce_helper( + start, end, 2 * node + 1, mid + 1, node_end + ) else: return self._operation( self._reduce_helper(start, mid, 2 * node, node_start, mid), - self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, - node_end) + self._reduce_helper( + mid + 1, end, 2 * node + 1, mid + 1, node_end + ) ) def reduce(self, start=0, end=None): @@ -90,8 +92,8 @@ def __setitem__(self, idx, val): idx //= 2 while idx >= 1: self._value[idx] = self._operation( - self._value[2 * idx], - self._value[2 * idx + 1]) + self._value[2 * idx], self._value[2 * idx + 1] + ) idx //= 2 def __getitem__(self, idx): @@ -102,9 +104,8 @@ def __getitem__(self, idx): class SumSegmentTree(SegmentTree): def __init__(self, capacity): super(SumSegmentTree, self).__init__( - capacity=capacity, - operation=operator.add, - neutral_element=0.0) + capacity=capacity, operation=operator.add, neutral_element=0.0 + ) def sum(self, start=0, end=None): """Returns arr[start] + ... 
+ arr[end]""" @@ -142,9 +143,8 @@ def find_prefixsum_idx(self, prefixsum): class MinSegmentTree(SegmentTree): def __init__(self, capacity): super(MinSegmentTree, self).__init__( - capacity=capacity, - operation=min, - neutral_element=float('inf')) + capacity=capacity, operation=min, neutral_element=float('inf') + ) def min(self, start=0, end=None): """Returns min(arr[start], ..., arr[end])""" diff --git a/python/ray/rllib/pg/pg.py b/python/ray/rllib/pg/pg.py index c3726f89f504..81203a2cf2fb 100644 --- a/python/ray/rllib/pg/pg.py +++ b/python/ray/rllib/pg/pg.py @@ -11,7 +11,6 @@ from ray.tune.result import TrainingResult from ray.tune.trial import Resources - DEFAULT_CONFIG = { # Number of workers (excluding master) "num_workers": 4, @@ -26,14 +25,15 @@ # Arguments to pass to the rllib optimizer "optimizer": {}, # Model parameters - "model": {"fcnet_hiddens": [128, 128]}, + "model": { + "fcnet_hiddens": [128, 128] + }, # Arguments to pass to the env creator "env_config": {}, } class PGAgent(Agent): - """Simple policy gradient agent. This is an example agent to show how to implement algorithms in RLlib. @@ -53,15 +53,18 @@ def _init(self): evaluator_cls=PGEvaluator, evaluator_args=[self.registry, self.env_creator, self.config], num_workers=self.config["num_workers"], - optimizer_config=self.config["optimizer"]) + optimizer_config=self.config["optimizer"] + ) def _train(self): self.optimizer.step() episode_rewards = [] episode_lengths = [] - metric_lists = [a.get_completed_rollout_metrics.remote() - for a in self.optimizer.remote_evaluators] + metric_lists = [ + a.get_completed_rollout_metrics.remote() + for a in self.optimizer.remote_evaluators + ] for metrics in metric_lists: for episode in ray.get(metrics): episode_lengths.append(episode.episode_length) @@ -74,7 +77,8 @@ def _train(self): episode_reward_mean=avg_reward, episode_len_mean=avg_length, timesteps_this_iter=timesteps, - info={}) + info={} + ) return result diff --git a/python/ray/rllib/pg/pg_evaluator.py b/python/ray/rllib/pg/pg_evaluator.py index 1f217ba02855..5164c99994d2 100644 --- a/python/ray/rllib/pg/pg_evaluator.py +++ b/python/ray/rllib/pg/pg_evaluator.py @@ -15,20 +15,26 @@ class PGEvaluator(PolicyEvaluator): def __init__(self, registry, env_creator, config): self.env = ModelCatalog.get_preprocessor_as_wrapper( - registry, env_creator(config["env_config"]), config["model"]) + registry, env_creator(config["env_config"]), config["model"] + ) self.config = config - self.policy = PGPolicy(registry, self.env.observation_space, - self.env.action_space, config) + self.policy = PGPolicy( + registry, self.env.observation_space, self.env.action_space, config + ) self.sampler = SyncSampler( - self.env, self.policy, NoFilter(), - config["batch_size"], horizon=config["horizon"]) + self.env, + self.policy, + NoFilter(), + config["batch_size"], + horizon=config["horizon"] + ) def sample(self): rollout = self.sampler.get_data() samples = process_rollout( - rollout, NoFilter(), - gamma=self.config["gamma"], use_gae=False) + rollout, NoFilter(), gamma=self.config["gamma"], use_gae=False + ) return samples def get_completed_rollout_metrics(self): diff --git a/python/ray/rllib/pg/policy.py b/python/ray/rllib/pg/policy.py index cc53eebcbd84..7ab0c3487600 100644 --- a/python/ray/rllib/pg/policy.py +++ b/python/ray/rllib/pg/policy.py @@ -24,16 +24,18 @@ def __init__(self, registry, ob_space, ac_space, config): self.initialize() def _setup_graph(self, ob_space, ac_space): - self.x = tf.placeholder(tf.float32, shape=[None]+list(ob_space.shape)) 
+ self.x = tf.placeholder(tf.float32, shape=[None] + list(ob_space.shape)) dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space) self.model = ModelCatalog.get_model( - self.registry, self.x, self.logit_dim, - options=self.config["model"]) + self.registry, self.x, self.logit_dim, options=self.config["model"] + ) self.action_logits = self.model.outputs # logit for each action self.dist = dist_class(self.action_logits) self.sample = self.dist.sample() - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) + self.var_list = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name + ) def _setup_loss(self, action_space): self.ac = ModelCatalog.get_action_placeholder(action_space) @@ -53,7 +55,8 @@ def _setup_gradients(self): def initialize(self): self.sess = tf.Session() self.variables = ray.experimental.TensorFlowVariables( - self.loss, self.sess) + self.loss, self.sess + ) self.sess.run(tf.global_variables_initializer()) def compute_gradients(self, samples): diff --git a/python/ray/rllib/ppo/loss.py b/python/ray/rllib/ppo/loss.py index 3f69ff711692..88d84bac3268 100644 --- a/python/ray/rllib/ppo/loss.py +++ b/python/ray/rllib/ppo/loss.py @@ -13,17 +13,18 @@ class ProximalPolicyLoss(object): is_recurrent = False def __init__( - self, observation_space, action_space, - observations, value_targets, advantages, actions, - prev_logits, prev_vf_preds, logit_dim, - kl_coeff, distribution_class, config, sess, registry): + self, observation_space, action_space, observations, value_targets, + advantages, actions, prev_logits, prev_vf_preds, logit_dim, kl_coeff, + distribution_class, config, sess, registry + ): self.prev_dist = distribution_class(prev_logits) # Saved so that we can compute actions given different observations self.observations = observations self.curr_logits = ModelCatalog.get_model( - registry, observations, logit_dim, config["model"]).outputs + registry, observations, logit_dim, config["model"] + ).outputs self.curr_dist = distribution_class(self.curr_logits) self.sampler = self.curr_dist.sample() @@ -35,19 +36,22 @@ def __init__( vf_config["free_log_std"] = False with tf.variable_scope("value_function"): self.value_function = ModelCatalog.get_model( - registry, observations, 1, vf_config).outputs + registry, observations, 1, vf_config + ).outputs self.value_function = tf.reshape(self.value_function, [-1]) # Make loss functions. 
- self.ratio = tf.exp(self.curr_dist.logp(actions) - - self.prev_dist.logp(actions)) + self.ratio = tf.exp( + self.curr_dist.logp(actions) - self.prev_dist.logp(actions) + ) self.kl = self.prev_dist.kl(self.curr_dist) self.mean_kl = tf.reduce_mean(self.kl) self.entropy = self.curr_dist.entropy() self.mean_entropy = tf.reduce_mean(self.entropy) self.surr1 = self.ratio * advantages - self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"], - 1 + config["clip_param"]) * advantages + self.surr2 = tf.clip_by_value( + self.ratio, 1 - config["clip_param"], 1 + config["clip_param"] + ) * advantages self.surr = tf.minimum(self.surr1, self.surr2) self.mean_policy_loss = tf.reduce_mean(-self.surr) @@ -57,35 +61,40 @@ def __init__( # scales superlinearly with the length of the rollout) self.vf_loss1 = tf.square(self.value_function - value_targets) vf_clipped = prev_vf_preds + tf.clip_by_value( - self.value_function - prev_vf_preds, - -config["clip_param"], config["clip_param"]) + self.value_function - prev_vf_preds, -config["clip_param"], + config["clip_param"] + ) self.vf_loss2 = tf.square(vf_clipped - value_targets) self.vf_loss = tf.minimum(self.vf_loss1, self.vf_loss2) self.mean_vf_loss = tf.reduce_mean(self.vf_loss) self.loss = tf.reduce_mean( -self.surr + kl_coeff * self.kl + config["vf_loss_coeff"] * self.vf_loss - - config["entropy_coeff"] * self.entropy) + config["entropy_coeff"] * self.entropy + ) else: self.mean_vf_loss = tf.constant(0.0) self.loss = tf.reduce_mean( - -self.surr + - kl_coeff * self.kl - - config["entropy_coeff"] * self.entropy) + -self.surr + kl_coeff * self.kl - + config["entropy_coeff"] * self.entropy + ) self.sess = sess if config["use_gae"]: self.policy_results = [ - self.sampler, self.curr_logits, self.value_function] + self.sampler, self.curr_logits, self.value_function + ] else: self.policy_results = [ - self.sampler, self.curr_logits, tf.constant("NA")] + self.sampler, self.curr_logits, + tf.constant("NA") + ] def compute(self, observation): action, logprobs, vf = self.sess.run( - self.policy_results, - feed_dict={self.observations: [observation]}) + self.policy_results, feed_dict={self.observations: [observation]} + ) return action[0], {"vf_preds": vf[0], "logprobs": logprobs[0]} def loss(self): diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index 8f550d318f43..9ca0a88fe6d5 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -18,7 +18,6 @@ from ray.rllib.ppo.ppo_evaluator import PPOEvaluator from ray.rllib.ppo.rollout import collect_samples - DEFAULT_CONFIG = { # Discount factor of the MDP "gamma": 0.995, @@ -39,7 +38,9 @@ # as a command line argument. 
"devices": ["/cpu:%d" % i for i in range(4)], "tf_session_args": { - "device_count": {"CPU": 4}, + "device_count": { + "CPU": 4 + }, "log_device_placement": False, "allow_soft_placement": True, "intra_op_parallelism_threads": 1, @@ -58,7 +59,9 @@ # Target value for KL divergence "kl_target": 0.01, # Config params to pass to the model - "model": {"free_log_std": False}, + "model": { + "free_log_std": False + }, # Which observation filter to apply to the observation "observation_filter": "MeanStdFilter", # If >1, adds frameskip @@ -102,25 +105,31 @@ def default_resource_request(cls, config): cpu=1, gpu=len([d for d in cf["devices"] if "gpu" in d.lower()]), extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) + extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"] + ) def _init(self): self.global_step = 0 self.kl_coeff = self.config["kl_coeff"] self.local_evaluator = PPOEvaluator( - self.registry, self.env_creator, self.config, self.logdir, False) + self.registry, self.env_creator, self.config, self.logdir, False + ) RemotePPOEvaluator = ray.remote( num_cpus=self.config["num_cpus_per_worker"], - num_gpus=self.config["num_gpus_per_worker"])(PPOEvaluator) + num_gpus=self.config["num_gpus_per_worker"] + )( + PPOEvaluator + ) self.remote_evaluators = [ RemotePPOEvaluator.remote( - self.registry, self.env_creator, self.config, self.logdir, - True) - for _ in range(self.config["num_workers"])] + self.registry, self.env_creator, self.config, self.logdir, True + ) for _ in range(self.config["num_workers"]) + ] self.start_time = time.time() if self.config["write_logs"]: self.file_writer = tf.summary.FileWriter( - self.logdir, self.local_evaluator.sess.graph) + self.logdir, self.local_evaluator.sess.graph + ) else: self.file_writer = None self.saver = tf.train.Saver(max_to_keep=None) @@ -130,13 +139,16 @@ def _train(self): config = self.config model = self.local_evaluator - if (config["num_workers"] * config["min_steps_per_task"] > - config["timesteps_per_batch"]): + if ( + config["num_workers"] * config["min_steps_per_task"] > + config["timesteps_per_batch"] + ): print( "WARNING: num_workers * min_steps_per_task > " "timesteps_per_batch. This means that the output of some " "tasks will be wasted. Consider decreasing " - "min_steps_per_task or increasing timesteps_per_batch.") + "min_steps_per_task or increasing timesteps_per_batch." 
+ ) print("===> iteration", self.iteration) @@ -153,15 +165,19 @@ def standardized(value): samples.data["advantages"] = standardized(samples["advantages"]) rollouts_end = time.time() - print("Computing policy (iterations=" + str(config["num_sgd_iter"]) + - ", stepsize=" + str(config["sgd_stepsize"]) + "):") + print( + "Computing policy (iterations=" + str(config["num_sgd_iter"]) + + ", stepsize=" + str(config["sgd_stepsize"]) + "):" + ) names = [ - "iter", "total loss", "policy loss", "vf loss", "kl", "entropy"] + "iter", "total loss", "policy loss", "vf loss", "kl", "entropy" + ] print(("{:>15}" * len(names)).format(*names)) samples.shuffle() shuffle_end = time.time() tuples_per_device = model.load_data( - samples, self.iteration == 0 and config["full_trace_data_load"]) + samples, self.iteration == 0 and config["full_trace_data_load"] + ) load_end = time.time() rollouts_time = rollouts_end - iter_start shuffle_time = shuffle_end - rollouts_end @@ -171,7 +187,8 @@ def standardized(value): sgd_start = time.time() batch_index = 0 num_batches = ( - int(tuples_per_device) // int(model.per_device_batch_size)) + int(tuples_per_device) // int(model.per_device_batch_size) + ) loss, policy_loss, vf_loss, kl, entropy = [], [], [], [], [] permutation = np.random.permutation(num_batches) # Prepare to drop into the debugger @@ -179,8 +196,9 @@ def standardized(value): model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess) while batch_index < num_batches: full_trace = ( - i == 0 and self.iteration == 0 and - batch_index == config["full_trace_nth_sgd_batch"]) + i == 0 and self.iteration == 0 + and batch_index == config["full_trace_nth_sgd_batch"] + ) batch_loss, batch_policy_loss, batch_vf_loss, batch_kl, \ batch_entropy = model.run_sgd_minibatch( permutation[batch_index] * model.per_device_batch_size, @@ -200,24 +218,31 @@ def standardized(value): sgd_end = time.time() print( "{:>15}{:15.5e}{:15.5e}{:15.5e}{:15.5e}{:15.5e}".format( - i, loss, policy_loss, vf_loss, kl, entropy)) + i, loss, policy_loss, vf_loss, kl, entropy + ) + ) values = [] if i == config["num_sgd_iter"] - 1: metric_prefix = "ppo/sgd/final_iter/" - values.append(tf.Summary.Value( - tag=metric_prefix + "kl_coeff", - simple_value=self.kl_coeff)) + values.append( + tf.Summary.Value( + tag=metric_prefix + "kl_coeff", + simple_value=self.kl_coeff + ) + ) values.extend([ tf.Summary.Value( tag=metric_prefix + "mean_entropy", - simple_value=entropy), + simple_value=entropy + ), tf.Summary.Value( - tag=metric_prefix + "mean_loss", - simple_value=loss), + tag=metric_prefix + "mean_loss", simple_value=loss + ), tf.Summary.Value( - tag=metric_prefix + "mean_kl", - simple_value=kl)]) + tag=metric_prefix + "mean_kl", simple_value=kl + ) + ]) if self.file_writer: sgd_stats = tf.Summary(value=values) self.file_writer.add_summary(sgd_stats, self.global_step) @@ -239,7 +264,8 @@ def standardized(value): } FilterManager.synchronize( - self.local_evaluator.filters, self.remote_evaluators) + self.local_evaluator.filters, self.remote_evaluators + ) res = self._fetch_metrics_from_remote_evaluators() res = res._replace(info=info) return res @@ -247,22 +273,27 @@ def standardized(value): def _fetch_metrics_from_remote_evaluators(self): episode_rewards = [] episode_lengths = [] - metric_lists = [a.get_completed_rollout_metrics.remote() - for a in self.remote_evaluators] + metric_lists = [ + a.get_completed_rollout_metrics.remote() + for a in self.remote_evaluators + ] for metrics in metric_lists: for episode in ray.get(metrics): 
episode_lengths.append(episode.episode_length) episode_rewards.append(episode.episode_reward) avg_reward = ( - np.mean(episode_rewards) if episode_rewards else float('nan')) + np.mean(episode_rewards) if episode_rewards else float('nan') + ) avg_length = ( - np.mean(episode_lengths) if episode_lengths else float('nan')) + np.mean(episode_lengths) if episode_lengths else float('nan') + ) timesteps = np.sum(episode_lengths) if episode_lengths else 0 result = TrainingResult( episode_reward_mean=avg_reward, episode_len_mean=avg_length, - timesteps_this_iter=timesteps) + timesteps_this_iter=timesteps + ) return result @@ -275,14 +306,13 @@ def _save(self, checkpoint_dir): checkpoint_path = self.saver.save( self.local_evaluator.sess, os.path.join(checkpoint_dir, "checkpoint"), - global_step=self.iteration) - agent_state = ray.get( - [a.save.remote() for a in self.remote_evaluators]) + global_step=self.iteration + ) + agent_state = ray.get([a.save.remote() for a in self.remote_evaluators]) extra_data = [ - self.local_evaluator.save(), - self.global_step, - self.kl_coeff, - agent_state] + self.local_evaluator.save(), self.global_step, self.kl_coeff, + agent_state + ] pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) return checkpoint_path @@ -294,9 +324,9 @@ def _restore(self, checkpoint_path): self.kl_coeff = extra_data[2] ray.get([ a.restore.remote(o) - for (a, o) in zip(self.remote_evaluators, extra_data[3])]) + for (a, o) in zip(self.remote_evaluators, extra_data[3]) + ]) def compute_action(self, observation): - observation = self.local_evaluator.obs_filter( - observation, update=False) + observation = self.local_evaluator.obs_filter(observation, update=False) return self.local_evaluator.common_policy.compute(observation)[0] diff --git a/python/ray/rllib/ppo/ppo_evaluator.py b/python/ray/rllib/ppo/ppo_evaluator.py index 434feb094d7e..0f1c18ef3afc 100644 --- a/python/ray/rllib/ppo/ppo_evaluator.py +++ b/python/ray/rllib/ppo/ppo_evaluator.py @@ -42,7 +42,8 @@ def __init__(self, registry, env_creator, config, logdir, is_remote): self.config = config self.logdir = logdir self.env = ModelCatalog.get_preprocessor_as_wrapper( - registry, env_creator(config["env_config"]), config["model"]) + registry, env_creator(config["env_config"]), config["model"] + ) if is_remote: config_proto = tf.ConfigProto() else: @@ -51,112 +52,126 @@ def __init__(self, registry, env_creator, config, logdir, is_remote): if config["tf_debug_inf_or_nan"] and not is_remote: self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess) self.sess.add_tensor_filter( - "has_inf_or_nan", tf_debug.has_inf_or_nan) + "has_inf_or_nan", tf_debug.has_inf_or_nan + ) # Defines the training inputs: # The coefficient of the KL penalty. - self.kl_coeff = tf.placeholder( - name="newkl", shape=(), dtype=tf.float32) + self.kl_coeff = tf.placeholder(name="newkl", shape=(), dtype=tf.float32) # The input observations. self.observations = tf.placeholder( - tf.float32, shape=(None,) + self.env.observation_space.shape) + tf.float32, shape=(None, ) + self.env.observation_space.shape + ) # Targets of the value function. - self.value_targets = tf.placeholder(tf.float32, shape=(None,)) + self.value_targets = tf.placeholder(tf.float32, shape=(None, )) # Advantage values in the policy gradient estimator. 
- self.advantages = tf.placeholder(tf.float32, shape=(None,)) + self.advantages = tf.placeholder(tf.float32, shape=(None, )) action_space = self.env.action_space self.actions = ModelCatalog.get_action_placeholder(action_space) self.distribution_class, self.logit_dim = ModelCatalog.get_action_dist( - action_space) + action_space + ) # Log probabilities from the policy before the policy update. self.prev_logits = tf.placeholder( - tf.float32, shape=(None, self.logit_dim)) + tf.float32, shape=(None, self.logit_dim) + ) # Value function predictions before the policy update. - self.prev_vf_preds = tf.placeholder(tf.float32, shape=(None,)) + self.prev_vf_preds = tf.placeholder(tf.float32, shape=(None, )) if is_remote: self.batch_size = config["rollout_batchsize"] self.per_device_batch_size = config["rollout_batchsize"] else: - self.batch_size = int( - config["sgd_batchsize"] / len(devices)) * len(devices) + self.batch_size = int(config["sgd_batchsize"] / len(devices) + ) * len(devices) assert self.batch_size % len(devices) == 0 self.per_device_batch_size = int(self.batch_size / len(devices)) def build_loss(obs, vtargets, advs, acts, plog, pvf_preds): return ProximalPolicyLoss( - self.env.observation_space, self.env.action_space, - obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim, - self.kl_coeff, self.distribution_class, self.config, - self.sess, self.registry) + self.env.observation_space, self.env.action_space, obs, + vtargets, advs, acts, plog, pvf_preds, self.logit_dim, + self.kl_coeff, self.distribution_class, self.config, self.sess, + self.registry + ) self.par_opt = LocalSyncParallelOptimizer( - tf.train.AdamOptimizer(self.config["sgd_stepsize"]), - self.devices, - [self.observations, self.value_targets, self.advantages, - self.actions, self.prev_logits, self.prev_vf_preds], - self.per_device_batch_size, - build_loss, - self.logdir) + tf.train.AdamOptimizer(self.config["sgd_stepsize"]), self.devices, + [ + self.observations, self.value_targets, self.advantages, + self.actions, self.prev_logits, self.prev_vf_preds + ], self.per_device_batch_size, build_loss, self.logdir + ) # Metric ops with tf.name_scope("test_outputs"): policies = self.par_opt.get_device_losses() self.mean_loss = tf.reduce_mean( - tf.stack(values=[ - policy.loss for policy in policies]), 0) + tf.stack(values=[policy.loss for policy in policies]), 0 + ) self.mean_policy_loss = tf.reduce_mean( - tf.stack(values=[ - policy.mean_policy_loss for policy in policies]), 0) + tf.stack( + values=[policy.mean_policy_loss for policy in policies] + ), + 0 + ) self.mean_vf_loss = tf.reduce_mean( - tf.stack(values=[ - policy.mean_vf_loss for policy in policies]), 0) + tf.stack(values=[policy.mean_vf_loss for policy in policies]), + 0 + ) self.mean_kl = tf.reduce_mean( - tf.stack(values=[ - policy.mean_kl for policy in policies]), 0) + tf.stack(values=[policy.mean_kl for policy in policies]), 0 + ) self.mean_entropy = tf.reduce_mean( - tf.stack(values=[ - policy.mean_entropy for policy in policies]), 0) + tf.stack(values=[policy.mean_entropy for policy in policies]), + 0 + ) # References to the model weights self.common_policy = self.par_opt.get_common_loss() self.variables = ray.experimental.TensorFlowVariables( - self.common_policy.loss, self.sess) + self.common_policy.loss, self.sess + ) self.obs_filter = get_filter( - config["observation_filter"], self.env.observation_space.shape) + config["observation_filter"], self.env.observation_space.shape + ) self.rew_filter = MeanStdFilter((), clip=5.0) - self.filters = 
{"obs_filter": self.obs_filter, - "rew_filter": self.rew_filter} + self.filters = { + "obs_filter": self.obs_filter, + "rew_filter": self.rew_filter + } self.sampler = SyncSampler( self.env, self.common_policy, self.obs_filter, - self.config["horizon"], self.config["horizon"]) + self.config["horizon"], self.config["horizon"] + ) self.sess.run(tf.global_variables_initializer()) def load_data(self, trajectories, full_trace): use_gae = self.config["use_gae"] dummy = np.zeros_like(trajectories["advantages"]) return self.par_opt.load_data( - self.sess, - [trajectories["obs"], - trajectories["value_targets"] if use_gae else dummy, - trajectories["advantages"], - trajectories["actions"], - trajectories["logprobs"], - trajectories["vf_preds"] if use_gae else dummy], - full_trace=full_trace) - - def run_sgd_minibatch( - self, batch_index, kl_coeff, full_trace, file_writer): + self.sess, [ + trajectories["obs"], trajectories["value_targets"] if use_gae + else dummy, trajectories["advantages"], trajectories["actions"], + trajectories["logprobs"], trajectories["vf_preds"] + if use_gae else dummy + ], + full_trace=full_trace + ) + + def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace, file_writer): return self.par_opt.optimize( self.sess, batch_index, extra_ops=[ self.mean_loss, self.mean_policy_loss, self.mean_vf_loss, - self.mean_kl, self.mean_entropy], + self.mean_kl, self.mean_entropy + ], extra_feed_dict={self.kl_coeff: kl_coeff}, - file_writer=file_writer if full_trace else None) + file_writer=file_writer if full_trace else None + ) def compute_gradients(self, samples): raise NotImplementedError @@ -191,8 +206,12 @@ def sample(self): while num_steps_so_far < self.config["min_steps_per_task"]: rollout = self.sampler.get_data() samples = process_rollout( - rollout, self.rew_filter, self.config["gamma"], - self.config["lambda"], use_gae=self.config["use_gae"]) + rollout, + self.rew_filter, + self.config["gamma"], + self.config["lambda"], + use_gae=self.config["use_gae"] + ) num_steps_so_far += samples.count all_samples.append(samples) return SampleBatch.concat_samples(all_samples) diff --git a/python/ray/rllib/ppo/test/test.py b/python/ray/rllib/ppo/test/test.py index 6ab59af9348c..f9bcdc276461 100644 --- a/python/ray/rllib/ppo/test/test.py +++ b/python/ray/rllib/ppo/test/test.py @@ -13,7 +13,6 @@ # TODO(ekl): move to rllib/models dir class DistributionsTest(unittest.TestCase): - def testCategorical(self): num_samples = 100000 logits = tf.placeholder(tf.float32, shape=(None, 10)) @@ -32,10 +31,11 @@ def testCategorical(self): class UtilsTest(unittest.TestCase): - def testFlatten(self): - d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]), - "a": np.array([[[5], [-5]], [[6], [-6]]])} + d = { + "s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]), + "a": np.array([[[5], [-5]], [[6], [-6]]]) + } flat = flatten(d.copy(), start=0, stop=2) assert_allclose(d["s"][0][0][:], flat["s"][0][:]) assert_allclose(d["s"][0][1][:], flat["s"][1][:]) diff --git a/python/ray/rllib/ppo/utils.py b/python/ray/rllib/ppo/utils.py index 5e8ac5a3ab44..e97dce5cf12e 100644 --- a/python/ray/rllib/ppo/utils.py +++ b/python/ray/rllib/ppo/utils.py @@ -16,7 +16,7 @@ def flatten(weights, start=0, stop=2): stop: The ending index. 
""" for key, val in weights.items(): - new_shape = val.shape[0:start] + (-1,) + val.shape[stop:] + new_shape = val.shape[0:start] + (-1, ) + val.shape[stop:] weights[key] = val.reshape(new_shape) return weights diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 64174866aa19..d13a137eca99 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -25,30 +25,44 @@ parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description="Roll out a reinforcement learning agent " - "given a checkpoint.", epilog=EXAMPLE_USAGE) + "given a checkpoint.", + epilog=EXAMPLE_USAGE +) parser.add_argument( - "checkpoint", type=str, help="Checkpoint from which to roll out.") + "checkpoint", type=str, help="Checkpoint from which to roll out." +) required_named = parser.add_argument_group("required named arguments") required_named.add_argument( - "--run", type=str, required=True, + "--run", + type=str, + required=True, help="The algorithm or model to train. This may refer to the name " - "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a " - "user-defined trainable function or class registered in the " - "tune registry.") + "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a " + "user-defined trainable function or class registered in the " + "tune registry." +) required_named.add_argument( - "--env", type=str, help="The gym environment to use.") + "--env", type=str, help="The gym environment to use." +) parser.add_argument( - "--no-render", default=False, action="store_const", const=True, - help="Surpress rendering of the environment.") + "--no-render", + default=False, + action="store_const", + const=True, + help="Surpress rendering of the environment." +) parser.add_argument( - "--steps", default=None, help="Number of steps to roll out.") + "--steps", default=None, help="Number of steps to roll out." +) +parser.add_argument("--out", default=None, help="Output filename.") parser.add_argument( - "--out", default=None, help="Output filename.") -parser.add_argument( - "--config", default="{}", type=json.loads, + "--config", + default="{}", + type=json.loads, help="Algorithm-specific configuration (e.g. env, hyperparams). " - "Surpresses loading of configuration from checkpoint.") + "Surpresses loading of configuration from checkpoint." 
+) if __name__ == "__main__": args = parser.parse_args() @@ -76,8 +90,9 @@ env = gym.make(args.env) env = wrap_dqn(get_registry(), env, args.config.get("model", {})) else: - env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(), - gym.make(args.env)) + env = ModelCatalog.get_preprocessor_as_wrapper( + get_registry(), gym.make(args.env) + ) if args.out is not None: rollouts = [] steps = 0 diff --git a/python/ray/rllib/test/mock_evaluator.py b/python/ray/rllib/test/mock_evaluator.py index 4762bb877d5b..c4e36997878c 100644 --- a/python/ray/rllib/test/mock_evaluator.py +++ b/python/ray/rllib/test/mock_evaluator.py @@ -15,16 +15,18 @@ def __init__(self, sample_count=10): self._sample_count = sample_count self.obs_filter = MeanStdFilter(()) self.rew_filter = MeanStdFilter(()) - self.filters = {"obs_filter": self.obs_filter, - "rew_filter": self.rew_filter} + self.filters = { + "obs_filter": self.obs_filter, + "rew_filter": self.rew_filter + } def sample(self): samples_dict = {"observations": [], "rewards": []} for i in range(self._sample_count): samples_dict["observations"].append( - self.obs_filter(np.random.randn())) - samples_dict["rewards"].append( - self.rew_filter(np.random.randn())) + self.obs_filter(np.random.randn()) + ) + samples_dict["rewards"].append(self.rew_filter(np.random.randn())) return SampleBatch(samples_dict) def compute_gradients(self, samples): diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index c5e503b717ee..2d306922c80a 100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -10,7 +10,8 @@ from ray.rllib.models import ModelCatalog from ray.rllib.models.model import Model from ray.rllib.models.preprocessors import ( - NoPreprocessor, OneHotPreprocessor, Preprocessor) + NoPreprocessor, OneHotPreprocessor, Preprocessor +) from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork @@ -34,11 +35,13 @@ def tearDown(self): def testGymPreprocessors(self): p1 = ModelCatalog.get_preprocessor( - get_registry(), gym.make("CartPole-v0")) + get_registry(), gym.make("CartPole-v0") + ) self.assertEqual(type(p1), NoPreprocessor) p2 = ModelCatalog.get_preprocessor( - get_registry(), gym.make("FrozenLake-v0")) + get_registry(), gym.make("FrozenLake-v0") + ) self.assertEqual(type(p2), OneHotPreprocessor) def testTuplePreprocessor(self): @@ -46,14 +49,17 @@ def testTuplePreprocessor(self): class TupleEnv(object): def __init__(self): - self.observation_space = Tuple( - [Discrete(5), Box(0, 1, shape=(3,), dtype=np.float32)]) - p1 = ModelCatalog.get_preprocessor( - get_registry(), TupleEnv()) - self.assertEqual(p1.shape, (8,)) + self.observation_space = Tuple([ + Discrete(5), + Box(0, 1, shape=(3, ), dtype=np.float32) + ]) + + p1 = ModelCatalog.get_preprocessor(get_registry(), TupleEnv()) + self.assertEqual(p1.shape, (8, )) self.assertEqual( list(p1.transform((0, [1, 2, 3]))), - [float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]]) + [float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]] + ) def testCustomPreprocessor(self): ray.init() @@ -61,10 +67,12 @@ def testCustomPreprocessor(self): ModelCatalog.register_custom_preprocessor("bar", CustomPreprocessor2) env = gym.make("CartPole-v0") p1 = ModelCatalog.get_preprocessor( - get_registry(), env, {"custom_preprocessor": "foo"}) + get_registry(), env, {"custom_preprocessor": "foo"} + ) self.assertEqual(str(type(p1)), str(CustomPreprocessor)) p2 = ModelCatalog.get_preprocessor( - get_registry(), env, 
{"custom_preprocessor": "bar"}) + get_registry(), env, {"custom_preprocessor": "bar"} + ) self.assertEqual(str(type(p2)), str(CustomPreprocessor2)) p3 = ModelCatalog.get_preprocessor(get_registry(), env) self.assertEqual(type(p3), NoPreprocessor) @@ -74,19 +82,22 @@ def testDefaultModels(self): with tf.variable_scope("test1"): p1 = ModelCatalog.get_model( - get_registry(), np.zeros((10, 3), dtype=np.float32), 5) + get_registry(), np.zeros((10, 3), dtype=np.float32), 5 + ) self.assertEqual(type(p1), FullyConnectedNetwork) with tf.variable_scope("test2"): p2 = ModelCatalog.get_model( - get_registry(), np.zeros((10, 80, 80, 3), dtype=np.float32), 5) + get_registry(), np.zeros((10, 80, 80, 3), dtype=np.float32), 5 + ) self.assertEqual(type(p2), VisionNetwork) def testCustomModel(self): ray.init() ModelCatalog.register_custom_model("foo", CustomModel) p1 = ModelCatalog.get_model( - get_registry(), 1, 5, {"custom_model": "foo"}) + get_registry(), 1, 5, {"custom_model": "foo"} + ) self.assertEqual(str(type(p1)), str(CustomModel)) diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index 9e583c877bb9..6a7f73dd0772 100644 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -20,11 +20,21 @@ def get_mean_action(alg, obs): ray.init() CONFIGS = { - "ES": {"episodes_per_batch": 10, "timesteps_per_batch": 100}, + "ES": { + "episodes_per_batch": 10, + "timesteps_per_batch": 100 + }, "DQN": {}, - "DDPG": {"noise_scale": 0.0}, - "PPO": {"num_sgd_iter": 5, "timesteps_per_batch": 1000}, - "A3C": {"use_lstm": False}, + "DDPG": { + "noise_scale": 0.0 + }, + "PPO": { + "num_sgd_iter": 5, + "timesteps_per_batch": 1000 + }, + "A3C": { + "use_lstm": False + }, } diff --git a/python/ray/rllib/test/test_evaluators.py b/python/ray/rllib/test/test_evaluators.py index 29c054a0d418..c9644747ec63 100644 --- a/python/ray/rllib/test/test_evaluators.py +++ b/python/ray/rllib/test/test_evaluators.py @@ -31,7 +31,6 @@ def testNStep(self): class A3CEvaluatorTest(unittest.TestCase): - def setUp(self): ray.init(num_cpus=1) config = DEFAULT_CONFIG.copy() @@ -44,7 +43,8 @@ def setUp(self): get_registry(), lambda config: gym.make("CartPole-v0"), config, - logdir=self._temp_dir) + logdir=self._temp_dir + ) def tearDown(self): ray.worker.cleanup() diff --git a/python/ray/rllib/test/test_filters.py b/python/ray/rllib/test/test_filters.py index 1147c1768c81..eac7f5cf29f9 100644 --- a/python/ray/rllib/test/test_filters.py +++ b/python/ray/rllib/test/test_filters.py @@ -13,7 +13,7 @@ class RunningStatTest(unittest.TestCase): def testRunningStat(self): - for shp in ((), (3,), (3, 4)): + for shp in ((), (3, ), (3, 4)): li = [] rs = RunningStat(shp) for _ in range(5): @@ -22,12 +22,14 @@ def testRunningStat(self): li.append(val) m = np.mean(li, axis=0) self.assertTrue(np.allclose(rs.mean, m)) - v = (np.square(m) if (len(li) == 1) - else np.var(li, ddof=1, axis=0)) + v = ( + np.square(m) + if (len(li) == 1) else np.var(li, ddof=1, axis=0) + ) self.assertTrue(np.allclose(rs.var, v)) def testCombiningStat(self): - for shape in [(), (3,), (3, 4)]: + for shape in [(), (3, ), (3, 4)]: li = [] rs1 = RunningStat(shape) rs2 = RunningStat(shape) @@ -48,7 +50,7 @@ def testCombiningStat(self): class MSFTest(unittest.TestCase): def testBasic(self): - for shape in [(), (3,), (3, 4, 4)]: + for shape in [(), (3, ), (3, 4, 4)]: filt = MeanStdFilter(shape) for i in range(5): filt(np.ones(shape)) @@ -93,8 +95,11 @@ def 
testSynchronize(self): remote_e = RemoteEvaluator.remote(sample_count=10) remote_e.sample.remote() - FilterManager.synchronize( - {"obs_filter": filt1, "rew_filter": filt1.copy()}, [remote_e]) + FilterManager.synchronize({ + "obs_filter": filt1, + "rew_filter": filt1.copy() + }, + [remote_e]) filters = ray.get(remote_e.get_filters.remote()) obs_f = filters["obs_filter"] diff --git a/python/ray/rllib/test/test_optimizers.py b/python/ray/rllib/test/test_optimizers.py index cfb606101db9..496935499cbe 100644 --- a/python/ray/rllib/test/test_optimizers.py +++ b/python/ray/rllib/test/test_optimizers.py @@ -12,7 +12,6 @@ class AsyncOptimizerTest(unittest.TestCase): - def tearDown(self): ray.worker.cleanup() @@ -21,8 +20,9 @@ def testBasic(self): local = _MockEvaluator() remotes = ray.remote(_MockEvaluator) remote_evaluators = [remotes.remote() for i in range(5)] - test_optimizer = AsyncOptimizer( - {"grads_per_step": 10}, local, remote_evaluators) + test_optimizer = AsyncOptimizer({ + "grads_per_step": 10 + }, local, remote_evaluators) test_optimizer.step() self.assertTrue(all(local.get_weights() == 0)) diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index 2e41c85a0233..67c05881f50d 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -12,28 +12,39 @@ from ray.tune.registry import register_env ACTION_SPACES_TO_TEST = { - "discrete": Discrete(5), - "vector": Box(0.0, 1.0, (5,), dtype=np.float32), - "simple_tuple": Tuple([ - Box(0.0, 1.0, (5,), dtype=np.float32), - Box(0.0, 1.0, (5,), dtype=np.float32)]), + "discrete": + Discrete(5), + "vector": + Box(0.0, 1.0, (5, ), dtype=np.float32), + "simple_tuple": + Tuple([ + Box(0.0, 1.0, (5, ), dtype=np.float32), + Box(0.0, 1.0, (5, ), dtype=np.float32) + ]), "implicit_tuple": [ - Box(0.0, 1.0, (5,), dtype=np.float32), - Box(0.0, 1.0, (5,), dtype=np.float32)], + Box(0.0, 1.0, (5, ), dtype=np.float32), + Box(0.0, 1.0, (5, ), dtype=np.float32) + ], } OBSERVATION_SPACES_TO_TEST = { - "discrete": Discrete(5), - "vector": Box(0.0, 1.0, (5,), dtype=np.float32), - "image": Box(0.0, 1.0, (80, 80, 1), dtype=np.float32), - "atari": Box(0.0, 1.0, (210, 160, 3), dtype=np.float32), - "atari_ram": Box(0.0, 1.0, (128,), dtype=np.float32), - "simple_tuple": Tuple([ - Box(0.0, 1.0, (5,), dtype=np.float32), - Box(0.0, 1.0, (5,), dtype=np.float32)]), - "mixed_tuple": Tuple([ - Discrete(10), - Box(0.0, 1.0, (5,), dtype=np.float32)]), + "discrete": + Discrete(5), + "vector": + Box(0.0, 1.0, (5, ), dtype=np.float32), + "image": + Box(0.0, 1.0, (80, 80, 1), dtype=np.float32), + "atari": + Box(0.0, 1.0, (210, 160, 3), dtype=np.float32), + "atari_ram": + Box(0.0, 1.0, (128, ), dtype=np.float32), + "simple_tuple": + Tuple([ + Box(0.0, 1.0, (5, ), dtype=np.float32), + Box(0.0, 1.0, (5, ), dtype=np.float32) + ]), + "mixed_tuple": + Tuple([Discrete(10), Box(0.0, 1.0, (5, ), dtype=np.float32)]), } # (alg, action_space, obs_space) @@ -85,8 +96,7 @@ def check_support(alg, config, stats): for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items(): print("=== Testing", alg, action_space, obs_space, "===") stub_env = make_stub_env(action_space, obs_space) - register_env( - "stub_env", lambda c: stub_env()) + register_env("stub_env", lambda c: stub_env()) stat = "ok" a = None try: @@ -117,23 +127,32 @@ def testAll(self): check_support("DDPG", {"timesteps_per_iteration": 1}, stats) check_support("DQN", {"timesteps_per_iteration": 1}, stats) check_support( - "A3C", 
{"num_workers": 1, "optimizer": {"grads_per_step": 1}}, - stats) + "A3C", { + "num_workers": 1, + "optimizer": { + "grads_per_step": 1 + } + }, stats + ) check_support( - "PPO", - {"num_workers": 1, "num_sgd_iter": 1, "timesteps_per_batch": 1, - "devices": ["/cpu:0"], "min_steps_per_task": 1, - "sgd_batchsize": 1}, - stats) + "PPO", { + "num_workers": 1, + "num_sgd_iter": 1, + "timesteps_per_batch": 1, + "devices": ["/cpu:0"], + "min_steps_per_task": 1, + "sgd_batchsize": 1 + }, stats + ) check_support( - "ES", - {"num_workers": 1, "noise_size": 10000000, - "episodes_per_batch": 1, "timesteps_per_batch": 1}, - stats) - check_support( - "PG", - {"num_workers": 1, "optimizer": {}}, - stats) + "ES", { + "num_workers": 1, + "noise_size": 10000000, + "episodes_per_batch": 1, + "timesteps_per_batch": 1 + }, stats + ) + check_support("PG", {"num_workers": 1, "optimizer": {}}, stats) num_unexpected_errors = 0 num_unexpected_success = 0 for (alg, a_name, o_name), stat in sorted(stats.items()): @@ -144,8 +163,8 @@ def testAll(self): if (alg, a_name, o_name) not in KNOWN_FAILURES: num_unexpected_errors += 1 print( - alg, "action_space", a_name, "obs_space", o_name, - "result", stat) + alg, "action_space", a_name, "obs_space", o_name, "result", stat + ) self.assertEqual(num_unexpected_errors, 0) self.assertEqual(num_unexpected_success, 0) diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index 41d7771f0dd7..f0e6f44ccd1f 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -12,7 +12,6 @@ from ray.tune.config_parser import make_parser, resources_to_json from ray.tune.tune import _make_scheduler, run_experiments - EXAMPLE_USAGE = """ Training example: ./train.py --run DQN --env CartPole-v0 @@ -23,38 +22,57 @@ Note that -f overrides all other trial-specific command-line options. """ - parser = make_parser( formatter_class=argparse.RawDescriptionHelpFormatter, description="Train a reinforcement learning agent.", - epilog=EXAMPLE_USAGE) + epilog=EXAMPLE_USAGE +) # See also the base parser definition in ray/tune/config_parser.py parser.add_argument( - "--redis-address", default=None, type=str, - help="The Redis address of the cluster.") + "--redis-address", + default=None, + type=str, + help="The Redis address of the cluster." +) parser.add_argument( - "--ray-num-cpus", default=None, type=int, - help="--num-cpus to pass to Ray. This only has an affect in local mode.") + "--ray-num-cpus", + default=None, + type=int, + help="--num-cpus to pass to Ray. This only has an affect in local mode." +) parser.add_argument( - "--ray-num-gpus", default=None, type=int, - help="--num-gpus to pass to Ray. This only has an affect in local mode.") + "--ray-num-gpus", + default=None, + type=int, + help="--num-gpus to pass to Ray. This only has an affect in local mode." +) parser.add_argument( - "--experiment-name", default="default", type=str, - help="Name of the subdirectory under `local_dir` to put results in.") + "--experiment-name", + default="default", + type=str, + help="Name of the subdirectory under `local_dir` to put results in." +) parser.add_argument( - "--env", default=None, type=str, help="The gym environment to use.") + "--env", default=None, type=str, help="The gym environment to use." +) parser.add_argument( - "--queue-trials", action='store_true', + "--queue-trials", + action='store_true', help=( "Whether to queue trials when the cluster does not currently have " "enough resources to launch one. 
This should be set to True when " - "running on an autoscaling cluster to enable automatic scale-up.")) + "running on an autoscaling cluster to enable automatic scale-up." + ) +) parser.add_argument( - "-f", "--config-file", default=None, type=str, + "-f", + "--config-file", + default=None, + type=str, help="If specified, use config options from this file. Note that this " - "overrides any trial-specific options set via flags above.") - + "overrides any trial-specific options set via flags above." +) if __name__ == "__main__": args = parser.parse_args(sys.argv[1:]) @@ -87,7 +105,11 @@ ray.init( redis_address=args.redis_address, - num_cpus=args.ray_num_cpus, num_gpus=args.ray_num_gpus) + num_cpus=args.ray_num_cpus, + num_gpus=args.ray_num_gpus + ) run_experiments( - experiments, scheduler=_make_scheduler(args), - queue_trials=args.queue_trials) + experiments, + scheduler=_make_scheduler(args), + queue_trials=args.queue_trials + ) diff --git a/python/ray/rllib/tuned_examples/run_regression_tests.py b/python/ray/rllib/tuned_examples/run_regression_tests.py index 3bb7d52248d3..934235fe91c4 100755 --- a/python/ray/rllib/tuned_examples/run_regression_tests.py +++ b/python/ray/rllib/tuned_examples/run_regression_tests.py @@ -8,7 +8,6 @@ import ray from ray.tune import run_experiments - if __name__ == '__main__': experiments = {} @@ -24,10 +23,11 @@ num_failures = 0 for t in trials: - if (t.last_result.episode_reward_mean < - t.stopping_criterion["episode_reward_mean"]): + if ( + t.last_result.episode_reward_mean < + t.stopping_criterion["episode_reward_mean"] + ): num_failures += 1 if num_failures: - raise Exception( - "{} trials did not converge".format(num_failures)) + raise Exception("{} trials did not converge".format(num_failures)) diff --git a/python/ray/rllib/utils/atari_wrappers.py b/python/ray/rllib/utils/atari_wrappers.py index ac2ebb7050f5..1ae56372900c 100644 --- a/python/ray/rllib/utils/atari_wrappers.py +++ b/python/ray/rllib/utils/atari_wrappers.py @@ -24,8 +24,7 @@ def reset(self, **kwargs): if self.override_num_noops is not None: noops = self.override_num_noops else: - noops = self.unwrapped.np_random.randint( - 1, self.noop_max + 1) + noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) assert noops > 0 obs = None for _ in range(noops): @@ -116,8 +115,8 @@ def __init__(self, env, skip=4): """Return only every `skip`-th frame""" gym.Wrapper.__init__(self, env) # most recent raw observations (for max pooling across time steps) - self._obs_buffer = np.zeros( - (2,)+env.observation_space.shape, dtype=np.uint8) + self._obs_buffer = np.zeros((2, ) + env.observation_space.shape, + dtype=np.uint8) self._skip = skip def step(self, action): @@ -150,12 +149,14 @@ def __init__(self, env, dim): self.width = dim # in rllib we use 80 self.height = dim self.observation_space = spaces.Box( - low=0, high=255, shape=(self.height, self.width, 1)) + low=0, high=255, shape=(self.height, self.width, 1) + ) def observation(self, frame): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) frame = cv2.resize( - frame, (self.width, self.height), interpolation=cv2.INTER_AREA) + frame, (self.width, self.height), interpolation=cv2.INTER_AREA + ) return frame[:, :, None] @@ -167,7 +168,8 @@ def __init__(self, env, k): self.frames = deque([], maxlen=k) shp = env.observation_space.shape self.observation_space = spaces.Box( - low=0, high=255, shape=(shp[0], shp[1], shp[2] * k)) + low=0, high=255, shape=(shp[0], shp[1], shp[2] * k) + ) def reset(self): ob = self.env.reset() diff --git 
a/python/ray/rllib/utils/compression.py b/python/ray/rllib/utils/compression.py index dee8d875df3c..3032a133fafc 100644 --- a/python/ray/rllib/utils/compression.py +++ b/python/ray/rllib/utils/compression.py @@ -14,7 +14,8 @@ print( "WARNING: lz4 not available, disabling sample compression. " "This will significantly impact RLlib performance. " - "To install lz4, run `pip install lz4`.") + "To install lz4, run `pip install lz4`." + ) LZ4_ENABLED = False diff --git a/python/ray/rllib/utils/filter.py b/python/ray/rllib/utils/filter.py index 6e60b4e5f947..17cc67ba6834 100644 --- a/python/ray/rllib/utils/filter.py +++ b/python/ray/rllib/utils/filter.py @@ -59,7 +59,6 @@ def as_serializable(self): # http://www.johndcook.com/blog/standard_deviation/ class RunningStat(object): - def __init__(self, shape=None): self._n = 0 self._M = np.zeros(shape) @@ -75,8 +74,9 @@ def copy(self): def push(self, x): x = np.asarray(x) # Unvectorized update of the running statistics. - assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}" - .format(x.shape, self._M.shape)) + assert x.shape == self._M.shape, ( + "x.shape = {}, self.shape = {}".format(x.shape, self._M.shape) + ) n1 = self._n self._n += 1 if self._n == 1: @@ -103,7 +103,8 @@ def update(self, other): def __repr__(self): return '(n={}, mean_mean={}, mean_std={})'.format( - self.n, np.mean(self.mean), np.mean(self.std)) + self.n, np.mean(self.mean), np.mean(self.std) + ) @property def n(self): @@ -227,8 +228,8 @@ def __call__(self, x, update=True): def __repr__(self): return 'MeanStdFilter({}, {}, {}, {}, {}, {})'.format( - self.shape, self.demean, self.destd, - self.clip, self.rs, self.buffer) + self.shape, self.demean, self.destd, self.clip, self.rs, self.buffer + ) class ConcurrentMeanStdFilter(MeanStdFilter): @@ -242,6 +243,7 @@ def lock_wrap(func): def wrapper(*args, **kwargs): with self._lock: return func(*args, **kwargs) + return wrapper self.__getattribute__ = lock_wrap(self.__getattribute__) @@ -260,8 +262,8 @@ def copy(self): def __repr__(self): return 'ConcurrentMeanStdFilter({}, {}, {}, {}, {}, {})'.format( - self.shape, self.demean, self.destd, - self.clip, self.rs, self.buffer) + self.shape, self.demean, self.destd, self.clip, self.rs, self.buffer + ) def get_filter(filter_config, shape): @@ -273,5 +275,4 @@ def get_filter(filter_config, shape): elif filter_config == "NoFilter": return NoFilter() else: - raise Exception("Unknown observation_filter: " + - str(filter_config)) + raise Exception("Unknown observation_filter: " + str(filter_config)) diff --git a/python/ray/rllib/utils/filter_manager.py b/python/ray/rllib/utils/filter_manager.py index 98b0471e95f8..e9ad5be02d28 100644 --- a/python/ray/rllib/utils/filter_manager.py +++ b/python/ray/rllib/utils/filter_manager.py @@ -20,8 +20,9 @@ def synchronize(local_filters, remotes): local_filters (dict): Filters to be synchronized. remotes (list): Remote evaluators with filters. 
""" - remote_filters = ray.get( - [r.get_filters.remote(flush_after=True) for r in remotes]) + remote_filters = ray.get([ + r.get_filters.remote(flush_after=True) for r in remotes + ]) for rf in remote_filters: for k in local_filters: local_filters[k].apply_changes(rf[k], with_buffer=False) diff --git a/python/ray/rllib/utils/process_rollout.py b/python/ray/rllib/utils/process_rollout.py index 2232135780d2..103638afd6be 100644 --- a/python/ray/rllib/utils/process_rollout.py +++ b/python/ray/rllib/utils/process_rollout.py @@ -33,7 +33,8 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True): if use_gae: assert "vf_preds" in rollout.data, "Values not found!" vpred_t = np.stack( - rollout.data["vf_preds"] + [np.array(rollout.last_r)]).squeeze() + rollout.data["vf_preds"] + [np.array(rollout.last_r)] + ).squeeze() delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1] # This formula for the advantage comes # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438 @@ -41,7 +42,8 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True): traj["value_targets"] = traj["advantages"] + traj["vf_preds"] else: rewards_plus_v = np.stack( - rollout.data["rewards"] + [np.array(rollout.last_r)]).squeeze() + rollout.data["rewards"] + [np.array(rollout.last_r)] + ).squeeze() traj["advantages"] = discount(rewards_plus_v, gamma)[:-1] for i in range(traj["advantages"].shape[0]): diff --git a/python/ray/rllib/utils/reshaper.py b/python/ray/rllib/utils/reshaper.py index c0687b488036..3b6f74e5b3f9 100644 --- a/python/ray/rllib/utils/reshaper.py +++ b/python/ray/rllib/utils/reshaper.py @@ -7,6 +7,7 @@ class Reshaper(object): This class keeps track of where in the flattened observation space we should be slicing and what the new shapes should be """ + def __init__(self, env_space): self.shapes = [] self.slice_positions = [] @@ -24,8 +25,9 @@ def __init__(self, env_space): if len(self.slice_positions) == 0: self.slice_positions.append(np.product(arr_shape)) else: - self.slice_positions.append(np.product(arr_shape) + - self.slice_positions[-1]) + self.slice_positions.append( + np.product(arr_shape) + self.slice_positions[-1] + ) else: self.shapes.append(np.asarray(env_space.shape)) self.slice_positions.append(np.product(env_space.shape)) @@ -38,11 +40,14 @@ def get_slice_lengths(self): def split_tensor(self, tensor, axis=-1): # FIXME (ev) This won't work for mixed action distributions like # one agent Gaussian one agent discrete - slice_rescale = int(tensor.shape.as_list()[axis] / - int(np.sum(self.get_slice_lengths()))) - return tf.split(tensor, slice_rescale*self.get_slice_lengths(), - axis=axis) + slice_rescale = int( + tensor.shape.as_list()[axis] / + int(np.sum(self.get_slice_lengths())) + ) + return tf.split( + tensor, slice_rescale * self.get_slice_lengths(), axis=axis + ) def split_number(self, number): slice_rescale = int(number / int(np.sum(self.get_slice_lengths()))) - return slice_rescale*self.get_slice_lengths() + return slice_rescale * self.get_slice_lengths() diff --git a/python/ray/rllib/utils/sampler.py b/python/ray/rllib/utils/sampler.py index c522806d6699..aac3ff1f80f5 100644 --- a/python/ray/rllib/utils/sampler.py +++ b/python/ray/rllib/utils/sampler.py @@ -58,7 +58,8 @@ def is_terminal(self): CompletedRollout = namedtuple( - "CompletedRollout", ["episode_length", "episode_reward"]) + "CompletedRollout", ["episode_length", "episode_reward"] +) class SyncSampler(object): @@ -71,8 +72,7 @@ class SyncSampler(object): thread.""" 
async = False - def __init__(self, env, policy, obs_filter, - num_local_steps, horizon=None): + def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None): self.num_local_steps = num_local_steps self.horizon = horizon self.env = env @@ -80,7 +80,8 @@ def __init__(self, env, policy, obs_filter, self._obs_filter = obs_filter self.rollout_provider = _env_runner( self.env, self.policy, self.num_local_steps, self.horizon, - self._obs_filter) + self._obs_filter + ) self.metrics_queue = queue.Queue() def get_data(self): @@ -108,10 +109,10 @@ class AsyncSampler(threading.Thread): accumulate and the gradient can be calculated on up to 5 batches.""" async = True - def __init__(self, env, policy, obs_filter, - num_local_steps, horizon=None): - assert getattr(obs_filter, "is_concurrent", False), ( - "Observation Filter must support concurrent updates.") + def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None): + assert getattr( + obs_filter, "is_concurrent", False + ), ("Observation Filter must support concurrent updates.") threading.Thread.__init__(self) self.queue = queue.Queue(5) self.metrics_queue = queue.Queue() @@ -133,8 +134,9 @@ def run(self): def _run(self): rollout_provider = _env_runner( - self.env, self.policy, self.num_local_steps, - self.horizon, self._obs_filter) + self.env, self.policy, self.num_local_steps, self.horizon, + self._obs_filter + ) while True: # The timeout variable exists because apparently, if one worker # dies, the other workers won't die with it, unless the timeout is @@ -232,13 +234,15 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter): action = np.concatenate(action, axis=0).flatten() # Collect the experience. - rollout.add(obs=last_observation, - actions=action, - rewards=reward, - dones=terminal, - features=last_features, - new_obs=observation, - **pi_info) + rollout.add( + obs=last_observation, + actions=action, + rewards=reward, + dones=terminal, + features=last_features, + new_obs=observation, + **pi_info + ) last_observation = observation last_features = features @@ -247,8 +251,10 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter): terminal_end = True yield CompletedRollout(length, rewards) - if (length >= horizon or - not env.metadata.get("semantics.autoreset")): + if ( + length >= horizon + or not env.metadata.get("semantics.autoreset") + ): last_observation = obs_filter(env.reset()) if hasattr(policy, "get_initial_features"): last_features = policy.get_initial_features() diff --git a/python/ray/rllib/utils/window_stat.py b/python/ray/rllib/utils/window_stat.py index ed1d99c46c87..d20c6476c7ce 100644 --- a/python/ray/rllib/utils/window_stat.py +++ b/python/ray/rllib/utils/window_stat.py @@ -23,7 +23,8 @@ def stats(self): quantiles = [] else: quantiles = np.percentile( - self.items[:self.count], [0, 10, 50, 90, 100]).tolist() + self.items[:self.count], [0, 10, 50, 90, 100] + ).tolist() return { self.name + "_count": int(self.count), self.name + "_mean": float(np.mean(self.items[:self.count])), diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index d06d4606890c..ddf4a2dab2ab 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -7,8 +7,9 @@ import subprocess import ray.services as services -from ray.autoscaler.commands import (create_or_update_cluster, - teardown_cluster, get_head_node_ip) +from ray.autoscaler.commands import ( + create_or_update_cluster, teardown_cluster, get_head_node_ip +) def 
check_no_existing_redis_clients(node_ip_address, redis_client): @@ -31,8 +32,10 @@ def check_no_existing_redis_clients(node_ip_address, redis_client): continue if info[b"node_ip_address"].decode("ascii") == node_ip_address: - raise Exception("This Redis instance is already connected to " - "clients with this IP address.") + raise Exception( + "This Redis instance is already connected to " + "clients with this IP address." + ) @click.group() @@ -45,112 +48,138 @@ def cli(): "--node-ip-address", required=False, type=str, - help="the IP address of this node") + help="the IP address of this node" +) @click.option( "--redis-address", required=False, type=str, - help="the address to use for connecting to Redis") + help="the address to use for connecting to Redis" +) @click.option( "--redis-port", required=False, type=str, - help="the port to use for starting Redis") + help="the port to use for starting Redis" +) @click.option( "--num-redis-shards", required=False, type=int, - help=("the number of additional Redis shards to use in " - "addition to the primary Redis shard")) + help=( + "the number of additional Redis shards to use in " + "addition to the primary Redis shard" + ) +) @click.option( "--redis-max-clients", required=False, type=int, - help=("If provided, attempt to configure Redis with this " - "maximum number of clients.")) + help=( + "If provided, attempt to configure Redis with this " + "maximum number of clients." + ) +) @click.option( "--redis-shard-ports", required=False, type=str, help="the port to use for the Redis shards other than the " - "primary Redis shard") + "primary Redis shard" +) @click.option( "--object-manager-port", required=False, type=int, - help="the port to use for starting the object manager") + help="the port to use for starting the object manager" +) @click.option( "--object-store-memory", required=False, type=int, help="the maximum amount of memory (in bytes) to allow the " - "object store to use") + "object store to use" +) @click.option( "--num-workers", required=False, type=int, - help=("The initial number of workers to start on this node, " - "note that the local scheduler may start additional " - "workers. If you wish to control the total number of " - "concurent tasks, then use --resources instead and " - "specify the CPU field.")) + help=( + "The initial number of workers to start on this node, " + "note that the local scheduler may start additional " + "workers. If you wish to control the total number of " + "concurent tasks, then use --resources instead and " + "specify the CPU field." 
+ ) +) @click.option( "--num-cpus", required=False, type=int, - help="the number of CPUs on this node") + help="the number of CPUs on this node" +) @click.option( "--num-gpus", required=False, type=int, - help="the number of GPUs on this node") + help="the number of GPUs on this node" +) @click.option( "--resources", required=False, default="{}", type=str, help="a JSON serialized dictionary mapping resource name to " - "resource quantity") + "resource quantity" +) @click.option( "--head", is_flag=True, default=False, - help="provide this argument for the head node") + help="provide this argument for the head node" +) @click.option( "--no-ui", is_flag=True, default=False, - help="provide this argument if the UI should not be started") + help="provide this argument if the UI should not be started" +) @click.option( "--block", is_flag=True, default=False, - help="provide this argument to block forever in this command") + help="provide this argument to block forever in this command" +) @click.option( "--plasma-directory", required=False, type=str, - help="object store directory for memory mapped files") + help="object store directory for memory mapped files" +) @click.option( "--huge-pages", is_flag=True, default=False, - help="enable support for huge pages in the object store") + help="enable support for huge pages in the object store" +) @click.option( "--autoscaling-config", required=False, type=str, - help="the file that contains the autoscaling config") + help="the file that contains the autoscaling config" +) @click.option( "--use-raylet", is_flag=True, default=False, - help="use the raylet code path, this is not supported yet") -def start(node_ip_address, redis_address, redis_port, num_redis_shards, - redis_max_clients, redis_shard_ports, object_manager_port, - object_store_memory, num_workers, num_cpus, num_gpus, resources, - head, no_ui, block, plasma_directory, huge_pages, autoscaling_config, - use_raylet): + help="use the raylet code path, this is not supported yet" +) +def start( + node_ip_address, redis_address, redis_port, num_redis_shards, + redis_max_clients, redis_shard_ports, object_manager_port, + object_store_memory, num_workers, num_cpus, num_gpus, resources, head, + no_ui, block, plasma_directory, huge_pages, autoscaling_config, use_raylet +): # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) @@ -160,10 +189,12 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, try: resources = json.loads(resources) except Exception as e: - raise Exception("Unable to parse the --resources argument using " - "json.loads. Try using a format like\n\n" - " --resources='{\"CustomResource1\": 3, " - "\"CustomReseource2\": 2}'") + raise Exception( + "Unable to parse the --resources argument using " + "json.loads. Try using a format like\n\n" + " --resources='{\"CustomResource1\": 3, " + "\"CustomReseource2\": 2}'" + ) assert "CPU" not in resources, "Use the --num-cpus argument." assert "GPU" not in resources, "Use the --num-gpus argument." @@ -182,16 +213,20 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, num_redis_shards = len(redis_shard_ports) # Check that the arguments match. 
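Apart from the new bracket placement, the --resources handling above is unchanged: the flag takes a JSON object parsed with json.loads, and CPU/GPU counts must go through their dedicated flags. A minimal sketch of what the parser accepts (the resource names here are made up):

    import json

    resources = json.loads('{"CustomResource1": 3, "CustomResource2": 2}')
    assert "CPU" not in resources, "Use the --num-cpus argument."
    assert "GPU" not in resources, "Use the --num-gpus argument."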
if len(redis_shard_ports) != num_redis_shards: - raise Exception("If --redis-shard-ports is provided, it must " - "have the form '6380,6381,6382', and the " - "number of ports provided must equal " - "--num-redis-shards (which is 1 if not " - "provided)") + raise Exception( + "If --redis-shard-ports is provided, it must " + "have the form '6380,6381,6382', and the " + "number of ports provided must equal " + "--num-redis-shards (which is 1 if not " + "provided)" + ) if redis_address is not None: - raise Exception("If --head is passed in, a Redis server will be " - "started, so a Redis address should not be " - "provided.") + raise Exception( + "If --head is passed in, a Redis server will be " + "started, so a Redis address should not be " + "provided." + ) # Get the node IP address if one is not provided. if node_ip_address is None: @@ -222,40 +257,56 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, - use_raylet=use_raylet) + use_raylet=use_raylet + ) print(address_info) - print("\nStarted Ray on this node. You can add additional nodes to " - "the cluster by calling\n\n" - " ray start --redis-address {}\n\n" - "from the node you wish to add. You can connect a driver to the " - "cluster from Python by running\n\n" - " import ray\n" - " ray.init(redis_address=\"{}\")\n\n" - "If you have trouble connecting from a different machine, check " - "that your firewall is configured properly. If you wish to " - "terminate the processes that have been started, run\n\n" - " ray stop".format(address_info["redis_address"], - address_info["redis_address"])) + print( + "\nStarted Ray on this node. You can add additional nodes to " + "the cluster by calling\n\n" + " ray start --redis-address {}\n\n" + "from the node you wish to add. You can connect a driver to the " + "cluster from Python by running\n\n" + " import ray\n" + " ray.init(redis_address=\"{}\")\n\n" + "If you have trouble connecting from a different machine, check " + "that your firewall is configured properly. If you wish to " + "terminate the processes that have been started, run\n\n" + " ray stop".format( + address_info["redis_address"], address_info["redis_address"] + ) + ) else: # Start Ray on a non-head node. if redis_port is not None: - raise Exception("If --head is not passed in, --redis-port is not " - "allowed") + raise Exception( + "If --head is not passed in, --redis-port is not " + "allowed" + ) if redis_shard_ports is not None: - raise Exception("If --head is not passed in, --redis-shard-ports " - "is not allowed") + raise Exception( + "If --head is not passed in, --redis-shard-ports " + "is not allowed" + ) if redis_address is None: - raise Exception("If --head is not passed in, --redis-address must " - "be provided.") + raise Exception( + "If --head is not passed in, --redis-address must " + "be provided." + ) if num_redis_shards is not None: - raise Exception("If --head is not passed in, --num-redis-shards " - "must not be provided.") + raise Exception( + "If --head is not passed in, --num-redis-shards " + "must not be provided." + ) if redis_max_clients is not None: - raise Exception("If --head is not passed in, --redis-max-clients " - "must not be provided.") + raise Exception( + "If --head is not passed in, --redis-max-clients " + "must not be provided." 
+ ) if no_ui: - raise Exception("If --head is not passed in, the --no-ui flag is " - "not relevant.") + raise Exception( + "If --head is not passed in, the --no-ui flag is " + "not relevant." + ) redis_ip_address, redis_port = redis_address.split(":") # Wait for the Redis server to be started. And throw an exception if we @@ -289,11 +340,14 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, resources=resources, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) + use_raylet=use_raylet + ) print(address_info) - print("\nStarted Ray on this node. If you wish to terminate the " - "processes that have been started, run\n\n" - " ray stop") + print( + "\nStarted Ray on this node. If you wish to terminate the " + "processes that have been started, run\n\n" + " ray stop" + ) if block: import time @@ -303,54 +357,50 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, @click.command() def stop(): - subprocess.call( - [ - "killall global_scheduler plasma_store plasma_manager " - "local_scheduler raylet raylet_monitor" - ], - shell=True) + subprocess.call([ + "killall global_scheduler plasma_store plasma_manager " + "local_scheduler raylet raylet_monitor" + ], + shell=True) # Find the PID of the monitor process and kill it. - subprocess.call( - [ - "kill $(ps aux | grep monitor.py | grep -v grep | " - "awk '{ print $2 }') 2> /dev/null" - ], - shell=True) + subprocess.call([ + "kill $(ps aux | grep monitor.py | grep -v grep | " + "awk '{ print $2 }') 2> /dev/null" + ], + shell=True) # Find the PID of the Redis process and kill it. - subprocess.call( - [ - "kill $(ps aux | grep redis-server | grep -v grep | " - "awk '{ print $2 }') 2> /dev/null" - ], - shell=True) + subprocess.call([ + "kill $(ps aux | grep redis-server | grep -v grep | " + "awk '{ print $2 }') 2> /dev/null" + ], + shell=True) # Find the PIDs of the worker processes and kill them. - subprocess.call( - [ - "kill -9 $(ps aux | grep default_worker.py | " - "grep -v grep | awk '{ print $2 }') 2> /dev/null" - ], - shell=True) + subprocess.call([ + "kill -9 $(ps aux | grep default_worker.py | " + "grep -v grep | awk '{ print $2 }') 2> /dev/null" + ], + shell=True) # Find the PID of the Ray log monitor process and kill it. - subprocess.call( - [ - "kill $(ps aux | grep log_monitor.py | grep -v grep | " - "awk '{ print $2 }') 2> /dev/null" - ], - shell=True) + subprocess.call([ + "kill $(ps aux | grep log_monitor.py | grep -v grep | " + "awk '{ print $2 }') 2> /dev/null" + ], + shell=True) # Find the PID of the jupyter process and kill it. try: from notebook.notebookapp import list_running_servers pids = [ - str(server["pid"]) for server in list_running_servers() + str(server["pid"]) + for server in list_running_servers() if "/tmp/raylogs" in server["notebook_dir"] ] - subprocess.call( - ["kill {} 2> /dev/null".format(" ".join(pids))], shell=True) + subprocess.call(["kill {} 2> /dev/null".format(" ".join(pids))], + shell=True) except ImportError: pass @@ -361,28 +411,36 @@ def stop(): "--no-restart", is_flag=True, default=False, - help=("Whether to skip restarting Ray services during the update. " - "This avoids interrupting running jobs.")) + help=( + "Whether to skip restarting Ray services during the update. " + "This avoids interrupting running jobs." 
+ ) +) @click.option( "--min-workers", required=False, type=int, - help=("Override the configured min worker node count for the cluster.")) + help=("Override the configured min worker node count for the cluster.") +) @click.option( "--max-workers", required=False, type=int, - help=("Override the configured max worker node count for the cluster.")) + help=("Override the configured max worker node count for the cluster.") +) @click.option( "--yes", "-y", is_flag=True, default=False, - help=("Don't ask for confirmation.")) -def create_or_update(cluster_config_file, min_workers, max_workers, no_restart, - yes): - create_or_update_cluster(cluster_config_file, min_workers, max_workers, - no_restart, yes) + help=("Don't ask for confirmation.") +) +def create_or_update( + cluster_config_file, min_workers, max_workers, no_restart, yes +): + create_or_update_cluster( + cluster_config_file, min_workers, max_workers, no_restart, yes + ) @click.command() @@ -392,7 +450,8 @@ def create_or_update(cluster_config_file, min_workers, max_workers, no_restart, "-y", is_flag=True, default=False, - help=("Don't ask for confirmation.")) + help=("Don't ask for confirmation.") +) def teardown(cluster_config_file, yes): teardown_cluster(cluster_config_file, yes) diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 0998888e8128..72b8ea83be42 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -27,32 +27,40 @@ def check_serializable(cls): # This case works. return if not hasattr(cls, "__new__"): - print("The class {} does not have a '__new__' attribute and is " - "probably an old-stye class. Please make it a new-style class " - "by inheriting from 'object'.") - raise RayNotDictionarySerializable("The class {} does not have a " - "'__new__' attribute and is " - "probably an old-style class. We " - "do not support this. Please make " - "it a new-style class by " - "inheriting from 'object'." - .format(cls)) + print( + "The class {} does not have a '__new__' attribute and is " + "probably an old-stye class. Please make it a new-style class " + "by inheriting from 'object'." + ) + raise RayNotDictionarySerializable( + "The class {} does not have a " + "'__new__' attribute and is " + "probably an old-style class. We " + "do not support this. Please make " + "it a new-style class by " + "inheriting from 'object'.".format(cls) + ) try: obj = cls.__new__(cls) except Exception: - raise RayNotDictionarySerializable("The class {} has overridden " - "'__new__', so Ray may not be able " - "to serialize it efficiently." 
- .format(cls)) + raise RayNotDictionarySerializable( + "The class {} has overridden " + "'__new__', so Ray may not be able " + "to serialize it efficiently.".format(cls) + ) if not hasattr(obj, "__dict__"): - raise RayNotDictionarySerializable("Objects of the class {} do not " - "have a '__dict__' attribute, so " - "Ray cannot serialize it " - "efficiently.".format(cls)) + raise RayNotDictionarySerializable( + "Objects of the class {} do not " + "have a '__dict__' attribute, so " + "Ray cannot serialize it " + "efficiently.".format(cls) + ) if hasattr(obj, "__slots__"): - raise RayNotDictionarySerializable("The class {} uses '__slots__', so " - "Ray may not be able to serialize " - "it efficiently.".format(cls)) + raise RayNotDictionarySerializable( + "The class {} uses '__slots__', so " + "Ray may not be able to serialize " + "it efficiently.".format(cls) + ) def is_named_tuple(cls): diff --git a/python/ray/services.py b/python/ray/services.py index 8cfd8aeee17f..3c5e0828bd9a 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -41,12 +41,13 @@ # important because it determines the order in which these processes will be # terminated when Ray exits, and certain orders will cause errors to be logged # to the screen. -all_processes = OrderedDict( - [(PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []), - (PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []), - (PROCESS_TYPE_LOCAL_SCHEDULER, []), (PROCESS_TYPE_PLASMA_MANAGER, []), - (PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_GLOBAL_SCHEDULER, []), - (PROCESS_TYPE_REDIS_SERVER, []), (PROCESS_TYPE_WEB_UI, [])], ) +all_processes = OrderedDict([ + (PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []), + (PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []), + (PROCESS_TYPE_LOCAL_SCHEDULER, []), (PROCESS_TYPE_PLASMA_MANAGER, []), + (PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_GLOBAL_SCHEDULER, []), + (PROCESS_TYPE_REDIS_SERVER, []), (PROCESS_TYPE_WEB_UI, []) +], ) # True if processes are run in the valgrind profiler. RUN_RAYLET_PROFILER = False @@ -57,36 +58,44 @@ # Location of the redis server and module. REDIS_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/common/thirdparty/redis/src/redis-server") + "core/src/common/thirdparty/redis/src/redis-server" +) REDIS_MODULE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/common/redis_module/libray_redis_module.so") + "core/src/common/redis_module/libray_redis_module.so" +) # Location of the credis server and modules. CREDIS_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/credis/redis/src/redis-server") + "core/src/credis/redis/src/redis-server" +) CREDIS_MASTER_MODULE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/credis/build/src/libmaster.so") + "core/src/credis/build/src/libmaster.so" +) CREDIS_MEMBER_MODULE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/credis/build/src/libmember.so") + "core/src/credis/build/src/libmember.so" +) # Location of the raylet executables. RAYLET_MONITOR_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/ray/raylet/raylet_monitor") + "core/src/ray/raylet/raylet_monitor" +) RAYLET_EXECUTABLE = os.path.join( - os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet") + os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet" +) # ObjectStoreAddress tuples contain all information necessary to connect to an # object store. 
The fields are: # - name: The socket name for the object store # - manager_name: The socket name for the object store manager # - manager_port: The Internet port that the object store manager listens on -ObjectStoreAddress = namedtuple("ObjectStoreAddress", - ["name", "manager_name", "manager_port"]) +ObjectStoreAddress = namedtuple( + "ObjectStoreAddress", ["name", "manager_name", "manager_port"] +) def address(ip_address, port): @@ -128,8 +137,8 @@ def kill_process(p): # The process has already terminated. return True if any([ - RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER, - RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER + RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER, + RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER ]): # Give process signal to write profiler data. os.kill(p.pid, signal.SIGINT) @@ -257,7 +266,8 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files): if log_file is not None: redis_ip_address, redis_port = redis_address.split(":") redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port + ) # The name of the key storing the list of log filenames for this IP # address. log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address) @@ -300,8 +310,11 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): while counter < num_retries: try: # Run some random command and see if it worked. - print("Waiting for redis server at {}:{} to respond...".format( - redis_ip_address, redis_port)) + print( + "Waiting for redis server at {}:{} to respond...".format( + redis_ip_address, redis_port + ) + ) redis_client.client_list() except redis.ConnectionError as e: # Wait a little bit. @@ -311,9 +324,11 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): else: break if counter == num_retries: - raise Exception("Unable to connect to Redis. If the Redis instance is " - "on a different machine, check that your firewall is " - "configured properly.") + raise Exception( + "Unable to connect to Redis. If the Redis instance is " + "on a different machine, check that your firewall is " + "configured properly." 
+ ) def _autodetect_num_gpus(): @@ -378,26 +393,30 @@ def check_version_info(redis_client): version_info = _compute_version_info() if version_info != true_version_info: node_ip_address = ray.services.get_node_ip_address() - error_message = ("Version mismatch: The cluster was started with:\n" - " Ray: " + true_version_info[0] + "\n" - " Python: " + true_version_info[1] + "\n" - " Pyarrow: " + str(true_version_info[2]) + "\n" - "This process on node " + node_ip_address + - " was started with:" + "\n" - " Ray: " + version_info[0] + "\n" - " Python: " + version_info[1] + "\n" - " Pyarrow: " + str(version_info[2])) + error_message = ( + "Version mismatch: The cluster was started with:\n" + " Ray: " + true_version_info[0] + "\n" + " Python: " + true_version_info[1] + "\n" + " Pyarrow: " + str(true_version_info[2]) + "\n" + "This process on node " + node_ip_address + " was started with:" + + "\n" + " Ray: " + version_info[0] + "\n" + " Python: " + version_info[1] + "\n" + " Pyarrow: " + str(version_info[2]) + ) if version_info[:2] != true_version_info[:2]: raise Exception(error_message) else: print(error_message) -def start_credis(node_ip_address, - redis_address, - port=None, - redirect_output=False, - cleanup=True): +def start_credis( + node_ip_address, + redis_address, + port=None, + redirect_output=False, + cleanup=True +): """Start the credis global state store. Credis is a chain replicated reliable redis store. It consists @@ -423,9 +442,7 @@ def start_credis(node_ip_address, """ components = ["credis_master", "credis_head", "credis_tail"] - modules = [ - CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE, CREDIS_MEMBER_MODULE - ] + modules = [CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE, CREDIS_MEMBER_MODULE] ports = [] for i, component in enumerate(components): @@ -438,7 +455,8 @@ def start_credis(node_ip_address, stderr_file=stderr_file, cleanup=cleanup, module=modules[i], - executable=CREDIS_EXECUTABLE) + executable=CREDIS_EXECUTABLE + ) ports.append(new_port) @@ -460,14 +478,16 @@ def start_credis(node_ip_address, return credis_address -def start_redis(node_ip_address, - port=None, - redis_shard_ports=None, - num_redis_shards=1, - redis_max_clients=None, - redirect_output=False, - redirect_worker_output=False, - cleanup=True): +def start_redis( + node_ip_address, + port=None, + redis_shard_ports=None, + num_redis_shards=1, + redis_max_clients=None, + redirect_output=False, + redirect_worker_output=False, + cleanup=True +): """Start the Redis global state store. Args: @@ -497,13 +517,16 @@ def start_redis(node_ip_address, addresses for the remaining shards. """ redis_stdout_file, redis_stderr_file = new_log_files( - "redis", redirect_output) + "redis", redirect_output + ) if redis_shard_ports is None: redis_shard_ports = num_redis_shards * [None] elif len(redis_shard_ports) != num_redis_shards: - raise Exception("The number of Redis shard ports does not match the " - "number of Redis shards.") + raise Exception( + "The number of Redis shard ports does not match the " + "number of Redis shards." 
+ ) assigned_port, _ = start_redis_instance( node_ip_address=node_ip_address, @@ -511,7 +534,8 @@ def start_redis(node_ip_address, redis_max_clients=redis_max_clients, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup) + cleanup=cleanup + ) if port is not None: assert assigned_port == port port = assigned_port @@ -534,14 +558,16 @@ def start_redis(node_ip_address, redis_shards = [] for i in range(num_redis_shards): redis_stdout_file, redis_stderr_file = new_log_files( - "redis-{}".format(i), redirect_output) + "redis-{}".format(i), redirect_output + ) redis_shard_port, _ = start_redis_instance( node_ip_address=node_ip_address, port=redis_shard_ports[i], redis_max_clients=redis_max_clients, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup) + cleanup=cleanup + ) if redis_shard_ports[i] is not None: assert redis_shard_port == redis_shard_ports[i] shard_address = address(node_ip_address, redis_shard_port) @@ -552,15 +578,17 @@ def start_redis(node_ip_address, return redis_address, redis_shards -def start_redis_instance(node_ip_address="127.0.0.1", - port=None, - redis_max_clients=None, - num_retries=20, - stdout_file=None, - stderr_file=None, - cleanup=True, - executable=REDIS_EXECUTABLE, - module=REDIS_MODULE): +def start_redis_instance( + node_ip_address="127.0.0.1", + port=None, + redis_max_clients=None, + num_retries=20, + stdout_file=None, + stderr_file=None, + cleanup=True, + executable=REDIS_EXECUTABLE, + module=REDIS_MODULE +): """Start a single Redis server. Args: @@ -601,13 +629,12 @@ def start_redis_instance(node_ip_address="127.0.0.1", while counter < num_retries: if counter > 0: print("Redis failed to start, retrying now.") - p = subprocess.Popen( - [ - executable, "--port", - str(port), "--loglevel", "warning", "--loadmodule", module - ], - stdout=stdout_file, - stderr=stderr_file) + p = subprocess.Popen([ + executable, "--port", + str(port), "--loglevel", "warning", "--loadmodule", module + ], + stdout=stdout_file, + stderr=stderr_file) time.sleep(0.1) # Check if Redis successfully started (or at least if it the executable # did not exit within 0.1 seconds). @@ -639,7 +666,8 @@ def start_redis_instance(node_ip_address="127.0.0.1", # We will use this to attempt to raise the maximum number of Redis # clients. current_max_clients = int( - redis_client.config_get("maxclients")["maxclients"]) + redis_client.config_get("maxclients")["maxclients"] + ) # The below command should be the same as doing ulimit -n. ulimit_n = resource.getrlimit(resource.RLIMIT_NOFILE)[0] # The quantity redis_client_buffer appears to be the required buffer @@ -648,33 +676,40 @@ def start_redis_instance(node_ip_address="127.0.0.1", # 10000 - redis_client_buffer. redis_client_buffer = 32 if current_max_clients < ulimit_n - redis_client_buffer: - redis_client.config_set("maxclients", - ulimit_n - redis_client_buffer) + redis_client.config_set( + "maxclients", ulimit_n - redis_client_buffer + ) # Increase the hard and soft limits for the redis client pubsub buffer to # 128MB. This is a hack to make it less likely for pubsub messages to be # dropped and for pubsub connections to therefore be killed. 
- cur_config = (redis_client.config_get("client-output-buffer-limit")[ - "client-output-buffer-limit"]) + cur_config = ( + redis_client.config_get("client-output-buffer-limit") + ["client-output-buffer-limit"] + ) cur_config_list = cur_config.split() assert len(cur_config_list) == 12 cur_config_list[8:] = ["pubsub", "134217728", "134217728", "60"] - redis_client.config_set("client-output-buffer-limit", - " ".join(cur_config_list)) + redis_client.config_set( + "client-output-buffer-limit", " ".join(cur_config_list) + ) # Put a time stamp in Redis to indicate when it was started. redis_client.set("redis_start_time", time.time()) # Record the log files in Redis. record_log_files_in_redis( address(node_ip_address, port), node_ip_address, - [stdout_file, stderr_file]) + [stdout_file, stderr_file] + ) return port, p -def start_log_monitor(redis_address, - node_ip_address, - stdout_file=None, - stderr_file=None, - cleanup=cleanup): +def start_log_monitor( + redis_address, + node_ip_address, + stdout_file=None, + stderr_file=None, + cleanup=cleanup +): """Start a log monitor process. Args: @@ -690,25 +725,29 @@ def start_log_monitor(redis_address, Python process that imported services exits. """ log_monitor_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "log_monitor.py") - p = subprocess.Popen( - [ - sys.executable, "-u", log_monitor_filepath, "--redis-address", - redis_address, "--node-ip-address", node_ip_address - ], - stdout=stdout_file, - stderr=stderr_file) + os.path.dirname(os.path.abspath(__file__)), "log_monitor.py" + ) + p = subprocess.Popen([ + sys.executable, "-u", log_monitor_filepath, "--redis-address", + redis_address, "--node-ip-address", node_ip_address + ], + stdout=stdout_file, + stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_LOG_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) - - -def start_global_scheduler(redis_address, - node_ip_address, - stdout_file=None, - stderr_file=None, - cleanup=True): + record_log_files_in_redis( + redis_address, node_ip_address, + [stdout_file, stderr_file] + ) + + +def start_global_scheduler( + redis_address, + node_ip_address, + stdout_file=None, + stderr_file=None, + cleanup=True +): """Start a global scheduler process. Args: @@ -727,11 +766,14 @@ def start_global_scheduler(redis_address, redis_address, node_ip_address, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file + ) if cleanup: all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, node_ip_address, + [stdout_file, stderr_file] + ) def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): @@ -749,7 +791,8 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): """ new_env = os.environ.copy() notebook_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb") + os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb" + ) # We copy the notebook file so that the original doesn't get modified by # the user. 
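For context on the client-output-buffer-limit hunk in start_redis_instance above: only the call wrapping changed; the logic still rewrites the pubsub entry of the 12-token limit string Redis reports. A standalone sketch, using Redis's default limits as the starting value purely for illustration:

    # Redis reports three (class, hard-limit, soft-limit, seconds) groups.
    cur_config = ("normal 0 0 0 "
                  "slave 268435456 67108864 60 "
                  "pubsub 33554432 8388608 60")
    cur_config_list = cur_config.split()
    assert len(cur_config_list) == 12
    # Raise the pubsub hard and soft limits to 128MB, keep the 60s window.
    cur_config_list[8:] = ["pubsub", "134217728", "134217728", "60"]
    new_limit = " ".join(cur_config_list)
    # redis_client.config_set("client-output-buffer-limit", new_limit)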
random_ui_id = random.randint(0, 100000) @@ -786,15 +829,20 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): env=new_env, cwd=new_notebook_directory, stdout=stdout_file, - stderr=stderr_file) + stderr=stderr_file + ) except Exception: - print("Failed to start the UI, you may need to run " - "'pip install jupyter'.") + print( + "Failed to start the UI, you may need to run " + "'pip install jupyter'." + ) else: if cleanup: all_processes[PROCESS_TYPE_WEB_UI].append(ui_process) - webui_url = ("http://localhost:{}/notebooks/ray_ui{}.ipynb?token={}" - .format(port, random_ui_id, token)) + webui_url = ( + "http://localhost:{}/notebooks/ray_ui{}.ipynb?token={}" + .format(port, random_ui_id, token) + ) print("\n" + "=" * 70) print("View the web UI at {}".format(webui_url)) print("=" * 70 + "\n") @@ -823,11 +871,16 @@ def check_and_update_resources(resources): # Check that the number of GPUs that the local scheduler wants doesn't # excede the amount allowed by CUDA_VISIBLE_DEVICES. - if ("GPU" in resources and gpu_ids is not None - and resources["GPU"] > len(gpu_ids)): - raise Exception("Attempting to start local scheduler with {} GPUs, " - "but CUDA_VISIBLE_DEVICES contains {}.".format( - resources["GPU"], gpu_ids)) + if ( + "GPU" in resources and gpu_ids is not None + and resources["GPU"] > len(gpu_ids) + ): + raise Exception( + "Attempting to start local scheduler with {} GPUs, " + "but CUDA_VISIBLE_DEVICES contains {}.".format( + resources["GPU"], gpu_ids + ) + ) if "GPU" not in resources: # Try to automatically detect the number of GPUs. @@ -838,23 +891,27 @@ def check_and_update_resources(resources): # Check types. for _, resource_quantity in resources.items(): - assert (isinstance(resource_quantity, int) - or isinstance(resource_quantity, float)) + assert ( + isinstance(resource_quantity, int) + or isinstance(resource_quantity, float) + ) return resources -def start_local_scheduler(redis_address, - node_ip_address, - plasma_store_name, - plasma_manager_name, - worker_path, - plasma_address=None, - stdout_file=None, - stderr_file=None, - cleanup=True, - resources=None, - num_workers=0): +def start_local_scheduler( + redis_address, + node_ip_address, + plasma_store_name, + plasma_manager_name, + worker_path, + plasma_address=None, + stdout_file=None, + stderr_file=None, + cleanup=True, + resources=None, + num_workers=0 +): """Start a local scheduler process. Args: @@ -884,8 +941,10 @@ def start_local_scheduler(redis_address, """ resources = check_and_update_resources(resources) - print("Starting local scheduler with the following resources: {}." - .format(resources)) + print( + "Starting local scheduler with the following resources: {}." 
+ .format(resources) + ) local_scheduler_name, p = ray.local_scheduler.start_local_scheduler( plasma_store_name, plasma_manager_name, @@ -897,23 +956,28 @@ def start_local_scheduler(redis_address, stdout_file=stdout_file, stderr_file=stderr_file, static_resources=resources, - num_workers=num_workers) + num_workers=num_workers + ) if cleanup: all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, node_ip_address, + [stdout_file, stderr_file] + ) return local_scheduler_name -def start_raylet(redis_address, - node_ip_address, - plasma_store_name, - worker_path, - resources=None, - num_workers=0, - stdout_file=None, - stderr_file=None, - cleanup=True): +def start_raylet( + redis_address, + node_ip_address, + plasma_store_name, + worker_path, + resources=None, + num_workers=0, + stdout_file=None, + stderr_file=None, + cleanup=True +): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -940,21 +1004,24 @@ def start_raylet(redis_address, # Format the resource argument in a form like 'CPU,1.0,GPU,0,Custom,3'. resource_argument = ",".join([ "{},{}".format(resource_name, resource_value) - for resource_name, resource_value in zip(static_resources.keys(), - static_resources.values()) + for resource_name, resource_value in + zip(static_resources.keys(), static_resources.values()) ]) gcs_ip_address, gcs_port = redis_address.split(":") raylet_name = "/tmp/raylet{}".format(random_name()) # Create the command that the Raylet will use to start workers. - start_worker_command = ("{} {} " - "--node-ip-address={} " - "--object-store-name={} " - "--raylet-name={} " - "--redis-address={}".format( - sys.executable, worker_path, node_ip_address, - plasma_store_name, raylet_name, redis_address)) + start_worker_command = ( + "{} {} " + "--node-ip-address={} " + "--object-store-name={} " + "--raylet-name={} " + "--redis-address={}".format( + sys.executable, worker_path, node_ip_address, plasma_store_name, + raylet_name, redis_address + ) + ) command = [ RAYLET_EXECUTABLE, @@ -971,24 +1038,28 @@ def start_raylet(redis_address, if cleanup: all_processes[PROCESS_TYPE_RAYLET].append(pid) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, node_ip_address, + [stdout_file, stderr_file] + ) return raylet_name -def start_objstore(node_ip_address, - redis_address, - object_manager_port=None, - store_stdout_file=None, - store_stderr_file=None, - manager_stdout_file=None, - manager_stderr_file=None, - objstore_memory=None, - cleanup=True, - plasma_directory=None, - huge_pages=False, - use_raylet=False): +def start_objstore( + node_ip_address, + redis_address, + object_manager_port=None, + store_stdout_file=None, + store_stderr_file=None, + manager_stdout_file=None, + manager_stderr_file=None, + objstore_memory=None, + cleanup=True, + plasma_directory=None, + huge_pages=False, + use_raylet=False +): """This method starts an object store process. Args: @@ -1041,12 +1112,14 @@ def start_objstore(node_ip_address, # blocks. shm_avail = shm_fs_stats.f_bsize * shm_fs_stats.f_bavail if objstore_memory > shm_avail: - print("Warning: Reducing object store memory because " - "/dev/shm has only {} bytes available. You may be " - "able to free up space by deleting files in " - "/dev/shm. 
If you are inside a Docker container, " - "you may need to pass an argument with the flag " - "'--shm-size' to 'docker run'.".format(shm_avail)) + print( + "Warning: Reducing object store memory because " + "/dev/shm has only {} bytes available. You may be " + "able to free up space by deleting files in " + "/dev/shm. If you are inside a Docker container, " + "you may need to pass an argument with the flag " + "'--shm-size' to 'docker run'.".format(shm_avail) + ) objstore_memory = int(shm_avail * 0.8) finally: os.close(shm_fd) @@ -1059,7 +1132,8 @@ def start_objstore(node_ip_address, stdout_file=store_stdout_file, stderr_file=store_stderr_file, plasma_directory=plasma_directory, - huge_pages=huge_pages) + huge_pages=huge_pages + ) # Start the plasma manager. if not use_raylet: if object_manager_port is not None: @@ -1072,7 +1146,8 @@ def start_objstore(node_ip_address, num_retries=1, run_profiler=RUN_PLASMA_MANAGER_PROFILER, stdout_file=manager_stdout_file, - stderr_file=manager_stderr_file) + stderr_file=manager_stderr_file + ) assert plasma_manager_port == object_manager_port else: (plasma_manager_name, p2, @@ -1082,34 +1157,42 @@ def start_objstore(node_ip_address, node_ip_address=node_ip_address, run_profiler=RUN_PLASMA_MANAGER_PROFILER, stdout_file=manager_stdout_file, - stderr_file=manager_stderr_file) + stderr_file=manager_stderr_file + ) else: plasma_manager_port = None plasma_manager_name = None if cleanup: all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1) - record_log_files_in_redis(redis_address, node_ip_address, - [store_stdout_file, store_stderr_file]) + record_log_files_in_redis( + redis_address, node_ip_address, + [store_stdout_file, store_stderr_file] + ) if not use_raylet: if cleanup: all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2) - record_log_files_in_redis(redis_address, node_ip_address, - [manager_stdout_file, manager_stderr_file]) - - return ObjectStoreAddress(plasma_store_name, plasma_manager_name, - plasma_manager_port) - - -def start_worker(node_ip_address, - object_store_name, - object_store_manager_name, - local_scheduler_name, - redis_address, - worker_path, - stdout_file=None, - stderr_file=None, - cleanup=True): + record_log_files_in_redis( + redis_address, node_ip_address, + [manager_stdout_file, manager_stderr_file] + ) + + return ObjectStoreAddress( + plasma_store_name, plasma_manager_name, plasma_manager_port + ) + + +def start_worker( + node_ip_address, + object_store_name, + object_store_manager_name, + local_scheduler_name, + redis_address, + worker_path, + stdout_file=None, + stderr_file=None, + cleanup=True +): """This method starts a worker process. Args: @@ -1141,16 +1224,20 @@ def start_worker(node_ip_address, p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_WORKER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) - - -def start_monitor(redis_address, - node_ip_address, - stdout_file=None, - stderr_file=None, - cleanup=True, - autoscaling_config=None): + record_log_files_in_redis( + redis_address, node_ip_address, + [stdout_file, stderr_file] + ) + + +def start_monitor( + redis_address, + node_ip_address, + stdout_file=None, + stderr_file=None, + cleanup=True, + autoscaling_config=None +): """Run a process to monitor the other processes. Args: @@ -1168,7 +1255,8 @@ def start_monitor(redis_address, autoscaling_config: path to autoscaling config file. 
""" monitor_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "monitor.py") + os.path.dirname(os.path.abspath(__file__)), "monitor.py" + ) command = [ sys.executable, "-u", monitor_path, "--redis-address=" + str(redis_address) @@ -1178,14 +1266,15 @@ def start_monitor(redis_address, p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, node_ip_address, + [stdout_file, stderr_file] + ) -def start_raylet_monitor(redis_address, - stdout_file=None, - stderr_file=None, - cleanup=True): +def start_raylet_monitor( + redis_address, stdout_file=None, stderr_file=None, cleanup=True +): """Run a process to monitor the other processes. Args: @@ -1206,28 +1295,30 @@ def start_raylet_monitor(redis_address, all_processes[PROCESS_TYPE_MONITOR].append(p) -def start_ray_processes(address_info=None, - node_ip_address="127.0.0.1", - redis_port=None, - redis_shard_ports=None, - num_workers=None, - num_local_schedulers=1, - object_store_memory=None, - num_redis_shards=1, - redis_max_clients=None, - worker_path=None, - cleanup=True, - redirect_worker_output=False, - redirect_output=False, - include_global_scheduler=False, - include_log_monitor=False, - include_webui=False, - start_workers_from_local_scheduler=True, - resources=None, - plasma_directory=None, - huge_pages=False, - autoscaling_config=None, - use_raylet=False): +def start_ray_processes( + address_info=None, + node_ip_address="127.0.0.1", + redis_port=None, + redis_shard_ports=None, + num_workers=None, + num_local_schedulers=1, + object_store_memory=None, + num_redis_shards=1, + redis_max_clients=None, + worker_path=None, + cleanup=True, + redirect_worker_output=False, + redirect_output=False, + include_global_scheduler=False, + include_log_monitor=False, + include_webui=False, + start_workers_from_local_scheduler=True, + resources=None, + plasma_directory=None, + huge_pages=False, + autoscaling_config=None, + use_raylet=False +): """Helper method to start Ray processes. Args: @@ -1300,8 +1391,9 @@ def start_ray_processes(address_info=None, workers_per_local_scheduler = [] for resource_dict in resources: cpus = resource_dict.get("CPU") - workers_per_local_scheduler.append(cpus if cpus is not None else - psutil.cpu_count()) + workers_per_local_scheduler.append( + cpus if cpus is not None else psutil.cpu_count() + ) if address_info is None: address_info = {} @@ -1310,7 +1402,8 @@ def start_ray_processes(address_info=None, if worker_path is None: worker_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), - "workers/default_worker.py") + "workers/default_worker.py" + ) # Start Redis if there isn't already an instance running. TODO(rkn): We are # suppressing the output of Redis because on Linux it prints a bunch of @@ -1327,36 +1420,40 @@ def start_ray_processes(address_info=None, redis_max_clients=redis_max_clients, redirect_output=True, redirect_worker_output=redirect_worker_output, - cleanup=cleanup) + cleanup=cleanup + ) address_info["redis_address"] = redis_address if "RAY_USE_NEW_GCS" in os.environ: credis_address = start_credis( - node_ip_address, redis_address, cleanup=cleanup) + node_ip_address, redis_address, cleanup=cleanup + ) address_info["credis_address"] = credis_address time.sleep(0.1) # Start monitoring the processes. 
monitor_stdout_file, monitor_stderr_file = new_log_files( - "monitor", redirect_output) + "monitor", redirect_output + ) start_monitor( redis_address, node_ip_address, stdout_file=monitor_stdout_file, stderr_file=monitor_stderr_file, cleanup=cleanup, - autoscaling_config=autoscaling_config) + autoscaling_config=autoscaling_config + ) if use_raylet: start_raylet_monitor( redis_address, stdout_file=monitor_stdout_file, stderr_file=monitor_stderr_file, - cleanup=cleanup) + cleanup=cleanup + ) if redis_shards == []: # Get redis shards from primary redis instance. redis_ip_address, redis_port = redis_address.split(":") - redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) redis_shards = [shard.decode("ascii") for shard in redis_shards] address_info["redis_shards"] = redis_shards @@ -1364,24 +1461,28 @@ def start_ray_processes(address_info=None, # Start the log monitor, if necessary. if include_log_monitor: log_monitor_stdout_file, log_monitor_stderr_file = new_log_files( - "log_monitor", redirect_output=True) + "log_monitor", redirect_output=True + ) start_log_monitor( redis_address, node_ip_address, stdout_file=log_monitor_stdout_file, stderr_file=log_monitor_stderr_file, - cleanup=cleanup) + cleanup=cleanup + ) # Start the global scheduler, if necessary. if include_global_scheduler and not use_raylet: global_scheduler_stdout_file, global_scheduler_stderr_file = ( - new_log_files("global_scheduler", redirect_output)) + new_log_files("global_scheduler", redirect_output) + ) start_global_scheduler( redis_address, node_ip_address, stdout_file=global_scheduler_stdout_file, stderr_file=global_scheduler_stderr_file, - cleanup=cleanup) + cleanup=cleanup + ) # Initialize with existing services. if "object_store_addresses" not in address_info: @@ -1395,8 +1496,10 @@ def start_ray_processes(address_info=None, raylet_socket_names = address_info["raylet_socket_names"] # Get the ports to use for the object managers if any are provided. - object_manager_ports = (address_info["object_manager_ports"] if - "object_manager_ports" in address_info else None) + object_manager_ports = ( + address_info["object_manager_ports"] + if "object_manager_ports" in address_info else None + ) if not isinstance(object_manager_ports, list): object_manager_ports = num_local_schedulers * [object_manager_ports] assert len(object_manager_ports) == num_local_schedulers @@ -1405,9 +1508,11 @@ def start_ray_processes(address_info=None, for i in range(num_local_schedulers - len(object_store_addresses)): # Start Plasma. plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( - "plasma_store_{}".format(i), redirect_output) + "plasma_store_{}".format(i), redirect_output + ) plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files( - "plasma_manager_{}".format(i), redirect_output) + "plasma_manager_{}".format(i), redirect_output + ) object_store_address = start_objstore( node_ip_address, redis_address, @@ -1420,19 +1525,20 @@ def start_ray_processes(address_info=None, cleanup=cleanup, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) + use_raylet=use_raylet + ) object_store_addresses.append(object_store_address) time.sleep(0.1) if not use_raylet: # Start any local schedulers that do not yet exist. 
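The workers_per_local_scheduler list consumed by the loops below is built earlier in start_ray_processes; the reformatted append keeps the same default of one worker per logical CPU whenever a resource dict does not pin a CPU count. Roughly (the resource dicts here are hypothetical):

    import psutil

    # Hypothetical per-local-scheduler resource dicts.
    resources = [{"CPU": 4}, {}]
    workers_per_local_scheduler = []
    for resource_dict in resources:
        cpus = resource_dict.get("CPU")
        workers_per_local_scheduler.append(
            cpus if cpus is not None else psutil.cpu_count()
        )
    # -> [4, <number of logical CPUs on this machine>]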
- for i in range( - len(local_scheduler_socket_names), num_local_schedulers): + for i in range(len(local_scheduler_socket_names), num_local_schedulers): # Connect the local scheduler to the object store at the same # index. object_store_address = object_store_addresses[i] - plasma_address = "{}:{}".format(node_ip_address, - object_store_address.manager_port) + plasma_address = "{}:{}".format( + node_ip_address, object_store_address.manager_port + ) # Determine how many workers this local scheduler should start. if start_workers_from_local_scheduler: num_local_scheduler_workers = workers_per_local_scheduler[i] @@ -1447,7 +1553,9 @@ def start_ray_processes(address_info=None, local_scheduler_stdout_file, local_scheduler_stderr_file = ( new_log_files( "local_scheduler_{}".format(i), - redirect_output=redirect_worker_output)) + redirect_output=redirect_worker_output + ) + ) local_scheduler_name = start_local_scheduler( redis_address, node_ip_address, @@ -1459,7 +1567,8 @@ def start_ray_processes(address_info=None, stderr_file=local_scheduler_stderr_file, cleanup=cleanup, resources=resources[i], - num_workers=num_local_scheduler_workers) + num_workers=num_local_scheduler_workers + ) local_scheduler_socket_names.append(local_scheduler_name) # Make sure that we have exactly num_local_schedulers instances of @@ -1471,7 +1580,8 @@ def start_ray_processes(address_info=None, # Start any raylets that do not exist yet. for i in range(len(raylet_socket_names), num_local_schedulers): raylet_stdout_file, raylet_stderr_file = new_log_files( - "raylet_{}".format(i), redirect_output=redirect_output) + "raylet_{}".format(i), redirect_output=redirect_output + ) address_info["raylet_socket_names"].append( start_raylet( redis_address, @@ -1482,17 +1592,21 @@ def start_ray_processes(address_info=None, num_workers=workers_per_local_scheduler[i], stdout_file=raylet_stdout_file, stderr_file=raylet_stderr_file, - cleanup=cleanup)) + cleanup=cleanup + ) + ) if not use_raylet: # Start any workers that the local scheduler has not already started. for i, num_local_scheduler_workers in enumerate( - workers_per_local_scheduler): + workers_per_local_scheduler + ): object_store_address = object_store_addresses[i] local_scheduler_name = local_scheduler_socket_names[i] for j in range(num_local_scheduler_workers): worker_stdout_file, worker_stderr_file = new_log_files( - "worker_{}_{}".format(i, j), redirect_output) + "worker_{}_{}".format(i, j), redirect_output + ) start_worker( node_ip_address, object_store_address.name, @@ -1502,7 +1616,8 @@ def start_ray_processes(address_info=None, worker_path, stdout_file=worker_stdout_file, stderr_file=worker_stderr_file, - cleanup=cleanup) + cleanup=cleanup + ) workers_per_local_scheduler[i] -= 1 # Make sure that we've started all the workers. @@ -1511,32 +1626,36 @@ def start_ray_processes(address_info=None, # Try to start the web UI. if include_webui: ui_stdout_file, ui_stderr_file = new_log_files( - "webui", redirect_output=True) + "webui", redirect_output=True + ) address_info["webui_url"] = start_ui( redis_address, stdout_file=ui_stdout_file, stderr_file=ui_stderr_file, - cleanup=cleanup) + cleanup=cleanup + ) else: address_info["webui_url"] = "" # Return the addresses of the relevant processes. 
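The address_info dict returned below aggregates everything started above. A rough sketch of the keys a caller can expect; the values are placeholders, and the exact contents depend on the flags used (raylet vs. legacy path, whether the web UI started, and so on):

    address_info = {
        "redis_address": "127.0.0.1:6379",          # placeholder
        "redis_shards": ["127.0.0.1:6380"],         # placeholder
        "object_store_addresses": [],               # ObjectStoreAddress namedtuples
        "local_scheduler_socket_names": [],         # legacy (non-raylet) path
        "raylet_socket_names": [],                  # raylet path
        "webui_url": "",
    }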
return address_info -def start_ray_node(node_ip_address, - redis_address, - object_manager_ports=None, - num_workers=0, - num_local_schedulers=1, - object_store_memory=None, - worker_path=None, - cleanup=True, - redirect_worker_output=False, - redirect_output=False, - resources=None, - plasma_directory=None, - huge_pages=False, - use_raylet=False): +def start_ray_node( + node_ip_address, + redis_address, + object_manager_ports=None, + num_workers=0, + num_local_schedulers=1, + object_store_memory=None, + worker_path=None, + cleanup=True, + redirect_worker_output=False, + redirect_output=False, + resources=None, + plasma_directory=None, + huge_pages=False, + use_raylet=False +): """Start the Ray processes for a single node. This assumes that the Ray processes on some master node have already been @@ -1594,29 +1713,32 @@ def start_ray_node(node_ip_address, resources=resources, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) - - -def start_ray_head(address_info=None, - node_ip_address="127.0.0.1", - redis_port=None, - redis_shard_ports=None, - num_workers=0, - num_local_schedulers=1, - object_store_memory=None, - worker_path=None, - cleanup=True, - redirect_worker_output=False, - redirect_output=False, - start_workers_from_local_scheduler=True, - resources=None, - num_redis_shards=None, - redis_max_clients=None, - include_webui=True, - plasma_directory=None, - huge_pages=False, - autoscaling_config=None, - use_raylet=False): + use_raylet=use_raylet + ) + + +def start_ray_head( + address_info=None, + node_ip_address="127.0.0.1", + redis_port=None, + redis_shard_ports=None, + num_workers=0, + num_local_schedulers=1, + object_store_memory=None, + worker_path=None, + cleanup=True, + redirect_worker_output=False, + redirect_output=False, + start_workers_from_local_scheduler=True, + resources=None, + num_redis_shards=None, + redis_max_clients=None, + include_webui=True, + plasma_directory=None, + huge_pages=False, + autoscaling_config=None, + use_raylet=False +): """Start Ray in local mode. Args: @@ -1693,7 +1815,8 @@ def start_ray_head(address_info=None, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, - use_raylet=use_raylet) + use_raylet=use_raylet + ) def try_to_create_directory(directory_path): @@ -1708,8 +1831,10 @@ def try_to_create_directory(directory_path): except OSError as e: if e.errno != os.errno.EEXIST: raise e - print("Attempted to create '{}', but the directory already " - "exists.".format(directory_path)) + print( + "Attempted to create '{}', but the directory already " + "exists.".format(directory_path) + ) # Change the log directory permissions so others can use it. This is # important when multiple people are using the same machine. os.chmod(directory_path, 0o0777) diff --git a/python/ray/signature.py b/python/ray/signature.py index c4ae60aa368f..292b9135f814 100644 --- a/python/ray/signature.py +++ b/python/ray/signature.py @@ -7,10 +7,12 @@ from ray.utils import is_cython -FunctionSignature = namedtuple("FunctionSignature", [ - "arg_names", "arg_defaults", "arg_is_positionals", "keyword_names", - "function_name" -]) +FunctionSignature = namedtuple( + "FunctionSignature", [ + "arg_names", "arg_defaults", "arg_is_positionals", "keyword_names", + "function_name" + ] +) """This class is used to represent a function signature. 
Attributes: @@ -61,8 +63,9 @@ def func(): for attr in attrs: setattr(func, attr, getattr(original_func, attr)) else: - raise TypeError("{0!r} is not a Python function we can process" - .format(func)) + raise TypeError( + "{0!r} is not a Python function we can process".format(func) + ) return list(funcsigs.signature(func).parameters.items()) @@ -97,8 +100,10 @@ def check_signature_supported(func, warn=False): has_keyword_arg = True if has_kwargs_param: - message = ("The function {} has a **kwargs argument, which is " - "currently not supported.".format(function_name)) + message = ( + "The function {} has a **kwargs argument, which is " + "currently not supported.".format(function_name) + ) if warn: print(message) else: @@ -106,9 +111,10 @@ def check_signature_supported(func, warn=False): # Check if the user specified a variable number of arguments and any # keyword arguments. if has_vararg_param and has_keyword_arg: - message = ("Function {} has a *args argument as well as a keyword " - "argument, which is currently not supported." - .format(function_name)) + message = ( + "Function {} has a *args argument as well as a keyword " + "argument, which is currently not supported.".format(function_name) + ) if warn: print(message) else: @@ -131,9 +137,10 @@ def extract_signature(func, ignore_first=False): if ignore_first: if len(sig_params) == 0: - raise Exception("Methods must take a 'self' argument, but the " - "method '{}' does not have one.".format( - func.__name__)) + raise Exception( + "Methods must take a 'self' argument, but the " + "method '{}' does not have one.".format(func.__name__) + ) sig_params = sig_params[1:] # Extract the names of the keyword arguments. @@ -151,8 +158,10 @@ def extract_signature(func, ignore_first=False): arg_defaults.append(parameter.default) arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL) - return FunctionSignature(arg_names, arg_defaults, arg_is_positionals, - keyword_names, func.__name__) + return FunctionSignature( + arg_names, arg_defaults, arg_is_positionals, keyword_names, + func.__name__ + ) def extend_args(function_signature, args, kwargs): @@ -184,9 +193,10 @@ def extend_args(function_signature, args, kwargs): for keyword_name in kwargs: if keyword_name not in keyword_names: - raise Exception("The name '{}' is not a valid keyword argument " - "for the function '{}'.".format( - keyword_name, function_name)) + raise Exception( + "The name '{}' is not a valid keyword argument " + "for the function '{}'.".format(keyword_name, function_name) + ) # Fill in the remaining arguments. zipped_info = list(zip(arg_names, arg_defaults, @@ -202,14 +212,20 @@ def extend_args(function_signature, args, kwargs): # the last argument and it is a *args argument in which case it # can be omitted. 
if not is_positional: - raise Exception("No value was provided for the argument " - "'{}' for the function '{}'.".format( - keyword_name, function_name)) - - too_many_arguments = (len(args) > len(arg_names) - and (len(arg_is_positionals) == 0 - or not arg_is_positionals[-1])) + raise Exception( + "No value was provided for the argument " + "'{}' for the function '{}'.".format( + keyword_name, function_name + ) + ) + + too_many_arguments = ( + len(args) > len(arg_names) + and (len(arg_is_positionals) == 0 or not arg_is_positionals[-1]) + ) if too_many_arguments: - raise Exception("Too many arguments were passed to the function '{}'" - .format(function_name)) + raise Exception( + "Too many arguments were passed to the function '{}'" + .format(function_name) + ) return args diff --git a/python/ray/test/test_utils.py b/python/ray/test/test_utils.py index 177a72519ec8..eb264954b54b 100644 --- a/python/ray/test/test_utils.py +++ b/python/ray/test/test_utils.py @@ -47,15 +47,17 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20): return if num_ready_nodes > num_nodes: # Too many nodes have joined. Something must be wrong. - raise Exception("{} nodes have joined the cluster, but we were " - "expecting {} nodes.".format( - num_ready_nodes, num_nodes)) + raise Exception( + "{} nodes have joined the cluster, but we were " + "expecting {} nodes.".format(num_ready_nodes, num_nodes) + ) time.sleep(0.1) # If we get here then we timed out. - raise Exception("Timed out while waiting for {} nodes to join. Only {} " - "nodes have joined so far.".format(num_ready_nodes, - num_nodes)) + raise Exception( + "Timed out while waiting for {} nodes to join. Only {} " + "nodes have joined so far.".format(num_ready_nodes, num_nodes) + ) def _broadcast_event(event_name, redis_address, data=None): @@ -99,8 +101,9 @@ def _wait_for_event(event_name, redis_address, extra_buffer=0): for event_info in event_infos: name, data = json.loads(event_info) if name in events: - raise Exception("The same event {} was broadcast twice." - .format(name)) + raise Exception( + "The same event {} was broadcast twice.".format(name) + ) events[name] = data if event_name in events: # Potentially sleep a little longer and then return the event data. diff --git a/python/ray/tune/async_hyperband.py b/python/ray/tune/async_hyperband.py index 3fafe0ace3d3..425fb38e3a99 100644 --- a/python/ray/tune/async_hyperband.py +++ b/python/ray/tune/async_hyperband.py @@ -35,13 +35,15 @@ class AsyncHyperBandScheduler(FIFOScheduler): halving rate, specified by the reduction factor. """ - def __init__(self, - time_attr='training_iteration', - reward_attr='episode_reward_mean', - max_t=100, - grace_period=10, - reduction_factor=3, - brackets=3): + def __init__( + self, + time_attr='training_iteration', + reward_attr='episode_reward_mean', + max_t=100, + grace_period=10, + reduction_factor=3, + brackets=3 + ): assert max_t > 0, "Max (time_attr) not valid!" assert max_t >= grace_period, "grace_period must be <= max_t!" assert grace_period > 0, "grace_period must be positive!" 
@@ -76,16 +78,20 @@ def on_trial_result(self, trial_runner, trial, result): action = TrialScheduler.STOP else: bracket = self._trial_info[trial.trial_id] - action = bracket.on_result(trial, getattr(result, self._time_attr), - getattr(result, self._reward_attr)) + action = bracket.on_result( + trial, getattr(result, self._time_attr), + getattr(result, self._reward_attr) + ) if action == TrialScheduler.STOP: self._num_stopped += 1 return action def on_trial_complete(self, trial_runner, trial, result): bracket = self._trial_info[trial.trial_id] - bracket.on_result(trial, getattr(result, self._time_attr), - getattr(result, self._reward_attr)) + bracket.on_result( + trial, getattr(result, self._time_attr), + getattr(result, self._reward_attr) + ) del self._trial_info[trial.trial_id] def on_trial_remove(self, trial_runner, trial): @@ -133,8 +139,10 @@ def on_result(self, trial, cur_iter, cur_rew): if cutoff is not None and cur_rew < cutoff: action = TrialScheduler.STOP if cur_rew is None: - print("Reward attribute is None! Consider" - " reporting using a different field.") + print( + "Reward attribute is None! Consider" + " reporting using a different field." + ) else: recorded[trial.trial_id] = cur_rew break @@ -150,7 +158,8 @@ def debug_str(self): if __name__ == '__main__': sched = AsyncHyperBandScheduler( - grace_period=1, max_t=10, reduction_factor=2) + grace_period=1, max_t=10, reduction_factor=2 + ) print(sched.debug_string()) bracket = sched._brackets[0] print(bracket.cutoff({str(i): i for i in range(20)})) diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 509d77c00473..7391fdc683b7 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -19,14 +19,18 @@ def json_to_resources(data): if k in ["driver_cpu_limit", "driver_gpu_limit"]: raise TuneError( "The field `{}` is no longer supported. Use `extra_cpu` " - "or `extra_gpu` instead.".format(k)) + "or `extra_gpu` instead.".format(k) + ) if k not in Resources._fields: raise TuneError( "Unknown resource type {}, must be one of {}".format( - k, Resources._fields)) + k, Resources._fields + ) + ) return Resources( data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0), - data.get("extra_gpu", 0)) + data.get("extra_gpu", 0) + ) def resources_to_json(resources): @@ -57,7 +61,8 @@ def make_parser(**kwargs): help="The algorithm or model to train. This may refer to the name " "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a " "user-defined trainable function or class registered in the " - "tune registry.") + "tune registry." + ) parser.add_argument( "--stop", default="{}", @@ -65,13 +70,15 @@ def make_parser(**kwargs): help="The stopping criteria, specified in JSON. The keys may be any " "field in TrainingResult, e.g. " "'{\"time_total_s\": 600, \"timesteps_total\": 100000}' to stop " - "after 600 seconds or 100k timesteps, whichever is reached first.") + "after 600 seconds or 100k timesteps, whichever is reached first." + ) parser.add_argument( "--config", default="{}", type=json.loads, help="Algorithm-specific configuration (e.g. env, hyperparams), " - "specified in JSON.") + "specified in JSON." + ) parser.add_argument( "--trial-resources", default=None, @@ -79,52 +86,61 @@ def make_parser(**kwargs): help="Override the machine resources to allocate per trial, e.g. " "'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned " "unless you specify them here. 
For RLlib, you probably want to " - "leave this alone and use RLlib configs to control parallelism.") + "leave this alone and use RLlib configs to control parallelism." + ) parser.add_argument( "--repeat", default=1, type=int, - help="Number of times to repeat each trial.") + help="Number of times to repeat each trial." + ) parser.add_argument( "--local-dir", default=DEFAULT_RESULTS_DIR, type=str, - help="Local dir to save training results to. Defaults to '{}'.".format( - DEFAULT_RESULTS_DIR)) + help="Local dir to save training results to. Defaults to '{}'.". + format(DEFAULT_RESULTS_DIR) + ) parser.add_argument( "--upload-dir", default="", type=str, - help="Optional URI to sync training results to (e.g. s3://bucket).") + help="Optional URI to sync training results to (e.g. s3://bucket)." + ) parser.add_argument( "--checkpoint-freq", default=0, type=int, help="How many training iterations between checkpoints. " - "A value of 0 (default) disables checkpointing.") + "A value of 0 (default) disables checkpointing." + ) parser.add_argument( "--max-failures", default=3, type=int, help="Try to recover a trial from its last checkpoint at least this " - "many times. Only applies if checkpointing is enabled.") + "many times. Only applies if checkpointing is enabled." + ) parser.add_argument( "--scheduler", default="FIFO", type=str, help="FIFO (default), MedianStopping, AsyncHyperBand, " - "HyperBand, or HyperOpt.") + "HyperBand, or HyperOpt." + ) parser.add_argument( "--scheduler-config", default="{}", type=json.loads, - help="Config options to pass to the scheduler.") + help="Config options to pass to the scheduler." + ) # Note: this currently only makes sense when running a single trial parser.add_argument( "--restore", default=None, type=str, - help="If specified, restore from this checkpoint.") + help="If specified, restore from this checkpoint." 
+ ) return parser diff --git a/python/ray/tune/examples/async_hyperband_example.py b/python/ray/tune/examples/async_hyperband_example.py index a37755935edc..5313dd1d8104 100644 --- a/python/ray/tune/examples/async_hyperband_example.py +++ b/python/ray/tune/examples/async_hyperband_example.py @@ -52,7 +52,8 @@ def _restore(self, checkpoint_path): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") + "--smoke-test", action="store_true", help="Finish quickly for testing" + ) args, _ = parser.parse_known_args() ray.init() @@ -63,24 +64,24 @@ def _restore(self, checkpoint_path): time_attr="timesteps_total", reward_attr="episode_reward_mean", grace_period=5, - max_t=100) - - run_experiments( - { - "asynchyperband_test": { - "run": "my_class", - "stop": { - "training_iteration": 1 if args.smoke_test else 99999 - }, - "repeat": 20, - "trial_resources": { - "cpu": 1, - "gpu": 0 - }, - "config": { - "width": lambda spec: 10 + int(90 * random.random()), - "height": lambda spec: int(100 * random.random()), - }, - } - }, - scheduler=ahb) + max_t=100 + ) + + run_experiments({ + "asynchyperband_test": { + "run": "my_class", + "stop": { + "training_iteration": 1 if args.smoke_test else 99999 + }, + "repeat": 20, + "trial_resources": { + "cpu": 1, + "gpu": 0 + }, + "config": { + "width": lambda spec: 10 + int(90 * random.random()), + "height": lambda spec: int(100 * random.random()), + }, + } + }, + scheduler=ahb) diff --git a/python/ray/tune/examples/hyperband_example.py b/python/ray/tune/examples/hyperband_example.py index 65410e80b320..de8a3d7f6a64 100755 --- a/python/ray/tune/examples/hyperband_example.py +++ b/python/ray/tune/examples/hyperband_example.py @@ -52,7 +52,8 @@ def _restore(self, checkpoint_path): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") + "--smoke-test", action="store_true", help="Finish quickly for testing" + ) args, _ = parser.parse_known_args() ray.init() @@ -61,7 +62,8 @@ def _restore(self, checkpoint_path): hyperband = HyperBandScheduler( time_attr="timesteps_total", reward_attr="episode_reward_mean", - max_t=100) + max_t=100 + ) exp = Experiment( name="hyperband_test", @@ -71,6 +73,7 @@ def _restore(self, checkpoint_path): config={ "width": lambda spec: 10 + int(90 * random.random()), "height": lambda spec: int(100 * random.random()) - }) + } + ) run_experiments(exp, scheduler=hyperband) diff --git a/python/ray/tune/examples/hyperopt_example.py b/python/ray/tune/examples/hyperopt_example.py index 6104d39ec908..070807dca95b 100644 --- a/python/ray/tune/examples/hyperopt_example.py +++ b/python/ray/tune/examples/hyperopt_example.py @@ -12,8 +12,9 @@ def easy_objective(config, reporter): time.sleep(0.2) reporter( timesteps_total=1, - episode_reward_mean=-( - (config["height"] - 14)**2 + abs(config["width"] - 3))) + episode_reward_mean=-((config["height"] - 14)**2 + + abs(config["width"] - 3)) + ) time.sleep(0.2) @@ -23,7 +24,8 @@ def easy_objective(config, reporter): parser = argparse.ArgumentParser() parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") + "--smoke-test", action="store_true", help="Finish quickly for testing" + ) args, _ = parser.parse_known_args() ray.init(redirect_output=True) diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index e63a5e5421fc..03a5462fca4d 100755 
--- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -29,7 +29,8 @@ def _train(self): # Reward increase is parabolic as a function of factor_2, with a # maxima around factor_1=10.0. self.current_value += max( - 0.0, random.gauss(5.0 - (self.config["factor_1"] - 10.0)**2, 2.0)) + 0.0, random.gauss(5.0 - (self.config["factor_1"] - 10.0)**2, 2.0) + ) # Flat increase by factor_2 self.current_value += random.gauss(self.config["factor_2"], 1.0) @@ -37,7 +38,8 @@ def _train(self): # Here we use `episode_reward_mean`, but you can also report other # objectives such as loss or accuracy (see tune/result.py). return TrainingResult( - episode_reward_mean=self.current_value, timesteps_this_iter=1) + episode_reward_mean=self.current_value, timesteps_this_iter=1 + ) def _save(self, checkpoint_dir): path = os.path.join(checkpoint_dir, "checkpoint") @@ -46,7 +48,8 @@ def _save(self, checkpoint_dir): json.dumps({ "timestep": self.timestep, "value": self.current_value - })) + }) + ) return path def _restore(self, checkpoint_path): @@ -61,7 +64,8 @@ def _restore(self, checkpoint_path): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") + "--smoke-test", action="store_true", help="Finish quickly for testing" + ) args, _ = parser.parse_known_args() ray.init() @@ -75,22 +79,22 @@ def _restore(self, checkpoint_path): "factor_1": lambda: random.uniform(0.0, 20.0), # Allow perturbations within this set of categorical values. "factor_2": [1, 2], - }) + } + ) # Try to find the best factor 1 and factor 2 - run_experiments( - { - "pbt_test": { - "run": "my_class", - "stop": { - "training_iteration": 2 if args.smoke_test else 99999 - }, - "repeat": 10, - "config": { - "factor_1": 4.0, - "factor_2": 1.0, - }, - } - }, - scheduler=pbt, - verbose=False) + run_experiments({ + "pbt_test": { + "run": "my_class", + "stop": { + "training_iteration": 2 if args.smoke_test else 99999 + }, + "repeat": 10, + "config": { + "factor_1": 4.0, + "factor_2": 1.0, + }, + } + }, + scheduler=pbt, + verbose=False) diff --git a/python/ray/tune/examples/pbt_ppo_example.py b/python/ray/tune/examples/pbt_ppo_example.py index 9914999868e8..ee4e7e0f4897 100755 --- a/python/ray/tune/examples/pbt_ppo_example.py +++ b/python/ray/tune/examples/pbt_ppo_example.py @@ -42,7 +42,8 @@ def explore(config): "sgd_batchsize": lambda: random.randint(128, 16384), "timesteps_per_batch": lambda: random.randint(2000, 160000), }, - custom_explore_fn=explore) + custom_explore_fn=explore + ) ray.init() run_experiments( @@ -77,4 +78,5 @@ def explore(config): }, }, }, - scheduler=pbt) + scheduler=pbt + ) diff --git a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py index fc02ad8120bc..ecbf01c90f4d 100755 --- a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py +++ b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py @@ -57,14 +57,20 @@ def _build_model(self, input_shape): strides=1, padding="same", activation="relu", - kernel_initializer="he_normal")(y) + kernel_initializer="he_normal" + )( + y + ) y = Convolution2D( filters=64, kernel_size=3, strides=1, padding="same", activation="relu", - kernel_initializer="he_normal")(y) + kernel_initializer="he_normal" + )( + y + ) y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y) y = Convolution2D( @@ -73,14 +79,20 @@ def _build_model(self, input_shape): strides=1, padding="same", activation="relu", - 
kernel_initializer="he_normal")(y) + kernel_initializer="he_normal" + )( + y + ) y = Convolution2D( filters=128, kernel_size=3, strides=1, padding="same", activation="relu", - kernel_initializer="he_normal")(y) + kernel_initializer="he_normal" + )( + y + ) y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y) y = Convolution2D( @@ -89,20 +101,29 @@ def _build_model(self, input_shape): strides=1, padding="same", activation="relu", - kernel_initializer="he_normal")(y) + kernel_initializer="he_normal" + )( + y + ) y = Convolution2D( filters=256, kernel_size=3, strides=1, padding="same", activation="relu", - kernel_initializer="he_normal")(y) + kernel_initializer="he_normal" + )( + y + ) y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y) y = Flatten()(y) y = Dropout(self.config["dropout"])(y) y = Dense( - units=10, activation="softmax", kernel_initializer="he_normal")(y) + units=10, activation="softmax", kernel_initializer="he_normal" + )( + y + ) model = Model(inputs=x, outputs=y, name="model1") return model @@ -116,7 +137,8 @@ def _setup(self): model.compile( loss="categorical_crossentropy", optimizer=opt, - metrics=["accuracy"]) + metrics=["accuracy"] + ) self.model = model def _train(self): @@ -148,12 +170,14 @@ def _train(self): aug_gen.fit(x_train) gen = aug_gen.flow( - x_train, y_train, batch_size=self.config["batch_size"]) + x_train, y_train, batch_size=self.config["batch_size"] + ) self.model.fit_generator( generator=gen, steps_per_epoch=50000 // self.config["batch_size"], epochs=self.config["epochs"], - validation_data=None) + validation_data=None + ) # loss, accuracy _, accuracy = self.model.evaluate(x_test, y_test, verbose=0) @@ -177,7 +201,8 @@ def _stop(self): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") + "--smoke-test", action="store_true", help="Finish quickly for testing" + ) args, _ = parser.parse_known_args() register_trainable("train_cifar10", Cifar10Model) @@ -213,6 +238,7 @@ def _stop(self): perturbation_interval=10, hyperparam_mutations={ "dropout": lambda _: np.random.uniform(0, 1), - }) + } + ) run_experiments({"pbt_cifar10": train_spec}, scheduler=pbt) diff --git a/python/ray/tune/examples/tune_mnist_async_hyperband.py b/python/ray/tune/examples/tune_mnist_async_hyperband.py index 6e37fc234230..bcf56facef3d 100755 --- a/python/ray/tune/examples/tune_mnist_async_hyperband.py +++ b/python/ray/tune/examples/tune_mnist_async_hyperband.py @@ -115,7 +115,8 @@ def conv2d(x, W): def max_pool_2x2(x): """max_pool_2x2 downsamples a feature map by 2X.""" return tf.nn.max_pool( - x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') + x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME' + ) def weight_variable(shape): @@ -150,7 +151,8 @@ def main(_): with tf.name_scope('loss'): cross_entropy = tf.nn.softmax_cross_entropy_with_logits( - labels=y_, logits=y_conv) + labels=y_, logits=y_conv + ) cross_entropy = tf.reduce_mean(cross_entropy) with tf.name_scope('adam_optimizer'): @@ -171,29 +173,38 @@ def main(_): for i in range(20000): batch = mnist.train.next_batch(50) if i % 10 == 0: - train_accuracy = accuracy.eval(feed_dict={ - x: batch[0], - y_: batch[1], - keep_prob: 1.0 - }) + train_accuracy = accuracy.eval( + feed_dict={ + x: batch[0], + y_: batch[1], + keep_prob: 1.0 + } + ) # !!! Report status to ray.tune !!! 
if status_reporter: status_reporter( - timesteps_total=i, mean_accuracy=train_accuracy) + timesteps_total=i, mean_accuracy=train_accuracy + ) print('step %d, training accuracy %g' % (i, train_accuracy)) - train_step.run(feed_dict={ - x: batch[0], - y_: batch[1], - keep_prob: 0.5 - }) - - print('test accuracy %g' % accuracy.eval(feed_dict={ - x: mnist.test.images, - y_: mnist.test.labels, - keep_prob: 1.0 - })) + train_step.run( + feed_dict={ + x: batch[0], + y_: batch[1], + keep_prob: 0.5 + } + ) + + print( + 'test accuracy %g' % accuracy.eval( + feed_dict={ + x: mnist.test.images, + y_: mnist.test.labels, + keep_prob: 1.0 + } + ) + ) # !!! Entrypoint for ray.tune !!! @@ -206,7 +217,8 @@ def train(config={'activation': 'relu'}, reporter=None): '--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') + help='Directory for storing input data' + ) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) @@ -215,7 +227,8 @@ def train(config={'activation': 'relu'}, reporter=None): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - '--smoke-test', action='store_true', help='Finish quickly for testing') + '--smoke-test', action='store_true', help='Finish quickly for testing' + ) args, _ = parser.parse_known_args() register_trainable('train_mnist', train) @@ -238,12 +251,11 @@ def train(config={'activation': 'relu'}, reporter=None): ray.init() from ray.tune.async_hyperband import AsyncHyperBandScheduler - run_experiments( - { - 'tune_mnist_test': mnist_spec - }, - scheduler=AsyncHyperBandScheduler( - time_attr="timesteps_total", - reward_attr="mean_accuracy", - max_t=600, - )) + run_experiments({ + 'tune_mnist_test': mnist_spec + }, + scheduler=AsyncHyperBandScheduler( + time_attr="timesteps_total", + reward_attr="mean_accuracy", + max_t=600, + )) diff --git a/python/ray/tune/examples/tune_mnist_ray.py b/python/ray/tune/examples/tune_mnist_ray.py index 176bffbc0052..950cedd6418a 100755 --- a/python/ray/tune/examples/tune_mnist_ray.py +++ b/python/ray/tune/examples/tune_mnist_ray.py @@ -115,7 +115,8 @@ def conv2d(x, W): def max_pool_2x2(x): """max_pool_2x2 downsamples a feature map by 2X.""" return tf.nn.max_pool( - x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') + x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME' + ) def weight_variable(shape): @@ -150,7 +151,8 @@ def main(_): with tf.name_scope('loss'): cross_entropy = tf.nn.softmax_cross_entropy_with_logits( - labels=y_, logits=y_conv) + labels=y_, logits=y_conv + ) cross_entropy = tf.reduce_mean(cross_entropy) with tf.name_scope('adam_optimizer'): @@ -171,29 +173,38 @@ def main(_): for i in range(20000): batch = mnist.train.next_batch(50) if i % 10 == 0: - train_accuracy = accuracy.eval(feed_dict={ - x: batch[0], - y_: batch[1], - keep_prob: 1.0 - }) + train_accuracy = accuracy.eval( + feed_dict={ + x: batch[0], + y_: batch[1], + keep_prob: 1.0 + } + ) # !!! Report status to ray.tune !!! 
if status_reporter: status_reporter( - timesteps_total=i, mean_accuracy=train_accuracy) + timesteps_total=i, mean_accuracy=train_accuracy + ) print('step %d, training accuracy %g' % (i, train_accuracy)) - train_step.run(feed_dict={ - x: batch[0], - y_: batch[1], - keep_prob: 0.5 - }) - - print('test accuracy %g' % accuracy.eval(feed_dict={ - x: mnist.test.images, - y_: mnist.test.labels, - keep_prob: 1.0 - })) + train_step.run( + feed_dict={ + x: batch[0], + y_: batch[1], + keep_prob: 0.5 + } + ) + + print( + 'test accuracy %g' % accuracy.eval( + feed_dict={ + x: mnist.test.images, + y_: mnist.test.labels, + keep_prob: 1.0 + } + ) + ) # !!! Entrypoint for ray.tune !!! @@ -206,7 +217,8 @@ def train(config={'activation': 'relu'}, reporter=None): '--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') + help='Directory for storing input data' + ) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) @@ -215,7 +227,8 @@ def train(config={'activation': 'relu'}, reporter=None): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - '--smoke-test', action='store_true', help='Finish quickly for testing') + '--smoke-test', action='store_true', help='Finish quickly for testing' + ) args, _ = parser.parse_known_args() register_trainable('train_mnist', train) diff --git a/python/ray/tune/examples/tune_mnist_ray_hyperband.py b/python/ray/tune/examples/tune_mnist_ray_hyperband.py index c320f8e5c666..807d1874b085 100755 --- a/python/ray/tune/examples/tune_mnist_ray_hyperband.py +++ b/python/ray/tune/examples/tune_mnist_ray_hyperband.py @@ -110,7 +110,8 @@ def conv2d(x, W): def max_pool_2x2(x): """max_pool_2x2 downsamples a feature map by 2X.""" return tf.nn.max_pool( - x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') + x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME' + ) def weight_variable(shape): @@ -137,7 +138,8 @@ def _setup(self): for _ in range(10): try: self.mnist = input_data.read_data_sets( - "/tmp/mnist_ray_demo", one_hot=True) + "/tmp/mnist_ray_demo", one_hot=True + ) break except Exception as e: print("Error loading data, retrying", e) @@ -155,18 +157,20 @@ def _setup(self): with tf.name_scope('loss'): cross_entropy = tf.nn.softmax_cross_entropy_with_logits( - labels=self.y_, logits=y_conv) + labels=self.y_, logits=y_conv + ) cross_entropy = tf.reduce_mean(cross_entropy) with tf.name_scope('adam_optimizer'): - train_step = tf.train.AdamOptimizer( - self.config['learning_rate']).minimize(cross_entropy) + train_step = tf.train.AdamOptimizer(self.config['learning_rate'] + ).minimize(cross_entropy) self.train_step = train_step with tf.name_scope('accuracy'): correct_prediction = tf.equal( - tf.argmax(y_conv, 1), tf.argmax(self.y_, 1)) + tf.argmax(y_conv, 1), tf.argmax(self.y_, 1) + ) correct_prediction = tf.cast(correct_prediction, tf.float32) self.accuracy = tf.reduce_mean(correct_prediction) @@ -184,7 +188,8 @@ def _train(self): self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5 - }) + } + ) batch = self.mnist.train.next_batch(50) train_accuracy = self.sess.run( @@ -193,15 +198,18 @@ def _train(self): self.x: batch[0], self.y_: batch[1], self.keep_prob: 1.0 - }) + } + ) self.iterations += 1 return TrainingResult( - timesteps_this_iter=10, mean_accuracy=train_accuracy) + timesteps_this_iter=10, mean_accuracy=train_accuracy + ) def _save(self, checkpoint_dir): return self.saver.save( - self.sess, checkpoint_dir + "/save", 
global_step=self.iterations) + self.sess, checkpoint_dir + "/save", global_step=self.iterations + ) def _restore(self, path): return self.saver.restore(self.sess, path) @@ -211,7 +219,8 @@ def _restore(self, path): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - '--smoke-test', action='store_true', help='Finish quickly for testing') + '--smoke-test', action='store_true', help='Finish quickly for testing' + ) args, _ = parser.parse_known_args() register_trainable("my_class", TrainMNIST) @@ -234,6 +243,7 @@ def _restore(self, path): ray.init() hyperband = HyperBandScheduler( - time_attr="timesteps_total", reward_attr="mean_accuracy", max_t=100) + time_attr="timesteps_total", reward_attr="mean_accuracy", max_t=100 + ) run_experiments({'mnist_hyperband_test': mnist_spec}, scheduler=hyperband) diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 09102349440b..b9912e7b9278 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -36,17 +36,19 @@ class Experiment(object): checkpointing is enabled. Defaults to 3. """ - def __init__(self, - name, - run, - stop=None, - config=None, - trial_resources=None, - repeat=1, - local_dir=None, - upload_dir="", - checkpoint_freq=0, - max_failures=3): + def __init__( + self, + name, + run, + stop=None, + config=None, + trial_resources=None, + repeat=1, + local_dir=None, + upload_dir="", + checkpoint_freq=0, + max_failures=3 + ): spec = { "run": run, "stop": stop or {}, diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index b4fff276b371..ba4e0fe727c9 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -92,8 +92,9 @@ def _setup(self): for k in self._default_config: if k in scrubbed_config: del scrubbed_config[k] - self._runner = _RunnerThread(entrypoint, scrubbed_config, - self._status_reporter) + self._runner = _RunnerThread( + entrypoint, scrubbed_config, self._status_reporter + ) self._start_time = time.time() self._last_reported_timestep = 0 self._runner.start() @@ -105,8 +106,11 @@ def _trainable_func(self): def _train(self): time.sleep( - self.config.get("script_min_iter_time_s", - self._default_config["script_min_iter_time_s"])) + self.config.get( + "script_min_iter_time_s", + self._default_config["script_min_iter_time_s"] + ) + ) result = self._status_reporter._get_and_clear_status() while result is None: _serve_get_pin_requests() @@ -117,7 +121,9 @@ def _train(self): result = result._replace( timesteps_this_iter=( - result.timesteps_total - self._last_reported_timestep)) + result.timesteps_total - self._last_reported_timestep + ) + ) self._last_reported_timestep = result.timesteps_total return result diff --git a/python/ray/tune/hpo_scheduler.py b/python/ray/tune/hpo_scheduler.py index 017f89ba2363..0e18722829b5 100644 --- a/python/ray/tune/hpo_scheduler.py +++ b/python/ray/tune/hpo_scheduler.py @@ -102,8 +102,10 @@ def _trial_generator(self): self._hpopt_trials.refresh() # Get new suggestion from - new_trials = self.algo(new_ids, self.domain, self._hpopt_trials, - self.rstate.randint(2**31 - 1)) + new_trials = self.algo( + new_ids, self.domain, self._hpopt_trials, + self.rstate.randint(2**31 - 1) + ) self._hpopt_trials.insert_trial_docs(new_trials) self._hpopt_trials.refresh() new_trial = new_trials[0] @@ -129,7 +131,8 @@ def _trial_generator(self): checkpoint_freq=self.args.checkpoint_freq, restore_path=self.args.restore, upload_dir=self.args.upload_dir, - 
max_failures=self.args.max_failures) + max_failures=self.args.max_failures + ) self._tune_to_hp[trial] = new_trial_id self._num_trials_left -= 1 diff --git a/python/ray/tune/hyperband.py b/python/ray/tune/hyperband.py index bbc4fa0077bc..a19a6f42ae00 100644 --- a/python/ray/tune/hyperband.py +++ b/python/ray/tune/hyperband.py @@ -66,10 +66,12 @@ class HyperBandScheduler(FIFOScheduler): mentioned in the original HyperBand paper. """ - def __init__(self, - time_attr='training_iteration', - reward_attr='episode_reward_mean', - max_t=81): + def __init__( + self, + time_attr='training_iteration', + reward_attr='episode_reward_mean', + max_t=81 + ): assert max_t > 0, "Max (time_attr) not valid!" FIFOScheduler.__init__(self) self._eta = 3 @@ -116,9 +118,10 @@ def on_trial_add(self, trial_runner, trial): cur_bracket = None else: retry = False - cur_bracket = Bracket(self._time_attr, self._get_n0(s), - self._get_r0(s), self._max_t_attr, - self._eta, s) + cur_bracket = Bracket( + self._time_attr, self._get_n0(s), self._get_r0(s), + self._max_t_attr, self._eta, s + ) cur_band.append(cur_bracket) self._state["bracket"] = cur_bracket @@ -218,10 +221,13 @@ def choose_trial_to_run(self, trial_runner): for hyperband in self._hyperbands: for bracket in sorted( - hyperband, key=lambda b: b.completion_percentage()): + hyperband, key=lambda b: b.completion_percentage() + ): for trial in bracket.current_trials(): - if (trial.status == Trial.PENDING - and trial_runner.has_resources(trial.resources)): + if ( + trial.status == Trial.PENDING + and trial_runner.has_resources(trial.resources) + ): return trial return None @@ -244,7 +250,8 @@ def debug_string(self): """ out = "Using HyperBand: " out += "num_stopped={} total_brackets={}".format( - self._num_stopped, sum(len(band) for band in self._hyperbands)) + self._num_stopped, sum(len(band) for band in self._hyperbands) + ) for i, band in enumerate(self._hyperbands): out += "\nRound #{}:".format(i) for bracket in band: @@ -290,7 +297,8 @@ def cur_iter_done(self): TODO(rliaw): also check that `t.iterations == self._r`""" return all( self._get_result_time(result) >= self._cumul_r - for result in self._live_trials.values()) + for result in self._live_trials.values() + ) def finished(self): return self._halves == 0 and self.cur_iter_done() @@ -324,7 +332,8 @@ def successive_halving(self, reward_attr): self._cumul_r += self._r sorted_trials = sorted( self._live_trials, - key=lambda t: getattr(self._live_trials[t], reward_attr)) + key=lambda t: getattr(self._live_trials[t], reward_attr) + ) good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n] return good, bad @@ -391,11 +400,12 @@ def _calculate_total_work(self, n, r, s): def __repr__(self): status = ", ".join([ - "Max Size (n)={}".format(self._n), "Milestone (r)={}".format( - self._cumul_r), "completed={:.1%}".format( - self.completion_percentage()) + "Max Size (n)={}".format(self._n), + "Milestone (r)={}".format(self._cumul_r), + "completed={:.1%}".format(self.completion_percentage()) ]) counts = collections.Counter([t.status for t in self._all_trials]) trial_statuses = ", ".join( - sorted(["{}: {}".format(k, v) for k, v in counts.items()])) + sorted(["{}: {}".format(k, v) for k, v in counts.items()]) + ) return "Bracket({}): {{{}}} ".format(status, trial_statuses) diff --git a/python/ray/tune/log_sync.py b/python/ray/tune/log_sync.py index d47d0e4e21c0..942f38e0fa34 100644 --- a/python/ray/tune/log_sync.py +++ b/python/ray/tune/log_sync.py @@ -68,8 +68,11 @@ def sync_if_needed(self): def sync_now(self, 
force=False): self.last_sync_time = time.time() if not self.worker_ip: - print("Worker ip unknown, skipping log sync for {}".format( - self.local_dir)) + print( + "Worker ip unknown, skipping log sync for {}".format( + self.local_dir + ) + ) return if self.worker_ip == self.local_ip: @@ -78,21 +81,28 @@ def sync_now(self, force=False): ssh_key = get_ssh_key() ssh_user = get_ssh_user() if ssh_key is None or ssh_user is None: - print("Error: log sync requires cluster to be setup with " - "`ray create_or_update`.") + print( + "Error: log sync requires cluster to be setup with " + "`ray create_or_update`." + ) return if not distutils.spawn.find_executable("rsync"): print("Error: log sync requires rsync to be installed.") return worker_to_local_sync_cmd = (( """rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """ - """-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format( - ssh_key, ssh_user, self.worker_ip, - pipes.quote(self.local_dir), pipes.quote(self.local_dir))) + """-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""" + ).format( + ssh_key, ssh_user, self.worker_ip, pipes.quote(self.local_dir), + pipes.quote(self.local_dir) + )) if self.remote_dir: - local_to_remote_sync_cmd = ("aws s3 sync '{}' '{}'".format( - pipes.quote(self.local_dir), pipes.quote(self.remote_dir))) + local_to_remote_sync_cmd = ( + "aws s3 sync '{}' '{}'".format( + pipes.quote(self.local_dir), pipes.quote(self.remote_dir) + ) + ) else: local_to_remote_sync_cmd = None diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index bb708f99ae9a..0c848ed872d1 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -112,7 +112,9 @@ def to_tf_values(result, path): if type(value) in [int, float]: values.append( tf.Summary.Value( - tag="/".join(path + [attr]), simple_value=value)) + tag="/".join(path + [attr]), simple_value=value + ) + ) elif type(value) is dict: values.extend(to_tf_values(value, path + [attr])) return values @@ -125,7 +127,7 @@ def _init(self): def on_result(self, result): tmp = result._asdict() for k in [ - "config", "pid", "timestamp", "time_total_s", "timesteps_total" + "config", "pid", "timestamp", "time_total_s", "timesteps_total" ]: del tmp[k] # not useful to tf log these values = to_tf_values(tmp, ["ray", "tune"]) @@ -167,7 +169,8 @@ def floatstr(o, allow_nan=self.allow_nan, nan_str=self.nan_str): _iterencode = json.encoder._make_iterencode( None, self.default, _encoder, self.indent, floatstr, self.key_separator, self.item_separator, self.sort_keys, - self.skipkeys, _one_shot) + self.skipkeys, _one_shot + ) return _iterencode(o, 0) def default(self, value): diff --git a/python/ray/tune/median_stopping_rule.py b/python/ray/tune/median_stopping_rule.py index 5d75e62b05a6..5b1d7edbe640 100644 --- a/python/ray/tune/median_stopping_rule.py +++ b/python/ray/tune/median_stopping_rule.py @@ -32,13 +32,15 @@ class MedianStoppingRule(FIFOScheduler): time a trial reports. Defaults to True. 
""" - def __init__(self, - time_attr="time_total_s", - reward_attr="episode_reward_mean", - grace_period=60.0, - min_samples_required=3, - hard_stop=True, - verbose=True): + def __init__( + self, + time_attr="time_total_s", + reward_attr="episode_reward_mean", + grace_period=60.0, + min_samples_required=3, + hard_stop=True, + verbose=True + ): FIFOScheduler.__init__(self) self._stopped_trials = set() self._completed_trials = set() @@ -67,8 +69,11 @@ def on_trial_result(self, trial_runner, trial, result): median_result = self._get_median_result(time) best_result = self._best_result(trial) if self._verbose: - print("Trial {} best res={} vs median res={} at t={}".format( - trial, best_result, median_result, time)) + print( + "Trial {} best res={} vs median res={} at t={}".format( + trial, best_result, median_result, time + ) + ) if best_result < median_result and time > self._grace_period: if self._verbose: print("MedianStoppingRule: early stopping {}".format(trial)) @@ -91,7 +96,8 @@ def on_trial_remove(self, trial_runner, trial): def debug_string(self): return "Using MedianStoppingRule: num_stopped={}.".format( - len(self._stopped_trials)) + len(self._stopped_trials) + ) def _get_median_result(self, time): scores = [] @@ -107,7 +113,8 @@ def _running_result(self, trial, t_max=float('inf')): # TODO(ekl) we could do interpolation to be more precise, but for now # assume len(results) is large and the time diffs are roughly equal return np.mean([ - getattr(r, self._reward_attr) for r in results + getattr(r, self._reward_attr) + for r in results if getattr(r, self._time_attr) <= t_max ]) diff --git a/python/ray/tune/pbt.py b/python/ray/tune/pbt.py index ec1fe6225ecc..fd589267c61f 100644 --- a/python/ray/tune/pbt.py +++ b/python/ray/tune/pbt.py @@ -26,8 +26,9 @@ def __init__(self, trial): self.last_perturbation_time = 0 def __repr__(self): - return str((self.last_score, self.last_checkpoint, - self.last_perturbation_time)) + return str(( + self.last_score, self.last_checkpoint, self.last_perturbation_time + )) def explore(config, mutations, resample_probability, custom_explore_fn): @@ -51,11 +52,13 @@ def explore(config, mutations, resample_probability, custom_explore_fn): elif random.random() > 0.5: new_config[key] = distribution[max( 0, - distribution.index(config[key]) - 1)] + distribution.index(config[key]) - 1 + )] else: new_config[key] = distribution[min( len(distribution) - 1, - distribution.index(config[key]) + 1)] + distribution.index(config[key]) + 1 + )] else: if random.random() < resample_probability: new_config[key] = distribution() @@ -69,8 +72,7 @@ def explore(config, mutations, resample_probability, custom_explore_fn): new_config = custom_explore_fn(new_config) assert new_config is not None, \ "Custom explore fn failed to return new config" - print("[explore] perturbed config from {} -> {}".format( - config, new_config)) + print("[explore] perturbed config from {} -> {}".format(config, new_config)) return new_config @@ -147,17 +149,20 @@ class PopulationBasedTraining(FIFOScheduler): >>> run_experiments({...}, scheduler=pbt) """ - def __init__(self, - time_attr="time_total_s", - reward_attr="episode_reward_mean", - perturbation_interval=60.0, - hyperparam_mutations={}, - resample_probability=0.25, - custom_explore_fn=None): + def __init__( + self, + time_attr="time_total_s", + reward_attr="episode_reward_mean", + perturbation_interval=60.0, + hyperparam_mutations={}, + resample_probability=0.25, + custom_explore_fn=None + ): if not hyperparam_mutations and not custom_explore_fn: 
raise TuneError( "You must specify at least one of `hyperparam_mutations` or " - "`custom_explore_fn` to use PBT.") + "`custom_explore_fn` to use PBT." + ) FIFOScheduler.__init__(self) self._reward_attr = reward_attr self._time_attr = time_attr @@ -211,19 +216,24 @@ def _exploit(self, trial, trial_to_clone): if not new_state.last_checkpoint: print("[pbt] warn: no checkpoint for trial, skip exploit", trial) return - new_config = explore(trial_to_clone.config, self._hyperparam_mutations, - self._resample_probability, - self._custom_explore_fn) - print("[exploit] transferring weights from trial " - "{} (score {}) -> {} (score {})".format( - trial_to_clone, new_state.last_score, trial, - trial_state.last_score)) + new_config = explore( + trial_to_clone.config, self._hyperparam_mutations, + self._resample_probability, self._custom_explore_fn + ) + print( + "[exploit] transferring weights from trial " + "{} (score {}) -> {} (score {})".format( + trial_to_clone, new_state.last_score, trial, + trial_state.last_score + ) + ) # TODO(ekl) restarting the trial is expensive. We should implement a # lighter way reset() method that can alter the trial config. trial.stop(stop_logger=False) trial.config = new_config trial.experiment_tag = make_experiment_tag( - trial_state.orig_tag, new_config, self._hyperparam_mutations) + trial_state.orig_tag, new_config, self._hyperparam_mutations + ) trial.start(new_state.last_checkpoint) self._num_perturbations += 1 # Transfer over the last perturbation time as well @@ -243,8 +253,10 @@ def _quantiles(self): if len(trials) <= 1: return [], [] else: - return (trials[:int(math.ceil(len(trials) * PBT_QUANTILE))], - trials[int(math.floor(-len(trials) * PBT_QUANTILE)):]) + return ( + trials[:int(math.ceil(len(trials) * PBT_QUANTILE))], + trials[int(math.floor(-len(trials) * PBT_QUANTILE)):] + ) def choose_trial_to_run(self, trial_runner): """Ensures all trials get fair share of time (as defined by time_attr). 
@@ -259,7 +271,8 @@ def choose_trial_to_run(self, trial_runner): trial_runner.has_resources(trial.resources): candidates.append(trial) candidates.sort( - key=lambda trial: self._trial_state[trial].last_perturbation_time) + key=lambda trial: self._trial_state[trial].last_perturbation_time + ) return candidates[0] if candidates else None def reset_stats(self): @@ -276,4 +289,5 @@ def last_scores(self, trials): def debug_string(self): return "PopulationBasedTraining: {} checkpoints, {} perturbs".format( - self._num_checkpoints, self._num_perturbations) + self._num_checkpoints, self._num_perturbations + ) diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index f17267eaadd4..858fcf1cffe0 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -33,8 +33,9 @@ def register_trainable(name, trainable): if isinstance(trainable, FunctionType): trainable = wrap_function(trainable) if not issubclass(trainable, Trainable): - raise TypeError("Second argument must be convertable to Trainable", - trainable) + raise TypeError( + "Second argument must be convertable to Trainable", trainable + ) _default_registry.register(TRAINABLE_CLASS, name, trainable) @@ -84,8 +85,11 @@ def __init__(self, objs=None): def register(self, category, key, value): if category not in KNOWN_CATEGORIES: from ray.tune import TuneError - raise TuneError("Unknown category {} not among {}".format( - category, KNOWN_CATEGORIES)) + raise TuneError( + "Unknown category {} not among {}".format( + category, KNOWN_CATEGORIES + ) + ) self._all_objects[(category, key)] = value def contains(self, category, key): diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 74ea2bcb9838..5082e07a003a 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -88,6 +88,7 @@ # (Auto=filled) The current hyperparameter configuration. 
"config", - ]) + ] +) TrainingResult.__new__.__defaults__ = (None, ) * len(TrainingResult._fields) diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index d51f9ec6f988..5133f5fd2bd9 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -88,17 +88,16 @@ def _train(self): register_trainable("B", B) def f(cpus, gpus, queue_trials): - return run_experiments( - { - "foo": { - "run": "B", - "config": { - "cpu": cpus, - "gpu": gpus, - }, - } - }, - queue_trials=queue_trials)[0] + return run_experiments({ + "foo": { + "run": "B", + "config": { + "cpu": cpus, + "gpu": gpus, + }, + } + }, + queue_trials=queue_trials)[0] # Should all succeed self.assertEqual(f(0, 0, False).status, Trial.TERMINATED) @@ -359,13 +358,15 @@ def train(config, reporter): reporter(timesteps_total=i) register_trainable("f1", train) - exp1 = Experiment(**{ - "name": "foo", - "run": "f1", - "config": { - "script_min_iter_time_s": 0 + exp1 = Experiment( + **{ + "name": "foo", + "run": "f1", + "config": { + "script_min_iter_time_s": 0 + } } - }) + ) [trial] = run_experiments(exp1) self.assertEqual(trial.status, Trial.TERMINATED) self.assertEqual(trial.last_result.timesteps_total, 99) @@ -376,20 +377,24 @@ def train(config, reporter): reporter(timesteps_total=i) register_trainable("f1", train) - exp1 = Experiment(**{ - "name": "foo", - "run": "f1", - "config": { - "script_min_iter_time_s": 0 + exp1 = Experiment( + **{ + "name": "foo", + "run": "f1", + "config": { + "script_min_iter_time_s": 0 + } } - }) - exp2 = Experiment(**{ - "name": "bar", - "run": "f1", - "config": { - "script_min_iter_time_s": 0 + ) + exp2 = Experiment( + **{ + "name": "bar", + "run": "f1", + "config": { + "script_min_iter_time_s": 0 + } } - }) + ) trials = run_experiments([exp1, exp2]) for trial in trials: self.assertEqual(trial.status, Trial.TERMINATED) @@ -421,8 +426,9 @@ def testParseToTrials(self): self.assertEqual(trials[0].trainable_name, "PPO") self.assertEqual(trials[0].experiment_tag, "0") self.assertEqual(trials[0].max_failures, 5) - self.assertEqual(trials[0].local_dir, - os.path.join(DEFAULT_RESULTS_DIR, "tune-pong")) + self.assertEqual( + trials[0].local_dir, os.path.join(DEFAULT_RESULTS_DIR, "tune-pong") + ) self.assertEqual(trials[1].experiment_tag, "1") def testEval(self): @@ -528,7 +534,8 @@ def testRecursiveDep(self): "config": { "foo": lambda spec: spec.config.foo, }, - })) + }) + ) except RecursiveDependencyError as e: assert "`foo` recursively depends on" in str(e), e else: diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index b008af3c7d6a..4cf8d451b2a6 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -21,7 +21,8 @@ def result(t, rew): return TrainingResult( - time_total_s=t, episode_reward_mean=rew, training_iteration=int(t)) + time_total_s=t, episode_reward_mean=rew, training_iteration=int(t) + ) class EarlyStoppingSuite(unittest.TestCase): @@ -38,11 +39,13 @@ def basicSetup(self, rule): for i in range(10): self.assertEqual( rule.on_trial_result(None, t1, result(i, i * 100)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) for i in range(5): self.assertEqual( rule.on_trial_result(None, t2, result(i, 450)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) return t1, t2 def testMedianStoppingConstantPerf(self): @@ -51,24 +54,27 @@ def testMedianStoppingConstantPerf(self): rule.on_trial_complete(None, 
t1, result(10, 1000)) self.assertEqual( rule.on_trial_result(None, t2, result(5, 450)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( rule.on_trial_result(None, t2, result(6, 0)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( - rule.on_trial_result(None, t2, result(10, 450)), - TrialScheduler.STOP) + rule.on_trial_result(None, t2, result(10, 450)), TrialScheduler.STOP + ) def testMedianStoppingOnCompleteOnly(self): rule = MedianStoppingRule(grace_period=0, min_samples_required=1) t1, t2 = self.basicSetup(rule) self.assertEqual( rule.on_trial_result(None, t2, result(100, 0)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) rule.on_trial_complete(None, t1, result(10, 1000)) self.assertEqual( - rule.on_trial_result(None, t2, result(101, 0)), - TrialScheduler.STOP) + rule.on_trial_result(None, t2, result(101, 0)), TrialScheduler.STOP + ) def testMedianStoppingGracePeriod(self): rule = MedianStoppingRule(grace_period=2.5, min_samples_required=1) @@ -78,12 +84,15 @@ def testMedianStoppingGracePeriod(self): t3 = Trial("PPO") self.assertEqual( rule.on_trial_result(None, t3, result(1, 10)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( rule.on_trial_result(None, t3, result(2, 10)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( - rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP) + rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP + ) def testMedianStoppingMinSamples(self): rule = MedianStoppingRule(grace_period=0, min_samples_required=2) @@ -92,10 +101,12 @@ def testMedianStoppingMinSamples(self): t3 = Trial("PPO") self.assertEqual( rule.on_trial_result(None, t3, result(3, 10)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) rule.on_trial_complete(None, t2, result(10, 1000)) self.assertEqual( - rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP) + rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP + ) def testMedianStoppingUsesMedian(self): rule = MedianStoppingRule(grace_period=0, min_samples_required=1) @@ -105,24 +116,27 @@ def testMedianStoppingUsesMedian(self): t3 = Trial("PPO") self.assertEqual( rule.on_trial_result(None, t3, result(1, 260)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( - rule.on_trial_result(None, t3, result(2, 260)), - TrialScheduler.STOP) + rule.on_trial_result(None, t3, result(2, 260)), TrialScheduler.STOP + ) def testMedianStoppingSoftStop(self): rule = MedianStoppingRule( - grace_period=0, min_samples_required=1, hard_stop=False) + grace_period=0, min_samples_required=1, hard_stop=False + ) t1, t2 = self.basicSetup(rule) rule.on_trial_complete(None, t1, result(10, 1000)) rule.on_trial_complete(None, t2, result(10, 1000)) t3 = Trial("PPO") self.assertEqual( rule.on_trial_result(None, t3, result(1, 260)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( - rule.on_trial_result(None, t3, result(2, 260)), - TrialScheduler.PAUSE) + rule.on_trial_result(None, t3, result(2, 260)), TrialScheduler.PAUSE + ) def testAlternateMetrics(self): def result2(t, rew): @@ -132,24 +146,29 @@ def result2(t, rew): grace_period=0, min_samples_required=1, time_attr='training_iteration', - reward_attr='neg_mean_loss') + reward_attr='neg_mean_loss' + ) t1 = Trial("PPO") # mean is 450, max 900, t_max=10 t2 = Trial("PPO") # mean is 450, max 450, t_max=5 for i in range(10): self.assertEqual( rule.on_trial_result(None, t1, result2(i, i * 100)), - 
TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) for i in range(5): self.assertEqual( rule.on_trial_result(None, t2, result2(i, 450)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) rule.on_trial_complete(None, t1, result2(10, 1000)) self.assertEqual( rule.on_trial_result(None, t2, result2(5, 450)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( rule.on_trial_result(None, t2, result2(6, 0)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) class _MockTrialRunner(): @@ -314,8 +333,9 @@ def testSuccessiveHalving(self): # Provides results from 0 to 8 in order, keeping last one running for i, trl in enumerate(trials): - action = sched.on_trial_result(mock_runner, trl, - result(cur_units, i)) + action = sched.on_trial_result( + mock_runner, trl, result(cur_units, i) + ) if i < current_length - 1: self.assertEqual(action, TrialScheduler.PAUSE) mock_runner.process_action(trl, action) @@ -337,8 +357,9 @@ def testHalvingStop(self): # # Provides result in reverse order, killing the last one cur_units = stats[str(1)]["r"] for i, trl in reversed(list(enumerate(big_bracket.current_trials()))): - action = sched.on_trial_result(mock_runner, trl, - result(cur_units, i)) + action = sched.on_trial_result( + mock_runner, trl, result(cur_units, i) + ) mock_runner.process_action(trl, action) self.assertEqual(action, TrialScheduler.STOP) @@ -354,8 +375,9 @@ def testStopsLastOne(self): # # Provides result in reverse order, killing the last one cur_units = stats[str(0)]["r"] for i, trl in enumerate(big_bracket.current_trials()): - action = sched.on_trial_result(mock_runner, trl, - result(cur_units, i)) + action = sched.on_trial_result( + mock_runner, trl, result(cur_units, i) + ) mock_runner.process_action(trl, action) self.assertEqual(action, TrialScheduler.STOP) @@ -370,12 +392,18 @@ def testTrialErrored(self): mock_runner._launch_trial(t) sched.on_trial_error(mock_runner, t3) - self.assertEqual(TrialScheduler.PAUSE, - sched.on_trial_result(mock_runner, t1, - result(stats[str(1)]["r"], 10))) - self.assertEqual(TrialScheduler.CONTINUE, - sched.on_trial_result(mock_runner, t2, - result(stats[str(1)]["r"], 10))) + self.assertEqual( + TrialScheduler.PAUSE, + sched.on_trial_result( + mock_runner, t1, result(stats[str(1)]["r"], 10) + ) + ) + self.assertEqual( + TrialScheduler.CONTINUE, + sched.on_trial_result( + mock_runner, t2, result(stats[str(1)]["r"], 10) + ) + ) def testTrialErrored2(self): """Check successive halving happened even when last trial failed""" @@ -385,14 +413,16 @@ def testTrialErrored2(self): trials = sched._state["bracket"].current_trials() for t in trials[:-1]: mock_runner._launch_trial(t) - sched.on_trial_result(mock_runner, t, result( - stats[str(1)]["r"], 10)) + sched.on_trial_result( + mock_runner, t, result(stats[str(1)]["r"], 10) + ) mock_runner._launch_trial(trials[-1]) sched.on_trial_error(mock_runner, trials[-1]) self.assertEqual( len(sched._state["bracket"].current_trials()), - self.downscale(stats[str(1)]["n"], sched)) + self.downscale(stats[str(1)]["n"], sched) + ) def testTrialEndedEarly(self): """Check successive halving happened even when one trial failed""" @@ -405,12 +435,18 @@ def testTrialEndedEarly(self): mock_runner._launch_trial(t) sched.on_trial_complete(mock_runner, t3, result(1, 12)) - self.assertEqual(TrialScheduler.PAUSE, - sched.on_trial_result(mock_runner, t1, - result(stats[str(1)]["r"], 10))) - self.assertEqual(TrialScheduler.CONTINUE, - sched.on_trial_result(mock_runner, t2, - result(stats[str(1)]["r"], 10))) + 
self.assertEqual( + TrialScheduler.PAUSE, + sched.on_trial_result( + mock_runner, t1, result(stats[str(1)]["r"], 10) + ) + ) + self.assertEqual( + TrialScheduler.CONTINUE, + sched.on_trial_result( + mock_runner, t2, result(stats[str(1)]["r"], 10) + ) + ) def testTrialEndedEarly2(self): """Check successive halving happened even when last trial failed""" @@ -420,14 +456,16 @@ def testTrialEndedEarly2(self): trials = sched._state["bracket"].current_trials() for t in trials[:-1]: mock_runner._launch_trial(t) - sched.on_trial_result(mock_runner, t, result( - stats[str(1)]["r"], 10)) + sched.on_trial_result( + mock_runner, t, result(stats[str(1)]["r"], 10) + ) mock_runner._launch_trial(trials[-1]) sched.on_trial_complete(mock_runner, trials[-1], result(100, 12)) self.assertEqual( len(sched._state["bracket"].current_trials()), - self.downscale(stats[str(1)]["n"], sched)) + self.downscale(stats[str(1)]["n"], sched) + ) def testAddAfterHalving(self): stats = self.default_statistics() @@ -440,8 +478,9 @@ def testAddAfterHalving(self): mock_runner._launch_trial(t) for i, t in enumerate(bracket_trials): - action = sched.on_trial_result(mock_runner, t, result( - init_units, i)) + action = sched.on_trial_result( + mock_runner, t, result(init_units, i) + ) self.assertEqual(action, TrialScheduler.CONTINUE) t = Trial("__fake") sched.on_trial_add(None, t) @@ -449,13 +488,15 @@ def testAddAfterHalving(self): self.assertEqual(len(sched._state["bracket"].current_trials()), 2) # Make sure that newly added trial gets fair computation (not just 1) - self.assertEqual(TrialScheduler.CONTINUE, - sched.on_trial_result(mock_runner, t, - result(init_units, 12))) + self.assertEqual( + TrialScheduler.CONTINUE, + sched.on_trial_result(mock_runner, t, result(init_units, 12)) + ) new_units = init_units + int(init_units * sched._eta) - self.assertEqual(TrialScheduler.PAUSE, - sched.on_trial_result(mock_runner, t, - result(new_units, 12))) + self.assertEqual( + TrialScheduler.PAUSE, + sched.on_trial_result(mock_runner, t, result(new_units, 12)) + ) def testAlternateMetrics(self): """Checking that alternate metrics will pass.""" @@ -464,7 +505,8 @@ def result2(t, rew): return TrainingResult(time_total_s=t, neg_mean_loss=rew) sched = HyperBandScheduler( - time_attr='time_total_s', reward_attr='neg_mean_loss') + time_attr='time_total_s', reward_attr='neg_mean_loss' + ) stats = self.default_statistics() for i in range(stats["max_trials"]): @@ -570,7 +612,8 @@ def basicSetup(self, resample_prob=0.0, explore=None): "float_factor": lambda: 100.0, "int_factor": lambda: 10, }, - custom_explore_fn=explore) + custom_explore_fn=explore + ) runner = _MockTrialRunner(pbt) for i in range(5): trial = _MockTrial( @@ -579,12 +622,14 @@ def basicSetup(self, resample_prob=0.0, explore=None): "float_factor": 2.0, "const_factor": 3, "int_factor": 10 - }) + } + ) runner.add_trial(trial) trial.status = Trial.RUNNING self.assertEqual( pbt.on_trial_result(runner, trial, result(10, 50 * i)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) pbt.reset_stats() return pbt, runner @@ -596,26 +641,30 @@ def testCheckpointsMostPromisingTrials(self): self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200]) self.assertEqual( pbt.on_trial_result(runner, trials[0], result(15, 200)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200]) self.assertEqual(pbt._num_checkpoints, 0) # checkpoint: both past interval and upper quantile self.assertEqual( pbt.on_trial_result(runner, 
trials[0], result(20, 200)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual(pbt.last_scores(trials), [200, 50, 100, 150, 200]) self.assertEqual(pbt._num_checkpoints, 1) self.assertEqual( pbt.on_trial_result(runner, trials[1], result(30, 201)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual(pbt.last_scores(trials), [200, 201, 100, 150, 200]) self.assertEqual(pbt._num_checkpoints, 2) # not upper quantile any more self.assertEqual( pbt.on_trial_result(runner, trials[4], result(30, 199)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual(pbt._num_checkpoints, 2) self.assertEqual(pbt._num_perturbations, 0) @@ -626,7 +675,8 @@ def testPerturbsLowPerformingTrials(self): # no perturbation: haven't hit next perturbation interval self.assertEqual( pbt.on_trial_result(runner, trials[0], result(15, -100)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200]) self.assertTrue("@perturbed" not in trials[0].experiment_tag) self.assertEqual(pbt._num_perturbations, 0) @@ -634,7 +684,8 @@ def testPerturbsLowPerformingTrials(self): # perturb since it's lower quantile self.assertEqual( pbt.on_trial_result(runner, trials[0], result(20, -100)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual(pbt.last_scores(trials), [-100, 50, 100, 150, 200]) self.assertTrue("@perturbed" in trials[0].experiment_tag) self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"]) @@ -643,7 +694,8 @@ def testPerturbsLowPerformingTrials(self): # also perturbed self.assertEqual( pbt.on_trial_result(runner, trials[2], result(20, 40)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual(pbt.last_scores(trials), [-100, 50, 40, 150, 200]) self.assertEqual(pbt._num_perturbations, 2) self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"]) @@ -654,7 +706,8 @@ def testPerturbWithoutResample(self): trials = runner.get_trials() self.assertEqual( pbt.on_trial_result(runner, trials[0], result(20, -100)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"]) self.assertIn(trials[0].config["id_factor"], [100]) self.assertIn(trials[0].config["float_factor"], [2.4, 1.6]) @@ -668,7 +721,8 @@ def testPerturbWithResample(self): trials = runner.get_trials() self.assertEqual( pbt.on_trial_result(runner, trials[0], result(20, -100)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"]) self.assertEqual(trials[0].config["id_factor"], 100) self.assertEqual(trials[0].config["float_factor"], 100.0) @@ -726,7 +780,8 @@ def testYieldsTimeToOtherTrials(self): self.assertEqual( pbt.on_trial_result(runner, trials[1], result(20, 1000)), - TrialScheduler.PAUSE) + TrialScheduler.PAUSE + ) self.assertEqual(pbt.last_scores(trials), [0, 1000, 100, 150, 200]) self.assertEqual(pbt.choose_trial_to_run(runner), trials[0]) @@ -767,7 +822,8 @@ def explore(new_config): trials = runner.get_trials() self.assertEqual( pbt.on_trial_result(runner, trials[0], result(20, -100)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual(trials[0].config["id_factor"], 42) self.assertEqual(trials[0].config["float_factor"], 43) @@ -788,11 +844,13 @@ def basicSetup(self, scheduler): for i in range(10): self.assertEqual( scheduler.on_trial_result(None, t1, result(i, i * 100)), - TrialScheduler.CONTINUE) + 
TrialScheduler.CONTINUE + ) for i in range(5): self.assertEqual( scheduler.on_trial_result(None, t2, result(i, 450)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) return t1, t2 def testAsyncHBOnComplete(self): @@ -803,11 +861,13 @@ def testAsyncHBOnComplete(self): scheduler.on_trial_complete(None, t3, result(10, 1000)) self.assertEqual( scheduler.on_trial_result(None, t2, result(101, 0)), - TrialScheduler.STOP) + TrialScheduler.STOP + ) def testAsyncHBGracePeriod(self): scheduler = AsyncHyperBandScheduler( - grace_period=2.5, reduction_factor=3, brackets=1) + grace_period=2.5, reduction_factor=3, brackets=1 + ) t1, t2 = self.basicSetup(scheduler) scheduler.on_trial_complete(None, t1, result(10, 1000)) scheduler.on_trial_complete(None, t2, result(10, 1000)) @@ -815,13 +875,16 @@ def testAsyncHBGracePeriod(self): scheduler.on_trial_add(None, t3) self.assertEqual( scheduler.on_trial_result(None, t3, result(1, 10)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( scheduler.on_trial_result(None, t3, result(2, 10)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( scheduler.on_trial_result(None, t3, result(3, 10)), - TrialScheduler.STOP) + TrialScheduler.STOP + ) def testAsyncHBAllCompletes(self): scheduler = AsyncHyperBandScheduler(max_t=10, brackets=10) @@ -832,11 +895,13 @@ def testAsyncHBAllCompletes(self): for t in trials: self.assertEqual( scheduler.on_trial_result(None, t, result(10, -2)), - TrialScheduler.STOP) + TrialScheduler.STOP + ) def testAsyncHBUsesPercentile(self): scheduler = AsyncHyperBandScheduler( - grace_period=1, max_t=10, reduction_factor=2, brackets=1) + grace_period=1, max_t=10, reduction_factor=2, brackets=1 + ) t1, t2 = self.basicSetup(scheduler) scheduler.on_trial_complete(None, t1, result(10, 1000)) scheduler.on_trial_complete(None, t2, result(10, 1000)) @@ -844,10 +909,12 @@ def testAsyncHBUsesPercentile(self): scheduler.on_trial_add(None, t3) self.assertEqual( scheduler.on_trial_result(None, t3, result(1, 260)), - TrialScheduler.STOP) + TrialScheduler.STOP + ) self.assertEqual( scheduler.on_trial_result(None, t3, result(2, 260)), - TrialScheduler.STOP) + TrialScheduler.STOP + ) def testAlternateMetrics(self): def result2(t, rew): @@ -857,7 +924,8 @@ def result2(t, rew): grace_period=1, time_attr='training_iteration', reward_attr='neg_mean_loss', - brackets=1) + brackets=1 + ) t1 = Trial("PPO") # mean is 450, max 900, t_max=10 t2 = Trial("PPO") # mean is 450, max 450, t_max=5 scheduler.on_trial_add(None, t1) @@ -865,18 +933,22 @@ def result2(t, rew): for i in range(10): self.assertEqual( scheduler.on_trial_result(None, t1, result2(i, i * 100)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) for i in range(5): self.assertEqual( scheduler.on_trial_result(None, t2, result2(i, 450)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) scheduler.on_trial_complete(None, t1, result2(10, 1000)) self.assertEqual( scheduler.on_trial_result(None, t2, result2(5, 450)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) self.assertEqual( scheduler.on_trial_result(None, t2, result2(6, 0)), - TrialScheduler.CONTINUE) + TrialScheduler.CONTINUE + ) if __name__ == "__main__": diff --git a/python/ray/tune/test/tune_server_test.py b/python/ray/tune/test/tune_server_test.py index f80e90bc481c..d25873c0ded0 100644 --- a/python/ray/tune/test/tune_server_test.py +++ b/python/ray/tune/test/tune_server_test.py @@ -89,7 +89,8 @@ def testStopTrial(self): runner.step() all_trials = 
client.get_all_trials()["trials"] self.assertEqual( - len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1) + len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1 + ) tid = [t for t in all_trials if t["status"] == Trial.RUNNING][0]["id"] client.stop_trial(tid) @@ -97,7 +98,8 @@ def testStopTrial(self): all_trials = client.get_all_trials()["trials"] self.assertEqual( - len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0) + len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0 + ) if __name__ == "__main__": diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 03976300e27e..a34ae588ce47 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -81,7 +81,8 @@ def __init__(self, config=None, registry=None, logger_creator=None): if not os.path.exists(DEFAULT_RESULTS_DIR): os.makedirs(DEFAULT_RESULTS_DIR) self.logdir = tempfile.mkdtemp( - prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR) + prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR + ) self._result_logger = UnifiedLogger(self.config, self.logdir, None) self._iteration = 0 @@ -120,7 +121,8 @@ def train(self): if not self._initialize_ok: raise ValueError( - "Trainable initialization failed, see previous errors") + "Trainable initialization failed, see previous errors" + ) start = time.time() result = self._train() @@ -131,8 +133,9 @@ def train(self): time_this_iter = time.time() - start if result.timesteps_this_iter is None: - raise TuneError("Must specify timesteps_this_iter in result", - result) + raise TuneError( + "Must specify timesteps_this_iter in result", result + ) self._time_total += time_this_iter self._timesteps_total += result.timesteps_this_iter @@ -156,7 +159,8 @@ def train(self): pid=os.getpid(), hostname=os.uname()[1], node_ip=self._local_ip, - config=self.config) + config=self.config + ) self._result_logger.on_result(result) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 9d12e768ce8d..2f1cf9326270 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -28,7 +28,8 @@ def date_str(): class Resources( - namedtuple("Resources", ["cpu", "gpu", "extra_cpu", "extra_gpu"])): + namedtuple("Resources", ["cpu", "gpu", "extra_cpu", "extra_gpu"]) +): """Ray resources required to schedule a trial. 
Attributes: @@ -42,12 +43,13 @@ class Resources( __slots__ = () def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0): - return super(Resources, cls).__new__(cls, cpu, gpu, extra_cpu, - extra_gpu) + return super(Resources, + cls).__new__(cls, cpu, gpu, extra_cpu, extra_gpu) def summary_string(self): - return "{} CPUs, {} GPUs".format(self.cpu + self.extra_cpu, - self.gpu + self.extra_gpu) + return "{} CPUs, {} GPUs".format( + self.cpu + self.extra_cpu, self.gpu + self.extra_gpu + ) def cpu_total(self): return self.cpu + self.extra_cpu @@ -58,7 +60,8 @@ def gpu_total(self): def has_trainable(trainable_name): return ray.tune.registry._default_registry.contains( - ray.tune.registry.TRAINABLE_CLASS, trainable_name) + ray.tune.registry.TRAINABLE_CLASS, trainable_name + ) class Trial(object): @@ -77,17 +80,19 @@ class Trial(object): TERMINATED = "TERMINATED" ERROR = "ERROR" - def __init__(self, - trainable_name, - config=None, - local_dir=DEFAULT_RESULTS_DIR, - experiment_tag="", - resources=None, - stopping_criterion=None, - checkpoint_freq=0, - restore_path=None, - upload_dir=None, - max_failures=0): + def __init__( + self, + trainable_name, + config=None, + local_dir=DEFAULT_RESULTS_DIR, + experiment_tag="", + resources=None, + stopping_criterion=None, + checkpoint_freq=0, + restore_path=None, + upload_dir=None, + max_failures=0 + ): """Initialize a new trial. The args here take the same meaning as the command line flags defined @@ -105,7 +110,9 @@ def __init__(self, if k not in TrainingResult._fields: raise TuneError( "Stopping condition key `{}` must be one of {}".format( - k, TrainingResult._fields)) + k, TrainingResult._fields + ) + ) # Trial config self.trainable_name = trainable_name @@ -114,7 +121,8 @@ def __init__(self, self.experiment_tag = experiment_tag self.resources = ( resources - or self._get_trainable_cls().default_resource_request(self.config)) + or self._get_trainable_cls().default_resource_request(self.config) + ) self.stopping_criterion = stopping_criterion or {} self.checkpoint_freq = checkpoint_freq self.upload_dir = upload_dir @@ -174,8 +182,9 @@ def stop(self, error=False, error_msg=None, stop_logger=True): try: if error_msg and self.logdir: self.num_failures += 1 - error_file = os.path.join(self.logdir, "error_{}.txt".format( - date_str())) + error_file = os.path.join( + self.logdir, "error_{}.txt".format(date_str()) + ) with open(error_file, "w") as f: f.write(error_msg) self.error_file = error_file @@ -184,10 +193,11 @@ def stop(self, error=False, error_msg=None, stop_logger=True): stop_tasks.append(self.runner.stop.remote()) stop_tasks.append( self.runner.__ray_terminate__.remote( - self.runner._ray_actor_id.id())) + self.runner._ray_actor_id.id() + ) + ) # TODO(ekl) seems like wait hangs when killing actors - _, unfinished = ray.wait( - stop_tasks, num_returns=2, timeout=250) + _, unfinished = ray.wait(stop_tasks, num_returns=2, timeout=250) except Exception: print("Error stopping runner:", traceback.format_exc()) self.status = Trial.ERROR @@ -261,30 +271,40 @@ def location_string(hostname, pid): return '{} pid={}'.format(hostname, pid) pieces = [ - '{} [{}]'.format(self._status_string(), - location_string(self.last_result.hostname, - self.last_result.pid)), - '{} s'.format(int(self.last_result.time_total_s)), '{} ts'.format( - int(self.last_result.timesteps_total)) + '{} [{}]'.format( + self._status_string(), + location_string( + self.last_result.hostname, self.last_result.pid + ) + ), '{} s'.format(int(self.last_result.time_total_s)), + '{} 
ts'.format(int(self.last_result.timesteps_total)) ] if self.last_result.episode_reward_mean is not None: - pieces.append('{} rew'.format( - format(self.last_result.episode_reward_mean, '.3g'))) + pieces.append( + '{} rew'.format( + format(self.last_result.episode_reward_mean, '.3g') + ) + ) if self.last_result.mean_loss is not None: - pieces.append('{} loss'.format( - format(self.last_result.mean_loss, '.3g'))) + pieces.append( + '{} loss'.format(format(self.last_result.mean_loss, '.3g')) + ) if self.last_result.mean_accuracy is not None: - pieces.append('{} acc'.format( - format(self.last_result.mean_accuracy, '.3g'))) + pieces.append( + '{} acc'.format(format(self.last_result.mean_accuracy, '.3g')) + ) return ', '.join(pieces) def _status_string(self): - return "{}{}".format(self.status, ", {} failures: {}".format( - self.num_failures, self.error_file) if self.error_file else "") + return "{}{}".format( + self.status, + ", {} failures: {}".format(self.num_failures, self.error_file) + if self.error_file else "" + ) def has_checkpoint(self): return self._checkpoint_path is not None or \ @@ -342,8 +362,9 @@ def restore_from_obj(self, obj): def update_last_result(self, result, terminate=False): if terminate: result = result._replace(done=True) - if self.verbose and (terminate or time.time() - self.last_debug > - DEBUG_PRINT_INTERVAL): + if self.verbose and ( + terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL + ): print("TrainingResult for {}:".format(self)) print(" {}".format(pretty_print(result).replace("\n", "\n "))) self.last_debug = time.time() @@ -353,17 +374,22 @@ def update_last_result(self, result, terminate=False): def _setup_runner(self): self.status = Trial.RUNNING cls = ray.remote( - num_cpus=self.resources.cpu, - num_gpus=self.resources.gpu)(self._get_trainable_cls()) + num_cpus=self.resources.cpu, num_gpus=self.resources.gpu + )( + self._get_trainable_cls() + ) if not self.result_logger: if not os.path.exists(self.local_dir): os.makedirs(self.local_dir) self.logdir = tempfile.mkdtemp( prefix="{}_{}".format( - str(self)[:MAX_LEN_IDENTIFIER], date_str()), - dir=self.local_dir) - self.result_logger = UnifiedLogger(self.config, self.logdir, - self.upload_dir) + str(self)[:MAX_LEN_IDENTIFIER], date_str() + ), + dir=self.local_dir + ) + self.result_logger = UnifiedLogger( + self.config, self.logdir, self.upload_dir + ) remote_logdir = self.logdir def logger_creator(config): @@ -378,11 +404,13 @@ def logger_creator(config): self.runner = cls.remote( config=self.config, registry=ray.tune.registry.get_registry(), - logger_creator=logger_creator) + logger_creator=logger_creator + ) def _get_trainable_cls(self): return ray.tune.registry.get_registry().get( - ray.tune.registry.TRAINABLE_CLASS, self.trainable_name) + ray.tune.registry.TRAINABLE_CLASS, self.trainable_name + ) def set_verbose(self, verbose): self.verbose = verbose @@ -396,8 +424,7 @@ def __repr__(self): def __str__(self): """Combines ``env`` with ``trainable_name`` and ``experiment_tag``.""" if "env" in self.config: - identifier = "{}_{}".format(self.trainable_name, - self.config["env"]) + identifier = "{}_{}".format(self.trainable_name, self.config["env"]) else: identifier = self.trainable_name if self.experiment_tag: diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 56474c119495..90b9fcc392bd 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -38,12 +38,14 @@ class TrialRunner(object): misleading benchmark results. 
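# --- Illustrative sketch (not part of the diff) ---------------------------
# The Resources namedtuple reformatted above combines a base request with
# "extra" allocations reserved for child tasks/actors. This is a standalone
# restatement of cpu_total/gpu_total/summary_string for clarity; only the
# field names are taken from trial.py, everything else is illustrative.
from collections import namedtuple

Resources = namedtuple("Resources", ["cpu", "gpu", "extra_cpu", "extra_gpu"])

def cpu_total(res):
    # Base CPUs plus the extra CPUs the trial reserves beyond its own actor.
    return res.cpu + res.extra_cpu

def gpu_total(res):
    return res.gpu + res.extra_gpu

res = Resources(cpu=1, gpu=0, extra_cpu=2, extra_gpu=1)
print("{} CPUs, {} GPUs".format(cpu_total(res), gpu_total(res)))  # "3 CPUs, 1 GPUs"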
""" - def __init__(self, - scheduler=None, - launch_web_server=False, - server_port=TuneServer.DEFAULT_PORT, - verbose=True, - queue_trials=False): + def __init__( + self, + scheduler=None, + launch_web_server=False, + server_port=TuneServer.DEFAULT_PORT, + verbose=True, + queue_trials=False + ): """Initializes a new TrialRunner. Args: @@ -68,7 +70,8 @@ def __init__(self, # For debugging, it may be useful to halt trials after some time has # elapsed. TODO(ekl) consider exposing this in the API. self._global_time_limit = float( - os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf'))) + os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf')) + ) self._total_time = 0 self._server = None if launch_web_server: @@ -81,8 +84,11 @@ def is_finished(self): """Returns whether all trials have finished running.""" if self._total_time > self._global_time_limit: - print("Exceeded global time limit {} / {}".format( - self._total_time, self._global_time_limit)) + print( + "Exceeded global time limit {} / {}".format( + self._total_time, self._global_time_limit + ) + ) return True for t in self._trials: @@ -105,21 +111,25 @@ def step(self): for trial in self._trials: if trial.status == Trial.PENDING: if not self.has_resources(trial.resources): - raise TuneError( - ("Insufficient cluster resources to launch trial: " - "trial requested {} but the cluster only has {} " - "available. Pass `queue_trials=True` in " - "ray.tune.run_experiments() or on the command " - "line to queue trials until the cluster scales " - "up. {}").format( - trial.resources.summary_string(), - self._avail_resources.summary_string(), - trial._get_trainable_cls().resource_help( - trial.config))) + raise TuneError(( + "Insufficient cluster resources to launch trial: " + "trial requested {} but the cluster only has {} " + "available. Pass `queue_trials=True` in " + "ray.tune.run_experiments() or on the command " + "line to queue trials until the cluster scales " + "up. {}" + ).format( + trial.resources.summary_string(), + self._avail_resources.summary_string(), + trial._get_trainable_cls().resource_help( + trial.config + ) + )) elif trial.status == Trial.PAUSED: raise TuneError( "There are paused trials, but no more pending " - "trials with sufficient resources.") + "trials with sufficient resources." + ) raise TuneError("Called step when all trials finished?") if self._server: @@ -181,7 +191,8 @@ def debug_string(self, max_debug=MAX_DEBUG_TRIALS): messages.append(" - {}:\t{}".format(t, t.progress_string())) if len(trials) > limit: messages.append( - " ... {} more not shown".format(len(trials) - limit)) + " ... 
{} more not shown".format(len(trials) - limit) + ) return "\n".join(messages) + "\n" def _debug_messages(self): @@ -191,7 +202,9 @@ def _debug_messages(self): messages.append( "Resources requested: {}/{} CPUs, {}/{} GPUs".format( self._committed_resources.cpu, self._avail_resources.cpu, - self._committed_resources.gpu, self._avail_resources.gpu)) + self._committed_resources.gpu, self._avail_resources.gpu + ) + ) return messages def has_resources(self, resources): @@ -200,8 +213,10 @@ def has_resources(self, resources): cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu - have_space = (resources.cpu_total() <= cpu_avail - and resources.gpu_total() <= gpu_avail) + have_space = ( + resources.cpu_total() <= cpu_avail + and resources.gpu_total() <= gpu_avail + ) if have_space: return True @@ -209,16 +224,18 @@ def has_resources(self, resources): can_overcommit = self._queue_trials if ((resources.cpu_total() > 0 and cpu_avail <= 0) - or (resources.gpu_total() > 0 and gpu_avail <= 0)): + or (resources.gpu_total() > 0 and gpu_avail <= 0)): can_overcommit = False # requested resource is already saturated if can_overcommit: - print("WARNING:tune:allowing trial to start even though the " - "cluster does not have enough free resources. Trial actors " - "may appear to hang until enough resources are added to the " - "cluster (e.g., via autoscaling). You can disable this " - "behavior by specifying `queue_trials=False` in " - "ray.tune.run_experiments().") + print( + "WARNING:tune:allowing trial to start even though the " + "cluster does not have enough free resources. Trial actors " + "may appear to hang until enough resources are added to the " + "cluster (e.g., via autoscaling). You can disable this " + "behavior by specifying `queue_trials=False` in " + "ray.tune.run_experiments()." + ) return True return False @@ -261,9 +278,11 @@ def _process_events(self): decision = TrialScheduler.STOP else: decision = self._scheduler_alg.on_trial_result( - self, trial, result) + self, trial, result + ) trial.update_last_result( - result, terminate=(decision == TrialScheduler.STOP)) + result, terminate=(decision == TrialScheduler.STOP) + ) if decision == TrialScheduler.CONTINUE: if trial.should_checkpoint(): @@ -275,8 +294,7 @@ def _process_events(self): elif decision == TrialScheduler.STOP: self._stop_trial(trial) else: - assert False, "Invalid scheduling decision: {}".format( - decision) + assert False, "Invalid scheduling decision: {}".format(decision) except Exception: error_msg = traceback.format_exc() print("Error processing event:", error_msg) @@ -303,12 +321,14 @@ def _try_recover(self, trial, error_msg): def _commit_resources(self, resources): self._committed_resources = Resources( self._committed_resources.cpu + resources.cpu_total(), - self._committed_resources.gpu + resources.gpu_total()) + self._committed_resources.gpu + resources.gpu_total() + ) def _return_resources(self, resources): self._committed_resources = Resources( self._committed_resources.cpu - resources.cpu_total(), - self._committed_resources.gpu - resources.gpu_total()) + self._committed_resources.gpu - resources.gpu_total() + ) assert self._committed_resources.cpu >= 0 assert self._committed_resources.gpu >= 0 @@ -336,9 +356,8 @@ def stop_trial(self, trial): self._scheduler_alg.on_trial_remove(self, trial) elif trial.status is Trial.RUNNING: # NOTE: There should only be one... 
- result_id = [ - rid for rid, t in self._running.items() if t is trial - ][0] + result_id = [rid for rid, t in self._running.items() + if t is trial][0] self._running.pop(result_id) try: result = ray.get(result_id) diff --git a/python/ray/tune/trial_scheduler.py b/python/ray/tune/trial_scheduler.py index e8529878632e..0d17fcdd5d72 100644 --- a/python/ray/tune/trial_scheduler.py +++ b/python/ray/tune/trial_scheduler.py @@ -99,12 +99,16 @@ def on_trial_remove(self, trial_runner, trial): def choose_trial_to_run(self, trial_runner): for trial in trial_runner.get_trials(): - if (trial.status == Trial.PENDING - and trial_runner.has_resources(trial.resources)): + if ( + trial.status == Trial.PENDING + and trial_runner.has_resources(trial.resources) + ): return trial for trial in trial_runner.get_trials(): - if (trial.status == Trial.PAUSED - and trial_runner.has_resources(trial.resources)): + if ( + trial.status == Trial.PAUSED + and trial_runner.has_resources(trial.resources) + ): return trial return None diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index b81cdc9b67a3..3fdb8ca64ade 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -29,16 +29,21 @@ def _make_scheduler(args): if args.scheduler in _SCHEDULERS: return _SCHEDULERS[args.scheduler](**args.scheduler_config) else: - raise TuneError("Unknown scheduler: {}, should be one of {}".format( - args.scheduler, _SCHEDULERS.keys())) - - -def run_experiments(experiments, - scheduler=None, - with_server=False, - server_port=TuneServer.DEFAULT_PORT, - verbose=True, - queue_trials=False): + raise TuneError( + "Unknown scheduler: {}, should be one of {}".format( + args.scheduler, _SCHEDULERS.keys() + ) + ) + + +def run_experiments( + experiments, + scheduler=None, + with_server=False, + server_port=TuneServer.DEFAULT_PORT, + verbose=True, + queue_trials=False +): """Tunes experiments. 
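# --- Illustrative sketch (not part of the diff) ---------------------------
# The trial_scheduler.py hunk above only splits a long conditional; the
# underlying policy is "prefer PENDING trials over PAUSED ones, and only pick
# a trial the runner can actually fit". Restated compactly (the status
# strings are those defined in trial.py; the rest is illustrative):
def choose_trial_to_run(trial_runner):
    for status in ("PENDING", "PAUSED"):
        for trial in trial_runner.get_trials():
            if (trial.status == status
                    and trial_runner.has_resources(trial.resources)):
                return trial
    return None  # Nothing runnable right now.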
Args: @@ -64,7 +69,8 @@ def run_experiments(experiments, launch_web_server=with_server, server_port=server_port, verbose=verbose, - queue_trials=queue_trials) + queue_trials=queue_trials + ) exp_list = experiments if isinstance(experiments, Experiment): exp_list = [experiments] @@ -74,8 +80,10 @@ def run_experiments(experiments, for name, spec in experiments.items() ] - if (type(exp_list) is list - and all(isinstance(exp, Experiment) for exp in exp_list)): + if ( + type(exp_list) is list + and all(isinstance(exp, Experiment) for exp in exp_list) + ): for experiment in exp_list: scheduler.add_experiment(experiment, runner) else: diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index 45718916acbb..ed1a14155f33 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -24,8 +24,10 @@ def pin_in_object_store(obj): obj_id = ray.put(_to_pinnable(obj)) _pinned_objects.append(ray.get(obj_id)) - return "{}{}".format(PINNED_OBJECT_PREFIX, - base64.b64encode(obj_id.id()).decode("utf-8")) + return "{}{}".format( + PINNED_OBJECT_PREFIX, + base64.b64encode(obj_id.id()).decode("utf-8") + ) def get_pinned_object(pinned_id): @@ -41,7 +43,9 @@ def get_pinned_object(pinned_id): return _from_pinnable( ray.get( - ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):])))) + ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):])) + ) + ) def _serve_get_pin_requests(): diff --git a/python/ray/tune/variant_generator.py b/python/ray/tune/variant_generator.py index 0d8957c411c1..c6f94c35c790 100644 --- a/python/ray/tune/variant_generator.py +++ b/python/ray/tune/variant_generator.py @@ -67,7 +67,8 @@ def generate_trials(unresolved_spec, output_path=''): checkpoint_freq=args.checkpoint_freq, restore_path=spec.get("restore"), upload_dir=args.upload_dir, - max_failures=args.max_failures) + max_failures=args.max_failures + ) def generate_variants(unresolved_spec): @@ -167,12 +168,15 @@ def _generate_variants(spec): for path, value in grid_vars: resolved_vars[path] = _get_value(spec, path) for k, v in resolved.items(): - if (k in resolved_vars and v != resolved_vars[k] - and _is_resolved(resolved_vars[k])): + if ( + k in resolved_vars and v != resolved_vars[k] + and _is_resolved(resolved_vars[k]) + ): raise ValueError( "The variable `{}` could not be unambiguously " "resolved to a single value. Consider simplifying " - "your variable dependencies.".format(k)) + "your variable dependencies.".format(k) + ) resolved_vars[k] = v yield resolved_vars, spec @@ -255,8 +259,9 @@ def _try_resolve(v): grid_values = v["grid_search"] if not isinstance(grid_values, list): raise TuneError( - "Grid search expected list of values, got: {}".format( - grid_values)) + "Grid search expected list of values, got: {}". + format(grid_values) + ) return False, grid_values return True, v @@ -288,7 +293,8 @@ def __getattribute__(self, item): value = dict.__getattribute__(self, item) if not _is_resolved(value): raise RecursiveDependencyError( - "`{}` recursively depends on {}".format(item, value)) + "`{}` recursively depends on {}".format(item, value) + ) elif isinstance(value, dict): return _UnresolvedAccessGuard(value) else: diff --git a/python/ray/tune/web_server.py b/python/ray/tune/web_server.py index 4e9255b23a5e..95a9129a3322 100644 --- a/python/ray/tune/web_server.py +++ b/python/ray/tune/web_server.py @@ -19,8 +19,10 @@ import requests # `requests` is not part of stdlib. except ImportError: requests = None - print("Couldn't import `requests` library. 
Be sure to install it on" - " the client side.") + print( + "Couldn't import `requests` library. Be sure to install it on" + " the client side." + ) class TuneClient(object): diff --git a/python/ray/utils.py b/python/ray/utils.py index 0ef47daf971f..f31db4918b89 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -46,11 +46,9 @@ def format_error_message(exception_message, task_exception=False): return "\n".join(lines) -def push_error_to_driver(redis_client, - error_type, - message, - driver_id=None, - data=None): +def push_error_to_driver( + redis_client, error_type, message, driver_id=None, data=None +): """Push an error message to the driver to be printed in the background. Args: @@ -67,11 +65,14 @@ def push_error_to_driver(redis_client, driver_id = DRIVER_ID_LENGTH * b"\x00" error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string() data = {} if data is None else data - redis_client.hmset(error_key, { - "type": error_type, - "message": message, - "data": data - }) + redis_client.hmset( + error_key, + { + "type": error_type, + "message": message, + "data": data + } + ) redis_client.rpush("ErrorKeys", error_key) @@ -140,7 +141,8 @@ def hex_to_binary(hex_identifier): FunctionProperties = collections.namedtuple( - "FunctionProperties", ["num_return_vals", "resources", "max_calls"]) + "FunctionProperties", ["num_return_vals", "resources", "max_calls"] +) """FunctionProperties: A named tuple storing remote functions information.""" diff --git a/python/ray/worker.py b/python/ray/worker.py index a12f93a541b1..92d99aac7e35 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -28,8 +28,9 @@ import ray.signature as signature import ray.local_scheduler import ray.plasma -from ray.utils import (FunctionProperties, random_string, binary_to_hex, - is_cython) +from ray.utils import ( + FunctionProperties, random_string, binary_to_hex, is_cython +) # Import flatbuffer bindings. from ray.core.generated.ClientTableData import ClientTableData @@ -107,8 +108,10 @@ class RayTaskError(Exception): def __init__(self, function_name, exception, traceback_str): """Initialize a RayTaskError.""" self.function_name = function_name - if (isinstance(exception, RayGetError) - or isinstance(exception, RayGetArgumentError)): + if ( + isinstance(exception, RayGetError) + or isinstance(exception, RayGetArgumentError) + ): self.exception = exception else: self.exception = None @@ -118,14 +121,20 @@ def __str__(self): """Format a RayTaskError as a string.""" if self.traceback_str is None: # This path is taken if getting the task arguments failed. - return ("Remote function {}{}{} failed with:\n\n{}".format( - colorama.Fore.RED, self.function_name, colorama.Fore.RESET, - self.exception)) + return ( + "Remote function {}{}{} failed with:\n\n{}".format( + colorama.Fore.RED, self.function_name, colorama.Fore.RESET, + self.exception + ) + ) else: # This path is taken if the task execution failed. - return ("Remote function {}{}{} failed with:\n\n{}".format( - colorama.Fore.RED, self.function_name, colorama.Fore.RESET, - self.traceback_str)) + return ( + "Remote function {}{}{} failed with:\n\n{}".format( + colorama.Fore.RED, self.function_name, colorama.Fore.RESET, + self.traceback_str + ) + ) class RayGetError(Exception): @@ -144,11 +153,13 @@ def __init__(self, objectid, task_error): def __str__(self): """Format a RayGetError as a string.""" - return ("Could not get objectid {}. 
It was created by remote function " - "{}{}{} which failed with:\n\n{}".format( - self.objectid, colorama.Fore.RED, - self.task_error.function_name, colorama.Fore.RESET, - self.task_error)) + return ( + "Could not get objectid {}. It was created by remote function " + "{}{}{} which failed with:\n\n{}".format( + self.objectid, colorama.Fore.RED, self.task_error.function_name, + colorama.Fore.RESET, self.task_error + ) + ) class RayGetArgumentError(Exception): @@ -173,13 +184,16 @@ def __init__(self, function_name, argument_index, objectid, task_error): def __str__(self): """Format a RayGetArgumentError as a string.""" - return ("Failed to get objectid {} as argument {} for remote function " - "{}{}{}. It was created by remote function {}{}{} which " - "failed with:\n{}".format( - self.objectid, self.argument_index, colorama.Fore.RED, - self.function_name, colorama.Fore.RESET, colorama.Fore.RED, - self.task_error.function_name, colorama.Fore.RESET, - self.task_error)) + return ( + "Failed to get objectid {} as argument {} for remote function " + "{}{}{}. It was created by remote function {}{}{} which " + "failed with:\n{}".format( + self.objectid, self.argument_index, colorama.Fore.RED, + self.function_name, colorama.Fore.RESET, colorama.Fore.RED, + self.task_error.function_name, colorama.Fore.RESET, + self.task_error + ) + ) class Worker(object): @@ -290,52 +304,63 @@ def store_and_register(self, object_id, value, depth=100): counter = 0 while True: if counter == depth: - raise Exception("Ray exceeded the maximum number of classes " - "that it will recursively serialize when " - "attempting to serialize an object of " - "type {}.".format(type(value))) + raise Exception( + "Ray exceeded the maximum number of classes " + "that it will recursively serialize when " + "attempting to serialize an object of " + "type {}.".format(type(value)) + ) counter += 1 try: self.plasma_client.put( value, object_id=pyarrow.plasma.ObjectID(object_id.id()), memcopy_threads=self.memcopy_threads, - serialization_context=self.serialization_context) + serialization_context=self.serialization_context + ) break except pyarrow.SerializationCallbackError as e: try: register_custom_serializer( - type(e.example_object), use_dict=True) - warning_message = ("WARNING: Serializing objects of type " - "{} by expanding them as dictionaries " - "of their fields. This behavior may " - "be incorrect in some cases.".format( - type(e.example_object))) + type(e.example_object), use_dict=True + ) + warning_message = ( + "WARNING: Serializing objects of type " + "{} by expanding them as dictionaries " + "of their fields. This behavior may " + "be incorrect in some cases.".format( + type(e.example_object) + ) + ) print(warning_message) - except (serialization.RayNotDictionarySerializable, - serialization.CloudPickleError, - pickle.pickle.PicklingError, Exception): + except ( + serialization.RayNotDictionarySerializable, + serialization.CloudPickleError, pickle.pickle.PicklingError, + Exception + ): # We also handle generic exceptions here because # cloudpickle can fail with many different types of errors. try: register_custom_serializer( - type(e.example_object), use_pickle=True) - warning_message = ("WARNING: Falling back to " - "serializing objects of type {} by " - "using pickle. This may be " - "inefficient.".format( - type(e.example_object))) + type(e.example_object), use_pickle=True + ) + warning_message = ( + "WARNING: Falling back to " + "serializing objects of type {} by " + "using pickle. 
This may be " + "inefficient.".format(type(e.example_object)) + ) print(warning_message) except serialization.CloudPickleError: register_custom_serializer( - type(e.example_object), - use_pickle=True, - local=True) - warning_message = ("WARNING: Pickling the class {} " - "failed, so we are using pickle " - "and only registering the class " - "locally.".format( - type(e.example_object))) + type(e.example_object), use_pickle=True, local=True + ) + warning_message = ( + "WARNING: Pickling the class {} " + "failed, so we are using pickle " + "and only registering the class " + "locally.".format(type(e.example_object)) + ) print(warning_message) def put_object(self, object_id, value): @@ -357,16 +382,20 @@ def put_object(self, object_id, value): """ # Make sure that the value is not an object ID. if isinstance(value, ray.local_scheduler.ObjectID): - raise Exception("Calling 'put' on an ObjectID is not allowed " - "(similarly, returning an ObjectID from a remote " - "function is not allowed). If you really want to " - "do this, you can wrap the ObjectID in a list and " - "call 'put' on it (or return it).") + raise Exception( + "Calling 'put' on an ObjectID is not allowed " + "(similarly, returning an ObjectID from a remote " + "function is not allowed). If you really want to " + "do this, you can wrap the ObjectID in a list and " + "call 'put' on it (or return it)." + ) if isinstance(value, ray.actor.ActorHandleParent): - raise Exception("Calling 'put' on an actor handle is currently " - "not allowed (similarly, returning an actor " - "handle from a remote function is not allowed).") + raise Exception( + "Calling 'put' on an actor handle is currently " + "not allowed (similarly, returning an actor " + "handle from a remote function is not allowed)." + ) # Serialize and put the object in the object store. try: @@ -377,8 +406,10 @@ def put_object(self, object_id, value): # and make sure that the objects are in fact the same. We also # should return an error code to the caller instead of printing a # message. - print("The object with ID {} already exists in the object store." - .format(object_id)) + print( + "The object with ID {} already exists in the object store." + .format(object_id) + ) def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): start_time = time.time() @@ -391,12 +422,14 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): # long time, if the store is blocked, it can block the manager # as well as a consequence. results = [] - for i in range(0, len(object_ids), - ray._config.worker_get_request_size()): + for i in range( + 0, len(object_ids), ray._config.worker_get_request_size() + ): results += self.plasma_client.get( - object_ids[i:( - i + ray._config.worker_get_request_size())], - timeout, self.serialization_context) + object_ids[i: + (i + ray._config.worker_get_request_size())], + timeout, self.serialization_context + ) return results except pyarrow.lib.ArrowInvalid as e: # TODO(ekl): the local scheduler could include relevant @@ -404,7 +437,8 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): invalid_error = RayTaskError( "", None, "Invalid return value: likely worker died or was killed " - "while executing the task.") + "while executing the task." + ) return [invalid_error] * len(object_ids) except pyarrow.DeserializationCallbackError as e: # Wait a little bit for the import thread to import the class. 
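# --- Illustrative sketch (not part of the diff) ---------------------------
# Worker.store_and_register above falls back through three serialization
# strategies when a type cannot be serialized directly. The helpers below are
# hypothetical stand-ins for register_custom_serializer(..., use_dict=True)
# and register_custom_serializer(..., use_pickle=True); only the order of the
# fallbacks mirrors the code.
def register_with_fallbacks(cls, register_by_dict, register_by_pickle):
    try:
        # First choice: expand instances into dictionaries of their fields.
        register_by_dict(cls)
    except Exception:
        try:
            # Second choice: serialize whole instances with (cloud)pickle.
            register_by_pickle(cls, local=False)
        except Exception:
            # Last resort: pickle, but register the class on this worker only.
            register_by_pickle(cls, local=True)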
@@ -417,17 +451,20 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): self.lock.acquire() if time.time() - start_time > error_timeout: - warning_message = ("This worker or driver is waiting to " - "receive a class definition so that it " - "can deserialize an object from the " - "object store. This may be fine, or it " - "may be a bug.") + warning_message = ( + "This worker or driver is waiting to " + "receive a class definition so that it " + "can deserialize an object from the " + "object store. This may be fine, or it " + "may be a bug." + ) if not warning_sent: ray.utils.push_error_to_driver( self.redis_client, "wait_for_class", warning_message, - driver_id=self.task_driver_id.id()) + driver_id=self.task_driver_id.id() + ) warning_sent = True def get_object(self, object_ids): @@ -444,19 +481,25 @@ def get_object(self, object_ids): # Make sure that the values are object IDs. for object_id in object_ids: if not isinstance(object_id, ray.local_scheduler.ObjectID): - raise Exception("Attempting to call `get` on the value {}, " - "which is not an ObjectID.".format(object_id)) + raise Exception( + "Attempting to call `get` on the value {}, " + "which is not an ObjectID.".format(object_id) + ) # Do an initial fetch for remote objects. We divide the fetch into # smaller fetches so as to not block the manager for a prolonged period # of time in a single call. plain_object_ids = [ plasma.ObjectID(object_id.id()) for object_id in object_ids ] - for i in range(0, len(object_ids), - ray._config.worker_fetch_request_size()): + for i in range( + 0, len(object_ids), ray._config.worker_fetch_request_size() + ): if not self.use_raylet: - self.plasma_client.fetch(plain_object_ids[i:( - i + ray._config.worker_fetch_request_size())]) + self.plasma_client.fetch( + plain_object_ids[i:( + i + ray._config.worker_fetch_request_size() + )] + ) else: print("plasma_client.fetch has not been implemented yet") @@ -478,13 +521,17 @@ def get_object(self, object_ids): # in case they were evicted since the last fetch. We divide the # fetch into smaller fetches so as to not block the manager for a # prolonged period of time in a single call. - object_ids_to_fetch = list( - map(plasma.ObjectID, unready_ids.keys())) - for i in range(0, len(object_ids_to_fetch), - ray._config.worker_fetch_request_size()): + object_ids_to_fetch = list(map(plasma.ObjectID, unready_ids.keys())) + for i in range( + 0, len(object_ids_to_fetch), + ray._config.worker_fetch_request_size() + ): if not self.use_raylet: - self.plasma_client.fetch(object_ids_to_fetch[i:( - i + ray._config.worker_fetch_request_size())]) + self.plasma_client.fetch( + object_ids_to_fetch[i:( + i + ray._config.worker_fetch_request_size() + )] + ) else: print("plasma_client.fetch has not been implemented yet") results = self.retrieve_and_deserialize( @@ -492,7 +539,8 @@ def get_object(self, object_ids): max([ ray._config.get_timeout_milliseconds(), int(0.01 * len(unready_ids)) - ])) + ]) + ) # Remove any entries for objects we received during this iteration # so we don't retrieve the same object twice. 
for i, val in enumerate(results): @@ -510,20 +558,22 @@ def get_object(self, object_ids): assert len(final_results) == len(object_ids) return final_results - def submit_task(self, - function_id, - args, - actor_id=None, - actor_handle_id=None, - actor_counter=0, - is_actor_checkpoint_method=False, - actor_creation_id=None, - actor_creation_dummy_object_id=None, - execution_dependencies=None, - num_return_vals=None, - num_cpus=None, - num_gpus=None, - resources=None): + def submit_task( + self, + function_id, + args, + actor_id=None, + actor_handle_id=None, + actor_counter=0, + is_actor_checkpoint_method=False, + actor_creation_id=None, + actor_creation_dummy_object_id=None, + execution_dependencies=None, + num_return_vals=None, + num_cpus=None, + num_gpus=None, + resources=None + ): """Submit a remote task to the scheduler. Tell the scheduler to schedule the execution of the function with ID @@ -560,7 +610,8 @@ def submit_task(self, assert actor_handle_id is None actor_id = ray.local_scheduler.ObjectID(NIL_ACTOR_ID) actor_handle_id = ray.local_scheduler.ObjectID( - NIL_ACTOR_HANDLE_ID) + NIL_ACTOR_HANDLE_ID + ) else: assert actor_handle_id is not None @@ -569,7 +620,8 @@ def submit_task(self, if actor_creation_dummy_object_id is None: actor_creation_dummy_object_id = ( - ray.local_scheduler.ObjectID(NIL_ID)) + ray.local_scheduler.ObjectID(NIL_ID) + ) # Put large or complex arguments that are passed by value in the # object store first. @@ -579,7 +631,8 @@ def submit_task(self, args_for_local_scheduler.append(arg) elif isinstance(arg, ray.actor.ActorHandleParent): args_for_local_scheduler.append( - put(ray.actor.wrap_actor_handle(arg))) + put(ray.actor.wrap_actor_handle(arg)) + ) elif ray.local_scheduler.check_simple_value(arg): args_for_local_scheduler.append(arg) else: @@ -590,8 +643,9 @@ def submit_task(self, execution_dependencies = [] # Look up the various function properties. - function_properties = self.function_properties[ - self.task_driver_id.id()][function_id.id()] + function_properties = self.function_properties[self.task_driver_id. + id()][function_id.id() + ] if num_return_vals is None: num_return_vals = function_properties.num_return_vals @@ -601,20 +655,23 @@ def submit_task(self, else: resources = {} if resources is None else resources if "CPU" in resources or "GPU" in resources: - raise ValueError("The resources dictionary must not " - "contain the keys 'CPU' or 'GPU'") + raise ValueError( + "The resources dictionary must not " + "contain the keys 'CPU' or 'GPU'" + ) resources["CPU"] = num_cpus resources["GPU"] = num_gpus # Submit the task to local scheduler. task = ray.local_scheduler.Task( self.task_driver_id, - ray.local_scheduler.ObjectID( - function_id.id()), args_for_local_scheduler, + ray.local_scheduler.ObjectID(function_id.id() + ), args_for_local_scheduler, num_return_vals, self.current_task_id, self.task_index, actor_creation_id, actor_creation_dummy_object_id, actor_id, actor_handle_id, actor_counter, is_actor_checkpoint_method, - execution_dependencies, resources, self.use_raylet) + execution_dependencies, resources, self.use_raylet + ) # Increment the worker's task index to track how many tasks have # been submitted by the current task so far. self.task_index += 1 @@ -660,11 +717,13 @@ def run_function_on_all_workers(self, function): return # Run the function on all workers. 
self.redis_client.hmset( - key, { + key, + { "driver_id": self.task_driver_id.id(), "function_id": function_to_run_id, "function": pickled_function - }) + } + ) self.redis_client.rpush("Exports", key) # TODO(rkn): If the worker fails after it calls setnx and before it # successfully completes the hmset and rpush, then the program will @@ -694,23 +753,29 @@ def _wait_for_function(self, function_id, driver_id, timeout=10): warning_sent = False while True: with self.lock: - if (self.actor_id == NIL_ACTOR_ID - and (function_id.id() in self.functions[driver_id])): + if ( + self.actor_id == NIL_ACTOR_ID + and (function_id.id() in self.functions[driver_id]) + ): break elif self.actor_id != NIL_ACTOR_ID and ( - self.actor_id in self.actors): + self.actor_id in self.actors + ): break if time.time() - start_time > timeout: - warning_message = ("This worker was asked to execute a " - "function that it does not have " - "registered. You may have to restart " - "Ray.") + warning_message = ( + "This worker was asked to execute a " + "function that it does not have " + "registered. You may have to restart " + "Ray." + ) if not warning_sent: ray.utils.push_error_to_driver( self.redis_client, "wait_for_function", warning_message, - driver_id=driver_id) + driver_id=driver_id + ) warning_sent = True time.sleep(0.001) @@ -803,22 +868,25 @@ def _process_task(self, task): if task.actor_id().id() != NIL_ACTOR_ID: dummy_return_id = return_object_ids.pop() function_name, function_executor = ( - self.functions[self.task_driver_id.id()][function_id.id()]) + self.functions[self.task_driver_id.id()][function_id.id()] + ) # Get task arguments from the object store. try: with log_span("ray:task:get_arguments", worker=self): arguments = self._get_arguments_for_execution( - function_name, args) + function_name, args + ) except (RayGetError, RayGetArgumentError) as e: - self._handle_process_task_failure(function_id, return_object_ids, - e, None) + self._handle_process_task_failure( + function_id, return_object_ids, e, None + ) return except Exception as e: - self._handle_process_task_failure(function_id, return_object_ids, - e, - ray.utils.format_error_message( - traceback.format_exc())) + self._handle_process_task_failure( + function_id, return_object_ids, e, + ray.utils.format_error_message(traceback.format_exc()) + ) return # Execute the task. @@ -829,15 +897,18 @@ def _process_task(self, task): else: outputs = function_executor( dummy_return_id, self.actors[task.actor_id().id()], - *arguments) + *arguments + ) except Exception as e: # Determine whether the exception occured during a task, not an # actor method. task_exception = task.actor_id().id() == NIL_ACTOR_ID traceback_str = ray.utils.format_error_message( - traceback.format_exc(), task_exception=task_exception) - self._handle_process_task_failure(function_id, return_object_ids, - e, traceback_str) + traceback.format_exc(), task_exception=task_exception + ) + self._handle_process_task_failure( + function_id, return_object_ids, e, traceback_str + ) return # Store the outputs in the local object store. 
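# --- Illustrative sketch (not part of the diff) ---------------------------
# run_function_on_all_workers above exports a pickled function through Redis:
# a hash holding the driver id, function id and payload, followed by a push
# onto the "Exports" list that workers poll. Minimal restatement, assuming a
# redis.StrictRedis client and precomputed ids:
def export_function(redis_client, key, driver_id, function_id, pickled_function):
    redis_client.hmset(
        key,
        {
            "driver_id": driver_id,
            "function_id": function_id,
            "function": pickled_function,
        },
    )
    redis_client.rpush("Exports", key)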
@@ -851,15 +922,16 @@ def _process_task(self, task): outputs = (outputs, ) self._store_outputs_in_objstore(return_object_ids, outputs) except Exception as e: - self._handle_process_task_failure(function_id, return_object_ids, - e, - ray.utils.format_error_message( - traceback.format_exc())) - - def _handle_process_task_failure(self, function_id, return_object_ids, - error, backtrace): - function_name, _ = self.functions[self.task_driver_id.id()][ - function_id.id()] + self._handle_process_task_failure( + function_id, return_object_ids, e, + ray.utils.format_error_message(traceback.format_exc()) + ) + + def _handle_process_task_failure( + self, function_id, return_object_ids, error, backtrace + ): + function_name, _ = self.functions[self.task_driver_id.id() + ][function_id.id()] failure_object = RayTaskError(function_name, error, backtrace) failure_objects = [ failure_object for _ in range(len(return_object_ids)) @@ -874,7 +946,8 @@ def _handle_process_task_failure(self, function_id, return_object_ids, data={ "function_id": function_id.id(), "function_name": function_name - }) + } + ) def _become_actor(self, task): """Turn this worker into an actor. @@ -910,8 +983,10 @@ def _wait_for_and_process_task(self, task): # TODO(rkn): It would be preferable for actor creation tasks to share # more of the code path with regular task execution. - if (task.actor_creation_id() != - ray.local_scheduler.ObjectID(NIL_ACTOR_ID)): + if ( + task.actor_creation_id() != + ray.local_scheduler.ObjectID(NIL_ACTOR_ID) + ): self._become_actor(task) return @@ -931,7 +1006,8 @@ def _wait_for_and_process_task(self, task): log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=self) function_name, _ = ( - self.functions[task.driver_id().id()][function_id.id()]) + self.functions[task.driver_id().id()][function_id.id()] + ) contents = { "function_name": function_name, "task_id": task.task_id().hex(), @@ -944,12 +1020,13 @@ def _wait_for_and_process_task(self, task): flush_log() # Increase the task execution counter. - (self.num_task_executions[task.driver_id().id()][function_id.id()] - ) += 1 + (self.num_task_executions[task.driver_id().id()][function_id.id()]) += 1 - reached_max_executions = (self.num_task_executions[task.driver_id().id( - )][function_id.id()] == self.function_properties[task.driver_id().id()] - [function_id.id()].max_calls) + reached_max_executions = ( + self.num_task_executions[task.driver_id().id()][function_id.id()] == + self.function_properties[task.driver_id().id()][function_id.id() + ].max_calls + ) if reached_max_executions: ray.worker.global_worker.local_scheduler_client.disconnect() os._exit(0) @@ -995,8 +1072,10 @@ def get_gpu_ids(): A list of GPU IDs. """ if _mode() == PYTHON_MODE: - raise Exception("ray.get_gpu_ids() currently does not work in PYTHON " - "MODE.") + raise Exception( + "ray.get_gpu_ids() currently does not work in PYTHON " + "MODE." + ) assigned_ids = global_worker.local_scheduler_client.gpu_ids() # If the user had already set CUDA_VISIBLE_DEVICES, then respect that (in @@ -1032,8 +1111,10 @@ def get_webui_url(): The URL of the web UI as a string. """ if _mode() == PYTHON_MODE: - raise Exception("ray.get_webui_url() currently does not work in " - "PYTHON MODE.") + raise Exception( + "ray.get_webui_url() currently does not work in " + "PYTHON MODE." + ) return _webui_url_helper(global_worker.redis_client) @@ -1059,10 +1140,11 @@ def check_main_thread(): than the main thread. 
""" if threading.current_thread().getName() != "MainThread": - raise Exception("The Ray methods are not thread safe and must be " - "called from the main thread. This method was called " - "from thread {}." - .format(threading.current_thread().getName())) + raise Exception( + "The Ray methods are not thread safe and must be " + "called from the main thread. This method was called " + "from thread {}.".format(threading.current_thread().getName()) + ) def check_connected(worker=global_worker): @@ -1072,9 +1154,11 @@ def check_connected(worker=global_worker): Exception: An exception is raised if the worker is not connected. """ if not worker.connected: - raise RayConnectionError("This command cannot be called before Ray " - "has been started. You can start Ray with " - "'ray.init()'.") + raise RayConnectionError( + "This command cannot be called before Ray " + "has been started. You can start Ray with " + "'ray.init()'." + ) def print_failed_task(task_status): @@ -1084,28 +1168,35 @@ def print_failed_task(task_status): task_status (Dict): A dictionary containing the name, operationid, and error message for a failed task. """ - print(""" + print( + """ Error: Task failed Function Name: {} Task ID: {} Error Message: \n{} - """.format(task_status["function_name"], task_status["operationid"], - task_status["error_message"])) + """.format( + task_status["function_name"], task_status["operationid"], + task_status["error_message"] + ) + ) def error_applies_to_driver(error_key, worker=global_worker): """Return True if the error is for this driver and false otherwise.""" # TODO(rkn): Should probably check that this is only called on a driver. # Check that the error key is formatted as in push_error_to_driver. - assert len(error_key) == (len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH + 1 + - ERROR_ID_LENGTH), error_key + assert len(error_key) == ( + len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH + 1 + ERROR_ID_LENGTH + ), error_key # If the driver ID in the error message is a sequence of all zeros, then # the message is intended for all drivers. generic_driver_id = DRIVER_ID_LENGTH * b"\x00" - driver_id = error_key[len(ERROR_KEY_PREFIX):( - len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH)] - return (driver_id == worker.task_driver_id.id() - or driver_id == generic_driver_id) + driver_id = error_key[len(ERROR_KEY_PREFIX): + (len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH)] + return ( + driver_id == worker.task_driver_id.id() + or driver_id == generic_driver_id + ) def error_info(worker=global_worker): @@ -1146,7 +1237,8 @@ def objectid_custom_deserializer(serialized_obj): "ray.ObjectID", pickle=False, custom_serializer=objectid_custom_serializer, - custom_deserializer=objectid_custom_deserializer) + custom_deserializer=objectid_custom_deserializer + ) if worker.mode in [SCRIPT_MODE, SILENT_MODE]: # These should only be called on the driver because @@ -1164,24 +1256,25 @@ def objectid_custom_deserializer(serialized_obj): # Tell Ray to serialize FunctionSignatures as dictionaries. This is # used when passing around actor handles. register_custom_serializer( - ray.signature.FunctionSignature, use_dict=True) + ray.signature.FunctionSignature, use_dict=True + ) -def get_address_info_from_redis_helper(redis_address, - node_ip_address, - use_raylet=False): +def get_address_info_from_redis_helper( + redis_address, node_ip_address, use_raylet=False +): redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine as # Redis) must have run "CONFIG SET protected-mode no". 
redis_client = redis.StrictRedis( - host=redis_ip_address, port=int(redis_port)) + host=redis_ip_address, port=int(redis_port) + ) if not use_raylet: # The client table prefix must be kept in sync with the file # "src/common/redis_module/ray_redis_module.cc" where it is defined. REDIS_CLIENT_TABLE_PREFIX = "CL:" - client_keys = redis_client.keys( - "{}*".format(REDIS_CLIENT_TABLE_PREFIX)) + client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) # Filter to live clients on the same node and do some basic checking. plasma_managers = [] local_schedulers = [] @@ -1198,9 +1291,12 @@ def get_address_info_from_redis_helper(redis_address, assert b"node_ip_address" in info assert b"client_type" in info client_node_ip_address = info[b"node_ip_address"].decode("ascii") - if (client_node_ip_address == node_ip_address or - (client_node_ip_address == "127.0.0.1" - and redis_ip_address == ray.services.get_node_ip_address())): + if ( + client_node_ip_address == node_ip_address or ( + client_node_ip_address == "127.0.0.1" + and redis_ip_address == ray.services.get_node_ip_address() + ) + ): if info[b"client_type"].decode("ascii") == "plasma_manager": plasma_managers.append(info) elif info[b"client_type"].decode("ascii") == "local_scheduler": @@ -1217,9 +1313,11 @@ def get_address_info_from_redis_helper(redis_address, object_store_addresses.append( services.ObjectStoreAddress( name=manager[b"store_socket_name"].decode("ascii"), - manager_name=manager[b"manager_socket_name"].decode( - "ascii"), - manager_port=port)) + manager_name=manager[b"manager_socket_name"] + .decode("ascii"), + manager_port=port + ) + ) scheduler_names = [ scheduler[b"local_scheduler_socket_name"].decode("ascii") for scheduler in local_schedulers @@ -1242,20 +1340,22 @@ def get_address_info_from_redis_helper(redis_address, clients = redis_client.zrange(client_key, 0, -1) raylets = [] for client_message in clients: - client = ClientTableData.GetRootAsClientTableData( - client_message, 0) - client_node_ip_address = client.NodeManagerAddress().decode( - "ascii") - if (client_node_ip_address == node_ip_address or - (client_node_ip_address == "127.0.0.1" - and redis_ip_address == ray.services.get_node_ip_address())): + client = ClientTableData.GetRootAsClientTableData(client_message, 0) + client_node_ip_address = client.NodeManagerAddress().decode("ascii") + if ( + client_node_ip_address == node_ip_address or ( + client_node_ip_address == "127.0.0.1" + and redis_ip_address == ray.services.get_node_ip_address() + ) + ): raylets.append(client) object_store_addresses = [ services.ObjectStoreAddress( name=raylet.ObjectStoreSocketName().decode("ascii"), manager_name=None, - manager_port=None) for raylet in raylets + manager_port=None + ) for raylet in raylets ] raylet_socket_names = [ raylet.RayletSocketName().decode("ascii") for raylet in raylets @@ -1270,29 +1370,32 @@ def get_address_info_from_redis_helper(redis_address, } -def get_address_info_from_redis(redis_address, - node_ip_address, - num_retries=5, - use_raylet=False): +def get_address_info_from_redis( + redis_address, node_ip_address, num_retries=5, use_raylet=False +): counter = 0 while True: try: return get_address_info_from_redis_helper( - redis_address, node_ip_address, use_raylet=use_raylet) + redis_address, node_ip_address, use_raylet=use_raylet + ) except Exception as e: if counter == num_retries: raise # Some of the information may not be in Redis yet, so wait a little # bit. 
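# --- Illustrative sketch (not part of the diff) ---------------------------
# get_address_info_from_redis_helper above keeps only the client-table
# entries that live on the requested node, treating "127.0.0.1" entries as
# local when Redis itself runs on this machine. local_ip stands in for
# ray.services.get_node_ip_address(); the predicate mirrors the hunk.
def client_is_local(client_ip, node_ip_address, redis_ip_address, local_ip):
    return client_ip == node_ip_address or (
        client_ip == "127.0.0.1" and redis_ip_address == local_ip)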
- print("Some processes that the driver needs to connect to have " - "not registered with Redis, so retrying. Have you run " - "'ray start' on this node?") + print( + "Some processes that the driver needs to connect to have " + "not registered with Redis, so retrying. Have you run " + "'ray start' on this node?" + ) time.sleep(1) counter += 1 -def _normalize_resource_arguments(num_cpus, num_gpus, resources, - num_local_schedulers): +def _normalize_resource_arguments( + num_cpus, num_gpus, resources, num_local_schedulers +): """Stick the CPU and GPU arguments into the resources dictionary. This also checks that the arguments are well-formed. @@ -1329,25 +1432,27 @@ def _normalize_resource_arguments(num_cpus, num_gpus, resources, return new_resources -def _init(address_info=None, - start_ray_local=False, - object_id_seed=None, - num_workers=None, - num_local_schedulers=None, - object_store_memory=None, - driver_mode=SCRIPT_MODE, - redirect_worker_output=False, - redirect_output=True, - start_workers_from_local_scheduler=True, - num_cpus=None, - num_gpus=None, - resources=None, - num_redis_shards=None, - redis_max_clients=None, - plasma_directory=None, - huge_pages=False, - include_webui=True, - use_raylet=False): +def _init( + address_info=None, + start_ray_local=False, + object_id_seed=None, + num_workers=None, + num_local_schedulers=None, + object_store_memory=None, + driver_mode=SCRIPT_MODE, + redirect_worker_output=False, + redirect_output=True, + start_workers_from_local_scheduler=True, + num_cpus=None, + num_gpus=None, + resources=None, + num_redis_shards=None, + redis_max_clients=None, + plasma_directory=None, + huge_pages=False, + include_webui=True, + use_raylet=False +): """Helper method to connect to an existing Ray cluster or start a new one. This method handles two cases. Either a Ray cluster already exists and we @@ -1414,8 +1519,10 @@ def _init(address_info=None, """ check_main_thread() if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]: - raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, " - "ray.PYTHON_MODE, ray.SILENT_MODE].") + raise Exception( + "Driver_mode must be in [ray.SCRIPT_MODE, " + "ray.PYTHON_MODE, ray.SILENT_MODE]." + ) # Get addresses of existing services. if address_info is None: @@ -1449,7 +1556,8 @@ def _init(address_info=None, # Stick the CPU and GPU resources into the resource dictionary. resources = _normalize_resource_arguments( - num_cpus, num_gpus, resources, num_local_schedulers) + num_cpus, num_gpus, resources, num_local_schedulers + ) # Start the scheduler, object store, and some workers. These will be # killed by the call to cleanup(), which happens when the Python script @@ -1463,51 +1571,74 @@ def _init(address_info=None, redirect_worker_output=redirect_worker_output, redirect_output=redirect_output, start_workers_from_local_scheduler=( - start_workers_from_local_scheduler), + start_workers_from_local_scheduler + ), resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, - use_raylet=use_raylet) + use_raylet=use_raylet + ) else: if redis_address is None: - raise Exception("When connecting to an existing cluster, " - "redis_address must be provided.") + raise Exception( + "When connecting to an existing cluster, " + "redis_address must be provided." 
+ ) if num_workers is not None: - raise Exception("When connecting to an existing cluster, " - "num_workers must not be provided.") + raise Exception( + "When connecting to an existing cluster, " + "num_workers must not be provided." + ) if num_local_schedulers is not None: - raise Exception("When connecting to an existing cluster, " - "num_local_schedulers must not be provided.") + raise Exception( + "When connecting to an existing cluster, " + "num_local_schedulers must not be provided." + ) if num_cpus is not None or num_gpus is not None: - raise Exception("When connecting to an existing cluster, num_cpus " - "and num_gpus must not be provided.") + raise Exception( + "When connecting to an existing cluster, num_cpus " + "and num_gpus must not be provided." + ) if resources is not None: - raise Exception("When connecting to an existing cluster, " - "resources must not be provided.") + raise Exception( + "When connecting to an existing cluster, " + "resources must not be provided." + ) if num_redis_shards is not None: - raise Exception("When connecting to an existing cluster, " - "num_redis_shards must not be provided.") + raise Exception( + "When connecting to an existing cluster, " + "num_redis_shards must not be provided." + ) if redis_max_clients is not None: - raise Exception("When connecting to an existing cluster, " - "redis_max_clients must not be provided.") + raise Exception( + "When connecting to an existing cluster, " + "redis_max_clients must not be provided." + ) if object_store_memory is not None: - raise Exception("When connecting to an existing cluster, " - "object_store_memory must not be provided.") + raise Exception( + "When connecting to an existing cluster, " + "object_store_memory must not be provided." + ) if plasma_directory is not None: - raise Exception("When connecting to an existing cluster, " - "plasma_directory must not be provided.") + raise Exception( + "When connecting to an existing cluster, " + "plasma_directory must not be provided." + ) if huge_pages: - raise Exception("When connecting to an existing cluster, " - "huge_pages must not be provided.") + raise Exception( + "When connecting to an existing cluster, " + "huge_pages must not be provided." + ) # Get the node IP address if one is not provided. if node_ip_address is None: node_ip_address = services.get_node_ip_address(redis_address) # Get the address info of the processes to connect to from Redis. address_info = get_address_info_from_redis( - redis_address, node_ip_address, use_raylet=use_raylet) + redis_address, node_ip_address, use_raylet=use_raylet + ) # Connect this driver to Redis, the object store, and the local scheduler. # Choose the first object store and local scheduler if there are multiple. 
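# --- Illustrative sketch (not part of the patch): the two call patterns that
# the _init()/init() code in the hunks above distinguishes. When no
# redis_address is given, a local cluster is started and the driver connects
# to it; when a redis_address is given, the driver only attaches, and per the
# checks above arguments such as num_workers, num_cpus, or plasma_directory
# must not be passed. The address and resource values below are made up.
import ray

# Start a new local cluster and connect this driver to it.
ray.init(num_cpus=4)

# Attach to an already running cluster instead (mutually exclusive with the
# call above, so it is left commented out here).
# ray.init(redis_address="127.0.0.1:6379")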
@@ -1525,39 +1656,45 @@ def _init(address_info=None, } if not use_raylet: driver_address_info["manager_socket_name"] = ( - address_info["object_store_addresses"][0].manager_name) + address_info["object_store_addresses"][0].manager_name + ) driver_address_info["local_scheduler_socket_name"] = ( - address_info["local_scheduler_socket_names"][0]) + address_info["local_scheduler_socket_names"][0] + ) else: driver_address_info["raylet_socket_name"] = ( - address_info["raylet_socket_names"][0]) + address_info["raylet_socket_names"][0] + ) connect( driver_address_info, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker, - use_raylet=use_raylet) + use_raylet=use_raylet + ) return address_info -def init(redis_address=None, - node_ip_address=None, - object_id_seed=None, - num_workers=None, - driver_mode=SCRIPT_MODE, - redirect_worker_output=False, - redirect_output=True, - num_cpus=None, - num_gpus=None, - resources=None, - num_custom_resource=None, - num_redis_shards=None, - redis_max_clients=None, - plasma_directory=None, - huge_pages=False, - include_webui=True, - object_store_memory=None, - use_raylet=False): +def init( + redis_address=None, + node_ip_address=None, + object_id_seed=None, + num_workers=None, + driver_mode=SCRIPT_MODE, + redirect_worker_output=False, + redirect_output=True, + num_cpus=None, + num_gpus=None, + resources=None, + num_custom_resource=None, + num_redis_shards=None, + redis_max_clients=None, + plasma_directory=None, + huge_pages=False, + include_webui=True, + object_store_memory=None, + use_raylet=False +): """Connect to an existing Ray cluster or start one and connect to it. This method handles two cases. Either a Ray cluster already exists and we @@ -1635,7 +1772,8 @@ def init(redis_address=None, huge_pages=huge_pages, include_webui=include_webui, object_store_memory=object_store_memory, - use_raylet=use_raylet) + use_raylet=use_raylet + ) def cleanup(worker=global_worker): @@ -1655,8 +1793,9 @@ def cleanup(worker=global_worker): if worker.mode in [SCRIPT_MODE, SILENT_MODE]: # If this is a driver, push the finish time to Redis and clean up any # other services that were started with the driver. - worker.redis_client.hmset(b"Drivers:" + worker.worker_id, - {"end_time": time.time()}) + worker.redis_client.hmset( + b"Drivers:" + worker.worker_id, {"end_time": time.time()} + ) services.cleanup() else: # If this is not a driver, make sure there are no orphan processes, @@ -1681,8 +1820,10 @@ def custom_excepthook(type, value, tb): # If this is a driver, push the exception to redis. if global_worker.mode in [SCRIPT_MODE, SILENT_MODE]: error_message = "".join(traceback.format_tb(tb)) - global_worker.redis_client.hmset(b"Drivers:" + global_worker.worker_id, - {"exception": error_message}) + global_worker.redis_client.hmset( + b"Drivers:" + global_worker.worker_id, + {"exception": error_message} + ) # Call the normal excepthook. 
normal_excepthook(type, value, tb) @@ -1722,8 +1863,8 @@ def print_error_messages(worker): error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) for error_key in error_keys: if error_applies_to_driver(error_key, worker=worker): - error_message = worker.redis_client.hget( - error_key, "message").decode("ascii") + error_message = worker.redis_client.hget(error_key, "message" + ).decode("ascii") print(error_message) print(helpful_message) num_errors_received += 1 @@ -1732,10 +1873,12 @@ def print_error_messages(worker): for msg in worker.error_message_pubsub_client.listen(): with worker.lock: for error_key in worker.redis_client.lrange( - "ErrorKeys", num_errors_received, -1): + "ErrorKeys", num_errors_received, -1 + ): if error_applies_to_driver(error_key, worker=worker): error_message = worker.redis_client.hget( - error_key, "message").decode("ascii") + error_key, "message" + ).decode("ascii") print(error_message) print(helpful_message) num_errors_received += 1 @@ -1747,18 +1890,22 @@ def print_error_messages(worker): def fetch_and_register_remote_function(key, worker=global_worker): """Import a remote function.""" - (driver_id, function_id_str, function_name, serialized_function, - num_return_vals, module, resources, - max_calls) = worker.redis_client.hmget(key, [ - "driver_id", "function_id", "name", "function", "num_return_vals", - "module", "resources", "max_calls" - ]) + ( + driver_id, function_id_str, function_name, serialized_function, + num_return_vals, module, resources, max_calls + ) = worker.redis_client.hmget( + key, [ + "driver_id", "function_id", "name", "function", "num_return_vals", + "module", "resources", "max_calls" + ] + ) function_id = ray.local_scheduler.ObjectID(function_id_str) function_name = function_name.decode("ascii") function_properties = FunctionProperties( num_return_vals=int(num_return_vals), resources=json.loads(resources.decode("ascii")), - max_calls=int(max_calls)) + max_calls=int(max_calls) + ) module = module.decode("ascii") # This is a placeholder in case the function can't be unpickled. This will @@ -1767,10 +1914,11 @@ def f(): raise Exception("This function was not imported properly.") remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f()) - worker.functions[driver_id][function_id.id()] = (function_name, - remote_f_placeholder) + worker.functions[driver_id][function_id.id() + ] = (function_name, remote_f_placeholder) worker.function_properties[driver_id][function_id.id()] = ( - function_properties) + function_properties + ) worker.num_task_executions[driver_id][function_id.id()] = 0 try: @@ -1788,24 +1936,30 @@ def f(): data={ "function_id": function_id.id(), "function_name": function_name - }) + } + ) else: # TODO(rkn): Why is the below line necessary? function.__module__ = module - worker.functions[driver_id][function_id.id()] = ( - function_name, remote(function_id=function_id)(function)) + worker.functions[driver_id][ + function_id.id() + ] = (function_name, remote(function_id=function_id)(function)) # Add the function to the function table. 
- worker.redis_client.rpush(b"FunctionTable:" + function_id.id(), - worker.worker_id) + worker.redis_client.rpush( + b"FunctionTable:" + function_id.id(), worker.worker_id + ) def fetch_and_execute_function_to_run(key, worker=global_worker): """Run on arbitrary function on the worker.""" driver_id, serialized_function = worker.redis_client.hmget( - key, ["driver_id", "function"]) + key, ["driver_id", "function"] + ) - if (worker.mode in [SCRIPT_MODE, SILENT_MODE] - and driver_id != worker.task_driver_id.id()): + if ( + worker.mode in [SCRIPT_MODE, SILENT_MODE] + and driver_id != worker.task_driver_id.id() + ): # This export was from a different driver and there's no need for this # driver to import it. return @@ -1820,14 +1974,16 @@ def fetch_and_execute_function_to_run(key, worker=global_worker): # traceback and notify the scheduler of the failure. traceback_str = traceback.format_exc() # Log the error message. - name = function.__name__ if ("function" in locals() - and hasattr(function, "__name__")) else "" + name = function.__name__ if ( + "function" in locals() and hasattr(function, "__name__") + ) else "" ray.utils.push_error_to_driver( worker.redis_client, "function_to_run", traceback_str, driver_id=driver_id, - data={"name": name}) + data={"name": name} + ) def import_thread(worker, mode): @@ -1881,24 +2037,29 @@ def import_thread(worker, mode): if mode != WORKER_MODE: if key.startswith(b"FunctionsToRun"): with log_span( - "ray:import_function_to_run", - worker=worker): + "ray:import_function_to_run", worker=worker + ): fetch_and_execute_function_to_run( - key, worker=worker) + key, worker=worker + ) # Continue because FunctionsToRun are the only things # that the driver should import. continue if key.startswith(b"RemoteFunction"): with log_span( - "ray:import_remote_function", worker=worker): + "ray:import_remote_function", worker=worker + ): fetch_and_register_remote_function( - key, worker=worker) + key, worker=worker + ) elif key.startswith(b"FunctionsToRun"): with log_span( - "ray:import_function_to_run", worker=worker): + "ray:import_function_to_run", worker=worker + ): fetch_and_execute_function_to_run( - key, worker=worker) + key, worker=worker + ) elif key.startswith(b"ActorClass"): # Keep track of the fact that this actor class has been # exported so that we know it is safe to turn this @@ -1915,11 +2076,13 @@ def import_thread(worker, mode): pass -def connect(info, - object_id_seed=None, - mode=WORKER_MODE, - worker=global_worker, - use_raylet=False): +def connect( + info, + object_id_seed=None, + mode=WORKER_MODE, + worker=global_worker, + use_raylet=False +): """Connect this worker to the local scheduler, to Plasma, and to Redis. Args: @@ -1963,7 +2126,8 @@ def connect(info, # Create a Redis client. redis_ip_address, redis_port = info["redis_address"].split(":") worker.redis_client = redis.StrictRedis( - host=redis_ip_address, port=int(redis_port)) + host=redis_ip_address, port=int(redis_port) + ) # For driver's check that the version information matches the version # information that the Ray cluster was started with. @@ -1978,7 +2142,8 @@ def connect(info, worker.redis_client, "version_mismatch", traceback_str, - driver_id=None) + driver_id=None + ) worker.lock = threading.Lock() @@ -1987,19 +2152,23 @@ def connect(info, if mode == WORKER_MODE: # This key is set in services.py when Redis is started. 
redirect_worker_output_val = worker.redis_client.get("RedirectOutput") - if (redirect_worker_output_val is not None - and int(redirect_worker_output_val) == 1): + if ( + redirect_worker_output_val is not None + and int(redirect_worker_output_val) == 1 + ): redirect_worker_output = 1 else: redirect_worker_output = 0 if redirect_worker_output: log_stdout_file, log_stderr_file = services.new_log_files( - "worker", True) + "worker", True + ) sys.stdout = log_stdout_file sys.stderr = log_stderr_file services.record_log_files_in_redis( info["redis_address"], info["node_ip_address"], - [log_stdout_file, log_stderr_file]) + [log_stdout_file, log_stderr_file] + ) # Create an object for interfacing with the global state. global_state._initialize_global_state(redis_ip_address, int(redis_port)) @@ -2018,8 +2187,9 @@ def connect(info, "local_scheduler_socket": info.get("local_scheduler_socket_name"), "raylet_socket": info.get("raylet_socket_name") } - driver_info["name"] = (main.__file__ if hasattr(main, "__file__") else - "INTERACTIVE MODE") + driver_info["name"] = ( + main.__file__ if hasattr(main, "__file__") else "INTERACTIVE MODE" + ) worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info) if not worker.redis_client.exists("webui"): worker.redis_client.hmset("webui", {"url": info["webui_url"]}) @@ -2042,11 +2212,11 @@ def connect(info, # Create an object store client. if not worker.use_raylet: - worker.plasma_client = plasma.connect(info["store_socket_name"], - info["manager_socket_name"], 64) + worker.plasma_client = plasma.connect( + info["store_socket_name"], info["manager_socket_name"], 64 + ) else: - worker.plasma_client = plasma.connect(info["store_socket_name"], "", - 64) + worker.plasma_client = plasma.connect(info["store_socket_name"], "", 64) if not worker.use_raylet: local_scheduler_socket = info["local_scheduler_socket_name"] @@ -2054,7 +2224,8 @@ def connect(info, local_scheduler_socket = info["raylet_socket_name"] worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient( - local_scheduler_socket, worker.worker_id, is_worker) + local_scheduler_socket, worker.worker_id, is_worker + ) # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -2070,7 +2241,8 @@ def connect(info, # Try to use true randomness. np.random.seed(None) worker.current_task_id = ray.local_scheduler.ObjectID( - np.random.bytes(20)) + np.random.bytes(20) + ) # When tasks are executed on remote workers in the context of multiple # drivers, the task driver ID is used to keep track of which driver is # responsible for the task so that error messages will be propagated to @@ -2098,13 +2270,15 @@ def connect(info, ray.local_scheduler.ObjectID(NIL_ACTOR_ID), ray.local_scheduler.ObjectID(NIL_ACTOR_ID), ray.local_scheduler.ObjectID(NIL_ACTOR_ID), nil_actor_counter, - False, [], {"CPU": 0}, worker.use_raylet) + False, [], {"CPU": 0}, worker.use_raylet + ) global_state._execute_command( driver_task.task_id(), "RAY.TASK_TABLE_ADD", driver_task.task_id().id(), TASK_STATUS_RUNNING, NIL_LOCAL_SCHEDULER_ID, driver_task.execution_dependencies_string(), 0, - ray.local_scheduler.task_to_string(driver_task)) + ray.local_scheduler.task_to_string(driver_task) + ) # Set the driver's current task ID to the task ID assigned to the # driver task. 
worker.current_task_id = driver_task.task_id() @@ -2143,9 +2317,11 @@ def connect(info, script_directory = os.path.abspath(os.path.dirname(sys.argv[0])) current_directory = os.path.abspath(os.path.curdir) worker.run_function_on_all_workers( - lambda worker_info: sys.path.insert(1, script_directory)) + lambda worker_info: sys.path.insert(1, script_directory) + ) worker.run_function_on_all_workers( - lambda worker_info: sys.path.insert(1, current_directory)) + lambda worker_info: sys.path.insert(1, current_directory) + ) # TODO(rkn): Here we first export functions to run, then remote # functions. The order matters. For example, one of the functions to # run may set the Python path, which is needed to import a module used @@ -2161,15 +2337,19 @@ def connect(info, # Export cached remote functions to the workers. for cached_type, info in worker.cached_remote_functions_and_actors: if cached_type == "remote_function": - (function_id, func_name, func, func_invoker, - function_properties) = info - export_remote_function(function_id, func_name, func, - func_invoker, function_properties, - worker) + ( + function_id, func_name, func, func_invoker, + function_properties + ) = info + export_remote_function( + function_id, func_name, func, func_invoker, + function_properties, worker + ) elif cached_type == "actor": (key, actor_class_info) = info - ray.actor.publish_actor_class_to_key(key, actor_class_info, - worker) + ray.actor.publish_actor_class_to_key( + key, actor_class_info, worker + ) else: assert False, "This code should be unreachable." worker.cached_functions_to_run = None @@ -2227,17 +2407,20 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): print( "WARNING: Could not produce a deterministic class ID for class " "{}".format(cls), - file=sys.stderr) + file=sys.stderr + ) return hashlib.sha1(new_class_id).digest() -def register_custom_serializer(cls, - use_pickle=False, - use_dict=False, - serializer=None, - deserializer=None, - local=False, - worker=global_worker): +def register_custom_serializer( + cls, + use_pickle=False, + use_dict=False, + serializer=None, + deserializer=None, + local=False, + worker=global_worker +): """Enable serialization and deserialization for a particular class. This method runs the register_class function defined below on every worker, @@ -2265,12 +2448,14 @@ def register_custom_serializer(cls, """ assert (serializer is None) == (deserializer is None), ( "The serializer/deserializer arguments must both be provided or " - "both not be provided.") + "both not be provided." + ) use_custom_serializer = (serializer is not None) assert use_custom_serializer + use_pickle + use_dict == 1, ( "Exactly one of use_pickle, use_dict, or serializer/deserializer must " - "be specified.") + "be specified." + ) if use_dict: # Raise an exception if cls cannot be serialized efficiently by Ray. @@ -2289,8 +2474,10 @@ def register_custom_serializer(cls, # may be different on different workers. class_id = _try_to_compute_deterministic_class_id(cls) except Exception as e: - raise serialization.CloudPickleError("Failed to pickle class " - "'{}'".format(cls)) + raise serialization.CloudPickleError( + "Failed to pickle class " + "'{}'".format(cls) + ) else: # In this case, the class ID only needs to be meaningful on this worker # and not across workers. 
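# --- Illustrative sketch (not part of the patch): a minimal use of the
# register_custom_serializer() API whose argument checks are reformatted in
# the hunk above. Those assertions require serializer and deserializer to be
# supplied together, and exactly one of use_pickle, use_dict, or that pair to
# be chosen. The Foo class and the lambdas are hypothetical.
import ray


class Foo(object):
    def __init__(self, value):
        self.value = value


ray.init()

ray.worker.register_custom_serializer(
    Foo,
    serializer=lambda obj: obj.value,
    deserializer=lambda value: Foo(value)
)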
@@ -2307,7 +2494,8 @@ def register_class_for_serialization(worker_info): class_id, pickle=use_pickle, custom_serializer=serializer, - custom_deserializer=deserializer) + custom_deserializer=deserializer + ) if not local: worker.run_function_on_all_workers(register_class_for_serialization) @@ -2333,26 +2521,32 @@ def __init__(self, event_type, contents=None, worker=global_worker): def __enter__(self): """Log the beginning of a span event.""" - log(event_type=self.event_type, + log( + event_type=self.event_type, contents=self.contents, kind=LOG_SPAN_START, - worker=self.worker) + worker=self.worker + ) def __exit__(self, type, value, tb): """Log the end of a span event. Log any exception that occurred.""" if type is None: - log(event_type=self.event_type, + log( + event_type=self.event_type, kind=LOG_SPAN_END, - worker=self.worker) + worker=self.worker + ) else: - log(event_type=self.event_type, + log( + event_type=self.event_type, contents={ "type": str(type), "value": value, "traceback": traceback.format_exc() }, kind=LOG_SPAN_END, - worker=self.worker) + worker=self.worker + ) def log_span(event_type, contents=None, worker=global_worker): @@ -2396,8 +2590,9 @@ def flush_log(worker=global_worker): event_log_key = b"event_log:" + worker.worker_id event_log_value = json.dumps(worker.events) if not worker.use_raylet: - worker.local_scheduler_client.log_event(event_log_key, event_log_value, - time.time()) + worker.local_scheduler_client.log_event( + event_log_key, event_log_value, time.time() + ) worker.events = [] @@ -2458,7 +2653,8 @@ def put(value, worker=global_worker): # In PYTHON_MODE, ray.put is the identity operation. return value object_id = worker.local_scheduler_client.compute_put_id( - worker.current_task_id, worker.put_index, worker.use_raylet) + worker.current_task_id, worker.put_index, worker.use_raylet + ) worker.put_object(object_id, value) worker.put_index += 1 return object_id @@ -2493,18 +2689,23 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): if isinstance(object_ids, ray.local_scheduler.ObjectID): raise TypeError( - "wait() expected a list of ObjectID, got a single ObjectID") + "wait() expected a list of ObjectID, got a single ObjectID" + ) if not isinstance(object_ids, list): - raise TypeError("wait() expected a list of ObjectID, got {}".format( - type(object_ids))) + raise TypeError( + "wait() expected a list of ObjectID, got {}".format( + type(object_ids) + ) + ) if worker.mode != PYTHON_MODE: for object_id in object_ids: if not isinstance(object_id, ray.local_scheduler.ObjectID): - raise TypeError("wait() expected a list of ObjectID, " - "got list containing {}".format( - type(object_id))) + raise TypeError( + "wait() expected a list of ObjectID, " + "got list containing {}".format(type(object_id)) + ) check_connected(worker) with log_span("ray:wait", worker=worker): @@ -2526,7 +2727,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): ] timeout = timeout if timeout is not None else 2**30 ready_ids, remaining_ids = worker.plasma_client.wait( - object_id_strs, timeout, num_returns) + object_id_strs, timeout, num_returns + ) ready_ids = [ ray.local_scheduler.ObjectID(object_id.binary()) for object_id in ready_ids @@ -2560,19 +2762,23 @@ def _mode(worker=global_worker): return worker.mode -def export_remote_function(function_id, - func_name, - func, - func_invoker, - function_properties, - worker=global_worker): +def export_remote_function( + function_id, + func_name, + func, + func_invoker, + function_properties, + 
worker=global_worker +): check_main_thread() if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: - raise Exception("export_remote_function can only be called on a " - "driver.") + raise Exception( + "export_remote_function can only be called on a " + "driver." + ) - worker.function_properties[worker.task_driver_id.id()][ - function_id.id()] = function_properties + worker.function_properties[worker.task_driver_id.id()][function_id.id() + ] = function_properties task_driver_id = worker.task_driver_id key = b"RemoteFunction:" + task_driver_id.id() + b":" + function_id.id() @@ -2592,7 +2798,8 @@ def export_remote_function(function_id, del func.__globals__[func.__name__] worker.redis_client.hmset( - key, { + key, + { "driver_id": worker.task_driver_id.id(), "function_id": function_id.id(), "name": func_name, @@ -2601,7 +2808,8 @@ def export_remote_function(function_id, "num_return_vals": function_properties.num_return_vals, "resources": json.dumps(function_properties.resources), "max_calls": function_properties.max_calls - }) + } + ) worker.redis_client.rpush("Exports", key) @@ -2661,27 +2869,35 @@ def remote(*args, **kwargs): """ worker = global_worker - def make_remote_decorator(num_return_vals, - num_cpus, - num_gpus, - resources, - max_calls, - checkpoint_interval, - func_id=None): + def make_remote_decorator( + num_return_vals, + num_cpus, + num_gpus, + resources, + max_calls, + checkpoint_interval, + func_id=None + ): def remote_decorator(func_or_class): if inspect.isfunction(func_or_class) or is_cython(func_or_class): # Set the remote function default resources. - resources["CPU"] = (DEFAULT_REMOTE_FUNCTION_CPUS - if num_cpus is None else num_cpus) - resources["GPU"] = (DEFAULT_REMOTE_FUNCTION_GPUS - if num_gpus is None else num_gpus) + resources["CPU"] = ( + DEFAULT_REMOTE_FUNCTION_CPUS + if num_cpus is None else num_cpus + ) + resources["GPU"] = ( + DEFAULT_REMOTE_FUNCTION_GPUS + if num_gpus is None else num_gpus + ) function_properties = FunctionProperties( num_return_vals=num_return_vals, resources=resources, - max_calls=max_calls) - return remote_function_decorator(func_or_class, - function_properties) + max_calls=max_calls + ) + return remote_function_decorator( + func_or_class, function_properties + ) if inspect.isclass(func_or_class): # Set the actor default resources. if num_cpus is None and num_gpus is None and resources == {}: @@ -2695,18 +2911,24 @@ def remote_decorator(func_or_class): # associated with methods. resources["CPU"] = ( DEFAULT_ACTOR_CREATION_CPUS_SPECIFIED_CASE - if num_cpus is None else num_cpus) + if num_cpus is None else num_cpus + ) resources["GPU"] = ( DEFAULT_ACTOR_CREATION_GPUS_SPECIFIED_CASE - if num_gpus is None else num_gpus) + if num_gpus is None else num_gpus + ) actor_method_cpus = ( - DEFAULT_ACTOR_METHOD_CPUS_SPECIFIED_CASE) - - return worker.make_actor(func_or_class, resources, - checkpoint_interval, - actor_method_cpus) - raise Exception("The @ray.remote decorator must be applied to " - "either a function or to a class.") + DEFAULT_ACTOR_METHOD_CPUS_SPECIFIED_CASE + ) + + return worker.make_actor( + func_or_class, resources, checkpoint_interval, + actor_method_cpus + ) + raise Exception( + "The @ray.remote decorator must be applied to " + "either a function or to a class." 
+ ) def remote_function_decorator(func, function_properties): func_name = "{}.{}".format(func.__module__, func.__name__) @@ -2719,12 +2941,14 @@ def func_call(*args, **kwargs): """This runs immediately when a remote function is called.""" return _submit(args=args, kwargs=kwargs) - def _submit(args=None, - kwargs=None, - num_return_vals=None, - num_cpus=None, - num_gpus=None, - resources=None): + def _submit( + args=None, + kwargs=None, + num_return_vals=None, + num_cpus=None, + num_gpus=None, + resources=None + ): """An experimental alternate way to submit remote functions.""" check_connected() check_main_thread() @@ -2744,7 +2968,8 @@ def _submit(args=None, num_return_vals=num_return_vals, num_cpus=num_cpus, num_gpus=num_gpus, - resources=resources) + resources=resources + ) if len(object_ids) == 1: return object_ids[0] elif len(object_ids) > 1: @@ -2757,9 +2982,11 @@ def func_executor(arguments): def func_invoker(*args, **kwargs): """This is used to invoke the function.""" - raise Exception("Remote functions cannot be called directly. " - "Instead of running '{}()', try '{}.remote()'." - .format(func_name, func_name)) + raise Exception( + "Remote functions cannot be called directly. " + "Instead of running '{}()', try '{}.remote()'." + .format(func_name, func_name) + ) func_invoker.remote = func_call func_invoker._submit = _submit @@ -2777,12 +3004,18 @@ def func_invoker(*args, **kwargs): # Everything ready - export the function if worker.mode in [SCRIPT_MODE, SILENT_MODE]: - export_remote_function(function_id, func_name, func, - func_invoker, function_properties) + export_remote_function( + function_id, func_name, func, func_invoker, + function_properties + ) elif worker.mode is None: - worker.cached_remote_functions_and_actors.append( - ("remote_function", (function_id, func_name, func, - func_invoker, function_properties))) + worker.cached_remote_functions_and_actors.append(( + "remote_function", + ( + function_id, func_name, func, func_invoker, + function_properties + ) + )) return func_invoker return remote_decorator @@ -2792,40 +3025,49 @@ def func_invoker(*args, **kwargs): num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else None resources = kwargs.get("resources", {}) if not isinstance(resources, dict): - raise Exception("The 'resources' keyword argument must be a " - "dictionary, but received type {}.".format( - type(resources))) + raise Exception( + "The 'resources' keyword argument must be a " + "dictionary, but received type {}.".format(type(resources)) + ) assert "CPU" not in resources, "Use the 'num_cpus' argument." assert "GPU" not in resources, "Use the 'num_gpus' argument." # Handle other arguments. 
- num_return_vals = (kwargs["num_return_vals"] - if "num_return_vals" in kwargs else 1) + num_return_vals = ( + kwargs["num_return_vals"] if "num_return_vals" in kwargs else 1 + ) max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0 - checkpoint_interval = (kwargs["checkpoint_interval"] - if "checkpoint_interval" in kwargs else -1) + checkpoint_interval = ( + kwargs["checkpoint_interval"] if "checkpoint_interval" in kwargs else -1 + ) if _mode() == WORKER_MODE: if "function_id" in kwargs: function_id = kwargs["function_id"] - return make_remote_decorator(num_return_vals, num_cpus, num_gpus, - resources, max_calls, - checkpoint_interval, function_id) + return make_remote_decorator( + num_return_vals, num_cpus, num_gpus, resources, max_calls, + checkpoint_interval, function_id + ) if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): # This is the case where the decorator is just @ray.remote. - return make_remote_decorator(num_return_vals, num_cpus, num_gpus, - resources, max_calls, - checkpoint_interval)(args[0]) + return make_remote_decorator( + num_return_vals, num_cpus, num_gpus, resources, max_calls, + checkpoint_interval + )( + args[0] + ) else: # This is the case where the decorator is something like # @ray.remote(num_return_vals=2). - error_string = ("The @ray.remote decorator must be applied either " - "with no arguments and no parentheses, for example " - "'@ray.remote', or it must be applied using some of " - "the arguments 'num_return_vals', 'resources', " - "or 'max_calls', like " - "'@ray.remote(num_return_vals=2, " - "resources={\"GPU\": 1})'.") + error_string = ( + "The @ray.remote decorator must be applied either " + "with no arguments and no parentheses, for example " + "'@ray.remote', or it must be applied using some of " + "the arguments 'num_return_vals', 'resources', " + "or 'max_calls', like " + "'@ray.remote(num_return_vals=2, " + "resources={\"GPU\": 1})'." 
+ ) assert len(args) == 0 and len(kwargs) > 0, error_string for key in kwargs: assert key in [ @@ -2833,5 +3075,7 @@ def func_invoker(*args, **kwargs): "max_calls", "checkpoint_interval" ], error_string assert "function_id" not in kwargs - return make_remote_decorator(num_return_vals, num_cpus, num_gpus, - resources, max_calls, checkpoint_interval) + return make_remote_decorator( + num_return_vals, num_cpus, num_gpus, resources, max_calls, + checkpoint_interval + ) diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 3e761a9d4c77..6f15f852dea4 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -10,34 +10,41 @@ parser = argparse.ArgumentParser( description=("Parse addresses for the worker " - "to connect to.")) + "to connect to.") +) parser.add_argument( "--node-ip-address", required=True, type=str, - help="the ip address of the worker's node") + help="the ip address of the worker's node" +) parser.add_argument( "--redis-address", required=True, type=str, - help="the address to use for Redis") + help="the address to use for Redis" +) parser.add_argument( "--object-store-name", required=True, type=str, - help="the object store's name") + help="the object store's name" +) parser.add_argument( "--object-store-manager-name", required=False, type=str, - help="the object store manager's name") + help="the object store manager's name" +) parser.add_argument( "--local-scheduler-name", required=False, type=str, - help="the local scheduler's name") + help="the local scheduler's name" +) parser.add_argument( - "--raylet-name", required=False, type=str, help="the raylet's name") + "--raylet-name", required=False, type=str, help="the raylet's name" +) if __name__ == "__main__": args = parser.parse_args() @@ -52,7 +59,8 @@ } ray.worker.connect( - info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None)) + info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None) + ) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker @@ -72,7 +80,8 @@ # Create a Redis client. redis_client = ray.services.create_redis_client(args.redis_address) ray.utils.push_error_to_driver( - redis_client, "worker_crash", traceback_str, driver_id=None) + redis_client, "worker_crash", traceback_str, driver_id=None + ) # TODO(rkn): Note that if the worker was in the middle of executing # a task, then any worker or driver that is blocking in a get call # and waiting for the output of that task will hang. We need to diff --git a/python/setup.py b/python/setup.py index 85e639c50c92..aada4bf21b31 100644 --- a/python/setup.py +++ b/python/setup.py @@ -70,8 +70,8 @@ def run(self): pyarrow_files = [ os.path.join("ray/pyarrow_files/pyarrow", filename) for filename in os.listdir("./ray/pyarrow_files/pyarrow") - if not os.path.isdir( - os.path.join("ray/pyarrow_files/pyarrow", filename)) + if not os.path. + isdir(os.path.join("ray/pyarrow_files/pyarrow", filename)) ] files_to_include = ray_files + pyarrow_files @@ -83,15 +83,18 @@ def run(self): for filename in os.listdir(generated_python_directory): if filename[-3:] == ".py": self.move_file( - os.path.join(generated_python_directory, filename)) + os.path.join(generated_python_directory, filename) + ) # Try to copy over the optional files. for filename in optional_ray_files: try: self.move_file(filename) except Exception as e: - print("Failed to copy optional file {}. This is ok." - .format(filename)) + print( + "Failed to copy optional file {}. 
This is ok." + .format(filename) + ) def move_file(self, filename): # TODO(rkn): This feels very brittle. It may not handle all cases. See @@ -138,4 +141,5 @@ def has_ext_modules(self): entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]}, include_package_data=True, zip_safe=False, - license="Apache 2.0") + license="Apache 2.0" +) diff --git a/test/actor_test.py b/test/actor_test.py index 7e040185b9fc..5708427de565 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -39,21 +39,24 @@ def get_values(self, arg0, arg1=2, arg2="b"): actor = Actor.remote(1, 2, "c") self.assertEqual( - ray.get(actor.get_values.remote(2, 3, "d")), (3, 5, "cd")) + ray.get(actor.get_values.remote(2, 3, "d")), (3, 5, "cd") + ) actor = Actor.remote(1, arg2="c") self.assertEqual( - ray.get(actor.get_values.remote(0, arg2="d")), (1, 3, "cd")) + ray.get(actor.get_values.remote(0, arg2="d")), (1, 3, "cd") + ) self.assertEqual( - ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), - (1, 1, "cd")) + ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), (1, 1, "cd") + ) actor = Actor.remote(1, arg2="c", arg1=2) self.assertEqual( - ray.get(actor.get_values.remote(0, arg2="d")), (1, 4, "cd")) + ray.get(actor.get_values.remote(0, arg2="d")), (1, 4, "cd") + ) self.assertEqual( - ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), - (1, 2, "cd")) + ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), (1, 2, "cd") + ) # Make sure we get an exception if the constructor is called # incorrectly. @@ -85,18 +88,19 @@ def get_values(self, arg0, arg1=2, *args): self.assertEqual(ray.get(actor.get_values.remote(1)), (1, 3, (), ())) actor = Actor.remote(1, 2) - self.assertEqual( - ray.get(actor.get_values.remote(2, 3)), (3, 5, (), ())) + self.assertEqual(ray.get(actor.get_values.remote(2, 3)), (3, 5, (), ())) actor = Actor.remote(1, 2, "c") self.assertEqual( - ray.get(actor.get_values.remote(2, 3, "d")), (3, 5, ("c", ), - ("d", ))) + ray.get(actor.get_values.remote(2, 3, "d")), + (3, 5, ("c", ), ("d", )) + ) actor = Actor.remote(1, 2, "a", "b", "c", "d") self.assertEqual( ray.get(actor.get_values.remote(2, 3, 1, 2, 3, 4)), - (3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4))) + (3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4)) + ) @ray.remote class Actor(object): @@ -429,7 +433,8 @@ def reset(self): for i in range(num_actors): self.assertEqual( result_values[(num_increases * i):(num_increases * (i + 1))], - list(range(i + 1, num_increases + i + 1))) + list(range(i + 1, num_increases + i + 1)) + ) # Reset the actor values. [actor.reset.remote() for actor in actors] @@ -442,7 +447,8 @@ def reset(self): for j in range(num_increases): self.assertEqual( result_values[(num_actors * j):(num_actors * (j + 1))], - num_actors * [j + 1]) + num_actors * [j + 1] + ) class ActorNesting(unittest.TestCase): @@ -496,7 +502,8 @@ def h(self, object_ids): self.assertEqual(ray.get(actor.g.remote()), list(range(1, 6))) self.assertEqual( ray.get(actor.h.remote([f.remote(i) for i in range(5)])), - list(range(1, 6))) + list(range(1, 6)) + ) def testDefineActorWithinActor(self): # Make sure we can use remote funtions within actors. @@ -569,7 +576,8 @@ def get_value(self): self.assertEqual(ray.get(f.remote(3, 1)), [3]) self.assertEqual( ray.get([f.remote(i, 20) for i in range(10)]), - [20 * [i] for i in range(10)]) + [20 * [i] for i in range(10)] + ) def testUseActorWithinRemoteFunction(self): # Make sure we can create and use actors within remote funtions. 
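# --- Illustrative sketch (not part of the patch): a standalone version of the
# actor pattern the tests above exercise -- declare a class with @ray.remote,
# instantiate it with .remote(), and call its methods with .remote(). The
# Counter class and the values used are made up for illustration.
import ray

ray.init()


@ray.remote
class Counter(object):
    def __init__(self, start=0):
        self.value = start

    def inc(self):
        self.value += 1
        return self.value


counter = Counter.remote(10)
assert ray.get(counter.inc.remote()) == 11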
@@ -701,7 +709,8 @@ def testActorLoadBalancing(self): ray.worker._init( start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers) + num_local_schedulers=num_local_schedulers + ) @ray.remote class Actor1(object): @@ -720,13 +729,16 @@ def get_location(self): attempts = 0 while attempts < num_attempts: actors = [Actor1.remote() for _ in range(num_actors)] - locations = ray.get( - [actor.get_location.remote() for actor in actors]) + locations = ray.get([ + actor.get_location.remote() for actor in actors + ]) names = set(locations) counts = [locations.count(name) for name in names] print("Counts are {}.".format(counts)) - if (len(names) == num_local_schedulers - and all([count >= minimum_count for count in counts])): + if ( + len(names) == num_local_schedulers + and all([count >= minimum_count for count in counts]) + ): break attempts += 1 self.assertLess(attempts, num_attempts) @@ -744,7 +756,8 @@ def tearDown(self): ray.worker.cleanup() @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Crashing with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Crashing with new GCS API." + ) def testActorGPUs(self): num_local_schedulers = 3 num_gpus_per_scheduler = 4 @@ -753,7 +766,8 @@ def testActorGPUs(self): num_workers=0, num_local_schedulers=num_local_schedulers, num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]), - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]) + ) @ray.remote(num_gpus=1) class Actor1(object): @@ -764,7 +778,8 @@ def get_location_and_ids(self): assert ray.get_gpu_ids() == self.gpu_ids return ( ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) + tuple(self.gpu_ids) + ) # Create one actor per GPU. actors = [ @@ -772,8 +787,9 @@ def get_location_and_ids(self): for _ in range(num_local_schedulers * num_gpus_per_scheduler) ] # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get( - [actor.get_location_and_ids.remote() for actor in actors]) + locations_and_ids = ray.get([ + actor.get_location_and_ids.remote() for actor in actors + ]) node_names = set([location for location, gpu_id in locations_and_ids]) self.assertEqual(len(node_names), num_local_schedulers) location_actor_combinations = [] @@ -781,7 +797,8 @@ def get_location_and_ids(self): for gpu_id in range(num_gpus_per_scheduler): location_actor_combinations.append((node_name, (gpu_id, ))) self.assertEqual( - set(locations_and_ids), set(location_actor_combinations)) + set(locations_and_ids), set(location_actor_combinations) + ) # Creating a new actor should fail because all of the GPUs are being # used. @@ -797,7 +814,8 @@ def testActorMultipleGPUs(self): num_workers=0, num_local_schedulers=num_local_schedulers, num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]), - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]) + ) @ray.remote(num_gpus=2) class Actor1(object): @@ -808,13 +826,15 @@ def get_location_and_ids(self): assert ray.get_gpu_ids() == self.gpu_ids return ( ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) + tuple(self.gpu_ids) + ) # Create some actors. actors1 = [Actor1.remote() for _ in range(num_local_schedulers * 2)] # Make sure that no two actors are assigned to the same GPU. 
- locations_and_ids = ray.get( - [actor.get_location_and_ids.remote() for actor in actors1]) + locations_and_ids = ray.get([ + actor.get_location_and_ids.remote() for actor in actors1 + ]) node_names = set([location for location, gpu_id in locations_and_ids]) self.assertEqual(len(node_names), num_local_schedulers) @@ -840,16 +860,19 @@ def __init__(self): def get_location_and_ids(self): return ( ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) + tuple(self.gpu_ids) + ) # Create some actors. actors2 = [Actor2.remote() for _ in range(num_local_schedulers)] # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get( - [actor.get_location_and_ids.remote() for actor in actors2]) + locations_and_ids = ray.get([ + actor.get_location_and_ids.remote() for actor in actors2 + ]) self.assertEqual( node_names, - set([location for location, gpu_id in locations_and_ids])) + set([location for location, gpu_id in locations_and_ids]) + ) for location, gpu_ids in locations_and_ids: gpus_in_use[location].extend(gpu_ids) for node_name in node_names: @@ -870,7 +893,8 @@ def testActorDifferentNumbersOfGPUs(self): num_workers=0, num_local_schedulers=3, num_cpus=[10, 10, 10], - num_gpus=[0, 5, 10]) + num_gpus=[0, 5, 10] + ) @ray.remote(num_gpus=1) class Actor1(object): @@ -880,13 +904,15 @@ def __init__(self): def get_location_and_ids(self): return ( ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) + tuple(self.gpu_ids) + ) # Create some actors. actors = [Actor1.remote() for _ in range(0 + 5 + 10)] # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get( - [actor.get_location_and_ids.remote() for actor in actors]) + locations_and_ids = ray.get([ + actor.get_location_and_ids.remote() for actor in actors + ]) node_names = set([location for location, gpu_id in locations_and_ids]) self.assertEqual(len(node_names), 2) for node_name in node_names: @@ -897,7 +923,8 @@ def get_location_and_ids(self): self.assertIn(len(node_gpu_ids), [5, 10]) self.assertEqual( set(node_gpu_ids), - set([(i, ) for i in range(len(node_gpu_ids))])) + set([(i, ) for i in range(len(node_gpu_ids))]) + ) # Creating a new actor should fail because all of the GPUs are being # used. @@ -914,7 +941,8 @@ def testActorMultipleGPUsFromMultipleTasks(self): num_local_schedulers=num_local_schedulers, redirect_output=True, num_cpus=(num_local_schedulers * [10 * num_gpus_per_scheduler]), - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]) + ) @ray.remote def create_actors(n): @@ -924,8 +952,9 @@ def __init__(self): self.gpu_ids = ray.get_gpu_ids() def get_location_and_ids(self): - return ((ray.worker.global_worker.plasma_client. - store_socket_name), tuple(self.gpu_ids)) + return (( + ray.worker.global_worker.plasma_client.store_socket_name + ), tuple(self.gpu_ids)) # Create n actors. for _ in range(n): @@ -944,7 +973,8 @@ def __init__(self): def get_location_and_ids(self): return ( ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) + tuple(self.gpu_ids) + ) # All the GPUs should be used up now. 
a = Actor.remote() @@ -960,7 +990,8 @@ def testActorsAndTasksWithGPUs(self): num_workers=0, num_local_schedulers=num_local_schedulers, num_cpus=num_gpus_per_scheduler, - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler]) + ) def check_intervals_non_overlapping(list_of_intervals): for i in range(len(list_of_intervals)): @@ -973,10 +1004,13 @@ def check_intervals_non_overlapping(list_of_intervals): self.assertLess(second_interval[0], second_interval[1]) intervals_nonoverlapping = ( first_interval[1] <= second_interval[0] - or second_interval[1] <= first_interval[0]) + or second_interval[1] <= first_interval[0] + ) assert intervals_nonoverlapping, ( "Intervals {} and {} are overlapping.".format( - first_interval, second_interval)) + first_interval, second_interval + ) + ) @ray.remote(num_gpus=1) def f1(): @@ -986,8 +1020,11 @@ def f1(): gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 1 assert gpu_ids[0] in range(num_gpus_per_scheduler) - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(gpu_ids), [t1, t2]) + return ( + ray.worker.global_worker.plasma_client.store_socket_name, + tuple(gpu_ids), + [t1, t2] + ) @ray.remote(num_gpus=2) def f2(): @@ -998,8 +1035,11 @@ def f2(): assert len(gpu_ids) == 2 assert gpu_ids[0] in range(num_gpus_per_scheduler) assert gpu_ids[1] in range(num_gpus_per_scheduler) - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(gpu_ids), [t1, t2]) + return ( + ray.worker.global_worker.plasma_client.store_socket_name, + tuple(gpu_ids), + [t1, t2] + ) @ray.remote(num_gpus=1) class Actor1(object): @@ -1012,19 +1052,20 @@ def get_location_and_ids(self): assert ray.get_gpu_ids() == self.gpu_ids return ( ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) + tuple(self.gpu_ids) + ) def locations_to_intervals_for_many_tasks(): # Launch a bunch of GPU tasks. locations_ids_and_intervals = ray.get([ - f1.remote() for _ in range( - 5 * num_local_schedulers * num_gpus_per_scheduler) + f1.remote() for _ in + range(5 * num_local_schedulers * num_gpus_per_scheduler) ] + [ - f2.remote() for _ in range( - 5 * num_local_schedulers * num_gpus_per_scheduler) + f2.remote() for _ in + range(5 * num_local_schedulers * num_gpus_per_scheduler) ] + [ - f1.remote() for _ in range( - 5 * num_local_schedulers * num_gpus_per_scheduler) + f1.remote() for _ in + range(5 * num_local_schedulers * num_gpus_per_scheduler) ]) locations_to_intervals = collections.defaultdict(lambda: []) @@ -1038,7 +1079,8 @@ def locations_to_intervals_for_many_tasks(): # Make sure that all GPUs were used. self.assertEqual( len(locations_to_intervals), - num_local_schedulers * num_gpus_per_scheduler) + num_local_schedulers * num_gpus_per_scheduler + ) # For each GPU, verify that the set of tasks that used this specific # GPU did not overlap in time. for locations in locations_to_intervals: @@ -1057,7 +1099,8 @@ def locations_to_intervals_for_many_tasks(): # Make sure that all but one of the GPUs were used. self.assertEqual( len(locations_to_intervals), - num_local_schedulers * num_gpus_per_scheduler - 1) + num_local_schedulers * num_gpus_per_scheduler - 1 + ) # For each GPU, verify that the set of tasks that used this specific # GPU did not overlap in time. for locations in locations_to_intervals: @@ -1067,15 +1110,17 @@ def locations_to_intervals_for_many_tasks(): # Create several more actors that use GPUs. 
actors = [Actor1.remote() for _ in range(3)] - actor_locations = ray.get( - [actor.get_location_and_ids.remote() for actor in actors]) + actor_locations = ray.get([ + actor.get_location_and_ids.remote() for actor in actors + ]) # Run a bunch of GPU tasks. locations_to_intervals = locations_to_intervals_for_many_tasks() # Make sure that all but 11 of the GPUs were used. self.assertEqual( len(locations_to_intervals), - num_local_schedulers * num_gpus_per_scheduler - 1 - 3) + num_local_schedulers * num_gpus_per_scheduler - 1 - 3 + ) # For each GPU, verify that the set of tasks that used this specific # GPU did not overlap in time. for locations in locations_to_intervals: @@ -1087,8 +1132,8 @@ def locations_to_intervals_for_many_tasks(): # Create more actors to fill up all the GPUs. more_actors = [ - Actor1.remote() for _ in range( - num_local_schedulers * num_gpus_per_scheduler - 1 - 3) + Actor1.remote() for _ in + range(num_local_schedulers * num_gpus_per_scheduler - 1 - 3) ] # Wait for the actors to finish being created. ray.get([actor.get_location_and_ids.remote() for actor in more_actors]) @@ -1227,13 +1272,15 @@ def tearDown(self): ray.worker.cleanup() @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testLocalSchedulerDying(self): ray.worker._init( start_ray_local=True, num_local_schedulers=2, num_workers=0, - redirect_output=True) + redirect_output=True + ) @ray.remote class Counter(object): @@ -1261,8 +1308,8 @@ def inc(self): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1272,7 +1319,8 @@ def inc(self): self.assertEqual(results, list(range(1, 1 + len(results)))) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testManyLocalSchedulersDying(self): # This test can be made more stressful by increasing the numbers below. # The total number of actors created will be @@ -1285,7 +1333,8 @@ def testManyLocalSchedulersDying(self): start_ray_local=True, num_local_schedulers=num_local_schedulers, num_workers=0, - redirect_output=True) + redirect_output=True + ) @ray.remote class SlowCounter(object): @@ -1311,8 +1360,9 @@ def inc(self, duration): # a local scheduler, and run some more methods. for i in range(num_local_schedulers - 1): # Create some actors. - actors.extend( - [SlowCounter.remote() for _ in range(num_actors_at_a_time)]) + actors.extend([ + SlowCounter.remote() for _ in range(num_actors_at_a_time) + ]) # Run some methods. for j in range(len(actors)): actor = actors[j] @@ -1322,8 +1372,9 @@ def inc(self, duration): # exit of the corresponding local scheduler. Don't kill the first # local scheduler since that is the one that the driver is # connected to. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][i + 1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][i + + 1] process.kill() process.wait() @@ -1336,19 +1387,23 @@ def inc(self, duration): # Get the results and check that they have the correct values. 
for _, result_id_list in result_ids.items(): self.assertEqual( - ray.get(result_id_list), list( - range(1, - len(result_id_list) + 1))) - - def setup_counter_actor(self, - test_checkpoint=False, - save_exception=False, - resume_exception=False): + ray.get(result_id_list), + list(range(1, + len(result_id_list) + 1)) + ) + + def setup_counter_actor( + self, + test_checkpoint=False, + save_exception=False, + resume_exception=False + ): ray.worker._init( start_ray_local=True, num_local_schedulers=2, num_workers=0, - redirect_output=True) + redirect_output=True + ) # Only set the checkpoint interval if we're testing with checkpointing. checkpoint_interval = -1 @@ -1406,15 +1461,16 @@ def __ray_restore__(self, checkpoint): return actor, ids @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testCheckpointing(self): actor, ids = self.setup_counter_actor(test_checkpoint=True) # Wait for the last task to finish running. ray.get(ids[-1]) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1431,7 +1487,8 @@ def testCheckpointing(self): self.assertLess(num_inc_calls, x) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testRemoteCheckpoint(self): actor, ids = self.setup_counter_actor(test_checkpoint=True) @@ -1439,8 +1496,8 @@ def testRemoteCheckpoint(self): ray.get(actor.__ray_checkpoint__.remote()) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1457,15 +1514,16 @@ def testRemoteCheckpoint(self): self.assertEqual(x, 101) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testLostCheckpoint(self): actor, ids = self.setup_counter_actor(test_checkpoint=True) # Wait for the first fraction of tasks to finish running. ray.get(ids[len(ids) // 10]) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1483,16 +1541,18 @@ def testLostCheckpoint(self): self.assertLess(5, num_inc_calls) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testCheckpointException(self): actor, ids = self.setup_counter_actor( - test_checkpoint=True, save_exception=True) + test_checkpoint=True, save_exception=True + ) # Wait for the last task to finish running. ray.get(ids[-1]) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. 
+ PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1512,16 +1572,18 @@ def testCheckpointException(self): self.assertEqual(error[b"type"], b"checkpoint") @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testCheckpointResumeException(self): actor, ids = self.setup_counter_actor( - test_checkpoint=True, resume_exception=True) + test_checkpoint=True, resume_exception=True + ) # Wait for the last task to finish running. ray.get(ids[-1]) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1565,8 +1627,8 @@ def fork_many_incs(counter, num_incs): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1578,7 +1640,8 @@ def fork_many_incs(counter, num_incs): self.assertEqual(x, count + 1) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testRemoteCheckpointDistributedHandle(self): counter, ids = self.setup_counter_actor(test_checkpoint=True) @@ -1603,8 +1666,8 @@ def fork_many_incs(counter, num_incs): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1644,8 +1707,8 @@ def fork_many_incs(counter, num_incs): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. + PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1657,12 +1720,14 @@ def fork_many_incs(counter, num_incs): self.assertEqual(x, count + 1) def _testNondeterministicReconstruction( - self, num_forks, num_items_per_fork, num_forks_to_wait): + self, num_forks, num_items_per_fork, num_forks_to_wait + ): ray.worker._init( start_ray_local=True, num_local_schedulers=2, num_workers=0, - redirect_output=True) + redirect_output=True + ) # Make a shared queue. @ray.remote @@ -1703,8 +1768,10 @@ def enqueue(queue, items): enqueue_tasks = [] for fork in range(num_forks): enqueue_tasks.append( - enqueue.remote(actor, - [(fork, i) for i in range(num_items_per_fork)])) + enqueue.remote( + actor, [(fork, i) for i in range(num_items_per_fork)] + ) + ) # Wait for the forks to complete their tasks. enqueue_tasks = ray.get(enqueue_tasks) enqueue_tasks = [fork_ids[0] for fork_ids in enqueue_tasks] @@ -1715,8 +1782,8 @@ def enqueue(queue, items): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] + process = ray.services.all_processes[ray.services. 
+ PROCESS_TYPE_PLASMA_STORE][1] process.kill() process.wait() @@ -1725,20 +1792,24 @@ def enqueue(queue, items): reconstructed_queue = ray.get(actor.read.remote()) # Make sure the final queue has all items from all forks. self.assertEqual( - len(reconstructed_queue), num_forks * num_items_per_fork) + len(reconstructed_queue), num_forks * num_items_per_fork + ) # Make sure that the prefix of the final queue matches the queue from # the initial execution. self.assertEqual(queue, reconstructed_queue[:len(queue)]) @unittest.skipIf( os.environ.get('RAY_USE_NEW_GCS', False), - "Currently doesn't work with the new GCS.") + "Currently doesn't work with the new GCS." + ) def testNondeterministicReconstruction(self): self._testNondeterministicReconstruction(10, 100, 10) - @unittest.skip("Nondeterministic reconstruction currently not supported " - "when there are concurrent forks that didn't finish " - "initial execution.") + @unittest.skip( + "Nondeterministic reconstruction currently not supported " + "when there are concurrent forks that didn't finish " + "initial execution." + ) def testNondeterministicReconstructionConcurrentForks(self): self._testNondeterministicReconstruction(10, 100, 1) @@ -1799,8 +1870,10 @@ def fork(queue, key, num_items): filtered_items = [item[1] for item in items if item[0] == i] self.assertEqual(filtered_items, list(range(num_items_per_fork))) - @unittest.skip("Garbage collection for distributed actor handles not " - "implemented.") + @unittest.skip( + "Garbage collection for distributed actor handles not " + "implemented." + ) def testGarbageCollection(self): queue = self.setup_queue_actor() @@ -1871,7 +1944,8 @@ def method(self): actor2s = [Actor2.remote() for _ in range(2)] results = [a.method.remote() for a in actor2s] ready_ids, remaining_ids = ray.wait( - results, num_returns=len(results), timeout=1000) + results, num_returns=len(results), timeout=1000 + ) self.assertEqual(len(ready_ids), 1) def testCustomLabelPlacement(self): @@ -1883,7 +1957,8 @@ def testCustomLabelPlacement(self): "CustomResource1": 2 }, { "CustomResource2": 2 - }]) + }] + ) @ray.remote(resources={"CustomResource1": 1}) class ResourceActor1(object): @@ -1912,7 +1987,8 @@ def testCreatingMoreActorsThanResources(self): num_workers=0, num_cpus=10, num_gpus=2, - resources={"CustomResource1": 1}) + resources={"CustomResource1": 1} + ) @ray.remote(num_gpus=1) class ResourceActor1(object): diff --git a/test/array_test.py b/test/array_test.py index 567ac81152c6..a7c5ac192380 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -21,7 +21,7 @@ def tearDown(self): def testMethods(self): for module in [ - ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg + ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg ]: reload(module) ray.init() @@ -58,28 +58,31 @@ def tearDown(self): def testAssemble(self): for module in [ - ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg + ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg ]: reload(module) ray.init() a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) - x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], - np.array([[a], [b]])) - assert_equal(x.assemble(), - np.vstack([ - np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), - np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE]) - ])) + x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], + [b]])) + assert_equal( + x.assemble(), + np.vstack([ + np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), + np.zeros([da.BLOCK_SIZE, 
da.BLOCK_SIZE]) + ]) + ) def testMethods(self): for module in [ - ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg + ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg ]: reload(module) ray.worker._init( - start_ray_local=True, num_local_schedulers=2, num_cpus=[10, 10]) + start_ray_local=True, num_local_schedulers=2, num_cpus=[10, 10] + ) x = da.zeros.remote([9, 25, 51], "float") assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51])) @@ -90,7 +93,8 @@ def testMethods(self): x = da.random.normal.remote([11, 25, 49]) y = da.copy.remote(x) assert_equal( - ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(y))) + ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(y)) + ) x = da.eye.remote(25, dtype_name="float") assert_equal(ray.get(da.assemble.remote(x)), np.eye(25)) @@ -99,13 +103,15 @@ def testMethods(self): y = da.triu.remote(x) assert_equal( ray.get(da.assemble.remote(y)), - np.triu(ray.get(da.assemble.remote(x)))) + np.triu(ray.get(da.assemble.remote(x))) + ) x = da.random.normal.remote([25, 49]) y = da.tril.remote(x) assert_equal( ray.get(da.assemble.remote(y)), - np.tril(ray.get(da.assemble.remote(x)))) + np.tril(ray.get(da.assemble.remote(x))) + ) x = da.random.normal.remote([25, 49]) y = da.random.normal.remote([49, 18]) @@ -122,7 +128,8 @@ def testMethods(self): z = da.add.remote(x, y) assert_almost_equal( ray.get(da.assemble.remote(z)), - ray.get(da.assemble.remote(x)) + ray.get(da.assemble.remote(y))) + ray.get(da.assemble.remote(x)) + ray.get(da.assemble.remote(y)) + ) # test subtract x = da.random.normal.remote([33, 40]) @@ -130,13 +137,15 @@ def testMethods(self): z = da.subtract.remote(x, y) assert_almost_equal( ray.get(da.assemble.remote(z)), - ray.get(da.assemble.remote(x)) - ray.get(da.assemble.remote(y))) + ray.get(da.assemble.remote(x)) - ray.get(da.assemble.remote(y)) + ) # test transpose x = da.random.normal.remote([234, 432]) y = da.transpose.remote(x) assert_equal( - ray.get(da.assemble.remote(x)).T, ray.get(da.assemble.remote(y))) + ray.get(da.assemble.remote(x)).T, ray.get(da.assemble.remote(y)) + ) # test numpy_to_dist x = da.random.normal.remote([23, 45]) @@ -144,11 +153,13 @@ def testMethods(self): z = da.numpy_to_dist.remote(y) w = da.assemble.remote(z) assert_equal( - ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(z))) + ray.get(da.assemble.remote(x)), ray.get(da.assemble.remote(z)) + ) assert_equal(ray.get(y), ray.get(w)) # test da.tsqr - for shape in [[123, da.BLOCK_SIZE], [7, da.BLOCK_SIZE], + for shape in [[123, da.BLOCK_SIZE], + [7, da.BLOCK_SIZE], [da.BLOCK_SIZE, da.BLOCK_SIZE], [da.BLOCK_SIZE, 7], [10 * da.BLOCK_SIZE, da.BLOCK_SIZE]]: x = da.random.normal.remote(shape) @@ -164,8 +175,10 @@ def testMethods(self): # test da.linalg.modified_lu def test_modified_lu(d1, d2): - print("testing dist_modified_lu with d1 = " + str(d1) + ", d2 = " + - str(d2)) + print( + "testing dist_modified_lu with d1 = " + str(d1) + ", d2 = " + + str(d2) + ) assert d1 >= d2 m = ra.random.normal.remote([d1, d2]) q, r = ra.linalg.qr.remote(m) @@ -185,14 +198,15 @@ def test_modified_lu(d1, d2): # Check that l is lower triangular. 
assert_equal(np.tril(l_val), l_val) - for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7), (20, - 10)]: + for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7), (20, 10)]: test_modified_lu(d1, d2) # test dist_tsqr_hr def test_dist_tsqr_hr(d1, d2): - print("testing dist_tsqr_hr with d1 = " + str(d1) + ", d2 = " + - str(d2)) + print( + "testing dist_tsqr_hr with d1 = " + str(d1) + ", d2 = " + + str(d2) + ) a = da.random.normal.remote([d1, d2]) y, t, y_top, r = da.linalg.tsqr_hr.remote(a) a_val = ray.get(da.assemble.remote(a)) @@ -208,7 +222,8 @@ def test_dist_tsqr_hr(d1, d2): # Check that a = (I - y * t * y_top.T) * r. assert_almost_equal(np.dot(q, r_val), a_val) - for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), + for d1, d2 in [(123, da.BLOCK_SIZE), + (7, da.BLOCK_SIZE), (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), (10 * da.BLOCK_SIZE, da.BLOCK_SIZE)]: test_dist_tsqr_hr(d1, d2) @@ -227,7 +242,8 @@ def test_dist_qr(d1, d2): assert_equal(r_val, np.triu(r_val)) assert_almost_equal(a_val, np.dot(q_val, r_val)) - for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), + for d1, d2 in [(123, da.BLOCK_SIZE), + (7, da.BLOCK_SIZE), (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), (13, 21), (34, 35), (8, 7)]: test_dist_qr(d1, d2) diff --git a/test/autoscaler_test.py b/test/autoscaler_test.py index c3b84ca968f1..0436585fc93f 100644 --- a/test/autoscaler_test.py +++ b/test/autoscaler_test.py @@ -55,7 +55,8 @@ def nodes(self, tag_filters): if self.throw: raise Exception("oops") return [ - n.node_id for n in self.mock_nodes.values() + n.node_id + for n in self.mock_nodes.values() if n.matches(tag_filters) and n.state != "terminated" ] @@ -238,7 +239,8 @@ def testScaleUp(self): config_path = self.write_config(SMALL_CLUSTER) self.provider = MockProvider() autoscaler = StandardAutoscaler( - config_path, LoadMetrics(), max_failures=0, update_interval_s=0) + config_path, LoadMetrics(), max_failures=0, update_interval_s=0 + ) self.assertEqual(len(self.provider.nodes({})), 0) autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) @@ -253,7 +255,8 @@ def testTerminateOutdatedNodesGracefully(self): self.provider = MockProvider() self.provider.create_node({}, {TAG_RAY_NODE_TYPE: "Worker"}, 10) autoscaler = StandardAutoscaler( - config_path, LoadMetrics(), max_failures=0, update_interval_s=0) + config_path, LoadMetrics(), max_failures=0, update_interval_s=0 + ) self.assertEqual(len(self.provider.nodes({})), 10) # Gradually scales down to meet target size, never going too low @@ -273,7 +276,8 @@ def testDynamicScaling(self): LoadMetrics(), max_concurrent_launches=5, max_failures=0, - update_interval_s=0) + update_interval_s=0 + ) self.assertEqual(len(self.provider.nodes({})), 0) autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) @@ -302,7 +306,8 @@ def testUpdateThrottling(self): LoadMetrics(), max_concurrent_launches=5, max_failures=0, - update_interval_s=10) + update_interval_s=10 + ) autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) new_config = SMALL_CLUSTER.copy() @@ -315,7 +320,8 @@ def testLaunchConfigChange(self): config_path = self.write_config(SMALL_CLUSTER) self.provider = MockProvider() autoscaler = StandardAutoscaler( - config_path, LoadMetrics(), max_failures=0, update_interval_s=0) + config_path, LoadMetrics(), max_failures=0, update_interval_s=0 + ) autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) @@ -338,7 +344,8 @@ def testIgnoresCorruptedConfig(self): LoadMetrics(), 
max_concurrent_launches=10, max_failures=0, - update_interval_s=0) + update_interval_s=0 + ) autoscaler.update() # Write a corrupted config @@ -360,7 +367,8 @@ def testMaxFailures(self): self.provider = MockProvider() self.provider.throw = True autoscaler = StandardAutoscaler( - config_path, LoadMetrics(), max_failures=2, update_interval_s=0) + config_path, LoadMetrics(), max_failures=2, update_interval_s=0 + ) autoscaler.update() autoscaler.update() self.assertRaises(Exception, autoscaler.update) @@ -370,14 +378,16 @@ def testAbortOnCreationFailures(self): self.provider = MockProvider() self.provider.fail_creates = True autoscaler = StandardAutoscaler( - config_path, LoadMetrics(), max_failures=0, update_interval_s=0) + config_path, LoadMetrics(), max_failures=0, update_interval_s=0 + ) self.assertRaises(AssertionError, autoscaler.update) def testLaunchNewNodeOnOutOfBandTerminate(self): config_path = self.write_config(SMALL_CLUSTER) self.provider = MockProvider() autoscaler = StandardAutoscaler( - config_path, LoadMetrics(), max_failures=0, update_interval_s=0) + config_path, LoadMetrics(), max_failures=0, update_interval_s=0 + ) autoscaler.update() autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) @@ -398,16 +408,16 @@ def testConfiguresNewNodes(self): process_runner=runner, verbose_updates=True, node_updater_cls=NodeUpdaterThread, - update_interval_s=0) + update_interval_s=0 + ) autoscaler.update() autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) for node in self.provider.mock_nodes.values(): node.state = "running" - assert len( - self.provider.nodes({ - TAG_RAY_NODE_STATUS: "Uninitialized" - })) == 2 + assert len(self.provider.nodes({ + TAG_RAY_NODE_STATUS: "Uninitialized" + })) == 2 autoscaler.update() self.waitFor( lambda: len(self.provider.nodes( @@ -424,16 +434,16 @@ def testReportsConfigFailures(self): process_runner=runner, verbose_updates=True, node_updater_cls=NodeUpdaterThread, - update_interval_s=0) + update_interval_s=0 + ) autoscaler.update() autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) for node in self.provider.mock_nodes.values(): node.state = "running" - assert len( - self.provider.nodes({ - TAG_RAY_NODE_STATUS: "Uninitialized" - })) == 2 + assert len(self.provider.nodes({ + TAG_RAY_NODE_STATUS: "Uninitialized" + })) == 2 autoscaler.update() self.waitFor( lambda: len(self.provider.nodes( @@ -450,7 +460,8 @@ def testConfiguresOutdatedNodes(self): process_runner=runner, verbose_updates=True, node_updater_cls=NodeUpdaterThread, - update_interval_s=0) + update_interval_s=0 + ) autoscaler.update() autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) @@ -477,7 +488,8 @@ def testScaleUpBasedOnLoad(self): self.provider = MockProvider() lm = LoadMetrics() autoscaler = StandardAutoscaler( - config_path, lm, max_failures=0, update_interval_s=0) + config_path, lm, max_failures=0, update_interval_s=0 + ) self.assertEqual(len(self.provider.nodes({})), 0) autoscaler.update() self.assertEqual(len(self.provider.nodes({})), 2) @@ -521,7 +533,8 @@ def testRecoverUnhealthyWorkers(self): process_runner=runner, verbose_updates=True, node_updater_cls=NodeUpdaterThread, - update_interval_s=0) + update_interval_s=0 + ) autoscaler.update() for node in self.provider.mock_nodes.values(): node.state = "running" @@ -544,7 +557,8 @@ def testExternalNodeScaler(self): } config_path = self.write_config(config) autoscaler = StandardAutoscaler( - config_path, LoadMetrics(), max_failures=0, update_interval_s=0) + config_path, 
LoadMetrics(), max_failures=0, update_interval_s=0 + ) self.assertIsInstance(autoscaler.provider, NodeProvider) def testExternalNodeScalerWrongImport(self): diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 81bc41cbfbcc..fbf8a68e26c0 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -28,28 +28,34 @@ def f(): driver_mode=ray.SILENT_MODE, start_workers_from_local_scheduler=False, start_ray_local=True, - redirect_output=True) + redirect_output=True + ) # Have the worker wait in a get call. f.remote() # Kill the worker. time.sleep(1) - (ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0] - .terminate()) + ( + ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0] + .terminate() + ) time.sleep(0.1) # Seal the object so the store attempts to notify the worker that the # get has been fulfilled. ray.worker.global_worker.plasma_client.create( - pa.plasma.ObjectID(obj_id), 100) + pa.plasma.ObjectID(obj_id), 100 + ) ray.worker.global_worker.plasma_client.seal(pa.plasma.ObjectID(obj_id)) time.sleep(0.1) # Make sure that nothing has died. self.assertTrue( ray.services.all_processes_alive( - exclude=[ray.services.PROCESS_TYPE_WORKER])) + exclude=[ray.services.PROCESS_TYPE_WORKER] + ) + ) # This test checks that when a worker dies in the middle of a wait, the # plasma store and manager will not die. @@ -65,28 +71,34 @@ def f(): driver_mode=ray.SILENT_MODE, start_workers_from_local_scheduler=False, start_ray_local=True, - redirect_output=True) + redirect_output=True + ) # Have the worker wait in a get call. f.remote() # Kill the worker. time.sleep(1) - (ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0] - .terminate()) + ( + ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0] + .terminate() + ) time.sleep(0.1) # Seal the object so the store attempts to notify the worker that the # get has been fulfilled. ray.worker.global_worker.plasma_client.create( - pa.plasma.ObjectID(obj_id), 100) + pa.plasma.ObjectID(obj_id), 100 + ) ray.worker.global_worker.plasma_client.seal(pa.plasma.ObjectID(obj_id)) time.sleep(0.1) # Make sure that nothing has died. self.assertTrue( ray.services.all_processes_alive( - exclude=[ray.services.PROCESS_TYPE_WORKER])) + exclude=[ray.services.PROCESS_TYPE_WORKER] + ) + ) def _testWorkerFailed(self, num_local_schedulers): @ray.remote @@ -101,7 +113,8 @@ def f(x): start_workers_from_local_scheduler=False, start_ray_local=True, num_cpus=[num_initial_workers] * num_local_schedulers, - redirect_output=True) + redirect_output=True + ) # Submit more tasks than there are workers so that all workers and # cores are utilized. object_ids = [ @@ -113,7 +126,8 @@ def f(x): time.sleep(0.1) # Kill the workers as the tasks execute. for worker in ( - ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER]): + ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER] + ): worker.terminate() time.sleep(0.1) # Make sure that we can still get the objects after the executing tasks @@ -142,7 +156,8 @@ def f(x, j): num_local_schedulers=num_local_schedulers, start_ray_local=True, num_cpus=[num_workers_per_scheduler] * num_local_schedulers, - redirect_output=True) + redirect_output=True + ) # Submit more tasks than there are workers so that all workers and # cores are utilized. @@ -170,7 +185,8 @@ def f(x, j): # died. 
results = ray.get(object_ids) expected_results = 4 * list( - range(num_workers_per_scheduler * num_local_schedulers)) + range(num_workers_per_scheduler * num_local_schedulers) + ) self.assertEqual(results, expected_results) def check_components_alive(self, component_type, check_component_alive): @@ -181,56 +197,72 @@ def check_components_alive(self, component_type, check_component_alive): if check_component_alive: self.assertTrue(component.poll() is None) else: - print("waiting for " + component_type + " with PID " + - str(component.pid) + "to terminate") + print( + "waiting for " + component_type + " with PID " + + str(component.pid) + "to terminate" + ) component.wait() - print("done waiting for " + component_type + " with PID " + - str(component.pid) + "to terminate") + print( + "done waiting for " + component_type + " with PID " + + str(component.pid) + "to terminate" + ) self.assertTrue(not component.poll() is None) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testLocalSchedulerFailed(self): # Kill all local schedulers on worker nodes. self._testComponentFailed(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER) # The plasma stores and plasma managers should still be alive on the # worker nodes. - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, - True) - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, - True) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, - False) + self.check_components_alive( + ray.services.PROCESS_TYPE_PLASMA_STORE, True + ) + self.check_components_alive( + ray.services.PROCESS_TYPE_PLASMA_MANAGER, True + ) + self.check_components_alive( + ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False + ) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testPlasmaManagerFailed(self): # Kill all plasma managers on worker nodes. self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_MANAGER) # The plasma stores should still be alive (but unreachable) on the # worker nodes. - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, - True) - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, - False) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, - False) + self.check_components_alive( + ray.services.PROCESS_TYPE_PLASMA_STORE, True + ) + self.check_components_alive( + ray.services.PROCESS_TYPE_PLASMA_MANAGER, False + ) + self.check_components_alive( + ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False + ) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testPlasmaStoreFailed(self): # Kill all plasma stores on worker nodes. self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_STORE) # No processes should be left alive on the worker nodes. 
- self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, - False) - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, - False) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, - False) + self.check_components_alive( + ray.services.PROCESS_TYPE_PLASMA_STORE, False + ) + self.check_components_alive( + ray.services.PROCESS_TYPE_PLASMA_MANAGER, False + ) + self.check_components_alive( + ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False + ) def testDriverLivesSequential(self): ray.worker.init(redirect_output=True) diff --git a/test/credis_test.py b/test/credis_test.py index 86d50b9a6f31..105992292983 100644 --- a/test/credis_test.py +++ b/test/credis_test.py @@ -9,8 +9,10 @@ import ray -@unittest.skipIf(not os.environ.get('RAY_USE_NEW_GCS', False), - "Tests functionality of the new GCS.") +@unittest.skipIf( + not os.environ.get('RAY_USE_NEW_GCS', False), + "Tests functionality of the new GCS." +) class CredisTest(unittest.TestCase): def setUp(self): self.config = ray.init(num_workers=0) @@ -21,8 +23,7 @@ def tearDown(self): def test_credis_started(self): assert "credis_address" in self.config credis_address, credis_port = self.config["credis_address"].split(":") - credis_client = redis.StrictRedis( - host=credis_address, port=credis_port) + credis_client = redis.StrictRedis(host=credis_address, port=credis_port) assert credis_client.ping() is True redis_client = ray.worker.global_state.redis_client diff --git a/test/failure_test.py b/test/failure_test.py index 560bc020506d..31fa48e045e7 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -42,8 +42,9 @@ def testFailedTask(self): wait_for_errors(b"task", 2) self.assertEqual(len(relevant_errors(b"task")), 2) for task in relevant_errors(b"task"): - self.assertIn(b"Test function 1 intentionally failed.", - task.get(b"message")) + self.assertIn( + b"Test function 1 intentionally failed.", task.get(b"message") + ) x = test_functions.throw_exception_fct2.remote() try: @@ -113,10 +114,14 @@ def f(worker): wait_for_errors(b"function_to_run", 2) # Check that the error message is in the task info. self.assertEqual(len(ray.error_info()), 2) - self.assertIn(b"Function to run failed.", - ray.error_info()[0][b"message"]) - self.assertIn(b"Function to run failed.", - ray.error_info()[1][b"message"]) + self.assertIn( + b"Function to run failed.", + ray.error_info()[0][b"message"] + ) + self.assertIn( + b"Function to run failed.", + ray.error_info()[1][b"message"] + ) def testFailImportingActor(self): ray.init(num_workers=2, driver_mode=ray.SILENT_MODE) @@ -160,7 +165,8 @@ def get_val(self): wait_for_errors(b"task", 1) self.assertIn( b"failed to be imported, and so cannot execute this method", - ray.error_info()[1][b"message"]) + ray.error_info()[1][b"message"] + ) # Check that if we try to get the function it throws an exception and # does not hang. @@ -171,7 +177,8 @@ def get_val(self): wait_for_errors(b"task", 2) self.assertIn( b"failed to be imported, and so cannot execute this method", - ray.error_info()[2][b"message"]) + ray.error_info()[2][b"message"] + ) f.close() @@ -205,15 +212,19 @@ def fail_method(self): # Make sure that we get errors from a failed constructor. wait_for_errors(b"task", 1) self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_message1, - ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn( + error_message1, + ray.error_info()[0][b"message"].decode("ascii") + ) # Make sure that we get errors from a failed method. 
a.fail_method.remote() wait_for_errors(b"task", 2) self.assertEqual(len(ray.error_info()), 2) - self.assertIn(error_message2, - ray.error_info()[1][b"message"].decode("ascii")) + self.assertIn( + error_message2, + ray.error_info()[1][b"message"].decode("ascii") + ) def testIncorrectMethodCalls(self): ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) @@ -285,8 +296,10 @@ def f(): wait_for_errors(b"worker_died", 1) self.assertEqual(len(ray.error_info()), 1) - self.assertIn("died or was killed while executing the task", - ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn( + "died or was killed while executing the task", + ray.error_info()[0][b"message"].decode("ascii") + ) def testActorWorkerDying(self): ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) @@ -353,7 +366,8 @@ def testPutError1(self): ray.worker._init( start_ray_local=True, driver_mode=ray.SILENT_MODE, - object_store_memory=store_size) + object_store_memory=store_size + ) num_objects = 3 object_size = 4 * 10**5 @@ -372,9 +386,9 @@ def put_arg_task(): # on the one before it. The result of the first task should get # evicted. args = [] - arg = single_dependency.remote(0, - np.zeros( - object_size, dtype=np.uint8)) + arg = single_dependency.remote( + 0, np.zeros(object_size, dtype=np.uint8) + ) for i in range(num_objects): arg = single_dependency.remote(i, arg) args.append(arg) @@ -401,7 +415,8 @@ def testPutError2(self): ray.worker._init( start_ray_local=True, driver_mode=ray.SILENT_MODE, - object_store_memory=store_size) + object_store_memory=store_size + ) num_objects = 3 object_size = 4 * 10**5 diff --git a/test/jenkins_tests/multi_node_docker_test.py b/test/jenkins_tests/multi_node_docker_test.py index b757d942d048..fb2d5d5d67b5 100644 --- a/test/jenkins_tests/multi_node_docker_test.py +++ b/test/jenkins_tests/multi_node_docker_test.py @@ -92,14 +92,12 @@ def _get_container_ip(self, container_id): Returns: The IP address of the container. 
""" - proc = subprocess.Popen( - [ - "docker", "inspect", - "--format={{.NetworkSettings.Networks.bridge" - ".IPAddress}}", container_id - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + proc = subprocess.Popen([ + "docker", "inspect", "--format={{.NetworkSettings.Networks.bridge" + ".IPAddress}}", container_id + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) stdout_data, _ = wait_for_output(proc) p = re.compile("([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})") m = p.match(stdout_data) @@ -108,16 +106,18 @@ def _get_container_ip(self, container_id): else: return m.group(1) - def _start_head_node(self, docker_image, mem_size, shm_size, - num_redis_shards, num_cpus, num_gpus, - development_mode): + def _start_head_node( + self, docker_image, mem_size, shm_size, num_redis_shards, num_cpus, + num_gpus, development_mode + ): """Start the Ray head node inside a docker container.""" mem_arg = ["--memory=" + mem_size] if mem_size else [] shm_arg = ["--shm-size=" + shm_size] if shm_size else [] volume_arg = ([ "-v", "{}:{}".format( os.path.dirname(os.path.realpath(__file__)), - "/ray/test/jenkins_tests") + "/ray/test/jenkins_tests" + ) ] if development_mode else []) command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + [ @@ -130,7 +130,8 @@ def _start_head_node(self, docker_image, mem_size, shm_size, print("Starting head node with command:{}".format(command)) proc = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) stdout_data, _ = wait_for_output(proc) container_id = self._get_container_id(stdout_data) if container_id is None: @@ -138,15 +139,18 @@ def _start_head_node(self, docker_image, mem_size, shm_size, self.head_container_id = container_id self.head_container_ip = self._get_container_ip(container_id) - def _start_worker_node(self, docker_image, mem_size, shm_size, num_cpus, - num_gpus, development_mode): + def _start_worker_node( + self, docker_image, mem_size, shm_size, num_cpus, num_gpus, + development_mode + ): """Start a Ray worker node inside a docker container.""" mem_arg = ["--memory=" + mem_size] if mem_size else [] shm_arg = ["--shm-size=" + shm_size] if shm_size else [] volume_arg = ([ "-v", "{}:{}".format( os.path.dirname(os.path.realpath(__file__)), - "/ray/test/jenkins_tests") + "/ray/test/jenkins_tests" + ) ] if development_mode else []) command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + [ "--shm-size=" + shm_size, docker_image, "ray", "start", "--block", @@ -155,22 +159,25 @@ def _start_worker_node(self, docker_image, mem_size, shm_size, num_cpus, ]) print("Starting worker node with command:{}".format(command)) proc = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) stdout_data, _ = wait_for_output(proc) container_id = self._get_container_id(stdout_data) if container_id is None: raise RuntimeError("Failed to find container id") self.worker_container_ids.append(container_id) - def start_ray(self, - docker_image=None, - mem_size=None, - shm_size=None, - num_nodes=None, - num_redis_shards=1, - num_cpus=None, - num_gpus=None, - development_mode=None): + def start_ray( + self, + docker_image=None, + mem_size=None, + shm_size=None, + num_nodes=None, + num_redis_shards=1, + num_cpus=None, + num_gpus=None, + development_mode=None + ): """Start a Ray cluster within docker. 
This starts one docker container running the head node and @@ -198,42 +205,44 @@ def start_ray(self, assert len(num_gpus) == num_nodes # Launch the head node. - self._start_head_node(docker_image, mem_size, shm_size, - num_redis_shards, num_cpus[0], num_gpus[0], - development_mode) + self._start_head_node( + docker_image, mem_size, shm_size, num_redis_shards, num_cpus[0], + num_gpus[0], development_mode + ) # Start the worker nodes. for i in range(num_nodes - 1): - self._start_worker_node(docker_image, mem_size, shm_size, - num_cpus[1 + i], num_gpus[1 + i], - development_mode) + self._start_worker_node( + docker_image, mem_size, shm_size, num_cpus[1 + i], + num_gpus[1 + i], development_mode + ) def _stop_node(self, container_id): """Stop a node in the Ray cluster.""" - proc = subprocess.Popen( - ["docker", "kill", container_id], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + proc = subprocess.Popen(["docker", "kill", container_id], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) stdout_data, _ = wait_for_output(proc) stopped_container_id = self._get_container_id(stdout_data) if not container_id == stopped_container_id: - raise Exception("Failed to stop container {}." - .format(container_id)) + raise Exception("Failed to stop container {}.".format(container_id)) - proc = subprocess.Popen( - ["docker", "rm", "-f", container_id], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + proc = subprocess.Popen(["docker", "rm", "-f", container_id], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) stdout_data, _ = wait_for_output(proc) removed_container_id = self._get_container_id(stdout_data) if not container_id == removed_container_id: - raise Exception("Failed to remove container {}." - .format(container_id)) + raise Exception( + "Failed to remove container {}.".format(container_id) + ) print( - "stop_node", { + "stop_node", + { "container_id": container_id, "is_head": container_id == self.head_container_id - }) + } + ) def stop_ray(self): """Stop the Ray cluster.""" @@ -252,11 +261,13 @@ def stop_ray(self): return success - def run_test(self, - test_script, - num_drivers, - driver_locations=None, - timeout_seconds=600): + def run_test( + self, + test_script, + num_drivers, + driver_locations=None, + timeout_seconds=600 + ): """Run a test script. Run a test using the Ray cluster. @@ -277,8 +288,8 @@ def run_test(self, Raises: Exception: An exception is raised if the timeout expires. """ - all_container_ids = ( - [self.head_container_id] + self.worker_container_ids) + all_container_ids = ([self.head_container_id] + + self.worker_container_ids) if driver_locations is None: driver_locations = [ np.random.randint(0, len(all_container_ids)) @@ -288,8 +299,9 @@ def run_test(self, # Define a signal handler and set an alarm to go off in # timeout_seconds. def handler(signum, frame): - raise RuntimeError("This test timed out after {} seconds." - .format(timeout_seconds)) + raise RuntimeError( + "This test timed out after {} seconds.".format(timeout_seconds) + ) signal.signal(signal.SIGALRM, handler) signal.alarm(timeout_seconds) @@ -300,14 +312,16 @@ def handler(signum, frame): # Get the container ID to run the ith driver in. 
container_id = all_container_ids[driver_locations[i]] command = [ - "docker", "exec", container_id, "/bin/bash", "-c", - ("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python " - "{}".format(self.head_container_ip, i, test_script)) + "docker", "exec", container_id, "/bin/bash", "-c", ( + "RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python " + "{}".format(self.head_container_ip, i, test_script) + ) ] print("Starting driver with command {}.".format(test_script)) # Start the driver. p = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) driver_processes.append(p) # Wait for the drivers to finish. @@ -331,44 +345,59 @@ def handler(signum, frame): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Run multinode tests in Docker.") + description="Run multinode tests in Docker." + ) parser.add_argument( - "--docker-image", default="ray-project/deploy", help="docker image") + "--docker-image", default="ray-project/deploy", help="docker image" + ) parser.add_argument("--mem-size", help="memory size") parser.add_argument("--shm-size", default="1G", help="shared memory size") parser.add_argument( "--num-nodes", default=1, type=int, - help="number of nodes to use in the cluster") + help="number of nodes to use in the cluster" + ) parser.add_argument( "--num-redis-shards", default=1, type=int, help=("the number of Redis shards to start on the " - "head node")) + "head node") + ) parser.add_argument( "--num-cpus", type=str, - help=("a comma separated list of values representing " - "the number of CPUs to start each node with")) + help=( + "a comma separated list of values representing " + "the number of CPUs to start each node with" + ) + ) parser.add_argument( "--num-gpus", type=str, - help=("a comma separated list of values representing " - "the number of GPUs to start each node with")) + help=( + "a comma separated list of values representing " + "the number of GPUs to start each node with" + ) + ) parser.add_argument( - "--num-drivers", default=1, type=int, help="number of drivers to run") + "--num-drivers", default=1, type=int, help="number of drivers to run" + ) parser.add_argument( "--driver-locations", type=str, - help=("a comma separated list of indices of the " - "containers to run the drivers in")) + help=( + "a comma separated list of indices of the " + "containers to run the drivers in" + ) + ) parser.add_argument("--test-script", required=True, help="test script") parser.add_argument( "--development-mode", action="store_true", - help="use local copies of the test scripts") + help="use local copies of the test scripts" + ) args = parser.parse_args() # Parse the number of CPUs and GPUs to use for each worker. @@ -379,8 +408,10 @@ def handler(signum, frame): if args.num_gpus is not None else num_nodes * [0]) # Parse the driver locations. 
- driver_locations = (None if args.driver_locations is None else - [int(i) for i in args.driver_locations.split(",")]) + driver_locations = ( + None if args.driver_locations is None else + [int(i) for i in args.driver_locations.split(",")] + ) d = DockerRunner() d.start_ray( @@ -391,12 +422,14 @@ def handler(signum, frame): num_redis_shards=args.num_redis_shards, num_cpus=num_cpus, num_gpus=num_gpus, - development_mode=args.development_mode) + development_mode=args.development_mode + ) try: run_results = d.run_test( args.test_script, args.num_drivers, - driver_locations=driver_locations) + driver_locations=driver_locations + ) finally: successfully_stopped = d.stop_ray() diff --git a/test/jenkins_tests/multi_node_tests/many_drivers_test.py b/test/jenkins_tests/multi_node_tests/many_drivers_test.py index d00e84a58c0f..b329ee5e08b3 100644 --- a/test/jenkins_tests/multi_node_tests/many_drivers_test.py +++ b/test/jenkins_tests/multi_node_tests/many_drivers_test.py @@ -6,8 +6,9 @@ import time import ray -from ray.test.test_utils import (_wait_for_nodes_to_join, _broadcast_event, - _wait_for_event) +from ray.test.test_utils import ( + _wait_for_nodes_to_join, _broadcast_event, _wait_for_event +) # This test should be run with 5 nodes, which have 0, 0, 5, 6, and 50 GPUs for # a total of 61 GPUs. It should be run with a large number of drivers (e.g., diff --git a/test/jenkins_tests/multi_node_tests/remove_driver_test.py b/test/jenkins_tests/multi_node_tests/remove_driver_test.py index d18afed6e463..5dbd478dd48a 100644 --- a/test/jenkins_tests/multi_node_tests/remove_driver_test.py +++ b/test/jenkins_tests/multi_node_tests/remove_driver_test.py @@ -6,8 +6,10 @@ import time import ray -from ray.test.test_utils import (_wait_for_nodes_to_join, _broadcast_event, - _wait_for_event, wait_for_pid_to_exit) +from ray.test.test_utils import ( + _wait_for_nodes_to_join, _broadcast_event, _wait_for_event, + wait_for_pid_to_exit +) # This test should be run with 5 nodes, which have 0, 1, 2, 3, and 4 GPUs for a # total of 10 GPUs. It should be run with 7 drivers. Drivers 2 through 6 must @@ -29,7 +31,8 @@ def long_running_task(driver_index, task_index, redis_address): _broadcast_event( remote_function_event_name(driver_index, task_index), redis_address, - data=(ray.services.get_node_ip_address(), os.getpid())) + data=(ray.services.get_node_ip_address(), os.getpid()) + ) # Loop forever. 
while True: time.sleep(100) @@ -44,7 +47,8 @@ def __init__(self, driver_index, actor_index, redis_address): _broadcast_event( actor_event_name(driver_index, actor_index), redis_address, - data=(ray.services.get_node_ip_address(), os.getpid())) + data=(ray.services.get_node_ip_address(), os.getpid()) + ) assert len(ray.get_gpu_ids()) == 0 def check_ids(self): @@ -62,7 +66,8 @@ def __init__(self, driver_index, actor_index, redis_address): _broadcast_event( actor_event_name(driver_index, actor_index), redis_address, - data=(ray.services.get_node_ip_address(), os.getpid())) + data=(ray.services.get_node_ip_address(), os.getpid()) + ) assert len(ray.get_gpu_ids()) == 1 def check_ids(self): @@ -80,7 +85,8 @@ def __init__(self, driver_index, actor_index, redis_address): _broadcast_event( actor_event_name(driver_index, actor_index), redis_address, - data=(ray.services.get_node_ip_address(), os.getpid())) + data=(ray.services.get_node_ip_address(), os.getpid()) + ) assert len(ray.get_gpu_ids()) == 2 def check_ids(self): @@ -190,15 +196,15 @@ def cleanup_driver(redis_address, driver_index): _wait_for_event("DRIVER_0_DONE", redis_address) _wait_for_event("DRIVER_1_DONE", redis_address) - def try_to_create_actor(actor_class, driver_index, actor_index, - timeout=20): + def try_to_create_actor(actor_class, driver_index, actor_index, timeout=20): # Try to create an actor, but allow failures while we wait for the # monitor to release the resources for the removed drivers. start_time = time.time() while time.time() - start_time < timeout: try: - actor = actor_class.remote(driver_index, actor_index, - redis_address) + actor = actor_class.remote( + driver_index, actor_index, redis_address + ) except Exception as e: time.sleep(0.1) else: @@ -212,12 +218,14 @@ def try_to_create_actor(actor_class, driver_index, actor_index, actors_two_gpus = [] for i in range(3): actors_two_gpus.append( - try_to_create_actor(Actor2, driver_index, 10 + i)) + try_to_create_actor(Actor2, driver_index, 10 + i) + ) # Create some actors that require one GPU. actors_one_gpu = [] for i in range(4): actors_one_gpu.append( - try_to_create_actor(Actor1, driver_index, 10 + 3 + i)) + try_to_create_actor(Actor1, driver_index, 10 + 3 + i) + ) removed_workers = 0 @@ -225,13 +233,15 @@ def try_to_create_actor(actor_class, driver_index, actor_index, # driver 1 have been killed. for i in range(num_long_running_tasks_per_driver): node_ip_address, pid = _wait_for_event( - remote_function_event_name(0, i), redis_address) + remote_function_event_name(0, i), redis_address + ) if node_ip_address == ray.services.get_node_ip_address(): wait_for_pid_to_exit(pid) removed_workers += 1 for i in range(num_long_running_tasks_per_driver): node_ip_address, pid = _wait_for_event( - remote_function_event_name(1, i), redis_address) + remote_function_event_name(1, i), redis_address + ) if node_ip_address == ray.services.get_node_ip_address(): wait_for_pid_to_exit(pid) removed_workers += 1 @@ -239,19 +249,22 @@ def try_to_create_actor(actor_class, driver_index, actor_index, # been killed. 
for i in range(10): node_ip_address, pid = _wait_for_event( - actor_event_name(0, i), redis_address) + actor_event_name(0, i), redis_address + ) if node_ip_address == ray.services.get_node_ip_address(): wait_for_pid_to_exit(pid) removed_workers += 1 for i in range(9): node_ip_address, pid = _wait_for_event( - actor_event_name(1, i), redis_address) + actor_event_name(1, i), redis_address + ) if node_ip_address == ray.services.get_node_ip_address(): wait_for_pid_to_exit(pid) removed_workers += 1 - print("{} workers/actors were removed on this node." - .format(removed_workers)) + print( + "{} workers/actors were removed on this node.".format(removed_workers) + ) # Only one of the cleanup drivers should create and use more actors. if driver_index == 2: diff --git a/test/jenkins_tests/multi_node_tests/test_0.py b/test/jenkins_tests/multi_node_tests/test_0.py index 7d8240568ba6..c2ba9e2d62fe 100644 --- a/test/jenkins_tests/multi_node_tests/test_0.py +++ b/test/jenkins_tests/multi_node_tests/test_0.py @@ -25,9 +25,7 @@ def f(): for i in range(num_attempts): ip_addresses = ray.get([f.remote() for i in range(1000)]) distinct_addresses = set(ip_addresses) - counts = [ - ip_addresses.count(address) for address in distinct_addresses - ] + counts = [ip_addresses.count(address) for address in distinct_addresses] print("Counts are {}".format(counts)) if len(counts) == 5: break diff --git a/test/microbenchmarks.py b/test/microbenchmarks.py index bd358d3d9b03..6457c3a60803 100644 --- a/test/microbenchmarks.py +++ b/test/microbenchmarks.py @@ -67,8 +67,10 @@ def testTiming(self): elapsed_times.append(end_time - start_time) elapsed_times = np.sort(elapsed_times) average_elapsed_time = sum(elapsed_times) / 1000 - print("Time required to submit a trivial function call and get the " - "result:") + print( + "Time required to submit a trivial function call and get the " + "result:" + ) print(" Average: {}".format(average_elapsed_time)) print(" 90th percentile: {}".format(elapsed_times[900])) print(" 99th percentile: {}".format(elapsed_times[990])) @@ -109,11 +111,15 @@ def testCache(self): if d > 1.5 * b: if os.getenv("TRAVIS") is None: - raise Exception("The caching test was too slow. " - "d = {}, b = {}".format(d, b)) + raise Exception( + "The caching test was too slow. " + "d = {}, b = {}".format(d, b) + ) else: - print("WARNING: The caching test was too slow. " - "d = {}, b = {}".format(d, b)) + print( + "WARNING: The caching test was too slow. " + "d = {}, b = {}".format(d, b) + ) if __name__ == "__main__": diff --git a/test/monitor_test.py b/test/monitor_test.py index 968b3a3dfea6..c5337eac5ef4 100644 --- a/test/monitor_test.py +++ b/test/monitor_test.py @@ -86,13 +86,15 @@ def f(): @unittest.skipIf( os.environ.get('RAY_USE_NEW_GCS', False), - "Failing with the new GCS API.") + "Failing with the new GCS API." + ) def testCleanupOnDriverExitSingleRedisShard(self): self._testCleanupOnDriverExit(num_redis_shards=1) @unittest.skipIf( os.environ.get('RAY_USE_NEW_GCS', False), - "Hanging with the new GCS API.") + "Hanging with the new GCS API." + ) def testCleanupOnDriverExitManyRedisShards(self): self._testCleanupOnDriverExit(num_redis_shards=5) self._testCleanupOnDriverExit(num_redis_shards=31) diff --git a/test/multi_node_test.py b/test/multi_node_test.py index 97c39d89544e..91e02d5cb658 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -35,7 +35,8 @@ def setUp(self): # Get the redis address from the output. 
redis_substring_prefix = "redis_address=\"" redis_address_location = ( - out.find(redis_substring_prefix) + len(redis_substring_prefix)) + out.find(redis_substring_prefix) + len(redis_substring_prefix) + ) redis_address = out[redis_address_location:] self.redis_address = redis_address.split("\"")[0] @@ -69,8 +70,10 @@ def f(): # Make sure we got the error. self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_string1, - ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn( + error_string1, + ray.error_info()[0][b"message"].decode("ascii") + ) # Start another driver and make sure that it does not receive this # error. Make the other driver throw an error, and make sure it @@ -110,8 +113,10 @@ def f(): # Make sure that the other error message doesn't show up for this # driver. self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_string1, - ray.error_info()[0][b"message"].decode("ascii")) + self.assertIn( + error_string1, + ray.error_info()[0][b"message"].decode("ascii") + ) def testRemoteFunctionIsolation(self): # This test will run multiple remote functions with the same names in @@ -218,13 +223,15 @@ def testCallingStartRayHead(self): subprocess.Popen(["ray", "stop"]).wait() # Test starting Ray with a node IP address specified. - run_and_get_output( - ["ray", "start", "--head", "--node-ip-address", "127.0.0.1"]) + run_and_get_output([ + "ray", "start", "--head", "--node-ip-address", "127.0.0.1" + ]) subprocess.Popen(["ray", "stop"]).wait() # Test starting Ray with an object manager port specified. - run_and_get_output( - ["ray", "start", "--head", "--object-manager-port", "12345"]) + run_and_get_output([ + "ray", "start", "--head", "--object-manager-port", "12345" + ]) subprocess.Popen(["ray", "stop"]).wait() # Test starting Ray with the number of CPUs specified. @@ -236,17 +243,17 @@ def testCallingStartRayHead(self): subprocess.Popen(["ray", "stop"]).wait() # Test starting Ray with the max redis clients specified. - run_and_get_output( - ["ray", "start", "--head", "--redis-max-clients", "100"]) + run_and_get_output([ + "ray", "start", "--head", "--redis-max-clients", "100" + ]) subprocess.Popen(["ray", "stop"]).wait() # Test starting Ray with all arguments specified. run_and_get_output([ "ray", "start", "--head", "--num-workers", "20", "--redis-port", "6379", "--redis-shard-ports", "6380,6381,6382", - "--object-manager-port", "12345", "--num-cpus", "100", - "--num-gpus", "0", "--redis-max-clients", "100", "--resources", - "{\"Custom\": 1}" + "--object-manager-port", "12345", "--num-cpus", "100", "--num-gpus", + "0", "--redis-max-clients", "100", "--resources", "{\"Custom\": 1}" ]) subprocess.Popen(["ray", "stop"]).wait() diff --git a/test/runtest.py b/test/runtest.py index a44543a21294..916e05e98dbd 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -19,8 +19,10 @@ def assert_equal(obj1, obj2): - module_numpy = (type(obj1).__module__ == np.__name__ - or type(obj2).__module__ == np.__name__) + module_numpy = ( + type(obj1).__module__ == np.__name__ + or type(obj2).__module__ == np.__name__ + ) if module_numpy: empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or (hasattr(obj2, "shape") and obj2.shape == ())) @@ -28,17 +30,20 @@ def assert_equal(obj1, obj2): # This is a special case because currently np.testing.assert_equal # fails because we do not properly handle different numerical # types. 
- assert obj1 == obj2, ("Objects {} and {} are " - "different.".format(obj1, obj2)) + assert obj1 == obj2, ( + "Objects {} and {} are " + "different.".format(obj1, obj2) + ) else: np.testing.assert_equal(obj1, obj2) elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"): special_keys = ["_pytype_"] - assert (set(list(obj1.__dict__.keys()) + special_keys) == set( - list(obj2.__dict__.keys()) + special_keys)), ("Objects {} " - "and {} are " - "different.".format( - obj1, obj2)) + assert ( + set(list(obj1.__dict__.keys()) + + special_keys) == set(list(obj2.__dict__.keys()) + special_keys) + ), ("Objects {} " + "and {} are " + "different.".format(obj1, obj2)) for key in obj1.__dict__.keys(): if key not in special_keys: assert_equal(obj1.__dict__[key], obj2.__dict__[key]) @@ -47,27 +52,33 @@ def assert_equal(obj1, obj2): for key in obj1.keys(): assert_equal(obj1[key], obj2[key]) elif type(obj1) is list or type(obj2) is list: - assert len(obj1) == len(obj2), ("Objects {} and {} are lists with " - "different lengths.".format( - obj1, obj2)) + assert len(obj1) == len(obj2), ( + "Objects {} and {} are lists with " + "different lengths.".format(obj1, obj2) + ) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) elif type(obj1) is tuple or type(obj2) is tuple: - assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with " - "different lengths.".format( - obj1, obj2)) + assert len(obj1) == len(obj2), ( + "Objects {} and {} are tuples with " + "different lengths.".format(obj1, obj2) + ) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) - elif (ray.serialization.is_named_tuple(type(obj1)) - or ray.serialization.is_named_tuple(type(obj2))): - assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples " - "with different lengths.".format( - obj1, obj2)) + elif ( + ray.serialization.is_named_tuple(type(obj1)) + or ray.serialization.is_named_tuple(type(obj2)) + ): + assert len(obj1) == len(obj2), ( + "Objects {} and {} are named tuples " + "with different lengths.".format(obj1, obj2) + ) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) else: assert obj1 == obj2, "Objects {} and {} are different.".format( - obj1, obj2) + obj1, obj2 + ) if sys.version_info >= (3, 0): @@ -107,7 +118,8 @@ def assert_equal(obj1, obj2): # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): { # (): {(): {}}}}}}}}}}}}}, ( - (((((((((), ), ), ), ), ), ), ), ), ), + (((((((((), ), ), ), ), ), ), ), ), + ), { "a": { "b": { @@ -161,8 +173,9 @@ class CustomError(Exception): Point = namedtuple("Point", ["x", "y"]) -NamedTupleExample = namedtuple("Example", - "field1, field2, field3, field4, field5") +NamedTupleExample = namedtuple( + "Example", "field1, field2, field3, field4, field5" +) CUSTOM_OBJECTS = [ Exception("Test object."), @@ -180,16 +193,17 @@ class CustomError(Exception): TUPLE_OBJECTS = [(obj, ) for obj in BASE_OBJECTS] # The check that type(obj).__module__ != "numpy" should be unnecessary, but # otherwise this seems to fail on Mac OS X on Travis. 
-DICT_OBJECTS = ( - [{ +DICT_OBJECTS = ([ + { obj: obj - } for obj in PRIMITIVE_OBJECTS - if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] + - [{ - 0: obj - } for obj in BASE_OBJECTS] + [{ - Foo(123): Foo(456) - }]) + } + for obj in PRIMITIVE_OBJECTS + if (obj.__hash__ is not None and type(obj).__module__ != "numpy") +] + [{ + 0: obj +} for obj in BASE_OBJECTS] + [{ + Foo(123): Foo(456) +}]) RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS @@ -300,7 +314,8 @@ def testPythonWorkers(self): ray.worker._init( num_workers=num_workers, start_workers_from_local_scheduler=False, - start_ray_local=True) + start_ray_local=True + ) @ray.remote def f(x): @@ -360,28 +375,28 @@ def custom_deserializer(serialized_obj): return serialized_obj, "string2" ray.register_custom_serializer( - Foo, - serializer=custom_serializer, - deserializer=custom_deserializer) + Foo, serializer=custom_serializer, deserializer=custom_deserializer + ) self.assertEqual( - ray.get(ray.put(Foo())), ((3, "string1", Foo.__name__), "string2")) + ray.get(ray.put(Foo())), ((3, "string1", Foo.__name__), "string2") + ) class Bar(object): def __init__(self): self.x = 3 ray.register_custom_serializer( - Bar, - serializer=custom_serializer, - deserializer=custom_deserializer) + Bar, serializer=custom_serializer, deserializer=custom_deserializer + ) @ray.remote def f(): return Bar() self.assertEqual( - ray.get(f.remote()), ((3, "string1", Bar.__name__), "string2")) + ray.get(f.remote()), ((3, "string1", Bar.__name__), "string2") + ) def testRegisterClass(self): self.init_ray(num_workers=2) @@ -705,9 +720,8 @@ def g(): assert ray.get(f._submit(args=[2], num_return_vals=2)) == [0, 1] assert ray.get(f._submit(args=[3], num_return_vals=3)) == [0, 1, 2] assert ray.get( - g._submit( - args=[], num_cpus=1, num_gpus=1, resources={"Custom": - 1})) == [0] + g._submit(args=[], num_cpus=1, num_gpus=1, resources={"Custom": 1}) + ) == [0] def testGetMultiple(self): self.init_ray() @@ -728,12 +742,7 @@ def f(delay): time.sleep(delay) return 1 - objectids = [ - f.remote(1.0), - f.remote(0.5), - f.remote(0.5), - f.remote(0.5) - ] + objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)] ready_ids, remaining_ids = ray.wait(objectids) self.assertEqual(len(ready_ids), 1) self.assertEqual(len(remaining_ids), 3) @@ -741,25 +750,16 @@ def f(delay): self.assertEqual(set(ready_ids), set(objectids)) self.assertEqual(remaining_ids, []) - objectids = [ - f.remote(0.5), - f.remote(0.5), - f.remote(0.5), - f.remote(0.5) - ] + objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), f.remote(0.5)] start_time = time.time() ready_ids, remaining_ids = ray.wait( - objectids, timeout=1750, num_returns=4) + objectids, timeout=1750, num_returns=4 + ) self.assertLess(time.time() - start_time, 2) self.assertEqual(len(ready_ids), 3) self.assertEqual(len(remaining_ids), 1) ray.wait(objectids) - objectids = [ - f.remote(1.0), - f.remote(0.5), - f.remote(0.5), - f.remote(0.5) - ] + objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)] start_time = time.time() ready_ids, remaining_ids = ray.wait(objectids, timeout=5000) self.assertTrue(time.time() - start_time < 5) @@ -897,7 +897,8 @@ def events(): res = [] for key in keys: res.extend( - ray.worker.global_worker.redis_client.zrange(key, 0, -1)) + ray.worker.global_worker.redis_client.zrange(key, 0, -1) + ) return res def wait_for_num_events(num_events, timeout=10): @@ -1064,7 +1065,8 @@ def f(): num_returns = 5 object_ids = [ray.put(i) for i in 
range(20)] ready, remaining = ray.wait( - object_ids, num_returns=num_returns, timeout=None) + object_ids, num_returns=num_returns, timeout=None + ) assert_equal(ready, object_ids[:num_returns]) assert_equal(remaining, object_ids[num_returns:]) @@ -1114,10 +1116,12 @@ def get_worker_id(): # Attempt to wait for all of the workers to start up. while True: if len( - set( - ray.get([ - get_worker_id.remote() for _ in range(num_workers) - ]))) == num_workers: + set( + ray.get([ + get_worker_id.remote() for _ in range(num_workers) + ]) + ) + ) == num_workers: break time_buffer = 0.3 @@ -1189,10 +1193,12 @@ def get_worker_id(): # Attempt to wait for all of the workers to start up. while True: if len( - set( - ray.get([ - get_worker_id.remote() for _ in range(num_workers) - ]))) == num_workers: + set( + ray.get([ + get_worker_id.remote() for _ in range(num_workers) + ]) + ) + ) == num_workers: break @ray.remote(num_cpus=1, num_gpus=9) @@ -1238,8 +1244,11 @@ def f0(): time.sleep(0.1) gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 0 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) for gpu_id in gpu_ids: assert gpu_id in range(num_gpus) return gpu_ids @@ -1249,8 +1258,11 @@ def f1(): time.sleep(0.1) gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 1 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) for gpu_id in gpu_ids: assert gpu_id in range(num_gpus) return gpu_ids @@ -1260,8 +1272,11 @@ def f2(): time.sleep(0.1) gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 2 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) for gpu_id in gpu_ids: assert gpu_id in range(num_gpus) return gpu_ids @@ -1271,8 +1286,11 @@ def f3(): time.sleep(0.1) gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 3 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) for gpu_id in gpu_ids: assert gpu_id in range(num_gpus) return gpu_ids @@ -1282,8 +1300,11 @@ def f4(): time.sleep(0.1) gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 4 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) for gpu_id in gpu_ids: assert gpu_id in range(num_gpus) return gpu_ids @@ -1293,8 +1314,11 @@ def f5(): time.sleep(0.1) gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 5 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) for gpu_id in gpu_ids: assert gpu_id in range(num_gpus) return gpu_ids @@ -1310,8 +1334,10 @@ def f(): if len(set(ray.get([f.remote() for _ in range(10)]))) == 10: break if time.time() > start_time + 10: - raise Exception("Timed out while waiting for workers to start " - "up.") + raise Exception( + "Timed out while waiting for workers to start " + "up." 
+ ) list_of_ids = ray.get([f0.remote() for _ in range(10)]) self.assertEqual(list_of_ids, 10 * [[]]) @@ -1346,16 +1372,22 @@ class Actor0(object): def __init__(self): gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 0 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) # Set self.x to make sure that we got here. self.x = 1 def test(self): gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 0 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) return self.x @ray.remote(num_gpus=1) @@ -1363,16 +1395,22 @@ class Actor1(object): def __init__(self): gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 1 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) # Set self.x to make sure that we got here. self.x = 1 def test(self): gpu_ids = ray.get_gpu_ids() assert len(gpu_ids) == 1 - assert (os.environ["CUDA_VISIBLE_DEVICES"] == ",".join( - [str(i) for i in gpu_ids])) + assert ( + os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([ + str(i) for i in gpu_ids + ]) + ) return self.x a0 = Actor0.remote() @@ -1383,7 +1421,8 @@ def test(self): def testZeroCPUs(self): ray.worker._init( - start_ray_local=True, num_local_schedulers=2, num_cpus=[0, 2]) + start_ray_local=True, num_local_schedulers=2, num_cpus=[0, 2] + ) local_plasma = ray.worker.global_worker.plasma_client.store_socket_name @@ -1410,7 +1449,8 @@ def testMultipleLocalSchedulers(self): num_local_schedulers=3, num_workers=1, num_cpus=[100, 5, 10], - num_gpus=[0, 5, 1]) + num_gpus=[0, 5, 1] + ) # Define a bunch of remote functions that all return the socket name of # the plasma store. Since there is a one-to-one correspondence between @@ -1488,7 +1528,8 @@ def validate_names_and_results(names, results): elif name == "run_on_0_1_2": self.assertIn( result, - [store_names[0], store_names[1], store_names[2]]) + [store_names[0], store_names[1], store_names[2]] + ) elif name == "run_on_1_2": self.assertIn(result, [store_names[1], store_names[2]]) elif name == "run_on_0_2": @@ -1524,7 +1565,8 @@ def testCustomResources(self): "CustomResource": 0 }, { "CustomResource": 1 - }]) + }] + ) @ray.remote def f(): @@ -1566,7 +1608,8 @@ def testTwoCustomResources(self): }, { "CustomResource1": 3, "CustomResource2": 4 - }]) + }] + ) @ray.remote(resources={"CustomResource1": 1}) def f(): @@ -1606,8 +1649,8 @@ def k(): # Make sure that tasks with unsatisfied custom resource requirements do # not get scheduled. 
- ready_ids, remaining_ids = ray.wait( - [j.remote(), k.remote()], timeout=500) + ready_ids, remaining_ids = ray.wait([j.remote(), k.remote()], + timeout=500) self.assertEqual(ready_ids, []) def testManyCustomResources(self): @@ -1624,8 +1667,8 @@ def f(): remote_functions = [] for _ in range(20): num_resources = np.random.randint(0, num_custom_resources + 1) - permuted_resources = np.random.permutation( - num_custom_resources)[:num_resources] + permuted_resources = np.random.permutation(num_custom_resources + )[:num_resources] random_resources = { str(i): total_resources[str(i)] for i in permuted_resources @@ -1661,8 +1704,9 @@ def tearDown(self): def testSpecificGPUs(self): allowed_gpu_ids = [4, 5, 6] - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - [str(i) for i in allowed_gpu_ids]) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([ + str(i) for i in allowed_gpu_ids + ]) ray.init(num_gpus=3) @ray.remote(num_gpus=1) @@ -1752,22 +1796,27 @@ class SchedulingAlgorithm(unittest.TestCase): def tearDown(self): ray.worker.cleanup() - def attempt_to_load_balance(self, - remote_function, - args, - total_tasks, - num_local_schedulers, - minimum_count, - num_attempts=100): + def attempt_to_load_balance( + self, + remote_function, + args, + total_tasks, + num_local_schedulers, + minimum_count, + num_attempts=100 + ): attempts = 0 while attempts < num_attempts: - locations = ray.get( - [remote_function.remote(*args) for _ in range(total_tasks)]) + locations = ray.get([ + remote_function.remote(*args) for _ in range(total_tasks) + ]) names = set(locations) counts = [locations.count(name) for name in names] print("Counts are {}.".format(counts)) - if (len(names) == num_local_schedulers - and all([count >= minimum_count for count in counts])): + if ( + len(names) == num_local_schedulers + and all([count >= minimum_count for count in counts]) + ): break attempts += 1 self.assertLess(attempts, num_attempts) @@ -1780,7 +1829,8 @@ def testLoadBalancing(self): ray.worker._init( start_ray_local=True, num_local_schedulers=num_local_schedulers, - num_cpus=num_cpus) + num_cpus=num_cpus + ) @ray.remote def f(): @@ -1799,7 +1849,8 @@ def testLoadBalancingWithDependencies(self): ray.worker._init( start_ray_local=True, num_workers=num_workers, - num_local_schedulers=num_local_schedulers) + num_local_schedulers=num_local_schedulers + ) @ray.remote def f(x): @@ -1833,7 +1884,8 @@ def wait_for_num_objects(num_objects, timeout=10): @unittest.skipIf( os.environ.get('RAY_USE_NEW_GCS', False), - "New GCS API doesn't have a Python API yet.") + "New GCS API doesn't have a Python API yet." +) class GlobalStateAPI(unittest.TestCase): def tearDown(self): ray.worker.cleanup() @@ -1864,28 +1916,37 @@ def testGlobalStateAPI(self): ID_SIZE = 20 driver_id = ray.experimental.state.binary_to_hex( - ray.worker.global_worker.worker_id) + ray.worker.global_worker.worker_id + ) driver_task_id = ray.experimental.state.binary_to_hex( - ray.worker.global_worker.current_task_id.id()) + ray.worker.global_worker.current_task_id.id() + ) # One task is put in the task table which corresponds to this driver. 
wait_for_num_tasks(1) task_table = ray.global_state.task_table() self.assertEqual(len(task_table), 1) self.assertEqual(driver_task_id, list(task_table.keys())[0]) - self.assertEqual(task_table[driver_task_id]["State"], - ray.experimental.state.TASK_STATUS_RUNNING) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"], - driver_task_id) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"], - ID_SIZE * "ff") + self.assertEqual( + task_table[driver_task_id]["State"], + ray.experimental.state.TASK_STATUS_RUNNING + ) + self.assertEqual( + task_table[driver_task_id]["TaskSpec"]["TaskID"], driver_task_id + ) + self.assertEqual( + task_table[driver_task_id]["TaskSpec"]["ActorID"], ID_SIZE * "ff" + ) self.assertEqual(task_table[driver_task_id]["TaskSpec"]["Args"], []) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["DriverID"], - driver_id) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["FunctionID"], - ID_SIZE * "ff") self.assertEqual( - (task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"]), []) + task_table[driver_task_id]["TaskSpec"]["DriverID"], driver_id + ) + self.assertEqual( + task_table[driver_task_id]["TaskSpec"]["FunctionID"], ID_SIZE * "ff" + ) + self.assertEqual( + (task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"]), [] + ) client_table = ray.global_state.client_table() node_ip_address = ray.worker.global_worker.node_ip_address @@ -1925,8 +1986,9 @@ def f(*xs): self.assertEqual(function_table_entry["DriverID"], driver_id) self.assertEqual(function_table_entry["Module"], "__main__") - self.assertEqual(task_table[task_id], - ray.global_state.task_table(task_id)) + self.assertEqual( + task_table[task_id], ray.global_state.task_table(task_id) + ) # Wait for two objects, one for the x_id and one for result_id. wait_for_num_objects(2) @@ -1938,12 +2000,15 @@ def wait_for_object_table(): object_table = ray.global_state.object_table() tables_ready = ( object_table[x_id]["ManagerIDs"] is not None - and object_table[result_id]["ManagerIDs"] is not None) + and object_table[result_id]["ManagerIDs"] is not None + ) if tables_ready: return time.sleep(0.1) - raise Exception("Timed out while waiting for object table to " - "update.") + raise Exception( + "Timed out while waiting for object table to " + "update." + ) # Wait for the object table to be updated. 
wait_for_object_table() @@ -1952,18 +2017,23 @@ def wait_for_object_table(): self.assertEqual(object_table[x_id]["IsPut"], True) self.assertEqual(object_table[x_id]["TaskID"], driver_task_id) - self.assertEqual(object_table[x_id]["ManagerIDs"], - [manager_client["DBClientID"]]) + self.assertEqual( + object_table[x_id]["ManagerIDs"], [manager_client["DBClientID"]] + ) self.assertEqual(object_table[result_id]["IsPut"], False) self.assertEqual(object_table[result_id]["TaskID"], task_id) - self.assertEqual(object_table[result_id]["ManagerIDs"], - [manager_client["DBClientID"]]) + self.assertEqual( + object_table[result_id]["ManagerIDs"], + [manager_client["DBClientID"]] + ) - self.assertEqual(object_table[x_id], - ray.global_state.object_table(x_id)) - self.assertEqual(object_table[result_id], - ray.global_state.object_table(result_id)) + self.assertEqual( + object_table[x_id], ray.global_state.object_table(x_id) + ) + self.assertEqual( + object_table[result_id], ray.global_state.object_table(result_id) + ) def testLogFileAPI(self): ray.init(redirect_worker_output=True) @@ -2009,9 +2079,11 @@ def f(): start_time = time.time() while time.time() - start_time < 10: profiles = ray.global_state.task_profiles( - 100, start=0, end=time.time()) + 100, start=0, end=time.time() + ) limited_profiles = ray.global_state.task_profiles( - 1, start=0, end=time.time()) + 1, start=0, end=time.time() + ) if len(profiles) == num_calls and len(limited_profiles) == 1: break time.sleep(0.1) @@ -2032,7 +2104,8 @@ def testWorkers(self): ray.init( redirect_worker_output=True, num_cpus=num_workers, - num_workers=num_workers) + num_workers=num_workers + ) @ray.remote def f(): @@ -2084,7 +2157,8 @@ def method(self): path = os.path.join("/tmp/ray_test_trace") task_info = ray.global_state.task_profiles( - 100, start=0, end=time.time()) + 100, start=0, end=time.time() + ) ray.global_state.dump_catapult_trace(path, task_info) # TODO(rkn): This test is not perfect because it does not verify that @@ -2118,16 +2192,19 @@ def f(): for object_info in object_table.values(): if len(object_info) != 5: tables_ready = False - if (object_info["ManagerIDs"] is None - or object_info["DataSize"] == -1 - or object_info["Hash"] == ""): + if ( + object_info["ManagerIDs"] is None + or object_info["DataSize"] == -1 + or object_info["Hash"] == "" + ): tables_ready = False if len(task_table) != 10 + 1: tables_ready = False driver_task_id = ray.utils.binary_to_hex( - ray.worker.global_worker.current_task_id.id()) + ray.worker.global_worker.current_task_id.id() + ) for info in task_table.values(): if info["State"] != ray.experimental.state.TASK_STATUS_DONE: diff --git a/test/stress_tests.py b/test/stress_tests.py index 62bf3604e72a..bdeafa643ef9 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -18,7 +18,8 @@ def testSubmittingTasks(self): start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers, - num_cpus=100) + num_cpus=100 + ) @ray.remote def f(x): @@ -47,7 +48,8 @@ def testDependencies(self): start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers, - num_cpus=100) + num_cpus=100 + ) @ray.remote def f(x): @@ -126,7 +128,8 @@ def testWait(self): start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers, - num_cpus=100) + num_cpus=100 + ) @ray.remote def f(x): @@ -169,13 +172,15 @@ def setUp(self): self.plasma_store_memory = 10**9 plasma_addresses = [] objstore_memory = ( - self.plasma_store_memory // self.num_local_schedulers) + 
self.plasma_store_memory // self.num_local_schedulers + ) for i in range(self.num_local_schedulers): store_stdout_file, store_stderr_file = ray.services.new_log_files( - "plasma_store_{}".format(i), True) + "plasma_store_{}".format(i), True + ) manager_stdout_file, manager_stderr_file = ( - ray.services.new_log_files("plasma_manager_{}".format(i), - True)) + ray.services.new_log_files("plasma_manager_{}".format(i), True) + ) plasma_addresses.append( ray.services.start_objstore( node_ip_address, @@ -184,7 +189,9 @@ def setUp(self): store_stdout_file=store_stdout_file, store_stderr_file=store_stderr_file, manager_stdout_file=manager_stdout_file, - manager_stderr_file=manager_stderr_file)) + manager_stderr_file=manager_stderr_file + ) + ) # Start the rest of the services in the Ray cluster. address_info = { @@ -199,7 +206,8 @@ def setUp(self): num_local_schedulers=self.num_local_schedulers, num_cpus=[1] * self.num_local_schedulers, redirect_output=True, - driver_mode=ray.SILENT_MODE) + driver_mode=ray.SILENT_MODE + ) def tearDown(self): self.assertTrue(ray.services.all_processes_alive()) @@ -211,7 +219,8 @@ def tearDown(self): if os.environ.get('RAY_USE_NEW_GCS', False): tasks = state.task_table() local_scheduler_ids = set( - task["LocalSchedulerID"] for task in tasks.values()) + task["LocalSchedulerID"] for task in tasks.values() + ) # Make sure that all nodes in the cluster were used by checking that # the set of local scheduler IDs that had a task scheduled or submitted @@ -222,14 +231,16 @@ def tearDown(self): # scheduler. if os.environ.get('RAY_USE_NEW_GCS', False): self.assertEqual( - len(local_scheduler_ids), self.num_local_schedulers + 1) + len(local_scheduler_ids), self.num_local_schedulers + 1 + ) # Clean up the Ray cluster. ray.worker.cleanup() @unittest.skipIf( os.environ.get('RAY_USE_NEW_GCS', False), - "Failing with new GCS API on Linux.") + "Failing with new GCS API on Linux." + ) def testSimple(self): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' @@ -267,7 +278,8 @@ def foo(i, size): del values @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Failing with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Failing with new GCS API." + ) def testRecursive(self): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' @@ -320,7 +332,8 @@ def single_dependency(i, arg): del values @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Failing with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Failing with new GCS API." + ) def testMultipleRecursive(self): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' @@ -386,7 +399,8 @@ def wait_for_errors(self, error_check): return errors @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testNondeterministicTask(self): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' @@ -446,10 +460,12 @@ def error_check(errors): errors = self.wait_for_errors(error_check) # Make sure all the errors have the correct type. 
self.assertTrue( - all(error[b"type"] == b"object_hash_mismatch" for error in errors)) + all(error[b"type"] == b"object_hash_mismatch" for error in errors) + ) @unittest.skipIf( - os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.") + os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API." + ) def testDriverPutErrors(self): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' @@ -485,14 +501,16 @@ def single_dependency(i, arg): # for-loop should hang on its first iteration and push an error to the # driver. ray.worker.global_worker.local_scheduler_client.reconstruct_object( - args[0].id()) + args[0].id() + ) def error_check(errors): return len(errors) > 1 errors = self.wait_for_errors(error_check) self.assertTrue( - all(error[b"type"] == b"put_reconstruction" for error in errors)) + all(error[b"type"] == b"put_reconstruction" for error in errors) + ) class ReconstructionTestsMultinode(ReconstructionTests): diff --git a/test/tensorflow_test.py b/test/tensorflow_test.py index a7518c69d1e7..b82b664256e4 100644 --- a/test/tensorflow_test.py +++ b/test/tensorflow_test.py @@ -18,8 +18,10 @@ def make_linear_network(w_name=None, b_name=None): b = tf.Variable(tf.zeros([1]), name=b_name) y = w * x_data + b # Return the loss and weight initializer. - return (tf.reduce_mean(tf.square(y - y_data)), - tf.global_variables_initializer(), x_data, y_data) + return ( + tf.reduce_mean(tf.square(y - y_data)), + tf.global_variables_initializer(), x_data, y_data + ) class LossActor(object): @@ -32,7 +34,8 @@ def __init__(self, use_loss=True): sess = tf.Session() # Additional code for setting and getting the weights. weights = ray.experimental.TensorFlowVariables( - loss if use_loss else None, sess, input_variables=var) + loss if use_loss else None, sess, input_variables=var + ) # Return all of the data needed to use the network. 
self.values = [weights, init, sess] sess.run(init) @@ -78,16 +81,18 @@ def __init__(self): grads = optimizer.compute_gradients(loss) train = optimizer.apply_gradients(grads) self.values = [ - loss, variables, init, sess, grads, train, [x_data, y_data] + loss, variables, init, sess, grads, train, + [x_data, y_data] ] sess.run(init) def training_step(self, weights): _, variables, _, sess, grads, _, placeholders = self.values variables.set_weights(weights) - return sess.run( - [grad[0] for grad in grads], - feed_dict=dict(zip(placeholders, [[1] * 100, [2] * 100]))) + return sess.run([grad[0] for grad in grads], + feed_dict=dict( + zip(placeholders, [[1] * 100, [2] * 100]) + )) def get_weights(self): return self.values[1].get_weights() @@ -213,7 +218,8 @@ def testNetworkDriverWorkerIndependent(self): weights2 = ray.get(net2.get_weights.remote()) new_weights2 = ray.get( - net2.set_and_get_weights.remote(net2.get_weights.remote())) + net2.set_and_get_weights.remote(net2.get_weights.remote()) + ) self.assertEqual(weights2, new_weights2) def testVariablesControlDependencies(self): @@ -244,7 +250,8 @@ def testRemoteTrainingLoss(self): loss, variables, _, sess, grads, train, placeholders = net_values before_acc = sess.run( - loss, feed_dict=dict(zip(placeholders, [[2] * 100, [4] * 100]))) + loss, feed_dict=dict(zip(placeholders, [[2] * 100, [4] * 100])) + ) for _ in range(3): gradients_list = ray.get([ @@ -262,7 +269,8 @@ def testRemoteTrainingLoss(self): } sess.run(train, feed_dict=feed_dict) after_acc = sess.run( - loss, feed_dict=dict(zip(placeholders, [[2] * 100, [4] * 100]))) + loss, feed_dict=dict(zip(placeholders, [[2] * 100, [4] * 100])) + ) self.assertTrue(before_acc < after_acc)