diff --git a/BUILD.bazel b/BUILD.bazel
index 1d5861afeff7..a9f389fe7c4e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1783,11 +1783,6 @@ filegroup(
"python/ray/experimental/*.py",
"python/ray/util/*.py",
"python/ray/internal/*.py",
- "python/ray/projects/*.py",
- "python/ray/projects/schema.json",
- "python/ray/projects/templates/cluster_template.yaml",
- "python/ray/projects/templates/project_template.yaml",
- "python/ray/projects/templates/requirements.txt",
"python/ray/workers/default_worker.py",
]),
)
diff --git a/bazel/python.bzl b/bazel/python.bzl
index 2b6f548a24d0..523792fdc4eb 100644
--- a/bazel/python.bzl
+++ b/bazel/python.bzl
@@ -1,12 +1,22 @@
-# py_test_module_list creates a py_test target for each
+# py_test_module_list creates a py_test target for each
# Python file in `files`
def py_test_module_list(files, size, deps, extra_srcs, **kwargs):
- for file in files:
- # remove .py
- name = file[:-3]
- native.py_test(
- name=name,
- size=size,
- srcs=extra_srcs+[file],
- **kwargs
- )
+ for file in files:
+ # remove .py
+ name = file[:-3]
+ native.py_test(
+ name = name,
+ size = size,
+ srcs = extra_srcs + [file],
+ **kwargs
+ )
+
+def py_test_run_all_subdirectory(include, exclude, extra_srcs, **kwargs):
+ for file in native.glob(include = include, exclude = exclude):
+ print(file)
+ basename = file.rpartition("/")[-1]
+ native.py_test(
+ name = basename[:-3],
+ srcs = extra_srcs + [file],
+ **kwargs
+ )
diff --git a/ci/long_running_distributed_tests/workloads/pytorch_pbt_failure.py b/ci/long_running_distributed_tests/workloads/pytorch_pbt_failure.py
index 94eafa2de9c3..6ab4ea61cacd 100644
--- a/ci/long_running_distributed_tests/workloads/pytorch_pbt_failure.py
+++ b/ci/long_running_distributed_tests/workloads/pytorch_pbt_failure.py
@@ -93,7 +93,7 @@ def optimizer_creator(model, config):
momentum=config.get("momentum", 0.9))
-ray.init(address="auto" if not args.smoke_test else None, log_to_driver=True)
+ray.init(address="auto" if not args.smoke_test else None, _log_to_driver=True)
num_training_workers = 1 if args.smoke_test else 3
executor = FailureInjectorExecutor(queue_trials=True)
diff --git a/ci/long_running_tests/workloads/serve.py b/ci/long_running_tests/workloads/serve.py
index 0f401c83621a..f3db5f7ffbde 100644
--- a/ci/long_running_tests/workloads/serve.py
+++ b/ci/long_running_tests/workloads/serve.py
@@ -38,7 +38,8 @@
@serve.accept_batch
def echo(_):
time.sleep(0.01) # Sleep for 10ms
- ray.show_in_webui(str(serve.context.batch_size), key="Current batch size")
+ ray.show_in_dashboard(
+ str(serve.context.batch_size), key="Current batch size")
return ["hi {}".format(i) for i in range(serve.context.batch_size)]
diff --git a/ci/long_running_tests/workloads/serve_failure.py b/ci/long_running_tests/workloads/serve_failure.py
index 406d35e55cfe..04577b6dc0a0 100644
--- a/ci/long_running_tests/workloads/serve_failure.py
+++ b/ci/long_running_tests/workloads/serve_failure.py
@@ -26,7 +26,7 @@
dashboard_host="0.0.0.0")
ray.init(
- address=cluster.address, dashboard_host="0.0.0.0", log_to_driver=False)
+ address=cluster.address, dashboard_host="0.0.0.0", _log_to_driver=False)
serve.init()
diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh
index e9a95289bc46..8ad938525c7c 100755
--- a/ci/travis/ci.sh
+++ b/ci/travis/ci.sh
@@ -153,7 +153,6 @@ test_python() {
-python/ray/tests:test_multiprocessing # test_connect_to_ray() fails to connect to raylet
-python/ray/tests:test_node_manager
-python/ray/tests:test_object_manager
- -python/ray/tests:test_projects
-python/ray/tests:test_ray_init # test_redis_port() seems to fail here, but pass in isolation
-python/ray/tests:test_resource_demand_scheduler
-python/ray/tests:test_stress # timeout
@@ -279,12 +278,12 @@ build_wheels() {
# caused timeouts in the past. See the "cache: false" line below.
local MOUNT_BAZEL_CACHE=(
-v "${HOME}/ray-bazel-cache":/root/ray-bazel-cache
- -e TRAVIS=true
- -e TRAVIS_PULL_REQUEST="${TRAVIS_PULL_REQUEST:-false}"
- -e encrypted_1c30b31fe1ee_key="${encrypted_1c30b31fe1ee_key-}"
- -e encrypted_1c30b31fe1ee_iv="${encrypted_1c30b31fe1ee_iv-}"
- -e TRAVIS_COMMIT="${TRAVIS_COMMIT}"
- -e CI="${CI}"
+ -e "TRAVIS=true"
+ -e "TRAVIS_PULL_REQUEST=${TRAVIS_PULL_REQUEST:-false}"
+ -e "encrypted_1c30b31fe1ee_key=${encrypted_1c30b31fe1ee_key-}"
+ -e "encrypted_1c30b31fe1ee_iv=${encrypted_1c30b31fe1ee_iv-}"
+ -e "TRAVIS_COMMIT=${TRAVIS_COMMIT}"
+ -e "CI=${CI}"
)
# This command should be kept in sync with ray/python/README-building-wheels.md,
diff --git a/dashboard/BUILD b/dashboard/BUILD
index 15ed537e68be..4a4d0894c510 100644
--- a/dashboard/BUILD
+++ b/dashboard/BUILD
@@ -1,13 +1,19 @@
+load("//bazel:python.bzl", "py_test_run_all_subdirectory")
+
# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
name = "dashboard_lib",
- srcs = glob(["**/*.py"],exclude=["tests/*"]),
+ srcs = glob(
+ ["**/*.py"],
+ exclude = ["tests/*"],
+ ),
)
-py_test(
- name = "test_dashboard",
+py_test_run_all_subdirectory(
size = "small",
- srcs = glob(["tests/*.py"]),
- deps = [":dashboard_lib"]
+ include = ["**/test*.py"],
+ exclude = ["modules/test/**"],
+ extra_srcs = [],
+ tags = ["exclusive"],
)
diff --git a/dashboard/agent.py b/dashboard/agent.py
index f0fce2f917f7..b7c379da3cb5 100644
--- a/dashboard/agent.py
+++ b/dashboard/agent.py
@@ -24,6 +24,11 @@
from ray.core.generated import agent_manager_pb2_grpc
import psutil
+try:
+ create_task = asyncio.create_task
+except AttributeError:
+ create_task = asyncio.ensure_future
+
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
@@ -36,6 +41,7 @@ def __init__(self,
redis_password=None,
temp_dir=None,
log_dir=None,
+ metrics_export_port=None,
node_manager_port=None,
object_store_name=None,
raylet_name=None):
@@ -45,6 +51,7 @@ def __init__(self,
self.redis_password = redis_password
self.temp_dir = temp_dir
self.log_dir = log_dir
+ self.metrics_export_port = metrics_export_port
self.node_manager_port = node_manager_port
self.object_store_name = object_store_name
self.raylet_name = raylet_name
@@ -69,7 +76,7 @@ def _load_modules(self):
c = cls(self)
dashboard_utils.ClassMethodRouteTable.bind(c)
modules.append(c)
- logger.info("Loaded {} modules.".format(len(modules)))
+ logger.info("Loaded %d modules.", len(modules))
return modules
async def run(self):
@@ -86,7 +93,7 @@ async def _check_parent():
dashboard_consts.
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS)
- check_parent_task = asyncio.create_task(_check_parent())
+ check_parent_task = create_task(_check_parent())
# Create an aioredis client for all modules.
try:
@@ -172,6 +179,11 @@ async def _check_parent():
required=True,
type=str,
help="The address to use for Redis.")
+ parser.add_argument(
+ "--metrics-export-port",
+ required=True,
+ type=int,
+ help="The port to expose metrics through Prometheus.")
parser.add_argument(
"--node-manager-port",
required=True,
@@ -277,6 +289,7 @@ async def _check_parent():
redis_password=args.redis_password,
temp_dir=temp_dir,
log_dir=log_dir,
+ metrics_export_port=args.metrics_export_port,
node_manager_port=args.node_manager_port,
object_store_name=args.object_store_name,
raylet_name=args.raylet_name)
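For context on the create_task shim added near the top of agent.py: asyncio.create_task only exists on Python 3.7+, and asyncio.ensure_future is the pre-3.7 way to schedule a coroutine on the running loop. A one-line sketch of the same fallback:

    import asyncio

    # asyncio.create_task was added in Python 3.7; ensure_future also
    # schedules a coroutine on the event loop and exists on 3.6.
    create_task = getattr(asyncio, "create_task", asyncio.ensure_future)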
diff --git a/dashboard/dashboard.py b/dashboard/dashboard.py
index f8f19ef212af..6abc9293f88e 100644
--- a/dashboard/dashboard.py
+++ b/dashboard/dashboard.py
@@ -37,10 +37,10 @@ def setup_static_dir():
errno.ENOENT, "Dashboard build directory not found. If installing "
"from source, please follow the additional steps "
"required to build the dashboard"
- "(cd python/ray/{}/client "
+ f"(cd python/ray/{module_name}/client "
"&& npm install "
"&& npm ci "
- "&& npm run build)".format(module_name), build_dir)
+ "&& npm run build)", build_dir)
static_dir = os.path.join(build_dir, "static")
routes.static("/static", static_dir, follow_symlinks=True)
@@ -59,14 +59,21 @@ class Dashboard:
port(int): Port number of dashboard aiohttp server.
redis_address(str): GCS address of a Ray cluster
redis_password(str): Redis password to access GCS
+ log_dir(str): Log directory of dashboard.
"""
- def __init__(self, host, port, redis_address, redis_password=None):
+ def __init__(self,
+ host,
+ port,
+ redis_address,
+ redis_password=None,
+ log_dir=None):
self.dashboard_head = dashboard_head.DashboardHead(
http_host=host,
http_port=port,
redis_address=redis_address,
- redis_password=redis_password)
+ redis_password=redis_password,
+ log_dir=log_dir)
# Setup Dashboard Routes
build_dir = setup_static_dir()
@@ -197,7 +204,8 @@ async def run(self):
args.host,
args.port,
args.redis_address,
- redis_password=args.redis_password)
+ redis_password=args.redis_password,
+ log_dir=log_dir)
loop = asyncio.get_event_loop()
loop.run_until_complete(dashboard.run())
except Exception as e:
diff --git a/dashboard/datacenter.py b/dashboard/datacenter.py
index 65e3bb449628..1ee454917671 100644
--- a/dashboard/datacenter.py
+++ b/dashboard/datacenter.py
@@ -19,7 +19,7 @@ class DataSource:
# {actor id hex(str): actor table data(dict of ActorTableData
# in gcs.proto)}
actors = Dict()
- # {ip address(str): dashboard agent grpc server port(int)}
+ # {ip address(str): dashboard agent [http port(int), grpc port(int)]}
agents = Dict()
# {ip address(str): gcs node info(dict of GcsNodeInfo in gcs.proto)}
nodes = Dict()
@@ -56,7 +56,7 @@ async def get_node_actors(cls, hostname):
node_worker_id_set.add(worker_stats["workerId"])
node_actors = {}
for actor_id, actor_table_data in DataSource.actors.items():
- if actor_table_data["workerId"] in node_worker_id_set:
+ if actor_table_data["address"]["workerId"] in node_worker_id_set:
node_actors[actor_id] = actor_table_data
return node_actors
@@ -102,6 +102,7 @@ async def get_all_node_summary(cls):
for hostname in DataSource.hostname_to_ip.keys():
node_info = await cls.get_node_info(hostname)
node_info.pop("workers", None)
+ node_info.pop("actors", None)
node_info["raylet"].pop("workersStats", None)
node_info["raylet"].pop("viewData", None)
all_nodes_summary.append(node_info)
diff --git a/dashboard/head.py b/dashboard/head.py
index 7b40ef013064..9a8f273796f3 100644
--- a/dashboard/head.py
+++ b/dashboard/head.py
@@ -28,7 +28,8 @@ def gcs_node_info_to_dict(message):
class DashboardHead:
- def __init__(self, http_host, http_port, redis_address, redis_password):
+ def __init__(self, http_host, http_port, redis_address, redis_password,
+ log_dir):
# NodeInfoGcsService
self._gcs_node_info_stub = None
self._gcs_rpc_error_counter = 0
@@ -37,6 +38,7 @@ def __init__(self, http_host, http_port, redis_address, redis_password):
self.http_port = http_port
self.redis_address = dashboard_utils.address_tuple(redis_address)
self.redis_password = redis_password
+ self.log_dir = log_dir
self.aioredis_client = None
self.aiogrpc_gcs_channel = None
self.http_session = None
@@ -74,6 +76,22 @@ async def _update_nodes(self):
while True:
try:
nodes = await self._get_nodes()
+
+ # Pick the correct node info by state:
+ # 1. The node is ALIVE if any ALIVE node info
+ # for the hostname exists.
+ # 2. The node is DEAD only if all node info
+ # for the hostname is DEAD.
+ hostname_to_node_info = {}
+ for node in nodes:
+ hostname = node["nodeManagerAddress"]
+ assert node["state"] in ["ALIVE", "DEAD"]
+ choose = hostname_to_node_info.get(hostname)
+ if choose is not None and choose["state"] == "ALIVE":
+ continue
+ hostname_to_node_info[hostname] = node
+ nodes = hostname_to_node_info.values()
+
self._gcs_rpc_error_counter = 0
node_ips = [node["nodeManagerAddress"] for node in nodes]
node_hostnames = [
@@ -83,13 +101,11 @@ async def _update_nodes(self):
agents = dict(DataSource.agents)
for node in nodes:
node_ip = node["nodeManagerAddress"]
- if node_ip not in agents:
- key = "{}{}".format(
- dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
- node_ip)
- agent_port = await self.aioredis_client.get(key)
- if agent_port:
- agents[node_ip] = json.loads(agent_port)
+ key = "{}{}".format(
+ dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX, node_ip)
+ agent_port = await self.aioredis_client.get(key)
+ if agent_port:
+ agents[node_ip] = json.loads(agent_port)
for ip in agents.keys() - set(node_ips):
agents.pop(ip, None)
@@ -99,8 +115,8 @@ async def _update_nodes(self):
dict(zip(node_hostnames, node_ips)))
DataSource.ip_to_hostname.reset(
dict(zip(node_ips, node_hostnames)))
- except aiogrpc.AioRpcError as ex:
- logger.exception(ex)
+ except aiogrpc.AioRpcError:
+ logger.exception("Got AioRpcError when updating nodes.")
self._gcs_rpc_error_counter += 1
if self._gcs_rpc_error_counter > \
dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR:
@@ -109,8 +125,8 @@ async def _update_nodes(self):
self._gcs_rpc_error_counter,
dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR)
sys.exit(-1)
- except Exception as ex:
- logger.exception(ex)
+ except Exception:
+ logger.exception("Error updating nodes.")
finally:
await asyncio.sleep(
dashboard_consts.UPDATE_NODES_INTERVAL_SECONDS)
@@ -126,7 +142,7 @@ def _load_modules(self):
c = cls(self)
dashboard_utils.ClassMethodRouteTable.bind(c)
modules.append(c)
- logger.info("Loaded {} modules.".format(len(modules)))
+ logger.info("Loaded %d modules.", len(modules))
return modules
async def run(self):
@@ -183,8 +199,8 @@ async def _async_notify():
co = await dashboard_utils.NotifyQueue.get()
try:
await co
- except Exception as e:
- logger.exception(e)
+ except Exception:
+ logger.exception(f"Error notifying coroutine {co}")
async def _purge_data():
"""Purge data in datacenter."""
@@ -193,8 +209,8 @@ async def _purge_data():
dashboard_consts.PURGE_DATA_INTERVAL_SECONDS)
try:
await DataOrganizer.purge()
- except Exception as e:
- logger.exception(e)
+ except Exception:
+ logger.exception("Error purging data.")
modules = self._load_modules()
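A toy run of the ALIVE-over-DEAD selection added to _update_nodes above, with node dicts trimmed to the two fields the loop reads (addresses are illustrative):

    nodes = [
        {"nodeManagerAddress": "10.0.0.1", "state": "DEAD"},   # stale entry
        {"nodeManagerAddress": "10.0.0.1", "state": "ALIVE"},  # restarted node
        {"nodeManagerAddress": "10.0.0.2", "state": "DEAD"},
    ]
    hostname_to_node_info = {}
    for node in nodes:
        chosen = hostname_to_node_info.get(node["nodeManagerAddress"])
        if chosen is not None and chosen["state"] == "ALIVE":
            continue
        hostname_to_node_info[node["nodeManagerAddress"]] = node

    # 10.0.0.1 resolves to its ALIVE entry; 10.0.0.2 stays DEAD.
    assert hostname_to_node_info["10.0.0.1"]["state"] == "ALIVE"
    assert hostname_to_node_info["10.0.0.2"]["state"] == "DEAD"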
diff --git a/dashboard/modules/log/__init__.py b/dashboard/modules/log/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/dashboard/modules/log/log_agent.py b/dashboard/modules/log/log_agent.py
new file mode 100644
index 000000000000..8ef31d96852b
--- /dev/null
+++ b/dashboard/modules/log/log_agent.py
@@ -0,0 +1,19 @@
+import logging
+
+import mimetypes
+
+import ray.new_dashboard.utils as dashboard_utils
+
+logger = logging.getLogger(__name__)
+routes = dashboard_utils.ClassMethodRouteTable
+
+
+class LogAgent(dashboard_utils.DashboardAgentModule):
+ def __init__(self, dashboard_agent):
+ super().__init__(dashboard_agent)
+ mimetypes.add_type("text/plain", ".err")
+ mimetypes.add_type("text/plain", ".out")
+ routes.static("/logs", self._dashboard_agent.log_dir, show_index=True)
+
+ async def run(self, server):
+ pass
diff --git a/dashboard/modules/log/log_head.py b/dashboard/modules/log/log_head.py
new file mode 100644
index 000000000000..5f1a1f6dcf0a
--- /dev/null
+++ b/dashboard/modules/log/log_head.py
@@ -0,0 +1,69 @@
+import logging
+
+import mimetypes
+
+import aiohttp.web
+import ray.new_dashboard.utils as dashboard_utils
+from ray.new_dashboard.datacenter import DataSource, GlobalSignals
+
+logger = logging.getLogger(__name__)
+routes = dashboard_utils.ClassMethodRouteTable
+
+
+class LogHead(dashboard_utils.DashboardHeadModule):
+ LOG_URL_TEMPLATE = "http://{ip}:{port}/logs"
+
+ def __init__(self, dashboard_head):
+ super().__init__(dashboard_head)
+ mimetypes.add_type("text/plain", ".err")
+ mimetypes.add_type("text/plain", ".out")
+ routes.static("/logs", self._dashboard_head.log_dir, show_index=True)
+ GlobalSignals.node_info_fetched.append(
+ self.insert_log_url_to_node_info)
+
+ async def insert_log_url_to_node_info(self, node_info):
+ ip = node_info.get("ip")
+ if ip is None:
+ return
+ agent_port = DataSource.agents.get(ip)
+ if agent_port is None:
+ return
+ agent_http_port, _ = agent_port
+ log_url = self.LOG_URL_TEMPLATE.format(ip=ip, port=agent_http_port)
+ node_info["logUrl"] = log_url
+
+ @routes.get("/log_index")
+ async def get_log_index(self, req) -> aiohttp.web.Response:
+ url_list = []
+ for ip, ports in DataSource.agents.items():
+ url_list.append(
+ self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
+ if self._dashboard_head.ip not in DataSource.agents:
+ url_list.append(
+ self.LOG_URL_TEMPLATE.format(
+ ip=self._dashboard_head.ip,
+ port=self._dashboard_head.http_port))
+ return aiohttp.web.Response(
+ text=self._directory_as_html(url_list), content_type="text/html")
+
+ @staticmethod
+ def _directory_as_html(url_list) -> str:
+ # returns directory's index as html
+
+ index_of = "Index of logs"
+ h1 = f"<h1>{index_of}</h1>"
+
+ index_list = []
+ for url in sorted(url_list):
+ index_list.append(f'<li><a href="{url}">{url}</a></li>')
+ index_list = "\n".join(index_list)
+ ul = f"<ul>\n{index_list}\n</ul>"
+ body = f"<body>\n{h1}\n{ul}\n</body>"
+
+ head_str = f"<head>\n<title>{index_of}</title>\n</head>"
+ html = f"<html>\n{head_str}\n{body}\n</html>"
+
+ return html
+
+ async def run(self, server):
+ pass
diff --git a/dashboard/modules/log/test_log.py b/dashboard/modules/log/test_log.py
new file mode 100644
index 000000000000..a40643f9035e
--- /dev/null
+++ b/dashboard/modules/log/test_log.py
@@ -0,0 +1,115 @@
+import os
+import sys
+import logging
+import requests
+import socket
+import time
+import traceback
+import html.parser
+import urllib.parse
+
+import pytest
+import ray
+from ray.new_dashboard.tests.conftest import * # noqa
+from ray.test_utils import (
+ format_web_url,
+ wait_until_server_available,
+)
+
+os.environ["RAY_USE_NEW_DASHBOARD"] = "1"
+
+logger = logging.getLogger(__name__)
+
+
+class LogUrlParser(html.parser.HTMLParser):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._urls = []
+
+ def handle_starttag(self, tag, attrs):
+ if tag == "a":
+ self._urls.append(dict(attrs)["href"])
+
+ def error(self, message):
+ logger.error(message)
+
+ def get_urls(self):
+ return self._urls
+
+
+def test_log(ray_start_with_dashboard):
+ @ray.remote
+ def write_log(s):
+ print(s)
+
+ test_log_text = "test_log_text"
+ ray.get(write_log.remote(test_log_text))
+ assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
+ is True)
+ webui_url = ray_start_with_dashboard["webui_url"]
+ webui_url = format_web_url(webui_url)
+
+ timeout_seconds = 10
+ start_time = time.time()
+ last_ex = None
+ while True:
+ time.sleep(1)
+ try:
+ response = requests.get(webui_url + "/log_index")
+ response.raise_for_status()
+ parser = LogUrlParser()
+ parser.feed(response.text)
+ all_nodes_log_urls = parser.get_urls()
+ assert len(all_nodes_log_urls) == 1
+
+ response = requests.get(all_nodes_log_urls[0])
+ response.raise_for_status()
+ parser = LogUrlParser()
+ parser.feed(response.text)
+
+ # Search test_log_text from all worker logs.
+ parsed_url = urllib.parse.urlparse(all_nodes_log_urls[0])
+ paths = parser.get_urls()
+ urls = []
+ for p in paths:
+ if "worker" in p:
+ urls.append(parsed_url._replace(path=p).geturl())
+
+ for u in urls:
+ response = requests.get(u)
+ response.raise_for_status()
+ if test_log_text in response.text:
+ break
+ else:
+ raise Exception(f"Can't find {test_log_text} from {urls}")
+
+ # Test range request.
+ response = requests.get(
+ webui_url + "/logs/dashboard.log",
+ headers={"Range": "bytes=43-51"})
+ response.raise_for_status()
+ assert response.text == "Dashboard"
+
+ # Test logUrl in node info.
+ response = requests.get(webui_url +
+ f"/nodes/{socket.gethostname()}")
+ response.raise_for_status()
+ node_info = response.json()
+ assert node_info["result"] is True
+ node_info = node_info["data"]["detail"]
+ assert "logUrl" in node_info
+ assert node_info["logUrl"] in all_nodes_log_urls
+ break
+ except Exception as ex:
+ last_ex = ex
+ finally:
+ if time.time() > start_time + timeout_seconds:
+ ex_stack = traceback.format_exception(
+ type(last_ex), last_ex,
+ last_ex.__traceback__) if last_ex else []
+ ex_stack = "".join(ex_stack)
+ raise Exception(f"Timed out while testing, {ex_stack}")
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-v", __file__]))
diff --git a/dashboard/modules/reporter/reporter_agent.py b/dashboard/modules/reporter/reporter_agent.py
index 4e3f7cd55d6d..f129a702e4e3 100644
--- a/dashboard/modules/reporter/reporter_agent.py
+++ b/dashboard/modules/reporter/reporter_agent.py
@@ -6,6 +6,7 @@
import socket
import subprocess
import sys
+import traceback
import aioredis
@@ -17,6 +18,7 @@
import ray.utils
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
+from ray.metrics_agent import MetricsAgent
import psutil
logger = logging.getLogger(__name__)
@@ -67,30 +69,46 @@ def __init__(self, dashboard_agent):
self._hostname = socket.gethostname()
self._workers = set()
self._network_stats_hist = [(0, (0.0, 0.0))] # time, (sent, recv)
+ self._metrics_agent = MetricsAgent(dashboard_agent.metrics_export_port)
async def GetProfilingStats(self, request, context):
pid = request.pid
duration = request.duration
profiling_file_path = os.path.join(ray.utils.get_ray_temp_dir(),
- "{}_profiling.txt".format(pid))
- process = subprocess.Popen(
- "sudo $(which py-spy) record -o {} -p {} -d {} -f speedscope"
- .format(profiling_file_path, pid, duration),
+ f"{pid}_profiling.txt")
+ sudo = "sudo" if ray.utils.get_user() != "root" else ""
+ process = await asyncio.create_subprocess_shell(
+ f"{sudo} $(which py-spy) record "
+ f"-o {profiling_file_path} -p {pid} -d {duration} -f speedscope",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
- stdout, stderr = process.communicate()
+ stdout, stderr = await process.communicate()
if process.returncode != 0:
profiling_stats = ""
else:
with open(profiling_file_path, "r") as f:
profiling_stats = f.read()
return reporter_pb2.GetProfilingStatsReply(
- profiling_stats=profiling_stats, stdout=stdout, stderr=stderr)
+ profiling_stats=profiling_stats, std_out=stdout, std_err=stderr)
async def ReportMetrics(self, request, context):
- # TODO(sang): Process metrics here.
- return reporter_pb2.ReportMetricsReply()
+ # NOTE: Exceptions are not propagated properly
+ # when we don't catch them here.
+ try:
+ metrcs_description_required = (
+ self._metrics_agent.record_metrics_points(
+ request.metrics_points))
+ except Exception as e:
+ logger.error(e)
+ logger.error(traceback.format_exc())
+
+ # If the metrics description is missing, we should notify the cpp
+ # processes that we need it. The cpp processes will then report it here.
+ # We need this when (1) a new metric is reported (application metric)
+ # or (2) a reporter goes down and is restarted (currently not implemented).
+ return reporter_pb2.ReportMetricsReply(
+ metrcs_description_required=metrcs_description_required)
@staticmethod
def _get_cpu_percent():
@@ -222,8 +240,8 @@ async def _perform_iteration(self):
await aioredis_client.publish(
"{}{}".format(reporter_consts.REPORTER_PREFIX,
self._hostname), jsonify_asdict(stats))
- except Exception as ex:
- logger.exception(ex)
+ except Exception:
+ logger.exception("Error publishing node physical stats.")
await asyncio.sleep(
reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py
index fdd262a604bd..6046fd325781 100644
--- a/dashboard/modules/reporter/reporter_head.py
+++ b/dashboard/modules/reporter/reporter_head.py
@@ -1,6 +1,5 @@
import json
import logging
-import uuid
import aiohttp.web
from aioredis.pubsub import Receiver
@@ -24,58 +23,32 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._stubs = {}
- self._profiling_stats = {}
DataSource.agents.signal.append(self._update_stubs)
async def _update_stubs(self, change):
+ if change.old:
+ ip, port = change.old
+ self._stubs.pop(ip)
if change.new:
- ip, ports = next(iter(change.new.items()))
- channel = aiogrpc.insecure_channel("{}:{}".format(ip, ports[1]))
+ ip, ports = change.new
+ channel = aiogrpc.insecure_channel(f"{ip}:{ports[1]}")
stub = reporter_pb2_grpc.ReporterServiceStub(channel)
self._stubs[ip] = stub
- if change.old:
- ip, port = next(iter(change.old.items()))
- self._stubs.pop(ip)
@routes.get("/api/launch_profiling")
async def launch_profiling(self, req) -> aiohttp.web.Response:
- node_id = req.query.get("node_id")
- pid = int(req.query.get("pid"))
- duration = int(req.query.get("duration"))
- profiling_id = str(uuid.uuid4())
- reporter_stub = self._stubs[node_id]
+ ip = req.query["ip"]
+ pid = int(req.query["pid"])
+ duration = int(req.query["duration"])
+ reporter_stub = self._stubs[ip]
reply = await reporter_stub.GetProfilingStats(
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration))
- self._profiling_stats[profiling_id] = reply
- return await dashboard_utils.rest_response(
- success=True,
- message="Profiling launched.",
- profiling_id=profiling_id)
-
- @routes.get("/api/check_profiling_status")
- async def check_profiling_status(self, req) -> aiohttp.web.Response:
- profiling_id = req.query.get("profiling_id")
- is_present = profiling_id in self._profiling_stats
- if not is_present:
- status = {"status": "pending"}
- else:
- reply = self._profiling_stats[profiling_id]
- if reply.stderr:
- status = {"status": "error", "error": reply.stderr}
- else:
- status = {"status": "finished"}
- return await dashboard_utils.rest_response(
- success=True, message="Profiling status fetched.", status=status)
-
- @routes.get("/api/get_profiling_info")
- async def get_profiling_info(self, req) -> aiohttp.web.Response:
- profiling_id = req.query.get("profiling_id")
- profiling_stats = self._profiling_stats.get(profiling_id)
- assert profiling_stats, "profiling not finished"
+ profiling_info = (json.loads(reply.profiling_stats)
+ if reply.profiling_stats else reply.std_out)
return await dashboard_utils.rest_response(
success=True,
- message="Profiling info fetched.",
- profiling_info=json.loads(profiling_stats.profiling_stats))
+ message="Profiling success.",
+ profiling_info=profiling_info)
async def run(self, server):
aioredis_client = self._dashboard_head.aioredis_client
@@ -83,12 +56,13 @@ async def run(self, server):
reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
await aioredis_client.psubscribe(receiver.pattern(reporter_key))
- logger.info("Subscribed to {}".format(reporter_key))
+ logger.info(f"Subscribed to {reporter_key}")
async for sender, msg in receiver.iter():
try:
_, data = msg
data = json.loads(ray.utils.decode(data))
DataSource.node_physical_stats[data["ip"]] = data
- except Exception as ex:
- logger.exception(ex)
+ except Exception:
+ logger.exception(
+ "Error receiving node physical stats from reporter agent.")
diff --git a/dashboard/modules/reporter/test_reporter.py b/dashboard/modules/reporter/test_reporter.py
new file mode 100644
index 000000000000..0097bc465c3a
--- /dev/null
+++ b/dashboard/modules/reporter/test_reporter.py
@@ -0,0 +1,102 @@
+import os
+import sys
+import logging
+import requests
+import time
+
+import pytest
+import ray
+from ray.new_dashboard.tests.conftest import * # noqa
+from ray.test_utils import (
+ format_web_url,
+ RayTestTimeoutException,
+ wait_until_server_available,
+ wait_for_condition,
+)
+
+os.environ["RAY_USE_NEW_DASHBOARD"] = "1"
+
+logger = logging.getLogger(__name__)
+
+
+def test_profiling(shutdown_only):
+ addresses = ray.init(include_dashboard=True, num_cpus=6)
+
+ @ray.remote(num_cpus=2)
+ class Actor:
+ def getpid(self):
+ return os.getpid()
+
+ c = Actor.remote()
+ actor_pid = ray.get(c.getpid.remote())
+
+ webui_url = addresses["webui_url"]
+ assert (wait_until_server_available(webui_url) is True)
+ webui_url = format_web_url(webui_url)
+
+ start_time = time.time()
+ launch_profiling = None
+ while True:
+ # Sometimes some startup time is required
+ if time.time() - start_time > 10:
+ raise RayTestTimeoutException(
+ "Timed out while collecting profiling stats, "
+ f"launch_profiling: {launch_profiling}")
+ launch_profiling = requests.get(
+ webui_url + "/api/launch_profiling",
+ params={
+ "ip": ray.nodes()[0]["NodeManagerAddress"],
+ "pid": actor_pid,
+ "duration": 5
+ }).json()
+ if launch_profiling["result"]:
+ profiling_info = launch_profiling["data"]["profilingInfo"]
+ break
+ time.sleep(1)
+ logger.info(profiling_info)
+
+
+def test_node_physical_stats(enable_test_module, shutdown_only):
+ addresses = ray.init(include_dashboard=True, num_cpus=6)
+
+ @ray.remote(num_cpus=1)
+ class Actor:
+ def getpid(self):
+ return os.getpid()
+
+ actors = [Actor.remote() for _ in range(6)]
+ actor_pids = ray.get([actor.getpid.remote() for actor in actors])
+ actor_pids = set(actor_pids)
+
+ webui_url = addresses["webui_url"]
+ assert (wait_until_server_available(webui_url) is True)
+ webui_url = format_web_url(webui_url)
+
+ def _check_workers():
+ try:
+ resp = requests.get(webui_url +
+ "/test/dump?key=node_physical_stats")
+ resp.raise_for_status()
+ result = resp.json()
+ assert result["result"] is True
+ node_physical_stats = result["data"]["nodePhysicalStats"]
+ assert len(node_physical_stats) == 1
+ current_stats = node_physical_stats[addresses["raylet_ip_address"]]
+ # Check Actor workers
+ current_actor_pids = set()
+ for worker in current_stats["workers"]:
+ if "ray::Actor" in worker["cmdline"][0]:
+ current_actor_pids.add(worker["pid"])
+ assert current_actor_pids == actor_pids
+ # Check raylet cmdline
+ assert "raylet" in current_stats["cmdline"][0]
+ return True
+ except Exception as ex:
+ logger.info(ex)
+ return False
+
+ wait_for_condition(_check_workers, timeout=10)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-v", __file__]))
diff --git a/dashboard/modules/stats_collector/__init__.py b/dashboard/modules/stats_collector/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/dashboard/modules/stats_collector/stats_collector_consts.py b/dashboard/modules/stats_collector/stats_collector_consts.py
new file mode 100644
index 000000000000..a0d58d32602f
--- /dev/null
+++ b/dashboard/modules/stats_collector/stats_collector_consts.py
@@ -0,0 +1,3 @@
+NODE_STATS_UPDATE_INTERVAL_SECONDS = 1
+RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS = 1
+ACTOR_CHANNEL = "ACTOR"
diff --git a/dashboard/modules/stats_collector/stats_collector_head.py b/dashboard/modules/stats_collector/stats_collector_head.py
new file mode 100644
index 000000000000..5ff21070618b
--- /dev/null
+++ b/dashboard/modules/stats_collector/stats_collector_head.py
@@ -0,0 +1,155 @@
+import asyncio
+import logging
+
+import aiohttp.web
+from aioredis.pubsub import Receiver
+from grpc.experimental import aio as aiogrpc
+
+import ray.gcs_utils
+import ray.new_dashboard.modules.stats_collector.stats_collector_consts \
+ as stats_collector_consts
+import ray.new_dashboard.utils as dashboard_utils
+from ray.new_dashboard.utils import async_loop_forever
+from ray.core.generated import node_manager_pb2
+from ray.core.generated import node_manager_pb2_grpc
+from ray.core.generated import gcs_service_pb2
+from ray.core.generated import gcs_service_pb2_grpc
+from ray.new_dashboard.datacenter import DataSource, DataOrganizer
+from ray.utils import binary_to_hex
+
+logger = logging.getLogger(__name__)
+routes = dashboard_utils.ClassMethodRouteTable
+
+
+def node_stats_to_dict(message):
+ return dashboard_utils.message_to_dict(
+ message, {
+ "actorId", "jobId", "taskId", "parentTaskId", "sourceActorId",
+ "callerId", "rayletId", "workerId"
+ })
+
+
+def actor_table_data_to_dict(message):
+ return dashboard_utils.message_to_dict(
+ message, {
+ "actorId", "parentId", "jobId", "workerId", "rayletId",
+ "actorCreationDummyObjectId"
+ },
+ including_default_value_fields=True)
+
+
+class StatsCollector(dashboard_utils.DashboardHeadModule):
+ def __init__(self, dashboard_head):
+ super().__init__(dashboard_head)
+ self._stubs = {}
+ # JobInfoGcsServiceStub
+ self._gcs_job_info_stub = None
+ # ActorInfoGcsService
+ self._gcs_actor_info_stub = None
+ DataSource.nodes.signal.append(self._update_stubs)
+
+ async def _update_stubs(self, change):
+ if change.old:
+ ip, port = change.old
+ self._stubs.pop(ip)
+ if change.new:
+ ip, node_info = change.new
+ address = "{}:{}".format(ip, int(node_info["nodeManagerPort"]))
+ channel = aiogrpc.insecure_channel(address)
+ stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
+ self._stubs[ip] = stub
+
+ @routes.get("/nodes")
+ async def get_all_nodes(self, req) -> aiohttp.web.Response:
+ view = req.query.get("view")
+ if view == "summary":
+ all_node_summary = await DataOrganizer.get_all_node_summary()
+ return await dashboard_utils.rest_response(
+ success=True,
+ message="Node summary fetched.",
+ summary=all_node_summary)
+ elif view is not None and view.lower() == "hostNameList".lower():
+ return await dashboard_utils.rest_response(
+ success=True,
+ message="Node hostname list fetched.",
+ host_name_list=list(DataSource.hostname_to_ip.keys()))
+ else:
+ return await dashboard_utils.rest_response(
+ success=False, message=f"Unknown view {view}")
+
+ @routes.get("/nodes/{hostname}")
+ async def get_node(self, req) -> aiohttp.web.Response:
+ hostname = req.match_info.get("hostname")
+ node_info = await DataOrganizer.get_node_info(hostname)
+ return await dashboard_utils.rest_response(
+ success=True, message="Node detail fetched.", detail=node_info)
+
+ async def _update_actors(self):
+ # Subscribe actor channel.
+ aioredis_client = self._dashboard_head.aioredis_client
+ receiver = Receiver()
+
+ key = "{}:*".format(stats_collector_consts.ACTOR_CHANNEL)
+ pattern = receiver.pattern(key)
+ await aioredis_client.psubscribe(pattern)
+ logger.info("Subscribed to %s", key)
+
+ # Get all actor info.
+ while True:
+ try:
+ logger.info("Getting all actor info from GCS.")
+ request = gcs_service_pb2.GetAllActorInfoRequest()
+ reply = await self._gcs_actor_info_stub.GetAllActorInfo(
+ request, timeout=2)
+ if reply.status.code == 0:
+ result = {}
+ for actor_info in reply.actor_table_data:
+ result[binary_to_hex(actor_info.actor_id)] = \
+ actor_table_data_to_dict(actor_info)
+ DataSource.actors.reset(result)
+ logger.info("Received %d actor info from GCS.",
+ len(result))
+ break
+ else:
+ raise Exception(
+ f"Failed to GetAllActorInfo: {reply.status.message}")
+ except Exception:
+ logger.exception("Error Getting all actor info from GCS.")
+ await asyncio.sleep(stats_collector_consts.
+ RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)
+
+ # Receive actors from channel.
+ async for sender, msg in receiver.iter():
+ try:
+ _, data = msg
+ pubsub_message = ray.gcs_utils.PubSubMessage.FromString(data)
+ actor_info = ray.gcs_utils.ActorTableData.FromString(
+ pubsub_message.data)
+ DataSource.actors[binary_to_hex(actor_info.actor_id)] = \
+ actor_table_data_to_dict(actor_info)
+ except Exception:
+ logger.exception("Error receiving actor info.")
+
+ @async_loop_forever(
+ stats_collector_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
+ async def _update_node_stats(self):
+ for ip, stub in self._stubs.items():
+ node_info = DataSource.nodes.get(ip)
+ if node_info["state"] != "ALIVE":
+ continue
+ try:
+ reply = await stub.GetNodeStats(
+ node_manager_pb2.GetNodeStatsRequest(), timeout=2)
+ reply_dict = node_stats_to_dict(reply)
+ DataSource.node_stats[ip] = reply_dict
+ except Exception:
+ logger.exception(f"Error updating node stats of {ip}.")
+
+ async def run(self, server):
+ gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
+ self._gcs_job_info_stub = \
+ gcs_service_pb2_grpc.JobInfoGcsServiceStub(gcs_channel)
+ self._gcs_actor_info_stub = \
+ gcs_service_pb2_grpc.ActorInfoGcsServiceStub(gcs_channel)
+
+ await asyncio.gather(self._update_node_stats(), self._update_actors())
diff --git a/dashboard/modules/stats_collector/test_stats_collector.py b/dashboard/modules/stats_collector/test_stats_collector.py
new file mode 100644
index 000000000000..72e614437d10
--- /dev/null
+++ b/dashboard/modules/stats_collector/test_stats_collector.py
@@ -0,0 +1,93 @@
+import os
+import sys
+import logging
+import requests
+import time
+import traceback
+
+import pytest
+import ray
+from ray.new_dashboard.tests.conftest import * # noqa
+from ray.test_utils import (
+ format_web_url,
+ wait_until_server_available,
+)
+
+os.environ["RAY_USE_NEW_DASHBOARD"] = "1"
+
+logger = logging.getLogger(__name__)
+
+
+def test_node_info(ray_start_with_dashboard):
+ @ray.remote
+ class Actor:
+ def getpid(self):
+ return os.getpid()
+
+ actors = [Actor.remote(), Actor.remote()]
+ actor_pids = [actor.getpid.remote() for actor in actors]
+ actor_pids = set(ray.get(actor_pids))
+
+ assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
+ is True)
+ webui_url = ray_start_with_dashboard["webui_url"]
+ webui_url = format_web_url(webui_url)
+
+ timeout_seconds = 10
+ start_time = time.time()
+ last_ex = None
+ while True:
+ time.sleep(1)
+ try:
+ response = requests.get(webui_url + "/nodes?view=hostnamelist")
+ response.raise_for_status()
+ hostname_list = response.json()
+ assert hostname_list["result"] is True, hostname_list["msg"]
+ hostname_list = hostname_list["data"]["hostNameList"]
+ assert len(hostname_list) == 1
+
+ hostname = hostname_list[0]
+ response = requests.get(webui_url + f"/nodes/{hostname}")
+ response.raise_for_status()
+ detail = response.json()
+ assert detail["result"] is True, detail["msg"]
+ detail = detail["data"]["detail"]
+ assert detail["hostname"] == hostname
+ assert detail["state"] == "ALIVE"
+ assert "raylet" in detail["cmdline"][0]
+ assert len(detail["workers"]) >= 2
+ assert len(detail["actors"]) == 2, detail["actors"]
+ assert len(detail["raylet"]["viewData"]) > 0
+
+ actor_worker_pids = set()
+ for worker in detail["workers"]:
+ if "ray::Actor" in worker["cmdline"][0]:
+ actor_worker_pids.add(worker["pid"])
+ assert actor_worker_pids == actor_pids
+
+ response = requests.get(webui_url + "/nodes?view=summary")
+ response.raise_for_status()
+ summary = response.json()
+ assert summary["result"] is True, summary["msg"]
+ assert len(summary["data"]["summary"]) == 1
+ summary = summary["data"]["summary"][0]
+ assert summary["hostname"] == hostname
+ assert summary["state"] == "ALIVE"
+ assert "raylet" in summary["cmdline"][0]
+ assert "workers" not in summary
+ assert "actors" not in summary
+ assert "viewData" not in summary["raylet"]
+ break
+ except Exception as ex:
+ last_ex = ex
+ finally:
+ if time.time() > start_time + timeout_seconds:
+ ex_stack = traceback.format_exception(
+ type(last_ex), last_ex,
+ last_ex.__traceback__) if last_ex else []
+ ex_stack = "".join(ex_stack)
+ raise Exception(f"Timed out while testing, {ex_stack}")
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-v", __file__]))
diff --git a/dashboard/modules/test/test_agent.py b/dashboard/modules/test/test_agent.py
index efe9c3285a77..dd4597f7dc16 100644
--- a/dashboard/modules/test/test_agent.py
+++ b/dashboard/modules/test/test_agent.py
@@ -4,12 +4,16 @@
import ray.new_dashboard.utils as dashboard_utils
import ray.new_dashboard.modules.test.test_utils as test_utils
+import ray.new_dashboard.modules.test.test_consts as test_consts
+from ray.ray_constants import env_bool
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
-class HeadAgent(dashboard_utils.DashboardAgentModule):
+@dashboard_utils.dashboard_module(
+ enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False))
+class TestAgent(dashboard_utils.DashboardAgentModule):
def __init__(self, dashboard_agent):
super().__init__(dashboard_agent)
@@ -20,5 +24,17 @@ async def get_url(self, req) -> aiohttp.web.Response:
url)
return aiohttp.web.json_response(result)
+ @routes.head("/test/route_head")
+ async def route_head(self, req) -> aiohttp.web.Response:
+ pass
+
+ @routes.post("/test/route_post")
+ async def route_post(self, req) -> aiohttp.web.Response:
+ pass
+
+ @routes.patch("/test/route_patch")
+ async def route_patch(self, req) -> aiohttp.web.Response:
+ pass
+
async def run(self, server):
pass
diff --git a/dashboard/modules/test/test_consts.py b/dashboard/modules/test/test_consts.py
new file mode 100644
index 000000000000..3c53b928c1ef
--- /dev/null
+++ b/dashboard/modules/test/test_consts.py
@@ -0,0 +1 @@
+TEST_MODULE_ENVIRONMENT_KEY = "RAY_DASHBOARD_MODULE_TEST"
diff --git a/dashboard/modules/test/test_head.py b/dashboard/modules/test/test_head.py
index 61e7635e0e5b..699fc15458cd 100644
--- a/dashboard/modules/test/test_head.py
+++ b/dashboard/modules/test/test_head.py
@@ -4,12 +4,16 @@
import ray.new_dashboard.utils as dashboard_utils
import ray.new_dashboard.modules.test.test_utils as test_utils
+import ray.new_dashboard.modules.test.test_consts as test_consts
from ray.new_dashboard.datacenter import DataSource
+from ray.ray_constants import env_bool
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
+@dashboard_utils.dashboard_module(
+ enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False))
class TestHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
@@ -17,12 +21,28 @@ def __init__(self, dashboard_head):
DataSource.agents.signal.append(self._update_notified_agents)
async def _update_notified_agents(self, change):
- if change.new:
- ip, ports = next(iter(change.new.items()))
- self._notified_agents[ip] = ports
if change.old:
- ip, port = next(iter(change.old.items()))
+ ip, port = change.old
self._notified_agents.pop(ip)
+ if change.new:
+ ip, ports = change.new
+ self._notified_agents[ip] = ports
+
+ @routes.get("/test/route_get")
+ async def route_get(self, req) -> aiohttp.web.Response:
+ pass
+
+ @routes.put("/test/route_put")
+ async def route_put(self, req) -> aiohttp.web.Response:
+ pass
+
+ @routes.delete("/test/route_delete")
+ async def route_delete(self, req) -> aiohttp.web.Response:
+ pass
+
+ @routes.view("/test/route_view")
+ async def route_view(self, req) -> aiohttp.web.Response:
+ pass
@routes.get("/test/dump")
async def dump(self, req) -> aiohttp.web.Response:
@@ -41,7 +61,7 @@ async def dump(self, req) -> aiohttp.web.Response:
data = dict(DataSource.__dict__.get(key))
return await dashboard_utils.rest_response(
success=True,
- message="Fetch {} from datacenter success.".format(key),
+ message=f"Fetch {key} from datacenter success.",
**{key: data})
@routes.get("/test/notified_agents")
diff --git a/dashboard/tests/conftest.py b/dashboard/tests/conftest.py
index a60ce1007d3f..030e258069c5 100644
--- a/dashboard/tests/conftest.py
+++ b/dashboard/tests/conftest.py
@@ -1 +1,10 @@
+import os
+import pytest
from ray.tests.conftest import * # noqa
+
+
+@pytest.fixture
+def enable_test_module():
+ os.environ["RAY_DASHBOARD_MODULE_TEST"] = "true"
+ yield
+ os.environ.pop("RAY_DASHBOARD_MODULE_TEST", None)
diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py
index a6ebcaa49132..7983a02c174a 100644
--- a/dashboard/tests/test_dashboard.py
+++ b/dashboard/tests/test_dashboard.py
@@ -1,7 +1,11 @@
import os
+import sys
+import json
import time
import logging
+import asyncio
+import aiohttp.web
import ray
import psutil
import pytest
@@ -9,13 +13,20 @@
import requests
from ray import ray_constants
-from ray.test_utils import wait_for_condition, wait_until_server_available
+from ray.test_utils import (
+ format_web_url,
+ wait_for_condition,
+ wait_until_server_available,
+ run_string_as_driver,
+)
import ray.new_dashboard.consts as dashboard_consts
+import ray.new_dashboard.utils as dashboard_utils
import ray.new_dashboard.modules
os.environ["RAY_USE_NEW_DASHBOARD"] = "1"
logger = logging.getLogger(__name__)
+routes = dashboard_utils.ClassMethodRouteTable
def cleanup_test_files():
@@ -106,7 +117,9 @@ def _search_agent(processes):
logger.info("Test agent register is OK.")
wait_for_condition(lambda: _search_agent(raylet_proc.children()))
- assert dashboard_proc.status() == psutil.STATUS_RUNNING
+ assert dashboard_proc.status() in [
+ psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING
+ ]
agent_proc = _search_agent(raylet_proc.children())
agent_pid = agent_proc.pid
@@ -130,13 +143,13 @@ def _search_agent(processes):
assert agent_ports is not None
-def test_nodes_update(ray_start_with_dashboard):
+def test_nodes_update(enable_test_module, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
webui_url = ray_start_with_dashboard["webui_url"]
- webui_url = webui_url.replace("localhost", "http://127.0.0.1")
+ webui_url = format_web_url(webui_url)
- timeout_seconds = 20
+ timeout_seconds = 10
start_time = time.time()
while True:
time.sleep(1)
@@ -146,7 +159,7 @@ def test_nodes_update(ray_start_with_dashboard):
try:
dump_info = response.json()
except Exception as ex:
- logger.info("failed response: {}".format(response.text))
+ logger.info("failed response: %s", response.text)
raise ex
assert dump_info["result"] is True
dump_data = dump_info["data"]
@@ -162,7 +175,7 @@ def test_nodes_update(ray_start_with_dashboard):
try:
notified_agents = response.json()
except Exception as ex:
- logger.info("failed response: {}".format(response.text))
+ logger.info("failed response: %s", response.text)
raise ex
assert notified_agents["result"] is True
notified_agents = notified_agents["data"]
@@ -173,19 +186,18 @@ def test_nodes_update(ray_start_with_dashboard):
logger.info("Retry because of %s", e)
finally:
if time.time() > start_time + timeout_seconds:
- raise Exception(
- "Timed out while waiting for dashboard to start.")
+ raise Exception("Timed out while testing.")
-def test_http_get(ray_start_with_dashboard):
+def test_http_get(enable_test_module, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
webui_url = ray_start_with_dashboard["webui_url"]
- webui_url = webui_url.replace("localhost", "http://127.0.0.1")
+ webui_url = format_web_url(webui_url)
target_url = webui_url + "/test/dump"
- timeout_seconds = 20
+ timeout_seconds = 10
start_time = time.time()
while True:
time.sleep(1)
@@ -196,7 +208,7 @@ def test_http_get(ray_start_with_dashboard):
try:
dump_info = response.json()
except Exception as ex:
- logger.info("failed response: {}".format(response.text))
+ logger.info("failed response: %s", response.text)
raise ex
assert dump_info["result"] is True
dump_data = dump_info["data"]
@@ -205,13 +217,13 @@ def test_http_get(ray_start_with_dashboard):
http_port, grpc_port = ports
response = requests.get(
- "http://{}:{}/test/http_get_from_agent?url={}".format(
- ip, http_port, target_url))
+ f"http://{ip}:{http_port}"
+ f"/test/http_get_from_agent?url={target_url}")
response.raise_for_status()
try:
dump_info = response.json()
except Exception as ex:
- logger.info("failed response: {}".format(response.text))
+ logger.info("failed response: %s", response.text)
raise ex
assert dump_info["result"] is True
break
@@ -219,5 +231,135 @@ def test_http_get(ray_start_with_dashboard):
logger.info("Retry because of %s", e)
finally:
if time.time() > start_time + timeout_seconds:
- raise Exception(
- "Timed out while waiting for dashboard to start.")
+ raise Exception("Timed out while testing.")
+
+
+def test_class_method_route_table(enable_test_module):
+ head_cls_list = dashboard_utils.get_all_modules(
+ dashboard_utils.DashboardHeadModule)
+ agent_cls_list = dashboard_utils.get_all_modules(
+ dashboard_utils.DashboardAgentModule)
+ test_head_cls = None
+ for cls in head_cls_list:
+ if cls.__name__ == "TestHead":
+ test_head_cls = cls
+ break
+ assert test_head_cls is not None
+ test_agent_cls = None
+ for cls in agent_cls_list:
+ if cls.__name__ == "TestAgent":
+ test_agent_cls = cls
+ break
+ assert test_agent_cls is not None
+
+ def _has_route(route, method, path):
+ if isinstance(route, aiohttp.web.RouteDef):
+ if route.method == method and route.path == path:
+ return True
+ return False
+
+ def _has_static(route, path, prefix):
+ if isinstance(route, aiohttp.web.StaticDef):
+ if route.path == path and route.prefix == prefix:
+ return True
+ return False
+
+ all_routes = dashboard_utils.ClassMethodRouteTable.routes()
+ assert any(_has_route(r, "HEAD", "/test/route_head") for r in all_routes)
+ assert any(_has_route(r, "GET", "/test/route_get") for r in all_routes)
+ assert any(_has_route(r, "POST", "/test/route_post") for r in all_routes)
+ assert any(_has_route(r, "PUT", "/test/route_put") for r in all_routes)
+ assert any(_has_route(r, "PATCH", "/test/route_patch") for r in all_routes)
+ assert any(
+ _has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
+ assert any(_has_route(r, "*", "/test/route_view") for r in all_routes)
+
+ # Test bind()
+ bound_routes = dashboard_utils.ClassMethodRouteTable.bound_routes()
+ assert len(bound_routes) == 0
+ dashboard_utils.ClassMethodRouteTable.bind(
+ test_agent_cls.__new__(test_agent_cls))
+ bound_routes = dashboard_utils.ClassMethodRouteTable.bound_routes()
+ assert any(_has_route(r, "POST", "/test/route_post") for r in bound_routes)
+ assert all(
+ not _has_route(r, "PUT", "/test/route_put") for r in bound_routes)
+
+ # Static def should be in bound routes.
+ routes.static("/test/route_static", "/path")
+ bound_routes = dashboard_utils.ClassMethodRouteTable.bound_routes()
+ assert any(
+ _has_static(r, "/path", "/test/route_static") for r in bound_routes)
+
+ # Test duplicated routes should raise exception.
+ try:
+
+ @routes.get("/test/route_get")
+ def _duplicated_route(req):
+ pass
+
+ raise Exception("Duplicated routes should raise exception.")
+ except Exception as ex:
+ message = str(ex)
+ assert "/test/route_get" in message
+ assert "test_head.py" in message
+
+ # Test exception in handler
+ post_handler = None
+ for r in bound_routes:
+ if _has_route(r, "POST", "/test/route_post"):
+ post_handler = r.handler
+ break
+ assert post_handler is not None
+
+ loop = asyncio.get_event_loop()
+ r = loop.run_until_complete(post_handler())
+ assert r.status == 200
+ resp = json.loads(r.body)
+ assert resp["result"] is False
+ assert "Traceback" in resp["msg"]
+
+
+def test_async_loop_forever():
+ counter = [0]
+
+ @dashboard_utils.async_loop_forever(interval_seconds=0.1)
+ async def foo():
+ counter[0] += 1
+ raise Exception("Test exception")
+
+ loop = asyncio.get_event_loop()
+ loop.create_task(foo())
+ loop.call_later(1, loop.stop)
+ loop.run_forever()
+ assert counter[0] > 2
+
+
+def test_dashboard_module_decorator(enable_test_module):
+ head_cls_list = dashboard_utils.get_all_modules(
+ dashboard_utils.DashboardHeadModule)
+ agent_cls_list = dashboard_utils.get_all_modules(
+ dashboard_utils.DashboardAgentModule)
+
+ assert any(cls.__name__ == "TestHead" for cls in head_cls_list)
+ assert any(cls.__name__ == "TestAgent" for cls in agent_cls_list)
+
+ test_code = """
+import os
+import ray.new_dashboard.utils as dashboard_utils
+
+os.environ.pop("RAY_DASHBOARD_MODULE_TEST")
+head_cls_list = dashboard_utils.get_all_modules(
+ dashboard_utils.DashboardHeadModule)
+agent_cls_list = dashboard_utils.get_all_modules(
+ dashboard_utils.DashboardAgentModule)
+print(head_cls_list)
+print(agent_cls_list)
+assert all(cls.__name__ != "TestHead" for cls in head_cls_list)
+assert all(cls.__name__ != "TestAgent" for cls in agent_cls_list)
+print("success")
+"""
+ run_string_as_driver(test_code)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-v", __file__]))
diff --git a/dashboard/utils.py b/dashboard/utils.py
index 054fe9fd4839..9fd170890b8e 100644
--- a/dashboard/utils.py
+++ b/dashboard/utils.py
@@ -13,6 +13,7 @@
import traceback
from base64 import b64decode
from collections.abc import MutableMapping, Mapping
+from collections import namedtuple
from typing import Any
import aioredis
@@ -103,10 +104,9 @@ def _register_route(cls, method, path, **kwargs):
def _wrapper(handler):
if path in cls._bind_map[method]:
bind_info = cls._bind_map[method][path]
- raise Exception("Duplicated route path: {}, "
- "previous one registered at {}:{}".format(
- path, bind_info.filename,
- bind_info.lineno))
+ raise Exception(f"Duplicated route path: {path}, "
+ f"previous one registered at "
+ f"{bind_info.filename}:{bind_info.lineno}")
bind_info = cls._BindInfo(handler.__code__.co_filename,
handler.__code__.co_firstlineno, None)
@@ -174,15 +174,28 @@ def predicate(o):
h.__func__.__route_path__].instance = instance
+def dashboard_module(enable):
+ """A decorator for dashboard module."""
+
+ def _cls_wrapper(cls):
+ cls.__ray_dashboard_module_enable__ = enable
+ return cls
+
+ return _cls_wrapper
+
+
def get_all_modules(module_type):
- logger.info("Get all modules by type: {}".format(module_type.__name__))
+ logger.info(f"Get all modules by type: {module_type.__name__}")
import ray.new_dashboard.modules
for module_loader, name, ispkg in pkgutil.walk_packages(
ray.new_dashboard.modules.__path__,
ray.new_dashboard.modules.__name__ + "."):
importlib.import_module(name)
- return module_type.__subclasses__()
+ return [
+ m for m in module_type.__subclasses__()
+ if getattr(m, "__ray_dashboard_module_enable__", True)
+ ]
def to_posix_time(dt):
@@ -314,8 +327,7 @@ def __init__(self, owner=None, old=None, new=None):
self.new = new
def __str__(self):
- return "Change(owner: {}, old: {}, new: {}".format(
- self.owner, self.old, self.new)
+ return f"Change(owner: {self.owner}, old: {self.old}, new: {self.new}"
class NotifyQueue:
@@ -338,6 +350,8 @@ class Dict(MutableMapping):
:note: Only the first level data report change.
"""
+ ChangeItem = namedtuple("DictChangeItem", ["key", "value"])
+
def __init__(self, *args, **kwargs):
self._data = dict(*args, **kwargs)
self.signal = Signal(self)
@@ -347,10 +361,14 @@ def __setitem__(self, key, value):
self._data[key] = value
if len(self.signal) and old != value:
if old is None:
- co = self.signal.send(Change(owner=self, new={key: value}))
+ co = self.signal.send(
+ Change(owner=self, new=Dict.ChangeItem(key, value)))
else:
co = self.signal.send(
- Change(owner=self, old={key: old}, new={key: value}))
+ Change(
+ owner=self,
+ old=Dict.ChangeItem(key, old),
+ new=Dict.ChangeItem(key, value)))
NotifyQueue.put(co)
def __getitem__(self, item):
@@ -359,7 +377,8 @@ def __getitem__(self, item):
def __delitem__(self, key):
old = self._data.pop(key, None)
if len(self.signal) and old is not None:
- co = self.signal.send(Change(owner=self, old={key: old}))
+ co = self.signal.send(
+ Change(owner=self, old=Dict.ChangeItem(key, old)))
NotifyQueue.put(co)
def __len__(self):
@@ -387,3 +406,19 @@ async def get_aioredis_client(redis_address, redis_password,
# Raise exception from create_redis_pool
return await aioredis.create_redis_pool(
address=redis_address, password=redis_password)
+
+
+def async_loop_forever(interval_seconds):
+ def _wrapper(coro):
+ @functools.wraps(coro)
+ async def _looper(*args, **kwargs):
+ while True:
+ try:
+ await coro(*args, **kwargs)
+ except Exception:
+ logger.exception(f"Error looping coroutine {coro}.")
+ await asyncio.sleep(interval_seconds)
+
+ return _looper
+
+ return _wrapper
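A minimal usage sketch of the new async_loop_forever decorator (the interval and body are illustrative); _update_node_stats in stats_collector_head.py uses it the same way:

    import ray.new_dashboard.utils as dashboard_utils

    @dashboard_utils.async_loop_forever(2)
    async def poll_something():
        # Runs every 2 seconds forever; exceptions are logged by the
        # wrapper and do not break the loop.
        ...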
diff --git a/doc/source/actors.rst b/doc/source/actors.rst
index 6d3127158ded..9601486507c8 100644
--- a/doc/source/actors.rst
+++ b/doc/source/actors.rst
@@ -54,7 +54,7 @@ Any method of the actor can return multiple object refs with the ``ray.method``
@ray.remote
class Foo(object):
- @ray.method(num_return_vals=2)
+ @ray.method(num_returns=2)
def bar(self):
return 1, 2
@@ -72,7 +72,7 @@ Resources with Actors
You can specify that an actor requires CPUs or GPUs in the decorator. While Ray has built-in support for CPUs and GPUs, Ray can also handle custom resources.
When using GPUs, Ray will automatically set the environment variable ``CUDA_VISIBLE_DEVICES`` for the actor after it is instantiated. The actor will have access to a list of the IDs of the GPUs
-that it is allowed to use via ``ray.get_gpu_ids(as_str=True)``. This is a list of strings,
+that it is allowed to use via ``ray.get_gpu_ids()``. This is a list of strings,
like ``[]``, or ``['1']``, or ``['2', '5', '6']``. Under some circumstances, the IDs of GPUs could be given as UUID strings instead of indices (see the `CUDA programming guide `__).
.. code-block:: python
@@ -149,10 +149,12 @@ approach should generally not be necessary as actors are automatically garbage
collected. The ``ObjectRef`` resulting from the task can be waited on to wait
for the actor to exit (calling ``ray.get()`` on it will raise a ``RayActorError``).
Note that this method of termination will wait until any previously submitted
-tasks finish executing. If you want to terminate an actor immediately, you can
-call ``ray.kill(actor_handle)``. This will cause the actor to exit immediately
-and any pending tasks to fail. Any exit handlers installed in the actor using
-``atexit`` will be called.
+tasks finish executing and then exit the process gracefully with ``sys.exit``. If
+you want to terminate an actor forcefully, you can call ``ray.kill(actor_handle)``.
+This will call the exit syscall from within the actor, causing it to exit
+immediately and any pending tasks to fail. Because this does not go through the
+normal Python ``sys.exit`` teardown logic, any exit handlers installed in the
+actor using ``atexit`` will not be called.
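A minimal sketch of the forceful path described above (the actor class and handler are illustrative):

    import atexit
    import ray

    ray.init()

    @ray.remote
    class Worker:
        def __init__(self):
            # Never runs when the actor is killed forcefully.
            atexit.register(lambda: print("cleanup"))

        def ping(self):
            return "pong"

    w = Worker.remote()
    ray.kill(w)  # exits the actor process immediately; pending tasks fail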
Passing Around Actor Handles
----------------------------
diff --git a/doc/source/advanced.rst b/doc/source/advanced.rst
index fc742fb5ecdf..52ce58bd24de 100644
--- a/doc/source/advanced.rst
+++ b/doc/source/advanced.rst
@@ -47,7 +47,7 @@ And vary the number of return values for tasks (and actor methods too):
def f(n):
return list(range(n))
- id1, id2 = f.options(num_return_vals=2).remote(2)
+ id1, id2 = f.options(num_returns=2).remote(2)
assert ray.get(id1) == 0
assert ray.get(id2) == 1
diff --git a/doc/source/cluster/autoscaling.rst b/doc/source/cluster/autoscaling.rst
new file mode 100644
index 000000000000..4113aa795526
--- /dev/null
+++ b/doc/source/cluster/autoscaling.rst
@@ -0,0 +1,116 @@
+.. _ref-autoscaling:
+
+Cluster Autoscaling
+===================
+
+Basics
+------
+
+The Ray Cluster Launcher will automatically enable a load-based autoscaler. When cluster resource usage exceeds a configurable threshold (80% by default), new nodes will be launched, up to the ``max_workers`` limit specified in the cluster config. When nodes are idle for more than a timeout, they will be removed, down to the ``min_workers`` limit. The head node is never removed.
+
+In more detail, the autoscaler implements the following control loop:
+
+  1. It calculates the estimated utilization of the cluster based on the most heavily assigned resource. For example, suppose a cluster has 100/200 CPUs assigned and 15/25 GPUs assigned; the utilization will then be considered to be max(100/200, 15/25) = 60% (see the sketch below).
+ 2. If the estimated utilization is greater than the target (80% by default), then the autoscaler will attempt to add nodes to the cluster.
+ 3. If a node is idle for a timeout (5 minutes by default), it is removed from the cluster.
+
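+As a rough sketch (illustrative only, not the autoscaler's actual code), the utilization estimate in step 1 can be computed as follows:
+
+.. code-block:: python
+
+    def estimated_utilization(assigned, total):
+        # Overall utilization is the max over the per-resource fractions.
+        return max(assigned[r] / total[r] for r in total)
+
+    # max(100/200, 15/25) == 0.6, i.e. the 60% from the example above.
+    assert estimated_utilization({"CPU": 100, "GPU": 15},
+                                 {"CPU": 200, "GPU": 25}) == 0.6
+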
+The basic autoscaling config settings are as follows:
+
+.. code::
+
+    # A unique identifier for the head node and workers of this cluster.
+ cluster_name: default
+
+    # The minimum number of worker nodes to launch in addition to the head
+ # node. This number should be >= 0.
+ min_workers: 0
+
+ # The autoscaler will scale up the cluster to this target fraction of resource
+ # usage. For example, if a cluster of 10 nodes is 100% busy and
+ # target_utilization is 0.8, it would resize the cluster to 13. This fraction
+ # can be decreased to increase the aggressiveness of upscaling.
+ # The max value allowed is 1.0, which is the most conservative setting.
+ target_utilization_fraction: 0.8
+
+ # If a node is idle for this many minutes, it will be removed. A node is
+ # considered idle if there are no tasks or actors running on it.
+ idle_timeout_minutes: 5
+
+Multiple Node Type Autoscaling
+------------------------------
+
+Ray supports multiple node types in a single cluster. In this mode of operation, the scheduler will look at the queue of resource shape demands from the cluster (e.g., there might be 10 tasks queued, each requesting ``{"GPU": 4, "CPU": 16}``) and try to add the minimum set of nodes that can fulfill these resource demands. This enables precise, rapid scale-up compared to looking only at resource utilization, as the autoscaler also has visibility into the aggregate resource load.
+
+The concept of a cluster node type encompasses both the physical instance type (e.g., AWS p3.8xl GPU nodes vs m4.16xl CPU nodes), as well as other attributes (e.g., IAM role, the machine image, etc.). `Custom resources `__ can be specified for each node type so that Ray is aware of the demand for specific node types at the application level (e.g., a task may request to be placed on a machine with a specific role or machine image via a custom resource).
+
+Multi-node type autoscaling operates in conjunction with the basic autoscaler. You may want to configure the basic autoscaler to act conservatively (i.e., set ``target_utilization_fraction: 1.0``).
+
+An example of configuring multiple node types is as follows `(full example) `__:
+
+.. code::
+
+ # Specify the allowed node types and the resources they provide.
+ # The key is the name of the node type, which is just for debugging purposes.
+ # The node config specifies the launch config and physical instance type.
+ available_node_types:
+ cpu_4_ondemand:
+ node_config:
+ InstanceType: m4.xlarge
+ resources: {"CPU": 4}
+ max_workers: 5
+ cpu_16_spot:
+ node_config:
+ InstanceType: m4.4xlarge
+ InstanceMarketOptions:
+ MarketType: spot
+ resources: {"CPU": 16, "Custom1": 1, "is_spot": 1}
+ max_workers: 10
+ gpu_1_ondemand:
+ node_config:
+ InstanceType: p2.xlarge
+ resources: {"CPU": 4, "GPU": 1, "Custom2": 2}
+ max_workers: 4
+ worker_setup_commands:
+ - pip install tensorflow-gpu # Example command.
+ gpu_8_ondemand:
+ node_config:
+ InstanceType: p2.8xlarge
+ resources: {"CPU": 32, "GPU": 8}
+ max_workers: 2
+ worker_setup_commands:
+ - pip install tensorflow-gpu # Example command.
+
+ # Specify the node type of the head node (as configured above).
+ head_node_type: cpu_4_ondemand
+
+ # Specify the default type of the worker node (as configured above).
+ worker_default_node_type: cpu_16_spot
+
+
+The above config defines two CPU node types (``cpu_4_ondemand`` and ``cpu_16_spot``), and two GPU types (``gpu_1_ondemand`` and ``gpu_8_ondemand``). Each node type has a name (e.g., ``cpu_4_ondemand``), which has no semantic meaning and is only for debugging. Let's look at the inner fields of the ``gpu_1_ondemand`` node type:
+
+The ``node_config`` field tells the underlying cloud provider how to launch a node of this type. It is merged with the top-level node config of the YAML and can override fields (e.g., to specify the ``p2.xlarge`` instance type here):
+
+.. code::
+
+ node_config:
+ InstanceType: p2.xlarge
+
+The ``resources`` field tells the autoscaler what kinds of resources this node provides. This can include custom resources as well (e.g., "Custom2"). This field enables the autoscaler to automatically select the right kind of nodes to launch given the resource demands of the application. The resources specified here will be automatically passed to the ``ray start`` command for the node via an environment variable. For more information, see also the `resource demand scheduler `__:
+
+.. code::
+
+ resources: {"CPU": 4, "GPU": 1, "Custom2": 2}
+
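+For example, a task can steer the autoscaler toward this node type by requesting its custom resource in the ``ray.remote`` decorator (an illustrative sketch; ``train`` is a hypothetical task):
+
+.. code-block:: python
+
+    @ray.remote(num_gpus=1, resources={"Custom2": 1})
+    def train():
+        ...  # only gpu_1_ondemand nodes provide "Custom2"
+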
+The ``max_workers`` field constrains the number of nodes of this type that can be launched:
+
+.. code::
+
+ max_workers: 4
+
+The ``worker_setup_commands`` field can be used to override the setup and initialization commands for a node type. Note that you can only override the setup for worker nodes. The head node's setup commands are always configured via the top level field in the cluster YAML:
+
+.. code::
+
+ worker_setup_commands:
+ - pip install tensorflow-gpu # Example command.
diff --git a/doc/source/cluster/launcher.rst b/doc/source/cluster/launcher.rst
index a1f58102a4e4..cde1eb30530e 100644
--- a/doc/source/cluster/launcher.rst
+++ b/doc/source/cluster/launcher.rst
@@ -213,15 +213,6 @@ This tells ``ray up`` to sync the current git branch SHA from your personal comp
2. Commit the changes with ``git commit`` and ``git push``
3. Update files on your Ray cluster with ``ray up``
-
-Autoscaling
------------
-
-The Ray Cluster Launcher will automatically enable a load-based autoscaler. When cluster resource usage exceeds a configurable threshold (80% by default), new nodes will be launched up the specified ``max_workers`` limit (in the cluster config). When nodes are idle for more than a timeout, they will be removed, down to the ``min_workers`` limit. The head node is never removed.
-
-The default idle timeout is 5 minutes, which can be set in the cluster config. This is to prevent excessive node churn which could impact performance and increase costs (in AWS / GCP there is a minimum billing charge of 1 minute per instance, after which usage is billed by the second).
-
-
Questions or Issues?
--------------------
diff --git a/doc/source/configure.rst b/doc/source/configure.rst
index 98c9f8e55f47..5b152b4e8535 100644
--- a/doc/source/configure.rst
+++ b/doc/source/configure.rst
@@ -159,7 +159,7 @@ To add authentication via the Python API, start Ray using:
.. code-block:: python
- ray.init(redis_password="password")
+ ray.init(_redis_password="password")
To add authentication via the CLI or to connect to an existing Ray instance with
password-protected Redis ports:
@@ -182,48 +182,4 @@ to localhost when the ray is started using ``ray.init``.
See the `Redis security documentation `__
for more information.
-
-Using the Object Store with Huge Pages
---------------------------------------
-
-Plasma is a high-performance shared memory object store originally developed in
-Ray and now being developed in `Apache Arrow`_. See the `relevant
-documentation`_.
-
-On Linux, it is possible to increase the write throughput of the Plasma object
-store by using huge pages. You first need to create a file system and activate
-huge pages as follows.
-
-.. code-block:: shell
-
- sudo mkdir -p /mnt/hugepages
- gid=`id -g`
- uid=`id -u`
- sudo mount -t hugetlbfs -o uid=$uid -o gid=$gid none /mnt/hugepages
- sudo bash -c "echo $gid > /proc/sys/vm/hugetlb_shm_group"
- # This typically corresponds to 20000 2MB pages (about 40GB), but this
- # depends on the platform.
- sudo bash -c "echo 20000 > /proc/sys/vm/nr_hugepages"
-
-**Note:** Once you create the huge pages, they will take up memory which will
-never be freed unless you remove the huge pages. If you run into memory issues,
-that may be the issue.
-
-You need root access to create the file system, but not for running the object
-store.
-
-You can then start Ray with huge pages on a single machine as follows.
-
-.. code-block:: python
-
- ray.init(huge_pages=True, plasma_directory="/mnt/hugepages")
-
-In the cluster case, you can do it by passing ``--huge-pages`` and
-``--plasma-directory=/mnt/hugepages`` into ``ray start`` on any machines where
-huge pages should be enabled.
-
-See the relevant `Arrow documentation for huge pages`_.
-
.. _`Apache Arrow`: https://arrow.apache.org/
-.. _`relevant documentation`: https://arrow.apache.org/docs/python/plasma.html#the-plasma-in-memory-object-store
-.. _`Arrow documentation for huge pages`: https://arrow.apache.org/docs/python/plasma.html#using-plasma-with-huge-pages
diff --git a/doc/source/fault-tolerance.rst b/doc/source/fault-tolerance.rst
index fa3abb09493c..82f77f864f4f 100644
--- a/doc/source/fault-tolerance.rst
+++ b/doc/source/fault-tolerance.rst
@@ -38,7 +38,7 @@ You can experiment with this behavior by running the following code.
# exception.
ray.get(potentially_fail.remote(0.5))
print('SUCCESS')
- except ray.exceptions.RayWorkerError:
+ except ray.exceptions.WorkerCrashedError:
print('FAILURE')
.. _actor-fault-tolerance:
@@ -172,7 +172,7 @@ Task outputs over a configurable threshold (default 100KB) may be stored in
Ray's distributed object store. Thus, a node failure can cause the loss of a
task output. If this occurs, Ray will automatically attempt to recover the
value by looking for copies of the same object on other nodes. If there are no
-other copies left, an ``UnreconstructableError`` will be raised.
+other copies left, an ``ObjectLostError`` will be raised.
When there are no copies of an object left, Ray also provides an option to
automatically recover the value by re-executing the task that created the
diff --git a/doc/source/index.rst b/doc/source/index.rst
index d2acec412e27..e25bf380dcc3 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -141,6 +141,7 @@ Academic Papers
cluster/index.rst
cluster/launcher.rst
+ cluster/autoscaling.rst
cluster/cloud.rst
cluster/deploy.rst
@@ -206,7 +207,6 @@ Academic Papers
joblib.rst
iter.rst
pandas_on_ray.rst
- projects.rst
.. toctree::
:hidden:
diff --git a/doc/source/memory-management.rst b/doc/source/memory-management.rst
index 4c0003845115..30d98ea8e2f3 100644
--- a/doc/source/memory-management.rst
+++ b/doc/source/memory-management.rst
@@ -1,7 +1,23 @@
Memory Management
=================
-This page describes how memory management works in Ray and how you can set memory quotas to ensure memory-intensive applications run predictably and reliably.
+This page describes how memory management works in Ray.
+
+Concepts
+~~~~~~~~
+
+There are several ways that Ray applications use memory:
+
+.. image:: images/memory.svg
+
+Ray system memory: this is memory used internally by Ray
+ - **Redis**: memory used for storing the list of nodes and actors present in the cluster. The amount of memory used for these purposes is typically quite small.
+ - **Raylet**: memory used by the C++ raylet process running on each node. This cannot be controlled, but is typically quite small.
+
+Application memory: this is memory used by your application
+ - **Worker heap**: memory used by your application (e.g., in Python code or TensorFlow), best measured as the *resident set size (RSS)* of your application minus its *shared memory usage (SHR)* in commands such as ``top``. The reason you need to subtract *SHR* is that object store shared memory is reported by the OS as shared with each worker. Not subtracting *SHR* will result in double counting memory usage.
+ - **Object store memory**: memory used when your application creates objects in the object store via ``ray.put`` and when returning values from remote functions. Objects are reference counted and evicted when they fall out of scope. There is an object store server running on each node.
+ - **Object store shared memory**: memory used when your application reads objects via ``ray.get``. Note that if an object is already present on the node, this does not cause additional allocations. This allows large objects to be efficiently shared among many actors and tasks (see the sketch below).
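+
+A minimal sketch of the last two categories (illustrative only):
+
+.. code-block:: python
+
+    import numpy as np
+    import ray
+
+    ray.init()
+    ref = ray.put(np.zeros(10**6))  # allocates object store memory
+    arr = ray.get(ref)              # maps the object into shared memory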
ObjectRef Reference Counting
----------------------------
@@ -197,92 +213,14 @@ To enable LRU eviction when the object store is full, initialize ray with the ``
ray start --lru-evict
-Memory Quotas
--------------
-
-You can set memory quotas to ensure your application runs predictably on any Ray cluster configuration. If you're not sure, you can start with a conservative default configuration like the following and see if any limits are hit.
-
-For Ray initialization on a single node, consider setting the following fields:
-
-.. code-block:: python
-
- ray.init(
- memory=2000 * 1024 * 1024,
- object_store_memory=200 * 1024 * 1024,
- driver_object_store_memory=100 * 1024 * 1024)
-
-For Ray usage on a cluster, consider setting the following fields on both the command line and in your Python script:
-
-.. tip:: 200 * 1024 * 1024 bytes is 200 MiB. Use double parentheses to evaluate math in Bash: ``$((200 * 1024 * 1024))``.
-
-.. code-block:: bash
-
- # On the head node
- ray start --head --redis-port=6379 \
- --object-store-memory=$((200 * 1024 * 1024)) \
- --memory=$((200 * 1024 * 1024)) \
- --num-cpus=1
-
- # On the worker node
- ray start --object-store-memory=$((200 * 1024 * 1024)) \
- --memory=$((200 * 1024 * 1024)) \
- --num-cpus=1 \
- --address=$RAY_HEAD_ADDRESS:6379
-
-.. code-block:: python
-
- # In your Python script connecting to Ray:
- ray.init(
- address="auto", # or ":" if not using the default port
- driver_object_store_memory=100 * 1024 * 1024
- )
-
-
-For any custom remote method or actor, you can set requirements as follows:
+Memory Aware Scheduling
+~~~~~~~~~~~~~~~~~~~~~~~
-.. code-block:: python
-
- @ray.remote(
- memory=2000 * 1024 * 1024,
- )
-
-
-Concept Overview
-~~~~~~~~~~~~~~~~
-
-There are several ways that Ray applications use memory:
-
-.. image:: images/memory.svg
-
-Ray system memory: this is memory used internally by Ray
- - **Redis**: memory used for storing task lineage and object metadata. When Redis becomes full, lineage will start to be be LRU evicted, which makes the corresponding objects ineligible for reconstruction on failure.
- - **Raylet**: memory used by the C++ raylet process running on each node. This cannot be controlled, but is usually quite small.
-
-Application memory: this is memory used by your application
- - **Worker heap**: memory used by your application (e.g., in Python code or TensorFlow), best measured as the *resident set size (RSS)* of your application minus its *shared memory usage (SHR)* in commands such as ``top``. The reason you need to subtract *SHR* is that object store shared memory is reported by the OS as shared with each worker. Not subtracting *SHR* will result in double counting memory usage.
- - **Object store memory**: memory used when your application creates objects in the objects store via ``ray.put`` and when returning values from remote functions. Objects are LRU evicted when the store is full, prioritizing objects that are no longer in scope on the driver or any worker. There is an object store server running on each node.
- - **Object store shared memory**: memory used when your application reads objects via ``ray.get``. Note that if an object is already present on the node, this does not cause additional allocations. This allows large objects to be efficiently shared among many actors and tasks.
-
-By default, Ray will cap the memory used by Redis at ``min(30% of node memory, 10GiB)``, and object store at ``min(10% of node memory, 20GiB)``, leaving half of the remaining memory on the node available for use by worker heap. You can also manually configure this by setting ``redis_max_memory=`` and ``object_store_memory=`` on Ray init.
-
-It is important to note that these default Redis and object store limits do not address the following issues:
-
-* Actor or task heap usage exceeding the remaining available memory on a node.
-
-* Heavy use of the object store by certain actors or tasks causing objects required by other tasks to be prematurely evicted.
-
-To avoid these potential sources of instability, you can set *memory quotas* to reserve memory for individual actors and tasks.
-
-Heap memory quota
-~~~~~~~~~~~~~~~~~
-
-When Ray starts, it queries the available memory on a node / container not reserved for Redis and the object store or being used by other applications. This is considered "available memory" that actors and tasks can request memory out of. You can also set ``memory=`` on Ray init to tell Ray explicitly how much memory is available.
+By default, Ray does not take into account the potential memory usage of a task or actor when scheduling. This is simply because it cannot estimate ahead of time how much memory is required. However, if you know how much memory a task or actor requires, you can specify it in the resource requirements of its ``ray.remote`` decorator to enable memory-aware scheduling:
.. important::
- Setting available memory for the node does NOT impose any limits on memory usage
- unless you specify memory resource requirements in decorators. By default, tasks
- and actors request no memory (and hence have no limit).
+ Specifying a memory requirement does NOT impose any limits on memory usage. The requirements are used for admission control during scheduling only (similar to how CPU scheduling works in Ray). It is up to the task itself to not use more memory than it requested.
To tell the Ray scheduler a task or actor requires a certain amount of available memory to run, set the ``memory`` argument. The Ray scheduler will then reserve the specified amount of available memory during scheduling, similar to how it handles CPU and GPU resources:
@@ -311,20 +249,6 @@ In the above example, the memory quota is specified statically by the decorator,
**Enforcement**: If an actor exceeds its memory quota, calls to it will throw ``RayOutOfMemoryError`` and it may be killed. Memory quota is currently enforced on a best-effort basis for actors only (but quota is taken into account during scheduling in all cases).
-Object store memory quota
-~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Use ``@ray.remote(object_store_memory=)`` to cap the amount of memory an actor can use for ``ray.put`` and method call returns. This gives the actor its own LRU queue within the object store of the given size, both protecting its objects from eviction by other actors and preventing it from using more than the specified quota. This quota protects objects from unfair eviction when certain actors are producing objects at a much higher rate than others.
-
-Ray takes this resource into account during scheduling, with the caveat that a node will always reserve ~30% of its object store for global shared use.
-
-For the driver, you can set its object store memory quota with ``driver_object_store_memory``. Setting object store quota is not supported for tasks.
-
-Object store shared memory
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Object store memory is also used to map objects returned by ``ray.get`` calls in shared memory. While an object is mapped in this way (i.e., there is a Python reference to the object), it is pinned and cannot be evicted from the object store. However, ray does not provide quota management for this kind of shared memory usage.
-
Questions or Issues?
--------------------
diff --git a/doc/source/package-ref.rst b/doc/source/package-ref.rst
index 2fa9e55f6f9b..23c827f8a841 100644
--- a/doc/source/package-ref.rst
+++ b/doc/source/package-ref.rst
@@ -74,12 +74,12 @@ ray.get_resource_ids
.. autofunction:: ray.get_resource_ids
-.. _ray-get_webui_url-ref:
+.. _ray-get_dashboard_url-ref:
-ray.get_webui_url
-~~~~~~~~~~~~~~~~~
+ray.get_dashboard_url
+~~~~~~~~~~~~~~~~~~~~~
-.. autofunction:: ray.get_webui_url
+.. autofunction:: ray.get_dashboard_url
.. _ray-shutdown-ref:
@@ -88,21 +88,6 @@ ray.shutdown
.. autofunction:: ray.shutdown
-
-.. _ray-register_custom_serializer-ref:
-
-ray.register_custom_serializer
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autofunction:: ray.register_custom_serializer
-
-.. _ray-profile-ref:
-
-ray.profile
-~~~~~~~~~~~
-
-.. autofunction:: ray.profile
-
.. _ray-method-ref:
ray.method
@@ -123,13 +108,6 @@ ray.nodes
.. autofunction:: ray.nodes
-.. _ray-objects-ref:
-
-ray.objects
-~~~~~~~~~~~
-
-.. autofunction:: ray.objects
-
.. _ray-timeline-ref:
ray.timeline
@@ -137,13 +115,6 @@ ray.timeline
.. autofunction:: ray.timeline
-.. _ray-object_transfer_timeline-ref:
-
-ray.object_transfer_timeline
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autofunction:: ray.object_transfer_timeline
-
.. _ray-cluster_resources-ref:
ray.cluster_resources
@@ -221,24 +192,12 @@ The Ray Command Line API
:prog: ray stack
:show-nested:
-.. _ray-stat-doc:
-
-.. click:: ray.scripts.scripts:statistics
- :prog: ray statistics
- :show-nested:
-
.. _ray-memory-doc:
.. click:: ray.scripts.scripts:memory
:prog: ray memory
:show-nested:
-.. _ray-globalgc-doc:
-
-.. click:: ray.scripts.scripts:globalgc
- :prog: ray globalgc
- :show-nested:
-
.. _ray-timeline-doc:
.. click:: ray.scripts.scripts:timeline
diff --git a/doc/source/projects.rst b/doc/source/projects.rst
deleted file mode 100644
index be8900452840..000000000000
--- a/doc/source/projects.rst
+++ /dev/null
@@ -1,189 +0,0 @@
-Ray Projects (Experimental)
-===========================
-
-Ray projects make it easy to package a Ray application so it can be
-rerun later in the same environment. They allow for the sharing and
-reliable reuse of existing code.
-
-Quick start (CLI)
------------------
-
-.. code-block:: bash
-
- # Creates a project in the current directory. It will create a
- # project.yaml defining the code and environment and a cluster.yaml
- # describing the cluster configuration. Both will be created in the
- # ray-project subdirectory of the current directory.
- $ ray project create
-
- # Create a new session from the given project. Launch a cluster and run
- # the command, which must be specified in the project.yaml file. If no
- # command is specified, the "default" command in ray-project/project.yaml
- # will be used. Alternatively, use --shell to run a raw shell command.
- $ ray session start [arguments] [--shell]
-
- # Open a console for the given session.
- $ ray session attach
-
- # Stop the given session and terminate all of its worker nodes.
- $ ray session stop
-
-Examples
---------
-See `the readme `__
-for instructions on how to run these examples:
-
-- `Open Tacotron `__:
- A TensorFlow implementation of Google's Tacotron speech synthesis with pre-trained model (unofficial)
-- `PyTorch Transformers `__:
- A library of state-of-the-art pretrained models for Natural Language Processing (NLP)
-
-Tutorial
---------
-
-We will walk through how to use projects by executing the `streaming MapReduce example `_.
-Commands always apply to the project in the current directory.
-Let us switch into the project directory with
-
-.. code-block:: bash
-
- cd ray/doc/examples/streaming
-
-
-A session represents a running instance of a project. Let's start one with
-
-.. code-block:: bash
-
- ray session start
-
-
-The ``ray session start`` command
-will bring up a new cluster and initialize the environment of the cluster
-according to the `environment` section of the `project.yaml`, installing all
-dependencies of the project.
-
-Now we can execute a command in the session. To see a list of all available
-commands of the project, run
-
-.. code-block:: bash
-
- ray session commands
-
-
-which produces the following output:
-
-.. code-block::
-
- Active project: ray-example-streaming
-
- Command "run":
- usage: run [--num-mappers NUM_MAPPERS] [--num-reducers NUM_REDUCERS]
-
- Start the streaming example.
-
- optional arguments:
- --num-mappers NUM_MAPPERS
- Number of mapper actors used
- --num-reducers NUM_REDUCERS
- Number of reducer actors used
-
-
-As you see, in this project there is only a single ``run`` command which has arguments
-``--num-mappers`` and ``--num-reducers``. We can execute the streaming
-wordcount with the default parameters by running
-
-.. code-block:: bash
-
- ray session execute run
-
-
-You can interrupt the command with ``-c`` and attach to the running session by executing
-
-.. code-block:: bash
-
- ray session attach --tmux
-
-
-Inside the session you can for example edit the streaming applications with
-
-.. code-block:: bash
-
- cd ray-example-streaming
- emacs streaming.py
-
-
-Try for example to add the following lines after the ``for count in counts:`` loop:
-
-.. code-block:: python
-
- if "million" in wordcounts:
- print("Found the word!")
-
-
-and re-run the application from outside the session with
-
-.. code-block:: bash
-
- ray session execute run
-
-
-The session can be terminated from outside the session with
-
-.. code-block:: bash
-
- ray session stop
-
-
-Project file format (project.yaml)
-----------------------------------
-
-A project file contains everything required to run a project.
-This includes a cluster configuration, the environment and dependencies
-for the application, and the specific inputs used to run the project.
-
-Here is an example for a minimal project format:
-
-.. code-block:: yaml
-
- name: test-project
- description: "This is a simple test project"
- repo: https://github.com/ray-project/ray
-
- # Cluster to be instantiated by default when starting the project.
- cluster:
- config: ray-project/cluster.yaml
-
- # Commands/information to build the environment, once the cluster is
- # instantiated. This can include the versions of python libraries etc.
- # It can be specified as a Python requirements.txt, a conda environment,
- # a Dockerfile, or a shell script to run to set up the libraries.
- environment:
- requirements: requirements.txt
-
- # List of commands that can be executed once the cluster is instantiated
- # and the environment is set up.
- # A command can also specify a cluster that overwrites the default cluster.
- commands:
- - name: default
- command: python default.py
- help: "The command that will be executed if no command name is specified"
- - name: test
- command: python test.py --param1={{param1}} --param2={{param2}}
- help: "A test command"
- params:
- - name: "param1"
- help: "The first parameter"
- # The following line indicates possible values this parameter can take.
- choices: ["1", "2"]
- - name: "param2"
- help: "The second parameter"
-
-Project files have to adhere to the following schema:
-
-.. jsonschema:: ../../python/ray/projects/schema.json
-
-Cluster file format (cluster.yaml)
-----------------------------------
-
-This is the same as for the autoscaler, see
-:ref:`Cluster Launch page `.
diff --git a/doc/source/serialization.rst b/doc/source/serialization.rst
index 732e547976c1..e1cadf7cb60e 100644
--- a/doc/source/serialization.rst
+++ b/doc/source/serialization.rst
@@ -10,7 +10,7 @@ Since Ray processes do not share memory space, data transferred between workers
Plasma Object Store
-------------------
-Plasma is an in-memory object store that is being developed as part of `Apache Arrow`_. Ray uses Plasma to efficiently transfer objects across different processes and different nodes. All objects in Plasma object store are **immutable** and held in shared memory. This is so that they can be accessed efficiently by many workers on the same node.
+Plasma is an in-memory object store that is being developed as part of Apache Arrow. Ray uses Plasma to efficiently transfer objects across different processes and different nodes. All objects in the Plasma object store are **immutable** and held in shared memory. This is so that they can be accessed efficiently by many workers on the same node.
Each node has its own object store. When data is put into the object store, it does not get automatically broadcasted to other nodes. Data remains local to the writer until requested by another task or actor on another node.
@@ -64,49 +64,6 @@ Serialization notes
- Lock objects are mostly unserializable, because copying a lock is meaningless and could cause serious concurrency problems. You may have to come up with a workaround if your object contains a lock.
-Last resort: Custom Serialization
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-If none of these options work, you can try registering a custom serializer with ``ray.register_custom_serializer`` (:ref:`docstring `):
-
-.. code-block:: python
-
- import ray
-
- ray.init()
-
- class Foo(object):
- def __init__(self, value):
- self.value = value
-
- def custom_serializer(obj):
- return obj.value
-
- def custom_deserializer(value):
- object = Foo()
- object.value = value
- return object
-
- ray.register_custom_serializer(
- Foo, serializer=custom_serializer, deserializer=custom_deserializer)
-
- object_ref = ray.put(Foo(100))
- assert ray.get(object_ref).value == 100
-
-
-If you find cases where Ray serialization doesn't work or does something unexpected, please `let us know`_ so we can fix it.
-
-.. _`let us know`: https://github.com/ray-project/ray/issues
-
-Advanced: Huge Pages
-~~~~~~~~~~~~~~~~~~~~
-
-On Linux, it is possible to increase the write throughput of the Plasma object store by using huge pages. See the `Configuration page `_ for information on how to use huge pages in Ray.
-
-
-.. _`Apache Arrow`: https://arrow.apache.org/
-
-
Known Issues
------------
diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst
index d2c8ffb3ef77..ddbf58256644 100644
--- a/doc/source/tune/user-guide.rst
+++ b/doc/source/tune/user-guide.rst
@@ -472,8 +472,7 @@ decide between the two options.
Redirecting stdout and stderr to files
--------------------------------------
The stdout and stderr streams are usually printed to the console. For remote actors,
-Ray collects these logs and prints them to the head process, as long as it
-has been initialized with ``log_to_driver=True``, which is the default.
+Ray collects these logs and prints them to the head process.
However, if you would like to collect the stream outputs in files for later
analysis or troubleshooting, Tune offers a utility parameter, ``log_to_file``,
@@ -508,17 +507,6 @@ too.
If ``log_to_file`` is set, Tune will automatically register a new logging handler
for Ray's base logger and log the output to the specified stderr output file.
-Setting ``log_to_file`` does not disable logging to the driver. If you would
-like to disable the logs showing up in the driver output (i.e. they should only
-show up in the logfiles), initialize Ray accordingly:
-
-.. code-block:: python
-
- ray.init(log_to_driver=False)
- tune.run(
- trainable,
- log_to_file=True)
-
.. _tune-debugging:
Debugging
diff --git a/doc/source/using-ray-with-gpus.rst b/doc/source/using-ray-with-gpus.rst
index 05c1b9ab9198..8e189c4ae696 100644
--- a/doc/source/using-ray-with-gpus.rst
+++ b/doc/source/using-ray-with-gpus.rst
@@ -33,7 +33,7 @@ remote decorator.
print("ray.get_gpu_ids(): {}".format(ray.get_gpu_ids()))
print("CUDA_VISIBLE_DEVICES: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
-Inside of the remote function, a call to ``ray.get_gpu_ids(as_str=True)`` will return a
+Inside of the remote function, a call to ``ray.get_gpu_ids()`` will return a
list of strings indicating which GPUs the remote function is allowed to use.
Typically, it is not necessary to call ``ray.get_gpu_ids()`` because Ray will
automatically set the ``CUDA_VISIBLE_DEVICES`` environment variable.
diff --git a/doc/source/walkthrough.rst b/doc/source/walkthrough.rst
index f73e9d619d8b..7333a77ea723 100644
--- a/doc/source/walkthrough.rst
+++ b/doc/source/walkthrough.rst
@@ -253,7 +253,7 @@ Multiple returns
.. code-block:: python
- @ray.remote(num_return_vals=3)
+ @ray.remote(num_returns=3)
def return_multiple():
return 1, 2, 3
@@ -341,7 +341,7 @@ If the current node's object store does not contain the object, the object is do
assert ray.get([ray.put(i) for i in range(3)]) == [0, 1, 2]
# You can also set a timeout to return early from a ``get`` that's blocking for too long.
- from ray.exceptions import RayTimeoutError
+ from ray.exceptions import GetTimeoutError
@ray.remote
def long_running_function():
@@ -350,7 +350,7 @@ If the current node's object store does not contain the object, the object is do
obj_ref = long_running_function.remote()
try:
ray.get(obj_ref, timeout=4)
- except RayTimeoutError:
+ except GetTimeoutError:
print("`get` timed out.")
.. group-tab:: Java
diff --git a/python/ray/__init__.py b/python/ray/__init__.py
index 6b0bdd1b25a3..251a7453672f 100644
--- a/python/ray/__init__.py
+++ b/python/ray/__init__.py
@@ -80,12 +80,11 @@
available_resources) # noqa: E402
from ray.worker import ( # noqa: F401
LOCAL_MODE, SCRIPT_MODE, WORKER_MODE, IO_WORKER_MODE, cancel, connect,
- disconnect, get, get_actor, get_gpu_ids, get_resource_ids, get_webui_url,
- init, is_initialized, put, kill, register_custom_serializer, remote,
- shutdown, show_in_webui, wait,
+ disconnect, get, get_actor, get_gpu_ids, get_resource_ids,
+ get_dashboard_url, init, is_initialized, put, kill, remote, shutdown,
+ show_in_dashboard, wait,
) # noqa: E402
import ray.internal # noqa: E402
-import ray.projects # noqa: E402
# We import ray.actor because some code is run in actor.py which initializes
# some functions in the worker.
import ray.actor # noqa: F401
@@ -113,7 +112,7 @@
"get_actor",
"get_gpu_ids",
"get_resource_ids",
- "get_webui_url",
+ "get_dashboard_url",
"init",
"internal",
"is_initialized",
@@ -127,12 +126,10 @@
"objects",
"object_transfer_timeline",
"profile",
- "projects",
"put",
- "register_custom_serializer",
"remote",
"shutdown",
- "show_in_webui",
+ "show_in_dashboard",
"timeline",
"util",
"wait",
diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index d5f4c1bd1ecc..654a7a538fd0 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -101,11 +101,11 @@ import ray.ray_constants as ray_constants
from ray import profiling
from ray.exceptions import (
RayError,
- RayletError,
+ RaySystemError,
RayTaskError,
ObjectStoreFullError,
- RayTimeoutError,
- RayCancellationError
+ GetTimeoutError,
+ TaskCancelledError
)
from ray.utils import decode
import gc
@@ -143,11 +143,11 @@ cdef int check_status(const CRayStatus& status) nogil except -1:
elif status.IsInterrupted():
raise KeyboardInterrupt()
elif status.IsTimedOut():
- raise RayTimeoutError(message)
+ raise GetTimeoutError(message)
elif status.IsNotFound():
raise ValueError(message)
else:
- raise RayletError(message)
+ raise RaySystemError(message)
cdef RayObjectsToDataMetadataPairs(
const c_vector[shared_ptr[CRayObject]] objects):
@@ -360,7 +360,7 @@ cdef execute_task(
CFiberEvent task_done_event
# Automatically restrict the GPUs available to this task.
- ray.utils.set_cuda_visible_devices(ray.get_gpu_ids(as_str=True))
+ ray.utils.set_cuda_visible_devices(ray.get_gpu_ids())
function_descriptor = CFunctionDescriptorToPython(
ray_function.GetFunctionDescriptor())
@@ -481,7 +481,7 @@ cdef execute_task(
outputs = function_executor(*args, **kwargs)
task_exception = False
except KeyboardInterrupt as e:
- raise RayCancellationError(
+ raise TaskCancelledError(
core_worker.get_current_task_id())
if c_return_ids.size() == 1:
outputs = (outputs,)
@@ -489,7 +489,7 @@ cdef execute_task(
# was exiting and was raised after the except block.
if not check_signals().ok():
task_exception = True
- raise RayCancellationError(
+ raise TaskCancelledError(
core_worker.get_current_task_id())
# Store the outputs in the object store.
with core_worker.profile_event(b"task:store_outputs"):
@@ -697,6 +697,14 @@ cdef void terminate_asyncio_thread() nogil:
core_worker.destroy_event_loop_if_exists()
+# An empty profile event context to be used when the timeline is disabled.
+cdef class EmptyProfileEvent:
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ pass
+
cdef class CoreWorker:
def __cinit__(self, worker_type, store_socket, raylet_socket,
@@ -976,7 +984,7 @@ cdef class CoreWorker:
Language language,
FunctionDescriptor function_descriptor,
args,
- int num_return_vals,
+ int num_returns,
resources,
int max_retries,
PlacementGroupID placement_group_id,
@@ -993,7 +1001,7 @@ cdef class CoreWorker:
with self.profile_event(b"submit_task"):
prepare_resources(resources, &c_resources)
task_options = CTaskOptions(
- num_return_vals, c_resources)
+ num_returns, c_resources)
ray_function = CRayFunction(
language.lang, function_descriptor.descriptor)
prepare_args(self, language, args, &args_vector)
@@ -1103,7 +1111,7 @@ cdef class CoreWorker:
ActorID actor_id,
FunctionDescriptor function_descriptor,
args,
- int num_return_vals,
+ int num_returns,
double num_method_cpus):
cdef:
@@ -1117,7 +1125,7 @@ cdef class CoreWorker:
with self.profile_event(b"submit_task"):
if num_method_cpus > 0:
c_resources[b"CPU"] = num_method_cpus
- task_options = CTaskOptions(num_return_vals, c_resources)
+ task_options = CTaskOptions(num_returns, c_resources)
ray_function = CRayFunction(
language.lang, function_descriptor.descriptor)
prepare_args(self, language, args, &args_vector)
@@ -1172,9 +1180,12 @@ cdef class CoreWorker:
return resources_dict
def profile_event(self, c_string event_type, object extra_data=None):
- return ProfileEvent.make(
- CCoreWorkerProcess.GetCoreWorker().CreateProfileEvent(event_type),
- extra_data)
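+        # When the timeline is disabled, skip creating a real profile event
+        # in the core worker and return a no-op context manager instead.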
+ if RayConfig.instance().enable_timeline():
+ return ProfileEvent.make(
+ CCoreWorkerProcess.GetCoreWorker().CreateProfileEvent(
+ event_type), extra_data)
+ else:
+ return EmptyProfileEvent()
def remove_actor_handle_reference(self, ActorID actor_id):
cdef:
@@ -1209,7 +1220,7 @@ cdef class CoreWorker:
return ray.actor.ActorHandle(language, actor_id,
method_meta.decorators,
method_meta.signatures,
- method_meta.num_return_vals,
+ method_meta.num_returns,
actor_method_cpu,
actor_creation_function_descriptor,
worker.current_session_and_job)
@@ -1217,7 +1228,7 @@ cdef class CoreWorker:
return ray.actor.ActorHandle(language, actor_id,
{}, # method decorators
{}, # method signatures
- {}, # method num_return_vals
+ {}, # method num_returns
0, # actor method cpu
actor_creation_function_descriptor,
worker.current_session_and_job)
diff --git a/python/ray/actor.py b/python/ray/actor.py
index 9ae43d77b47f..7a60ec1ec903 100644
--- a/python/ray/actor.py
+++ b/python/ray/actor.py
@@ -23,7 +23,7 @@ def method(*args, **kwargs):
@ray.remote
class Foo:
- @ray.method(num_return_vals=2)
+ @ray.method(num_returns=2)
def bar(self):
return 1, 2
@@ -32,16 +32,16 @@ def bar(self):
_, _ = f.bar.remote()
Args:
- num_return_vals: The number of object refs that should be returned by
+ num_returns: The number of object refs that should be returned by
invocations of this actor method.
"""
assert len(args) == 0
assert len(kwargs) == 1
- assert "num_return_vals" in kwargs
- num_return_vals = kwargs["num_return_vals"]
+ assert "num_returns" in kwargs
+ num_returns = kwargs["num_returns"]
def annotate_method(method):
- method.__ray_num_return_vals__ = num_return_vals
+ method.__ray_num_returns__ = num_returns
return method
return annotate_method
@@ -58,7 +58,7 @@ class ActorMethod:
Attributes:
_actor: A handle to the actor.
_method_name: The name of the actor method.
- _num_return_vals: The default number of return values that the method
+ _num_returns: The default number of return values that the method
invocation should return.
_decorator: An optional decorator that should be applied to the actor
method invocation (as opposed to the actor method execution) before
@@ -72,12 +72,12 @@ class ActorMethod:
def __init__(self,
actor,
method_name,
- num_return_vals,
+ num_returns,
decorator=None,
hardref=False):
self._actor_ref = weakref.ref(actor)
self._method_name = method_name
- self._num_return_vals = num_return_vals
+ self._num_returns = num_returns
# This is a decorator that is used to wrap the function invocation (as
# opposed to the function execution). The decorator must return a
# function that takes in two arguments ("args" and "kwargs"). In most
@@ -100,9 +100,9 @@ def __call__(self, *args, **kwargs):
def remote(self, *args, **kwargs):
return self._remote(args, kwargs)
- def _remote(self, args=None, kwargs=None, num_return_vals=None):
- if num_return_vals is None:
- num_return_vals = self._num_return_vals
+ def _remote(self, args=None, kwargs=None, num_returns=None):
+ if num_returns is None:
+ num_returns = self._num_returns
def invocation(args, kwargs):
actor = self._actor_hard_ref or self._actor_ref()
@@ -112,7 +112,7 @@ def invocation(args, kwargs):
self._method_name,
args=args,
kwargs=kwargs,
- num_return_vals=num_return_vals)
+ num_returns=num_returns)
# Apply the decorator if there is one.
if self._decorator is not None:
@@ -124,7 +124,7 @@ def __getstate__(self):
return {
"actor": self._actor_ref(),
"method_name": self._method_name,
- "num_return_vals": self._num_return_vals,
+ "num_returns": self._num_returns,
"decorator": self._decorator,
}
@@ -132,7 +132,7 @@ def __setstate__(self, state):
self.__init__(
state["actor"],
state["method_name"],
- state["num_return_vals"],
+ state["num_returns"],
state["decorator"],
hardref=True)
@@ -147,7 +147,7 @@ class ActorClassMethodMetadata(object):
can be set by attaching the attribute
"__ray_invocation_decorator__" to the actor method.
signatures: The signatures of the methods.
- num_return_vals: The default number of return values for
+ num_returns: The default number of return values for
each actor method.
"""
@@ -182,7 +182,7 @@ def create(cls, modified_class, actor_creation_function_descriptor):
# arguments.
self.decorators = {}
self.signatures = {}
- self.num_return_vals = {}
+ self.num_returns = {}
for method_name, method in actor_methods:
# Whether or not this method requires binding of its first
# argument. For class and static methods, we do not want to bind
@@ -198,11 +198,10 @@ def create(cls, modified_class, actor_creation_function_descriptor):
self.signatures[method_name] = signature.extract_signature(
method, ignore_first=not is_bound)
# Set the default number of return values for this method.
- if hasattr(method, "__ray_num_return_vals__"):
- self.num_return_vals[method_name] = (
- method.__ray_num_return_vals__)
+ if hasattr(method, "__ray_num_returns__"):
+ self.num_returns[method_name] = (method.__ray_num_returns__)
else:
- self.num_return_vals[method_name] = (
+ self.num_returns[method_name] = (
ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS)
if hasattr(method, "__ray_invocation_decorator__"):
@@ -406,7 +405,6 @@ def _remote(self,
memory=None,
object_store_memory=None,
resources=None,
- is_direct_call=None,
max_concurrency=None,
max_restarts=None,
max_task_retries=None,
@@ -430,7 +428,6 @@ def _remote(self,
this actor when creating objects.
resources: The custom resources required by the actor creation
task.
- is_direct_call: Use direct actor calls.
max_concurrency: The max number of concurrent calls to allow for
this actor. This only works with direct actor calls. The max
concurrency defaults to 1 for threaded execution, and 1000 for
@@ -456,8 +453,6 @@ def _remote(self,
args = []
if kwargs is None:
kwargs = {}
- if is_direct_call is not None and not is_direct_call:
- raise ValueError("Non-direct call actors are no longer supported.")
meta = self.__ray_metadata__
actor_has_async_methods = len(
inspect.getmembers(
@@ -593,7 +588,7 @@ def _remote(self,
actor_id,
meta.method_meta.decorators,
meta.method_meta.signatures,
- meta.method_meta.num_return_vals,
+ meta.method_meta.num_returns,
actor_method_cpu,
meta.actor_creation_function_descriptor,
worker.current_session_and_job,
@@ -621,7 +616,7 @@ class ActorHandle:
invocation side, whereas a regular decorator can be used to change
the behavior on the execution side.
_ray_method_signatures: The signatures of the actor methods.
- _ray_method_num_return_vals: The default number of return values for
+ _ray_method_num_returns: The default number of return values for
each method.
_ray_actor_method_cpus: The number of CPUs required by actor methods.
_ray_original_handle: True if this is the original actor handle for a
@@ -637,7 +632,7 @@ def __init__(self,
actor_id,
method_decorators,
method_signatures,
- method_num_return_vals,
+ method_num_returns,
actor_method_cpus,
actor_creation_function_descriptor,
session_and_job,
@@ -647,7 +642,7 @@ def __init__(self,
self._ray_original_handle = original_handle
self._ray_method_decorators = method_decorators
self._ray_method_signatures = method_signatures
- self._ray_method_num_return_vals = method_num_return_vals
+ self._ray_method_num_returns = method_num_returns
self._ray_actor_method_cpus = actor_method_cpus
self._ray_session_and_job = session_and_job
self._ray_is_cross_language = language != Language.PYTHON
@@ -668,7 +663,7 @@ def __init__(self,
method = ActorMethod(
self,
method_name,
- self._ray_method_num_return_vals[method_name],
+ self._ray_method_num_returns[method_name],
decorator=self._ray_method_decorators.get(method_name))
setattr(self, method_name, method)
@@ -684,7 +679,7 @@ def _actor_method_call(self,
method_name,
args=None,
kwargs=None,
- num_return_vals=None):
+ num_returns=None):
"""Method execution stub for an actor handle.
This is the function that executes when
@@ -696,7 +691,7 @@ def _actor_method_call(self,
method_name: The name of the actor method to execute.
args: A list of arguments for the actor method.
kwargs: A dictionary of keyword arguments for the actor method.
- num_return_vals (int): The number of return values for the method.
+ num_returns (int): The number of return values for the method.
Returns:
object_refs: A list of object refs returned by the remote actor
@@ -729,7 +724,7 @@ def _actor_method_call(self,
object_refs = worker.core_worker.submit_actor_task(
self._ray_actor_language, self._ray_actor_id, function_descriptor,
- list_args, num_return_vals, self._ray_actor_method_cpus)
+ list_args, num_returns, self._ray_actor_method_cpus)
if len(object_refs) == 1:
object_refs = object_refs[0]
@@ -799,7 +794,7 @@ def _serialization_helper(self):
"actor_id": self._ray_actor_id,
"method_decorators": self._ray_method_decorators,
"method_signatures": self._ray_method_signatures,
- "method_num_return_vals": self._ray_method_num_return_vals,
+ "method_num_returns": self._ray_method_num_returns,
"actor_method_cpus": self._ray_actor_method_cpus,
"actor_creation_function_descriptor": self.
_ray_actor_creation_function_descriptor,
@@ -834,7 +829,7 @@ def _deserialization_helper(cls, state, outer_object_ref=None):
state["actor_id"],
state["method_decorators"],
state["method_signatures"],
- state["method_num_return_vals"],
+ state["method_num_returns"],
state["actor_method_cpus"],
state["actor_creation_function_descriptor"],
worker.current_session_and_job)
diff --git a/python/ray/autoscaler/aws/example-full.yaml b/python/ray/autoscaler/aws/example-full.yaml
index 345a9ccb6b53..89f815acb79a 100644
--- a/python/ray/autoscaler/aws/example-full.yaml
+++ b/python/ray/autoscaler/aws/example-full.yaml
@@ -42,7 +42,7 @@ docker:
# usage. For example, if a cluster of 10 nodes is 100% busy and
# target_utilization is 0.8, it would resize the cluster to 13. This fraction
# can be decreased to increase the aggressiveness of upscaling.
-# This value must be less than 1.0 for scaling to happen.
+# The max value allowed is 1.0, which is the most conservative setting.
target_utilization_fraction: 0.8
# If a node is idle for this many minutes, it will be removed.
diff --git a/python/ray/autoscaler/aws/example-multi-node-type.yaml b/python/ray/autoscaler/aws/example-multi-node-type.yaml
index a6f9a311866f..ec3d5c37bda7 100644
--- a/python/ray/autoscaler/aws/example-multi-node-type.yaml
+++ b/python/ray/autoscaler/aws/example-multi-node-type.yaml
@@ -18,16 +18,11 @@ available_node_types:
InstanceType: m4.xlarge
resources: {"CPU": 4}
max_workers: 5
- cpu_4_spot:
+ cpu_16_spot:
node_config:
- InstanceType: m4.xlarge
+ InstanceType: m4.4xlarge
InstanceMarketOptions:
MarketType: spot
- resources: {"CPU": 4}
- max_workers: 20
- cpu_16_ondemand:
- node_config:
- InstanceType: m4.4xlarge
resources: {"CPU": 16, "Custom1": 1}
max_workers: 10
gpu_1_ondemand:
@@ -35,17 +30,21 @@ available_node_types:
InstanceType: p2.xlarge
resources: {"CPU": 4, "GPU": 1, "Custom2": 2}
max_workers: 4
+ worker_setup_commands:
+ - pip install tensorflow-gpu # Example command.
gpu_8_ondemand:
node_config:
InstanceType: p2.8xlarge
resources: {"CPU": 32, "GPU": 8}
max_workers: 2
+ worker_setup_commands:
+ - pip install tensorflow-gpu # Example command.
# Specify the node type of the head node (as configured above).
head_node_type: cpu_4_ondemand
# Specify the default type of the worker node (as configured above).
-worker_default_node_type: cpu_4_spot
+worker_default_node_type: cpu_16_spot
# The default settings for the head node. This will be merged with the per-node
# type configs given above.
diff --git a/python/ray/autoscaler/command_runner.py b/python/ray/autoscaler/command_runner.py
index 240317572344..c6ae246691d5 100644
--- a/python/ray/autoscaler/command_runner.py
+++ b/python/ray/autoscaler/command_runner.py
@@ -8,6 +8,7 @@
import subprocess
import sys
import time
+import json
from ray.autoscaler.docker import check_docker_running_cmd, \
check_docker_image, \
@@ -81,18 +82,9 @@ def _with_environment_variables(cmd: str,
-            automatically be converted to a one line yaml string.
+            automatically be converted to a compact one-line JSON string.
"""
- def dict_as_one_line_yaml(d):
- items = []
- for key, val in d.items():
- item_str = "{}: {}".format(quote(str(key)), quote(str(val)))
- items.append(item_str)
-
- return "{" + ",".join(items) + "}"
-
as_strings = []
for key, val in environment_variables.items():
- if isinstance(val, dict):
- val = dict_as_one_line_yaml(val)
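+        # Encode each value as compact one-line JSON (dicts included) so it
+        # can be safely exported as a shell environment variable below.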
+ val = json.dumps(val, separators=(",", ":"))
s = "export {}={};".format(key, quote(val))
as_strings.append(s)
all_vars = "".join(as_strings)
@@ -102,7 +94,6 @@ def dict_as_one_line_yaml(d):
def _with_interactive(cmd):
force_interactive = ("true && source ~/.bashrc && "
"export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ")
-
return ["bash", "--login", "-c", "-i", quote(force_interactive + cmd)]
diff --git a/python/ray/autoscaler/resource_demand_scheduler.py b/python/ray/autoscaler/resource_demand_scheduler.py
index 908bb14eac22..f1b6baeba706 100644
--- a/python/ray/autoscaler/resource_demand_scheduler.py
+++ b/python/ray/autoscaler/resource_demand_scheduler.py
@@ -1,3 +1,12 @@
+"""Implements multi-node-type autoscaling.
+
+This file implements an autoscaling algorithm that is aware of multiple node
+types (e.g., example-multi-node-type.yaml). The Ray autoscaler will pass in
+a vector of resource shape demands, and the resource demand scheduler will
+return a list of node types that can satisfy the demands given constraints
+(i.e., reverse bin packing).
+"""
+
import copy
import numpy as np
import logging
@@ -30,18 +39,43 @@ def __init__(self, provider: NodeProvider,
self.node_types = node_types
self.max_workers = max_workers
- def debug_string(self, nodes: List[NodeID],
- pending_nodes: Dict[NodeID, int]) -> str:
+ # TODO(ekl) take into account existing utilization of node resources. We
+ # should subtract these from node resources prior to running bin packing.
+ def get_nodes_to_launch(self, nodes: List[NodeID],
+ pending_nodes: Dict[NodeType, int],
+ resource_demands: List[ResourceDict]
+ ) -> List[Tuple[NodeType, int]]:
+ """Given resource demands, return node types to add to the cluster.
+
+ This method:
+ (1) calculates the resources present in the cluster.
+ (2) calculates the unfulfilled resource bundles.
+ (3) calculates which nodes need to be launched to fulfill all
+            the bundle requests, subject to max_workers constraints.
+
+ Args:
+ nodes: List of existing nodes in the cluster.
+ pending_nodes: Summary of node types currently being launched.
+ resource_demands: Vector of resource demands from the scheduler.
+ """
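+
+        Returns:
+            List of (node_type, count) tuples of nodes to launch.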
+        """
+
+ if resource_demands is None:
+ logger.info("No resource demands")
+ return []
+
node_resources, node_type_counts = self.calculate_node_resources(
nodes, pending_nodes)
+ logger.info("Cluster resources: {}".format(node_resources))
+ logger.info("Node counts: {}".format(node_type_counts))
- out = "Worker node types:"
- for node_type, count in node_type_counts.items():
- out += "\n - {}: {}".format(node_type, count)
- if pending_nodes.get(node_type):
- out += " ({} pending)".format(pending_nodes[node_type])
+ unfulfilled = get_bin_pack_residual(node_resources, resource_demands)
+ logger.info("Resource demands: {}".format(resource_demands))
+ logger.info("Unfulfilled demands: {}".format(unfulfilled))
- return out
+ nodes = get_nodes_for(self.node_types, node_type_counts,
+ self.max_workers - len(nodes), unfulfilled)
+ logger.info("Node requests: {}".format(nodes))
+ return nodes
def calculate_node_resources(
self, nodes: List[NodeID], pending_nodes: Dict[NodeID, int]
@@ -73,36 +107,18 @@ def add_node(node_type):
return node_resources, node_type_counts
- def get_nodes_to_launch(self, nodes: List[NodeID],
- pending_nodes: Dict[NodeType, int],
- resource_demands: List[ResourceDict]
- ) -> List[Tuple[NodeType, int]]:
- """Get a list of node types that should be added to the cluster.
-
- This method:
- (1) calculates the resources present in the cluster.
- (2) calculates the unfulfilled resource bundles.
- (3) calculates which nodes need to be launched to fulfill all
- the bundle requests, subject to max_worker constraints.
- """
-
- if resource_demands is None:
- logger.info("No resource demands")
- return []
-
+ def debug_string(self, nodes: List[NodeID],
+ pending_nodes: Dict[NodeID, int]) -> str:
node_resources, node_type_counts = self.calculate_node_resources(
nodes, pending_nodes)
- logger.info("Cluster resources: {}".format(node_resources))
- logger.info("Node counts: {}".format(node_type_counts))
- unfulfilled = get_bin_pack_residual(node_resources, resource_demands)
- logger.info("Resource demands: {}".format(resource_demands))
- logger.info("Unfulfilled demands: {}".format(unfulfilled))
+ out = "Worker node types:"
+ for node_type, count in node_type_counts.items():
+ out += "\n - {}: {}".format(node_type, count)
+ if pending_nodes.get(node_type):
+ out += " ({} pending)".format(pending_nodes[node_type])
- nodes = get_nodes_for(self.node_types, node_type_counts,
- self.max_workers - len(nodes), unfulfilled)
- logger.info("Node requests: {}".format(nodes))
- return nodes
+ return out
def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
diff --git a/python/ray/cloudpickle/__init__.py b/python/ray/cloudpickle/__init__.py
index 91c46bb768c3..b28e91ee8184 100644
--- a/python/ray/cloudpickle/__init__.py
+++ b/python/ray/cloudpickle/__init__.py
@@ -1,3 +1,11 @@
-from ray.cloudpickle.cloudpickle_fast import * # noqa: F401, F403
+from __future__ import absolute_import
-__version__ = '1.4.1'
+
+from ray.cloudpickle.cloudpickle import * # noqa
+from ray.cloudpickle.cloudpickle_fast import CloudPickler, dumps, dump # noqa
+
+# Conform to the convention used by python serialization libraries, which
+# expose their Pickler subclass at top-level under the "Pickler" name.
+Pickler = CloudPickler
+
+__version__ = '1.6.0'
diff --git a/python/ray/cloudpickle/cloudpickle.py b/python/ray/cloudpickle/cloudpickle.py
index c639daab1c87..05d52afa0da9 100644
--- a/python/ray/cloudpickle/cloudpickle.py
+++ b/python/ray/cloudpickle/cloudpickle.py
@@ -42,29 +42,21 @@
"""
from __future__ import print_function
-import abc
import builtins
import dis
-import io
-import itertools
-import logging
import opcode
-import operator
-import pickle
import platform
-import struct
import sys
import types
import weakref
import uuid
import threading
import typing
-from enum import Enum
+import warnings
+from .compat import pickle
from typing import Generic, Union, Tuple, Callable
-from pickle import _Pickler as Pickler
from pickle import _getattribute
-from io import BytesIO
from importlib._bootstrap import _find_spec
try: # pragma: no branch
@@ -78,6 +70,17 @@
else: # pragma: no cover
ClassVar = None
+if sys.version_info >= (3, 8):
+ from types import CellType
+else:
+ def f():
+ a = 1
+
+ def g():
+ return a
+ return g
+ CellType = type(f().__closure__[0])
+
# cloudpickle is meant for inter process communication: we expect all
# communicating processes to run the same Python version hence we favor
@@ -163,9 +166,24 @@ def _whichmodule(obj, name):
return None
-def _is_importable_by_name(obj, name=None):
- """Determine if obj can be pickled as attribute of a file-backed module"""
- return _lookup_module_and_qualname(obj, name=name) is not None
+def _is_importable(obj, name=None):
+ """Dispatcher utility to test the importability of various constructs."""
+ if isinstance(obj, types.FunctionType):
+ return _lookup_module_and_qualname(obj, name=name) is not None
+ elif issubclass(type(obj), type):
+ return _lookup_module_and_qualname(obj, name=name) is not None
+ elif isinstance(obj, types.ModuleType):
+ # We assume that sys.modules is primarily used as a cache mechanism for
+ # the Python import machinery. Checking if a module has been added in
+ # sys.modules is therefore a cheap and simple heuristic to tell us whether
+ # we can assume that a given module could be imported by name in
+ # another Python process.
+ return obj.__name__ in sys.modules
+ else:
+ raise TypeError(
+ "cannot check importability of {} instances".format(
+ type(obj).__name__)
+ )
def _lookup_module_and_qualname(obj, name=None):
@@ -187,6 +205,8 @@ def _lookup_module_and_qualname(obj, name=None):
if module_name == "__main__":
return None
+ # Note: if module_name is in sys.modules, the corresponding module is
+ # assumed importable at unpickling time. See #357
module = sys.modules.get(module_name, None)
if module is None:
# The main reason why obj's module would not be imported is that this
@@ -196,10 +216,6 @@ def _lookup_module_and_qualname(obj, name=None):
# supported, as the standard pickle does not support it either.
return None
- # module has been added to sys.modules, but it can still be dynamic.
- if _is_dynamic(module):
- return None
-
try:
obj2, parent = _getattribute(module, name)
except AttributeError:
@@ -458,577 +474,61 @@ def _is_parametrized_type_hint(obj):
def _create_parametrized_type_hint(origin, args):
return origin[args]
+else:
+ _is_parametrized_type_hint = None
+ _create_parametrized_type_hint = None
+
+
+def parametrized_type_hint_getinitargs(obj):
+ # The distorted type check semantic for typing constructs becomes:
+ # ``type(obj) is type(TypeHint)``, which means "obj is a
+ # parametrized TypeHint"
+ if type(obj) is type(Literal): # pragma: no branch
+ initargs = (Literal, obj.__values__)
+ elif type(obj) is type(Final): # pragma: no branch
+ initargs = (Final, obj.__type__)
+ elif type(obj) is type(ClassVar):
+ initargs = (ClassVar, obj.__type__)
+ elif type(obj) is type(Generic):
+ parameters = obj.__parameters__
+ if len(obj.__parameters__) > 0:
+ # in early Python 3.5, __parameters__ was sometimes
+ # preferred to __args__
+ initargs = (obj.__origin__, parameters)
-
-class CloudPickler(Pickler):
-
- dispatch = Pickler.dispatch.copy()
-
- def __init__(self, file, protocol=None):
- if protocol is None:
- protocol = DEFAULT_PROTOCOL
- Pickler.__init__(self, file, protocol=protocol)
- # map ids to dictionary. used to ensure that functions can share global env
- self.globals_ref = {}
-
- def dump(self, obj):
- self.inject_addons()
- try:
- return Pickler.dump(self, obj)
- except RuntimeError as e:
- if 'recursion' in e.args[0]:
- msg = """Could not pickle object as excessively deep recursion required."""
- raise pickle.PicklingError(msg)
- else:
- raise
-
- def save_typevar(self, obj):
- self.save_reduce(*_typevar_reduce(obj), obj=obj)
-
- dispatch[typing.TypeVar] = save_typevar
-
- def save_memoryview(self, obj):
- self.save(obj.tobytes())
-
- dispatch[memoryview] = save_memoryview
-
- def save_module(self, obj):
- """
- Save a module as an import
- """
- if _is_dynamic(obj):
- obj.__dict__.pop('__builtins__', None)
- self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)),
- obj=obj)
else:
- self.save_reduce(subimport, (obj.__name__,), obj=obj)
-
- dispatch[types.ModuleType] = save_module
-
- def save_codeobject(self, obj):
- """
- Save a code object
- """
- if hasattr(obj, "co_posonlyargcount"): # pragma: no branch
- args = (
- obj.co_argcount, obj.co_posonlyargcount,
- obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
- obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
- obj.co_varnames, obj.co_filename, obj.co_name,
- obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
- obj.co_cellvars
- )
+ initargs = (obj.__origin__, obj.__args__)
+ elif type(obj) is type(Union):
+ if sys.version_info < (3, 5, 3): # pragma: no cover
+ initargs = (Union, obj.__union_params__)
else:
- args = (
- obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals,
- obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts,
- obj.co_names, obj.co_varnames, obj.co_filename,
- obj.co_name, obj.co_firstlineno, obj.co_lnotab,
- obj.co_freevars, obj.co_cellvars
- )
- self.save_reduce(types.CodeType, args, obj=obj)
-
- dispatch[types.CodeType] = save_codeobject
-
- def save_function(self, obj, name=None):
- """ Registered with the dispatch to handle all function types.
-
- Determines what kind of function obj is (e.g. lambda, defined at
- interactive prompt, etc) and handles the pickling appropriately.
- """
- if _is_importable_by_name(obj, name=name):
- return Pickler.save_global(self, obj, name=name)
- elif PYPY and isinstance(obj.__code__, builtin_code_type):
- return self.save_pypy_builtin_func(obj)
+ initargs = (Union, obj.__args__)
+ elif type(obj) is type(Tuple):
+ if sys.version_info < (3, 5, 3): # pragma: no cover
+ initargs = (Tuple, obj.__tuple_params__)
else:
- return self.save_function_tuple(obj)
-
- dispatch[types.FunctionType] = save_function
-
- def save_pypy_builtin_func(self, obj):
- """Save pypy equivalent of builtin functions.
-
- PyPy does not have the concept of builtin-functions. Instead,
- builtin-functions are simple function instances, but with a
- builtin-code attribute.
- Most of the time, builtin functions should be pickled by attribute. But
- PyPy has flaky support for __qualname__, so some builtin functions such
- as float.__new__ will be classified as dynamic. For this reason only,
- we created this special routine. Because builtin-functions are not
- expected to have closure or globals, there is no additional hack
- (compared the one already implemented in pickle) to protect ourselves
- from reference cycles. A simple (reconstructor, newargs, obj.__dict__)
- tuple is save_reduced.
-
- Note also that PyPy improved their support for __qualname__ in v3.6, so
- this routing should be removed when cloudpickle supports only PyPy 3.6
- and later.
- """
- rv = (types.FunctionType, (obj.__code__, {}, obj.__name__,
- obj.__defaults__, obj.__closure__),
- obj.__dict__)
- self.save_reduce(*rv, obj=obj)
-
- def _save_dynamic_enum(self, obj, clsdict):
- """Special handling for dynamic Enum subclasses
-
- Use a dedicated Enum constructor (inspired by EnumMeta.__call__) as the
- EnumMeta metaclass has complex initialization that makes the Enum
- subclasses hold references to their own instances.
- """
- members = dict((e.name, e.value) for e in obj)
-
- self.save_reduce(
- _make_skeleton_enum,
- (obj.__bases__, obj.__name__, obj.__qualname__,
- members, obj.__module__, _get_or_create_tracker_id(obj), None),
- obj=obj
- )
-
- # Cleanup the clsdict that will be passed to _rehydrate_skeleton_class:
- # Those attributes are already handled by the metaclass.
- for attrname in ["_generate_next_value_", "_member_names_",
- "_member_map_", "_member_type_",
- "_value2member_map_"]:
- clsdict.pop(attrname, None)
- for member in members:
- clsdict.pop(member)
-
- def save_dynamic_class(self, obj):
- """Save a class that can't be stored as module global.
-
- This method is used to serialize classes that are defined inside
- functions, or that otherwise can't be serialized as attribute lookups
- from global modules.
- """
- clsdict = _extract_class_dict(obj)
- clsdict.pop('__weakref__', None)
-
- if issubclass(type(obj), abc.ABCMeta):
- # If obj is an instance of an ABCMeta subclass, dont pickle the
- # cache/negative caches populated during isinstance/issubclass
- # checks, but pickle the list of registered subclasses of obj.
- clsdict.pop('_abc_cache', None)
- clsdict.pop('_abc_negative_cache', None)
- clsdict.pop('_abc_negative_cache_version', None)
- registry = clsdict.pop('_abc_registry', None)
- if registry is None:
- # in Python3.7+, the abc caches and registered subclasses of a
- # class are bundled into the single _abc_impl attribute
- clsdict.pop('_abc_impl', None)
- (registry, _, _, _) = abc._get_dump(obj)
-
- clsdict["_abc_impl"] = [subclass_weakref()
- for subclass_weakref in registry]
- else:
- # In the above if clause, registry is a set of weakrefs -- in
- # this case, registry is a WeakSet
- clsdict["_abc_impl"] = [type_ for type_ in registry]
-
- # On PyPy, __doc__ is a readonly attribute, so we need to include it in
- # the initial skeleton class. This is safe because we know that the
- # doc can't participate in a cycle with the original class.
- type_kwargs = {'__doc__': clsdict.pop('__doc__', None)}
-
- if "__slots__" in clsdict:
- type_kwargs['__slots__'] = obj.__slots__
- # pickle string length optimization: member descriptors of obj are
- # created automatically from obj's __slots__ attribute, no need to
- # save them in obj's state
- if isinstance(obj.__slots__, str):
- clsdict.pop(obj.__slots__)
- else:
- for k in obj.__slots__:
- clsdict.pop(k, None)
-
- # If type overrides __dict__ as a property, include it in the type
- # kwargs. In Python 2, we can't set this attribute after construction.
- # XXX: can this ever happen in Python 3? If so add a test.
- __dict__ = clsdict.pop('__dict__', None)
- if isinstance(__dict__, property):
- type_kwargs['__dict__'] = __dict__
-
- save = self.save
- write = self.write
-
- # We write pickle instructions explicitly here to handle the
- # possibility that the type object participates in a cycle with its own
- # __dict__. We first write an empty "skeleton" version of the class and
- # memoize it before writing the class' __dict__ itself. We then write
- # instructions to "rehydrate" the skeleton class by restoring the
- # attributes from the __dict__.
- #
- # A type can appear in a cycle with its __dict__ if an instance of the
- # type appears in the type's __dict__ (which happens for the stdlib
- # Enum class), or if the type defines methods that close over the name
- # of the type, (which is common for Python 2-style super() calls).
-
- # Push the rehydration function.
- save(_rehydrate_skeleton_class)
-
- # Mark the start of the args tuple for the rehydration function.
- write(pickle.MARK)
-
- # Create and memoize an skeleton class with obj's name and bases.
- if Enum is not None and issubclass(obj, Enum):
- # Special handling of Enum subclasses
- self._save_dynamic_enum(obj, clsdict)
+ initargs = (Tuple, obj.__args__)
+ elif type(obj) is type(Callable):
+ if sys.version_info < (3, 5, 3): # pragma: no cover
+ args = obj.__args__
+ result = obj.__result__
+ if args != Ellipsis:
+ if isinstance(args, tuple):
+ args = list(args)
+ else:
+ args = [args]
else:
- # "Regular" class definition:
- tp = type(obj)
- self.save_reduce(_make_skeleton_class,
- (tp, obj.__name__, _get_bases(obj), type_kwargs,
- _get_or_create_tracker_id(obj), None),
- obj=obj)
-
- # Now save the rest of obj's __dict__. Any references to obj
- # encountered while saving will point to the skeleton class.
- save(clsdict)
-
- # Write a tuple of (skeleton_class, clsdict).
- write(pickle.TUPLE)
-
- # Call _rehydrate_skeleton_class(skeleton_class, clsdict)
- write(pickle.REDUCE)
-
- def save_function_tuple(self, func):
- """ Pickles an actual func object.
-
- A func comprises: code, globals, defaults, closure, and dict. We
- extract and save these, injecting reducing functions at certain points
- to recreate the func object. Keep in mind that some of these pieces
- can contain a ref to the func itself. Thus, a naive save on these
- pieces could trigger an infinite loop of save's. To get around that,
- we first create a skeleton func object using just the code (this is
- safe, since this won't contain a ref to the func), and memoize it as
- soon as it's created. The other stuff can then be filled in later.
- """
- if is_tornado_coroutine(func):
- self.save_reduce(_rebuild_tornado_coroutine, (func.__wrapped__,),
- obj=func)
- return
-
- save = self.save
- write = self.write
-
- code, f_globals, defaults, closure_values, dct, base_globals = self.extract_func_data(func)
-
- save(_fill_function) # skeleton function updater
- write(pickle.MARK) # beginning of tuple that _fill_function expects
-
- # Extract currently-imported submodules used by func. Storing these
- # modules in a smoke _cloudpickle_subimports attribute of the object's
- # state will trigger the side effect of importing these modules at
- # unpickling time (which is necessary for func to work correctly once
- # depickled)
- submodules = _find_imported_submodules(
- code,
- itertools.chain(f_globals.values(), closure_values or ()),
- )
-
- # create a skeleton function object and memoize it
- save(_make_skel_func)
- save((
- code,
- len(closure_values) if closure_values is not None else -1,
- base_globals,
- ))
- write(pickle.REDUCE)
- self.memoize(func)
-
- # save the rest of the func data needed by _fill_function
- state = {
- 'globals': f_globals,
- 'defaults': defaults,
- 'dict': dct,
- 'closure_values': closure_values,
- 'module': func.__module__,
- 'name': func.__name__,
- 'doc': func.__doc__,
- '_cloudpickle_submodules': submodules
- }
- if hasattr(func, '__annotations__'):
- state['annotations'] = func.__annotations__
- if hasattr(func, '__qualname__'):
- state['qualname'] = func.__qualname__
- if hasattr(func, '__kwdefaults__'):
- state['kwdefaults'] = func.__kwdefaults__
- save(state)
- write(pickle.TUPLE)
- write(pickle.REDUCE) # applies _fill_function on the tuple
-
- def extract_func_data(self, func):
- """
- Turn the function into a tuple of data necessary to recreate it:
- code, globals, defaults, closure_values, dict
- """
- code = func.__code__
-
- # extract all global ref's
- func_global_refs = _extract_code_globals(code)
-
- # process all variables referenced by global environment
- f_globals = {}
- for var in func_global_refs:
- if var in func.__globals__:
- f_globals[var] = func.__globals__[var]
-
- # defaults requires no processing
- defaults = func.__defaults__
-
- # process closure
- closure = (
- list(map(_get_cell_contents, func.__closure__))
- if func.__closure__ is not None
- else None
+ (*args, result) = obj.__args__
+ if len(args) == 1 and args[0] is Ellipsis:
+ args = Ellipsis
+ else:
+ args = list(args)
+ initargs = (Callable, (args, result))
+ else: # pragma: no cover
+ raise pickle.PicklingError(
+ "Cloudpickle Error: Unknown type {}".format(type(obj))
)
-
- # save the dict
- dct = func.__dict__
-
- # base_globals represents the future global namespace of func at
- # unpickling time. Looking it up and storing it in globals_ref allow
- # functions sharing the same globals at pickling time to also
- # share them once unpickled, at one condition: since globals_ref is
- # an attribute of a Cloudpickler instance, and that a new CloudPickler is
- # created each time pickle.dump or pickle.dumps is called, functions
- # also need to be saved within the same invokation of
- # cloudpickle.dump/cloudpickle.dumps (for example: cloudpickle.dumps([f1, f2])). There
- # is no such limitation when using Cloudpickler.dump, as long as the
- # multiple invokations are bound to the same Cloudpickler.
- base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
-
- if base_globals == {}:
- # Add module attributes used to resolve relative imports
- # instructions inside func.
- for k in ["__package__", "__name__", "__path__", "__file__"]:
- # Some built-in functions/methods such as object.__new__ have
- # their __globals__ set to None in PyPy
- if func.__globals__ is not None and k in func.__globals__:
- base_globals[k] = func.__globals__[k]
-
- return (code, f_globals, defaults, closure, dct, base_globals)
-
- def save_getset_descriptor(self, obj):
- return self.save_reduce(getattr, (obj.__objclass__, obj.__name__))
-
- dispatch[types.GetSetDescriptorType] = save_getset_descriptor
-
- def save_global(self, obj, name=None, pack=struct.pack):
- """
- Save a "global".
-
- The name of this method is somewhat misleading: all types get
- dispatched here.
- """
- if obj is type(None):
- return self.save_reduce(type, (None,), obj=obj)
- elif obj is type(Ellipsis):
- return self.save_reduce(type, (Ellipsis,), obj=obj)
- elif obj is type(NotImplemented):
- return self.save_reduce(type, (NotImplemented,), obj=obj)
- elif obj in _BUILTIN_TYPE_NAMES:
- return self.save_reduce(
- _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
-
- if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
- # Parametrized typing constructs in Python < 3.7 are not compatible
- # with type checks and ``isinstance`` semantics. For this reason,
- # it is easier to detect them using a duck-typing-based check
- # (``_is_parametrized_type_hint``) than to populate the Pickler's
- # dispatch with type-specific savers.
- self._save_parametrized_type_hint(obj)
- elif name is not None:
- Pickler.save_global(self, obj, name=name)
- elif not _is_importable_by_name(obj, name=name):
- self.save_dynamic_class(obj)
- else:
- Pickler.save_global(self, obj, name=name)
-
- dispatch[type] = save_global
-
- def save_instancemethod(self, obj):
- # Memoization rarely is ever useful due to python bounding
- if obj.__self__ is None:
- self.save_reduce(getattr, (obj.im_class, obj.__name__))
- else:
- self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
-
- dispatch[types.MethodType] = save_instancemethod
-
- def save_property(self, obj):
- # properties not correctly saved in python
- self.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__),
- obj=obj)
-
- dispatch[property] = save_property
-
- def save_classmethod(self, obj):
- orig_func = obj.__func__
- self.save_reduce(type(obj), (orig_func,), obj=obj)
-
- dispatch[classmethod] = save_classmethod
- dispatch[staticmethod] = save_classmethod
-
- def save_itemgetter(self, obj):
- """itemgetter serializer (needed for namedtuple support)"""
- class Dummy:
- def __getitem__(self, item):
- return item
- items = obj(Dummy())
- if not isinstance(items, tuple):
- items = (items,)
- return self.save_reduce(operator.itemgetter, items)
-
- if type(operator.itemgetter) is type:
- dispatch[operator.itemgetter] = save_itemgetter
-
- def save_attrgetter(self, obj):
- """attrgetter serializer"""
- class Dummy(object):
- def __init__(self, attrs, index=None):
- self.attrs = attrs
- self.index = index
- def __getattribute__(self, item):
- attrs = object.__getattribute__(self, "attrs")
- index = object.__getattribute__(self, "index")
- if index is None:
- index = len(attrs)
- attrs.append(item)
- else:
- attrs[index] = ".".join([attrs[index], item])
- return type(self)(attrs, index)
- attrs = []
- obj(Dummy(attrs))
- return self.save_reduce(operator.attrgetter, tuple(attrs))
-
- if type(operator.attrgetter) is type:
- dispatch[operator.attrgetter] = save_attrgetter
-
- def save_file(self, obj):
- """Save a file"""
-
- if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
- raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
- if obj is sys.stdout:
- return self.save_reduce(getattr, (sys, 'stdout'), obj=obj)
- if obj is sys.stderr:
- return self.save_reduce(getattr, (sys, 'stderr'), obj=obj)
- if obj is sys.stdin:
- raise pickle.PicklingError("Cannot pickle standard input")
- if obj.closed:
- raise pickle.PicklingError("Cannot pickle closed files")
- if hasattr(obj, 'isatty') and obj.isatty():
- raise pickle.PicklingError("Cannot pickle files that map to tty objects")
- if 'r' not in obj.mode and '+' not in obj.mode:
- raise pickle.PicklingError("Cannot pickle files that are not opened for reading: %s" % obj.mode)
-
- name = obj.name
-
- # TODO: also support binary mode files with io.BytesIO
- retval = io.StringIO()
-
- try:
- # Read the whole file
- curloc = obj.tell()
- obj.seek(0)
- contents = obj.read()
- obj.seek(curloc)
- except IOError:
- raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
- retval.write(contents)
- retval.seek(curloc)
-
- retval.name = name
- self.save(retval)
- self.memoize(obj)
-
- def save_ellipsis(self, obj):
- self.save_reduce(_gen_ellipsis, ())
-
- def save_not_implemented(self, obj):
- self.save_reduce(_gen_not_implemented, ())
-
- dispatch[io.TextIOWrapper] = save_file
- dispatch[type(Ellipsis)] = save_ellipsis
- dispatch[type(NotImplemented)] = save_not_implemented
-
- def save_weakset(self, obj):
- self.save_reduce(weakref.WeakSet, (list(obj),))
-
- dispatch[weakref.WeakSet] = save_weakset
-
- def save_logger(self, obj):
- self.save_reduce(logging.getLogger, (obj.name,), obj=obj)
-
- dispatch[logging.Logger] = save_logger
-
- def save_root_logger(self, obj):
- self.save_reduce(logging.getLogger, (), obj=obj)
-
- dispatch[logging.RootLogger] = save_root_logger
-
- if hasattr(types, "MappingProxyType"): # pragma: no branch
- def save_mappingproxy(self, obj):
- self.save_reduce(types.MappingProxyType, (dict(obj),), obj=obj)
-
- dispatch[types.MappingProxyType] = save_mappingproxy
-
- """Special functions for Add-on libraries"""
- def inject_addons(self):
- """Plug in system. Register additional pickling functions if modules already loaded"""
- pass
-
- if sys.version_info < (3, 7): # pragma: no branch
- def _save_parametrized_type_hint(self, obj):
- # The distorted type check sematic for typing construct becomes:
- # ``type(obj) is type(TypeHint)``, which means "obj is a
- # parametrized TypeHint"
- if type(obj) is type(Literal): # pragma: no branch
- initargs = (Literal, obj.__values__)
- elif type(obj) is type(Final): # pragma: no branch
- initargs = (Final, obj.__type__)
- elif type(obj) is type(ClassVar):
- initargs = (ClassVar, obj.__type__)
- elif type(obj) is type(Generic):
- parameters = obj.__parameters__
- if len(obj.__parameters__) > 0:
- # in early Python 3.5, __parameters__ was sometimes
- # preferred to __args__
- initargs = (obj.__origin__, parameters)
- else:
- initargs = (obj.__origin__, obj.__args__)
- elif type(obj) is type(Union):
- if sys.version_info < (3, 5, 3): # pragma: no cover
- initargs = (Union, obj.__union_params__)
- else:
- initargs = (Union, obj.__args__)
- elif type(obj) is type(Tuple):
- if sys.version_info < (3, 5, 3): # pragma: no cover
- initargs = (Tuple, obj.__tuple_params__)
- else:
- initargs = (Tuple, obj.__args__)
- elif type(obj) is type(Callable):
- if sys.version_info < (3, 5, 3): # pragma: no cover
- args = obj.__args__
- result = obj.__result__
- if args != Ellipsis:
- if isinstance(args, tuple):
- args = list(args)
- else:
- args = [args]
- else:
- (*args, result) = obj.__args__
- if len(args) == 1 and args[0] is Ellipsis:
- args = Ellipsis
- else:
- args = list(args)
- initargs = (Callable, (args, result))
- else: # pragma: no cover
- raise pickle.PicklingError(
- "Cloudpickle Error: Unknown type {}".format(type(obj))
- )
- self.save_reduce(_create_parametrized_type_hint, initargs, obj=obj)
+ return initargs
# Tornado support
@@ -1052,40 +552,6 @@ def _rebuild_tornado_coroutine(func):
return gen.coroutine(func)
-# Shorthands for legacy support
-
-def dump(obj, file, protocol=None):
- """Serialize obj as bytes streamed into file
-
- protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
- pickle.HIGHEST_PROTOCOL. This setting favors maximum communication speed
- between processes running the same Python version.
-
- Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
- compatibility with older versions of Python.
- """
- CloudPickler(file, protocol=protocol).dump(obj)
-
-
-def dumps(obj, protocol=None):
- """Serialize obj as a string of bytes allocated in memory
-
- protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
- pickle.HIGHEST_PROTOCOL. This setting favors maximum communication speed
- between processes running the same Python version.
-
- Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
- compatibility with older versions of Python.
- """
- file = BytesIO()
- try:
- cp = CloudPickler(file, protocol=protocol)
- cp.dump(obj)
- return file.getvalue()
- finally:
- file.close()
-
-
# including pickles unloading functions in this namespace
load = pickle.load
loads = pickle.loads
@@ -1184,7 +650,7 @@ def _fill_function(*args):
if 'annotations' in state:
func.__annotations__ = state['annotations']
if 'doc' in state:
- func.__doc__ = state['doc']
+ func.__doc__ = state['doc']
if 'name' in state:
func.__name__ = state['name']
if 'module' in state:
@@ -1219,11 +685,24 @@ def _make_empty_cell():
return (lambda: cell).__closure__[0]
+def _make_cell(value=_empty_cell_value):
+ cell = _make_empty_cell()
+ if value is not _empty_cell_value:
+ cell_set(cell, value)
+ return cell
+
+
def _make_skel_func(code, cell_count, base_globals=None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
"""
+ # This function is deprecated and should be removed in cloudpickle 1.7
+ warnings.warn(
+ "A pickle file created using an old (<=1.4.1) version of cloudpicke "
+ "is currently being loaded. This is not supported by cloudpickle and "
+ "will break in cloudpickle 1.7", category=UserWarning
+ )
# This is backward-compatibility code: for cloudpickle versions between
# 0.5.4 and 0.7, base_globals could be a string or None. base_globals
# should now always be a dictionary.
@@ -1307,39 +786,6 @@ class id will also reuse this enum definition.
return _lookup_class_or_track(class_tracker_id, enum_class)
-def _is_dynamic(module):
- """
- Return True if the module is special module that cannot be imported by its
- name.
- """
- # Quick check: module that have __file__ attribute are not dynamic modules.
- if hasattr(module, '__file__'):
- return False
-
- if module.__spec__ is not None:
- return False
-
- # In PyPy, Some built-in modules such as _codecs can have their
- # __spec__ attribute set to None despite being imported. For such
- # modules, the ``_find_spec`` utility of the standard library is used.
- parent_name = module.__name__.rpartition('.')[0]
- if parent_name: # pragma: no cover
- # This code handles the case where an imported package (and not
- # module) remains with __spec__ set to None. It is however untested
- # as no package in the PyPy stdlib has __spec__ set to None after
- # it is imported.
- try:
- parent = sys.modules[parent_name]
- except KeyError:
- msg = "parent {!r} not in sys.modules"
- raise ImportError(msg.format(parent_name))
- else:
- pkgpath = parent.__path__
- else:
- pkgpath = None
- return _find_spec(module.__name__, pkgpath, module) is None
-
-
def _make_typevar(name, bound, constraints, covariant, contravariant,
class_tracker_id):
tv = typing.TypeVar(
@@ -1382,3 +828,15 @@ def _get_bases(typ):
# For regular class objects
bases_attr = '__bases__'
return getattr(typ, bases_attr)
+
+
+def _make_dict_keys(obj):
+ return dict.fromkeys(obj).keys()
+
+
+def _make_dict_values(obj):
+ return {i: _ for i, _ in enumerate(obj)}.values()
+
+
+def _make_dict_items(obj):
+ return obj.items()
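# Illustrative sketch (not from the patch; helper names as defined above):
# how the new dict-view helpers behave at unpickling time. The reducers
# deliberately ship only the view's contents (list(obj)), not the whole
# backing dict, and these constructors rebuild genuine view objects:
assert list(_make_dict_keys(["a", "b"])) == ["a", "b"]    # a real dict_keys
assert list(_make_dict_values([1, 2])) == [1, 2]          # a real dict_values
assert list(_make_dict_items({"a": 1})) == [("a", 1)]     # a real dict_items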
diff --git a/python/ray/cloudpickle/cloudpickle_fast.py b/python/ray/cloudpickle/cloudpickle_fast.py
index fd16493b8475..fa8da0f635c4 100644
--- a/python/ray/cloudpickle/cloudpickle_fast.py
+++ b/python/ray/cloudpickle/cloudpickle_fast.py
@@ -10,65 +10,100 @@
guards present in cloudpickle.py that were written to handle PyPy specificities
are not present in cloudpickle_fast.py
"""
+import _collections_abc
import abc
import copyreg
import io
import itertools
import logging
import sys
+import struct
import types
import weakref
import typing
+from enum import Enum
+from collections import ChainMap
+
+from .compat import pickle, Pickler
from .cloudpickle import (
- _is_dynamic, _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
- _find_imported_submodules, _get_cell_contents, _is_importable_by_name, _builtin_type,
- Enum, _get_or_create_tracker_id, _make_skeleton_class, _make_skeleton_enum,
- _extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, _get_bases,
- cell_set, _make_empty_cell,
+ _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
+ _find_imported_submodules, _get_cell_contents, _is_importable,
+ _builtin_type, _get_or_create_tracker_id, _make_skeleton_class,
+ _make_skeleton_enum, _extract_class_dict, dynamic_subimport, subimport,
+ _typevar_reduce, _get_bases, _make_cell, _make_empty_cell, CellType,
+ _is_parametrized_type_hint, PYPY, cell_set,
+ parametrized_type_hint_getinitargs, _create_parametrized_type_hint,
+ builtin_code_type,
+ _make_dict_keys, _make_dict_values, _make_dict_items,
)
-if sys.version_info[:2] < (3, 8):
- import pickle5 as pickle
- from pickle5 import Pickler
- load, loads = pickle.load, pickle.loads
-else:
- import _pickle
- import pickle
- from _pickle import Pickler
- load, loads = _pickle.load, _pickle.loads
-import numpy
+if pickle.HIGHEST_PROTOCOL >= 5 and not PYPY:
+ # Shorthands similar to pickle.dump/pickle.dumps
+ def dump(obj, file, protocol=None, buffer_callback=None):
+ """Serialize obj as bytes streamed into file
-# Shorthands similar to pickle.dump/pickle.dumps
-def dump(obj, file, protocol=None, buffer_callback=None):
- """Serialize obj as bytes streamed into file
+ protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
+ pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
+ speed between processes running the same Python version.
- protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
- pickle.HIGHEST_PROTOCOL. This setting favors maximum communication speed
- between processes running the same Python version.
+ Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
+ compatibility with older versions of Python.
+ """
+ CloudPickler(
+ file, protocol=protocol, buffer_callback=buffer_callback
+ ).dump(obj)
- Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
- compatibility with older versions of Python.
- """
- CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback).dump(obj)
+ def dumps(obj, protocol=None, buffer_callback=None):
+ """Serialize obj as a string of bytes allocated in memory
+ protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
+ pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
+ speed between processes running the same Python version.
+
-def dumps(obj, protocol=None, buffer_callback=None):
- """Serialize obj as a string of bytes allocated in memory
+ Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
+ compatibility with older versions of Python.
+ """
+ with io.BytesIO() as file:
+ cp = CloudPickler(
+ file, protocol=protocol, buffer_callback=buffer_callback
+ )
+ cp.dump(obj)
+ return file.getvalue()
+
+else:
+ # Shorthands similar to pickle.dump/pickle.dumps
+ def dump(obj, file, protocol=None):
+ """Serialize obj as bytes streamed into file
- protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
- pickle.HIGHEST_PROTOCOL. This setting favors maximum communication speed
- between processes running the same Python version.
+ protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
+ pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
+ speed between processes running the same Python version.
- Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
- compatibility with older versions of Python.
- """
- with io.BytesIO() as file:
- cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
- cp.dump(obj)
- return file.getvalue()
+ Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
+ compatibility with older versions of Python.
+ """
+ CloudPickler(file, protocol=protocol).dump(obj)
+
+ def dumps(obj, protocol=None):
+ """Serialize obj as a string of bytes allocated in memory
+
+ protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
+ pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
+ speed between processes running the same Python version.
+
+ Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
+ compatibility with older versions of Python.
+ """
+ with io.BytesIO() as file:
+ cp = CloudPickler(file, protocol=protocol)
+ cp.dump(obj)
+ return file.getvalue()
+
+
+load, loads = pickle.load, pickle.loads
# COLLECTION OF OBJECTS __getnewargs__-LIKE METHODS
@@ -151,21 +186,15 @@ def _class_getstate(obj):
clsdict.pop('_abc_cache', None)
clsdict.pop('_abc_negative_cache', None)
clsdict.pop('_abc_negative_cache_version', None)
- clsdict.pop('_abc_impl', None)
registry = clsdict.pop('_abc_registry', None)
if registry is None:
# in Python3.7+, the abc caches and registered subclasses of a
# class are bundled into the single _abc_impl attribute
- if hasattr(abc, '_get_dump'):
- (registry, _, _, _) = abc._get_dump(obj)
- clsdict["_abc_impl"] = [subclass_weakref()
- for subclass_weakref in registry]
- else:
- # FIXME(suquark): The upstream cloudpickle cannot work in Ray
- # because sometimes both '_abc_registry' and '_get_dump' does
- # not exist. Some strange typing objects may cause this issue.
- # Here the workaround just set "_abc_impl" to None.
- clsdict["_abc_impl"] = None
+ clsdict.pop('_abc_impl', None)
+ (registry, _, _, _) = abc._get_dump(obj)
+
+ clsdict["_abc_impl"] = [subclass_weakref()
+ for subclass_weakref in registry]
else:
# In the above if clause, registry is a set of weakrefs -- in
# this case, registry is a WeakSet
@@ -217,13 +246,13 @@ def _code_reduce(obj):
"""codeobject reducer"""
if hasattr(obj, "co_posonlyargcount"): # pragma: no branch
args = (
- obj.co_argcount, obj.co_posonlyargcount,
- obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
- obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
- obj.co_varnames, obj.co_filename, obj.co_name,
- obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
- obj.co_cellvars
- )
+ obj.co_argcount, obj.co_posonlyargcount,
+ obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
+ obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
+ obj.co_varnames, obj.co_filename, obj.co_name,
+ obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
+ obj.co_cellvars
+ )
else:
args = (
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals,
@@ -235,26 +264,14 @@ def _code_reduce(obj):
return types.CodeType, args
-def _make_cell(contents):
- cell = _make_empty_cell()
- cell_set(cell, contents)
- return cell
-
-
def _cell_reduce(obj):
"""Cell (containing values of a function's free variables) reducer"""
try:
- contents = (obj.cell_contents,)
+ obj.cell_contents
except ValueError: # cell is empty
- contents = ()
-
- if sys.version_info[:2] < (3, 8):
- if contents:
- return _make_cell, contents
- else:
- return _make_empty_cell, ()
+ return _make_empty_cell, ()
else:
- return types.CellType, contents
+ return _make_cell, (obj.cell_contents, )
def _classmethod_reduce(obj):
@@ -298,10 +315,10 @@ def _file_reduce(obj):
obj.seek(0)
contents = obj.read()
obj.seek(curloc)
- except IOError:
+ except IOError as e:
raise pickle.PicklingError(
"Cannot pickle file %s as it cannot be read" % name
- )
+ ) from e
retval.write(contents)
retval.seek(curloc)
@@ -322,11 +339,11 @@ def _memoryview_reduce(obj):
def _module_reduce(obj):
- if _is_dynamic(obj):
+ if _is_importable(obj):
+ return subimport, (obj.__name__,)
+ else:
obj.__dict__.pop('__builtins__', None)
return dynamic_subimport, (obj.__name__, vars(obj))
- else:
- return subimport, (obj.__name__,)
def _method_reduce(obj):
@@ -379,11 +396,29 @@ def _class_reduce(obj):
return type, (NotImplemented,)
elif obj in _BUILTIN_TYPE_NAMES:
return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],)
- elif not _is_importable_by_name(obj):
+ elif not _is_importable(obj):
return _dynamic_class_reduce(obj)
return NotImplemented
+def _dict_keys_reduce(obj):
+ # Safer not to ship the full dict as sending the rest might
+ # be unintended and could potentially cause leaking of
+ # sensitive information
+ return _make_dict_keys, (list(obj), )
+
+
+def _dict_values_reduce(obj):
+ # Safer not to ship the full dict as sending the rest might
+ # be unintended and could potentially cause leaking of
+ # sensitive information
+ return _make_dict_values, (list(obj), )
+
+
+def _dict_items_reduce(obj):
+ return _make_dict_items, (dict(obj), )
+
+
# COLLECTIONS OF OBJECTS STATE SETTERS
# ------------------------------------
# state setters are called at unpickling time, once the object is created and
@@ -439,154 +474,30 @@ def _class_setstate(obj, state):
return obj
-def _numpy_frombuffer(buffer, dtype, shape, order):
- # Get the _frombuffer() function for reconstruction
- from numpy.core.numeric import _frombuffer
- array = _frombuffer(buffer, dtype, shape, order)
- # Unfortunately, numpy does not follow the standard, so we still
- # have to set the readonly flag for it here.
- array.setflags(write=isinstance(buffer, bytearray) or not buffer.readonly)
- return array
-
-
-def _numpy_ndarray_reduce(array):
- # This function is implemented according to 'array_reduce_ex_picklebuffer'
- # in numpy C backend. This is a workaround for python3.5 pickling support.
- if sys.version_info >= (3, 8):
- import pickle
- picklebuf_class = pickle.PickleBuffer
- elif sys.version_info >= (3, 5):
- try:
- import pickle5
- picklebuf_class = pickle5.PickleBuffer
- except Exception:
- raise ImportError("Using pickle protocol 5 requires the pickle5 "
- "module for Python >=3.5 and <3.8")
- else:
- raise ValueError("pickle protocol 5 is not available for Python < 3.5")
- # if the array if Fortran-contiguous and not C-contiguous,
- # the PickleBuffer instance will hold a view on the transpose
- # of the initial array, that is C-contiguous.
- if not array.flags.c_contiguous and array.flags.f_contiguous:
- order = "F"
- picklebuf_args = array.transpose()
- else:
- order = "C"
- picklebuf_args = array
- try:
- buffer = picklebuf_class(picklebuf_args)
- except Exception:
- # Some arrays may refuse to export a buffer, in which case
- # just fall back on regular __reduce_ex__ implementation
- # (gh-12745).
- return array.__reduce__()
-
- return _numpy_frombuffer, (buffer, array.dtype, array.shape, order)
-
-
class CloudPickler(Pickler):
- """Fast C Pickler extension with additional reducing routines.
-
- CloudPickler's extensions exist into into:
-
- * its dispatch_table containing reducers that are called only if ALL
- built-in saving functions were previously discarded.
- * a special callback named "reducer_override", invoked before standard
- function/class builtin-saving method (save_global), to serialize dynamic
- functions
- """
-
- # cloudpickle's own dispatch_table, containing the additional set of
- # objects (compared to the standard library pickle) that cloupickle can
- # serialize.
- dispatch = {}
- dispatch[classmethod] = _classmethod_reduce
- dispatch[io.TextIOWrapper] = _file_reduce
- dispatch[logging.Logger] = _logger_reduce
- dispatch[logging.RootLogger] = _root_logger_reduce
- dispatch[memoryview] = _memoryview_reduce
- dispatch[property] = _property_reduce
- dispatch[staticmethod] = _classmethod_reduce
- if sys.version_info[:2] >= (3, 8):
- dispatch[types.CellType] = _cell_reduce
- else:
- dispatch[type(_make_empty_cell())] = _cell_reduce
- dispatch[types.CodeType] = _code_reduce
- dispatch[types.GetSetDescriptorType] = _getset_descriptor_reduce
- dispatch[types.ModuleType] = _module_reduce
- dispatch[types.MethodType] = _method_reduce
- dispatch[types.MappingProxyType] = _mappingproxy_reduce
- dispatch[weakref.WeakSet] = _weakset_reduce
- dispatch[typing.TypeVar] = _typevar_reduce
-
- def __init__(self, file, protocol=None, buffer_callback=None):
- if protocol is None:
- protocol = DEFAULT_PROTOCOL
- Pickler.__init__(self, file, protocol=protocol, buffer_callback=buffer_callback)
- # map functions __globals__ attribute ids, to ensure that functions
- # sharing the same global namespace at pickling time also share their
- # global namespace at unpickling time.
- self.globals_ref = {}
-
- # Take into account potential custom reducers registered by external
- # modules
- self.dispatch_table = copyreg.dispatch_table.copy()
- self.dispatch_table.update(self.dispatch)
- self.proto = int(protocol)
-
- def reducer_override(self, obj):
- """Type-agnostic reducing callback for function and classes.
-
- For performance reasons, subclasses of the C _pickle.Pickler class
- cannot register custom reducers for functions and classes in the
- dispatch_table. Reducer for such types must instead implemented in the
- special reducer_override method.
-
- Note that method will be called for any object except a few
- builtin-types (int, lists, dicts etc.), which differs from reducers in
- the Pickler's dispatch_table, each of them being invoked for objects of
- a specific type only.
-
- This property comes in handy for classes: although most classes are
- instances of the ``type`` metaclass, some of them can be instances of
- other custom metaclasses (such as enum.EnumMeta for example). In
- particular, the metaclass will likely not be known in advance, and thus
- cannot be special-cased using an entry in the dispatch_table.
- reducer_override, among other things, allows us to register a reducer
- that will be called for any class, independently of its type.
-
-
- Notes:
-
- * reducer_override has the priority over dispatch_table-registered
- reducers.
- * reducer_override can be used to fix other limitations of cloudpickle
- for other types that suffered from type-specific reducers, such as
- Exceptions. See https://github.com/cloudpipe/cloudpickle/issues/248
- """
-
- # This is a patch for python3.5
- if isinstance(obj, numpy.ndarray):
- if (self.proto < 5 or
- (not obj.flags.c_contiguous and not obj.flags.f_contiguous) or
- (issubclass(type(obj), numpy.ndarray) and type(obj) is not numpy.ndarray) or
- obj.dtype == "O" or obj.itemsize == 0):
- return NotImplemented
- return _numpy_ndarray_reduce(obj)
-
- t = type(obj)
- try:
- is_anyclass = issubclass(t, type)
- except TypeError: # t is not a class (old Boost; see SF #502085)
- is_anyclass = False
-
- if is_anyclass:
- return _class_reduce(obj)
- elif isinstance(obj, types.FunctionType):
- return self._function_reduce(obj)
- else:
- # fallback to save_global, including the Pickler's distpatch_table
- return NotImplemented
+ # set of reducers defined and used by cloudpickle (private)
+ _dispatch_table = {}
+ _dispatch_table[classmethod] = _classmethod_reduce
+ _dispatch_table[io.TextIOWrapper] = _file_reduce
+ _dispatch_table[logging.Logger] = _logger_reduce
+ _dispatch_table[logging.RootLogger] = _root_logger_reduce
+ _dispatch_table[memoryview] = _memoryview_reduce
+ _dispatch_table[property] = _property_reduce
+ _dispatch_table[staticmethod] = _classmethod_reduce
+ _dispatch_table[CellType] = _cell_reduce
+ _dispatch_table[types.CodeType] = _code_reduce
+ _dispatch_table[types.GetSetDescriptorType] = _getset_descriptor_reduce
+ _dispatch_table[types.ModuleType] = _module_reduce
+ _dispatch_table[types.MethodType] = _method_reduce
+ _dispatch_table[types.MappingProxyType] = _mappingproxy_reduce
+ _dispatch_table[weakref.WeakSet] = _weakset_reduce
+ _dispatch_table[typing.TypeVar] = _typevar_reduce
+ _dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
+ _dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
+ _dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
+
+
+ dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)
# function reducers are defined as instance methods of CloudPickler
# objects, as they rely on a CloudPickler attribute (globals_ref)
@@ -609,7 +520,7 @@ def _function_reduce(self, obj):
As opposed to cloudpickle.py, there is no special handling for builtin
pypy functions because cloudpickle_fast is CPython-specific.
"""
- if _is_importable_by_name(obj):
+ if _is_importable(obj):
return NotImplemented
else:
return self._dynamic_function_reduce(obj)
@@ -642,12 +553,8 @@ def _function_getnewargs(self, func):
if func.__closure__ is None:
closure = None
else:
- if sys.version_info[:2] >= (3, 8):
- closure = tuple(
- types.CellType() for _ in range(len(code.co_freevars)))
- else:
- closure = tuple(
- _make_empty_cell() for _ in range(len(code.co_freevars)))
+ closure = tuple(
+ _make_empty_cell() for _ in range(len(code.co_freevars)))
return code, base_globals, None, None, closure
@@ -660,6 +567,204 @@ def dump(self, obj):
"Could not pickle object as excessively deep recursion "
"required."
)
- raise pickle.PicklingError(msg)
+ raise pickle.PicklingError(msg) from e
else:
raise
+
+ if pickle.HIGHEST_PROTOCOL >= 5:
+ # `CloudPickler.dispatch` is only left for backward compatibility - note
+ # that when using protocol 5, `CloudPickler.dispatch` is not an
+ # extension of `Pickler.dispatch` dictionary, because CloudPickler
+ # subclasses the C-implemented Pickler, which does not expose a
+ # `dispatch` attribute. Earlier versions of the protocol 5 CloudPickler
+ # used `CloudPickler.dispatch` as a class-level attribute storing all
+ # reducers implemented by cloudpickle, but the attribute name was not a
+ # great choice given the meaning of `Cloudpickler.dispatch` when
+ # `CloudPickler` extends the pure-python pickler.
+ dispatch = dispatch_table
+
+ # Implementation of the reducer_override callback, in order to
+ # efficiently serialize dynamic functions and classes by subclassing
+ # the C-implemented Pickler.
+ # TODO: decorrelate reducer_override (which is tied to CPython's
+ # implementation - would it make sense to backport it to pypy?) and
+ # pickle's protocol 5, which is implementation agnostic. Currently, the
+ # availability of both notions coincides on CPython's pickle and the
+ # pickle5 backport, but it may not be the case anymore when pypy
+ # implements protocol 5.
+ def __init__(self, file, protocol=None, buffer_callback=None):
+ if protocol is None:
+ protocol = DEFAULT_PROTOCOL
+ Pickler.__init__(
+ self, file, protocol=protocol, buffer_callback=buffer_callback
+ )
+ # map functions __globals__ attribute ids, to ensure that functions
+ # sharing the same global namespace at pickling time also share
+ # their global namespace at unpickling time.
+ self.globals_ref = {}
+ self.proto = int(protocol)
+
+ def reducer_override(self, obj):
+ """Type-agnostic reducing callback for function and classes.
+
+ For performance reasons, subclasses of the C _pickle.Pickler class
+ cannot register custom reducers for functions and classes in the
+ dispatch_table. Reducers for such types must instead be implemented in
+ the special reducer_override method.
+
+ Note that this method will be called for any object except a few
+ builtin-types (int, lists, dicts etc.), which differs from reducers
+ in the Pickler's dispatch_table, each of them being invoked for
+ objects of a specific type only.
+
+ This property comes in handy for classes: although most classes are
+ instances of the ``type`` metaclass, some of them can be instances
+ of other custom metaclasses (such as enum.EnumMeta for example). In
+ particular, the metaclass will likely not be known in advance, and
+ thus cannot be special-cased using an entry in the dispatch_table.
+ reducer_override, among other things, allows us to register a
+ reducer that will be called for any class, independently of its
+ type.
+
+
+ Notes:
+
+ * reducer_override has the priority over dispatch_table-registered
+ reducers.
+ * reducer_override can be used to fix other limitations of
+ cloudpickle for other types that suffered from type-specific
+ reducers, such as Exceptions. See
+ https://github.com/cloudpipe/cloudpickle/issues/248
+ """
+ if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
+ return (
+ _create_parametrized_type_hint,
+ parametrized_type_hint_getinitargs(obj)
+ )
+ t = type(obj)
+ try:
+ is_anyclass = issubclass(t, type)
+ except TypeError: # t is not a class (old Boost; see SF #502085)
+ is_anyclass = False
+
+ if is_anyclass:
+ return _class_reduce(obj)
+ elif isinstance(obj, types.FunctionType):
+ return self._function_reduce(obj)
+ else:
+ # fallback to save_global, including the Pickler's
+ # dispatch_table
+ return NotImplemented
+
+ else:
+ # When reducer_override is not available, hack the pure-Python
+ # Pickler's types.FunctionType and type savers. Note: the type saver
+ # must override Pickler.save_global, because pickle.py contains a
+ # hard-coded call to save_global when pickling meta-classes.
+ dispatch = Pickler.dispatch.copy()
+
+ def __init__(self, file, protocol=None):
+ if protocol is None:
+ protocol = DEFAULT_PROTOCOL
+ Pickler.__init__(self, file, protocol=protocol)
+ # map functions __globals__ attribute ids, to ensure that functions
+ # sharing the same global namespace at pickling time also share
+ # their global namespace at unpickling time.
+ self.globals_ref = {}
+ assert hasattr(self, 'proto')
+
+ def _save_reduce_pickle5(self, func, args, state=None, listitems=None,
+ dictitems=None, state_setter=None, obj=None):
+ save = self.save
+ write = self.write
+ self.save_reduce(
+ func, args, state=None, listitems=listitems,
+ dictitems=dictitems, obj=obj
+ )
+ # backport of the Python 3.8 state_setter pickle operations
+ save(state_setter)
+ save(obj) # simple BINGET opcode as obj is already memoized.
+ save(state)
+ write(pickle.TUPLE2)
+ # Trigger a state_setter(obj, state) function call.
+ write(pickle.REDUCE)
+ # The purpose of state_setter is to carry out an
+ # in-place modification of obj. We do not care about what the
+ # method might return, so its output is eventually removed from
+ # the stack.
+ write(pickle.POP)
+
+ def save_global(self, obj, name=None, pack=struct.pack):
+ """
+ Save a "global".
+
+ The name of this method is somewhat misleading: all types get
+ dispatched here.
+ """
+ if obj is type(None): # noqa
+ return self.save_reduce(type, (None,), obj=obj)
+ elif obj is type(Ellipsis):
+ return self.save_reduce(type, (Ellipsis,), obj=obj)
+ elif obj is type(NotImplemented):
+ return self.save_reduce(type, (NotImplemented,), obj=obj)
+ elif obj in _BUILTIN_TYPE_NAMES:
+ return self.save_reduce(
+ _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
+
+ if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
+ # Parametrized typing constructs in Python < 3.7 are not
+ # compatible with type checks and ``isinstance`` semantics. For
+ # this reason, it is easier to detect them using a
+ # duck-typing-based check (``_is_parametrized_type_hint``) than
+ # to populate the Pickler's dispatch with type-specific savers.
+ self.save_reduce(
+ _create_parametrized_type_hint,
+ parametrized_type_hint_getinitargs(obj),
+ obj=obj
+ )
+ elif name is not None:
+ Pickler.save_global(self, obj, name=name)
+ elif not _is_importable(obj, name=name):
+ self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj)
+ else:
+ Pickler.save_global(self, obj, name=name)
+ dispatch[type] = save_global
+
+ def save_function(self, obj, name=None):
+ """ Registered with the dispatch to handle all function types.
+
+ Determines what kind of function obj is (e.g. lambda, defined at
+ interactive prompt, etc) and handles the pickling appropriately.
+ """
+ if _is_importable(obj, name=name):
+ return Pickler.save_global(self, obj, name=name)
+ elif PYPY and isinstance(obj.__code__, builtin_code_type):
+ return self.save_pypy_builtin_func(obj)
+ else:
+ return self._save_reduce_pickle5(
+ *self._dynamic_function_reduce(obj), obj=obj
+ )
+
+ def save_pypy_builtin_func(self, obj):
+ """Save pypy equivalent of builtin functions.
+ PyPy does not have the concept of builtin-functions. Instead,
+ builtin-functions are simple function instances, but with a
+ builtin-code attribute.
+ Most of the time, builtin functions should be pickled by attribute.
+ But PyPy has flaky support for __qualname__, so some builtin
+ functions such as float.__new__ will be classified as dynamic. For
+ this reason only, we created this special routine. Because
+ builtin-functions are not expected to have closure or globals,
+ there is no additional hack (compared to the one already implemented
+ in pickle) to protect ourselves from reference cycles. A simple
+ (reconstructor, newargs, obj.__dict__) tuple is save_reduced. Note
+ also that PyPy improved their support for __qualname__ in v3.6, so
+ this routing should be removed when cloudpickle supports only PyPy
+ 3.6 and later.
+ """
+ rv = (types.FunctionType, (obj.__code__, {}, obj.__name__,
+ obj.__defaults__, obj.__closure__),
+ obj.__dict__)
+ self.save_reduce(*rv, obj=obj)
+
+ dispatch[types.FunctionType] = save_function
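# Illustrative sketch (not from the patch): a self-contained demo of the
# reducer_override mechanism the protocol-5 branch above builds on. This
# assumes CPython >= 3.8, where the C-implemented pickle.Pickler exposes
# the callback; Point and DemoPickler are hypothetical names.
import io
import pickle


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


class DemoPickler(pickle.Pickler):
    def reducer_override(self, obj):
        if isinstance(obj, Point):
            # Return a (callable, args) reduce tuple, as cloudpickle's
            # reducers do.
            return Point, (obj.x, obj.y)
        # Defer everything else to the regular pickling machinery.
        return NotImplemented


buf = io.BytesIO()
DemoPickler(buf, protocol=5).dump(Point(1, 2))
restored = pickle.loads(buf.getvalue())
assert (restored.x, restored.y) == (1, 2)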
diff --git a/python/ray/cloudpickle/compat.py b/python/ray/cloudpickle/compat.py
new file mode 100644
index 000000000000..afa285f62903
--- /dev/null
+++ b/python/ray/cloudpickle/compat.py
@@ -0,0 +1,13 @@
+import sys
+
+
+if sys.version_info < (3, 8):
+ try:
+ import pickle5 as pickle # noqa: F401
+ from pickle5 import Pickler # noqa: F401
+ except ImportError:
+ import pickle # noqa: F401
+ from pickle import _Pickler as Pickler # noqa: F401
+else:
+ import pickle # noqa: F401
+ from _pickle import Pickler # noqa: F401
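# Illustrative sketch (not from the patch): what the new compat shim means
# for callers -- a protocol-5-capable `pickle` on 3.6/3.7 when the pickle5
# backport is installed, plain stdlib pickle otherwise, with no version
# checks of their own:
from ray.cloudpickle.compat import pickle

protocol = min(5, pickle.HIGHEST_PROTOCOL)  # 5 where available
data = pickle.dumps({"x": 1}, protocol=protocol)
assert pickle.loads(data) == {"x": 1}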
diff --git a/python/ray/cluster_utils.py b/python/ray/cluster_utils.py
index 21a2f40fde92..a8cdf949732e 100644
--- a/python/ray/cluster_utils.py
+++ b/python/ray/cluster_utils.py
@@ -53,7 +53,7 @@ def connect(self):
output_info = ray.init(
ignore_reinit_error=True,
address=self.redis_address,
- redis_password=self.redis_password)
+ _redis_password=self.redis_password)
logger.info(output_info)
self.connected = True
diff --git a/python/ray/cross_language.py b/python/ray/cross_language.py
index 3f0c047144d6..8b9dc77daf39 100644
--- a/python/ray/cross_language.py
+++ b/python/ray/cross_language.py
@@ -66,7 +66,7 @@ def java_function(class_name, function_name):
None, # memory,
None, # object_store_memory,
None, # resources,
- None, # num_return_vals,
+ None, # num_returns,
None, # max_calls,
None, # max_retries
placement_group=None,
diff --git a/python/ray/dashboard/client/src/pages/dashboard/logical-view/Actor.tsx b/python/ray/dashboard/client/src/pages/dashboard/logical-view/Actor.tsx
index 3fe6bdcff81b..ebfc50d9ede9 100644
--- a/python/ray/dashboard/client/src/pages/dashboard/logical-view/Actor.tsx
+++ b/python/ray/dashboard/client/src/pages/dashboard/logical-view/Actor.tsx
@@ -202,7 +202,7 @@ class Actor extends React.Component, State> {
.sort()
.map((key, _, __) => {
// Construct the value from actor.
- // Please refer to worker.py::show_in_webui for schema.
+ // Please refer to worker.py::show_in_dashboard for schema.
const valueEncoded = actor.webuiDisplay![key];
const valueParsed = JSON.parse(valueEncoded);
let valueRendered = valueParsed["message"];
diff --git a/python/ray/dashboard/node_stats.py b/python/ray/dashboard/node_stats.py
index 63eb420e8c8e..2ad1b3eca76e 100644
--- a/python/ray/dashboard/node_stats.py
+++ b/python/ray/dashboard/node_stats.py
@@ -60,14 +60,22 @@ def __init__(self, redis_address, redis_password=None):
def _insert_log_counts(self):
for ip, logs_by_pid in self._logs.items():
hostname = self._ip_to_hostname[ip]
- logs_by_pid = {pid: len(logs) for pid, logs in logs_by_pid.items()}
- self._node_stats[hostname]["log_count"] = logs_by_pid
+ if hostname in self._node_stats:
+ logs_by_pid = {
+ pid: len(logs)
+ for pid, logs in logs_by_pid.items()
+ }
+ self._node_stats[hostname]["log_count"] = logs_by_pid
def _insert_error_counts(self):
for ip, errs_by_pid in self._errors.items():
hostname = self._ip_to_hostname[ip]
- errs_by_pid = {pid: len(errs) for pid, errs in errs_by_pid.items()}
- self._node_stats[hostname]["error_count"] = errs_by_pid
+ if hostname in self._node_stats:
+ errs_by_pid = {
+ pid: len(errs)
+ for pid, errs in errs_by_pid.items()
+ }
+ self._node_stats[hostname]["error_count"] = errs_by_pid
def _purge_outdated_stats(self):
def current(then, now):
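Why the new `hostname in self._node_stats` guards matter, as a tiny standalone sketch with hypothetical data: log/error entries can reference a node that is not (or is no longer) in the stats table, and the old unconditional assignment raised `KeyError`.

```python
node_stats = {"known-host": {}}           # hypothetical stats table
logs_by_pid = {1234: ["line1", "line2"]}  # hypothetical log entries

hostname = "departed-host"                # node missing from the table
if hostname in node_stats:                # old code skipped this check
    node_stats[hostname]["log_count"] = {
        pid: len(logs)
        for pid, logs in logs_by_pid.items()
    }
```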
diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py
index 5b246afdae99..ec8cb6a88a12 100644
--- a/python/ray/exceptions.py
+++ b/python/ray/exceptions.py
@@ -41,12 +41,7 @@ def __init__(self, ray_exception):
ray_exception.formatted_exception_string))
-class RayConnectionError(RayError):
- """Raised when ray is not yet connected but needs to be."""
- pass
-
-
-class RayCancellationError(RayError):
+class TaskCancelledError(RayError):
"""Raised when this task is cancelled.
Attributes:
@@ -143,7 +138,7 @@ def __str__(self):
return "\n".join(out)
-class RayWorkerError(RayError):
+class WorkerCrashedError(RayError):
"""Indicates that the worker died unexpectedly while executing a task."""
def __str__(self):
@@ -161,8 +156,8 @@ def __str__(self):
return "The actor died unexpectedly before finishing this task."
-class RayletError(RayError):
- """Indicates that the Raylet client has errored.
+class RaySystemError(RayError):
+ """Indicates that Ray encountered a system error.
This exception can be thrown when the raylet is killed.
"""
@@ -171,7 +166,7 @@ def __init__(self, client_exc):
self.client_exc = client_exc
def __str__(self):
- return f"The Raylet died with this message: {self.client_exc}"
+ return f"System error: {self.client_exc}"
class ObjectStoreFullError(RayError):
@@ -184,21 +179,13 @@ class ObjectStoreFullError(RayError):
def __str__(self):
return super(ObjectStoreFullError, self).__str__() + (
"\n"
- "The local object store is full of objects that are still in scope"
- " and cannot be evicted. Try increasing the object store memory "
- "available with ray.init(object_store_memory=). "
- "You can also try setting an option to fallback to LRU eviction "
- "when the object store is full by calling "
- "ray.init(lru_evict=True). See also: "
- "https://docs.ray.io/en/latest/memory-management.html.")
-
+ "The local object store is full of objects that are still in "
+ "scope and cannot be evicted. Tip: Use the `ray memory` command "
+ "to list active objects in the cluster.")
-class UnreconstructableError(RayError):
- """Indicates that an object is lost and cannot be reconstructed.
- Note, this exception only happens for actor objects. If actor's current
- state is after object's creating task, the actor cannot re-run the task to
- reconstruct the object.
+class ObjectLostError(RayError):
+ """Indicates that an object has been lost due to node failure.
Attributes:
object_ref: ID of the object.
@@ -208,17 +195,10 @@ def __init__(self, object_ref):
self.object_ref = object_ref
def __str__(self):
- return (
- f"Object {self.object_ref.hex()} is lost "
- "(either LRU evicted or deleted by user) and "
- "cannot be reconstructed. Try increasing the object store "
- "memory available with ray.init(object_store_memory=) "
- "or setting object store limits with "
- "ray.remote(object_store_memory=). "
- "See also: https://docs.ray.io/en/latest/memory-management.html")
+ return (f"Object {self.object_ref.hex()} is lost due to node failure.")
-class RayTimeoutError(RayError):
+class GetTimeoutError(RayError):
"""Indicates that a call to the worker timed out."""
pass
@@ -232,9 +212,9 @@ class PlasmaObjectNotAvailable(RayError):
PlasmaObjectNotAvailable,
RayError,
RayTaskError,
- RayWorkerError,
+ WorkerCrashedError,
RayActorError,
ObjectStoreFullError,
- UnreconstructableError,
- RayTimeoutError,
+ ObjectLostError,
+ GetTimeoutError,
]
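For downstream code the renames are one-for-one; for example, a timed-out `ray.get` now surfaces as `GetTimeoutError` rather than `RayTimeoutError`:

```python
import time

import ray
from ray.exceptions import GetTimeoutError  # formerly RayTimeoutError

ray.init()


@ray.remote
def slow():
    time.sleep(10)


try:
    ray.get(slow.remote(), timeout=0.1)
except GetTimeoutError:
    print("ray.get timed out")
```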
diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py
index 4ef40a716b41..9ffab24e8560 100644
--- a/python/ray/experimental/__init__.py
+++ b/python/ray/experimental/__init__.py
@@ -1,10 +1,8 @@
-from .api import get, wait
from .dynamic_resources import set_resource
from .object_spilling import force_spill_objects, force_restore_spilled_objects
from .placement_group import (placement_group, placement_group_table,
remove_placement_group)
__all__ = [
- "get", "wait", "set_resource", "force_spill_objects",
- "force_restore_spilled_objects", "placement_group",
- "placement_group_table", "remove_placement_group"
+ "set_resource", "force_spill_objects", "force_restore_spilled_objects",
+ "placement_group", "placement_group_table", "remove_placement_group"
]
diff --git a/python/ray/experimental/api.py b/python/ray/experimental/api.py
deleted file mode 100644
index 4fab64d03f3d..000000000000
--- a/python/ray/experimental/api.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import ray
-import numpy as np
-
-
-def get(object_refs):
- """Get a single or a collection of remote objects from the object store.
-
- This method is identical to `ray.get` except it adds support for tuples,
- ndarrays and dictionaries.
-
- Args:
- object_refs: Object ref of the object to get, a list, tuple, ndarray of
- object refs to get or a dict of {key: object ref}.
-
- Returns:
- A Python object, a list of Python objects or a dict of {key: object}.
- """
- if isinstance(object_refs, (tuple, np.ndarray)):
- return ray.get(list(object_refs))
- elif isinstance(object_refs, dict):
- keys_to_get = [
- k for k, v in object_refs.items() if isinstance(v, ray.ObjectRef)
- ]
- ids_to_get = [
- v for k, v in object_refs.items() if isinstance(v, ray.ObjectRef)
- ]
- values = ray.get(ids_to_get)
-
- result = object_refs.copy()
- for key, value in zip(keys_to_get, values):
- result[key] = value
- return result
- else:
- return ray.get(object_refs)
-
-
-def wait(object_refs, num_returns=1, timeout=None):
- """Return a list of IDs that are ready and a list of IDs that are not.
-
- This method is identical to `ray.wait` except it adds support for tuples
- and ndarrays.
-
- Args:
- object_refs (List[ObjectRef], Tuple(ObjectRef), np.array(ObjectRef)):
- List like of object refs for objects that may or may not be ready.
- Note that these IDs must be unique.
- num_returns (int): The number of object refs that should be returned.
- timeout (float): The maximum amount of time in seconds to wait before
- returning.
-
- Returns:
- A list of object refs that are ready and a list of the remaining object
- IDs.
- """
- if isinstance(object_refs, (tuple, np.ndarray)):
- return ray.wait(
- list(object_refs), num_returns=num_returns, timeout=timeout)
-
- return ray.wait(object_refs, num_returns=num_returns, timeout=timeout)
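Callers that depended on the removed dict support can reproduce it with plain `ray.get` in a few lines; a minimal sketch (the helper name is hypothetical, the logic mirrors the deleted function):

```python
import ray


def get_dict(object_refs):
    """Hypothetical stand-in for the removed ray.experimental.get on dicts:
    resolve only the values that are ObjectRefs, pass others through."""
    keys = [k for k, v in object_refs.items() if isinstance(v, ray.ObjectRef)]
    values = ray.get([object_refs[k] for k in keys])
    result = object_refs.copy()
    result.update(zip(keys, values))
    return result
```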
diff --git a/python/ray/experimental/array/distributed/linalg.py b/python/ray/experimental/array/distributed/linalg.py
index 6317ff676a49..bde572c85817 100644
--- a/python/ray/experimental/array/distributed/linalg.py
+++ b/python/ray/experimental/array/distributed/linalg.py
@@ -7,7 +7,7 @@
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
-@ray.remote(num_return_vals=2)
+@ray.remote(num_returns=2)
def tsqr(a):
"""Perform a QR decomposition of a tall-skinny matrix.
@@ -83,7 +83,7 @@ def tsqr(a):
# TODO(rkn): This is unoptimized, we really want a block version of this.
# This is Algorithm 5 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
-@ray.remote(num_return_vals=3)
+@ray.remote(num_returns=3)
def modified_lu(q):
"""Perform a modified LU decomposition of a matrix.
@@ -121,7 +121,7 @@ def modified_lu(q):
return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S
-@ray.remote(num_return_vals=2)
+@ray.remote(num_returns=2)
def tsqr_hr_helper1(u, s, y_top_block, b):
y_top = y_top_block[:b, :b]
s_full = np.diag(s)
@@ -137,7 +137,7 @@ def tsqr_hr_helper2(s, r_temp):
# This is Algorithm 6 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
-@ray.remote(num_return_vals=4)
+@ray.remote(num_returns=4)
def tsqr_hr(a):
q, r_temp = tsqr.remote(a)
y, u, s = modified_lu.remote(q)
@@ -160,7 +160,7 @@ def qr_helper2(y_ri, a_rc):
# This is Algorithm 7 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
-@ray.remote(num_return_vals=2)
+@ray.remote(num_returns=2)
def qr(a):
m, n = a.shape[0], a.shape[1]
diff --git a/python/ray/experimental/array/remote/linalg.py b/python/ray/experimental/array/remote/linalg.py
index bdde3a17b21b..47c3b98bf57c 100644
--- a/python/ray/experimental/array/remote/linalg.py
+++ b/python/ray/experimental/array/remote/linalg.py
@@ -18,12 +18,12 @@ def solve(a, b):
return np.linalg.solve(a, b)
-@ray.remote(num_return_vals=2)
+@ray.remote(num_returns=2)
def tensorsolve(a):
raise NotImplementedError
-@ray.remote(num_return_vals=2)
+@ray.remote(num_returns=2)
def tensorinv(a):
raise NotImplementedError
@@ -63,22 +63,22 @@ def det(a):
return np.linalg.det(a)
-@ray.remote(num_return_vals=3)
+@ray.remote(num_returns=3)
def svd(a):
return np.linalg.svd(a)
-@ray.remote(num_return_vals=2)
+@ray.remote(num_returns=2)
def eig(a):
return np.linalg.eig(a)
-@ray.remote(num_return_vals=2)
+@ray.remote(num_returns=2)
def eigh(a):
return np.linalg.eigh(a)
-@ray.remote(num_return_vals=4)
+@ray.remote(num_returns=4)
def lstsq(a, b):
return np.linalg.lstsq(a)
@@ -88,7 +88,7 @@ def norm(x):
return np.linalg.norm(x)
-@ray.remote(num_return_vals=2)
+@ray.remote(num_returns=2)
def qr(a):
return np.linalg.qr(a)
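The `num_return_vals` → `num_returns` rename is mechanical; the new spelling in use:

```python
import ray

ray.init()


@ray.remote(num_returns=2)  # was: @ray.remote(num_return_vals=2)
def divmod_task(a, b):
    return a // b, a % b


quot, rem = divmod_task.remote(7, 3)
print(ray.get(quot), ray.get(rem))  # 2 1
```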
diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py
index 67284a6a2c1d..29737fa2e539 100644
--- a/python/ray/import_thread.py
+++ b/python/ray/import_thread.py
@@ -18,10 +18,6 @@
class ImportThread:
"""A thread used to import exports from the driver or other workers.
- Note: The driver also has an import thread, which is used only to import
- custom class definitions from calls to _register_custom_serializer that
- happen under the hood on workers.
-
Attributes:
worker: the worker object in this process.
mode: worker mode
diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd
index 13eb9c6f7901..8748449b9116 100644
--- a/python/ray/includes/ray_config.pxd
+++ b/python/ray/includes/ray_config.pxd
@@ -64,3 +64,5 @@ cdef extern from "ray/common/ray_config.h" nogil:
uint32_t max_tasks_in_flight_per_worker() const
uint64_t metrics_report_interval_ms() const
+
+ c_bool enable_timeline() const
diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi
index 25bd1a2400ab..5aac8994aff5 100644
--- a/python/ray/includes/ray_config.pxi
+++ b/python/ray/includes/ray_config.pxi
@@ -111,3 +111,7 @@ cdef class Config:
@staticmethod
def metrics_report_interval_ms():
return RayConfig.instance().metrics_report_interval_ms()
+
+ @staticmethod
+ def enable_timeline():
+ return RayConfig.instance().enable_timeline()
diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi
index 31a5f1e04581..f20165fc1594 100644
--- a/python/ray/includes/serialization.pxi
+++ b/python/ray/includes/serialization.pxi
@@ -1,6 +1,5 @@
from libc.string cimport memcpy
from libc.stdint cimport uintptr_t, uint64_t, INT32_MAX
-from libcpp cimport nullptr
import cython
DEF MEMCOPY_THREADS = 6
@@ -116,6 +115,9 @@ cdef class SubBuffer:
self.buf, self.len)
def __getbuffer__(self, Py_buffer* buffer, int flags):
+ if flags & cpython.PyBUF_WRITABLE:
+ # Ray ensures all buffers are immutable.
+ raise BufferError
buffer.readonly = self.readonly
buffer.buf = self.buf
buffer.format = self._format.c_str()
diff --git a/python/ray/node.py b/python/ray/node.py
index 45bc3c8b10d2..fdc59330de9a 100644
--- a/python/ray/node.py
+++ b/python/ray/node.py
@@ -257,9 +257,15 @@ def get_resource_spec(self):
"""Resolve and return the current resource spec for the node."""
def merge_resources(env_dict, params_dict):
- """Merge two dictionaries, picking from the second in the event of a conflict.
- Also emit a warning on every conflict.
+ """Separates special case params and merges two dictionaries, picking from the
+ first in the event of a conflict. Also emit a warning on every
+ conflict.
"""
+ num_cpus = env_dict.pop("CPU", None)
+ num_gpus = env_dict.pop("GPU", None)
+ memory = env_dict.pop("memory", None)
+ object_store_memory = env_dict.pop("object_store_memory", None)
+
result = params_dict.copy()
result.update(env_dict)
@@ -268,19 +274,24 @@ def merge_resources(env_dict, params_dict):
logger.warning("Autoscaler is overriding your resource:"
"{}: {} with {}.".format(
key, params_dict[key], env_dict[key]))
- return result
+ return num_cpus, num_gpus, memory, object_store_memory, result
env_resources = {}
env_string = os.getenv(ray_constants.RESOURCES_ENVIRONMENT_VARIABLE)
if env_string:
env_resources = json.loads(env_string)
+ logger.info(f"Autosaler overriding resources: {env_resources}.")
if not self._resource_spec:
- resources = merge_resources(env_resources,
- self._ray_params.resources)
+ num_cpus, num_gpus, memory, object_store_memory, resources = \
+ merge_resources(env_resources, self._ray_params.resources)
self._resource_spec = ResourceSpec(
- self._ray_params.num_cpus, self._ray_params.num_gpus,
- self._ray_params.memory, self._ray_params.object_store_memory,
+ self._ray_params.num_cpus
+ if num_cpus is None else num_cpus, self._ray_params.num_gpus
+ if num_gpus is None else num_gpus, self._ray_params.memory
+ if memory is None else memory,
+ self._ray_params.object_store_memory
+ if object_store_memory is None else object_store_memory,
resources, self._ray_params.redis_max_memory).resolve(
is_head=self.head, node_ip_address=self.node_ip_address)
return self._resource_spec
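The effect of the reworked merge, as a self-contained sketch with hypothetical inputs: the special keys are split out of the autoscaler-provided env dict and, when present, override the corresponding `ray_params` values, while the remaining custom resources are merged with env taking precedence.

```python
env = {"CPU": 8, "GPU": 1, "accel": 2}  # e.g. from RAY_OVERRIDE_RESOURCES
params = {"accel": 1, "fpga": 4}        # e.g. from ray_params.resources

num_cpus = env.pop("CPU", None)         # overrides ray_params.num_cpus
num_gpus = env.pop("GPU", None)         # overrides ray_params.num_gpus
merged = params.copy()
merged.update(env)                      # env wins on custom-resource conflicts
print(num_cpus, num_gpus, merged)       # 8 1 {'accel': 2, 'fpga': 4}
```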
diff --git a/python/ray/projects/__init__.py b/python/ray/projects/__init__.py
deleted file mode 100644
index 82d7e0766d58..000000000000
--- a/python/ray/projects/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from ray.projects.projects import ProjectDefinition
-
-__all__ = [
- "ProjectDefinition",
-]
diff --git a/python/ray/projects/examples/README.md b/python/ray/projects/examples/README.md
deleted file mode 100644
index cee7ece16b9b..000000000000
--- a/python/ray/projects/examples/README.md
+++ /dev/null
@@ -1,41 +0,0 @@
-Ray Projects
-============
-
-To run these example projects, we first have to make sure the full
-repository is checked out into the project directory.
-
-Open Tacotron
--------------
-
-```shell
-cd open-tacotron
-# Check out the original repository
-git init
-git remote add origin https://github.com/keithito/tacotron.git
-git fetch
-git checkout -t origin/master
-
-# Serve the model
-ray session start serve
-
-# Terminate the session
-ray session stop
-```
-
-PyTorch Transformers
---------------------
-
-```shell
-cd python-transformers
-# Check out the original repository
-git init
-git remote add origin https://github.com/huggingface/pytorch-transformers.git
-git fetch
-git checkout -t origin/master
-
-# Now we can start the training
-ray session start train --dataset SST-2
-
-# Terminate the session
-ray session stop
-```
diff --git a/python/ray/projects/examples/open-tacotron/ray-project/cluster.yaml b/python/ray/projects/examples/open-tacotron/ray-project/cluster.yaml
deleted file mode 100644
index 2dbfb452a6bd..000000000000
--- a/python/ray/projects/examples/open-tacotron/ray-project/cluster.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-# This file is generated by `ray project create`
-
-# A unique identifier for the head node and workers of this cluster.
-cluster_name: open-tacotron
-
-# The maximum number of workers nodes to launch in addition to the head
-# node. This takes precedence over min_workers. min_workers defaults to 0.
-max_workers: 1
-
-# Cloud-provider specific configuration.
-provider:
- type: aws
- region: us-west-2
- availability_zone: us-west-2a
-
-# How Ray will authenticate with newly launched nodes.
-auth:
- ssh_user: ubuntu
diff --git a/python/ray/projects/examples/open-tacotron/ray-project/project.yaml b/python/ray/projects/examples/open-tacotron/ray-project/project.yaml
deleted file mode 100644
index 4a2c1dfb5ce8..000000000000
--- a/python/ray/projects/examples/open-tacotron/ray-project/project.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-# This file is generated by `ray project create`
-
-name: open-tacotron
-description: "A TensorFlow implementation of Google's Tacotron speech synthesis with pre-trained model (unofficial)"
-repo: https://github.com/keithito/tacotron
-
-cluster:
- config: ray-project/cluster.yaml
-
-environment:
- requirements: ray-project/requirements.txt
-
- shell:
- - curl http://data.keithito.com/data/speech/tacotron-20180906.tar.gz | tar xzC /tmp
-
-commands:
- - name: serve
- command: python demo_server.py --checkpoint /tmp/tacotron-20180906/model.ckpt
diff --git a/python/ray/projects/examples/open-tacotron/ray-project/requirements.txt b/python/ray/projects/examples/open-tacotron/ray-project/requirements.txt
deleted file mode 100644
index 100979f0823f..000000000000
--- a/python/ray/projects/examples/open-tacotron/ray-project/requirements.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-# Adapted from https://github.com/keithito/tacotron/blob/master/requirements.txt
-# Note: this doesn't include tensorflow or tensorflow-gpu because the package you need to install
-# depends on your platform. It is assumed you have already installed tensorflow.
-falcon==1.2.0
-inflect==0.2.5
-librosa==0.5.1
-matplotlib==2.0.2
-numpy==1.14.3
-scipy==0.19.0
-tqdm==4.11.2
-Unidecode==0.4.20
diff --git a/python/ray/projects/examples/pytorch-transformers/ray-project/cluster.yaml b/python/ray/projects/examples/pytorch-transformers/ray-project/cluster.yaml
deleted file mode 100644
index b2143d8c458e..000000000000
--- a/python/ray/projects/examples/pytorch-transformers/ray-project/cluster.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-# This file is generated by `ray project create`
-
-# An unique identifier for the head node and workers of this cluster.
-cluster_name: pytorch-transformers
-
-# The maximum number of workers nodes to launch in addition to the head
-# node. This takes precedence over min_workers. min_workers default to 0.
-max_workers: 1
-
-# Cloud-provider specific configuration.
-provider:
- type: aws
- region: us-west-2
- availability_zone: us-west-2a
-
-# How Ray will authenticate with newly launched nodes.
-auth:
- ssh_user: ubuntu
diff --git a/python/ray/projects/examples/pytorch-transformers/ray-project/project.yaml b/python/ray/projects/examples/pytorch-transformers/ray-project/project.yaml
deleted file mode 100644
index 7bc27ff95ac1..000000000000
--- a/python/ray/projects/examples/pytorch-transformers/ray-project/project.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-# This file is generated by `ray project create`
-
-name: pytorch-transformers
-description: "A library of state-of-the-art pretrained models for Natural Language Processing (NLP)"
-repo: https://github.com/huggingface/pytorch-transformers
-
-cluster:
- config: ray-project/cluster.yaml
-
-environment:
- requirements: ray-project/requirements.txt
-
-commands:
- - name: train
- command: |
- wget https://raw.githubusercontent.com/ray-project/project-data/master/download_glue_data.py && \
- python download_glue_data.py -d /tmp -t {{dataset}} && \
- python ./examples/run_glue.py \
- --model_type bert \
- --model_name_or_path bert-base-uncased \
- --task_name {{dataset}} \
- --do_train \
- --do_eval \
- --do_lower_case \
- --data_dir /tmp/{{dataset}} \
- --max_seq_length 128 \
- --per_gpu_eval_batch_size=8 \
- --per_gpu_train_batch_size=8 \
- --learning_rate 2e-5 \
- --num_train_epochs 3.0 \
- --output_dir /tmp/output/
- params:
- - name: "dataset"
- help: "The GLUE dataset to fine-tune on"
- choices: ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI", "QNLI", "RTE", "WNLI"]
diff --git a/python/ray/projects/examples/pytorch-transformers/ray-project/requirements.txt b/python/ray/projects/examples/pytorch-transformers/ray-project/requirements.txt
deleted file mode 100644
index 0ebd68f98eb7..000000000000
--- a/python/ray/projects/examples/pytorch-transformers/ray-project/requirements.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-# Adapted from https://github.com/huggingface/pytorch-transformers/blob/master/requirements.txt
-# PyTorch
-torch>=1.0.0
-# progress bars in model download and training scripts
-tqdm
-# Accessing files from S3 directly.
-boto3
-# Used for downloading models over HTTP
-requests
-# For OpenAI GPT
-regex
-# For XLNet
-sentencepiece
-# TensorBoard visualization
-tensorboardX
-# Pytorch transformers
-pytorch_transformers
diff --git a/python/ray/projects/projects.py b/python/ray/projects/projects.py
deleted file mode 100644
index add817c8754e..000000000000
--- a/python/ray/projects/projects.py
+++ /dev/null
@@ -1,213 +0,0 @@
-import argparse
-import copy
-import json
-import jsonschema
-import os
-import yaml
-
-
-def make_argument_parser(name, params, wildcards):
- """Build argument parser dynamically to parse parameter arguments.
-
- Args:
- name (str): Name of the command to parse.
- params (dict): Parameter specification used to construct
- the argparse parser.
- wildcards (bool): Whether wildcards are allowed as arguments.
-
- Returns:
- The argparse parser.
- A dictionary from argument name to list of valid choices.
- """
-
- parser = argparse.ArgumentParser(prog=name)
- # For argparse arguments that have a 'choices' list associated
- # with them, save it in the following dictionary.
- choices = {}
- for param in params:
- # Construct arguments to pass into argparse's parser.add_argument.
- argparse_kwargs = copy.deepcopy(param)
- name = argparse_kwargs.pop("name")
- if wildcards and "choices" in param:
- choices[name] = param["choices"]
- argparse_kwargs["choices"] = param["choices"] + ["*"]
- if "type" in param:
- types = {"int": int, "str": str, "float": float}
- if param["type"] in types:
- argparse_kwargs["type"] = types[param["type"]]
- else:
- raise ValueError(
- "Parameter {} has type {} which is not supported. "
- "Type must be one of {}".format(name, param["type"],
- list(types.keys())))
- parser.add_argument("--" + name, dest=name, **argparse_kwargs)
-
- return parser, choices
-
-
-class ProjectDefinition:
- def __init__(self, current_dir):
- """Finds ray-project folder for current project, parse and validates it.
-
- Args:
- current_dir (str): Path from which to search for ray-project.
-
- Raises:
- jsonschema.exceptions.ValidationError: This exception is raised
- if the project file is not valid.
- ValueError: This exception is raised if there are other errors in
- the project definition (e.g. files not existing).
- """
- root = find_root(current_dir)
- if root is None:
- raise ValueError("No project root found")
- # Add an empty pathname to the end so that rsync will copy the project
- # directory to the correct target.
- self.root = os.path.join(root, "")
-
- # Parse the project YAML.
- project_file = os.path.join(self.root, "ray-project", "project.yaml")
- if not os.path.exists(project_file):
- raise ValueError("Project file {} not found".format(project_file))
- with open(project_file) as f:
- self.config = yaml.safe_load(f)
-
- check_project_config(self.root, self.config)
-
- def cluster_yaml(self):
- """Return the project's cluster configuration filename."""
- return self.config["cluster"]["config"]
-
- def working_directory(self):
- """Return the project's working directory on a cluster session."""
- # Add an empty pathname to the end so that rsync will copy the project
- # directory to the correct target.
- directory = os.path.join("~", self.config["name"], "")
- return directory
-
- def get_command_info(self, command_name, args, shell, wildcards=False):
- """Get the shell command, parsed arguments and config for a command.
-
- Args:
- command_name (str): Name of the command to run. The command
- definition should be available in project.yaml.
- args (tuple): Tuple containing arguments to format the command
- with.
- wildcards (bool): If True, enable wildcards as arguments.
-
- Returns:
- The raw shell command to run with placeholders for the arguments.
- The parsed argument dictonary, parsed with argparse.
- The config dictionary of the command.
-
- Raises:
- ValueError: This exception is raised if the given command is not
- found in project.yaml.
- """
- if shell or not command_name:
- return command_name, {}, {}
-
- command_to_run = None
- params = None
- config = None
-
- for command_definition in self.config["commands"]:
- if command_definition["name"] == command_name:
- command_to_run = command_definition["command"]
- params = command_definition.get("params", [])
- config = command_definition.get("config", {})
- if not command_to_run:
- raise ValueError(
- "Cannot find the command named '{}' in commmands section "
- "of the project file.".format(command_name))
-
- parser, choices = make_argument_parser(command_name, params, wildcards)
- parsed_args = vars(parser.parse_args(list(args)))
-
- if wildcards:
- for key, val in parsed_args.items():
- if val == "*":
- parsed_args[key] = choices[key]
-
- return command_to_run, parsed_args, config
-
- def git_repo(self):
- return self.config.get("repo", None)
-
-
-def find_root(directory):
- """Find root directory of the ray project.
-
- Args:
- directory (str): Directory to start the search in.
-
- Returns:
- Path of the parent directory containing the ray-project or
- None if no such project is found.
- """
- prev, directory = None, os.path.abspath(directory)
- while prev != directory:
- if os.path.isdir(os.path.join(directory, "ray-project")):
- return directory
- prev, directory = directory, os.path.abspath(
- os.path.join(directory, os.pardir))
- return None
-
-
-def validate_project_schema(project_config):
- """Validate a project config against the official ray project schema.
-
- Args:
- project_config (dict): Parsed project yaml.
-
- Raises:
- jsonschema.exceptions.ValidationError: This exception is raised
- if the project file is not valid.
- """
- dir = os.path.dirname(os.path.abspath(__file__))
- with open(os.path.join(dir, "schema.json")) as f:
- schema = json.load(f)
-
- jsonschema.validate(instance=project_config, schema=schema)
-
-
-def check_project_config(project_root, project_config):
- """Checks if the project definition is valid.
-
- Args:
- project_root (str): Path containing the ray-project
- project_config (dict): Project config definition
-
- Raises:
- jsonschema.exceptions.ValidationError: This exception is raised
- if the project file is not valid.
- ValueError: This exception is raised if there are other errors in
- the project definition (e.g. files not existing).
- """
- validate_project_schema(project_config)
-
- # Make sure the cluster yaml file exists
- cluster_file = os.path.join(project_root,
- project_config["cluster"]["config"])
- if not os.path.exists(cluster_file):
- raise ValueError("'cluster' file does not exist "
- "in {}".format(project_root))
-
- if "environment" in project_config:
- env = project_config["environment"]
-
- if sum(["dockerfile" in env, "dockerimage" in env]) > 1:
- raise ValueError("Cannot specify both 'dockerfile' and "
- "'dockerimage' in environment.")
-
- if "requirements" in env:
- requirements_file = os.path.join(project_root, env["requirements"])
- if not os.path.exists(requirements_file):
- raise ValueError("'requirements' file in 'environment' does "
- "not exist in {}".format(project_root))
-
- if "dockerfile" in env:
- docker_file = os.path.join(project_root, env["dockerfile"])
- if not os.path.exists(docker_file):
- raise ValueError("'dockerfile' file in 'environment' does "
- "not exist in {}".format(project_root))
diff --git a/python/ray/projects/schema.json b/python/ray/projects/schema.json
deleted file mode 100644
index 2040835bf05f..000000000000
--- a/python/ray/projects/schema.json
+++ /dev/null
@@ -1,183 +0,0 @@
-{
- "$schema": "http://json-schema.org/draft-07/schema#",
- "type": "object",
- "properties": {
- "name": {
- "description": "The name of the project",
- "type": "string"
- },
- "description": {
- "description": "A short description of the project",
- "type": "string"
- },
- "repo": {
- "description": "The URL of the repo this project is part of",
- "type": "string"
- },
- "documentation": {
- "description": "Link to the documentation of this project",
- "type": "string"
- },
- "tags": {
- "description": "Relevant tags for this project",
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- "cluster": {
- "type": "object",
- "properties": {
- "config": {
- "type": "string",
- "description": "Path to a .yaml cluster configuration file (relative to the project root)"
- },
- "params": {
- "type": "array",
- "items": {
- "type": "object",
- "properties": {
- "name": {
- "type": "string"
- },
- "help": {
- "type": "string"
- },
- "choices": {
- "type": "array"
- },
- "default": {
- },
- "type": {
- "type": "string",
- "enum": [
- "int",
- "float",
- "str"
- ]
- }
- },
- "required": [
- "name"
- ],
- "additionalProperties": false
- }
- }
- },
- "required": [
- "config"
- ],
- "additionalProperties": false
- },
- "environment": {
- "description": "The environment that needs to be set up to run the project",
- "type": "object",
- "properties": {
- "dockerimage": {
- "description": "URL to a docker image that can be pulled to run the project in",
- "type": "string"
- },
- "dockerfile": {
- "description": "Path to a Dockerfile to set up an image the project can run in (relative to the project root)",
- "type": "string"
- },
- "requirements": {
- "description": "Path to a Python requirements.txt file to set up project dependencies (relative to the project root)",
- "type": "string"
- },
- "shell": {
- "description": "A sequence of shell commands to run to set up the project environment",
- "type": "array",
- "items": {
- "type": "string"
- }
- }
- },
- "additionalProperties": false
- },
- "commands": {
- "type": "array",
- "items": {
- "description": "Possible commands to run to start a session",
- "type": "object",
- "properties": {
- "name": {
- "description": "Name of the command",
- "type": "string"
- },
- "help": {
- "description": "Help string for the command",
- "type": "string"
- },
- "command": {
- "description": "Shell command to run on the cluster",
- "type": "string"
- },
- "params": {
- "type": "array",
- "items": {
- "description": "Possible parameters in the command",
- "type": "object",
- "properties": {
- "name": {
- "description": "Name of the parameter",
- "type": "string"
- },
- "help": {
- "description": "Help string for the parameter",
- "type": "string"
- },
- "choices": {
- "description": "Possible values the parameter can take",
- "type": "array"
- },
- "default": {
- },
- "type": {
- "description": "Required type for the parameter",
- "type": "string",
- "enum": [
- "int",
- "float",
- "str"
- ]
- }
- },
- "required": [
- "name"
- ],
- "additionalProperties": false
- }
- },
- "config": {
- "description": "Configuration options for the command",
- "type": "object",
- "properties": {
- "tmux": {
- "description": "If true, the command will be run inside of tmux",
- "type": "boolean"
- }
- },
- "additionalProperties": false
- }
- },
- "required": [
- "name",
- "command"
- ],
- "additionalProperties": false
- }
- },
- "output_files": {
- "type": "array",
- "items": {
- "type": "string"
- }
- }
- },
- "required": [
- "name",
- "cluster"
- ],
- "additionalProperties": false
-}
diff --git a/python/ray/projects/scripts.py b/python/ray/projects/scripts.py
deleted file mode 100644
index 31f566a9e8ac..000000000000
--- a/python/ray/projects/scripts.py
+++ /dev/null
@@ -1,445 +0,0 @@
-import argparse
-import click
-import copy
-import jsonschema
-import logging
-import os
-from shutil import copyfile
-import subprocess
-import sys
-import time
-
-import ray
-from ray.autoscaler.commands import (
- attach_cluster,
- exec_cluster,
- create_or_update_cluster,
- rsync,
- teardown_cluster,
-)
-
-logging.basicConfig(format=ray.ray_constants.LOGGER_FORMAT, level=logging.INFO)
-logger = logging.getLogger(__file__)
-
-# File layout for generated project files
-# user-dir/
-# ray-project/
-# project.yaml
-# cluster.yaml
-# requirements.txt
-PROJECT_DIR = "ray-project"
-PROJECT_YAML = os.path.join(PROJECT_DIR, "project.yaml")
-CLUSTER_YAML = os.path.join(PROJECT_DIR, "cluster.yaml")
-REQUIREMENTS_TXT = os.path.join(PROJECT_DIR, "requirements.txt")
-
-# File layout for templates file
-# RAY/.../projects/
-# templates/
-# cluster_template.yaml
-# project_template.yaml
-# requirements.txt
-_THIS_FILE_DIR = os.path.split(os.path.abspath(__file__))[0]
-_TEMPLATE_DIR = os.path.join(_THIS_FILE_DIR, "templates")
-PROJECT_TEMPLATE = os.path.join(_TEMPLATE_DIR, "project_template.yaml")
-CLUSTER_TEMPLATE = os.path.join(_TEMPLATE_DIR, "cluster_template.yaml")
-REQUIREMENTS_TXT_TEMPLATE = os.path.join(_TEMPLATE_DIR, "requirements.txt")
-
-
-@click.group(
- "project", help="[Experimental] Commands working with ray project")
-def project_cli():
- pass
-
-
-@project_cli.command(help="Validate current project spec")
-@click.option(
- "--verbose", help="If set, print the validated file", is_flag=True)
-def validate(verbose):
- try:
- project = ray.projects.ProjectDefinition(os.getcwd())
- print("Project files validated!", file=sys.stderr)
- if verbose:
- print(project.config)
- except (jsonschema.exceptions.ValidationError, ValueError) as e:
- print("Validation failed for the following reason", file=sys.stderr)
- raise click.ClickException(e)
-
-
-@project_cli.command(help="Create a new project within current directory")
-@click.argument("project_name")
-@click.option(
- "--cluster-yaml",
- help="Path to autoscaler yaml. Created by default",
- default=None)
-@click.option(
- "--requirements",
- help="Path to requirements.txt. Created by default",
- default=None)
-def create(project_name, cluster_yaml, requirements):
- if os.path.exists(PROJECT_DIR):
- raise click.ClickException(
- "Project directory {} already exists.".format(PROJECT_DIR))
- os.makedirs(PROJECT_DIR)
-
- if cluster_yaml is None:
- logger.warning("Using default autoscaler yaml")
-
- with open(CLUSTER_TEMPLATE) as f:
- template = f.read().replace(r"{{name}}", project_name)
- with open(CLUSTER_YAML, "w") as f:
- f.write(template)
-
- cluster_yaml = CLUSTER_YAML
-
- if requirements is None:
- logger.warning("Using default requirements.txt")
- # no templating required, just copy the file
- copyfile(REQUIREMENTS_TXT_TEMPLATE, REQUIREMENTS_TXT)
-
- requirements = REQUIREMENTS_TXT
-
- repo = None
- if os.path.exists(".git"):
- try:
- repo = subprocess.check_output(
- "git remote get-url origin".split(" ")).strip()
- logger.info("Setting repo URL to %s", repo)
- except subprocess.CalledProcessError:
- pass
-
- with open(PROJECT_TEMPLATE) as f:
- project_template = f.read()
- # NOTE(simon):
- # We could use jinja2, which will make the templating part easier.
- project_template = project_template.replace(r"{{name}}", project_name)
- project_template = project_template.replace(r"{{cluster}}",
- cluster_yaml)
- project_template = project_template.replace(r"{{requirements}}",
- requirements)
- if repo is None:
- project_template = project_template.replace(
- r"{{repo_string}}", "# repo: {}".format("..."))
- else:
- project_template = project_template.replace(
- r"{{repo_string}}", "repo: {}".format(repo))
- with open(PROJECT_YAML, "w") as f:
- f.write(project_template)
-
-
-@click.group(
- "session",
- help="[Experimental] Commands working with sessions, which are "
- "running instances of a project.")
-def session_cli():
- pass
-
-
-def load_project_or_throw():
- # Validate the project file
- try:
- return ray.projects.ProjectDefinition(os.getcwd())
- except (jsonschema.exceptions.ValidationError, ValueError):
- raise click.ClickException(
- "Project file validation failed. Please run "
- "`ray project validate` to inspect the error.")
-
-
-class SessionRunner:
- """Class for setting up a session and executing commands in it."""
-
- def __init__(self, session_name=None):
- """Initialize session runner and try to parse the command arguments.
-
- Args:
- session_name (str): Name of the session.
-
- Raises:
- click.ClickException: This exception is raised if any error occurs.
- """
- self.project_definition = load_project_or_throw()
- self.session_name = session_name
-
- # Check for features we don't support right now
- project_environment = self.project_definition.config.get(
- "environment", {})
- need_docker = ("dockerfile" in project_environment
- or "dockerimage" in project_environment)
- if need_docker:
- raise click.ClickException(
- "Docker support in session is currently not implemented.")
-
- def create_cluster(self, no_config_cache):
- """Create a cluster that will run the session."""
- create_or_update_cluster(
- config_file=self.project_definition.cluster_yaml(),
- override_min_workers=None,
- override_max_workers=None,
- no_restart=False,
- restart_only=False,
- yes=True,
- override_cluster_name=self.session_name,
- no_config_cache=no_config_cache,
- )
-
- def sync_files(self):
- """Synchronize files with the session."""
- rsync(
- self.project_definition.cluster_yaml(),
- source=self.project_definition.root,
- target=self.project_definition.working_directory(),
- override_cluster_name=self.session_name,
- down=False,
- )
-
- def setup_environment(self):
- """Set up the environment of the session."""
- project_environment = self.project_definition.config.get(
- "environment", {})
-
- if "requirements" in project_environment:
- requirements_txt = project_environment["requirements"]
-
- # Create a temporary requirements_txt in the head node.
- remote_requirements_txt = os.path.join(
- ray.utils.get_user_temp_dir(),
- "ray_project_requirements_txt_{}".format(time.time()))
-
- rsync(
- self.project_definition.cluster_yaml(),
- source=requirements_txt,
- target=remote_requirements_txt,
- override_cluster_name=self.session_name,
- down=False,
- )
- self.execute_command(
- "pip install -r {}".format(remote_requirements_txt))
-
- if "shell" in project_environment:
- for cmd in project_environment["shell"]:
- self.execute_command(cmd)
-
- def execute_command(self, cmd, config={}):
- """Execute a shell command in the session.
-
- Args:
- cmd (str): Shell command to run in the session. It will be
- run in the working directory of the project.
- """
- cwd = self.project_definition.working_directory()
- cmd = "cd {cwd}; {cmd}".format(cwd=cwd, cmd=cmd)
- exec_cluster(
- config_file=self.project_definition.cluster_yaml(),
- cmd=cmd,
- run_env=config.get("run_env", "auto"),
- screen=False,
- tmux=config.get("tmux", False),
- stop=False,
- start=False,
- override_cluster_name=self.session_name,
- port_forward=config.get("port_forward", None),
- )
-
-
-def format_command(command, parsed_args):
- """Substitute arguments into command.
-
- Args:
- command (str): Shell comand with argument placeholders.
- parsed_args (dict): Dictionary that maps from argument names
- to their value.
-
- Returns:
- Shell command with parameters from parsed_args substituted.
- """
- for key, val in parsed_args.items():
- command = command.replace("{{" + key + "}}", str(val))
- return command
-
-
-def get_session_runs(name, command, parsed_args):
- """Get a list of sessions to start.
-
- Args:
- command (str): Shell command with argument placeholders.
- parsed_args (dict): Dictionary that maps from argument names
- to their values.
-
- Returns:
- List of sessions to start, which are dictionaries with keys:
- "name": Name of the session to start,
- "command": Command to run after starting the session,
- "params": Parameters for this run,
- "num_steps": 4 if a command should be run, 3 if not.
- """
- if not command:
- return [{"name": name, "command": None, "params": {}, "num_steps": 3}]
-
- # Try to find a wildcard argument (i.e. one that has a list of values)
- # and give an error if there is more than one (currently unsupported).
- wildcard_arg = None
- for key, val in parsed_args.items():
- if isinstance(val, list):
- if not wildcard_arg:
- wildcard_arg = key
- else:
- raise click.ClickException(
- "More than one wildcard is not supported at the moment")
-
- if not wildcard_arg:
- session_run = {
- "name": name,
- "command": format_command(command, parsed_args),
- "params": parsed_args,
- "num_steps": 4
- }
- return [session_run]
- else:
- session_runs = []
- for val in parsed_args[wildcard_arg]:
- parsed_args = copy.deepcopy(parsed_args)
- parsed_args[wildcard_arg] = val
- session_run = {
- "name": "{}-{}-{}".format(name, wildcard_arg, val),
- "command": format_command(command, parsed_args),
- "params": parsed_args,
- "num_steps": 4
- }
- session_runs.append(session_run)
- return session_runs
-
-
-@session_cli.command(help="Attach to an existing cluster")
-@click.option(
- "--screen", is_flag=True, default=False, help="Run the command in screen.")
-@click.option("--tmux", help="Attach to tmux session", is_flag=True)
-def attach(screen, tmux):
- project_definition = load_project_or_throw()
- attach_cluster(
- project_definition.cluster_yaml(),
- start=False,
- use_screen=screen,
- use_tmux=tmux,
- override_cluster_name=None,
- new=False,
- )
-
-
-@session_cli.command(help="Stop a session based on current project config")
-@click.option("--name", help="Name of the session to stop", default=None)
-def stop(name):
- project_definition = load_project_or_throw()
-
- if not name:
- name = project_definition.config["name"]
-
- teardown_cluster(
- project_definition.cluster_yaml(),
- yes=True,
- workers_only=False,
- override_cluster_name=name)
-
-
-@session_cli.command(
- name="start",
- context_settings=dict(ignore_unknown_options=True, ),
- help="Start a session based on current project config")
-@click.argument("command", required=False)
-@click.argument("args", nargs=-1, type=click.UNPROCESSED)
-@click.option(
- "--shell",
- help=(
- "If set, run the command as a raw shell command instead of looking up "
- "the command in the project config"),
- is_flag=True)
-@click.option("--name", help="A name to tag the session with.", default=None)
-@click.option(
- "--no-config-cache",
- is_flag=True,
- default=False,
- help="Disable the local cluster config cache.")
-def session_start(command, args, shell, name, no_config_cache):
- project_definition = load_project_or_throw()
-
- if not name:
- name = project_definition.config["name"]
-
- # Get the actual command to run. This also validates the command,
- # which should be done before the cluster is started.
- try:
- command, parsed_args, config = project_definition.get_command_info(
- command, args, shell, wildcards=True)
- except ValueError as e:
- raise click.ClickException(e)
- session_runs = get_session_runs(name, command, parsed_args)
-
- if len(session_runs) > 1 and not config.get("tmux", False):
- logging.info("Using wildcards with tmux = False would not create "
- "sessions in parallel, so we are overriding it with "
- "tmux = True.")
- config["tmux"] = True
-
- for run in session_runs:
- runner = SessionRunner(session_name=run["name"])
- logger.info("[1/{}] Creating cluster".format(run["num_steps"]))
- runner.create_cluster(no_config_cache)
- logger.info("[2/{}] Syncing the project".format(run["num_steps"]))
- runner.sync_files()
- logger.info("[3/{}] Setting up environment".format(run["num_steps"]))
- runner.setup_environment()
-
- if run["command"]:
- # Run the actual command.
- logger.info("[4/4] Running command")
- runner.execute_command(run["command"], config)
-
-
-@session_cli.command(
- name="commands",
- help="Print available commands for sessions of this project.")
-def session_commands():
- project_definition = load_project_or_throw()
- print("Active project: " + project_definition.config["name"])
- print()
-
- commands = project_definition.config["commands"]
-
- for command in commands:
- print("Command \"{}\":".format(command["name"]))
- parser = argparse.ArgumentParser(
- command["name"], description=command.get("help"), add_help=False)
- params = command.get("params", [])
- for param in params:
- name = param.pop("name")
- if "type" in param:
- param.pop("type")
- parser.add_argument("--" + name, **param)
- help_string = parser.format_help()
- # Indent the help message by two spaces and print it.
- print("\n".join([" " + line for line in help_string.split("\n")]))
-
-
-@session_cli.command(
- name="execute",
- context_settings=dict(ignore_unknown_options=True, ),
- help="Execute a command in a session")
-@click.argument("command", required=False)
-@click.argument("args", nargs=-1, type=click.UNPROCESSED)
-@click.option(
- "--shell",
- help=(
- "If set, run the command as a raw shell command instead of looking up "
- "the command in the project config"),
- is_flag=True)
-@click.option(
- "--name", help="Name of the session to run this command on", default=None)
-def session_execute(command, args, shell, name):
- project_definition = load_project_or_throw()
- try:
- command, parsed_args, config = project_definition.get_command_info(
- command, args, shell, wildcards=False)
- except ValueError as e:
- raise click.ClickException(e)
-
- runner = SessionRunner(session_name=name)
- command = format_command(command, parsed_args)
- runner.execute_command(command)
diff --git a/python/ray/projects/templates/cluster_template.yaml b/python/ray/projects/templates/cluster_template.yaml
deleted file mode 100644
index 962dd1768313..000000000000
--- a/python/ray/projects/templates/cluster_template.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-# This file is generated by `ray project create`.
-
-# A unique identifier for the head node and workers of this cluster.
-cluster_name: {{name}}
-
-# The maximum number of workers nodes to launch in addition to the head
-# node. This takes precedence over min_workers. min_workers defaults to 0.
-max_workers: 1
-
-# Cloud-provider specific configuration.
-provider:
- type: aws
- region: us-west-2
- availability_zone: us-west-2a
-
-# How Ray will authenticate with newly launched nodes.
-auth:
- ssh_user: ubuntu
diff --git a/python/ray/projects/templates/project_template.yaml b/python/ray/projects/templates/project_template.yaml
deleted file mode 100644
index f0c4d1b8fa27..000000000000
--- a/python/ray/projects/templates/project_template.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-# This file is generated by `ray project create`.
-
-name: {{name}}
-
-# description: A short description of the project.
-# The URL of the repo this project is part of.
-{{repo_string}}
-
-cluster:
- config: {{cluster}}
-
-environment:
- # dockerfile: The dockerfile to be built and ran the commands with.
- # dockerimage: The docker image to be used to run the project in, e.g. ubuntu:18.04.
- requirements: {{requirements}}
-
- shell: # Shell commands to be ran for environment setup.
- - echo "Setting up the environment"
-
-commands:
- - name: default
- command: echo "Starting ray job"
diff --git a/python/ray/projects/templates/requirements.txt b/python/ray/projects/templates/requirements.txt
deleted file mode 100644
index 0f026d879f17..000000000000
--- a/python/ray/projects/templates/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-ray[debug]
\ No newline at end of file
diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py
index 47aff0ebe517..92aeac05f608 100644
--- a/python/ray/ray_constants.py
+++ b/python/ray/ray_constants.py
@@ -130,7 +130,7 @@ def to_memory_units(memory_bytes, round_up):
RAYLET_CONNECTION_ERROR = "raylet_connection_error"
# Used in gpu detection
-RESOURCE_CONSTRAINT_PREFIX = "GPUType:"
+RESOURCE_CONSTRAINT_PREFIX = "gpu_type:"
RESOURCES_ENVIRONMENT_VARIABLE = "RAY_OVERRIDE_RESOURCES"
diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py
index 1696212aa5c0..02a4735a2619 100644
--- a/python/ray/remote_function.py
+++ b/python/ray/remote_function.py
@@ -40,7 +40,7 @@ class RemoteFunction:
_object_store_memory: The object store memory request for this task.
_resources: The default custom resource requirements for invocations of
this remote function.
- _num_return_vals: The default number of return values for invocations
+ _num_returns: The default number of return values for invocations
of this remote function.
_max_calls: The number of times a worker can execute this function
before exiting.
@@ -61,8 +61,8 @@ class RemoteFunction:
"""
def __init__(self, language, function, function_descriptor, num_cpus,
- num_gpus, memory, object_store_memory, resources,
- num_return_vals, max_calls, max_retries, placement_group,
+ num_gpus, memory, object_store_memory, resources, num_returns,
+ max_calls, max_retries, placement_group,
placement_group_bundle_index):
self._language = language
self._function = function
@@ -79,8 +79,8 @@ def __init__(self, language, function, function_descriptor, num_cpus,
"setting object_store_memory is not implemented for tasks")
self._object_store_memory = None
self._resources = resources
- self._num_return_vals = (DEFAULT_REMOTE_FUNCTION_NUM_RETURN_VALS if
- num_return_vals is None else num_return_vals)
+ self._num_returns = (DEFAULT_REMOTE_FUNCTION_NUM_RETURN_VALS
+ if num_returns is None else num_returns)
self._max_calls = (DEFAULT_REMOTE_FUNCTION_MAX_CALLS
if max_calls is None else max_calls)
self._max_retries = (DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES
@@ -107,7 +107,7 @@ def __call__(self, *args, **kwargs):
def _submit(self,
args=None,
kwargs=None,
- num_return_vals=None,
+ num_returns=None,
num_cpus=None,
num_gpus=None,
resources=None):
@@ -116,7 +116,7 @@ def _submit(self,
return self._remote(
args=args,
kwargs=kwargs,
- num_return_vals=num_return_vals,
+ num_returns=num_returns,
num_cpus=num_cpus,
num_gpus=num_gpus,
resources=resources)
@@ -144,8 +144,7 @@ def remote(self, *args, **kwargs):
def _remote(self,
args=None,
kwargs=None,
- num_return_vals=None,
- is_direct_call=None,
+ num_returns=None,
num_cpus=None,
num_gpus=None,
memory=None,
@@ -183,10 +182,8 @@ def _remote(self,
kwargs = {} if kwargs is None else kwargs
args = [] if args is None else args
- if num_return_vals is None:
- num_return_vals = self._num_return_vals
- if is_direct_call is not None and not is_direct_call:
- raise ValueError("Non-direct call tasks are no longer supported.")
+ if num_returns is None:
+ num_returns = self._num_returns
if max_retries is None:
max_retries = self._max_retries
@@ -216,7 +213,7 @@ def invocation(args, kwargs):
"cannot be executed locally."
object_refs = worker.core_worker.submit_task(
self._language, self._function_descriptor, list_args,
- num_return_vals, resources, max_retries, placement_group.id,
+ num_returns, resources, max_retries, placement_group.id,
placement_group_bundle_index)
if len(object_refs) == 1:
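At call sites the per-invocation override uses the same renamed keyword; assuming the public `.options()` API, a short sketch:

```python
import ray

ray.init()


@ray.remote
def pair():
    return 1, 2


# Per-call override with the renamed keyword (was num_return_vals).
a, b = pair.options(num_returns=2).remote()
print(ray.get(a), ray.get(b))  # 1 2
```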
diff --git a/python/ray/reporter.py b/python/ray/reporter.py
index 3f7295df6536..517f4d8ba77e 100644
--- a/python/ray/reporter.py
+++ b/python/ray/reporter.py
@@ -43,9 +43,10 @@ def GetProfilingStats(self, request, context):
duration = request.duration
profiling_file_path = os.path.join(ray.utils.get_ray_temp_dir(),
f"{pid}_profiling.txt")
+ sudo = "sudo" if ray.utils.get_user() != "root" else ""
process = subprocess.Popen(
- "sudo $(which py-spy) record -o {} -p {} -d {} -f speedscope"
- .format(profiling_file_path, pid, duration),
+ (f"{sudo} $(which py-spy) record -o {profiling_file_path} -p {pid}"
+ f" -d {duration} -f speedscope"),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
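The change only drops the `sudo` prefix when already running as root; for illustration, with hypothetical values:

```python
for user in ("root", "ubuntu"):
    sudo = "sudo" if user != "root" else ""
    cmd = (f"{sudo} $(which py-spy) record -o /tmp/42_profiling.txt"
           f" -p 42 -d 5 -f speedscope").strip()
    print(user, "->", cmd)
```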
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py
index b53f00b10bb3..be87e70b647f 100644
--- a/python/ray/scripts/scripts.py
+++ b/python/ray/scripts/scripts.py
@@ -20,7 +20,6 @@
debug_status, RUN_ENV_TYPES)
import ray.ray_constants as ray_constants
import ray.utils
-from ray.projects.scripts import project_cli, session_cli
from ray.autoscaler.cli_logger import cli_logger
import colorful as cf
@@ -460,10 +459,9 @@ def start(node_ip_address, redis_address, address, redis_port, port,
if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address)
- if redis_address is not None or address is not None:
+ if address is not None:
(redis_address, redis_address_ip,
- redis_address_port) = services.validate_redis_address(
- address, redis_address)
+ redis_address_port) = services.validate_redis_address(address)
try:
resources = json.loads(resources)
@@ -1009,7 +1007,7 @@ def down(cluster_config_file, yes, workers_only, cluster_name,
keep_min_workers)
-@cli.command()
+@cli.command(hidden=True)
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
"--yes",
@@ -1454,36 +1452,6 @@ def timeline(address):
"You can open this with chrome://tracing in the Chrome browser.")
-@cli.command()
-@click.option(
- "--address",
- required=False,
- type=str,
- help="Override the address to connect to.")
-def statistics(address):
- """Get the current metrics protobuf from a Ray cluster (developer tool)."""
- if not address:
- address = services.find_redis_address_or_die()
- logger.info(f"Connecting to Ray instance at {address}.")
- ray.init(address=address)
-
- import grpc
- from ray.core.generated import node_manager_pb2
- from ray.core.generated import node_manager_pb2_grpc
-
- for raylet in ray.nodes():
- raylet_address = "{}:{}".format(raylet["NodeManagerAddress"],
- ray.nodes()[0]["NodeManagerPort"])
- logger.info(f"Querying raylet {raylet_address}")
-
- channel = grpc.insecure_channel(raylet_address)
- stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
- reply = stub.GetNodeStats(
- node_manager_pb2.GetNodeStatsRequest(include_memory_info=False),
- timeout=2.0)
- print(reply)
-
-
@cli.command()
@click.option(
"--address",
@@ -1520,13 +1488,13 @@ def status(address):
print(debug_status())
-@cli.command()
+@cli.command(hidden=True)
@click.option(
"--address",
required=False,
type=str,
help="Override the address to connect to.")
-def globalgc(address):
+def global_gc(address):
"""Trigger Python garbage collection on all cluster workers."""
if not address:
address = services.find_redis_address_or_die()
@@ -1614,13 +1582,10 @@ def add_command_alias(command, name, hidden):
cli.add_command(get_worker_ips)
cli.add_command(microbenchmark)
cli.add_command(stack)
-cli.add_command(statistics)
cli.add_command(status)
cli.add_command(memory)
-cli.add_command(globalgc)
+cli.add_command(global_gc)
cli.add_command(timeline)
-cli.add_command(project_cli)
-cli.add_command(session_cli)
cli.add_command(install_nightly)
try:
diff --git a/python/ray/serialization.py b/python/ray/serialization.py
index faee819713cc..ae51289c5077 100644
--- a/python/ray/serialization.py
+++ b/python/ray/serialization.py
@@ -13,9 +13,9 @@
PlasmaObjectNotAvailable,
RayTaskError,
RayActorError,
- RayCancellationError,
- RayWorkerError,
- UnreconstructableError,
+ TaskCancelledError,
+ WorkerCrashedError,
+ ObjectLostError,
)
from ray._raylet import (
split_buffer,
@@ -265,14 +265,13 @@ def _deserialize_object(self, data, metadata, object_ref):
obj = self._deserialize_msgpack_data(data, metadata)
return RayError.from_bytes(obj)
elif error_type == ErrorType.Value("WORKER_DIED"):
- return RayWorkerError()
+ return WorkerCrashedError()
elif error_type == ErrorType.Value("ACTOR_DIED"):
return RayActorError()
elif error_type == ErrorType.Value("TASK_CANCELLED"):
- return RayCancellationError()
+ return TaskCancelledError()
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
- return UnreconstructableError(
- ray.ObjectRef(object_ref.binary()))
+ return ObjectLostError(ray.ObjectRef(object_ref.binary()))
else:
assert error_type != ErrorType.Value("OBJECT_IN_PLASMA"), \
"Tried to get object that has been promoted to plasma."
diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py
index 3634b11e7562..a7e06a1fd2c3 100644
--- a/python/ray/serve/tests/test_router.py
+++ b/python/ray/serve/tests/test_router.py
@@ -191,7 +191,7 @@ def get_queues(self):
second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
# Neither queries should be available
- with pytest.raises(ray.exceptions.RayTimeoutError):
+ with pytest.raises(ray.exceptions.GetTimeoutError):
ray.get([first_query, second_query], timeout=0.2)
# Let's retrieve the router internal state
diff --git a/python/ray/services.py b/python/ray/services.py
index 4adbb1f85292..46f002840708 100644
--- a/python/ray/services.py
+++ b/python/ray/services.py
@@ -256,35 +256,18 @@ def remaining_processes_alive():
return ray.worker._global_node.remaining_processes_alive()
-def validate_redis_address(address, redis_address):
- """Validates redis address parameter and splits it into host/ip components.
-
- We temporarily support both 'address' and 'redis_address', so both are
- handled here.
+def validate_redis_address(address):
+ """Validates address parameter.
Returns:
redis_address: string containing the full address.
redis_ip: string representing the host portion of the address.
redis_port: integer representing the port portion of the address.
-
- Raises:
- ValueError: if both address and redis_address were specified or the
- address was malformed.
"""
- if redis_address == "auto":
- raise ValueError("auto address resolution not supported for "
- "redis_address parameter. Please use address.")
-
- if address:
- if redis_address:
- raise ValueError(
- "Both address and redis_address specified. Use only address.")
- if address == "auto":
- address = find_redis_address_or_die()
- redis_address = address
-
- redis_address = address_to_ip(redis_address)
+ if address == "auto":
+ address = find_redis_address_or_die()
+ redis_address = address_to_ip(address)
redis_address_parts = redis_address.split(":")
if len(redis_address_parts) != 2:
@@ -1430,6 +1413,7 @@ def start_raylet(redis_address,
os.path.dirname(os.path.abspath(__file__)),
"new_dashboard/agent.py"),
"--redis-address={}".format(redis_address),
+ "--metrics-export-port={}".format(metrics_export_port),
"--node-manager-port={}".format(node_manager_port),
"--object-store-name={}".format(plasma_store_name),
"--raylet-name={}".format(raylet_name),
diff --git a/python/ray/state.py b/python/ray/state.py
index 005ab1a92ce0..05044cce41c7 100644
--- a/python/ray/state.py
+++ b/python/ray/state.py
@@ -48,7 +48,7 @@ def _check_connected(self):
"""
if (self.redis_client is None or self.redis_clients is None
or self.global_state_accessor is None):
- raise ray.exceptions.RayConnectionError(
+ raise ray.exceptions.RaySystemError(
"Ray has not been started yet. You can start Ray with "
"'ray.init()'.")
@@ -844,7 +844,7 @@ def actor_checkpoint_info(self, actor_id):
def jobs():
- """Get a list of the jobs in the cluster.
+ """Get a list of the jobs in the cluster (for debugging only).
Returns:
Information from the job table, namely a list of dicts with keys:
@@ -858,7 +858,7 @@ def jobs():
def nodes():
- """Get a list of the nodes in the cluster.
+ """Get a list of the nodes in the cluster (for debugging only).
Returns:
Information about the Ray clients in the cluster.
@@ -899,7 +899,7 @@ def node_ids():
def actors(actor_id=None):
- """Fetch and parse the actor info for one or more actor IDs.
+ """Fetch actor info for one or more actor IDs (for debugging only).
Args:
actor_id: A hex string of the actor ID to fetch information about. If
diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py
index 99cbd31c49ec..8bb01bf226af 100644
--- a/python/ray/test_utils.py
+++ b/python/ray/test_utils.py
@@ -428,3 +428,11 @@ def get_error_message(pub_sub, num, error_type=None, timeout=5):
time.sleep(0.01)
return msgs
+
+
+def format_web_url(url):
+ """Format web url."""
+ url = url.replace("localhost", "http://127.0.0.1")
+ if not url.startswith("http://"):
+ return "http://" + url
+ return url
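
Usage of the new helper, following directly from the implementation above:

    from ray.test_utils import format_web_url

    assert format_web_url("localhost:8265") == "http://127.0.0.1:8265"
    assert format_web_url("1.2.3.4:8265") == "http://1.2.3.4:8265"
    assert format_web_url("http://1.2.3.4:8265") == "http://1.2.3.4:8265"
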
diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD
index a597e2e446a7..dee43ed66500 100644
--- a/python/ray/tests/BUILD
+++ b/python/ray/tests/BUILD
@@ -31,6 +31,7 @@ py_test_module_list(
"test_global_gc.py",
"test_iter.py",
"test_joblib.py",
+ "test_resource_demand_scheduler.py",
],
size = "medium",
extra_srcs = SRCS,
@@ -52,9 +53,9 @@ py_test_module_list(
"test_output.py",
"test_placement_group.py",
"test_reconstruction.py",
- "test_resource_demand_scheduler.py",
"test_reference_counting_2.py",
"test_reference_counting.py",
+ "test_serialization.py",
"test_stress.py",
"test_stress_sharded.py",
"test_unreconstructable_errors.py",
@@ -88,9 +89,7 @@ py_test_module_list(
"test_multi_tenancy.py",
"test_node_manager.py",
"test_numba.py",
- "test_projects.py",
"test_ray_init.py",
- "test_serialization.py",
"test_tempfile.py",
"test_webui.py",
],
@@ -102,13 +101,12 @@ py_test_module_list(
py_test_module_list(
files = [
- "test_stress_faiure.py",
+ "test_stress_failure.py",
"test_failure.py"
],
size = "large",
extra_srcs = SRCS,
- # TODO(ekl) enable again once we support direct call reconstruction
- tags = ["exclusive", "manual"],
+ tags = ["exclusive"],
deps = ["//:ray_lib"],
)
diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py
index d5ff824ba797..6b7cc662b9af 100644
--- a/python/ray/tests/test_actor.py
+++ b/python/ray/tests/test_actor.py
@@ -646,15 +646,15 @@ class Foo:
def method0(self):
return 1
- @ray.method(num_return_vals=1)
+ @ray.method(num_returns=1)
def method1(self):
return 1
- @ray.method(num_return_vals=2)
+ @ray.method(num_returns=2)
def method2(self):
return 1, 2
- @ray.method(num_return_vals=3)
+ @ray.method(num_returns=3)
def method3(self):
return 1, 2, 3
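
The num_return_vals to num_returns rename applies to actor methods as well as plain tasks; a minimal sketch:

    import ray

    ray.init()

    @ray.remote
    class Pair:
        @ray.method(num_returns=2)  # formerly num_return_vals=2
        def values(self):
            return 1, 2

    a = Pair.remote()
    x_ref, y_ref = a.values.remote()
    assert ray.get([x_ref, y_ref]) == [1, 2]
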
diff --git a/python/ray/tests/test_actor_failures.py b/python/ray/tests/test_actor_failures.py
index 7985050ad599..7ddd4b26b732 100644
--- a/python/ray/tests/test_actor_failures.py
+++ b/python/ray/tests/test_actor_failures.py
@@ -30,7 +30,7 @@ def ray_init_with_task_retry_delay():
@pytest.mark.parametrize(
"ray_start_regular", [{
"object_store_memory": 150 * 1024 * 1024,
- "lru_evict": True,
+ "_lru_evict": True,
}],
indirect=True)
def test_actor_eviction(ray_start_regular):
@@ -63,7 +63,7 @@ def create_object(self, size):
val = ray.get(obj)
assert isinstance(val, np.ndarray), val
num_success += 1
- except ray.exceptions.UnreconstructableError:
+ except ray.exceptions.ObjectLostError:
num_evicted += 1
# Some objects should have been evicted, and some should still be in the
# object store.
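
A sketch combining the now-private _lru_evict flag with the renamed exception; the store and object sizes are illustrative:

    import numpy as np
    import ray

    ray.init(object_store_memory=150 * 1024 * 1024, _lru_evict=True)
    ref = ray.put(np.zeros(20 * 1024 * 1024, dtype=np.uint8))
    # Filling the store may evict `ref`; a later get() then raises the
    # renamed ObjectLostError (formerly UnreconstructableError).
    for _ in range(20):
        ray.put(np.zeros(20 * 1024 * 1024, dtype=np.uint8))
    try:
        ray.get(ref)
    except ray.exceptions.ObjectLostError:
        pass
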
diff --git a/python/ray/tests/test_actor_resources.py b/python/ray/tests/test_actor_resources.py
index f7f75bcb45ca..66fb5fe09d1a 100644
--- a/python/ray/tests/test_actor_resources.py
+++ b/python/ray/tests/test_actor_resources.py
@@ -95,10 +95,10 @@ def test_actor_gpus(ray_start_cluster):
@ray.remote(num_gpus=1)
class Actor1:
def __init__(self):
- self.gpu_ids = ray.get_gpu_ids(as_str=True)
+ self.gpu_ids = ray.get_gpu_ids()
def get_location_and_ids(self):
- assert ray.get_gpu_ids(as_str=True) == self.gpu_ids
+ assert ray.get_gpu_ids() == self.gpu_ids
return (ray.worker.global_worker.node.unique_id,
tuple(self.gpu_ids))
diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py
index 9d4c328995c8..08dd168fa505 100644
--- a/python/ray/tests/test_advanced.py
+++ b/python/ray/tests/test_advanced.py
@@ -49,27 +49,6 @@ def sample_big(self):
ray.get(big_id)
-def test_wait_iterables(ray_start_regular):
- @ray.remote
- def f(delay):
- time.sleep(delay)
- return 1
-
- object_refs = (f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5))
- ready_ids, remaining_ids = ray.experimental.wait(object_refs)
- assert len(ready_ids) == 1
- assert len(remaining_ids) == 3
-
- object_refs = np.array(
- [f.remote(1.0),
- f.remote(0.5),
- f.remote(0.5),
- f.remote(0.5)])
- ready_ids, remaining_ids = ray.experimental.wait(object_refs)
- assert len(ready_ids) == 1
- assert len(remaining_ids) == 3
-
-
def test_multiple_waits_and_gets(shutdown_only):
# It is important to use three workers here, so that the three tasks
# launched in this experiment can run at the same time.
diff --git a/python/ray/tests/test_advanced_2.py b/python/ray/tests/test_advanced_2.py
index 19f400534202..da20a974d65c 100644
--- a/python/ray/tests/test_advanced_2.py
+++ b/python/ray/tests/test_advanced_2.py
@@ -633,25 +633,6 @@ def save_gpu_ids_shutdown_only():
del os.environ["CUDA_VISIBLE_DEVICES"]
-@pytest.mark.parametrize("as_str", [False, True])
-def test_gpu_ids_as_str(save_gpu_ids_shutdown_only, as_str):
- allowed_gpu_ids = [4, 5, 6]
- os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
- str(i) for i in allowed_gpu_ids)
- ray.init()
-
- @ray.remote
- def get_gpu_ids(as_str):
- gpu_ids = ray.get_gpu_ids(as_str)
- for gpu_id in gpu_ids:
- if as_str:
- assert isinstance(gpu_id, str)
- else:
- assert isinstance(gpu_id, int)
-
- ray.get([get_gpu_ids.remote(as_str) for _ in range(10)])
-
-
def test_specific_gpus(save_gpu_ids_shutdown_only):
allowed_gpu_ids = [4, 5, 6]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py
index 1ec1d78218ff..ffb6e0ceab5d 100644
--- a/python/ray/tests/test_advanced_3.py
+++ b/python/ray/tests/test_advanced_3.py
@@ -272,23 +272,6 @@ def f():
worker_ids = set(ray.get([f.remote() for _ in range(10)]))
-def test_specific_job_id():
- dummy_driver_id = ray.JobID.from_int(1)
- ray.init(num_cpus=1, job_id=dummy_driver_id)
-
- # in driver
- assert dummy_driver_id == ray.worker.global_worker.current_job_id
-
- # in worker
- @ray.remote
- def f():
- return ray.worker.global_worker.current_job_id
-
- assert dummy_driver_id == ray.get(f.remote())
-
- ray.shutdown()
-
-
def test_object_ref_properties():
id_bytes = b"00112233445566778899"
object_ref = ray.ObjectRef(id_bytes)
@@ -397,23 +380,6 @@ def unique_name_3():
"'ray stack'")
-def test_socket_dir_not_existing(shutdown_only):
- if sys.platform != "win32":
- random_name = ray.ObjectRef.from_random().hex()
- temp_raylet_socket_dir = os.path.join(ray.utils.get_ray_temp_dir(),
- "tests", random_name)
- temp_raylet_socket_name = os.path.join(temp_raylet_socket_dir,
- "raylet_socket")
- ray.init(num_cpus=2, raylet_socket_name=temp_raylet_socket_name)
-
- @ray.remote
- def foo(x):
- time.sleep(1)
- return 2 * x
-
- ray.get([foo.remote(i) for i in range(2)])
-
-
def test_raylet_is_robust_to_random_messages(ray_start_regular):
node_manager_address = None
node_manager_port = None
@@ -465,13 +431,6 @@ def test_put_pins_object(ray_start_object_store_memory):
assert not ray.worker.global_worker.core_worker.object_exists(
ray.ObjectRef(x_binary))
- # weakref put
- y_id = ray.put(obj, weakref=True)
- for _ in range(10):
- ray.put(np.zeros(10 * 1024 * 1024))
- with pytest.raises(ray.exceptions.UnreconstructableError):
- ray.get(y_id)
-
def test_decorated_function(ray_start_regular):
def function_invocation_decorator(f):
@@ -656,7 +615,8 @@ def test_ray_address_environment_variable(ray_start_cluster):
def test_ray_resources_environment_variable(ray_start_cluster):
address = ray_start_cluster.address
- os.environ["RAY_OVERRIDE_RESOURCES"] = "{\"custom1\":1, \"custom2\":2}"
+ os.environ[
+ "RAY_OVERRIDE_RESOURCES"] = "{\"custom1\":1, \"custom2\":2, \"CPU\":3}"
ray.init(address=address, resources={"custom1": 3, "custom3": 3})
cluster_resources = ray.cluster_resources()
@@ -664,6 +624,7 @@ def test_ray_resources_environment_variable(ray_start_cluster):
assert cluster_resources["custom1"] == 1
assert cluster_resources["custom2"] == 2
assert cluster_resources["custom3"] == 3
+ assert cluster_resources["CPU"] == 3
def test_gpu_info_parsing():
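
The added assertion reflects that RAY_OVERRIDE_RESOURCES can now override built-in resources such as CPU, not only custom ones; a minimal sketch of the behavior the test pins down:

    import os
    import ray

    os.environ["RAY_OVERRIDE_RESOURCES"] = '{"custom1": 1, "CPU": 3}'
    ray.init(resources={"custom1": 3})
    # The environment variable wins over explicit arguments and defaults.
    assert ray.cluster_resources()["custom1"] == 1
    assert ray.cluster_resources()["CPU"] == 3
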
diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py
index c4b564f72622..203e484fc467 100644
--- a/python/ray/tests/test_basic.py
+++ b/python/ray/tests/test_basic.py
@@ -1,11 +1,9 @@
# coding: utf-8
-import io
import logging
import os
import pickle
import sys
import time
-import weakref
import numpy as np
import pytest
@@ -64,12 +62,12 @@ def f(n):
def g():
return ray.get_gpu_ids()
- assert f._remote([0], num_return_vals=0) is None
- id1 = f._remote(args=[1], num_return_vals=1)
+ assert f._remote([0], num_returns=0) is None
+ id1 = f._remote(args=[1], num_returns=1)
assert ray.get(id1) == [0]
- id1, id2 = f._remote(args=[2], num_return_vals=2)
+ id1, id2 = f._remote(args=[2], num_returns=2)
assert ray.get([id1, id2]) == [0, 1]
- id1, id2, id3 = f._remote(args=[3], num_return_vals=3)
+ id1, id2, id3 = f._remote(args=[3], num_returns=3)
assert ray.get([id1, id2, id3]) == [0, 1, 2]
assert ray.get(
g._remote(args=[], num_cpus=1, num_gpus=1,
@@ -107,10 +105,64 @@ def method(self):
ray.get(a2.method._remote())
id1, id2, id3, id4 = a.method._remote(
- args=["test"], kwargs={"b": 2}, num_return_vals=4)
+ args=["test"], kwargs={"b": 2}, num_returns=4)
assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
+def test_invalid_arguments(shutdown_only):
+ ray.init(num_cpus=2)
+
+ for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]:
+ with pytest.raises(
+ ValueError,
+ match="The keyword 'num_returns' only accepts 0 or a"
+ " positive integer"):
+
+ @ray.remote(num_returns=opt)
+ def g1():
+ return 1
+
+ for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]:
+ with pytest.raises(
+ ValueError,
+ match="The keyword 'max_retries' only accepts 0, -1 or a"
+ " positive integer"):
+
+ @ray.remote(max_retries=opt)
+ def g2():
+ return 1
+
+ for opt in [np.random.randint(-100, -1), np.random.uniform(0, 1)]:
+ with pytest.raises(
+ ValueError,
+ match="The keyword 'max_calls' only accepts 0 or a positive"
+ " integer"):
+
+ @ray.remote(max_calls=opt)
+ def g3():
+ return 1
+
+ for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]:
+ with pytest.raises(
+ ValueError,
+ match="The keyword 'max_restarts' only accepts -1, 0 or a"
+ " positive integer"):
+
+ @ray.remote(max_restarts=opt)
+ class A1:
+ x = 1
+
+ for opt in [np.random.randint(-100, -2), np.random.uniform(0, 1)]:
+ with pytest.raises(
+ ValueError,
+ match="The keyword 'max_task_retries' only accepts -1, 0 or a"
+ " positive integer"):
+
+ @ray.remote(max_task_retries=opt)
+ class A2:
+ x = 1
+
+
def test_many_fractional_resources(shutdown_only):
ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2})
@@ -362,51 +414,6 @@ class ClassA:
ray.put(obj)
-def test_reducer_override_no_reference_cycle(ray_start_shared_local_modes):
- # bpo-39492: reducer_override used to induce a spurious reference cycle
- # inside the Pickler object, that could prevent all serialized objects
- # from being garbage-collected without explicity invoking gc.collect.
-
- # test a dynamic function
- def f():
- return 4669201609102990671853203821578
-
- wr = weakref.ref(f)
-
- bio = io.BytesIO()
- from ray.cloudpickle import CloudPickler, loads, dumps
- p = CloudPickler(bio, protocol=5)
- p.dump(f)
- new_f = loads(bio.getvalue())
- assert new_f() == 4669201609102990671853203821578
-
- del p
- del f
-
- assert wr() is None
-
- # test a dynamic class
- class ShortlivedObject:
- def __del__(self):
- print("Went out of scope!")
-
- obj = ShortlivedObject()
- new_obj = weakref.ref(obj)
-
- dumps(obj)
- del obj
- assert new_obj() is None
-
-
-def test_deserialized_from_buffer_immutable(ray_start_shared_local_modes):
- x = np.full((2, 2), 1.)
- o = ray.put(x)
- y = ray.get(o)
- with pytest.raises(
- ValueError, match="assignment destination is read-only"):
- y[0, 0] = 9.
-
-
def test_passing_arguments_by_value_out_of_the_box(
ray_start_shared_local_modes):
@ray.remote
@@ -456,36 +463,6 @@ def method(self):
ray.put(f)
-def test_custom_serializers(ray_start_shared_local_modes):
- class Foo:
- def __init__(self):
- self.x = 3
-
- def custom_serializer(obj):
- return 3, "string1", type(obj).__name__
-
- def custom_deserializer(serialized_obj):
- return serialized_obj, "string2"
-
- ray.register_custom_serializer(
- Foo, serializer=custom_serializer, deserializer=custom_deserializer)
-
- assert ray.get(ray.put(Foo())) == ((3, "string1", Foo.__name__), "string2")
-
- class Bar:
- def __init__(self):
- self.x = 3
-
- ray.register_custom_serializer(
- Bar, serializer=custom_serializer, deserializer=custom_deserializer)
-
- @ray.remote
- def f():
- return Bar()
-
- assert ray.get(f.remote()) == ((3, "string1", Bar.__name__), "string2")
-
-
def test_keyword_args(ray_start_shared_local_modes):
@ray.remote
def keyword_fct1(a, b="hello"):
diff --git a/python/ray/tests/test_basic_2.py b/python/ray/tests/test_basic_2.py
index e3804b54182d..216481b5655c 100644
--- a/python/ray/tests/test_basic_2.py
+++ b/python/ray/tests/test_basic_2.py
@@ -12,7 +12,7 @@
import ray
import ray.cluster_utils
import ray.test_utils
-from ray.exceptions import RayTimeoutError
+from ray.exceptions import GetTimeoutError
logger = logging.getLogger(__name__)
@@ -351,7 +351,7 @@ def test_system_config_when_connecting(ray_start_cluster):
ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8))
# This would not raise an exception if object pinning was enabled.
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
ray.get(obj_ref)
@@ -366,25 +366,6 @@ def test_get_multiple(ray_start_regular_shared):
assert results == indices
-def test_get_multiple_experimental(ray_start_regular_shared):
- object_refs = [ray.put(i) for i in range(10)]
-
- object_refs_tuple = tuple(object_refs)
- assert ray.experimental.get(object_refs_tuple) == list(range(10))
-
- object_refs_nparray = np.array(object_refs)
- assert ray.experimental.get(object_refs_nparray) == list(range(10))
-
-
-def test_get_dict(ray_start_regular_shared):
- d = {str(i): ray.put(i) for i in range(5)}
- for i in range(5, 10):
- d[str(i)] = i
- result = ray.experimental.get(d)
- expected = {str(i): i for i in range(10)}
- assert result == expected
-
-
def test_get_with_timeout(ray_start_regular_shared):
signal = ray.test_utils.SignalActor.remote()
@@ -396,7 +377,7 @@ def test_get_with_timeout(ray_start_regular_shared):
# Check that get() raises a TimeoutError after the timeout if the object
# is not ready yet.
result_id = signal.wait.remote()
- with pytest.raises(RayTimeoutError):
+ with pytest.raises(GetTimeoutError):
ray.get(result_id, timeout=0.1)
# Check that a subsequent get() returns early.
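
A sketch of the renamed timeout exception; the sleep and timeout values are illustrative:

    import time
    import ray

    @ray.remote
    def slow():
        time.sleep(10)

    ray.init()
    try:
        ray.get(slow.remote(), timeout=0.1)
    except ray.exceptions.GetTimeoutError:  # formerly RayTimeoutError
        pass
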
diff --git a/python/ray/tests/test_cancel.py b/python/ray/tests/test_cancel.py
index 17fc14967929..b4e4aa439bf6 100644
--- a/python/ray/tests/test_cancel.py
+++ b/python/ray/tests/test_cancel.py
@@ -5,18 +5,18 @@
import pytest
import ray
-from ray.exceptions import RayCancellationError, RayTaskError, \
- RayTimeoutError, RayWorkerError, \
- UnreconstructableError
+from ray.exceptions import TaskCancelledError, RayTaskError, \
+ GetTimeoutError, WorkerCrashedError, \
+ ObjectLostError
from ray.test_utils import SignalActor
def valid_exceptions(use_force):
if use_force:
- return (RayTaskError, RayCancellationError, RayWorkerError,
- UnreconstructableError)
+ return (RayTaskError, TaskCancelledError, WorkerCrashedError,
+ ObjectLostError)
else:
- return (RayTaskError, RayCancellationError)
+ return (RayTaskError, TaskCancelledError)
@pytest.mark.parametrize("use_force", [True, False])
@@ -33,7 +33,7 @@ def wait_for(t):
obj4 = wait_for.remote([obj3])
assert len(ray.wait([obj1], timeout=.1)[0]) == 0
- ray.cancel(obj1, use_force)
+ ray.cancel(obj1, force=use_force)
for ob in [obj1, obj2, obj3, obj4]:
with pytest.raises(valid_exceptions(use_force)):
ray.get(ob)
@@ -45,15 +45,15 @@ def wait_for(t):
obj4 = wait_for.remote([obj3])
assert len(ray.wait([obj3], timeout=.1)[0]) == 0
- ray.cancel(obj3, use_force)
+ ray.cancel(obj3, force=use_force)
for ob in [obj3, obj4]:
with pytest.raises(valid_exceptions(use_force)):
ray.get(ob)
- with pytest.raises(RayTimeoutError):
+ with pytest.raises(GetTimeoutError):
ray.get(obj1, timeout=.1)
- with pytest.raises(RayTimeoutError):
+ with pytest.raises(GetTimeoutError):
ray.get(obj2, timeout=.1)
signaler2.send.remote()
@@ -74,7 +74,7 @@ def wait_for(t):
deps.append(wait_for.remote([head]))
assert len(ray.wait([head], timeout=.1)[0]) == 0
- ray.cancel(head, use_force)
+ ray.cancel(head, force=use_force)
for d in deps:
with pytest.raises(valid_exceptions(use_force)):
ray.get(d)
@@ -86,7 +86,7 @@ def wait_for(t):
deps2.append(wait_for.remote([head]))
for d in deps2:
- ray.cancel(d, use_force)
+ ray.cancel(d, force=use_force)
for d in deps2:
with pytest.raises(valid_exceptions(use_force)):
@@ -111,11 +111,11 @@ def wait_for(t):
indep = wait_for.remote([signaler.wait.remote()])
assert len(ray.wait([obj3], timeout=.1)[0]) == 0
- ray.cancel(obj3, use_force)
+ ray.cancel(obj3, force=use_force)
with pytest.raises(valid_exceptions(use_force)):
ray.get(obj3)
- ray.cancel(obj1, use_force)
+ ray.cancel(obj1, force=use_force)
for d in [obj1, obj2]:
with pytest.raises(valid_exceptions(use_force)):
@@ -145,12 +145,12 @@ def combine(a, b):
assert len(ray.wait([a, b, a2, combo], timeout=1)[0]) == 0
- ray.cancel(a, use_force)
+ ray.cancel(a, force=use_force)
with pytest.raises(valid_exceptions(use_force)):
- ray.get(a, 10)
+ ray.get(a, timeout=10)
with pytest.raises(valid_exceptions(use_force)):
- ray.get(a2, 10)
+ ray.get(a2, timeout=10)
signaler.send.remote()
@@ -177,10 +177,10 @@ def infinite_sleep(y):
cancelled = set()
for t in tasks:
if random.random() > 0.5:
- ray.cancel(t, use_force)
+ ray.cancel(t, force=use_force)
cancelled.add(t)
- ray.cancel(first, use_force)
+ ray.cancel(first, force=use_force)
cancelled.add(first)
for done in cancelled:
@@ -188,7 +188,7 @@ def infinite_sleep(y):
ray.get(done)
for indx, t in enumerate(tasks):
if sleep_or_no[indx]:
- ray.cancel(t, use_force)
+ ray.cancel(t, force=use_force)
cancelled.add(t)
if t in cancelled:
with pytest.raises(valid_exceptions(use_force)):
@@ -209,7 +209,7 @@ def fast(y):
ids = list()
for _ in range(100):
x = fast.remote("a")
- ray.cancel(x, use_force)
+ ray.cancel(x, force=use_force)
ids.append(x)
@ray.remote
@@ -223,7 +223,7 @@ def wait_for(y):
for idx in range(100, 5100):
if random.random() > 0.95:
- ray.cancel(ids[idx], use_force)
+ ray.cancel(ids[idx], force=use_force)
signaler.send.remote()
for obj_ref in ids:
try:
@@ -249,13 +249,13 @@ def remote_wait(sg):
outer = remote_wait.remote([sig])
inner = ray.get(outer)[0]
- with pytest.raises(RayTimeoutError):
- ray.get(inner, 1)
+ with pytest.raises(GetTimeoutError):
+ ray.get(inner, timeout=1)
- ray.cancel(inner, use_force)
+ ray.cancel(inner, force=use_force)
with pytest.raises(valid_exceptions(use_force)):
- ray.get(inner, 10)
+ ray.get(inner, timeout=10)
if __name__ == "__main__":
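
These tests also switch ray.cancel to the keyword form; a minimal sketch combining it with the renamed exceptions (the spinning task is illustrative):

    import time
    import ray

    @ray.remote
    def spin():
        while True:
            time.sleep(1)

    ray.init()
    ref = spin.remote()
    ray.cancel(ref, force=True)  # formerly passed positionally as use_force
    try:
        ray.get(ref, timeout=10)
    except (ray.exceptions.RayTaskError, ray.exceptions.TaskCancelledError,
            ray.exceptions.WorkerCrashedError,
            ray.exceptions.ObjectLostError):
        pass  # the valid outcomes for a force-cancelled task, per the test
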
diff --git a/python/ray/tests/test_cli_patterns/test_ray_attach.txt b/python/ray/tests/test_cli_patterns/test_ray_attach.txt
index 791a2766161a..cc648a61f64d 100644
--- a/python/ray/tests/test_cli_patterns/test_ray_attach.txt
+++ b/python/ray/tests/test_cli_patterns/test_ray_attach.txt
@@ -1,3 +1,3 @@
-Bootstraping AWS config
+Bootstrapping AWS config
Fetched IP: .+
ubuntu@ip-.+:~\$ exit
diff --git a/python/ray/tests/test_cli_patterns/test_ray_exec.txt b/python/ray/tests/test_cli_patterns/test_ray_exec.txt
index 62179d79c3dd..5496ac18714d 100644
--- a/python/ray/tests/test_cli_patterns/test_ray_exec.txt
+++ b/python/ray/tests/test_cli_patterns/test_ray_exec.txt
@@ -1,3 +1,3 @@
-Bootstraping AWS config
+Bootstrapping AWS config
Fetched IP: .+
This is a test!
diff --git a/python/ray/tests/test_cli_patterns/test_ray_submit.txt b/python/ray/tests/test_cli_patterns/test_ray_submit.txt
index a7e4be5f6d6c..ef94f13c076b 100644
--- a/python/ray/tests/test_cli_patterns/test_ray_submit.txt
+++ b/python/ray/tests/test_cli_patterns/test_ray_submit.txt
@@ -1,5 +1,5 @@
-Bootstraping AWS config
+Bootstrapping AWS config
Fetched IP: .+
-Bootstraping AWS config
+Bootstrapping AWS config
Fetched IP: .+
This is a test!
diff --git a/python/ray/tests/test_cli_patterns/test_ray_up.txt b/python/ray/tests/test_cli_patterns/test_ray_up.txt
index 31d62671fd33..5f41dfd048af 100644
--- a/python/ray/tests/test_cli_patterns/test_ray_up.txt
+++ b/python/ray/tests/test_cli_patterns/test_ray_up.txt
@@ -6,7 +6,7 @@ Cluster configuration valid
Cluster: test-cli
-Bootstraping AWS config
+Bootstrapping AWS config
AWS config
IAM Profile: .+ \[default\]
EC2 Key pair \(head & workers\): .+ \[default\]
diff --git a/python/ray/tests/test_command_runner.py b/python/ray/tests/test_command_runner.py
index c540a6989412..164ee32aa96a 100644
--- a/python/ray/tests/test_command_runner.py
+++ b/python/ray/tests/test_command_runner.py
@@ -14,7 +14,7 @@
def test_environment_variable_encoder_strings():
env_vars = {"var1": "quote between this \" and this", "var2": "123"}
res = _with_environment_variables("echo hello", env_vars)
- expected = """export var1='quote between this " and this';export var2=123;echo hello""" # noqa: E501
+ expected = """export var1='"quote between this \\" and this"';export var2='"123"';echo hello""" # noqa: E501
assert res == expected
@@ -22,7 +22,7 @@ def test_environment_variable_encoder_dict():
env_vars = {"value1": "string1", "value2": {"a": "b", "c": 2}}
res = _with_environment_variables("echo hello", env_vars)
- expected = """export value1=string1;export value2='{a: b,c: 2}';echo hello""" # noqa: E501
+ expected = """export value1='"string1"';export value2='{"a":"b","c":2}';echo hello""" # noqa: E501
assert res == expected
@@ -84,7 +84,7 @@ def test_ssh_command_runner():
"--login",
"-c",
"-i",
- """'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && export var1='"'"'quote between this " and this'"'"';export var2=123;echo helloo'""" # noqa: E501
+ """'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && export var1='"'"'"quote between this \\" and this"'"'"';export var2='"'"'"123"'"'"';echo helloo'""" # noqa: E501
]
# Much easier to debug this loop than the function call.
@@ -122,7 +122,7 @@ def test_docker_command_runner():
# This string is insane because there are an absurd number of embedded
# quotes. While this is a ridiculous string, the escape behavior is
# important and somewhat difficult to get right for environment variables.
- cmd = """'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && docker exec -it container /bin/bash -c '"'"'bash --login -c -i '"'"'"'"'"'"'"'"'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && export var1='"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'quote between this " and this'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"';export var2=123;echo hello'"'"'"'"'"'"'"'"''"'"' '""" # noqa: E501
+ cmd = """'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && docker exec -it container /bin/bash -c '"'"'bash --login -c -i '"'"'"'"'"'"'"'"'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && export var1='"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"quote between this \\" and this"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"';export var2='"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"123"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"';echo hello'"'"'"'"'"'"'"'"''"'"' '""" # noqa: E501
expected = [
"ssh", "-tt", "-i", "8265.pem", "-o", "StrictHostKeyChecking=no", "-o",
@@ -135,6 +135,8 @@ def test_docker_command_runner():
]
# Much easier to debug this loop than the function call.
for x, y in zip(process_runner.calls[0], expected):
+ print(f"expeted:\t{y}")
+ print(f"actual: \t{x}")
assert x == y
process_runner.assert_has_call("1.2.3.4", exact=expected)
diff --git a/python/ray/tests/test_component_failures_2.py b/python/ray/tests/test_component_failures_2.py
index 4a056ba798de..5bfd159cc3c6 100644
--- a/python/ray/tests/test_component_failures_2.py
+++ b/python/ray/tests/test_component_failures_2.py
@@ -70,7 +70,8 @@ def f(x):
for object_ref in object_refs:
try:
ray.get(object_ref)
- except (ray.exceptions.RayTaskError, ray.exceptions.RayWorkerError):
+ except (ray.exceptions.RayTaskError,
+ ray.exceptions.WorkerCrashedError):
pass
diff --git a/python/ray/tests/test_cross_language.py b/python/ray/tests/test_cross_language.py
index cf655eae7eb8..9ba24a980628 100644
--- a/python/ray/tests/test_cross_language.py
+++ b/python/ray/tests/test_cross_language.py
@@ -6,7 +6,7 @@
def test_cross_language_raise_kwargs(shutdown_only):
- ray.init(load_code_from_local=True, include_java=True)
+ ray.init(_load_code_from_local=True, _include_java=True)
with pytest.raises(Exception, match="kwargs"):
ray.java_function("a", "b").remote(x="arg1")
@@ -16,7 +16,7 @@ def test_cross_language_raise_kwargs(shutdown_only):
def test_cross_language_raise_exception(shutdown_only):
- ray.init(load_code_from_local=True, include_java=True)
+ ray.init(_load_code_from_local=True, _include_java=True)
class PythonObject(object):
pass
diff --git a/python/ray/tests/test_error_ray_not_initialized.py b/python/ray/tests/test_error_ray_not_initialized.py
index fae3eb25c261..dd006b23919f 100644
--- a/python/ray/tests/test_error_ray_not_initialized.py
+++ b/python/ray/tests/test_error_ray_not_initialized.py
@@ -23,7 +23,7 @@ class Foo:
lambda: ray.get_actor("name"),
ray.get_gpu_ids,
ray.get_resource_ids,
- ray.get_webui_url,
+ ray.get_dashboard_url,
ray.jobs,
lambda: ray.kill(None), # Not valid API usage.
ray.nodes,
@@ -36,7 +36,7 @@ def test_exceptions_raised():
for api_method in api_methods:
print(api_method)
with pytest.raises(
- ray.exceptions.RayConnectionError,
+ ray.exceptions.RaySystemError,
match="Ray has not been started yet."):
api_method()
diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py
index a5256175c3db..b209b7fc6cab 100644
--- a/python/ray/tests/test_failure.py
+++ b/python/ray/tests/test_failure.py
@@ -30,7 +30,7 @@ def throw_exception_fct1():
def throw_exception_fct2():
raise Exception("Test function 2 intentionally failed.")
- @ray.remote(num_return_vals=3)
+ @ray.remote(num_returns=3)
def throw_exception_fct3(x):
raise Exception("Test function 3 intentionally failed.")
@@ -362,7 +362,7 @@ def test_worker_dying(ray_start_regular, error_pubsub):
def f():
eval("exit()")
- with pytest.raises(ray.exceptions.RayWorkerError):
+ with pytest.raises(ray.exceptions.WorkerCrashedError):
ray.get(f.remote())
errors = get_error_message(p, 1, ray_constants.WORKER_DIED_PUSH_ERROR)
@@ -901,7 +901,7 @@ def sleep_to_kill_raylet():
thread = threading.Thread(target=sleep_to_kill_raylet)
thread.start()
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
ray.get(object_ref)
thread.join()
@@ -1000,7 +1000,7 @@ def test_fill_object_store_lru_fallback(shutdown_only):
ray.init(
num_cpus=2,
object_store_memory=10**8,
- lru_evict=True,
+ _lru_evict=True,
_system_config=config)
@ray.remote
@@ -1062,7 +1062,7 @@ def large_object():
# Evict the object.
ray.internal.free([obj])
# ray.get throws an exception.
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
ray.get(obj)
@ray.remote
diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py
index 4fbd978cb068..836749049a55 100644
--- a/python/ray/tests/test_gcs_fault_tolerance.py
+++ b/python/ray/tests/test_gcs_fault_tolerance.py
@@ -56,7 +56,7 @@ def test_gcs_server_restart_during_actor_creation(ray_start_regular):
ray.worker._global_node.kill_gcs_server()
ray.worker._global_node.start_gcs_server()
- ready, unready = ray.wait(ids, 100, 240)
+ ready, unready = ray.wait(ids, num_returns=100, timeout=240)
print("Ready objects is {}.".format(ready))
print("Unready objects is {}.".format(unready))
assert len(unready) == 0
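
Spelling out the keywords removes the ambiguity of the old positional call; a minimal sketch of the same pattern:

    import ray

    ray.init()
    refs = [ray.put(i) for i in range(100)]
    ready, unready = ray.wait(refs, num_returns=100, timeout=240)
    assert len(unready) == 0
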
diff --git a/python/ray/tests/test_memory_limits.py b/python/ray/tests/test_memory_limits.py
index 3167f1354d03..fd7f487fb41c 100644
--- a/python/ray/tests/test_memory_limits.py
+++ b/python/ray/tests/test_memory_limits.py
@@ -5,7 +5,7 @@
MB = 1024 * 1024
-OBJECT_EVICTED = ray.exceptions.UnreconstructableError
+OBJECT_EVICTED = ray.exceptions.ObjectLostError
OBJECT_TOO_LARGE = ray.exceptions.ObjectStoreFullError
@@ -47,8 +47,9 @@ def testQuotaTooLarge(self):
def testTooLargeAllocation(self):
try:
- ray.init(num_cpus=1, driver_object_store_memory=100 * MB)
- ray.put(np.zeros(50 * MB, dtype=np.uint8), weakref=True)
+ ray.init(num_cpus=1, _driver_object_store_memory=100 * MB)
+ ray.worker.global_worker.put_object(
+ np.zeros(50 * MB, dtype=np.uint8), pin_object=False)
self.assertRaises(
OBJECT_TOO_LARGE,
lambda: ray.put(np.zeros(200 * MB, dtype=np.uint8)))
@@ -61,9 +62,9 @@ def _run(self, driver_quota, a_quota, b_quota):
ray.init(
num_cpus=1,
object_store_memory=300 * MB,
- driver_object_store_memory=driver_quota)
+ _driver_object_store_memory=driver_quota)
obj = np.ones(200 * 1024, dtype=np.uint8)
- z = ray.put(obj, weakref=True)
+ z = ray.worker.global_worker.put_object(obj, pin_object=False)
a = LightActor._remote(object_store_memory=a_quota)
b = GreedyActor._remote(object_store_memory=b_quota)
for _ in range(5):
diff --git a/python/ray/tests/test_memory_scheduling.py b/python/ray/tests/test_memory_scheduling.py
index 166313c34f9e..fa642a4b754a 100644
--- a/python/ray/tests/test_memory_scheduling.py
+++ b/python/ray/tests/test_memory_scheduling.py
@@ -34,7 +34,7 @@ def train_oom(config, reporter):
class TestMemoryScheduling(unittest.TestCase):
def testMemoryRequest(self):
try:
- ray.init(num_cpus=1, memory=200 * MB)
+ ray.init(num_cpus=1, _memory=200 * MB)
# fits first 2
a = Actor.remote()
b = Actor.remote()
diff --git a/python/ray/tests/test_metrics.py b/python/ray/tests/test_metrics.py
index 6bdaa334fcc7..d23348ce0f95 100644
--- a/python/ray/tests/test_metrics.py
+++ b/python/ray/tests/test_metrics.py
@@ -50,7 +50,7 @@ def try_get_node_stats(num_retry=5, timeout=2):
@ray.remote
def f():
- ray.show_in_webui("test")
+ ray.show_in_dashboard("test")
return os.getpid()
@ray.remote
@@ -59,10 +59,10 @@ def __init__(self):
pass
def f(self):
- ray.show_in_webui("test")
+ ray.show_in_dashboard("test")
return os.getpid()
- # Test show_in_webui for remote functions.
+ # Test show_in_dashboard for remote functions.
worker_pid = ray.get(f.remote())
reply = try_get_node_stats()
target_worker_present = False
@@ -75,7 +75,7 @@ def f(self):
assert stats.webui_display[""] == "" # Empty proto
assert target_worker_present
- # Test show_in_webui for remote actors.
+ # Test show_in_dashboard for remote actors.
a = Actor.remote()
worker_pid = ray.get(a.f.remote())
reply = try_get_node_stats()
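
A sketch of the renamed display hook; the message text is illustrative:

    import ray

    ray.init()

    @ray.remote
    def f():
        # Renamed from ray.show_in_webui; the text is surfaced on this
        # worker's entry in the dashboard.
        ray.show_in_dashboard("computing...")
        return 1

    ray.get(f.remote())
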
diff --git a/python/ray/tests/test_mini.py b/python/ray/tests/test_mini.py
index 55b6237338bc..dae1e11bd38f 100644
--- a/python/ray/tests/test_mini.py
+++ b/python/ray/tests/test_mini.py
@@ -15,7 +15,7 @@ def f_simple():
# Test multiple return values.
- @ray.remote(num_return_vals=3)
+ @ray.remote(num_returns=3)
def f_multiple_returns():
return 1, 2, 3
diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py
index 99f89486502d..1b376243254b 100644
--- a/python/ray/tests/test_multi_node.py
+++ b/python/ray/tests/test_multi_node.py
@@ -448,7 +448,7 @@ def test_calling_start_ray_head(call_ray_stop_only):
["ray start --head --num-cpus=1 " + "--node-ip-address=localhost"],
indirect=True)
def test_using_hostnames(call_ray_start):
- ray.init(node_ip_address="localhost", address="localhost:6379")
+ ray.init(_node_ip_address="localhost", address="localhost:6379")
@ray.remote
def f():
diff --git a/python/ray/tests/test_multinode_failures.py b/python/ray/tests/test_multinode_failures.py
index 025ae80739b4..58f17cc1ab13 100644
--- a/python/ray/tests/test_multinode_failures.py
+++ b/python/ray/tests/test_multinode_failures.py
@@ -70,7 +70,8 @@ def f(x):
for object_ref in object_refs:
try:
ray.get(object_ref)
- except (ray.exceptions.RayTaskError, ray.exceptions.RayWorkerError):
+ except (ray.exceptions.RayTaskError,
+ ray.exceptions.WorkerCrashedError):
pass
diff --git a/python/ray/tests/test_object_spilling.py b/python/ray/tests/test_object_spilling.py
index 405a667263f1..05bd70a93508 100644
--- a/python/ray/tests/test_object_spilling.py
+++ b/python/ray/tests/test_object_spilling.py
@@ -11,7 +11,7 @@ def test_spill_objects_manually(shutdown_only):
# Limit our object store to 75 MiB of memory.
ray.init(
object_store_memory=75 * 1024 * 1024,
- object_spilling_config={
+ _object_spilling_config={
"type": "filesystem",
"params": {
"directory_path": "/tmp"
@@ -58,7 +58,7 @@ def test_spill_objects_manually_from_workers(shutdown_only):
# Limit our object store to 100 MiB of memory.
ray.init(
object_store_memory=100 * 1024 * 1024,
- object_spilling_config={
+ _object_spilling_config={
"type": "filesystem",
"params": {
"directory_path": "/tmp"
@@ -84,7 +84,7 @@ def test_spill_objects_manually_with_workers(shutdown_only):
# Limit our object store to 75 MiB of memory.
ray.init(
object_store_memory=100 * 1024 * 1024,
- object_spilling_config={
+ _object_spilling_config={
"type": "filesystem",
"params": {
"directory_path": "/tmp"
@@ -111,7 +111,7 @@ def _worker(object_refs):
"ray_start_cluster_head", [{
"num_cpus": 0,
"object_store_memory": 75 * 1024 * 1024,
- "object_spilling_config": {
+ "_object_spilling_config": {
"type": "filesystem",
"params": {
"directory_path": "/tmp"
@@ -127,7 +127,7 @@ def test_spill_remote_object(ray_start_cluster_head):
cluster = ray_start_cluster_head
cluster.add_node(
object_store_memory=75 * 1024 * 1024,
- object_spilling_config={
+ _object_spilling_config={
"type": "filesystem",
"params": {
"directory_path": "/tmp"
diff --git a/python/ray/tests/test_placement_group.py b/python/ray/tests/test_placement_group.py
index c23420196af2..0b5fdf67f550 100644
--- a/python/ray/tests/test_placement_group.py
+++ b/python/ray/tests/test_placement_group.py
@@ -355,7 +355,7 @@ def f():
# That means this request should fail.
with pytest.raises(ray.exceptions.RayActorError, match="actor died"):
ray.get(a.f.remote(), timeout=3.0)
- with pytest.raises(ray.exceptions.RayWorkerError):
+ with pytest.raises(ray.exceptions.WorkerCrashedError):
ray.get(task_ref)
@@ -576,7 +576,7 @@ def test_pending_placement_group_wait(ray_start_cluster):
assert len(ready) == 0
table = ray.experimental.placement_group_table(placement_group)
assert table["state"] == "PENDING"
- with pytest.raises(ray.exceptions.RayTimeoutError):
+ with pytest.raises(ray.exceptions.GetTimeoutError):
ray.get(placement_group.ready(), timeout=0.1)
diff --git a/python/ray/tests/test_projects.py b/python/ray/tests/test_projects.py
deleted file mode 100644
index 4e306546c514..000000000000
--- a/python/ray/tests/test_projects.py
+++ /dev/null
@@ -1,256 +0,0 @@
-import jsonschema
-import os
-import pytest
-import subprocess
-import yaml
-from click.testing import CliRunner
-import sys
-from unittest.mock import patch, DEFAULT
-
-from contextlib import contextmanager
-
-from ray.projects.scripts import (session_start, session_commands,
- session_execute)
-from ray.test_utils import check_call_ray
-import ray
-
-TEST_DIR = os.path.join(
- os.path.dirname(os.path.abspath(__file__)), "project_files")
-
-
-def load_project_description(project_file):
- path = os.path.join(TEST_DIR, project_file)
- with open(path) as f:
- return yaml.safe_load(f)
-
-
-def test_validation():
- project_dirs = ["docker_project", "requirements_project", "shell_project"]
- for project_dir in project_dirs:
- project_dir = os.path.join(TEST_DIR, project_dir)
- ray.projects.ProjectDefinition(project_dir)
-
- bad_schema_dirs = ["no_project1"]
- for project_dir in bad_schema_dirs:
- project_dir = os.path.join(TEST_DIR, project_dir)
- with pytest.raises(jsonschema.exceptions.ValidationError):
- ray.projects.ProjectDefinition(project_dir)
-
- bad_project_dirs = ["no_project2", "noproject3"]
- for project_dir in bad_project_dirs:
- project_dir = os.path.join(TEST_DIR, project_dir)
- with pytest.raises(ValueError):
- ray.projects.ProjectDefinition(project_dir)
-
-
-def test_project_root():
- path = os.path.join(TEST_DIR, "project1")
- project_definition = ray.projects.ProjectDefinition(path)
- assert os.path.normpath(project_definition.root) == os.path.normpath(path)
-
- path2 = os.path.join(TEST_DIR, "project1", "subdir")
- project_definition = ray.projects.ProjectDefinition(path2)
- assert os.path.normpath(project_definition.root) == os.path.normpath(path)
-
- path3 = ray.utils.get_user_temp_dir() + os.sep
- with pytest.raises(ValueError):
- project_definition = ray.projects.ProjectDefinition(path3)
-
-
-def test_project_validation():
- with _chdir_and_back(os.path.join(TEST_DIR, "project1")):
- check_call_ray(["project", "validate"])
-
-
-def test_project_no_validation():
- with _chdir_and_back(TEST_DIR):
- with pytest.raises(subprocess.CalledProcessError):
- check_call_ray(["project", "validate"])
-
-
-@contextmanager
-def _chdir_and_back(d):
- old_dir = os.getcwd()
- try:
- os.chdir(d)
- yield
- finally:
- os.chdir(old_dir)
-
-
-def run_test_project(project_dir, command, args):
- # Run the CLI commands with patching
- test_dir = os.path.join(TEST_DIR, project_dir)
- with _chdir_and_back(test_dir):
- runner = CliRunner()
- with patch.multiple(
- "ray.projects.scripts",
- create_or_update_cluster=DEFAULT,
- rsync=DEFAULT,
- exec_cluster=DEFAULT,
- ) as mock_calls:
- result = runner.invoke(command, args)
-
- return result, mock_calls, test_dir
-
-
-def test_session_start_default_project():
- result, mock_calls, test_dir = run_test_project(
- os.path.join("session-tests", "project-pass"), session_start,
- ["default"])
-
- loaded_project = ray.projects.ProjectDefinition(test_dir)
- assert result.exit_code == 0
-
- # Part 1/3: Cluster Launching Call
- create_or_update_cluster_call = mock_calls["create_or_update_cluster"]
- assert create_or_update_cluster_call.call_count == 1
- _, kwargs = create_or_update_cluster_call.call_args
- assert kwargs["config_file"] == loaded_project.cluster_yaml()
-
- # Part 2/3: Rsync Calls
- rsync_call = mock_calls["rsync"]
- # 1 for rsyncing the project directory, 1 for rsyncing the
- # requirements.txt.
- assert rsync_call.call_count == 2
- _, kwargs = rsync_call.call_args
- assert kwargs["source"] == loaded_project.config["environment"][
- "requirements"]
-
- # Part 3/3: Exec Calls
- exec_cluster_call = mock_calls["exec_cluster"]
- commands_executed = []
- for _, kwargs in exec_cluster_call.call_args_list:
- commands_executed.append(kwargs["cmd"].replace(
- "cd {}; ".format(loaded_project.working_directory()), ""))
-
- expected_commands = loaded_project.config["environment"]["shell"]
- expected_commands += [
- command["command"] for command in loaded_project.config["commands"]
- ]
-
- if "requirements" in loaded_project.config["environment"]:
- assert any("pip install -r" for cmd in commands_executed)
- # pop the `pip install` off commands executed
- commands_executed = [
- cmd for cmd in commands_executed if "pip install -r" not in cmd
- ]
-
- assert expected_commands == commands_executed
-
-
-def test_session_execute_default_project():
- result, mock_calls, test_dir = run_test_project(
- os.path.join("session-tests", "project-pass"), session_execute,
- ["default"])
-
- loaded_project = ray.projects.ProjectDefinition(test_dir)
- assert result.exit_code == 0
-
- assert mock_calls["rsync"].call_count == 0
- assert mock_calls["create_or_update_cluster"].call_count == 0
-
- exec_cluster_call = mock_calls["exec_cluster"]
- commands_executed = []
- for _, kwargs in exec_cluster_call.call_args_list:
- commands_executed.append(kwargs["cmd"].replace(
- "cd {}; ".format(loaded_project.working_directory()), ""))
-
- expected_commands = [
- command["command"] for command in loaded_project.config["commands"]
- ]
-
- assert expected_commands == commands_executed
-
- result, mock_calls, test_dir = run_test_project(
- os.path.join("session-tests", "project-pass"), session_execute,
- ["--shell", "uptime"])
- assert result.exit_code == 0
-
-
-def test_session_start_docker_fail():
- result, _, _ = run_test_project(
- os.path.join("session-tests", "with-docker-fail"), session_start, [])
-
- assert result.exit_code == 1
- assert ("Docker support in session is currently "
- "not implemented") in result.output
-
-
-def test_session_invalid_config_errored():
- result, _, _ = run_test_project(
- os.path.join("session-tests", "invalid-config-fail"), session_start,
- [])
-
- assert result.exit_code == 1
- assert "validation failed" in result.output
- # check that we are displaying actional error message
- assert "ray project validate" in result.output
-
-
-def test_session_create_command():
- result, mock_calls, test_dir = run_test_project(
- os.path.join("session-tests", "commands-test"), session_start,
- ["first", "--a", "1", "--b", "2"])
-
- # Verify the project can be loaded.
- ray.projects.ProjectDefinition(test_dir)
- assert result.exit_code == 0
-
- exec_cluster_call = mock_calls["exec_cluster"]
- found_command = False
- for _, kwargs in exec_cluster_call.call_args_list:
- if "Starting ray job with 1 and 2" in kwargs["cmd"]:
- found_command = True
- assert found_command
-
-
-def test_session_create_multiple():
- for args in [{"a": "*", "b": "2"}, {"a": "1", "b": "*"}]:
- result, mock_calls, test_dir = run_test_project(
- os.path.join("session-tests", "commands-test"), session_start,
- ["first", "--a", args["a"], "--b", args["b"]])
-
- loaded_project = ray.projects.ProjectDefinition(test_dir)
- assert result.exit_code == 0
-
- exec_cluster_call = mock_calls["exec_cluster"]
- commands_executed = []
- for _, kwargs in exec_cluster_call.call_args_list:
- commands_executed.append(kwargs["cmd"].replace(
- "cd {}; ".format(loaded_project.working_directory()), ""))
- assert commands_executed.count("echo \"Setting up\"") == 2
- if args["a"] == "*":
- assert commands_executed.count(
- "echo \"Starting ray job with 1 and 2\"") == 1
- assert commands_executed.count(
- "echo \"Starting ray job with 2 and 2\"") == 1
- if args["b"] == "*":
- assert commands_executed.count(
- "echo \"Starting ray job with 1 and 1\"") == 1
- assert commands_executed.count(
- "echo \"Starting ray job with 1 and 2\"") == 1
-
- # Using multiple wildcards shouldn't work
- result, mock_calls, test_dir = run_test_project(
- os.path.join("session-tests", "commands-test"), session_start,
- ["first", "--a", "*", "--b", "*"])
- assert result.exit_code == 1
-
-
-def test_session_commands():
- result, mock_calls, test_dir = run_test_project(
- os.path.join("session-tests", "commands-test"), session_commands, [])
-
- assert "This is the first parameter" in result.output
- assert "This is the second parameter" in result.output
-
- assert 'Command "first"' in result.output
- assert 'Command "second"' in result.output
-
-
-if __name__ == "__main__":
- # Make subprocess happy in bazel.
- os.environ["LC_ALL"] = "en_US.UTF-8"
- os.environ["LANG"] = "en_US.UTF-8"
- sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/tests/test_queue.py b/python/ray/tests/test_queue.py
index 72bbb821c1b0..cfb254884dbc 100644
--- a/python/ray/tests/test_queue.py
+++ b/python/ray/tests/test_queue.py
@@ -1,7 +1,7 @@
import pytest
import ray
-from ray.exceptions import RayTimeoutError
+from ray.exceptions import GetTimeoutError
from ray.experimental.queue import Queue, Empty, Full
@@ -80,7 +80,7 @@ def test_async_get(ray_start_regular):
with pytest.raises(Empty):
q.get_nowait()
- with pytest.raises(RayTimeoutError):
+ with pytest.raises(GetTimeoutError):
ray.get(future, timeout=0.1) # task not canceled on timeout.
q.put(1)
@@ -95,7 +95,7 @@ def test_async_put(ray_start_regular):
with pytest.raises(Full):
q.put_nowait(3)
- with pytest.raises(RayTimeoutError):
+ with pytest.raises(GetTimeoutError):
ray.get(future, timeout=0.1) # task not canceled on timeout.
assert q.get() == 1
diff --git a/python/ray/tests/test_ray_init.py b/python/ray/tests/test_ray_init.py
index 13d16f303ad9..2b9b66f061da 100644
--- a/python/ray/tests/test_ray_init.py
+++ b/python/ray/tests/test_ray_init.py
@@ -23,7 +23,7 @@ def test_redis_password(self, password, shutdown_only):
def f():
return 1
- info = ray.init(redis_password=password)
+ info = ray.init(_redis_password=password)
address = info["redis_address"]
redis_ip, redis_port = address.split(":")
@@ -58,20 +58,6 @@ def f():
object_ref = f.remote()
ray.get(object_ref)
- def test_redis_port(self, shutdown_only):
- @ray.remote
- def f():
- return 1
-
- info = ray.init(redis_port=1234, redis_password="testpassword")
- address = info["redis_address"]
- redis_ip, redis_port = address.split(":")
- assert redis_port == "1234"
-
- redis_client = redis.StrictRedis(
- host=redis_ip, port=redis_port, password="testpassword")
- assert redis_client.ping()
-
if __name__ == "__main__":
import pytest
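
A sketch of the renamed private keyword; the password value is illustrative:

    import ray

    # Renamed from redis_password; the returned info dict still carries
    # the cluster's Redis address.
    info = ray.init(_redis_password="testpassword")
    print(info["redis_address"])
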
diff --git a/python/ray/tests/test_reconstruction.py b/python/ray/tests/test_reconstruction.py
index ab4a9b4ef015..f6599483dc94 100644
--- a/python/ray/tests/test_reconstruction.py
+++ b/python/ray/tests/test_reconstruction.py
@@ -110,7 +110,7 @@ def dependent_task(x):
else:
with pytest.raises(ray.exceptions.RayTaskError) as e:
ray.get(dependent_task.remote(obj))
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
raise e.as_instanceof_cause()
@@ -159,7 +159,7 @@ def dependent_task(x):
else:
with pytest.raises(ray.exceptions.RayTaskError) as e:
ray.get(dependent_task.remote(obj))
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
raise e.as_instanceof_cause()
@@ -215,7 +215,7 @@ def dependent_task(x):
# been evicted.
try:
ray.get(result)
- except ray.exceptions.UnreconstructableError:
+ except ray.exceptions.ObjectLostError:
pass
@@ -284,7 +284,7 @@ def dependent_task(x):
else:
with pytest.raises(ray.exceptions.RayTaskError) as e:
ray.get(dependent_task.remote(obj))
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
raise e.as_instanceof_cause()
# Make sure the actor handle is still usable.
@@ -356,8 +356,7 @@ def probe():
return True
except ray.exceptions.RayActorError:
return False
- except (ray.exceptions.RayTaskError,
- ray.exceptions.UnreconstructableError):
+ except (ray.exceptions.RayTaskError, ray.exceptions.ObjectLostError):
return True
wait_for_condition(probe)
@@ -369,7 +368,7 @@ def probe():
x = a.dependent_task.remote(obj)
print(x)
ray.get(x)
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
raise e.as_instanceof_cause()
@@ -429,7 +428,7 @@ def dependent_task(x):
dependent_task.options(resources={
"node1": 1
}).remote(obj))
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
raise e.as_instanceof_cause()
@@ -480,7 +479,7 @@ def dependent_task(x):
else:
with pytest.raises(ray.exceptions.RayTaskError) as e:
ray.get(dependent_task.remote(obj))
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
raise e.as_instanceof_cause()
diff --git a/python/ray/tests/test_reference_counting.py b/python/ray/tests/test_reference_counting.py
index 787d93144643..ba6a4b067934 100644
--- a/python/ray/tests/test_reference_counting.py
+++ b/python/ray/tests/test_reference_counting.py
@@ -354,7 +354,7 @@ def pending(ref, dep):
try:
ray.get(obj_ref)
assert not failure
- except ray.exceptions.RayWorkerError:
+ except ray.exceptions.WorkerCrashedError:
assert failure
# Reference should be gone, check that array gets evicted.
@@ -403,7 +403,7 @@ def recursive(ref, signal, max_depth, depth=0):
assert ray.get(tail_oid) is None
assert not failure
# TODO(edoakes): this should raise WorkerError.
- except ray.exceptions.UnreconstructableError:
+ except ray.exceptions.ObjectLostError:
assert failure
# Reference should be gone, check that array gets evicted.
@@ -501,8 +501,7 @@ def launch_pending_task(ref, signal):
try:
ray.get(child_return_id)
assert not failure
- except (ray.exceptions.RayWorkerError,
- ray.exceptions.UnreconstructableError):
+ except (ray.exceptions.WorkerCrashedError, ray.exceptions.ObjectLostError):
assert failure
del child_return_id
diff --git a/python/ray/tests/test_reference_counting_2.py b/python/ray/tests/test_reference_counting_2.py
index 092314357cc5..17c578408246 100644
--- a/python/ray/tests/test_reference_counting_2.py
+++ b/python/ray/tests/test_reference_counting_2.py
@@ -91,7 +91,7 @@ def recursive(ref, signal, max_depth, depth=0):
ray.get(tail_oid)
assert not failure
# TODO(edoakes): this should raise WorkerError.
- except ray.exceptions.UnreconstructableError:
+ except ray.exceptions.ObjectLostError:
assert failure
# Reference should be gone, check that array gets evicted.
@@ -130,7 +130,7 @@ def exit():
# Check that the owner dying unpins the object. This should execute on
# the same worker because there is only one started and the other tasks
# have finished.
- with pytest.raises(ray.exceptions.RayWorkerError):
+ with pytest.raises(ray.exceptions.WorkerCrashedError):
ray.get(exit.remote())
else:
# Check that removing the inner ID unpins the object.
@@ -173,7 +173,7 @@ def pending(ref, signal):
# Should succeed because inner_oid is pinned if no failure.
ray.get(pending_oid)
assert not failure
- except ray.exceptions.RayWorkerError:
+ except ray.exceptions.WorkerCrashedError:
assert failure
def ref_not_exists():
@@ -232,7 +232,7 @@ def recursive(ref, signal, max_depth, depth=0):
_fill_object_store_and_get(inner_oid)
assert not failure
# TODO(edoakes): this should raise WorkerError.
- except ray.exceptions.UnreconstructableError:
+ except ray.exceptions.ObjectLostError:
assert failure
inner_oid_bytes = inner_oid.binary()
@@ -311,7 +311,7 @@ def receive_ref(self, ref):
def resolve_ref(self):
assert self.ref is not None
if failure:
- with pytest.raises(ray.exceptions.UnreconstructableError):
+ with pytest.raises(ray.exceptions.ObjectLostError):
ray.get(self.ref)
else:
ray.get(self.ref)
diff --git a/python/ray/tests/test_resource_demand_scheduler.py b/python/ray/tests/test_resource_demand_scheduler.py
index fbc731654b77..5d913c9a5e09 100644
--- a/python/ray/tests/test_resource_demand_scheduler.py
+++ b/python/ray/tests/test_resource_demand_scheduler.py
@@ -366,10 +366,10 @@ def testResourcePassing(self):
# These checks are done separately because we have no guarantees on the
# order the dict is serialized in.
runner.assert_has_call("172.0.0.0", "RAY_OVERRIDE_RESOURCES=")
- runner.assert_has_call("172.0.0.0", "CPU: 2")
+ runner.assert_has_call("172.0.0.0", "\"CPU\":2")
runner.assert_has_call("172.0.0.1", "RAY_OVERRIDE_RESOURCES=")
- runner.assert_has_call("172.0.0.1", "CPU: 32")
- runner.assert_has_call("172.0.0.1", "GPU: 8")
+ runner.assert_has_call("172.0.0.1", "\"CPU\":32")
+ runner.assert_has_call("172.0.0.1", "\"GPU\":8")
def testScaleUpLoadMetrics(self):
config = MULTI_WORKER_CLUSTER.copy()
diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py
index 10a9ccdaa74e..8004e694d49a 100644
--- a/python/ray/tests/test_serialization.py
+++ b/python/ray/tests/test_serialization.py
@@ -5,6 +5,7 @@
import re
import string
import sys
+import weakref
import numpy as np
import pytest
@@ -315,30 +316,6 @@ def test_numpy_serialization(ray_start_regular):
assert len(buffers) == 1
-def test_numpy_subclass_serialization(ray_start_regular):
- class MyNumpyConstant(np.ndarray):
- def __init__(self, value):
- super().__init__()
- self.constant = value
-
- def __str__(self):
- print(self.constant)
-
- constant = MyNumpyConstant(123)
-
- def explode(x):
- raise RuntimeError("Expected error.")
-
- ray.register_custom_serializer(
- type(constant), serializer=explode, deserializer=explode)
-
- try:
- ray.put(constant)
- assert False, "Should never get here!"
- except (RuntimeError, IndexError):
- print("Correct behavior, proof that customer serializer was used.")
-
-
def test_numpy_subclass_serialization_pickle(ray_start_regular):
class MyNumpyConstant(np.ndarray):
def __init__(self, value):
@@ -446,7 +423,7 @@ def h2(x):
assert ray.get(h2.remote(10)).value == 10
# Test registering multiple classes with the same name.
- @ray.remote(num_return_vals=3)
+ @ray.remote(num_returns=3)
def j():
class Class0:
def method0(self):
@@ -521,6 +498,51 @@ def method2(self):
assert not hasattr(c2, "method1")
+def test_deserialized_from_buffer_immutable(ray_start_shared_local_modes):
+ x = np.full((2, 2), 1.)
+ o = ray.put(x)
+ y = ray.get(o)
+ with pytest.raises(
+ ValueError, match="assignment destination is read-only"):
+ y[0, 0] = 9.
+
+
+def test_reducer_override_no_reference_cycle(ray_start_shared_local_modes):
+ # bpo-39492: reducer_override used to induce a spurious reference cycle
+    # inside the Pickler object, which could prevent all serialized objects
+    # from being garbage-collected without explicitly invoking gc.collect.
+
+ # test a dynamic function
+ def f():
+ return 4669201609102990671853203821578
+
+ wr = weakref.ref(f)
+
+ bio = io.BytesIO()
+ from ray.cloudpickle import CloudPickler, loads, dumps
+ p = CloudPickler(bio, protocol=5)
+ p.dump(f)
+ new_f = loads(bio.getvalue())
+ assert new_f() == 4669201609102990671853203821578
+
+ del p
+ del f
+
+ assert wr() is None
+
+ # test a dynamic class
+ class ShortlivedObject:
+ def __del__(self):
+ print("Went out of scope!")
+
+ obj = ShortlivedObject()
+ new_obj = weakref.ref(obj)
+
+ dumps(obj)
+ del obj
+ assert new_obj() is None
+
+
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/tests/test_stress_failure.py b/python/ray/tests/test_stress_failure.py
index 88b7fa04bead..01d39afa8065 100644
--- a/python/ray/tests/test_stress_failure.py
+++ b/python/ray/tests/test_stress_failure.py
@@ -1,5 +1,4 @@
import numpy as np
-import os
import pytest
import sys
import time
@@ -38,9 +37,7 @@ def ray_start_reconstruction(request):
cluster.shutdown()
-@pytest.mark.skipif(
- os.environ.get("RAY_USE_NEW_GCS") == "on",
- reason="Failing with new GCS API on Linux.")
+@pytest.mark.skip(reason="Failing with new GCS API on Linux.")
def test_simple(ray_start_reconstruction):
plasma_store_memory, num_nodes, cluster = ray_start_reconstruction
# Define the size of one task's return argument so that the combined
@@ -87,9 +84,7 @@ def sorted_random_indexes(total, output_num):
return random_indexes
-@pytest.mark.skipif(
- os.environ.get("RAY_USE_NEW_GCS") == "on",
- reason="Failing with new GCS API on Linux.")
+@pytest.mark.skip(reason="Failing with new GCS API on Linux.")
def test_recursive(ray_start_reconstruction):
plasma_store_memory, num_nodes, cluster = ray_start_reconstruction
# Define the size of one task's return argument so that the combined
@@ -146,9 +141,6 @@ def single_dependency(i, arg):
@pytest.mark.skip(reason="This test often hangs or fails in CI.")
-@pytest.mark.skipif(
- os.environ.get("RAY_USE_NEW_GCS") == "on",
- reason="Failing with new GCS API on Linux.")
def test_multiple_recursive(ray_start_reconstruction):
plasma_store_memory, _, cluster = ray_start_reconstruction
# Define the size of one task's return argument so that the combined
@@ -219,9 +211,6 @@ def wait_for_errors(p, error_check):
@pytest.mark.skip("This test does not work yet.")
-@pytest.mark.skipif(
- os.environ.get("RAY_USE_NEW_GCS") == "on",
- reason="Failing with new GCS API on Linux.")
def test_nondeterministic_task(ray_start_reconstruction, error_pubsub):
p = error_pubsub
plasma_store_memory, num_nodes, cluster = ray_start_reconstruction
@@ -288,9 +277,7 @@ def error_check(errors):
assert cluster.remaining_processes_alive()
-@pytest.mark.skipif(
- os.environ.get("RAY_USE_NEW_GCS") == "on",
- reason="Failing with new GCS API on Linux.")
+@pytest.mark.skip(reason="Failing with new GCS API on Linux.")
@pytest.mark.parametrize(
"ray_start_object_store_memory", [10**9], indirect=True)
def test_driver_put_errors(ray_start_object_store_memory, error_pubsub):
@@ -335,10 +322,9 @@ def error_check(errors):
return len(errors) > 1
errors = wait_for_errors(p, error_check)
- assert all(
- error.type == ray_constants.PUT_RECONSTRUCTION_PUSH_ERROR
- or "ray.exceptions.UnreconstructableError" in error.error_messages
- for error in errors)
+ assert all(error.type == ray_constants.PUT_RECONSTRUCTION_PUSH_ERROR
+ or "ray.exceptions.ObjectLostError" in error.error_messages
+ for error in errors)
# NOTE(swang): This test tries to launch 1000 workers and breaks.
diff --git a/python/ray/tests/test_stress_sharded.py b/python/ray/tests/test_stress_sharded.py
index 814348e9b142..8e49c9719d49 100644
--- a/python/ray/tests/test_stress_sharded.py
+++ b/python/ray/tests/test_stress_sharded.py
@@ -1,25 +1,20 @@
import numpy as np
-import os
import pytest
import ray
-@pytest.fixture(params=[1, 4])
+@pytest.fixture(params=[1])
def ray_start_sharded(request):
- num_redis_shards = request.param
-
- if os.environ.get("RAY_USE_NEW_GCS") == "on":
- num_redis_shards = 1
- # For now, RAY_USE_NEW_GCS supports 1 shard, and credis supports
- # 1-node chain for that shard only.
+ # TODO(ekl) enable this again once GCS supports sharding.
+ # num_redis_shards = request.param
# Start the Ray processes.
ray.init(
object_store_memory=int(0.5 * 10**9),
num_cpus=10,
- num_redis_shards=num_redis_shards,
- redis_max_memory=10**7)
+ # _num_redis_shards=num_redis_shards,
+ _redis_max_memory=10**7)
yield None
diff --git a/python/ray/tests/test_tempfile.py b/python/ray/tests/test_tempfile.py
index 0a800049034b..8875499d91e5 100644
--- a/python/ray/tests/test_tempfile.py
+++ b/python/ray/tests/test_tempfile.py
@@ -5,7 +5,6 @@
import pytest
import ray
-from ray.cluster_utils import Cluster
from ray.test_utils import check_call_ray
@@ -24,43 +23,11 @@ def unix_socket_delete(unix_socket):
return os.remove(unix_socket) if unix else None
-def test_conn_cluster():
- # plasma_store_socket_name
- with pytest.raises(Exception) as exc_info:
- ray.init(
- address="127.0.0.1:6379",
- plasma_store_socket_name=os.path.join(
- ray.utils.get_user_temp_dir(), "this_should_fail"))
- assert exc_info.value.args[0] == (
- "When connecting to an existing cluster, "
- "plasma_store_socket_name must not be provided.")
-
- # raylet_socket_name
- with pytest.raises(Exception) as exc_info:
- ray.init(
- address="127.0.0.1:6379",
- raylet_socket_name=os.path.join(ray.utils.get_user_temp_dir(),
- "this_should_fail"))
- assert exc_info.value.args[0] == (
- "When connecting to an existing cluster, "
- "raylet_socket_name must not be provided.")
-
- # temp_dir
- with pytest.raises(Exception) as exc_info:
- ray.init(
- address="127.0.0.1:6379",
- temp_dir=os.path.join(ray.utils.get_user_temp_dir(),
- "this_should_fail"))
- assert exc_info.value.args[0] == (
- "When connecting to an existing cluster, "
- "temp_dir must not be provided.")
-
-
def test_tempdir(shutdown_only):
shutil.rmtree(ray.utils.get_ray_temp_dir(), ignore_errors=True)
ray.init(
- temp_dir=os.path.join(ray.utils.get_user_temp_dir(),
- "i_am_a_temp_dir"))
+ _temp_dir=os.path.join(ray.utils.get_user_temp_dir(),
+ "i_am_a_temp_dir"))
assert os.path.exists(
os.path.join(ray.utils.get_user_temp_dir(),
"i_am_a_temp_dir")), "Specified temp dir not found."
@@ -94,47 +61,7 @@ def test_tempdir_long_path():
maxlen = 104 if sys.platform.startswith("darwin") else 108
temp_dir = os.path.join(ray.utils.get_user_temp_dir(), "z" * maxlen)
with pytest.raises(OSError):
- ray.init(temp_dir=temp_dir) # path should be too long
-
-
-def test_raylet_socket_name(shutdown_only):
- sock1 = unix_socket_create_path("i_am_a_temp_socket_1")
- ray.init(raylet_socket_name=sock1)
- unix_socket_verify(sock1)
- ray.shutdown()
- try:
- unix_socket_delete(sock1)
- except OSError:
- pass # It could have been removed by Ray.
- cluster = Cluster(True)
- sock2 = unix_socket_create_path("i_am_a_temp_socket_2")
- cluster.add_node(raylet_socket_name=sock2)
- unix_socket_verify(sock2)
- cluster.shutdown()
- try:
- unix_socket_delete(sock2)
- except OSError:
- pass # It could have been removed by Ray.
-
-
-def test_temp_plasma_store_socket(shutdown_only):
- sock1 = unix_socket_create_path("i_am_a_temp_socket_1")
- ray.init(plasma_store_socket_name=sock1)
- unix_socket_verify(sock1)
- ray.shutdown()
- try:
- unix_socket_delete(sock1)
- except OSError:
- pass # It could have been removed by Ray.
- cluster = Cluster(True)
- sock2 = unix_socket_create_path("i_am_a_temp_socket_2")
- cluster.add_node(plasma_store_socket_name=sock2)
- unix_socket_verify(sock2)
- cluster.shutdown()
- try:
- unix_socket_delete(sock2)
- except OSError:
- pass # It could have been removed by Ray.
+ ray.init(_temp_dir=temp_dir) # path should be too long
def test_raylet_tempfiles(shutdown_only):
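The renames in this file follow the new convention that underscore-prefixed `ray.init` arguments (`_temp_dir`, `_redis_max_memory`, `_log_to_driver`, ...) are internal and may change without deprecation. A minimal sketch of the updated call (`my_ray_tmp` is an arbitrary example path):

    import os
    import ray

    # Underscore-prefixed arguments are internal/unstable API in this
    # version of Ray and may change without a deprecation cycle.
    ray.init(
        _temp_dir=os.path.join(ray.utils.get_user_temp_dir(), "my_ray_tmp"),
        _redis_max_memory=10**7)
    ray.shutdown()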
diff --git a/python/ray/tests/test_unreconstructable_errors.py b/python/ray/tests/test_unreconstructable_errors.py
index c61fbdd586a9..a86d50da41c6 100644
--- a/python/ray/tests/test_unreconstructable_errors.py
+++ b/python/ray/tests/test_unreconstructable_errors.py
@@ -4,22 +4,23 @@
import ray
-class TestUnreconstructableErrors(unittest.TestCase):
+class TestObjectLostErrors(unittest.TestCase):
def setUp(self):
ray.init(
num_cpus=1,
object_store_memory=150 * 1024 * 1024,
- redis_max_memory=10000000)
+ _redis_max_memory=10000000)
def tearDown(self):
ray.shutdown()
def testDriverPutEvictedCannotReconstruct(self):
- x_id = ray.put(np.zeros(1 * 1024 * 1024), weakref=True)
+ x_id = ray.worker.global_worker.put_object(
+ np.zeros(1 * 1024 * 1024), pin_object=False)
ray.get(x_id)
for _ in range(20):
ray.put(np.zeros(10 * 1024 * 1024))
- self.assertRaises(ray.exceptions.UnreconstructableError,
+ self.assertRaises(ray.exceptions.ObjectLostError,
lambda: ray.get(x_id))
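For downstream code, the rename means exception handlers must be updated from `UnreconstructableError` to `ObjectLostError`. A sketch mirroring the updated test above:

    import numpy as np
    import ray

    ray.init(object_store_memory=150 * 1024 * 1024, _redis_max_memory=10**7)
    # Put an unpinned object so that it can be evicted under memory pressure.
    ref = ray.worker.global_worker.put_object(
        np.zeros(1024 * 1024), pin_object=False)
    for _ in range(20):
        ray.put(np.zeros(10 * 1024 * 1024))  # force eviction
    try:
        ray.get(ref)
    except ray.exceptions.ObjectLostError:  # formerly UnreconstructableError
        print("Object was evicted and could not be reconstructed.")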
diff --git a/python/ray/tests/test_webui.py b/python/ray/tests/test_webui.py
index b02cc9572626..011993af3e77 100644
--- a/python/ray/tests/test_webui.py
+++ b/python/ray/tests/test_webui.py
@@ -13,7 +13,7 @@
def test_get_webui(shutdown_only):
addresses = ray.init(include_dashboard=True, num_cpus=1)
webui_url = addresses["webui_url"]
- assert ray.get_webui_url() == webui_url
+ assert ray.get_dashboard_url() == webui_url
assert re.match(r"^(localhost|\d+\.\d+\.\d+\.\d+):\d+$", webui_url)
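The rename is mechanical; a one-line sketch of the new accessor:

    import ray

    ray.init(include_dashboard=True, num_cpus=1)
    # ray.get_webui_url() was renamed to match the "dashboard" terminology:
    print(ray.get_dashboard_url())  # e.g. "localhost:8265"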
diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py
index 43a65c430a9d..6312fe8986dc 100644
--- a/python/ray/tune/__init__.py
+++ b/python/ray/tune/__init__.py
@@ -12,15 +12,15 @@
save_checkpoint, checkpoint_dir)
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
JupyterNotebookReporter)
-from ray.tune.sample import (function, sample_from, uniform, choice, randint,
- randn, loguniform)
+from ray.tune.sample import (sample_from, uniform, choice, randint, randn,
+ loguniform)
__all__ = [
"Trainable", "DurableTrainable", "TuneError", "grid_search",
"register_env", "register_trainable", "run", "run_experiments", "Stopper",
- "EarlyStopping", "Experiment", "function", "sample_from", "track",
- "uniform", "choice", "randint", "randn", "loguniform",
- "ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
- "ProgressReporter", "report", "get_trial_dir", "get_trial_name",
- "get_trial_id", "make_checkpoint_dir", "save_checkpoint", "checkpoint_dir"
+ "EarlyStopping", "Experiment", "sample_from", "track", "uniform", "choice",
+ "randint", "randn", "loguniform", "ExperimentAnalysis", "Analysis",
+ "CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report",
+ "get_trial_dir", "get_trial_name", "get_trial_id", "make_checkpoint_dir",
+ "save_checkpoint", "checkpoint_dir"
]
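Since `tune.function` is removed from the public API, callables are now passed directly; values computed per trial still go through `tune.sample_from`. A minimal migration sketch (config keys are illustrative):

    from ray import tune

    # Before: {"custom_fn": tune.function(my_fn)} -- the wrapper was a
    # no-op and has been removed. Pass callables directly instead:
    config = {
        "lr": tune.loguniform(1e-4, 1e-1),
        "seed": tune.sample_from(lambda spec: 1234),
    }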
diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py
index 66410d5428ce..a312c1afeb12 100644
--- a/python/ray/tune/experiment.py
+++ b/python/ray/tune/experiment.py
@@ -8,7 +8,7 @@
from ray.tune.function_runner import detect_checkpoint_function
from ray.tune.registry import register_trainable, get_trainable_cls
from ray.tune.result import DEFAULT_RESULTS_DIR
-from ray.tune.sample import sample_from
+from ray.tune.sample import Domain
from ray.tune.stopper import FunctionStopper, Stopper
logger = logging.getLogger(__name__)
@@ -235,7 +235,7 @@ def register_if_needed(cls, run_object):
if isinstance(run_object, str):
return run_object
- elif isinstance(run_object, sample_from):
+ elif isinstance(run_object, Domain):
logger.warning("Not registering trainable. Resolving as variant.")
return run_object
elif isinstance(run_object, type) or callable(run_object):
diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py
index 309c7a8e3b6a..3018ea261d86 100644
--- a/python/ray/tune/ray_trial_executor.py
+++ b/python/ray/tune/ray_trial_executor.py
@@ -8,7 +8,7 @@
from contextlib import contextmanager
import ray
-from ray.exceptions import RayTimeoutError
+from ray.exceptions import GetTimeoutError
from ray import ray_constants
from ray.resource_spec import ResourceSpec
from ray.tune.durable_trainable import DurableTrainable
@@ -396,8 +396,8 @@ def reset_trial(self, trial, new_config, new_experiment_tag):
try:
reset_val = ray.get(
trainable.reset.remote(new_config, trial.logdir),
- DEFAULT_GET_TIMEOUT)
- except RayTimeoutError:
+ timeout=DEFAULT_GET_TIMEOUT)
+ except GetTimeoutError:
logger.exception("Trial %s: reset timed out.", trial)
return False
return reset_val
@@ -465,7 +465,7 @@ def fetch_result(self, trial):
raise ValueError("Trial was not running.")
self._running.pop(trial_future[0])
with warn_if_slow("fetch_result"):
- result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
+ result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
# For local mode
if isinstance(result, _LocalWrapper):
@@ -734,7 +734,7 @@ def export_trial_if_needed(self, trial):
with self._change_working_directory(trial):
return ray.get(
trial.runner.export_model.remote(trial.export_formats),
- DEFAULT_GET_TIMEOUT)
+ timeout=DEFAULT_GET_TIMEOUT)
return {}
def has_gpus(self):
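The changes here track two API updates: `ray.get` now takes `timeout` as a keyword argument, and the timeout exception was renamed. A minimal sketch of the new pattern:

    import time
    import ray
    from ray.exceptions import GetTimeoutError  # formerly RayTimeoutError

    ray.init()

    @ray.remote
    def slow():
        time.sleep(10)
        return "done"

    try:
        # timeout must now be passed as a keyword argument.
        ray.get(slow.remote(), timeout=1)
    except GetTimeoutError:
        print("timed out waiting for the task")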
diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py
index 4cac8572b417..f0e96cab3b58 100644
--- a/python/ray/tune/sample.py
+++ b/python/ray/tune/sample.py
@@ -1,97 +1,441 @@
import logging
import random
+from copy import copy
+from numbers import Number
+from typing import (Any, Callable, Dict, Iterator, List, Optional, Sequence,
+ Union)
import numpy as np
+from scipy import stats
logger = logging.getLogger(__name__)
-class sample_from:
+class Domain:
+ sampler = None
+ default_sampler_cls = None
+
+ def set_sampler(self, sampler, allow_override=False):
+ if self.sampler and not allow_override:
+ raise ValueError("You can only choose one sampler for parameter "
+ "domains. Existing sampler for parameter {}: "
+ "{}. Tried to add {}".format(
+ self.__class__.__name__, self.sampler,
+ sampler))
+ self.sampler = sampler
+
+ def get_sampler(self):
+ sampler = self.sampler
+ if not sampler:
+ sampler = self.default_sampler_cls()
+ return sampler
+
+ def sample(self, spec=None, size=1):
+ sampler = self.get_sampler()
+ return sampler.sample(self, spec=spec, size=size)
+
+ def is_grid(self):
+ return isinstance(self.sampler, Grid)
+
+ def is_function(self):
+ return False
+
+
+class Sampler:
+ def sample(self,
+ domain: Domain,
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ raise NotImplementedError
+
+
+class BaseSampler(Sampler):
+ pass
+
+
+class Uniform(Sampler):
+ pass
+
+
+class LogUniform(Sampler):
+ pass
+
+
+class Normal(Sampler):
+ pass
+
+
+class Grid(Sampler):
+ """Dummy sampler used for grid search"""
+
+ def sample(self,
+ domain: Domain,
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ return RuntimeError("Do not call `sample()` on grid.")
+
+
+class Float(Domain):
+ class _Uniform(Uniform):
+ def sample(self,
+ domain: "Float",
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ assert domain.min > float("-inf"), \
+ "Uniform needs a minimum bound"
+ assert 0 < domain.max < float("inf"), \
+ "Uniform needs a maximum bound"
+ items = np.random.uniform(domain.min, domain.max, size=size)
+ if len(items) == 1:
+ return items[0]
+ return list(items)
+
+ class _LogUniform(LogUniform):
+ def __init__(self, base: int = 10):
+ self.base = base
+ assert self.base > 0, "Base has to be strictly greater than 0"
+
+ def sample(self,
+ domain: "Float",
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ assert domain.min > 0, \
+ "LogUniform needs a minimum bound greater than 0"
+ assert 0 < domain.max < float("inf"), \
+ "LogUniform needs a maximum bound greater than 0"
+ logmin = np.log(domain.min) / np.log(self.base)
+ logmax = np.log(domain.max) / np.log(self.base)
+
+ items = self.base**(np.random.uniform(logmin, logmax, size=size))
+ if len(items) == 1:
+ return items[0]
+ return list(items)
+
+ class _Normal(Normal):
+ def __init__(self, mean: float = 0., sd: float = 1.):
+ self.mean = mean
+ self.sd = sd
+
+ assert self.sd > 0, "SD has to be strictly greater than 0"
+
+ def sample(self,
+ domain: "Float",
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ # Use a truncated normal to avoid oversampling border values
+ dist = stats.truncnorm(
+ (domain.min - self.mean) / self.sd,
+ (domain.max - self.mean) / self.sd,
+ loc=self.mean,
+ scale=self.sd)
+ items = dist.rvs(size)
+ if len(items) == 1:
+ return items[0]
+ return list(items)
+
+ default_sampler_cls = _Uniform
+
+ def __init__(self, min=float("-inf"), max=float("inf")):
+ self.min = min
+ self.max = max
+
+ def uniform(self):
+ if not self.min > float("-inf"):
+ raise ValueError(
+ "Uniform requires a minimum bound. Make sure to set the "
+ "`min` parameter of `Float()`.")
+ if not self.max < float("inf"):
+ raise ValueError(
+ "Uniform requires a maximum bound. Make sure to set the "
+ "`max` parameter of `Float()`.")
+ new = copy(self)
+ new.set_sampler(self._Uniform())
+ return new
+
+ def loguniform(self, base: int = 10):
+ if not self.min > 0:
+ raise ValueError(
+ "LogUniform requires a minimum bound greater than 0. "
+ "Make sure to set the `min` parameter of `Float()` correctly.")
+ if not 0 < self.max < float("inf"):
+ raise ValueError(
+ "LogUniform requires a minimum bound greater than 0. "
+ "Make sure to set the `max` parameter of `Float()` correctly.")
+ new = copy(self)
+ new.set_sampler(self._LogUniform(base))
+ return new
+
+ def normal(self, mean=0., sd=1.):
+ new = copy(self)
+ new.set_sampler(self._Normal(mean, sd))
+ return new
+
+ def quantized(self, q: Number):
+ new = copy(self)
+ new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
+ return new
+
+
+class Integer(Domain):
+ class _Uniform(Uniform):
+ def sample(self,
+ domain: "Integer",
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ items = np.random.randint(domain.min, domain.max, size=size)
+ if len(items) == 1:
+ return items[0]
+ return list(items)
+
+ default_sampler_cls = _Uniform
+
+ def __init__(self, min, max):
+ self.min = min
+ self.max = max
+
+ def quantized(self, q: Number):
+ new = copy(self)
+ new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
+ return new
+
+ def uniform(self):
+ new = copy(self)
+ new.set_sampler(self._Uniform())
+ return new
+
+
+class Categorical(Domain):
+ class _Uniform(Uniform):
+ def sample(self,
+ domain: "Categorical",
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ choices = []
+ for i in range(size):
+ choices.append(random.choice(domain.categories))
+ if len(choices) == 1:
+ return choices[0]
+ return choices
+
+ default_sampler_cls = _Uniform
+
+ def __init__(self, categories: Sequence):
+ self.categories = list(categories)
+
+ def uniform(self):
+ new = copy(self)
+ new.set_sampler(self._Uniform())
+ return new
+
+ def grid(self):
+ new = copy(self)
+ new.set_sampler(Grid())
+ return new
+
+ def __len__(self):
+ return len(self.categories)
+
+ def __getitem__(self, item):
+ return self.categories[item]
+
+
+class Iterative(Domain):
+ class _NextSampler(BaseSampler):
+ def sample(self,
+ domain: "Iterative",
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ items = []
+ for i in range(size):
+ items.append(next(domain.iterator))
+ if len(items) == 1:
+ return items[0]
+ return items
+
+ default_sampler_cls = _NextSampler
+
+ def __init__(self, iterator: Iterator):
+ self.iterator = iterator
+
+
+class Function(Domain):
+ class _CallSampler(BaseSampler):
+ def sample(self,
+ domain: "Function",
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ items = []
+ for i in range(size):
+ this_spec = spec[i] if isinstance(spec, list) else spec
+ items.append(domain.func(this_spec))
+ if len(items) == 1:
+ return items[0]
+ return items
+
+ default_sampler_cls = _CallSampler
+
+ def __init__(self, func: Callable):
+ self.func = func
+
+ def is_function(self):
+ return True
+
+
+class Quantized(Sampler):
+ def __init__(self, sampler: Sampler, q: Number):
+ self.sampler = sampler
+ self.q = q
+
+ assert self.sampler, "Quantized() expects a sampler instance"
+
+ def get_sampler(self):
+ return self.sampler
+
+ def sample(self,
+ domain: Domain,
+ spec: Optional[Union[List[Dict], Dict]] = None,
+ size: int = 1):
+ values = self.sampler.sample(domain, spec, size)
+ quantized = np.round(np.divide(values, self.q)) * self.q
+ if size == 1:
+ # For size=1 the wrapped sampler returns a scalar, so
+ # `quantized` is a scalar here, too.
+ return quantized
+ return list(quantized)
+
+
+def sample_from(func: Callable[[Dict], Any]):
"""Specify that tune should sample configuration values from this function.
Arguments:
func: A callable function to draw a sample from.
"""
+ return Function(func)
- def __init__(self, func):
- self.func = func
- def __str__(self):
- return "tune.sample_from({})".format(str(self.func))
+def uniform(min: float, max: float):
+ """Sample a float value uniformly between ``min`` and ``max``.
- def __repr__(self):
- return "tune.sample_from({})".format(repr(self.func))
+ Sampling from ``tune.uniform(1, 10)`` is equivalent to sampling from
+ ``np.random.uniform(1, 10)``
+ """
+ return Float(min, max).uniform()
-def function(func):
- logger.warning(
- "DeprecationWarning: wrapping {} with tune.function() is no "
- "longer needed".format(func))
- return func
+def quniform(min: float, max: float, q: float):
+ """Sample a quantized float value uniformly between ``min`` and ``max``.
-class uniform(sample_from):
- """Wraps tune.sample_from around ``np.random.uniform``.
+ Sampling from ``tune.quniform(1, 10, 2)`` is equivalent to sampling from
+ ``np.random.uniform(1, 10)`` and rounding the result to a multiple of 2.
- ``tune.uniform(1, 10)`` is equivalent to
- ``tune.sample_from(lambda _: np.random.uniform(1, 10))``
+ The value will be quantized, i.e. rounded to an integer increment of ``q``.
"""
+ return Float(min, max).uniform().quantized(q)
- def __init__(self, *args, **kwargs):
- super().__init__(lambda _: np.random.uniform(*args, **kwargs))
+def loguniform(min: float, max: float, base: float = 10):
+ """Sugar for sampling in different orders of magnitude.
+
+ Args:
+ min (float): Lower boundary of the output interval (e.g. 1e-4)
+ max (float): Upper boundary of the output interval (e.g. 1e-2)
+ base (float): Base of the log. Defaults to 10.
-class loguniform(sample_from):
+ """
+ return Float(min, max).loguniform(base)
+
+
+def qloguniform(min, max, q, base=10):
"""Sugar for sampling in different orders of magnitude.
+ The value will be quantized, i.e. rounded to an integer increment of ``q``.
+
Args:
- min_bound (float): Lower boundary of the output interval (1e-4)
- max_bound (float): Upper boundary of the output interval (1e-2)
- base (float): Base of the log. Defaults to 10.
+ min (float): Lower boundary of the output interval (e.g. 1e-4)
+ max (float): Upper boundary of the output interval (e.g. 1e-2)
+ q (float): Quantization number. The result will be rounded to an
+ integer increment of this value.
+ base (float): Base of the log. Defaults to 10.
+
"""
+ return Float(min, max).loguniform(base).quantized(q)
- def __init__(self, min_bound, max_bound, base=10):
- logmin = np.log(min_bound) / np.log(base)
- logmax = np.log(max_bound) / np.log(base)
- def apply_log(_):
- return base**(np.random.uniform(logmin, logmax))
+def choice(categories: List):
+ """Sample a categorical value.
- super().__init__(apply_log)
+ Sampling from ``tune.choice([1, 2])`` is equivalent to sampling from
+ ``random.choice([1, 2])``
+
+ """
+ return Categorical(categories).uniform()
-class choice(sample_from):
- """Wraps tune.sample_from around ``random.choice``.
+def randint(min: int, max: int):
+ """Sample an integer value uniformly between ``min`` and ``max``.
- ``tune.choice([1, 2])`` is equivalent to
- ``tune.sample_from(lambda _: random.choice([1, 2]))``
+ ``min`` is inclusive, ``max`` is exclusive.
+
+ Sampling from ``tune.randint(0, 10)`` is equivalent to sampling from
+ ``np.random.randint(0, 10)``
"""
+ return Integer(min, max).uniform()
+
- def __init__(self, *args, **kwargs):
- super().__init__(lambda _: random.choice(*args, **kwargs))
+def qrandint(min: int, max: int, q: int = 1):
+ """Sample an integer value uniformly between ``min`` and ``max``.
+ ``min`` is inclusive, ``max`` is exclusive.
-class randint(sample_from):
- """Wraps tune.sample_from around ``np.random.randint``.
+ The value will be quantized, i.e. rounded to an integer increment of ``q``.
- ``tune.randint(10)`` is equivalent to
- ``tune.sample_from(lambda _: np.random.randint(10))``
+ Sampling from ``tune.qrandint(0, 10, 2)`` is equivalent to sampling from
+ ``np.random.randint(0, 10)`` and rounding the result to a multiple of 2.
"""
+ return Integer(min, max).uniform().quantized(q)
- def __init__(self, *args, **kwargs):
- super().__init__(lambda _: np.random.randint(*args, **kwargs))
+def randn(mean: float = 0.,
+ sd: float = 1.,
+ min: float = float("-inf"),
+ max: float = float("inf")):
+ """Sample a float value normally with ``mean`` and ``sd``.
-class randn(sample_from):
- """Wraps tune.sample_from around ``np.random.randn``.
+ Will truncate the normal distribution at ``min`` and ``max`` to avoid
+ oversampling the border regions.
- ``tune.randn(10)`` is equivalent to
- ``tune.sample_from(lambda _: np.random.randn(10))``
+ Args:
+ mean (float): Mean of the normal distribution. Defaults to 0.
+ sd (float): SD of the normal distribution. Defaults to 1.
+ min (float): Minimum bound. Defaults to -inf.
+ max (float): Maximum bound. Defaults to inf.
"""
+ return Float(min, max).normal(mean, sd)
+
+
+def qrandn(q: float,
+ mean: float = 0.,
+ sd: float = 1.,
+ min: float = float("-inf"),
+ max: float = float("inf")):
+ """Sample a float value normally with ``mean`` and ``sd``.
+
+ The value will be quantized, i.e. rounded to an integer increment of ``q``.
- def __init__(self, *args, **kwargs):
- super().__init__(lambda _: np.random.randn(*args, **kwargs))
+ Will truncate the normal distribution at ``min`` and ``max`` to avoid
+ oversampling the border regions.
+
+ Args:
+ q (float): Quantization number. The result will be rounded to an
+ integer increment of this value.
+ mean (float): Mean of the normal distribution. Defaults to 0.
+ sd (float): SD of the normal distribution. Defaults to 1.
+ min (float): Minimum bound. Defaults to -inf.
+ max (float): Maximum bound. Defaults to inf.
+
+ """
+ return Float(min, max).normal(mean, sd).quantized(q)
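The public helpers above are thin wrappers over the new `Domain`/`Sampler` objects, which can also be composed explicitly. A minimal sketch of both styles (the tests below access the module as `tune.sample` in the same way):

    from ray import tune

    # Public helpers build Domain objects with an attached sampler:
    config = {
        "lr": tune.loguniform(1e-4, 1e-1),           # Float(...).loguniform()
        "batch_size": tune.randint(16, 256),         # Integer(...).uniform()
        "activation": tune.choice(["relu", "tanh"]), # Categorical(...).uniform()
    }

    # The same domains can be composed explicitly, e.g. with quantization:
    lr = tune.sample.Float(1e-4, 1e-1).loguniform(base=10).quantized(1e-4)
    print(lr.sample(size=5))  # five quantized, log-uniform floats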
diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py
index 588c247bde46..4b84b5d63b4c 100644
--- a/python/ray/tune/schedulers/pbt.py
+++ b/python/ray/tune/schedulers/pbt.py
@@ -9,7 +9,7 @@
from ray.tune.error import TuneError
from ray.tune.result import TRAINING_ITERATION
from ray.tune.logger import _SafeFallbackEncoder
-from ray.tune.sample import sample_from
+from ray.tune.sample import Domain, sample_from
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest.variant_generator import format_vars
from ray.tune.trial import Trial, Checkpoint
@@ -27,10 +27,12 @@ def __init__(self, trial):
self.last_score = None
self.last_checkpoint = None
self.last_perturbation_time = 0
+ self.last_train_time = 0 # Used for synchronous mode.
+ self.last_result = None # Used for synchronous mode.
def __repr__(self):
return str((self.last_score, self.last_checkpoint,
- self.last_perturbation_time))
+ self.last_train_time, self.last_perturbation_time))
def explore(config, mutations, resample_probability, custom_explore_fn):
@@ -66,8 +68,8 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
distribution.index(config[key]) + 1)]
else:
if random.random() < resample_probability:
- new_config[key] = distribution.func(None) if isinstance(
- distribution, sample_from) else distribution()
+ new_config[key] = distribution.sample(None) if isinstance(
+ distribution, Domain) else distribution()
elif random.random() > 0.5:
new_config[key] = config[key] * 1.2
else:
@@ -94,8 +96,8 @@ def fill_config(config, attr, search_space):
"""Add attr to config by sampling from search_space."""
if callable(search_space):
config[attr] = search_space()
- elif isinstance(search_space, sample_from):
- config[attr] = search_space.func(None)
+ elif isinstance(search_space, Domain):
+ config[attr] = search_space.sample(None)
elif isinstance(search_space, list):
config[attr] = random.choice(search_space)
elif isinstance(search_space, dict):
@@ -174,6 +176,13 @@ class PopulationBasedTraining(FIFOScheduler):
require_attrs (bool): Whether to require time_attr and metric to appear
in result for every iteration. If True, error will be raised
if these values are not present in trial result.
+ synch (bool): If False, will use the asynchronous implementation of
+ PBT: each trial is perturbed independently every
+ perturbation_interval. If True, will use the synchronous
+ implementation: perturbations only occur once all trials have
+ reached the same time_attr value, and then happen every
+ perturbation_interval. Defaults to False. See Appendix A.1 of
+ https://arxiv.org/pdf/1711.09846.pdf.
.. code-block:: python
@@ -215,10 +224,11 @@ def __init__(self,
resample_probability=0.25,
custom_explore_fn=None,
log_config=True,
- require_attrs=True):
+ require_attrs=True,
+ synch=False):
for value in hyperparam_mutations.values():
if not (isinstance(value,
- (list, dict, sample_from)) or callable(value)):
+ (list, dict, Domain)) or callable(value)):
raise TypeError("`hyperparam_mutation` values must be either "
"a List, Dict, a tune search space object, or "
"callable.")
@@ -234,10 +244,15 @@ def __init__(self,
"`custom_explore_fn` to use PBT.")
if quantile_fraction > 0.5 or quantile_fraction < 0:
- raise TuneError(
+ raise ValueError(
"You must set `quantile_fraction` to a value between 0 and"
"0.5. Current value: '{}'".format(quantile_fraction))
+ if perturbation_interval <= 0:
+ raise ValueError(
+ "perturbation_interval must be a positive number greater "
+ "than 0. Current value: '{}'".format(perturbation_interval))
+
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
@@ -263,6 +278,8 @@ def __init__(self,
self._custom_explore_fn = custom_explore_fn
self._log_config = log_config
self._require_attrs = require_attrs
+ self._synch = synch
+ self._next_perturbation_sync = self._perturbation_interval
# Metrics
self._num_checkpoints = 0
@@ -320,34 +337,90 @@ def on_trial_result(self, trial_runner, trial, result):
time = result[self._time_attr]
state = self._trial_state[trial]
+ # Continue training if perturbation interval has not been reached yet.
if time - state.last_perturbation_time < self._perturbation_interval:
return TrialScheduler.CONTINUE # avoid checkpoint overhead
+ # This trial has reached its perturbation interval
score = self._metric_op * result[self._metric]
state.last_score = score
- state.last_perturbation_time = time
- lower_quantile, upper_quantile = self._quantiles()
+ state.last_train_time = time
+ state.last_result = result
+
+ if not self._synch:
+ state.last_perturbation_time = time
+ lower_quantile, upper_quantile = self._quantiles()
+ self._perturb_trial(trial, trial_runner, upper_quantile,
+ lower_quantile)
+ for trial in trial_runner.get_trials():
+ if trial.status in [Trial.PENDING, Trial.PAUSED]:
+ return TrialScheduler.PAUSE # yield time to other trials
+ return TrialScheduler.CONTINUE
+ else:
+ # Synchronous mode.
+ if any(self._trial_state[t].last_train_time <
+ self._next_perturbation_sync and t != trial
+ for t in trial_runner.get_trials()):
+ logger.debug("Pausing trial {}".format(trial))
+ else:
+ # All trials are synced at the same timestep.
+ lower_quantile, upper_quantile = self._quantiles()
+ all_trials = trial_runner.get_trials()
+ not_in_quantile = []
+ for t in all_trials:
+ if t not in lower_quantile and t not in upper_quantile:
+ not_in_quantile.append(t)
+ # Move upper quantile trials to beginning and lower quantile
+ # to end. This ensures that checkpointing of strong trials
+ # occurs before exploiting of weaker ones.
+ all_trials = upper_quantile + not_in_quantile + lower_quantile
+ for t in all_trials:
+ logger.debug("Perturbing Trial {}".format(t))
+ self._trial_state[t].last_perturbation_time = time
+ self._perturb_trial(t, trial_runner, upper_quantile,
+ lower_quantile)
+
+ all_train_times = [
+ self._trial_state[trial].last_train_time
+ for trial in trial_runner.get_trials()
+ ]
+ max_last_train_time = max(all_train_times)
+ self._next_perturbation_sync = max(
+ self._next_perturbation_sync + self._perturbation_interval,
+ max_last_train_time)
+ # In sync mode we should pause all trials once result comes in.
+ # Once a perturbation step happens for all trials, they should
+ # still all be paused.
+ # choose_trial_to_run will then pick the next trial to run out of
+ # the paused trials.
+ return TrialScheduler.PAUSE
+
+ def _perturb_trial(self, trial, trial_runner, upper_quantile,
+ lower_quantile):
+ """Checkpoint if in upper quantile, exploits if in lower."""
+ state = self._trial_state[trial]
if trial in upper_quantile:
# The trial last result is only updated after the scheduler
# callback. So, we override with the current result.
- state.last_checkpoint = trial_runner.trial_executor.save(
- trial, Checkpoint.MEMORY, result=result)
+ logger.debug("Trial {} is in upper quantile".format(trial))
+ logger.debug("Checkpointing {}".format(trial))
+ if trial.status == Trial.PAUSED:
+ # Paused trial will always have an in-memory checkpoint.
+ state.last_checkpoint = trial.checkpoint
+ else:
+ state.last_checkpoint = trial_runner.trial_executor.save(
+ trial, Checkpoint.MEMORY, result=state.last_result)
self._num_checkpoints += 1
else:
state.last_checkpoint = None # not a top trial
if trial in lower_quantile:
+ logger.debug("Trial {} is in lower quantile".format(trial))
trial_to_clone = random.choice(upper_quantile)
assert trial is not trial_to_clone
self._exploit(trial_runner.trial_executor, trial, trial_to_clone)
- for trial in trial_runner.get_trials():
- if trial.status in [Trial.PENDING, Trial.PAUSED]:
- return TrialScheduler.PAUSE # yield time to other trials
-
- return TrialScheduler.CONTINUE
-
def _log_config_on_step(self, trial_state, new_state, trial,
trial_to_clone, new_config):
"""Logs transition during exploit/exploit step.
@@ -417,26 +490,40 @@ def _exploit(self, trial_executor, trial, trial_to_clone):
new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
self._hyperparam_mutations)
- reset_successful = trial_executor.reset_trial(trial, new_config,
- new_tag)
-
- # TODO(ujvl): Refactor Scheduler abstraction to abstract
- # mechanism for trial restart away. We block on restore
- # and suppress train on start as a stop-gap fix to
- # https://github.com/ray-project/ray/issues/7258.
- if reset_successful:
- trial_executor.restore(
- trial, new_state.last_checkpoint, block=True)
- else:
- trial_executor.stop_trial(trial, stop_logger=False)
+ if trial.status == Trial.PAUSED:
+ # If trial is paused we update it with a new checkpoint.
+ # When the trial is started again, the new checkpoint is used.
+ if not self._synch:
+ raise TuneError("Trials should be paused here only if in "
+ "synchronous mode. If you encounter this error"
+ " please raise an issue on Ray Github.")
trial.config = new_config
trial.experiment_tag = new_tag
- trial_executor.start_trial(
- trial, new_state.last_checkpoint, train=False)
+ trial.on_checkpoint(new_state.last_checkpoint)
+ else:
+ # If trial is running, we first try to reset it.
+ # If that is unsuccessful, then we have to stop it and start it
+ # again with a new checkpoint.
+ reset_successful = trial_executor.reset_trial(
+ trial, new_config, new_tag)
+ # TODO(ujvl): Refactor Scheduler abstraction to abstract
+ # mechanism for trial restart away. We block on restore
+ # and suppress train on start as a stop-gap fix to
+ # https://github.com/ray-project/ray/issues/7258.
+ if reset_successful:
+ trial_executor.restore(
+ trial, new_state.last_checkpoint, block=True)
+ else:
+ trial_executor.stop_trial(trial, stop_logger=False)
+ trial.config = new_config
+ trial.experiment_tag = new_tag
+ trial_executor.start_trial(
+ trial, new_state.last_checkpoint, train=False)
self._num_perturbations += 1
# Transfer over the last perturbation time as well
trial_state.last_perturbation_time = new_state.last_perturbation_time
+ trial_state.last_train_time = new_state.last_train_time
def _quantiles(self):
"""Returns trials in the lower and upper `quantile` of the population.
@@ -445,6 +532,9 @@ def _quantiles(self):
"""
trials = []
for trial, state in self._trial_state.items():
+ logger.debug("Trial {}, state {}".format(trial, state))
+ if trial.is_finished():
+ logger.debug("Trial {} is finished".format(trial))
if state.last_score is not None and not trial.is_finished():
trials.append(trial)
trials.sort(key=lambda t: self._trial_state[t].last_score)
@@ -469,9 +559,13 @@ def choose_trial_to_run(self, trial_runner):
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
trial_runner.has_resources(trial.resources):
- candidates.append(trial)
+ if not self._synch:
+ candidates.append(trial)
+ elif self._trial_state[trial].last_train_time < \
+ self._next_perturbation_sync:
+ candidates.append(trial)
candidates.sort(
- key=lambda trial: self._trial_state[trial].last_perturbation_time)
+ key=lambda trial: self._trial_state[trial].last_train_time)
return candidates[0] if candidates else None
def reset_stats(self):
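A sketch of enabling the new synchronous mode, using the argument names defined in the diff above (metric, hyperparameter, and interval values are illustrative):

    from ray import tune
    from ray.tune.schedulers import PopulationBasedTraining

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=4,  # must be > 0, now enforced
        synch=True,               # pause all trials and perturb together
        hyperparam_mutations={
            "lr": tune.loguniform(1e-4, 1e-1),
        })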
diff --git a/python/ray/tune/suggest/ax.py b/python/ray/tune/suggest/ax.py
index 9e58c4cb1f11..cdc86790a730 100644
--- a/python/ray/tune/suggest/ax.py
+++ b/python/ray/tune/suggest/ax.py
@@ -1,3 +1,11 @@
+from typing import Dict
+
+from ray.tune.sample import Categorical, Float, Integer, LogUniform, \
+ Quantized, Uniform
+from ray.tune.suggest.variant_generator import parse_spec_vars
+from ray.tune.utils import flatten_dict
+from ray.tune.utils.util import unflatten_dict
+
try:
import ax
except ImportError:
@@ -94,7 +102,7 @@ def suggest(self, trial_id):
return None
parameters, trial_index = self._ax.get_next_trial()
self._live_trial_mapping[trial_id] = trial_index
- return parameters
+ return unflatten_dict(parameters)
def on_trial_complete(self, trial_id, result=None, error=False):
"""Notification for the completion of trial.
@@ -117,3 +125,79 @@ def _process_result(self, trial_id, result):
metric_dict.update({on: (result[on], 0.0) for on in outcome_names})
self._ax.complete_trial(
trial_index=ax_trial_index, raw_data=metric_dict)
+
+ @staticmethod
+ def convert_search_space(spec: Dict):
+ spec = flatten_dict(spec)
+ resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
+
+ if grid_vars:
+ raise ValueError(
+ "Grid search parameters cannot be automatically converted "
+ "to an Ax search space.")
+
+ values = [{
+ "name": "/".join(path),
+ "type": "fixed",
+ "value": val
+ } for path, val in resolved_vars]
+
+ def resolve_value(par, domain):
+ sampler = domain.get_sampler()
+ if isinstance(sampler, Quantized):
+ logger.warning("Ax search does not support quantization. "
+ "Dropped quantization.")
+ sampler = sampler.sampler
+
+ if isinstance(domain, Float):
+ if isinstance(sampler, LogUniform):
+ return {
+ "name": par,
+ "type": "range",
+ "bounds": [domain.min, domain.max],
+ "value_type": "float",
+ "log_scale": True
+ }
+ elif isinstance(sampler, Uniform):
+ return {
+ "name": par,
+ "type": "range",
+ "bounds": [domain.min, domain.max],
+ "value_type": "float",
+ "log_scale": False
+ }
+ elif isinstance(domain, Integer):
+ if isinstance(sampler, LogUniform):
+ return {
+ "name": par,
+ "type": "range",
+ "bounds": [domain.min, domain.max],
+ "value_type": "int",
+ "log_scale": True
+ }
+ elif isinstance(sampler, Uniform):
+ return {
+ "name": par,
+ "type": "range",
+ "bounds": [domain.min, domain.max],
+ "value_type": "int",
+ "log_scale": False
+ }
+ elif isinstance(domain, Categorical):
+ if isinstance(sampler, Uniform):
+ return {
+ "name": par,
+ "type": "choice",
+ "values": domain.categories
+ }
+
+ raise ValueError("Ax search does not support parameters of type "
+ "`{}` with samplers of type `{}`".format(
+ type(domain).__name__,
+ type(domain.sampler).__name__))
+
+ # Parameter name is e.g. "a/b/c" for nested dicts
+ for path, domain in domain_vars:
+ par = "/".join(path)
+ values.append(resolve_value(par, domain))
+ return values
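For reference, a short sketch of the mapping `convert_search_space` produces for a log-uniform float, consistent with the unit test added later in this diff:

    from ray import tune
    from ray.tune.suggest.ax import AxSearch

    config = {"lr": tune.sample.Float(1e-4, 1e-2).loguniform()}
    print(AxSearch.convert_search_space(config))
    # -> [{"name": "lr", "type": "range", "bounds": [0.0001, 0.01],
    #      "value_type": "float", "log_scale": True}]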
diff --git a/python/ray/tune/suggest/bayesopt.py b/python/ray/tune/suggest/bayesopt.py
index c647966fe1ae..99cd0e1ce293 100644
--- a/python/ray/tune/suggest/bayesopt.py
+++ b/python/ray/tune/suggest/bayesopt.py
@@ -1,8 +1,13 @@
-import copy
from collections import defaultdict
import logging
import pickle
import json
+from typing import Dict
+
+from ray.tune.sample import Float, Quantized
+from ray.tune.suggest.variant_generator import parse_spec_vars
+from ray.tune.utils.util import flatten_dict, unflatten_dict
+
try: # Python 3 only -- needed for lint test.
import bayes_opt as byo
except ImportError:
@@ -214,7 +219,7 @@ def suggest(self, trial_id):
self._live_trial_mapping[trial_id] = config
# Return a deep copy of the mapping
- return copy.deepcopy(config)
+ return unflatten_dict(config)
def register_analysis(self, analysis):
"""Integrate the given analysis into the gaussian process.
@@ -283,3 +288,44 @@ def restore(self, checkpoint_path):
(self.optimizer, self._buffered_trial_results,
self._total_random_search_trials,
self._config_counter) = pickle.load(f)
+
+ @staticmethod
+ def convert_search_space(spec: Dict):
+ spec = flatten_dict(spec)
+ resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
+
+ if grid_vars:
+ raise ValueError(
+ "Grid search parameters cannot be automatically converted "
+ "to a BayesOpt search space.")
+
+ if resolved_vars:
+ raise ValueError(
+ "BayesOpt does not support fixed parameters. Please find a "
+ "different way to pass constants to your training function.")
+
+ def resolve_value(domain):
+ sampler = domain.get_sampler()
+ if isinstance(sampler, Quantized):
+ logger.warning(
+ "BayesOpt search does not support quantization. "
+ "Dropped quantization.")
+ sampler = sampler.get_sampler()
+
+ if isinstance(domain, Float):
+ if domain.sampler is not None:
+ logger.warning(
+ "BayesOpt does not support specific sampling methods. "
+ "The {} sampler will be dropped.".format(sampler))
+ return (domain.min, domain.max)
+
+ raise ValueError("BayesOpt does not support parameters of type "
+ "`{}`".format(type(domain).__name__))
+
+ # Parameter name is e.g. "a/b/c" for nested dicts
+ bounds = {
+ "/".join(path): resolve_value(domain)
+ for path, domain in domain_vars
+ }
+
+ return bounds
diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py
index 5bdb1cbbdd08..83e76762f71e 100644
--- a/python/ray/tune/suggest/hyperopt.py
+++ b/python/ray/tune/suggest/hyperopt.py
@@ -1,8 +1,16 @@
+from typing import Dict
+
import numpy as np
import copy
import logging
from functools import partial
import pickle
+
+from ray.tune.sample import (Categorical, Float, Integer, LogUniform, Normal,
+ Quantized, Uniform)
+from ray.tune.suggest.variant_generator import assign_value, parse_spec_vars
+
try:
hyperopt_logger = logging.getLogger("hyperopt")
hyperopt_logger.setLevel(logging.WARNING)
@@ -235,3 +243,76 @@ def restore(self, checkpoint_path):
self.rstate.set_state(trials_object[1])
else:
self.set_state(trials_object)
+
+ @staticmethod
+ def convert_search_space(spec: Dict):
+ spec = copy.deepcopy(spec)
+ resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
+
+ if not domain_vars and not grid_vars:
+ return []
+
+ if grid_vars:
+ raise ValueError(
+ "Grid search parameters cannot be automatically converted "
+ "to a HyperOpt search space.")
+
+ def resolve_value(par, domain):
+ quantize = None
+
+ sampler = domain.get_sampler()
+ if isinstance(sampler, Quantized):
+ quantize = sampler.q
+ sampler = sampler.sampler
+
+ if isinstance(domain, Float):
+ if isinstance(sampler, LogUniform):
+ if quantize:
+ # hp.qloguniform expects log-space bounds, like hp.loguniform.
+ return hpo.hp.qloguniform(par, np.log(domain.min),
+ np.log(domain.max), quantize)
+ return hpo.hp.loguniform(par, np.log(domain.min),
+ np.log(domain.max))
+ elif isinstance(sampler, Uniform):
+ if quantize:
+ return hpo.hp.quniform(par, domain.min, domain.max,
+ quantize)
+ return hpo.hp.uniform(par, domain.min, domain.max)
+ elif isinstance(sampler, Normal):
+ if quantize:
+ return hpo.hp.qnormal(par, sampler.mean, sampler.sd,
+ quantize)
+ return hpo.hp.normal(par, sampler.mean, sampler.sd)
+
+ elif isinstance(domain, Integer):
+ if isinstance(sampler, Uniform):
+ if quantize:
+ logger.warning(
+ "HyperOpt does not support quantization for "
+ "integer values. Dropped quantization.")
+ if domain.min != 0:
+ logger.warning(
+ "HyperOpt only allows integer sampling with "
+ "lower bound 0. Dropped the lower bound {}".format(
+ domain.min))
+ if domain.max < 1:
+ raise ValueError(
+ "HyperOpt does not support integer sampling "
+ "of values lower than 0. Set your maximum range "
+ "to something above 0 (currently {})".format(
+ domain.max))
+ return hpo.hp.randint(par, domain.max)
+ elif isinstance(domain, Categorical):
+ if isinstance(sampler, Uniform):
+ return hpo.hp.choice(par, domain.categories)
+
+ raise ValueError("HyperOpt does not support parameters of type "
+ "`{}` with samplers of type `{}`".format(
+ type(domain).__name__,
+ type(domain.sampler).__name__))
+
+ for path, domain in domain_vars:
+ par = "/".join(path)
+ value = resolve_value(par, domain)
+ assign_value(spec, path, value)
+
+ return spec
diff --git a/python/ray/tune/suggest/optuna.py b/python/ray/tune/suggest/optuna.py
index 51d3efc14ee7..5c22c3cf21fe 100644
--- a/python/ray/tune/suggest/optuna.py
+++ b/python/ray/tune/suggest/optuna.py
@@ -1,7 +1,13 @@
+import copy
import logging
import pickle
+from typing import Dict
from ray.tune.result import TRAINING_ITERATION
+from ray.tune.sample import Categorical, Float, Integer, LogUniform, \
+ Quantized, Uniform
+from ray.tune.suggest.variant_generator import assign_value, parse_spec_vars
+from ray.tune.utils import flatten_dict
try:
import optuna as ot
@@ -53,6 +59,10 @@ class OptunaSearch(Searcher):
minimizing or maximizing the metric attribute.
sampler (optuna.samplers.BaseSampler): Optuna sampler used to
draw hyperparameter configurations. Defaults to ``TPESampler``.
+ config (dict): Base config dict that gets overwritten by the Optuna
+ sampling and is returned to each Tune trial. This can be used,
+ e.g., to pass static parameters or constant configuration values
+ to each trial.
Example:
@@ -80,6 +90,7 @@ def __init__(
metric="episode_reward_mean",
mode="max",
sampler=None,
+ config=None,
):
assert ot is not None, (
"Optuna must be installed! Run `pip install optuna`.")
@@ -91,6 +102,8 @@ def __init__(
self._space = space
+ self._config = config or {}
+
self._study_name = "optuna" # Fixed study name for in-memory storage
self._sampler = sampler or ot.samplers.TPESampler()
assert isinstance(self._sampler, ot.samplers.BaseSampler), \
@@ -116,10 +129,11 @@ def suggest(self, trial_id):
self._ot_trials[trial_id] = ot.trial.Trial(self._ot_study,
ot_trial_id)
ot_trial = self._ot_trials[trial_id]
- params = {}
+ params = copy.copy(self._config)
for (fn, args, kwargs) in self._space:
param_name = args[0] if len(args) > 0 else kwargs["name"]
- params[param_name] = getattr(ot_trial, fn)(*args, **kwargs)
+ value = getattr(ot_trial, fn)(*args, **kwargs) # Call Optuna trial
+ assign_value(params, param_name.split("/"), value)
return params
def on_trial_result(self, trial_id, result):
@@ -147,3 +161,66 @@ def restore(self, checkpoint_path):
save_object = pickle.load(inputFile)
self._storage, self._pruner, self._sampler, \
self._ot_trials, self._ot_study = save_object
+
+ @staticmethod
+ def convert_search_space(spec: Dict):
+ spec = flatten_dict(spec)
+ resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
+
+ if not domain_vars and not grid_vars:
+ return []
+
+ if grid_vars:
+ raise ValueError(
+ "Grid search parameters cannot be automatically converted "
+ "to an Optuna search space.")
+
+ def resolve_value(par, domain):
+ quantize = None
+
+ sampler = domain.get_sampler()
+ if isinstance(sampler, Quantized):
+ quantize = sampler.q
+ sampler = sampler.sampler
+
+ if isinstance(domain, Float):
+ if isinstance(sampler, LogUniform):
+ if quantize:
+ logger.warning(
+ "Optuna does not support both quantization and "
+ "sampling from LogUniform. Dropped quantization.")
+ return param.suggest_loguniform(par, domain.min,
+ domain.max)
+ elif isinstance(sampler, Uniform):
+ if quantize:
+ return param.suggest_discrete_uniform(
+ par, domain.min, domain.max, quantize)
+ return param.suggest_uniform(par, domain.min, domain.max)
+ elif isinstance(domain, Integer):
+ if isinstance(sampler, LogUniform):
+ if quantize:
+ logger.warning(
+ "Optuna does not support both quantization and "
+ "sampling from LogUniform. Dropped quantization.")
+ return param.suggest_int(
+ par, domain.min, domain.max, log=True)
+ elif isinstance(sampler, Uniform):
+ return param.suggest_int(
+ par, domain.min, domain.max, step=quantize or 1)
+ elif isinstance(domain, Categorical):
+ if isinstance(sampler, Uniform):
+ return param.suggest_categorical(par, domain.categories)
+
+ raise ValueError(
+ "Optuna search does not support parameters of type "
+ "`{}` with samplers of type `{}`".format(
+ type(domain).__name__,
+ type(domain.sampler).__name__))
+
+ # Parameter name is e.g. "a/b/c" for nested dicts
+ values = [
+ resolve_value("/".join(path), domain)
+ for path, domain in domain_vars
+ ]
+
+ return values
diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py
index e772ffb4933c..181296faec03 100644
--- a/python/ray/tune/suggest/variant_generator.py
+++ b/python/ray/tune/suggest/variant_generator.py
@@ -4,7 +4,7 @@
import random
from ray.tune import TuneError
-from ray.tune.sample import sample_from
+from ray.tune.sample import Categorical, Domain, Function
logger = logging.getLogger(__name__)
@@ -115,25 +115,38 @@ def _clean_value(value):
return str(value).replace("/", "_")
-def _generate_variants(spec):
- spec = copy.deepcopy(spec)
- unresolved = _unresolved_values(spec)
+def parse_spec_vars(spec):
+ resolved, unresolved = _split_resolved_unresolved_values(spec)
+ resolved_vars = []
+ for path, value in resolved.items():
+ resolved_vars.append((path, value))
+
if not unresolved:
- yield {}, spec
- return
+ return resolved_vars, [], []
grid_vars = []
- lambda_vars = []
+ domain_vars = []
for path, value in unresolved.items():
- if callable(value):
- lambda_vars.append((path, value))
- else:
+ if value.is_grid():
grid_vars.append((path, value))
+ else:
+ domain_vars.append((path, value))
grid_vars.sort()
+ return resolved_vars, domain_vars, grid_vars
+
+
+def _generate_variants(spec):
+ spec = copy.deepcopy(spec)
+ _, domain_vars, grid_vars = parse_spec_vars(spec)
+
+ if not domain_vars and not grid_vars:
+ yield {}, spec
+ return
+
grid_search = _grid_search_generator(spec, grid_vars)
for resolved_spec in grid_search:
- resolved_vars = _resolve_lambda_vars(resolved_spec, lambda_vars)
+ resolved_vars = _resolve_domain_vars(resolved_spec, domain_vars)
for resolved, spec in _generate_variants(resolved_spec):
for path, value in grid_vars:
resolved_vars[path] = _get_value(spec, path)
@@ -148,7 +161,7 @@ def _generate_variants(spec):
yield resolved_vars, spec
-def _assign_value(spec, path, value):
+def assign_value(spec, path, value):
for k in path[:-1]:
spec = spec[k]
spec[path[-1]] = value
@@ -160,23 +173,24 @@ def _get_value(spec, path):
return spec
-def _resolve_lambda_vars(spec, lambda_vars):
+def _resolve_domain_vars(spec, domain_vars):
resolved = {}
error = True
num_passes = 0
while error and num_passes < _MAX_RESOLUTION_PASSES:
num_passes += 1
error = False
- for path, fn in lambda_vars:
+ for path, domain in domain_vars:
try:
- value = fn(_UnresolvedAccessGuard(spec))
+ value = domain.sample(_UnresolvedAccessGuard(spec))
except RecursiveDependencyError as e:
error = e
except Exception:
raise ValueError(
- "Failed to evaluate expression: {}: {}".format(path, fn))
+ "Failed to evaluate expression: {}: {}".format(
+ path, domain))
else:
- _assign_value(spec, path, value)
+ assign_value(spec, path, value)
resolved[path] = value
if error:
raise error
@@ -203,7 +217,7 @@ def increment(i):
while value_indices[-1] < len(grid_vars[-1][1]):
spec = copy.deepcopy(unresolved_spec)
for i, (path, values) in enumerate(grid_vars):
- _assign_value(spec, path, values[value_indices[i]])
+ assign_value(spec, path, values[value_indices[i]])
yield spec
if grid_vars:
done = increment(0)
@@ -217,13 +231,13 @@ def _is_resolved(v):
def _try_resolve(v):
- if isinstance(v, sample_from):
- # Function to sample from
- return False, v.func
+ if isinstance(v, Domain):
+ # Domain to sample from
+ return False, v
elif isinstance(v, dict) and len(v) == 1 and "eval" in v:
# Lambda function in eval syntax
- return False, lambda spec: eval(
- v["eval"], _STANDARD_IMPORTS, {"spec": spec})
+ return False, Function(
+ lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec}))
elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
# Grid search values
grid_values = v["grid_search"]
@@ -231,26 +245,41 @@ def _try_resolve(v):
raise TuneError(
"Grid search expected list of values, got: {}".format(
grid_values))
- return False, grid_values
+ return False, Categorical(grid_values).grid()
return True, v
-def _unresolved_values(spec):
- found = {}
+def _split_resolved_unresolved_values(spec):
+ resolved_vars = {}
+ unresolved_vars = {}
for k, v in spec.items():
resolved, v = _try_resolve(v)
if not resolved:
- found[(k, )] = v
+ unresolved_vars[(k, )] = v
elif isinstance(v, dict):
# Recurse into a dict
- for (path, value) in _unresolved_values(v).items():
- found[(k, ) + path] = value
+ _resolved_children, _unresolved_children = \
+ _split_resolved_unresolved_values(v)
+ for (path, value) in _resolved_children.items():
+ resolved_vars[(k, ) + path] = value
+ for (path, value) in _unresolved_children.items():
+ unresolved_vars[(k, ) + path] = value
elif isinstance(v, list):
# Recurse into a list
for i, elem in enumerate(v):
- for (path, value) in _unresolved_values({i: elem}).items():
- found[(k, ) + path] = value
- return found
+ _resolved_children, _unresolved_children = \
+ _split_resolved_unresolved_values({i: elem})
+ for (path, value) in _resolved_children.items():
+ resolved_vars[(k, ) + path] = value
+ for (path, value) in _unresolved_children.items():
+ unresolved_vars[(k, ) + path] = value
+ else:
+ resolved_vars[(k, )] = v
+ return resolved_vars, unresolved_vars
+
+
+def _unresolved_values(spec):
+ return _split_resolved_unresolved_values(spec)[1]
class _UnresolvedAccessGuard(dict):
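A small sketch of what the new `parse_spec_vars` helper returns for a mixed spec; paths are tuples of keys, and the expected outputs in the comments are illustrative:

    from ray import tune
    from ray.tune.suggest.variant_generator import parse_spec_vars

    spec = {
        "a": 4,                              # resolved constant
        "b": {"x": tune.uniform(0.0, 1.0)},  # domain to sample from
        "c": tune.grid_search([1, 2, 3]),    # grid search values
    }
    resolved, domains, grids = parse_spec_vars(spec)
    # resolved -> [(("a",), 4)]
    # domains  -> [(("b", "x"), <Float domain>)]
    # grids    -> [(("c",), <Categorical domain with Grid sampler>)]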
diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py
new file mode 100644
index 000000000000..4e786a266606
--- /dev/null
+++ b/python/ray/tune/tests/test_sample.py
@@ -0,0 +1,260 @@
+import numpy as np
+import unittest
+
+from ray import tune
+
+
+class SearchSpaceTest(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def testBoundedFloat(self):
+ bounded = tune.sample.Float(-4.2, 8.3)
+
+ # Don't allow to specify more than one sampler
+ with self.assertRaises(ValueError):
+ bounded.normal().uniform()
+
+ # Normal
+ samples = bounded.normal(-4, 2).sample(size=1000)
+ self.assertTrue(any(-4.2 < s < 8.3 for s in samples))
+ self.assertTrue(np.mean(samples) < -2)
+
+ # Uniform
+ samples = bounded.uniform().sample(size=1000)
+ self.assertTrue(any(-4.2 < s < 8.3 for s in samples))
+ self.assertFalse(np.mean(samples) < -2)
+
+ # Loguniform
+ with self.assertRaises(ValueError):
+ bounded.loguniform().sample(size=1000)
+
+ bounded_positive = tune.sample.Float(1e-4, 1e-1)
+ samples = bounded_positive.loguniform().sample(size=1000)
+ self.assertTrue(any(1e-4 < s < 1e-1 for s in samples))
+
+ def testUnboundedFloat(self):
+ unbounded = tune.sample.Float()
+
+ # Require min and max bounds for loguniform
+ with self.assertRaises(ValueError):
+ unbounded.loguniform()
+
+ def testBoundedInt(self):
+ bounded = tune.sample.Integer(-3, 12)
+
+ samples = bounded.uniform().sample(size=1000)
+ self.assertTrue(any(-3 <= s < 12 for s in samples))
+ self.assertFalse(np.mean(samples) < 2)
+
+ def testCategorical(self):
+ categories = [-2, -1, 0, 1, 2]
+ cat = tune.sample.Categorical(categories)
+
+ samples = cat.uniform().sample(size=1000)
+ self.assertTrue(any([-2 <= s <= 2 for s in samples]))
+ self.assertTrue(all([c in samples for c in categories]))
+
+ def testIterative(self):
+ categories = [-2, -1, 0, 1, 2]
+
+ def test_iter():
+ for i in categories:
+ yield i
+
+ itr = tune.sample.Iterative(test_iter())
+ samples = itr.sample(size=5)
+ self.assertTrue(any([-2 <= s <= 2 for s in samples]))
+ self.assertTrue(all([c in samples for c in categories]))
+
+ itr = tune.sample.Iterative(iter(categories))
+ samples = itr.sample(size=5)
+ self.assertTrue(any([-2 <= s <= 2 for s in samples]))
+ self.assertTrue(all([c in samples for c in categories]))
+
+ def testFunction(self):
+ def sample(spec):
+ return np.random.uniform(-4, 4)
+
+ fnc = tune.sample.Function(sample)
+
+ samples = fnc.sample(size=1000)
+ self.assertTrue(any([-4 < s < 4 for s in samples]))
+ self.assertTrue(-2 < np.mean(samples) < 2)
+
+ def testQuantized(self):
+ bounded_positive = tune.sample.Float(1e-4, 1e-1)
+ samples = bounded_positive.loguniform().quantized(5e-4).sample(size=10)
+
+ for sample in samples:
+ factor = sample / 5e-4
+ self.assertAlmostEqual(factor, round(factor), places=10)
+
+ def testConvertAx(self):
+ from ray.tune.suggest.ax import AxSearch
+ from ax.service.ax_client import AxClient
+
+ config = {
+ "a": tune.sample.Categorical([2, 3, 4]).uniform(),
+ "b": {
+ "x": tune.sample.Integer(0, 5).quantized(2),
+ "y": 4,
+ "z": tune.sample.Float(1e-4, 1e-2).loguniform()
+ }
+ }
+ converted_config = AxSearch.convert_search_space(config)
+ ax_config = [
+ {
+ "name": "a",
+ "type": "choice",
+ "values": [2, 3, 4]
+ },
+ {
+ "name": "b/x",
+ "type": "range",
+ "bounds": [0, 5],
+ "value_type": "int"
+ },
+ {
+ "name": "b/y",
+ "type": "fixed",
+ "value": 4
+ },
+ {
+ "name": "b/z",
+ "type": "range",
+ "bounds": [1e-4, 1e-2],
+ "value_type": "float",
+ "log_scale": True
+ },
+ ]
+
+ client1 = AxClient(random_seed=1234)
+ client1.create_experiment(parameters=converted_config)
+ searcher1 = AxSearch(client1)
+
+ client2 = AxClient(random_seed=1234)
+ client2.create_experiment(parameters=ax_config)
+ searcher2 = AxSearch(client2)
+
+ config1 = searcher1.suggest("0")
+ config2 = searcher2.suggest("0")
+
+ self.assertEqual(config1, config2)
+ self.assertIn(config1["a"], [2, 3, 4])
+ self.assertIn(config1["b"]["x"], list(range(5)))
+ self.assertEqual(config1["b"]["y"], 4)
+ self.assertLess(1e-4, config1["b"]["z"])
+ self.assertLess(config1["b"]["z"], 1e-2)
+
+ def testConvertBayesOpt(self):
+ from ray.tune.suggest.bayesopt import BayesOptSearch
+
+ config = {
+ "a": tune.sample.Categorical([2, 3, 4]).uniform(),
+ "b": {
+ "x": tune.sample.Integer(0, 5).quantized(2),
+ "y": 4,
+ "z": tune.sample.Float(1e-4, 1e-2).loguniform()
+ }
+ }
+ with self.assertRaises(ValueError):
+ converted_config = BayesOptSearch.convert_search_space(config)
+
+ config = {"b": {"z": tune.sample.Float(1e-4, 1e-2).loguniform()}}
+ bayesopt_config = {"b/z": (1e-4, 1e-2)}
+
+ converted_config = BayesOptSearch.convert_search_space(config)
+
+ searcher1 = BayesOptSearch(space=converted_config, metric="none")
+ searcher2 = BayesOptSearch(space=bayesopt_config, metric="none")
+
+ config1 = searcher1.suggest("0")
+ config2 = searcher2.suggest("0")
+
+ self.assertEqual(config1, config2)
+ self.assertLess(1e-4, config1["b"]["z"])
+ self.assertLess(config1["b"]["z"], 1e-2)
+
+ def testConvertHyperOpt(self):
+ from ray.tune.suggest.hyperopt import HyperOptSearch
+ from hyperopt import hp
+
+ config = {
+ "a": tune.sample.Categorical([2, 3, 4]).uniform(),
+ "b": {
+ "x": tune.sample.Integer(0, 5).quantized(2),
+ "y": 4,
+ "z": tune.sample.Float(1e-4, 1e-2).loguniform()
+ }
+ }
+ converted_config = HyperOptSearch.convert_search_space(config)
+ hyperopt_config = {
+ "a": hp.choice("a", [2, 3, 4]),
+ "b": {
+ "x": hp.randint("x", 5),
+ "y": 4,
+ "z": hp.loguniform("z", np.log(1e-4), np.log(1e-2))
+ }
+ }
+
+ searcher1 = HyperOptSearch(
+ space=converted_config, random_state_seed=1234)
+ searcher2 = HyperOptSearch(
+ space=hyperopt_config, random_state_seed=1234)
+
+ config1 = searcher1.suggest("0")
+ config2 = searcher2.suggest("0")
+
+ self.assertEqual(config1, config2)
+ self.assertIn(config1["a"], [2, 3, 4])
+ self.assertIn(config1["b"]["x"], list(range(5)))
+ self.assertEqual(config1["b"]["y"], 4)
+ self.assertLess(1e-4, config1["b"]["z"])
+ self.assertLess(config1["b"]["z"], 1e-2)
+
+ def testConvertOptuna(self):
+ from ray.tune.suggest.optuna import OptunaSearch, param
+ from optuna.samplers import RandomSampler
+
+ config = {
+ "a": tune.sample.Categorical([2, 3, 4]).uniform(),
+ "b": {
+ "x": tune.sample.Integer(0, 5).quantized(2),
+ "y": 4,
+ "z": tune.sample.Float(1e-4, 1e-2).loguniform()
+ }
+ }
+ converted_config = OptunaSearch.convert_search_space(config)
+ optuna_config = [
+ param.suggest_categorical("a", [2, 3, 4]),
+ param.suggest_int("b/x", 0, 5, 2),
+ param.suggest_loguniform("b/z", 1e-4, 1e-2)
+ ]
+
+ sampler1 = RandomSampler(seed=1234)
+ searcher1 = OptunaSearch(
+ space=converted_config, sampler=sampler1, config=config)
+
+ sampler2 = RandomSampler(seed=1234)
+ searcher2 = OptunaSearch(
+ space=optuna_config, sampler=sampler2, config=config)
+
+ config1 = searcher1.suggest("0")
+ config2 = searcher2.suggest("0")
+
+ self.assertEqual(config1, config2)
+ self.assertIn(config1["a"], [2, 3, 4])
+ self.assertIn(config1["b"]["x"], list(range(5)))
+ self.assertEqual(config1["b"]["y"], 4)
+ self.assertLess(1e-4, config1["b"]["z"])
+ self.assertLess(config1["b"]["z"], 1e-2)
+
+
+if __name__ == "__main__":
+ import pytest
+ import sys
+ sys.exit(pytest.main(["-v", __file__]))
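For context on testQuantized above: it only asserts that every draw is an integer multiple of the quantum. A minimal standalone sketch of that contract, assuming .quantized(q) snaps raw samples to the nearest multiple of q (the quantize helper below is hypothetical, for illustration only):

import numpy as np

def quantize(values, q):
    # Snap each raw draw to the nearest integer multiple of the quantum q.
    return np.round(np.asarray(values) / q) * q

raw = np.exp(np.random.uniform(np.log(1e-4), np.log(1e-1), size=10))
for s in quantize(raw, 5e-4):
    factor = s / 5e-4
    assert abs(factor - round(factor)) < 1e-10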
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py
index 30c3649bcc71..ec7e96c5bf14 100644
--- a/python/ray/tune/tests/test_trial_scheduler.py
+++ b/python/ray/tune/tests/test_trial_scheduler.py
@@ -241,6 +241,7 @@ def has_resources(self, resources):
return True
def _pause_trial(self, trial):
+ self.trial_executor.save(trial, Checkpoint.MEMORY, None)
trial.status = Trial.PAUSED
def _launch_trial(self, trial):
@@ -702,6 +703,13 @@ def __init__(self, i, config):
self.custom_trial_name = None
self.custom_dirname = None
+ def on_checkpoint(self, checkpoint):
+ self.restored_checkpoint = checkpoint.value
+
+ @property
+ def checkpoint(self):
+ return Checkpoint(Checkpoint.MEMORY, self.trainable_name, None)
+
class PopulationBasedTestingSuite(unittest.TestCase):
def setUp(self):
@@ -720,7 +728,8 @@ def basicSetup(self,
require_attrs=True,
hyperparams=None,
hyperparam_mutations=None,
- step_once=True):
+ step_once=True,
+ synch=False):
hyperparam_mutations = hyperparam_mutations or {
"float_factor": lambda: 100.0,
"int_factor": lambda: 10,
@@ -734,7 +743,9 @@ def basicSetup(self,
hyperparam_mutations=hyperparam_mutations,
custom_explore_fn=explore,
log_config=log_config,
- require_attrs=require_attrs)
+ synch=synch,
+ require_attrs=require_attrs,
+ )
runner = _MockTrialRunner(pbt)
for i in range(num_trials):
trial_hyperparams = hyperparams or {
@@ -746,10 +757,17 @@ def basicSetup(self,
trial = _MockTrial(i, trial_hyperparams)
runner.add_trial(trial)
trial.status = Trial.RUNNING
+ for i in range(num_trials):
+ trial = runner.trials[i]
if step_once:
- self.assertEqual(
- pbt.on_trial_result(runner, trial, result(10, 50 * i)),
- TrialScheduler.CONTINUE)
+ if synch:
+ self.assertEqual(
+ pbt.on_trial_result(runner, trial, result(10, 50 * i)),
+ TrialScheduler.PAUSE)
+ else:
+ self.assertEqual(
+ pbt.on_trial_result(runner, trial, result(10, 50 * i)),
+ TrialScheduler.CONTINUE)
pbt.reset_stats()
return pbt, runner
@@ -822,6 +840,32 @@ def testCheckpointsMostPromisingTrials(self):
self.assertEqual(pbt._num_checkpoints, 2)
self.assertEqual(pbt._num_perturbations, 0)
+ def testCheckpointMostPromisingTrialsSynch(self):
+ pbt, runner = self.basicSetup(synch=True)
+ trials = runner.get_trials()
+
+ # no checkpoint: haven't hit next perturbation interval yet
+ self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
+ self.assertEqual(
+ pbt.on_trial_result(runner, trials[0], result(15, 200)),
+ TrialScheduler.CONTINUE)
+ self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
+ self.assertEqual(pbt._num_checkpoints, 0)
+
+ # trials should be paused until all trials are synced.
+ for i in range(len(trials) - 1):
+ self.assertEqual(
+ pbt.on_trial_result(runner, trials[i], result(20, 200 + i)),
+ TrialScheduler.PAUSE)
+
+ self.assertEqual(pbt.last_scores(trials), [200, 201, 202, 203, 200])
+ self.assertEqual(pbt._num_checkpoints, 0)
+
+ self.assertEqual(
+ pbt.on_trial_result(runner, trials[-1], result(20, 204)),
+ TrialScheduler.PAUSE)
+ self.assertEqual(pbt._num_checkpoints, 2)
+
def testPerturbsLowPerformingTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
@@ -852,6 +896,35 @@ def testPerturbsLowPerformingTrials(self):
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertTrue("@perturbed" in trials[2].experiment_tag)
+ def testPerturbsLowPerformingTrialsSynch(self):
+ pbt, runner = self.basicSetup(synch=True)
+ trials = runner.get_trials()
+
+ # no perturbation: haven't hit next perturbation interval
+ self.assertEqual(
+ pbt.on_trial_result(runner, trials[-1], result(15, -100)),
+ TrialScheduler.CONTINUE)
+ self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
+ self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
+ self.assertEqual(pbt._num_perturbations, 0)
+
+ # Don't perturb until all trials are synched.
+ self.assertEqual(
+ pbt.on_trial_result(runner, trials[-1], result(20, -100)),
+ TrialScheduler.PAUSE)
+ self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, -100])
+ self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
+
+ # Synch all trials.
+ for i in range(len(trials) - 1):
+ self.assertEqual(
+ pbt.on_trial_result(runner, trials[i], result(20, -10 * i)),
+ TrialScheduler.PAUSE)
+ self.assertEqual(pbt.last_scores(trials), [0, -10, -20, -30, -100])
+ self.assertIn(trials[-1].restored_checkpoint, ["trial_0", "trial_1"])
+ self.assertIn(trials[-2].restored_checkpoint, ["trial_0", "trial_1"])
+ self.assertEqual(pbt._num_perturbations, 2)
+
def testPerturbWithoutResample(self):
pbt, runner = self.basicSetup(resample_prob=0.0)
trials = runner.get_trials()
@@ -1088,6 +1161,27 @@ def testSchedulesMostBehindTrialToRun(self):
trials[i].status = Trial.PENDING
self.assertEqual(pbt.choose_trial_to_run(runner), trials[3])
+ def testSchedulesMostBehindTrialToRunSynch(self):
+ pbt, runner = self.basicSetup(synch=True)
+ trials = runner.get_trials()
+ runner.process_action(
+ trials[0], pbt.on_trial_result(runner, trials[0], result(
+ 800, 1000)))
+ runner.process_action(
+ trials[1], pbt.on_trial_result(runner, trials[1], result(
+ 700, 1001)))
+ runner.process_action(
+ trials[2], pbt.on_trial_result(runner, trials[2], result(
+ 600, 1002)))
+ runner.process_action(
+ trials[3], pbt.on_trial_result(runner, trials[3], result(
+ 500, 1003)))
+ runner.process_action(
+ trials[4], pbt.on_trial_result(runner, trials[4], result(
+ 700, 1004)))
+ self.assertIn(
+ pbt.choose_trial_to_run(runner), [trials[0], trials[1], trials[3]])
+
def testPerturbationResetsLastPerturbTime(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
@@ -1141,6 +1235,44 @@ def check_policy(policy):
check_policy(json.loads(line))
shutil.rmtree(tmpdir)
+ def testLogConfigSynch(self):
+ def check_policy(policy):
+ self.assertIsInstance(policy[2], int)
+ self.assertIsInstance(policy[3], int)
+ self.assertIn(policy[0], ["0tag", "1tag"])
+ self.assertIn(policy[1], ["3tag", "4tag"])
+ self.assertIn(policy[2], [0, 1])
+ self.assertIn(policy[3], [3, 4])
+ for i in [4, 5]:
+ self.assertIsInstance(policy[i], dict)
+ for key in [
+ "const_factor", "int_factor", "float_factor",
+ "id_factor"
+ ]:
+ self.assertIn(key, policy[i])
+ self.assertIsInstance(policy[i]["float_factor"], float)
+ self.assertIsInstance(policy[i]["int_factor"], int)
+ self.assertIn(policy[i]["const_factor"], [3])
+ self.assertIn(policy[i]["int_factor"], [8, 10, 12])
+ self.assertIn(policy[i]["float_factor"], [2.4, 2, 1.6])
+ self.assertIn(policy[i]["id_factor"], [3, 4, 100])
+
+ pbt, runner = self.basicSetup(
+ log_config=True, synch=True, step_once=False)
+ trials = runner.get_trials()
+ tmpdir = tempfile.mkdtemp()
+ for i, trial in enumerate(trials):
+ trial.local_dir = tmpdir
+ trial.last_result = {TRAINING_ITERATION: i}
+ pbt.on_trial_result(runner, trials[i], result(10, i))
+ log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_1.txt"]
+ for log_file in log_files:
+ self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file)))
+ raw_policy = open(os.path.join(tmpdir, log_file), "r").readlines()
+ for line in raw_policy:
+ check_policy(json.loads(line))
+ shutil.rmtree(tmpdir)
+
def testReplay(self):
# Returns unique increasing parameter mutations
class _Counter:
@@ -1156,6 +1288,7 @@ def __call__(self, *args, **kwargs):
perturbation_interval=5,
log_config=True,
step_once=False,
+ synch=False,
hyperparam_mutations={
"float_factor": lambda: 100.0,
"int_factor": _Counter(1000)
@@ -1292,6 +1425,176 @@ def load_checkpoint(self, checkpoint):
shutil.rmtree(tmpdir)
+ def testReplaySynch(self):
+ # Returns unique increasing parameter mutations
+ class _Counter:
+ def __init__(self, start=0):
+ self.count = start - 1
+
+ def __call__(self, *args, **kwargs):
+ self.count += 1
+ return self.count
+
+ pbt, runner = self.basicSetup(
+ num_trials=4,
+ perturbation_interval=5,
+ log_config=True,
+ step_once=False,
+ synch=True,
+ hyperparam_mutations={
+ "float_factor": lambda: 100.0,
+ "int_factor": _Counter(1000)
+ })
+ trials = runner.get_trials()
+ tmpdir = tempfile.mkdtemp()
+
+ # Internal trial state to collect the real PBT history
+ class _TrialState:
+ def __init__(self, config):
+ self.step = 0
+ self.config = config
+ self.history = []
+
+ def forward(self, t):
+ while self.step < t:
+ self.history.append(self.config)
+ self.step += 1
+
+ trial_state = []
+ for i, trial in enumerate(trials):
+ trial.local_dir = tmpdir
+ trial.last_result = {TRAINING_ITERATION: 0}
+ trial_state.append(_TrialState(trial.config))
+
+ # Helper function to simulate stepping trial k a number of steps,
+ # and reporting a score at the end
+ def trial_step(k, steps, score, synced=False):
+ res = result(trial_state[k].step + steps, score)
+
+ trials[k].last_result = res
+ trial_state[k].forward(res[TRAINING_ITERATION])
+
+ trials[k].status = Trial.RUNNING
+ if not synced:
+ action = pbt.on_trial_result(runner, trials[k], res)
+ runner.process_action(trials[k], action)
+ return
+ else:
+ # Reached synchronization point
+ old_configs = [trial.config for trial in trials]
+ action = pbt.on_trial_result(runner, trials[k], res)
+ runner.process_action(trials[k], action)
+ new_configs = [trial.config for trial in trials]
+
+ for i in range(len(trials)):
+ old_config = old_configs[i]
+ new_config = new_configs[i]
+ if old_config != new_config:
+ # Copy history from source trial
+ source = -1
+ for m, cand in enumerate(trials):
+ if cand.trainable_name == trials[
+ i].restored_checkpoint:
+ source = m
+ break
+ assert source >= 0
+ trial_state[i].history = trial_state[
+ source].history.copy()
+ trial_state[i].step = trial_state[source].step
+ trial_state[i].config = new_config.copy()
+
+ # Initial steps
+ trial_step(0, 10, 0)
+ trial_step(1, 11, 10)
+ trial_step(2, 12, 0)
+ trial_step(3, 13, -1, synced=True)
+
+ # 3 <-- 1, new_t 11
+ # next_perturb_sync = 13
+
+ # Next block
+ trial_step(0, 17, -10) # 20
+ trial_step(2, 15, -20) # 20
+ trial_step(3, 16, 0) # 20
+ trial_step(1, 7, 1, synced=True) # 18
+
+ # 2 <-- 1, new_t=11+7=18
+ # next_perturb_sync = 20
+
+ # Next block
+ trial_step(2, 13, 0) # 31
+ trial_step(3, 14, 10) # 34
+ trial_step(0, 11, -1) # 31
+ trial_step(1, 12, 0, synced=True) # 30
+
+ # 0 <-- 3, new_t=11+9+14=34
+ # next_perturb_sync = 34
+
+ # Next block
+ trial_step(0, 6, 20) # 40
+ trial_step(3, 9, -40) # 43
+ trial_step(2, 8, -50) # 39
+ trial_step(1, 7, 30, synced=True) # 37
+
+ # 2 <-- 1, new_t=18+13+8=37
+ # next_perturb_sync = 43
+
+ # Playback trainable to collect configs at each step
+ class Playback(Trainable):
+ def setup(self, config):
+ self.config = config
+ self.replayed = []
+ self.iter = 0
+
+ def step(self):
+ self.iter += 1
+ self.replayed.append(self.config)
+ return {
+ "reward": 0,
+ "done": False,
+ "replayed": self.replayed,
+ TRAINING_ITERATION: self.iter
+ }
+
+ def reset_config(self, new_config):
+ self.config = new_config
+ return True
+
+ def save_checkpoint(self, tmp_checkpoint_dir):
+ return tmp_checkpoint_dir
+
+ def load_checkpoint(self, checkpoint):
+ pass
+
+ # Loop through all trials and check if PBT history is the
+ # same as the playback history
+ for i, trial in enumerate(trials):
+ if trial.trial_id in ["1"]: # Did not exploit anything
+ continue
+
+ replay = PopulationBasedTrainingReplay(
+ os.path.join(tmpdir,
+ "pbt_policy_{}.txt".format(trial.trial_id)))
+ analysis = tune.run(
+ Playback,
+ scheduler=replay,
+ stop={TRAINING_ITERATION: trial_state[i].step})
+
+ replayed = analysis.trials[0].last_result["replayed"]
+ self.assertSequenceEqual(trial_state[i].history, replayed)
+
+ # Trial 1 did not exploit anything and should raise an error
+ with self.assertRaises(ValueError):
+ replay = PopulationBasedTrainingReplay(
+ os.path.join(tmpdir,
+ "pbt_policy_{}.txt".format(trials[1].trial_id)))
+ tune.run(
+ Playback,
+ scheduler=replay,
+ stop={TRAINING_ITERATION: trial_state[1].step})
+
+ shutil.rmtree(tmpdir)
+
def testPostprocessingHook(self):
def explore(new_config):
new_config["id_factor"] = 42
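The synchronous-mode tests above all follow the same protocol: at the perturbation interval every trial returns PAUSE, and only the last straggler's result triggers checkpointing and exploit/explore across the whole population. A schematic of that barrier, as an illustrative reimplementation rather than the Tune source:

class SynchBarrier:
    def __init__(self, num_trials):
        self.num_trials = num_trials
        self.waiting = set()
        self.num_perturbations = 0

    def on_result(self, trial_id):
        self.waiting.add(trial_id)
        if len(self.waiting) == self.num_trials:
            # Last straggler arrived: exploit/explore the population,
            # then release the barrier for the next interval.
            self.num_perturbations += 1
            self.waiting.clear()
        return "PAUSE"  # paused trials are resumed via choose_trial_to_run()

barrier = SynchBarrier(num_trials=5)
assert [barrier.on_result(i) for i in range(5)] == ["PAUSE"] * 5
assert barrier.num_perturbations == 1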
diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py
index f4cfbdaadd4b..740616e8ce4d 100644
--- a/python/ray/tune/tests/test_trial_scheduler_pbt.py
+++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py
@@ -4,73 +4,93 @@
import random
import unittest
import sys
+import time
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
-class MockTrainable(tune.Trainable):
- def setup(self, config):
- self.iter = 0
- self.a = config["a"]
- self.b = config["b"]
- self.c = config["c"]
+class MockParam(object):
+ def __init__(self, params):
+ self._params = params
+ self._index = 0
- def step(self):
- self.iter += 1
- return {"mean_accuracy": (self.a - self.iter) * self.b}
+ def __call__(self, *args, **kwargs):
+ val = self._params[self._index % len(self._params)]
+ self._index += 1
+ return val
- def save_checkpoint(self, tmp_checkpoint_dir):
- checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
- with open(checkpoint_path, "wb") as fp:
- pickle.dump((self.a, self.b, self.iter), fp)
- return tmp_checkpoint_dir
- def load_checkpoint(self, tmp_checkpoint_dir):
- checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
- with open(checkpoint_path, "rb") as fp:
- self.a, self.b, self.iter = pickle.load(fp)
+class PopulationBasedTrainingSynchTest(unittest.TestCase):
+ def setUp(self):
+ ray.init(num_cpus=2)
+ def MockTrainingFuncSync(config, checkpoint_dir=None):
+ iter = 0
-def MockTrainingFunc(config, checkpoint_dir=None):
- iter = 0
- a = config["a"]
- b = config["b"]
+ if checkpoint_dir:
+ checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
+ with open(checkpoint_path, "rb") as fp:
+ a, iter = pickle.load(fp)
- if checkpoint_dir:
- checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
- with open(checkpoint_path, "rb") as fp:
- a, b, iter = pickle.load(fp)
+ a = config["a"] # Use the new hyperparameter if perturbed.
- while True:
- iter += 1
- with tune.checkpoint_dir(step=iter) as checkpoint_dir:
- checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
- with open(checkpoint_path, "wb") as fp:
- pickle.dump((a, b, iter), fp)
- tune.report(mean_accuracy=(a - iter) * b)
+ while True:
+ iter += 1
+ with tune.checkpoint_dir(step=iter) as checkpoint_dir:
+ checkpoint_path = os.path.join(checkpoint_dir,
+ "checkpoint")
+ with open(checkpoint_path, "wb") as fp:
+ pickle.dump((a, iter), fp)
+ # Score gets better every iteration.
+ time.sleep(1)
+ tune.report(mean_accuracy=iter + a, a=a)
+ self.MockTrainingFuncSync = MockTrainingFuncSync
-def MockTrainingFunc2(config):
- a = config["a"]
- b = config["b"]
- c1 = config["c"]["c1"]
- c2 = config["c"]["c2"]
+ def tearDown(self):
+ ray.shutdown()
- while True:
- tune.report(mean_accuracy=a * b * (c1 + c2))
+ def synchSetup(self, synch, param=[10, 20, 30]):
+ scheduler = PopulationBasedTraining(
+ time_attr="training_iteration",
+ metric="mean_accuracy",
+ mode="max",
+ perturbation_interval=1,
+ log_config=True,
+ hyperparam_mutations={"c": lambda: 1},
+ synch=synch)
+ param_a = MockParam(param)
-class MockParam(object):
- def __init__(self, params):
- self._params = params
- self._index = 0
+ random.seed(100)
+ np.random.seed(100)
+ analysis = tune.run(
+ self.MockTrainingFuncSync,
+ config={
+ "a": tune.sample_from(lambda _: param_a()),
+ "c": 1
+ },
+ fail_fast=True,
+ num_samples=3,
+ scheduler=scheduler,
+ name="testPBTSync",
+ stop={"training_iteration": 3},
+ )
+ return analysis
- def __call__(self, *args, **kwargs):
- val = self._params[self._index % len(self._params)]
- self._index += 1
- return val
+ def testAsynchFail(self):
+ analysis = self.synchSetup(False)
+ self.assertTrue(any(analysis.dataframe()["mean_accuracy"] != 33))
+
+ def testSynchPass(self):
+ analysis = self.synchSetup(True)
+ self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33))
+
+ def testSynchPassLast(self):
+ analysis = self.synchSetup(True, param=[30, 20, 10])
+ self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33))
class PopulationBasedTrainingConfigTest(unittest.TestCase):
@@ -81,6 +101,15 @@ def tearDown(self):
ray.shutdown()
def testNoConfig(self):
+ def MockTrainingFunc(config):
+ a = config["a"]
+ b = config["b"]
+ c1 = config["c"]["c1"]
+ c2 = config["c"]["c2"]
+
+ while True:
+ tune.report(mean_accuracy=a * b * (c1 + c2))
+
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
@@ -97,7 +126,7 @@ def testNoConfig(self):
)
tune.run(
- MockTrainingFunc2,
+ MockTrainingFunc,
fail_fast=True,
num_samples=4,
scheduler=scheduler,
@@ -120,6 +149,31 @@ def testPermutationContinuation(self):
fix was not applied.
See issue #9036.
"""
+
+ class MockTrainable(tune.Trainable):
+ def setup(self, config):
+ self.iter = 0
+ self.a = config["a"]
+ self.b = config["b"]
+ self.c = config["c"]
+
+ def step(self):
+ self.iter += 1
+ return {"mean_accuracy": (self.a - self.iter) * self.b}
+
+ def save_checkpoint(self, tmp_checkpoint_dir):
+ checkpoint_path = os.path.join(tmp_checkpoint_dir,
+ "model.mock")
+ with open(checkpoint_path, "wb") as fp:
+ pickle.dump((self.a, self.b, self.iter), fp)
+ return tmp_checkpoint_dir
+
+ def load_checkpoint(self, tmp_checkpoint_dir):
+ checkpoint_path = os.path.join(tmp_checkpoint_dir,
+ "model.mock")
+ with open(checkpoint_path, "rb") as fp:
+ self.a, self.b, self.iter = pickle.load(fp)
+
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
@@ -151,6 +205,25 @@ def testPermutationContinuation(self):
stop={"training_iteration": 3})
def testPermutationContinuationFunc(self):
+ def MockTrainingFunc(config, checkpoint_dir=None):
+ iter = 0
+ a = config["a"]
+ b = config["b"]
+
+ if checkpoint_dir:
+ checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
+ with open(checkpoint_path, "rb") as fp:
+ a, b, iter = pickle.load(fp)
+
+ while True:
+ iter += 1
+ with tune.checkpoint_dir(step=iter) as checkpoint_dir:
+ checkpoint_path = os.path.join(checkpoint_dir,
+ "model.mock")
+ with open(checkpoint_path, "wb") as fp:
+ pickle.dump((a, b, iter), fp)
+ tune.report(mean_accuracy=(a - iter) * b)
+
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
diff --git a/python/ray/tune/tests/test_var.py b/python/ray/tune/tests/test_var.py
index 3df8e7f16edb..082a9fe0345c 100644
--- a/python/ray/tune/tests/test_var.py
+++ b/python/ray/tune/tests/test_var.py
@@ -263,13 +263,13 @@ def testNestedValues(self):
})
def testLogUniform(self):
- sampler = tune.loguniform(1e-10, 1e-1).func
- results = [sampler(None) for i in range(1000)]
+ sampler = tune.loguniform(1e-10, 1e-1)
+ results = sampler.sample(None, 1000)
assert abs(np.log(min(results)) / np.log(10) - -10) < 0.1
assert abs(np.log(max(results)) / np.log(10) - -1) < 0.1
- sampler_e = tune.loguniform(np.e**-4, np.e, base=np.e).func
- results_e = [sampler_e(None) for i in range(1000)]
+ sampler_e = tune.loguniform(np.e**-4, np.e, base=np.e)
+ results_e = sampler_e.sample(None, 1000)
assert abs(np.log(min(results_e)) - -4) < 0.1
assert abs(np.log(max(results_e)) - 1) < 0.1
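The bounds these assertions rely on follow from drawing uniformly in log space and exponentiating; a minimal sketch of that sampling scheme, under the assumption that tune.loguniform works this way:

import numpy as np

def loguniform(low, high, size, base=10.0):
    # Sample uniformly between log_base(low) and log_base(high),
    # then map back through base ** x.
    log_low = np.log(low) / np.log(base)
    log_high = np.log(high) / np.log(base)
    return base ** np.random.uniform(log_low, log_high, size=size)

samples = loguniform(1e-10, 1e-1, size=1000)
# With 1000 draws the extremes land close to the bounds in log10 space.
assert abs(np.log10(samples.min()) - (-10)) < 0.1
assert abs(np.log10(samples.max()) - (-1)) < 0.1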
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py
index 8e644c452402..57c1d8119649 100644
--- a/python/ray/tune/trial_runner.py
+++ b/python/ray/tune/trial_runner.py
@@ -460,6 +460,7 @@ def _get_next_trial(self):
self._update_trial_queue(blocking=wait_for_trial)
with warn_if_slow("choose_trial_to_run"):
trial = self._scheduler_alg.choose_trial_to_run(self)
+ logger.debug("Running trial {}".format(trial))
return trial
def _process_events(self):
diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py
index fefcf7665c81..8b8136657953 100644
--- a/python/ray/tune/utils/util.py
+++ b/python/ray/tune/utils/util.py
@@ -96,9 +96,9 @@ def stop(self):
def pin_in_object_store(obj):
- """Deprecated, use ray.put(value, weakref=False) instead."""
+ """Deprecated, use ray.put(value) instead."""
- obj_ref = ray.put(obj, weakref=False)
+ obj_ref = ray.put(obj)
_pinned_objects.append(obj_ref)
return obj_ref
@@ -231,6 +231,18 @@ def flatten_dict(dt, delimiter="/"):
return dt
+def unflatten_dict(dt, delimiter="/"):
+    """Unflatten dict. Does not support unflattening lists."""
+    out = {}
+    for key, val in dt.items():
+        path = key.split(delimiter)
+        item = out
+        for k in path[:-1]:
+            # setdefault creates missing intermediate dicts, so keys nested
+            # deeper than two levels unflatten correctly too.
+            item = item.setdefault(k, {})
+        item[path[-1]] = val
+    return out
+
+
def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs):
"""
Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
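A quick round-trip check of the new helper against flatten_dict (the pre-existing Tune utility it inverts), using the setdefault-based body above:

from ray.tune.utils.util import flatten_dict, unflatten_dict

nested = {"a": 1, "b": {"x": 2, "y": {"z": 3}}}
flat = flatten_dict(nested)            # {"a": 1, "b/x": 2, "b/y/z": 3}
assert unflatten_dict(flat) == nested  # lists, per the docstring, stay unsupported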
diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py
index 17b048c16c50..6bf57683f100 100644
--- a/python/ray/util/__init__.py
+++ b/python/ray/util/__init__.py
@@ -2,13 +2,11 @@
from ray.util.actor_pool import ActorPool
from ray.util.debug import log_once, disable_log_once_globally, \
enable_periodic_logging
-from ray.util.named_actors import get_actor
__all__ = [
"ActorPool",
"disable_log_once_globally",
"enable_periodic_logging",
- "get_actor",
"iter",
"log_once",
]
diff --git a/python/ray/util/named_actors.py b/python/ray/util/named_actors.py
deleted file mode 100644
index b1d39fe2d100..000000000000
--- a/python/ray/util/named_actors.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import logging
-
-import ray
-
-logger = logging.getLogger(__name__)
-
-
-def _get_actor(name):
- worker = ray.worker.global_worker
- handle = worker.core_worker.get_named_actor_handle(name)
- return handle
-
-
-def get_actor(name: str) -> ray.actor.ActorHandle:
- """Get a named actor which was previously created.
-
- If the actor doesn't exist, an exception will be raised.
-
- Args:
- name: The name of the named actor.
-
- Returns:
- The ActorHandle object corresponding to the name.
- """
- logger.warning("ray.util.get_actor has been moved to ray.get_actor and "
- "will be removed in the future.")
- return _get_actor(name)
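With ray.util.named_actors removed, retrieval goes through the top-level API instead. A minimal migration sketch (the Counter actor is illustrative, and .options(name=...) is assumed as the naming mechanism):

import ray

@ray.remote
class Counter:
    def __init__(self):
        self.n = 0

    def inc(self):
        self.n += 1
        return self.n

ray.init()
counter = Counter.options(name="counter").remote()  # create a named actor
handle = ray.get_actor("counter")  # was: ray.util.get_actor("counter")
assert ray.get(handle.inc.remote()) == 1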
diff --git a/python/ray/utils.py b/python/ray/utils.py
index c41c76ba5601..f5071a7f9dbb 100644
--- a/python/ray/utils.py
+++ b/python/ray/utils.py
@@ -18,6 +18,10 @@
import ray.ray_constants as ray_constants
import psutil
+pwd = None
+if sys.platform != "win32":
+ import pwd
+
logger = logging.getLogger(__name__)
# Linux can bind child processes' lifetimes to that of their parents via prctl.
@@ -118,7 +122,7 @@ def push_error_to_driver_through_redis(redis_client,
error_data = ray.gcs_utils.construct_error_message(job_id, error_type,
message, time.time())
pubsub_msg = ray.gcs_utils.PubSubMessage()
- pubsub_msg.id = job_id.hex()
+ pubsub_msg.id = job_id.binary()
pubsub_msg.data = error_data
redis_client.publish("ERROR_INFO:" + job_id.hex(),
pubsub_msg.SerializeAsString())
@@ -780,3 +784,12 @@ def try_to_symlink(symlink_path, target_path):
os.symlink(target_path, symlink_path)
except OSError:
return
+
+
+def get_user():
+ if pwd is None:
+ return ""
+ try:
+ return pwd.getpwuid(os.getuid()).pw_name
+ except Exception:
+ return ""
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 72f72e1d09a4..f7d67e008e98 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -41,7 +41,7 @@
from ray import profiling
from ray.exceptions import (
- RayConnectionError,
+ RaySystemError,
RayError,
RayTaskError,
ObjectStoreFullError,
@@ -202,8 +202,8 @@ def check_connected(self):
Exception: An exception is raised if the worker is not connected.
"""
if not self.connected:
- raise RayConnectionError("Ray has not been started yet. You can "
- "start Ray with 'ray.init()'.")
+ raise RaySystemError("Ray has not been started yet. You can "
+ "start Ray with 'ray.init()'.")
def set_mode(self, mode):
"""Set the mode of the worker.
@@ -373,7 +373,7 @@ def sigterm_handler(signum, frame):
sys.exit(0)
-def get_gpu_ids(as_str=False):
+def get_gpu_ids():
"""Get the IDs of the GPUs that are available to the worker.
If the CUDA_VISIBLE_DEVICES environment variable was set when the worker
@@ -407,16 +407,6 @@ def get_gpu_ids(as_str=False):
max_gpus = global_worker.node.get_resource_spec().num_gpus
assigned_ids = global_worker.original_gpu_ids[:max_gpus]
- if not as_str:
- from ray.util.debug import log_once
- if log_once("ray.get_gpu_ids.as_str"):
- logger.warning(
- "ray.get_gpu_ids() will return a list of strings by default"
- " in a future version of Ray for compatibility with CUDA. "
- "To enable the forward-compatible behavior, use "
- "`ray.get_gpu_ids(as_str=True)`.")
- assigned_ids = [int(assigned_id) for assigned_id in assigned_ids]
-
return assigned_ids
@@ -438,13 +428,13 @@ def get_resource_ids():
return global_worker.core_worker.resource_ids()
-def get_webui_url():
- """Get the URL to access the web UI.
+def get_dashboard_url():
+ """Get the URL to access the Ray dashboard.
- Note that the URL does not specify which node the web UI is on.
+ Note that the URL does not specify which node the dashboard is on.
Returns:
- The URL of the web UI as a string.
+ The URL of the dashboard as a string.
"""
worker = global_worker
worker.check_connected()
@@ -477,48 +467,38 @@ def print_failed_task(task_status):
""")
-def init(address=None,
- redis_address=None,
- redis_port=None,
- num_cpus=None,
- num_gpus=None,
- memory=None,
- object_store_memory=None,
- resources=None,
- driver_object_store_memory=None,
- redis_max_memory=None,
- log_to_driver=True,
- node_ip_address=ray_constants.NODE_DEFAULT_IP,
- object_ref_seed=None,
- local_mode=False,
- redirect_worker_output=None,
- redirect_output=None,
- ignore_reinit_error=False,
- num_redis_shards=None,
- redis_max_clients=None,
- redis_password=ray_constants.REDIS_DEFAULT_PASSWORD,
- plasma_directory=None,
- huge_pages=False,
- include_java=False,
- include_dashboard=None,
- dashboard_host=ray_constants.DEFAULT_DASHBOARD_IP,
- dashboard_port=ray_constants.DEFAULT_DASHBOARD_PORT,
- job_id=None,
- job_config=None,
- configure_logging=True,
- logging_level=logging.INFO,
- logging_format=ray_constants.LOGGER_FORMAT,
- plasma_store_socket_name=None,
- raylet_socket_name=None,
- temp_dir=None,
- load_code_from_local=False,
- java_worker_options=None,
- use_pickle=True,
- _system_config=None,
- lru_evict=False,
- enable_object_reconstruction=False,
- _metrics_export_port=None,
- object_spilling_config=None):
+def init(
+ address=None,
+ *,
+ num_cpus=None,
+ num_gpus=None,
+ resources=None,
+ object_store_memory=None,
+ local_mode=False,
+ ignore_reinit_error=False,
+ include_dashboard=None,
+ dashboard_host=ray_constants.DEFAULT_DASHBOARD_IP,
+ dashboard_port=ray_constants.DEFAULT_DASHBOARD_PORT,
+ job_config=None,
+ configure_logging=True,
+ logging_level=logging.INFO,
+ logging_format=ray_constants.LOGGER_FORMAT,
+ log_to_driver=True,
+ enable_object_reconstruction=False,
+ # The following are unstable parameters and their use is discouraged.
+ _redis_max_memory=None,
+ _node_ip_address=ray_constants.NODE_DEFAULT_IP,
+ _driver_object_store_memory=None,
+ _memory=None,
+ _redis_password=ray_constants.REDIS_DEFAULT_PASSWORD,
+ _include_java=False,
+ _java_worker_options=None,
+ _temp_dir=None,
+ _load_code_from_local=False,
+ _lru_evict=False,
+ _metrics_export_port=None,
+ _object_spilling_config=None,
+ _system_config=None):
"""
Connect to an existing Ray cluster or start one and connect to it.
@@ -551,53 +531,19 @@ def init(address=None,
is running on a node in a Ray cluster, using `auto` as the value
tells the driver to detect the cluster, removing the need to
specify a specific node address.
- redis_address (str): Deprecated; same as address.
- redis_port (int): The port that the primary Redis shard should listen
- to. If None, then a random port will be chosen.
num_cpus (int): Number of CPUs the user wishes to assign to each
- raylet.
+ raylet. By default, this is set based on virtual cores.
num_gpus (int): Number of GPUs the user wishes to assign to each
- raylet.
+ raylet. By default, this is set based on detected GPUs.
resources: A dictionary mapping the names of custom resources to the
quantities for them available.
- memory: The amount of memory (in bytes) that is available for use by
- workers requesting memory resources. By default, this is
- automatically set based on available system memory.
object_store_memory: The amount of memory (in bytes) to start the
object store with. By default, this is automatically set based on
- available system memory, subject to a 20GB cap.
- redis_max_memory: The max amount of memory (in bytes) to allow each
- redis shard to use. Once the limit is exceeded, redis will start
- LRU eviction of entries. This only applies to the sharded redis
- tables (task, object, and profile tables). By default, this is
- autoset based on available system memory, subject to a 10GB cap.
- log_to_driver (bool): If true, the output from all of the worker
- processes on all nodes will be directed to the driver.
- node_ip_address (str): The IP address of the node that we are on.
- object_ref_seed (int): Used to seed the deterministic generation of
- object refs. The same value can be used across multiple runs of the
- same driver in order to generate the object refs in a consistent
- manner. However, the same ID should not be used for different
- drivers.
+ available system memory.
local_mode (bool): If true, the code will be executed serially. This
is useful for debugging.
- driver_object_store_memory (int): Limit the amount of memory the driver
- can use in the object store for creating objects. By default, this
- is autoset based on available system memory, subject to a 20GB cap.
ignore_reinit_error: If true, Ray suppresses errors from calling
ray.init() a second time. Ray won't be restarted.
- num_redis_shards: The number of Redis shards to start in addition to
- the primary Redis shard.
- redis_max_clients: If provided, attempt to configure Redis with this
- maxclients number.
- redis_password (str): Prevents external clients without the password
- from connecting to Redis if provided.
- plasma_directory: A directory where the Plasma memory mapped files
- will be created.
- huge_pages: Boolean flag indicating whether to start the Object
- Store with hugetlbfs support. Requires plasma_directory.
- include_java: Boolean flag indicating whether or not to enable java
- workers.
include_dashboard: Boolean flag indicating whether or not to start the
Ray dashboard, which displays the status of the Ray
cluster. If this argument is None, then the UI will be started if
@@ -608,7 +554,6 @@ def init(address=None,
external machines.
dashboard_port: The port to bind the dashboard server to. Defaults to
8265.
- job_id: The ID of this job.
job_config (ray.job_config.JobConfig): The job configuration.
configure_logging: True (default) if configuration of logging is
allowed here. Otherwise, the user may want to configure it
@@ -619,37 +564,42 @@ def init(address=None,
timestamp, filename, line number, and message. See the source file
ray_constants.py for details. Ignored unless "configure_logging"
is true.
- plasma_store_socket_name (str): If provided, specifies the socket
- name used by the plasma store.
- raylet_socket_name (str): If provided, specifies the socket path
- used by the raylet process.
- temp_dir (str): If provided, specifies the root temporary
+ log_to_driver (bool): If true, the output from all of the worker
+ processes on all nodes will be directed to the driver.
+ enable_object_reconstruction (bool): If True, when an object stored in
+ the distributed plasma store is lost due to node failure, Ray will
+ attempt to reconstruct the object by re-executing the task that
+ created the object. Arguments to the task will be recursively
+ reconstructed. If False, then ray.ObjectLostError will be
+ thrown.
+ _redis_max_memory: Redis max memory.
+ _node_ip_address (str): The IP address of the node that we are on.
+ _driver_object_store_memory (int): Limit the amount of memory the
+ driver can use in the object store for creating objects.
+ _memory: Amount of reservable memory resource to create.
+ _redis_password (str): Prevents external clients without the password
+ from connecting to Redis if provided.
+ _include_java: Boolean flag indicating whether or not to enable java
+ workers.
+ _temp_dir (str): If provided, specifies the root temporary
directory for the Ray process. Defaults to an OS-specific
conventional location, e.g., "/tmp/ray".
- load_code_from_local: Whether code should be loaded from a local
+ _load_code_from_local: Whether code should be loaded from a local
module or from the GCS.
- java_worker_options: Overwrite the options to start Java workers.
- use_pickle: Deprecated.
- _system_config (dict): Configuration for overriding RayConfig
- defaults. Used to set system configuration and for experimental Ray
- core feature flags.
- lru_evict (bool): If True, when an object store is full, it will evict
+ _java_worker_options: Overwrite the options to start Java workers.
+ _lru_evict (bool): If True, when an object store is full, it will evict
objects in LRU order to make more space and when under memory
- pressure, ray.UnreconstructableError may be thrown. If False, then
+ pressure, ray.ObjectLostError may be thrown. If False, then
reference counting will be used to decide which objects are safe
to evict and when under memory pressure, ray.ObjectStoreFullError
may be thrown.
- enable_object_reconstruction (bool): If True, when an object stored in
- the distributed plasma store is lost due to node failure, Ray will
- attempt to reconstruct the object by re-executing the task that
- created the object. Arguments to the task will be recursively
- reconstructed. If False, then ray.UnreconstructableError will be
- thrown.
_metrics_export_port (int): Port number Ray exposes system metrics
through a Prometheus endpoint. It is currently under active
development, and the API is subject to change.
- object_spilling_config (str): The configuration json string for object
+ _object_spilling_config (str): The configuration json string for object
spilling I/O worker.
+ _system_config (dict): Configuration for overriding
+ RayConfig defaults. For testing purposes ONLY.
Returns:
Address information about the started processes.
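Since everything after address is now keyword-only and the unstable knobs carry a leading underscore, a pre-1.0 call along the lines of ray.init(redis_password="pw", memory=10**9, temp_dir="/tmp/ray") migrates as follows (argument values here are illustrative):

import ray

ray.init(
    num_cpus=4,            # stable arguments keep their names...
    _redis_password="pw",  # ...unstable ones gain a leading underscore
    _memory=10**9,
    _temp_dir="/tmp/ray")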
@@ -659,15 +609,8 @@ def init(address=None,
arguments is passed in.
"""
- if not use_pickle:
- raise DeprecationWarning("The use_pickle argument is deprecated.")
-
- if redis_address is not None:
- raise DeprecationWarning("The redis_address argument is deprecated. "
- "Please use address instead.")
-
if "RAY_ADDRESS" in os.environ:
- if redis_address is None and (address is None or address == "auto"):
+ if address is None or address == "auto":
address = os.environ["RAY_ADDRESS"]
else:
raise RuntimeError(
@@ -677,9 +620,10 @@ def init(address=None,
"please call ray.init() or ray.init(address=\"auto\") on the "
"driver.")
- if redis_address is not None or address is not None:
- redis_address, _, _ = services.validate_redis_address(
- address, redis_address)
+ if address:
+ redis_address, _, _ = services.validate_redis_address(address)
+ else:
+ redis_address = None
if configure_logging:
setup_logger(logging_level, logging_format)
@@ -700,12 +644,6 @@ def init(address=None,
"'ignore_reinit_error=True' or by calling "
"'ray.shutdown()' prior to 'ray.init()'.")
- # Convert hostnames to numerical IP address.
- if node_ip_address is not None:
- node_ip_address = services.address_to_ip(node_ip_address)
-
- raylet_ip_address = node_ip_address
-
_system_config = _system_config or {}
if not isinstance(_system_config, dict):
raise TypeError("The _system_config must be a dict.")
@@ -715,39 +653,37 @@ def init(address=None,
# In this case, we need to start a new cluster.
ray_params = ray.parameter.RayParams(
redis_address=redis_address,
- redis_port=redis_port,
- node_ip_address=node_ip_address,
- raylet_ip_address=raylet_ip_address,
- object_ref_seed=object_ref_seed,
+ node_ip_address=None,
+ raylet_ip_address=None,
+ object_ref_seed=None,
driver_mode=driver_mode,
- redirect_worker_output=redirect_worker_output,
- redirect_output=redirect_output,
+ redirect_worker_output=None,
+ redirect_output=None,
num_cpus=num_cpus,
num_gpus=num_gpus,
resources=resources,
- num_redis_shards=num_redis_shards,
- redis_max_clients=redis_max_clients,
- redis_password=redis_password,
- plasma_directory=plasma_directory,
- huge_pages=huge_pages,
- include_java=include_java,
+ num_redis_shards=None,
+ redis_max_clients=None,
+ redis_password=_redis_password,
+ plasma_directory=None,
+ huge_pages=None,
+ include_java=_include_java,
include_dashboard=include_dashboard,
dashboard_host=dashboard_host,
dashboard_port=dashboard_port,
- memory=memory,
+ memory=_memory,
object_store_memory=object_store_memory,
- redis_max_memory=redis_max_memory,
- plasma_store_socket_name=plasma_store_socket_name,
- raylet_socket_name=raylet_socket_name,
- temp_dir=temp_dir,
- load_code_from_local=load_code_from_local,
- java_worker_options=java_worker_options,
+ redis_max_memory=_redis_max_memory,
+ plasma_store_socket_name=None,
+ temp_dir=_temp_dir,
+ load_code_from_local=_load_code_from_local,
+ java_worker_options=_java_worker_options,
start_initial_python_workers_for_first_job=True,
_system_config=_system_config,
- lru_evict=lru_evict,
+ lru_evict=_lru_evict,
enable_object_reconstruction=enable_object_reconstruction,
metrics_export_port=_metrics_export_port,
- object_spilling_config=object_spilling_config)
+ object_spilling_config=_object_spilling_config)
# Start the Ray processes. We set shutdown_at_exit=False because we
# shutdown the node in the ray.shutdown call that happens in the atexit
# handler. We still spawn a reaper process in case the atexit handler
@@ -766,45 +702,15 @@ def init(address=None,
if resources is not None:
raise ValueError("When connecting to an existing cluster, "
"resources must not be provided.")
- if num_redis_shards is not None:
- raise ValueError("When connecting to an existing cluster, "
- "num_redis_shards must not be provided.")
- if redis_max_clients is not None:
- raise ValueError("When connecting to an existing cluster, "
- "redis_max_clients must not be provided.")
- if memory is not None:
- raise ValueError("When connecting to an existing cluster, "
- "memory must not be provided.")
if object_store_memory is not None:
raise ValueError("When connecting to an existing cluster, "
"object_store_memory must not be provided.")
- if redis_max_memory is not None:
- raise ValueError("When connecting to an existing cluster, "
- "redis_max_memory must not be provided.")
- if plasma_directory is not None:
- raise ValueError("When connecting to an existing cluster, "
- "plasma_directory must not be provided.")
- if huge_pages:
- raise ValueError("When connecting to an existing cluster, "
- "huge_pages must not be provided.")
- if temp_dir is not None:
- raise ValueError("When connecting to an existing cluster, "
- "temp_dir must not be provided.")
- if plasma_store_socket_name is not None:
- raise ValueError("When connecting to an existing cluster, "
- "plasma_store_socket_name must not be provided.")
- if raylet_socket_name is not None:
- raise ValueError("When connecting to an existing cluster, "
- "raylet_socket_name must not be provided.")
- if java_worker_options is not None:
- raise ValueError("When connecting to an existing cluster, "
- "java_worker_options must not be provided.")
if _system_config is not None and len(_system_config) != 0:
raise ValueError("When connecting to an existing cluster, "
"_system_config must not be provided.")
- if lru_evict:
+ if _lru_evict:
raise ValueError("When connecting to an existing cluster, "
- "lru_evict must not be provided.")
+ "_lru_evict must not be provided.")
if enable_object_reconstruction:
raise ValueError(
"When connecting to an existing cluster, "
@@ -812,15 +718,15 @@ def init(address=None,
# In this case, we only need to connect the node.
ray_params = ray.parameter.RayParams(
- node_ip_address=node_ip_address,
- raylet_ip_address=raylet_ip_address,
+ node_ip_address=None,
+ raylet_ip_address=None,
redis_address=redis_address,
- redis_password=redis_password,
- object_ref_seed=object_ref_seed,
- temp_dir=temp_dir,
- load_code_from_local=load_code_from_local,
+ redis_password=_redis_password,
+ object_ref_seed=None,
+ temp_dir=_temp_dir,
+ load_code_from_local=_load_code_from_local,
_system_config=_system_config,
- lru_evict=lru_evict,
+ lru_evict=_lru_evict,
enable_object_reconstruction=enable_object_reconstruction,
metrics_export_port=_metrics_export_port)
_global_node = ray.node.Node(
@@ -835,8 +741,8 @@ def init(address=None,
mode=driver_mode,
log_to_driver=log_to_driver,
worker=global_worker,
- driver_object_store_memory=driver_object_store_memory,
- job_id=job_id,
+ driver_object_store_memory=_driver_object_store_memory,
+ job_id=None,
job_config=job_config)
for hook in _post_init_hooks:
@@ -849,7 +755,7 @@ def init(address=None,
_post_init_hooks = []
-def shutdown(exiting_interpreter=False):
+def shutdown(_exiting_interpreter=False):
"""Disconnect the worker, and terminate processes started by ray.init().
This will automatically run at the end when a Python process that uses Ray
@@ -863,16 +769,16 @@ def shutdown(exiting_interpreter=False):
will need to reload the module.
Args:
- exiting_interpreter (bool): True if this is called by the atexit hook
+ _exiting_interpreter (bool): True if this is called by the atexit hook
and false otherwise. If we are exiting the interpreter, we will
wait a little while to print any extra error messages.
"""
- if exiting_interpreter and global_worker.mode == SCRIPT_MODE:
+ if _exiting_interpreter and global_worker.mode == SCRIPT_MODE:
# This is a duration to sleep before shutting down everything in order
# to make sure that log messages finish printing.
time.sleep(0.5)
- disconnect(exiting_interpreter)
+ disconnect(_exiting_interpreter)
# We need to destruct the core worker here because after this function,
# we will tear down any processes spawned by ray.init() and the background
@@ -1422,50 +1328,7 @@ def _changeproctitle(title, next_title):
setproctitle.setproctitle(next_title)
-def register_custom_serializer(cls,
- serializer,
- deserializer,
- use_pickle=False,
- use_dict=False,
- class_id=None):
- """Registers custom functions for efficient object serialization.
-
- The serializer and deserializer are used when transferring objects of
- `cls` across processes and nodes. This can be significantly faster than
- the Ray default fallbacks. Wraps `register_custom_serializer` underneath.
-
- Args:
- cls (type): The class that ray should use this custom serializer for.
- serializer: The custom serializer that takes in a cls instance and
- outputs a serialized representation. use_pickle and use_dict
- must be False if provided.
- deserializer: The custom deserializer that takes in a serialized
- representation of the cls and outputs a cls instance. use_pickle
- and use_dict must be False if provided.
- use_pickle: Deprecated.
- use_dict: Deprecated.
- class_id (str): Unique ID of the class. Autogenerated if None.
- """
- worker = global_worker
- worker.check_connected()
-
- if use_pickle:
- raise DeprecationWarning(
- "`use_pickle` is no longer a valid parameter and will be removed "
- "in future versions of Ray. If this breaks your application, "
- "see `SerializationContext.register_custom_serializer`.")
- if use_dict:
- raise DeprecationWarning(
- "`use_pickle` is no longer a valid parameter and will be removed "
- "in future versions of Ray. If this breaks your application, "
- "see `SerializationContext.register_custom_serializer`.")
- assert serializer is not None and deserializer is not None
- context = global_worker.get_serialization_context()
- context.register_custom_serializer(
- cls, serializer, deserializer, class_id=class_id)
-
-
-def show_in_webui(message, key="", dtype="text"):
+def show_in_dashboard(message, key="", dtype="text"):
"""Display message in dashboard.
Display message for the current task or actor in the dashboard.
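The rename parallels get_webui_url -> get_dashboard_url above; call sites inside tasks or actors swap one name, e.g.:

import ray

@ray.remote
def report_progress():
    # was: ray.show_in_webui("halfway done", key="progress")
    ray.show_in_dashboard("halfway done", key="progress")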
@@ -1497,7 +1360,7 @@ def show_in_webui(message, key="", dtype="text"):
blocking_get_inside_async_warned = False
-def get(object_refs, timeout=None):
+def get(object_refs, *, timeout=None):
"""Get a remote object or a list of remote objects from the object store.
This method blocks until the object corresponding to the object ref is
@@ -1520,7 +1383,7 @@ def get(object_refs, timeout=None):
A Python object or a list of Python objects.
Raises:
- RayTimeoutError: A RayTimeoutError is raised if a timeout is set and
+ GetTimeoutError: A GetTimeoutError is raised if a timeout is set and
the get takes longer than timeout to return.
Exception: An exception is raised if the task that created the object
or that created one of the objects raised an exception.
@@ -1554,7 +1417,7 @@ def get(object_refs, timeout=None):
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
- if isinstance(value, ray.exceptions.UnreconstructableError):
+ if isinstance(value, ray.exceptions.ObjectLostError):
worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
raise value.as_instanceof_cause()
@@ -1570,17 +1433,13 @@ def get(object_refs, timeout=None):
return values
-def put(value, weakref=False):
+def put(value):
"""Store an object in the object store.
The object may not be evicted while a reference to the returned ID exists.
Args:
value: The Python object to be stored.
- weakref: If set, allows the object to be evicted while a reference
- to the returned ID exists. You might want to set this if putting
- a lot of objects that you might not need in the future.
- It allows Ray to more aggressively reclaim memory.
Returns:
The object ref assigned to this value.
@@ -1589,7 +1448,7 @@ def put(value, weakref=False):
worker.check_connected()
with profiling.profile("ray.put"):
try:
- object_ref = worker.put_object(value, pin_object=not weakref)
+ object_ref = worker.put_object(value, pin_object=True)
except ObjectStoreFullError:
logger.info(
"Put failed since the value was either too large or the "
@@ -1602,7 +1461,7 @@ def put(value, weakref=False):
blocking_wait_inside_async_warned = False
-def wait(object_refs, num_returns=1, timeout=None):
+def wait(object_refs, *, num_returns=1, timeout=None):
"""Return a list of IDs that are ready and a list of IDs that are not.
If timeout is set, the function returns either when the requested number of
@@ -1710,11 +1569,11 @@ def get_actor(name):
"""
worker = global_worker
worker.check_connected()
-
- return ray.util.named_actors._get_actor(name)
+ handle = worker.core_worker.get_named_actor_handle(name)
+ return handle
-def kill(actor, no_restart=True):
+def kill(actor, *, no_restart=True):
"""Kill an actor forcefully.
This will interrupt any running tasks on the actor, causing them to fail
@@ -1740,7 +1599,7 @@ def kill(actor, no_restart=True):
worker.core_worker.kill_actor(actor._ray_actor_id, no_restart)
-def cancel(object_ref, force=False):
+def cancel(object_ref, *, force=False):
"""Cancels a task according to the following conditions.
If the specified task is pending execution, it will not be executed. If
@@ -1752,7 +1611,7 @@ def cancel(object_ref, force=False):
Only non-actor tasks can be canceled. Canceled tasks will not be
retried (max_retries will not be respected).
- Calling ray.get on a canceled task will raise a RayCancellationError.
+ Calling ray.get on a canceled task will raise a TaskCancelledError.
Args:
object_ref (ObjectRef): ObjectRef returned by the task
@@ -1783,7 +1642,7 @@ def _mode(worker=global_worker):
return worker.mode
-def make_decorator(num_return_vals=None,
+def make_decorator(num_returns=None,
num_cpus=None,
num_gpus=None,
memory=None,
@@ -1806,21 +1665,43 @@ def decorator(function_or_class):
if max_task_retries is not None:
raise ValueError("The keyword 'max_task_retries' is not "
"allowed for remote functions.")
-
+ if num_returns is not None and (not isinstance(num_returns, int)
+ or num_returns < 0):
+ raise ValueError(
+ "The keyword 'num_returns' only accepts 0 or a"
+ " positive integer")
+ if max_retries is not None and (not isinstance(max_retries, int)
+ or max_retries < -1):
+ raise ValueError(
+ "The keyword 'max_retries' only accepts 0, -1 or a"
+ " positive integer")
+ if max_calls is not None and (not isinstance(max_calls, int)
+ or max_calls < 0):
+ raise ValueError(
+ "The keyword 'max_calls' only accepts 0 or a positive"
+ " integer")
return ray.remote_function.RemoteFunction(
Language.PYTHON, function_or_class, None, num_cpus, num_gpus,
- memory, object_store_memory, resources, num_return_vals,
- max_calls, max_retries, placement_group,
- placement_group_bundle_index)
+ memory, object_store_memory, resources, num_returns, max_calls,
+ max_retries, placement_group, placement_group_bundle_index)
if inspect.isclass(function_or_class):
- if num_return_vals is not None:
- raise TypeError("The keyword 'num_return_vals' is not "
+ if num_returns is not None:
+ raise TypeError("The keyword 'num_returns' is not "
"allowed for actors.")
if max_calls is not None:
raise TypeError("The keyword 'max_calls' is not "
"allowed for actors.")
-
+ if max_restarts is not None and (not isinstance(max_restarts, int)
+ or max_restarts < -1):
+ raise ValueError(
+ "The keyword 'max_restarts' only accepts -1, 0 or a"
+ " positive integer")
+ if max_task_retries is not None and (not isinstance(
+ max_task_retries, int) or max_task_retries < -1):
+ raise ValueError(
+ "The keyword 'max_task_retries' only accepts -1, 0 or a"
+ " positive integer")
return ray.actor.make_actor(function_or_class, num_cpus, num_gpus,
memory, object_store_memory, resources,
max_restarts, max_task_retries)
@@ -1850,7 +1731,7 @@ def method(self):
It can also be used with specific keyword arguments:
- * **num_return_vals:** This is only for *remote functions*. It specifies
+ * **num_returns:** This is only for *remote functions*. It specifies
the number of object refs returned by the remote function invocation.
* **num_cpus:** The quantity of CPU cores to reserve for this task or for
the lifetime of the actor.
@@ -1892,7 +1773,7 @@ def method(self):
.. code-block:: python
- @ray.remote(num_gpus=1, max_calls=1, num_return_vals=2)
+ @ray.remote(num_gpus=1, max_calls=1, num_returns=2)
def f():
return 1, 2
@@ -1907,7 +1788,7 @@ def method(self):
.. code-block:: python
- @ray.remote(num_gpus=1, max_calls=1, num_return_vals=2)
+ @ray.remote(num_gpus=1, max_calls=1, num_returns=2)
def f():
return 1, 2
g = f.options(num_gpus=2, max_calls=None)
@@ -1933,15 +1814,15 @@ def method(self):
error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
- "the arguments 'num_return_vals', 'num_cpus', 'num_gpus', "
+ "the arguments 'num_returns', 'num_cpus', 'num_gpus', "
"'memory', 'object_store_memory', 'resources', "
"'max_calls', or 'max_restarts', like "
- "'@ray.remote(num_return_vals=2, "
+ "'@ray.remote(num_returns=2, "
"resources={\"CustomResource\": 1})'.")
assert len(args) == 0 and len(kwargs) > 0, error_string
for key in kwargs:
assert key in [
- "num_return_vals",
+ "num_returns",
"num_cpus",
"num_gpus",
"memory",
@@ -1966,7 +1847,7 @@ def method(self):
assert "GPU" not in resources, "Use the 'num_gpus' argument."
# Handle other arguments.
- num_return_vals = kwargs.get("num_return_vals")
+ num_returns = kwargs.get("num_returns")
max_calls = kwargs.get("max_calls")
max_restarts = kwargs.get("max_restarts")
max_task_retries = kwargs.get("max_task_retries")
@@ -1975,7 +1856,7 @@ def method(self):
max_retries = kwargs.get("max_retries")
return make_decorator(
- num_return_vals=num_return_vals,
+ num_returns=num_returns,
num_cpus=num_cpus,
num_gpus=num_gpus,
memory=memory,
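Taken together, the worker.py changes tighten the public signatures: timeout, num_returns, no_restart, and force become keyword-only, and num_return_vals is renamed. A short sketch of code that still parses after this change (positional forms such as ray.get(ref, 5.0) no longer do):

import ray

@ray.remote(num_returns=2)  # was: num_return_vals=2
def pair():
    return 1, 2

ray.init()
a, b = pair.remote()
ready, not_ready = ray.wait([a, b], num_returns=2, timeout=10.0)
print(ray.get(a, timeout=5.0), ray.get(b))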
diff --git a/python/requirements_tune.txt b/python/requirements_tune.txt
index 1123fc7ed879..b61dc62393a4 100644
--- a/python/requirements_tune.txt
+++ b/python/requirements_tune.txt
@@ -19,6 +19,7 @@ optuna
pytest-remotedata>=0.3.1
pytorch-lightning
scikit-optimize
+scipy
sigopt
smart_open
tensorflow_probability
diff --git a/python/setup.py b/python/setup.py
index bbe2a0be44b3..0d804209af06 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -111,7 +111,7 @@
extras = {
"debug": [],
"serve": ["uvicorn", "flask", "requests"],
- "tune": ["tabulate", "tensorboardX", "pandas"]
+ "tune": ["tabulate", "tensorboardX", "pandas", "scipy"]
}
extras["rllib"] = extras["tune"] + [
@@ -121,7 +121,6 @@
"lz4",
"opencv-python-headless<=4.3.0.36",
"pyyaml",
- "scipy",
]
extras["streaming"] = []
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 73e69bd55998..0cef1b1783e5 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -420,7 +420,7 @@ def make_env(vector_index):
if (ray.is_initialized()
and ray.worker._mode() != ray.worker.LOCAL_MODE):
# Check available number of GPUs
- if not ray.get_gpu_ids(as_str=True):
+ if not ray.get_gpu_ids():
logger.debug("Creating policy evaluation worker {}".format(
worker_index) +
" on CPU (please ignore any CUDA init errors)")
@@ -619,7 +619,7 @@ def sample(self) -> SampleBatchType:
return batch
@DeveloperAPI
- @ray.method(num_return_vals=2)
+ @ray.method(num_returns=2)
def sample_with_count(self) -> Tuple[SampleBatchType, int]:
"""Same as sample() but returns the count as a separate future."""
batch = self.sample()
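This call site tracks the removal of the as_str flag above: ray.get_gpu_ids() now takes no arguments and returns the assigned IDs as-is (strings when they come from CUDA_VISIBLE_DEVICES), so callers cast explicitly if they need integers. A minimal sketch:

import ray

ray.init(num_gpus=1)

@ray.remote(num_gpus=1)
def which_gpu():
    ids = ray.get_gpu_ids()  # no as_str flag anymore
    return [int(i) for i in ids]

print(ray.get(which_gpu.remote()))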
diff --git a/rllib/train.py b/rllib/train.py
index b19268dfd4d4..08c9a4cefc08 100755
--- a/rllib/train.py
+++ b/rllib/train.py
@@ -69,16 +69,6 @@ def create_parser(parser_creator=None):
default=None,
type=int,
help="Emulate multiple cluster nodes for debugging.")
- parser.add_argument(
- "--ray-redis-max-memory",
- default=None,
- type=int,
- help="--redis-max-memory to use if starting a new cluster.")
- parser.add_argument(
- "--ray-memory",
- default=None,
- type=int,
- help="--memory to use if starting a new cluster.")
parser.add_argument(
"--ray-object-store-memory",
default=None,
@@ -204,17 +194,13 @@ def run(args, parser):
cluster.add_node(
num_cpus=args.ray_num_cpus or 1,
num_gpus=args.ray_num_gpus or 0,
- object_store_memory=args.ray_object_store_memory,
- memory=args.ray_memory,
- redis_max_memory=args.ray_redis_max_memory)
+ object_store_memory=args.ray_object_store_memory)
ray.init(address=cluster.address)
else:
ray.init(
include_dashboard=not args.no_ray_ui,
address=args.ray_address,
object_store_memory=args.ray_object_store_memory,
- memory=args.ray_memory,
- redis_max_memory=args.ray_redis_max_memory,
num_cpus=args.ray_num_cpus,
num_gpus=args.ray_num_gpus,
local_mode=args.local_mode)
diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h
index e643daef96e7..08280b9b611d 100644
--- a/src/ray/common/ray_config_def.h
+++ b/src/ray/common/ray_config_def.h
@@ -294,3 +294,7 @@ RAY_CONFIG(uint64_t, metrics_report_interval_ms, 10000)
/// The maximum number of I/O worker that raylet starts.
RAY_CONFIG(int, max_io_workers, 1)
+
+/// Enable the task timeline. If this is enabled, certain events such as task
+/// execution are profiled and sent to the GCS.
+RAY_CONFIG(bool, enable_timeline, true)
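RAY_CONFIG(type, name, default) entries such as the one added above are exposed as accessor methods on the RayConfig singleton. A minimal, hedged sketch of how a profiling call site could consult the new flag (the function and call site below are illustrative, not part of this diff):

#include "ray/common/ray_config.h"

// Hypothetical guard around a profiling event: if the timeline is disabled,
// skip both the event bookkeeping and the push to the GCS.
void MaybeRecordTaskEvent() {
  if (!RayConfig::instance().enable_timeline()) {
    return;  // Timeline disabled via the new config flag.
  }
  // ... build the profile event and send it to the GCS ...
}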
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index 425f967ab1e0..abd3d7712b28 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -2146,6 +2146,7 @@ void CoreWorker::HandleGetCoreWorkerStats(const rpc::GetCoreWorkerStatsRequest &
stats->set_current_task_func_desc(current_task_.FunctionDescriptor()->ToString());
stats->set_ip_address(rpc_address_.ip_address());
stats->set_port(rpc_address_.port());
+ stats->set_job_id(worker_context_.GetCurrentJobID().Binary());
stats->set_actor_id(actor_id_.Binary());
auto used_resources_map = stats->mutable_used_resources();
for (auto const &it : *resource_ids_) {
diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc
index 388eeb36b077..ff21188adcc8 100644
--- a/src/ray/core_worker/test/direct_task_transport_test.cc
+++ b/src/ray/core_worker/test/direct_task_transport_test.cc
@@ -365,6 +365,10 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
@@ -395,6 +399,10 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
@@ -446,6 +454,10 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
@@ -490,7 +502,6 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
-
// Task 3 finishes, the worker is returned.
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_returned, 1);
@@ -504,6 +515,10 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestRetryLeaseCancellation) {
@@ -560,6 +575,10 @@ TEST(DirectTaskTransportTest, TestRetryLeaseCancellation) {
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 3);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestConcurrentCancellationAndSubmission) {
@@ -613,6 +632,10 @@ TEST(DirectTaskTransportTest, TestConcurrentCancellationAndSubmission) {
ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
@@ -657,6 +680,10 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) {
@@ -691,6 +718,10 @@ TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) {
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestSpillback) {
@@ -749,6 +780,10 @@ TEST(DirectTaskTransportTest, TestSpillback) {
ASSERT_EQ(remote_client.second->num_leases_canceled, 0);
ASSERT_FALSE(remote_client.second->ReplyCancelWorkerLease());
}
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) {
@@ -813,6 +848,10 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) {
ASSERT_EQ(remote_client.second->num_leases_canceled, 0);
ASSERT_FALSE(remote_client.second->ReplyCancelWorkerLease());
}
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
// Helper to run a test that checks that 'same1' and 'same2' are treated as the same
@@ -841,13 +880,16 @@ void TestSchedulingKey(const std::shared_ptr<CoreWorkerMemoryStore> store,
ASSERT_EQ(worker_client->callbacks.size(), 1);
// Another worker is requested because same2 is pending.
ASSERT_EQ(raylet_client->num_workers_requested, 3);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
// same1 runs successfully. Worker isn't returned.
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_returned, 0);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
- // taske1_2 is pushed.
+ // same2 is pushed.
ASSERT_EQ(worker_client->callbacks.size(), 1);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 1);
+ ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
// different is pushed.
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
@@ -863,6 +905,16 @@ void TestSchedulingKey(const std::shared_ptr<CoreWorkerMemoryStore> store,
ASSERT_TRUE(worker_client->ReplyPushTask());
ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+
+ ASSERT_EQ(raylet_client->num_leases_canceled, 1);
+
+ // Trigger reply to RequestWorkerLease to remove the canceled pending lease request
+ ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, ClientID::Nil(), true));
+ ASSERT_EQ(raylet_client->num_workers_returned, 2);
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestSchedulingKeys) {
@@ -983,6 +1035,10 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) {
ASSERT_EQ(raylet_client->num_workers_disconnected, 1);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestKillExecutingTask) {
@@ -1032,6 +1088,10 @@ TEST(DirectTaskTransportTest, TestKillExecutingTask) {
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestKillPendingTask) {
@@ -1061,6 +1121,13 @@ TEST(DirectTaskTransportTest, TestKillPendingTask) {
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_TRUE(raylet_client->ReplyCancelWorkerLease());
+
+ // Trigger reply to RequestWorkerLease to remove the canceled pending lease request
+ ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil(), true));
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestKillResolvingTask) {
@@ -1092,6 +1159,10 @@ TEST(DirectTaskTransportTest, TestKillResolvingTask) {
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 0);
ASSERT_EQ(task_finisher->num_tasks_failed, 1);
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) {
@@ -1162,6 +1233,10 @@ TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) {
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) {
@@ -1237,6 +1312,188 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) {
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 1);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
+}
+
+TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) {
+ rpc::Address address;
+ auto raylet_client = std::make_shared<MockRayletClient>();
+ auto worker_client = std::make_shared<MockWorkerClient>();
+ auto store = std::make_shared<CoreWorkerMemoryStore>();
+ auto client_pool = std::make_shared<rpc::CoreWorkerClientPool>(
+ [&](const rpc::Address &addr) { return worker_client; });
+ auto task_finisher = std::make_shared<MockTaskFinisher>();
+ auto actor_creator = std::make_shared<MockActorCreator>();
+
+ // Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the pipelining
+ // of task submissions. This is done by passing a max_tasks_in_flight_per_worker
+ // parameter to the CoreWorkerDirectTaskSubmitter.
+ uint32_t max_tasks_in_flight_per_worker = 10;
+ CoreWorkerDirectTaskSubmitter submitter(
+ address, raylet_client, client_pool, nullptr, store, task_finisher, ClientID::Nil(),
+ kLongTimeout, actor_creator, max_tasks_in_flight_per_worker);
+
+ // Prepare 30 tasks and save them in a vector.
+ std::unordered_map<std::string, double> empty_resources;
+ ray::FunctionDescriptor empty_descriptor =
+ ray::FunctionDescriptorBuilder::BuildPython("", "", "", "");
+ std::vector<TaskSpecification> tasks;
+ for (int i = 0; i < 30; i++) {
+ tasks.push_back(BuildTaskSpec(empty_resources, empty_descriptor));
+ }
+ ASSERT_EQ(tasks.size(), 30);
+
+ // Submit 4 tasks, and check that 1 worker is requested.
+ for (int i = 1; i <= 4; i++) {
+ auto task = tasks.front();
+ ASSERT_TRUE(submitter.SubmitTask(task).ok());
+ tasks.erase(tasks.begin());
+ }
+ ASSERT_EQ(tasks.size(), 26);
+ ASSERT_EQ(raylet_client->num_workers_requested, 1);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 0);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 0);
+
+ // Grant a worker lease, and check that still only 1 worker was requested.
+ ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil()));
+ ASSERT_EQ(raylet_client->num_workers_requested, 1);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 0);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 4);
+
+ // Submit 6 more tasks, and check that still only 1 worker was requested.
+ for (int i = 1; i <= 6; i++) {
+ auto task = tasks.front();
+ ASSERT_TRUE(submitter.SubmitTask(task).ok());
+ tasks.erase(tasks.begin());
+ }
+ ASSERT_EQ(tasks.size(), 20);
+ ASSERT_EQ(raylet_client->num_workers_requested, 1);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 0);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 10);
+
+ // Submit 1 more task, and check that one more worker is requested, for a total of 2.
+ auto task = tasks.front();
+ ASSERT_TRUE(submitter.SubmitTask(task).ok());
+ tasks.erase(tasks.begin());
+ ASSERT_EQ(tasks.size(), 19);
+ ASSERT_EQ(raylet_client->num_workers_requested, 2);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 0);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 10);
+
+ // Grant a worker lease, and check that still only 2 workers were requested.
+ ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil()));
+ ASSERT_EQ(raylet_client->num_workers_requested, 2);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 0);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 11);
+
+ // Submit 9 more tasks, and check that the total number of workers requested is still 2.
+ for (int i = 1; i <= 9; i++) {
+ auto task = tasks.front();
+ ASSERT_TRUE(submitter.SubmitTask(task).ok());
+ tasks.erase(tasks.begin());
+ }
+ ASSERT_EQ(tasks.size(), 10);
+ ASSERT_EQ(raylet_client->num_workers_requested, 2);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 0);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 20);
+
+ // Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the total
+ // number of workers requested remains equal to 2.
+ for (int i = 1; i <= 5; i++) {
+ ASSERT_TRUE(worker_client->ReplyPushTask());
+ }
+ ASSERT_EQ(raylet_client->num_workers_requested, 2);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 5);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 15);
+
+ // Submit 5 new tasks, and check that we still have requested only 2 workers.
+ for (int i = 1; i <= 5; i++) {
+ auto task = tasks.front();
+ ASSERT_TRUE(submitter.SubmitTask(task).ok());
+ tasks.erase(tasks.begin());
+ }
+ ASSERT_EQ(tasks.size(), 5);
+ ASSERT_EQ(raylet_client->num_workers_requested, 2);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 5);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 20);
+
+ // Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the total
+ // number of workers requested remains equal to 2.
+ for (int i = 1; i <= 5; i++) {
+ ASSERT_TRUE(worker_client->ReplyPushTask());
+ }
+ ASSERT_EQ(raylet_client->num_workers_requested, 2);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 10);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 15);
+
+ // Submit the last 5 tasks, and check that the total number of workers requested is still 2.
+ for (int i = 1; i <= 5; i++) {
+ auto task = tasks.front();
+ ASSERT_TRUE(submitter.SubmitTask(task).ok());
+ tasks.erase(tasks.begin());
+ }
+ ASSERT_EQ(tasks.size(), 0);
+ ASSERT_EQ(raylet_client->num_workers_requested, 2);
+ ASSERT_EQ(raylet_client->num_workers_returned, 0);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 10);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 20);
+
+ // Execute the remaining 20 tasks, and check that the total number of workers
+ // requested is still 2.
+ for (int i = 1; i <= 20; i++) {
+ ASSERT_TRUE(worker_client->ReplyPushTask());
+ }
+ ASSERT_EQ(raylet_client->num_workers_requested, 2);
+ ASSERT_EQ(raylet_client->num_workers_returned, 2);
+ ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+ ASSERT_EQ(task_finisher->num_tasks_complete, 30);
+ ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+ ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+ ASSERT_EQ(worker_client->callbacks.size(), 0);
+
+ // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ // would otherwise cause a memory leak.
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
} // namespace ray
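Every test above now ends with the same leak check. The accessor it calls (declared in the direct_task_transport.h diff below) is a thin public wrapper that takes the submitter's mutex before reading private state. A minimal sketch of that lock-then-check pattern, assuming abseil synchronization and with hypothetical names:

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"

// Sketch of the test-only accessor pattern behind CheckNoSchedulingKeyEntriesPublic:
// lock the same mutex that guards the map, then check emptiness, so tests never
// race with (or observe a partially updated) map.
class SubmitterLike {
 public:
  bool CheckNoEntriesPublic() {
    absl::MutexLock lock(&mu_);
    return entries_.empty();
  }

 private:
  absl::Mutex mu_;
  absl::flat_hash_map<int, int> entries_;  // Stand-in for scheduling_key_entries_.
};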
diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc
index 79f9cd733645..3895cfbbe3f2 100644
--- a/src/ray/core_worker/transport/direct_task_transport.cc
+++ b/src/ray/core_worker/transport/direct_task_transport.cc
@@ -75,12 +75,32 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(),
task_spec.IsActorCreationTask() ? task_spec.ActorCreationId()
: ActorID::Nil());
- auto it = task_queues_.find(scheduling_key);
- if (it == task_queues_.end()) {
- it =
- task_queues_.emplace(scheduling_key, std::deque<TaskSpecification>()).first;
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ scheduling_key_entry.task_queue.push_back(task_spec);
+ if (!scheduling_key_entry.AllPipelinesToWorkersFull(
+ max_tasks_in_flight_per_worker_)) {
+ // The pipelines to the current workers are not full yet, so we don't need more
+ // workers.
+
+ // Find a worker with a number of tasks in flight that is less than the maximum
+ // value (max_tasks_in_flight_per_worker_) and call OnWorkerIdle to send tasks
+ // to that worker
+ for (auto active_worker_addr : scheduling_key_entry.active_workers) {
+ RAY_CHECK(worker_to_lease_entry_.find(active_worker_addr) !=
+ worker_to_lease_entry_.end());
+ auto &lease_entry = worker_to_lease_entry_[active_worker_addr];
+ if (!lease_entry.PipelineToWorkerFull(max_tasks_in_flight_per_worker_)) {
+ OnWorkerIdle(active_worker_addr, scheduling_key, false,
+ lease_entry.assigned_resources);
+ // If we find a worker with a non-full pipeline, all we need to do is to
+ // submit the new task to the worker in question by calling OnWorkerIdle
+ // once. We don't need to worry about other tasks in the queue because the
+ // queue cannot have other tasks in it if there are active workers with
+ // non-full pipelines.
+ break;
+ }
+ }
}
- it->second.push_back(task_spec);
RequestNewWorkerIfNeeded(scheduling_key);
}
}
@@ -93,32 +113,51 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
}
void CoreWorkerDirectTaskSubmitter::AddWorkerLeaseClient(
- const rpc::WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client) {
+ const rpc::WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client,
+ const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> &assigned_resources,
+ const SchedulingKey &scheduling_key) {
client_cache_->GetOrConnect(addr.ToProto());
int64_t expiration = current_time_ms() + lease_timeout_ms_;
- LeaseEntry new_lease_entry = LeaseEntry(std::move(lease_client), expiration, 0);
+ LeaseEntry new_lease_entry = LeaseEntry(std::move(lease_client), expiration, 0,
+ assigned_resources, scheduling_key);
worker_to_lease_entry_.emplace(addr, new_lease_entry);
+
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ RAY_CHECK(scheduling_key_entry.active_workers.emplace(addr).second);
+ RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
}
void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
const rpc::WorkerAddress &addr, const SchedulingKey &scheduling_key, bool was_error,
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> &assigned_resources) {
auto &lease_entry = worker_to_lease_entry_[addr];
- if (!lease_entry.lease_client_) {
+ if (!lease_entry.lease_client) {
return;
}
- RAY_CHECK(lease_entry.lease_client_);
+ RAY_CHECK(lease_entry.lease_client);
- auto queue_entry = task_queues_.find(scheduling_key);
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ auto &current_queue = scheduling_key_entry.task_queue;
// Return the worker if there was an error executing the previous task,
// the previous task is an actor creation task,
// there are no more applicable queued tasks, or the lease is expired.
- if (was_error || queue_entry == task_queues_.end() ||
- current_time_ms() > lease_entry.lease_expiration_time_) {
+ if (was_error || current_queue.empty() ||
+ current_time_ms() > lease_entry.lease_expiration_time) {
+ RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
+
// Return the worker only if there are no tasks in flight
- if (lease_entry.tasks_in_flight_ == 0) {
+ if (lease_entry.tasks_in_flight == 0) {
+ // Decrement the number of active workers consuming tasks from the queue associated
+ // with the current scheduling_key
+ scheduling_key_entry.active_workers.erase(addr);
+ if (scheduling_key_entry.CanDelete()) {
+ // We can safely remove the entry keyed by scheduling_key from the
+ // scheduling_key_entries_ hashmap.
+ scheduling_key_entries_.erase(scheduling_key);
+ }
+
auto status =
- lease_entry.lease_client_->ReturnWorker(addr.port, addr.worker_id, was_error);
+ lease_entry.lease_client->ReturnWorker(addr.port, addr.worker_id, was_error);
if (!status.ok()) {
RAY_LOG(ERROR) << "Error returning worker to raylet: " << status.ToString();
}
@@ -128,21 +167,27 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
} else {
auto &client = *client_cache_->GetOrConnect(addr.ToProto());
- while (!queue_entry->second.empty() &&
- lease_entry.tasks_in_flight_ < max_tasks_in_flight_per_worker_) {
- auto task_spec = queue_entry->second.front();
+ while (!current_queue.empty() &&
+ !lease_entry.PipelineToWorkerFull(max_tasks_in_flight_per_worker_)) {
+ auto task_spec = current_queue.front();
lease_entry
- .tasks_in_flight_++; // Increment the number of tasks in flight to the worker
+ .tasks_in_flight++; // Increment the number of tasks in flight to the worker
+
+ // Increment the total number of tasks in flight to any worker associated with the
+ // current scheduling_key
+
+ RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
+ scheduling_key_entry.total_tasks_in_flight++;
+
executing_tasks_.emplace(task_spec.TaskId(), addr);
PushNormalTask(addr, client, scheduling_key, task_spec, assigned_resources);
- queue_entry->second.pop_front();
+ current_queue.pop_front();
}
// Delete the queue if it's now empty. Note that the queue cannot already be empty
// because this is the only place tasks are removed from it.
- if (queue_entry->second.empty()) {
- task_queues_.erase(queue_entry);
- RAY_LOG(DEBUG) << "Task queue empty, canceling lease request";
+ if (current_queue.empty()) {
+ RAY_LOG(INFO) << "Task queue empty, canceling lease request";
CancelWorkerLeaseIfNeeded(scheduling_key);
}
}
@@ -151,17 +196,18 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
void CoreWorkerDirectTaskSubmitter::CancelWorkerLeaseIfNeeded(
const SchedulingKey &scheduling_key) {
- auto queue_entry = task_queues_.find(scheduling_key);
- if (queue_entry != task_queues_.end()) {
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ auto &task_queue = scheduling_key_entry.task_queue;
+ if (!task_queue.empty()) {
// There are still pending tasks, so let the worker lease request succeed.
return;
}
- auto it = pending_lease_requests_.find(scheduling_key);
- if (it != pending_lease_requests_.end()) {
+ auto &pending_lease_request = scheduling_key_entry.pending_lease_request;
+ if (pending_lease_request.first) {
// There is an in-flight lease request. Cancel it.
- auto &lease_client = it->second.first;
- auto &lease_id = it->second.second;
+ auto &lease_client = pending_lease_request.first;
+ auto &lease_id = pending_lease_request.second;
RAY_LOG(DEBUG) << "Canceling lease request " << lease_id;
lease_client->CancelWorkerLease(
lease_id, [this, scheduling_key](const Status &status,
@@ -208,30 +254,48 @@ CoreWorkerDirectTaskSubmitter::GetOrConnectLeaseClient(
void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
const SchedulingKey &scheduling_key, const rpc::Address *raylet_address) {
- if (pending_lease_requests_.find(scheduling_key) != pending_lease_requests_.end()) {
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ auto &pending_lease_request = scheduling_key_entry.pending_lease_request;
+
+ if (pending_lease_request.first) {
// There's already an outstanding lease request for this type of task.
return;
}
- auto it = task_queues_.find(scheduling_key);
- if (it == task_queues_.end()) {
+
+ auto &task_queue = scheduling_key_entry.task_queue;
+ if (task_queue.empty()) {
// We don't have any of this type of task to run.
+ if (scheduling_key_entry.CanDelete()) {
+ // We can safely remove the entry keyed by scheduling_key from the
+ // scheduling_key_entries_ hashmap.
+ scheduling_key_entries_.erase(scheduling_key);
+ }
+ return;
+ }
+
+ // Check whether we really need a new worker or whether we have
+ // enough room in an existing worker's pipeline to send the new tasks
+ if (!scheduling_key_entry.AllPipelinesToWorkersFull(max_tasks_in_flight_per_worker_)) {
+ // The pipelines to the current workers are not full yet, so we don't need more
+ // workers.
+
return;
}
auto lease_client = GetOrConnectLeaseClient(raylet_address);
- TaskSpecification &resource_spec = it->second.front();
+ TaskSpecification &resource_spec = task_queue.front();
TaskID task_id = resource_spec.TaskId();
- RAY_LOG(DEBUG) << "Lease requested " << task_id;
lease_client->RequestWorkerLease(
resource_spec, [this, scheduling_key](const Status &status,
const rpc::RequestWorkerLeaseReply &reply) {
absl::MutexLock lock(&mu_);
- auto it = pending_lease_requests_.find(scheduling_key);
- RAY_CHECK(it != pending_lease_requests_.end());
- auto lease_client = std::move(it->second.first);
- const auto task_id = it->second.second;
- pending_lease_requests_.erase(it);
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ auto &pending_lease_request = scheduling_key_entry.pending_lease_request;
+ RAY_CHECK(pending_lease_request.first);
+ auto lease_client = std::move(pending_lease_request.first);
+ const auto task_id = pending_lease_request.second;
+ pending_lease_request = std::make_pair(nullptr, TaskID::Nil());
if (status.ok()) {
if (reply.canceled()) {
@@ -242,8 +306,11 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
// assign work to the worker.
RAY_LOG(DEBUG) << "Lease granted " << task_id;
rpc::WorkerAddress addr(reply.worker_address());
- AddWorkerLeaseClient(addr, std::move(lease_client));
auto resources_copy = reply.resource_mapping();
+
+ AddWorkerLeaseClient(addr, std::move(lease_client), resources_copy,
+ scheduling_key);
+ RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
OnWorkerIdle(addr, scheduling_key,
/*error=*/false, resources_copy);
} else {
@@ -267,9 +334,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
RAY_LOG(FATAL) << status.ToString();
}
});
- RAY_CHECK(pending_lease_requests_
- .emplace(scheduling_key, std::make_pair(lease_client, task_id))
- .second);
+ pending_lease_request = std::make_pair(lease_client, task_id);
}
void CoreWorkerDirectTaskSubmitter::PushNormalTask(
@@ -281,7 +346,6 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
bool is_actor = task_spec.IsActorTask();
bool is_actor_creation = task_spec.IsActorCreationTask();
- RAY_LOG(DEBUG) << "Pushing normal task " << task_spec.TaskId();
// NOTE(swang): CopyFrom is needed because if we use Swap here and the task
// fails, then the task data will be gone when the TaskManager attempts to
// access the task.
@@ -298,14 +362,28 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
// Decrement the number of tasks in flight to the worker
auto &lease_entry = worker_to_lease_entry_[addr];
- RAY_CHECK(lease_entry.tasks_in_flight_ > 0);
- lease_entry.tasks_in_flight_--;
+ RAY_CHECK(lease_entry.tasks_in_flight > 0);
+ lease_entry.tasks_in_flight--;
+
+ // Decrement the total number of tasks in flight to any worker with the current
+ // scheduling_key.
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ RAY_CHECK(scheduling_key_entry.active_workers.size() >= 1);
+ RAY_CHECK(scheduling_key_entry.total_tasks_in_flight >= 1);
+ scheduling_key_entry.total_tasks_in_flight--;
}
if (reply.worker_exiting()) {
// The worker is draining and will shutdown after it is done. Don't return
// it to the Raylet since that will kill it early.
absl::MutexLock lock(&mu_);
worker_to_lease_entry_.erase(addr);
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ scheduling_key_entry.active_workers.erase(addr);
+ if (scheduling_key_entry.CanDelete()) {
+ // We can safely remove the entry keyed by scheduling_key from the
+ // scheduling_key_entries_ hashmap.
+ scheduling_key_entries_.erase(scheduling_key);
+ }
} else if (!status.ok() || !is_actor_creation) {
// Successful actor creation leases the worker indefinitely from the raylet.
absl::MutexLock lock(&mu_);
@@ -340,17 +418,16 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
return Status::OK();
}
- auto scheduled_tasks = task_queues_.find(scheduling_key);
+ auto &scheduling_key_entry = scheduling_key_entries_[scheduling_key];
+ auto &scheduled_tasks = scheduling_key_entry.task_queue;
// This cancels tasks that have completed dependencies and are awaiting
// a worker lease.
- if (scheduled_tasks != task_queues_.end()) {
- for (auto spec = scheduled_tasks->second.begin();
- spec != scheduled_tasks->second.end(); spec++) {
+ if (!scheduled_tasks.empty()) {
+ for (auto spec = scheduled_tasks.begin(); spec != scheduled_tasks.end(); spec++) {
if (spec->TaskId() == task_spec.TaskId()) {
- scheduled_tasks->second.erase(spec);
+ scheduled_tasks.erase(spec);
- if (scheduled_tasks->second.empty()) {
- task_queues_.erase(scheduling_key);
+ if (scheduled_tasks.empty()) {
CancelWorkerLeaseIfNeeded(scheduling_key);
}
RAY_UNUSED(task_finisher_->PendingTaskFailed(task_spec.TaskId(),
@@ -359,6 +436,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
}
}
}
+
// This will get removed either when the RPC call to cancel is returned
// or when all dependencies are resolved.
RAY_CHECK(cancelled_tasks_.emplace(task_spec.TaskId()).second);
@@ -367,6 +445,11 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
if (rpc_client == executing_tasks_.end()) {
// This case is reached for tasks that have unresolved dependencies.
// No executing tasks, so cancelling is a noop.
+ if (scheduling_key_entry.CanDelete()) {
+ // We can safely remove the entry keyed by scheduling_key from the
+ // scheduling_key_entries_ hashmap.
+ scheduling_key_entries_.erase(scheduling_key);
+ }
return Status::OK();
}
// Looks for an RPC handle for the worker executing the task.
@@ -385,10 +468,11 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
request.set_intended_task_id(task_spec.TaskId().Binary());
request.set_force_kill(force_kill);
client->CancelTask(
- request, [this, task_spec, force_kill](const Status &status,
- const rpc::CancelTaskReply &reply) {
+ request, [this, task_spec, scheduling_key, force_kill](
+ const Status &status, const rpc::CancelTaskReply &reply) {
absl::MutexLock lock(&mu_);
cancelled_tasks_.erase(task_spec.TaskId());
+
if (status.ok() && !reply.attempt_succeeded()) {
if (cancel_retry_timer_.has_value()) {
if (cancel_retry_timer_->expiry().time_since_epoch() <=
diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h
index c038bb75c0b2..25985e03f38b 100644
--- a/src/ray/core_worker/transport/direct_task_transport.h
+++ b/src/ray/core_worker/transport/direct_task_transport.h
@@ -86,6 +86,13 @@ class CoreWorkerDirectTaskSubmitter {
Status CancelRemoteTask(const ObjectID &object_id, const rpc::Address &worker_addr,
bool force_kill);
+ /// Check that the scheduling_key_entries_ hashmap is empty by calling the private
+ /// CheckNoSchedulingKeyEntries function after acquiring the lock.
+ bool CheckNoSchedulingKeyEntriesPublic() {
+ absl::MutexLock lock(&mu_);
+ return scheduling_key_entries_.empty();
+ }
+
private:
/// Schedule more work onto an idle worker or return it back to the raylet if
/// no more tasks are queued for submission. If an error was encountered
@@ -122,9 +129,10 @@ class CoreWorkerDirectTaskSubmitter {
EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Set up client state for newly granted worker lease.
- void AddWorkerLeaseClient(const rpc::WorkerAddress &addr,
- std::shared_ptr<WorkerLeaseInterface> lease_client)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ void AddWorkerLeaseClient(
+ const rpc::WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client,
+ const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> &assigned_resources,
+ const SchedulingKey &scheduling_key) EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Push a task to a specific worker.
void PushNormalTask(const rpc::WorkerAddress &addr,
@@ -134,6 +142,11 @@ class CoreWorkerDirectTaskSubmitter {
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry>
&assigned_resources);
+ /// Check that the scheduling_key_entries_ hashmap is empty.
+ bool CheckNoSchedulingKeyEntries() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ return scheduling_key_entries_.empty();
+ }
+
/// Address of our RPC server.
rpc::Address rpc_address_;
@@ -178,31 +191,75 @@ class CoreWorkerDirectTaskSubmitter {
/// (1) The lease client through which the worker should be returned
/// (2) The expiration time of a worker's lease.
/// (3) The number of tasks that are currently in flight to the worker
+ /// (4) The resources assigned to the worker
+ /// (5) The SchedulingKey assigned to tasks that will be sent to the worker
struct LeaseEntry {
- std::shared_ptr<WorkerLeaseInterface> lease_client_;
- int64_t lease_expiration_time_;
- uint32_t tasks_in_flight_;
-
- LeaseEntry(std::shared_ptr<WorkerLeaseInterface> lease_client = nullptr,
- int64_t lease_expiration_time = 0, uint32_t tasks_in_flight = 0)
- : lease_client_(lease_client),
- lease_expiration_time_(lease_expiration_time),
- tasks_in_flight_(tasks_in_flight) {}
+ std::shared_ptr<WorkerLeaseInterface> lease_client;
+ int64_t lease_expiration_time;
+ uint32_t tasks_in_flight;
+ google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> assigned_resources;
+ SchedulingKey scheduling_key;
+
+ LeaseEntry(
+ std::shared_ptr<WorkerLeaseInterface> lease_client = nullptr,
+ int64_t lease_expiration_time = 0, uint32_t tasks_in_flight = 0,
+ google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> assigned_resources =
+ google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry>(),
+ SchedulingKey scheduling_key = std::make_tuple(0, std::vector<ObjectID>(),
+ ActorID::Nil()))
+ : lease_client(lease_client),
+ lease_expiration_time(lease_expiration_time),
+ tasks_in_flight(tasks_in_flight),
+ assigned_resources(assigned_resources),
+ scheduling_key(scheduling_key) {}
+
+ // Check whether the pipeline to the worker associated with a LeaseEntry is full.
+ bool PipelineToWorkerFull(uint32_t max_tasks_in_flight_per_worker) const {
+ return tasks_in_flight == max_tasks_in_flight_per_worker;
+ }
};
// Map from worker address to a LeaseEntry struct containing the lease's metadata.
absl::flat_hash_map<rpc::WorkerAddress, LeaseEntry> worker_to_lease_entry_
GUARDED_BY(mu_);
- // Keeps track of pending worker lease requests to the raylet.
- absl::flat_hash_map<SchedulingKey, std::pair<std::shared_ptr<WorkerLeaseInterface>, TaskID>>
- pending_lease_requests_ GUARDED_BY(mu_);
+ struct SchedulingKeyEntry {
+ // Keep track of pending worker lease requests to the raylet.
+ std::pair<std::shared_ptr<WorkerLeaseInterface>, TaskID> pending_lease_request =
+ std::make_pair(nullptr, TaskID::Nil());
+ // Tasks that are queued for execution. We keep an individual queue per
+ // scheduling class to ensure fairness.
+ std::deque<TaskSpecification> task_queue = std::deque<TaskSpecification>();
+ // Keep track of the active workers, so that we can quickly check if one of them has
+ // room for more tasks in flight
+ absl::flat_hash_set<rpc::WorkerAddress> active_workers =
+ absl::flat_hash_set<rpc::WorkerAddress>();
+ // Keep track of how many tasks with this SchedulingKey are in flight, in total
+ uint32_t total_tasks_in_flight = 0;
+
+ // Check whether it's safe to delete this SchedulingKeyEntry from the
+ // scheduling_key_entries_ hashmap.
+ bool CanDelete() const {
+ if (!pending_lease_request.first && task_queue.empty() &&
+ active_workers.size() == 0 && total_tasks_in_flight == 0) {
+ return true;
+ }
+
+ return false;
+ }
+
+ // Check whether the pipelines to the active workers associated with a
+ // SchedulingKeyEntry are all full.
+ bool AllPipelinesToWorkersFull(uint32_t max_tasks_in_flight_per_worker) const {
+ return total_tasks_in_flight ==
+ (active_workers.size() * max_tasks_in_flight_per_worker);
+ }
+ };
- // Tasks that are queued for execution. We keep individual queues per
- // scheduling class to ensure fairness.
- // Invariant: if a queue is in this map, it has at least one task.
- absl::flat_hash_map<SchedulingKey, std::deque<TaskSpecification>> task_queues_
+ // For each SchedulingKey, scheduling_key_entries_ contains a SchedulingKeyEntry struct
+ // with the queue of tasks belonging to that SchedulingKey, together with the other
+ // fields needed to orchestrate the execution of those tasks by the workers.
+ absl::flat_hash_map<SchedulingKey, SchedulingKeyEntry> scheduling_key_entries_
GUARDED_BY(mu_);
// Tasks that were cancelled while being resolved.
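The two predicates defined above drive the lease-request heuristic exercised by TestPipeliningNumberOfWorkersRequested: a new worker is requested only when every active worker already has max_tasks_in_flight_per_worker tasks in flight. A self-contained sketch of that arithmetic under illustrative values (the struct below is a stand-in, not the real SchedulingKeyEntry):

#include <cassert>
#include <cstdint>

// Stand-in bookkeeping mirroring SchedulingKeyEntry's counters.
struct PipelineBookkeeping {
  uint32_t num_active_workers = 0;
  uint32_t total_tasks_in_flight = 0;

  // All pipelines are full exactly when in-flight tasks saturate every worker.
  bool AllPipelinesToWorkersFull(uint32_t max_in_flight_per_worker) const {
    return total_tasks_in_flight == num_active_workers * max_in_flight_per_worker;
  }
};

int main() {
  const uint32_t kMaxInFlight = 10;
  // One granted worker with 10 tasks in flight: full, so an 11th submitted
  // task triggers a second lease request (as the test asserts).
  PipelineBookkeeping entry{1, 10};
  assert(entry.AllPipelinesToWorkersFull(kMaxInFlight));
  // After 5 replies there is room again, so no further workers are requested.
  entry.total_tasks_in_flight = 5;
  assert(!entry.AllPipelinesToWorkersFull(kMaxInFlight));
  return 0;
}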
diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto
index 1c622b602955..fa2cf15a6d70 100644
--- a/src/ray/protobuf/common.proto
+++ b/src/ray/protobuf/common.proto
@@ -384,6 +384,8 @@ message CoreWorkerStats {
string actor_title = 16;
// Local reference table.
repeated ObjectRefInfo object_refs = 17;
+ // Job ID.
+ bytes job_id = 18;
}
message MetricPoint {
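Consumers of the new bytes field can recover the typed ID with the existing ID helpers; a hedged sketch (the helper function and include paths are illustrative, while JobID::FromBinary and Hex are the standard utilities):

#include <string>

#include "ray/common/id.h"
#include "src/ray/protobuf/common.pb.h"

// Illustrative consumer: turn the raw bytes that HandleGetCoreWorkerStats now
// fills in back into a typed JobID, e.g. for display in the dashboard.
std::string JobIdHexFromStats(const ray::rpc::CoreWorkerStats &stats) {
  const ray::JobID job_id = ray::JobID::FromBinary(stats.job_id());
  return job_id.Hex();
}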
diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto
index c33f0abcbbee..45c22be962b5 100644
--- a/src/ray/protobuf/node_manager.proto
+++ b/src/ray/protobuf/node_manager.proto
@@ -114,6 +114,10 @@ message WorkerStats {
CoreWorkerStats core_worker_stats = 3;
// Error string if fetching core worker stats failed.
string fetch_error = 4;
+ // Worker id of core worker.
+ bytes worker_id = 5;
+ // Worker language.
+ Language language = 6;
}
message GetNodeStatsReply {
@@ -122,6 +126,7 @@ message GetNodeStatsReply {
uint32 num_workers = 3;
repeated TaskSpec infeasible_tasks = 4;
repeated TaskSpec ready_tasks = 5;
+ int32 pid = 6;
}
message GlobalGCRequest {
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index 3007da6db279..bce98cba5fd6 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -3147,6 +3147,7 @@ void NodeManager::FlushObjectsToFree() {
void NodeManager::HandleGetNodeStats(const rpc::GetNodeStatsRequest &node_stats_request,
rpc::GetNodeStatsReply *reply,
rpc::SendReplyCallback send_reply_callback) {
+ reply->set_pid(getpid());
for (const auto &task : local_queues_.GetTasks(TaskState::INFEASIBLE)) {
if (task.GetTaskSpecification().IsActorCreationTask()) {
auto infeasible_task = reply->add_infeasible_tasks();
@@ -3211,6 +3212,10 @@ void NodeManager::HandleGetNodeStats(const rpc::GetNodeStatsRequest &node_stats_
all_workers.push_back(driver);
driver_ids.insert(driver->WorkerId());
}
+ if (all_workers.empty()) {
+ send_reply_callback(Status::OK(), nullptr, nullptr);
+ return;
+ }
for (const auto &worker : all_workers) {
rpc::GetCoreWorkerStatsRequest request;
request.set_intended_worker_id(worker->WorkerId().Binary());
@@ -3220,7 +3225,9 @@ void NodeManager::HandleGetNodeStats(const rpc::GetNodeStatsRequest &node_stats_
const ray::Status &status, const rpc::GetCoreWorkerStatsReply &r) {
auto worker_stats = reply->add_workers_stats();
worker_stats->set_pid(worker->GetProcess().GetId());
+ worker_stats->set_worker_id(worker->WorkerId().Binary());
worker_stats->set_is_driver(driver_ids.contains(worker->WorkerId()));
+ worker_stats->set_language(worker->GetLanguage());
reply->set_num_workers(reply->num_workers() + 1);
if (status.ok()) {
worker_stats->mutable_core_worker_stats()->MergeFrom(r.core_worker_stats());