Skip to content

Latest commit

 

History

History
334 lines (255 loc) · 15.5 KB

README_zh-CN.md

File metadata and controls

334 lines (255 loc) · 15.5 KB
 
OpenMMLab 官网 HOT      OpenMMLab 开放平台 TRY IT OUT
 

PyPI - Python Version PyPI license open issues issue resolution

📘使用文档 | 🛠️安装教程 | 🤔报告问题

English | 简体中文

简介

MMEngine 是一个基于 PyTorch 用于深度学习模型训练的基础库,支持在 Linux、Windows、macOS 上运行。它具有如下三个亮点:

  1. 通用:MMEngine 实现了一个高级的通用训练器,它能够:

    • 支持用少量代码训练不同的任务,例如仅使用 80 行代码就可以训练 imagenet(原始pytorch example 400 行)
    • 轻松兼容流行的算法库 (如 TIMM、TorchVision 和 Detectron2 ) 中的模型
  2. 统一:MMEngine 设计了一个接口统一的开放架构,使得:

    • 用户可以仅依赖一份代码实现所有任务的轻量化,例如 MMRazor 1.x 相比 MMRazor 0.x 优化了 40% 的代码量
    • 上下游的对接更加统一便捷,在为上层算法库提供统一抽象的同时,支持多种后端设备。目前 MMEngine 支持 Nvidia CUDA、Mac MPS、AMD、MLU 等设备进行模型训练。
  3. 灵活:MMEngine 实现了“乐高”式的训练流程,支持了:

    • 根据迭代数、 loss 和评测结果等动态调整的训练流程、优化策略和数据增强策略,例如早停(early stopping)机制等
    • 任意形式的模型权重平均,如 Exponential Momentum Average (EMA) 和 Stochastic Weight Averaging (SWA)
    • 训练过程中针对任意数据和任意节点的灵活可视化和日志控制
    • 对神经网络模型中各个层的优化配置进行细粒度调整
    • 混合精度训练的灵活控制

最近进展

最新版本 v0.2.0 在 2022.10.11 发布。

  1. 重构 FileIO 以提供更加易用的接口并保持向下兼容。
  2. 新增 Test time augmentation 模型基类。
  3. 分布式训练时,支持将 BN 自动转化为 SyncBN。
  4. 新增了 SMDDP 后端并支持在 AWS 进行分布式训练。

如果想了解更多版本更新细节和历史信息,请阅读更新日志

安装

在安装 MMengine 之前,请确保 PyTorch 已成功安装在环境中,可以参考 PyTorch 官方安装文档

安装 MMEngine

pip install -U openmim
mim install mmengine

验证是否安装成功

python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'

更多安装方式请阅读安装文档

快速上手

以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、可配置的训练和验证流程。

构建模型

首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel,并且其 forward 方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode:对于训练,我们需要 mode 接受字符串 "loss",并返回一个包含 "loss" 字段的字典;对于验证,我们需要 mode 接受字符串 "predict",并返回同时包含预测信息和真实信息的结果。

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
构建数据集

其次,我们需要构建训练和验证所需要的数据集 (Dataset)和数据加载器 (DataLoader)。 对于基础的训练和验证功能,我们可以直接使用符合 PyTorch 标准的数据加载器和数据集。

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))
val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))
构建评测指标

为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric,并实现 processcompute_metrics 方法。

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # 将一个批次的中间结果保存至 `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })
    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # 返回保存有评测指标结果的字典,其中键为指标名称
        return dict(accuracy=100 * total_correct / total_size)
构建执行器

最后,我们利用构建好的模型数据加载器评测指标构建一个执行器 (Runner),同时在其中配置 优化器工作路径训练与验证配置等选项

from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # 用以训练和验证的模型,需要满足特定的接口需求
    model=MMResNet50(),
    # 工作路径,用以保存训练日志、权重文件信息
    work_dir='./work_dir',
    # 训练数据加载器,需要满足 PyTorch 数据加载器协议
    train_dataloader=train_dataloader,
    # 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # 训练配置,用于指定训练周期、验证间隔等信息
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    # 验证数据加载器,需要满足 PyTorch 数据加载器协议
    val_dataloader=val_dataloader,
    # 验证配置,用于指定验证所需要的额外参数
    val_cfg=dict(),
    # 用于验证的评测器,这里使用默认评测器,并评测指标
    val_evaluator=dict(type=Accuracy),
)
开始训练
runner.train()

了解更多

入门教程
进阶教程
示例
架构设计
迁移指南

贡献指南

我们感谢所有的贡献者为改进和提升 MMEngine 所作出的努力。请参考贡献指南来了解参与项目贡献的相关指引。

开源许可证

该项目采用 Apache 2.0 license 开源许可证。

OpenMMLab 的其他项目

  • MIM: MIM 是 OpenMMLab 项目、算法、模型的统一入口
  • MMCV: OpenMMLab 计算机视觉基础库
  • MMClassification: OpenMMLab 图像分类工具箱
  • MMDetection: OpenMMLab 目标检测工具箱
  • MMDetection3D: OpenMMLab 新一代通用 3D 目标检测平台
  • MMRotate: OpenMMLab 旋转框检测工具箱与测试基准
  • MMSegmentation: OpenMMLab 语义分割工具箱
  • MMOCR: OpenMMLab 全流程文字检测识别理解工具包
  • MMPose: OpenMMLab 姿态估计工具箱
  • MMHuman3D: OpenMMLab 人体参数化模型工具箱与测试基准
  • MMSelfSup: OpenMMLab 自监督学习工具箱与测试基准
  • MMRazor: OpenMMLab 模型压缩工具箱与测试基准
  • MMFewShot: OpenMMLab 少样本学习工具箱与测试基准
  • MMAction2: OpenMMLab 新一代视频理解工具箱
  • MMTracking: OpenMMLab 一体化视频目标感知平台
  • MMFlow: OpenMMLab 光流估计工具箱与测试基准
  • MMEditing: OpenMMLab 图像视频编辑工具箱
  • MMGeneration: OpenMMLab 图片视频生成模型工具箱
  • MMDeploy: OpenMMLab 模型部署框架

欢迎加入 OpenMMLab 社区

扫描下方的二维码可关注 OpenMMLab 团队的 知乎官方账号,加入 OpenMMLab 团队的 官方交流 QQ 群,或通过添加微信“Open小喵Lab”加入官方交流微信群。

我们会在 OpenMMLab 社区为大家

  • 📢 分享 AI 框架的前沿核心技术
  • 💻 解读 PyTorch 常用模块源码
  • 📰 发布 OpenMMLab 的相关新闻
  • 🚀 介绍 OpenMMLab 开发的前沿算法
  • 🏃 获取更高效的问题答疑和意见反馈
  • 🔥 提供与各行各业开发者充分交流的平台

干货满满 📘,等你来撩 💗,OpenMMLab 社区期待您的加入 👬