diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py index 50aa5fb817..a2c4072adc 100644 --- a/paddleseg/core/train.py +++ b/paddleseg/core/train.py @@ -70,6 +70,7 @@ def train(model, keep_checkpoint_max=5, test_config=None, precision='fp32', + amp_level='O1', profiler_options=None, to_static_training=False): """ @@ -93,6 +94,9 @@ def train(model, keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5. test_config(dict, optional): Evaluation config. precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the training is normal. + amp_level (str, optional): Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision, + the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators + parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp) profiler_options (str, optional): The option of train profiler. to_static_training (bool, optional): Whether to use @to_static for training. """ @@ -109,6 +113,17 @@ def train(model, os.remove(save_dir) os.makedirs(save_dir) + # use amp + if precision == 'fp16': + logger.info('use AMP to train. AMP level = {}'.format(amp_level)) + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + if amp_level == 'O2': + model, optimizer = paddle.amp.decorate( + models=model, + optimizers=optimizer, + level='O2', + save_dtype='float32') + if nranks > 1: paddle.distributed.fleet.init(is_collective=True) optimizer = paddle.distributed.fleet.distributed_optimizer( @@ -125,11 +140,6 @@ def train(model, return_list=True, worker_init_fn=worker_init_fn, ) - # use amp - if precision == 'fp16': - logger.info('use amp to train') - scaler = paddle.amp.GradScaler(init_loss_scaling=1024) - if use_vdl: from visualdl import LogWriter log_writer = LogWriter(save_dir) @@ -169,15 +179,14 @@ def train(model, if precision == 'fp16': with paddle.amp.auto_cast( + level=amp_level, enable=True, custom_white_list={ "elementwise_add", "batch_norm", "sync_batch_norm" }, custom_black_list={'bilinear_interp_v2'}): - if nranks > 1: - logits_list = ddp_model(images) - else: - logits_list = model(images) + logits_list = ddp_model(images) if nranks > 1 else model( + images) loss_list = loss_computation( logits_list=logits_list, labels=labels, @@ -192,10 +201,7 @@ def train(model, else: scaler.minimize(optimizer, scaled) # update parameters else: - if nranks > 1: - logits_list = ddp_model(images) - else: - logits_list = model(images) + logits_list = ddp_model(images) if nranks > 1 else model(images) loss_list = loss_computation( logits_list=logits_list, labels=labels, @@ -273,7 +279,12 @@ def train(model, test_config = {} mean_iou, acc, _, _, _ = evaluate( - model, val_dataset, num_workers=num_workers, **test_config) + model, + val_dataset, + num_workers=num_workers, + precision=precision, + amp_level=amp_level, + **test_config) model.train() @@ -309,7 +320,7 @@ def train(model, batch_start = time.time() # Calculate flops. - if local_rank == 0: + if local_rank == 0 and not (precision == 'fp16' and amp_level == 'O2'): _, c, h, w = images.shape _ = paddle.flops( model, [1, c, h, w], diff --git a/paddleseg/core/val.py b/paddleseg/core/val.py index 2923ae469f..28579912fe 100644 --- a/paddleseg/core/val.py +++ b/paddleseg/core/val.py @@ -34,6 +34,8 @@ def evaluate(model, is_slide=False, stride=None, crop_size=None, + precision='fp32', + amp_level='O1', num_workers=0, print_detail=True, auc_roc=False): @@ -52,6 +54,8 @@ def evaluate(model, It should be provided when `is_slide` is True. crop_size (tuple|list, optional): The crop size of sliding window, the first is width and the second is height. It should be provided when `is_slide` is True. + precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the evaluation is normal. + amp_level (str, optional): Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp) num_workers (int, optional): Num workers for data loader. Default: 0. print_detail (bool, optional): Whether to print detailed information about the evaluation process. Default: True. auc_roc(bool, optional): whether add auc_roc metric @@ -99,26 +103,65 @@ def evaluate(model, ori_shape = label.shape[-2:] if aug_eval: - pred, logits = infer.aug_inference( - model, - im, - ori_shape=ori_shape, - transforms=eval_dataset.transforms.transforms, - scales=scales, - flip_horizontal=flip_horizontal, - flip_vertical=flip_vertical, - is_slide=is_slide, - stride=stride, - crop_size=crop_size) + if precision == 'fp16': + with paddle.amp.auto_cast( + level=amp_level, + enable=True, + custom_white_list={ + "elementwise_add", "batch_norm", + "sync_batch_norm" + }, + custom_black_list={'bilinear_interp_v2'}): + pred, logits = infer.aug_inference( + model, + im, + ori_shape=ori_shape, + transforms=eval_dataset.transforms.transforms, + scales=scales, + flip_horizontal=flip_horizontal, + flip_vertical=flip_vertical, + is_slide=is_slide, + stride=stride, + crop_size=crop_size) + else: + pred, logits = infer.aug_inference( + model, + im, + ori_shape=ori_shape, + transforms=eval_dataset.transforms.transforms, + scales=scales, + flip_horizontal=flip_horizontal, + flip_vertical=flip_vertical, + is_slide=is_slide, + stride=stride, + crop_size=crop_size) else: - pred, logits = infer.inference( - model, - im, - ori_shape=ori_shape, - transforms=eval_dataset.transforms.transforms, - is_slide=is_slide, - stride=stride, - crop_size=crop_size) + if precision == 'fp16': + with paddle.amp.auto_cast( + level=amp_level, + enable=True, + custom_white_list={ + "elementwise_add", "batch_norm", + "sync_batch_norm" + }, + custom_black_list={'bilinear_interp_v2'}): + pred, logits = infer.inference( + model, + im, + ori_shape=ori_shape, + transforms=eval_dataset.transforms.transforms, + is_slide=is_slide, + stride=stride, + crop_size=crop_size) + else: + pred, logits = infer.inference( + model, + im, + ori_shape=ori_shape, + transforms=eval_dataset.transforms.transforms, + is_slide=is_slide, + stride=stride, + crop_size=crop_size) intersect_area, pred_area, label_area = metrics.calculate_area( pred, diff --git a/test_tipc/configs/pphumanseg_lite/train_amp_infer_python.txt b/test_tipc/configs/pphumanseg_lite/train_amp_infer_python.txt new file mode 100644 index 0000000000..e724c07301 --- /dev/null +++ b/test_tipc/configs/pphumanseg_lite/train_amp_infer_python.txt @@ -0,0 +1,52 @@ +===========================train_params=========================== +model_name:pphumanseg_lite +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:null|null +Global.auto_cast:null +--iters:lite_train_lite_infer=50|lite_train_whole_infer=50|whole_train_whole_infer=1000 +--save_dir: +--batch_size:lite_train_lite_infer=2|lite_train_whole_infer=2|whole_train_whole_infer=8 +--model_path:null +train_model_name:best_model/model.pdparams +train_infer_img_dir:test_tipc/data/mini_supervisely/test.txt +null:null +## +trainer:amp_train +amp_train:train.py --config test_tipc/configs/pphumanseg_lite/pphumanseg_lite_mini_supervisely.yml --precision fp16 --amp_level O1 --do_eval --save_interval 500 --seed 100 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir: +--model_path: +norm_export:export.py --config test_tipc/configs/pphumanseg_lite/pphumanseg_lite_mini_supervisely.yml +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:./test_tipc/output/pphumanseg_lite/pphumanseg_lite_generic_192x192/model.pdparams +infer_export:export.py --config test_tipc/configs/pphumanseg_lite/pphumanseg_lite_mini_supervisely.yml +infer_quant:False +inference:deploy/python/infer.py +--device:cpu|gpu +--enable_mkldnn:True|False +--cpu_threads:1|6 +--batch_size:1 +--use_trt:False|True +--precision:fp32|int8|fp16 +--config: +--image_path:./test_tipc/data/mini_supervisely/test.txt +--save_log_path:null +--benchmark:True +--save_dir: +--model_name:pphumanseg_lite diff --git a/test_tipc/docs/install.md b/test_tipc/docs/install.md index 2bfde8c7f7..29778982f6 100644 --- a/test_tipc/docs/install.md +++ b/test_tipc/docs/install.md @@ -96,7 +96,7 @@ git clone https://github.com/PaddlePaddle/PaddleSeg cd PaddleSeg pip3.7 install -r test_tipc/requirements.txt -pip install -e . +pip3.7 install -e . cd .. ``` @@ -105,15 +105,15 @@ cd .. # 安装AutoLog(规范化日志输出工具) git clone https://github.com/LDOUBLEV/AutoLog cd AutoLog -pip install -r requirements.txt +pip3.7 install -r requirements.txt python3.7 setup.py bdist_wheel -pip install ./dist/auto_log-[xxx]-py3-none-any.whl +pip3.7 install ./dist/auto_log-[xxx]-py3-none-any.whl cd ../ ``` - 安装PaddleSlim (可选) ``` # 如果要测试量化、裁剪等功能,需要安装PaddleSlim -pip3 install paddleslim +pip3.7 install paddleslim ``` ## FAQ : diff --git a/test_tipc/docs/test_train_amp_inference_python.md b/test_tipc/docs/test_train_amp_inference_python.md new file mode 100644 index 0000000000..b35c073afd --- /dev/null +++ b/test_tipc/docs/test_train_amp_inference_python.md @@ -0,0 +1,108 @@ +# Linux GPU/CPU 混合精度训练推理测试 + +Linux GPU/CPU 混合精度训练推理测试的主程序为`test_train_inference_python.sh`,可以测试基于Python的模型训练、评估、推理等基本功能。 + +## 1. 测试结论汇总 + +- 训练相关: + +| 算法名称 | 模型名称 | 单机单卡 | 单机多卡 | +| :----: | :----: | :----: | :----: | +| ConnectNet | PP-HumanSeg-Lite | 混合精度训练 | 混合精度训练 | + + +- 推理相关: + +| 算法名称 | 模型名称 | device_CPU | device_GPU | batchsize | +| :----: | :----: | :----: | :----: | :----: | +| ConnectNet | PP-HumanSeg-Lite | 支持 | 支持 | 1 | + + +## 2. 测试流程 + +### 2.1 准备环境 + + +- 安装PaddlePaddle:如果您已经安装了2.2或者以上版本的paddlepaddle,那么无需运行下面的命令安装paddlepaddle。 + ``` + # 需要安装2.2及以上版本的Paddle + # 安装GPU版本的Paddle + pip install paddlepaddle-gpu==2.2.2 + # 安装CPU版本的Paddle + pip install paddlepaddle==2.2.2 + ``` + +- 安装依赖 + ``` + pip install -r requirements.txt + ``` +- 安装AutoLog(规范化日志输出工具) + ``` + pip install https://paddleocr.bj.bcebos.com/libs/auto_log-1.2.0-py3-none-any.whl + ``` + +### 2.2 功能测试 + + +测试方法如下所示,希望测试不同的模型文件,只需更换为自己的参数配置文件,即可完成对应模型的测试。 + +```bash +bash test_tipc/test_train_inference_python.sh ${your_params_file} lite_train_lite_infer +``` + +以`pphumanseg_lite`的`Linux GPU/CPU 混合精度训练推理测试`为例,命令如下所示。 + +```bash +bash test_tipc/prepare.sh test_tipc/configs/pphumanseg_lite/train_amp_infer_python.txt lite_train_lite_infer +``` + +```bash +bash test_tipc/test_train_inference_python.sh test_tipc/configs/pphumanseg_lite/train_amp_infer_python.txt lite_train_lite_infer +``` + +输出结果如下,表示命令运行成功。 + +```bash +[33m Run successfully with command - python3.7 train.py --config test_tipc/configs/pphumanseg_lite/pphumanseg_lite_mini_supervisely.yml --precision fp16 --amp_level O2 --do_eval --save_interval 500 --seed 100 --save_dir=./test_tipc/output/pphumanseg_lite/amp_train_gpus_0_autocast_null --iters=50 --batch_size=2 +...... +[33m Run successfully with command - python3.7 deploy/python/infer.py --device=cpu --enable_mkldnn=True --cpu_threads=1 --config=./test_tipc/output/pphumanseg_lite/amp_train_gpus_0_autocast_null//deploy.yaml --batch_size=1 --image_path=test_tipc/data/mini_supervisely/test.txt --benchmark=True --precision=fp32 --save_dir=./test_tipc/output/pphumanseg_lite/python_infer_cpu_usemkldnn_True_threads_1_precision_fp32_batchsize_1_results > ./test_tipc/output/pphumanseg_lite/python_infer_cpu_usemkldnn_True_threads_1_precision_fp32_batchsize_1.log 2>&1 ! +``` + +在开启benchmark选项时,可以得到测试的详细数据,包含运行环境信息(系统版本、CUDA版本、CUDNN版本、驱动版本),Paddle版本信息,参数设置信息(运行设备、线程数、是否开启内存优化等),模型信息(模型名称、精度),数据信息(batchsize、是否为动态shape等),性能信息(CPU/GPU的占用、运行耗时、预处理耗时、推理耗时、后处理耗时),内容如下所示: + +``` +2022-04-13 11:43:09 [INFO] ---------------------- Env info ---------------------- +2022-04-13 11:43:09 [INFO] OS_version: CentOS Linux 7 +2022-04-13 11:43:09 [INFO] CUDA_version: 11.2.67 +Build cuda_11.2.r11.2/compiler.29373293_0 +2022-04-13 11:43:09 [INFO] CUDNN_version: None.None.None +2022-04-13 11:43:09 [INFO] drivier_version: 460.27.04 +2022-04-13 11:43:09 [INFO] ---------------------- Paddle info ---------------------- +2022-04-13 11:43:09 [INFO] paddle_version: 2.2.2 +2022-04-13 11:43:09 [INFO] paddle_commit: b031c389938bfa15e15bb20494c76f86289d77b0 +2022-04-13 11:43:09 [INFO] log_api_version: 1.0 +2022-04-13 11:43:09 [INFO] ----------------------- Conf info ----------------------- +2022-04-13 11:43:09 [INFO] runtime_device: cpu +2022-04-13 11:43:09 [INFO] ir_optim: True +2022-04-13 11:43:09 [INFO] enable_memory_optim: True +2022-04-13 11:43:09 [INFO] enable_tensorrt: False +2022-04-13 11:43:09 [INFO] enable_mkldnn: False +2022-04-13 11:43:09 [INFO] cpu_math_library_num_threads: 1 +2022-04-13 11:43:09 [INFO] ----------------------- Model info ---------------------- +2022-04-13 11:43:09 [INFO] model_name: +2022-04-13 11:43:09 [INFO] precision: fp32 +2022-04-13 11:43:09 [INFO] ----------------------- Data info ----------------------- +2022-04-13 11:43:09 [INFO] batch_size: 1 +2022-04-13 11:43:09 [INFO] input_shape: dynamic +2022-04-13 11:43:09 [INFO] data_num: 50 +2022-04-13 11:43:09 [INFO] ----------------------- Perf info ----------------------- +2022-04-13 11:43:09 [INFO] cpu_rss(MB): 315.2656, gpu_rss(MB): 8842.0, gpu_util: 0.0% +2022-04-13 11:43:09 [INFO] total time spent(s): 4.2386 +2022-04-13 11:43:09 [INFO] preprocess_time(ms): 33.7887, inference_time(ms): 50.7443, postprocess_time(ms): 0.2391 +``` + +该信息可以在运行log中查看,以`pphumanseg_lite`为例,log位置在`./output/pphumanseg_lite/lite_train_lite_infer/python_infer_cpu_usemkldnn_False_threads_1_precision_fp32_batchsize_1.log`。 + +如果运行失败,也会在终端中输出运行失败的日志信息以及对应的运行命令。可以基于该命令,分析运行失败的原因。 + +`注意`: 混合精度参数配置文件中,默认使用O1模式;O2模式存在部分问题,需要安装PaddlePaddle develop版本才可使用。 diff --git a/train.py b/train.py index 063ee1cbf6..b7e4939712 100644 --- a/train.py +++ b/train.py @@ -103,8 +103,17 @@ def parse_args(): default="fp32", type=str, choices=["fp32", "fp16"], - help="Use AMP if precision='fp16'. If precision='fp32', the training is normal." + help="Use AMP (Auto mixed precision) if precision='fp16'. If precision='fp32', the training is normal." ) + parser.add_argument( + "--amp_level", + default="O1", + type=str, + choices=["O1", "O2"], + help="Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision, the input \ + data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators \ + parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel \ + and batchnorm. Default is O1(amp)") parser.add_argument( '--data_format', dest='data_format', @@ -211,6 +220,7 @@ def main(args): keep_checkpoint_max=args.keep_checkpoint_max, test_config=cfg.test_config, precision=args.precision, + amp_level=args.amp_level, profiler_options=args.profiler_options, to_static_training=cfg.to_static_training)