Skip to content

Commit ec39199

Browse files
author
Yuanjing Shi
authored
[PyTorch] [Relay] Add l1 and mse loss function for pytorch frontend (#11978)
* add l1 and mse loss function for pytorch frontend * fix CI
1 parent beea0d2 commit ec39199

File tree

3 files changed

+70
-3
lines changed

3 files changed

+70
-3
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,35 @@ def cross_entropy_loss_with_logits(self, inputs, input_types):
932932
assert weights is None, "weight not supported in cross_entropy_loss"
933933
return _op.nn.cross_entropy_with_logits(_op.nn.log_softmax(input), target)
934934

935+
def l1_loss(self, inputs, input_types):
936+
assert len(inputs) == 3
937+
[predictions, targets, reduction] = inputs
938+
delta = _op.abs(_op.subtract(predictions, targets))
939+
if reduction == 0:
940+
# reduction = "none"
941+
return delta
942+
elif reduction == 1:
943+
# reduction = "mean"
944+
return _op.mean(delta)
945+
else:
946+
# reduction = "sum"
947+
return _op.sum(delta)
948+
949+
def mse_loss(self, inputs, input_types):
950+
assert len(inputs) == 3
951+
[predictions, targets, reduction] = inputs
952+
delta = _op.subtract(predictions, targets)
953+
delta = _op.power(delta, _expr.const(2, input_types[0]))
954+
if reduction == 0:
955+
# reduction = "none"
956+
return delta
957+
elif reduction == 1:
958+
# reduction = "mean"
959+
return _op.mean(delta)
960+
else:
961+
# reduction = "sum"
962+
return _op.sum(delta)
963+
935964
def hard_sigmoid(self, inputs, input_types):
936965
def _relu6(x):
937966
return _op.tensor.clip(x, 0.0, 6.0)
@@ -3200,7 +3229,6 @@ def create_convert_map(self):
32003229
"aten::silu": self.silu,
32013230
"aten::glu": self.glu,
32023231
"aten::log_sigmoid": self.log_sigmoid,
3203-
"aten::cross_entropy_loss": self.cross_entropy_loss_with_logits,
32043232
"aten::adaptive_avg_pool1d": functools.partial(
32053233
self.adaptive_avg_pool, _op.nn.adaptive_avg_pool1d
32063234
),
@@ -3374,6 +3402,9 @@ def create_convert_map(self):
33743402
"aten::nll_loss": self.nll_loss,
33753403
"aten::nll_loss2d": self.nll_loss,
33763404
"aten::nll_loss_nd": self.nll_loss,
3405+
"aten::cross_entropy_loss": self.cross_entropy_loss_with_logits,
3406+
"aten::l1_loss": self.l1_loss,
3407+
"aten::mse_loss": self.mse_loss,
33773408
"aten::flip": self.flip,
33783409
"aten::gru": self.gru,
33793410
"aten::lstm": self.lstm,

python/tvm/topi/nn/softmax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ def log_softmax(x, axis=-1):
129129
Parameters
130130
----------
131131
data : tvm.te.Tensor
132-
2-D input data
132+
N-D input data
133133
134134
Returns
135135
-------
136136
output : tvm.te.Tensor
137-
2-D output with same shape
137+
N-D output with same shape
138138
"""
139139
shape = x.shape
140140
if axis < 0:

tests/python/frontend/pytorch/test_forward.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4177,6 +4177,42 @@ def test_cross_entropy_loss():
41774177
verify_model(torch.nn.CrossEntropyLoss().eval(), input_data=[predictions, targets])
41784178

41794179

4180+
def test_forward_l1_loss():
4181+
torch.set_grad_enabled(False)
4182+
N, C = 10, 3
4183+
predictions = torch.rand((N, C)).float()
4184+
targets = torch.rand((N, C)).float()
4185+
verify_model(torch.nn.L1Loss().eval(), input_data=[predictions, targets])
4186+
verify_model(torch.nn.L1Loss(reduction="sum").eval(), input_data=[predictions, targets])
4187+
verify_model(torch.nn.L1Loss(reduction="none").eval(), input_data=[predictions, targets])
4188+
4189+
# multidimension l1 loss
4190+
d1, d2 = 2, 3
4191+
predictions = torch.rand((N, C, d1, d2)).float()
4192+
targets = torch.rand((N, C, d1, d2)).float()
4193+
verify_model(torch.nn.L1Loss().eval(), input_data=[predictions, targets])
4194+
verify_model(torch.nn.L1Loss(reduction="sum").eval(), input_data=[predictions, targets])
4195+
verify_model(torch.nn.L1Loss(reduction="none").eval(), input_data=[predictions, targets])
4196+
4197+
4198+
def test_forward_mse_loss():
4199+
torch.set_grad_enabled(False)
4200+
N, C = 10, 3
4201+
predictions = torch.rand((N, C)).float()
4202+
targets = torch.rand((N, C)).float()
4203+
verify_model(torch.nn.MSELoss().eval(), input_data=[predictions, targets])
4204+
verify_model(torch.nn.MSELoss(reduction="sum").eval(), input_data=[predictions, targets])
4205+
verify_model(torch.nn.MSELoss(reduction="none").eval(), input_data=[predictions, targets])
4206+
4207+
# multidimension mse loss
4208+
d1, d2 = 2, 3
4209+
predictions = torch.rand((N, C, d1, d2)).float()
4210+
targets = torch.rand((N, C, d1, d2)).float()
4211+
verify_model(torch.nn.MSELoss().eval(), input_data=[predictions, targets])
4212+
verify_model(torch.nn.MSELoss(reduction="sum").eval(), input_data=[predictions, targets])
4213+
verify_model(torch.nn.MSELoss(reduction="none").eval(), input_data=[predictions, targets])
4214+
4215+
41804216
@tvm.testing.uses_gpu
41814217
def test_forward_flip():
41824218
torch.set_grad_enabled(False)

0 commit comments

Comments
 (0)