Add SQLite client and converter #424

Merged · 3 commits · Feb 5, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -81,7 +81,7 @@ jobs:

- name: Test with pytest
env:
RECAP_URLS: '["postgresql://postgres:password@localhost:5432/testdb"]'
RECAP_URLS: '["postgresql://postgres:password@localhost:5432/testdb", "sqlite:///file:mem1?mode=memory&cache=shared&uri=true"]'
run: |
pdm run integration
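For context on the new URL (not part of the diff): `sqlite:///file:mem1?mode=memory&cache=shared&uri=true` uses SQLite's URI filename syntax so the integration tests can share one named in-memory database across connections. A standalone sketch of the same mechanism, using only the stdlib sqlite3 module:

```python
import sqlite3

# Two connections to the same named in-memory database. The data persists
# as long as at least one connection stays open.
uri = "file:mem1?mode=memory&cache=shared"
conn_a = sqlite3.connect(uri, uri=True)
conn_b = sqlite3.connect(uri, uri=True)

conn_a.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
conn_a.commit()

# conn_b sees the table created by conn_a. (sqlite_schema needs SQLite 3.33+;
# use sqlite_master on older versions.)
print(conn_b.execute(
    "SELECT name FROM sqlite_schema WHERE type='table'"
).fetchall())  # [('t',)]
```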

1 change: 1 addition & 0 deletions recap/clients/__init__.py
@@ -19,6 +19,7 @@
"mysql": "recap.clients.mysql.MysqlClient",
"postgresql": "recap.clients.postgresql.PostgresqlClient",
"snowflake": "recap.clients.snowflake.SnowflakeClient",
"sqlite": "recap.clients.sqlite.SQLiteClient",
"thrift+hms": "recap.clients.hive_metastore.HiveMetastoreClient",
}

148 changes: 148 additions & 0 deletions recap/clients/sqlite.py
@@ -0,0 +1,148 @@
from __future__ import annotations

from contextlib import contextmanager
from re import compile as re_compile
from typing import Any, Generator

from recap.clients.dbapi import Connection
from recap.converters.sqlite import SQLiteAffinity, SQLiteConverter
from recap.types import StructType

SQLITE3_CONNECT_ARGS = {
"database",
"timeout",
"detect_types",
"isolation_level",
"check_same_thread",
"factory",
"cached_statements",
"uri",
}


class SQLiteClient:
def __init__(self, connection: Connection) -> None:
self.connection = connection
self.converter = SQLiteConverter()

@staticmethod
@contextmanager
def create(url: str, **url_args) -> Generator[SQLiteClient, None, None]:
import sqlite3

# Strip sqlite:/// URL prefix
url_args["database"] = url[len("sqlite:///") :]

        # Only include kwargs that are valid for sqlite3.connect()
url_args = {k: v for k, v in url_args.items() if k in SQLITE3_CONNECT_ARGS}

with sqlite3.connect(**url_args) as client:
yield SQLiteClient(client) # type: ignore
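A minimal usage sketch of the factory (the file path is illustrative, not from this PR):

```python
# Hypothetical usage; table names depend on the database, of course.
with SQLiteClient.create("sqlite:///some/file.db") as client:
    print(client.ls())             # e.g. ['users', 'orders']
    print(client.schema("users"))  # a recap.types.StructType
```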

@staticmethod
def parse(method: str, **url_args) -> tuple[str, list[Any]]:
from urllib.parse import urlunparse

match method:
case "ls":
return (url_args["url"], [])
case "schema":
table = url_args["paths"].pop(-1)
connection_url = urlunparse(
[
url_args.get("dialect") or url_args.get("scheme"),
url_args.get("netloc"),
                    # Join the path segments; the leading slash is restored below
                    "/".join(url_args.get("paths", [])),
url_args.get("params"),
url_args.get("query"),
url_args.get("fragment"),
]
)

            # urlunparse does not add the double slash when netloc is
            # empty, but most URLs with an empty netloc should have one
            # (e.g. bigquery:// or sqlite:///some/file.db). Use ":///"
            # because join() above also omits the root path's leading
            # slash when there is no netloc.
if not url_args.get("netloc"):
connection_url = connection_url.replace(":", ":///", 1)

return (connection_url, [table])
case _:
raise ValueError("Invalid method")
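To make the `schema` branch concrete, a sketch of inputs and outputs, assuming the dispatcher passes split-URL pieces like these (the path is hypothetical):

```python
url, args = SQLiteClient.parse(
    "schema",
    scheme="sqlite",
    netloc="",
    paths=["some", "file.db", "users"],
)
# The last path segment is treated as the table; the rest is rebuilt into a
# connection URL, with ":///" restored because netloc is empty.
assert url == "sqlite:///some/file.db"
assert args == ["users"]
```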

def ls(self) -> list[str]:
cursor = self.connection.cursor()
cursor.execute("SELECT name FROM sqlite_schema WHERE type='table'")
[Contributor]
I could imagine also wanting to capture schema of views. Doesn't seem necessary for the first go at this.

[Contributor Author]
Ohh, good point. I'll open a follow-on GH issue for that. I hope it's as easy as type in ('table', 'view') 😅
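If the follow-on issue pans out, the change might be as small as the reply suggests — a hypothetical sketch, not part of this PR:

```python
# Hypothetical: list views alongside tables in ls().
cursor.execute("SELECT name FROM sqlite_schema WHERE type IN ('table', 'view')")
```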

return [row[0] for row in cursor.fetchall()]

def schema(self, table: str) -> StructType:
cursor = self.connection.cursor()

# Validate that table exists since we want to prevent SQL injections in
# the PRAGMA call
if not self._table_exists(table):
raise ValueError(f"Table '{table}' does not exist in the database.")

cursor.execute(f"PRAGMA table_info({table});")
names = [name[0].upper() for name in cursor.description]
rows = []

for row_cells in cursor.fetchall():
row = dict(zip(names, row_cells))
[Contributor] <3 python

row = self.add_information_schema(row)
rows.append(row)

return self.converter.to_recap(rows)
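For reference, the rows that `PRAGMA table_info` feeds into the loop above look like this (standalone sqlite3, illustrative table):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR(255))")

# Each row is (cid, name, type, notnull, dflt_value, pk); schema() uppercases
# the cursor's column names, so rows become e.g. {"CID": 0, "NAME": "id", ...}.
for row in conn.execute("PRAGMA table_info(users)"):
    print(row)
# (0, 'id', 'INTEGER', 0, None, 1)
# (1, 'name', 'VARCHAR(255)', 0, None, 0)
```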

def add_information_schema(self, row: dict[str, Any]) -> dict[str, Any]:
"""
SQLite does not have an INFORMATION_SCHEMA, so we need to add these
columns.

:param row: A row from the PRAGMA table_info() query.
:return: The row with the INFORMATION_SCHEMA columns added.
"""

is_not_null = row["NOTNULL"] or row["PK"]

# Set defaults.
information_schema_cols = {
"COLUMN_NAME": row["NAME"],
"IS_NULLABLE": "NO" if is_not_null else "YES",
"COLUMN_DEFAULT": row["DFLT_VALUE"],
"NUMERIC_PRECISION": None,
"NUMERIC_SCALE": None,
"CHARACTER_OCTET_LENGTH": None,
}

# Extract precision, scale, and octet length.
# This regex matches the following patterns:
# - <type>(<param1>(, <param2>)?)
# param1 can be a precision or octet length, and param2 can be a scale.
numeric_pattern = re_compile(r"(\w+)\((\d+)(?:,\s*(\d+))?\)")
[Contributor] GPT helped me understand this matches strings that look like function calls. Might be helpful to point to some documentation on what TYPE can look like.

[Contributor Author]
Added docs.
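A quick demonstration of what the pattern does and doesn't match (runnable standalone):

```python
from re import compile as re_compile

numeric_pattern = re_compile(r"(\w+)\((\d+)(?:,\s*(\d+))?\)")

print(numeric_pattern.search("VARCHAR(255)").groups())    # ('VARCHAR', '255', None)
print(numeric_pattern.search("DECIMAL(10, 2)").groups())  # ('DECIMAL', '10', '2')
print(numeric_pattern.search("INTEGER"))                  # None: no parameters
```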

param_match = numeric_pattern.search(row["TYPE"])

if param_match:
# Extract matched values
base_type, precision, scale = param_match.groups()
base_type = base_type.upper()
precision = int(precision)
scale = int(scale) if scale else 0

match SQLiteConverter.get_affinity(base_type):
case SQLiteAffinity.INTEGER | SQLiteAffinity.REAL | SQLiteAffinity.NUMERIC:
information_schema_cols["NUMERIC_PRECISION"] = precision
information_schema_cols["NUMERIC_SCALE"] = scale
case SQLiteAffinity.TEXT | SQLiteAffinity.BLOB:
information_schema_cols["CHARACTER_OCTET_LENGTH"] = precision

return row | information_schema_cols

def _table_exists(self, table: str) -> bool:
cursor = self.connection.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
)
return bool(cursor.fetchone())
[Contributor]
This confuses me since above I saw row[0] for row in cursor.fetchall(). If there is only one element SELECTed, is a tuple returned for that row, or just the value?

[Contributor Author]
The content isn't as important as whether the row exists or not. I'm doing the _table_exists check to prevent SQL injection attacks since you can't %s or ? parameterize pragma calls. The SELECT is getting all rows that exactly match the table string. If at least one row exists, then we assume the table exists.

The row[0] for row in cursor.fetchall() is for listing all tables in the database. But it's not guaranteed that a user will call schema() with a table from the ls() command, so I wanted to guard against injection.
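The constraint being described, demonstrated with standalone sqlite3 (not this PR's code): PRAGMA arguments can't be bound as parameters, so the table name has to be interpolated, and the existence check is what keeps arbitrary strings out of that interpolation.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER)")

# Parameter placeholders are rejected in PRAGMA arguments:
try:
    conn.execute("PRAGMA table_info(?)", ("users",))
except sqlite3.OperationalError as e:
    print(e)  # near "?": syntax error

# The guard: a malicious "table name" simply fails the lookup, so schema()
# raises before the string ever reaches an interpolated PRAGMA.
bad = "users); DROP TABLE users; --"
row = conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (bad,)
).fetchone()
print(row)  # None
```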

90 changes: 90 additions & 0 deletions recap/converters/sqlite.py
@@ -0,0 +1,90 @@
from enum import Enum
from typing import Any

from recap.converters.dbapi import DbapiConverter
from recap.types import (
BytesType,
FloatType,
IntType,
NullType,
RecapType,
StringType,
UnionType,
)

# SQLite's maximum length is 2^31-1 bytes, or 2147483647 bytes.
SQLITE_MAX_LENGTH = 2147483647


class SQLiteAffinity(Enum):
"""
SQLite uses column affinity to map non-STRICT table columns to values. See
https://www.sqlite.org/datatype3.html#type_affinity for details.
"""

INTEGER = "integer"
REAL = "real"
TEXT = "text"
BLOB = "blob"
NUMERIC = "numeric"


class SQLiteConverter(DbapiConverter):
def _parse_type(self, column_props: dict[str, Any]) -> RecapType:
column_name = column_props["COLUMN_NAME"]
column_type = column_props["TYPE"]
octet_length = column_props["CHARACTER_OCTET_LENGTH"]
precision = column_props["NUMERIC_PRECISION"]

match SQLiteConverter.get_affinity(column_type):
case SQLiteAffinity.INTEGER:
return IntType(bits=64)
case SQLiteAffinity.REAL:
if precision and precision <= 23:
return FloatType(bits=32)
return FloatType(bits=64)
case SQLiteAffinity.TEXT:
return StringType(bytes_=octet_length or SQLITE_MAX_LENGTH)
case SQLiteAffinity.BLOB:
return BytesType(bytes_=octet_length or SQLITE_MAX_LENGTH)
case SQLiteAffinity.NUMERIC:
# NUMERIC affinity may contain values using all five storage classes
return UnionType(
types=[
NullType(),
IntType(bits=64),
FloatType(bits=64),
StringType(bytes_=SQLITE_MAX_LENGTH),
BytesType(bytes_=SQLITE_MAX_LENGTH),
]
)
case _:
raise ValueError(
f"Unsupported `{column_type}` type for `{column_name}`"
)

@staticmethod
def get_affinity(column_type: str | None) -> SQLiteAffinity:
"""
Encode affinity rules as defined here:

https://www.sqlite.org/datatype3.html#determination_of_column_affinity

:param column_type: The column type to determine the affinity of.
:return: The affinity of the column type.
"""

column_type = (column_type or "").upper()

if not column_type:
return SQLiteAffinity.BLOB
elif "INT" in column_type:
return SQLiteAffinity.INTEGER
elif "CHAR" in column_type or "TEXT" in column_type or "CLOB" in column_type:
return SQLiteAffinity.TEXT
elif "BLOB" in column_type:
return SQLiteAffinity.BLOB
elif "REAL" in column_type or "FLOA" in column_type or "DOUB" in column_type:
return SQLiteAffinity.REAL
else:
return SQLiteAffinity.NUMERIC
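A few spot-checks of these rules (assuming the module imports as recap.converters.sqlite, per this diff):

```python
from recap.converters.sqlite import SQLiteAffinity, SQLiteConverter

assert SQLiteConverter.get_affinity("BIGINT") == SQLiteAffinity.INTEGER
assert SQLiteConverter.get_affinity("VARCHAR(255)") == SQLiteAffinity.TEXT
assert SQLiteConverter.get_affinity("DOUBLE PRECISION") == SQLiteAffinity.REAL
assert SQLiteConverter.get_affinity("DECIMAL(10, 2)") == SQLiteAffinity.NUMERIC
assert SQLiteConverter.get_affinity(None) == SQLiteAffinity.BLOB  # no declared type
```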
20 changes: 13 additions & 7 deletions recap/settings.py
@@ -58,12 +58,18 @@ def safe_urls(self) -> list[str]:
(
split_url.scheme,
netloc,
-                    split_url.path.strip("/"),
+                    split_url.path.rstrip("/"),
split_url.query,
split_url.fragment,
)
)

+        # urlunsplit does not add the double slash when netloc is empty, but
+        # most URLs with an empty netloc should have one (e.g. bigquery:// or
+        # sqlite:///some/file.db).
+        if not netloc:
+            sanitized_url = sanitized_url.replace(":", "://", 1)

safe_urls_list.append(sanitized_url)

return safe_urls_list
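The stdlib behavior being patched around, runnable on its own:

```python
from urllib.parse import urlunsplit

# With an empty netloc, urlunsplit emits no "//" at all:
print(urlunsplit(("sqlite", "", "/some/file.db", "", "")))  # sqlite:/some/file.db
print(urlunsplit(("bigquery", "", "", "", "")))             # bigquery:

# Rewriting the first ":" restores the conventional form:
print("sqlite:/some/file.db".replace(":", "://", 1))  # sqlite:///some/file.db
print("bigquery:".replace(":", "://", 1))             # bigquery://
```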
@@ -107,17 +113,17 @@ def unsafe_url(self, url: str, strict: bool = True) -> str:
(
url_split.scheme,
netloc,
-                url_path.strip("/"),
+                url_path.rstrip("/"),
query,
url_split.fragment or unsafe_url_split.fragment,
)
)

-            # Unsplit returns a URL with a trailing colon if the URL only
-            # has a scheme. This looks weird, so include trailing double
-            # slash (e.g. bigquery: to bigquery://).
-            if merged_url == f"{url_split.scheme}:":
-                merged_url += "//"
+            # urlunsplit does not add the double slash when netloc is empty,
+            # but most URLs with an empty netloc should have one (e.g.
+            # bigquery:// or sqlite:///some/file.db).
+            if not netloc:
+                merged_url = merged_url.replace(":", "://", 1)

return merged_url

14 changes: 13 additions & 1 deletion recap/types.py
@@ -52,8 +52,20 @@ def make_nullable(self) -> UnionType:
# Move field name to the union type
if "name" in extra_attrs:
union_attrs["name"] = extra_attrs.pop("name")
# Create a copy of the type with doc, default, and name removed
type_copy = RecapTypeClass(**attrs, **extra_attrs)
-        return UnionType([NullType(), type_copy], **union_attrs)
+        if RecapTypeClass == UnionType:
+            # Avoid UnionType(types=[NullType(), UnionType(...)]).
+            # Instead, just add NullType and default=None to the existing union.
+            type_copy = UnionType(**attrs, **extra_attrs, **union_attrs)
+        else:
+            type_copy = UnionType([type_copy], **union_attrs)
+        # If a NullType isn't in the UnionType, add it. Can't do `NullType() in
+        # type_copy.types` because equality checks extra_attrs, which can vary.
+        # Instead, just look for any NullType instance.
+        if not list(filter(lambda t: isinstance(t, NullType), type_copy.types)):
+            type_copy.types.insert(0, NullType())
+        return type_copy
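A sketch of the intended behavior, assuming constructors as they appear elsewhere in this diff (equality and attribute details may differ in the real library):

```python
from recap.types import IntType, NullType, UnionType

# Nullable unions stay flat: one union with a NullType prepended, instead of
# UnionType(types=[NullType(), UnionType(...)]).
nullable = UnionType([IntType(bits=32), IntType(bits=64)]).make_nullable()
assert isinstance(nullable.types[0], NullType)
assert not any(isinstance(t, UnionType) for t in nullable.types)
```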

def validate(self) -> None:
# Default to valid type
2 changes: 1 addition & 1 deletion tests/docker-compose.yml
@@ -86,6 +86,6 @@ services:
retries: 5

hive-metastore:
-    image: ghcr.io/criccomini/hive-metastore-standalone:latest
+    image: ghcr.io/recap-build/hive-metastore-standalone:latest
ports:
- "9083:9083"
5 changes: 4 additions & 1 deletion tests/integration/server/test_gateway.py
@@ -68,7 +68,10 @@ def teardown_class(cls):
def test_ls_root(self):
response = client.get("/ls")
assert response.status_code == 200
-        assert response.json() == ["postgresql://localhost:5432/testdb"]
+        assert response.json() == [
+            "postgresql://localhost:5432/testdb",
+            "sqlite:///file:mem1?mode=memory&cache=shared&uri=true",
+        ]

def test_ls_subpath(self):
response = client.get("/ls/postgresql://localhost:5432/testdb")