diff --git a/python/ray/dashboard/client/src/pages/dashboard/state.ts b/python/ray/dashboard/client/src/pages/dashboard/state.ts index 68cad9edd65c..732418378b2a 100644 --- a/python/ray/dashboard/client/src/pages/dashboard/state.ts +++ b/python/ray/dashboard/client/src/pages/dashboard/state.ts @@ -62,7 +62,10 @@ const slice = createSlice({ tuneAvailability: TuneAvailabilityResponse; }> ) => { - state.tuneAvailability = action.payload.tuneAvailability["available"]; + const tuneAvailability = action.payload.tuneAvailability === null ? + false : + action.payload.tuneAvailability["available"]; + state.tuneAvailability = tuneAvailability; state.lastUpdatedAt = Date.now(); }, setError: (state, action: PayloadAction) => { diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py index 2f37e48f09df..b2499defd222 100644 --- a/python/ray/dashboard/dashboard.py +++ b/python/ray/dashboard/dashboard.py @@ -5,7 +5,6 @@ import sys sys.exit(1) -import argparse import copy import datetime import errno @@ -27,13 +26,24 @@ import grpc from google.protobuf.json_format import MessageToDict import ray + from ray.core.generated import node_manager_pb2 from ray.core.generated import node_manager_pb2_grpc from ray.core.generated import reporter_pb2 from ray.core.generated import reporter_pb2_grpc from ray.core.generated import core_worker_pb2 from ray.core.generated import core_worker_pb2_grpc -import ray.ray_constants as ray_constants +from ray.dashboard.interface.dashboard_controller_interface import ( + BaseDashboardController) +from ray.dashboard.interface.dashboard_route_handler_interface import ( + BaseDashboardRouteHandler) +try: + from ray.dashboard.metrics_exporter.exporter import Exporter + from ray.dashboard.metrics_exporter.client import MetricsExportClient +except ImportError: + print("Pydantic is not donwloaded. pip install pydantic") + MetricsExportClient = None + Exporter = None try: from ray.tune.result import DEFAULT_RESULTS_DIR @@ -96,15 +106,381 @@ def b64_decode(reply): return b64decode(reply).decode("utf-8") -class Dashboard(object): +async def json_response(is_dev, result=None, error=None, + ts=None) -> aiohttp.web.Response: + if ts is None: + ts = datetime.datetime.utcnow() + + headers = None + if is_dev: + headers = {"Access-Control-Allow-Origin": "*"} + + return aiohttp.web.json_response( + { + "result": result, + "timestamp": to_unix_time(ts), + "error": error, + }, + headers=headers) + + +class DashboardController(BaseDashboardController): + def __init__(self, redis_address, redis_password): + self.node_stats = NodeStats(redis_address, redis_password) + self.raylet_stats = RayletStats( + redis_address, redis_password=redis_password) + if Analysis is not None: + self.tune_stats = TuneCollector(DEFAULT_RESULTS_DIR, 2.0) + + def _construct_raylet_info(self): + D = self.raylet_stats.get_raylet_stats() + workers_info_by_node = { + data["nodeId"]: data.get("workersStats") + for data in D.values() + } + infeasible_tasks = sum( + (data.get("infeasibleTasks", []) for data in D.values()), []) + # ready_tasks are used to render tasks that are not schedulable + # due to resource limitations. + # (e.g., Actor requires 2 GPUs but there is only 1 gpu available). 
+ ready_tasks = sum((data.get("readyTasks", []) for data in D.values()), + []) + actor_tree = self.node_stats.get_actor_tree( + workers_info_by_node, infeasible_tasks, ready_tasks) + for address, data in D.items(): + # process view data + measures_dicts = {} + for view_data in data["viewData"]: + view_name = view_data["viewName"] + if view_name in ("local_available_resource", + "local_total_resource", + "object_manager_stats"): + measures_dicts[view_name] = measures_to_dict( + view_data["measures"]) + # process resources info + extra_info_strings = [] + prefix = "ResourceName:" + for resource_name, total_resource in measures_dicts[ + "local_total_resource"].items(): + available_resource = measures_dicts[ + "local_available_resource"].get(resource_name, .0) + resource_name = resource_name[len(prefix):] + extra_info_strings.append("{}: {} / {}".format( + resource_name, + format_resource(resource_name, + total_resource - available_resource), + format_resource(resource_name, total_resource))) + data["extraInfo"] = ", ".join(extra_info_strings) + "\n" + if os.environ.get("RAY_DASHBOARD_DEBUG"): + # process object store info + extra_info_strings = [] + prefix = "ValueType:" + for stats_name in [ + "used_object_store_memory", "num_local_objects" + ]: + stats_value = measures_dicts["object_manager_stats"].get( + prefix + stats_name, .0) + extra_info_strings.append("{}: {}".format( + stats_name, stats_value)) + data["extraInfo"] += ", ".join(extra_info_strings) + # process actor info + actor_tree_str = json.dumps( + actor_tree, indent=2, sort_keys=True) + lines = actor_tree_str.split("\n") + max_line_length = max(map(len, lines)) + to_print = [] + for line in lines: + to_print.append(line + (max_line_length - len(line)) * " ") + data["extraInfo"] += "\n" + "\n".join(to_print) + return {"nodes": D, "actors": actor_tree} + + def get_ray_config(self): + try: + config_path = os.path.expanduser("~/ray_bootstrap_config.yaml") + with open(config_path) as f: + cfg = yaml.safe_load(f) + except Exception: + error = "No config" + return error, None + + D = { + "min_workers": cfg["min_workers"], + "max_workers": cfg["max_workers"], + "initial_workers": cfg["initial_workers"], + "autoscaling_mode": cfg["autoscaling_mode"], + "idle_timeout_minutes": cfg["idle_timeout_minutes"], + } + + try: + D["head_type"] = cfg["head_node"]["InstanceType"] + except KeyError: + D["head_type"] = "unknown" + + try: + D["worker_type"] = cfg["worker_nodes"]["InstanceType"] + except KeyError: + D["worker_type"] = "unknown" + + return None, D + + def get_node_info(self): + return self.node_stats.get_node_stats() + + def get_raylet_info(self): + return self._construct_raylet_info() + + def tune_info(self): + if Analysis is not None: + D = self.tune_stats.get_stats() + else: + D = {} + return D + + def tune_availability(self): + if Analysis is not None: + D = self.tune_stats.get_availability() + else: + D = {"available": False} + return D + + def launch_profiling(self, node_id, pid, duration): + profiling_id = self.raylet_stats.launch_profiling( + node_id=node_id, pid=pid, duration=duration) + return profiling_id + + def check_profiling_status(self, profiling_id): + return self.raylet_stats.check_profiling_status(profiling_id) + + def get_profiling_info(self, profiling_id): + return self.raylet_stats.get_profiling_info(profiling_id) + + def kill_actor(self, actor_id, ip_address, port): + return self.raylet_stats.kill_actor(actor_id, ip_address, port) + + def get_logs(self, hostname, pid): + return self.node_stats.get_logs(hostname, 
pid) + + def get_errors(self, hostname, pid): + return self.node_stats.get_errors(hostname, pid) + + def start_collecting_metrics(self): + self.node_stats.start() + self.raylet_stats.start() + + +class DashboardRouteHandler(BaseDashboardRouteHandler): + def __init__(self, dashboard_controller: DashboardController, + is_dev=False): + self.dashboard_controller = dashboard_controller + self.is_dev = is_dev + + def forbidden(self) -> aiohttp.web.Response: + return aiohttp.web.Response(status=403, text="403 Forbidden") + + def get_forbidden(self, _) -> aiohttp.web.Response: + return self.forbidden() + + async def get_index(self, req) -> aiohttp.web.Response: + return aiohttp.web.FileResponse( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "client/build/index.html")) + + async def get_favicon(self, req) -> aiohttp.web.Response: + return aiohttp.web.FileResponse( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "client/build/favicon.ico")) + + async def ray_config(self, req) -> aiohttp.web.Response: + error, result = self.dashboard_controller.get_ray_config() + if error: + return await json_response(self.is_dev, error=error) + return await json_response(self.is_dev, result=result) + + async def node_info(self, req) -> aiohttp.web.Response: + now = datetime.datetime.utcnow() + D = self.dashboard_controller.get_node_info() + return await json_response(self.is_dev, result=D, ts=now) + + async def raylet_info(self, req) -> aiohttp.web.Response: + result = self.dashboard_controller.get_raylet_info() + return await json_response(self.is_dev, result=result) + + async def tune_info(self, req) -> aiohttp.web.Response: + result = self.dashboard_controller.tune_info() + return await json_response(self.is_dev, result=result) + + async def tune_availability(self, req) -> aiohttp.web.Response: + result = self.dashboard_controller.tune_availability() + return await json_response(self.is_dev, result=result) + + async def launch_profiling(self, req) -> aiohttp.web.Response: + node_id = req.query.get("node_id") + pid = int(req.query.get("pid")) + duration = int(req.query.get("duration")) + profiling_id = self.dashboard_controller.launch_profiling( + node_id, pid, duration) + return await json_response(self.is_dev, result=str(profiling_id)) + + async def check_profiling_status(self, req) -> aiohttp.web.Response: + profiling_id = req.query.get("profiling_id") + status = self.dashboard_controller.check_profiling_status(profiling_id) + return await json_response(self.is_dev, result=status) + + async def get_profiling_info(self, req) -> aiohttp.web.Response: + profiling_id = req.query.get("profiling_id") + profiling_info = self.dashboard_controller.get_profiling_info( + profiling_id) + return aiohttp.web.json_response(self.is_dev, profiling_info) + + async def kill_actor(self, req) -> aiohttp.web.Response: + actor_id = req.query.get("actor_id") + ip_address = req.query.get("ip_address") + port = req.query.get("port") + return await json_response( + self.is_dev, + self.dashboard_controller.kill_actor(actor_id, ip_address, port)) + + async def logs(self, req) -> aiohttp.web.Response: + hostname = req.query.get("hostname") + pid = req.query.get("pid") + result = self.dashboard_controller.get_logs(hostname, pid) + return await json_response(self.is_dev, result=result) + + async def errors(self, req) -> aiohttp.web.Response: + hostname = req.query.get("hostname") + pid = req.query.get("pid") + result = self.dashboard_controller.get_errors(hostname, pid) + return await 
json_response(self.is_dev, result=result) + + +class MetricsExportHandler: + def __init__(self, + dashboard_controller: DashboardController, + metrics_export_client: MetricsExportClient, + dashboard_id, + is_dev=False): + assert metrics_export_client is not None + self.metrics_export_client = metrics_export_client + self.dashboard_controller = dashboard_controller + self.is_dev = is_dev + + async def enable_export_metrics(self, req) -> aiohttp.web.Response: + if self.metrics_export_client.enabled: + return await json_response( + self.is_dev, result={"url": None}, error="Already enabled") + + succeed = self.metrics_export_client.start_exporting_metrics() + if not succeed: + return await json_response( + self.is_dev, + result={"url": None}, + error="Failed to enable metrics exporting.") + + url = self.metrics_export_client.dashboard_url + return await json_response(self.is_dev, result={"url": url}) + + async def get_dashboard_address(self, req) -> aiohttp.web.Response: + if not self.metrics_export_client.enabled: + return await json_response( + self.is_dev, + result={"url": None}, + error="Metrics exporting is not enabled.") + + url = self.metrics_export_client.dashboard_url + return await json_response(self.is_dev, result={"url": url}) + + async def redirect_to_dashboard(self, req) -> aiohttp.web.Response: + if not self.metrics_export_client.enabled: + return await json_response( + self.is_dev, + result={"url": None}, + error="You should enable metrics export to use this endpoint.") + + raise aiohttp.web.HTTPFound(self.metrics_export_client.dashboard_url) + + +def setup_metrics_export_routes(app: aiohttp.web.Application, + handler: MetricsExportHandler): + """Routes that require dynamically changing class attributes.""" + app.router.add_get("/api/enable_metrics_export", + handler.enable_export_metrics) + app.router.add_get("/api/dashboard_url", handler.get_dashboard_address) + app.router.add_get("/dashboard", handler.redirect_to_dashboard) + + +def setup_static_dir(app): + build_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "client/build") + if not os.path.isdir(build_dir): + raise OSError( + errno.ENOENT, "Dashboard build directory not found. 
If installing " + "from source, please follow the additional steps " + "required to build the dashboard" + "(cd python/ray/dashboard/client " + "&& npm ci " + "&& npm run build)", build_dir) + + static_dir = os.path.join(build_dir, "static") + app.router.add_static("/static", static_dir) + return build_dir + + +def setup_speedscope_dir(app, build_dir): + speedscope_dir = os.path.join(build_dir, "speedscope-1.5.3") + app.router.add_static("/speedscope", speedscope_dir) + + +def setup_dashboard_route(app: aiohttp.web.Application, + handler: BaseDashboardRouteHandler, + index=None, + favicon=None, + ray_config=None, + node_info=None, + raylet_info=None, + tune_info=None, + tune_availability=None, + launch_profiling=None, + check_profiling_status=None, + get_profiling_info=None, + kill_actor=None, + logs=None, + errors=None): + def add_get_route(route, handler_func): + if route is not None: + app.router.add_get(route, handler_func) + + add_get_route(index, handler.get_index) + add_get_route(favicon, handler.get_favicon) + add_get_route(ray_config, handler.ray_config) + add_get_route(node_info, handler.node_info) + add_get_route(raylet_info, handler.raylet_info) + add_get_route(tune_info, handler.tune_info) + add_get_route(tune_availability, handler.tune_availability) + add_get_route(launch_profiling, handler.launch_profiling) + add_get_route(check_profiling_status, handler.check_profiling_status) + add_get_route(get_profiling_info, handler.get_profiling_info) + add_get_route(kill_actor, handler.kill_actor) + add_get_route(logs, handler.logs) + add_get_route(errors, handler.errors) + + +class Dashboard: """A dashboard process for monitoring Ray nodes. This dashboard is made up of a REST API which collates data published by Reporter processes on nodes into a json structure, and a webserver which polls said API for display purposes. - Attributes: - redis_client: A client used to communicate with the Redis server. + Args: + host(str): Host address of dashboard aiohttp server. + port(str): Port number of dashboard aiohttp server. + redis_address(str): GCS address of a Ray cluster + temp_dir (str): The temporary directory used for log files and + information for this Ray session. + redis_passord(str): Redis password to access GCS + metrics_export_address(str): The address users host their dashboard. 
""" def __init__(self, @@ -112,18 +488,16 @@ def __init__(self, port, redis_address, temp_dir, - redis_password=None): - """Initialize the dashboard object.""" + redis_password=None, + metrics_export_address=None): self.host = host self.port = port self.redis_client = ray.services.create_redis_client( redis_address, password=redis_password) self.temp_dir = temp_dir - - self.node_stats = NodeStats(redis_address, redis_password) - self.raylet_stats = RayletStats(redis_address, redis_password) - if Analysis is not None: - self.tune_stats = TuneCollector(DEFAULT_RESULTS_DIR, 2.0) + self.dashboard_id = str(uuid.uuid4()) + self.dashboard_controller = DashboardController( + redis_address, redis_password) # Setting the environment variable RAY_DASHBOARD_DEV=1 disables some # security checks in the dashboard server to ease development while @@ -132,240 +506,72 @@ def __init__(self, self.is_dev = os.environ.get("RAY_DASHBOARD_DEV") == "1" self.app = aiohttp.web.Application() - self.setup_routes() - - def setup_routes(self): - def forbidden() -> aiohttp.web.Response: - return aiohttp.web.Response(status=403, text="403 Forbidden") - - def get_forbidden(_) -> aiohttp.web.Response: - return forbidden() - - async def get_index(req) -> aiohttp.web.Response: - return aiohttp.web.FileResponse( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "client/build/index.html")) - - async def get_favicon(req) -> aiohttp.web.Response: - return aiohttp.web.FileResponse( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "client/build/favicon.ico")) - - async def json_response(result=None, error=None, - ts=None) -> aiohttp.web.Response: - if ts is None: - ts = datetime.datetime.utcnow() - - headers = None - if self.is_dev: - headers = {"Access-Control-Allow-Origin": "*"} - - return aiohttp.web.json_response( - { - "result": result, - "timestamp": to_unix_time(ts), - "error": error, - }, - headers=headers) - - async def ray_config(_) -> aiohttp.web.Response: - try: - config_path = os.path.expanduser("~/ray_bootstrap_config.yaml") - with open(config_path) as f: - cfg = yaml.safe_load(f) - except Exception: - return await json_response(error="No config") - - D = { - "min_workers": cfg["min_workers"], - "max_workers": cfg["max_workers"], - "initial_workers": cfg["initial_workers"], - "autoscaling_mode": cfg["autoscaling_mode"], - "idle_timeout_minutes": cfg["idle_timeout_minutes"], - } - - try: - D["head_type"] = cfg["head_node"]["InstanceType"] - except KeyError: - D["head_type"] = "unknown" - - try: - D["worker_type"] = cfg["worker_nodes"]["InstanceType"] - except KeyError: - D["worker_type"] = "unknown" - - return await json_response(result=D) - - async def node_info(req) -> aiohttp.web.Response: - now = datetime.datetime.utcnow() - D = self.node_stats.get_node_stats() - return await json_response(result=D, ts=now) - - async def raylet_info(req) -> aiohttp.web.Response: - D = self.raylet_stats.get_raylet_stats() - workers_info_by_node = { - data["nodeId"]: data.get("workersStats") - for data in D.values() - } - infeasible_tasks = sum( - (data.get("infeasibleTasks", []) for data in D.values()), []) - # ready_tasks are used to render tasks that are not schedulable - # due to resource limitations. - # (e.g., Actor requires 2 GPUs but there is only 1 gpu available). 
- ready_tasks = sum( - (data.get("readyTasks", []) for data in D.values()), []) - actor_tree = self.node_stats.get_actor_tree( - workers_info_by_node, infeasible_tasks, ready_tasks) - for address, data in D.items(): - # process view data - measures_dicts = {} - for view_data in data["viewData"]: - view_name = view_data["viewName"] - if view_name in ("local_available_resource", - "local_total_resource", - "object_manager_stats"): - measures_dicts[view_name] = measures_to_dict( - view_data["measures"]) - # process resources info - extra_info_strings = [] - prefix = "ResourceName:" - for resource_name, total_resource in measures_dicts[ - "local_total_resource"].items(): - available_resource = measures_dicts[ - "local_available_resource"].get(resource_name, .0) - resource_name = resource_name[len(prefix):] - extra_info_strings.append("{}: {} / {}".format( - resource_name, - format_resource(resource_name, - total_resource - available_resource), - format_resource(resource_name, total_resource))) - data["extraInfo"] = ", ".join(extra_info_strings) + "\n" - if os.environ.get("RAY_DASHBOARD_DEBUG"): - # process object store info - extra_info_strings = [] - prefix = "ValueType:" - for stats_name in [ - "used_object_store_memory", "num_local_objects" - ]: - stats_value = measures_dicts[ - "object_manager_stats"].get( - prefix + stats_name, .0) - extra_info_strings.append("{}: {}".format( - stats_name, stats_value)) - data["extraInfo"] += ", ".join(extra_info_strings) - # process actor info - actor_tree_str = json.dumps( - actor_tree, indent=2, sort_keys=True) - lines = actor_tree_str.split("\n") - max_line_length = max(map(len, lines)) - to_print = [] - for line in lines: - to_print.append(line + - (max_line_length - len(line)) * " ") - data["extraInfo"] += "\n" + "\n".join(to_print) - result = {"nodes": D, "actors": actor_tree} - return await json_response(result=result) - - async def tune_info(req) -> aiohttp.web.Response: - if Analysis is not None: - D = self.tune_stats.get_stats() - else: - D = {} - return await json_response(result=D) - - async def tune_availability(req) -> aiohttp.web.Response: - if Analysis is not None: - D = self.tune_stats.get_availability() - else: - D = {"available": False} - return await json_response(result=D) - - async def launch_profiling(req) -> aiohttp.web.Response: - node_id = req.query.get("node_id") - pid = int(req.query.get("pid")) - duration = int(req.query.get("duration")) - profiling_id = self.raylet_stats.launch_profiling( - node_id=node_id, pid=pid, duration=duration) - return await json_response(str(profiling_id)) - - async def check_profiling_status(req) -> aiohttp.web.Response: - profiling_id = req.query.get("profiling_id") - return await json_response( - self.raylet_stats.check_profiling_status(profiling_id)) - - async def get_profiling_info(req) -> aiohttp.web.Response: - profiling_id = req.query.get("profiling_id") - return aiohttp.web.json_response( - self.raylet_stats.get_profiling_info(profiling_id)) - - async def kill_actor(req) -> aiohttp.web.Response: - actor_id = req.query.get("actor_id") - ip_address = req.query.get("ip_address") - port = req.query.get("port") - return await json_response( - self.raylet_stats.kill_actor(actor_id, ip_address, port)) - - async def logs(req) -> aiohttp.web.Response: - hostname = req.query.get("hostname") - pid = req.query.get("pid") - result = self.node_stats.get_logs(hostname, pid) - return await json_response(result=result) - - async def errors(req) -> aiohttp.web.Response: - hostname = req.query.get("hostname") 
- pid = req.query.get("pid") - result = self.node_stats.get_errors(hostname, pid) - return await json_response(result=result) - - self.app.router.add_get("/", get_index) - self.app.router.add_get("/favicon.ico", get_favicon) - - build_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "client/build") - if not os.path.isdir(build_dir): - raise OSError( - errno.ENOENT, - "Dashboard build directory not found. If installing " - "from source, please follow the additional steps required to " - "build the dashboard " - "(cd python/ray/dashboard/client && npm ci && npm run build)", - build_dir) - - static_dir = os.path.join(build_dir, "static") - self.app.router.add_static("/static", static_dir) - - speedscope_dir = os.path.join(build_dir, "speedscope-1.5.3") - self.app.router.add_static("/speedscope", speedscope_dir) - - self.app.router.add_get("/api/ray_config", ray_config) - self.app.router.add_get("/api/node_info", node_info) - self.app.router.add_get("/api/raylet_info", raylet_info) - self.app.router.add_get("/api/tune_info", tune_info) - self.app.router.add_get("/api/tune_availability", tune_availability) - self.app.router.add_get("/api/launch_profiling", launch_profiling) - self.app.router.add_get("/api/check_profiling_status", - check_profiling_status) - self.app.router.add_get("/api/get_profiling_info", get_profiling_info) - self.app.router.add_get("/api/kill_actor", kill_actor) - self.app.router.add_get("/api/logs", logs) - self.app.router.add_get("/api/errors", errors) - - self.app.router.add_get("/{_}", get_forbidden) + route_handler = DashboardRouteHandler( + self.dashboard_controller, is_dev=self.is_dev) + + # Setup Metrics exporting service if necessary. + self.metrics_export_address = metrics_export_address + if self.metrics_export_address: + self._setup_metrics_export() + + # Setup Dashboard Routes + build_dir = setup_static_dir(self.app) + setup_speedscope_dir(self.app, build_dir) + setup_dashboard_route( + self.app, + route_handler, + index="/", + favicon="/favicon.ico", + ray_config="/api/ray_config", + node_info="/api/node_info", + raylet_info="/api/raylet_info", + tune_info="/api/tune_info", + tune_availability="/api/tune_availability", + launch_profiling="/api/launch_profiling", + check_profiling_status="/api/check_profiling_status", + get_profiling_info="/api/get_profiling_info", + kill_actor="/api/kill_actor", + logs="/api/logs", + errors="/api/errors") + self.app.router.add_get("/{_}", route_handler.get_forbidden) + + def _setup_metrics_export(self): + exporter = Exporter(self.dashboard_id, self.metrics_export_address, + self.dashboard_controller) + self.metrics_export_client = MetricsExportClient( + self.metrics_export_address, self.dashboard_controller, + self.dashboard_id, exporter) + + # Setup endpoints + metrics_export_handler = MetricsExportHandler( + self.dashboard_controller, + self.metrics_export_client, + self.dashboard_id, + is_dev=self.is_dev) + setup_metrics_export_routes(self.app, metrics_export_handler) + + def _start_exporting_metrics(self): + result, error = self.metrics_export_client.start_exporting_metrics() + if not result and error: + url = ray.services.get_webui_url_from_redis(self.redis_client) + error += (" Please reenable the metrics export by going to " + "the url: {}/api/enable_metrics_export".format(url)) + ray.utils.push_error_to_driver_through_redis( + self.redis_client, "metrics export failed", error) def log_dashboard_url(self): url = ray.services.get_webui_url_from_redis(self.redis_client) + if url is None: + raise 
ValueError("WebUI URL is not present in GCS.") with open(os.path.join(self.temp_dir, "dashboard_url"), "w") as f: f.write(url) logger.info("Dashboard running on {}".format(url)) def run(self): self.log_dashboard_url() - self.node_stats.start() - self.raylet_stats.start() - if Analysis is not None: - self.tune_stats.start() + self.dashboard_controller.start_collecting_metrics() + if self.metrics_export_address: + self._start_exporting_metrics() aiohttp.web.run_app(self.app, host=self.host, port=self.port) @@ -409,7 +615,7 @@ def __init__(self, redis_address, redis_password=None): super().__init__() - def calculate_log_counts(self): + def _calculate_log_counts(self): return { ip: { pid: len(logs_for_pid) @@ -418,7 +624,7 @@ def calculate_log_counts(self): for ip, logs_for_ip in self._logs.items() } - def calculate_error_counts(self): + def _calculate_error_counts(self): return { ip: { pid: len(errors_for_pid) @@ -427,7 +633,7 @@ def calculate_error_counts(self): for ip, errors_for_ip in self._errors.items() } - def purge_outdated_stats(self): + def _purge_outdated_stats(self): def current(then, now): if (now - then) > 5: return False @@ -442,14 +648,14 @@ def current(then, now): def get_node_stats(self) -> Dict: with self._node_stats_lock: - self.purge_outdated_stats() + self._purge_outdated_stats() node_stats = sorted( (v for v in self._node_stats.values()), key=itemgetter("boot_time")) return { "clients": node_stats, - "log_counts": self.calculate_log_counts(), - "error_counts": self.calculate_error_counts(), + "log_counts": self._calculate_log_counts(), + "error_counts": self._calculate_error_counts(), } def get_actor_tree(self, workers_info_by_node, infeasible_tasks, @@ -606,6 +812,7 @@ def run(self): else: data = json.loads(ray.utils.decode(data)) self._node_stats[data["hostname"]] = data + except Exception: logger.exception(traceback.format_exc()) continue @@ -624,11 +831,11 @@ def __init__(self, redis_address, redis_password=None): self._raylet_stats = {} self._profiling_stats = {} - self.update_nodes() + self._update_nodes() super().__init__() - def update_nodes(self): + def _update_nodes(self): with self.nodes_lock: self.nodes = ray.nodes() node_ids = [node["NodeID"] for node in self.nodes] @@ -688,15 +895,15 @@ def _callback(reply_future): def check_profiling_status(self, profiling_id): with self._raylet_stats_lock: is_present = profiling_id in self._profiling_stats - if is_present: - reply = self._profiling_stats[profiling_id] - if reply.stderr: - return {"status": "error", "error": reply.stderr} - else: - return {"status": "finished"} - else: + if not is_present: return {"status": "pending"} + reply = self._profiling_stats[profiling_id] + if reply.stderr: + return {"status": "error", "error": reply.stderr} + else: + return {"status": "finished"} + def get_profiling_info(self, profiling_id): with self._raylet_stats_lock: profiling_stats = self._profiling_stats.get(profiling_id) @@ -721,22 +928,27 @@ def run(self): while True: time.sleep(1.0) replies = {} - for node in self.nodes: - node_id = node["NodeID"] - stub = self.stubs[node_id] - reply = stub.GetNodeStats( - node_manager_pb2.GetNodeStatsRequest(), timeout=2) - reply_dict = MessageToDict(reply) - reply_dict["nodeId"] = node_id - replies[node["NodeManagerAddress"]] = reply_dict - with self._raylet_stats_lock: - for address, reply_dict in replies.items(): - self._raylet_stats[address] = reply_dict - counter += 1 - # From time to time, check if new nodes have joined the cluster - # and update self.nodes - if counter % 10: - 
self.update_nodes() + + try: + for node in self.nodes: + node_id = node["NodeID"] + stub = self.stubs[node_id] + reply = stub.GetNodeStats( + node_manager_pb2.GetNodeStatsRequest(), timeout=2) + reply_dict = MessageToDict(reply) + reply_dict["nodeId"] = node_id + replies[node["NodeManagerAddress"]] = reply_dict + with self._raylet_stats_lock: + for address, reply_dict in replies.items(): + self._raylet_stats[address] = reply_dict + except Exception: + logger.exception(traceback.format_exc()) + finally: + counter += 1 + # From time to time, check if new nodes have joined the cluster + # and update self.nodes + if counter % 10: + self._update_nodes() class TuneCollector(threading.Thread): @@ -749,7 +961,6 @@ class TuneCollector(threading.Thread): """ def __init__(self, logdir, reload_interval): - super().__init__() self._logdir = logdir self._trial_records = {} self._data_lock = threading.Lock() @@ -757,6 +968,9 @@ def __init__(self, logdir, reload_interval): self._available = False self._tensor_board_started = False + os.makedirs(self._logdir, exist_ok=True) + super().__init__() + def get_stats(self): with self._data_lock: return {"trial_records": copy.deepcopy(self._trial_records)} @@ -871,74 +1085,3 @@ def clean_trials(self, trial_details, job_name): details["job_id"] = job_name return trial_details - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=("Parse Redis server for the " - "dashboard to connect to.")) - parser.add_argument( - "--host", - required=True, - type=str, - help="The host to use for the HTTP server.") - parser.add_argument( - "--port", - required=True, - type=int, - help="The port to use for the HTTP server.") - parser.add_argument( - "--redis-address", - required=True, - type=str, - help="The address to use for Redis.") - parser.add_argument( - "--redis-password", - required=False, - type=str, - default=None, - help="the password to use for Redis") - parser.add_argument( - "--logging-level", - required=False, - type=str, - default=ray_constants.LOGGER_LEVEL, - choices=ray_constants.LOGGER_LEVEL_CHOICES, - help=ray_constants.LOGGER_LEVEL_HELP) - parser.add_argument( - "--logging-format", - required=False, - type=str, - default=ray_constants.LOGGER_FORMAT, - help=ray_constants.LOGGER_FORMAT_HELP) - parser.add_argument( - "--temp-dir", - required=False, - type=str, - default=None, - help="Specify the path of the temporary directory use by Ray process.") - args = parser.parse_args() - ray.utils.setup_logger(args.logging_level, args.logging_format) - - try: - dashboard = Dashboard( - args.host, - args.port, - args.redis_address, - args.temp_dir, - redis_password=args.redis_password, - ) - dashboard.run() - except Exception as e: - # Something went wrong, so push an error to all drivers. 
- redis_client = ray.services.create_redis_client( - args.redis_address, password=args.redis_password) - traceback_str = ray.utils.format_error_message(traceback.format_exc()) - message = ("The dashboard on node {} failed with the following " - "error:\n{}".format(os.uname()[1], traceback_str)) - ray.utils.push_error_to_driver_through_redis( - redis_client, ray_constants.DASHBOARD_DIED_ERROR, message) - if isinstance(e, OSError) and e.errno == errno.ENOENT: - logger.warning(message) - else: - raise e diff --git a/python/ray/dashboard/dashboard_main.py b/python/ray/dashboard/dashboard_main.py new file mode 100644 index 000000000000..dc7e5cb08361 --- /dev/null +++ b/python/ray/dashboard/dashboard_main.py @@ -0,0 +1,80 @@ +import os +import traceback + +import click +import ray + +import ray.ray_constants as ray_constants +from ray.dashboard.dashboard import Dashboard + + +@click.command() +@click.option( + "--host", + required=True, + type=str, + help="The host to use for the HTTP server.") +@click.option( + "--port", + required=True, + type=int, + help="The port to use for the HTTP server.") +@click.option( + "--redis-address", + required=True, + type=str, + help="The address to use for Redis.") +@click.option( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") +@click.option( + "--logging-level", + required=False, + type=click.Choice(ray_constants.LOGGER_LEVEL_CHOICES), + default=ray_constants.LOGGER_LEVEL, + help=ray_constants.LOGGER_LEVEL_HELP) +@click.option( + "--logging-format", + required=False, + type=str, + default=ray_constants.LOGGER_FORMAT, + help=ray_constants.LOGGER_FORMAT_HELP) +@click.option( + "--temp-dir", + required=False, + type=str, + default=None, + help="Specify the path of the temporary directory use by Ray process.") +def main(host, port, redis_address, redis_password, logging_level, + logging_format, temp_dir): + ray.utils.setup_logger(logging_level, logging_format) + + metrics_export_address = os.environ.get("METRICS_EXPORT_ADDRESS") + + try: + dashboard = Dashboard( + host, + port, + redis_address, + temp_dir, + metrics_export_address=metrics_export_address, + redis_password=redis_password) + dashboard.run() + except Exception as e: + # Something went wrong, so push an error to all drivers. + redis_client = ray.services.create_redis_client( + redis_address, password=redis_password) + traceback_str = ray.utils.format_error_message(traceback.format_exc()) + message = ( + "The dashboard on node {} failed to start with the following " + "error:\n{}".format(os.uname()[1], traceback_str)) + ray.utils.push_error_to_driver_through_redis( + redis_client, ray_constants.DASHBOARD_DIED_ERROR, message) + raise e + + +if __name__ == "__main__": + main() diff --git a/python/ray/dashboard/interface/dashboard_controller_interface.py b/python/ray/dashboard/interface/dashboard_controller_interface.py new file mode 100644 index 000000000000..1cb38dddf6eb --- /dev/null +++ b/python/ray/dashboard/interface/dashboard_controller_interface.py @@ -0,0 +1,62 @@ +from abc import ABC, abstractmethod + + +class BaseDashboardController(ABC): + """Set of APIs to interact with a Dashboard class and routes. + + Make sure you run start_collecting_metrics function before using + get_[stats]_info methods. 
+ """ + + @abstractmethod + def get_ray_config(self): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def get_node_info(self): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def get_raylet_info(self): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def tune_info(self): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def tune_availability(self): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def launch_profiling(self, node_id, pid, duration): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def check_profiling_status(self, profiling_id): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def get_profiling_info(self, profiling_id): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def kill_actor(self, actor_id, ip_address, port): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def get_logs(self, hostname, pid): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def get_errors(self, hostname, pid): + raise NotImplementedError("Please implement this method.") + + @abstractmethod + def start_collecting_metrics(self): + """Start threads/processes/actors to collect metrics + + NOTE: This interface should be called only once before using + other api calls. + """ + raise NotImplementedError("Please implement this method.") diff --git a/python/ray/dashboard/interface/dashboard_route_handler_interface.py b/python/ray/dashboard/interface/dashboard_route_handler_interface.py new file mode 100644 index 000000000000..37446770ef3f --- /dev/null +++ b/python/ray/dashboard/interface/dashboard_route_handler_interface.py @@ -0,0 +1,59 @@ +import aiohttp + +from abc import ABC, abstractmethod + + +class BaseDashboardRouteHandler(ABC): + """Collection of routes that should be implemented for dashboard.""" + + @abstractmethod + def get_forbidden(self, _) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def get_index(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def ray_config(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def node_info(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def raylet_info(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def tune_info(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def tune_availability(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def launch_profiling(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def check_profiling_status(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def get_profiling_info(self, req) -> aiohttp.web.Response: + raise NotImplementedError("Please implement this method.") + + @abstractmethod + async def kill_actor(self, req) -> 
aiohttp.web.Response:
+        raise NotImplementedError("Please implement this method.")
+
+    @abstractmethod
+    async def logs(self, req) -> aiohttp.web.Response:
+        raise NotImplementedError("Please implement this method.")
+
+    @abstractmethod
+    async def errors(self, req) -> aiohttp.web.Response:
+        raise NotImplementedError("Please implement this method.")
diff --git a/python/ray/dashboard/metrics_exporter/__init__.py b/python/ray/dashboard/metrics_exporter/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/python/ray/dashboard/metrics_exporter/api.py b/python/ray/dashboard/metrics_exporter/api.py
new file mode 100644
index 000000000000..d6c24de89be6
--- /dev/null
+++ b/python/ray/dashboard/metrics_exporter/api.py
@@ -0,0 +1,31 @@
+try:
+    import requests  # `requests` is not part of stdlib.
+except ImportError:
+    requests = None
+    print("Couldn't import the `requests` library. "
+          "Be sure to install it on the client side.")
+
+from ray.dashboard.metrics_exporter.schema import AuthRequest, AuthResponse
+from ray.dashboard.metrics_exporter.schema import IngestRequest, IngestResponse
+
+
+def authentication_request(url, cluster_id):
+    auth_request = AuthRequest(cluster_id=cluster_id)
+    response = requests.post(url, data=auth_request.json())
+    response.raise_for_status()
+    return AuthResponse.parse_obj(response.json())
+
+
+def ingest_request(url, cluster_id, access_token, ray_config, node_info,
+                   raylet_info, tune_info, tune_availability):
+    ingest_request = IngestRequest(
+        cluster_id=cluster_id,
+        access_token=access_token,
+        ray_config=ray_config,
+        node_info=node_info,
+        raylet_info=raylet_info,
+        tune_info=tune_info,
+        tune_availability=tune_availability)
+    response = requests.post(url, data=ingest_request.json())
+    response.raise_for_status()
+    return IngestResponse.parse_obj(response.json())
diff --git a/python/ray/dashboard/metrics_exporter/client.py b/python/ray/dashboard/metrics_exporter/client.py
new file mode 100644
index 000000000000..f7315c44d1ac
--- /dev/null
+++ b/python/ray/dashboard/metrics_exporter/client.py
@@ -0,0 +1,83 @@
+import logging
+import traceback
+
+from ray.dashboard.metrics_exporter import api
+from ray.dashboard.metrics_exporter.exporter import Exporter
+
+logger = logging.getLogger(__name__)
+
+
+class MetricsExportClient:
+    """Client used by the Dashboard to talk to the metrics export service.
+
+    start_exporting_metrics should not be called more than once, as it can
+    create multiple threads that export the same metrics.
+
+    Args:
+        address: Address to export metrics to.
+        dashboard_controller(BaseDashboardController): Dashboard controller to
+            run dashboard business logic.
+        dashboard_id(str): Unique dashboard ID.
+    """
+
+    def __init__(self, address, dashboard_controller, dashboard_id,
+                 exporter: Exporter):
+        self.dashboard_id = dashboard_id
+        self.auth_url = "{}/auth".format(address)
+        self.dashboard_controller = dashboard_controller
+        self.exporter = exporter
+
+        # Data obtained from requests.
+        self._dashboard_url = None
+        self.auth_info = None
+
+        # Client states
+        self.is_authenticated = False
+        self.is_exporting_started = False
+
+    def _authenticate(self):
+        """Authenticate against the metrics export service.
+
+        Raises if the request fails; otherwise stores the auth info.
+ """ + self.auth_info = api.authentication_request(self.auth_url, + self.dashboard_id) + self._dashboard_url = self.auth_info.dashboard_url + self.is_authenticated = True + + @property + def enabled(self): + return self.is_authenticated + + @property + def dashboard_url(self): + # This function should be used only after authentication succeed. + assert self._dashboard_url is not None, ( + "dashboard url should be obtained by " + "`start_exporting_metrics` method first.") + return self._dashboard_url + + def start_exporting_metrics(self): + """Create a thread to export metrics. + + Once this function succeeds, it should not be called again. + + Return: + Whether or not it suceedes to run exporter. + """ + assert not self.is_exporting_started + if not self.is_authenticated: + try: + self._authenticate() + except Exception as e: + error = ("Authentication failed with an error: {}\n" + "Traceback: {}".format(e, traceback.format_exc())) + logger.error(error) + return False, error + + # Exporter is a Python thread that keeps exporting metrics with + # access token obtained by an authentication process. + self.exporter.access_token = self.auth_info.access_token + self.exporter.start() + self.is_exporting_started = True + return True, None diff --git a/python/ray/dashboard/metrics_exporter/exporter.py b/python/ray/dashboard/metrics_exporter/exporter.py new file mode 100644 index 000000000000..267fd88cdb42 --- /dev/null +++ b/python/ray/dashboard/metrics_exporter/exporter.py @@ -0,0 +1,66 @@ +import logging +import threading +import traceback +import time + +from ray.dashboard.metrics_exporter import api + +logger = logging.getLogger(__file__) + + +class Exporter(threading.Thread): + """Python thread that exports metrics periodically. + + Args: + dashboard_id(str): Unique Dashboard ID. + address(str): Address to export metrics. + dashboard_controller(BaseDashboardController): dashboard + controller for dashboard business logic. + update_frequency(float): Frequency to export metrics. + """ + + def __init__(self, + dashboard_id, + address, + dashboard_controller, + update_frequency=1.0): + assert update_frequency >= 1.0 + + self.dashboard_id = dashboard_id + self.dashboard_controller = dashboard_controller + self.export_address = "{}/ingest".format(address) + self.update_frequency = update_frequency + self._access_token = None + + super().__init__() + + @property + def access_token(self): + return self._access_token + + @access_token.setter + def access_token(self, access_token): + self._access_token = access_token + + def export(self, ray_config, node_info, raylet_info, tune_info, + tune_availability): + api.ingest_request(self.export_address, self.dashboard_id, + self.access_token, ray_config, node_info, + raylet_info, tune_info, tune_availability) + # TODO(sang): Add piggybacking response handler. 
+ + def run(self): + assert self.access_token is not None, ( + "Set access token before running an exporter thread.") + while True: + try: + time.sleep(self.update_frequency) + self.export(self.dashboard_controller.get_ray_config(), + self.dashboard_controller.get_node_info(), + self.dashboard_controller.get_raylet_info(), + self.dashboard_controller.tune_info(), + self.dashboard_controller.tune_availability()) + except Exception as e: + logger.error("Exception occured while exporting metrics: {}.\n" + "Traceback: {}".format(e, traceback.format_exc())) + continue diff --git a/python/ray/dashboard/metrics_exporter/schema.py b/python/ray/dashboard/metrics_exporter/schema.py new file mode 100644 index 000000000000..9a31b7811043 --- /dev/null +++ b/python/ray/dashboard/metrics_exporter/schema.py @@ -0,0 +1,26 @@ +import pydantic +import typing + + +class IngestRequest(pydantic.BaseModel): + cluster_id: str + access_token: str + ray_config: typing.Any + node_info: dict + raylet_info: dict + tune_info: dict + tune_availability: dict + + +# TODO(sang): Add piggybacked response. +class IngestResponse(pydantic.BaseModel): + pass + + +class AuthRequest(pydantic.BaseModel): + cluster_id: str + + +class AuthResponse(pydantic.BaseModel): + dashboard_url: str + access_token: str diff --git a/python/ray/services.py b/python/ray/services.py index 6d3632542a3c..25d27a500fa6 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1081,15 +1081,12 @@ def start_dashboard(require_webui, port += 1 dashboard_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "dashboard/dashboard.py") + os.path.dirname(os.path.abspath(__file__)), + "dashboard/dashboard_main.py") command = [ - sys.executable, - "-u", - dashboard_filepath, - "--host={}".format(host), - "--port={}".format(port), - "--redis-address={}".format(redis_address), - "--temp-dir={}".format(temp_dir), + sys.executable, "-u", dashboard_filepath, "--host={}".format(host), + "--port={}".format(port), "--redis-address={}".format(redis_address), + "--temp-dir={}".format(temp_dir) ] if redis_password: command += ["--redis-password", redis_password] @@ -1121,6 +1118,7 @@ def start_dashboard(require_webui, logger.info("View the Ray dashboard at {}{}{}{}{}".format( colorama.Style.BRIGHT, colorama.Fore.GREEN, dashboard_url, colorama.Fore.RESET, colorama.Style.NORMAL)) + return dashboard_url, process_info else: return None, None diff --git a/python/ray/tests/test_metrics_export.py b/python/ray/tests/test_metrics_export.py new file mode 100644 index 000000000000..efb23b6e47a8 --- /dev/null +++ b/python/ray/tests/test_metrics_export.py @@ -0,0 +1,127 @@ +import pytest +import requests + +from unittest.mock import patch + +from ray.dashboard.metrics_exporter.client import MetricsExportClient +from ray.dashboard.metrics_exporter.exporter import Exporter +from ray.dashboard.metrics_exporter.schema import AuthResponse + +MOCK_DASHBOARD_ID = "1234" +MOCK_DASHBOARD_ADDRESS = "127.0.0.1:9081" +MOCK_ACCESS_TOKEN = "1234" + + +def _setup_client_and_exporter(controller): + exporter = Exporter(MOCK_DASHBOARD_ID, MOCK_DASHBOARD_ADDRESS, controller) + client = MetricsExportClient(MOCK_DASHBOARD_ADDRESS, controller, + MOCK_DASHBOARD_ID, exporter) + return exporter, client + + +@patch("ray.dashboard.dashboard.DashboardController") +def test_verify_exporter_cannot_run_without_access_token(mock_controller): + exporter, client = _setup_client_and_exporter(mock_controller) + # Should raise an assertion error because there's no access token set. 
+ with pytest.raises(AssertionError): + exporter.run() + + +@patch("ray.dashboard.dashboard.DashboardController") +@patch( + "ray.dashboard.metrics_exporter.api.authentication_request", + side_effect=requests.exceptions.HTTPError) +def test_client_invalid_request_status_returned(auth_request, mock_controller): + """ + If authentication request fails with an invalid status code, + `start_exporting_metrics` should fail. + """ + exporter, client = _setup_client_and_exporter(mock_controller) + + # authenticate should throw an exception because API request fails. + with pytest.raises(requests.exceptions.HTTPError): + client._authenticate() + + # This should fail because authentication throws an exception. + result, error = client.start_exporting_metrics() + assert result is False + + +@patch("ray.dashboard.dashboard.DashboardController") +@patch("ray.dashboard.metrics_exporter.api.authentication_request") +def test_authentication(auth_request, mock_controller): + auth_request.return_value = AuthResponse( + dashboard_url=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN) + exporter, client = _setup_client_and_exporter(mock_controller) + + assert client.enabled is False + client._authenticate() + assert client.dashboard_url == MOCK_DASHBOARD_ADDRESS + assert client.enabled is True + + +@patch.object(Exporter, "start") +@patch("ray.dashboard.dashboard.DashboardController") +@patch("ray.dashboard.metrics_exporter.api.authentication_request") +def test_start_exporting_metrics_without_authentication( + auth_request, mock_controller, start): + """ + `start_exporting_metrics` should trigger authentication if users + are not authenticated. + """ + auth_request.return_value = AuthResponse( + dashboard_url=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN) + exporter, client = _setup_client_and_exporter(mock_controller) + + # start_exporting_metrics should succeed. + result, error = client.start_exporting_metrics() + assert result is True + assert error is None + assert client.enabled is True + + +@patch.object(Exporter, "start") +@patch("ray.dashboard.dashboard.DashboardController") +@patch("ray.dashboard.metrics_exporter.api.authentication_request") +def test_start_exporting_metrics_with_authentication(auth_request, + mock_controller, start): + """ + If users are already authenticated, `start_exporting_metrics` + should not authenticate users. + """ + auth_request.return_value = AuthResponse( + dashboard_url=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN) + exporter, client = _setup_client_and_exporter(mock_controller) + # Already authenticated. + client._authenticate() + assert client.enabled is True + + result, error = client.start_exporting_metrics() + # Auth request should be called only once because + # it was already authenticated. 
+    assert auth_request.call_count == 1
+    assert result is True
+    assert error is None
+
+
+@patch.object(Exporter, "start")
+@patch("ray.dashboard.dashboard.DashboardController")
+@patch("ray.dashboard.metrics_exporter.api.authentication_request")
+def test_start_exporting_metrics_succeed(auth_request, mock_controller, start):
+    auth_request.return_value = AuthResponse(
+        dashboard_url=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN)
+    exporter, client = _setup_client_and_exporter(mock_controller)
+
+    result, error = client.start_exporting_metrics()
+    assert result is True
+    assert error is None
+    assert client.is_exporting_started is True
+    assert start.call_count == 1
+
+    with pytest.raises(AssertionError):
+        client.start_exporting_metrics()
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))
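
For reference, the metrics-export flow introduced above can be exercised without a running cluster. The sketch below mirrors the setup used in test_metrics_export.py: the client authenticates against <metrics_export_address>/auth, hands the returned access token to the Exporter thread, and the thread then posts snapshots to <metrics_export_address>/ingest. The address, dashboard ID, and token are placeholders, the controller is a mock, and Exporter.start is patched out, so no real backend is contacted; it assumes only that pydantic is installed.

# Hypothetical walk-through of the export flow added in this change.
# The controller and the HTTP layer are mocked, so it runs standalone.
from unittest.mock import MagicMock, patch

from ray.dashboard.metrics_exporter.client import MetricsExportClient
from ray.dashboard.metrics_exporter.exporter import Exporter
from ray.dashboard.metrics_exporter.schema import AuthResponse

ADDRESS = "127.0.0.1:9081"          # placeholder export backend
DASHBOARD_ID = "example-dashboard"  # normally str(uuid.uuid4())

controller = MagicMock()  # stands in for DashboardController
exporter = Exporter(DASHBOARD_ID, ADDRESS, controller, update_frequency=1.0)
client = MetricsExportClient(ADDRESS, controller, DASHBOARD_ID, exporter)

# start_exporting_metrics() authenticates first, then starts the thread.
with patch("ray.dashboard.metrics_exporter.api.authentication_request",
           return_value=AuthResponse(
               dashboard_url=ADDRESS, access_token="token")), \
        patch.object(Exporter, "start"):
    ok, error = client.start_exporting_metrics()
    assert ok and error is None
    print("Hosted dashboard at:", client.dashboard_url)

In a real deployment the same flow is triggered over HTTP: dashboard_main.py reads METRICS_EXPORT_ADDRESS from the environment, and a GET to /api/enable_metrics_export on the dashboard calls start_exporting_metrics and returns the hosted dashboard URL.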