https://github.com/Sharpiless/Yolov5-distillation-train-inference
请移步:https://github.com/Sharpiless/yolov5-distillation-5.0
链接:https://pan.baidu.com/s/13gq5QwCrRNdRXWzSYUeJIw
提取码:4ppv
python train_distill.py --weights yolov5s.pt \
--teacher weights/yolov5l_voc.pt --distill_ratio 0.001 \
--teacher-cfg model/yolov5l.yaml --data data/voc.yaml \
--epochs 30 --batch-size 16
--weights:预训练模型
--teacher:教师模型权重
--distill-ratio:蒸馏损失权重
--with-gt-loss:是否同时使用ground truth
--soft-loss:是否使用KL散度作为蒸馏的类别损失(缺省使用L2-logits损失)
--full-output-loss:是否使用《Object detection at 200 Frames Per Second》中的损失
这篇文章分别对这几个损失函数做出改进,具体思路为只有当teacher network的objectness value高时,才学习bounding box坐标和class probabilities。
默认会启用 data/voc.yaml 自动下载VOC数据集进行训练
或者手动运行 data/scripts/get_voc2007.sh 下载
如需修改成自己的数据集,则只需要修改yaml路径即可
数据集:
VOC2007(补充的无标签数据使用VOC2012)
GPU:2080Ti*1
Batch Size:16
Epoches:30
Baseline:Yolov5s
Teacher model:Yolov5l(mAP 0.5:0.95 = 0.541)
这里假设VOC2012中新增加的数据为无标签数据(2k张)。
教师模型 | 训练方法 | 蒸馏损失 | P | R | mAP50 |
---|---|---|---|---|---|
无 | 正常训练 | 不使用 | 0.7756 | 0.7115 | 0.7609 |
Yolov5l | output based | l2 | 0.7585 | 0.7198 | 0.7644 |
Yolov5l | output based | KL | 0.7417 | 0.7207 | 0.7536 |
Yolov5m | output based | l2 | 0.7682 | 0.7436 | 0.7976 |
Yolov5m | output based | KL | 0.7731 | 0.7313 | 0.7931 |
参数和细节正在完善,支持KL散度、L2 logits损失和Sigmoid蒸馏损失等
- [√] 修改logist输出作为蒸馏损失输入
- [√] 完善代码结构和相关参数设定
- [×] 查找为何蒸馏损失不起作用(或者收敛慢)的原因
- [×] 完善相关实验并测试精度
- [√] 修改dataloader加快训练速度
- [√] 修改teacher model的批量推理加快训练速度
- 1.训练轮数太少没收敛,可能蒸馏训练收敛满最终结果高
- 2.教师模型是Yolov5l在VOC训练30轮得到的(mAP 0.5:0.95 = 0.541),质量比标注较差影响蒸馏训练的结果
- 3.可调整的参数还有很多(教师模型的检测、IOU阈值,蒸馏损失种类,蒸馏损失比率等)