diff --git a/neural_compressor/experimental/common/criterion.py b/neural_compressor/experimental/common/criterion.py
index 5a6a99f2fda..4382e827225 100644
--- a/neural_compressor/experimental/common/criterion.py
+++ b/neural_compressor/experimental/common/criterion.py
@@ -82,6 +82,7 @@ def __getitem__(self, criterion_type):
 
         Args:
            criterion_type (string): criterion type.
+
        Returns:
            cls: criterion class.
        """
@@ -1249,8 +1250,21 @@ def __call__(self, **kwargs):
 
 
 class SelfKnowledgeDistillationLoss(KnowledgeDistillationFramework):
-    def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temperature=1.0,
-                 add_origin_loss=False, student_model=None, teacher_model=None):
+    """SelfKnowledge Distillation Loss."""
+
+    def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temperature=1.0, add_origin_loss=False, student_model=None, teacher_model=None):
+        """Initialize SelfKnowledge Distillation Loss class.
+
+        Args:
+            layer_mappings (list): layers of distillation. Format like
+                [[[student1_layer_name1, teacher_layer_name1], [student2_layer_name1, teacher_layer_name1]], [[student1_layer_name2, teacher_layer_name2], [student2_layer_name2, teacher_layer_name2]]]
+            loss_types (list, optional): loss types. Defaults to ['CE'] * len(layer_mappings).
+            loss_weights (list, optional): loss weights. Defaults to [1.0 / len(layer_mappings)] * len(layer_mappings).
+            temperature (float, optional): used to calculate the soft label CE. Defaults to 1.0.
+            add_origin_loss (bool, optional): whether to add origin loss for hard label loss.
+            student_model (optional): student model. Defaults to None.
+            teacher_model (optional): teacher model. Defaults to None.
+        """
         super(SelfKnowledgeDistillationLoss, self).__init__(student_model=student_model,
                                                             teacher_model=teacher_model)
         self.temperature = temperature
@@ -1277,32 +1291,70 @@ def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temper
         )
 
     def init_loss_funcs(self):
+        """Initialize loss functions.
+
+        Raises:
+            NotImplementedError: this method should be implemented by a framework-specific subclass.
+        """
         raise NotImplementedError('Function init_loss_funcs '
                                   'should be framework related.')
 
     def teacher_model_forward(self, input, teacher_model=None):
+        """Run the teacher model forward.
+
+        Raises:
+            NotImplementedError: this method should be implemented by a framework-specific subclass.
+        """
         raise NotImplementedError('Function teacher_model_forward '
                                   'should be framework related.')
 
     def loss_cal(self, student_outputs):
+        """Calculate loss.
+
+        Raises:
+            NotImplementedError: this method should be implemented by a framework-specific subclass.
+        """
         raise NotImplementedError(
             'Function loss_cal should be framework related.')
 
     def loss_cal_sloss(self, student_outputs, teacher_outputs, student_loss):
+        """Calculate all losses between student model and teacher model.
+
+        Args:
+            student_outputs (dict): student outputs
+            teacher_outputs (dict): teacher outputs
+            student_loss (tensor): student loss
+
+        Returns:
+            tensor: loss
+        """
         loss = self.loss_cal(student_outputs)
         if self.add_origin_loss:
             loss += student_loss
         return loss
 
     def __call__(self, student_outputs, targets):
+        """Return 0."""
         return 0
 
 
 class PyTorchSelfKnowledgeDistillationLoss(
     SelfKnowledgeDistillationLoss
 ):
-    def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temperature=1.0,
-                 add_origin_loss=False, student_model=None, teacher_model=None):
+    """PyTorch SelfKnowledge Distillation Loss."""
+    def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temperature=1.0, add_origin_loss=False, student_model=None, teacher_model=None):
+        """Initialize PyTorch SelfKnowledge Distillation Loss class.
+
+        Args:
+            layer_mappings (list): layers of distillation. Format like
+                [[[student1_layer_name1, teacher_layer_name1], [student2_layer_name1, teacher_layer_name1]], [[student1_layer_name2, teacher_layer_name2], [student2_layer_name2, teacher_layer_name2]]]
+            loss_types (list, optional): loss types. Defaults to ['CE'] * len(layer_mappings).
+            loss_weights (list, optional): loss weights. Defaults to [1.0 / len(layer_mappings)] * len(layer_mappings).
+            temperature (float, optional): used to calculate the soft label CE. Defaults to 1.0.
+            add_origin_loss (bool, optional): whether to add origin loss for hard label loss.
+            student_model (optional): student model. Defaults to None.
+            teacher_model (optional): teacher model. Defaults to None.
+        """
         super(PyTorchSelfKnowledgeDistillationLoss, self).__init__(
             layer_mappings=layer_mappings,
             loss_types=loss_types,
@@ -1313,19 +1365,47 @@ def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temper
             teacher_model=teacher_model)
 
     def SoftCrossEntropy(self, logits, targets):
+        """Return SoftCrossEntropy.
+
+        Args:
+            logits (tensor): output logits
+            targets (tensor): target logits used as soft labels
+
+        Returns:
+            tensor: SoftCrossEntropy
+        """
         log_prob = torch.nn.functional.log_softmax(logits, dim=-1)
         targets_prob = torch.nn.functional.softmax(targets, dim=-1)
         return (-targets_prob * log_prob).sum(dim=-1).mean()
 
     def KullbackLeiblerDivergence(self, logits, targets):
+        """Return KullbackLeiblerDivergence.
+
+        Args:
+            logits (tensor): output logits
+            targets (tensor): target logits used as soft labels
+
+        Returns:
+            tensor: KullbackLeiblerDivergence
+        """
         log_prob = torch.nn.functional.log_softmax(logits, dim=-1)
         targets_prob = torch.nn.functional.softmax(targets, dim=-1)
         return torch.nn.functional.kl_div(log_prob, targets_prob)
 
     def L2Divergence(self, feature1, feature2):
+        """Return L2Divergence.
+
+        Args:
+            feature1 (tensor): feature1 value
+            feature2 (tensor): feature2 value
+
+        Returns:
+            tensor: L2Divergence between feature1 and feature2
+        """
         return torch.dist(feature1, feature2)
 
     def init_loss_funcs(self):
+        """Initialize loss functions."""
         for loss_type in self.loss_types:
             if loss_type == 'CE':
                 loss_func = self.SoftCrossEntropy
@@ -1340,6 +1420,14 @@ def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temper
             self.loss_funcs.append(loss_func)
 
     def loss_cal(self, student_outputs):
+        """Calculate the loss of the student model.
+
+        Args:
+            student_outputs (dict): student outputs
+
+        Returns:
+            tensor: loss
+        """
         self.loss = torch.FloatTensor([0.])
         tmp_loss = 0
         temperature = self.temperature
@@ -1363,7 +1451,13 @@ def loss_cal(self, student_outputs):
 
 @criterion_registry('SelfKnowledgeDistillationLoss', 'pytorch')
 class PyTorchSelfKnowledgeDistillationLossWrapper(object):
+    """PyTorch SelfKnowledge Distillation Loss Wrapper."""
     def __init__(self, param_dict):
+        """Initialize PyTorchSelfKnowledgeDistillationLossWrapper class.
+
+        Args:
+            param_dict (dict): param dict
+        """
         self.param_dict = param_dict
 
     def _param_check(self):
@@ -1412,4 +1506,10 @@ def _param_check(self):
         return new_dict
 
     def __call__(self, **kwargs):
+        """Return PyTorchSelfKnowledgeDistillationLoss and the param dict.
+
+        Returns:
+            class: PyTorchSelfKnowledgeDistillationLoss
+            dict: param dict
+        """
         return PyTorchSelfKnowledgeDistillationLoss, self._param_check()
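For reviewers parsing the `layer_mappings` docstring added above, here is what that nested-list format looks like with concrete values. This is a hedged illustration only: the layer names, loss types, and weights below are hypothetical and not taken from any shipped example.

```python
# Each outer entry describes one distillation target; each inner pair is
# [student_layer_name, teacher_layer_name]. In self-distillation the "teacher"
# layers belong to the same network, so the two shallow heads below both
# distill from the deepest head. All layer names are made up for illustration.
layer_mappings = [
    # entry 1: match the logits of two auxiliary heads to the deepest head
    [['aux_head1.fc', 'deepest_head.fc'],
     ['aux_head2.fc', 'deepest_head.fc']],
    # entry 2: match the intermediate feature maps of the same blocks
    [['aux_head1.feature', 'deepest_head.feature'],
     ['aux_head2.feature', 'deepest_head.feature']],
]

# One loss type and one weight per outer entry, mirroring the documented
# defaults of ['CE'] * len(layer_mappings) and equal weights.
loss_types = ['CE', 'L2']
loss_weights = [0.5, 0.5]
```

These three lists, together with `temperature` and `add_origin_loss`, are exactly the arguments the new `__init__` docstrings describe.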
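The three loss terms documented in `SoftCrossEntropy`, `KullbackLeiblerDivergence`, and `L2Divergence` can also be sanity-checked in isolation. Below is a minimal standalone sketch in plain PyTorch with made-up tensor shapes; it only reproduces the math of each term, not the way `loss_cal` wires them together (temperature scaling of the logits, per the `temperature` docstring, is handled there and omitted here).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(4, 10)   # hypothetical batch of 4 samples, 10 classes
teacher_logits = torch.randn(4, 10)

# Soft label cross entropy: softmax the teacher logits into soft labels,
# then take the expected negative log-likelihood under the student.
log_prob = F.log_softmax(student_logits, dim=-1)
targets_prob = F.softmax(teacher_logits, dim=-1)
soft_ce = (-targets_prob * log_prob).sum(dim=-1).mean()

# Kullback-Leibler divergence between the same two distributions.
kl = F.kl_div(log_prob, targets_prob)

# L2 distance between two intermediate feature tensors.
feature1, feature2 = torch.randn(4, 32), torch.randn(4, 32)
l2 = torch.dist(feature1, feature2)

print(soft_ce.item(), kl.item(), l2.item())
```

Note that `F.kl_div` expects log-probabilities as its first argument, which is why `log_softmax` is applied to the student side, matching the implementation shown in the diff.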