diff --git a/Cargo.lock b/Cargo.lock index 62a21959bf0..8908c059635 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,6 +164,34 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "async-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90e661b6cb0a6eb34d02c520b052daa3aa9ac0cc02495c9d066bbce13ead132b" +dependencies = [ + "futures-io", + "futures-util", + "log", + "native-tls", + "pin-project-lite", + "tokio", + "tokio-native-tls", + "tungstenite", +] + +[[package]] +name = "async_io_stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" +dependencies = [ + "futures", + "pharos", + "rustc_version", + "tokio", +] + [[package]] name = "asynchronous-codec" version = "0.6.2" @@ -1736,7 +1764,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 2.2.2", "slab", "tokio", @@ -1908,6 +1936,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -1915,7 +1954,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -1942,7 +1981,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", @@ -3150,7 +3189,7 @@ dependencies = [ "async-trait", "futures", "futures-util", - "http", + "http 0.2.9", "opentelemetry", "prost", "thiserror", @@ -3379,6 +3418,16 @@ dependencies = [ "indexmap 1.9.3", ] +[[package]] +name = "pharos" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9567389417feee6ce15dd6527a8a1ecac205ef62c2932bcf3d9f6fc5b78b414" +dependencies = [ + "futures", + "rustc_version", +] + [[package]] name = "phf" version = "0.11.2" @@ -3770,6 +3819,7 @@ name = "quaint" version = "0.2.0-alpha.13" dependencies = [ "async-trait", + "async-tungstenite", "base64 0.12.3", "bigdecimal", "bit-vec", @@ -3810,11 +3860,12 @@ dependencies = [ "tiberius", "tokio", "tokio-postgres", - "tokio-util 0.6.10", + "tokio-util 0.7.8", "tracing", "tracing-core", "url", "uuid", + "ws_stream_tungstenite", ] [[package]] @@ -4484,7 +4535,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-tls", @@ -5867,7 +5918,6 @@ checksum = "36943ee01a6d67977dd3f84a5a1d2efeb4ada3a1ae771cadfaa535d9d9fc6507" dependencies = [ "bytes", "futures-core", - "futures-io", "futures-sink", "log", "pin-project-lite", @@ -5911,7 +5961,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-timeout", @@ -6117,6 +6167,25 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "native-tls", + "rand 0.8.5", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "twox-hash" version = "1.6.3" @@ -6258,6 +6327,12 @@ dependencies = [ "user-facing-error-macros", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8-width" version = "0.1.6" @@ -6778,6 +6853,26 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ws_stream_tungstenite" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed39ff9f8b2eda91bf6390f9f49eee93d655489e15708e3bb638c1c4f07cecb4" +dependencies = [ + "async-tungstenite", + "async_io_stream", + "bitflags 2.4.0", + "futures-core", + "futures-io", + "futures-sink", + "futures-util", + "pharos", + "rustc_version", + "tokio", + "tracing", + "tungstenite", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs index f14a6b9bf1b..65a0d929995 100644 --- a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs @@ -463,7 +463,10 @@ impl Connector for PostgresDatamodelConnector { } fn validate_url(&self, url: &str) -> Result<(), String> { - if !url.starts_with("postgres://") && !url.starts_with("postgresql://") { + if !url.starts_with("postgres://") + && !url.starts_with("postgresql://") + && !url.starts_with("prisma+postgres://") + { return Err("must start with the protocol `postgresql://` or `postgres://`.".to_owned()); } diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b63fe18b494..42e40d537fd 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -51,6 +51,8 @@ postgresql-native = [ "bit-vec", "lru-cache", "byteorder", + "dep:ws_stream_tungstenite", + "dep:async-tungstenite" ] postgresql = [] @@ -111,6 +113,16 @@ expect-test = "1" version = "0.2" features = ["js"] +[dependencies.ws_stream_tungstenite] +version = "0.14.0" +features = ["tokio_io"] +optional = true + +[dependencies.async-tungstenite] +version = "0.28.0" +features = ["tokio-runtime", "tokio-native-tls"] +optional = true + [dependencies.byteorder] default-features = false optional = true @@ -180,7 +192,7 @@ features = ["rt-multi-thread", "macros", "sync"] optional = true [dependencies.tokio-util] -version = "0.6" +version = "0.7" features = ["compat"] optional = true diff --git a/quaint/src/connector/connection_info.rs b/quaint/src/connector/connection_info.rs index 7dd8a5b5825..80d63089e48 100644 --- a/quaint/src/connector/connection_info.rs +++ b/quaint/src/connector/connection_info.rs @@ -84,7 +84,7 @@ impl ConnectionInfo { } #[cfg(feature = "postgresql")] SqlFamily::Postgres => Ok(ConnectionInfo::Native(NativeConnectionInfo::Postgres( - PostgresUrl::new(url)?, + super::PostgresUrl::new_native(url)?, ))), #[allow(unreachable_patterns)] _ => unreachable!(), @@ -243,7 +243,7 @@ impl ConnectionInfo { pub fn pg_bouncer(&self) -> bool { match self { #[cfg(all(not(target_arch = "wasm32"), feature = "postgresql"))] - ConnectionInfo::Native(NativeConnectionInfo::Postgres(url)) => url.pg_bouncer(), + ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(url))) => url.pg_bouncer(), _ => false, } } diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 805ba13a602..ad53908383f 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -5,8 +5,9 @@ pub(crate) mod column_type; mod conversion; mod error; mod explain; +mod websocket; -pub(crate) use crate::connector::postgres::url::PostgresUrl; +pub(crate) use crate::connector::postgres::url::PostgresNativeUrl; use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; use crate::connector::{ timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel, Transaction, @@ -37,12 +38,15 @@ use std::{ time::Duration, }; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; +use websocket::connect_via_websocket; /// The underlying postgres driver. Only available with the `expose-drivers` /// Cargo feature. #[cfg(feature = "expose-drivers")] pub use tokio_postgres; +use super::PostgresWebSocketUrl; + struct PostgresClient(Client); impl Debug for PostgresClient { @@ -160,7 +164,7 @@ impl SslParams { } } -impl PostgresUrl { +impl PostgresNativeUrl { pub(crate) fn cache(&self) -> StatementCache { if self.query_params.pg_bouncer { StatementCache::new(0) @@ -228,7 +232,7 @@ impl PostgresUrl { impl PostgreSql { /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { + pub async fn new(url: PostgresNativeUrl) -> crate::Result { let config = url.to_config(); let mut tls_builder = TlsConnector::builder(); @@ -292,6 +296,21 @@ impl PostgreSql { }) } + /// Create a new websocket connection to managed database + pub async fn new_with_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let client = connect_via_websocket(url).await?; + + Ok(Self { + client: PostgresClient(client), + socket_timeout: None, + pg_bouncer: false, + statement_cache: Mutex::new(StatementCache::new(0)), + is_healthy: AtomicBool::new(true), + is_cockroachdb: false, + is_materialize: false, + }) + } + /// The underlying tokio_postgres::Client. Only available with the /// `expose-drivers` Cargo feature. This is a lower level API when you need /// to get into database specific features. @@ -922,7 +941,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let client = PostgreSql::new(pg_url).await.unwrap(); @@ -974,7 +993,7 @@ mod tests { url.query_pairs_mut().append_pair("schema", schema_name); url.query_pairs_mut().append_pair("pbbouncer", "true"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let client = PostgreSql::new(pg_url).await.unwrap(); @@ -1025,7 +1044,7 @@ mod tests { let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let client = PostgreSql::new(pg_url).await.unwrap(); @@ -1076,7 +1095,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); let client = PostgreSql::new(pg_url).await.unwrap(); @@ -1127,7 +1146,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); let client = PostgreSql::new(pg_url).await.unwrap(); diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs new file mode 100644 index 00000000000..7899e9a22ec --- /dev/null +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -0,0 +1,87 @@ +use std::str::FromStr; + +use async_tungstenite::{ + tokio::connect_async, + tungstenite::{ + self, + client::IntoClientRequest, + http::{HeaderMap, HeaderValue, StatusCode}, + Error as TungsteniteError, + }, +}; +use futures::FutureExt; +use postgres_native_tls::TlsConnector; +use tokio_postgres::{Client, Config}; +use ws_stream_tungstenite::WsStream; + +use crate::{ + connector::PostgresWebSocketUrl, + error::{self, Error, ErrorKind, Name, NativeErrorKind}, +}; + +const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; +const HOST_HEADER: &str = "Prisma-Db-Host"; + +pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let (ws_stream, response) = connect_async(url).await?; + + let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?; + let db_host = require_header_value(response.headers(), HOST_HEADER)?; + + let config = Config::from_str(connection_params)?; + let ws_byte_stream = WsStream::new(ws_stream); + + let tls = TlsConnector::new(native_tls::TlsConnector::new()?, db_host); + let (client, connection) = config.connect_raw(ws_byte_stream, tls).await?; + tokio::spawn(connection.map(|r| match r { + Ok(_) => (), + Err(e) => { + tracing::error!("Error in PostgreSQL WebSocket connection: {:?}", e); + } + })); + Ok(client) +} + +fn require_header_value<'a>(headers: &'a HeaderMap, name: &str) -> crate::Result<&'a str> { + let Some(header) = headers.get(name) else { + let message = format!("Missing response header {name}"); + let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build(); + return Err(error); + }; + + let value = header.to_str().map_err(|inner| { + Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build() + })?; + + Ok(value) +} + +impl IntoClientRequest for PostgresWebSocketUrl { + fn into_client_request(self) -> tungstenite::Result { + let mut request = self.url.to_string().into_client_request()?; + let bearer = format!("Bearer {}", self.api_key()); + let auth_header = HeaderValue::from_str(&bearer)?; + request.headers_mut().insert("Authorization", auth_header); + Ok(request) + } +} + +impl From for error::Error { + fn from(value: TungsteniteError) -> Self { + let builder = match value { + TungsteniteError::Tls(tls_error) => Error::builder(ErrorKind::Native(NativeErrorKind::TlsError { + message: tls_error.to_string(), + })), + + TungsteniteError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => { + Error::builder(ErrorKind::DatabaseAccessDenied { + db_name: Name::Unavailable, + }) + } + + _ => Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(value)))), + }; + + builder.build() + } +} diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 844da48c8d6..703aff33ebb 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -63,16 +63,74 @@ impl PostgresFlavour { } } +#[derive(Debug, Clone)] +pub enum PostgresUrl { + Native(Box), + WebSocket(PostgresWebSocketUrl), +} + +impl PostgresUrl { + pub fn new_native(url: Url) -> Result { + Ok(Self::Native(Box::new(PostgresNativeUrl::new(url)?))) + } + + pub fn new_websocket(url: Url, api_key: String) -> Result { + Ok(Self::WebSocket(PostgresWebSocketUrl::new(url, api_key))) + } + + pub fn dbname(&self) -> &str { + match self { + Self::Native(url) => url.dbname(), + Self::WebSocket(_) => "postgres", + } + } + + pub fn host(&self) -> &str { + match self { + Self::Native(native_url) => native_url.host(), + Self::WebSocket(ws_url) => ws_url.host(), + } + } + + pub fn port(&self) -> u16 { + match self { + Self::Native(native_url) => native_url.port(), + Self::WebSocket(ws_url) => ws_url.port(), + } + } + + pub fn username(&self) -> Cow<'_, str> { + match self { + Self::Native(native_url) => native_url.username(), + Self::WebSocket(_) => Cow::Borrowed(""), + } + } + + pub fn schema(&self) -> &str { + match self { + Self::Native(native_url) => native_url.schema(), + Self::WebSocket(_) => "public", + } + } + + pub fn socket_timeout(&self) -> Option { + match self { + Self::Native(native_url) => native_url.socket_timeout(), + Self::WebSocket(_) => None, + } + } +} + /// Wraps a connection url and exposes the parsing logic used by Quaint, /// including default values. #[derive(Debug, Clone)] -pub struct PostgresUrl { +pub struct PostgresNativeUrl { pub(crate) url: Url, pub(crate) query_params: PostgresUrlQueryParams, pub(crate) flavour: PostgresFlavour, } -impl PostgresUrl { +impl PostgresNativeUrl { /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection /// parameters. pub fn new(url: Url) -> Result { @@ -431,6 +489,30 @@ pub(crate) struct PostgresUrlQueryParams { pub(crate) ssl_mode: SslMode, } +#[derive(Debug, Clone)] +pub struct PostgresWebSocketUrl { + pub(crate) url: Url, + pub(crate) api_key: String, +} + +impl PostgresWebSocketUrl { + pub fn new(url: Url, api_key: String) -> Self { + Self { url, api_key } + } + + pub fn api_key(&self) -> &str { + &self.api_key + } + + pub fn host(&self) -> &str { + self.url.host_str().unwrap_or("localhost") + } + + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(80) + } +} + #[cfg(test)] mod tests { use super::*; @@ -442,14 +524,15 @@ mod tests { #[test] fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("/var/run/psql.sock", url.host()); } #[test] fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("/var/run/postgresql", url.host()); } @@ -457,63 +540,69 @@ mod tests { #[test] fn should_allow_changing_of_cache_size() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()) + .unwrap(); assert_eq!(420, url.cache().capacity()); } #[test] fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); assert_eq!(100, url.cache().capacity()); } #[test] fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()) + .unwrap(); assert_eq!(Some("test"), url.application_name()); } #[test] fn should_have_channel_binding() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()) + .unwrap(); assert_eq!(ChannelBinding::Require, url.channel_binding()); } #[test] fn should_have_default_channel_binding() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()) + .unwrap(); assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); assert_eq!(ChannelBinding::Prefer, url.channel_binding()); } #[test] fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); assert_eq!(0, url.cache().capacity()); } #[test] fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("localhost", url.host()); } #[test] fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); assert_eq!("2001:db8:1234::ffff", url.host()); } #[test] fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); assert_eq!("--cluster=my_cluster", url.options().unwrap()); } @@ -600,7 +689,7 @@ mod tests { url.query_pairs_mut().append_pair("schema", "hello"); url.query_pairs_mut().append_pair("pgbouncer", "true"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let config = pg_url.to_config(); @@ -616,7 +705,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "hello"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let config = pg_url.to_config(); @@ -630,7 +719,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "hello"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let config = pg_url.to_config(); @@ -644,7 +733,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "HeLLo"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let config = pg_url.to_config(); diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 381f0c82414..2026679cd48 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -307,8 +307,10 @@ impl Builder { /// - Defaults to `PostgresFlavour::Unknown`. #[cfg(feature = "postgresql-native")] pub fn set_postgres_flavour(&mut self, flavour: crate::connector::PostgresFlavour) { - use crate::connector::NativeConnectionInfo; - if let ConnectionInfo::Native(NativeConnectionInfo::Postgres(ref mut url)) = self.connection_info { + use crate::connector::{NativeConnectionInfo, PostgresUrl}; + if let ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(ref mut url))) = + self.connection_info + { url.set_flavour(flavour); } @@ -415,7 +417,7 @@ impl Quaint { } #[cfg(feature = "postgresql")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { - let url = crate::connector::PostgresUrl::new(url::Url::parse(s)?)?; + let url = crate::connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; let connection_limit = url.connection_limit(); let pool_timeout = url.pool_timeout(); let max_connection_lifetime = url.max_connection_lifetime(); diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 7533dffcfcc..0a2fc0adbd0 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -3,7 +3,7 @@ use crate::connector::MssqlUrl; #[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql-native")] -use crate::connector::PostgresUrl; +use crate::connector::PostgresNativeUrl; use crate::{ ast, connector::{self, impl_default_TransactionCapable, IsolationLevel, Queryable, Transaction, TransactionCapable}, @@ -85,7 +85,7 @@ pub enum QuaintManager { Mysql { url: MysqlUrl }, #[cfg(feature = "postgresql")] - Postgres { url: PostgresUrl }, + Postgres { url: PostgresNativeUrl }, #[cfg(feature = "sqlite")] Sqlite { url: String, db_name: String }, diff --git a/quaint/src/single.rs b/quaint/src/single.rs index cbf460c4150..fd018925852 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -148,7 +148,7 @@ impl Quaint { } #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { - let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; + let url = connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs b/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs index 0d876d6b4dc..c901b9ef388 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs @@ -5,7 +5,7 @@ use url::Url; pub(crate) async fn cockroach_setup(url: String, prisma_schema: &str) -> ConnectorResult<()> { let mut parsed_url = Url::parse(&url).map_err(ConnectorError::url_parse_error)?; - let mut quaint_url = quaint::connector::PostgresUrl::new(parsed_url.clone()).unwrap(); + let mut quaint_url = quaint::connector::PostgresNativeUrl::new(parsed_url.clone()).unwrap(); quaint_url.set_flavour(PostgresFlavour::Cockroach); let db_name = quaint_url.dbname(); diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs b/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs index 6bbba8564ca..536f51eb483 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs @@ -5,7 +5,7 @@ use url::Url; pub(crate) async fn postgres_setup(url: String, prisma_schema: &str, db_schemas: &[&str]) -> ConnectorResult<()> { let mut parsed_url = Url::parse(&url).map_err(ConnectorError::url_parse_error)?; - let mut quaint_url = quaint::connector::PostgresUrl::new(parsed_url.clone()).unwrap(); + let mut quaint_url = quaint::connector::PostgresNativeUrl::new(parsed_url.clone()).unwrap(); quaint_url.set_flavour(PostgresFlavour::Postgres); let (db_name, schema) = (quaint_url.dbname(), quaint_url.schema()); diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index dca3b89f6f2..ac704459b5a 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -5,12 +5,13 @@ use self::connection::*; use crate::SqlFlavour; use enumflags2::BitFlags; use indoc::indoc; -use quaint::{connector::PostgresUrl, Value}; +use once_cell::sync::Lazy; +use quaint::{connector::PostgresUrl, prelude::NativeConnectionInfo, Value}; use schema_connector::{ migrations_directory::MigrationDirectory, BoxFuture, ConnectorError, ConnectorParams, ConnectorResult, Namespaces, }; use sql_schema_describer::SqlSchema; -use std::{borrow::Cow, collections::HashMap, future, time}; +use std::{borrow::Cow, collections::HashMap, future, str::FromStr, time}; use url::Url; use user_facing_errors::{ common::{DatabaseAccessDenied, DatabaseDoesNotExist}, @@ -28,9 +29,62 @@ SET enable_experimental_alter_column_type_general = true; type State = super::State, Connection)>; +#[derive(Debug, Clone)] +struct MigratePostgresUrl(PostgresUrl); + +static MIGRATE_WS_BASE_URL: Lazy> = Lazy::new(|| { + std::env::var("PRISMA_SCHEMA_ENGINE_WS_BASE_URL") + .map(Cow::Owned) + .unwrap_or_else(|_| Cow::Borrowed("wss://migrations.prisma-data.net/websocket")) +}); + +impl MigratePostgresUrl { + const WEBSOCKET_SCHEME: &'static str = "prisma+postgres"; + const API_KEY_PARAM: &'static str = "api_key"; + + fn new(url: Url) -> ConnectorResult { + let postgres_url = if url.scheme() == Self::WEBSOCKET_SCHEME { + let ws_url = Url::from_str(&MIGRATE_WS_BASE_URL).map_err(ConnectorError::url_parse_error)?; + let Some((_, api_key)) = url.query_pairs().find(|(name, _)| name == Self::API_KEY_PARAM) else { + return Err(ConnectorError::url_parse_error( + "Required `api_key` query string parameter was not provided in a connection URL", + )); + }; + PostgresUrl::new_websocket(ws_url, api_key.into_owned()) + } else { + PostgresUrl::new_native(url) + } + .map_err(ConnectorError::url_parse_error)?; + + Ok(Self(postgres_url)) + } + + pub(super) fn host(&self) -> &str { + self.0.host() + } + + pub(super) fn port(&self) -> u16 { + self.0.port() + } + + pub(super) fn dbname(&self) -> &str { + self.0.dbname() + } + + pub(super) fn schema(&self) -> &str { + self.0.schema() + } +} + +impl From for NativeConnectionInfo { + fn from(value: MigratePostgresUrl) -> Self { + NativeConnectionInfo::Postgres(value.0) + } +} + struct Params { connector_params: ConnectorParams, - url: PostgresUrl, + url: MigratePostgresUrl, } /// The specific provider that was requested by the user. @@ -378,7 +432,7 @@ impl SqlFlavour for PostgresFlavour { .map_err(ConnectorError::url_parse_error)?; disable_postgres_statement_cache(&mut url)?; let connection_string = url.to_string(); - let url = PostgresUrl::new(url).map_err(ConnectorError::url_parse_error)?; + let url = MigratePostgresUrl::new(url)?; connector_params.connection_string = connection_string; let params = Params { connector_params, url }; self.state.set_params(params); @@ -510,7 +564,11 @@ impl SqlFlavour for PostgresFlavour { /// TL;DR, /// 1. pg >= 13 -> it works. /// 2. pg < 13 -> syntax error on WITH (FORCE), and then fail with db in use if pgbouncer is used. -async fn drop_db_try_force(conn: &mut Connection, url: &PostgresUrl, database_name: &str) -> ConnectorResult<()> { +async fn drop_db_try_force( + conn: &mut Connection, + url: &MigratePostgresUrl, + database_name: &str, +) -> ConnectorResult<()> { let drop_database = format!("DROP DATABASE IF EXISTS \"{database_name}\" WITH (FORCE)"); if let Err(err) = conn.raw_cmd(&drop_database, url).await { if let Some(msg) = err.message() { @@ -537,7 +595,7 @@ fn strip_schema_param_from_url(url: &mut Url) { /// Try to connect as an admin to a postgres database. We try to pick a default database from which /// we can create another database. -async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection, PostgresUrl)> { +async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection, MigratePostgresUrl)> { // "postgres" is the default database on most postgres installations, // "template1" is guaranteed to exist, and "defaultdb" is the only working // option on DigitalOcean managed postgres databases. @@ -547,7 +605,7 @@ async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection for database_name in CANDIDATE_DEFAULT_DATABASES { url.set_path(&format!("/{database_name}")); - let postgres_url = PostgresUrl::new(url.clone()).unwrap(); + let postgres_url = MigratePostgresUrl(PostgresUrl::new_native(url.clone()).unwrap()); match Connection::new(url.clone()).await { // If the database does not exist, try the next one. Err(err) => match &err.error_code() { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index 3ca9b673b0a..cb31b4394d7 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -5,7 +5,7 @@ use indoc::indoc; use psl::PreviewFeature; use quaint::{ connector::{self, tokio_postgres::error::ErrorPosition, PostgresUrl}, - prelude::{ConnectionInfo, NativeConnectionInfo, Queryable}, + prelude::{ConnectionInfo, Queryable}, }; use schema_connector::{ConnectorError, ConnectorResult, Namespaces}; use sql_schema_describer::{postgres::PostgresSchemaExt, SqlSchema}; @@ -13,19 +13,19 @@ use user_facing_errors::{schema_engine::ApplyMigrationError, schema_engine::Data use crate::sql_renderer::IteratorJoin; +use super::MigratePostgresUrl; + pub(super) struct Connection(connector::PostgreSql); impl Connection { pub(super) async fn new(url: url::Url) -> ConnectorResult { - let url = PostgresUrl::new(url).map_err(|err| { - ConnectorError::user_facing(user_facing_errors::common::InvalidConnectionString { - details: err.to_string(), - }) - })?; + let url = MigratePostgresUrl::new(url)?; - let quaint = connector::PostgreSql::new(url.clone()) - .await - .map_err(quaint_err(&url))?; + let quaint = match url.0 { + PostgresUrl::Native(ref native_url) => connector::PostgreSql::new(native_url.as_ref().clone()).await, + PostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()).await, + } + .map_err(quaint_err(&url))?; let version = quaint.version().await.map_err(quaint_err(&url))?; @@ -116,12 +116,12 @@ impl Connection { Ok(schema) } - pub(super) async fn raw_cmd(&mut self, sql: &str, url: &PostgresUrl) -> ConnectorResult<()> { + pub(super) async fn raw_cmd(&mut self, sql: &str, url: &MigratePostgresUrl) -> ConnectorResult<()> { tracing::debug!(query_type = "raw_cmd", sql); self.0.raw_cmd(sql).await.map_err(quaint_err(url)) } - pub(super) async fn version(&mut self, url: &PostgresUrl) -> ConnectorResult> { + pub(super) async fn version(&mut self, url: &MigratePostgresUrl) -> ConnectorResult> { tracing::debug!(query_type = "version"); self.0.version().await.map_err(quaint_err(url)) } @@ -129,7 +129,7 @@ impl Connection { pub(super) async fn query( &mut self, query: quaint::ast::Query<'_>, - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { use quaint::visitor::Visitor; let (sql, params) = quaint::visitor::Postgres::build(query).unwrap(); @@ -140,7 +140,7 @@ impl Connection { &self, sql: &str, params: &[quaint::prelude::Value<'_>], - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { tracing::debug!(query_type = "query_raw", sql, ?params); self.0.query_raw(sql, params).await.map_err(quaint_err(url)) @@ -149,7 +149,7 @@ impl Connection { pub(super) async fn describe_query( &self, sql: &str, - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { tracing::debug!(query_type = "describe_query", sql); self.0.describe_query(sql).await.map_err(quaint_err(url)) @@ -237,11 +237,6 @@ fn normalize_sql_schema(schema: &mut SqlSchema, preview_features: BitFlags impl (Fn(quaint::error::Error) -> ConnectorError) + '_ { - |err| { - crate::flavour::quaint_error_to_connector_error( - err, - &ConnectionInfo::Native(NativeConnectionInfo::Postgres(url.clone())), - ) - } +fn quaint_err(url: &MigratePostgresUrl) -> impl (Fn(quaint::error::Error) -> ConnectorError) + '_ { + |err| crate::flavour::quaint_error_to_connector_error(err, &ConnectionInfo::Native(url.clone().into())) } diff --git a/schema-engine/core/src/lib.rs b/schema-engine/core/src/lib.rs index b367ab0bfff..3c0a2bf6d6a 100644 --- a/schema-engine/core/src/lib.rs +++ b/schema-engine/core/src/lib.rs @@ -41,7 +41,7 @@ fn connector_for_connection_string( preview_features: BitFlags, ) -> CoreResult> { match connection_string.split(':').next() { - Some("postgres") | Some("postgresql") => { + Some("postgres") | Some("postgresql") | Some("prisma+postgres") => { let params = ConnectorParams { connection_string, preview_features,