Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/api/auth/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};

use super::{ServerParameterProvider, StartupHandler};

pub mod oauth;
pub mod scram;

#[derive(Debug)]
Expand All @@ -22,6 +23,10 @@ pub enum SASLState {
ScramClientFirstReceived,
// cached password, channel_binding and partial auth-message
ScramServerFirstSent(Password, String, String),
// oauth authentication method selected
OauthStateInit,
// failure during authentication
OauthStateError,
// finished
Finished,
}
Expand All @@ -33,6 +38,10 @@ impl SASLState {
SASLState::ScramClientFirstReceived | SASLState::ScramServerFirstSent(_, _, _)
)
}

fn is_oauth(&self) -> bool {
matches!(self, SASLState::OauthStateInit)
}
}

#[derive(Debug)]
Expand All @@ -42,6 +51,8 @@ pub struct SASLAuthStartupHandler<P> {
state: Mutex<SASLState>,
/// scram configuration
scram: Option<scram::ScramAuth>,
/// oauth configuration
oauth: Option<oauth::Oauth>,
}

#[async_trait]
Expand Down Expand Up @@ -77,12 +88,13 @@ impl<P: ServerParameterProvider> StartupHandler for SASLAuthStartupHandler<P> {
if [Self::SCRAM_SHA_256, Self::SCRAM_SHA_256_PLUS].contains(&selected_mechanism)
{
*state = SASLState::ScramClientFirstReceived;
} else if [Self::OAUTHBEARER].contains(&selected_mechanism) {
*state = SASLState::OauthStateInit;
} else {
return Err(PgWireError::UnsupportedSASLAuthMethod(
selected_mechanism.to_string(),
));
}

msg = PasswordMessageFamily::SASLInitialResponse(sasl_initial_response);
} else {
let sasl_response = msg.into_sasl_response()?;
Expand All @@ -104,6 +116,17 @@ impl<P: ServerParameterProvider> StartupHandler for SASLAuthStartupHandler<P> {
}
}

// oauth
if state.is_oauth() {
if let Some(_) = &self.oauth {
} else {
// oauth is not configured
return Err(PgWireError::UnsupportedSASLAuthMethod(
"OAUTHBEARER".to_string(),
));
}
}

if matches!(*state, SASLState::Finished) {
super::finish_authentication(client, self.parameter_provider.as_ref()).await?;
}
Expand All @@ -121,6 +144,7 @@ impl<P> SASLAuthStartupHandler<P> {
parameter_provider,
state: Mutex::new(SASLState::Initial),
scram: None,
oauth: None,
}
}

Expand All @@ -131,6 +155,7 @@ impl<P> SASLAuthStartupHandler<P> {

const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
const OAUTHBEARER: &str = "OAUTHBEARER";

fn supported_mechanisms(&self) -> Vec<String> {
let mut mechanisms = vec![];
Expand All @@ -143,6 +168,10 @@ impl<P> SASLAuthStartupHandler<P> {
}
}

if self.oauth.is_some() {
mechanisms.push(Self::OAUTHBEARER.to_owned());
}

mechanisms
}
}
41 changes: 41 additions & 0 deletions src/api/auth/sasl/oauth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use std::{fmt::Debug, sync::Arc};

use crate::{
api::{auth::sasl::SASLState, ClientInfo},
error::PgWireResult,
messages::startup::{Authentication, PasswordMessageFamily},
};

#[derive(Debug)]
// TODO: find a way to get the issuer and scope, i think it'll be gotten from the client?
// as per: https://github.com/postgres/postgres/blob/e7ccb247b38fff342c13aa7bdf61ce5ab45b2a85/src/backend/libpq/auth-oauth.c#L102
pub struct Oauth {
pub issuer: String,
pub scope: String,
pub validator: Arc<dyn OauthValidator>,
}

// TODO: move this out
pub trait OauthValidator: Send + Sync + Debug {}

impl Oauth {
/// initialize the oauth context.
pub fn init() -> Self {
todo!()
}

pub async fn process_oauth_message<C>(
&self,
client: &C,
msg: PasswordMessageFamily,
state: &SASLState,
) -> PgWireResult<(Authentication, SASLState)>
where
C: ClientInfo + Unpin + Send,
{
// decode the message, if there's no auth in the SASL data field, return Authentication::SASLContinue
// to make the client respond with the Initial Client Response (still figuring out what this is exactly, but I think it is just the auth field as descibed in the docs)
// then handle authentication based on the states, validate token and bla bla bla
todo!()
}
}
1 change: 1 addition & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub enum PgWireConnectionState {
AwaitingSync,
}

// TODO: add oauth scope and issuer
/// Describe a client information holder
pub trait ClientInfo {
fn socket_addr(&self) -> SocketAddr;
Expand Down
Loading