This page provides basic tutorials about the usage of MMDetection3D. For installation instructions, please see install.md.
It is recommended to symlink the dataset root to $MMDETECTION3D/data.
If your folder structure is different from the following, you may need to change the corresponding paths in config files.
mmdetection3d
├── mmdet3d
├── tools
├── configs
├── data
│ ├── nuscenes
│ │ ├── maps
│ │ ├── samples
│ │ ├── sweeps
│ │ ├── v1.0-test
│ │ ├── v1.0-trainval
│ ├── kitti
│ │ ├── ImageSets
│ │ ├── testing
│ │ │ ├── calib
│ │ │ ├── image_2
│ │ │ ├── velodyne
│ │ ├── training
│ │ │ ├── calib
│ │ │ ├── image_2
│ │ │ ├── label_2
│ │ │ ├── velodyne
│ ├── lyft
│ │ ├── v1.01-train
│ │ │ ├── v1.01-train (train_data)
│ │ │ ├── lidar (train_lidar)
│ │ │ ├── images (train_images)
│ │ │ ├── maps (train_maps)
│ │ ├── v1.01-test
│ │ │ ├── v1.01-test (test_data)
│ │ │ ├── lidar (test_lidar)
│ │ │ ├── images (test_images)
│ │ │ ├── maps (test_maps)
│ │ ├── train.txt
│ │ ├── val.txt
│ │ ├── test.txt
│ │ ├── sample_submission.csv
│ ├── scannet
│ │ ├── meta_data
│ │ ├── scans
│ │ ├── batch_load_scannet_data.py
│ │ ├── load_scannet_data.py
│ │ ├── scannet_utils.py
│ │ ├── README.md
│ ├── sunrgbd
│ │ ├── OFFICIAL_SUNRGBD
│ │ ├── matlab
│ │ ├── sunrgbd_data.py
│ │ ├── sunrgbd_utils.py
│ │ ├── README.md
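Before running the data preparation scripts, it can help to confirm that the expected folders exist. The sketch below is a minimal check and is not part of MMDetection3D: the helper name `missing_dirs` and the `EXPECTED` table are hypothetical, and the table covers only a subset of the tree above.

```python
from pathlib import Path

# Hypothetical subset of the layout shown above; extend it for the datasets you use.
EXPECTED = {
    "nuscenes": ["maps", "samples", "sweeps", "v1.0-test", "v1.0-trainval"],
    "kitti": [
        "ImageSets",
        "testing/calib", "testing/image_2", "testing/velodyne",
        "training/calib", "training/image_2", "training/label_2", "training/velodyne",
    ],
}


def missing_dirs(data_root, dataset):
    """Return the expected sub-directories of `dataset` that are absent under `data_root`."""
    base = Path(data_root) / dataset
    return [rel for rel in EXPECTED[dataset] if not (base / rel).is_dir()]


if __name__ == "__main__":
    for name in EXPECTED:
        print(name, "missing:", missing_dirs("./data", name))
```

An empty list for a dataset means every listed folder was found. Note that ImageSets for KITTI is created during the preparation steps, so it may be reported as missing on a fresh download.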
Download the nuScenes V1.0 full dataset HERE. Prepare the nuScenes data by running
python tools/create_data.py nuscenes --root-path ./data/nuscenes --out-dir ./data/nuscenes --extra-tag nuscenes
Download KITTI 3D detection data HERE. Prepare kitti data by running
mkdir ./data/kitti/ && mkdir ./data/kitti/ImageSets
# Download data split
wget -c https://raw.githubusercontent.com/traveller59/second.pytorch/master/second/data/ImageSets/test.txt --no-check-certificate --content-disposition -O ./data/kitti/ImageSets/test.txt
wget -c https://raw.githubusercontent.com/traveller59/second.pytorch/master/second/data/ImageSets/train.txt --no-check-certificate --content-disposition -O ./data/kitti/ImageSets/train.txt
wget -c https://raw.githubusercontent.com/traveller59/second.pytorch/master/second/data/ImageSets/val.txt --no-check-certificate --content-disposition -O ./data/kitti/ImageSets/val.txt
wget -c https://raw.githubusercontent.com/traveller59/second.pytorch/master/second/data/ImageSets/trainval.txt --no-check-certificate --content-disposition -O ./data/kitti/ImageSets/trainval.txt
python tools/create_data.py kitti --root-path ./data/kitti --out-dir ./data/kitti --extra-tag kitti
Download Lyft 3D detection data HERE. Prepare Lyft data by running
python tools/create_data.py lyft --root-path ./data/lyft --out-dir ./data/lyft --extra-tag lyft --version v1.01
Note that we follow the original folder names for clear organization. Please rename the raw folders as shown above.
To prepare scannet data, please see scannet.
To prepare sunrgbd data, please see sunrgbd.
For using custom datasets, please refer to Tutorials 2: Adding New Dataset.
We provide testing scripts to evaluate a whole dataset (SUNRGBD, ScanNet, KITTI, etc.), and also some high-level APIs for easier integration into other projects. The following testing modes are supported:
- single GPU
- single node multiple GPU
- multiple node
You can use the following commands to test a dataset.
# single-gpu testing
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
# multi-gpu testing
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]
Optional arguments:
- RESULT_FILE: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
- EVAL_METRICS: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., proposal_fast, proposal, bbox, segm are available for COCO, and mAP, recall for PASCAL VOC. Cityscapes can be evaluated by cityscapes as well as all COCO metrics.
- --show: If specified, detection results will be plotted in silent mode. It is only applicable to single GPU testing and is used for debugging and visualization. This should be used with --show-dir.
- --show-dir: If specified, detection results will be written to ***_points.obj and ***_pred.ply files in the specified directory. It is only applicable to single GPU testing and is used for debugging and visualization. You do NOT need a GUI available in your environment to use this option.
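Since RESULT_FILE is written in pickle format, it can be read back with the standard library. The following is a minimal sketch: the helper names `load_results` and `summarize` are hypothetical, and the assumption that the file holds a sequence with one entry per test sample is not confirmed by this page.

```python
import pickle


def load_results(result_file):
    """Load the object written by `--out` (a pickle file)."""
    with open(result_file, "rb") as f:
        return pickle.load(f)


def summarize(results):
    """Return the number of entries and the type name of the first entry.

    Assumes `results` is a sequence with one entry per test sample; the
    layout of each entry is not specified here and may differ per model.
    """
    first = type(results[0]).__name__ if len(results) > 0 else None
    return len(results), first
```

If the saved entries contain tensors, the library that defines them most likely needs to be importable when the file is loaded. As with any pickle file, only load results that come from a source you trust.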
Examples:
Assume that you have already downloaded the checkpoints to the directory checkpoints/.
- Test votenet on ScanNet and save the points and prediction visualization results.
python tools/test.py configs/votenet/votenet_8x8_scannet-3d-18class.py \
    checkpoints/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth \
    --show --show-dir ./data/scannet/show_results
- Test votenet on ScanNet, save the points, prediction, and ground truth visualization results, and evaluate the mAP.
python tools/test.py configs/votenet/votenet_8x8_scannet-3d-18class.py \
    checkpoints/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth \
    --eval mAP --options 'show=True' 'out_dir=./data/scannet/show_results'
- Test votenet on ScanNet (without saving the test results) and evaluate the mAP.
python tools/test.py configs/votenet/votenet_8x8_scannet-3d-18class.py \
    checkpoints/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth \
    --eval mAP
- Test SECOND with 8 GPUs, and evaluate the mAP.
./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} configs/second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py \
    checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-3class_20200620_230238-9208083a.pth \
    --out results.pkl --eval mAP
- Test PointPillars on nuScenes with 8 GPUs, and generate the json file to be submitted to the official evaluation server.
./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} configs/pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py \
    checkpoints/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d_20200620_230405-2fa62f3d.pth \
    --format-only --options 'jsonfile_prefix=./pointpillars_nuscenes_results'
The generated results will be under the ./pointpillars_nuscenes_results directory.
- Test SECOND on KITTI with 8 GPUs, and generate the pkl files and submission data to be submitted to the official evaluation server.
./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} configs/second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py \
    checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-3class_20200620_230238-9208083a.pth \
    --format-only --options 'pklfile_prefix=./second_kitti_results' 'submission_prefix=./second_kitti_results'
The generated results will be under the ./second_kitti_results directory.
To see the SUNRGBD, ScanNet or KITTI points and detection results, you can run the following command
python tools/test.py ${CONFIG_FILE} ${CKPT_PATH} --show --show-dir ${SHOW_DIR}
After running this command, the plotted results (***_points.obj and ***_pred.ply files) are saved in ${SHOW_DIR}.
To see the points, detection results and ground truth of SUNRGBD, ScanNet or KITTI during evaluation time, you can run the following command
python tools/test.py ${CONFIG_FILE} ${CKPT_PATH} --eval 'mAP' --options 'show=True' 'out_dir=${SHOW_DIR}'
After running this command, you will obtain ***_points.obj, ***_pred.ply and ***_gt.ply files in ${SHOW_DIR}.
You can use 3D visualization software such as MeshLab to open these files under ${SHOW_DIR} and see the 3D detection output. Specifically, open ***_points.obj to see the input point cloud and open ***_pred.ply to see the predicted 3D bounding boxes. This allows inference and result generation to be done on a remote server, while users open the files on their own host with a GUI.
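If no 3D viewer is at hand, the point file can also be inspected programmatically. The sketch below assumes that ***_points.obj stores one point per line in the usual Wavefront OBJ form `v x y z`, possibly followed by extra values such as colors; this layout and the helper name `read_obj_vertices` are assumptions, not something this page specifies.

```python
import numpy as np


def read_obj_vertices(lines):
    """Collect the xyz coordinates from Wavefront OBJ vertex lines ("v x y z ...")."""
    points = []
    for line in lines:
        parts = line.split()
        # Keep only vertex records; faces, comments and blank lines are skipped.
        if len(parts) >= 4 and parts[0] == "v":
            points.append([float(value) for value in parts[1:4]])
    return np.asarray(points, dtype=np.float32).reshape(-1, 3)
```

The function accepts any iterable of text lines, so an open file handle can be passed directly, for example `read_obj_vertices(open(path))`. The resulting array shape gives a quick count of the saved points.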
Notice: The visualization API is a little unstable since we plan to refactor these parts together with MMDetection in the future.
We provide a demo script to test a single sample.
python demo/pcd_demo.py ${PCD_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--score-thr ${SCORE_THR}] [--out-dir ${OUT_DIR}]
Examples:
python demo/pcd_demo.py demo/kitti_000008.bin configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-car_20200620_230238-393f000c.pth
If you want to input a ply file, you can use the following function to convert it to bin format. Then you can use the converted bin file to run the demo.
Note that you need to install pandas and plyfile before using this script. This function can also be used to preprocess ply data for training.
import numpy as np
import pandas as pd
from plyfile import PlyData


def convert_ply(input_path, output_path):
    plydata = PlyData.read(input_path)  # read file
    data = plydata.elements[0].data  # read data
    data_pd = pd.DataFrame(data)  # convert to DataFrame
    data_np = np.zeros(data_pd.shape, dtype=np.float64)  # initialize array to store data
    property_names = data[0].dtype.names  # read names of properties
    for i, name in enumerate(property_names):  # read data by property
        data_np[:, i] = data_pd[name]
    data_np.astype(np.float32).tofile(output_path)  # write as headerless float32
Examples:
convert_ply('./test.ply', './test.bin')
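The converted file is a flat run of float32 values without any header, so the number of properties per point has to be known to read it back. The following sketch shows one way to check a converted file; the helper name `load_bin` is hypothetical, and `num_properties` must match the number of properties in your ply file.

```python
import numpy as np


def load_bin(path, num_properties):
    """Read a headerless float32 file back into an (N, num_properties) array."""
    flat = np.fromfile(path, dtype=np.float32)
    if flat.size % num_properties != 0:
        raise ValueError(
            f"{flat.size} values cannot be split into rows of {num_properties}")
    return flat.reshape(-1, num_properties)
```

For instance, a ply file whose points carry x, y, z and one extra property would be read with `load_bin('./test.bin', 4)`. A `ValueError` indicates that the property count does not match the file.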
Here is an example of building the model and testing it on a given point cloud.
from mmdet3d.apis import init_detector, inference_detector
config_file = 'configs/votenet/votenet_8x8_scannet-3d-18class.py'
checkpoint_file = 'checkpoints/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth'
# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# test a single point cloud and show the results
point_cloud = 'test.bin'
result, data = inference_detector(model, point_cloud)
# visualize the results and save the results in 'results' folder
model.show_results(data, result, out_dir='results')
A notebook demo can be found in demo/inference_demo.ipynb.
MMDetection3D implements distributed training and non-distributed training, which use MMDistributedDataParallel and MMDataParallel respectively.
All outputs (log files and checkpoints) will be saved to the working directory, which is specified by work_dir in the config file.
By default, we evaluate the model on the validation set after each epoch. You can change the evaluation interval by adding the interval argument in the training config.
evaluation = dict(interval=12) # This evaluates the model every 12 epochs.
Important: The default learning rate in config files is for 8 GPUs and the exact batch size is marked by the config's file name, e.g. '2x8' means 2 samples per GPU using 8 GPUs. According to the Linear Scaling Rule, you need to set the learning rate proportional to the batch size if you use a different number of GPUs or samples per GPU, e.g., lr=0.01 for 4 GPUs * 2 img/gpu and lr=0.08 for 16 GPUs * 4 img/gpu. However, since most of the models in this repo use Adam rather than SGD for optimization, the rule may not hold and users need to tune the learning rate by themselves.
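The arithmetic behind the Linear Scaling Rule can be sketched as follows. The function name `scale_lr` is hypothetical, and the reference point (lr=0.01 at a total batch size of 8) is taken from the example above. As noted, the result is only a starting point when the optimizer is not SGD.

```python
def scale_lr(base_lr, base_batch_size, gpus, samples_per_gpu):
    """Scale a reference learning rate linearly with the total batch size."""
    total_batch_size = gpus * samples_per_gpu
    return base_lr * total_batch_size / base_batch_size


# Reference point from the text: lr=0.01 at 4 GPUs * 2 samples per GPU (batch size 8).
# Scaling to 16 GPUs * 4 samples per GPU multiplies the batch size, and the lr, by 8.
print(scale_lr(0.01, 8, gpus=16, samples_per_gpu=4))
```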
python tools/train.py ${CONFIG_FILE} [optional arguments]
If you want to specify the working directory in the command, you can add the argument --work-dir ${YOUR_WORK_DIR}.
./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
Optional arguments are:
- --no-validate (not suggested): By default, the codebase will perform evaluation every k epochs during training (the default value of k is 1, which can be modified through the evaluation interval in the config). To disable this behavior, use --no-validate.
- --work-dir ${WORK_DIR}: Override the working directory specified in the config file.
- --resume-from ${CHECKPOINT_FILE}: Resume from a previous checkpoint file.
- --options 'Key=value': Override some settings in the used config.
Difference between resume-from and load-from:
- resume-from loads both the model weights and optimizer status, and the epoch is also inherited from the specified checkpoint. It is usually used for resuming the training process that is interrupted accidentally.
- load-from only loads the model weights and the training epoch starts from 0. It is usually used for finetuning.
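The difference between the two modes can be pictured with a plain-Python sketch. Here `checkpoint` is a hypothetical dictionary with 'state_dict', 'optimizer' and 'epoch' keys that stands in for a real checkpoint file; the actual file layout and loading code are not shown on this page and may differ.

```python
def start_state(checkpoint, resume):
    """Return the training state that each mode would start from.

    `checkpoint` is a hypothetical dict with 'state_dict', 'optimizer' and
    'epoch' keys; it stands in for a real checkpoint file.
    """
    if resume:
        # resume-from: weights, optimizer status and epoch are all restored.
        return {
            "weights": checkpoint["state_dict"],
            "optimizer": checkpoint["optimizer"],
            "epoch": checkpoint["epoch"],
        }
    # load-from: only the weights are reused; training restarts at epoch 0.
    return {"weights": checkpoint["state_dict"], "optimizer": None, "epoch": 0}
```

In other words, resume-from continues the same run, while load-from starts a new run that is merely initialized from existing weights.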
If you run MMDetection3D on a cluster managed with slurm, you can use the script slurm_train.sh. (This script also supports single machine training.)
[GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR}
Here is an example of using 16 GPUs to train Mask R-CNN on the dev partition.
GPUS=16 ./tools/slurm_train.sh dev mask_r50_1x configs/mask_rcnn_r50_fpn_1x_coco.py /nfs/xxxx/mask_rcnn_r50_fpn_1x
You can check slurm_train.sh for full arguments and environment variables.
If you just have multiple machines connected with Ethernet, you can refer to the PyTorch launch utility. Training is usually slow if you do not have high-speed networking like InfiniBand.
If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs, you need to specify different ports (29500 by default) for each job to avoid communication conflict.
If you use dist_train.sh to launch training jobs, you can set the port in commands.
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
If you launch training jobs with Slurm, there are two ways to specify the ports.
- Set the port through --options. This is recommended since it does not change the original configs.
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR} --options 'dist_params.port=29500'
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR} --options 'dist_params.port=29501'
- Modify the config files (usually the 6th line from the bottom in config files) to set different communication ports.
In config1.py,
dist_params = dict(backend='nccl', port=29500)
In config2.py,
dist_params = dict(backend='nccl', port=29501)
Then you can launch two jobs with config1.py and config2.py.
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
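Rather than picking port numbers by hand, you can ask the operating system for one that is currently unused. This is a minimal sketch using only the Python standard library; the helper name `find_free_port` is hypothetical and is not provided by the codebase.

```python
import socket


def find_free_port():
    """Ask the OS for a TCP port that is unused at the moment of the call."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))  # port 0 lets the OS pick any free port
        return sock.getsockname()[1]


if __name__ == "__main__":
    print(find_free_port())
```

The returned number can then be passed as the port for one job. The port is released as soon as the function returns, so another process could in principle claim it before the training job starts; for a handful of jobs on one machine this is rarely a problem.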
We provide lots of useful tools under the tools/ directory.
You can plot loss/mAP curves given a training log file. Run pip install seaborn first to install the dependency.
python tools/analyze_logs.py plot_curve [--keys ${KEYS}] [--title ${TITLE}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}]
Examples:
-
Plot the classification loss of some run.
python tools/analyze_logs.py plot_curve log.json --keys loss_cls --legend loss_cls
-
Plot the classification and regression loss of some run, and save the figure to a pdf.
python tools/analyze_logs.py plot_curve log.json --keys loss_cls loss_bbox --out losses.pdf
-
Compare the bbox mAP of two runs in the same figure.
python tools/analyze_logs.py plot_curve log1.json log2.json --keys bbox_mAP --legend run1 run2
You can also compute the average training speed.
python tools/analyze_logs.py cal_train_time log.json [--include-outliers]
The output is expected to be like the following.
-----Analyze train time of work_dirs/some_exp/20190611_192040.log.json-----
slowest epoch 11, average time is 1.2024
fastest epoch 1, average time is 1.1909
time std over epochs is 0.0028
average iter time: 1.1959 s/iter
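For a custom analysis, the same kind of statistic can be computed directly from the log. The sketch below assumes that log.json contains one JSON object per line and that training records carry "mode", "epoch" and "time" fields; these field names are an assumption about the log format, and the helper `average_iter_time` is hypothetical. Unlike the tool, it applies no outlier handling.

```python
import json
from collections import defaultdict


def average_iter_time(lines):
    """Return (per-epoch mean time, overall mean time) from JSON log lines.

    Assumes each training record is a JSON object with "mode", "epoch" and
    "time" fields; records without them are skipped.
    """
    times = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        if record.get("mode") == "train" and "epoch" in record and "time" in record:
            times[record["epoch"]].append(record["time"])
    per_epoch = {epoch: sum(values) / len(values) for epoch, values in times.items()}
    all_times = [t for values in times.values() for t in values]
    overall = sum(all_times) / len(all_times) if all_times else None
    return per_epoch, overall
```

The function takes any iterable of lines, so an open log file can be passed directly.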
Before you upload a model to AWS, you may want to (1) convert model weights to CPU tensors, (2) delete the optimizer states and (3) compute the hash of the checkpoint file and append the hash id to the filename.
python tools/publish_model.py ${INPUT_FILENAME} ${OUTPUT_FILENAME}
E.g.,
python tools/publish_model.py work_dirs/faster_rcnn/latest.pth faster_rcnn_r50_fpn_1x_20190801.pth
The final output filename will be faster_rcnn_r50_fpn_1x_20190801-{hash id}.pth.
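Step (3) can be illustrated with a short sketch. The choice of SHA-256 and of an 8-character hash id are assumptions made for illustration, not details confirmed by this page, and the helper name `hashed_filename` is hypothetical; it only computes the name and does not modify the checkpoint.

```python
import hashlib


def hashed_filename(checkpoint_path, output_filename, digits=8):
    """Build '<name>-<hash id>.pth' from the SHA-256 digest of the checkpoint file."""
    sha = hashlib.sha256()
    with open(checkpoint_path, "rb") as f:
        # Read in 1 MiB chunks so large checkpoints do not need to fit in memory.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    if output_filename.endswith(".pth"):
        stem = output_filename[:-len(".pth")]
    else:
        stem = output_filename
    return f"{stem}-{sha.hexdigest()[:digits]}.pth"
```

Embedding part of the hash in the filename lets users verify that a downloaded checkpoint matches the published one.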
Currently, we provide four tutorials for users to finetune models, add new datasets, design data pipelines and add new modules. We also provide a full description of the config system.