diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index 3017c4dea3..4c58a678cf 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -43,6 +43,9 @@ def initialize_observer(
     The name is then used to load the observer from the registry and attached
     to the module. The name of the observer uses the base_name provided.
 
+    Weight observers named "minmax", "static_minmax" or "mse" are replaced by
+    their memoryless counterparts; a warning is logged when this happens.
+
     :param module: torch.nn.Module that the observer is being attached to
     :param base_name: str used to name the observer attribute
 
@@ -57,9 +60,26 @@ def initialize_observer(
     args: QuantizationArgs = getattr_chain(
         module, f"quantization_scheme.{arg_name}", None
     )
     if args is not None and args.dynamic is not True:
+        # read args.observer only inside this guard: args may be None (the
+        # default passed to getattr_chain above), in which case nothing is attached
+        observer_name = args.observer
+
+        # training is no longer supported: always use memoryless for weights
+        memoryless_overrides = {
+            "static_minmax": "memoryless_minmax",
+            "minmax": "memoryless_minmax",
+            "mse": "memoryless_mse",
+        }
+        if base_name == "weight" and observer_name in memoryless_overrides:
+            observer_name = memoryless_overrides[observer_name]
+            logger.warning(
+                "Overriding weight observer for lower memory usage "
+                f"({args.observer} -> {observer_name})"
+            )
+
         observer = Observer.load_from_registry(
-            args.observer, base_name=base_name, args=args, module=module
+            observer_name, base_name=base_name, args=args, module=module
         )
         module.register_module(f"{base_name}_observer", observer)
 