
Commit 31dd3b8

Merge branch 'main' into multiweight

# Conflicts:
#	torchvision/prototype/models/detection/ssdlite.py

2 parents: 6d96ed5 + 289fce2


56 files changed: +769, -248 lines

README.rst

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@ supported Python versions.
 +--------------------------+--------------------------+---------------------------------+
 | ``torch``                | ``torchvision``          | ``python``                      |
 +==========================+==========================+=================================+
-| ``main`` / ``nightly``   | ``main`` / ``nightly``   | ``>=3.7``, ``<=3.9``            |
+| ``main`` / ``nightly``   | ``main`` / ``nightly``   | ``>=3.7``, ``<=3.10``           |
++--------------------------+--------------------------+---------------------------------+
+| ``1.11.0``               | ``0.12.0``               | ``>=3.7``, ``<=3.10``           |
 +--------------------------+--------------------------+---------------------------------+
 | ``1.10.2``               | ``0.11.3``               | ``>=3.6``, ``<=3.9``            |
 +--------------------------+--------------------------+---------------------------------+

references/classification/transforms.py

Lines changed: 12 additions & 4 deletions
@@ -21,8 +21,14 @@ class RandomMixup(torch.nn.Module):
 
     def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
         super().__init__()
-        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
-        assert alpha > 0, "Alpha param can't be zero."
+
+        if num_classes < 1:
+            raise ValueError(
+                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
+            )
+
+        if alpha <= 0:
+            raise ValueError("Alpha param can't be zero.")
 
         self.num_classes = num_classes
         self.p = p
@@ -99,8 +105,10 @@ class RandomCutmix(torch.nn.Module):
 
     def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
         super().__init__()
-        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
-        assert alpha > 0, "Alpha param can't be zero."
+        if num_classes < 1:
+            raise ValueError("Please provide a valid positive value for the num_classes.")
+        if alpha <= 0:
+            raise ValueError("Alpha param can't be zero.")
 
         self.num_classes = num_classes
         self.p = p
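
The switch from `assert` to `ValueError` matters because asserts are stripped when Python runs with `-O`, so these argument checks would silently disappear in optimized runs. A minimal sketch of the difference (function names are illustrative, not part of the commit):

```python
def init_with_assert(num_classes: int) -> None:
    # Stripped entirely under `python -O`: invalid input slips through.
    assert num_classes > 0, "Please provide a valid positive value for the num_classes."


def init_with_raise(num_classes: int) -> None:
    # Fires regardless of interpreter flags, matching the new behavior.
    if num_classes < 1:
        raise ValueError(f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}")


try:
    init_with_raise(0)
except ValueError as err:
    print(err)  # raised even under `python -O`; the assert version would stay silent there
```
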

references/detection/coco_eval.py

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,8 @@
 
 class CocoEvaluator:
     def __init__(self, coco_gt, iou_types):
-        assert isinstance(iou_types, (list, tuple))
+        if not isinstance(iou_types, (list, tuple)):
+            raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
         coco_gt = copy.deepcopy(coco_gt)
         self.coco_gt = coco_gt

references/detection/coco_utils.py

Lines changed: 4 additions & 1 deletion
@@ -126,7 +126,10 @@ def _has_valid_annotation(anno):
             return True
         return False
 
-    assert isinstance(dataset, torchvision.datasets.CocoDetection)
+    if not isinstance(dataset, torchvision.datasets.CocoDetection):
+        raise TypeError(
+            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
+        )
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)

references/optical_flow/train.py

Lines changed: 79 additions & 43 deletions
@@ -60,16 +60,21 @@ def get_train_dataset(stage, dataset_root):
 
 
 @torch.no_grad()
-def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
+def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
     """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
 
     We process as many samples as possible with ddp, and process the rest on a single worker.
     """
     batch_size = batch_size or args.batch_size
+    device = torch.device(args.device)
 
     model.eval()
 
-    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    else:
+        sampler = torch.utils.data.SequentialSampler(val_dataset)
+
     val_loader = torch.utils.data.DataLoader(
         val_dataset,
         sampler=sampler,
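
This hunk is what lets evaluation run without `torchrun`: under DDP each rank gets a shard via `DistributedSampler`, while a plain single-process run just walks the dataset in order. A standalone sketch of the pattern, with `dataset` and `distributed` as stand-ins for the script's objects:

```python
import torch


def make_eval_loader(dataset, distributed: bool, batch_size: int):
    # DistributedSampler shards the dataset across ranks; SequentialSampler
    # simply iterates it in order for single-process runs.
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False, drop_last=True)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    return torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_size)
```
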
@@ -88,7 +93,7 @@ def inner_loop(blob):
         image1, image2, flow_gt = blob[:3]
         valid_flow_mask = None if len(blob) == 3 else blob[-1]
 
-        image1, image2 = image1.cuda(), image2.cuda()
+        image1, image2 = image1.to(device), image2.to(device)
 
         padder = utils.InputPadder(image1.shape, mode=padder_mode)
         image1, image2 = padder.pad(image1, image2)
@@ -115,21 +120,22 @@ def inner_loop(blob):
         inner_loop(blob)
         num_processed_samples += blob[0].shape[0]  # batch size
 
-    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
-    print(
-        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
-        "Going to process the remaining samples individually, if any."
-    )
+    if args.distributed:
+        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+        print(
+            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
+            "Going to process the remaining samples individually, if any."
+        )
+        if args.rank == 0:  # we only need to process the rest on a single worker
+            for i in range(num_processed_samples, len(val_dataset)):
+                inner_loop(val_dataset[i])
 
-    if args.rank == 0:  # we only need to process the rest on a single worker
-        for i in range(num_processed_samples, len(val_dataset)):
-            inner_loop(val_dataset[i])
+    logger.synchronize_between_processes()
 
-    logger.synchronize_between_processes()
     print(header, logger)
 
 
-def validate(model, args):
+def evaluate(model, args):
     val_datasets = args.val_dataset or []
 
     if args.prototype:
@@ -145,21 +151,21 @@ def validate(model, args):
         if name == "kitti":
             # Kitti has different image sizes so we need to individually pad them, we can't batch.
             # see comment in InputPadder
-            if args.batch_size != 1 and args.rank == 0:
+            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                 warnings.warn(
                     f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                 )
 
             val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
-            _validate(
+            _evaluate(
                 model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
             )
         elif name == "sintel":
             for pass_name in ("clean", "final"):
                 val_dataset = Sintel(
                     root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                 )
-                _validate(
+                _evaluate(
                     model,
                     args,
                     val_dataset,
@@ -172,11 +178,12 @@ def validate(model, args):
 
 
 def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
+    device = torch.device(args.device)
     for data_blob in logger.log_every(train_loader):
 
        optimizer.zero_grad()
 
-        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
+        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
         flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
 
         loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
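
Replacing `.cuda()` with `.to(device)` makes the training loop device-agnostic, so the new `--device cpu` flag works end to end. A toy version of the idiom (the fallback logic here is illustrative; the script takes the device from its CLI flag):

```python
import torch

# Pick CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(2, 3).to(device)  # works on either device, unlike x.cuda()
print(x.device)
```
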
@@ -200,36 +207,68 @@ def main(args):
         raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     utils.setup_ddp(args)
 
+    if args.distributed and args.device == "cpu":
+        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
+    device = torch.device(args.device)
+
     if args.prototype:
         model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
     else:
         model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
 
-    model = model.to(args.local_rank)
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+    if args.distributed:
+        model = model.to(args.local_rank)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+        model_without_ddp = model.module
+    else:
+        model.to(device)
+        model_without_ddp = model
 
     if args.resume is not None:
-        d = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(d, strict=True)
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])
 
     if args.train_dataset is None:
         # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
-        validate(model, args)
+        evaluate(model, args)
         return
 
     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
+    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
+
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=args.lr,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
+        pct_start=0.05,
+        cycle_momentum=False,
+        anneal_strategy="linear",
+    )
+
+    if args.resume is not None:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+    else:
+        args.start_epoch = 0
+
     torch.backends.cudnn.benchmark = True
 
     model.train()
     if args.freeze_batch_norm:
         utils.freeze_batch_norm(model.module)
 
-    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
+    else:
+        sampler = torch.utils.data.RandomSampler(train_dataset)
 
-    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         sampler=sampler,
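
`model_without_ddp` is the usual DDP bookkeeping idiom: `DistributedDataParallel` wraps the module and prefixes parameter names with `module.`, so saving and loading go through the bare module to keep `state_dict` keys identical in both modes. A minimal sketch, with `distributed` standing in for `args.distributed`:

```python
import torch

model = torch.nn.Linear(4, 2)
distributed = False  # stand-in for args.distributed

if distributed:
    model = torch.nn.parallel.DistributedDataParallel(model)
    model_without_ddp = model.module  # unwrap for checkpointing
else:
    model_without_ddp = model

# Keys carry no "module." prefix either way, so a checkpoint written on a
# DDP run loads cleanly on a single-device run and vice versa.
print(list(model_without_ddp.state_dict())[0])  # "weight"
```
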
@@ -238,25 +277,15 @@ def main(args):
         num_workers=args.num_workers,
     )
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        optimizer=optimizer,
-        max_lr=args.lr,
-        epochs=args.epochs,
-        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
-        pct_start=0.05,
-        cycle_momentum=False,
-        anneal_strategy="linear",
-    )
-
     logger = utils.MetricLogger()
 
     done = False
-    for current_epoch in range(args.epochs):
+    for current_epoch in range(args.start_epoch, args.epochs):
         print(f"EPOCH {current_epoch}")
+        if args.distributed:
+            # needed on distributed mode, otherwise the data loading order would be the same for all epochs
+            sampler.set_epoch(current_epoch)
 
-        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
         train_one_epoch(
             model=model,
             optimizer=optimizer,
@@ -269,13 +298,19 @@ def main(args):
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)
 
-        if args.rank == 0:
-            # TODO: Also save the optimizer and scheduler
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
+        if not args.distributed or args.rank == 0:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "epoch": current_epoch,
+                "args": args,
+            }
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
 
         if current_epoch % args.val_freq == 0 or done:
-            validate(model, args)
+            evaluate(model, args)
             model.train()
             if args.freeze_batch_norm:
                 utils.freeze_batch_norm(model.module)
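
Checkpoints now bundle model, optimizer, scheduler, and epoch instead of a bare `state_dict`, which is what makes the `--resume` path above restore full training state. A self-contained round trip using the same keys (the tiny model and hyper-parameters are made up); note that checkpoints written before this change were raw `state_dict`s and will not load through the new resume path:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-4, total_steps=100)

# Save everything needed to resume, as the new training loop does.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "epoch": 5,
}
torch.save(checkpoint, "ckpt.pth")

# Resume: restore each component, then continue from the next epoch.
resumed = torch.load("ckpt.pth", map_location="cpu")
model.load_state_dict(resumed["model"])
optimizer.load_state_dict(resumed["optimizer"])
scheduler.load_state_dict(resumed["scheduler"])
start_epoch = resumed["epoch"] + 1
```
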
@@ -349,6 +384,7 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
 
     return parser

references/optical_flow/transforms.py

Lines changed: 14 additions & 8 deletions
@@ -7,16 +7,21 @@ class ValidateModelInput(torch.nn.Module):
     # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
     def forward(self, img1, img2, flow, valid_flow_mask):
 
-        assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
-        assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)
+        if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
+            raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
+        if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
+            raise TypeError("This method expects the tensors img1, img2 and flow to be of dtype torch.float32.")
 
-        assert img1.shape == img2.shape
+        if img1.shape != img2.shape:
+            raise ValueError("img1 and img2 should have the same shape.")
         h, w = img1.shape[-2:]
-        if flow is not None:
-            assert flow.shape == (2, h, w)
+        if flow is not None and flow.shape != (2, h, w):
+            raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
         if valid_flow_mask is not None:
-            assert valid_flow_mask.shape == (h, w)
-            assert valid_flow_mask.dtype == torch.bool
+            if valid_flow_mask.shape != (h, w):
+                raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
+            if valid_flow_mask.dtype != torch.bool:
+                raise TypeError(f"valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")
 
         return img1, img2, flow, valid_flow_mask
 
@@ -109,7 +114,8 @@ class RandomErasing(T.RandomErasing):
     def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
         super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
         self.max_erase = max_erase
-        assert self.max_erase > 0
+        if self.max_erase <= 0:
+            raise ValueError("max_erase should be greater than 0")
 
     def forward(self, img1, img2, flow, valid_flow_mask):
         if torch.rand(1) > self.p:
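
Beyond surviving `python -O`, the rewritten checks give actionable messages where a bare `AssertionError` gave none. A toy reproduction of the flow-shape check (tensor sizes invented for the example):

```python
import torch

img1 = torch.rand(3, 240, 320, dtype=torch.float32)
flow = torch.rand(3, 240, 320)  # deliberately wrong: first dim should be 2

h, w = img1.shape[-2:]
try:
    if flow is not None and flow.shape != (2, h, w):
        raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
except ValueError as err:
    print(err)  # flow.shape should be (2, 240, 320) instead of torch.Size([3, 240, 320])
```
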

references/optical_flow/utils.py

Lines changed: 10 additions & 2 deletions
@@ -71,7 +71,10 @@ def update(self, **kwargs):
         for k, v in kwargs.items():
             if isinstance(v, torch.Tensor):
                 v = v.item()
-            assert isinstance(v, (float, int))
+            if not isinstance(v, (float, int)):
+                raise TypeError(
+                    f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
+                )
             self.meters[k].update(v)
 
     def __getattr__(self, attr):
@@ -256,7 +259,12 @@ def setup_ddp(args):
         # if we're here, the script was called by run_with_submitit.py
         args.local_rank = args.gpu
     else:
-        raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")
+        print("Not using distributed mode!")
+        args.distributed = False
+        args.world_size = 1
+        return
+
+    args.distributed = True
 
     _redefine_print(is_main=(args.rank == 0))
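
With this fallback, `setup_ddp` marks the run as non-distributed instead of raising, and every later `args.distributed` branch selects the single-process path. For context, `torchrun` advertises itself to workers through environment variables; a sketch of that style of detection (these may not be the exact variables the script inspects):

```python
import os


def launched_by_torchrun() -> bool:
    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every worker;
    # a plain `python train.py` launch has none of them.
    return all(key in os.environ for key in ("RANK", "WORLD_SIZE", "LOCAL_RANK"))


print(launched_by_torchrun())  # False outside torchrun
```
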

references/segmentation/coco_utils.py

Lines changed: 5 additions & 1 deletion
@@ -68,7 +68,11 @@ def _has_valid_annotation(anno):
         # if more than 1k pixels occupied in the image
         return sum(obj["area"] for obj in anno) > 1000
 
-    assert isinstance(dataset, torchvision.datasets.CocoDetection)
+    if not isinstance(dataset, torchvision.datasets.CocoDetection):
+        raise TypeError(
+            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
+        )
+
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)

references/segmentation/utils.py

Lines changed: 4 additions & 1 deletion
@@ -118,7 +118,10 @@ def update(self, **kwargs):
         for k, v in kwargs.items():
             if isinstance(v, torch.Tensor):
                 v = v.item()
-            assert isinstance(v, (float, int))
+            if not isinstance(v, (float, int)):
+                raise TypeError(
+                    f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
+                )
             self.meters[k].update(v)
 
     def __getattr__(self, attr):
