Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/api/auth/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use super::{ServerParameterProvider, StartupHandler};

pub mod scram;
pub mod oauth;

#[derive(Debug)]
pub enum SASLState {
Expand Down Expand Up @@ -42,6 +43,9 @@ pub struct SASLAuthStartupHandler<P> {
state: Mutex<SASLState>,
/// scram configuration
scram: Option<scram::ScramAuth>,
/// oauth configuration
// TODO
oauth: Option<String>,
}

#[async_trait]
Expand Down Expand Up @@ -74,6 +78,10 @@ impl<P: ServerParameterProvider> StartupHandler for SASLAuthStartupHandler<P> {
let sasl_initial_response = msg.into_sasl_initial_response()?;
let selected_mechanism = sasl_initial_response.auth_method.as_str();

// TODO: include the oauth mechanism, but I am not sure if the state should still be
// ScramClientFirstReceived. when I am respomding with the SASLInitialResponse, I have to include the `auth`
// field. that means I have to check if the selected mechanism is oauth, then construct the SASLInitialResponse.
// then, I'll handle the AuthenticationSASLContinue message type (add it to the PasswordMessageFamily enum)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add it to the PasswordMessageFamily enum

The reason we have PasswordMessageFamily enum is because all these types shares the same type code p. We already have AuthenticationSASLCcontinue in Authentication.

if [Self::SCRAM_SHA_256, Self::SCRAM_SHA_256_PLUS].contains(&selected_mechanism)
{
*state = SASLState::ScramClientFirstReceived;
Expand Down Expand Up @@ -121,6 +129,7 @@ impl<P> SASLAuthStartupHandler<P> {
parameter_provider,
state: Mutex::new(SASLState::Initial),
scram: None,
oauth: None,
}
}

Expand All @@ -131,6 +140,7 @@ impl<P> SASLAuthStartupHandler<P> {

const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
const OAUTHBEARER: &str = "OAUTHBEARER";

fn supported_mechanisms(&self) -> Vec<String> {
let mut mechanisms = vec![];
Expand All @@ -143,6 +153,10 @@ impl<P> SASLAuthStartupHandler<P> {
}
}

if let Some(oauth) = &self.oauth {
mechanisms.push(Self::OAUTHBEARER.to_owned());
}

mechanisms
}
}
Empty file added src/api/auth/sasl/oauth.rs
Empty file.
1 change: 1 addition & 0 deletions src/messages/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ impl Message for GssEncRequest {
pub struct SASLInitialResponse {
pub auth_method: String,
pub data: Option<Bytes>,
// TODO: add auth field here to be None
}

impl Message for SASLInitialResponse {
Expand Down
Loading