import numpy as np
import torch


# TODO:
# parameterize stuff
class Augment:
    """Applies the augmentations enabled via args to 4-D batches laid out as
    (b, d, n, m); dim 2 is treated as the spatial axis and dim 3 as the
    temporal one. Occlusion, uniform noise and the zero-mean helpers only
    modify the first d-channel."""

    def __init__(self, args):
        self.args = args
        self.la = []  # enabled augmentation callables, applied in registration order
        print('augmentations start')
        if args.augment_occlude_spatial_probability > 0:
            print('occlude_spatial_probability', args.augment_occlude_spatial_probability,
                  'occlude_spatial_scale', args.augment_occlude_spatial_scale)
            self.la.append(self.occlude_spatial)
        if args.augment_occlude_temporal_probability > 0:
            print('occlude_temporal_probability', args.augment_occlude_temporal_probability,
                  'occlude_temporal_scale', args.augment_occlude_temporal_scale)
            self.la.append(self.occlude_temporal)
        if args.augment_swap_spatial_k > 1:
            print('swap_spatial_k', args.augment_swap_spatial_k)
            self.la.append(self.swap_spatial)
        if args.augment_swap_temporal_k > 1:
            print('swap_temporal_k', args.augment_swap_temporal_k)
            self.la.append(self.swap_temporal)
        if args.augment_scramble_spatial_probability > 0:
            print('scramble_spatial_probability', args.augment_scramble_spatial_probability)
            self.la.append(self.scramble_spatial)
        if args.augment_scramble_temporal_probability > 0:
            print('scramble_temporal_probability', args.augment_scramble_temporal_probability)
            self.la.append(self.scramble_temporal)
        if args.augment_uniform_noise_scale > 0:
            print('uniform_noise_scale', args.augment_uniform_noise_scale)
            self.la.append(self.uniform_noise)
        print('augmentations end')

    def augment(self, x):
        # Run every enabled augmentation in registration order; gradients are
        # not tracked since this only perturbs the input data.
        with torch.no_grad():
            for ia in self.la:
                x = ia(x)
        return x
    def occlude_spatial(self, x):  # bdnm
        # Randomly pick spatial positions per batch element and scale channel 0
        # there by a random factor in [0, augment_occlude_spatial_scale).
        b, d, n, m = x.size()
        # Build the mask on the same device as x so boolean indexing also works on GPU.
        mask = torch.rand(b, 1, n, 1, device=x.device)
        mask = mask < self.args.augment_occlude_spatial_probability
        mask = mask.expand(-1, d, -1, m).clone()
        mask[:, 1:, :, :] = False  # only the first channel is occluded
        tensor_random = torch.rand_like(mask.float())[mask]
        x[mask] *= tensor_random * self.args.augment_occlude_spatial_scale
        return x

    def occlude_temporal(self, x):  # bdnm
        # Same as occlude_spatial, but the mask is drawn over the temporal axis.
        b, d, n, m = x.size()
        mask = torch.rand(b, 1, 1, m, device=x.device)
        mask = mask < self.args.augment_occlude_temporal_probability
        mask = mask.expand(-1, d, n, -1).clone()
        mask[:, 1:, :, :] = False  # only the first channel is occluded
        tensor_random = torch.rand_like(mask.float())[mask]
        x[mask] *= tensor_random * self.args.augment_occlude_temporal_scale
        return x
    def swap_spatial(self, x):  # bdnm
        # For each batch element, pick k spatial indices without replacement
        # and permute them among themselves (all channels and frames move together).
        k = self.args.augment_swap_spatial_k
        for ib in range(x.size(0)):
            i1 = np.random.choice(np.arange(x.size(2)), k, replace=False).astype(int)
            i2 = np.random.permutation(i1)
            x[ib, :, i1, :] = x[ib, :, i2, :]
        return x

    def swap_temporal(self, x):  # bdnm
        # Same as swap_spatial, but k temporal indices are permuted.
        k = self.args.augment_swap_temporal_k
        for ib in range(x.size(0)):
            i1 = np.random.choice(np.arange(x.size(3)), k, replace=False).astype(int)
            i2 = np.random.permutation(i1)
            x[ib, :, :, i1] = x[ib, :, :, i2]
        return x
    def scramble_spatial(self, x):  # bdnm
        # With some probability per spatial position, shuffle that position's
        # values along the temporal axis.
        for ib in range(x.size(0)):
            for inn in range(x.size(2)):
                if torch.rand(1) < self.args.augment_scramble_spatial_probability:
                    x[ib, :, inn, :] = x[ib, :, inn, np.random.permutation(np.arange(x.size(3)))]
        return x

    def scramble_temporal(self, x):  # bdnm
        # With some probability per frame, shuffle that frame's values along
        # the spatial axis (note: uses the temporal probability, not the spatial one).
        for ib in range(x.size(0)):
            for il in range(x.size(3)):
                if torch.rand(1) < self.args.augment_scramble_temporal_probability:
                    x[ib, :, :, il] = x[ib, :, np.random.permutation(np.arange(x.size(2))), il]
        return x
    def uniform_noise(self, x):  # bdnm
        # Add zero-centred uniform noise, scaled by augment_uniform_noise_scale,
        # to the first channel only.
        rand = (torch.rand_like(x) - 0.5) * self.args.augment_uniform_noise_scale
        rand[:, 1:, :, :] = 0
        return x + rand

    def datapoint_zero_mean(self, x):  # bdnm
        # Subtract the per-sample mean of channel 0.
        mu = x[:, 0:1, :, :].mean(dim=[1, 2, 3])
        x[:, 0:1, :, :] = x[:, 0:1, :, :] - mu.view(-1, 1, 1, 1)
        return x

    def temporal_zero_mean(self, x):  # bdnm
        # Subtract the per-frame mean of channel 0 (mean taken over the spatial axis).
        mu = x[:, 0:1, :, :].mean(dim=[1, 2])
        x[:, 0:1, :, :] = x[:, 0:1, :, :] - mu.view(mu.size(0), 1, 1, mu.size(1))
        return x

    def spatial_zero_mean(self, x):  # bdnm
        # Subtract the per-position mean of channel 0 (mean taken over the temporal axis).
        mu = x[:, 0:1, :, :].mean(dim=[1, 3])
        x[:, 0:1, :, :] = x[:, 0:1, :, :] - mu.view(mu.size(0), 1, mu.size(1), 1)
        return x
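

# Minimal usage sketch: it assumes args is any object exposing the augment_*
# attributes read in __init__, so an argparse.Namespace with illustrative
# (not prescribed) values stands in for the real argument parser.
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(
        augment_occlude_spatial_probability=0.1,
        augment_occlude_spatial_scale=0.5,
        augment_occlude_temporal_probability=0.1,
        augment_occlude_temporal_scale=0.5,
        augment_swap_spatial_k=2,
        augment_swap_temporal_k=2,
        augment_scramble_spatial_probability=0.05,
        augment_scramble_temporal_probability=0.05,
        augment_uniform_noise_scale=0.01,
    )
    aug = Augment(args)
    # Dummy batch with the (b, d, n, m) layout the class expects; values are arbitrary.
    x = torch.rand(4, 3, 25, 64)
    y = aug.augment(x)
    print('augmented batch:', tuple(y.size()))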