diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 669ff38a4532..f167caf787e1 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -160,11 +160,14 @@ def fetch_and_execute_function_to_run(self, key): (driver_id, serialized_function, run_on_other_drivers) = self.redis_client.hmget( key, ["driver_id", "function", "run_on_other_drivers"]) - - if (run_on_other_drivers == "False" - and self.worker.mode == ray.SCRIPT_MODE - and driver_id != self.worker.task_driver_id.id()): - return + run_on_other_drivers = ray.utils.decode(run_on_other_drivers) == "True" + + # If this is a driver, we should only import and run the function if it + # is from another driver and run_on_other_drivers is True. + if self.worker.mode == ray.SCRIPT_MODE: + if (driver_id == self.worker.task_driver_id.id() + or not run_on_other_drivers): + return try: # Deserialize the function. diff --git a/test/runtest.py b/test/runtest.py index 9301a80f7dd9..9922359f967d 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -955,12 +955,19 @@ def f(worker_info): ray.worker.global_worker.run_function_on_all_workers(f) - self.init_ray() + self.init_ray(num_cpus=1) + + def get_modified_path(): + return sys.path[-4], sys.path[-3], sys.path[-2], sys.path[-1] + + # run_function_on_all_workers also runs on the driver, so make sure it + # did the right thing. + assert get_modified_path() == (1, 2, 3, 4) @ray.remote def get_state(): time.sleep(1) - return sys.path[-4], sys.path[-3], sys.path[-2], sys.path[-1] + return get_modified_path() res1 = get_state.remote() res2 = get_state.remote()