diff --git a/Cargo.lock b/Cargo.lock
index db14c4599..45e8e21b7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6236,12 +6236,14 @@ dependencies = [
  "http-serde",
  "hyper",
  "once_cell",
+ "prost-types",
  "reqwest",
  "rmp-serde",
  "rustrict",
  "serde",
  "serde_json",
  "strum",
+ "thiserror",
  "tracing",
  "tracing-subscriber",
  "uuid 1.2.2",
diff --git a/Cargo.toml b/Cargo.toml
index b4e3005c9..fb3a9ccda 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -40,6 +40,7 @@ async-trait = "0.1.58"
 axum = { version = "0.6.0", default-features = false }
 chrono = { version = "0.4.23", default-features = false, features = ["clock"] }
 once_cell = "1.16.0"
+prost-types = "0.11.0"
 uuid = "1.2.2"
 thiserror = "1.0.37"
 serde = { version = "1.0.148", default-features = false }
diff --git a/cargo-shuttle/src/lib.rs b/cargo-shuttle/src/lib.rs
index 3e9215dc2..e12b43d6c 100644
--- a/cargo-shuttle/src/lib.rs
+++ b/cargo-shuttle/src/lib.rs
@@ -7,6 +7,7 @@ mod provisioner_server;
 use shuttle_common::project::ProjectName;
 use shuttle_proto::runtime::{self, LoadRequest, StartRequest, SubscribeLogsRequest};
 use std::collections::HashMap;
+use std::convert::TryInto;
 use std::ffi::OsString;
 use std::fs::{read_to_string, File};
 use std::io::stdout;
@@ -454,7 +455,7 @@ impl Shuttle {
 
         tokio::spawn(async move {
             while let Some(log) = stream.message().await.expect("to get log from stream") {
-                let log: shuttle_common::LogItem = log.into();
+                let log: shuttle_common::LogItem = log.try_into().expect("to convert log");
                 println!("{log}");
             }
         });
diff --git a/common/Cargo.toml b/common/Cargo.toml
index 47db8aa75..776a21fbd 100644
--- a/common/Cargo.toml
+++ b/common/Cargo.toml
@@ -16,12 +16,14 @@ crossterm = { version = "0.25.0", optional = true }
 http = { version = "0.2.8", optional = true }
 http-serde = { version = "1.1.2", optional = true }
 once_cell = { workspace = true, optional = true }
+prost-types = { workspace = true, optional = true }
 reqwest = { version = "0.11.13", optional = true }
 rmp-serde = { version = "1.1.1", optional = true }
 rustrict = { version = "0.5.5", optional = true }
 serde = { workspace = true }
 serde_json = { workspace = true, optional = true }
 strum = { version = "0.24.1", features = ["derive"], optional = true }
+thiserror = { workspace = true, optional = true }
 tracing = { workspace = true }
 tracing-subscriber = { workspace = true, optional = true }
 uuid = { workspace = true, features = ["v4", "serde"], optional = true }
@@ -35,5 +37,5 @@ backend = ["async-trait", "axum"]
 display = ["comfy-table", "crossterm"]
 tracing = ["serde_json"]
 wasm = ["http-serde", "http", "rmp-serde", "tracing", "tracing-subscriber"]
-models = ["anyhow", "async-trait", "display", "http", "reqwest", "serde_json", "service"]
+models = ["anyhow", "async-trait", "display", "http", "prost-types", "reqwest", "serde_json", "service", "thiserror"]
 service = ["chrono/serde", "once_cell", "rustrict", "serde/derive", "strum", "uuid"]
diff --git a/common/src/models/mod.rs b/common/src/models/mod.rs
index 0236c641d..17a8cb293 100644
--- a/common/src/models/mod.rs
+++ b/common/src/models/mod.rs
@@ -11,8 +11,10 @@ use anyhow::{Context, Result};
 use async_trait::async_trait;
 use http::StatusCode;
 use serde::de::DeserializeOwned;
+use thiserror::Error;
 use tracing::trace;
+/// A to_json wrapper for handling our error states
 #[async_trait]
 pub trait ToJson {
     async fn to_json<T: DeserializeOwned>(self) -> Result<T>;
 }
@@ -48,3 +50,14 @@ impl ToJson for reqwest::Response {
         }
     }
 }
+
+/// Errors that can occur when changing types. Especially from prost
+#[derive(Error, Debug)]
+pub enum ParseError {
+    #[error("failed to parse UUID: {0}")]
+    Uuid(#[from] uuid::Error),
+    #[error("failed to parse timestamp: {0}")]
+    Timestamp(#[from] prost_types::TimestampError),
+    #[error("failed to parse serde: {0}")]
+    Serde(#[from] serde_json::Error),
+}
diff --git a/common/src/wasm.rs b/common/src/wasm.rs
index 82da1d2ee..edaa8808b 100644
--- a/common/src/wasm.rs
+++ b/common/src/wasm.rs
@@ -15,7 +15,6 @@ use crate::tracing::JsonVisitor;
 
 extern crate rmp_serde as rmps;
 
-// todo: add http extensions field
 #[derive(Serialize, Deserialize, Debug)]
 pub struct RequestWrapper {
     #[serde(with = "http_serde::method")]
@@ -44,11 +43,11 @@ impl From<http::request::Parts> for RequestWrapper {
 
 impl RequestWrapper {
     /// Serialize a RequestWrapper to the Rust MessagePack data format
-    pub fn into_rmp(self) -> Vec<u8> {
+    pub fn into_rmp(self) -> Result<Vec<u8>, rmps::encode::Error> {
         let mut buf = Vec::new();
-        self.serialize(&mut Serializer::new(&mut buf)).unwrap();
+        self.serialize(&mut Serializer::new(&mut buf))?;
 
-        buf
+        Ok(buf)
     }
 
     /// Consume the wrapper and return a request builder with `Parts` set
@@ -60,7 +59,7 @@ impl RequestWrapper {
 
         request
             .headers_mut()
-            .unwrap()
+            .unwrap() // Safe to unwrap as we just made the builder
             .extend(self.headers.into_iter());
 
         request
@@ -92,11 +91,11 @@ impl From<http::response::Parts> for ResponseWrapper {
 
 impl ResponseWrapper {
     /// Serialize a ResponseWrapper into the Rust MessagePack data format
-    pub fn into_rmp(self) -> Vec<u8> {
+    pub fn into_rmp(self) -> Result<Vec<u8>, rmps::encode::Error> {
         let mut buf = Vec::new();
-        self.serialize(&mut Serializer::new(&mut buf)).unwrap();
+        self.serialize(&mut Serializer::new(&mut buf))?;
 
-        buf
+        Ok(buf)
     }
 
     /// Consume the wrapper and return a response builder with `Parts` set
@@ -107,7 +106,7 @@ impl ResponseWrapper {
 
         response
             .headers_mut()
-            .unwrap()
+            .unwrap() // Safe to unwrap since we just made the builder
             .extend(self.headers.into_iter());
 
         response
@@ -389,7 +388,7 @@ mod test {
             .unwrap();
         let (parts, _) = request.into_parts();
 
-        let rmp = RequestWrapper::from(parts).into_rmp();
+        let rmp = RequestWrapper::from(parts).into_rmp().unwrap();
 
        let back: RequestWrapper = rmps::from_slice(&rmp).unwrap();
 
@@ -415,7 +414,7 @@ mod test {
             .unwrap();
         let (parts, _) = response.into_parts();
 
-        let rmp = ResponseWrapper::from(parts).into_rmp();
+        let rmp = ResponseWrapper::from(parts).into_rmp().unwrap();
 
        let back: ResponseWrapper = rmps::from_slice(&rmp).unwrap();
 
diff --git a/deployer/src/deployment/deploy_layer.rs b/deployer/src/deployment/deploy_layer.rs
index 5b56dcba0..f7d35873b 100644
--- a/deployer/src/deployment/deploy_layer.rs
+++ b/deployer/src/deployment/deploy_layer.rs
@@ -21,9 +21,9 @@ use chrono::{DateTime, Utc};
 
 use serde_json::json;
-use shuttle_common::{tracing::JsonVisitor, STATE_MESSAGE};
+use shuttle_common::{models::ParseError, tracing::JsonVisitor, STATE_MESSAGE};
 use shuttle_proto::runtime;
-use std::{str::FromStr, time::SystemTime};
+use std::{convert::TryFrom, str::FromStr, time::SystemTime};
 use tracing::{field::Visit, span, warn, Metadata, Subscriber};
 use tracing_subscriber::Layer;
 use uuid::Uuid;
@@ -112,19 +112,25 @@ impl From<Log> for DeploymentState {
     }
 }
 
-impl From<runtime::LogItem> for Log {
-    fn from(log: runtime::LogItem) -> Self {
-        Self {
-            id: Uuid::from_slice(&log.id).unwrap(),
-            state: runtime::LogState::from_i32(log.state).unwrap().into(),
-            level: runtime::LogLevel::from_i32(log.level).unwrap().into(),
-            timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap()).unwrap()),
+impl TryFrom<runtime::LogItem> for Log {
+    type Error = ParseError;
+
+    fn try_from(log: runtime::LogItem) -> Result<Self, Self::Error> {
+        Ok(Self {
+            id: Uuid::from_slice(&log.id)?,
+            state: runtime::LogState::from_i32(log.state)
+                .unwrap_or_default()
+                .into(),
+            level: runtime::LogLevel::from_i32(log.level)
+                .unwrap_or_default()
+                .into(),
+            timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap_or_default())?),
             file: log.file,
             line: log.line,
             target: log.target,
-            fields: serde_json::from_slice(&log.fields).unwrap(),
+            fields: serde_json::from_slice(&log.fields)?,
             r#type: LogType::Event,
-        }
+        })
     }
 }
diff --git a/deployer/src/deployment/run.rs b/deployer/src/deployment/run.rs
index 095943cc0..b0ce1e3be 100644
--- a/deployer/src/deployment/run.rs
+++ b/deployer/src/deployment/run.rs
@@ -241,19 +241,23 @@ async fn load(
 ) -> Result<()> {
     info!(
         "loading project from: {}",
-        so_path.clone().into_os_string().into_string().unwrap()
+        so_path
+            .clone()
+            .into_os_string()
+            .into_string()
+            .unwrap_or_default()
     );
 
     let secrets = secret_getter
         .get_secrets(&service_id)
         .await
-        .unwrap()
+        .map_err(|e| Error::SecretsGet(Box::new(e)))?
         .into_iter()
         .map(|secret| (secret.key, secret.value));
     let secrets = HashMap::from_iter(secrets);
 
     let load_request = tonic::Request::new(LoadRequest {
-        path: so_path.into_os_string().into_string().unwrap(),
+        path: so_path.into_os_string().into_string().unwrap_or_default(),
         service_name: service_name.clone(),
         secrets,
     });
@@ -283,7 +287,10 @@ async fn run(
     mut kill_recv: KillReceiver,
     cleanup: impl FnOnce(std::result::Result, Status>) + Send + 'static,
 ) {
-    deployment_updater.set_address(&id, &address).await.unwrap();
+    deployment_updater
+        .set_address(&id, &address)
+        .await
+        .expect("to set deployment address");
 
     let start_request = tonic::Request::new(StartRequest {
         deployment_id: id.as_bytes().to_vec(),
@@ -292,7 +299,10 @@ async fn run(
     });
 
     info!("starting service");
-    let response = runtime_client.start(start_request).await.unwrap();
+    let response = runtime_client
+        .start(start_request)
+        .await
+        .expect("to start deployment");
 
     info!(response = ?response.into_inner(), "start client response: ");
 
diff --git a/deployer/src/error.rs b/deployer/src/error.rs
index 0f9ad03cc..e81eae92b 100644
--- a/deployer/src/error.rs
+++ b/deployer/src/error.rs
@@ -24,6 +24,8 @@ pub enum Error {
     SecretsParse(#[from] toml::de::Error),
     #[error("Failed to set secrets: {0}")]
     SecretsSet(#[source] Box<dyn StdError + Send>),
+    #[error("Failed to get secrets: {0}")]
+    SecretsGet(#[source] Box<dyn StdError + Send>),
     #[error("Failed to cleanup old deployments: {0}")]
     OldCleanup(#[source] Box<dyn StdError + Send>),
     #[error("Gateway client error: {0}")]
diff --git a/deployer/src/runtime_manager.rs b/deployer/src/runtime_manager.rs
index 089350c17..4952ba9dd 100644
--- a/deployer/src/runtime_manager.rs
+++ b/deployer/src/runtime_manager.rs
@@ -1,4 +1,4 @@
-use std::{path::PathBuf, sync::Arc};
+use std::{convert::TryInto, path::PathBuf, sync::Arc};
 
 use anyhow::Context;
 use shuttle_proto::runtime::{self, runtime_client::RuntimeClient, SubscribeLogsRequest};
@@ -99,7 +99,9 @@ impl RuntimeManager {
         tokio::spawn(async move {
             while let Ok(Some(log)) = stream.message().await {
-                sender.send(log.into()).expect("to send log to persistence");
+                if let Ok(log) = log.try_into() {
+                    sender.send(log).expect("to send log to persistence");
+                }
             }
         });
 
diff --git a/proto/Cargo.toml b/proto/Cargo.toml
index 9e44fa415..aadf6f749 100644
--- a/proto/Cargo.toml
+++ b/proto/Cargo.toml
@@ -11,7 +11,7 @@ anyhow = { workspace = true }
 chrono = { workspace = true }
 home = "0.5.4"
 prost = "0.11.2"
-prost-types = "0.11.0"
"0.11.0" +prost-types = { workspace = true } tokio = { version = "1.22.0", features = ["process"] } tonic = { workspace = true } tracing = { workspace = true } diff --git a/proto/src/lib.rs b/proto/src/lib.rs index 122a2cb0c..3a452cac9 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -96,6 +96,7 @@ pub mod provisioner { pub mod runtime { use std::{ + convert::TryFrom, path::PathBuf, process::Command, time::{Duration, SystemTime}, @@ -104,6 +105,7 @@ pub mod runtime { use anyhow::Context; use chrono::DateTime; use prost_types::Timestamp; + use shuttle_common::models::ParseError; use tokio::process; use tonic::transport::{Channel, Endpoint}; use tracing::info; @@ -159,18 +161,20 @@ pub mod runtime { } } - impl From for shuttle_common::LogItem { - fn from(log: LogItem) -> Self { - Self { - id: Uuid::from_slice(&log.id).unwrap(), - timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap()).unwrap()), - state: LogState::from_i32(log.state).unwrap().into(), - level: LogLevel::from_i32(log.level).unwrap().into(), + impl TryFrom for shuttle_common::LogItem { + type Error = ParseError; + + fn try_from(log: LogItem) -> Result { + Ok(Self { + id: Uuid::from_slice(&log.id)?, + timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap_or_default())?), + state: LogState::from_i32(log.state).unwrap_or_default().into(), + level: LogLevel::from_i32(log.level).unwrap_or_default().into(), file: log.file, line: log.line, target: log.target, fields: log.fields, - } + }) } } diff --git a/runtime/src/axum/mod.rs b/runtime/src/axum/mod.rs index 74e3fd41f..15b3cf0c7 100644 --- a/runtime/src/axum/mod.rs +++ b/runtime/src/axum/mod.rs @@ -40,7 +40,7 @@ const BODY_FD: u32 = 4; pub struct AxumWasm { router: Mutex>, logs_rx: Mutex>>>, - logs_tx: Mutex>>, + logs_tx: Sender>, kill_tx: Mutex>>, } @@ -57,7 +57,7 @@ impl AxumWasm { Self { router: Mutex::new(None), logs_rx: Mutex::new(Some(rx)), - logs_tx: Mutex::new(tx), + logs_tx: tx, kill_tx: Mutex::new(None), } } @@ -103,7 +103,7 @@ impl Runtime for AxumWasm { .context("invalid socket address") .map_err(|err| Status::invalid_argument(err.to_string()))?; - let logs_tx = self.logs_tx.lock().unwrap().clone(); + let logs_tx = self.logs_tx.clone(); let (kill_tx, kill_rx) = tokio::sync::oneshot::channel(); @@ -200,7 +200,7 @@ impl RouterBuilder { } fn build(self) -> anyhow::Result { - let file = self.src.expect("module path should be set"); + let file = self.src.context("module path should be set")?; let module = Module::from_file(&self.engine, file)?; for export in module.exports() { @@ -268,14 +268,16 @@ impl Router { let mut log: runtime::LogItem = log.into(); log.id = deployment_id.clone(); - logs_tx.blocking_send(Ok(log)).unwrap(); + logs_tx.blocking_send(Ok(log)).expect("to send log"); } }); let (parts, body) = req.into_parts(); // Serialise request parts to rmp - let request_rmp = RequestWrapper::from(parts).into_rmp(); + let request_rmp = RequestWrapper::from(parts) + .into_rmp() + .context("failed to make request wrapper")?; // Write request parts to wasm module parts_stream @@ -315,9 +317,9 @@ impl Router { trace!("calling Router"); self.linker .get(&mut store, "axum", "__SHUTTLE_Axum_call") - .expect("wasm module should be loaded and the router function should be available") + .context("wasm module should be loaded and the router function should be available")? .into_func() - .expect("router function should be a function") + .context("router function should be a function")? .typed::<(RawFd, RawFd, RawFd), ()>(&store)? 
             .call(
                 &mut store,
diff --git a/runtime/src/legacy/mod.rs b/runtime/src/legacy/mod.rs
index c7b19b4fc..137e7a369 100644
--- a/runtime/src/legacy/mod.rs
+++ b/runtime/src/legacy/mod.rs
@@ -24,7 +24,7 @@ use tonic::{transport::Endpoint, Request, Response, Status};
 use tracing::{error, instrument, trace};
 use uuid::Uuid;
 
-use crate::provisioner_factory::{AbstractFactory, AbstractProvisionerFactory};
+use crate::provisioner_factory::ProvisionerFactory;
 
 mod error;
 
@@ -35,7 +35,7 @@ where
     // Mutexes are for interior mutability
     so_path: Mutex<Option<PathBuf>>,
     logs_rx: Mutex<Option<UnboundedReceiver<LogItem>>>,
-    logs_tx: Mutex<UnboundedSender<LogItem>>,
+    logs_tx: UnboundedSender<LogItem>,
     provisioner_address: Endpoint,
     kill_tx: Mutex<Option<oneshot::Sender<String>>>,
     secrets: Mutex<Option<BTreeMap<String, String>>>,
@@ -52,7 +52,7 @@ where
         Self {
             so_path: Mutex::new(None),
             logs_rx: Mutex::new(Some(rx)),
-            logs_tx: Mutex::new(tx),
+            logs_tx: tx,
             kill_tx: Mutex::new(None),
             provisioner_address,
             secrets: Mutex::new(None),
@@ -92,8 +92,8 @@ where
         let provisioner_client = ProvisionerClient::connect(self.provisioner_address.clone())
             .await
-            .expect("failed to connect to provisioner");
-        let abstract_factory = AbstractProvisionerFactory::new(provisioner_client);
+            .context("failed to connect to provisioner")
+            .map_err(|err| Status::internal(err.to_string()))?;
 
         let so_path = self
             .so_path
@@ -132,9 +132,11 @@ where
         let service_name = ServiceName::from_str(service_name.as_str())
             .map_err(|err| Status::from_error(Box::new(err)))?;
 
-        let deployment_id = Uuid::from_slice(&deployment_id).unwrap();
+        let deployment_id = Uuid::from_slice(&deployment_id)
+            .map_err(|error| Status::invalid_argument(error.to_string()))?;
 
-        let mut factory = abstract_factory.get_factory(
+        let mut factory = ProvisionerFactory::new(
+            provisioner_client,
             service_name,
             deployment_id,
             secrets,
@@ -142,7 +144,7 @@ where
         );
         trace!("got factory");
 
-        let logs_tx = self.logs_tx.lock().unwrap().clone();
+        let logs_tx = self.logs_tx.clone();
 
         let logger = Logger::new(logs_tx, deployment_id);
 
@@ -177,7 +179,7 @@ where
 
         // Move logger items into stream to be returned
         tokio::spawn(async move {
             while let Some(log) = logs_rx.recv().await {
-                tx.send(Ok(log.into())).await.unwrap();
+                tx.send(Ok(log.into())).await.expect("to send log");
             }
         });
 
diff --git a/runtime/src/provisioner_factory.rs b/runtime/src/provisioner_factory.rs
index 58cd34aa2..9b017d01c 100644
--- a/runtime/src/provisioner_factory.rs
+++ b/runtime/src/provisioner_factory.rs
@@ -10,55 +10,6 @@ use tonic::{transport::Channel, Request};
 use tracing::{debug, info, trace};
 use uuid::Uuid;
 
-/// Trait to make it easy to get a factory (service locator) for each service being started
-pub trait AbstractFactory<S: StorageManager>: Send + 'static {
-    type Output: Factory;
-
-    /// Get a factory for a specific service
-    fn get_factory(
-        &self,
-        service_name: ServiceName,
-        deployment_id: Uuid,
-        secrets: BTreeMap<String, String>,
-        storage_manager: S,
-    ) -> Self::Output;
-}
-
-/// An abstract factory that makes factories which uses provisioner
-#[derive(Clone)]
-pub struct AbstractProvisionerFactory {
-    provisioner_client: ProvisionerClient<Channel>,
-}
-
-impl<S> AbstractFactory<S> for AbstractProvisionerFactory
-where
-    S: StorageManager,
-{
-    type Output = ProvisionerFactory<S>;
-
-    fn get_factory(
-        &self,
-        service_name: ServiceName,
-        deployment_id: Uuid,
-        secrets: BTreeMap<String, String>,
-        storage_manager: S,
-    ) -> Self::Output {
-        ProvisionerFactory::new(
-            self.provisioner_client.clone(),
-            service_name,
-            deployment_id,
-            secrets,
-            storage_manager,
-        )
-    }
-}
-
-impl AbstractProvisionerFactory {
-    pub fn new(provisioner_client: ProvisionerClient<Channel>) -> Self {
-        Self { provisioner_client }
-    }
-}
-
 /// A factory (service locator) which goes through the provisioner crate
 pub struct ProvisionerFactory<S>
 where