Skip to content

Commit

Permalink
[Feature] Add support to annotate KIE linking field (#429)
Browse files Browse the repository at this point in the history
  • Loading branch information
CVHub520 committed Jul 26, 2024
1 parent dfd4794 commit 4441704
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 13 deletions.
2 changes: 2 additions & 0 deletions anylabeling/views/labeling/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def load(self, filename):
"flags",
"description",
"attributes",
"kie_linking",
]
try:
with io_open(filename, "r") as f:
Expand Down Expand Up @@ -144,6 +145,7 @@ def load(self, filename):
"description": s.get("description"),
"difficult": s.get("difficult", False),
"attributes": s.get("attributes", {}),
"kie_linking": s.get("kie_linking", []),
"other_data": {
k: v for k, v in s.items() if k not in shape_keys
},
Expand Down
13 changes: 13 additions & 0 deletions anylabeling/views/labeling/label_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2356,12 +2356,14 @@ def edit_label(self, item=None):
group_id,
description,
difficult,
kie_linking,
) = self.label_dialog.pop_up(
text=shape.label,
flags=shape.flags,
group_id=shape.group_id,
description=shape.description,
difficult=shape.difficult,
kie_linking=shape.kie_linking
)
if text is None:
return
Expand All @@ -2380,6 +2382,7 @@ def edit_label(self, item=None):
shape.group_id = group_id
shape.description = description
shape.difficult = difficult
shape.kie_linking = kie_linking

# Add to label history
self.label_dialog.add_label_history(shape.label)
Expand Down Expand Up @@ -2523,6 +2526,7 @@ def format_shape(s):
"shape_type": s.shape_type,
"flags": s.flags,
"attributes": s.attributes,
"kie_linking": s.kie_linking,
}
if s.shape_type == "rotation":
info["direction"] = s.direction
Expand Down Expand Up @@ -2713,6 +2717,7 @@ def load_labels(self, shapes):
difficult = shape.get("difficult", False)
attributes = shape.get("attributes", {})
direction = shape.get("direction", 0)
kie_linking = shape.get("kie_linking", [])
other_data = shape["other_data"]

if label in self.hidden_cls or not points:
Expand All @@ -2728,6 +2733,7 @@ def load_labels(self, shapes):
difficult=difficult,
direction=direction,
attributes=attributes,
kie_linking=kie_linking,
)
for x, y in points:
shape.add_point(QtCore.QPointF(x, y))
Expand Down Expand Up @@ -2784,6 +2790,7 @@ def format_shape(s):
"shape_type": s.shape_type,
"flags": s.flags,
"attributes": s.attributes,
"kie_linking": s.kie_linking,
}
if s.shape_type == "rotation":
info["direction"] = s.direction
Expand Down Expand Up @@ -2900,6 +2907,7 @@ def new_shape(self):
group_id = None
description = ""
difficult = False
kie_linking = []

if self.canvas.shapes[-1].label in [
AutoLabelingMode.ADD,
Expand All @@ -2922,6 +2930,7 @@ def new_shape(self):
group_id,
description,
difficult,
kie_linking,
) = self.label_dialog.pop_up(text)
if not text:
self.label_dialog.edit.setText(previous_text)
Expand All @@ -2946,6 +2955,7 @@ def new_shape(self):
shape.description = description
shape.label = text
shape.difficult = difficult
shape.kie_linking = kie_linking
self.add_label(shape)
self.actions.edit_mode.setEnabled(True)
self.actions.undo_last_point.setEnabled(False)
Expand Down Expand Up @@ -5563,12 +5573,14 @@ def finish_auto_labeling_object(self):
group_id,
description,
difficult,
kie_linking,
) = self.label_dialog.pop_up(
text=self.find_last_label(),
flags={},
group_id=None,
description=None,
difficult=False,
kie_linking=[],
)
if not text:
self.label_dialog.edit.setText(previous_text)
Expand Down Expand Up @@ -5599,6 +5611,7 @@ def finish_auto_labeling_object(self):
shape.group_id = group_id
shape.description = description
shape.difficult = difficult
shape.kie_linking = kie_linking
# Update unique label list
if not self.unique_label_list.find_items_by_label(shape.label):
unique_label_item = (
Expand Down
2 changes: 2 additions & 0 deletions anylabeling/views/labeling/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ def __init__(
difficult=False,
direction=0,
attributes={},
kie_linking=[],
):
self.label = label
self.score = score
self.group_id = group_id
self.description = description
self.difficult = difficult
self.kie_linking = kie_linking
self.points = []
self.fill = False
self.selected = False
Expand Down
108 changes: 95 additions & 13 deletions anylabeling/views/labeling/widgets/label_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,16 +388,40 @@ def __init__(
QtCore.QRegularExpression(r"\d*"), None
)
)
self.edit_group_id.setAlignment(QtCore.Qt.AlignCenter)

# Add difficult checkbox
self.edit_difficult = QtWidgets.QCheckBox(self.tr("useDifficult"))
self.edit_difficult.setChecked(difficult)

# Add linking input
self.linking_input = QtWidgets.QLineEdit()
self.linking_input.setPlaceholderText(self.tr("Enter linking, e.g., [0,1]"))
linking_font = self.linking_input.font() # Adjust placeholder font size
linking_font.setPointSize(8)
self.linking_input.setFont(linking_font)
self.linking_list = QtWidgets.QListWidget()
self.linking_list.setHidden(True) # Initially hide the list
row_height = self.linking_list.fontMetrics().height()
self.linking_list.setFixedHeight(row_height * 4 + 2 * self.linking_list.frameWidth())
self.add_linking_button = QtWidgets.QPushButton(self.tr("Add"))
self.add_linking_button.clicked.connect(self.add_linking_pair)

layout = QtWidgets.QVBoxLayout()
layout.setContentsMargins(10, 10, 10, 10)
if show_text_field:
layout_edit = QtWidgets.QHBoxLayout()
layout_edit.addWidget(self.edit, 6)
layout_edit.addWidget(self.edit_group_id, 2)
layout_edit.addWidget(self.edit, 2.5)
layout_edit.addWidget(self.edit_group_id, 1.5)
layout.addLayout(layout_edit)

# Add linking layout
layout_linking = QtWidgets.QHBoxLayout()
layout_linking.addWidget(self.linking_input, 2.5)
layout_linking.addWidget(self.add_linking_button, 1.5)
layout.addLayout(layout_linking)
layout.addWidget(self.linking_list)

# buttons
self.button_box = bb = QtWidgets.QDialogButtonBox(
QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel,
Expand All @@ -409,6 +433,13 @@ def __init__(
bb.accepted.connect(self.validate)
bb.rejected.connect(self.reject)

# text edit
self.edit_description = QtWidgets.QTextEdit()
self.edit_description.setPlaceholderText(self.tr("Label description"))
self.edit_description.setFixedHeight(50)
layout.addWidget(self.edit_description)

# difficult & confirm button
layout_button = QtWidgets.QHBoxLayout()
layout_button.addWidget(self.edit_difficult)
layout_button.addWidget(self.button_box)
Expand Down Expand Up @@ -445,11 +476,6 @@ def __init__(
self.reset_flags()
layout.addItem(self.flags_layout)
self.edit.textChanged.connect(self.update_flags)
# text edit
self.edit_description = QtWidgets.QTextEdit()
self.edit_description.setPlaceholderText(self.tr("Label description"))
self.edit_description.setFixedHeight(50)
layout.addWidget(self.edit_description)
self.setLayout(layout)
# completion
completer = QtWidgets.QCompleter()
Expand All @@ -467,6 +493,53 @@ def __init__(
# Save last label
self._last_label = ""

def add_linking_pair(self):
linking_text = self.linking_input.text()
try:
linking_pairs = eval(linking_text)
if (isinstance(linking_pairs, list)
and len(linking_pairs) == 2
and all(isinstance(item, int) for item in linking_pairs)):
if linking_pairs in self.get_kie_linking():
QtWidgets.QMessageBox.warning(
self,
self.tr("Duplicate Entry"),
self.tr("This linking pair already exists."),
)
self.linking_list.addItem(str(linking_pairs))
self.linking_input.clear()
self.linking_list.setHidden(False) # Show the list when an item is added
else:
raise ValueError
except:
QtWidgets.QMessageBox.warning(
self,
self.tr("Invalid Input"),
self.tr("Please enter a valid list of linking pairs like [1,2]."),
)

def keyPressEvent(self, event):
if event.key() == QtCore.Qt.Key_Delete:
selected_items = self.linking_list.selectedItems()
if selected_items:
for item in selected_items:
self.linking_list.takeItem(self.linking_list.row(item))
if len(self.get_kie_linking) == 0:
self.linking_list.setHidden(True)
else:
super(LabelDialog, self).keyPressEvent(event)

def remove_linking_item(self, item_widget):
list_item = self.linking_list.itemWidget(item_widget)
self.linking_list.takeItem(self.linking_list.row(list_item))
item_widget.deleteLater()

def reset_linking(self, kie_linking=[]):
self.linking_list.clear()
for linking_pair in kie_linking:
self.linking_list.addItem(str(linking_pair))
self.linking_list.setHidden(False if kie_linking else True)

def get_last_label(self):
return self._last_label

Expand Down Expand Up @@ -553,6 +626,13 @@ def get_description(self):
def get_difficult_state(self):
return self.edit_difficult.isChecked()

def get_kie_linking(self):
kie_linking = []
for index in range(self.linking_list.count()):
item = self.linking_list.item(index)
kie_linking.append(eval(item.text()))
return kie_linking

def pop_up(
self,
text=None,
Expand All @@ -561,6 +641,7 @@ def pop_up(
group_id=None,
description=None,
difficult=False,
kie_linking=[],
):
if self._fit_to_content["row"]:
self.label_list.setMinimumHeight(
Expand All @@ -577,6 +658,8 @@ def pop_up(
if description is None:
description = ""
self.edit_description.setPlainText(description)
# Set initial values for kie_linking
self.reset_linking(kie_linking)
if flags:
self.set_flags(flags)
else:
Expand All @@ -592,37 +675,36 @@ def pop_up(
else:
self.edit_group_id.setText(str(group_id))
items = self.label_list.findItems(text, QtCore.Qt.MatchFixedString)

if items:
if len(items) != 1:
logger.warning("Label list has duplicate '%s'", text)
self.label_list.setCurrentItem(items[0])
row = self.label_list.row(items[0])
self.edit.completer().setCurrentRow(row)
self.edit.setFocus(QtCore.Qt.PopupFocusReason)

if move:
cursor_pos = QtGui.QCursor.pos()
screen = QtWidgets.QApplication.desktop().screenGeometry(cursor_pos)
dialog_frame_size = self.frameGeometry()

# Calculate the ideal top-left corner position for the dialog based on the mouse click
ideal_pos = cursor_pos

# Adjust to prevent the dialog from exceeding the right screen boundary
if (ideal_pos.x() + dialog_frame_size.width()) > screen.right():
ideal_pos.setX(screen.right() - dialog_frame_size.width())

# Adjust to prevent the dialog's bottom from going off-screen
if (ideal_pos.y() + dialog_frame_size.height()) > screen.bottom():
ideal_pos.setY(screen.bottom() - dialog_frame_size.height())

self.move(ideal_pos)

if self.exec_():
return (
self.edit.text(),
self.get_flags(),
self.get_group_id(),
self.get_description(),
self.get_difficult_state(),
self.get_kie_linking(),
)

return None, None, None, None, False
return None, None, None, None, False, []

0 comments on commit 4441704

Please sign in to comment.