Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@
import math

def EPE(input_flow, target_flow):
return torch.norm(target_flow-input_flow,p=2,dim=1).mean()
return torch.norm(target_flow-input_flow[:,None],p=2,dim=1).mean()

class L1(nn.Module):
def __init__(self):
super(L1, self).__init__()
def forward(self, output, target):
lossvalue = torch.abs(output - target).mean()
lossvalue = torch.abs(output[:,None] - target).mean()
return lossvalue

class L2(nn.Module):
def __init__(self):
super(L2, self).__init__()
def forward(self, output, target):
lossvalue = torch.norm(output-target,p=2,dim=1).mean()
lossvalue = torch.norm(output[:,None]-target,p=2,dim=1).mean()
return lossvalue

class L1Loss(nn.Module):
Expand Down
10 changes: 5 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ def train(args, epoch, start_iteration, data_loader, model, optimizer, logger, i
losses = model(data[0], target[0])
losses = [torch.mean(loss_value) for loss_value in losses]
loss_val = losses[0] # Collect first loss for weight update
total_loss += loss_val.data[0]
loss_values = [v.data[0] for v in losses]
total_loss += loss_val.item()
loss_values = [v.item() for v in losses]

# gather loss_labels, direct return leads to recursion limit error as it looks for variables to gather'
loss_labels = list(model.module.loss.loss_labels)
Expand Down Expand Up @@ -366,10 +366,10 @@ def inference(args, epoch, data_loader, model, offset=0):
with torch.no_grad():
losses, output = model(data[0], target[0], inference=True)

losses = [torch.mean(loss_value) for loss_value in losses]
losses = [torch.mean(loss_value) for loss_value in losses]
Copy link

@stefan-sf-wu stefan-sf-wu Jan 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate.

loss_val = losses[0] # Collect first loss for weight update
total_loss += loss_val.data[0]
loss_values = [v.data[0] for v in losses]
total_loss += loss_val.item()
loss_values = [v.item() for v in losses]

# gather loss_labels, direct return leads to recursion limit error as it looks for variables to gather'
loss_labels = list(model.module.loss.loss_labels)
Expand Down