From d061b60caedc86d7d440f8b425ac7d60273a18c1 Mon Sep 17 00:00:00 2001 From: Ccc <52520497+juncaipeng@users.noreply.github.com> Date: Mon, 9 Jan 2023 14:06:31 +0800 Subject: [PATCH] [Enhancement] Simplify Config and Builder (#2897) * Simplify Config and Builder --- paddleseg/cvlibs/__init__.py | 1 + paddleseg/cvlibs/builder.py | 256 ++++++++++++++++++++++++++--- paddleseg/cvlibs/config.py | 217 ++++++------------------ paddleseg/cvlibs/config_checker.py | 165 ++++++++++--------- paddleseg/utils/utils.py | 16 ++ tools/export.py | 11 +- tools/predict.py | 8 +- tools/train.py | 19 +-- tools/val.py | 12 +- 9 files changed, 415 insertions(+), 290 deletions(-) diff --git a/paddleseg/cvlibs/__init__.py b/paddleseg/cvlibs/__init__.py index 5fcb1d6c10..8e1f8b0828 100644 --- a/paddleseg/cvlibs/__init__.py +++ b/paddleseg/cvlibs/__init__.py @@ -15,3 +15,4 @@ from . import manager from . import param_init from .config import Config +from .builder import Builder, SegBuilder diff --git a/paddleseg/cvlibs/builder.py b/paddleseg/cvlibs/builder.py index 72d9b6718b..b2f77bf0f1 100644 --- a/paddleseg/cvlibs/builder.py +++ b/paddleseg/cvlibs/builder.py @@ -13,22 +13,31 @@ # limitations under the License. import copy +from typing import Any, Optional +import yaml +import paddle -class ComponentBuilder(object): +from paddleseg.cvlibs import manager, Config +from paddleseg.utils import utils, logger +from paddleseg.utils.utils import CachedProperty as cached_property + + +class Builder(object): """ - This class is responsible for building components. All component classes must be available - in the list of maintained components. + The base class for building components. Args: - com_list (list): A list of component classes. + config (Config): A Config class. + comp_list (list, optional): A list of component classes. Default: None """ - def __init__(self, com_list): + def __init__(self, config: Config, comp_list: Optional[list]=None): super().__init__() - self.com_list = com_list + self.config = config + self.comp_list = comp_list - def create_object(self, cfg): + def build_component(self, cfg): """ Create Python object, such as model, loss, dataset, etc. """ @@ -44,17 +53,17 @@ def create_object(self, cfg): params = {} for key, val in cfg.items(): if self.is_meta_type(val): - params[key] = self.create_object(val) + params[key] = self.build_component(val) elif isinstance(val, list): params[key] = [ - self.create_object(item) + self.build_component(item) if self.is_meta_type(item) else item for item in val ] else: params[key] = val try: - obj = self.create_object_impl(com_class, **params) + obj = self.build_component_impl(com_class, **params) except Exception as e: if hasattr(com_class, '__name__'): com_name = com_class.__name__ @@ -64,28 +73,16 @@ def create_object(self, cfg): f"Tried to create a {com_name} object, but the operation has failed. 
" "Please double check the arguments used to create the object.\n" f"The error message is: \n{str(e)}") - return obj - - def create_object_impl(self, component_class, *args, **kwargs): - raise NotImplementedError - - def load_component_class(self, cfg): - raise NotImplementedError - - @classmethod - def is_meta_type(cls, obj): - raise NotImplementedError + return obj -class DefaultComponentBuilder(ComponentBuilder): - def create_object_impl(self, component_class, *args, **kwargs): + def build_component_impl(self, component_class, *args, **kwargs): return component_class(*args, **kwargs) def load_component_class(self, class_type): - for com in self.com_list: + for com in self.comp_list: if class_type in com.components_dict: return com[class_type] - raise RuntimeError("The specified component ({}) was not found.".format( class_type)) @@ -94,3 +91,212 @@ def is_meta_type(cls, obj): # TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol) # to make it more pythonic? return isinstance(obj, dict) and 'type' in obj + + @classmethod + def show_msg(cls, name, cfg): + msg = 'Use the following config to build {}\n'.format(name) + msg += str(yaml.dump({name: cfg}, Dumper=utils.NoAliasDumper)) + logger.info(msg[0:-1]) + + +class SegBuilder(Builder): + """ + This class is responsible for building components for semantic segmentation. + """ + + def __init__(self, config, comp_list=None): + if comp_list is None: + comp_list = [ + manager.MODELS, manager.BACKBONES, manager.DATASETS, + manager.TRANSFORMS, manager.LOSSES, manager.OPTIMIZERS + ] + super().__init__(config, comp_list) + + @cached_property + def model(self) -> paddle.nn.Layer: + model_cfg = self.config.model_cfg + assert model_cfg != {}, \ + 'No model specified in the configuration file.' + + if self.config.train_dataset_cfg['type'] != 'Dataset': + # check and synchronize the num_classes in model config and dataset class + assert hasattr(self.train_dataset_class, 'NUM_CLASSES'), \ + 'If train_dataset class is not `Dataset`, it must have `NUM_CLASSES` attr.' + num_classes = getattr(self.train_dataset_class, 'NUM_CLASSES') + if 'num_classes' in model_cfg: + assert model_cfg['num_classes'] == num_classes, \ + 'The num_classes is not consistent for model config ({}) ' \ + 'and train_dataset class ({}) '.format(model_cfg['num_classes'], num_classes) + else: + logger.warning( + 'Add the `num_classes` in train_dataset class to ' + 'model config. We suggest you manually set `num_classes` in model config.' + ) + model_cfg['num_classes'] = num_classes + # check and synchronize the in_channels in model config and dataset class + assert hasattr(self.train_dataset_class, 'IMG_CHANNELS'), \ + 'If train_dataset class is not `Dataset`, it must have `IMG_CHANNELS` attr.' + in_channels = getattr(self.train_dataset_class, 'IMG_CHANNELS') + x = utils.get_in_channels(model_cfg) + if x is not None: + assert x == in_channels, \ + 'The in_channels in model config ({}) and the img_channels in train_dataset ' \ + 'class ({}) is not consistent'.format(x, in_channels) + else: + model_cfg = utils.set_in_channels(model_cfg, in_channels) + logger.warning( + 'Add the `in_channels` in train_dataset class to ' + 'model config. We suggest you manually set `in_channels` in model config.' 
+ ) + + self.show_msg('model', model_cfg) + return self.build_component(model_cfg) + + @cached_property + def optimizer(self) -> paddle.optimizer.Optimizer: + opt_cfg = self.config.optimizer_cfg + assert opt_cfg != {}, \ + 'No optimizer specified in the configuration file.' + # For compatibility + if opt_cfg['type'] == 'adam': + opt_cfg['type'] = 'Adam' + if opt_cfg['type'] == 'sgd': + opt_cfg['type'] = 'SGD' + if opt_cfg['type'] == 'SGD' and 'momentum' in opt_cfg: + opt_cfg['type'] = 'Momentum' + logger.info('If the type is SGD and momentum in optimizer config, ' + 'the type is changed to Momentum.') + self.show_msg('optimizer', opt_cfg) + opt = self.build_component(opt_cfg) + opt = opt(self.model, self.lr_scheduler) + return opt + + @cached_property + def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler: + lr_cfg = self.config.lr_scheduler_cfg + assert lr_cfg != {}, \ + 'No lr_scheduler specified in the configuration file.' + + use_warmup = False + if 'warmup_iters' in lr_cfg: + use_warmup = True + warmup_iters = lr_cfg.pop('warmup_iters') + assert 'warmup_start_lr' in lr_cfg, \ + "When use warmup, please set warmup_start_lr and warmup_iters in lr_scheduler" + warmup_start_lr = lr_cfg.pop('warmup_start_lr') + end_lr = lr_cfg['learning_rate'] + + lr_type = lr_cfg.pop('type') + if lr_type == 'PolynomialDecay': + iters = self.config.iters - warmup_iters if use_warmup else self.config.iters + iters = max(iters, 1) + lr_cfg.setdefault('decay_steps', iters) + + try: + lr_sche = getattr(paddle.optimizer.lr, lr_type)(**lr_cfg) + except Exception as e: + raise RuntimeError( + "Create {} has failed. Please check lr_scheduler in config. " + "The error message: {}".format(lr_type, e)) + + if use_warmup: + lr_sche = paddle.optimizer.lr.LinearWarmup( + learning_rate=lr_sche, + warmup_steps=warmup_iters, + start_lr=warmup_start_lr, + end_lr=end_lr) + + return lr_sche + + @cached_property + def loss(self) -> dict: + loss_cfg = self.config.loss_cfg + assert loss_cfg != {}, \ + 'No loss specified in the configuration file.' + return self._build_loss('loss', loss_cfg) + + @cached_property + def distill_loss(self) -> dict: + loss_cfg = self.config.distill_loss_cfg + assert loss_cfg != {}, \ + 'No distill_loss specified in the configuration file.' + return self._build_loss('distill_loss', loss_cfg) + + def _build_loss(self, loss_name, loss_cfg: dict): + def _check_helper(loss_cfg, ignore_index): + if 'ignore_index' not in loss_cfg: + loss_cfg['ignore_index'] = ignore_index + logger.warning('Add the `ignore_index` in train_dataset ' \ + 'class to {} config. We suggest you manually set ' \ + '`ignore_index` in {} config.'.format(loss_name, loss_name) + ) + else: + assert loss_cfg['ignore_index'] == ignore_index, \ + 'the ignore_index in loss and train_dataset must be the same. Currently, loss ignore_index = {}, '\ + 'train_dataset ignore_index = {}'.format(loss_cfg['ignore_index'], ignore_index) + + # check and synchronize the ignore_index in model config and dataset class + if self.config.train_dataset_cfg['type'] != 'Dataset': + assert hasattr(self.train_dataset_class, 'IGNORE_INDEX'), \ + 'If train_dataset class is not `Dataset`, it must have `IGNORE_INDEX` attr.' 
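For reference, a rough sketch of what the lr_scheduler property above constructs when a PolynomialDecay scheduler is combined with warmup (the numeric values below are illustrative, not repository defaults):

import paddle

# Assumed example values; in practice they come from the `iters` and `lr_scheduler`
# entries of the YAML config.
iters, warmup_iters = 80000, 1000
base_lr = paddle.optimizer.lr.PolynomialDecay(
    learning_rate=0.01,                # lr_scheduler.learning_rate
    decay_steps=iters - warmup_iters,  # filled in by the builder when not set explicitly
    end_lr=0.0,
    power=0.9)
lr_sche = paddle.optimizer.lr.LinearWarmup(
    learning_rate=base_lr,
    warmup_steps=warmup_iters,         # lr_scheduler.warmup_iters
    start_lr=1.0e-5,                   # lr_scheduler.warmup_start_lr
    end_lr=0.01)                       # warmup ramps up to the base learning_rate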
+ ignore_index = getattr(self.train_dataset_class, 'IGNORE_INDEX') + for loss_cfg_i in loss_cfg['types']: + if loss_cfg_i['type'] == 'MixedLoss': + for loss_cfg_j in loss_cfg_i['losses']: + _check_helper(loss_cfg_j, ignore_index) + else: + _check_helper(loss_cfg_i, ignore_index) + + self.show_msg(loss_name, loss_cfg) + loss_dict = {'coef': loss_cfg['coef'], "types": []} + for item in loss_cfg['types']: + loss_dict['types'].append(self.build_component(item)) + return loss_dict + + @cached_property + def train_dataset(self) -> paddle.io.Dataset: + dataset_cfg = self.config.train_dataset_cfg + assert dataset_cfg != {}, \ + 'No train_dataset specified in the configuration file.' + self.show_msg('train_dataset', dataset_cfg) + dataset = self.build_component(dataset_cfg) + assert len(dataset) != 0, \ + 'The number of samples in train_dataset is 0. Please check whether the dataset is valid.' + return dataset + + @cached_property + def val_dataset(self) -> paddle.io.Dataset: + dataset_cfg = self.config.val_dataset_cfg + assert dataset_cfg != {}, \ + 'No val_dataset specified in the configuration file.' + self.show_msg('val_dataset', dataset_cfg) + dataset = self.build_component(dataset_cfg) + assert len(dataset) != 0, \ + 'The number of samples in val_dataset is 0. Please check whether the dataset is valid.' + return dataset + + @cached_property + def train_dataset_class(self) -> Any: + dataset_cfg = self.config.train_dataset_cfg + assert dataset_cfg != {}, \ + 'No train_dataset specified in the configuration file.' + dataset_type = dataset_cfg.get('type') + return self.load_component_class(dataset_type) + + @cached_property + def val_dataset_class(self) -> Any: + dataset_cfg = self.config.val_dataset_cfg + assert dataset_cfg != {}, \ + 'No val_dataset specified in the configuration file.' + dataset_type = dataset_cfg.get('type') + return self.load_component_class(dataset_type) + + @cached_property + def val_transforms(self) -> list: + dataset_cfg = self.config.val_dataset_cfg + assert dataset_cfg != {}, \ + 'No val_dataset specified in the configuration file.' + transforms = [] + for item in dataset_cfg.get('transforms', []): + transforms.append(self.build_component(item)) + return transforms diff --git a/paddleseg/cvlibs/config.py b/paddleseg/cvlibs/config.py index e6fc0c263f..34f3ed037b 100644 --- a/paddleseg/cvlibs/config.py +++ b/paddleseg/cvlibs/config.py @@ -22,10 +22,8 @@ import paddle from paddleseg.cvlibs import config_checker as checker -from paddleseg.cvlibs import builder from paddleseg.cvlibs import manager from paddleseg.utils import logger, utils -from paddleseg.utils.utils import CachedProperty as cached_property _INHERIT_KEY = '_inherited_' _BASE_KEY = '_base_' @@ -33,7 +31,7 @@ class Config(object): """ - Training configuration parsing. The only yaml/yml file is supported. + Configuration parsing. The following hyper-parameters are available in the config file: batch_size: The number of samples per gpu. @@ -42,13 +40,12 @@ class Config(object): For data type, please refer to paddleseg.datasets. For specific transforms, please refer to paddleseg.transforms.transforms. val_dataset: A validation data config including type/data_root/transforms/mode. - optimizer: A optimizer config, but currently PaddleSeg only supports sgd with momentum in config file. - In addition, weight_decay could be set as a regularization. - learning_rate: A learning rate config. 
If decay is configured, learning _rate value is the starting learning rate, - where only poly decay is supported using the config file. In addition, decay power and end_lr are tuned experimentally. - loss: A loss config. Multi-loss config is available. The loss type order is consistent with the seg model outputs, - where the coef term indicates the weight of corresponding loss. Note that the number of coef must be the same as the number of - model outputs, and there could be only one loss type if using the same loss type among the outputs, otherwise the number of + optimizer: A optimizer config. Please refer to paddleseg.optimizers. + loss: A loss config. Multi-loss config is available. The loss type order is + consistent with the seg model outputs, where the coef term indicates the + weight of corresponding loss. Note that the number of coef must be the + same as the number of model outputs, and there could be only one loss type + if using the same loss type among the outputs, otherwise the number of loss type must be consistent with coef. model: A model config including type/backbone and model-dependent arguments. For model type, please refer to paddleseg.models. @@ -56,37 +53,24 @@ class Config(object): Args: path (str) : The path of config file, supports yaml format only. + opts (list, optional): Use opts to update the key-value pairs of all options. - Examples: - - from paddleseg.cvlibs.config import Config - - # Create a cfg object with yaml file path. - cfg = Config(yaml_cfg_path) - - # Parsing the argument when its property is used. - train_dataset = cfg.train_dataset - - # the argument of model should be parsed after dataset, - # since the model builder uses some properties in dataset. - model = cfg.model - ... """ - def __init__(self, - path: str, - learning_rate: Optional[float]=None, - batch_size: Optional[int]=None, - iters: Optional[int]=None, - opts: Optional[list]=None, - checker: Optional[checker.ConfigChecker]=None, - component_builder: Optional[builder.ComponentBuilder]=None): + def __init__( + self, + path: str, + learning_rate: Optional[float]=None, + batch_size: Optional[int]=None, + iters: Optional[int]=None, + opts: Optional[list]=None, + checker: Optional[checker.ConfigChecker]=None, ): assert os.path.exists(path), \ 'Config path ({}) does not exist'.format(path) assert path.endswith('yml') or path.endswith('yaml'), \ 'Config file ({}) should be yaml format'.format(path) - self.dic = self.parse_from_yaml(path) + self.dic = self._parse_from_yaml(path) self.dic = self.update_config_dict( self.dic, learning_rate=learning_rate, @@ -94,151 +78,61 @@ def __init__(self, iters=iters, opts=opts) - # We have to build the component builder before doing any sanity checks - # This is because during a sanity check, some component objects are (possibly) - # required to be constructed. 
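A minimal usage sketch of the refactored split (the config path is only an example): Config is now a thin parser that loads and merges the YAML file, while SegBuilder creates the actual objects lazily and caches them on first access.

from paddleseg.cvlibs import Config, SegBuilder

cfg = Config('configs/quick_start/pp_liteseg_optic_disc_512x512_1k.yml')  # example path
builder = SegBuilder(cfg)

model = builder.model                 # built (and cached) on first access
train_dataset = builder.train_dataset
optimizer = builder.optimizer         # also builds builder.lr_scheduler internally
loss = builder.loss                   # {'types': [...], 'coef': [...]}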
- if component_builder is None: - component_builder = self._build_default_component_builder() - self.builder = component_builder - if checker is None: checker = self._build_default_checker() checker.apply_all_rules(self) - def __str__(self) -> str: - # Use NoAliasDumper to avoid yml anchor - return yaml.dump(self.dic, Dumper=utils.NoAliasDumper) - - #################### hyper parameters - @cached_property + @property def batch_size(self) -> int: return self.dic.get('batch_size') - @cached_property + @property def iters(self) -> int: return self.dic.get('iters') - @cached_property + @property def to_static_training(self) -> bool: return self.dic.get('to_static_training', False) - #################### lr_scheduler and optimizer - @cached_property - def optimizer(self) -> paddle.optimizer.Optimizer: - opt_cfg = self.dic.get('optimizer').copy() - # For compatibility - if opt_cfg['type'] == 'adam': - opt_cfg['type'] = 'Adam' - if opt_cfg['type'] == 'sgd': - opt_cfg['type'] = 'SGD' - if opt_cfg['type'] == 'SGD' and 'momentum' in opt_cfg: - opt_cfg['type'] = 'Momentum' - opt = self.builder.create_object(opt_cfg) - opt = opt(self.model, self.lr_scheduler) - return opt - - @cached_property - def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler: - assert 'lr_scheduler' in self.dic, 'No `lr_scheduler` specified in the configuration file.' - params = self.dic.get('lr_scheduler').copy() - - use_warmup = False - if 'warmup_iters' in params: - use_warmup = True - warmup_iters = params.pop('warmup_iters') - assert 'warmup_start_lr' in params, \ - "When use warmup, please set warmup_start_lr and warmup_iters in lr_scheduler" - warmup_start_lr = params.pop('warmup_start_lr') - end_lr = params['learning_rate'] - - lr_type = params.pop('type') - if lr_type == 'PolynomialDecay': - iters = self.iters - warmup_iters if use_warmup else self.iters - iters = max(iters, 1) - params.setdefault('decay_steps', iters) - params.setdefault('end_lr', 0) - params.setdefault('power', 0.9) - lr_sche = getattr(paddle.optimizer.lr, lr_type)(**params) - - if use_warmup: - lr_sche = paddle.optimizer.lr.LinearWarmup( - learning_rate=lr_sche, - warmup_steps=warmup_iters, - start_lr=warmup_start_lr, - end_lr=end_lr) - - return lr_sche - - #################### loss - @cached_property - def loss(self) -> dict: - return self._prepare_loss('loss') - - @cached_property - def distill_loss(self) -> dict: - return self._prepare_loss('distill_loss') - - def _prepare_loss(self, loss_name): - args = self.dic.get(loss_name, {}).copy() - losses = {'coef': args['coef'], "types": []} - for loss_cfg in args['types']: - losses['types'].append(self.builder.create_object(loss_cfg)) - return losses - - #################### model - @cached_property - def model(self) -> paddle.nn.Layer: - model_cfg = self.dic.get('model').copy() - return self.builder.create_object(model_cfg) - - #################### dataset and transforms - @cached_property - def train_dataset(self) -> paddle.io.Dataset: - dataset_cfg = self.dic.get('train_dataset').copy() - return self.builder.create_object(dataset_cfg) - - @cached_property - def val_dataset(self) -> paddle.io.Dataset: - assert 'val_dataset' in self.dic, \ - 'No val_dataset specified in the configuration file.' 
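The recursive construction in Builder.build_component (paddleseg/cvlibs/builder.py above) treats any dict carrying a 'type' key as a component to build first; a commented sketch of how a nested model config is resolved (the component names here are illustrative):

model_cfg = {
    'type': 'FCN',
    'backbone': {'type': 'HRNet_W18', 'pretrained': None},  # nested component, built first
    'num_classes': 19,
}
# builder.build_component(model_cfg) is roughly equivalent to:
#   backbone = manager.BACKBONES['HRNet_W18'](pretrained=None)
#   model = manager.MODELS['FCN'](backbone=backbone, num_classes=19)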
- dataset_cfg = self.dic.get('val_dataset').copy() - return self.builder.create_object(dataset_cfg) - - @cached_property - def train_dataset_class(self) -> Any: - dataset_type = self.dic['train_dataset']['type'] - return self.builder.load_component_class(dataset_type) - - @cached_property - def val_dataset_class(self) -> Any: - assert 'val_dataset' in self.dic, \ - 'No val_dataset specified in the configuration file.' - dataset_type = self.dic['val_dataset']['type'] - return self.builder.load_component_class(dataset_type) - - @cached_property - def val_transforms(self) -> list: - transforms = [] - if 'val_dataset' in self.dic: - for tf in self.dic.get('val_dataset').get('transforms', []): - transforms.append(self.builder.create_object(tf)) - return transforms - - @cached_property - def val_dataset_config(self) -> Dict: + @property + def model_cfg(self) -> Dict: + return self.dic.get('model', {}).copy() + + @property + def loss_cfg(self) -> Dict: + return self.dic.get('loss', {}).copy() + + @property + def distill_loss_cfg(self) -> Dict: + return self.dic.get('distill_loss', {}).copy() + + @property + def lr_scheduler_cfg(self) -> Dict: + return self.dic.get('lr_scheduler', {}).copy() + + @property + def optimizer_cfg(self) -> Dict: + return self.dic.get('optimizer', {}).copy() + + @property + def train_dataset_cfg(self) -> Dict: + return self.dic.get('train_dataset', {}).copy() + + @property + def val_dataset_cfg(self) -> Dict: return self.dic.get('val_dataset', {}).copy() - @cached_property + # TODO merge test_config into val_dataset + @property def test_config(self) -> Dict: return self.dic.get('test_config', {}).copy() - #################### checker and builder @classmethod def update_config_dict(cls, dic: dict, *args, **kwargs) -> dict: return update_config_dict(dic, *args, **kwargs) @classmethod - def parse_from_yaml(cls, path: str, *args, **kwargs) -> dict: + def _parse_from_yaml(cls, path: str, *args, **kwargs) -> dict: return parse_from_yaml(path, *args, **kwargs) @classmethod @@ -256,14 +150,9 @@ def _build_default_checker(cls): return checker.ConfigChecker(rules, allow_update=True) - @classmethod - def _build_default_component_builder(cls): - com_list = [ - manager.MODELS, manager.BACKBONES, manager.DATASETS, - manager.TRANSFORMS, manager.LOSSES, manager.OPTIMIZERS - ] - component_builder = builder.DefaultComponentBuilder(com_list=com_list) - return component_builder + def __str__(self) -> str: + # Use NoAliasDumper to avoid yml anchor + return yaml.dump(self.dic, Dumper=utils.NoAliasDumper) def parse_from_yaml(path: str): @@ -321,7 +210,7 @@ def update_config_dict(dic: dict, if opts is not None: for item in opts: assert ('=' in item) and (len(item.split('=')) == 2), "--opts params should be key=value," \ - " such as `--opts train.batch_size=1 test_config.scales=0.75,1.0,1.25`, " \ + " such as `--opts batch_size=1 test_config.scales=0.75,1.0,1.25`, " \ "but got ({})".format(opts) key, value = item.split('=') diff --git a/paddleseg/cvlibs/config_checker.py b/paddleseg/cvlibs/config_checker.py index 694bc503fb..694a08297d 100644 --- a/paddleseg/cvlibs/config_checker.py +++ b/paddleseg/cvlibs/config_checker.py @@ -15,6 +15,7 @@ import copy from paddleseg.utils import logger +from paddleseg.utils import utils class ConfigChecker(object): @@ -99,69 +100,81 @@ def check_and_correct(self, cfg): class DefaultSyncNumClassesRule(Rule): def check_and_correct(self, cfg): - model_config = cfg.dic['model'] - train_dataset_config = cfg.dic['train_dataset'] - val_dataset_config = 
cfg.dic['val_dataset'] + # check the num_classes in model, train_dataset and val_dataset config + model_cfg = cfg.model_cfg + train_dataset_cfg = cfg.train_dataset_cfg + val_dataset_cfg = cfg.val_dataset_cfg + assert train_dataset_cfg != {}, \ + 'No train_dataset specified in the configuration file.' + if train_dataset_cfg['type'] != 'Dataset': + return + if val_dataset_cfg != {}: + assert val_dataset_cfg['type'] == 'Dataset', \ + 'The type of train_dataset and val_dataset must be the same' + assert 'num_classes' in val_dataset_cfg, \ + 'No num_classes specified in train_dataset config.' + assert 'num_classes' in train_dataset_cfg, \ + 'No num_classes specified in train_dataset config.' + value_set = set() value_name = 'num_classes' + if value_name in model_cfg: + value_set.add(model_cfg[value_name]) + if value_name in train_dataset_cfg: + value_set.add(train_dataset_cfg[value_name]) + if value_name in val_dataset_cfg: + value_set.add(val_dataset_cfg[value_name]) - if value_name in model_config: - value_set.add(model_config[value_name]) - if value_name in train_dataset_config: - value_set.add(train_dataset_config[value_name]) - if value_name in val_dataset_config: - value_set.add(val_dataset_config[value_name]) - if hasattr(cfg.train_dataset_class, 'NUM_CLASSES'): - value_set.add(cfg.train_dataset_class.NUM_CLASSES) - if hasattr(cfg.val_dataset_class, 'NUM_CLASSES'): - value_set.add(cfg.val_dataset_class.NUM_CLASSES) - - if len(value_set) == 0: + if len(value_set) > 1: raise ValueError( - '`num_classes` is not found. Please set it in model, train_dataset or val_dataset' + '`num_classes` is not consistent: {}. Please set it ' + 'consistently in model, train_dataset and val_dataset config'. + format(value_set)) + if len(value_set) == 1 and value_name not in model_cfg: + logger.warning( + 'Add the `num_classes` in train_dataset and val_dataset ' + 'config to model config. We suggest you manually set `num_classes` in model config.' ) - elif len(value_set) > 1: - raise ValueError( - '`num_classes` is not consistent: {}. Please set it consistently in model or train_dataset or val_dataset' - .format(value_set)) - - model_config[value_name] = value_set.pop() + model_cfg[value_name] = value_set.pop() + cfg.dic['model'] = model_cfg class DefaultSyncImgChannelsRule(Rule): def check_and_correct(self, cfg): - model_config = cfg.dic['model'] - train_dataset_config = cfg.dic['train_dataset'] - val_dataset_config = cfg.dic['val_dataset'] - value_set = set() + model_cfg = cfg.model_cfg + train_dataset_cfg = cfg.train_dataset_cfg + val_dataset_cfg = cfg.val_dataset_cfg + + assert train_dataset_cfg != {}, \ + 'No train_dataset specified in the configuration file.' + if train_dataset_cfg['type'] != 'Dataset': + return + if val_dataset_cfg != {}: + assert val_dataset_cfg['type'] == 'Dataset', \ + 'The type of train_dataset and val_dataset must be the same' # If the model has backbone, in_channels is the input params of backbone. # Otherwise, in_channels is the input params of the model. 
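A simplified, self-contained illustration of the consistency check performed by DefaultSyncNumClassesRule above, with plain dictionaries standing in for the parsed config:

model_cfg = {'type': 'FCN', 'backbone': {'type': 'HRNet_W18'}}
train_dataset_cfg = {'type': 'Dataset', 'num_classes': 19}
val_dataset_cfg = {'type': 'Dataset', 'num_classes': 19}

value_set = {c['num_classes']
             for c in (train_dataset_cfg, val_dataset_cfg)
             if 'num_classes' in c}
if 'num_classes' in model_cfg:
    value_set.add(model_cfg['num_classes'])

if len(value_set) > 1:
    raise ValueError('`num_classes` is not consistent: {}'.format(value_set))
if len(value_set) == 1 and 'num_classes' not in model_cfg:
    model_cfg['num_classes'] = value_set.pop()  # 19 is propagated to the model config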
- if 'backbone' in model_config: - x = model_config['backbone'].get('in_channels', None) - if x is not None: - value_set.add(x) - if 'in_channels' in model_config: - value_set.add(model_config['in_channels']) - if 'img_channels' in train_dataset_config: - value_set.add(train_dataset_config['img_channels']) - if 'img_channels' in val_dataset_config: - value_set.add(val_dataset_config['img_channels']) - if hasattr(cfg.train_dataset_class, 'IMG_CHANNELS'): - value_set.add(cfg.train_dataset_class.IMG_CHANNELS) - if hasattr(cfg.val_dataset_class, 'IMG_CHANNELS'): - value_set.add(cfg.val_dataset_class.IMG_CHANNELS) + value_set = set() + x = utils.get_in_channels(model_cfg) + if x is not None: + value_set.add(x) + if 'img_channels' in train_dataset_cfg: + value_set.add(train_dataset_cfg['img_channels']) + if 'img_channels' in val_dataset_cfg: + value_set.add(val_dataset_cfg['img_channels']) if len(value_set) > 1: - raise ValueError( - '`in_channels` is not consistent: {}. Please set it consistently in model or train_dataset or val_dataset' - .format(value_set)) - channels = 3 if len(value_set) == 0 else value_set.pop() - - if 'backbone' in model_config: - model_config['backbone']['in_channels'] = channels - else: - model_config['in_channels'] = channels + raise ValueError('`in_channels` is not consistent: {}. Please set it ' \ + 'consistently in model or train_dataset or val_dataset'.format(value_set)) + if len(value_set) == 1 and utils.get_in_channels(model_cfg) is None: + logger.warning( + 'Add the `in_channels` in train_dataset and val_dataset ' + 'config to model config. We suggest you manually set `in_channels` in model config.' + ) + model_cfg = utils.set_in_channels(model_cfg, value_set.pop()) + cfg.dic['model'] = model_cfg + # if len(value_set) == 0, model and dataset use the default in_channels (3) class DefaultSyncIgnoreIndexRule(Rule): @@ -170,41 +183,43 @@ def __init__(self, loss_name): self.loss_name = loss_name def check_and_correct(self, cfg): - def _check_ignore_index(loss_cfg, dataset_ignore_index): - if 'ignore_index' in loss_cfg: + def _check_helper(loss_cfg, dataset_ignore_index): + if 'ignore_index' not in loss_cfg: + loss_cfg['ignore_index'] = dataset_ignore_index + else: assert loss_cfg['ignore_index'] == dataset_ignore_index, \ 'the ignore_index in loss and train_dataset must be the same. Currently, loss ignore_index = {}, '\ 'train_dataset ignore_index = {}'.format(loss_cfg['ignore_index'], dataset_ignore_index) - else: - loss_cfg['ignore_index'] = dataset_ignore_index loss_cfg = cfg.dic.get(self.loss_name, None) if loss_cfg is None: return - train_dataset_config = cfg.dic['train_dataset'] - val_dataset_config = cfg.dic['val_dataset'] + train_dataset_cfg = cfg.train_dataset_cfg + val_dataset_cfg = cfg.val_dataset_cfg + assert train_dataset_cfg != {}, \ + 'No train_dataset specified in the configuration file.' 
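Both SegBuilder._build_loss and DefaultSyncIgnoreIndexRule push the dataset's ignore_index into every plain loss config and into the sub-losses of MixedLoss; a standalone sketch with an illustrative loss config:

loss_cfg = {
    'types': [
        {'type': 'MixedLoss',
         'losses': [{'type': 'CrossEntropyLoss'}, {'type': 'LovaszSoftmaxLoss'}],
         'coef': [0.8, 0.2]},
        {'type': 'CrossEntropyLoss'},
    ],
    'coef': [1.0, 0.4],
}
ignore_index = 255  # e.g. train_dataset_class.IGNORE_INDEX

def _fill(cfg, value):
    # Simplification: the real code asserts equality when ignore_index is already set.
    cfg.setdefault('ignore_index', value)

for item in loss_cfg['types']:
    if item['type'] == 'MixedLoss':
        for sub in item['losses']:
            _fill(sub, ignore_index)
    else:
        _fill(item, ignore_index)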
+ if train_dataset_cfg['type'] != 'Dataset': + return + if val_dataset_cfg != {}: + assert val_dataset_cfg['type'] == 'Dataset', \ + 'The type of train_dataset and val_dataset must be the same' + value_set = set() value_name = 'ignore_index' - - if value_name in train_dataset_config: - value_set.add(train_dataset_config[value_name]) - if value_name in val_dataset_config: - value_set.add(val_dataset_config[value_name]) - if hasattr(cfg.train_dataset_class, 'IGNORE_INDEX'): - value_set.add(cfg.train_dataset_class.IGNORE_INDEX) - if hasattr(cfg.val_dataset_class, 'IGNORE_INDEX'): - value_set.add(cfg.val_dataset_class.IGNORE_INDEX) + if value_name in train_dataset_cfg: + value_set.add(train_dataset_cfg[value_name]) + if value_name in val_dataset_cfg: + value_set.add(val_dataset_cfg[value_name]) if len(value_set) > 1: - raise ValueError( - '`ignore_index` is not consistent: {}. Please set it consistently in train_dataset and val_dataset' - .format(value_set)) - ignore_index = 255 if len(value_set) == 0 else value_set.pop() - - for loss_cfg_i in loss_cfg['types']: - if loss_cfg_i['type'] == 'MixedLoss': - for loss_cfg_j in loss_cfg_i['losses']: - _check_ignore_index(loss_cfg_j, ignore_index) - else: - _check_ignore_index(loss_cfg_i, ignore_index) + raise ValueError('`ignore_index` is not consistent: {}. Please set ' \ + 'it consistently in train_dataset and val_dataset'.format(value_set)) + if len(value_set) == 1: + ignore_index = value_set.pop() + for loss_cfg_i in loss_cfg['types']: + if loss_cfg_i['type'] == 'MixedLoss': + for loss_cfg_j in loss_cfg_i['losses']: + _check_helper(loss_cfg_j, ignore_index) + else: + _check_helper(loss_cfg_i, ignore_index) diff --git a/paddleseg/utils/utils.py b/paddleseg/utils/utils.py index bb524c3aaf..55e7036784 100644 --- a/paddleseg/utils/utils.py +++ b/paddleseg/utils/utils.py @@ -277,3 +277,19 @@ def __get__(self, obj, cls): # Note that this is only executed once obj.__dict__[self.func.__name__] = val return val + + +def get_in_channels(model_cfg): + if 'backbone' in model_cfg: + return model_cfg['backbone'].get('in_channels', None) + else: + return model_cfg.get('in_channels', None) + + +def set_in_channels(model_cfg, in_channels): + model_cfg = model_cfg.copy() + if 'backbone' in model_cfg: + model_cfg['backbone']['in_channels'] = in_channels + else: + model_cfg['in_channels'] = in_channels + return model_cfg diff --git a/tools/export.py b/tools/export.py index b23032fc6d..6c09b797d2 100644 --- a/tools/export.py +++ b/tools/export.py @@ -18,7 +18,7 @@ import paddle import yaml -from paddleseg.cvlibs import Config +from paddleseg.cvlibs import Config, SegBuilder from paddleseg.utils import logger, utils from paddleseg.deploy.export import WrappedModel @@ -56,13 +56,14 @@ def main(args): assert args.config is not None, \ 'No configuration file specified, please set --config' cfg = Config(args.config) + builder = SegBuilder(cfg) utils.show_env_info() utils.show_cfg_info(cfg) os.environ['PADDLESEG_EXPORT_STAGE'] = 'True' # save model - model = cfg.model + model = builder.model if args.model_path is not None: state_dict = paddle.load(args.model_path) model.set_dict(state_dict) @@ -78,9 +79,9 @@ def main(args): paddle.jit.save(model, os.path.join(args.save_dir, 'model')) # save deploy.yaml - val_dataset_config = cfg.val_dataset_config - assert val_dataset_config != {}, 'No val_dataset specified in the configuration file.' 
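The get_in_channels/set_in_channels helpers added to paddleseg/utils/utils.py above centralize where in_channels lives, depending on whether the model wraps a backbone; a quick illustration:

from paddleseg.utils.utils import get_in_channels, set_in_channels

with_backbone = {'type': 'FCN', 'backbone': {'type': 'HRNet_W18'}}
without_backbone = {'type': 'UNet'}

print(get_in_channels(with_backbone))                          # None, nothing set yet
print(get_in_channels(set_in_channels(with_backbone, 4)))      # 4, stored on the backbone
print(get_in_channels(set_in_channels(without_backbone, 3)))   # 3, stored on the model itself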
- transforms = val_dataset_config.get('transforms', None) + val_dataset_cfg = cfg.val_dataset_cfg + assert val_dataset_cfg != {}, 'No val_dataset specified in the configuration file.' + transforms = val_dataset_cfg.get('transforms', None) output_dtype = 'int32' if args.output_op == 'argmax' else 'float32' # TODO add test config diff --git a/tools/predict.py b/tools/predict.py index a3cdfaeaf6..ab19c9681b 100644 --- a/tools/predict.py +++ b/tools/predict.py @@ -17,7 +17,7 @@ import paddle -from paddleseg.cvlibs import manager, Config +from paddleseg.cvlibs import manager, Config, SegBuilder from paddleseg.utils import get_sys_env, logger, get_image_list, utils from paddleseg.core import predict from paddleseg.transforms import Compose @@ -118,18 +118,20 @@ def main(args): assert args.config is not None, \ 'No configuration file specified, please set --config' cfg = Config(args.config) + builder = SegBuilder(cfg) test_config = merge_test_config(cfg, args) utils.show_env_info() utils.show_cfg_info(cfg) utils.set_device(args.device) - transforms = Compose(cfg.val_transforms) + model = builder.model + transforms = Compose(builder.val_transforms) image_list, image_dir = get_image_list(args.image_path) logger.info('The number of images: {}'.format(len(image_list))) predict( - cfg.model, + model, model_path=args.model_path, transforms=transforms, image_list=image_list, diff --git a/tools/train.py b/tools/train.py index b99a047aea..d12d6878ec 100644 --- a/tools/train.py +++ b/tools/train.py @@ -19,7 +19,7 @@ import numpy as np import cv2 -from paddleseg.cvlibs import manager, Config +from paddleseg.cvlibs import Config, SegBuilder from paddleseg.utils import get_sys_env, logger, utils from paddleseg.core import train @@ -133,6 +133,7 @@ def main(args): iters=args.iters, batch_size=args.batch_size, opts=args.opts) + builder = SegBuilder(cfg) utils.show_env_info() utils.show_cfg_info(cfg) @@ -152,23 +153,21 @@ def main(args): for i in range(loss_len): cfg.dic['loss']['types'][i]['data_format'] = args.data_format - model = utils.convert_sync_batchnorm(cfg.model, args.device) + model = utils.convert_sync_batchnorm(builder.model, args.device) - train_dataset = cfg.train_dataset - assert train_dataset is not None, \ - 'The training dataset is not specified in the configuration file.' - assert len(train_dataset) != 0, \ - 'The length of train_dataset is 0. Please check whether the dataset is valid.' 
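The builder-based flow in tools/predict.py above can also be driven programmatically; a rough sketch, where the paths are placeholders and the remaining predict() arguments are left as in tools/predict.py:

from paddleseg.cvlibs import Config, SegBuilder
from paddleseg.core import predict
from paddleseg.transforms import Compose
from paddleseg.utils import get_image_list

cfg = Config('path/to/config.yml')         # placeholder path
builder = SegBuilder(cfg)

model = builder.model
transforms = Compose(builder.val_transforms)
image_list, image_dir = get_image_list('path/to/images')  # placeholder path

predict(
    model,
    model_path='path/to/model.pdparams',   # placeholder path
    transforms=transforms,
    image_list=image_list)                 # remaining keyword arguments as in tools/predict.py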
+ train_dataset = builder.train_dataset # TODO refactor if args.repeats > 1: train_dataset.file_list *= args.repeats - val_dataset = cfg.val_dataset if args.do_eval else None + val_dataset = builder.val_dataset if args.do_eval else None + optimizer = builder.optimizer + loss = builder.loss train( model, train_dataset, val_dataset=val_dataset, - optimizer=cfg.optimizer, + optimizer=optimizer, save_dir=args.save_dir, iters=cfg.iters, batch_size=cfg.batch_size, @@ -177,7 +176,7 @@ def main(args): log_iters=args.log_iters, num_workers=args.num_workers, use_vdl=args.use_vdl, - losses=cfg.loss, + losses=loss, keep_checkpoint_max=args.keep_checkpoint_max, test_config=cfg.test_config, precision=args.precision, diff --git a/tools/val.py b/tools/val.py index b57e432795..454737608b 100644 --- a/tools/val.py +++ b/tools/val.py @@ -17,7 +17,7 @@ import paddle -from paddleseg.cvlibs import manager, Config +from paddleseg.cvlibs import manager, Config, SegBuilder from paddleseg.core import evaluate from paddleseg.utils import get_sys_env, logger, utils @@ -119,6 +119,7 @@ def main(args): assert args.config is not None, \ 'No configuration file specified, please set --config' cfg = Config(args.config, opts=args.opts) + builder = SegBuilder(cfg) test_config = merge_test_config(cfg, args) utils.show_env_info() @@ -137,16 +138,11 @@ def main(args): for i in range(loss_len): cfg.dic['loss']['types'][i]['data_format'] = args.data_format - model = cfg.model + model = builder.model if args.model_path: utils.load_entire_model(model, args.model_path) logger.info('Loaded trained weights successfully.') - - val_dataset = cfg.val_dataset - assert val_dataset is not None, \ - 'The val_dataset is not specified in the configuration file.' - assert len(val_dataset) != 0, \ - 'The length of val_dataset is 0. Please check whether the dataset is valid.' + val_dataset = builder.val_dataset evaluate(model, val_dataset, num_workers=args.num_workers, **test_config)
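Putting the pieces together, the training and evaluation entry points above reduce to roughly the following sketch (paths and hyper-parameter overrides are illustrative; tools/train.py and tools/val.py additionally handle devices, AMP, resuming, and so on):

from paddleseg.cvlibs import Config, SegBuilder
from paddleseg.core import train, evaluate
from paddleseg.utils import utils

cfg = Config('path/to/config.yml', batch_size=4, iters=10000)  # overrides are optional
builder = SegBuilder(cfg)

model = utils.convert_sync_batchnorm(builder.model, 'gpu')
train(
    model,
    builder.train_dataset,
    val_dataset=builder.val_dataset,
    optimizer=builder.optimizer,
    losses=builder.loss,
    iters=cfg.iters,
    batch_size=cfg.batch_size,
    save_dir='output')

evaluate(model, builder.val_dataset)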