diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 72bd9f2b47..64e77f8140 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -37,6 +37,8 @@ from .query import ( AndQuery, FieldQuery, + FieldQueryType, + FieldSort, MatchQuery, NullSort, Query, @@ -47,6 +49,15 @@ if TYPE_CHECKING: from types import TracebackType + from .query import SQLiteType + + D = TypeVar("D", bound="Database", default=Any) +else: + D = TypeVar("D", bound="Database") + + +FlexAttrs = dict[str, str] + class DBAccessError(Exception): """The SQLite database became inaccessible. @@ -236,7 +247,7 @@ def __len__(self) -> int: # Abstract base for model classes. -class Model(ABC): +class Model(ABC, Generic[D]): """An abstract object representing an object in the database. Model objects act like dictionaries (i.e., they allow subscript access like ``obj['field']``). The same field set is available via attribute @@ -284,12 +295,12 @@ class Model(ABC): """Optional Types for non-fixed (i.e., flexible and computed) fields. """ - _sorts: dict[str, type[Sort]] = {} + _sorts: dict[str, type[FieldSort]] = {} """Optional named sort criteria. The keys are strings and the values are subclasses of `Sort`. """ - _queries: dict[str, type[FieldQuery]] = {} + _queries: dict[str, FieldQueryType] = {} """Named queries that use a field-like `name:value` syntax but which do not relate to any specific field. """ @@ -306,7 +317,7 @@ class Model(ABC): """ @cached_classproperty - def _relation(cls) -> type[Model]: + def _relation(cls): """The model that this model is closely related to.""" return cls @@ -347,7 +358,7 @@ def _template_funcs(self) -> Mapping[str, Callable[[str], str]]: # Basic operation. - def __init__(self, db: Database | None = None, **values): + def __init__(self, db: D | None = None, **values): """Create a new object with an optional Database association and initial field values. """ @@ -363,7 +374,7 @@ def __init__(self, db: Database | None = None, **values): @classmethod def _awaken( cls: type[AnyModel], - db: Database | None = None, + db: D | None = None, fixed_values: dict[str, Any] = {}, flex_values: dict[str, Any] = {}, ) -> AnyModel: @@ -393,7 +404,7 @@ def clear_dirty(self): if self._db: self._revision = self._db.revision - def _check_db(self, need_id: bool = True) -> Database: + def _check_db(self, need_id: bool = True) -> D: """Ensure that this object is associated with a database row: it has a reference to a database (`_db`) and an id. A ValueError exception is raised otherwise. @@ -574,7 +585,7 @@ def store(self, fields: Iterable[str] | None = None): # Build assignments for query. assignments = [] - subvars = [] + subvars: list[SQLiteType] = [] for key in fields: if key != "id" and key in self._dirty: self._dirty.remove(key) @@ -637,7 +648,7 @@ def remove(self): f"DELETE FROM {self._flex_table} WHERE entity_id=?", (self.id,) ) - def add(self, db: Database | None = None): + def add(self, db: D | None = None): """Add the object to the library database. This object must be associated with a database; you can provide one via the `db` parameter or use the currently associated database. @@ -714,7 +725,7 @@ def field_query( cls, field, pattern, - query_cls: type[FieldQuery] = MatchQuery, + query_cls: FieldQueryType = MatchQuery, ) -> FieldQuery: """Get a `FieldQuery` for this model.""" return query_cls(field, pattern, field in cls._fields) @@ -722,8 +733,8 @@ def field_query( @classmethod def all_fields_query( cls: type[Model], - pats: Mapping, - query_cls: type[FieldQuery] = MatchQuery, + pats: Mapping[str, str], + query_cls: FieldQueryType = MatchQuery, ): """Get a query that matches many fields with different patterns. @@ -749,8 +760,8 @@ class Results(Generic[AnyModel]): def __init__( self, model_class: type[AnyModel], - rows: list[Mapping], - db: Database, + rows: list[sqlite3.Row], + db: D, flex_rows, query: Query | None = None, sort=None, @@ -834,9 +845,9 @@ def __iter__(self) -> Iterator[AnyModel]: # Objects are pre-sorted (i.e., by the database). return self._get_objects() - def _get_indexed_flex_attrs(self) -> Mapping: + def _get_indexed_flex_attrs(self) -> dict[int, FlexAttrs]: """Index flexible attributes by the entity id they belong to""" - flex_values: dict[int, dict[str, Any]] = {} + flex_values: dict[int, FlexAttrs] = {} for row in self.flex_rows: if row["entity_id"] not in flex_values: flex_values[row["entity_id"]] = {} @@ -845,7 +856,9 @@ def _get_indexed_flex_attrs(self) -> Mapping: return flex_values - def _make_model(self, row, flex_values: dict = {}) -> AnyModel: + def _make_model( + self, row: sqlite3.Row, flex_values: FlexAttrs = {} + ) -> AnyModel: """Create a Model object for the given row""" cols = dict(row) values = {k: v for (k, v) in cols.items() if not k[:4] == "flex"} @@ -954,14 +967,16 @@ def __exit__( self._mutated = False self.db._db_lock.release() - def query(self, statement: str, subvals: Sequence = ()) -> list: + def query( + self, statement: str, subvals: Sequence[SQLiteType] = () + ) -> list[sqlite3.Row]: """Execute an SQL statement with substitution values and return a list of rows from the database. """ cursor = self.db._connection().execute(statement, subvals) return cursor.fetchall() - def mutate(self, statement: str, subvals: Sequence = ()) -> Any: + def mutate(self, statement: str, subvals: Sequence[SQLiteType] = ()) -> Any: """Execute an SQL statement with substitution values and return the row ID of the last affected row. """ @@ -1122,7 +1137,7 @@ def _close(self): conn.close() @contextlib.contextmanager - def _tx_stack(self) -> Generator[list]: + def _tx_stack(self) -> Generator[list[Transaction]]: """A context manager providing access to the current thread's transaction stack. The context manager synchronizes access to the stack map. Transactions should never migrate across threads. diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 0b52b0f222..da621a767a 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -19,7 +19,7 @@ import re import unicodedata from abc import ABC, abstractmethod -from collections.abc import Collection, Iterator, MutableSequence, Sequence +from collections.abc import Iterator, MutableSequence, Sequence from datetime import datetime, timedelta from functools import reduce from operator import mul, or_ @@ -30,6 +30,11 @@ if TYPE_CHECKING: from beets.dbcore import Model + from beets.dbcore.db import AnyModel + + P = TypeVar("P", default=Any) +else: + P = TypeVar("P") class ParsingError(ValueError): @@ -107,9 +112,9 @@ def __hash__(self) -> int: return hash(type(self)) -P = TypeVar("P") -SQLiteType = Union[str, bytes, float, int, memoryview] +SQLiteType = Union[str, bytes, float, int, memoryview, None] AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType) +FieldQueryType = type["FieldQuery"] class FieldQuery(Query, Generic[P]): @@ -289,7 +294,7 @@ def _normalize(s: str) -> str: return unicodedata.normalize("NFC", s) @classmethod - def string_match(cls, pattern: Pattern, value: str) -> bool: + def string_match(cls, pattern: Pattern[str], value: str) -> bool: return pattern.search(cls._normalize(value)) is not None @@ -451,7 +456,7 @@ def field_names(self) -> set[str]: """Return a set with field names that this query operates on.""" return reduce(or_, (sq.field_names for sq in self.subqueries)) - def __init__(self, subqueries: Sequence = ()): + def __init__(self, subqueries: Sequence[Query] = ()): self.subqueries = subqueries # Act like a sequence. @@ -462,7 +467,7 @@ def __len__(self) -> int: def __getitem__(self, key): return self.subqueries[key] - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[Query]: return iter(self.subqueries) def __contains__(self, subq) -> bool: @@ -476,7 +481,7 @@ def clause_with_joiner( all subqueries with the string joiner (padded by spaces). """ clause_parts = [] - subvals = [] + subvals: list[SQLiteType] = [] for subq in self.subqueries: subq_clause, subq_subvals = subq.clause() if not subq_clause: @@ -511,7 +516,7 @@ def field_names(self) -> set[str]: """Return a set with field names that this query operates on.""" return set(self.fields) - def __init__(self, pattern, fields, cls: type[FieldQuery]): + def __init__(self, pattern, fields, cls: FieldQueryType): self.pattern = pattern self.fields = fields self.query_class = cls @@ -549,7 +554,7 @@ class MutableCollectionQuery(CollectionQuery): query is initialized. """ - subqueries: MutableSequence + subqueries: MutableSequence[Query] def __setitem__(self, key, value): self.subqueries[key] = value @@ -894,7 +899,7 @@ def order_clause(self) -> str | None: """ return None - def sort(self, items: list) -> list: + def sort(self, items: list[AnyModel]) -> list[AnyModel]: """Sort the list of objects and return a list.""" return sorted(items) @@ -988,7 +993,7 @@ def __init__( self.ascending = ascending self.case_insensitive = case_insensitive - def sort(self, objs: Collection): + def sort(self, objs: list[AnyModel]) -> list[AnyModel]: # TODO: Conversion and null-detection here. In Python 3, # comparisons with None fail. We should also support flexible # attributes with different types without falling over. @@ -1047,7 +1052,7 @@ def is_slow(self) -> bool: class NullSort(Sort): """No sorting. Leave results unsorted.""" - def sort(self, items: list) -> list: + def sort(self, items: list[AnyModel]) -> list[AnyModel]: return items def __nonzero__(self) -> bool: @@ -1061,3 +1066,23 @@ def __eq__(self, other) -> bool: def __hash__(self) -> int: return 0 + + +class SmartArtistSort(FieldSort): + """Sort by artist (either album artist or track artist), + prioritizing the sort field over the raw field. + """ + + def order_clause(self): + order = "ASC" if self.ascending else "DESC" + collate = "COLLATE NOCASE" if self.case_insensitive else "" + field = self.field + + return f"COALESCE(NULLIF({field}_sort, ''), {field}) {collate} {order}" + + def sort(self, objs: list[AnyModel]) -> list[AnyModel]: + def key(o): + val = o[f"{self.field}_sort"] or o[self.field] + return val.lower() if self.case_insensitive else val + + return sorted(objs, key=key, reverse=not self.ascending) diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index f71f9c25c8..2896326680 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -25,7 +25,9 @@ if TYPE_CHECKING: from collections.abc import Collection, Sequence - from .query import Sort + from .query import FieldQueryType, Sort + + Prefixes = dict[str, FieldQueryType] PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. @@ -41,10 +43,10 @@ def parse_query_part( part: str, - query_classes: dict[str, type[query.FieldQuery]] = {}, - prefixes: dict = {}, + query_classes: dict[str, FieldQueryType] = {}, + prefixes: Prefixes = {}, default_class: type[query.SubstringQuery] = query.SubstringQuery, -) -> tuple[str | None, str, type[query.FieldQuery], bool]: +) -> tuple[str | None, str, FieldQueryType, bool]: """Parse a single *query part*, which is a chunk of a complete query string representing a single criterion. @@ -111,7 +113,7 @@ def parse_query_part( def construct_query_part( model_cls: type[Model], - prefixes: dict, + prefixes: Prefixes, query_part: str, ) -> query.Query: """Parse a *query part* string and return a :class:`Query` object. @@ -133,7 +135,7 @@ def construct_query_part( # Use `model_cls` to build up a map from field (or query) names to # `Query` classes. - query_classes: dict[str, type[query.FieldQuery]] = {} + query_classes: dict[str, FieldQueryType] = {} for k, t in itertools.chain( model_cls._fields.items(), model_cls._types.items() ): @@ -179,7 +181,7 @@ def construct_query_part( def query_from_strings( query_cls: type[query.CollectionQuery], model_cls: type[Model], - prefixes: dict, + prefixes: Prefixes, query_parts: Collection[str], ) -> query.Query: """Creates a collection query of type `query_cls` from a list of @@ -213,16 +215,16 @@ def construct_sort_part( assert direction in ("+", "-"), "part must end with + or -" is_ascending = direction == "+" - if field in model_cls._sorts: - sort = model_cls._sorts[field]( - model_cls, is_ascending, case_insensitive - ) + if sort_cls := model_cls._sorts.get(field): + if isinstance(sort_cls, query.SmartArtistSort): + field = "albumartist" if model_cls.__name__ == "Album" else "artist" elif field in model_cls._fields: - sort = query.FixedFieldSort(field, is_ascending, case_insensitive) + sort_cls = query.FixedFieldSort else: # Flexible or computed. - sort = query.SlowFieldSort(field, is_ascending, case_insensitive) - return sort + sort_cls = query.SlowFieldSort + + return sort_cls(field, is_ascending, case_insensitive) def sort_from_strings( @@ -247,7 +249,7 @@ def sort_from_strings( def parse_sorted_query( model_cls: type[Model], parts: list[str], - prefixes: dict = {}, + prefixes: Prefixes = {}, case_insensitive: bool = True, ) -> tuple[query.Query, Sort]: """Given a list of strings, create the `Query` and `Sort` that they diff --git a/beets/dbcore/types.py b/beets/dbcore/types.py index 7c546eb92c..2a64b2ed94 100644 --- a/beets/dbcore/types.py +++ b/beets/dbcore/types.py @@ -18,11 +18,17 @@ import typing from abc import ABC -from typing import Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from beets.util import str2bool -from .query import BooleanQuery, FieldQuery, NumericQuery, SubstringQuery +from .query import ( + BooleanQuery, + FieldQueryType, + NumericQuery, + SQLiteType, + SubstringQuery, +) class ModelType(typing.Protocol): @@ -37,8 +43,12 @@ def __init__(self, value: Any = None): ... # Generic type variables, used for the value type T and null type N (if # nullable, else T and N are set to the same type for the concrete subclasses # of Type). -N = TypeVar("N") -T = TypeVar("T", bound=ModelType) +if TYPE_CHECKING: + N = TypeVar("N", default=Any) + T = TypeVar("T", bound=ModelType, default=Any) +else: + N = TypeVar("N") + T = TypeVar("T", bound=ModelType) class Type(ABC, Generic[T, N]): @@ -51,7 +61,7 @@ class Type(ABC, Generic[T, N]): """The SQLite column type for the value. """ - query: type[FieldQuery] = SubstringQuery + query: FieldQueryType = SubstringQuery """The `Query` subclass to be used when querying the field. """ @@ -107,10 +117,7 @@ def normalize(self, value: Any) -> T | N: # `self.model_type(value)` return cast(T, value) - def from_sql( - self, - sql_value: None | int | float | str | bytes, - ) -> T | N: + def from_sql(self, sql_value: SQLiteType) -> T | N: """Receives the value stored in the SQL backend and return the value to be stored in the model. @@ -131,7 +138,7 @@ def from_sql( else: return self.normalize(sql_value) - def to_sql(self, model_value: Any) -> None | int | float | str | bytes: + def to_sql(self, model_value: Any) -> SQLiteType: """Convert a value as stored in the model object to a value used by the database adapter. """ @@ -234,7 +241,7 @@ class BaseFloat(Type[float, N]): """ sql = "REAL" - query: type[FieldQuery[Any]] = NumericQuery + query: FieldQueryType = NumericQuery model_type = float def __init__(self, digits: int = 1): diff --git a/beets/library.py b/beets/library.py index e5f26e422e..11ef7c416b 100644 --- a/beets/library.py +++ b/beets/library.py @@ -295,47 +295,6 @@ def parse(self, string): return self.null -# Library-specific sort types. - - -class SmartArtistSort(dbcore.query.Sort): - """Sort by artist (either album artist or track artist), - prioritizing the sort field over the raw field. - """ - - def __init__(self, model_cls, ascending=True, case_insensitive=True): - self.album = model_cls is Album - self.ascending = ascending - self.case_insensitive = case_insensitive - - def order_clause(self): - order = "ASC" if self.ascending else "DESC" - field = "albumartist" if self.album else "artist" - collate = "COLLATE NOCASE" if self.case_insensitive else "" - - return f"COALESCE(NULLIF({field}_sort, ''), {field}) {collate} {order}" - - def sort(self, objs): - if self.album: - - def field(a): - return a.albumartist_sort or a.albumartist - - else: - - def field(i): - return i.artist_sort or i.artist - - if self.case_insensitive: - - def key(x): - return field(x).lower() - - else: - key = field - return sorted(objs, key=key, reverse=not self.ascending) - - # Special path format key. PF_KEY_DEFAULT = "default" @@ -381,7 +340,7 @@ def __str__(self): # Item and Album model classes. -class LibModel(dbcore.Model): +class LibModel(dbcore.Model["Library"]): """Shared concrete functionality for Items and Albums.""" # Config key that specifies how an instance should be formatted. @@ -632,7 +591,7 @@ class Item(LibModel): _formatter = FormattedItemMapping - _sorts = {"artist": SmartArtistSort} + _sorts = {"artist": dbcore.query.SmartArtistSort} _queries = {"singleton": SingletonQuery} @@ -1074,9 +1033,9 @@ def destination( The path is returned as a bytestring. ``basedir`` can override the library's base directory for the destination. """ - self._check_db() - basedir = basedir or self._db.directory - path_formats = path_formats or self._db.path_formats + db = self._check_db() + basedir = basedir or db.directory + path_formats = path_formats or db.path_formats # Use a path format based on a query, falling back on the # default. @@ -1117,11 +1076,11 @@ def destination( maxlen = beets.config["max_filename_length"].get(int) if not maxlen: # When zero, try to determine from filesystem. - maxlen = util.max_filename_length(self._db.directory) + maxlen = util.max_filename_length(db.directory) lib_path_str, fallback = util.legalize_path( subpath, - self._db.replacements, + db.replacements, maxlen, os.path.splitext(self.path)[1], ) @@ -1205,8 +1164,8 @@ class Album(LibModel): } _sorts = { - "albumartist": SmartArtistSort, - "artist": SmartArtistSort, + "albumartist": dbcore.query.SmartArtistSort, + "artist": dbcore.query.SmartArtistSort, } # List of keys that are set on an album's items. @@ -1604,7 +1563,8 @@ def __init__( self.path_formats = path_formats self.replacements = replacements - self._memotable = {} # Used for template substitution performance. + # Used for template substitution performance. + self._memotable: dict[tuple[str, ...], str] = {} # Adding objects to the database. diff --git a/beets/util/__init__.py b/beets/util/__init__.py index 820f77029f..94d7650bb6 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -163,7 +163,7 @@ class MoveOperation(Enum): REFLINK_AUTO = 5 -def normpath(path: bytes) -> bytes: +def normpath(path: PathLike) -> bytes: """Provide the canonical form of the path suitable for storing in the database. """