Skip to content

Commit

Permalink
Merge pull request #74 from hyl-g/fix_open_images
Browse files Browse the repository at this point in the history
1. decouple boxes, labels; 2: cast lables to 64 bits
  • Loading branch information
qfgaohao authored Sep 14, 2019
2 parents bc7bbba + 4d54172 commit 7174f33
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions vision/datasets/open_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pathlib
import cv2
import pandas as pd

import copy

class OpenImagesDataset:

Expand All @@ -26,12 +26,14 @@ def __init__(self, root,
def _getitem(self, index):
image_info = self.data[index]
image = self._read_image(image_info['image_id'])
boxes = image_info['boxes']
# duplicate boxes to prevent corruption of dataset
boxes = copy.copy(image_info['boxes'])
boxes[:, 0] *= image.shape[1]
boxes[:, 1] *= image.shape[0]
boxes[:, 2] *= image.shape[1]
boxes[:, 3] *= image.shape[0]
labels = image_info['labels']
# duplicate labels to prevent corruption of dataset
labels = copy.copy(image_info['labels'])
if self.transform:
image, boxes, labels = self.transform(image, boxes, labels)
if self.target_transform:
Expand Down Expand Up @@ -63,7 +65,8 @@ def _read_data(self):
data = []
for image_id, group in annotations.groupby("ImageID"):
boxes = group.loc[:, ["XMin", "YMin", "XMax", "YMax"]].values.astype(np.float32)
labels = np.array([class_dict[name] for name in group["ClassName"]])
# make labels 64 bits to satisfy the cross_entropy function
labels = np.array([class_dict[name] for name in group["ClassName"]], dtype='int64')
data.append({
'image_id': image_id,
'boxes': boxes,
Expand Down

0 comments on commit 7174f33

Please sign in to comment.