diff --git a/linera-execution/src/execution_state_actor.rs b/linera-execution/src/execution_state_actor.rs index 3d21ab7ca5d7..6599ec46b0d7 100644 --- a/linera-execution/src/execution_state_actor.rs +++ b/linera-execution/src/execution_state_actor.rs @@ -27,7 +27,7 @@ use crate::{ execution::UserAction, runtime::ContractSyncRuntime, system::{CreateApplicationResult, OpenChainConfig}, - util::RespondExt, + util::{OracleResponseExt as _, RespondExt as _}, ApplicationDescription, ApplicationId, ExecutionError, ExecutionRuntimeConfig, ExecutionRuntimeContext, ExecutionStateView, Message, MessageContext, MessageKind, ModuleId, Operation, OperationContext, OutgoingMessage, ProcessStreamsContext, QueryContext, @@ -387,66 +387,60 @@ where http_responses_are_oracle_responses, callback, } => { - let response = if let Some(response) = - self.txn_tracker.next_replayed_oracle_response()? - { - match response { - OracleResponse::Http(response) => response, - _ => return Err(ExecutionError::OracleResponseMismatch), - } - } else { - let headers = request - .headers - .into_iter() - .map(|http::Header { name, value }| Ok((name.parse()?, value.try_into()?))) - .collect::>()?; - - let url = Url::parse(&request.url)?; - let host = url - .host_str() - .ok_or_else(|| ExecutionError::UnauthorizedHttpRequest(url.clone()))?; - - let (_epoch, committee) = self - .state - .system - .current_committee() - .ok_or_else(|| ExecutionError::UnauthorizedHttpRequest(url.clone()))?; - let allowed_hosts = &committee.policy().http_request_allow_list; - - ensure!( - allowed_hosts.contains(host), - ExecutionError::UnauthorizedHttpRequest(url) - ); - - #[cfg_attr(web, allow(unused_mut))] - let mut request = Client::new() - .request(request.method.into(), url) - .body(request.body) - .headers(headers); - #[cfg(not(web))] - { - request = request.timeout(linera_base::time::Duration::from_millis( - committee.policy().http_request_timeout_ms, - )); - } - - let response = request.send().await?; + let system = &mut self.state.system; + let response = self + .txn_tracker + .oracle(|| async { + let headers = request + .headers + .into_iter() + .map(|http::Header { name, value }| { + Ok((name.parse()?, value.try_into()?)) + }) + .collect::>()?; + + let url = Url::parse(&request.url)?; + let host = url + .host_str() + .ok_or_else(|| ExecutionError::UnauthorizedHttpRequest(url.clone()))?; + + let (_epoch, committee) = system + .current_committee() + .ok_or_else(|| ExecutionError::UnauthorizedHttpRequest(url.clone()))?; + let allowed_hosts = &committee.policy().http_request_allow_list; - let mut response_size_limit = committee.policy().maximum_http_response_bytes; + ensure!( + allowed_hosts.contains(host), + ExecutionError::UnauthorizedHttpRequest(url) + ); - if http_responses_are_oracle_responses { - response_size_limit = response_size_limit - .min(committee.policy().maximum_oracle_response_bytes); - } + #[cfg_attr(web, allow(unused_mut))] + let mut request = Client::new() + .request(request.method.into(), url) + .body(request.body) + .headers(headers); + #[cfg(not(web))] + { + request = request.timeout(linera_base::time::Duration::from_millis( + committee.policy().http_request_timeout_ms, + )); + } - self.receive_http_response(response, response_size_limit) - .await? - }; + let response = request.send().await?; - // Record the oracle response - self.txn_tracker - .add_oracle_response(OracleResponse::Http(response.clone())); + let mut response_size_limit = + committee.policy().maximum_http_response_bytes; + if http_responses_are_oracle_responses { + response_size_limit = response_size_limit + .min(committee.policy().maximum_oracle_response_bytes); + } + Ok(OracleResponse::Http( + Self::receive_http_response(response, response_size_limit).await?, + )) + }) + .await? + .to_http_response()?; callback.respond(response); } @@ -513,25 +507,18 @@ where } ReadEvent { event_id, callback } => { - let event = match self.txn_tracker.next_replayed_oracle_response()? { - None => { - let event = self - .state - .context() - .extra() + let extra = self.state.context().extra(); + let event = self + .txn_tracker + .oracle(|| async { + let event = extra .get_event(event_id.clone()) - .await?; - event.ok_or(ExecutionError::EventsNotFound(vec![event_id.clone()]))? - } - Some(OracleResponse::Event(recorded_event_id, event)) - if recorded_event_id == event_id => - { - event - } - Some(_) => return Err(ExecutionError::OracleResponseMismatch), - }; - self.txn_tracker - .add_oracle_response(OracleResponse::Event(event_id, event.clone())); + .await? + .ok_or(ExecutionError::EventsNotFound(vec![event_id.clone()]))?; + Ok(OracleResponse::Event(event_id.clone(), event)) + }) + .await? + .to_event(&event_id)?; callback.respond(event); } @@ -598,36 +585,37 @@ where query, callback, } => { - let response = match self.txn_tracker.next_replayed_oracle_response()? { - Some(OracleResponse::Service(bytes)) => bytes, - Some(_) => return Err(ExecutionError::OracleResponseMismatch), - None => { + let state = &mut self.state; + let local_time = self.txn_tracker.local_time(); + let created_blobs = self.txn_tracker.created_blobs().clone(); + let bytes = self + .txn_tracker + .oracle(|| async { let context = QueryContext { - chain_id: self.state.context().extra().chain_id(), + chain_id: state.context().extra().chain_id(), next_block_height, - local_time: self.txn_tracker.local_time(), + local_time, }; let QueryOutcome { response, operations, - } = Box::pin(self.state.query_user_application_with_deadline( + } = Box::pin(state.query_user_application_with_deadline( application_id, context, query, deadline, - self.txn_tracker.created_blobs().clone(), + created_blobs, )) .await?; ensure!( operations.is_empty(), ExecutionError::ServiceOracleQueryOperations(operations) ); - response - } - }; - self.txn_tracker - .add_oracle_response(OracleResponse::Service(response.clone())); - callback.respond(response); + Ok(OracleResponse::Service(response)) + }) + .await? + .to_service_response()?; + callback.respond(bytes); } AddOutgoingMessage { message, callback } => { @@ -673,18 +661,12 @@ where } ValidationRound { round, callback } => { - let result_round = - if let Some(response) = self.txn_tracker.next_replayed_oracle_response()? { - match response { - OracleResponse::Round(round) => round, - _ => return Err(ExecutionError::OracleResponseMismatch), - } - } else { - round - }; - self.txn_tracker - .add_oracle_response(OracleResponse::Round(result_round)); - callback.respond(result_round); + let validation_round = self + .txn_tracker + .oracle(|| async { Ok(OracleResponse::Round(round)) }) + .await? + .to_round()?; + callback.respond(validation_round); } } @@ -929,7 +911,6 @@ where /// /// Ensures that the response does not exceed the provided `size_limit`. async fn receive_http_response( - &mut self, response: reqwest::Response, size_limit: u64, ) -> Result { diff --git a/linera-execution/src/system.rs b/linera-execution/src/system.rs index 9b1b82313d8a..e2496f9fc486 100644 --- a/linera-execution/src/system.rs +++ b/linera-execution/src/system.rs @@ -31,9 +31,9 @@ use serde::{Deserialize, Serialize}; #[cfg(test)] use crate::test_utils::SystemExecutionState; use crate::{ - committee::Committee, ApplicationDescription, ApplicationId, ExecutionError, - ExecutionRuntimeContext, MessageContext, MessageKind, OperationContext, OutgoingMessage, - QueryContext, QueryOutcome, ResourceController, TransactionTracker, + committee::Committee, util::OracleResponseExt as _, ApplicationDescription, ApplicationId, + ExecutionError, ExecutionRuntimeContext, MessageContext, MessageKind, OperationContext, + OutgoingMessage, QueryContext, QueryOutcome, ResourceController, TransactionTracker, }; /// The event stream name for new epochs and committees. @@ -492,17 +492,14 @@ where stream_id: StreamId::system(EPOCH_STREAM_NAME), index: epoch.0, }; - let bytes = match txn_tracker.next_replayed_oracle_response()? { - None => self.get_event(event_id.clone()).await?, - Some(OracleResponse::Event(recorded_event_id, bytes)) - if recorded_event_id == event_id => - { - bytes - } - Some(_) => return Err(ExecutionError::OracleResponseMismatch), - }; + let bytes = txn_tracker + .oracle(|| async { + let bytes = self.get_event(event_id.clone()).await?; + Ok(OracleResponse::Event(event_id.clone(), bytes)) + }) + .await? + .to_event(&event_id)?; let blob_id = BlobId::new(bcs::from_bytes(&bytes)?, BlobType::Committee); - txn_tracker.add_oracle_response(OracleResponse::Event(event_id, bytes)); let committee = bcs::from_bytes(self.read_blob_content(blob_id).await?.bytes())?; self.blob_used(txn_tracker, blob_id).await?; self.committees.get_mut().insert(epoch, committee); @@ -522,16 +519,12 @@ where stream_id: StreamId::system(REMOVED_EPOCH_STREAM_NAME), index: epoch.0, }; - let bytes = match txn_tracker.next_replayed_oracle_response()? { - None => self.get_event(event_id.clone()).await?, - Some(OracleResponse::Event(recorded_event_id, bytes)) - if recorded_event_id == event_id => - { - bytes - } - Some(_) => return Err(ExecutionError::OracleResponseMismatch), - }; - txn_tracker.add_oracle_response(OracleResponse::Event(event_id, bytes)); + txn_tracker + .oracle(|| async { + let bytes = self.get_event(event_id.clone()).await?; + Ok(OracleResponse::Event(event_id, bytes)) + }) + .await?; } UpdateStreams(streams) => { let mut missing_events = Vec::new(); @@ -562,23 +555,15 @@ where stream_id, index, }; - match txn_tracker.next_replayed_oracle_response()? { - None => { - if !self - .context() - .extra() - .contains_event(event_id.clone()) - .await? - { - missing_events.push(event_id); - continue; + let extra = self.context().extra(); + txn_tracker + .oracle(|| async { + if !extra.contains_event(event_id.clone()).await? { + missing_events.push(event_id.clone()); } - } - Some(OracleResponse::EventExists(recorded_event_id)) - if recorded_event_id == event_id => {} - Some(_) => return Err(ExecutionError::OracleResponseMismatch), - } - txn_tracker.add_oracle_response(OracleResponse::EventExists(event_id)); + Ok(OracleResponse::EventExists(event_id)) + }) + .await?; } ensure!( missing_events.is_empty(), diff --git a/linera-execution/src/transaction_tracker.rs b/linera-execution/src/transaction_tracker.rs index e3e67ddde704..92627da192dc 100644 --- a/linera-execution/src/transaction_tracker.rs +++ b/linera-execution/src/transaction_tracker.rs @@ -3,6 +3,7 @@ use std::{ collections::{BTreeMap, BTreeSet}, + future::Future, mem, vec, }; @@ -166,14 +167,26 @@ impl TransactionTracker { &self.blobs } - pub fn add_oracle_response(&mut self, oracle_response: OracleResponse) { - self.oracle_responses.push(oracle_response); - } - pub fn add_operation_result(&mut self, result: Option>) { self.operation_result = result } + /// In replay mode, returns the next recorded oracle response. Otherwise executes `f` and + /// records and returns the result. `f` is the implementation of the actual oracle and is + /// only called in validation mode, so it does not have to be fully deterministic. + pub async fn oracle(&mut self, f: F) -> Result<&OracleResponse, ExecutionError> + where + F: FnOnce() -> G, + G: Future>, + { + let response = match self.next_replayed_oracle_response()? { + Some(response) => response, + None => f().await?, + }; + self.oracle_responses.push(response); + Ok(self.oracle_responses.last().unwrap()) + } + pub fn add_stream_to_process( &mut self, application_id: ApplicationId, @@ -245,7 +258,7 @@ impl TransactionTracker { } else { false }; - self.add_oracle_response(oracle_response); + self.oracle_responses.push(oracle_response); Ok(replaying) } @@ -256,9 +269,7 @@ impl TransactionTracker { /// /// In both cases, the value (returned or obtained from the oracle) must be recorded using /// `add_oracle_response`. - pub fn next_replayed_oracle_response( - &mut self, - ) -> Result, ExecutionError> { + fn next_replayed_oracle_response(&mut self) -> Result, ExecutionError> { let Some(responses) = &mut self.replaying_oracle_responses else { return Ok(None); // Not in replay mode. }; diff --git a/linera-execution/src/util/mod.rs b/linera-execution/src/util/mod.rs index 474ca267be30..e25549c28070 100644 --- a/linera-execution/src/util/mod.rs +++ b/linera-execution/src/util/mod.rs @@ -6,6 +6,7 @@ mod sync_response; use futures::channel::mpsc; +use linera_base::{data_types::OracleResponse, http::Response, identifiers::EventId}; pub use self::sync_response::SyncSender; use crate::ExecutionError; @@ -125,3 +126,45 @@ impl RespondExt for SyncSender { } } } + +pub(crate) trait OracleResponseExt { + fn to_round(&self) -> Result, ExecutionError>; + + fn to_service_response(&self) -> Result, ExecutionError>; + + fn to_http_response(&self) -> Result; + + fn to_event(&self, event_id: &EventId) -> Result, ExecutionError>; +} + +impl OracleResponseExt for OracleResponse { + fn to_round(&self) -> Result, ExecutionError> { + match self { + OracleResponse::Round(round) => Ok(*round), + _ => Err(ExecutionError::OracleResponseMismatch), + } + } + + fn to_service_response(&self) -> Result, ExecutionError> { + match self { + OracleResponse::Service(bytes) => Ok(bytes.clone()), + _ => Err(ExecutionError::OracleResponseMismatch), + } + } + + fn to_http_response(&self) -> Result { + match self { + OracleResponse::Http(response) => Ok(response.clone()), + _ => Err(ExecutionError::OracleResponseMismatch), + } + } + + fn to_event(&self, event_id: &EventId) -> Result, ExecutionError> { + match self { + OracleResponse::Event(recorded_event_id, event) if recorded_event_id == event_id => { + Ok(event.clone()) + } + _ => Err(ExecutionError::OracleResponseMismatch), + } + } +}