Skip to content
77 changes: 48 additions & 29 deletions src/tagstudio/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
DB_VERSION_LEGACY_KEY,
JSON_FILENAME,
SQL_FILENAME,
TAG_CHILDREN_QUERY,
)
from tagstudio.core.library.alchemy.db import make_tables
from tagstudio.core.library.alchemy.enums import (
Expand Down Expand Up @@ -555,6 +554,20 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
# Convert file extension list to ts_ignore file, if a .ts_ignore file does not exist
self.migrate_sql_to_ts_ignore(library_dir)

session.execute(
text("CREATE INDEX IF NOT EXISTS idx_tags_name_shorthand ON tags (name, shorthand)")
)
session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_tag_parents_child_id ON tag_parents (child_id)"
)
)
session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_tag_entries_entry_id ON tag_entries (entry_id)"
)
)

# Update DB_VERSION
if loaded_db_version < DB_VERSION:
self.set_version(DB_VERSION_CURRENT_KEY, DB_VERSION)
Expand Down Expand Up @@ -1054,55 +1067,61 @@ def search_library(

return res

def search_tags(self, name: str | None, limit: int = 100) -> list[set[Tag]]:
def search_tags(self, name: str | None, limit: int = 100) -> tuple[list[Tag], list[Tag]]:
"""Return a list of Tag records matching the query."""
name = name or ""
name = name.lower()

def sort_key(text: str):
return (not text.startswith(name), len(text), text)

with Session(self.engine) as session:
query = select(Tag).outerjoin(TagAlias).order_by(func.lower(Tag.name))
query = query.options(
selectinload(Tag.parent_tags),
selectinload(Tag.aliases),
)
if limit > 0:
query = query.limit(limit)
query = select(Tag.id, Tag.name)

if limit > 0 and not name:
query = query.limit(limit).order_by(func.lower(Tag.name))
Comment thread
Computerdores marked this conversation as resolved.
Outdated

if name:
query = query.where(
or_(
Tag.name.icontains(name),
Tag.shorthand.icontains(name),
TagAlias.name.icontains(name),
)
)

direct_tags = set(session.scalars(query))
ancestor_tag_ids: list[Tag] = []
for tag in direct_tags:
ancestor_tag_ids.extend(
list(session.scalars(TAG_CHILDREN_QUERY, {"tag_id": tag.id}))
)

ancestor_tags = session.scalars(
select(Tag)
.where(Tag.id.in_(ancestor_tag_ids))
.options(selectinload(Tag.parent_tags), selectinload(Tag.aliases))
)
tags = list(session.execute(query))

res = [
direct_tags,
{at for at in ancestor_tags if at not in direct_tags},
]
if name:
query = select(TagAlias.tag_id, TagAlias.name).where(TagAlias.name.icontains(name))
tags.extend(session.execute(query))

tags.sort(key=lambda t: sort_key(t[1]))
seen_ids = set()
tag_ids = []
for row in tags:
id = row[0]
if id in seen_ids:
continue
tag_ids.append(id)
seen_ids.add(id)
Comment thread
Computerdores marked this conversation as resolved.
Outdated

logger.info(
"searching tags",
search=name,
limit=limit,
statement=str(query),
results=len(res),
results=len(tag_ids),
)

session.expunge_all()
if limit <= 0:
limit = len(tag_ids)
tag_ids = tag_ids[:limit]

return res
hierarchy = self.get_tag_hierarchy(tag_ids)
direct_tags = [hierarchy.pop(id) for id in tag_ids]
ancestor_tags = list(hierarchy.values())
ancestor_tags.sort(key=lambda t: sort_key(t.name))
return direct_tags, ancestor_tags

def update_entry_path(self, entry_id: int | Entry, path: Path) -> bool:
"""Set the path field of an entry.
Expand Down
27 changes: 5 additions & 22 deletions src/tagstudio/qt/mixed/tag_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,32 +218,15 @@ def update_tags(self, query: str | None = None):
self.scroll_layout.takeAt(self.scroll_layout.count() - 1).widget().deleteLater()
self.create_button_in_layout = False

# Get results for the search query
query_lower = "" if not query else query.lower()
# Only use the tag limit if it's an actual number (aka not "All Tags")
tag_limit = TagSearchPanel.tag_limit if isinstance(TagSearchPanel.tag_limit, int) else -1
tag_results: list[set[Tag]] = self.lib.search_tags(name=query, limit=tag_limit)
if self.exclude:
tag_results[0] = {t for t in tag_results[0] if t.id not in self.exclude}
tag_results[1] = {t for t in tag_results[1] if t.id not in self.exclude}

# Sort and prioritize the results
results_0 = list(tag_results[0])
results_0.sort(key=lambda tag: tag.name.lower())
results_1 = list(tag_results[1])
results_1.sort(key=lambda tag: tag.name.lower())
raw_results = list(results_0 + results_1)
priority_results: set[Tag] = set()
all_results: list[Tag] = []
direct_tags, ancestor_tags = self.lib.search_tags(name=query, limit=tag_limit)

if query and query.strip():
for tag in raw_results:
if tag.name.lower().startswith(query_lower):
priority_results.add(tag)
all_results = [t for t in direct_tags if t.id not in self.exclude]
for tag in ancestor_tags:
if tag.id not in self.exclude:
all_results.append(tag)
Comment thread
Computerdores marked this conversation as resolved.
Outdated
Comment thread
Computerdores marked this conversation as resolved.
Outdated

all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
r for r in raw_results if r not in priority_results
]
if tag_limit > 0:
all_results = all_results[:tag_limit]

Expand Down
8 changes: 4 additions & 4 deletions tests/test_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ def test_library_search(library: Library, entry_full: Entry):
def test_tag_search(library: Library):
tag = library.tags[0]

assert library.search_tags(tag.name.lower())
assert library.search_tags(tag.name.upper())
assert library.search_tags(tag.name[2:-2])
assert library.search_tags(tag.name * 2) == [set(), set()]
assert library.search_tags(tag.name.lower())[0]
assert library.search_tags(tag.name.upper())[0]
assert library.search_tags(tag.name[2:-2])[0]
assert library.search_tags(tag.name * 2) == ([], [])


def test_get_entry(library: Library, entry_min: Entry):
Expand Down