
Quantize QuestionAnswering models #1581

Open: wants to merge 10 commits into master
Changes from 2 commits
19 changes: 15 additions & 4 deletions scripts/question_answering/models.py
@@ -180,6 +180,8 @@ def __init__(self, backbone, units=768, layer_norm_eps=1E-12, dropout_prob=0.1,
        self.answerable_scores.add(nn.Dense(2, flatten=False,
                                            weight_initializer=weight_initializer,
                                            bias_initializer=bias_initializer))
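        # Populated by quantize_and_calibrate() in run_squad.py; when `quantized`
        # is True, forward()/inference() route through the int8 backbone below.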
        self.quantized_backbone = None
        self.quantized = False

    def get_start_logits(self, contextual_embedding, p_mask):
        """
@@ -287,10 +289,14 @@ def forward(self, tokens, token_types, valid_length, p_mask, start_position):
            Shape (batch_size, sequence_length)
        answerable_logits
        """
        backbone_net = self.backbone
        if self.quantized:

Reviewer:

Suggested change:
-        if self.quantized:
+        if self.quantized_backbone is not None:

And remove the `quantized` flag?

Contributor Author:

I thought about it, but I kept the `quantized` flag as an on/off switch for the
quantized model. I am not sure whether that is really useful. What do you think?

Reviewer:

It does not seem useful to me.
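For context, a minimal sketch of the two variants under discussion (hypothetical
usage of a `net` instance, not part of this diff):

    # Variant kept in this PR: an explicit switch, so a calibrated model can
    # still be forced back onto the fp32 backbone by flipping the flag.
    net.quantized = True                  # use net.quantized_backbone
    net.quantized = False                 # fall back to the fp32 backbone

    # Suggested alternative: infer the choice from the presence of the
    # quantized backbone and drop the extra flag entirely.
    backbone_net = (net.quantized_backbone
                    if net.quantized_backbone is not None else net.backbone)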

            backbone_net = self.quantized_backbone

        if self.use_segmentation:
-            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+            contextual_embeddings = backbone_net(tokens, token_types, valid_length)
        else:
-            contextual_embeddings = self.backbone(tokens, valid_length)
+            contextual_embeddings = backbone_net(tokens, valid_length)
        start_logits = self.get_start_logits(contextual_embeddings, p_mask)
        end_logits = self.get_end_logits(contextual_embeddings,
                                         np.expand_dims(start_position, axis=1),
@@ -337,11 +343,16 @@ def inference(self, tokens, token_types, valid_length, p_mask,
            The answerable logits. Here 0 --> answerable and 1 --> not answerable.
            Shape (batch_size, sequence_length, 2)
        """
        backbone_net = self.backbone
        if self.quantized:
            backbone_net = self.quantized_backbone

        # Shape (batch_size, sequence_length, C)
        if self.use_segmentation:
-            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+            contextual_embeddings = backbone_net(tokens, token_types, valid_length)
        else:
-            contextual_embeddings = self.backbone(tokens, valid_length)
+            contextual_embeddings = backbone_net(tokens, valid_length)

        start_logits = self.get_start_logits(contextual_embeddings, p_mask)
        # The shape of start_top_index will be (..., start_top_n)
        start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
80 changes: 77 additions & 3 deletions scripts/question_answering/run_squad.py
@@ -147,9 +147,9 @@ def parse_args():
    parser.add_argument('--max_saved_ckpt', type=int, default=5,
                        help='The maximum number of saved checkpoints')
    parser.add_argument('--dtype', type=str, default='float32',
-                       help='Data type used for evaluation. Either float32 or float16. When you '
+                       help='Data type used for evaluation. Either float32, float16 or int8. When you '
                             'use --dtype float16, amp will be turned on in the training phase and '
-                            'fp16 will be used in evaluation.')
+                            'fp16 will be used in evaluation. For now, the int8 data type is '
+                            'supported on CPU only.')
    args = parser.parse_args()
    return args

@@ -815,6 +815,71 @@ def predict_extended(original_feature,
    assert len(nbest_json) >= 1
    return not_answerable_score, nbest[0][0], nbest_json

def quantize_and_calibrate(net, dataloader):
    class QuantizationDataLoader(mx.gluon.data.DataLoader):
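        """Re-emit each batch as the positional list of arrays that
        quantize_net passes to the backbone during calibration."""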
        def __init__(self, dataloader, use_segmentation):
            self._dataloader = dataloader
            self._iter = None
            self._use_segmentation = use_segmentation

        def __iter__(self):
            self._iter = iter(self._dataloader)
            return self

        def __next__(self):
            batch = next(self._iter)
            if self._use_segmentation:
                return [batch.data, batch.segment_ids, batch.valid_length]
            else:
                return [batch.data, batch.valid_length]

        def __del__(self):
            del self._dataloader

    class BertLayerCollector(mx.contrib.quantization.CalibrationCollector):
        """Saves layer output min and max values in a dict with layer names as keys.
        The collected min and max values will be directly used as thresholds for quantization.
        """
        def __init__(self, clip_min, clip_max):
            super(BertLayerCollector, self).__init__()
            self.clip_min = clip_min
            self.clip_max = clip_max

        def collect(self, name, op_name, arr):
            """Callback function for collecting min and max values from an NDArray."""
            if name not in self.include_layers:
                return
            arr = arr.copyto(mx.cpu()).asnumpy()
            min_range = np.min(arr)
            max_range = np.max(arr)

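            # Clip calibration outliers: cap the max of residual-copy/LayerNorm
            # outputs and raise the min of Dropout outputs so that rare extreme
            # activations do not stretch the int8 quantization range.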
if (op_name.find("npi_copy") != -1 or op_name.find("LayerNorm") != -1) and max_range > self.clip_max:
max_range = self.clip_max

if op_name.find('Dropout') != -1 and min_range < self.clip_min:
print(name, op_name)
min_range = self.clip_min

if name in self.min_max_dict:
cur_min_max = self.min_max_dict[name]
self.min_max_dict[name] = (min(cur_min_max[0], min_range),
max(cur_min_max[1], max_range))
else:
self.min_max_dict[name] = (min_range, max_range)

    calib_data = QuantizationDataLoader(dataloader, net.use_segmentation)
    net.quantized_backbone = mx.contrib.quant.quantize_net(net.backbone, quantized_dtype='auto',
                                                           exclude_layers=None,
                                                           exclude_layers_match=None,
                                                           calib_data=calib_data,
                                                           calib_mode='custom',
                                                           LayerOutputCollector=BertLayerCollector(clip_min=-50, clip_max=10),
                                                           num_calib_batches=10,
                                                           ctx=mx.current_context(),
                                                           logger=logging.getLogger())
    net.quantized = True
    return net
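
A minimal usage sketch (assumptions: `qa_net` and `dev_dataloader` are built as in
evaluate() below, and calibration consumes the first 10 batches of the loader):

    qa_net = quantize_and_calibrate(qa_net, dev_dataloader)
    # forward()/inference() now route through qa_net.quantized_backbone
    # because qa_net.quantized is True; call sites need no changes.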


def evaluate(args, last=True):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
@@ -828,9 +893,10 @@ def evaluate(args, last=True):
    logging.info(
        'Starting inference without horovod on the first node on device {}'.format(
            str(ctx_l)))
    network_dtype = args.dtype if args.dtype != 'int8' else 'float32'

    cfg, tokenizer, qa_net, use_segmentation = get_network(
-        args.model_name, ctx_l, args.classifier_dropout, dtype=args.dtype)
+        args.model_name, ctx_l, args.classifier_dropout, dtype=network_dtype)
    if args.dtype == 'float16':
        qa_net.cast('float16')
    qa_net.hybridize()
@@ -860,6 +926,9 @@ def eval_validation(ckpt_name, best_eval):
                                   num_workers=0,
                                   shuffle=False)

    if args.dtype == 'int8':
        quantize_and_calibrate(qa_net, dev_dataloader)

    log_interval = args.eval_log_interval
    all_results = []
    epoch_tic = time.time()
@@ -999,6 +1068,11 @@ def eval_validation(ckpt_name, best_eval):
if __name__ == '__main__':
    os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
    args = parse_args()
    if args.dtype == 'int8':
        ctx_l = parse_ctx(args.gpus)
        if ctx_l[0] != mx.cpu() or len(ctx_l) != 1:
            raise ValueError("Evaluation on int8 data type is supported only for CPU for now")

    if args.do_train:
        if args.dtype == 'float16':
            # Initialize amp if it's fp16 training
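
For reference, int8 evaluation would then be launched roughly like this (a sketch:
the --model_name value and --do_eval follow the script's existing interface, and it
assumes parse_ctx falls back to CPU when --gpus is not given; only --dtype int8 and
CPU-only execution are prescribed by this diff):

    python3 run_squad.py --model_name google_en_uncased_bert_base \
        --do_eval --dtype int8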