diff --git a/Cargo.lock b/Cargo.lock index 74f939759c..721fdd3425 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1838,9 +1838,11 @@ dependencies = [ "prost", "prost-types", "rand", + "serde_json5", "tokio", "tokio-stream", "tonic 0.11.0", + "tower", "tracing", "uuid", ] diff --git a/nativelink-config/examples/basic_cas.json b/nativelink-config/examples/basic_cas.json index a0403caf6c..173951deb9 100644 --- a/nativelink-config/examples/basic_cas.json +++ b/nativelink-config/examples/basic_cas.json @@ -154,7 +154,8 @@ "worker_api": { "scheduler": "MAIN_SCHEDULER", }, - "admin": {} + "admin": {}, + "health": {}, } }], "global": { diff --git a/nativelink-config/src/cas_server.rs b/nativelink-config/src/cas_server.rs index 42e18c6d9a..a8461d6da7 100644 --- a/nativelink-config/src/cas_server.rs +++ b/nativelink-config/src/cas_server.rs @@ -184,6 +184,18 @@ pub struct AdminConfig { pub path: String, } +#[derive(Deserialize, Debug, Default)] +#[serde(deny_unknown_fields)] +pub struct HealthConfig { + /// Path to register the health status check. If path is "/status", and your + /// domain is "example.com", you can reach the endpoint with: + /// . + /// + /// Default: "/status" + #[serde(default)] + pub path: String, +} + #[derive(Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct ServicesConfig { @@ -228,6 +240,9 @@ pub struct ServicesConfig { /// This is the service for any administrative tasks. /// It provides a REST API endpoint for administrative purposes. pub admin: Option, + + /// This is the service for health status check. + pub health: Option, } #[derive(Deserialize, Debug)] diff --git a/nativelink-service/BUILD.bazel b/nativelink-service/BUILD.bazel index 06e1d4ef2a..7dcdc3ab90 100644 --- a/nativelink-service/BUILD.bazel +++ b/nativelink-service/BUILD.bazel @@ -14,6 +14,7 @@ rust_library( "src/capabilities_server.rs", "src/cas_server.rs", "src/execution_server.rs", + "src/health_server.rs", "src/lib.rs", "src/worker_api_server.rs", ], @@ -27,13 +28,16 @@ rust_library( "//nativelink-util", "@crates//:bytes", "@crates//:futures", + "@crates//:hyper", "@crates//:log", "@crates//:parking_lot", "@crates//:prost", "@crates//:rand", + "@crates//:serde_json5", "@crates//:tokio", "@crates//:tokio-stream", "@crates//:tonic", + "@crates//:tower", "@crates//:tracing", "@crates//:uuid", ], @@ -64,9 +68,11 @@ rust_test_suite( "@crates//:prometheus-client", "@crates//:prost", "@crates//:prost-types", + "@crates//:serde_json5", "@crates//:tokio", "@crates//:tokio-stream", "@crates//:tonic", + "@crates//:tower", ], ) diff --git a/nativelink-service/Cargo.toml b/nativelink-service/Cargo.toml index a8f7adbe68..8a851d31ee 100644 --- a/nativelink-service/Cargo.toml +++ b/nativelink-service/Cargo.toml @@ -13,6 +13,8 @@ nativelink-scheduler = { path = "../nativelink-scheduler" } bytes = "1.6.0" futures = "0.3.30" +hyper = { version = "0.14.28" } +serde_json5 = "0.1.0" log = "0.4.21" parking_lot = "0.12.1" prost = "0.12.4" @@ -20,6 +22,7 @@ rand = "0.8.5" tokio = { version = "1.37.0", features = ["sync", "rt"] } tokio-stream = { version = "0.1.15", features = ["sync"] } tonic = { version = "0.11.0", features = ["gzip", "tls"] } +tower = "0.4.13" tracing = "0.1.40" uuid = { version = "1.8.0", features = ["v4"] } diff --git a/nativelink-service/src/health_server.rs b/nativelink-service/src/health_server.rs new file mode 100644 index 0000000000..1bd9af94a9 --- /dev/null +++ b/nativelink-service/src/health_server.rs @@ -0,0 +1,82 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::StreamExt; +use hyper::header::{HeaderValue, CONTENT_TYPE}; +use hyper::{Body, Request, Response, StatusCode}; +use nativelink_util::health_utils::{ + HealthRegistry, HealthStatus, HealthStatusDescription, HealthStatusReporter, +}; +use tower::Service; + +/// Content type header value for JSON. +const JSON_CONTENT_TYPE: &str = "application/json; charset=utf-8"; + +#[derive(Clone)] +pub struct HealthServer { + health_registry: HealthRegistry, +} + +impl HealthServer { + pub fn new(health_registry: HealthRegistry) -> Self { + Self { health_registry } + } +} + +impl Service> for HealthServer { + type Response = Response; + type Error = std::convert::Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + let health_registry = self.health_registry.clone(); + Box::pin(async move { + let health_status_descriptions: Vec = + health_registry.health_status_report().collect().await; + match serde_json5::to_string(&health_status_descriptions) { + Ok(body) => { + let contains_failed_report = + health_status_descriptions.iter().any(|description| { + matches!(description.status, HealthStatus::Failed { .. }) + }); + let status_code = if contains_failed_report { + StatusCode::SERVICE_UNAVAILABLE + } else { + StatusCode::OK + }; + + Ok(Response::builder() + .status(status_code) + .header(CONTENT_TYPE, HeaderValue::from_static(JSON_CONTENT_TYPE)) + .body(Body::from(body)) + .unwrap()) + } + + Err(e) => Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header(CONTENT_TYPE, HeaderValue::from_static(JSON_CONTENT_TYPE)) + .body(Body::from(format!("Internal Failure: {e:?}"))) + .unwrap()), + } + }) + } +} diff --git a/nativelink-service/src/lib.rs b/nativelink-service/src/lib.rs index c1891af2b0..ba78350b74 100644 --- a/nativelink-service/src/lib.rs +++ b/nativelink-service/src/lib.rs @@ -17,4 +17,5 @@ pub mod bytestream_server; pub mod capabilities_server; pub mod cas_server; pub mod execution_server; +pub mod health_server; pub mod worker_api_server; diff --git a/nativelink-util/src/health_utils.rs b/nativelink-util/src/health_utils.rs index f0a4e43a8f..1a1cb31722 100644 --- a/nativelink-util/src/health_utils.rs +++ b/nativelink-util/src/health_utils.rs @@ -162,13 +162,17 @@ pub struct HealthRegistry { } pub trait HealthStatusReporter { - fn health_status_report(&self) -> Pin + '_>>; + fn health_status_report( + &self, + ) -> Pin + Send + '_>>; } /// Health status reporter implementation for the health registry that provides a stream /// of health status descriptions. impl HealthStatusReporter for HealthRegistry { - fn health_status_report(&self) -> Pin + '_>> { + fn health_status_report( + &self, + ) -> Pin + Send + '_>> { Box::pin(futures::stream::iter(self.indicators.iter()).then( |(namespace, indicator)| async move { HealthStatusDescription { diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs index 787cc5858f..a00fa6f80c 100644 --- a/src/bin/nativelink.rs +++ b/src/bin/nativelink.rs @@ -21,7 +21,7 @@ use async_lock::Mutex as AsyncMutex; use axum::Router; use clap::Parser; use futures::future::{select_all, BoxFuture, OptionFuture, TryFutureExt}; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt; use hyper::server::conn::Http; use hyper::{Response, StatusCode}; use mimalloc::MiMalloc; @@ -37,14 +37,13 @@ use nativelink_service::bytestream_server::ByteStreamServer; use nativelink_service::capabilities_server::CapabilitiesServer; use nativelink_service::cas_server::CasServer; use nativelink_service::execution_server::ExecutionServer; +use nativelink_service::health_server::HealthServer; use nativelink_service::worker_api_server::WorkerApiServer; use nativelink_store::default_store_factory::store_factory; use nativelink_store::store_manager::StoreManager; use nativelink_util::common::fs::{set_idle_file_descriptor_timeout, set_open_file_limit}; use nativelink_util::digest_hasher::{set_default_digest_hasher_func, DigestHasherFunc}; -use nativelink_util::health_utils::{ - HealthRegistryBuilder, HealthStatus, HealthStatusDescription, HealthStatusReporter, -}; +use nativelink_util::health_utils::HealthRegistryBuilder; use nativelink_util::metrics_utils::{ set_metrics_enabled_for_this_thread, Collector, CollectorState, Counter, MetricsComponent, Registry, @@ -79,12 +78,12 @@ const DEFAULT_PROMETHEUS_METRICS_PATH: &str = "/metrics"; /// Note: This must be kept in sync with the documentation in `AdminConfig::path`. const DEFAULT_ADMIN_API_PATH: &str = "/admin"; +// Note: This must be kept in sync with the documentation in `HealthConfig::path`. +const DEFAULT_HEALTH_STATUS_CHECK_PATH: &str = "/status"; + /// Name of environment variable to disable metrics. const METRICS_DISABLE_ENV: &str = "NATIVELINK_DISABLE_METRICS"; -/// Content type header value for JSON. -const JSON_CONTENT_TYPE: &str = "application/json; charset=utf-8"; - /// Backend for bazel remote execution / cache API. #[derive(Parser, Debug)] #[clap( @@ -397,68 +396,20 @@ async fn inner_main( ); let root_metrics_registry = root_metrics_registry.clone(); - let health_registry_status = health_registry_builder.lock().await.build(); + let health_registry = health_registry_builder.lock().await.build(); let mut svc = Router::new() // This is the default service that executes if no other endpoint matches. - .fallback_service(tonic_services.into_service().map_err(|e| panic!("{e}"))) - .route_service( - "/status", - axum::routing::get(move || async move { - fn error_to_response(e: E) -> Response { - let mut response = Response::new(format!("Error: {e:?}")); - *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - response - } + .fallback_service(tonic_services.into_service().map_err(|e| panic!("{e}"))); - spawn_blocking(move || { - futures::executor::block_on(async { - let health_status_descriptions: Vec = - health_registry_status - .health_status_report() - .collect() - .await; - - match serde_json5::to_string(&health_status_descriptions) { - Ok(body) => { - let contains_failed_report = - health_status_descriptions.iter().any(|description| { - matches!( - description.status, - HealthStatus::Failed { .. } - ) - }); - let status_code = if contains_failed_report { - StatusCode::SERVICE_UNAVAILABLE - } else { - StatusCode::OK - }; - Response::builder() - .status(status_code) - .header( - hyper::header::CONTENT_TYPE, - hyper::header::HeaderValue::from_static( - JSON_CONTENT_TYPE, - ), - ) - .body(body) - .unwrap() - } - Err(e) => Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header( - hyper::header::CONTENT_TYPE, - hyper::header::HeaderValue::from_static(JSON_CONTENT_TYPE), - ) - .body(format!("Internal Failure: {e:?}")) - .unwrap(), - } - }) - }) - .await - .unwrap_or_else(error_to_response) - }), - ); + if let Some(health_cfg) = services.health { + let path = if health_cfg.path.is_empty() { + DEFAULT_HEALTH_STATUS_CHECK_PATH + } else { + &health_cfg.path + }; + svc = svc.route_service(path, HealthServer::new(health_registry)); + } if let Some(prometheus_cfg) = services.experimental_prometheus { fn error_to_response(e: E) -> Response {