Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from .query import (
AndQuery,
FieldQuery,
FieldQueryType,
FieldSort,
MatchQuery,
NullSort,
Query,
Expand All @@ -47,6 +49,15 @@
if TYPE_CHECKING:
from types import TracebackType

from .query import SQLiteType

D = TypeVar("D", bound="Database", default=Any)
else:
D = TypeVar("D", bound="Database")


FlexAttrs = dict[str, str]


class DBAccessError(Exception):
"""The SQLite database became inaccessible.
Expand Down Expand Up @@ -236,7 +247,7 @@ def __len__(self) -> int:
# Abstract base for model classes.


class Model(ABC):
class Model(ABC, Generic[D]):
"""An abstract object representing an object in the database. Model
objects act like dictionaries (i.e., they allow subscript access like
``obj['field']``). The same field set is available via attribute
Expand Down Expand Up @@ -284,12 +295,12 @@ class Model(ABC):
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""

_sorts: dict[str, type[Sort]] = {}
_sorts: dict[str, type[FieldSort]] = {}
"""Optional named sort criteria. The keys are strings and the values
are subclasses of `Sort`.
"""

_queries: dict[str, type[FieldQuery]] = {}
_queries: dict[str, FieldQueryType] = {}
"""Named queries that use a field-like `name:value` syntax but which
do not relate to any specific field.
"""
Expand All @@ -306,7 +317,7 @@ class Model(ABC):
"""

@cached_classproperty
def _relation(cls) -> type[Model]:
def _relation(cls):
"""The model that this model is closely related to."""
return cls

Expand Down Expand Up @@ -347,7 +358,7 @@ def _template_funcs(self) -> Mapping[str, Callable[[str], str]]:

# Basic operation.

def __init__(self, db: Database | None = None, **values):
def __init__(self, db: D | None = None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
Expand All @@ -363,7 +374,7 @@ def __init__(self, db: Database | None = None, **values):
@classmethod
def _awaken(
cls: type[AnyModel],
db: Database | None = None,
db: D | None = None,
fixed_values: dict[str, Any] = {},
flex_values: dict[str, Any] = {},
) -> AnyModel:
Expand Down Expand Up @@ -393,7 +404,7 @@ def clear_dirty(self):
if self._db:
self._revision = self._db.revision

def _check_db(self, need_id: bool = True) -> Database:
def _check_db(self, need_id: bool = True) -> D:
"""Ensure that this object is associated with a database row: it
has a reference to a database (`_db`) and an id. A ValueError
exception is raised otherwise.
Expand Down Expand Up @@ -574,7 +585,7 @@ def store(self, fields: Iterable[str] | None = None):

# Build assignments for query.
assignments = []
subvars = []
subvars: list[SQLiteType] = []
for key in fields:
if key != "id" and key in self._dirty:
self._dirty.remove(key)
Expand Down Expand Up @@ -637,7 +648,7 @@ def remove(self):
f"DELETE FROM {self._flex_table} WHERE entity_id=?", (self.id,)
)

def add(self, db: Database | None = None):
def add(self, db: D | None = None):
"""Add the object to the library database. This object must be
associated with a database; you can provide one via the `db`
parameter or use the currently associated database.
Expand Down Expand Up @@ -714,16 +725,16 @@ def field_query(
cls,
field,
pattern,
query_cls: type[FieldQuery] = MatchQuery,
query_cls: FieldQueryType = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)

@classmethod
def all_fields_query(
cls: type[Model],
pats: Mapping,
query_cls: type[FieldQuery] = MatchQuery,
pats: Mapping[str, str],
query_cls: FieldQueryType = MatchQuery,
):
"""Get a query that matches many fields with different patterns.

Expand All @@ -749,8 +760,8 @@ class Results(Generic[AnyModel]):
def __init__(
self,
model_class: type[AnyModel],
rows: list[Mapping],
db: Database,
rows: list[sqlite3.Row],
db: D,
flex_rows,
query: Query | None = None,
sort=None,
Expand Down Expand Up @@ -834,9 +845,9 @@ def __iter__(self) -> Iterator[AnyModel]:
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()

def _get_indexed_flex_attrs(self) -> Mapping:
def _get_indexed_flex_attrs(self) -> dict[int, FlexAttrs]:
"""Index flexible attributes by the entity id they belong to"""
flex_values: dict[int, dict[str, Any]] = {}
flex_values: dict[int, FlexAttrs] = {}
for row in self.flex_rows:
if row["entity_id"] not in flex_values:
flex_values[row["entity_id"]] = {}
Expand All @@ -845,7 +856,9 @@ def _get_indexed_flex_attrs(self) -> Mapping:

return flex_values

def _make_model(self, row, flex_values: dict = {}) -> AnyModel:
def _make_model(
self, row: sqlite3.Row, flex_values: FlexAttrs = {}
) -> AnyModel:
"""Create a Model object for the given row"""
cols = dict(row)
values = {k: v for (k, v) in cols.items() if not k[:4] == "flex"}
Expand Down Expand Up @@ -954,14 +967,16 @@ def __exit__(
self._mutated = False
self.db._db_lock.release()

def query(self, statement: str, subvals: Sequence = ()) -> list:
def query(
self, statement: str, subvals: Sequence[SQLiteType] = ()
) -> list[sqlite3.Row]:
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()

def mutate(self, statement: str, subvals: Sequence = ()) -> Any:
def mutate(self, statement: str, subvals: Sequence[SQLiteType] = ()) -> Any:
"""Execute an SQL statement with substitution values and return
the row ID of the last affected row.
"""
Expand Down Expand Up @@ -1122,7 +1137,7 @@ def _close(self):
conn.close()

@contextlib.contextmanager
def _tx_stack(self) -> Generator[list]:
def _tx_stack(self) -> Generator[list[Transaction]]:
"""A context manager providing access to the current thread's
transaction stack. The context manager synchronizes access to
the stack map. Transactions should never migrate across threads.
Expand Down
49 changes: 37 additions & 12 deletions beets/dbcore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
import unicodedata
from abc import ABC, abstractmethod
from collections.abc import Collection, Iterator, MutableSequence, Sequence
from collections.abc import Iterator, MutableSequence, Sequence
from datetime import datetime, timedelta
from functools import reduce
from operator import mul, or_
Expand All @@ -30,6 +30,11 @@

if TYPE_CHECKING:
from beets.dbcore import Model
from beets.dbcore.db import AnyModel

P = TypeVar("P", default=Any)
else:
P = TypeVar("P")


class ParsingError(ValueError):
Expand Down Expand Up @@ -107,9 +112,9 @@ def __hash__(self) -> int:
return hash(type(self))


P = TypeVar("P")
SQLiteType = Union[str, bytes, float, int, memoryview]
SQLiteType = Union[str, bytes, float, int, memoryview, None]
AnySQLiteType = TypeVar("AnySQLiteType", bound=SQLiteType)
FieldQueryType = type["FieldQuery"]


class FieldQuery(Query, Generic[P]):
Expand Down Expand Up @@ -289,7 +294,7 @@ def _normalize(s: str) -> str:
return unicodedata.normalize("NFC", s)

@classmethod
def string_match(cls, pattern: Pattern, value: str) -> bool:
def string_match(cls, pattern: Pattern[str], value: str) -> bool:
return pattern.search(cls._normalize(value)) is not None


Expand Down Expand Up @@ -451,7 +456,7 @@ def field_names(self) -> set[str]:
"""Return a set with field names that this query operates on."""
return reduce(or_, (sq.field_names for sq in self.subqueries))

def __init__(self, subqueries: Sequence = ()):
def __init__(self, subqueries: Sequence[Query] = ()):
self.subqueries = subqueries

# Act like a sequence.
Expand All @@ -462,7 +467,7 @@ def __len__(self) -> int:
def __getitem__(self, key):
return self.subqueries[key]

def __iter__(self) -> Iterator:
def __iter__(self) -> Iterator[Query]:
return iter(self.subqueries)

def __contains__(self, subq) -> bool:
Expand All @@ -476,7 +481,7 @@ def clause_with_joiner(
all subqueries with the string joiner (padded by spaces).
"""
clause_parts = []
subvals = []
subvals: list[SQLiteType] = []
for subq in self.subqueries:
subq_clause, subq_subvals = subq.clause()
if not subq_clause:
Expand Down Expand Up @@ -511,7 +516,7 @@ def field_names(self) -> set[str]:
"""Return a set with field names that this query operates on."""
return set(self.fields)

def __init__(self, pattern, fields, cls: type[FieldQuery]):
def __init__(self, pattern, fields, cls: FieldQueryType):
self.pattern = pattern
self.fields = fields
self.query_class = cls
Expand Down Expand Up @@ -549,7 +554,7 @@ class MutableCollectionQuery(CollectionQuery):
query is initialized.
"""

subqueries: MutableSequence
subqueries: MutableSequence[Query]

def __setitem__(self, key, value):
self.subqueries[key] = value
Expand Down Expand Up @@ -894,7 +899,7 @@ def order_clause(self) -> str | None:
"""
return None

def sort(self, items: list) -> list:
def sort(self, items: list[AnyModel]) -> list[AnyModel]:
"""Sort the list of objects and return a list."""
return sorted(items)

Expand Down Expand Up @@ -988,7 +993,7 @@ def __init__(
self.ascending = ascending
self.case_insensitive = case_insensitive

def sort(self, objs: Collection):
def sort(self, objs: list[AnyModel]) -> list[AnyModel]:
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
Expand Down Expand Up @@ -1047,7 +1052,7 @@ def is_slow(self) -> bool:
class NullSort(Sort):
"""No sorting. Leave results unsorted."""

def sort(self, items: list) -> list:
def sort(self, items: list[AnyModel]) -> list[AnyModel]:
return items

def __nonzero__(self) -> bool:
Expand All @@ -1061,3 +1066,23 @@ def __eq__(self, other) -> bool:

def __hash__(self) -> int:
return 0


class SmartArtistSort(FieldSort):
"""Sort by artist (either album artist or track artist),
prioritizing the sort field over the raw field.
"""

def order_clause(self):
order = "ASC" if self.ascending else "DESC"
collate = "COLLATE NOCASE" if self.case_insensitive else ""
field = self.field

return f"COALESCE(NULLIF({field}_sort, ''), {field}) {collate} {order}"

def sort(self, objs: list[AnyModel]) -> list[AnyModel]:
def key(o):
val = o[f"{self.field}_sort"] or o[self.field]
return val.lower() if self.case_insensitive else val

return sorted(objs, key=key, reverse=not self.ascending)
Loading