Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 8 additions & 34 deletions datafusion-postgres/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use log::warn;
use pgwire::api::auth::{AuthSource, LoginInfo, Password};
use pgwire::error::{PgWireError, PgWireResult};
use tokio::sync::RwLock;
Expand Down Expand Up @@ -122,11 +121,7 @@ impl Default for AuthManager {

impl AuthManager {
pub fn new() -> Self {
let auth_manager = AuthManager {
users: Arc::new(RwLock::new(HashMap::new())),
roles: Arc::new(RwLock::new(HashMap::new())),
};

let mut users = HashMap::new();
// Initialize with default postgres superuser
let postgres_user = User {
username: "postgres".to_string(),
Expand All @@ -136,7 +131,9 @@ impl AuthManager {
can_login: true,
connection_limit: None,
};
users.insert(postgres_user.username.clone(), postgres_user);

let mut roles = HashMap::new();
let postgres_role = Role {
name: "postgres".to_string(),
is_superuser: true,
Expand All @@ -153,35 +150,12 @@ impl AuthManager {
}],
inherited_roles: vec![],
};
roles.insert(postgres_role.name.clone(), postgres_role);

// Add default users and roles
let auth_manager_clone = AuthManager {
users: auth_manager.users.clone(),
roles: auth_manager.roles.clone(),
};

tokio::spawn({
let users = auth_manager.users.clone();
let roles = auth_manager.roles.clone();
let auth_manager_spawn = auth_manager_clone;
async move {
users
.write()
.await
.insert("postgres".to_string(), postgres_user);
roles
.write()
.await
.insert("postgres".to_string(), postgres_role);

// Create predefined roles
if let Err(e) = auth_manager_spawn.create_predefined_roles().await {
warn!("Failed to create predefined roles: {e:?}");
}
}
});

auth_manager
AuthManager {
users: Arc::new(RwLock::new(users)),
roles: Arc::new(RwLock::new(roles)),
}
}

/// Add a new user to the system
Expand Down
121 changes: 121 additions & 0 deletions datafusion-postgres/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use std::{collections::HashMap, sync::Arc};

use datafusion::prelude::SessionContext;
use datafusion_postgres::{auth::AuthManager, pg_catalog::setup_pg_catalog, DfSessionService};
use futures::Sink;
use pgwire::{
api::{ClientInfo, ClientPortalStore, PgWireConnectionState, METADATA_USER},
messages::{
response::TransactionStatus, startup::SecretKey, PgWireBackendMessage, ProtocolVersion,
},
};

pub fn setup_handlers() -> DfSessionService {
let session_context = SessionContext::new();
setup_pg_catalog(&session_context, "datafusion").expect("Failed to setup sesession context");

DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()))
}

#[derive(Debug, Default)]
pub struct MockClient {
metadata: HashMap<String, String>,
portal_store: HashMap<String, String>,
}

impl MockClient {
pub fn new() -> MockClient {
let mut metadata = HashMap::new();
metadata.insert(METADATA_USER.to_string(), "postgres".to_string());

MockClient {
metadata,
portal_store: HashMap::default(),
}
}
}

impl ClientInfo for MockClient {
fn socket_addr(&self) -> std::net::SocketAddr {
"127.0.0.1".parse().unwrap()
}

fn is_secure(&self) -> bool {
false
}

fn protocol_version(&self) -> ProtocolVersion {
ProtocolVersion::PROTOCOL3_0
}

fn set_protocol_version(&mut self, _version: ProtocolVersion) {}

fn pid_and_secret_key(&self) -> (i32, SecretKey) {
(0, SecretKey::I32(0))
}

fn set_pid_and_secret_key(&mut self, _pid: i32, _secret_key: SecretKey) {}

fn state(&self) -> PgWireConnectionState {
PgWireConnectionState::ReadyForQuery
}

fn set_state(&mut self, _new_state: PgWireConnectionState) {}

fn transaction_status(&self) -> TransactionStatus {
TransactionStatus::Idle
}

fn set_transaction_status(&mut self, _new_status: TransactionStatus) {}

fn metadata(&self) -> &HashMap<String, String> {
&self.metadata
}

fn metadata_mut(&mut self) -> &mut HashMap<String, String> {
&mut self.metadata
}

fn client_certificates<'a>(&self) -> Option<&[rustls_pki_types::CertificateDer<'a>]> {
None
}
}

impl ClientPortalStore for MockClient {
type PortalStore = HashMap<String, String>;
fn portal_store(&self) -> &Self::PortalStore {
&self.portal_store
}
}

impl Sink<PgWireBackendMessage> for MockClient {
type Error = std::io::Error;

fn poll_ready(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}

fn start_send(
self: std::pin::Pin<&mut Self>,
_item: PgWireBackendMessage,
) -> Result<(), Self::Error> {
Ok(())
}

fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}

fn poll_close(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
}
14 changes: 14 additions & 0 deletions datafusion-postgres/tests/dbeaver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
mod common;

use common::*;
use pgwire::api::query::SimpleQueryHandler;

#[tokio::test]
pub async fn test_dbeaver_startup_sql() {
let service = setup_handlers();
let mut client = MockClient::new();

SimpleQueryHandler::do_query(&service, &mut client, "SELECT 1")
.await
.expect("failed to run sql");
}
Loading