Skip to content

Commit

Permalink
[Fix] Fix SemanticConnectivityLoss bug on cpu (PaddlePaddle#1940)
Browse files Browse the repository at this point in the history
  • Loading branch information
LutaoChu authored Mar 30, 2022
1 parent bd46c4e commit fbb26ce
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions paddleseg/models/losses/semantic_connectivity_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def forward(self, logits, labels):
label_num_conn, label_conn = cv2.connectedComponents(
labels_np_class.astype(np.uint8))

origin_pred_num_conn = pred_num_conn
if pred_num_conn > 2 * label_num_conn:
pred_num_conn = min(pred_num_conn, self.max_pred_num_conn)
real_pred_num = pred_num_conn - 1
Expand All @@ -100,8 +101,9 @@ def forward(self, logits, labels):
# Connected Components Matching and SC Loss Calculation
if real_label_num > 0 and real_pred_num > 0:
img_connectivity = compute_class_connectiveity(
pred_conn, label_conn, pred_num_conn, label_num_conn,
pred_i, real_label_num, real_pred_num, zero)
pred_conn, label_conn, pred_num_conn,
origin_pred_num_conn, label_num_conn, pred_i,
real_label_num, real_pred_num, zero)
sc_loss += 1 - img_connectivity
elif real_label_num == 0 and real_pred_num == 0:
# if no connected component, SC Loss = 0, so pass
Expand All @@ -122,12 +124,12 @@ def forward(self, logits, labels):


def compute_class_connectiveity(pred_conn, label_conn, pred_num_conn,
label_num_conn, pred, real_label_num,
real_pred_num, zero):
origin_pred_num_conn, label_num_conn, pred,
real_label_num, real_pred_num, zero):

pred_conn = paddle.to_tensor(pred_conn)
label_conn = paddle.to_tensor(label_conn)
pred_conn = F.one_hot(pred_conn, pred_num_conn)
pred_conn = F.one_hot(pred_conn, origin_pred_num_conn)
label_conn = F.one_hot(label_conn, label_num_conn)

ious = paddle.zeros((real_label_num, real_pred_num))
Expand Down

0 comments on commit fbb26ce

Please sign in to comment.