From f6266f0c2ee2b62318cca3083708219f2cf97255 Mon Sep 17 00:00:00 2001 From: bgawrych Date: Mon, 21 Feb 2022 10:59:22 +0100 Subject: [PATCH] Reduce after quantization memory usage (#20894) --- python/mxnet/contrib/quantization.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 4ad354a7d2e2..10d2455cb9ae 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -921,6 +921,9 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize if calib_mode in ['naive', 'entropy', 'custom']: inputs = [mx.sym.var(desc.name) for desc in data_descs] calib_net = SymbolBlock(symnet, inputs) + for k, v in calib_net.collect_params().items(): + v.grad_req = 'null' + calib_net.load_dict(params, cast_dtype=True, dtype_source='saved') calib_net.hybridize(static_alloc=False, static_shape=False) num_batches = _collect_layer_statistics(calib_net, calib_data, collector, num_inputs, @@ -939,6 +942,9 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize inputs = [mx.sym.var(desc.name) for desc in data_descs] net = SymbolBlock(qsym, inputs) + for k, v in net.collect_params().items(): + v.grad_req = 'null' + all_params = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in qarg_params.items()} all_params.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) net.load_dict(all_params, cast_dtype=True, dtype_source='saved')