Skip to content

Commit a8b6166

Browse files
authored
refactor: unwraps and mor (#587)
* refactor: better error handling * refactor: trim mutexes * refactor: remove abstract factory * refactor: remove extension todo
1 parent 35c0660 commit a8b6166

File tree

15 files changed

+103
-106
lines changed

15 files changed

+103
-106
lines changed

Cargo.lock

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ async-trait = "0.1.58"
4040
axum = { version = "0.6.0", default-features = false }
4141
chrono = { version = "0.4.23", default-features = false, features = ["clock"] }
4242
once_cell = "1.16.0"
43+
prost-types = "0.11.0"
4344
uuid = "1.2.2"
4445
thiserror = "1.0.37"
4546
serde = { version = "1.0.148", default-features = false }

cargo-shuttle/src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod provisioner_server;
77
use shuttle_common::project::ProjectName;
88
use shuttle_proto::runtime::{self, LoadRequest, StartRequest, SubscribeLogsRequest};
99
use std::collections::HashMap;
10+
use std::convert::TryInto;
1011
use std::ffi::OsString;
1112
use std::fs::{read_to_string, File};
1213
use std::io::stdout;
@@ -454,7 +455,7 @@ impl Shuttle {
454455

455456
tokio::spawn(async move {
456457
while let Some(log) = stream.message().await.expect("to get log from stream") {
457-
let log: shuttle_common::LogItem = log.into();
458+
let log: shuttle_common::LogItem = log.try_into().expect("to convert log");
458459
println!("{log}");
459460
}
460461
});

common/Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ crossterm = { version = "0.25.0", optional = true }
1616
http = { version = "0.2.8", optional = true }
1717
http-serde = { version = "1.1.2", optional = true }
1818
once_cell = { workspace = true, optional = true }
19+
prost-types = { workspace = true, optional = true }
1920
reqwest = { version = "0.11.13", optional = true }
2021
rmp-serde = { version = "1.1.1", optional = true }
2122
rustrict = { version = "0.5.5", optional = true }
2223
serde = { workspace = true }
2324
serde_json = { workspace = true, optional = true }
2425
strum = { version = "0.24.1", features = ["derive"], optional = true }
26+
thiserror = { workspace = true, optional = true }
2527
tracing = { workspace = true }
2628
tracing-subscriber = { workspace = true, optional = true }
2729
uuid = { workspace = true, features = ["v4", "serde"], optional = true }
@@ -35,5 +37,5 @@ backend = ["async-trait", "axum"]
3537
display = ["comfy-table", "crossterm"]
3638
tracing = ["serde_json"]
3739
wasm = ["http-serde", "http", "rmp-serde", "tracing", "tracing-subscriber"]
38-
models = ["anyhow", "async-trait", "display", "http", "reqwest", "serde_json", "service"]
40+
models = ["anyhow", "async-trait", "display", "http", "prost-types", "reqwest", "serde_json", "service", "thiserror"]
3941
service = ["chrono/serde", "once_cell", "rustrict", "serde/derive", "strum", "uuid"]

common/src/models/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ use anyhow::{Context, Result};
1111
use async_trait::async_trait;
1212
use http::StatusCode;
1313
use serde::de::DeserializeOwned;
14+
use thiserror::Error;
1415
use tracing::trace;
1516

17+
/// A to_json wrapper for handling our error states
1618
#[async_trait]
1719
pub trait ToJson {
1820
async fn to_json<T: DeserializeOwned>(self) -> Result<T>;
@@ -48,3 +50,14 @@ impl ToJson for reqwest::Response {
4850
}
4951
}
5052
}
53+
54+
/// Errors that can occur when changing types. Especially from prost
55+
#[derive(Error, Debug)]
56+
pub enum ParseError {
57+
#[error("failed to parse UUID: {0}")]
58+
Uuid(#[from] uuid::Error),
59+
#[error("failed to parse timestamp: {0}")]
60+
Timestamp(#[from] prost_types::TimestampError),
61+
#[error("failed to parse serde: {0}")]
62+
Serde(#[from] serde_json::Error),
63+
}

common/src/wasm.rs

+10-11
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ use crate::tracing::JsonVisitor;
1515

1616
extern crate rmp_serde as rmps;
1717

18-
// todo: add http extensions field
1918
#[derive(Serialize, Deserialize, Debug)]
2019
pub struct RequestWrapper {
2120
#[serde(with = "http_serde::method")]
@@ -44,11 +43,11 @@ impl From<http::request::Parts> for RequestWrapper {
4443

4544
impl RequestWrapper {
4645
/// Serialize a RequestWrapper to the Rust MessagePack data format
47-
pub fn into_rmp(self) -> Vec<u8> {
46+
pub fn into_rmp(self) -> Result<Vec<u8>, rmps::encode::Error> {
4847
let mut buf = Vec::new();
49-
self.serialize(&mut Serializer::new(&mut buf)).unwrap();
48+
self.serialize(&mut Serializer::new(&mut buf))?;
5049

51-
buf
50+
Ok(buf)
5251
}
5352

5453
/// Consume the wrapper and return a request builder with `Parts` set
@@ -60,7 +59,7 @@ impl RequestWrapper {
6059

6160
request
6261
.headers_mut()
63-
.unwrap()
62+
.unwrap() // Safe to unwrap as we just made the builder
6463
.extend(self.headers.into_iter());
6564

6665
request
@@ -92,11 +91,11 @@ impl From<http::response::Parts> for ResponseWrapper {
9291

9392
impl ResponseWrapper {
9493
/// Serialize a ResponseWrapper into the Rust MessagePack data format
95-
pub fn into_rmp(self) -> Vec<u8> {
94+
pub fn into_rmp(self) -> Result<Vec<u8>, rmps::encode::Error> {
9695
let mut buf = Vec::new();
97-
self.serialize(&mut Serializer::new(&mut buf)).unwrap();
96+
self.serialize(&mut Serializer::new(&mut buf))?;
9897

99-
buf
98+
Ok(buf)
10099
}
101100

102101
/// Consume the wrapper and return a response builder with `Parts` set
@@ -107,7 +106,7 @@ impl ResponseWrapper {
107106

108107
response
109108
.headers_mut()
110-
.unwrap()
109+
.unwrap() // Safe to unwrap since we just made the builder
111110
.extend(self.headers.into_iter());
112111

113112
response
@@ -389,7 +388,7 @@ mod test {
389388
.unwrap();
390389

391390
let (parts, _) = request.into_parts();
392-
let rmp = RequestWrapper::from(parts).into_rmp();
391+
let rmp = RequestWrapper::from(parts).into_rmp().unwrap();
393392

394393
let back: RequestWrapper = rmps::from_slice(&rmp).unwrap();
395394

@@ -415,7 +414,7 @@ mod test {
415414
.unwrap();
416415

417416
let (parts, _) = response.into_parts();
418-
let rmp = ResponseWrapper::from(parts).into_rmp();
417+
let rmp = ResponseWrapper::from(parts).into_rmp().unwrap();
419418

420419
let back: ResponseWrapper = rmps::from_slice(&rmp).unwrap();
421420

deployer/src/deployment/deploy_layer.rs

+17-11
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
2222
use chrono::{DateTime, Utc};
2323
use serde_json::json;
24-
use shuttle_common::{tracing::JsonVisitor, STATE_MESSAGE};
24+
use shuttle_common::{models::ParseError, tracing::JsonVisitor, STATE_MESSAGE};
2525
use shuttle_proto::runtime;
26-
use std::{str::FromStr, time::SystemTime};
26+
use std::{convert::TryFrom, str::FromStr, time::SystemTime};
2727
use tracing::{field::Visit, span, warn, Metadata, Subscriber};
2828
use tracing_subscriber::Layer;
2929
use uuid::Uuid;
@@ -112,19 +112,25 @@ impl From<Log> for DeploymentState {
112112
}
113113
}
114114

115-
impl From<runtime::LogItem> for Log {
116-
fn from(log: runtime::LogItem) -> Self {
117-
Self {
118-
id: Uuid::from_slice(&log.id).unwrap(),
119-
state: runtime::LogState::from_i32(log.state).unwrap().into(),
120-
level: runtime::LogLevel::from_i32(log.level).unwrap().into(),
121-
timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap()).unwrap()),
115+
impl TryFrom<runtime::LogItem> for Log {
116+
type Error = ParseError;
117+
118+
fn try_from(log: runtime::LogItem) -> Result<Self, Self::Error> {
119+
Ok(Self {
120+
id: Uuid::from_slice(&log.id)?,
121+
state: runtime::LogState::from_i32(log.state)
122+
.unwrap_or_default()
123+
.into(),
124+
level: runtime::LogLevel::from_i32(log.level)
125+
.unwrap_or_default()
126+
.into(),
127+
timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap_or_default())?),
122128
file: log.file,
123129
line: log.line,
124130
target: log.target,
125-
fields: serde_json::from_slice(&log.fields).unwrap(),
131+
fields: serde_json::from_slice(&log.fields)?,
126132
r#type: LogType::Event,
127-
}
133+
})
128134
}
129135
}
130136

deployer/src/deployment/run.rs

+15-5
Original file line numberDiff line numberDiff line change
@@ -241,19 +241,23 @@ async fn load(
241241
) -> Result<()> {
242242
info!(
243243
"loading project from: {}",
244-
so_path.clone().into_os_string().into_string().unwrap()
244+
so_path
245+
.clone()
246+
.into_os_string()
247+
.into_string()
248+
.unwrap_or_default()
245249
);
246250

247251
let secrets = secret_getter
248252
.get_secrets(&service_id)
249253
.await
250-
.unwrap()
254+
.map_err(|e| Error::SecretsGet(Box::new(e)))?
251255
.into_iter()
252256
.map(|secret| (secret.key, secret.value));
253257
let secrets = HashMap::from_iter(secrets);
254258

255259
let load_request = tonic::Request::new(LoadRequest {
256-
path: so_path.into_os_string().into_string().unwrap(),
260+
path: so_path.into_os_string().into_string().unwrap_or_default(),
257261
service_name: service_name.clone(),
258262
secrets,
259263
});
@@ -283,7 +287,10 @@ async fn run(
283287
mut kill_recv: KillReceiver,
284288
cleanup: impl FnOnce(std::result::Result<Response<StopResponse>, Status>) + Send + 'static,
285289
) {
286-
deployment_updater.set_address(&id, &address).await.unwrap();
290+
deployment_updater
291+
.set_address(&id, &address)
292+
.await
293+
.expect("to set deployment address");
287294

288295
let start_request = tonic::Request::new(StartRequest {
289296
deployment_id: id.as_bytes().to_vec(),
@@ -292,7 +299,10 @@ async fn run(
292299
});
293300

294301
info!("starting service");
295-
let response = runtime_client.start(start_request).await.unwrap();
302+
let response = runtime_client
303+
.start(start_request)
304+
.await
305+
.expect("to start deployment");
296306

297307
info!(response = ?response.into_inner(), "start client response: ");
298308

deployer/src/error.rs

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ pub enum Error {
2424
SecretsParse(#[from] toml::de::Error),
2525
#[error("Failed to set secrets: {0}")]
2626
SecretsSet(#[source] Box<dyn StdError + Send>),
27+
#[error("Failed to get secrets: {0}")]
28+
SecretsGet(#[source] Box<dyn StdError + Send>),
2729
#[error("Failed to cleanup old deployments: {0}")]
2830
OldCleanup(#[source] Box<dyn StdError + Send>),
2931
#[error("Gateway client error: {0}")]

deployer/src/runtime_manager.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{path::PathBuf, sync::Arc};
1+
use std::{convert::TryInto, path::PathBuf, sync::Arc};
22

33
use anyhow::Context;
44
use shuttle_proto::runtime::{self, runtime_client::RuntimeClient, SubscribeLogsRequest};
@@ -99,7 +99,9 @@ impl RuntimeManager {
9999

100100
tokio::spawn(async move {
101101
while let Ok(Some(log)) = stream.message().await {
102-
sender.send(log.into()).expect("to send log to persistence");
102+
if let Ok(log) = log.try_into() {
103+
sender.send(log).expect("to send log to persistence");
104+
}
103105
}
104106
});
105107

proto/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ anyhow = { workspace = true }
1111
chrono = { workspace = true }
1212
home = "0.5.4"
1313
prost = "0.11.2"
14-
prost-types = "0.11.0"
14+
prost-types = { workspace = true }
1515
tokio = { version = "1.22.0", features = ["process"] }
1616
tonic = { workspace = true }
1717
tracing = { workspace = true }

proto/src/lib.rs

+12-8
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ pub mod provisioner {
9696

9797
pub mod runtime {
9898
use std::{
99+
convert::TryFrom,
99100
path::PathBuf,
100101
process::Command,
101102
time::{Duration, SystemTime},
@@ -104,6 +105,7 @@ pub mod runtime {
104105
use anyhow::Context;
105106
use chrono::DateTime;
106107
use prost_types::Timestamp;
108+
use shuttle_common::models::ParseError;
107109
use tokio::process;
108110
use tonic::transport::{Channel, Endpoint};
109111
use tracing::info;
@@ -159,18 +161,20 @@ pub mod runtime {
159161
}
160162
}
161163

162-
impl From<LogItem> for shuttle_common::LogItem {
163-
fn from(log: LogItem) -> Self {
164-
Self {
165-
id: Uuid::from_slice(&log.id).unwrap(),
166-
timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap()).unwrap()),
167-
state: LogState::from_i32(log.state).unwrap().into(),
168-
level: LogLevel::from_i32(log.level).unwrap().into(),
164+
impl TryFrom<LogItem> for shuttle_common::LogItem {
165+
type Error = ParseError;
166+
167+
fn try_from(log: LogItem) -> Result<Self, Self::Error> {
168+
Ok(Self {
169+
id: Uuid::from_slice(&log.id)?,
170+
timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap_or_default())?),
171+
state: LogState::from_i32(log.state).unwrap_or_default().into(),
172+
level: LogLevel::from_i32(log.level).unwrap_or_default().into(),
169173
file: log.file,
170174
line: log.line,
171175
target: log.target,
172176
fields: log.fields,
173-
}
177+
})
174178
}
175179
}
176180

0 commit comments

Comments
 (0)