[Feature] Support training on a single channel image dataset. #460

Merged Feb 23, 2023 · 27 commits · Changes from 18 commits
121 changes: 111 additions & 10 deletions docs/en/advanced_guides/how_to.md
@@ -547,26 +547,127 @@ python ./tools/train.py \
- `randomness.diff_rank_seed=True`, set different seeds according to global rank. Defaults to False.
- `randomness.deterministic=True`, set the deterministic option for cuDNN backend, i.e., set `torch.backends.cudnn.deterministic` to True and `torch.backends.cudnn.benchmark` to False. Defaults to False. See https://pytorch.org/docs/stable/notes/randomness.html for more details.
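
For reference, the same options can also be written directly in the config file rather than passed on the command line; a minimal MMEngine-style sketch (the field values are illustrative, not recommendations):

```python
# Illustrative randomness settings in an MMEngine-style config
randomness = dict(
    seed=2023,             # fix the global seed for reproducibility
    diff_rank_seed=False,  # use the same seed on every rank
    deterministic=True,    # deterministic cuDNN kernels (may be slower)
)
```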

## Training example on a single-channel image dataset

### Dataset preparation

This section uses the `cat` dataset as an example. If you are using a custom grayscale image dataset, you can skip this step.

For the full workflow of preparing and training on a custom dataset, see [Annotation-to-deployment workflow for custom dataset](../user_guides/custom_dataset.md).

```shell
python tools/misc/download_dataset.py --dataset-name cat --save-dir ./data/cat --unzip --delete
# --save-dir: path where the example dataset is stored
```

`cat` is a three-channel color image dataset. For demonstration purposes, you can run the following code and commands to replace the dataset images with single-channel images for subsequent verification.

Sample code for converting the images to single channel:

```python
import argparse
import imghdr
import os
from typing import List

from PIL import Image


def parse_args():
    parser = argparse.ArgumentParser(description='data_path')
    parser.add_argument('--path', type=str, help='Original dataset path')
    return parser.parse_args()


def main():
    args = parse_args()

    path = os.path.join(args.path, 'images')
    save_path = path
    file_list: List[str] = os.listdir(path)
    # Convert each JPEG image to grayscale and overwrite it in place
    for file in file_list:
        if imghdr.what(os.path.join(path, file)) != 'jpeg':
            continue
        o_img = Image.open(os.path.join(path, file))
        L_img = o_img.convert('L')
        L_img.save(os.path.join(save_path, file))


if __name__ == '__main__':
    main()
```

Save the above script as `cvt_single_channel.py` and run it with:

```shell
python cvt_single_channel.py --path data/cat
```
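
After the script finishes, you can quickly confirm the conversion worked. A minimal sketch (`example.jpg` is a placeholder; any converted image in `data/cat/images` will do):

```python
from PIL import Image

# 'example.jpg' is hypothetical; substitute any converted image
img = Image.open('data/cat/images/example.jpg')
print(img.mode)  # 'L' means the image is single-channel grayscale
```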

### Create a new config file and train

**At present, some image processing functions in MMYOLO are not compatible with single-channel images, so incompatibility issues are fairly likely.** The recommended approach is therefore to read single-channel images as three-channel images during training. Although this sacrifices some computational performance, the existing configuration can generally be used without modification.
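
This works without configuration changes because common image loaders replicate a grayscale image across three channels when asked for a color image. A minimal sketch with OpenCV (the path is hypothetical):

```python
import cv2

# Reading with IMREAD_COLOR replicates a grayscale image across 3 channels
img = cv2.imread('data/cat/images/example.jpg', cv2.IMREAD_COLOR)
print(img.shape)  # (H, W, 3) even though the source file is single-channel
```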

#### Modify the configuration file

Take `projects/misc/custom_dataset/yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py` as the `base` configuration, copy it to the `configs/yolov5` directory, and add a new `yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py` file in the same directory:

```python
_base_ = 'yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py'

load_from = './checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth'
```
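
Once the modified checkpoint described in the next subsection has been created, training launches like any other config, e.g. `python tools/train.py configs/yolov5/yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py`.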

#### Loading the pre-trained model

Directly using the original three-channel pre-trained model may, in theory, lead to some loss of accuracy (this has not been experimentally verified). There are several possible solutions: set the weight of each input-layer channel to the average of the original three channels' weights, set the weight of each input-layer channel to the weights of one of the original channels, or train directly without modifying the input-layer weights. The effect of each solution varies with the actual situation. In our implementation, we set the weights of the three input-layer channels to the average of the pre-trained three-channel weights.

```python
import torch


def main():
    # Load the checkpoint (onto CPU, so no GPU is required)
    state_dict = torch.load(
        'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth',
        map_location='cpu')

    # Replace each input-layer channel with the mean of the original
    # three channel weights
    weights = state_dict['state_dict']['backbone.stem.conv.weight']
    avg_weight = weights.mean(dim=1, keepdim=True)
    new_weights = torch.cat([avg_weight] * 3, dim=1)
    state_dict['state_dict']['backbone.stem.conv.weight'] = new_weights

    # Save the modified weights to a new file
    torch.save(
        state_dict,
        'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth'
    )


if __name__ == '__main__':
    main()
```
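
To sanity-check the result, you can reload the new checkpoint and verify that the three input channels now carry identical weights. A minimal sketch (the printed shape is an assumption that depends on the exact model variant):

```python
import torch

sd = torch.load(
    'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth',
    map_location='cpu')
w = sd['state_dict']['backbone.stem.conv.weight']
print(w.shape)  # e.g. torch.Size([32, 3, 6, 6]); depends on the model variant
# All three input channels should now be identical
print(torch.allclose(w[:, 0], w[:, 1]), torch.allclose(w[:, 1], w[:, 2]))
```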

#### Results of model training

<img src="https://raw.githubusercontent.com/landhill/mmyolo/main/resources/cat_single_channel_test.jpeg"/>

The left image shows the ground-truth labels and the right image shows the detection results.

```shell
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.798
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.953
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.864
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.798
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.777
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.847
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.850
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.850
bbox_mAP_copypaste: 0.798 0.953 0.864 -1.000 -1.000 0.798
Epoch(val) [100][116/116] coco/bbox_mAP: 0.7980 coco/bbox_mAP_50: 0.9530 coco/bbox_mAP_75: 0.8640 coco/bbox_mAP_s: -1.0000 coco/bbox_mAP_m: -1.0000 coco/bbox_mAP_l: 0.7980
```
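
Note that the `-1.000` entries for the small and medium areas simply indicate that the `cat` dataset contains no objects in those COCO size ranges; they are not a sign of model failure.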
119 changes: 109 additions & 10 deletions docs/zh_cn/advanced_guides/how_to.md
@@ -553,26 +553,125 @@ python ./tools/train.py \

- `randomness.deterministic=True`, set the deterministic option for the cuDNN backend, i.e., set `torch.backends.cudnn.deterministic` to True and `torch.backends.cudnn.benchmark` to False. `deterministic` defaults to False. See https://pytorch.org/docs/stable/notes/randomness.html for more details.

## Training example on a single-channel image dataset

### Dataset preprocessing

This section uses the `cat` dataset as an example. If you are using a custom grayscale image dataset, you can skip this step.

For the full workflow of preparing and training on a custom dataset, see [Annotation-to-deployment workflow for custom dataset](../user_guides/custom_dataset.md).

```shell
python tools/misc/download_dataset.py --dataset-name cat --save-dir ./data/cat --unzip --delete
```

`cat` is a three-channel color image dataset. For demonstration purposes, you can run the following code and commands to replace the dataset images with single-channel images for subsequent verification.

Sample code for converting the images to single channel:

```python
import argparse
import imghdr
import os
from typing import List

from PIL import Image


def parse_args():
    parser = argparse.ArgumentParser(description='data_path')
    parser.add_argument('--path', type=str, help='Original dataset path')
    return parser.parse_args()


def main():
    args = parse_args()

    path = os.path.join(args.path, 'images')
    save_path = path
    file_list: List[str] = os.listdir(path)
    # Convert each JPEG image to grayscale and overwrite it in place
    for file in file_list:
        if imghdr.what(os.path.join(path, file)) != 'jpeg':
            continue
        o_img = Image.open(os.path.join(path, file))
        L_img = o_img.convert('L')
        L_img.save(os.path.join(save_path, file))


if __name__ == '__main__':
    main()
```

Save the above script as `cvt_single_channel.py` and run it with:

```shell
python cvt_single_channel.py --path data/cat
```

### Create a new config file and train

**At present, some image processing functions in MMYOLO are not compatible with single-channel images, so incompatibility issues are fairly likely.** To avoid such problems, the recommended approach is to read single-channel images as three-channel images during training. This sacrifices some computational performance, but the existing configuration can generally be used without modification.

#### Modify the configuration file

Take `projects/misc/custom_dataset/yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py` as the `base` configuration, copy it to the `configs/yolov5` directory, and add a new `yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py` file in the same directory:

```python
_base_ = 'yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py'

load_from = './checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth'
```

#### Loading the pre-trained model

Directly using the original three-channel pre-trained model may, in theory, lead to some loss of accuracy (this has not been experimentally verified). There are several possible solutions: set the weight of each input-layer channel to the average of the original three channels' weights, set the weight of each input-layer channel to the weights of one of the original channels, or train directly without modifying the input-layer weights. The effect of each solution varies with the actual situation. Here we set the weights of the three input-layer channels to the average of the pre-trained three-channel weights.

```python
import torch


def main():
    # Load the checkpoint (onto CPU, so no GPU is required)
    state_dict = torch.load(
        'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth',
        map_location='cpu')

    # Replace each input-layer channel with the mean of the original
    # three channel weights
    weights = state_dict['state_dict']['backbone.stem.conv.weight']
    avg_weight = weights.mean(dim=1, keepdim=True)
    new_weights = torch.cat([avg_weight] * 3, dim=1)
    state_dict['state_dict']['backbone.stem.conv.weight'] = new_weights

    # Save the modified weights to a new file
    torch.save(
        state_dict,
        'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth'
    )


if __name__ == '__main__':
    main()
```

#### Model training results

<img src="https://raw.githubusercontent.com/landhill/mmyolo/main/resources/cat_single_channel_test.jpeg"/>

The left image shows the ground-truth labels and the right image shows the detection results.

```shell
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.798
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.953
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.864
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.798
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.777
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.847
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.850
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.850
bbox_mAP_copypaste: 0.798 0.953 0.864 -1.000 -1.000 0.798
Epoch(val) [100][116/116] coco/bbox_mAP: 0.7980 coco/bbox_mAP_50: 0.9530 coco/bbox_mAP_75: 0.8640 coco/bbox_mAP_s: -1.0000 coco/bbox_mAP_m: -1.0000 coco/bbox_mAP_l: 0.7980
```