diff --git a/diesel/src/mysql/connection/raw.rs b/diesel/src/mysql/connection/raw.rs index 27efc5ddfe75..be771cef27cf 100644 --- a/diesel/src/mysql/connection/raw.rs +++ b/diesel/src/mysql/connection/raw.rs @@ -45,6 +45,7 @@ impl RawConnection { let password = connection_options.password(); let database = connection_options.database(); let port = connection_options.port(); + let unix_socket = connection_options.unix_socket(); unsafe { // Make sure you don't use the fake one! @@ -59,7 +60,9 @@ impl RawConnection { .map(CStr::as_ptr) .unwrap_or_else(|| ptr::null_mut()), u32::from(port.unwrap_or(0)), - ptr::null_mut(), + unix_socket + .map(CStr::as_ptr) + .unwrap_or_else(|| ptr::null_mut()), 0, ) }; diff --git a/diesel/src/mysql/connection/url.rs b/diesel/src/mysql/connection/url.rs index 8dbbf76a94b2..5fc84622f056 100644 --- a/diesel/src/mysql/connection/url.rs +++ b/diesel/src/mysql/connection/url.rs @@ -3,6 +3,7 @@ extern crate url; use self::percent_encoding::percent_decode; use self::url::{Host, Url}; +use std::collections::HashMap; use std::ffi::{CStr, CString}; use crate::result::{ConnectionError, ConnectionResult}; @@ -13,6 +14,7 @@ pub struct ConnectionOptions { password: Option, database: Option, port: Option, + unix_socket: Option, } impl ConnectionOptions { @@ -30,8 +32,19 @@ impl ConnectionOptions { return Err(connection_url_error()); } + let query_pairs = url.query_pairs().into_owned().collect::>(); + if query_pairs.get("database").is_some() { + return Err(connection_url_error()); + } + + let unix_socket = match query_pairs.get("unix_socket") { + Some(v) => Some(CString::new(v.as_bytes())?), + _ => None, + }; + let host = match url.host() { Some(Host::Ipv6(host)) => Some(CString::new(host.to_string())?), + Some(host) if host.to_string() == "localhost" && unix_socket != None => None, Some(host) => Some(CString::new(host.to_string())?), None => None, }; @@ -40,6 +53,7 @@ impl ConnectionOptions { Some(password) => Some(decode_into_cstring(password)?), None => None, }; + let database = match url.path_segments().and_then(|mut iter| iter.next()) { Some("") | None => None, Some(segment) => Some(CString::new(segment.as_bytes())?), @@ -51,6 +65,7 @@ impl ConnectionOptions { password: password, database: database, port: url.port(), + unix_socket: unix_socket, }) } @@ -73,6 +88,10 @@ impl ConnectionOptions { pub fn port(&self) -> Option { self.port } + + pub fn unix_socket(&self) -> Option<&CStr> { + self.unix_socket.as_ref().map(|x| &**x) + } } fn decode_into_cstring(s: &str) -> ConnectionResult { @@ -84,7 +103,7 @@ fn decode_into_cstring(s: &str) -> ConnectionResult { fn connection_url_error() -> ConnectionError { let msg = "MySQL connection URLs must be in the form \ - `mysql://[[user]:[password]@]host[:port][/database]`"; + `mysql://[[user]:[password]@]host[:port][/database][?unix_socket=socket-path]`"; ConnectionError::InvalidConnectionUrl(msg.into()) } @@ -94,6 +113,7 @@ fn urls_with_schemes_other_than_mysql_are_errors() { assert!(ConnectionOptions::parse("http://localhost").is_err()); assert!(ConnectionOptions::parse("file:///tmp/mysql.sock").is_err()); assert!(ConnectionOptions::parse("socket:///tmp/mysql.sock").is_err()); + assert!(ConnectionOptions::parse("mysql://localhost?database=somedb").is_err()); assert!(ConnectionOptions::parse("mysql://localhost").is_ok()); } @@ -186,3 +206,24 @@ fn ipv6_host_not_wrapped_in_brackets() { .host() ); } + +#[test] +fn unix_socket_tests() { + let unix_socket = "/var/run/mysqld.sock"; + let username = "foo"; + let password = "bar"; + let db_url = format!( + "mysql://{}:{}@localhost?unix_socket={}", + username, password, unix_socket + ); + let conn_opts = ConnectionOptions::parse(db_url.as_str()).unwrap(); + let cstring = |s| CString::new(s).unwrap(); + assert_eq!(None, conn_opts.host); + assert_eq!(None, conn_opts.port); + assert_eq!(cstring(username), conn_opts.user); + assert_eq!(cstring(password), conn_opts.password.unwrap()); + assert_eq!( + CString::new(unix_socket).unwrap(), + conn_opts.unix_socket.unwrap() + ); +}