From 915aff54a4d403c50f9b98568bbcb4ed675dcce3 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 2 Oct 2024 15:36:42 +0200 Subject: [PATCH 01/15] [WIP]: Migrations over WebSocket --- Cargo.lock | 108 ++++++++++++++++-- Cargo.toml | 7 -- libs/user-facing-errors/src/quaint.rs | 29 +++++ .../postgres_datamodel_connector.rs | 5 +- quaint/Cargo.toml | 6 +- quaint/src/connector/connection_info.rs | 16 ++- quaint/src/connector/native.rs | 4 +- quaint/src/connector/postgres/native/mod.rs | 19 +++ .../connector/postgres/native/websocket.rs | 68 +++++++++++ quaint/src/connector/postgres/url.rs | 39 ++++++- .../src/flavour/postgres.rs | 101 ++++++++++++++-- .../src/flavour/postgres/connection.rs | 42 ++++--- schema-engine/core/src/lib.rs | 2 +- .../src/multi_engine_test_api.rs | 2 +- 14 files changed, 391 insertions(+), 57 deletions(-) create mode 100644 quaint/src/connector/postgres/native/websocket.rs diff --git a/Cargo.lock b/Cargo.lock index 62a21959bf0f..40cc4d01ba29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,6 +164,32 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "async-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90e661b6cb0a6eb34d02c520b052daa3aa9ac0cc02495c9d066bbce13ead132b" +dependencies = [ + "futures-io", + "futures-util", + "log", + "pin-project-lite", + "tokio", + "tungstenite", +] + +[[package]] +name = "async_io_stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" +dependencies = [ + "futures", + "pharos", + "rustc_version", + "tokio", +] + [[package]] name = "asynchronous-codec" version = "0.6.2" @@ -1736,7 +1762,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 2.2.2", "slab", "tokio", @@ -1908,6 +1934,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -1915,7 +1952,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -1942,7 +1979,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", @@ -3150,7 +3187,7 @@ dependencies = [ "async-trait", "futures", "futures-util", - "http", + "http 0.2.9", "opentelemetry", "prost", "thiserror", @@ -3379,6 +3416,16 @@ dependencies = [ "indexmap 1.9.3", ] +[[package]] +name = "pharos" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9567389417feee6ce15dd6527a8a1ecac205ef62c2932bcf3d9f6fc5b78b414" +dependencies = [ + "futures", + "rustc_version", +] + [[package]] name = "phf" version = "0.11.2" @@ -3770,6 +3817,7 @@ name = "quaint" version = "0.2.0-alpha.13" dependencies = [ "async-trait", + "async-tungstenite", "base64 0.12.3", "bigdecimal", "bit-vec", @@ -3810,11 +3858,12 @@ dependencies = [ "tiberius", "tokio", "tokio-postgres", - "tokio-util 0.6.10", + "tokio-util 0.7.8", "tracing", "tracing-core", "url", "uuid", + "ws_stream_tungstenite", ] [[package]] @@ -4484,7 +4533,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-tls", @@ -5867,7 +5916,6 @@ checksum = "36943ee01a6d67977dd3f84a5a1d2efeb4ada3a1ae771cadfaa535d9d9fc6507" dependencies = [ "bytes", "futures-core", - "futures-io", "futures-sink", "log", "pin-project-lite", @@ -5911,7 +5959,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-timeout", @@ -6117,6 +6165,24 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "twox-hash" version = "1.6.3" @@ -6258,6 +6324,12 @@ dependencies = [ "user-facing-error-macros", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8-width" version = "0.1.6" @@ -6778,6 +6850,26 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ws_stream_tungstenite" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed39ff9f8b2eda91bf6390f9f49eee93d655489e15708e3bb638c1c4f07cecb4" +dependencies = [ + "async-tungstenite", + "async_io_stream", + "bitflags 2.4.0", + "futures-core", + "futures-io", + "futures-sink", + "futures-util", + "pharos", + "rustc_version", + "tokio", + "tracing", + "tungstenite", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index f2000ba619d3..750429edee96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,7 +73,6 @@ tsify = { version = "0.4.5" } wasm-bindgen = { version = "0.2.92" } wasm-bindgen-futures = { version = "0.4" } wasm-rs-dbg = { version = "0.1.2", default-features = false, features = ["console-error"] } -wasm-bindgen-test = { version = "0.3.0" } url = { version = "2.5.0" } bson = { version = "2.11.0", features = ["chrono-0_4", "uuid-1"] } @@ -85,15 +84,9 @@ path = "quaint" [profile.dev.package.backtrace] opt-level = 3 -[profile.release.package.query-engine-node-api] -strip = "symbols" - [profile.release.package.query-engine] strip = "symbols" -[profile.release.package.query-engine-c-abi] -strip = "symbols" - [profile.release] lto = "fat" codegen-units = 1 diff --git a/libs/user-facing-errors/src/quaint.rs b/libs/user-facing-errors/src/quaint.rs index 9c5d72432edc..465f1960fef3 100644 --- a/libs/user-facing-errors/src/quaint.rs +++ b/libs/user-facing-errors/src/quaint.rs @@ -59,6 +59,14 @@ pub fn render_quaint_error(kind: &ErrorKind, connection_info: &ConnectionInfo) - database_port: url.port(), })) } + #[cfg(feature = "postgresql-native")] + ConnectionInfo::Native(NativeConnectionInfo::PostgresWs(url)) => { + Some(KnownError::new(common::DatabaseDoesNotExist::Postgres { + database_name: db_name.to_string(), + database_host: url.host().to_owned(), + database_port: url.port(), + })) + } #[cfg(feature = "mysql-native")] ConnectionInfo::Native(NativeConnectionInfo::Mysql(url)) => { Some(KnownError::new(common::DatabaseDoesNotExist::Mysql { @@ -87,6 +95,12 @@ pub fn render_quaint_error(kind: &ErrorKind, connection_info: &ConnectionInfo) - database_name: format!("{}.{}", url.dbname(), url.schema()), })) } + ConnectionInfo::Native(NativeConnectionInfo::PostgresWs(url)) => { + Some(KnownError::new(common::DatabaseAccessDenied { + database_user: "".to_owned(), + database_name: url.dbname().to_owned(), + })) + } ConnectionInfo::Native(NativeConnectionInfo::Mysql(url)) => { Some(KnownError::new(common::DatabaseAccessDenied { database_user: url.username().into_owned(), @@ -107,6 +121,14 @@ pub fn render_quaint_error(kind: &ErrorKind, connection_info: &ConnectionInfo) - database_port: url.port(), })) } + #[cfg(feature = "postgresql-native")] + ConnectionInfo::Native(NativeConnectionInfo::PostgresWs(url)) => { + Some(KnownError::new(common::DatabaseAlreadyExists { + database_name: format!("{db_name}"), + database_host: url.host().to_owned(), + database_port: url.port(), + })) + } #[cfg(feature = "mysql-native")] ConnectionInfo::Native(NativeConnectionInfo::Mysql(url)) => { Some(KnownError::new(common::DatabaseAlreadyExists { @@ -257,6 +279,13 @@ pub fn render_quaint_error(kind: &ErrorKind, connection_info: &ConnectionInfo) - database_host: url.host().to_owned(), })) } + #[cfg(feature = "postgresql-native")] + (NativeErrorKind::ConnectionError(_), ConnectionInfo::Native(NativeConnectionInfo::PostgresWs(url))) => { + Some(KnownError::new(common::DatabaseNotReachable { + database_port: url.port(), + database_host: url.host().to_owned(), + })) + } #[cfg(feature = "mysql-native")] (NativeErrorKind::ConnectionError(_), ConnectionInfo::Native(NativeConnectionInfo::Mysql(url))) => { Some(KnownError::new(common::DatabaseNotReachable { diff --git a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs index f14a6b9bf1be..65a0d929995f 100644 --- a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs @@ -463,7 +463,10 @@ impl Connector for PostgresDatamodelConnector { } fn validate_url(&self, url: &str) -> Result<(), String> { - if !url.starts_with("postgres://") && !url.starts_with("postgresql://") { + if !url.starts_with("postgres://") + && !url.starts_with("postgresql://") + && !url.starts_with("prisma+postgres://") + { return Err("must start with the protocol `postgresql://` or `postgres://`.".to_owned()); } diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b63fe18b4941..079d9035b6eb 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -95,6 +95,8 @@ sqlformat = { version = "0.2.3", optional = true } uuid.workspace = true crosstarget-utils = { path = "../libs/crosstarget-utils" } concat-idents = "1.1.5" +ws_stream_tungstenite = { version = "0.14.0", features = ["tokio_io"] } +async-tungstenite = { version = "0.28.0", features = ["tokio-runtime"]} [dev-dependencies] once_cell = "1.3" @@ -180,8 +182,8 @@ features = ["rt-multi-thread", "macros", "sync"] optional = true [dependencies.tokio-util] -version = "0.6" -features = ["compat"] +version = "0.7" +features = ["compat", "io"] optional = true [build-dependencies] diff --git a/quaint/src/connector/connection_info.rs b/quaint/src/connector/connection_info.rs index 7dd8a5b58257..5c2818f4ae6e 100644 --- a/quaint/src/connector/connection_info.rs +++ b/quaint/src/connector/connection_info.rs @@ -98,6 +98,8 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => Some(url.dbname()), + #[cfg(feature = "postgresql")] + NativeConnectionInfo::PostgresWs(url) => Some(url.dbname()), #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => Some(url.dbname()), #[cfg(feature = "mssql")] @@ -120,6 +122,8 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => url.schema(), + #[cfg(feature = "postgresql")] + NativeConnectionInfo::PostgresWs(_) => "public", #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => url.dbname(), #[cfg(feature = "mssql")] @@ -140,6 +144,8 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => url.host(), + #[cfg(feature = "postgresql")] + NativeConnectionInfo::PostgresWs(url) => url.host(), #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => url.host(), #[cfg(feature = "mssql")] @@ -159,6 +165,8 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => Some(url.username()), + #[cfg(feature = "postgresql")] + NativeConnectionInfo::PostgresWs(_) => None, #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => Some(url.username()), #[cfg(feature = "mssql")] @@ -176,7 +184,7 @@ impl ConnectionInfo { #[cfg(not(target_arch = "wasm32"))] ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] - NativeConnectionInfo::Postgres(_) => None, + NativeConnectionInfo::Postgres(_) | NativeConnectionInfo::PostgresWs(_) => None, #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(_) => None, #[cfg(feature = "mssql")] @@ -209,7 +217,7 @@ impl ConnectionInfo { #[cfg(not(target_arch = "wasm32"))] ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] - NativeConnectionInfo::Postgres(_) => SqlFamily::Postgres, + NativeConnectionInfo::Postgres(_) | NativeConnectionInfo::PostgresWs(_) => SqlFamily::Postgres, #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(_) => SqlFamily::Mysql, #[cfg(feature = "mssql")] @@ -228,6 +236,8 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => Some(url.port()), + #[cfg(feature = "postgresql")] + NativeConnectionInfo::PostgresWs(url) => Some(url.port()), #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => Some(url.port()), #[cfg(feature = "mssql")] @@ -256,6 +266,8 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => format!("{}:{}", url.host(), url.port()), + #[cfg(feature = "postgresql")] + NativeConnectionInfo::PostgresWs(url) => format!("{}:{}", url.host(), url.port()), #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => format!("{}:{}", url.host(), url.port()), #[cfg(feature = "mssql")] diff --git a/quaint/src/connector/native.rs b/quaint/src/connector/native.rs index d70f710da8d8..f4eba5e1d26b 100644 --- a/quaint/src/connector/native.rs +++ b/quaint/src/connector/native.rs @@ -3,7 +3,7 @@ use crate::connector::MssqlUrl; #[cfg(feature = "mysql")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql")] -use crate::connector::PostgresUrl; +use crate::connector::{PostgresUrl, PostgresWebSocketUrl}; /// General information about a SQL connection, provided by native Rust drivers. #[cfg(not(target_arch = "wasm32"))] @@ -12,6 +12,8 @@ pub enum NativeConnectionInfo { /// A PostgreSQL connection URL. #[cfg(feature = "postgresql")] Postgres(PostgresUrl), + #[cfg(feature = "postgresql")] + PostgresWs(PostgresWebSocketUrl), /// A MySQL connection URL. #[cfg(feature = "mysql")] Mysql(MysqlUrl), diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 805ba13a6021..f6f821817a24 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod column_type; mod conversion; mod error; mod explain; +mod websocket; pub(crate) use crate::connector::postgres::url::PostgresUrl; use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; @@ -37,12 +38,15 @@ use std::{ time::Duration, }; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; +use websocket::connect_via_websocket; /// The underlying postgres driver. Only available with the `expose-drivers` /// Cargo feature. #[cfg(feature = "expose-drivers")] pub use tokio_postgres; +use super::PostgresWebSocketUrl; + struct PostgresClient(Client); impl Debug for PostgresClient { @@ -292,6 +296,21 @@ impl PostgreSql { }) } + /// Create a new websocket connection to managed database + pub async fn new_with_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let client = connect_via_websocket(url).await?; + + Ok(Self { + client: PostgresClient(client), + socket_timeout: None, + pg_bouncer: false, + statement_cache: Mutex::new(StatementCache::new(0)), + is_healthy: AtomicBool::new(true), + is_cockroachdb: false, + is_materialize: false, + }) + } + /// The underlying tokio_postgres::Client. Only available with the /// `expose-drivers` Cargo feature. This is a lower level API when you need /// to get into database specific features. diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs new file mode 100644 index 000000000000..fba6742deccf --- /dev/null +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -0,0 +1,68 @@ +use std::str::FromStr; + +use async_tungstenite::{ + tokio::connect_async, + tungstenite::{self, client::IntoClientRequest, http::HeaderValue, Error as TungsteniteError}, +}; +use futures::FutureExt; +use tokio_postgres::{Client, Config, NoTls}; +use ws_stream_tungstenite::WsStream; + +use crate::{ + connector::PostgresWebSocketUrl, + error::{self, Error, ErrorKind, NativeErrorKind}, +}; + +const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; + +pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let (ws_stream, response) = connect_async(url).await.inspect_err(|e| { + eprintln!("{}", e); + })?; + + let Some(header) = response.headers().get(CONNECTION_PARAMS_HEADER) else { + let message = format!("Missing response header {CONNECTION_PARAMS_HEADER}"); + let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build(); + return Err(error); + }; + + let connection_params = header.to_str().map_err(|inner| { + Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build() + })?; + + let config = Config::from_str(connection_params)?; + let ws_byte_stream = WsStream::new(ws_stream); + + let (client, connection) = config.connect_raw(ws_byte_stream, NoTls).await?; + tokio::spawn(connection.map(|r| match r { + Ok(_) => (), + Err(e) => { + tracing::error!("Error in PostgreSQL connection: {:?}", e); + } + })); + Ok(client) +} + +impl IntoClientRequest for PostgresWebSocketUrl { + fn into_client_request(self) -> tungstenite::Result { + let mut request = self.url.to_string().into_client_request()?; + let bearer = format!("Bearer {}", self.api_key()); + let auth_header = HeaderValue::from_str(&bearer)?; + request.headers_mut().insert("Authorization", auth_header); + Ok(request) + } +} + +impl From for error::Error { + fn from(value: TungsteniteError) -> Self { + let builder = match value { + TungsteniteError::Tls(tls_error) => Error::builder(ErrorKind::Native(NativeErrorKind::TlsError { + message: tls_error.to_string(), + })), + + _ => Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(value)))), + }; + + builder.build() + } +} diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 844da48c8d66..8ee12320df1f 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -126,10 +126,7 @@ impl PostgresUrl { /// Name of the database connected. Defaults to `postgres`. pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } + dbname(&self.url) } /// The percent-decoded database password. @@ -431,6 +428,40 @@ pub(crate) struct PostgresUrlQueryParams { pub(crate) ssl_mode: SslMode, } +#[derive(Debug, Clone)] +pub struct PostgresWebSocketUrl { + pub(crate) url: Url, + pub(crate) api_key: String, +} + +impl PostgresWebSocketUrl { + pub fn new(url: Url, api_key: String) -> Self { + Self { url, api_key } + } + + pub fn api_key(&self) -> &str { + &self.api_key + } + + pub fn host(&self) -> &str { + self.url.host_str().unwrap_or("localhost") + } + + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(80) + } + + pub fn dbname(&self) -> &str { + dbname(&self.url) + } +} + +fn dbname(url: &Url) -> &str { + match url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } +} #[cfg(test)] mod tests { use super::*; diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index dca3b89f6f2c..4762427d8247 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -5,12 +5,17 @@ use self::connection::*; use crate::SqlFlavour; use enumflags2::BitFlags; use indoc::indoc; -use quaint::{connector::PostgresUrl, Value}; +use once_cell::sync::Lazy; +use quaint::{ + connector::{PostgresUrl, PostgresWebSocketUrl}, + prelude::NativeConnectionInfo, + Value, +}; use schema_connector::{ migrations_directory::MigrationDirectory, BoxFuture, ConnectorError, ConnectorParams, ConnectorResult, Namespaces, }; use sql_schema_describer::SqlSchema; -use std::{borrow::Cow, collections::HashMap, future, time}; +use std::{borrow::Cow, collections::HashMap, future, str::FromStr, time}; use url::Url; use user_facing_errors::{ common::{DatabaseAccessDenied, DatabaseDoesNotExist}, @@ -28,9 +33,79 @@ SET enable_experimental_alter_column_type_general = true; type State = super::State, Connection)>; +#[derive(Clone)] +enum MigratePostgresUrl { + Native(PostgresUrl), + WebSocket(PostgresWebSocketUrl), +} + +static MIGRATE_WS_BASE_URL: Lazy> = Lazy::new(|| { + std::env::var("PRISMA_SCHEMA_ENGINE_WS_BASE_URL") + .map(Cow::Owned) + .unwrap_or_else(|_| Cow::Borrowed("wss://migrations.prisma.io")) +}); + +impl MigratePostgresUrl { + const WEBSOCKET_SCHEME: &'static str = "prisma+postgres"; + const API_KEY_PARAM: &'static str = "apiKey"; + + fn new(url: Url) -> ConnectorResult { + if url.scheme() == Self::WEBSOCKET_SCHEME { + let mut ws_url = Url::from_str(&MIGRATE_WS_BASE_URL).map_err(ConnectorError::url_parse_error)?; + ws_url.set_path(url.path()); + let Some((_, api_key)) = url.query_pairs().find(|(name, _)| name == Self::API_KEY_PARAM) else { + return Err(ConnectorError::url_parse_error( + "Required `apiKey` query string parameter was not provided in a connection URL", + )); + }; + Ok(Self::WebSocket(PostgresWebSocketUrl::new(ws_url, api_key.into_owned()))) + } else { + let postgres_url = PostgresUrl::new(url).map_err(ConnectorError::url_parse_error)?; + Ok(Self::Native(postgres_url)) + } + } + + pub(super) fn host(&self) -> &str { + match self { + MigratePostgresUrl::Native(native_url) => native_url.host(), + MigratePostgresUrl::WebSocket(ws_url) => ws_url.host(), + } + } + + pub(super) fn port(&self) -> u16 { + match self { + MigratePostgresUrl::Native(native_url) => native_url.port(), + MigratePostgresUrl::WebSocket(ws_url) => ws_url.port(), + } + } + + pub(super) fn dbname(&self) -> &str { + match self { + MigratePostgresUrl::Native(native_url) => native_url.dbname(), + MigratePostgresUrl::WebSocket(ws_url) => ws_url.dbname(), + } + } + + pub(super) fn schema(&self) -> &str { + match self { + MigratePostgresUrl::Native(native_url) => native_url.schema(), + MigratePostgresUrl::WebSocket(_) => "public", + } + } +} + +impl From for NativeConnectionInfo { + fn from(value: MigratePostgresUrl) -> Self { + match value { + MigratePostgresUrl::Native(url) => NativeConnectionInfo::Postgres(url), + MigratePostgresUrl::WebSocket(url) => NativeConnectionInfo::PostgresWs(url), + } + } +} + struct Params { connector_params: ConnectorParams, - url: PostgresUrl, + url: MigratePostgresUrl, } /// The specific provider that was requested by the user. @@ -103,7 +178,13 @@ impl PostgresFlavour { } pub(crate) fn schema_name(&self) -> &str { - self.state.params().map(|p| p.url.schema()).unwrap_or("public") + self.state + .params() + .and_then(|p| match &p.url { + MigratePostgresUrl::Native(url) => Some(url.schema()), + MigratePostgresUrl::WebSocket(_) => None, + }) + .unwrap_or("public") } } @@ -378,7 +459,7 @@ impl SqlFlavour for PostgresFlavour { .map_err(ConnectorError::url_parse_error)?; disable_postgres_statement_cache(&mut url)?; let connection_string = url.to_string(); - let url = PostgresUrl::new(url).map_err(ConnectorError::url_parse_error)?; + let url = MigratePostgresUrl::new(url)?; connector_params.connection_string = connection_string; let params = Params { connector_params, url }; self.state.set_params(params); @@ -510,7 +591,11 @@ impl SqlFlavour for PostgresFlavour { /// TL;DR, /// 1. pg >= 13 -> it works. /// 2. pg < 13 -> syntax error on WITH (FORCE), and then fail with db in use if pgbouncer is used. -async fn drop_db_try_force(conn: &mut Connection, url: &PostgresUrl, database_name: &str) -> ConnectorResult<()> { +async fn drop_db_try_force( + conn: &mut Connection, + url: &MigratePostgresUrl, + database_name: &str, +) -> ConnectorResult<()> { let drop_database = format!("DROP DATABASE IF EXISTS \"{database_name}\" WITH (FORCE)"); if let Err(err) = conn.raw_cmd(&drop_database, url).await { if let Some(msg) = err.message() { @@ -537,7 +622,7 @@ fn strip_schema_param_from_url(url: &mut Url) { /// Try to connect as an admin to a postgres database. We try to pick a default database from which /// we can create another database. -async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection, PostgresUrl)> { +async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection, MigratePostgresUrl)> { // "postgres" is the default database on most postgres installations, // "template1" is guaranteed to exist, and "defaultdb" is the only working // option on DigitalOcean managed postgres databases. @@ -547,7 +632,7 @@ async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection for database_name in CANDIDATE_DEFAULT_DATABASES { url.set_path(&format!("/{database_name}")); - let postgres_url = PostgresUrl::new(url.clone()).unwrap(); + let postgres_url = MigratePostgresUrl::Native(PostgresUrl::new(url.clone()).unwrap()); match Connection::new(url.clone()).await { // If the database does not exist, try the next one. Err(err) => match &err.error_code() { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index 3ca9b673b0a0..c36161546cb4 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -4,8 +4,8 @@ use enumflags2::BitFlags; use indoc::indoc; use psl::PreviewFeature; use quaint::{ - connector::{self, tokio_postgres::error::ErrorPosition, PostgresUrl}, - prelude::{ConnectionInfo, NativeConnectionInfo, Queryable}, + connector::{self, tokio_postgres::error::ErrorPosition}, + prelude::{ConnectionInfo, Queryable}, }; use schema_connector::{ConnectorError, ConnectorResult, Namespaces}; use sql_schema_describer::{postgres::PostgresSchemaExt, SqlSchema}; @@ -13,19 +13,22 @@ use user_facing_errors::{schema_engine::ApplyMigrationError, schema_engine::Data use crate::sql_renderer::IteratorJoin; +use super::MigratePostgresUrl; + pub(super) struct Connection(connector::PostgreSql); impl Connection { pub(super) async fn new(url: url::Url) -> ConnectorResult { - let url = PostgresUrl::new(url).map_err(|err| { - ConnectorError::user_facing(user_facing_errors::common::InvalidConnectionString { - details: err.to_string(), - }) - })?; + let url = MigratePostgresUrl::new(url)?; - let quaint = connector::PostgreSql::new(url.clone()) - .await - .map_err(quaint_err(&url))?; + let quaint = match url { + MigratePostgresUrl::Native(ref native_url) => connector::PostgreSql::new(native_url.clone()) + .await + .map_err(quaint_err(&url))?, + MigratePostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()) + .await + .map_err(quaint_err(&url))?, + }; let version = quaint.version().await.map_err(quaint_err(&url))?; @@ -116,12 +119,12 @@ impl Connection { Ok(schema) } - pub(super) async fn raw_cmd(&mut self, sql: &str, url: &PostgresUrl) -> ConnectorResult<()> { + pub(super) async fn raw_cmd(&mut self, sql: &str, url: &MigratePostgresUrl) -> ConnectorResult<()> { tracing::debug!(query_type = "raw_cmd", sql); self.0.raw_cmd(sql).await.map_err(quaint_err(url)) } - pub(super) async fn version(&mut self, url: &PostgresUrl) -> ConnectorResult> { + pub(super) async fn version(&mut self, url: &MigratePostgresUrl) -> ConnectorResult> { tracing::debug!(query_type = "version"); self.0.version().await.map_err(quaint_err(url)) } @@ -129,7 +132,7 @@ impl Connection { pub(super) async fn query( &mut self, query: quaint::ast::Query<'_>, - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { use quaint::visitor::Visitor; let (sql, params) = quaint::visitor::Postgres::build(query).unwrap(); @@ -140,7 +143,7 @@ impl Connection { &self, sql: &str, params: &[quaint::prelude::Value<'_>], - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { tracing::debug!(query_type = "query_raw", sql, ?params); self.0.query_raw(sql, params).await.map_err(quaint_err(url)) @@ -149,7 +152,7 @@ impl Connection { pub(super) async fn describe_query( &self, sql: &str, - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { tracing::debug!(query_type = "describe_query", sql); self.0.describe_query(sql).await.map_err(quaint_err(url)) @@ -237,11 +240,6 @@ fn normalize_sql_schema(schema: &mut SqlSchema, preview_features: BitFlags impl (Fn(quaint::error::Error) -> ConnectorError) + '_ { - |err| { - crate::flavour::quaint_error_to_connector_error( - err, - &ConnectionInfo::Native(NativeConnectionInfo::Postgres(url.clone())), - ) - } +fn quaint_err(url: &MigratePostgresUrl) -> impl (Fn(quaint::error::Error) -> ConnectorError) + '_ { + |err| crate::flavour::quaint_error_to_connector_error(err, &ConnectionInfo::Native(url.clone().into())) } diff --git a/schema-engine/core/src/lib.rs b/schema-engine/core/src/lib.rs index b367ab0bfff9..3c0a2bf6d6a1 100644 --- a/schema-engine/core/src/lib.rs +++ b/schema-engine/core/src/lib.rs @@ -41,7 +41,7 @@ fn connector_for_connection_string( preview_features: BitFlags, ) -> CoreResult> { match connection_string.split(':').next() { - Some("postgres") | Some("postgresql") => { + Some("postgres") | Some("postgresql") | Some("prisma+postgres") => { let params = ConnectorParams { connection_string, preview_features, diff --git a/schema-engine/sql-migration-tests/src/multi_engine_test_api.rs b/schema-engine/sql-migration-tests/src/multi_engine_test_api.rs index ca1e807a46b9..93ac1d511890 100644 --- a/schema-engine/sql-migration-tests/src/multi_engine_test_api.rs +++ b/schema-engine/sql-migration-tests/src/multi_engine_test_api.rs @@ -196,7 +196,7 @@ impl TestApi { }; let mut connector = match &connection_info { - ConnectionInfo::Native(NativeConnectionInfo::Postgres(_)) => { + ConnectionInfo::Native(NativeConnectionInfo::Postgres(_) | NativeConnectionInfo::PostgresWs(_)) => { if self.args.provider() == "cockroachdb" { SqlSchemaConnector::new_cockroach() } else { From 0d495ae9180553eaffd1caabaf54382b6188cdd6 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 2 Oct 2024 17:44:07 +0200 Subject: [PATCH 02/15] Avoid panics in error handling code --- libs/user-facing-errors/src/quaint.rs | 29 ----- quaint/src/connector/connection_info.rs | 23 ++-- quaint/src/connector/native.rs | 4 +- quaint/src/connector/postgres/native/mod.rs | 16 +-- quaint/src/connector/postgres/url.rs | 105 ++++++++++++++---- quaint/src/pooled.rs | 8 +- quaint/src/pooled/manager.rs | 4 +- quaint/src/single.rs | 2 +- .../qe-setup/src/cockroachdb.rs | 2 +- .../qe-setup/src/postgres.rs | 2 +- .../src/flavour/postgres.rs | 59 +++------- .../src/flavour/postgres/connection.rs | 8 +- .../src/multi_engine_test_api.rs | 2 +- 13 files changed, 132 insertions(+), 132 deletions(-) diff --git a/libs/user-facing-errors/src/quaint.rs b/libs/user-facing-errors/src/quaint.rs index 465f1960fef3..9c5d72432edc 100644 --- a/libs/user-facing-errors/src/quaint.rs +++ b/libs/user-facing-errors/src/quaint.rs @@ -59,14 +59,6 @@ pub fn render_quaint_error(kind: &ErrorKind, connection_info: &ConnectionInfo) - database_port: url.port(), })) } - #[cfg(feature = "postgresql-native")] - ConnectionInfo::Native(NativeConnectionInfo::PostgresWs(url)) => { - Some(KnownError::new(common::DatabaseDoesNotExist::Postgres { - database_name: db_name.to_string(), - database_host: url.host().to_owned(), - database_port: url.port(), - })) - } #[cfg(feature = "mysql-native")] ConnectionInfo::Native(NativeConnectionInfo::Mysql(url)) => { Some(KnownError::new(common::DatabaseDoesNotExist::Mysql { @@ -95,12 +87,6 @@ pub fn render_quaint_error(kind: &ErrorKind, connection_info: &ConnectionInfo) - database_name: format!("{}.{}", url.dbname(), url.schema()), })) } - ConnectionInfo::Native(NativeConnectionInfo::PostgresWs(url)) => { - Some(KnownError::new(common::DatabaseAccessDenied { - database_user: "".to_owned(), - database_name: url.dbname().to_owned(), - })) - } ConnectionInfo::Native(NativeConnectionInfo::Mysql(url)) => { Some(KnownError::new(common::DatabaseAccessDenied { database_user: url.username().into_owned(), @@ -121,14 +107,6 @@ pub fn render_quaint_error(kind: &ErrorKind, connection_info: &ConnectionInfo) - database_port: url.port(), })) } - #[cfg(feature = "postgresql-native")] - ConnectionInfo::Native(NativeConnectionInfo::PostgresWs(url)) => { - Some(KnownError::new(common::DatabaseAlreadyExists { - database_name: format!("{db_name}"), - database_host: url.host().to_owned(), - database_port: url.port(), - })) - } #[cfg(feature = "mysql-native")] ConnectionInfo::Native(NativeConnectionInfo::Mysql(url)) => { Some(KnownError::new(common::DatabaseAlreadyExists { @@ -279,13 +257,6 @@ pub fn render_quaint_error(kind: &ErrorKind, connection_info: &ConnectionInfo) - database_host: url.host().to_owned(), })) } - #[cfg(feature = "postgresql-native")] - (NativeErrorKind::ConnectionError(_), ConnectionInfo::Native(NativeConnectionInfo::PostgresWs(url))) => { - Some(KnownError::new(common::DatabaseNotReachable { - database_port: url.port(), - database_host: url.host().to_owned(), - })) - } #[cfg(feature = "mysql-native")] (NativeErrorKind::ConnectionError(_), ConnectionInfo::Native(NativeConnectionInfo::Mysql(url))) => { Some(KnownError::new(common::DatabaseNotReachable { diff --git a/quaint/src/connector/connection_info.rs b/quaint/src/connector/connection_info.rs index 5c2818f4ae6e..2284aeb40902 100644 --- a/quaint/src/connector/connection_info.rs +++ b/quaint/src/connector/connection_info.rs @@ -37,6 +37,8 @@ impl ConnectionInfo { /// database. #[cfg(not(target_arch = "wasm32"))] pub fn from_url(url_str: &str) -> crate::Result { + use super::PostgresUrl; + let url_result: Result = url_str.parse(); // Non-URL database strings are interpreted as SQLite file paths. @@ -84,7 +86,8 @@ impl ConnectionInfo { } #[cfg(feature = "postgresql")] SqlFamily::Postgres => Ok(ConnectionInfo::Native(NativeConnectionInfo::Postgres( - PostgresUrl::new(url)?, + // TODO + PostgresUrl::new_native(url)?, ))), #[allow(unreachable_patterns)] _ => unreachable!(), @@ -98,8 +101,6 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => Some(url.dbname()), - #[cfg(feature = "postgresql")] - NativeConnectionInfo::PostgresWs(url) => Some(url.dbname()), #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => Some(url.dbname()), #[cfg(feature = "mssql")] @@ -122,8 +123,6 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => url.schema(), - #[cfg(feature = "postgresql")] - NativeConnectionInfo::PostgresWs(_) => "public", #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => url.dbname(), #[cfg(feature = "mssql")] @@ -144,8 +143,6 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => url.host(), - #[cfg(feature = "postgresql")] - NativeConnectionInfo::PostgresWs(url) => url.host(), #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => url.host(), #[cfg(feature = "mssql")] @@ -165,8 +162,6 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => Some(url.username()), - #[cfg(feature = "postgresql")] - NativeConnectionInfo::PostgresWs(_) => None, #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => Some(url.username()), #[cfg(feature = "mssql")] @@ -184,7 +179,7 @@ impl ConnectionInfo { #[cfg(not(target_arch = "wasm32"))] ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] - NativeConnectionInfo::Postgres(_) | NativeConnectionInfo::PostgresWs(_) => None, + NativeConnectionInfo::Postgres(_) => None, #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(_) => None, #[cfg(feature = "mssql")] @@ -217,7 +212,7 @@ impl ConnectionInfo { #[cfg(not(target_arch = "wasm32"))] ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] - NativeConnectionInfo::Postgres(_) | NativeConnectionInfo::PostgresWs(_) => SqlFamily::Postgres, + NativeConnectionInfo::Postgres(_) => SqlFamily::Postgres, #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(_) => SqlFamily::Mysql, #[cfg(feature = "mssql")] @@ -236,8 +231,6 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => Some(url.port()), - #[cfg(feature = "postgresql")] - NativeConnectionInfo::PostgresWs(url) => Some(url.port()), #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => Some(url.port()), #[cfg(feature = "mssql")] @@ -253,7 +246,7 @@ impl ConnectionInfo { pub fn pg_bouncer(&self) -> bool { match self { #[cfg(all(not(target_arch = "wasm32"), feature = "postgresql"))] - ConnectionInfo::Native(NativeConnectionInfo::Postgres(url)) => url.pg_bouncer(), + ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(url))) => url.pg_bouncer(), _ => false, } } @@ -266,8 +259,6 @@ impl ConnectionInfo { ConnectionInfo::Native(info) => match info { #[cfg(feature = "postgresql")] NativeConnectionInfo::Postgres(url) => format!("{}:{}", url.host(), url.port()), - #[cfg(feature = "postgresql")] - NativeConnectionInfo::PostgresWs(url) => format!("{}:{}", url.host(), url.port()), #[cfg(feature = "mysql")] NativeConnectionInfo::Mysql(url) => format!("{}:{}", url.host(), url.port()), #[cfg(feature = "mssql")] diff --git a/quaint/src/connector/native.rs b/quaint/src/connector/native.rs index f4eba5e1d26b..d70f710da8d8 100644 --- a/quaint/src/connector/native.rs +++ b/quaint/src/connector/native.rs @@ -3,7 +3,7 @@ use crate::connector::MssqlUrl; #[cfg(feature = "mysql")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql")] -use crate::connector::{PostgresUrl, PostgresWebSocketUrl}; +use crate::connector::PostgresUrl; /// General information about a SQL connection, provided by native Rust drivers. #[cfg(not(target_arch = "wasm32"))] @@ -12,8 +12,6 @@ pub enum NativeConnectionInfo { /// A PostgreSQL connection URL. #[cfg(feature = "postgresql")] Postgres(PostgresUrl), - #[cfg(feature = "postgresql")] - PostgresWs(PostgresWebSocketUrl), /// A MySQL connection URL. #[cfg(feature = "mysql")] Mysql(MysqlUrl), diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index f6f821817a24..ad53908383fb 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -7,7 +7,7 @@ mod error; mod explain; mod websocket; -pub(crate) use crate::connector::postgres::url::PostgresUrl; +pub(crate) use crate::connector::postgres::url::PostgresNativeUrl; use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; use crate::connector::{ timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel, Transaction, @@ -164,7 +164,7 @@ impl SslParams { } } -impl PostgresUrl { +impl PostgresNativeUrl { pub(crate) fn cache(&self) -> StatementCache { if self.query_params.pg_bouncer { StatementCache::new(0) @@ -232,7 +232,7 @@ impl PostgresUrl { impl PostgreSql { /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { + pub async fn new(url: PostgresNativeUrl) -> crate::Result { let config = url.to_config(); let mut tls_builder = TlsConnector::builder(); @@ -941,7 +941,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let client = PostgreSql::new(pg_url).await.unwrap(); @@ -993,7 +993,7 @@ mod tests { url.query_pairs_mut().append_pair("schema", schema_name); url.query_pairs_mut().append_pair("pbbouncer", "true"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let client = PostgreSql::new(pg_url).await.unwrap(); @@ -1044,7 +1044,7 @@ mod tests { let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let client = PostgreSql::new(pg_url).await.unwrap(); @@ -1095,7 +1095,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); let client = PostgreSql::new(pg_url).await.unwrap(); @@ -1146,7 +1146,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); let client = PostgreSql::new(pg_url).await.unwrap(); diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 8ee12320df1f..1095bb59fca0 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -63,16 +63,74 @@ impl PostgresFlavour { } } +#[derive(Debug, Clone)] +pub enum PostgresUrl { + Native(Box), + WebSocket(PostgresWebSocketUrl), +} + +impl PostgresUrl { + pub fn new_native(url: Url) -> Result { + Ok(Self::Native(Box::new(PostgresNativeUrl::new(url)?))) + } + + pub fn new_websocket(url: Url, api_key: String) -> Result { + Ok(Self::WebSocket(PostgresWebSocketUrl::new(url, api_key))) + } + + pub fn dbname(&self) -> &str { + match self { + Self::Native(url) => url.dbname(), + Self::WebSocket(url) => url.dbname(), + } + } + + pub fn host(&self) -> &str { + match self { + Self::Native(native_url) => native_url.host(), + Self::WebSocket(ws_url) => ws_url.host(), + } + } + + pub fn port(&self) -> u16 { + match self { + Self::Native(native_url) => native_url.port(), + Self::WebSocket(ws_url) => ws_url.port(), + } + } + + pub fn username(&self) -> Cow<'_, str> { + match self { + Self::Native(native_url) => native_url.username(), + Self::WebSocket(_) => Cow::Borrowed(""), + } + } + + pub fn schema(&self) -> &str { + match self { + Self::Native(native_url) => native_url.schema(), + Self::WebSocket(_) => "public", + } + } + + pub fn socket_timeout(&self) -> Option { + match self { + Self::Native(native_url) => native_url.socket_timeout(), + Self::WebSocket(_) => None, + } + } +} + /// Wraps a connection url and exposes the parsing logic used by Quaint, /// including default values. #[derive(Debug, Clone)] -pub struct PostgresUrl { +pub struct PostgresNativeUrl { pub(crate) url: Url, pub(crate) query_params: PostgresUrlQueryParams, pub(crate) flavour: PostgresFlavour, } -impl PostgresUrl { +impl PostgresNativeUrl { /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection /// parameters. pub fn new(url: Url) -> Result { @@ -473,14 +531,15 @@ mod tests { #[test] fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("/var/run/psql.sock", url.host()); } #[test] fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("/var/run/postgresql", url.host()); } @@ -488,63 +547,69 @@ mod tests { #[test] fn should_allow_changing_of_cache_size() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()) + .unwrap(); assert_eq!(420, url.cache().capacity()); } #[test] fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); assert_eq!(100, url.cache().capacity()); } #[test] fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()) + .unwrap(); assert_eq!(Some("test"), url.application_name()); } #[test] fn should_have_channel_binding() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()) + .unwrap(); assert_eq!(ChannelBinding::Require, url.channel_binding()); } #[test] fn should_have_default_channel_binding() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()) + .unwrap(); assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); assert_eq!(ChannelBinding::Prefer, url.channel_binding()); } #[test] fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); assert_eq!(0, url.cache().capacity()); } #[test] fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("localhost", url.host()); } #[test] fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); assert_eq!("2001:db8:1234::ffff", url.host()); } #[test] fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); assert_eq!("--cluster=my_cluster", url.options().unwrap()); } @@ -631,7 +696,7 @@ mod tests { url.query_pairs_mut().append_pair("schema", "hello"); url.query_pairs_mut().append_pair("pgbouncer", "true"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let config = pg_url.to_config(); @@ -647,7 +712,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "hello"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let config = pg_url.to_config(); @@ -661,7 +726,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "hello"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let config = pg_url.to_config(); @@ -675,7 +740,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "HeLLo"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let config = pg_url.to_config(); diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 381f0c824149..2026679cd480 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -307,8 +307,10 @@ impl Builder { /// - Defaults to `PostgresFlavour::Unknown`. #[cfg(feature = "postgresql-native")] pub fn set_postgres_flavour(&mut self, flavour: crate::connector::PostgresFlavour) { - use crate::connector::NativeConnectionInfo; - if let ConnectionInfo::Native(NativeConnectionInfo::Postgres(ref mut url)) = self.connection_info { + use crate::connector::{NativeConnectionInfo, PostgresUrl}; + if let ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(ref mut url))) = + self.connection_info + { url.set_flavour(flavour); } @@ -415,7 +417,7 @@ impl Quaint { } #[cfg(feature = "postgresql")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { - let url = crate::connector::PostgresUrl::new(url::Url::parse(s)?)?; + let url = crate::connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; let connection_limit = url.connection_limit(); let pool_timeout = url.pool_timeout(); let max_connection_lifetime = url.max_connection_lifetime(); diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 7533dffcfcc5..0a2fc0adbd03 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -3,7 +3,7 @@ use crate::connector::MssqlUrl; #[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql-native")] -use crate::connector::PostgresUrl; +use crate::connector::PostgresNativeUrl; use crate::{ ast, connector::{self, impl_default_TransactionCapable, IsolationLevel, Queryable, Transaction, TransactionCapable}, @@ -85,7 +85,7 @@ pub enum QuaintManager { Mysql { url: MysqlUrl }, #[cfg(feature = "postgresql")] - Postgres { url: PostgresUrl }, + Postgres { url: PostgresNativeUrl }, #[cfg(feature = "sqlite")] Sqlite { url: String, db_name: String }, diff --git a/quaint/src/single.rs b/quaint/src/single.rs index cbf460c41509..fd018925852d 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -148,7 +148,7 @@ impl Quaint { } #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { - let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; + let url = connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs b/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs index 0d876d6b4dcf..c901b9ef3887 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs @@ -5,7 +5,7 @@ use url::Url; pub(crate) async fn cockroach_setup(url: String, prisma_schema: &str) -> ConnectorResult<()> { let mut parsed_url = Url::parse(&url).map_err(ConnectorError::url_parse_error)?; - let mut quaint_url = quaint::connector::PostgresUrl::new(parsed_url.clone()).unwrap(); + let mut quaint_url = quaint::connector::PostgresNativeUrl::new(parsed_url.clone()).unwrap(); quaint_url.set_flavour(PostgresFlavour::Cockroach); let db_name = quaint_url.dbname(); diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs b/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs index 6bbba8564cae..536f51eb4834 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs @@ -5,7 +5,7 @@ use url::Url; pub(crate) async fn postgres_setup(url: String, prisma_schema: &str, db_schemas: &[&str]) -> ConnectorResult<()> { let mut parsed_url = Url::parse(&url).map_err(ConnectorError::url_parse_error)?; - let mut quaint_url = quaint::connector::PostgresUrl::new(parsed_url.clone()).unwrap(); + let mut quaint_url = quaint::connector::PostgresNativeUrl::new(parsed_url.clone()).unwrap(); quaint_url.set_flavour(PostgresFlavour::Postgres); let (db_name, schema) = (quaint_url.dbname(), quaint_url.schema()); diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index 4762427d8247..9c3feef989b0 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -6,11 +6,7 @@ use crate::SqlFlavour; use enumflags2::BitFlags; use indoc::indoc; use once_cell::sync::Lazy; -use quaint::{ - connector::{PostgresUrl, PostgresWebSocketUrl}, - prelude::NativeConnectionInfo, - Value, -}; +use quaint::{connector::PostgresUrl, prelude::NativeConnectionInfo, Value}; use schema_connector::{ migrations_directory::MigrationDirectory, BoxFuture, ConnectorError, ConnectorParams, ConnectorResult, Namespaces, }; @@ -33,11 +29,8 @@ SET enable_experimental_alter_column_type_general = true; type State = super::State, Connection)>; -#[derive(Clone)] -enum MigratePostgresUrl { - Native(PostgresUrl), - WebSocket(PostgresWebSocketUrl), -} +#[derive(Debug, Clone)] +struct MigratePostgresUrl(PostgresUrl); static MIGRATE_WS_BASE_URL: Lazy> = Lazy::new(|| { std::env::var("PRISMA_SCHEMA_ENGINE_WS_BASE_URL") @@ -50,7 +43,7 @@ impl MigratePostgresUrl { const API_KEY_PARAM: &'static str = "apiKey"; fn new(url: Url) -> ConnectorResult { - if url.scheme() == Self::WEBSOCKET_SCHEME { + let postgres_url = if url.scheme() == Self::WEBSOCKET_SCHEME { let mut ws_url = Url::from_str(&MIGRATE_WS_BASE_URL).map_err(ConnectorError::url_parse_error)?; ws_url.set_path(url.path()); let Some((_, api_key)) = url.query_pairs().find(|(name, _)| name == Self::API_KEY_PARAM) else { @@ -58,48 +51,34 @@ impl MigratePostgresUrl { "Required `apiKey` query string parameter was not provided in a connection URL", )); }; - Ok(Self::WebSocket(PostgresWebSocketUrl::new(ws_url, api_key.into_owned()))) + PostgresUrl::new_websocket(ws_url, api_key.into_owned()).map_err(ConnectorError::url_parse_error)? } else { - let postgres_url = PostgresUrl::new(url).map_err(ConnectorError::url_parse_error)?; - Ok(Self::Native(postgres_url)) - } + PostgresUrl::new_native(url).map_err(ConnectorError::url_parse_error)? + }; + + Ok(Self(postgres_url)) } pub(super) fn host(&self) -> &str { - match self { - MigratePostgresUrl::Native(native_url) => native_url.host(), - MigratePostgresUrl::WebSocket(ws_url) => ws_url.host(), - } + self.0.host() } pub(super) fn port(&self) -> u16 { - match self { - MigratePostgresUrl::Native(native_url) => native_url.port(), - MigratePostgresUrl::WebSocket(ws_url) => ws_url.port(), - } + self.0.port() } pub(super) fn dbname(&self) -> &str { - match self { - MigratePostgresUrl::Native(native_url) => native_url.dbname(), - MigratePostgresUrl::WebSocket(ws_url) => ws_url.dbname(), - } + self.0.dbname() } pub(super) fn schema(&self) -> &str { - match self { - MigratePostgresUrl::Native(native_url) => native_url.schema(), - MigratePostgresUrl::WebSocket(_) => "public", - } + self.0.schema() } } impl From for NativeConnectionInfo { fn from(value: MigratePostgresUrl) -> Self { - match value { - MigratePostgresUrl::Native(url) => NativeConnectionInfo::Postgres(url), - MigratePostgresUrl::WebSocket(url) => NativeConnectionInfo::PostgresWs(url), - } + NativeConnectionInfo::Postgres(value.0) } } @@ -178,13 +157,7 @@ impl PostgresFlavour { } pub(crate) fn schema_name(&self) -> &str { - self.state - .params() - .and_then(|p| match &p.url { - MigratePostgresUrl::Native(url) => Some(url.schema()), - MigratePostgresUrl::WebSocket(_) => None, - }) - .unwrap_or("public") + self.state.params().map(|p| p.url.schema()).unwrap_or("public") } } @@ -632,7 +605,7 @@ async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection for database_name in CANDIDATE_DEFAULT_DATABASES { url.set_path(&format!("/{database_name}")); - let postgres_url = MigratePostgresUrl::Native(PostgresUrl::new(url.clone()).unwrap()); + let postgres_url = MigratePostgresUrl(PostgresUrl::new_native(url.clone()).unwrap()); match Connection::new(url.clone()).await { // If the database does not exist, try the next one. Err(err) => match &err.error_code() { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index c36161546cb4..122d77eef257 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -4,7 +4,7 @@ use enumflags2::BitFlags; use indoc::indoc; use psl::PreviewFeature; use quaint::{ - connector::{self, tokio_postgres::error::ErrorPosition}, + connector::{self, tokio_postgres::error::ErrorPosition, PostgresUrl}, prelude::{ConnectionInfo, Queryable}, }; use schema_connector::{ConnectorError, ConnectorResult, Namespaces}; @@ -21,11 +21,11 @@ impl Connection { pub(super) async fn new(url: url::Url) -> ConnectorResult { let url = MigratePostgresUrl::new(url)?; - let quaint = match url { - MigratePostgresUrl::Native(ref native_url) => connector::PostgreSql::new(native_url.clone()) + let quaint = match url.0 { + PostgresUrl::Native(ref native_url) => connector::PostgreSql::new(native_url.as_ref().clone()) .await .map_err(quaint_err(&url))?, - MigratePostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()) + PostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()) .await .map_err(quaint_err(&url))?, }; diff --git a/schema-engine/sql-migration-tests/src/multi_engine_test_api.rs b/schema-engine/sql-migration-tests/src/multi_engine_test_api.rs index 93ac1d511890..ca1e807a46b9 100644 --- a/schema-engine/sql-migration-tests/src/multi_engine_test_api.rs +++ b/schema-engine/sql-migration-tests/src/multi_engine_test_api.rs @@ -196,7 +196,7 @@ impl TestApi { }; let mut connector = match &connection_info { - ConnectionInfo::Native(NativeConnectionInfo::Postgres(_) | NativeConnectionInfo::PostgresWs(_)) => { + ConnectionInfo::Native(NativeConnectionInfo::Postgres(_)) => { if self.args.provider() == "cockroachdb" { SqlSchemaConnector::new_cockroach() } else { From fd1033417d62c0a4353adf0d4d38511bda16ee9c Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 2 Oct 2024 17:50:23 +0200 Subject: [PATCH 03/15] Fix compilation --- quaint/src/connector/connection_info.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/quaint/src/connector/connection_info.rs b/quaint/src/connector/connection_info.rs index 2284aeb40902..80d63089e489 100644 --- a/quaint/src/connector/connection_info.rs +++ b/quaint/src/connector/connection_info.rs @@ -37,8 +37,6 @@ impl ConnectionInfo { /// database. #[cfg(not(target_arch = "wasm32"))] pub fn from_url(url_str: &str) -> crate::Result { - use super::PostgresUrl; - let url_result: Result = url_str.parse(); // Non-URL database strings are interpreted as SQLite file paths. @@ -86,8 +84,7 @@ impl ConnectionInfo { } #[cfg(feature = "postgresql")] SqlFamily::Postgres => Ok(ConnectionInfo::Native(NativeConnectionInfo::Postgres( - // TODO - PostgresUrl::new_native(url)?, + super::PostgresUrl::new_native(url)?, ))), #[allow(unreachable_patterns)] _ => unreachable!(), From 516094e4b52a28d97c8f68a408ace7f0e51bee8f Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 2 Oct 2024 18:20:53 +0200 Subject: [PATCH 04/15] Handle unauthorized error --- quaint/src/connector/postgres/native/websocket.rs | 15 +++++++++++++-- .../sql-schema-connector/src/flavour/postgres.rs | 6 +++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index fba6742deccf..7e572e3a1613 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -2,7 +2,12 @@ use std::str::FromStr; use async_tungstenite::{ tokio::connect_async, - tungstenite::{self, client::IntoClientRequest, http::HeaderValue, Error as TungsteniteError}, + tungstenite::{ + self, + client::IntoClientRequest, + http::{HeaderValue, StatusCode}, + Error as TungsteniteError, + }, }; use futures::FutureExt; use tokio_postgres::{Client, Config, NoTls}; @@ -10,7 +15,7 @@ use ws_stream_tungstenite::WsStream; use crate::{ connector::PostgresWebSocketUrl, - error::{self, Error, ErrorKind, NativeErrorKind}, + error::{self, Error, ErrorKind, Name, NativeErrorKind}, }; const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; @@ -60,6 +65,12 @@ impl From for error::Error { message: tls_error.to_string(), })), + TungsteniteError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => { + Error::builder(ErrorKind::DatabaseAccessDenied { + db_name: Name::Unavailable, + }) + } + _ => Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(value)))), }; diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index 9c3feef989b0..79b2b7d7f65d 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -51,10 +51,10 @@ impl MigratePostgresUrl { "Required `apiKey` query string parameter was not provided in a connection URL", )); }; - PostgresUrl::new_websocket(ws_url, api_key.into_owned()).map_err(ConnectorError::url_parse_error)? + PostgresUrl::new_websocket(ws_url, api_key.into_owned()) } else { - PostgresUrl::new_native(url).map_err(ConnectorError::url_parse_error)? - }; + PostgresUrl::new_native(url) + }.map_err(ConnectorError::url_parse_error)?; Ok(Self(postgres_url)) } From ae170354c9d41323622ac659c9a793a0f8e48da4 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Fri, 4 Oct 2024 11:47:13 +0200 Subject: [PATCH 05/15] Fix wasm build --- quaint/Cargo.toml | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index 079d9035b6eb..9ef66d1915aa 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -51,6 +51,8 @@ postgresql-native = [ "bit-vec", "lru-cache", "byteorder", + "dep:ws_stream_tungstenite", + "dep:async-tungstenite" ] postgresql = [] @@ -95,8 +97,6 @@ sqlformat = { version = "0.2.3", optional = true } uuid.workspace = true crosstarget-utils = { path = "../libs/crosstarget-utils" } concat-idents = "1.1.5" -ws_stream_tungstenite = { version = "0.14.0", features = ["tokio_io"] } -async-tungstenite = { version = "0.28.0", features = ["tokio-runtime"]} [dev-dependencies] once_cell = "1.3" @@ -113,6 +113,16 @@ expect-test = "1" version = "0.2" features = ["js"] +[dependencies.ws_stream_tungstenite] +version = "0.14.0" +features = ["tokio_io"] +optional = true + +[dependencies.async-tungstenite] +version = "0.28.0" +features = ["tokio-runtime"] +optional = true + [dependencies.byteorder] default-features = false optional = true @@ -183,7 +193,7 @@ optional = true [dependencies.tokio-util] version = "0.7" -features = ["compat", "io"] +features = ["compat"] optional = true [build-dependencies] From ad2a323f2eaff0ce607d5631739d533585a0c58b Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Fri, 4 Oct 2024 11:50:21 +0200 Subject: [PATCH 06/15] Cargo fmt --- .../connectors/sql-schema-connector/src/flavour/postgres.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index 79b2b7d7f65d..b32b9c3a19cc 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -54,7 +54,8 @@ impl MigratePostgresUrl { PostgresUrl::new_websocket(ws_url, api_key.into_owned()) } else { PostgresUrl::new_native(url) - }.map_err(ConnectorError::url_parse_error)?; + } + .map_err(ConnectorError::url_parse_error)?; Ok(Self(postgres_url)) } From f8516375c229b75214ebcead26748b81659959d1 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Fri, 4 Oct 2024 12:18:47 +0200 Subject: [PATCH 07/15] Use correct urls --- quaint/src/connector/postgres/url.rs | 17 +++++------------ .../src/flavour/postgres.rs | 5 ++--- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 1095bb59fca0..703aff33ebb5 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -81,7 +81,7 @@ impl PostgresUrl { pub fn dbname(&self) -> &str { match self { Self::Native(url) => url.dbname(), - Self::WebSocket(url) => url.dbname(), + Self::WebSocket(_) => "postgres", } } @@ -184,7 +184,10 @@ impl PostgresNativeUrl { /// Name of the database connected. Defaults to `postgres`. pub fn dbname(&self) -> &str { - dbname(&self.url) + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } } /// The percent-decoded database password. @@ -508,18 +511,8 @@ impl PostgresWebSocketUrl { pub fn port(&self) -> u16 { self.url.port().unwrap_or(80) } - - pub fn dbname(&self) -> &str { - dbname(&self.url) - } } -fn dbname(url: &Url) -> &str { - match url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } -} #[cfg(test)] mod tests { use super::*; diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index b32b9c3a19cc..bc99ec8c89b7 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -35,7 +35,7 @@ struct MigratePostgresUrl(PostgresUrl); static MIGRATE_WS_BASE_URL: Lazy> = Lazy::new(|| { std::env::var("PRISMA_SCHEMA_ENGINE_WS_BASE_URL") .map(Cow::Owned) - .unwrap_or_else(|_| Cow::Borrowed("wss://migrations.prisma.io")) + .unwrap_or_else(|_| Cow::Borrowed("wss://migrations.prisma-data.net/websocket")) }); impl MigratePostgresUrl { @@ -44,8 +44,7 @@ impl MigratePostgresUrl { fn new(url: Url) -> ConnectorResult { let postgres_url = if url.scheme() == Self::WEBSOCKET_SCHEME { - let mut ws_url = Url::from_str(&MIGRATE_WS_BASE_URL).map_err(ConnectorError::url_parse_error)?; - ws_url.set_path(url.path()); + let ws_url = Url::from_str(&MIGRATE_WS_BASE_URL).map_err(ConnectorError::url_parse_error)?; let Some((_, api_key)) = url.query_pairs().find(|(name, _)| name == Self::API_KEY_PARAM) else { return Err(ConnectorError::url_parse_error( "Required `apiKey` query string parameter was not provided in a connection URL", From f9c05c34f952a78faeb669d73f2642b199ec4e18 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Fri, 4 Oct 2024 12:19:41 +0200 Subject: [PATCH 08/15] Restore Cargo.toml --- Cargo.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 750429edee96..f2000ba619d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,6 +73,7 @@ tsify = { version = "0.4.5" } wasm-bindgen = { version = "0.2.92" } wasm-bindgen-futures = { version = "0.4" } wasm-rs-dbg = { version = "0.1.2", default-features = false, features = ["console-error"] } +wasm-bindgen-test = { version = "0.3.0" } url = { version = "2.5.0" } bson = { version = "2.11.0", features = ["chrono-0_4", "uuid-1"] } @@ -84,9 +85,15 @@ path = "quaint" [profile.dev.package.backtrace] opt-level = 3 +[profile.release.package.query-engine-node-api] +strip = "symbols" + [profile.release.package.query-engine] strip = "symbols" +[profile.release.package.query-engine-c-abi] +strip = "symbols" + [profile.release] lto = "fat" codegen-units = 1 From 449db7d7e31c2388a41f285a7d89c28c061a2ce1 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Fri, 4 Oct 2024 17:04:48 +0200 Subject: [PATCH 09/15] Fix TLS and api_key --- Cargo.lock | 3 +++ quaint/Cargo.toml | 2 +- .../connectors/sql-schema-connector/src/flavour/postgres.rs | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 40cc4d01ba29..8908c0596350 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -173,8 +173,10 @@ dependencies = [ "futures-io", "futures-util", "log", + "native-tls", "pin-project-lite", "tokio", + "tokio-native-tls", "tungstenite", ] @@ -6177,6 +6179,7 @@ dependencies = [ "http 1.1.0", "httparse", "log", + "native-tls", "rand 0.8.5", "sha1", "thiserror", diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index 9ef66d1915aa..42e40d537fd9 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -120,7 +120,7 @@ optional = true [dependencies.async-tungstenite] version = "0.28.0" -features = ["tokio-runtime"] +features = ["tokio-runtime", "tokio-native-tls"] optional = true [dependencies.byteorder] diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index bc99ec8c89b7..ac704459b5a9 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -40,14 +40,14 @@ static MIGRATE_WS_BASE_URL: Lazy> = Lazy::new(|| { impl MigratePostgresUrl { const WEBSOCKET_SCHEME: &'static str = "prisma+postgres"; - const API_KEY_PARAM: &'static str = "apiKey"; + const API_KEY_PARAM: &'static str = "api_key"; fn new(url: Url) -> ConnectorResult { let postgres_url = if url.scheme() == Self::WEBSOCKET_SCHEME { let ws_url = Url::from_str(&MIGRATE_WS_BASE_URL).map_err(ConnectorError::url_parse_error)?; let Some((_, api_key)) = url.query_pairs().find(|(name, _)| name == Self::API_KEY_PARAM) else { return Err(ConnectorError::url_parse_error( - "Required `apiKey` query string parameter was not provided in a connection URL", + "Required `api_key` query string parameter was not provided in a connection URL", )); }; PostgresUrl::new_websocket(ws_url, api_key.into_owned()) From 052f389b5f74065ef29fd643f0b98bc1af6c16b9 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Fri, 4 Oct 2024 17:38:33 +0200 Subject: [PATCH 10/15] TLS support --- quaint/src/connector/postgres/native/websocket.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index 7e572e3a1613..831e5ce7e3fb 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -10,6 +10,7 @@ use async_tungstenite::{ }, }; use futures::FutureExt; +use postgres_native_tls::{MakeTlsConnector, TlsConnector}; use tokio_postgres::{Client, Config, NoTls}; use ws_stream_tungstenite::WsStream; @@ -21,6 +22,7 @@ use crate::{ const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let host = url.host().to_owned(); let (ws_stream, response) = connect_async(url).await.inspect_err(|e| { eprintln!("{}", e); })?; @@ -38,7 +40,9 @@ pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::R let config = Config::from_str(connection_params)?; let ws_byte_stream = WsStream::new(ws_stream); - let (client, connection) = config.connect_raw(ws_byte_stream, NoTls).await?; + let native_tls = native_tls::TlsConnector::new()?; + let tls = TlsConnector::new(native_tls, &host); + let (client, connection) = config.connect_raw(ws_byte_stream, tls).await?; tokio::spawn(connection.map(|r| match r { Ok(_) => (), Err(e) => { From 6f42a5be081634a20cdb6ea860e6417ff8a8a3fb Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Fri, 4 Oct 2024 18:04:09 +0200 Subject: [PATCH 11/15] Fix fixed TLS --- quaint/src/connector/postgres/native/websocket.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index 831e5ce7e3fb..ff72bb9970c8 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -10,8 +10,8 @@ use async_tungstenite::{ }, }; use futures::FutureExt; -use postgres_native_tls::{MakeTlsConnector, TlsConnector}; -use tokio_postgres::{Client, Config, NoTls}; +use postgres_native_tls::TlsConnector; +use tokio_postgres::{Client, Config}; use ws_stream_tungstenite::WsStream; use crate::{ @@ -22,7 +22,6 @@ use crate::{ const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { - let host = url.host().to_owned(); let (ws_stream, response) = connect_async(url).await.inspect_err(|e| { eprintln!("{}", e); })?; @@ -40,8 +39,7 @@ pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::R let config = Config::from_str(connection_params)?; let ws_byte_stream = WsStream::new(ws_stream); - let native_tls = native_tls::TlsConnector::new()?; - let tls = TlsConnector::new(native_tls, &host); + let tls = TlsConnector::new(native_tls::TlsConnector::new()?, "TODO"); let (client, connection) = config.connect_raw(ws_byte_stream, tls).await?; tokio::spawn(connection.map(|r| match r { Ok(_) => (), From bc9fcf9be87ae60a3a3de776e61e162734fa4b69 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 9 Oct 2024 16:05:25 +0200 Subject: [PATCH 12/15] Implement TLS handshake with proper host name --- .../connector/postgres/native/websocket.rs | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index ff72bb9970c8..53412b8e2317 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -5,7 +5,7 @@ use async_tungstenite::{ tungstenite::{ self, client::IntoClientRequest, - http::{HeaderValue, StatusCode}, + http::{HeaderMap, HeaderValue, StatusCode}, Error as TungsteniteError, }, }; @@ -20,26 +20,25 @@ use crate::{ }; const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; +const HOST_HEADER: &str = "Prisma-Db-Host"; pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { let (ws_stream, response) = connect_async(url).await.inspect_err(|e| { - eprintln!("{}", e); + dbg!(&e); + if let TungsteniteError::Http(response) = e { + dbg!(String::from_utf8(response.body().clone().unwrap()).unwrap()); + } })?; - let Some(header) = response.headers().get(CONNECTION_PARAMS_HEADER) else { - let message = format!("Missing response header {CONNECTION_PARAMS_HEADER}"); - let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build(); - return Err(error); - }; - - let connection_params = header.to_str().map_err(|inner| { - Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build() - })?; + let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?; + dbg!(&connection_params); + let db_host = require_header_value(response.headers(), HOST_HEADER)?; + dbg!(&connection_params); let config = Config::from_str(connection_params)?; let ws_byte_stream = WsStream::new(ws_stream); - let tls = TlsConnector::new(native_tls::TlsConnector::new()?, "TODO"); + let tls = TlsConnector::new(native_tls::TlsConnector::new()?, db_host); let (client, connection) = config.connect_raw(ws_byte_stream, tls).await?; tokio::spawn(connection.map(|r| match r { Ok(_) => (), @@ -50,6 +49,20 @@ pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::R Ok(client) } +fn require_header_value<'a>(headers: &'a HeaderMap, name: &str) -> crate::Result<&'a str> { + let Some(header) = headers.get(name) else { + let message = format!("Missing response header {name}"); + let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build(); + return Err(error); + }; + + let value = header.to_str().map_err(|inner| { + Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build() + })?; + + Ok(value) +} + impl IntoClientRequest for PostgresWebSocketUrl { fn into_client_request(self) -> tungstenite::Result { let mut request = self.url.to_string().into_client_request()?; From 2af0f60b0c86577bb010f01b23fd3d852290c88d Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 9 Oct 2024 16:09:59 +0200 Subject: [PATCH 13/15] Update quaint/src/connector/postgres/native/websocket.rs Co-authored-by: Alberto Schiabel --- quaint/src/connector/postgres/native/websocket.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index 53412b8e2317..b38e91a973e2 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -43,7 +43,7 @@ pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::R tokio::spawn(connection.map(|r| match r { Ok(_) => (), Err(e) => { - tracing::error!("Error in PostgreSQL connection: {:?}", e); + tracing::error!("Error in PostgreSQL WebSocket connection: {:?}", e); } })); Ok(client) From 98181068eb488580347cbc447d000612d949e00f Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 9 Oct 2024 16:13:33 +0200 Subject: [PATCH 14/15] Remove dbg --- quaint/src/connector/postgres/native/websocket.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index b38e91a973e2..8b6b936bfd2b 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -23,15 +23,9 @@ const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; const HOST_HEADER: &str = "Prisma-Db-Host"; pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { - let (ws_stream, response) = connect_async(url).await.inspect_err(|e| { - dbg!(&e); - if let TungsteniteError::Http(response) = e { - dbg!(String::from_utf8(response.body().clone().unwrap()).unwrap()); - } - })?; + let (ws_stream, response) = connect_async(url).await?; let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?; - dbg!(&connection_params); let db_host = require_header_value(response.headers(), HOST_HEADER)?; dbg!(&connection_params); From 44fc31aa119a1850fd50d310484f52dd9ae4d6a9 Mon Sep 17 00:00:00 2001 From: Serhii Tatarintsev Date: Wed, 9 Oct 2024 16:48:06 +0200 Subject: [PATCH 15/15] Feedback & cleanup --- quaint/src/connector/postgres/native/websocket.rs | 1 - .../src/flavour/postgres/connection.rs | 11 ++++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs index 8b6b936bfd2b..7899e9a22ec0 100644 --- a/quaint/src/connector/postgres/native/websocket.rs +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -27,7 +27,6 @@ pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::R let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?; let db_host = require_header_value(response.headers(), HOST_HEADER)?; - dbg!(&connection_params); let config = Config::from_str(connection_params)?; let ws_byte_stream = WsStream::new(ws_stream); diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index 122d77eef257..cb31b4394d72 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -22,13 +22,10 @@ impl Connection { let url = MigratePostgresUrl::new(url)?; let quaint = match url.0 { - PostgresUrl::Native(ref native_url) => connector::PostgreSql::new(native_url.as_ref().clone()) - .await - .map_err(quaint_err(&url))?, - PostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()) - .await - .map_err(quaint_err(&url))?, - }; + PostgresUrl::Native(ref native_url) => connector::PostgreSql::new(native_url.as_ref().clone()).await, + PostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()).await, + } + .map_err(quaint_err(&url))?; let version = quaint.version().await.map_err(quaint_err(&url))?;