Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python3 shebang and whitespace cleanup #85

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,6 @@ If you have a RTX GPU, training can be accelerated using mixed precision. You ca
## (Optional) Efficient Implementation
You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension:
```Shell
cd alt_cuda_corr && python setup.py install && cd ..
cd alt_cuda_corr && python3 setup.py install --user && cd ..
```
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
17 changes: 7 additions & 10 deletions core/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid

try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass


class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
Expand All @@ -20,7 +14,7 @@ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):

batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)

self.corr_pyramid.append(corr)
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
Expand Down Expand Up @@ -53,15 +47,18 @@ def __call__(self, coords):
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)

corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())


class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
import alt_cuda_corr
self.alt_corr_fwd = alt_cuda_corr.forward

self.num_levels = num_levels
self.radius = radius

Expand All @@ -83,7 +80,7 @@ def __call__(self, coords):
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr, = self.alt_corr_fwd(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))

corr = torch.stack(corr_list, dim=1)
Expand Down
18 changes: 9 additions & 9 deletions core/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, args):
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3

else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
Expand All @@ -46,12 +46,12 @@ def __init__(self, args):

# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

else:
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

Expand Down Expand Up @@ -83,7 +83,7 @@ def upsample_flow(self, flow, mask):
return up_flow.reshape(N, 2, 8*H, 8*W)


def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
def forward(self, image1, image2, iters=torch.tensor(12), flow_init=torch.tensor([]), upsample=torch.tensor(True), test_mode=torch.tensor(False)):
""" Estimate optical flow between pair of frames """

image1 = 2 * (image1 / 255.0) - 1.0
Expand All @@ -97,8 +97,8 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_

# run the feature network
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1, fmap2 = self.fnet([image1, image2])

fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
Expand All @@ -115,7 +115,7 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_

coords0, coords1 = self.initialize_flow(image1)

if flow_init is not None:
if flow_init is not None and flow_init.numel()>0:
coords1 = coords1 + flow_init

flow_predictions = []
Expand All @@ -135,10 +135,10 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)

flow_predictions.append(flow_up)

if test_mode:
return coords1 - coords0, flow_up

return flow_predictions
8 changes: 5 additions & 3 deletions demo.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#! /usr/bin/env python3

import sys
sys.path.append('core')

Expand Down Expand Up @@ -26,7 +28,7 @@ def load_image(imfile):
def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()

# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)
Expand All @@ -50,7 +52,7 @@ def demo(args):
with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))

images = sorted(images)
for imfile1, imfile2 in zip(images[:-1], images[1:]):
image1 = load_image(imfile1)
Expand All @@ -59,7 +61,7 @@ def demo(args):
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)

flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
flow_low, flow_up = model(image1, image2, iters=torch.tensor(20), test_mode=torch.tensor(True))
viz(image1, flow_up)


Expand Down
8 changes: 5 additions & 3 deletions evaluate.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#! /usr/bin/env python3

import sys
sys.path.append('core')

Expand All @@ -24,13 +26,13 @@ def create_sintel_submission(model, iters=32, warm_start=False, output_path='sin
model.eval()
for dstype in ['clean', 'final']:
test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)

flow_prev, sequence_prev = None, None
for test_id in range(len(test_dataset)):
image1, image2, (sequence, frame) = test_dataset[test_id]
if sequence != sequence_prev:
flow_prev = None

padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())

Expand All @@ -39,7 +41,7 @@ def create_sintel_submission(model, iters=32, warm_start=False, output_path='sin

if warm_start:
flow_prev = forward_interpolate(flow_low[0])[None].cuda()

output_dir = os.path.join(output_path, dstype, sequence)
output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))

Expand Down
20 changes: 11 additions & 9 deletions train.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#! /usr/bin/env python3

from __future__ import print_function, division
import sys
sys.path.append('core')
Expand Down Expand Up @@ -47,7 +49,7 @@ def update(self):
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """

n_predictions = len(flow_preds)
n_predictions = len(flow_preds)
flow_loss = 0.0

# exclude invalid pixels and extremely large displacements
Expand Down Expand Up @@ -84,7 +86,7 @@ def fetch_optimizer(args, model):
pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

return optimizer, scheduler


class Logger:
def __init__(self, model, scheduler):
Expand All @@ -98,7 +100,7 @@ def _print_training_status(self):
metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)

# print the training status
print(training_str + metrics_str)

Expand Down Expand Up @@ -169,13 +171,13 @@ def train(args):
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)

flow_predictions = model(image1, image2, iters=args.iters)
flow_predictions = model(image1, image2, iters=args.iters)

loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

scaler.step(optimizer)
scheduler.step()
scaler.update()
Expand All @@ -196,11 +198,11 @@ def train(args):
results.update(evaluate.validate_kitti(model.module))

logger.write_dict(results)

model.train()
if args.stage != 'chairs':
model.module.freeze_bn()

total_steps += 1

if total_steps > args.num_steps:
Expand All @@ -217,7 +219,7 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='raft', help="name your experiment")
parser.add_argument('--stage', help="determines which dataset to use for training")
parser.add_argument('--stage', help="determines which dataset to use for training")
parser.add_argument('--restore_ckpt', help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--validation', type=str, nargs='+')
Expand Down