diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index a1a6a6fac9f1..9a4889005782 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -511,7 +511,7 @@ def _silu_flops_compute(input: Tensor, inplace: bool = False): return input.numel(), 0 -def _gelu_flops_compute(input): +def _gelu_flops_compute(input, **kwargs): return input.numel(), 0 @@ -668,16 +668,14 @@ def _instance_norm_flops_compute( return input.numel() * (5 if has_affine else 4), 0 -def _upsample_flops_compute(input, - size=None, - scale_factor=None, - mode="nearest", - align_corners=None): +def _upsample_flops_compute(input, **kwargs): + size = kwargs.get('size', None) if size is not None: - if isinstance(size, tuple): + if isinstance(size, tuple) or isinstance(size, list): return int(_prod(size)), 0 else: return int(size), 0 + scale_factor = kwargs.get('scale_factor', None) assert scale_factor is not None, "either size or scale_factor should be defined" flops = input.numel() if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):