From b781f4f27a86189ddeb1d30417794a303cf30983 Mon Sep 17 00:00:00 2001 From: PeiyuLau Date: Wed, 10 Jul 2024 14:05:00 +0800 Subject: [PATCH] add data_format selection support to ocr --- configs/cls/cls_mv3.yml | 2 + .../det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml | 9 ++ configs/det/det_mv3_db.yml | 7 +- configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml | 2 + .../PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml | 4 + ppocr/modeling/backbones/det_mobilenet_v3.py | 24 ++++- ppocr/modeling/backbones/det_resnet_vd.py | 33 +++++-- ppocr/modeling/backbones/rec_mobilenet_v3.py | 9 +- ppocr/modeling/backbones/rec_mv1_enhance.py | 30 +++++- ppocr/modeling/backbones/rec_svtrnet.py | 4 +- ppocr/modeling/heads/cls_head.py | 5 +- ppocr/modeling/heads/det_db_head.py | 20 +++- ppocr/modeling/heads/rec_multi_head.py | 12 ++- ppocr/modeling/heads/rec_sar_head.py | 4 +- ppocr/modeling/necks/db_fpn.py | 99 ++++++++++++------- ppocr/modeling/necks/intracl.py | 16 ++- ppocr/modeling/necks/rnn.py | 52 ++++++---- 17 files changed, 240 insertions(+), 92 deletions(-) mode change 100644 => 100755 configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml mode change 100644 => 100755 ppocr/modeling/heads/rec_sar_head.py diff --git a/configs/cls/cls_mv3.yml b/configs/cls/cls_mv3.yml index 0c46ff56027..8016f91c6a6 100644 --- a/configs/cls/cls_mv3.yml +++ b/configs/cls/cls_mv3.yml @@ -23,10 +23,12 @@ Architecture: name: MobileNetV3 scale: 0.35 model_name: small + data_format: NHWC Neck: Head: name: ClsHead class_dim: 2 + data_format: NHWC Loss: name: ClsLoss diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml index 252d1599776..7123dd1c109 100644 --- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml +++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml @@ -35,13 +35,16 @@ Architecture: scale: 0.5 model_name: large disable_se: true + data_format: NHWC Neck: name: RSEFPN out_channels: 96 shortcut: True + data_format: NHWC Head: name: DBHead k: 50 + data_format: NHWC Student2: pretrained: model_type: det @@ -52,13 +55,16 @@ Architecture: scale: 0.5 model_name: large disable_se: true + data_format: NHWC Neck: name: RSEFPN out_channels: 96 shortcut: True + data_format: NHWC Head: name: DBHead k: 50 + data_format: NHWC Teacher: freeze_params: true return_all_feats: false @@ -68,13 +74,16 @@ Architecture: name: ResNet_vd in_channels: 3 layers: 50 + data_format: NHWC Neck: name: LKPAN out_channels: 256 + data_format: NHWC Head: name: DBHead kernel_list: [7,2,2] k: 50 + data_format: NHWC Loss: name: CombinedLoss diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index 8f5685ec2a3..4455730728c 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -25,12 +25,15 @@ Architecture: name: MobileNetV3 scale: 0.5 model_name: large + data_format: NHWC Neck: name: DBFPN out_channels: 256 + data_format: NHWC Head: name: DBHead k: 50 + data_format: NHWC Loss: name: DBLoss @@ -64,7 +67,7 @@ Metric: Train: dataset: name: SimpleDataSet - data_dir: ./train_data/icdar2015/text_localization/ + data_dir: ./ label_file_list: - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt ratio_list: [1.0] @@ -107,7 +110,7 @@ Train: Eval: dataset: name: SimpleDataSet - data_dir: ./train_data/icdar2015/text_localization/ + data_dir: ./ label_file_list: - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt transforms: diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml index fd15873fbf8..39ca94f06b3 100644 --- 
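Note on the config changes above: each `data_format` key is forwarded unchanged from the YAML Architecture section to the corresponding module constructor. A minimal sketch of the equivalent direct construction (module paths and signatures are the ones touched by this patch; the config-to-kwargs wiring is PaddleOCR's existing builder and is assumed here):

    from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3
    from ppocr.modeling.necks.db_fpn import DBFPN
    from ppocr.modeling.heads.det_db_head import DBHead

    # Mirrors det_mv3_db.yml with data_format: NHWC at every stage.
    backbone = MobileNetV3(in_channels=3, model_name="large", scale=0.5, data_format="NHWC")
    neck = DBFPN(in_channels=backbone.out_channels, out_channels=256, data_format="NHWC")
    head = DBHead(in_channels=neck.out_channels, k=50, data_format="NHWC")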
a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml @@ -44,6 +44,7 @@ Architecture: last_conv_stride: [1, 2] last_pool_type: avg last_pool_kernel_size: [2, 2] + data_format: 'NHWC' Head: name: MultiHead head_list: @@ -59,6 +60,7 @@ Architecture: - SARHead: enc_dim: 512 max_text_length: *max_text_length + data_format: 'NHWC' Loss: name: MultiLoss diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml old mode 100644 new mode 100755 index 3b82ef857f0..31fac2e585b --- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml @@ -54,6 +54,7 @@ Architecture: last_conv_stride: [1, 2] last_pool_type: avg last_pool_kernel_size: [2, 2] + data_format: 'NHWC' Head: name: MultiHead head_list: @@ -69,6 +70,7 @@ Architecture: - SARHead: enc_dim: 512 max_text_length: *max_text_length + data_format: 'NHWC' Student: pretrained: freeze_params: false @@ -82,6 +84,7 @@ Architecture: last_conv_stride: [1, 2] last_pool_type: avg last_pool_kernel_size: [2, 2] + data_format: 'NHWC' Head: name: MultiHead head_list: @@ -97,6 +100,7 @@ Architecture: - SARHead: enc_dim: 512 max_text_length: *max_text_length + data_format: 'NHWC' Loss: name: CombinedLoss loss_config_list: diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index 98db44b6911..f5f84f14f15 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -35,7 +35,7 @@ def make_divisible(v, divisor=8, min_value=None): class MobileNetV3(nn.Layer): def __init__( - self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs + self, in_channels=3, model_name="large", scale=0.5, disable_se=False, data_format='NCHW', **kwargs ): """ the MobilenetV3 backbone network for detection module. 
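Note: the data pipeline still produces NCHW batches; a backbone built with data_format='NHWC' transposes once on entry and runs every subsequent layer channels-last. A small illustration of that entry step (shapes are made up):

    import paddle

    x = paddle.rand([8, 3, 48, 320])    # N, C, H, W as yielded by the dataloader
    x_nhwc = x.transpose([0, 2, 3, 1])  # N, H, W, C; same perm as the patched forward()
    print(x_nhwc.shape)                 # [8, 48, 320, 3]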
@@ -46,6 +46,7 @@ def __init__( self.disable_se = disable_se + self.nchw = data_format=='NCHW' if model_name == "large": cfg = [ # k, exp, c, se, nl, s, @@ -102,6 +103,7 @@ def __init__( groups=1, if_act=True, act="hardswish", + data_format=data_format ) self.stages = [] @@ -125,6 +127,7 @@ def __init__( stride=s, use_se=se, act=nl, + data_format=data_format ) ) inplanes = make_divisible(scale * c) @@ -139,6 +142,7 @@ def __init__( groups=1, if_act=True, act="hardswish", + data_format=data_format ) ) self.stages.append(nn.Sequential(*block_list)) @@ -147,6 +151,8 @@ def __init__( self.add_sublayer(sublayer=stage, name="stage{}".format(i)) def forward(self, x): + if not self.nchw: + x = x.transpose([0,2,3,1]) x = self.conv(x) out_list = [] for stage in self.stages: @@ -166,6 +172,7 @@ def __init__( groups=1, if_act=True, act=None, + data_format='NCHW' ): super(ConvBNLayer, self).__init__() self.if_act = if_act @@ -178,9 +185,10 @@ def __init__( padding=padding, groups=groups, bias_attr=False, + data_format=data_format ) - self.bn = nn.BatchNorm(num_channels=out_channels, act=None) + self.bn = nn.BatchNorm(num_channels=out_channels, act=None, data_layout=data_format) def forward(self, x): x = self.conv(x) @@ -210,6 +218,7 @@ def __init__( stride, use_se, act=None, + data_format='NCHW' ): super(ResidualUnit, self).__init__() self.if_shortcut = stride == 1 and in_channels == out_channels @@ -223,6 +232,7 @@ def __init__( padding=0, if_act=True, act=act, + data_format=data_format ) self.bottleneck_conv = ConvBNLayer( in_channels=mid_channels, @@ -233,9 +243,10 @@ def __init__( groups=mid_channels, if_act=True, act=act, + data_format=data_format ) if self.if_se: - self.mid_se = SEModule(mid_channels) + self.mid_se = SEModule(mid_channels, data_format=data_format) self.linear_conv = ConvBNLayer( in_channels=mid_channels, out_channels=out_channels, @@ -244,6 +255,7 @@ def __init__( padding=0, if_act=False, act=None, + data_format=data_format ) def forward(self, inputs): @@ -258,15 +270,16 @@ def forward(self, inputs): class SEModule(nn.Layer): - def __init__(self, in_channels, reduction=4): + def __init__(self, in_channels, reduction=4, data_format='NCHW'): super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2D(1) + self.avg_pool = nn.AdaptiveAvgPool2D(1, data_format=data_format) self.conv1 = nn.Conv2D( in_channels=in_channels, out_channels=in_channels // reduction, kernel_size=1, stride=1, padding=0, + data_format=data_format ) self.conv2 = nn.Conv2D( in_channels=in_channels // reduction, @@ -274,6 +287,7 @@ def __init__(self, in_channels, reduction=4): kernel_size=1, stride=1, padding=0, + data_format=data_format ) def forward(self, inputs): diff --git a/ppocr/modeling/backbones/det_resnet_vd.py b/ppocr/modeling/backbones/det_resnet_vd.py index 070ba3c9787..e2d3038680f 100644 --- a/ppocr/modeling/backbones/det_resnet_vd.py +++ b/ppocr/modeling/backbones/det_resnet_vd.py @@ -45,6 +45,7 @@ def __init__( skip_quant=False, dcn_bias_regularizer=L2Decay(0.0), dcn_bias_lr_scale=2.0, + data_format='NCHW' ): super(DeformableConvV2, self).__init__() self.offset_channel = 2 * kernel_size**2 * groups @@ -70,6 +71,7 @@ def __init__( deformable_groups=groups, weight_attr=weight_attr, bias_attr=dcn_bias_attr, + data_format=data_format ) if lr_scale == 1 and regularizer is None: @@ -88,6 +90,7 @@ def __init__( padding=(kernel_size - 1) // 2, weight_attr=ParamAttr(initializer=Constant(0.0)), bias_attr=offset_bias_attr, + data_format=data_format ) if skip_quant: self.conv_offset.skip_quant = True 
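Note on the ConvBNLayer changes here and below: Paddle spells the layout argument differently across APIs -- nn.Conv2D takes data_format while nn.BatchNorm takes data_layout. A self-contained check of the pairing (toy channel counts):

    import paddle
    import paddle.nn as nn

    conv = nn.Conv2D(3, 16, kernel_size=3, padding=1, bias_attr=False, data_format="NHWC")
    bn = nn.BatchNorm(num_channels=16, act=None, data_layout="NHWC")
    y = bn(conv(paddle.rand([1, 32, 32, 3])))  # NHWC in, NHWC out
    print(y.shape)                             # [1, 32, 32, 16]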
@@ -116,12 +119,13 @@ def __init__( is_vd_mode=False, act=None, is_dcn=False, + data_format='NCHW' ): super(ConvBNLayer, self).__init__() self.is_vd_mode = is_vd_mode self._pool2d_avg = nn.AvgPool2D( - kernel_size=2, stride=2, padding=0, ceil_mode=True + kernel_size=2, stride=2, padding=0, ceil_mode=True, data_format=data_format ) if not is_dcn: self._conv = nn.Conv2D( @@ -132,6 +136,7 @@ def __init__( padding=(kernel_size - 1) // 2, groups=groups, bias_attr=False, + data_format=data_format ) else: self._conv = DeformableConvV2( @@ -142,8 +147,9 @@ def __init__( padding=(kernel_size - 1) // 2, groups=dcn_groups, # groups, bias_attr=False, + data_format=data_format ) - self._batch_norm = nn.BatchNorm(out_channels, act=act) + self._batch_norm = nn.BatchNorm(out_channels, act=act, data_layout=data_format) def forward(self, inputs): if self.is_vd_mode: @@ -162,6 +168,7 @@ def __init__( shortcut=True, if_first=False, is_dcn=False, + data_format='NCHW' ): super(BottleneckBlock, self).__init__() self.conv0 = ConvBNLayer( in_channels=in_channels, out_channels=out_channels, kernel_size=1, act="relu", + data_format=data_format ) self.conv1 = ConvBNLayer( in_channels=out_channels, @@ -179,12 +187,14 @@ def __init__( act="relu", is_dcn=is_dcn, dcn_groups=2, + data_format=data_format ) self.conv2 = ConvBNLayer( in_channels=out_channels, out_channels=out_channels * 4, kernel_size=1, act=None, + data_format=data_format ) if not shortcut: @@ -194,6 +204,7 @@ def __init__( kernel_size=1, stride=1, is_vd_mode=False if if_first else True, + data_format=data_format ) self.shortcut = shortcut @@ -220,6 +231,7 @@ def __init__( stride, shortcut=True, if_first=False, + data_format='NCHW' ): super(BasicBlock, self).__init__() self.stride = stride @@ -229,9 +241,10 @@ def __init__( kernel_size=3, stride=stride, act="relu", + data_format=data_format ) self.conv1 = ConvBNLayer( - in_channels=out_channels, out_channels=out_channels, kernel_size=3, act=None + in_channels=out_channels, out_channels=out_channels, kernel_size=3, act=None, data_format=data_format ) if not shortcut: @@ -241,6 +254,7 @@ def __init__( kernel_size=1, stride=1, is_vd_mode=False if if_first else True, + data_format=data_format ) self.shortcut = shortcut @@ -260,7 +274,8 @@ def forward(self, inputs): class ResNet_vd(nn.Layer): def __init__( - self, in_channels=3, layers=50, dcn_stage=None, out_indices=None, **kwargs + self, in_channels=3, layers=50, dcn_stage=None, out_indices=None, data_format='NCHW', **kwargs ): super(ResNet_vd, self).__init__() + self.nchw = data_format=='NCHW' @@ -296,14 +310,15 @@ def __init__( kernel_size=3, stride=2, act="relu", + data_format=data_format ) self.conv1_2 = ConvBNLayer( - in_channels=32, out_channels=32, kernel_size=3, stride=1, act="relu" + in_channels=32, out_channels=32, kernel_size=3, stride=1, act="relu", data_format=data_format ) self.conv1_3 = ConvBNLayer( - in_channels=32, out_channels=64, kernel_size=3, stride=1, act="relu" + in_channels=32, out_channels=64, kernel_size=3, stride=1, act="relu", data_format=data_format ) - self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1, data_format=data_format) self.stages = [] self.out_channels = [] @@ -326,6 +341,7 @@ def __init__( shortcut=shortcut, if_first=block == i == 0, is_dcn=is_dcn, + data_format=data_format ), ) shortcut = True @@ -348,6 +364,7 @@ def __init__( stride=2 if i == 0 and block != 0 else 1, shortcut=shortcut, if_first=block == i == 0, + data_format=data_format ), ) shortcut = True @@ -357,6 +374,8 @@ def
__init__( self.stages.append(nn.Sequential(*block_list)) def forward(self, inputs): + if not self.nchw: + inputs = inputs.transpose([0,2,3,1]) # NCHW -> NHWC y = self.conv1_1(inputs) y = self.conv1_2(y) y = self.conv1_3(y) diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py index 00ee5a3da0b..7d17759ad0f 100644 --- a/ppocr/modeling/backbones/rec_mobilenet_v3.py +++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py @@ -32,10 +32,12 @@ def __init__( large_stride=None, small_stride=None, disable_se=False, + data_format='NCHW', **kwargs, ): super(MobileNetV3, self).__init__() self.disable_se = disable_se + self.nchw = data_format=='NCHW' if small_stride is None: small_stride = [2, 2, 2, 2] if large_stride is None: @@ -113,6 +115,7 @@ def __init__( groups=1, if_act=True, act="hardswish", + data_format=data_format ) i = 0 block_list = [] @@ -128,6 +131,7 @@ def __init__( stride=s, use_se=se, act=nl, + data_format=data_format ) ) inplanes = make_divisible(scale * c) @@ -143,12 +147,15 @@ def __init__( groups=1, if_act=True, act="hardswish", + data_format=data_format ) - self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0,data_format=data_format) self.out_channels = make_divisible(scale * cls_ch_squeeze) def forward(self, x): + if not self.nchw: + x = x.transpose([0,2,3,1]) # NCHW -> NHWC x = self.conv1(x) x = self.blocks(x) x = self.conv2(x) diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py index f20fa4cb1d2..e7c26cac0ba 100644 --- a/ppocr/modeling/backbones/rec_mv1_enhance.py +++ b/ppocr/modeling/backbones/rec_mv1_enhance.py @@ -42,6 +42,7 @@ def __init__( channels=None, num_groups=1, act="hard_swish", + data_format='NCHW' ): super(ConvBNLayer, self).__init__() @@ -54,6 +55,7 @@ def __init__( groups=num_groups, weight_attr=ParamAttr(initializer=KaimingNormal()), bias_attr=False, + data_format=data_format ) self._batch_norm = BatchNorm( @@ -61,6 +63,7 @@ def __init__( act=act, param_attr=ParamAttr(regularizer=L2Decay(0.0)), bias_attr=ParamAttr(regularizer=L2Decay(0.0)), + data_layout=data_format ) def forward(self, inputs): @@ -81,6 +84,7 @@ def __init__( dw_size=3, padding=1, use_se=False, + data_format='NCHW' ): super(DepthwiseSeparable, self).__init__() self.use_se = use_se @@ -91,15 +95,17 @@ def __init__( stride=stride, padding=padding, num_groups=int(num_groups * scale), + data_format=data_format ) if use_se: - self._se = SEModule(int(num_filters1 * scale)) + self._se = SEModule(int(num_filters1 * scale), data_format=data_format) self._pointwise_conv = ConvBNLayer( num_channels=int(num_filters1 * scale), filter_size=1, num_filters=int(num_filters2 * scale), stride=1, padding=0, + data_format=data_format ) def forward(self, inputs): @@ -118,6 +124,7 @@ def __init__( last_conv_stride=1, last_pool_type="max", last_pool_kernel_size=[3, 2], + data_format='NCHW', **kwargs, ): super().__init__() @@ -131,6 +138,7 @@ def __init__( num_filters=int(32 * scale), stride=2, padding=1, + data_format=data_format ) conv2_1 = DepthwiseSeparable( @@ -140,6 +148,7 @@ def __init__( num_groups=32, stride=1, scale=scale, + data_format=data_format ) self.block_list.append(conv2_1) @@ -150,6 +159,7 @@ def __init__( num_groups=64, stride=1, scale=scale, + data_format=data_format ) self.block_list.append(conv2_2) @@ -160,6 +170,7 @@ def __init__( num_groups=128, stride=1, scale=scale, + data_format=data_format ) self.block_list.append(conv3_1) @@ 
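Note: the SE blocks threaded through these backbones squeeze with an adaptive pool, and nn.AdaptiveAvgPool2D is layout-aware, so in NHWC the pooled statistics land on the last axis. A quick check (shapes illustrative):

    import paddle
    import paddle.nn as nn

    pool = nn.AdaptiveAvgPool2D(1, data_format="NHWC")
    s = pool(paddle.rand([2, 48, 160, 96]))
    print(s.shape)  # [2, 1, 1, 96], i.e. [N, 1, 1, C] rather than [N, C, 1, 1]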
-170,6 +181,7 @@ def __init__( num_groups=128, stride=(2, 1), scale=scale, + data_format=data_format ) self.block_list.append(conv3_2) @@ -180,6 +192,7 @@ def __init__( num_groups=256, stride=1, scale=scale, + data_format=data_format ) self.block_list.append(conv4_1) @@ -190,6 +203,7 @@ def __init__( num_groups=256, stride=(2, 1), scale=scale, + data_format=data_format ) self.block_list.append(conv4_2) @@ -204,6 +218,7 @@ def __init__( padding=2, scale=scale, use_se=False, + data_format=data_format ) self.block_list.append(conv5) @@ -217,6 +232,7 @@ def __init__( padding=2, scale=scale, use_se=True, + data_format=data_format ) self.block_list.append(conv5_6) @@ -230,6 +246,7 @@ def __init__( padding=2, use_se=True, scale=scale, + data_format=data_format ) self.block_list.append(conv6) @@ -239,12 +256,16 @@ def __init__( kernel_size=last_pool_kernel_size, stride=last_pool_kernel_size, padding=0, + data_format=data_format ) else: - self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, data_format=data_format) self.out_channels = int(1024 * scale) + self.data_format = data_format def forward(self, inputs): + if self.data_format == 'NHWC': + inputs = paddle.tensor.transpose(inputs, [0,2,3,1]) y = self.conv1(inputs) y = self.block_list(y) y = self.pool(y) @@ -252,9 +272,9 @@ class SEModule(nn.Layer): - def __init__(self, channel, reduction=4): + def __init__(self, channel, reduction=4, data_format='NCHW'): super(SEModule, self).__init__() - self.avg_pool = AdaptiveAvgPool2D(1) + self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format) self.conv1 = Conv2D( in_channels=channel, out_channels=channel // reduction, @@ -263,6 +283,7 @@ def __init__(self, channel, reduction=4): padding=0, weight_attr=ParamAttr(), bias_attr=ParamAttr(), + data_format=data_format ) self.conv2 = Conv2D( in_channels=channel // reduction, @@ -272,6 +293,7 @@ def __init__(self, channel, reduction=4): padding=0, weight_attr=ParamAttr(), bias_attr=ParamAttr(), + data_format=data_format ) def forward(self, inputs): diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py index 427c87b324a..724b0256242 100644 --- a/ppocr/modeling/backbones/rec_svtrnet.py +++ b/ppocr/modeling/backbones/rec_svtrnet.py @@ -51,6 +51,7 @@ def __init__( bias_attr=False, groups=1, act=nn.GELU, + data_format='NCHW' ): super().__init__() self.conv = nn.Conv2D( @@ -62,8 +63,9 @@ def __init__( groups=groups, weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), bias_attr=bias_attr, + data_format=data_format ) - self.norm = nn.BatchNorm2D(out_channels) + self.norm = nn.BatchNorm2D(out_channels, data_format=data_format) self.act = act() def forward(self, inputs): diff --git a/ppocr/modeling/heads/cls_head.py b/ppocr/modeling/heads/cls_head.py index 867e9601827..b5204b934f8 100644 --- a/ppocr/modeling/heads/cls_head.py +++ b/ppocr/modeling/heads/cls_head.py @@ -31,9 +31,10 @@ class ClsHead(nn.Layer): params(dict): super parameters for build Class network """ - def __init__(self, in_channels, class_dim, **kwargs): + def __init__(self, in_channels, class_dim, data_format='NCHW', **kwargs): super(ClsHead, self).__init__() - self.pool = nn.AdaptiveAvgPool2D(1) + self.pool = nn.AdaptiveAvgPool2D(1, data_format=data_format) + self.nchw = data_format=='NCHW' stdv = 1.0 / math.sqrt(in_channels * 1.0) self.fc = nn.Linear( in_channels, @@ -46,7 +47,7 @@ def __init__(self, in_channels, class_dim, **kwargs): def forward(self, x, targets=None): x = self.pool(x)
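Note: after the global pooling above, the channel axis depends on the layout, which is what the layout-conditional reshape in the next hunk line selects. A standalone version of that step under assumed NHWC shapes:

    import paddle

    x = paddle.rand([4, 1, 1, 200])                     # pooled NHWC feature, C last
    flat = paddle.reshape(x, [x.shape[0], x.shape[3]])  # pick axis 3 instead of axis 1
    print(flat.shape)                                   # [4, 200], ready for the fc layer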
- x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) + x = paddle.reshape(x, shape=[x.shape[0], x.shape[1 if self.nchw else 3]]) x = self.fc(x) if not self.training: x = F.softmax(x, axis=1) diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py index 8f41a25b01b..a02bbe58040 100644 --- a/ppocr/modeling/heads/det_db_head.py +++ b/ppocr/modeling/heads/det_db_head.py @@ -32,7 +32,7 @@ def get_bias_attr(k): class Head(nn.Layer): - def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs): + def __init__(self, in_channels, kernel_list=[3, 2, 2], data_format='NCHW', **kwargs): super(Head, self).__init__() self.conv1 = nn.Conv2D( @@ -42,12 +42,14 @@ def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs): padding=int(kernel_list[0] // 2), weight_attr=ParamAttr(), bias_attr=False, + data_format=data_format ) self.conv_bn1 = nn.BatchNorm( num_channels=in_channels // 4, param_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1.0)), bias_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1e-4)), act="relu", + data_layout=data_format ) self.conv2 = nn.Conv2DTranspose( @@ -57,12 +59,14 @@ def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs): stride=2, weight_attr=ParamAttr(initializer=paddle.nn.initializer.KaimingUniform()), bias_attr=get_bias_attr(in_channels // 4), + data_format=data_format ) self.conv_bn2 = nn.BatchNorm( num_channels=in_channels // 4, param_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1.0)), bias_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1e-4)), act="relu", + data_layout=data_format ) self.conv3 = nn.Conv2DTranspose( in_channels=in_channels // 4, @@ -71,6 +75,7 @@ def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs): stride=2, weight_attr=ParamAttr(initializer=paddle.nn.initializer.KaimingUniform()), bias_attr=get_bias_attr(in_channels // 4), + data_format=data_format ) def forward(self, x, return_f=False): @@ -95,11 +100,12 @@ class DBHead(nn.Layer): params(dict): super parameters for build DB network """ - def __init__(self, in_channels, k=50, **kwargs): + def __init__(self, in_channels, k=50, data_format='NCHW', **kwargs): super(DBHead, self).__init__() self.k = k - self.binarize = Head(in_channels, **kwargs) - self.thresh = Head(in_channels, **kwargs) + self.binarize = Head(in_channels, data_format=data_format, **kwargs) + self.thresh = Head(in_channels, data_format=data_format, **kwargs) + self.data_format = data_format def step_function(self, x, y): return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y))) @@ -107,11 +113,15 @@ def step_function(self, x, y): def forward(self, x, targets=None): shrink_maps = self.binarize(x) if not self.training: + if 'NHWC' == self.data_format: + shrink_maps = paddle.tensor.transpose(shrink_maps, [0, 3, 1, 2]) return {"maps": shrink_maps} threshold_maps = self.thresh(x) binary_maps = self.step_function(shrink_maps, threshold_maps) - y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1) + y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1 if 'NCHW' == self.data_format else 3) + if 'NHWC' == self.data_format: + y = paddle.tensor.transpose(y, [0, 3, 1, 2]) return {"maps": y}
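Note: DBHead normalizes its output back to NCHW before returning, so the DB loss and postprocess stay layout-agnostic; in NHWC mode the maps are concatenated along axis 3 and then permuted with [0, 3, 1, 2]. A standalone sketch with toy shapes:

    import paddle

    shrink = paddle.rand([1, 184, 184, 1])  # NHWC single-channel maps
    thresh = paddle.rand([1, 184, 184, 1])
    binary = paddle.rand([1, 184, 184, 1])
    maps = paddle.concat([shrink, thresh, binary], axis=3)
    maps = paddle.transpose(maps, [0, 3, 1, 2])
    print(maps.shape)                       # [1, 3, 184, 184], same as NCHW mode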
diff --git a/ppocr/modeling/heads/rec_multi_head.py b/ppocr/modeling/heads/rec_multi_head.py index be4461567bd..e80590cc9c8 100644 --- a/ppocr/modeling/heads/rec_multi_head.py +++ b/ppocr/modeling/heads/rec_multi_head.py @@ -65,14 +65,15 @@ def forward(self, x): class MultiHead(nn.Layer): - def __init__(self, in_channels, out_channels_list, **kwargs): + def __init__(self, in_channels, out_channels_list, data_format='NCHW', **kwargs): super().__init__() self.head_list = kwargs.pop("head_list") self.use_pool = kwargs.get("use_pool", False) self.use_pos = kwargs.get("use_pos", False) self.in_channels = in_channels + self.nchw=data_format=='NCHW' if self.use_pool: - self.pool = nn.AvgPool2D(kernel_size=[3, 2], stride=[3, 2], padding=0) + self.pool = nn.AvgPool2D(kernel_size=[3, 2], stride=[3, 2], padding=0, data_format=data_format) self.gtc_head = "sar" assert len(self.head_list) >= 2 for idx, head_name in enumerate(self.head_list): @@ -113,17 +114,18 @@ def __init__(self, in_channels, out_channels_list, **kwargs): ) elif name == "CTCHead": # ctc neck - self.encoder_reshape = Im2Seq(in_channels) + self.encoder_reshape = Im2Seq(in_channels, data_format=data_format) neck_args = self.head_list[idx][name]["Neck"] encoder_type = neck_args.pop("name") self.ctc_encoder = SequenceEncoder( - in_channels=in_channels, encoder_type=encoder_type, **neck_args + in_channels=in_channels, encoder_type=encoder_type, data_format=data_format, **neck_args ) # ctc head head_args = self.head_list[idx][name]["Head"] self.ctc_head = eval(name)( in_channels=self.ctc_encoder.out_channels, out_channels=out_channels_list["CTCLabelDecode"], + data_format=data_format, **head_args, ) else: @@ -144,6 +146,8 @@ def forward(self, x, targets=None): # eval mode if not self.training: return ctc_out + if not self.nchw: + x = x.transpose([0,3,1,2]) if self.gtc_head == "sar": sar_out = self.sar_head(x, targets[1:]) head_out["sar"] = sar_out diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py old mode 100644 new mode 100755 index 9c646a1d672..c5d44282b57 --- a/ppocr/modeling/heads/rec_sar_head.py +++ b/ppocr/modeling/heads/rec_sar_head.py @@ -70,7 +70,7 @@ def __init__( kwargs = dict( input_size=d_model, hidden_size=d_enc, - num_layers=2, + num_layers=1, time_major=False, dropout=enc_drop_rnn, direction=direction, @@ -197,7 +197,7 @@ def __init__( kwargs = dict( input_size=encoder_rnn_out_size, hidden_size=encoder_rnn_out_size, - num_layers=2, + num_layers=1, time_major=False, dropout=dec_drop_rnn, direction=direction, diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py index 5c1674b434a..5f00e722941 100644 --- a/ppocr/modeling/necks/db_fpn.py +++ b/ppocr/modeling/necks/db_fpn.py @@ -42,6 +42,7 @@ def __init__( groups=None, if_act=True, act="relu", + data_format='NCHW', **kwargs, ): super(DSConv, self).__init__() @@ -57,9 +58,10 @@ def __init__( padding=padding, groups=groups, bias_attr=False, + data_format=data_format ) - self.bn1 = nn.BatchNorm(num_channels=in_channels, act=None) + self.bn1 = nn.BatchNorm(num_channels=in_channels, act=None, data_layout=data_format) self.conv2 = nn.Conv2D( in_channels=in_channels, @@ -67,9 +69,10 @@ def __init__( kernel_size=1, stride=1, bias_attr=False, + data_format=data_format ) - self.bn2 = nn.BatchNorm(num_channels=int(in_channels * 4), act=None) + self.bn2 = nn.BatchNorm(num_channels=int(in_channels * 4), act=None, data_layout=data_format) self.conv3 = nn.Conv2D( in_channels=int(in_channels * 4), @@ -77,6 +80,7 @@ def __init__( kernel_size=1, stride=1, bias_attr=False, + data_format=data_format ) self._c = [in_channels, out_channels] if in_channels != out_channels: @@ -86,6 +90,7 @@ def __init__( kernel_size=1, stride=1, bias_attr=False, + data_format=data_format ) def forward(self, inputs): @@ -114,11 +119,12 @@ def forward(self, inputs): class DBFPN(nn.Layer): - def
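Note on MultiHead above: the CTC branch consumes the NHWC feature directly (via Im2Seq/SequenceEncoder), while SARHead is left NCHW-only, so in NHWC mode the feature is transposed back before the SAR branch runs during training. Standalone illustration (toy shapes):

    import paddle

    feat_nhwc = paddle.rand([8, 1, 40, 480])       # N, H, W, C from the backbone
    feat_nchw = feat_nhwc.transpose([0, 3, 1, 2])  # N, C, H, W for SARHead
    print(feat_nchw.shape)                         # [8, 480, 1, 40]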
__init__(self, in_channels, out_channels, use_asf=False, **kwargs): + def __init__(self, in_channels, out_channels, use_asf=False, data_format='NCHW', **kwargs): super(DBFPN, self).__init__() self.out_channels = out_channels self.use_asf = use_asf weight_attr = paddle.nn.initializer.KaimingUniform() + self.data_format = data_format self.in2_conv = nn.Conv2D( in_channels=in_channels[0], @@ -126,6 +132,7 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs): kernel_size=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) self.in3_conv = nn.Conv2D( in_channels=in_channels[1], @@ -133,6 +140,7 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs): kernel_size=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) self.in4_conv = nn.Conv2D( in_channels=in_channels[2], @@ -140,6 +148,7 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs): kernel_size=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) self.in5_conv = nn.Conv2D( in_channels=in_channels[3], @@ -147,6 +156,7 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs): kernel_size=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) self.p5_conv = nn.Conv2D( in_channels=self.out_channels, @@ -155,6 +165,7 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs): padding=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) self.p4_conv = nn.Conv2D( in_channels=self.out_channels, @@ -163,6 +174,7 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs): padding=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) self.p3_conv = nn.Conv2D( in_channels=self.out_channels, @@ -171,6 +183,7 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs): padding=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) self.p2_conv = nn.Conv2D( in_channels=self.out_channels, @@ -179,6 +192,7 @@ def __init__(self, in_channels, out_channels, use_asf=False, **kwargs): padding=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) if self.use_asf is True: @@ -193,24 +207,24 @@ def forward(self, x): in2 = self.in2_conv(c2) out4 = in4 + F.upsample( - in5, scale_factor=2, mode="nearest", align_mode=1 + in5, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/16 out3 = in3 + F.upsample( - out4, scale_factor=2, mode="nearest", align_mode=1 + out4, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/8 out2 = in2 + F.upsample( - out3, scale_factor=2, mode="nearest", align_mode=1 + out3, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/4 p5 = self.p5_conv(in5) p4 = self.p4_conv(out4) p3 = self.p3_conv(out3) p2 = self.p2_conv(out2) - p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) - p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) - p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) + p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1, data_format=self.data_format) + p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1, data_format=self.data_format) + p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1, 
data_format=self.data_format) - fuse = paddle.concat([p5, p4, p3, p2], axis=1) + fuse = paddle.concat([p5, p4, p3, p2], axis=1 if 'NCHW' == self.data_format else 3) if self.use_asf is True: fuse = self.asf(fuse, [p5, p4, p3, p2]) @@ -219,7 +233,7 @@ def forward(self, x): class RSELayer(nn.Layer): - def __init__(self, in_channels, out_channels, kernel_size, shortcut=True): + def __init__(self, in_channels, out_channels, kernel_size, shortcut=True, data_format='NCHW'): super(RSELayer, self).__init__() weight_attr = paddle.nn.initializer.KaimingUniform() self.out_channels = out_channels @@ -230,8 +244,9 @@ def __init__(self, in_channels, out_channels, kernel_size, shortcut=True): padding=int(kernel_size // 2), weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) - self.se_block = SEModule(self.out_channels) + self.se_block = SEModule(self.out_channels, data_format=data_format) self.shortcut = shortcut def forward(self, ins): @@ -244,26 +259,28 @@ def forward(self, ins): class RSEFPN(nn.Layer): - def __init__(self, in_channels, out_channels, shortcut=True, **kwargs): + def __init__(self, in_channels, out_channels, shortcut=True, data_format='NCHW', **kwargs): super(RSEFPN, self).__init__() self.out_channels = out_channels + self.nchw = data_format=='NCHW' + self.data_format = data_format self.ins_conv = nn.LayerList() self.inp_conv = nn.LayerList() self.intracl = False if "intracl" in kwargs.keys() and kwargs["intracl"] is True: self.intracl = kwargs["intracl"] - self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) - self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) - self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) - self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2, data_format=data_format) + self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2, data_format=data_format) + self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2, data_format=data_format) + self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2, data_format=data_format) for i in range(len(in_channels)): self.ins_conv.append( - RSELayer(in_channels[i], out_channels, kernel_size=1, shortcut=shortcut) + RSELayer(in_channels[i], out_channels, kernel_size=1, shortcut=shortcut, data_format=data_format) ) self.inp_conv.append( RSELayer( - out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut + out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut, data_format=data_format ) ) @@ -276,13 +293,13 @@ def forward(self, x): in2 = self.ins_conv[0](c2) out4 = in4 + F.upsample( - in5, scale_factor=2, mode="nearest", align_mode=1 + in5, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/16 out3 = in3 + F.upsample( - out4, scale_factor=2, mode="nearest", align_mode=1 + out4, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/8 out2 = in2 + F.upsample( - out3, scale_factor=2, mode="nearest", align_mode=1 + out3, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/4 p5 = self.inp_conv[3](in5) @@ -296,18 +313,20 @@ def forward(self, x): p3 = self.incl2(p3) p2 = self.incl1(p2)
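Note: in these FPN forwards both halves of the fusion become layout-dependent -- F.upsample must be told the layout explicitly (it would otherwise scale the wrong axes for NHWC), and the concat axis moves from 1 to 3. A standalone check with assumed toy shapes:

    import paddle
    import paddle.nn.functional as F

    p = paddle.rand([1, 20, 20, 24])  # one NHWC pyramid level
    up = F.upsample(p, scale_factor=2, mode="nearest", align_mode=1, data_format="NHWC")
    print(up.shape)                   # [1, 40, 40, 24]
    fuse = paddle.concat([up, up, up, up], axis=3)
    print(fuse.shape)                 # [1, 40, 40, 96]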
- p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) - p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) - p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) + p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1, data_format=self.data_format) + p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1, data_format=self.data_format) + p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format) - fuse = paddle.concat([p5, p4, p3, p2], axis=1) + fuse = paddle.concat([p5, p4, p3, p2], axis=1 if self.nchw else 3) return fuse class LKPAN(nn.Layer): - def __init__(self, in_channels, out_channels, mode="large", **kwargs): + def __init__(self, in_channels, out_channels, mode="large", data_format='NCHW', **kwargs): super(LKPAN, self).__init__() self.out_channels = out_channels + self.nchw = data_format=='NCHW' + self.data_format = data_format weight_attr = paddle.nn.initializer.KaimingUniform() self.ins_conv = nn.LayerList() @@ -335,6 +354,7 @@ def __init__(self, in_channels, out_channels, mode="large", **kwargs): kernel_size=1, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) ) @@ -346,6 +366,7 @@ def __init__(self, in_channels, out_channels, mode="large", **kwargs): padding=4, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) ) @@ -359,6 +380,7 @@ def __init__(self, in_channels, out_channels, mode="large", **kwargs): stride=2, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) ) self.pan_lat_conv.append( @@ -369,16 +391,17 @@ def __init__(self, in_channels, out_channels, mode="large", **kwargs): padding=4, weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False, + data_format=data_format ) ) self.intracl = False if "intracl" in kwargs.keys() and kwargs["intracl"] is True: self.intracl = kwargs["intracl"] - self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) - self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) - self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) - self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2) + self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2, data_format=data_format) + self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2, data_format=data_format) + self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2, data_format=data_format) + self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2, data_format=data_format) def forward(self, x): c2, c3, c4, c5 = x @@ -389,13 +412,13 @@ def forward(self, x): in2 = self.ins_conv[0](c2) out4 = in4 + F.upsample( - in5, scale_factor=2, mode="nearest", align_mode=1 + in5, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/16 out3 = in3 + F.upsample( - out4, scale_factor=2, mode="nearest", align_mode=1 + out4, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/8 out2 = in2 + F.upsample( - out3, scale_factor=2, mode="nearest", align_mode=1 + out3, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format ) # 1/4 f5 = self.inp_conv[3](in5) @@ -418,11 +441,11 @@ def forward(self, x): p3 = self.incl2(p3) p2 = self.incl1(p2) - p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) - p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) - p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) + p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1, data_format=self.data_format) + p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1, data_format=self.data_format) + p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1, data_format=self.data_format)
- fuse = paddle.concat([p5, p4, p3, p2], axis=1) + fuse = paddle.concat([p5, p4, p3, p2], axis=1 if self.nchw else 3) return fuse diff --git a/ppocr/modeling/necks/intracl.py b/ppocr/modeling/necks/intracl.py index 2c4809cb122..3b7d80b0f66 100644 --- a/ppocr/modeling/necks/intracl.py +++ b/ppocr/modeling/necks/intracl.py @@ -5,16 +5,16 @@ class IntraCLBlock(nn.Layer): - def __init__(self, in_channels=96, reduce_factor=4): + def __init__(self, in_channels=96, reduce_factor=4, data_format='NCHW'): super(IntraCLBlock, self).__init__() self.channels = in_channels self.rf = reduce_factor weight_attr = paddle.nn.initializer.KaimingUniform() self.conv1x1_reduce_channel = nn.Conv2D( - self.channels, self.channels // self.rf, kernel_size=1, stride=1, padding=0 + self.channels, self.channels // self.rf, kernel_size=1, stride=1, padding=0, data_format=data_format ) self.conv1x1_return_channel = nn.Conv2D( - self.channels // self.rf, self.channels, kernel_size=1, stride=1, padding=0 + self.channels // self.rf, self.channels, kernel_size=1, stride=1, padding=0, data_format=data_format ) self.v_layer_7x1 = nn.Conv2D( @@ -23,6 +23,7 @@ def __init__(self, in_channels=96, reduce_factor=4): kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), + data_format=data_format ) self.v_layer_5x1 = nn.Conv2D( self.channels // self.rf, @@ -30,6 +31,7 @@ def __init__(self, in_channels=96, reduce_factor=4): kernel_size=(5, 1), stride=(1, 1), padding=(2, 0), + data_format=data_format ) self.v_layer_3x1 = nn.Conv2D( self.channels // self.rf, @@ -37,6 +39,7 @@ def __init__(self, in_channels=96, reduce_factor=4): kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), + data_format=data_format ) self.q_layer_1x7 = nn.Conv2D( @@ -45,6 +48,7 @@ def __init__(self, in_channels=96, reduce_factor=4): kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), + data_format=data_format ) self.q_layer_1x5 = nn.Conv2D( self.channels // self.rf, @@ -52,6 +56,7 @@ def __init__(self, in_channels=96, reduce_factor=4): kernel_size=(1, 5), stride=(1, 1), padding=(0, 2), + data_format=data_format ) self.q_layer_1x3 = nn.Conv2D( self.channels // self.rf, @@ -59,6 +64,7 @@ def __init__(self, in_channels=96, reduce_factor=4): kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), + data_format=data_format ) # base @@ -68,6 +74,7 @@ def __init__(self, in_channels=96, reduce_factor=4): kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), + data_format=data_format ) self.c_layer_5x5 = nn.Conv2D( self.channels // self.rf, @@ -82,9 +89,10 @@ def __init__(self, in_channels=96, reduce_factor=4): kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), + data_format=data_format ) - self.bn = nn.BatchNorm2D(self.channels) + self.bn = nn.BatchNorm2D(self.channels, data_format=data_format) self.relu = nn.ReLU() def forward(self, x): diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py index fa7b8a1f1af..d38e15b55f7 100644 --- a/ppocr/modeling/necks/rnn.py +++ b/ppocr/modeling/necks/rnn.py @@ -30,15 +30,21 @@ class Im2Seq(nn.Layer): - def __init__(self, in_channels, **kwargs): + def __init__(self, in_channels, data_format='NCHW', **kwargs): super().__init__() self.out_channels = in_channels + self.nchw=data_format=='NCHW' def forward(self, x): - B, C, H, W = x.shape - assert H == 1 - x = x.squeeze(axis=2) - x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels) + if self.nchw: + B, C, H, W = x.shape + assert H == 1 + x = x.squeeze(axis=2) + x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels) + else: + B, H, W, C = x.shape + assert H == 1 + x = x.squeeze(axis=1) return x
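Note: Im2Seq is where NHWC pays off for recognition -- with the height axis at position 1, a single squeeze already yields the (batch, width, channels) sequence, no transpose needed. Standalone (toy shapes):

    import paddle

    x = paddle.rand([8, 1, 80, 120])  # N, H, W, C with H == 1
    seq = x.squeeze(axis=1)           # N, T, C directly
    print(seq.shape)                  # [8, 80, 120]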
@@ -152,19 +158,22 @@ def __init__( drop_path=0.0, kernel_size=[3, 3], qk_scale=None, + data_format='NCHW' ): super(EncoderWithSVTR, self).__init__() self.depth = depth self.use_guide = use_guide + self.nchw=data_format=='NCHW' self.conv1 = ConvBNLayer( in_channels, in_channels // 8, kernel_size=kernel_size, padding=[kernel_size[0] // 2, kernel_size[1] // 2], act=nn.Swish, + data_format=data_format ) self.conv2 = ConvBNLayer( - in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish + in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish, data_format=data_format ) self.svtr_block = nn.LayerList( @@ -189,7 +198,7 @@ def __init__( ] ) self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6) - self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act=nn.Swish) + self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act=nn.Swish, data_format=data_format) # last conv-nxn, the input is concat of input tensor and conv3 output tensor self.conv4 = ConvBNLayer( 2 * in_channels, @@ -197,9 +206,10 @@ def __init__( kernel_size=kernel_size, padding=[kernel_size[0] // 2, kernel_size[1] // 2], act=nn.Swish, + data_format=data_format ) - self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act=nn.Swish) + self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act=nn.Swish, data_format=data_format) self.out_channels = dims self.apply(self._init_weights) @@ -225,23 +235,31 @@ def forward(self, x): z = self.conv1(z) z = self.conv2(z) # SVTR global block - B, C, H, W = z.shape - z = z.flatten(2).transpose([0, 2, 1]) + if self.nchw: + B, C, H, W = z.shape + z = z.flatten(2).transpose([0, 2, 1]) + else: + B, H, W, C = z.shape + z = z.flatten(start_axis=1, stop_axis=2) + for blk in self.svtr_block: z = blk(z) z = self.norm(z) # last stage - z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2]) + if self.nchw: + z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2]) + else: + z = z.reshape([0, H, W, C]) z = self.conv3(z) - z = paddle.concat((h, z), axis=1) + z = paddle.concat((h, z), axis=1 if self.nchw else 3) z = self.conv1x1(self.conv4(z)) return z class SequenceEncoder(nn.Layer): - def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): + def __init__(self, in_channels, encoder_type, hidden_size=48, data_format='NCHW', **kwargs): super(SequenceEncoder, self).__init__() - self.encoder_reshape = Im2Seq(in_channels) + self.encoder_reshape = Im2Seq(in_channels, data_format=data_format) self.out_channels = self.encoder_reshape.out_channels self.encoder_type = encoder_type if encoder_type == "reshape": @@ -259,15 +277,15 @@ def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): ) if encoder_type == "svtr": self.encoder = support_encoder_dict[encoder_type]( - self.encoder_reshape.out_channels, **kwargs + self.encoder_reshape.out_channels, data_format=data_format, **kwargs ) elif encoder_type == "cascadernn": self.encoder = support_encoder_dict[encoder_type]( self.encoder_reshape.out_channels, hidden_size, **kwargs ) else: self.encoder = support_encoder_dict[encoder_type]( self.encoder_reshape.out_channels, hidden_size ) self.out_channels = self.encoder.out_channels self.only_reshape = False
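Closing note on EncoderWithSVTR: only the svtr encoder needs the layout flag, because Im2Seq has already normalized the rnn/cascadernn inputs to (batch, width, channels); flattening NHWC features to transformer tokens is a single flatten with no transpose, and the reshape back to 4D uses 0 to keep the batch dimension. A standalone round trip under assumed shapes:

    import paddle

    z = paddle.rand([2, 1, 40, 120])               # N, H, W, C
    B, H, W, C = z.shape
    tokens = z.flatten(start_axis=1, stop_axis=2)  # (B, H*W, C) for the mixing blocks
    z4d = tokens.reshape([0, H, W, C])             # back to NHWC; 0 copies dim 0
    print(tokens.shape, z4d.shape)                 # [2, 40, 120] [2, 1, 40, 120]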