File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change 33import torch
44from torch import nn , Tensor
55
6+ from .misc import FrozenBatchNorm2d
7+
68
79def _cat (tensors : List [Tensor ], dim : int = 0 ) -> Tensor :
810 """
@@ -43,7 +45,13 @@ def split_normalization_params(
4345) -> Tuple [List [Tensor ], List [Tensor ]]:
4446 # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
4547 if not norm_classes :
46- norm_classes = [nn .modules .batchnorm ._BatchNorm , nn .LayerNorm , nn .GroupNorm ]
48+ norm_classes = [
49+ nn .modules .batchnorm ._BatchNorm ,
50+ nn .LayerNorm ,
51+ nn .GroupNorm ,
52+ nn .modules .instancenorm ._InstanceNorm ,
53+ nn .LocalResponseNorm ,
54+ ]
4755
4856 for t in norm_classes :
4957 if not issubclass (t , nn .Module ):
You can’t perform that action at this time.
0 commit comments