Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def initialize_observer(
The name is then used to load the observer from the registry, and the
loaded observer is attached to the module. The name of the observer uses the base_name provided.

This function always initializes memoryless observers for weights.

:param module: torch.nn.Module that the observer is being attached to
:param base_name: str used to name the observer attribute

Expand All @@ -57,9 +59,25 @@ def initialize_observer(
args: QuantizationArgs = getattr_chain(
module, f"quantization_scheme.{arg_name}", None
)
observer = args.observer

# training is no longer supported: always use memoryless for weights
if base_name == "weight" and args.observer in ("static_minmax", "minmax"):
observer = "memoryless_minmax"
logger.warning(
"Overriding weight observer for lower memory usage "
f"({args.observer} -> {observer})"
)
if base_name == "weight" and args.observer in ("mse",):
observer = "memoryless_mse"
logger.warning(
"Overriding weight observer for lower memory usage "
f"({args.observer} -> {observer})"
)

if args is not None and args.dynamic is not True:
observer = Observer.load_from_registry(
args.observer, base_name=base_name, args=args, module=module
observer, base_name=base_name, args=args, module=module
)
module.register_module(f"{base_name}_observer", observer)

Expand Down