Skip to content

Commit

Permalink
first
Browse files Browse the repository at this point in the history
  • Loading branch information
JialeCao001 committed Aug 27, 2020
1 parent 70fb59e commit ca26d3f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ We propose a novel two-stage detection method, D2Det, that collectively addresse
## Train and Inference
Please use the following commands for training and testing by single GPU or multiple GPUs.


##### Train with a single GPU
```shell
python tools/train.py ${CONFIG_FILE}
Expand All @@ -42,10 +43,20 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]
```


- CONFIG_FILE about D2Det is in [configs/D2Det](configs/D2Det), please refer to [GETTING_STARTED.md](docs/GETTING_STARTED.md) for more details.


## Demo


With our trained model, detection results of an image can be visualized using the following command.
```shell
python ./demo/D2Det_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMAGE_FILE} [--out ${OUT_PATH}]
e.g.,
python ./demo/D2Det_demo.py ./configs/D2Det/D2Det_instance_r101_fpn_2x.py ./D2Det-instance-res101.pth ./demo/demo.jpg --out ./demo/aa.jpg
```


## Results

We provide some models with different backbones and results of object detection and instance segmentation on MS COCO benchmark.
Expand Down
30 changes: 30 additions & 0 deletions demo/D2Det_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import os.path as osp
import sys
sys.path.insert(0, osp.join(osp.dirname(osp.abspath(__file__)), '../'))
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
import mmcv
import argparse

def parse_args():
parser = argparse.ArgumentParser(
description='D2Det inference demo')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('img_file', help='img path')
parser.add_argument('--out', help='output result path')
args = parser.parse_args()
return args

def main():
args = parse_args()
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device='cuda:0')
# test a single image
result = inference_detector(model, args.img_file)
# show the results
show_result_pyplot(args.img_file, result, model.CLASSES, out_file=args.out)


if __name__ == '__main__':
main()
16 changes: 11 additions & 5 deletions mmdet/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def show_result(img,
labels = np.concatenate(labels)
# draw segmentation masks
if segm_result is not None:
segms = mmcv.concat_list(segm_result)
if len(segm_result) > 1:
segms = mmcv.concat_list(segm_result[0])
bboxes[:, -1] = np.concatenate(segm_result[1])/1.3#rescale the mask scores to the range of [0,1]
else:
segms = mmcv.concat_list(segm_result)
inds = np.where(bboxes[:, -1] > score_thr)[0]
np.random.seed(42)
color_masks = [
Expand Down Expand Up @@ -185,7 +189,8 @@ def show_result_pyplot(img,
result,
class_names,
score_thr=0.3,
fig_size=(15, 10)):
fig_size=(15, 10),
out_file=None):
"""Visualize the detection results on the image.
Args:
Expand All @@ -199,6 +204,7 @@ def show_result_pyplot(img,
be written to the out file instead of shown in a window.
"""
img = show_result(
img, result, class_names, score_thr=score_thr, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
img, result, class_names, score_thr=score_thr, show=False, out_file=out_file)
if out_file is None:
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
2 changes: 2 additions & 0 deletions mmdet/models/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def show_result(self, data, result, dataset=None, score_thr=0.3):
if segm_result is not None:
if len(segm_result) > 1:
segms = mmcv.concat_list(segm_result[0])
bboxes[:, -1] = np.concatenate(
segm_result[1]) / 1.3 # rescale the mask scores to the range of [0,1]
else:
segms = mmcv.concat_list(segm_result)
inds = np.where(bboxes[:, -1] > score_thr)[0]
Expand Down

0 comments on commit ca26d3f

Please sign in to comment.