Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[next] bug: misc fixes #725

Merged
merged 3 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions codegen/src/shuttle_main/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,10 @@ impl ToTokens for Loader {
.or_else(|_| shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();

let _guard = shuttle_runtime::tracing_subscriber::registry()
shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.set_default(); // Scope our runtime logger to this thread scope only
.init();

#vars
#(let #fn_inputs = #fn_inputs_builder::new()#fn_inputs_builder_options.build(&mut #factory_ident).await.context(format!("failed to provision {}", stringify!(#fn_inputs_builder)))?;)*
Expand Down Expand Up @@ -317,10 +317,10 @@ mod tests {
.or_else(|_| shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();

let _guard = shuttle_runtime::tracing_subscriber::registry()
shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.set_default();
.init();

simple().await
}
Expand Down Expand Up @@ -398,10 +398,10 @@ mod tests {
.or_else(|_| shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();

let _guard = shuttle_runtime::tracing_subscriber::registry()
shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.set_default();
.init();

let pool = shuttle_shared_db::Postgres::new().build(&mut factory).await.context(format!("failed to provision {}", stringify!(shuttle_shared_db::Postgres)))?;
let redis = shuttle_shared_db::Redis::new().build(&mut factory).await.context(format!("failed to provision {}", stringify!(shuttle_shared_db::Redis)))?;
Expand Down Expand Up @@ -513,10 +513,10 @@ mod tests {
.or_else(|_| shuttle_runtime::tracing_subscriber::EnvFilter::try_new("INFO"))
.unwrap();

let _guard = shuttle_runtime::tracing_subscriber::registry()
shuttle_runtime::tracing_subscriber::registry()
.with(filter_layer)
.with(logger)
.set_default();
.init();

let vars = std::collections::HashMap::from_iter(factory.get_secrets().await?.into_iter().map(|(key, value)| (format!("secrets.{}", key), value)));
let pool = shuttle_shared_db::Postgres::new().size(&shuttle_runtime::strfmt("10Gb", &vars)?).public(false).build(&mut factory).await.context(format!("failed to provision {}", stringify!(shuttle_shared_db::Postgres)))?;
Expand Down
10 changes: 8 additions & 2 deletions common/src/claims.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ use tower::{Layer, Service};
use tracing::{error, trace, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

pub const EXP_MINUTES: i64 = 5;
/// Minutes before a claim expires
///
/// We don't use the convention of 5 minutes because builds can take longer than 5 minutes. When this happens, requests
/// to provisioner will fail as the token expired.
pub const EXP_MINUTES: i64 = 15;
const ISS: &str = "shuttle";

/// The scope of operations that can be performed on shuttle
Expand Down Expand Up @@ -147,9 +151,11 @@ impl Claim {
"failed to convert token to claim"
);
match err.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
StatusCode::from_u16(499).unwrap() // Expired status code which is safe to unwrap
}
jsonwebtoken::errors::ErrorKind::InvalidSignature
| jsonwebtoken::errors::ErrorKind::InvalidAlgorithmName
| jsonwebtoken::errors::ErrorKind::ExpiredSignature
| jsonwebtoken::errors::ErrorKind::InvalidIssuer
| jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
StatusCode::UNAUTHORIZED
Expand Down
16 changes: 15 additions & 1 deletion deployer/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,25 @@ mkdir -p $CARGO_HOME; \
echo '[patch.crates-io]
shuttle-service = { path = "/usr/src/shuttle/service" }
shuttle-runtime = { path = "/usr/src/shuttle/runtime" }

shuttle-aws-rds = { path = "/usr/src/shuttle/resources/aws-rds" }
shuttle-persist = { path = "/usr/src/shuttle/resources/persist" }
shuttle-shared-db = { path = "/usr/src/shuttle/resources/shared-db" }
shuttle-secrets = { path = "/usr/src/shuttle/resources/secrets" }
shuttle-static-folder = { path = "/usr/src/shuttle/resources/static-folder" }' > $CARGO_HOME/config.toml
shuttle-static-folder = { path = "/usr/src/shuttle/resources/static-folder" }

shuttle-axum = { path = "/usr/src/shuttle/services/shuttle-axum" }
shuttle-actix-web = { path = "/usr/src/shuttle/services/shuttle-actix-web" }
shuttle-next = { path = "/usr/src/shuttle/services/shuttle-next" }
shuttle-poem = { path = "/usr/src/shuttle/services/shuttle-poem" }
shuttle-poise = { path = "/usr/src/shuttle/services/shuttle-poise" }
shuttle-rocket = { path = "/usr/src/shuttle/services/shuttle-rocket" }
shuttle-salvo = { path = "/usr/src/shuttle/services/shuttle-salvo" }
shuttle-serenity = { path = "/usr/src/shuttle/services/shuttle-serenity" }
shuttle-thruster = { path = "/usr/src/shuttle/services/shuttle-thruster" }
shuttle-tide = { path = "/usr/src/shuttle/services/shuttle-tide" }
shuttle-tower = { path = "/usr/src/shuttle/services/shuttle-tower" }
shuttle-warp = { path = "/usr/src/shuttle/services/shuttle-warp" }' > $CARGO_HOME/config.toml

# Add the wasm32-wasi target
rustup target add wasm32-wasi
Expand Down
70 changes: 10 additions & 60 deletions gateway/src/api/auth_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use axum::{
headers::{authorization::Bearer, Authorization, Cookie, Header, HeaderMapExt},
response::Response,
};
use chrono::{TimeZone, Utc};
use futures::future::BoxFuture;
use http::{Request, StatusCode, Uri};
use hyper::{
Expand All @@ -24,6 +23,11 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
static PROXY_CLIENT: Lazy<ReverseProxy<HttpConnector<GaiResolver>>> =
Lazy::new(|| ReverseProxy::new(Client::new()));

/// Time to cache tokens for. Currently tokens take 15 minutes to expire (see [EXP_MINUTES]) which leaves a 10 minutes
/// buffer (EXP_MINUTES - CACHE_MINUTES). We want the buffer to be atleast as long as the longest builds which has
/// been observed to be around 5 minutes.
const CACHE_MINUTES: u64 = 5;

/// The idea of this layer is to do two things:
/// 1. Forward all user related routes (`/login`, `/logout`, `/users/*`, etc) to our auth service
/// 2. Upgrade all Authorization Bearer keys and session cookies to JWT tokens for internal
Expand Down Expand Up @@ -247,25 +251,11 @@ where
}
};

match extract_token_expiration(response.token.clone()) {
Ok(expiration) => {
// Cache the token.
this.cache_manager.insert(
key.as_str(),
response.token.clone(),
expiration,
);
}
Err(status) => {
error!(
"failed to extract token expiration before inserting into cache"
);
return Ok(Response::builder()
.status(status)
.body(boxed(Body::empty()))
.unwrap());
}
};
this.cache_manager.insert(
key.as_str(),
response.token.clone(),
Duration::from_secs(CACHE_MINUTES * 60),
);

trace!("token inserted in cache, request proceeding");
req.headers_mut()
Expand All @@ -290,46 +280,6 @@ where
}
}

fn extract_token_expiration(token: String) -> Result<Duration, StatusCode> {
oddgrd marked this conversation as resolved.
Show resolved Hide resolved
let (_header, rest) = token
.split_once('.')
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

let (claim, _sig) = rest
.split_once('.')
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

let claim = base64::decode_config(claim, base64::URL_SAFE_NO_PAD)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let claim: serde_json::Map<String, serde_json::Value> =
serde_json::from_slice(&claim).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let exp = claim["exp"]
.as_i64()
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

let expiration_timestamp = Utc
.timestamp_opt(exp, 0)
.single()
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

let duration = expiration_timestamp - Utc::now();

// We will use this duration to set the TTL for the JWT in the cache. We subtract 180 seconds
// to make sure a token from the cache will still be valid in cases where it will be used to
// authorize some operation, the operation takes some time, and then the token needs to be
// used again.
//
// This number should never be negative since the JWT has just been created, and so should be
// safe to cast to u64. However, if the number *is* negative it would wrap and the TTL duration
// would be near u64::MAX, so we use try_from to ensure that can't happen.
let duration_minus_buffer = u64::try_from(duration.num_seconds() - 180)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Ok(std::time::Duration::from_secs(duration_minus_buffer))
}

fn make_token_request(uri: &str, header: impl Header) -> Request<Body> {
let mut token_request = Request::builder().uri(uri);
token_request
Expand Down
4 changes: 1 addition & 3 deletions runtime/src/alpha/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use core::future::Future;
use shuttle_common::{
backends::{
auth::{AuthPublicKey, JwtAuthenticationLayer},
tracing::{setup_tracing, ExtractPropagationLayer},
tracing::ExtractPropagationLayer,
},
claims::{Claim, ClaimLayer, InjectPropagationLayer},
resource,
Expand Down Expand Up @@ -54,8 +54,6 @@ pub async fn start(loader: impl Loader<ProvisionerFactory> + Send + 'static) {
let args = Args::parse();
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), args.port);

setup_tracing(tracing_subscriber::registry(), "shuttle-alpha");
oddgrd marked this conversation as resolved.
Show resolved Hide resolved

let provisioner_address = args.provisioner_address;
let mut server_builder = Server::builder()
.http2_keepalive_interval(Some(Duration::from_secs(60)))
Expand Down