diff --git a/sqeleton/abcs/compiler.py b/sqeleton/abcs/compiler.py index dee9f69..9cdbacf 100644 --- a/sqeleton/abcs/compiler.py +++ b/sqeleton/abcs/compiler.py @@ -4,8 +4,7 @@ class AbstractCompiler(ABC): @abstractmethod - def compile(self, elem: Any, params: Dict[str, Any] = None) -> str: - ... + def compile(self, elem: Any, params: Dict[str, Any] = None) -> str: ... class Compilable(ABC): diff --git a/sqeleton/abcs/database_types.py b/sqeleton/abcs/database_types.py index 59843c5..4b0c77d 100644 --- a/sqeleton/abcs/database_types.py +++ b/sqeleton/abcs/database_types.py @@ -126,9 +126,7 @@ class String_FixedAlphanum(String_Alphanum): def make_value(self, value): if len(value) != self.length: - raise ValueError( - f"Expected alphanumeric value of length {self.length}, but got '{value}'." - ) + raise ValueError(f"Expected alphanumeric value of length {self.length}, but got '{value}'.") return self.python_type(value, max_len=self.length) diff --git a/sqeleton/bound_exprs.py b/sqeleton/bound_exprs.py index b193de2..aa85de5 100644 --- a/sqeleton/bound_exprs.py +++ b/sqeleton/bound_exprs.py @@ -56,7 +56,7 @@ def with_schema(self, schema): table_path = self.node.replace(schema=schema) return self.replace(node=table_path) - def query_schema(self, *, refine: bool = True, refine_where = None, case_sensitive=True): + def query_schema(self, *, refine: bool = True, refine_where=None, case_sensitive=True): table_path = self.node if table_path.schema: @@ -77,5 +77,6 @@ def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tu if TYPE_CHECKING: + class BoundTable(BoundTable, TablePath): pass diff --git a/sqeleton/databases/base.py b/sqeleton/databases/base.py index 100c8ea..ec2c1be 100644 --- a/sqeleton/databases/base.py +++ b/sqeleton/databases/base.py @@ -331,16 +331,13 @@ def compile(self, sql_ast): # logger.setLevel(level) @overload - def query(self, query_input: QueryInput) -> Any: - ... + def query(self, query_input: QueryInput) -> Any: ... @overload - def query(self, query_input: QueryInput, res_type: None) -> Any: - ... + def query(self, query_input: QueryInput, res_type: None) -> Any: ... @overload - def query(self, query_input: QueryInput, res_type: Type[TRes]) -> TRes: - ... + def query(self, query_input: QueryInput, res_type: Type[TRes]) -> TRes: ... def query(self, query_input, res_type=None): """Query the given SQL code/AST, and attempt to convert the result to type 'res_type' diff --git a/sqeleton/databases/clickhouse.py b/sqeleton/databases/clickhouse.py index f5cf620..76b4aab 100644 --- a/sqeleton/databases/clickhouse.py +++ b/sqeleton/databases/clickhouse.py @@ -102,7 +102,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: class Dialect(BaseDialect): name = "Clickhouse" ROUNDS_ON_PREC_LOSS = False - ARG_SYMBOL = None # TODO Clickhouse only supports named parameters, not positional + ARG_SYMBOL = None # TODO Clickhouse only supports named parameters, not positional TYPE_CLASSES = { "Int8": Integer, "Int16": Integer, diff --git a/sqeleton/databases/databricks.py b/sqeleton/databases/databricks.py index 9d11dd6..ac6d097 100644 --- a/sqeleton/databases/databricks.py +++ b/sqeleton/databases/databricks.py @@ -178,7 +178,9 @@ def _process_table_schema( self._refine_coltypes(path, col_dict, where) return col_dict - def process_query_table_schema(self, path: DbPath, raw_schema: Dict[str, Tuple], refine: bool = True, refine_where: Optional[str] = None) -> Tuple[Dict[str, ColType], Optional[list]]: + def process_query_table_schema( + self, path: DbPath, raw_schema: Dict[str, Tuple], refine: bool = True, refine_where: Optional[str] = None + ) -> Tuple[Dict[str, ColType], Optional[list]]: if not refine: raise NotImplementedError() return self._process_table_schema(path, raw_schema, list(raw_schema), refine_where), None diff --git a/sqeleton/databases/presto.py b/sqeleton/databases/presto.py index 606169d..4f4416b 100644 --- a/sqeleton/databases/presto.py +++ b/sqeleton/databases/presto.py @@ -23,7 +23,17 @@ Native_UUID, ) from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue -from .base import BaseDialect, Database, QueryResult, import_helper, ThreadLocalInterpreter, Mixin_Schema, Mixin_RandomSample, SqlCode, logger +from .base import ( + BaseDialect, + Database, + QueryResult, + import_helper, + ThreadLocalInterpreter, + Mixin_Schema, + Mixin_RandomSample, + SqlCode, + logger, +) from .base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -76,7 +86,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: class Dialect(BaseDialect, Mixin_Schema): name = "Presto" ROUNDS_ON_PREC_LOSS = True - ARG_SYMBOL = None # Not implemented by Presto + ARG_SYMBOL = None # Not implemented by Presto TYPE_CLASSES = { # Timestamps "timestamp with time zone": TimestampTZ, @@ -186,7 +196,7 @@ def _query(self, sql_code: SqlCode) -> Optional[QueryResult]: if isinstance(sql_code, ThreadLocalInterpreter): return sql_code.apply_queries(partial(query_cursor, c)) elif isinstance(sql_code, str): - sql_code = CompiledCode(sql_code, [], None) # Unknown type. #TODO: Should we guess? + sql_code = CompiledCode(sql_code, [], None) # Unknown type. #TODO: Should we guess? return query_cursor(c, sql_code) diff --git a/sqeleton/databases/trino.py b/sqeleton/databases/trino.py index f07b149..08acf00 100644 --- a/sqeleton/databases/trino.py +++ b/sqeleton/databases/trino.py @@ -23,7 +23,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: else: s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')" + return ( + f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')" + ) def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: if isinstance(coltype, String_UUID): @@ -55,9 +57,7 @@ def __init__(self, **kw): self.default_schema = kw.get("schema") if kw.get("password"): - kw["auth"] = trino.auth.BasicAuthentication( - kw.pop("user"), kw.pop("password") - ) + kw["auth"] = trino.auth.BasicAuthentication(kw.pop("user"), kw.pop("password")) kw["http_scheme"] = "https" cert = kw.pop("cert", None) @@ -65,7 +65,6 @@ def __init__(self, **kw): if cert is not None: self._conn._http_session.verify = cert - @property def is_autocommit(self) -> bool: - return True \ No newline at end of file + return True diff --git a/sqeleton/queries/extras.py b/sqeleton/queries/extras.py index 382407c..f2e8c90 100644 --- a/sqeleton/queries/extras.py +++ b/sqeleton/queries/extras.py @@ -9,8 +9,6 @@ from .ast_classes import Expr, ExprNode, Concat, Code - - @dataclass class NormalizeAsString(ExprNode): expr: ExprNode @@ -35,7 +33,6 @@ def compile_node(c: Compiler, n: NormalizeAsString) -> str: expr = c.compile(n.expr) return c.dialect.normalize_value_by_type(expr, n.expr_type or n.expr.type) - @md def compile_node(c: Compiler, n: ApplyFuncAndNormalizeAsString) -> str: expr = n.expr @@ -56,7 +53,6 @@ def compile_node(c: Compiler, n: ApplyFuncAndNormalizeAsString) -> str: return c.compile(expr) - @md def compile_node(c: Compiler, n: Checksum) -> str: if len(n.exprs) > 1: diff --git a/sqeleton/schema.py b/sqeleton/schema.py index a557ba8..64fd4f6 100644 --- a/sqeleton/schema.py +++ b/sqeleton/schema.py @@ -10,6 +10,7 @@ Schema = CaseAwareMapping + class TableType: pass # TODO: This should replace the current Schema type @@ -21,6 +22,7 @@ def is_superclass(cls, t): SchemaInput = Union[Type[TableType], Schema, dict] + @dataclass class Options: default: Any = None @@ -29,11 +31,13 @@ class Options: # TODO: foreign_key, unique # TODO: index? + @dataclass class _Field: type: type options: Options + class _Schema(CaseAwareMapping[Union[type, _Field]]): pass @@ -41,6 +45,7 @@ class _Schema(CaseAwareMapping[Union[type, _Field]]): def make(cls, schema: SchemaInput): assert schema if TableType.is_superclass(schema): + def _make_field(k: str, v: type): field = getattr(schema, k) if field: @@ -49,7 +54,7 @@ def _make_field(k: str, v: type): return _Field(v, field) return v - schema = CaseSensitiveDict({k:_make_field(k, v) for k,v in schema.__annotations__.items()}) + schema = CaseSensitiveDict({k: _make_field(k, v) for k, v in schema.__annotations__.items()}) elif isinstance(schema, CaseAwareMapping): pass @@ -59,7 +64,8 @@ def _make_field(k: str, v: type): return schema -def options(**kw) -> Any: # Any, so that type-checking doesn't complain + +def options(**kw) -> Any: # Any, so that type-checking doesn't complain return Options(**kw) diff --git a/sqeleton/utils.py b/sqeleton/utils.py index 1ceee4a..ac1e6e5 100644 --- a/sqeleton/utils.py +++ b/sqeleton/utils.py @@ -92,12 +92,10 @@ def match_regexps(regexps: Dict[str, Any], s: str) -> Generator[tuple, None, Non class CaseAwareMapping(MutableMapping[str, V]): @abstractmethod - def get_key(self, key: str) -> str: - ... + def get_key(self, key: str) -> str: ... @abstractmethod - def __init__(self, initial): - ... + def __init__(self, initial): ... def new(self, initial=()): return type(self)(initial)