108 changes: 104 additions & 4 deletions neural_compressor/experimental/common/criterion.py
@@ -82,6 +82,7 @@ def __getitem__(self, criterion_type):

Args:
criterion_type (string): criterion type.

Returns:
cls: criterion class.
"""
@@ -1249,8 +1250,21 @@ def __call__(self, **kwargs):


class SelfKnowledgeDistillationLoss(KnowledgeDistillationFramework):
def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temperature=1.0,
add_origin_loss=False, student_model=None, teacher_model=None):
"""SelfKnowledge Distillation Loss."""

def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temperature=1.0, add_origin_loss=False, student_model=None, teacher_model=None):
"""Initialize SelfKnowledge Distillation Loss class.

Args:
layer_mappings (list): layers of distillation. Format like
[[[student1_layer_name1, teacher_layer_name1], [student2_layer_name1, teacher_layer_name1]],
[[student1_layer_name2, teacher_layer_name2], [student2_layer_name2, teacher_layer_name2]]]
(see the concrete example after this class below).
loss_types (list, optional): loss types. Defaults to ['CE'] * len(layer_mappings).
loss_weights (list, optional): loss weights. Defaults to [1.0 / len(layer_mappings)] * len(layer_mappings).
temperature (float, optional): temperature used to calculate the soft label CE. Defaults to 1.0.
add_origin_loss (bool, optional): whether to add the origin (hard label) loss of the student model. Defaults to False.
student_model (optional): student model. Defaults to None.
teacher_model (optional): teacher model. Defaults to None.
"""
super(SelfKnowledgeDistillationLoss, self).__init__(student_model=student_model,
teacher_model=teacher_model)
self.temperature = temperature
@@ -1277,32 +1291,70 @@ def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temper
)

def init_loss_funcs(self):
"""Init loss funcs.

Raises:
NotImplementedError: NotImplementedError
"""
raise NotImplementedError('Function init_loss_funcs '
'should be framework related.')

def teacher_model_forward(self, input, teacher_model=None):
"""Teacher model forward.

Raises:
NotImplementedError: teacher_model_forward should be implemented by a framework-specific subclass.
"""
raise NotImplementedError('Function teacher_model_forward '
'should be framework related.')

def loss_cal(self, student_outputs):
"""Calculate loss.

Raises:
NotImplementedError: loss_cal should be implemented by a framework-specific subclass.
"""
raise NotImplementedError(
'Function loss_cal should be framework related.')

def loss_cal_sloss(self, student_outputs, teacher_outputs, student_loss):
"""Calculate all losses between student model and teacher model.

Args:
student_outputs (dict): student outputs
teacher_outputs (dict): teacher outputs
student_loss (tensor): student loss

Returns:
tensor: loss
"""
loss = self.loss_cal(student_outputs)
if self.add_origin_loss:
loss += student_loss
return loss

def __call__(self, student_outputs, targets):
"""Return 0."""
return 0
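For readers of this diff: the nested layer_mappings format documented in the constructor above is easiest to see with a concrete value. The sketch below is a hypothetical configuration (the layer names and numbers are placeholders, not taken from this PR), using the PyTorch subclass defined further down:

# Hypothetical self-distillation setup: two shallow blocks each distilled from
# the deepest block, once on intermediate features and once on logits.
layer_mappings = [
    [['block1_feature', 'deepest_feature'],
     ['block2_feature', 'deepest_feature']],
    [['block1_logits', 'deepest_logits'],
     ['block2_logits', 'deepest_logits']],
]
loss_types = ['L2', 'KL']        # one loss type per outer entry of layer_mappings
loss_weights = [0.5, 0.5]        # one weight per outer entry of layer_mappings

criterion = PyTorchSelfKnowledgeDistillationLoss(
    layer_mappings=layer_mappings,
    loss_types=loss_types,
    loss_weights=loss_weights,
    temperature=3.0,
    add_origin_loss=True,
)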


class PyTorchSelfKnowledgeDistillationLoss(
SelfKnowledgeDistillationLoss
):
def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temperature=1.0,
add_origin_loss=False, student_model=None, teacher_model=None):
"""PyTorch SelfKnowledge Distillation Loss."""
def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temperature=1.0, add_origin_loss=False, student_model=None, teacher_model=None):
"""Initialize PyTorch SelfKnowledge Distillation Loss class.

Args:
layer_mappings (list): layers of distillation. Same format as in SelfKnowledgeDistillationLoss, e.g.
[[[student1_layer_name1, teacher_layer_name1], [student2_layer_name1, teacher_layer_name1]],
[[student1_layer_name2, teacher_layer_name2], [student2_layer_name2, teacher_layer_name2]]]
loss_types (list, optional): loss types. Defaults to ['CE'] * len(layer_mappings).
loss_weights (list, optional): loss weights. Defaults to [1.0 / len(layer_mappings)] * len(layer_mappings).
temperature (float, optional): temperature used to calculate the soft label CE. Defaults to 1.0.
add_origin_loss (bool, optional): whether to add the origin (hard label) loss of the student model. Defaults to False.
student_model (optional): student model. Defaults to None.
teacher_model (optional): teacher model. Defaults to None.
"""
super(PyTorchSelfKnowledgeDistillationLoss, self).__init__(
layer_mappings=layer_mappings,
loss_types=loss_types,
@@ -1313,19 +1365,47 @@ def __init__(self, layer_mappings=[], loss_types=None, loss_weights=None, temper
teacher_model=teacher_model)

def SoftCrossEntropy(self, logits, targets):
"""Return SoftCrossEntropy.

Args:
logits (tensor): output logits
targets (tensor): ground truth label

Returns:
tensor: SoftCrossEntropy
"""
log_prob = torch.nn.functional.log_softmax(logits, dim=-1)
targets_prob = torch.nn.functional.softmax(targets, dim=-1)
return (-targets_prob * log_prob).sum(dim=-1).mean()

def KullbackLeiblerDivergence(self, logits, targets):
"""Return KullbackLeiblerDivergence.

Args:
logits (tensor): output logits
targets (tensor): ground truth label

Returns:
tensor: KullbackLeiblerDivergence
"""
log_prob = torch.nn.functional.log_softmax(logits, dim=-1)
targets_prob = torch.nn.functional.softmax(targets, dim=-1)
return torch.nn.functional.kl_div(log_prob, targets_prob)

def L2Divergence(self, feature1, feature2):
"""Return L2Divergence.

Args:
feature1 (tensor): feature1 value
feature2 (tensor): feature2 value

Returns:
tensor: L2Divergence between feature1 and feature2
"""
return torch.dist(feature1, feature2)
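The three helpers above are the building blocks selected by init_loss_funcs below. As a quick, standalone illustration of what they compute (independent of the class; the tensor shapes and values are placeholders, and nothing beyond the formulas shown above is assumed):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)    # e.g. one branch's class logits for a batch of 8
targets = torch.randn(8, 10)   # another branch's logits, used as soft labels

# Soft cross entropy: -sum(softmax(targets) * log_softmax(logits)), averaged over the batch.
soft_ce = (-F.softmax(targets, dim=-1) * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

# KL divergence computed the same way as KullbackLeiblerDivergence above
# (note it relies on kl_div's default reduction).
kl = F.kl_div(F.log_softmax(logits, dim=-1), F.softmax(targets, dim=-1))

# L2 distance between two feature tensors, as in L2Divergence above.
l2 = torch.dist(torch.randn(8, 64), torch.randn(8, 64))

print(soft_ce.item(), kl.item(), l2.item())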

def init_loss_funcs(self):
"""Init loss funcs."""
for loss_type in self.loss_types:
if loss_type == 'CE':
loss_func = self.SoftCrossEntropy
@@ -1340,6 +1420,14 @@ def init_loss_funcs(self):
self.loss_funcs.append(loss_func)

def loss_cal(self, student_outputs):
"""Calculate loss of student model.

Args:
student_outputs (dict): student outputs

Returns:
tensor: loss
"""
self.loss = torch.FloatTensor([0.])
tmp_loss = 0
temperature = self.temperature
@@ -1363,7 +1451,13 @@ def loss_cal(self, student_outputs):

@criterion_registry('SelfKnowledgeDistillationLoss', 'pytorch')
class PyTorchSelfKnowledgeDistillationLossWrapper(object):
"""PyTorch SelfKnowledge Distillation Loss Wrapper."""
def __init__(self, param_dict):
"""Initialize PyTorchSelfKnowledgeDistillationLossWrapper class.

Args:
param_dict (dict): parameter dict for SelfKnowledgeDistillationLoss, typically taken from the distillation criterion configuration
"""
self.param_dict = param_dict

def _param_check(self):
@@ -1412,4 +1506,10 @@ def _param_check(self):
return new_dict

def __call__(self, **kwargs):
"""Return PyTorchSelfKnowledgeDistillationLoss, param dict.

Returns:
class: PyTorchSelfKnowledgeDistillationLoss
param dict (dict): param dict
"""
return PyTorchSelfKnowledgeDistillationLoss, self._param_check()
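End-to-end, the registry entry lets the distillation runner look up 'SelfKnowledgeDistillationLoss' for the 'pytorch' framework and receive this wrapper. A rough usage sketch follows; the param_dict keys are assumptions mirroring the constructor arguments (the authoritative key handling lives in _param_check, which is collapsed in this diff):

# Hypothetical param_dict; keys mirror the constructor arguments and are assumptions here.
param_dict = {
    'layer_mappings': [[['block1_logits', 'deepest_logits']]],   # placeholder layer names
    'loss_types': ['KL'],
    'loss_weights': [1.0],
    'temperature': 3.0,
    'add_origin_loss': True,
}
wrapper = PyTorchSelfKnowledgeDistillationLossWrapper(param_dict)
loss_cls, checked_params = wrapper()          # __call__ returns (class, checked param dict)
assert loss_cls is PyTorchSelfKnowledgeDistillationLoss
# Assuming the checked keys line up with the constructor arguments:
criterion = loss_cls(**checked_params)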