diff --git a/tagstudio/src/qt/flowlayout.py b/tagstudio/src/qt/flowlayout.py index 6334cf953..b37a45182 100644 --- a/tagstudio/src/qt/flowlayout.py +++ b/tagstudio/src/qt/flowlayout.py @@ -140,4 +140,4 @@ def _do_layout(self, rect: QRect, test_only: bool) -> float: x = next_x line_height = max(line_height, item.sizeHint().height()) - return y + line_height - rect.y() * ((len(self._item_list)) / len(self._item_list)) + return y + line_height - rect.y() diff --git a/tagstudio/src/qt/modals/tag_search.py b/tagstudio/src/qt/modals/tag_search.py index c44278fde..da80b31d2 100644 --- a/tagstudio/src/qt/modals/tag_search.py +++ b/tagstudio/src/qt/modals/tag_search.py @@ -4,6 +4,7 @@ import math +from typing import Optional import structlog from PySide6.QtCore import QSize, Qt, Signal @@ -20,7 +21,7 @@ from src.core.library.alchemy.enums import FilterState from src.core.palette import ColorType, get_tag_color from src.qt.widgets.panel import PanelWidget -from src.qt.widgets.tag import TagWidget +from src.qt.widgets.tag import Tag, TagWidget logger = structlog.get_logger(__name__) @@ -31,6 +32,7 @@ class TagSearchPanel(PanelWidget): def __init__(self, library: Library): super().__init__() self.lib = library + self.first_tag: Optional[Tag] = None self.first_tag_id = None self.tag_limit = 100 self.setMinimumSize(300, 400) @@ -63,14 +65,16 @@ def __init__(self, library: Library): self.update_tags() def on_return(self, text: str): - if text and self.first_tag_id is not None: - # callback(self.first_tag_id) - self.tag_chosen.emit(self.first_tag_id) + if text and self.first_tag is not None: + # callback(self.first_tag) + self.tag_chosen.emit(self.first_tag.id) self.search_field.setText("") self.update_tags() + return True else: self.search_field.setFocus() self.parentWidget().hide() + return False def update_tags(self, name: str | None = None): while self.scroll_layout.count(): @@ -78,11 +82,13 @@ def update_tags(self, name: str | None = None): found_tags = self.lib.search_tags( FilterState( - path=name, + tag=name, page_size=self.tag_limit, ) ) + self.first_tag = found_tags[0] if found_tags else None + for tag in found_tags: c = QWidget() layout = QHBoxLayout(c) diff --git a/tagstudio/src/qt/widgets/tag_box.py b/tagstudio/src/qt/widgets/tag_box.py index c24a3519e..d7df3abc9 100644 --- a/tagstudio/src/qt/widgets/tag_box.py +++ b/tagstudio/src/qt/widgets/tag_box.py @@ -7,10 +7,10 @@ import typing import structlog -from PySide6.QtCore import Qt, Signal -from PySide6.QtWidgets import QPushButton +from PySide6.QtCore import QObject, QStringListModel, Qt, Signal +from PySide6.QtWidgets import QCompleter, QHBoxLayout, QLineEdit, QPushButton, QVBoxLayout from src.core.constants import TAG_ARCHIVED, TAG_FAVORITE -from src.core.library import Entry, Tag +from src.core.library import Entry, Library, Tag from src.core.library.alchemy.enums import FilterState from src.core.library.alchemy.fields import TagBoxField from src.qt.flowlayout import FlowLayout @@ -26,6 +26,20 @@ logger = structlog.get_logger(__name__) +class TagCompleter(QCompleter): + def __init__(self, parent: QObject, lib: Library): + super().__init__(parent) + self.lib = lib + self.update(set()) + + def update(self, exclude: set[str]): + tags = {tag.name for tag in self.lib.tags} + tags -= exclude + model = QStringListModel(list(tags), self) + self.first_choice = model.stringList()[0] + self.setModel(model) + + class TagBoxWidget(FieldWidget): updated = Signal() error_occurred = Signal(Exception) @@ -45,11 +59,24 @@ def __init__( driver # Used for creating tag click callbacks that search entries for that tag. ) self.setObjectName("tagBox") - self.base_layout = FlowLayout() - self.base_layout.enable_grid_optimizations(value=False) - self.base_layout.setContentsMargins(0, 0, 0, 0) + self.base_layout = QVBoxLayout() self.setLayout(self.base_layout) + self.tags_layout = FlowLayout() + self.base_layout.addLayout(self.tags_layout) + self.tags_layout.enable_grid_optimizations(value=False) + self.tags_layout.setContentsMargins(0, 0, 0, 0) + + self.add_layout = QHBoxLayout() + self.base_layout.addLayout(self.add_layout) + + self.tag_entry = QLineEdit() + self.add_layout.addWidget(self.tag_entry) + + self.tag_completer = TagCompleter(self.tag_entry, self.driver.lib) + self.tag_completer.setCaseSensitivity(Qt.CaseSensitivity.CaseInsensitive) + self.tag_completer.setWidget(self.tag_entry) + self.add_button = QPushButton() self.add_button.setCursor(Qt.CursorShape.PointingHandCursor) self.add_button.setMinimumSize(23, 23) @@ -73,13 +100,40 @@ def __init__( f"background: #555555;" f"}}" ) + self.add_layout.addWidget(self.add_button) + tsp = TagSearchPanel(self.driver.lib) - tsp.tag_chosen.connect(lambda x: self.add_tag_callback(x)) + tsp.tag_chosen.connect( + lambda x: ( + self.add_tag_callback(x), + self.tag_entry.clear(), + ) + ) self.add_modal = PanelModal(tsp, title, "Add Tags") + self.add_button.clicked.connect( - lambda: ( - tsp.update_tags(), - self.add_modal.show(), + lambda: (self.add_modal.show(), tsp.update_tags(tsp.search_field.text())) + ) + self.tag_entry.textChanged.connect( + lambda text: ( + tsp.search_field.setText(text), + self.tag_completer.setCompletionPrefix(text), + self.tag_completer.complete(), + ) + ) + self.tag_entry.returnPressed.connect( + lambda: self.tag_completer.activated.emit( + self.tag_completer.first_choice + if (self.tag_completer.first_choice and self.tag_entry.text()) + else self.tag_entry.text() + ) + if not self.tag_completer.popup().selectedIndexes() + else () + ) + self.tag_completer.activated.connect( + lambda selected: ( + tsp.update_tags(selected), + self.tag_entry.clear() if tsp.on_return(selected) else (), ) ) @@ -89,10 +143,9 @@ def set_field(self, field: TagBoxField): self.field = field def set_tags(self, tags: typing.Iterable[Tag]): - is_recycled = False - while self.base_layout.itemAt(0) and self.base_layout.itemAt(1): - self.base_layout.takeAt(0).widget().deleteLater() - is_recycled = True + self.tag_completer.update({tag.name for tag in tags}) + while self.tags_layout.itemAt(0): + self.tags_layout.takeAt(0).widget().deleteLater() for tag in tags: tag_widget = TagWidget(tag, has_edit=True, has_remove=True) @@ -110,18 +163,7 @@ def set_tags(self, tags: typing.Iterable[Tag]): ) ) tag_widget.on_edit.connect(lambda t=tag: self.edit_tag(t)) - self.base_layout.addWidget(tag_widget) - - # Move or add the '+' button. - if is_recycled: - self.base_layout.addWidget(self.base_layout.takeAt(0).widget()) - else: - self.base_layout.addWidget(self.add_button) - - # Handles an edge case where there are no more tags and the '+' button - # doesn't move all the way to the left. - if self.base_layout.itemAt(0) and not self.base_layout.itemAt(1): - self.base_layout.update() + self.tags_layout.addWidget(tag_widget) def edit_tag(self, tag: Tag): assert isinstance(tag, Tag), f"tag is {type(tag)}" diff --git a/tagstudio/tests/qt/test_tag_widget.py b/tagstudio/tests/qt/test_tag_widget.py index 9d10691a9..78929cd58 100644 --- a/tagstudio/tests/qt/test_tag_widget.py +++ b/tagstudio/tests/qt/test_tag_widget.py @@ -73,7 +73,7 @@ def test_tag_widget_remove(qtbot, qt_driver, library, entry_full): qtbot.add_widget(tag_widget) - tag_widget = tag_widget.base_layout.itemAt(0).widget() + tag_widget = tag_widget.tags_layout.itemAt(0).widget() assert isinstance(tag_widget, TagWidget) tag_widget.remove_button.clicked.emit() @@ -95,7 +95,7 @@ def test_tag_widget_edit(qtbot, qt_driver, library, entry_full): qtbot.add_widget(tag_box_widget) - tag_widget = tag_box_widget.base_layout.itemAt(0).widget() + tag_widget = tag_box_widget.tags_layout.itemAt(0).widget() assert isinstance(tag_widget, TagWidget) # When @@ -108,3 +108,29 @@ def test_tag_widget_edit(qtbot, qt_driver, library, entry_full): assert isinstance(panel, BuildTagPanel) assert panel.tag.name == tag.name assert panel.name_field.text() == tag.name + + +def test_tag_widget_autocomplete(qtbot, qt_driver, library): + # Given + entry = next(library.get_entries(with_joins=True)) + field = entry.tag_box_fields[0] + + tag_widget = TagBoxWidget(field, "title", qt_driver) + tag_widget.driver.selected = [0] + + qtbot.add_widget(tag_widget) + + assert len(entry.tags) == 1 + + # Test autocomplete + tag_widget.tag_entry.setText("arch") + tag_widget.tag_entry.returnPressed.emit() + + entry = next(library.get_entries(with_joins=True)) # Update entry + assert len(entry.tags) == 2 + + # Test unmatched autocomplete + tag_widget.tag_completer.activated.emit("missing") + + entry = next(library.get_entries(with_joins=True)) # Update entry + assert len(entry.tags) == 2