diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index 9ddaceeaf..64a9e045c 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -32,6 +32,7 @@ desc, exists, func, + inspect, or_, select, text, @@ -43,6 +44,7 @@ contains_eager, joinedload, make_transient, + noload, selectinload, ) @@ -312,13 +314,12 @@ def get_field_name_from_id(self, field_id: int) -> _FieldID: return f return None - def tag_display_name(self, tag_id: int) -> str: - with Session(self.engine) as session: - tag = session.scalar(select(Tag).where(Tag.id == tag_id)) - if not tag: - return "" + def tag_display_name(self, tag: Tag | None) -> str: + if not tag: + return "" - if tag.disambiguation_id: + if tag.disambiguation_id: + with Session(self.engine) as session: disam_tag = session.scalar(select(Tag).where(Tag.id == tag.disambiguation_id)) if not disam_tag: return "" @@ -326,8 +327,8 @@ def tag_display_name(self, tag_id: int) -> str: if not disam_name: disam_name = disam_tag.name return f"{tag.name} ({disam_name})" - else: - return tag.name + else: + return tag.name def open_library( self, library_dir: Path, storage_path: Path | str | None = None @@ -1544,6 +1545,45 @@ def get_tag_color(self, slug: str, namespace: str) -> TagColorGroup | None: return session.scalar(statement) + def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]: + """Get a dictionary containing tags in `tag_ids` and all of their ancestor tags.""" + current_tag_ids: set[int] = set(tag_ids) + all_tag_ids: set[int] = set() + all_tags: dict[int, Tag] = {} + all_tag_parents: dict[int, list[int]] = {} + + with Session(self.engine) as session: + while len(current_tag_ids) > 0: + all_tag_ids.update(current_tag_ids) + statement = select(TagParent).where(TagParent.child_id.in_(current_tag_ids)) + tag_parents = session.scalars(statement).fetchall() + current_tag_ids.clear() + for tag_parent in tag_parents: + all_tag_parents.setdefault(tag_parent.child_id, []).append(tag_parent.parent_id) + current_tag_ids.add(tag_parent.parent_id) + current_tag_ids = current_tag_ids.difference(all_tag_ids) + + statement = select(Tag).where(Tag.id.in_(all_tag_ids)) + statement = statement.options( + noload(Tag.parent_tags), selectinload(Tag.aliases), joinedload(Tag.color) + ) + tags = session.scalars(statement).fetchall() + for tag in tags: + all_tags[tag.id] = tag + for tag in all_tags.values(): + # Sqlalchemy tracks this as a change to the parent_tags field + tag.parent_tags = {all_tags[p] for p in all_tag_parents.get(tag.id, [])} + # When calling session.add with this tag instance sqlalchemy will + # attempt to create TagParents that already exist. + + state = inspect(tag) + # Prevent sqlalchemy from thinking any fields are different from what's commited + # commited_state contains original values for fields that have changed. + # empty when no fields have changed + state.committed_state.clear() + + return all_tags + def add_parent_tag(self, parent_id: int, child_id: int) -> bool: if parent_id == child_id: return False diff --git a/src/tagstudio/core/library/alchemy/models.py b/src/tagstudio/core/library/alchemy/models.py index ea60e6f06..4d103be20 100644 --- a/src/tagstudio/core/library/alchemy/models.py +++ b/src/tagstudio/core/library/alchemy/models.py @@ -156,6 +156,9 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() + def __hash__(self) -> int: + return hash(self.id) + def __lt__(self, other) -> bool: return self.name < other.name diff --git a/src/tagstudio/qt/controller/components/tag_box_controller.py b/src/tagstudio/qt/controller/components/tag_box_controller.py index d07a08a45..1af5422bb 100644 --- a/src/tagstudio/qt/controller/components/tag_box_controller.py +++ b/src/tagstudio/qt/controller/components/tag_box_controller.py @@ -75,7 +75,7 @@ def _on_edit(self, tag: Tag) -> None: # type: ignore[misc] edit_modal = PanelModal( build_tag_panel, - self.__driver.lib.tag_display_name(tag.id), + self.__driver.lib.tag_display_name(tag), "Edit Tag", done_callback=self.on_update.emit, has_save=True, diff --git a/src/tagstudio/qt/modals/tag_database.py b/src/tagstudio/qt/modals/tag_database.py index c233baf08..9c8a29240 100644 --- a/src/tagstudio/qt/modals/tag_database.py +++ b/src/tagstudio/qt/modals/tag_database.py @@ -62,7 +62,7 @@ def delete_tag(self, tag: Tag): message_box = QMessageBox( QMessageBox.Question, # type: ignore Translations["tag.remove"], - Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag.id)), + Translations.format("tag.confirm_delete", tag_name=self.lib.tag_display_name(tag)), QMessageBox.Ok | QMessageBox.Cancel, # type: ignore ) diff --git a/src/tagstudio/qt/modals/tag_search.py b/src/tagstudio/qt/modals/tag_search.py index 732c78b86..4cc2b2f45 100644 --- a/src/tagstudio/qt/modals/tag_search.py +++ b/src/tagstudio/qt/modals/tag_search.py @@ -387,7 +387,7 @@ def callback(btp: BuildTagPanel): self.edit_modal = PanelModal( build_tag_panel, - self.lib.tag_display_name(tag.id), + self.lib.tag_display_name(tag), Translations["tag.edit"], done_callback=(self.update_tags(self.search_field.text())), has_save=True, diff --git a/src/tagstudio/qt/view/components/tag_box_view.py b/src/tagstudio/qt/view/components/tag_box_view.py index 5c44cafe4..f6183615c 100644 --- a/src/tagstudio/qt/view/components/tag_box_view.py +++ b/src/tagstudio/qt/view/components/tag_box_view.py @@ -32,7 +32,7 @@ def __init__(self, title: str, driver: "QtDriver") -> None: self.setLayout(self.__root_layout) def set_tags(self, tags: Iterable[Tag]) -> None: - tags_ = sorted(list(tags), key=lambda tag: self.__lib.tag_display_name(tag.id)) + tags_ = sorted(list(tags), key=lambda tag: self.__lib.tag_display_name(tag)) logger.info("[TagBoxWidget] Tags:", tags=tags) while self.__root_layout.itemAt(0): self.__root_layout.takeAt(0).widget().deleteLater() # pyright: ignore[reportOptionalMemberAccess] diff --git a/src/tagstudio/qt/widgets/preview/field_containers.py b/src/tagstudio/qt/widgets/preview/field_containers.py index 6ed7b1b02..a23714cca 100644 --- a/src/tagstudio/qt/widgets/preview/field_containers.py +++ b/src/tagstudio/qt/widgets/preview/field_containers.py @@ -160,96 +160,38 @@ def hide_containers(self): c.setHidden(True) def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: - """Get a dictionary of category tags mapped to their respective tags.""" - cats: dict[Tag | None, set[Tag]] = {} - cats[None] = set() - - base_tag_ids: set[int] = {x.id for x in tags} - exhausted: set[int] = set() - cluster_map: dict[int, set[int]] = {} - - def add_to_cluster(tag_id: int, p_ids: list[int] | None = None): - """Maps a Tag's child tags' IDs back to it's parent tag's ID. - - Example: - Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to: - "Cartoon Network" -> Johnny Bravo, - "Character" -> "Johnny Bravo", - "TV" -> Johnny Bravo" - """ - tag_obj = unwrap(self.lib.get_tag(tag_id)) # Get full object - if p_ids is None: - p_ids = tag_obj.parent_ids - - for p_id in p_ids: - if cluster_map.get(p_id) is None: - cluster_map[p_id] = set() - # If the p_tag has p_tags of its own, recursively link those to the original Tag. - if tag_id not in cluster_map[p_id]: - cluster_map[p_id].add(tag_id) - p_tag = unwrap(self.lib.get_tag(p_id)) # Get full object - if p_tag.parent_ids: - add_to_cluster( - tag_id, - [sub_id for sub_id in p_tag.parent_ids if sub_id != tag_id], - ) - exhausted.add(p_id) - exhausted.add(tag_id) - - for tag in tags: - add_to_cluster(tag.id) + """Get a dictionary of category tags mapped to their respective tags. - logger.info("[FieldContainers] Entry Cluster", entry_cluster=exhausted) - logger.info("[FieldContainers] Cluster Map", cluster_map=cluster_map) + Example: + Tag: ["Johnny Bravo", Parent Tags: "Cartoon Network (TV)", "Character"] maps to: + "Cartoon Network" -> Johnny Bravo, + "Character" -> "Johnny Bravo", + "TV" -> Johnny Bravo" + """ + hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags) - # Initialize all categories from parents. - tags_ = {t for tid in exhausted if (t := self.lib.get_tag(tid)) is not None} - for tag in tags_: + categories: dict[Tag | None, set[Tag]] = {None: set()} + for tag in hierarchy_tags.values(): if tag.is_category: - cats[tag] = set() - logger.info("[FieldContainers] Blank Tag Categories", cats=cats) - - # Add tags to any applicable categories. - added_ids: set[int] = set() - for key in cats: - logger.info("[FieldContainers] Checking category tag key", key=key) - - if key: - logger.info( - "[FieldContainers] Key cluster:", key=key, cluster=cluster_map.get(key.id) - ) - - if final_tags := cluster_map.get(key.id, set()).union([key.id]): - cats[key] = { - t - for tid in final_tags - if tid in base_tag_ids and (t := self.lib.get_tag(tid)) is not None - } - added_ids = added_ids.union({tid for tid in final_tags if tid in base_tag_ids}) - - # Add remaining tags to None key (general case). - cats[None] = { - t - for tid in base_tag_ids - if tid not in added_ids and (t := self.lib.get_tag(tid)) is not None - } - logger.info( - "[FieldContainers] Key cluster: None, general case!", - general_tags=cats[None], - added=added_ids, - base_tag_ids=base_tag_ids, - ) - - # Remove unused categories - empty: list[Tag | None] = [] - for k, v in list(cats.items()): - if not v: - empty.append(k) - for key in empty: - cats.pop(key, None) + categories[tag] = set() + for tag in tags: + tag = hierarchy_tags[tag.id] + has_category_parent = False + parent_tags = tag.parent_tags + while len(parent_tags) > 0: + grandparent_tags: set[Tag] = set() + for parent_tag in parent_tags: + if parent_tag in categories: + categories[parent_tag].add(tag) + has_category_parent = True + grandparent_tags.update(parent_tag.parent_tags) + parent_tags = grandparent_tags + if tag.is_category: + categories[tag].add(tag) + elif not has_category_parent: + categories[None].add(tag) - logger.info("[FieldContainers] Tag Categories", categories=cats) - return cats + return dict((c, d) for c, d in categories.items() if len(d) > 0) def remove_field_prompt(self, name: str) -> str: return Translations.format("library.field.confirm_remove", name=name) diff --git a/src/tagstudio/qt/widgets/tag.py b/src/tagstudio/qt/widgets/tag.py index 4ab55c88d..cdce68a67 100644 --- a/src/tagstudio/qt/widgets/tag.py +++ b/src/tagstudio/qt/widgets/tag.py @@ -266,7 +266,7 @@ def set_tag(self, tag: Tag | None) -> None: ) if self.lib: - self.bg_button.setText(escape_text(self.lib.tag_display_name(tag.id))) + self.bg_button.setText(escape_text(self.lib.tag_display_name(tag))) else: self.bg_button.setText(escape_text(tag.name))