7 changes: 7 additions & 0 deletions docs/api-reference/sql.md
@@ -0,0 +1,7 @@
+# `narwhals.sql`
+
+::: narwhals.sql
+    handler: python
+    options:
+      members:
+        - table
60 changes: 34 additions & 26 deletions docs/generating_sql.md
@@ -5,54 +5,68 @@ For example, what's the SQL equivalent to:

 ```python exec="1" source="above" session="generating-sql"
 import narwhals as nw
-from narwhals.typing import IntoFrameT
+from narwhals.typing import FrameT


-def avg_monthly_price(df_native: IntoFrameT) -> IntoFrameT:
+def avg_monthly_price(df: FrameT) -> FrameT:
     return (
-        nw.from_native(df_native)
-        .group_by(nw.col("date").dt.truncate("1mo"))
+        df.group_by(nw.col("date").dt.truncate("1mo"))
         .agg(nw.col("price").mean())
         .sort("date")
-        .to_native()
     )
 ```

 ?

-There are several ways to find out.
+Narwhals provides you with a `narwhals.sql` module to do just that!

-## Via DuckDB
+!!! info
+    `narwhals.sql` currently requires DuckDB to be installed.
+
+## `narwhals.sql`

 You can generate SQL directly from DuckDB.

 ```python exec="1" source="above" session="generating-sql" result="sql"
-import duckdb
-
-conn = duckdb.connect()
-conn.sql("""CREATE TABLE prices (date DATE, price DOUBLE);""")
-
-df = nw.from_native(conn.table("prices"))
-print(avg_monthly_price(df).sql_query())
+import narwhals as nw
+from narwhals.sql import table
+
+prices = table("prices", {"date": nw.Date, "price": nw.Float64})
+
+sql_query = (
+    prices.group_by(nw.col("date").dt.truncate("1mo"))
+    .agg(nw.col("price").mean())
+    .sort("date")
+    .to_native()
+    .sql_query()
+)
+print(sql_query)
 ```

-To make it look a bit prettier, or to then transpile it to other SQL dialects, we can pass it to [SQLGlot](https://github.com/tobymao/sqlglot):
+To make it look a bit prettier, or to then transpile it to other SQL dialects, you can pass it to [SQLGlot](https://github.com/tobymao/sqlglot):

 ```python exec="1" source="above" session="generating-sql" result="sql"
 import sqlglot

-print(sqlglot.transpile(avg_monthly_price(df).sql_query(), pretty=True)[0])
+print(sqlglot.transpile(sql_query, pretty=True)[0])
 ```

+You can even pass a [different dialect](https://github.com/tobymao/sqlglot?tab=readme-ov-file#supported-dialects):
+
+```python exec="1" source="above" session="generating-sql" result="sql"
+print(sqlglot.transpile(sql_query, pretty=True, dialect="databricks")[0])
+```
+
+
 ## Via Ibis

-We can also use Ibis to generate SQL:
+You can also use Ibis or SQLFrame to generate SQL:

 ```python exec="1" source="above" session="generating-sql" result="sql"
 import ibis

-t = ibis.table({"date": "date", "price": "double"}, name="prices")
-print(ibis.to_sql(avg_monthly_price(t)))
+df = nw.from_native(ibis.table({"date": "date", "price": "double"}, name="prices"))
+print(ibis.to_sql(avg_monthly_price(df).to_native()))
 ```

 ## Via SQLFrame
@@ -66,11 +80,5 @@ session = StandaloneSession.builder.getOrCreate()
 session.catalog.add_table("prices", column_mapping={"date": "date", "price": "float"})
 df = nw.from_native(session.read.table("prices"))

-print(avg_monthly_price(df).sql(dialect="duckdb"))
-```
-
-Or, to print the SQL code in a different dialect (say, databricks):
-
-```python exec="1" source="above" session="generating-sql" result="sql"
-print(avg_monthly_price(df).sql(dialect="databricks"))
+print(avg_monthly_price(df).to_native().sql(dialect="duckdb"))
 ```
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -69,6 +69,7 @@ nav:
 - api-reference/dtypes.md
 - api-reference/exceptions.md
 - api-reference/selectors.md
+- api-reference/sql.md
 - api-reference/testing.md
 - api-reference/typing.md
 - api-reference/utils.md
63 changes: 63 additions & 0 deletions narwhals/sql.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from narwhals._duckdb.utils import DeferredTimeZone, narwhals_to_native_dtype
+from narwhals.translate import from_native
+from narwhals.utils import Version
+
+if TYPE_CHECKING:
+    from duckdb import DuckDBPyRelation
+
+    from narwhals.dataframe import LazyFrame
+    from narwhals.typing import IntoSchema
+
+try:
+    import duckdb  # ignore-banned-import
+except ImportError as _exc:  # pragma: no cover
+    msg = (
+        "`narwhals.sql` requires DuckDB to be installed.\n\n"
+        "Hint: run `pip install -U narwhals[sql]`"
+    )
+    raise ModuleNotFoundError(msg) from _exc
+
+CONN = duckdb.connect()
+TZ = DeferredTimeZone(
+    CONN.sql("select value from duckdb_settings() where name = 'TimeZone'")
+)
Comment on lines +14 to +26
Member Author

I think (hope?) that in the future we could make a narwhals-sqlglot or narwhals-substrait plugin and use that here. But for now, I think using DuckDB for this is quite nice



+def table(name: str, schema: IntoSchema) -> LazyFrame[DuckDBPyRelation]:
+    """Generate standalone LazyFrame which you can use to generate SQL.
+
+    Note that this requires DuckDB to be installed.
+
+    Parameters:
+        name: Table name.
+        schema: Table schema.
+
+    Returns:
+        A LazyFrame.
+
+    Examples:
+        >>> import narwhals as nw
+        >>> from narwhals.sql import table
+        >>> schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String}
+        >>> assets = table("assets", schema)
+        >>> result = assets.filter(nw.col("price") > 100)
+        >>> print(result.to_native().sql_query())
+        SELECT * FROM main.assets WHERE (price > 100)
+    """
+    column_mapping = {
+        col: narwhals_to_native_dtype(dtype, Version.MAIN, TZ)
+        for col, dtype in schema.items()
+    }
+    dtypes = ", ".join(f"{col} {dtype}" for col, dtype in column_mapping.items())
+    CONN.sql(f"""
+        CREATE TABLE "{name}"
+        ({dtypes});
+        """)
+    return from_native(CONN.table(name))
Member Author

🤔 Maybe this could return a subclass of `LazyFrame` which has a `to_sql` method, with an optional dependency on SQLGlot if people want to pass `dialect != 'duckdb'` or `pretty=True`.
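
A rough sketch of how the `to_sql` idea in this comment might look, written as a free function rather than a `LazyFrame` subclass so that it stays self-contained. The name `to_sql` and its `dialect` / `pretty` parameters are hypothetical and not part of this PR; the input is assumed to be the DuckDB SQL string produced by `.to_native().sql_query()`. SQLGlot is imported only when a non-DuckDB dialect or pretty-printing is requested:

```python
from __future__ import annotations


def to_sql(query: str, *, dialect: str = "duckdb", pretty: bool = False) -> str:
    """Hypothetical helper: return DuckDB SQL, optionally transpiled via SQLGlot.

    `query` is assumed to be DuckDB-dialect SQL, for example the string
    returned by `DuckDBPyRelation.sql_query()`.
    """
    if dialect == "duckdb" and not pretty:
        # Fast path: nothing to transform, so no optional dependency is needed.
        return query
    try:
        import sqlglot  # optional dependency, imported lazily
    except ImportError as exc:
        msg = "Passing `dialect != 'duckdb'` or `pretty=True` requires SQLGlot."
        raise ModuleNotFoundError(msg) from exc
    # Parse as DuckDB SQL and generate SQL for the requested target dialect.
    return sqlglot.transpile(query, read="duckdb", write=dialect, pretty=pretty)[0]
```

With this shape, `to_sql(query)` returns the DuckDB string untouched, while something like `to_sql(query, dialect="databricks", pretty=True)` would only work when SQLGlot is installed.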



+__all__ = ["table"]
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dask = ["dask[dataframe]>=2024.8"]
 duckdb = ["duckdb>=1.1"]
 ibis = ["ibis-framework>=6.0.0", "rich", "packaging", "pyarrow_hotfix"]
 sqlframe = ["sqlframe>=3.22.0,!=3.39.3"]
+sql = ["duckdb>=1.1"]

 [dependency-groups]
 core = [
27 changes: 27 additions & 0 deletions tests/sql_test.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+import pytest
+
+import narwhals as nw
+from tests.utils import DUCKDB_VERSION
+
+
+def test_sql() -> None:
+    pytest.importorskip("duckdb")
+    if DUCKDB_VERSION < (1, 3):
+        pytest.skip()
+    from narwhals.sql import table
+
+    schema = {"date": nw.Date(), "price": nw.Int64(), "symbol": nw.String()}
+    assets = table("assets", schema)
+    result = (
+        assets.with_columns(
+            returns=(nw.col("price") / nw.col("price").shift(1)).over(
+                "symbol", order_by="date"
+            )
+        )
+        .to_native()
+        .sql_query()
+    )
+    expected = """SELECT date, price, symbol, (price / lag(price, 1) OVER (PARTITION BY symbol ORDER BY date ASC NULLS FIRST)) AS "returns" FROM main.assets"""
+    assert result == expected