-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SQLite client and converter #424
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
from __future__ import annotations | ||
|
||
from contextlib import contextmanager | ||
from re import compile as re_compile | ||
from typing import Any, Generator | ||
|
||
from recap.clients.dbapi import Connection | ||
from recap.converters.sqlite import SQLiteAffinity, SQLiteConverter | ||
from recap.types import StructType | ||
|
||
# Keyword arguments that are forwarded to sqlite3.connect(). Any other keyword
# handed to SQLiteClient.create() is dropped before connecting.
SQLITE3_CONNECT_ARGS = {
    "database",
    "timeout",
    "detect_types",
    "isolation_level",
    "check_same_thread",
    "factory",
    "cached_statements",
    "uri",
}
|
||
|
||
class SQLiteClient: | ||
def __init__(self, connection: Connection) -> None: | ||
self.connection = connection | ||
self.converter = SQLiteConverter() | ||
|
||
@staticmethod | ||
@contextmanager | ||
def create(url: str, **url_args) -> Generator[SQLiteClient, None, None]: | ||
import sqlite3 | ||
|
||
# Strip sqlite:/// URL prefix | ||
url_args["database"] = url[len("sqlite:///") :] | ||
|
||
# Only include kwargs that are valid for PsycoPG2 parse_dsn() | ||
url_args = {k: v for k, v in url_args.items() if k in SQLITE3_CONNECT_ARGS} | ||
|
||
with sqlite3.connect(**url_args) as client: | ||
yield SQLiteClient(client) # type: ignore | ||
|
||
@staticmethod | ||
def parse(method: str, **url_args) -> tuple[str, list[Any]]: | ||
from urllib.parse import urlunparse | ||
|
||
match method: | ||
case "ls": | ||
return (url_args["url"], []) | ||
case "schema": | ||
table = url_args["paths"].pop(-1) | ||
connection_url = urlunparse( | ||
[ | ||
url_args.get("dialect") or url_args.get("scheme"), | ||
url_args.get("netloc"), | ||
# Include / prefix for paths | ||
"/".join(url_args.get("paths", [])), | ||
url_args.get("params"), | ||
url_args.get("query"), | ||
url_args.get("fragment"), | ||
] | ||
) | ||
|
||
# urlunsplit does not double slashes if netloc is empty. But most | ||
# URLs with empty netloc should have a double slash (e.g. | ||
# bigquery:// or sqlite:///some/file.db). Include an extra "/" | ||
# because the root path is not included with an empty netloc | ||
# and join(). | ||
if not url_args.get("netloc"): | ||
connection_url = connection_url.replace(":", ":///", 1) | ||
|
||
return (connection_url, [table]) | ||
case _: | ||
raise ValueError("Invalid method") | ||
|
||
def ls(self) -> list[str]: | ||
cursor = self.connection.cursor() | ||
cursor.execute("SELECT name FROM sqlite_schema WHERE type='table'") | ||
return [row[0] for row in cursor.fetchall()] | ||
|
||
def schema(self, table: str) -> StructType: | ||
cursor = self.connection.cursor() | ||
|
||
# Validate that table exists since we want to prevent SQL injections in | ||
# the PRAGMA call | ||
if not self._table_exists(table): | ||
raise ValueError(f"Table '{table}' does not exist in the database.") | ||
|
||
cursor.execute(f"PRAGMA table_info({table});") | ||
names = [name[0].upper() for name in cursor.description] | ||
rows = [] | ||
|
||
for row_cells in cursor.fetchall(): | ||
row = dict(zip(names, row_cells)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. <3 python |
||
row = self.add_information_schema(row) | ||
rows.append(row) | ||
|
||
return self.converter.to_recap(rows) | ||
|
||
def add_information_schema(self, row: dict[str, Any]) -> dict[str, Any]: | ||
""" | ||
SQLite does not have an INFORMATION_SCHEMA, so we need to add these | ||
columns. | ||
|
||
:param row: A row from the PRAGMA table_info() query. | ||
:return: The row with the INFORMATION_SCHEMA columns added. | ||
""" | ||
|
||
is_not_null = row["NOTNULL"] or row["PK"] | ||
|
||
# Set defaults. | ||
information_schema_cols = { | ||
"COLUMN_NAME": row["NAME"], | ||
"IS_NULLABLE": "NO" if is_not_null else "YES", | ||
"COLUMN_DEFAULT": row["DFLT_VALUE"], | ||
"NUMERIC_PRECISION": None, | ||
"NUMERIC_SCALE": None, | ||
"CHARACTER_OCTET_LENGTH": None, | ||
} | ||
|
||
# Extract precision, scale, and octet length. | ||
# This regex matches the following patterns: | ||
# - <type>(<param1>(, <param2>)?) | ||
# param1 can be a precision or octet length, and param2 can be a scale. | ||
numeric_pattern = re_compile(r"(\w+)\((\d+)(?:,\s*(\d+))?\)") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GPT helped me understand this matches strings that look like function calls. Might be helpful to point to some documentation on what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added docs. |
||
param_match = numeric_pattern.search(row["TYPE"]) | ||
|
||
if param_match: | ||
# Extract matched values | ||
base_type, precision, scale = param_match.groups() | ||
base_type = base_type.upper() | ||
precision = int(precision) | ||
scale = int(scale) if scale else 0 | ||
|
||
match SQLiteConverter.get_affinity(base_type): | ||
case SQLiteAffinity.INTEGER | SQLiteAffinity.REAL | SQLiteAffinity.NUMERIC: | ||
information_schema_cols["NUMERIC_PRECISION"] = precision | ||
information_schema_cols["NUMERIC_SCALE"] = scale | ||
case SQLiteAffinity.TEXT | SQLiteAffinity.BLOB: | ||
information_schema_cols["CHARACTER_OCTET_LENGTH"] = precision | ||
|
||
return row | information_schema_cols | ||
|
||
def _table_exists(self, table: str) -> bool: | ||
cursor = self.connection.cursor() | ||
cursor.execute( | ||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) | ||
) | ||
return bool(cursor.fetchone()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This confuses me since above I saw There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The content isn't as important as whether the row exists or not. I'm doing the The |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from enum import Enum | ||
from typing import Any | ||
|
||
from recap.converters.dbapi import DbapiConverter | ||
from recap.types import ( | ||
BytesType, | ||
FloatType, | ||
IntType, | ||
NullType, | ||
RecapType, | ||
StringType, | ||
UnionType, | ||
) | ||
|
||
# Upper bound used for text and blob columns that declare no length:
# SQLite's maximum length is 2^31-1 bytes (2147483647).
SQLITE_MAX_LENGTH = 2**31 - 1
|
||
|
||
class SQLiteAffinity(Enum):
    """
    Column type affinities of SQLite.

    SQLite uses a column's affinity to map values stored in non-STRICT table
    columns. The rules are described at
    https://www.sqlite.org/datatype3.html#type_affinity.
    """

    INTEGER = "integer"
    REAL = "real"
    TEXT = "text"
    BLOB = "blob"
    NUMERIC = "numeric"
|
||
|
||
class SQLiteConverter(DbapiConverter):
    """
    Converter that maps SQLite column descriptions to Recap types using the
    column's type affinity.
    """

    def _parse_type(self, column_props: dict[str, Any]) -> RecapType:
        """
        Build the Recap type for a single column.

        :param column_props: Column description. Reads the ``COLUMN_NAME``,
            ``TYPE``, ``CHARACTER_OCTET_LENGTH`` and ``NUMERIC_PRECISION``
            keys.
        :return: The Recap type chosen for the column's affinity.
        :raises ValueError: If the affinity is not one of the handled values.
        """

        column_name = column_props["COLUMN_NAME"]
        column_type = column_props["TYPE"]
        max_bytes = column_props["CHARACTER_OCTET_LENGTH"]
        digits = column_props["NUMERIC_PRECISION"]

        affinity = SQLiteConverter.get_affinity(column_type)

        if affinity == SQLiteAffinity.INTEGER:
            return IntType(bits=64)

        if affinity == SQLiteAffinity.REAL:
            # A declared precision of at most 23 selects a 32-bit float;
            # anything else (including no precision) selects 64 bits.
            float_bits = 32 if digits and digits <= 23 else 64
            return FloatType(bits=float_bits)

        if affinity == SQLiteAffinity.TEXT:
            return StringType(bytes_=max_bytes or SQLITE_MAX_LENGTH)

        if affinity == SQLiteAffinity.BLOB:
            return BytesType(bytes_=max_bytes or SQLITE_MAX_LENGTH)

        if affinity == SQLiteAffinity.NUMERIC:
            # NUMERIC affinity may contain values using all five storage
            # classes, so the column is modelled as a union of all of them.
            storage_classes = [
                NullType(),
                IntType(bits=64),
                FloatType(bits=64),
                StringType(bytes_=SQLITE_MAX_LENGTH),
                BytesType(bytes_=SQLITE_MAX_LENGTH),
            ]
            return UnionType(types=storage_classes)

        raise ValueError(f"Unsupported `{column_type}` type for `{column_name}`")

    @staticmethod
    def get_affinity(column_type: str | None) -> SQLiteAffinity:
        """
        Determine a column's affinity from its declared type, following

        https://www.sqlite.org/datatype3.html#determination_of_column_affinity

        :param column_type: The declared column type; may be None or empty.
        :return: The affinity of the column type.
        """

        declared = column_type.upper() if column_type else ""

        # A column without a declared type has BLOB affinity.
        if declared == "":
            return SQLiteAffinity.BLOB

        # Rules are checked in order; the first rule with a fragment contained
        # in the declared type wins.
        ordered_rules = (
            (("INT",), SQLiteAffinity.INTEGER),
            (("CHAR", "TEXT", "CLOB"), SQLiteAffinity.TEXT),
            (("BLOB",), SQLiteAffinity.BLOB),
            (("REAL", "FLOA", "DOUB"), SQLiteAffinity.REAL),
        )
        for fragments, affinity in ordered_rules:
            if any(fragment in declared for fragment in fragments):
                return affinity

        # Everything else falls back to NUMERIC.
        return SQLiteAffinity.NUMERIC
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could imagine also wanting to capture schema of views. Doesn't seem necessary for the first go at this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohh, good point. I'll open a follow-on GH issue for that. I hope it's as easy as `type in ('table', 'view')` 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#425