diff --git a/nativelink-scheduler/BUILD.bazel b/nativelink-scheduler/BUILD.bazel index 7ee6ecfa2..e0f7c9550 100644 --- a/nativelink-scheduler/BUILD.bazel +++ b/nativelink-scheduler/BUILD.bazel @@ -11,6 +11,7 @@ rust_library( srcs = [ "src/action_scheduler.rs", "src/cache_lookup_scheduler.rs", + "src/default_action_listener.rs", "src/default_scheduler_factory.rs", "src/grpc_scheduler.rs", "src/lib.rs", diff --git a/nativelink-scheduler/src/action_scheduler.rs b/nativelink-scheduler/src/action_scheduler.rs index b04442294..5e07cb346 100644 --- a/nativelink-scheduler/src/action_scheduler.rs +++ b/nativelink-scheduler/src/action_scheduler.rs @@ -12,16 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; +use futures::Future; use nativelink_error::Error; use nativelink_util::action_messages::{ActionInfo, ActionState, ClientOperationId}; use nativelink_util::metrics_utils::Registry; -use tokio::sync::watch; use crate::platform_property_manager::PlatformPropertyManager; +/// ActionListener interface is responsible for interfacing with clients +/// that are interested in the state of an action. +pub trait ActionListener: Sync + Send + Unpin { + /// Returns the client operation id. + fn client_operation_id(&self) -> &ClientOperationId; + + /// Returns the current action state. + fn action_state(&self) -> Arc; + + /// Waits for the action state to change. + fn changed( + &mut self, + ) -> Pin, Error>> + Send + Sync + '_>>; +} + /// ActionScheduler interface is responsible for interactions between the scheduler /// and action related operations. #[async_trait] @@ -37,13 +53,13 @@ pub trait ActionScheduler: Sync + Send + Unpin { &self, client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result<(ClientOperationId, watch::Receiver>), Error>; + ) -> Result>, Error>; /// Find an existing action by its name. async fn find_by_client_operation_id( &self, client_operation_id: &ClientOperationId, - ) -> Result>>, Error>; + ) -> Result>>, Error>; /// Cleans up the cache of recently completed actions. async fn clean_recently_completed_actions(&self); diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs index 679d47759..9b6ccbee1 100644 --- a/nativelink-scheduler/src/cache_lookup_scheduler.rs +++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs @@ -13,21 +13,19 @@ // limitations under the License. use std::collections::HashMap; +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; -use futures::future::Shared as SharedFuture; -use futures::stream::StreamExt; -use futures::FutureExt; -use nativelink_error::{make_err, Code, Error}; +use futures::Future; +use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::{ ActionResult as ProtoActionResult, GetActionResultRequest, }; use nativelink_store::ac_utils::get_and_decode_digest; use nativelink_store::grpc_store::GrpcStore; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, ClientOperationId, - OperationId, + ActionInfo, ActionInfoHashKey, ActionStage, ActionState, ClientOperationId, OperationId, }; use nativelink_util::background_spawn; use nativelink_util::common::DigestInfo; @@ -35,28 +33,22 @@ use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::store_trait::{Store, StoreLike}; use parking_lot::{Mutex, MutexGuard}; use scopeguard::guard; -use tokio::select; -use tokio::sync::{oneshot, watch}; -use tokio_stream::wrappers::WatchStream; +use tokio::sync::oneshot; use tonic::Request; use tracing::{event, Level}; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; use crate::platform_property_manager::PlatformPropertyManager; -/// A future containing the resolved `ClientOperationId` once it is figured out. -/// This future may be cloned and will always yield the same value once resolved. -type ClientOperationIdFuture = SharedFuture>; - /// Actions that are having their cache checked or failed cache lookup and are /// being forwarded upstream. Missing the skip_cache_check actions which are /// forwarded directly. type CheckActions = HashMap< ActionInfoHashKey, - ( - SharedFuture>, - watch::Receiver>, - ), + Vec<( + ClientOperationId, + oneshot::Sender>, Error>>, + )>, >; pub struct CacheLookupScheduler { @@ -67,7 +59,7 @@ pub struct CacheLookupScheduler { /// in the action cache. action_scheduler: Arc, /// Actions that are currently performing a CacheCheck. - cache_check_actions: Arc>, + inflight_cache_checks: Arc>, } async fn get_action_from_store( @@ -98,25 +90,49 @@ async fn get_action_from_store( } } +/// Future for when ActionListeners are known. +type ActionListenerOneshot = oneshot::Receiver>, Error>>; + fn subscribe_to_existing_action( - cache_check_actions: &MutexGuard, + inflight_cache_checks: &mut MutexGuard, unique_qualifier: &ActionInfoHashKey, -) -> Option<(ClientOperationIdFuture, watch::Receiver>)> { - cache_check_actions - .get(unique_qualifier) - .map(|(client_operation_id_rx, rx)| { - let mut rx = rx.clone(); - rx.mark_changed(); - (client_operation_id_rx.clone(), rx) + client_operation_id: &ClientOperationId, +) -> Option { + inflight_cache_checks + .get_mut(unique_qualifier) + .map(|oneshots| { + let (tx, rx) = oneshot::channel(); + oneshots.push((client_operation_id.clone(), tx)); + rx }) } +struct CachedActionListener { + client_operation_id: ClientOperationId, + action_state: Arc, +} + +impl ActionListener for CachedActionListener { + fn client_operation_id(&self) -> &ClientOperationId { + &self.client_operation_id + } + + fn action_state(&self) -> Arc { + self.action_state.clone() + } + + fn changed( + &mut self, + ) -> Pin, Error>> + Send + Sync + '_>> { + Box::pin(async { Ok(self.action_state.clone()) }) + } +} impl CacheLookupScheduler { pub fn new(ac_store: Store, action_scheduler: Arc) -> Result { Ok(Self { ac_store, action_scheduler, - cache_check_actions: Default::default(), + inflight_cache_checks: Default::default(), }) } } @@ -136,7 +152,7 @@ impl ActionScheduler for CacheLookupScheduler { &self, client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result<(ClientOperationId, watch::Receiver>), Error> { + ) -> Result>, Error> { if action_info.skip_cache_lookup { // Cache lookup skipped, forward to the upstream. return self @@ -144,77 +160,85 @@ impl ActionScheduler for CacheLookupScheduler { .add_action(client_operation_id, action_info) .await; } - let mut current_state = Arc::new(ActionState { - id: OperationId::new(action_info.unique_qualifier.clone()), - stage: ActionStage::CacheCheck, - }); let cache_check_result = { // Check this isn't a duplicate request first. - let mut cache_check_actions = self.cache_check_actions.lock(); - let current_state = current_state.clone(); + let mut inflight_cache_checks = self.inflight_cache_checks.lock(); let unique_qualifier = action_info.unique_qualifier.clone(); - subscribe_to_existing_action(&cache_check_actions, &unique_qualifier).ok_or_else( - move || { - let (client_operation_id_tx, client_operation_id_rx) = oneshot::channel(); - let client_operation_id_rx = client_operation_id_rx.shared(); - let (tx, rx) = watch::channel(current_state); - cache_check_actions.insert( - unique_qualifier.clone(), - (client_operation_id_rx.clone(), rx), - ); - // In the event we loose the reference to our `scope_guard`, it will remove - // the action from the cache_check_actions map. - let cache_check_actions = self.cache_check_actions.clone(); - ( - client_operation_id_tx, - client_operation_id_rx, - tx, - guard((), move |_| { - cache_check_actions.lock().remove(&unique_qualifier); - }), - ) - }, + subscribe_to_existing_action( + &mut inflight_cache_checks, + &unique_qualifier, + &client_operation_id, ) + .ok_or_else(move || { + let (action_listener_tx, action_listener_rx) = oneshot::channel(); + inflight_cache_checks.insert( + unique_qualifier.clone(), + vec![(client_operation_id, action_listener_tx)], + ); + // In the event we loose the reference to our `scope_guard`, it will remove + // the action from the inflight_cache_checks map. + let inflight_cache_checks = self.inflight_cache_checks.clone(); + ( + action_listener_rx, + guard((), move |_| { + inflight_cache_checks.lock().remove(&unique_qualifier); + }), + ) + }) + }; + let (action_listener_rx, scope_guard) = match cache_check_result { + Ok(action_listener_fut) => { + let action_listener = action_listener_fut.await.map_err(|_| { + make_err!( + Code::Internal, + "ActionListener tx hung up in CacheLookupScheduler::add_action" + ) + })?; + return action_listener; + } + Err(client_tx_and_scope_guard) => client_tx_and_scope_guard, }; - let (client_operation_id_tx, client_operation_id_rx, tx, scope_guard) = - match cache_check_result { - Ok((client_operation_id_tx, rx)) => { - let client_operation_id = client_operation_id_tx.await.map_err(|_| { - make_err!( - Code::Internal, - "Client operation id tx hung up in CacheLookupScheduler::add_action" - ) - })?; - return Ok((client_operation_id, rx)); - } - Err(client_tx_and_scope_guard) => client_tx_and_scope_guard, - }; - let rx = tx.subscribe(); let ac_store = self.ac_store.clone(); let action_scheduler = self.action_scheduler.clone(); - let client_operation_id_clone = client_operation_id.clone(); + let inflight_cache_checks = self.inflight_cache_checks.clone(); // We need this spawn because we are returning a stream and this spawn will populate the stream's data. background_spawn!("cache_lookup_scheduler_add_action", async move { - // If our spawn ever dies, we will remove the action from the cache_check_actions map. + // If our spawn ever dies, we will remove the action from the inflight_cache_checks map. let _scope_guard = scope_guard; // Perform cache check. - let action_digest = current_state.action_digest(); - let instance_name = action_info.instance_name().clone(); + let instance_name = action_info.unique_qualifier.instance_name.clone(); if let Some(action_result) = get_action_from_store( &ac_store, - *action_digest, + action_info.unique_qualifier.digest, instance_name, - current_state.id.unique_qualifier.digest_function, + action_info.unique_qualifier.digest_function, ) .await { - match ac_store.has(*action_digest).await { + match ac_store.has(action_info.unique_qualifier.digest).await { Ok(Some(_)) => { - Arc::make_mut(&mut current_state).stage = - ActionStage::CompletedFromCache(action_result); - let _ = tx.send(current_state); + let maybe_pending_txs = { + let mut inflight_cache_checks = inflight_cache_checks.lock(); + // We are ready to resolve the in-flight actions. We remove the + // in-flight actions from the map. + inflight_cache_checks.remove(&action_info.unique_qualifier) + }; + let Some(pending_txs) = maybe_pending_txs else { + return; // Nobody is waiting for this action anymore. + }; + let action_state = Arc::new(ActionState { + id: OperationId::new(action_info.unique_qualifier.clone()), + stage: ActionStage::CompletedFromCache(action_result), + }); + for (client_operation_id, pending_tx) in pending_txs { + // Ignore errors here, as the other end may have hung up. + let _ = pending_tx.send(Ok(Box::pin(CachedActionListener { + client_operation_id, + action_state: action_state.clone(), + }))); + } return; } Err(err) => { @@ -227,52 +251,39 @@ impl ActionScheduler for CacheLookupScheduler { _ => {} } } - // Not in cache, forward to upstream and proxy state. - match action_scheduler - .add_action(client_operation_id_clone, action_info) - .await - { - Ok((new_client_operation_id, rx)) => { - // It's ok if the other end hung up, just keep going just - // in case they come back. - let _ = client_operation_id_tx.send(new_client_operation_id); - let mut watch_stream = WatchStream::new(rx); - loop { - select!( - Some(action_state) = watch_stream.next() => { - if tx.send(action_state).is_err() { - break; - } - } - _ = tx.closed() => { - break; - } - ) - } - } - Err(err) => { - Arc::make_mut(&mut current_state).stage = - ActionStage::Completed(ActionResult { - error: Some(err), - ..Default::default() - }); - let _ = tx.send(current_state); - } + + let maybe_pending_txs = { + let mut inflight_cache_checks = inflight_cache_checks.lock(); + inflight_cache_checks.remove(&action_info.unique_qualifier) + }; + let Some(pending_txs) = maybe_pending_txs else { + return; // Noone is waiting for this action anymore. + }; + + for (client_operation_id, pending_tx) in pending_txs { + // Ignore errors here, as the other end may have hung up. + let _ = pending_tx.send( + action_scheduler + .add_action(client_operation_id, action_info.clone()) + .await, + ); } }); - let client_operation_id = client_operation_id_rx.await.map_err(|_| { - make_err!( - Code::Internal, - "Client operation id tx hung up in CacheLookupScheduler::add_action" - ) - })?; - Ok((client_operation_id, rx)) + action_listener_rx + .await + .map_err(|_| { + make_err!( + Code::Internal, + "ActionListener tx hung up in CacheLookupScheduler::add_action" + ) + })? + .err_tip(|| "In CacheLookupScheduler::add_action") } async fn find_by_client_operation_id( &self, client_operation_id: &ClientOperationId, - ) -> Result>>, Error> { + ) -> Result>>, Error> { self.action_scheduler .find_by_client_operation_id(client_operation_id) .await diff --git a/nativelink-scheduler/src/default_action_listener.rs b/nativelink-scheduler/src/default_action_listener.rs new file mode 100644 index 000000000..91ac36ce8 --- /dev/null +++ b/nativelink-scheduler/src/default_action_listener.rs @@ -0,0 +1,80 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::pin::Pin; +use std::sync::Arc; + +use futures::Future; +use nativelink_error::{make_err, Code, Error}; +use nativelink_util::action_messages::{ActionState, ClientOperationId}; +use tokio::sync::watch; + +use crate::action_scheduler::ActionListener; + +/// Simple implementation of ActionListener using tokio's watch. +pub struct DefaultActionListener { + client_operation_id: ClientOperationId, + action_state: watch::Receiver>, +} + +impl DefaultActionListener { + pub fn new( + client_operation_id: ClientOperationId, + mut action_state: watch::Receiver>, + ) -> Self { + action_state.mark_changed(); + Self { + client_operation_id, + action_state, + } + } +} + +impl ActionListener for DefaultActionListener { + fn client_operation_id(&self) -> &ClientOperationId { + &self.client_operation_id + } + + fn action_state(&self) -> Arc { + self.action_state.borrow().clone() + } + + fn changed( + &mut self, + ) -> Pin, Error>> + Send + Sync + '_>> { + Box::pin(async move { + self.action_state.changed().await.map_or_else( + |e| { + Err(make_err!( + Code::Internal, + "Sender of ActionState went away unexpectedly - {e:?}" + )) + }, + |()| Ok(self.action_state.borrow_and_update().clone()), + ) + }) + } +} + +impl Clone for DefaultActionListener { + /// Clones the current action state and marks the receiver as changed. + fn clone(&self) -> Self { + let mut action_state = self.action_state.clone(); + action_state.mark_changed(); + Self { + client_operation_id: self.client_operation_id.clone(), + action_state, + } + } +} diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 5f4c174a5..ab96b02a2 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -14,6 +14,7 @@ use std::collections::HashMap; use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -45,7 +46,8 @@ use tokio::time::sleep; use tonic::{Request, Streaming}; use tracing::{event, Level}; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; +use crate::default_action_listener::DefaultActionListener; use crate::platform_property_manager::PlatformPropertyManager; pub struct GrpcScheduler { @@ -115,7 +117,7 @@ impl GrpcScheduler { async fn stream_state( mut result_stream: Streaming, - ) -> Result<(ClientOperationId, watch::Receiver>), Error> { + ) -> Result>, Error> { if let Some(initial_response) = result_stream .message() .await @@ -175,7 +177,10 @@ impl GrpcScheduler { ) } }); - return Ok((client_operation_id, rx)); + return Ok(Box::pin(DefaultActionListener::new( + client_operation_id, + rx, + ))); } Err(make_err!( Code::Internal, @@ -237,7 +242,7 @@ impl ActionScheduler for GrpcScheduler { &self, _client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result<(ClientOperationId, watch::Receiver>), Error> { + ) -> Result>, Error> { let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY { None } else { @@ -278,7 +283,7 @@ impl ActionScheduler for GrpcScheduler { async fn find_by_client_operation_id( &self, client_operation_id: &ClientOperationId, - ) -> Result>>, Error> { + ) -> Result>>, Error> { let request = WaitExecutionRequest { name: client_operation_id.to_string(), }; @@ -297,7 +302,7 @@ impl ActionScheduler for GrpcScheduler { .and_then(|result_stream| Self::stream_state(result_stream.into_inner())) .await; match result_stream { - Ok(result_stream) => Ok(Some(result_stream.1)), + Ok(result_stream) => Ok(Some(result_stream)), Err(err) => { event!( Level::WARN, diff --git a/nativelink-scheduler/src/lib.rs b/nativelink-scheduler/src/lib.rs index 0818ce059..27bc8f8a9 100644 --- a/nativelink-scheduler/src/lib.rs +++ b/nativelink-scheduler/src/lib.rs @@ -14,6 +14,7 @@ pub mod action_scheduler; pub mod cache_lookup_scheduler; +pub mod default_action_listener; pub mod default_scheduler_factory; pub mod grpc_scheduler; pub mod operation_state_manager; diff --git a/nativelink-scheduler/src/property_modifier_scheduler.rs b/nativelink-scheduler/src/property_modifier_scheduler.rs index 47e52fb4e..e6fce736d 100644 --- a/nativelink-scheduler/src/property_modifier_scheduler.rs +++ b/nativelink-scheduler/src/property_modifier_scheduler.rs @@ -14,17 +14,17 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; use nativelink_config::schedulers::{PropertyModification, PropertyType}; use nativelink_error::{Error, ResultExt}; -use nativelink_util::action_messages::{ActionInfo, ActionState, ClientOperationId}; +use nativelink_util::action_messages::{ActionInfo, ClientOperationId}; use nativelink_util::metrics_utils::Registry; use parking_lot::Mutex; -use tokio::sync::watch; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; use crate::platform_property_manager::PlatformPropertyManager; pub struct PropertyModifierScheduler { @@ -92,7 +92,7 @@ impl ActionScheduler for PropertyModifierScheduler { &self, client_operation_id: ClientOperationId, mut action_info: ActionInfo, - ) -> Result<(ClientOperationId, watch::Receiver>), Error> { + ) -> Result>, Error> { let platform_property_manager = self .get_platform_property_manager(&action_info.unique_qualifier.instance_name) .await @@ -120,7 +120,7 @@ impl ActionScheduler for PropertyModifierScheduler { async fn find_by_client_operation_id( &self, client_operation_id: &ClientOperationId, - ) -> Result>>, Error> { + ) -> Result>>, Error> { self.scheduler .find_by_client_operation_id(client_operation_id) .await diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index 7724aa0ad..3a23acff8 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -19,17 +19,18 @@ use async_trait::async_trait; use futures::{Future, Stream}; use nativelink_error::{Error, ResultExt}; use nativelink_util::action_messages::{ - ActionInfo, ActionStage, ActionState, ClientOperationId, OperationId, WorkerId, + ActionInfo, ActionStage, ClientOperationId, OperationId, WorkerId, }; use nativelink_util::metrics_utils::Registry; use nativelink_util::spawn; use nativelink_util::task::JoinHandleDropGuard; -use tokio::sync::{watch, Notify}; +use tokio::sync::Notify; use tokio::time::Duration; use tokio_stream::StreamExt; use tracing::{event, Level}; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; +use crate::default_action_listener::DefaultActionListener; use crate::operation_state_manager::{ ActionStateResult, ClientStateManager, MatchingEngineStateManager, OperationFilter, OperationStageFlags, @@ -83,7 +84,7 @@ impl SimpleScheduler { &self, client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result<(ClientOperationId, watch::Receiver>), Error> { + ) -> Result>, Error> { let add_action_result = self .client_state_manager .add_action(client_operation_id.clone(), action_info) @@ -91,7 +92,12 @@ impl SimpleScheduler { add_action_result .as_receiver() .await - .map(move |receiver| (client_operation_id, receiver.into_owned())) + .map(move |receiver| -> Pin> { + Box::pin(DefaultActionListener::new( + client_operation_id, + receiver.into_owned(), + )) + }) } async fn clean_recently_completed_actions(&self) { @@ -110,7 +116,7 @@ impl SimpleScheduler { async fn find_by_client_operation_id( &self, client_operation_id: &ClientOperationId, - ) -> Result>>, Error> { + ) -> Result>>, Error> { let filter_result = self .client_state_manager .filter_operations(&OperationFilter { @@ -124,13 +130,14 @@ impl SimpleScheduler { let Some(result) = stream.next().await else { return Ok(None); }; - Ok(Some( + Ok(Some(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), result .as_receiver() .await .err_tip(|| "In SimpleScheduler::find_by_client_operation_id getting receiver")? .into_owned(), - )) + )))) } async fn get_queued_operations( @@ -326,14 +333,14 @@ impl ActionScheduler for SimpleScheduler { &self, client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result<(ClientOperationId, watch::Receiver>), Error> { + ) -> Result>, Error> { self.add_action(client_operation_id, action_info).await } async fn find_by_client_operation_id( &self, client_operation_id: &ClientOperationId, - ) -> Result>>, Error> { + ) -> Result>>, Error> { let maybe_receiver = self .find_by_client_operation_id(client_operation_id) .await diff --git a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs index b986c255b..9599c0d2f 100644 --- a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs +++ b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs @@ -27,6 +27,7 @@ use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::ActionResult as ProtoActionResult; use nativelink_scheduler::action_scheduler::ActionScheduler; use nativelink_scheduler::cache_lookup_scheduler::CacheLookupScheduler; +use nativelink_scheduler::default_action_listener::DefaultActionListener; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_store::memory_store::MemoryStore; use nativelink_util::action_messages::{ @@ -105,7 +106,10 @@ async fn add_action_handles_skip_cache() -> Result<(), Error> { .add_action(client_operation_id.clone(), skip_cache_action), context .mock_scheduler - .expect_add_action(Ok((client_operation_id, forward_watch_channel_rx))) + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id, + forward_watch_channel_rx + )))) ); Ok(()) } diff --git a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs index c4b7551c5..c85b74347 100644 --- a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs +++ b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs @@ -26,6 +26,7 @@ use nativelink_config::schedulers::{PlatformPropertyAddition, PropertyModificati use nativelink_error::Error; use nativelink_macro::nativelink_test; use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::default_action_listener::DefaultActionListener; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_scheduler::property_modifier_scheduler::PropertyModifierScheduler; use nativelink_util::action_messages::{ @@ -88,7 +89,10 @@ async fn add_action_adds_property() -> Result<(), Error> { .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok((client_operation_id.clone(), forward_watch_channel_rx))), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( @@ -132,7 +136,10 @@ async fn add_action_overwrites_property() -> Result<(), Error> { .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok((client_operation_id.clone(), forward_watch_channel_rx))), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( @@ -173,7 +180,10 @@ async fn add_action_property_added_after_remove() -> Result<(), Error> { .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok((client_operation_id.clone(), forward_watch_channel_rx))), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( @@ -214,7 +224,10 @@ async fn add_action_property_remove_after_add() -> Result<(), Error> { .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok((client_operation_id.clone(), forward_watch_channel_rx))), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( @@ -250,7 +263,10 @@ async fn add_action_property_remove() -> Result<(), Error> { .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok((client_operation_id.clone(), forward_watch_channel_rx))), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index 93f79a4ef..70b23e42b 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -13,17 +13,20 @@ // limitations under the License. use std::collections::HashMap; +use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use futures::poll; +use futures::task::Poll; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, ExecuteRequest}; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ update_for_worker, ConnectionResult, StartExecute, UpdateForWorker, }; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_scheduler::simple_scheduler::SimpleScheduler; use nativelink_scheduler::worker::Worker; use nativelink_scheduler::worker_scheduler::WorkerScheduler; @@ -36,7 +39,7 @@ use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue}; use pretty_assertions::assert_eq; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc; use utils::scheduler_utils::{make_base_action_info, INSTANCE_NAME}; use uuid::Uuid; @@ -128,7 +131,7 @@ async fn setup_action( action_digest: DigestInfo, platform_properties: PlatformProperties, insert_timestamp: SystemTime, -) -> Result<(ClientOperationId, watch::Receiver>), Error> { +) -> Result>, Error> { let mut action_info = make_base_action_info(insert_timestamp); action_info.platform_properties = platform_properties; action_info.unique_qualifier.digest = action_digest; @@ -153,13 +156,14 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let (_, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp, ) - .await?; + .await + .unwrap(); { // Worker should have been sent an execute command. @@ -182,7 +186,7 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -207,17 +211,19 @@ async fn find_executing_action() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let (client_operation_id, client_rx) = setup_action( + let action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp, ) - .await?; + .await + .unwrap(); + let client_operation_id = action_listener.client_operation_id().clone(); // Drop our receiver and look up a new one. - drop(client_rx); - let mut client_rx = scheduler + drop(action_listener); + let mut action_listener = scheduler .find_by_client_operation_id(&client_operation_id) .await .expect("Action not found") @@ -244,7 +250,7 @@ async fn find_executing_action() -> Result<(), Error> { } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -273,7 +279,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; let insert_timestamp1 = make_system_time(1); - let (_client_operation_id1, mut client_rx1) = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, PlatformProperties::default(), @@ -281,7 +287,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err ) .await?; let insert_timestamp2 = make_system_time(2); - let (_client_operation_id2, mut client_rx2) = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, PlatformProperties::default(), @@ -358,14 +364,14 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err { let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx1.borrow_and_update(); + let action_state = client1_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. assert_eq!(&action_state.stage, &expected_action_stage); } { let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx2.borrow_and_update(); + let action_state = client2_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. assert_eq!(&action_state.stage, &expected_action_stage); } @@ -387,14 +393,14 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err { let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx1.borrow_and_update(); + let action_state = client2_action_listener.action_state(); // We now know the name of the action so populate it. assert_eq!(&action_state.stage, &expected_action_stage); } { let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx2.borrow_and_update(); + let action_state = client2_action_listener.action_state(); // We now know the name of the action so populate it. assert_eq!(&action_state.stage, &expected_action_stage); } @@ -441,7 +447,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let (_client_operation_id, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -458,7 +464,10 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> v => panic!("Expected StartAction, got : {v:?}"), }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); operation_id }; @@ -468,7 +477,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> let action_digest = DigestInfo::new([88u8; 32], 512); let insert_timestamp = make_system_time(14); - let (_client_operation_id, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -478,7 +487,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> { // Client should get notification saying it's been queued. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -493,7 +502,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -504,7 +513,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> Ok(()) } -// + #[nativelink_test] async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), Error> { let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); @@ -529,7 +538,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, platform_properties.clone()).await?; let insert_timestamp = make_system_time(1); - let (_client_operation_id, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, worker_properties.clone(), @@ -539,7 +548,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E { // Client should get notification saying it's been queued. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -569,7 +578,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -611,14 +620,14 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { let insert_timestamp1 = make_system_time(1); let insert_timestamp2 = make_system_time(2); - let (_client1_operation_id, mut client1_rx) = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp1, ) .await?; - let (_client2_operation_id, mut client2_rx) = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -628,8 +637,8 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { { // Clients should get notification saying it's been queued. - let action_state1 = client1_rx.borrow_and_update(); - let action_state2 = client2_rx.borrow_and_update(); + let action_state1 = client1_action_listener.changed().await.unwrap(); + let action_state2 = client2_action_listener.changed().await.unwrap(); // Name is random so we set force it to be the same. expected_action_state.id = action_state1.id.clone(); assert_eq!(action_state1.as_ref(), &expected_action_state); @@ -665,11 +674,11 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { // Both client1 and client2 should be receiving the same updates. // Most importantly the `name` (which is random) will be the same. assert_eq!( - client1_rx.borrow_and_update().as_ref(), + client1_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); assert_eq!( - client2_rx.borrow_and_update().as_ref(), + client2_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -677,7 +686,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { { // Now if another action is requested it should also join with executing action. let insert_timestamp3 = make_system_time(2); - let (_client3_operation_id, mut client3_rx) = setup_action( + let mut client3_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -685,7 +694,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { ) .await?; assert_eq!( - client3_rx.borrow_and_update().as_ref(), + client3_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -709,7 +718,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), drop(rx_from_worker); let insert_timestamp = make_system_time(1); - let (_client_operation_id, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -718,7 +727,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), .await?; { // Client should get notification saying it's being queued not executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -747,7 +756,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let (_client_operation_id, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -795,7 +804,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); assert_eq!( action_state.as_ref(), &ActionState { @@ -827,7 +836,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); assert_eq!( action_state.as_ref(), &ActionState { @@ -865,7 +874,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let (_client_operation_id, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -878,7 +887,10 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err match rx_from_worker.recv().await.unwrap().update { Some(update_for_worker::Update::StartAction(start_execute)) => { // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); start_execute.operation_id } v => panic!("Expected StartAction, got : {v:?}"), @@ -938,7 +950,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err { // Client should get notification saying it has been completed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -963,7 +975,7 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let (client_id, client_rx) = setup_action( + let action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -971,8 +983,10 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E ) .await?; + let client_id = action_listener.client_operation_id().clone(); + // Drop our receiver and don't reconnect until completed. - drop(client_rx); + drop(action_listener); let operation_id = { // Other tests check full data. We only care if we got StartAction. @@ -1030,14 +1044,14 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E .await?; // Now look up a channel after the action has completed. - let mut client_rx = scheduler + let mut action_listener = scheduler .find_by_client_operation_id(&client_id) .await .unwrap() .expect("Action not found"); { // Client should get notification saying it has been completed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1063,7 +1077,7 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, good_worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let (_client_id, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1078,7 +1092,10 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } let _ = setup_new_worker(&scheduler, rogue_worker_id, PlatformProperties::default()).await?; @@ -1137,8 +1154,8 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { { // Ensure client did not get notified. assert_eq!( - client_rx.has_changed().unwrap(), - false, + poll!(action_listener.changed()), + Poll::Pending, "Client should not have been notified of event" ); } @@ -1169,7 +1186,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro }; let insert_timestamp = make_system_time(1); - let (_, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1201,7 +1218,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro let operation_id = { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. expected_action_state.id = action_state.id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); @@ -1245,7 +1262,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro // Action should now be executing. expected_action_state.stage = ActionStage::Completed(action_result.clone()); assert_eq!( - client_rx.borrow_and_update().as_ref(), + action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -1255,7 +1272,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro { let insert_timestamp = make_system_time(1); - let (_, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1264,7 +1281,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro .await?; // We didn't disconnect our worker, so it will have scheduled it to the worker. expected_action_state.stage = ActionStage::Executing; - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); // The name of the action changed (since it's a new action), so update it. expected_action_state.id = action_state.id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); @@ -1292,7 +1309,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, platform_properties.clone()).await?; let insert_timestamp1 = make_system_time(1); - let (_client1_id, mut client1_rx) = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, platform_properties.clone(), @@ -1300,7 +1317,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> ) .await?; let insert_timestamp2 = make_system_time(1); - let (_client2_id, mut client2_rx) = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, platform_properties, @@ -1313,8 +1330,8 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> v => panic!("Expected StartAction, got : {v:?}"), } let (operation_id1, operation_id2) = { - let state_1 = client1_rx.borrow_and_update(); - let state_2 = client2_rx.borrow_and_update(); + let state_1 = client1_action_listener.changed().await.unwrap(); + let state_2 = client2_action_listener.changed().await.unwrap(); // First client should be in an Executing state. assert_eq!(state_1.stage, ActionStage::Executing); // Second client should be in a queued state. @@ -1356,15 +1373,9 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> ) .await?; - // Ensure client did not get notified. - assert!( - client1_rx.changed().await.is_ok(), - "Client should have been notified of event" - ); - { // First action should now be completed. - let action_state = client1_rx.borrow_and_update(); + let action_state = client1_action_listener.changed().await.unwrap(); let mut expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1385,7 +1396,10 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client2_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + client2_action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } // Tell scheduler our second task is completed. @@ -1399,7 +1413,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> { // Our second client should be notified it completed. - let action_state = client2_rx.borrow_and_update(); + let action_state = client2_action_listener.changed().await.unwrap(); let mut expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1432,7 +1446,7 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { // This is queued after the next one (even though it's placed in the map // first), so it should execute second. let insert_timestamp2 = make_system_time(2); - let (_, mut client2_rx) = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, platform_properties.clone(), @@ -1440,7 +1454,7 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { ) .await?; let insert_timestamp1 = make_system_time(1); - let (_, mut client1_rx) = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, platform_properties.clone(), @@ -1457,9 +1471,15 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { } { // First client should be in an Executing state. - assert_eq!(client1_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + client1_action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); // Second client should be in a queued state. - assert_eq!(client2_rx.borrow_and_update().stage, ActionStage::Queued); + assert_eq!( + client2_action_listener.changed().await.unwrap().stage, + ActionStage::Queued + ); } Ok(()) @@ -1481,7 +1501,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let (_, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1496,7 +1516,10 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> v => panic!("Expected StartAction, got : {v:?}"), }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); OperationId::try_from(operation_id.as_str())? }; @@ -1510,7 +1533,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> { // Client should get notification saying it has been queued again. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1529,7 +1552,10 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } let err = make_err!(Code::Internal, "Some error"); @@ -1540,7 +1566,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> { // Client should get notification saying it has been queued again. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1630,7 +1656,7 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; - let (_, mut client_rx) = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1648,7 +1674,10 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), v => panic!("Expected StartAction, got : {v:?}"), }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); OperationId::try_from(operation_id.as_str())? }; @@ -1670,7 +1699,10 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), .await .err_tip(|| "worker went away")?; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } Ok(()) diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs index 4e56ff038..e1b4d6d89 100644 --- a/nativelink-scheduler/tests/utils/mock_scheduler.rs +++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs @@ -12,14 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; use nativelink_error::{make_input_err, Error}; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; -use nativelink_util::action_messages::{ActionInfo, ActionState, ClientOperationId}; -use tokio::sync::{mpsc, watch, Mutex}; +use nativelink_util::action_messages::{ActionInfo, ClientOperationId}; +use tokio::sync::{mpsc, Mutex}; #[allow(clippy::large_enum_variant)] enum ActionSchedulerCalls { @@ -30,8 +31,8 @@ enum ActionSchedulerCalls { enum ActionSchedulerReturns { GetPlatformPropertyManager(Result, Error>), - AddAction(Result<(ClientOperationId, watch::Receiver>), Error>), - FindExistingAction(Result>>, Error>), + AddAction(Result>, Error>), + FindExistingAction(Result>>, Error>), } pub struct MockActionScheduler { @@ -81,7 +82,7 @@ impl MockActionScheduler { pub async fn expect_add_action( &self, - result: Result<(ClientOperationId, watch::Receiver>), Error>, + result: Result>, Error>, ) -> (ClientOperationId, ActionInfo) { let mut rx_call_lock = self.rx_call.lock().await; let ActionSchedulerCalls::AddAction(req) = rx_call_lock @@ -100,7 +101,7 @@ impl MockActionScheduler { pub async fn expect_find_by_client_operation_id( &self, - result: Result>>, Error>, + result: Result>>, Error>, ) -> ClientOperationId { let mut rx_call_lock = self.rx_call.lock().await; let ActionSchedulerCalls::FindExistingAction(req) = rx_call_lock @@ -144,7 +145,7 @@ impl ActionScheduler for MockActionScheduler { &self, client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result<(ClientOperationId, watch::Receiver>), Error> { + ) -> Result>, Error> { self.tx_call .send(ActionSchedulerCalls::AddAction(( client_operation_id, @@ -165,7 +166,7 @@ impl ActionScheduler for MockActionScheduler { async fn find_by_client_operation_id( &self, client_operation_id: &ClientOperationId, - ) -> Result>>, Error> { + ) -> Result>>, Error> { self.tx_call .send(ActionSchedulerCalls::FindExistingAction( client_operation_id.clone(), diff --git a/nativelink-service/src/execution_server.rs b/nativelink-service/src/execution_server.rs index 0832852fb..8b7ddde43 100644 --- a/nativelink-service/src/execution_server.rs +++ b/nativelink-service/src/execution_server.rs @@ -17,7 +17,8 @@ use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use futures::{Stream, StreamExt}; +use futures::stream::unfold; +use futures::Stream; use nativelink_config::cas_server::{ExecutionConfig, InstanceName}; use nativelink_error::{make_input_err, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::execution_server::{ @@ -27,18 +28,16 @@ use nativelink_proto::build::bazel::remote::execution::v2::{ Action, Command, ExecuteRequest, WaitExecutionRequest, }; use nativelink_proto::google::longrunning::Operation; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_store::ac_utils::get_and_decode_digest; use nativelink_store::store_manager::StoreManager; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionState, ClientOperationId, DEFAULT_EXECUTION_PRIORITY, + ActionInfo, ActionInfoHashKey, ClientOperationId, DEFAULT_EXECUTION_PRIORITY, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::{make_ctx_for_hash_func, DigestHasherFunc}; use nativelink_util::platform_properties::PlatformProperties; use nativelink_util::store_trait::Store; -use tokio::sync::watch; -use tokio_stream::wrappers::WatchStream; use tonic::{Request, Response, Status}; use tracing::{error_span, event, instrument, Level}; @@ -208,15 +207,34 @@ impl ExecutionServer { fn to_execute_stream( nl_client_operation_id: NativelinkClientOperationId, - receiver: watch::Receiver>, + action_listener: Pin>, ) -> Response { let client_operation_id_string = nl_client_operation_id.into_string(); - let receiver_stream = Box::pin(WatchStream::new(receiver).map(move |action_update| { - event!(Level::INFO, ?action_update, "Execute Resp Stream"); - let client_operation_id = - ClientOperationId::from_raw_string(client_operation_id_string.clone()); - Ok(action_update.as_operation(client_operation_id)) - })); + let receiver_stream = Box::pin(unfold( + Some(action_listener), + move |maybe_action_listener| { + let client_operation_id_string = client_operation_id_string.clone(); + async move { + let mut action_listener = maybe_action_listener?; + match action_listener.changed().await { + Ok(action_update) => { + event!(Level::INFO, ?action_update, "Execute Resp Stream"); + let client_operation_id = ClientOperationId::from_raw_string( + client_operation_id_string.clone(), + ); + Some(( + Ok(action_update.as_operation(client_operation_id)), + Some(action_listener), + )) + } + Err(err) => { + event!(Level::ERROR, ?err, "Error in action_listener stream"); + Some((Err(err.into()), None)) + } + } + } + }, + )); tonic::Response::new(receiver_stream) } @@ -258,7 +276,7 @@ impl ExecutionServer { ) .await?; - let (client_operation_id, rx) = instance_info + let action_listener = instance_info .scheduler .add_action( ClientOperationId::new(action_info.unique_qualifier.clone()), @@ -270,9 +288,9 @@ impl ExecutionServer { Ok(Self::to_execute_stream( NativelinkClientOperationId { instance_name, - client_operation_id, + client_operation_id: action_listener.client_operation_id().clone(), }, - rx, + action_listener, )) }