diff --git a/README.md b/README.md
index 98d2f9640..7bd64032e 100644
--- a/README.md
+++ b/README.md
@@ -204,6 +204,7 @@ For different parts from MMDetection, we have also prepared user guides and adva
- [Cross-library calls using mim](docs/en/common_usage/mim_usage.md)
- [Apply multiple Necks](docs/en/common_usage/multi_necks.md)
- [Specify specific device training or inference](docs/en/common_usage/specify_device.md)
+- [Single and multi-channel application examples](docs/en/common_usage/single_multi_channel_applications.md)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 4c9c52017..37c83ef77 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -225,6 +225,7 @@ MMYOLO 用法和 MMDetection 几乎一致,所有教程都是通用的,你也
- [使用 mim 跨库调用其他 OpenMMLab 仓库的脚本](docs/zh_cn/common_usage/mim_usage.md)
- [应用多个 Neck](docs/zh_cn/common_usage/multi_necks.md)
- [指定特定设备训练或推理](docs/zh_cn/common_usage/specify_device.md)
+- [单通道和多通道应用案例](docs/zh_cn/common_usage/single_multi_channel_applications.md)
diff --git a/docs/en/common_usage/single_multi_channel_applications.md b/docs/en/common_usage/single_multi_channel_applications.md
new file mode 100644
index 000000000..30932708b
--- /dev/null
+++ b/docs/en/common_usage/single_multi_channel_applications.md
@@ -0,0 +1,188 @@
+# Single and multi-channel application examples
+
+## Training example on a single-channel image dataset
+
+By default, MMYOLO assumes three-channel color training images. To train and test on a single-channel dataset, the following modifications are needed:
+
+1. All image processing pipelines must support single-channel operation
+2. The input channels of the first convolutional layer in the model's backbone need to be changed from 3 to 1
+3. If you wish to load COCO pre-trained weights, the size mismatch of the first convolutional layer's weights has to be handled
+
+The following uses the `cat` dataset as an example to walk through the entire modification process. If you are using a custom grayscale image dataset, you can skip the dataset preprocessing step.
+
+### 1 Dataset pre-processing
+
+The annotation and training workflow for custom datasets is described in [Annotation-to-deployment workflow for custom dataset](../recommended_topics/labeling_to_deployment_tutorials.md).
+
+`cat` is a three-channel color image dataset. For demonstration purposes, you can run the following code and commands to replace its images with single-channel ones for subsequent verification.
+
+**1. Download and unzip the `cat` dataset**
+
+```shell
+python tools/misc/download_dataset.py --dataset-name cat --save-dir ./data/cat --unzip --delete
+```
+
+**2. Convert the dataset to grayscale images**
+
+```python
+import argparse
+import imghdr
+import os
+from typing import List
+
+import cv2
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert dataset images to grayscale')
+    parser.add_argument('path', type=str, help='Original dataset path')
+    return parser.parse_args()
+
+def main():
+    args = parse_args()
+
+    img_dir = os.path.join(args.path, 'images')
+    file_list: List[str] = os.listdir(img_dir)
+    # Convert each image to grayscale and overwrite it in place
+    for file in file_list:
+        img_path = os.path.join(img_dir, file)
+        # Skip files that are not JPEG images
+        if imghdr.what(img_path) != 'jpeg':
+            continue
+        img = cv2.imread(img_path)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        cv2.imwrite(img_path, img)
+
+if __name__ == '__main__':
+    main()
+```
+
+Save the above script as `cvt_single_channel.py` and run it with:
+
+```shell
+python cvt_single_channel.py data/cat
+```
+
+### 2 Modify the base configuration file
+
+**At present, some of MMYOLO's image processing functions, such as color space transformations, are not compatible with single-channel images, so training directly on single-channel data would require modifying a large part of the pipeline.** To avoid this incompatibility, the recommended approach is to read single-channel images as three-channel data, and only convert them to single-channel format right before they are fed to the network. This slightly increases the computational cost, but requires essentially no code changes on the user's side.
+
+Take `projects/misc/custom_dataset/yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py` as the `base` configuration, copy it to the `configs/yolov5` directory, and add a new file `yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py` next to it. We can inherit from `YOLOv5DetDataPreprocessor` in `mmyolo/models/data_preprocessors/data_preprocessor.py`, name the new class `YOLOv5SCDetDataPreprocessor`, convert the image to single channel inside it, add the required imports, and register the new class in `mmyolo/models/data_preprocessors/__init__.py`. Sample code for `YOLOv5SCDetDataPreprocessor`:
+
+```python
+@MODELS.register_module()
+class YOLOv5SCDetDataPreprocessor(YOLOv5DetDataPreprocessor):
+    """Convert the batched images to single-channel during training.
+
+    Note: It must be used together with `mmyolo.datasets.utils.yolov5_collate`
+    """
+
+ def forward(self, data: dict, training: bool = False) -> dict:
+        """Perform normalization, padding, bgr2rgb conversion and conversion
+        to single-channel images, based on ``DetDataPreprocessor``.
+
+ Args:
+ data (dict): Data sampled from dataloader.
+ training (bool): Whether to enable training time augmentation.
+
+ Returns:
+ dict: Data in the same format as the model input.
+ """
+ if not training:
+ return super().forward(data, training)
+
+ data = self.cast_data(data)
+ inputs, data_samples = data['inputs'], data['data_samples']
+ assert isinstance(data['data_samples'], dict)
+
+ # TODO: Supports multi-scale training
+ if self._channel_conversion and inputs.shape[1] == 3:
+ inputs = inputs[:, [2, 1, 0], ...]
+
+ if self._enable_normalize:
+ inputs = (inputs - self.mean) / self.std
+
+ if self.batch_augments is not None:
+ for batch_aug in self.batch_augments:
+ inputs, data_samples = batch_aug(inputs, data_samples)
+
+ img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
+ data_samples = {
+ 'bboxes_labels': data_samples['bboxes_labels'],
+ 'img_metas': img_metas
+ }
+
+ # Convert to single channel image
+ inputs = inputs.mean(dim=1, keepdim=True)
+
+ return {'inputs': inputs, 'data_samples': data_samples}
+```
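+
+Registering the new class in `mmyolo/models/data_preprocessors/__init__.py` would then look roughly as follows (a sketch only; keep whatever other exports the file already contains):
+
+```python
+# mmyolo/models/data_preprocessors/__init__.py (sketch, existing exports omitted)
+from .data_preprocessor import (YOLOv5DetDataPreprocessor,
+                                YOLOv5SCDetDataPreprocessor)
+
+__all__ = ['YOLOv5DetDataPreprocessor', 'YOLOv5SCDetDataPreprocessor']
+```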
+
+At this point, the `yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py` configuration file reads as follows.
+
+```python
+_base_ = 'yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py'
+
+_base_.model.data_preprocessor.type = 'YOLOv5SCDetDataPreprocessor'
+```
+
+### 3 Loading the pre-trained model
+
+Directly loading the original three-channel pre-trained weights may, in theory, cost some accuracy (this has not been verified experimentally). Several remedies are possible: set the weight of the single input channel to the average of the original three channels' weights, copy the weights of one of the original three channels, or leave the input layer unmodified and simply train on top of it; which works best depends on the situation. Here we set the input layer weights to the average of the pre-trained three-channel weights.
+
+```python
+import torch
+
+def main():
+    # Load the checkpoint on CPU so that no GPU is required
+    state_dict = torch.load(
+        'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth',
+        map_location='cpu')
+
+    # Average the first conv layer's weights over the 3 input channels:
+    # (C_out, 3, k, k) -> (C_out, 1, k, k)
+ weights = state_dict['state_dict']['backbone.stem.conv.weight']
+ avg_weight = weights.mean(dim=1, keepdim=True)
+ state_dict['state_dict']['backbone.stem.conv.weight'] = avg_weight
+
+ # Save the modified weights to a new file
+ torch.save(
+ state_dict,
+ 'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth'
+ )
+
+if __name__ == '__main__':
+ main()
+```
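+
+A quick way to confirm the surgery worked is to reload the new checkpoint and inspect the first conv layer; a minimal sketch (for YOLOv5-s the input-channel dimension should now be 1, e.g. `torch.Size([32, 1, 6, 6])`):
+
+```python
+import torch
+
+sd = torch.load(
+    'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth',
+    map_location='cpu')
+# The input-channel dimension (dim 1) should now be 1
+print(sd['state_dict']['backbone.stem.conv.weight'].shape)
+```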
+
+At this point, the `yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py` configuration file reads as follows:
+
+```python
+_base_ = 'yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py'
+
+_base_.model.data_preprocessor.type = 'YOLOv5SCDetDataPreprocessor'
+
+load_from = './checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth'
+```
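+
+Training can then be launched with the standard MMYOLO entry point:
+
+```shell
+python tools/train.py configs/yolov5/yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py
+```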
+
+### 4 Model training results
+
+The left figure shows the ground-truth labels and the right figure shows the detection results.
+
+```shell
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.958
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 1.000
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.958
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.881
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.969
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.969
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.969
+bbox_mAP_copypaste: 0.958 1.000 1.000 -1.000 -1.000 0.958
+Epoch(val) [100][116/116] coco/bbox_mAP: 0.9580 coco/bbox_mAP_50: 1.0000 coco/bbox_mAP_75: 1.0000 coco/bbox_mAP_s: -1.0000 coco/bbox_mAP_m: -1.0000 coco/bbox_mAP_l: 0.9580
+```
+
+## Training example on a multi-channel image dataset
+
+TODO
diff --git a/docs/en/index.rst b/docs/en/index.rst
index 5516b6197..23f9a0e76 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -45,6 +45,7 @@ You can switch between Chinese and English documents in the top-right corner of
common_usage/mim_usage.md
common_usage/multi_necks.md
common_usage/specify_device.md
+ common_usage/single_multi_channel_applications.md
.. toctree::
diff --git a/docs/zh_cn/common_usage/single_multi_channel_applications.md b/docs/zh_cn/common_usage/single_multi_channel_applications.md
new file mode 100644
index 000000000..a20ef9044
--- /dev/null
+++ b/docs/zh_cn/common_usage/single_multi_channel_applications.md
@@ -0,0 +1,188 @@
+# 单通道和多通道应用案例
+
+## 在单通道图像数据集上训练示例
+
+MMYOLO 中默认的训练图片均为彩色三通道数据,如果希望采用单通道数据集进行训练和测试,预计需要修改的地方包括:
+
+1. 所有的图片处理 pipeline 都要支持单通道运算
+2. 模型的骨干网络的第一个卷积层输入通道需要从 3 改成 1
+3. 如果希望加载 COCO 预训练权重,则需要处理第一个卷积层权重尺寸不匹配问题
+
+下面以 `cat` 数据集为例描述整个修改过程。如果你使用的是自定义灰度图像数据集,可以跳过数据集预处理这一步。
+
+### 1 数据集预处理
+
+自定义数据集的处理与训练可参照[自定义数据集 标注+训练+测试+部署 全流程](../recommended_topics/labeling_to_deployment_tutorials.md)。
+
+`cat` 是一个三通道彩色图片数据集,为了方便演示,你可以运行下面的代码和命令,将数据集图片替换为单通道图片,方便后续验证。
+
+**1. 下载并解压 `cat` 数据集**
+
+```shell
+python tools/misc/download_dataset.py --dataset-name cat --save-dir ./data/cat --unzip --delete
+```
+
+**2. 将数据集转换为灰度图**
+
+```python
+import argparse
+import imghdr
+import os
+from typing import List
+
+import cv2
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert dataset images to grayscale')
+    parser.add_argument('path', type=str, help='Original dataset path')
+    return parser.parse_args()
+
+def main():
+    args = parse_args()
+
+    img_dir = os.path.join(args.path, 'images')
+    file_list: List[str] = os.listdir(img_dir)
+    # Convert each image to grayscale and overwrite it in place
+    for file in file_list:
+        img_path = os.path.join(img_dir, file)
+        # Skip files that are not JPEG images
+        if imghdr.what(img_path) != 'jpeg':
+            continue
+        img = cv2.imread(img_path)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        cv2.imwrite(img_path, img)
+
+if __name__ == '__main__':
+    main()
+```
+
+将上述脚本保存为 `cvt_single_channel.py`,运行命令为:
+
+```shell
+python cvt_single_channel.py data/cat
+```
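+
+转换完成后,建议检查磁盘上的图片是否确实为单通道。下面是一个最小的验证示例(假设默认的 `data/cat/images` 目录结构;`cv2.IMREAD_UNCHANGED` 会保留存储的通道数,而不会扩展为三通道):
+
+```python
+import os
+
+import cv2
+
+img_dir = 'data/cat/images'
+for name in os.listdir(img_dir):
+    img = cv2.imread(os.path.join(img_dir, name), cv2.IMREAD_UNCHANGED)
+    if img is None:
+        continue  # 跳过非图片文件
+    # 灰度 JPEG 读取后是二维数组 (H, W),没有通道维度
+    assert img.ndim == 2, f'{name} 仍然是多通道图片'
+print('所有图片均为单通道')
+```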
+
+### 2 修改 base 配置文件
+
+**目前 MMYOLO 的一些图像处理函数(例如颜色空间变换)还不兼容单通道图片,如果直接采用单通道数据训练,需要修改大量 pipeline,工作量较大**。为了解决不兼容问题,推荐的做法是将单通道图片按照三通道图片的方式读取并加载为三通道数据,在输入到网络前再将其转换为单通道格式。这种做法会稍微增加一些运算负担,但用户基本不需要修改代码即可使用。
+
+以 `projects/misc/custom_dataset/yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py` 为 `base` 配置,将其复制到 `configs/yolov5` 目录下,在同级配置路径下新增 `yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py` 文件。我们可以在 `mmyolo/models/data_preprocessors/data_preprocessor.py` 文件中继承 `YOLOv5DetDataPreprocessor` 并将新类命名为 `YOLOv5SCDetDataPreprocessor`,在其中把图片转成单通道,添加所需的依赖库,并在 `mmyolo/models/data_preprocessors/__init__.py` 中注册新类。`YOLOv5SCDetDataPreprocessor` 示例代码为:
+
+```python
+@MODELS.register_module()
+class YOLOv5SCDetDataPreprocessor(YOLOv5DetDataPreprocessor):
+    """Convert the batched images to single-channel during training.
+
+    Note: It must be used together with `mmyolo.datasets.utils.yolov5_collate`
+    """
+
+ def forward(self, data: dict, training: bool = False) -> dict:
+        """Perform normalization, padding, bgr2rgb conversion and conversion
+        to single-channel images, based on ``DetDataPreprocessor``.
+
+ Args:
+ data (dict): Data sampled from dataloader.
+ training (bool): Whether to enable training time augmentation.
+
+ Returns:
+ dict: Data in the same format as the model input.
+ """
+ if not training:
+ return super().forward(data, training)
+
+ data = self.cast_data(data)
+ inputs, data_samples = data['inputs'], data['data_samples']
+ assert isinstance(data['data_samples'], dict)
+
+ # TODO: Supports multi-scale training
+ if self._channel_conversion and inputs.shape[1] == 3:
+ inputs = inputs[:, [2, 1, 0], ...]
+
+ if self._enable_normalize:
+ inputs = (inputs - self.mean) / self.std
+
+ if self.batch_augments is not None:
+ for batch_aug in self.batch_augments:
+ inputs, data_samples = batch_aug(inputs, data_samples)
+
+ img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
+ data_samples = {
+ 'bboxes_labels': data_samples['bboxes_labels'],
+ 'img_metas': img_metas
+ }
+
+ # Convert to single channel image
+ inputs = inputs.mean(dim=1, keepdim=True)
+
+ return {'inputs': inputs, 'data_samples': data_samples}
+```
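+
+在 `mmyolo/models/data_preprocessors/__init__.py` 中注册新类的写法大致如下(仅为示意,文件中已有的其他导出请保持不变):
+
+```python
+# mmyolo/models/data_preprocessors/__init__.py(示意,省略已有导出)
+from .data_preprocessor import (YOLOv5DetDataPreprocessor,
+                                YOLOv5SCDetDataPreprocessor)
+
+__all__ = ['YOLOv5DetDataPreprocessor', 'YOLOv5SCDetDataPreprocessor']
+```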
+
+此时 `yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py` 配置文件内容如下所示:
+
+```python
+_base_ = 'yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py'
+
+_base_.model.data_preprocessor.type = 'YOLOv5SCDetDataPreprocessor'
+```
+
+### 3 预训练模型加载问题
+
+直接使用原三通道的预训练模型,理论上会导致精度有所降低(未经实验验证)。可行的解决思路有:将输入层单通道的权重设为原 3 个通道权重的平均值;或设为原 3 个通道中某一通道的权重;也可以不修改输入层权重直接训练。具体效果视实际情况而定。这里采用将输入层权重设为预训练 3 通道权重平均值的方式。
+
+```python
+import torch
+
+def main():
+    # 在 CPU 上加载权重文件,无需 GPU
+    state_dict = torch.load(
+        'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth',
+        map_location='cpu')
+
+    # 对第一个卷积层的权重在输入通道维度上取平均:
+    # (C_out, 3, k, k) -> (C_out, 1, k, k)
+ weights = state_dict['state_dict']['backbone.stem.conv.weight']
+ avg_weight = weights.mean(dim=1, keepdim=True)
+ state_dict['state_dict']['backbone.stem.conv.weight'] = avg_weight
+
+ # 保存修改后的权重到新文件
+ torch.save(
+ state_dict,
+ 'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth'
+ )
+
+if __name__ == '__main__':
+ main()
+```
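+
+可以重新加载新的权重文件并检查第一个卷积层的形状,以确认修改生效;最小示例如下(对 YOLOv5-s 而言,输入通道维度应变为 1,例如 `torch.Size([32, 1, 6, 6])`):
+
+```python
+import torch
+
+sd = torch.load(
+    'checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth',
+    map_location='cpu')
+# 输入通道维度(dim 1)应当为 1
+print(sd['state_dict']['backbone.stem.conv.weight'].shape)
+```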
+
+此时 `yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py` 配置文件内容如下所示:
+
+```python
+_base_ = 'yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py'
+
+_base_.model.data_preprocessor.type = 'YOLOv5SCDetDataPreprocessor'
+
+load_from = './checkpoints/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187_single_channel.pth'
+```
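+
+随后即可使用 MMYOLO 标准的训练入口启动训练:
+
+```shell
+python tools/train.py configs/yolov5/yolov5_s-v61_syncbn_fast_1xb32-100e_cat_single_channel.py
+```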
+
+### 4 模型训练效果
+
+左图是实际标签,右图是目标检测结果。
+
+```shell
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.958
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 1.000
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.958
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.881
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.969
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.969
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.969
+bbox_mAP_copypaste: 0.958 1.000 1.000 -1.000 -1.000 0.958
+Epoch(val) [100][116/116] coco/bbox_mAP: 0.9580 coco/bbox_mAP_50: 1.0000 coco/bbox_mAP_75: 1.0000 coco/bbox_mAP_s: -1.0000 coco/bbox_mAP_m: -1.0000 coco/bbox_mAP_l: 0.9580
+```
+
+## 在多通道图像数据集上训练示例
+
+TODO
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index 1138e9c34..2976849bb 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -45,6 +45,7 @@
common_usage/mim_usage.md
common_usage/multi_necks.md
common_usage/specify_device.md
+ common_usage/single_multi_channel_applications.md
.. toctree::