Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc committed Apr 17, 2024
1 parent 9e8de91 commit 66632de
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 61 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ databend-client = { path = "core", version = "0.16.3" }
databend-driver = { path = "driver", version = "0.16.3" }
databend-driver-macros = { path = "macros", version = "0.16.3" }
databend-sql = { path = "sql", version = "0.16.3" }

[patch.crates-io]
pyo3-asyncio = { git = "https://github.com/everpcpc/pyo3-asyncio", rev = "42af887" }
7 changes: 4 additions & 3 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ license = { workspace = true }
authors = { workspace = true }

[lib]
crate-type = ["cdylib", "rlib"]
crate-type = ["cdylib"]
name = "databend_driver"
doc = false

[dependencies]
chrono = { version = "0.4.35", default-features = false }
ctor = "0.2.5"
databend-driver = { workspace = true, features = ["rustls", "flight-sql"] }
once_cell = "1.18"
pyo3 = { version = "0.20", features = ["abi3-py37"] }
pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] }
pyo3 = { version = "0.21", features = ["abi3-py37", "chrono"] }
pyo3-asyncio = { version = "0.21", features = ["tokio-runtime"] }
tokio = "1.34"
tokio-stream = "0.1"
18 changes: 9 additions & 9 deletions bindings/python/src/asyncio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl AsyncDatabendClient {
Ok(Self(client))
}

pub fn get_conn<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> {
pub fn get_conn<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
future_into_py(py, async move {
let conn = this.get_conn().await.map_err(DriverError::new)?;
Expand All @@ -44,39 +44,39 @@ pub struct AsyncDatabendConnection(Box<dyn databend_driver::Connection>);

#[pymethods]
impl AsyncDatabendConnection {
pub fn info<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> {
pub fn info<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
future_into_py(py, async move {
let info = this.info().await;
Ok(ConnectionInfo::new(info))
})
}

pub fn version<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> {
pub fn version<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
future_into_py(py, async move {
let version = this.version().await.map_err(DriverError::new)?;
Ok(version)
})
}

pub fn exec<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
pub fn exec<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
future_into_py(py, async move {
let res = this.exec(&sql).await.map_err(DriverError::new)?;
Ok(res)
})
}

pub fn query_row<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
pub fn query_row<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
future_into_py(py, async move {
let row = this.query_row(&sql).await.map_err(DriverError::new)?;
Ok(row.map(Row::new))
})
}

pub fn query_all<'p>(&self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
pub fn query_all<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
future_into_py(py, async move {
let rows: Vec<Row> = this
Expand All @@ -90,7 +90,7 @@ impl AsyncDatabendConnection {
})
}

pub fn query_iter<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
pub fn query_iter<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
future_into_py(py, async move {
let streamer = this.query_iter(&sql).await.map_err(DriverError::new)?;
Expand All @@ -99,11 +99,11 @@ impl AsyncDatabendConnection {
}

pub fn stream_load<'p>(
&self,
&'p self,
py: Python<'p>,
sql: String,
data: Vec<Vec<String>>,
) -> PyResult<&'p PyAny> {
) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
future_into_py(py, async move {
let data = data
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::blocking::{BlockingDatabendClient, BlockingDatabendConnection};
use crate::types::{ConnectionInfo, Field, Row, RowIterator, Schema, ServerStats};

#[pymodule]
fn _databend_driver(_py: Python, m: &PyModule) -> PyResult<()> {
fn _databend_driver(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<AsyncDatabendClient>()?;
m.add_class::<AsyncDatabendConnection>()?;
m.add_class::<BlockingDatabendClient>()?;
Expand Down
64 changes: 32 additions & 32 deletions bindings/python/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

use std::sync::Arc;

use chrono::{NaiveDate, NaiveDateTime};
use once_cell::sync::Lazy;
use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration};
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::{PyDict, PyList, PyTuple, PyType};
use pyo3::{intern, prelude::*};
use pyo3::types::{PyBytes, PyDict, PyList, PyTuple, PyType};
use pyo3_asyncio::tokio::future_into_py;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
Expand All @@ -32,14 +34,14 @@ pub static VERSION: Lazy<String> = Lazy::new(|| {

pub static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();

fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> {
fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<PyType>> {
DECIMAL_CLS
.get_or_try_init(py, || {
py.import(intern!(py, "decimal"))?
py.import_bound(intern!(py, "decimal"))?
.getattr(intern!(py, "Decimal"))?
.extract()
})
.map(|ty| ty.as_ref(py))
.map(|ty| ty.bind(py))
}

pub struct Value(databend_driver::Value);
Expand All @@ -49,42 +51,45 @@ impl IntoPy<PyObject> for Value {
match self.0 {
databend_driver::Value::Null => py.None(),
databend_driver::Value::EmptyArray => {
let list = PyList::empty(py);
let list = PyList::empty_bound(py);
list.into_py(py)
}
databend_driver::Value::EmptyMap => {
let dict = PyDict::new(py);
let dict = PyDict::new_bound(py);
dict.into_py(py)
}
databend_driver::Value::Boolean(b) => b.into_py(py),
databend_driver::Value::Binary(b) => b.into_py(py),
databend_driver::Value::Binary(b) => {
let buf = PyBytes::new_bound(py, &b);
buf.into_py(py)
}
databend_driver::Value::String(s) => s.into_py(py),
databend_driver::Value::Number(n) => {
let v = NumberValue(n);
v.into_py(py)
}
databend_driver::Value::Timestamp(_) => {
let s = self.0.to_string();
s.into_py(py)
let t = NaiveDateTime::try_from(self.0).unwrap();
t.into_py(py)
}
databend_driver::Value::Date(_) => {
let s = self.0.to_string();
s.into_py(py)
let d = NaiveDate::try_from(self.0).unwrap();
d.into_py(py)
}
databend_driver::Value::Array(inner) => {
let list = PyList::new(py, inner.into_iter().map(|v| Value(v).into_py(py)));
let list = PyList::new_bound(py, inner.into_iter().map(|v| Value(v).into_py(py)));
list.into_py(py)
}
databend_driver::Value::Map(inner) => {
let dict = PyDict::new(py);
let dict = PyDict::new_bound(py);
for (k, v) in inner {
dict.set_item(Value(k).into_py(py), Value(v).into_py(py))
.unwrap();
}
dict.into_py(py)
}
databend_driver::Value::Tuple(inner) => {
let tuple = PyTuple::new(py, inner.into_iter().map(|v| Value(v).into_py(py)));
let tuple = PyTuple::new_bound(py, inner.into_iter().map(|v| Value(v).into_py(py)));
tuple.into_py(py)
}
databend_driver::Value::Bitmap(s) => s.into_py(py),
Expand Down Expand Up @@ -138,12 +143,9 @@ impl Row {

#[pymethods]
impl Row {
pub fn values<'p>(&'p self, py: Python<'p>) -> PyResult<PyObject> {
let res = PyTuple::new(
py,
self.0.values().iter().map(|v| Value(v.clone()).into_py(py)), // FIXME: do not clone
);
Ok(res.into_py(py))
pub fn values<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyTuple>> {
let vals = self.0.values().iter().map(|v| Value(v.clone()).into_py(py));
Ok(PyTuple::new_bound(py, vals))
}
}

Expand All @@ -158,7 +160,7 @@ impl RowIterator {

#[pymethods]
impl RowIterator {
fn schema<'p>(&self, py: Python) -> PyResult<Schema> {
pub fn schema(&self, py: Python) -> PyResult<Schema> {
let streamer = self.0.clone();
let ret = wait_for_future(py, async move {
let schema = streamer.lock().await.schema();
Expand All @@ -170,35 +172,33 @@ impl RowIterator {
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __next__(&self, py: Python) -> PyResult<Option<Row>> {
fn __next__(&self, py: Python) -> PyResult<Row> {
let streamer = self.0.clone();
let ret = wait_for_future(py, async move {
wait_for_future(py, async move {
match streamer.lock().await.next().await {
Some(val) => match val {
Err(e) => Err(PyException::new_err(format!("{}", e))),
Ok(ret) => Ok(Row(ret)),
},
None => Err(PyStopIteration::new_err("The iterator is exhausted")),
}
});
ret.map(Some)
})
}

fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __anext__(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
fn __anext__<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let streamer = self.0.clone();
let future = future_into_py(py, async move {
future_into_py(py, async move {
match streamer.lock().await.next().await {
Some(val) => match val {
Err(e) => Err(PyException::new_err(format!("{}", e))),
Ok(ret) => Ok(Row(ret)),
},
None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")),
}
});
Ok(Some(future?.into()))
})
}
}

Expand All @@ -207,13 +207,13 @@ pub struct Schema(databend_driver::SchemaRef);

#[pymethods]
impl Schema {
pub fn fields<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> {
pub fn fields<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyList>> {
let fields = self
.0
.fields()
.into_iter()
.map(|f| Field(f.clone()).into_py(py));
Ok(PyList::new(py, fields))
Ok(PyList::new_bound(py, fields))
}
}

Expand Down
17 changes: 9 additions & 8 deletions bindings/python/tests/asyncio/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from datetime import datetime, date
from decimal import Decimal

from behave import given, when, then
Expand Down Expand Up @@ -70,13 +71,13 @@ async def _(context):

# Map
row = await context.conn.query_row("select {'xx':to_date('2020-01-01')}")
assert row.values() == ({"xx": "2020-01-01"},)
assert row.values() == ({"xx": date(2020, 1, 1)},)

# Tuple
row = await context.conn.query_row(
"select (10, '20', to_datetime('2024-04-16 12:34:56.789'))"
)
assert row.values() == ((10, "20", "2024-04-16 12:34:56.789"),)
assert row.values() == ((10, "20", datetime(2024, 4, 16, 12, 34, 56, 789)),)


@then("Select numbers should iterate all rows")
Expand Down Expand Up @@ -106,9 +107,9 @@ async def _(context):
async for row in rows:
ret.append(row.values())
expected = [
(-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"),
(-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"),
(-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"),
(-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)),
(-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)),
(-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
]
assert ret == expected

Expand All @@ -130,8 +131,8 @@ async def _(context):
async for row in rows:
ret.append(row.values())
expected = [
(-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"),
(-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"),
(-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"),
(-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)),
(-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)),
(-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
]
assert ret == expected
17 changes: 9 additions & 8 deletions bindings/python/tests/blocking/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from datetime import datetime, date
from decimal import Decimal

from behave import given, when, then
Expand Down Expand Up @@ -65,13 +66,13 @@ async def _(context):

# Map
row = context.conn.query_row("select {'xx':to_date('2020-01-01')}")
assert row.values() == ({"xx": "2020-01-01"},)
assert row.values() == ({"xx": date(2020, 1, 1)},)

# Tuple
row = context.conn.query_row(
"select (10, '20', to_datetime('2024-04-16 12:34:56.789'))"
)
assert row.values() == ((10, "20", "2024-04-16 12:34:56.789"),)
assert row.values() == ((10, "20", datetime(2024, 4, 16, 12, 34, 56, 789)),)


@then("Select numbers should iterate all rows")
Expand Down Expand Up @@ -99,9 +100,9 @@ def _(context):
for row in rows:
ret.append(row.values())
expected = [
(-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"),
(-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"),
(-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"),
(-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)),
(-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)),
(-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
]
assert ret == expected

Expand All @@ -122,8 +123,8 @@ def _(context):
for row in rows:
ret.append(row.values())
expected = [
(-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"),
(-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"),
(-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"),
(-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)),
(-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)),
(-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
]
assert ret == expected

0 comments on commit 66632de

Please sign in to comment.