diff --git a/Cargo.toml b/Cargo.toml index 44c247a5..3c808fb1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,6 @@ databend-client = { path = "core", version = "0.16.3" } databend-driver = { path = "driver", version = "0.16.3" } databend-driver-macros = { path = "macros", version = "0.16.3" } databend-sql = { path = "sql", version = "0.16.3" } + +[patch.crates-io] +pyo3-asyncio = { git = "https://github.com/everpcpc/pyo3-asyncio", rev = "42af887" } diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 61cfcf39..ac5f2cf8 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -8,15 +8,16 @@ license = { workspace = true } authors = { workspace = true } [lib] -crate-type = ["cdylib", "rlib"] +crate-type = ["cdylib"] name = "databend_driver" doc = false [dependencies] +chrono = { version = "0.4.35", default-features = false } ctor = "0.2.5" databend-driver = { workspace = true, features = ["rustls", "flight-sql"] } once_cell = "1.18" -pyo3 = { version = "0.20", features = ["abi3-py37"] } -pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] } +pyo3 = { version = "0.21", features = ["abi3-py37", "chrono"] } +pyo3-asyncio = { version = "0.21", features = ["tokio-runtime"] } tokio = "1.34" tokio-stream = "0.1" diff --git a/bindings/python/src/asyncio.rs b/bindings/python/src/asyncio.rs index 28725d21..d95c034e 100644 --- a/bindings/python/src/asyncio.rs +++ b/bindings/python/src/asyncio.rs @@ -30,7 +30,7 @@ impl AsyncDatabendClient { Ok(Self(client)) } - pub fn get_conn<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn get_conn<'p>(&'p self, py: Python<'p>) -> PyResult> { let this = self.0.clone(); future_into_py(py, async move { let conn = this.get_conn().await.map_err(DriverError::new)?; @@ -44,7 +44,7 @@ pub struct AsyncDatabendConnection(Box); #[pymethods] impl AsyncDatabendConnection { - pub fn info<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn info<'p>(&'p self, py: Python<'p>) -> PyResult> { let this = self.0.clone(); future_into_py(py, async move { let info = this.info().await; @@ -52,7 +52,7 @@ impl AsyncDatabendConnection { }) } - pub fn version<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn version<'p>(&'p self, py: Python<'p>) -> PyResult> { let this = self.0.clone(); future_into_py(py, async move { let version = this.version().await.map_err(DriverError::new)?; @@ -60,7 +60,7 @@ impl AsyncDatabendConnection { }) } - pub fn exec<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> { + pub fn exec<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult> { let this = self.0.clone(); future_into_py(py, async move { let res = this.exec(&sql).await.map_err(DriverError::new)?; @@ -68,7 +68,7 @@ impl AsyncDatabendConnection { }) } - pub fn query_row<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> { + pub fn query_row<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult> { let this = self.0.clone(); future_into_py(py, async move { let row = this.query_row(&sql).await.map_err(DriverError::new)?; @@ -76,7 +76,7 @@ impl AsyncDatabendConnection { }) } - pub fn query_all<'p>(&self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> { + pub fn query_all<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult> { let this = self.0.clone(); future_into_py(py, async move { let rows: Vec = this @@ -90,7 +90,7 @@ impl AsyncDatabendConnection { }) } - pub fn query_iter<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> { + pub fn query_iter<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult> { let this = self.0.clone(); future_into_py(py, async move { let streamer = this.query_iter(&sql).await.map_err(DriverError::new)?; @@ -99,11 +99,11 @@ impl AsyncDatabendConnection { } pub fn stream_load<'p>( - &self, + &'p self, py: Python<'p>, sql: String, data: Vec>, - ) -> PyResult<&'p PyAny> { + ) -> PyResult> { let this = self.0.clone(); future_into_py(py, async move { let data = data diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index d8f3b8e3..20111b94 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -24,7 +24,7 @@ use crate::blocking::{BlockingDatabendClient, BlockingDatabendConnection}; use crate::types::{ConnectionInfo, Field, Row, RowIterator, Schema, ServerStats}; #[pymodule] -fn _databend_driver(_py: Python, m: &PyModule) -> PyResult<()> { +fn _databend_driver(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 2884f4a4..dd71ad18 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -14,11 +14,13 @@ use std::sync::Arc; +use chrono::{NaiveDate, NaiveDateTime}; use once_cell::sync::Lazy; use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration}; +use pyo3::intern; +use pyo3::prelude::*; use pyo3::sync::GILOnceCell; -use pyo3::types::{PyDict, PyList, PyTuple, PyType}; -use pyo3::{intern, prelude::*}; +use pyo3::types::{PyBytes, PyDict, PyList, PyTuple, PyType}; use pyo3_asyncio::tokio::future_into_py; use tokio::sync::Mutex; use tokio_stream::StreamExt; @@ -32,14 +34,14 @@ pub static VERSION: Lazy = Lazy::new(|| { pub static DECIMAL_CLS: GILOnceCell> = GILOnceCell::new(); -fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> { +fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound> { DECIMAL_CLS .get_or_try_init(py, || { - py.import(intern!(py, "decimal"))? + py.import_bound(intern!(py, "decimal"))? .getattr(intern!(py, "Decimal"))? .extract() }) - .map(|ty| ty.as_ref(py)) + .map(|ty| ty.bind(py)) } pub struct Value(databend_driver::Value); @@ -49,34 +51,37 @@ impl IntoPy for Value { match self.0 { databend_driver::Value::Null => py.None(), databend_driver::Value::EmptyArray => { - let list = PyList::empty(py); + let list = PyList::empty_bound(py); list.into_py(py) } databend_driver::Value::EmptyMap => { - let dict = PyDict::new(py); + let dict = PyDict::new_bound(py); dict.into_py(py) } databend_driver::Value::Boolean(b) => b.into_py(py), - databend_driver::Value::Binary(b) => b.into_py(py), + databend_driver::Value::Binary(b) => { + let buf = PyBytes::new_bound(py, &b); + buf.into_py(py) + } databend_driver::Value::String(s) => s.into_py(py), databend_driver::Value::Number(n) => { let v = NumberValue(n); v.into_py(py) } databend_driver::Value::Timestamp(_) => { - let s = self.0.to_string(); - s.into_py(py) + let t = NaiveDateTime::try_from(self.0).unwrap(); + t.into_py(py) } databend_driver::Value::Date(_) => { - let s = self.0.to_string(); - s.into_py(py) + let d = NaiveDate::try_from(self.0).unwrap(); + d.into_py(py) } databend_driver::Value::Array(inner) => { - let list = PyList::new(py, inner.into_iter().map(|v| Value(v).into_py(py))); + let list = PyList::new_bound(py, inner.into_iter().map(|v| Value(v).into_py(py))); list.into_py(py) } databend_driver::Value::Map(inner) => { - let dict = PyDict::new(py); + let dict = PyDict::new_bound(py); for (k, v) in inner { dict.set_item(Value(k).into_py(py), Value(v).into_py(py)) .unwrap(); @@ -84,7 +89,7 @@ impl IntoPy for Value { dict.into_py(py) } databend_driver::Value::Tuple(inner) => { - let tuple = PyTuple::new(py, inner.into_iter().map(|v| Value(v).into_py(py))); + let tuple = PyTuple::new_bound(py, inner.into_iter().map(|v| Value(v).into_py(py))); tuple.into_py(py) } databend_driver::Value::Bitmap(s) => s.into_py(py), @@ -138,12 +143,9 @@ impl Row { #[pymethods] impl Row { - pub fn values<'p>(&'p self, py: Python<'p>) -> PyResult { - let res = PyTuple::new( - py, - self.0.values().iter().map(|v| Value(v.clone()).into_py(py)), // FIXME: do not clone - ); - Ok(res.into_py(py)) + pub fn values<'p>(&'p self, py: Python<'p>) -> PyResult> { + let vals = self.0.values().iter().map(|v| Value(v.clone()).into_py(py)); + Ok(PyTuple::new_bound(py, vals)) } } @@ -158,7 +160,7 @@ impl RowIterator { #[pymethods] impl RowIterator { - fn schema<'p>(&self, py: Python) -> PyResult { + pub fn schema(&self, py: Python) -> PyResult { let streamer = self.0.clone(); let ret = wait_for_future(py, async move { let schema = streamer.lock().await.schema(); @@ -170,9 +172,9 @@ impl RowIterator { fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } - fn __next__(&self, py: Python) -> PyResult> { + fn __next__(&self, py: Python) -> PyResult { let streamer = self.0.clone(); - let ret = wait_for_future(py, async move { + wait_for_future(py, async move { match streamer.lock().await.next().await { Some(val) => match val { Err(e) => Err(PyException::new_err(format!("{}", e))), @@ -180,16 +182,15 @@ impl RowIterator { }, None => Err(PyStopIteration::new_err("The iterator is exhausted")), } - }); - ret.map(Some) + }) } fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } - fn __anext__(&self, py: Python<'_>) -> PyResult> { + fn __anext__<'p>(&'p self, py: Python<'p>) -> PyResult> { let streamer = self.0.clone(); - let future = future_into_py(py, async move { + future_into_py(py, async move { match streamer.lock().await.next().await { Some(val) => match val { Err(e) => Err(PyException::new_err(format!("{}", e))), @@ -197,8 +198,7 @@ impl RowIterator { }, None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")), } - }); - Ok(Some(future?.into())) + }) } } @@ -207,13 +207,13 @@ pub struct Schema(databend_driver::SchemaRef); #[pymethods] impl Schema { - pub fn fields<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn fields<'p>(&'p self, py: Python<'p>) -> PyResult> { let fields = self .0 .fields() .into_iter() .map(|f| Field(f.clone()).into_py(py)); - Ok(PyList::new(py, fields)) + Ok(PyList::new_bound(py, fields)) } } diff --git a/bindings/python/tests/asyncio/steps/binding.py b/bindings/python/tests/asyncio/steps/binding.py index 62cb1ddc..6a45ab0b 100644 --- a/bindings/python/tests/asyncio/steps/binding.py +++ b/bindings/python/tests/asyncio/steps/binding.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from datetime import datetime, date from decimal import Decimal from behave import given, when, then @@ -70,13 +71,13 @@ async def _(context): # Map row = await context.conn.query_row("select {'xx':to_date('2020-01-01')}") - assert row.values() == ({"xx": "2020-01-01"},) + assert row.values() == ({"xx": date(2020, 1, 1)},) # Tuple row = await context.conn.query_row( "select (10, '20', to_datetime('2024-04-16 12:34:56.789'))" ) - assert row.values() == ((10, "20", "2024-04-16 12:34:56.789"),) + assert row.values() == ((10, "20", datetime(2024, 4, 16, 12, 34, 56, 789)),) @then("Select numbers should iterate all rows") @@ -106,9 +107,9 @@ async def _(context): async for row in rows: ret.append(row.values()) expected = [ - (-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"), - (-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"), - (-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"), + (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)), + (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)), + (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)), ] assert ret == expected @@ -130,8 +131,8 @@ async def _(context): async for row in rows: ret.append(row.values()) expected = [ - (-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"), - (-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"), - (-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"), + (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)), + (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)), + (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)), ] assert ret == expected diff --git a/bindings/python/tests/blocking/steps/binding.py b/bindings/python/tests/blocking/steps/binding.py index 61ee9063..040d29b6 100644 --- a/bindings/python/tests/blocking/steps/binding.py +++ b/bindings/python/tests/blocking/steps/binding.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from datetime import datetime, date from decimal import Decimal from behave import given, when, then @@ -65,13 +66,13 @@ async def _(context): # Map row = context.conn.query_row("select {'xx':to_date('2020-01-01')}") - assert row.values() == ({"xx": "2020-01-01"},) + assert row.values() == ({"xx": date(2020, 1, 1)},) # Tuple row = context.conn.query_row( "select (10, '20', to_datetime('2024-04-16 12:34:56.789'))" ) - assert row.values() == ((10, "20", "2024-04-16 12:34:56.789"),) + assert row.values() == ((10, "20", datetime(2024, 4, 16, 12, 34, 56, 789)),) @then("Select numbers should iterate all rows") @@ -99,9 +100,9 @@ def _(context): for row in rows: ret.append(row.values()) expected = [ - (-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"), - (-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"), - (-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"), + (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)), + (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)), + (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)), ] assert ret == expected @@ -122,8 +123,8 @@ def _(context): for row in rows: ret.append(row.values()) expected = [ - (-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"), - (-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"), - (-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"), + (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)), + (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)), + (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)), ] assert ret == expected