diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py
index 186cd797a..845ae3a84 100644
--- a/tagstudio/src/core/library/alchemy/library.py
+++ b/tagstudio/src/core/library/alchemy/library.py
@@ -543,10 +543,18 @@ def search_library(
             statement = select(Entry)
 
             if search.ast:
+                start_time = time.time()
+
                 statement = statement.outerjoin(Entry.tag_box_fields).where(
                     SQLBoolExpressionBuilder(self).visit(search.ast)
                 )
 
+                end_time = time.time()
+
+                logger.info(
+                    f"SQL Expression Builder finished ({format_timespan(end_time - start_time)})"
+                )
+
             extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
             is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST)
 
diff --git a/tagstudio/src/core/library/alchemy/visitors.py b/tagstudio/src/core/library/alchemy/visitors.py
index 1756bb089..5eed4580f 100644
--- a/tagstudio/src/core/library/alchemy/visitors.py
+++ b/tagstudio/src/core/library/alchemy/visitors.py
@@ -1,6 +1,7 @@
 from typing import TYPE_CHECKING
 
-from sqlalchemy import and_, distinct, func, or_, select
+import structlog
+from sqlalchemy import and_, distinct, func, or_, select, text
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
 from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
@@ -16,6 +17,24 @@
 else:
     Library = None  # don't import .library because of circular imports
 
+logger = structlog.get_logger(__name__)
+
+# Recursive CTE that yields :tag_id itself together with the ids of all of its
+# descendant tags. Plain UNION (not UNION ALL) discards rows that were already
+# produced, so the recursion terminates even if tag_subtags contains a cycle and
+# no id is returned twice when a tag is reachable through several paths.
+CHILDREN_QUERY = text("""
+-- Note for this entire query that tag_subtags.child_id is the parent id and tag_subtags.parent_id is the child id due to bad naming
+WITH RECURSIVE Subtags AS (
+    SELECT :tag_id AS child_id
+    UNION
+    SELECT ts.parent_id AS child_id
+    FROM tag_subtags ts
+    INNER JOIN Subtags s ON ts.child_id = s.child_id
+)
+SELECT child_id FROM Subtags;
+""")  # noqa: E501
+
 
 def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
     for s in FILETYPE_EQUIVALENTS:
@@ -98,16 +117,40 @@ def visit_property(self, node: Property) -> None:
     def visit_not(self, node: Not) -> ColumnExpressionArgument:
         return ~self.__entry_satisfies_ast(node.child)
 
-    def __get_tag_ids(self, tag_name: str) -> list[int]:
-        """Given a tag name find the ids of all tags that this name could refer to."""
-        with Session(self.lib.engine, expire_on_commit=False) as session:
-            return list(
+    def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
+        """Given a tag name find the ids of all tags that this name could refer to.
+
+        Args:
+            tag_name: Text matched with ``ilike`` against tag names, shorthands and aliases.
+            include_children: If true, expand every match with ``CHILDREN_QUERY``
+                (the match itself plus all of its descendant tags).
+
+        Returns:
+            The resulting tag ids, each listed once, in first-seen order.
+        """
+        with Session(self.lib.engine) as session:
+            tag_ids = list(
                 session.scalars(
                     select(Tag.id)
                     .where(or_(Tag.name.ilike(tag_name), Tag.shorthand.ilike(tag_name)))
                     .union(select(TagAlias.tag_id).where(TagAlias.name.ilike(tag_name)))
                 )
             )
+            if len(tag_ids) > 1:
+                logger.debug(
+                    f'Tag Constraint "{tag_name}" is ambiguous, {len(tag_ids)} matching tags found',
+                    tag_ids=tag_ids,
+                    include_children=include_children,
+                )
+            if not include_children:
+                return tag_ids
+            # A dict keeps insertion order and drops ids that are reached from more
+            # than one matched tag (e.g. when one match is a descendant of another).
+            expanded: dict[int, None] = {}
+            for tag_id in tag_ids:
+                children = session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})
+                expanded.update(dict.fromkeys(children))
+            return list(expanded)
 
     def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
         """Returns Binary Expression that is true if the Entry has all provided tag ids."""
diff --git a/tagstudio/tests/test_search.py b/tagstudio/tests/test_search.py
index 7d2e69db1..9f3754b43 100644
--- a/tagstudio/tests/test_search.py
+++ b/tagstudio/tests/test_search.py
@@ -116,6 +116,20 @@ def test_parentheses(search_library: Library, query: str, count: int):
     verify_count(search_library, query, count)
 
 
+@pytest.mark.parametrize(
+    ["query", "count"],
+    [
+        ("ellipse", 17),
+        ("yellow", 15),
+        ("color", 24),
+        ("shape", 24),
+        ("yellow not green", 10),
+    ],
+)
+def test_parent_tags(search_library: Library, query: str, count: int):
+    verify_count(search_library, query, count)
+
+
 @pytest.mark.parametrize(
     "invalid_query", ["asd AND", "asd AND AND", "tag:(", "(asd", "asd[]", "asd]", ":", "tag: :"]
 )