-
Notifications
You must be signed in to change notification settings - Fork 14
/
ray_utils.py
316 lines (260 loc) · 12.7 KB
/
ray_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import numpy as np
import torch
import collections
# Named container for the per-ray quantities listed in its fields.
Rays = collections.namedtuple(
    'Rays',
    ['origins', 'directions', 'viewdirs', 'radii', 'lossmult', 'near', 'far'],
)
def namedtuple_map(fn, tup):
    """Apply `fn` to every field of `tup` and rebuild the same namedtuple type.

    Args:
      fn: callable taking one field value and returning its replacement.
      tup: namedtuple instance whose fields are transformed.

    Returns:
      A new instance of `type(tup)` holding the transformed fields, in order.
    """
    tuple_cls = type(tup)
    mapped_fields = [fn(field) for field in tup]
    return tuple_cls(*mapped_fields)
def sorted_piecewise_constant_pdf(bins, weights, num_samples, randomized):
    """Draw samples from a piecewise-constant PDF defined over sorted bins.

    Args:
      bins: tensor [..., num_bins + 1], sorted interval endpoints.
      weights: tensor [..., num_bins], weight of each interval. The tensor is
        left unmodified.
      num_samples: int, number of samples drawn per distribution.
      randomized: bool. If True, one jittered sample is drawn from each stratum
        [i / num_samples, (i + 1) / num_samples); otherwise evenly spaced
        deterministic positions spanning [0, 1 - eps] are used.

    Returns:
      Tensor [..., num_samples] of sample positions interpolated between
      entries of `bins`.
    """
    # Pad each weight vector (only if necessary) to bring its sum to `eps`. This
    # avoids NaNs when the input is zeros or small, but has no effect otherwise.
    eps = 1e-5
    weight_sum = torch.sum(weights, dim=-1, keepdim=True)
    padding = torch.maximum(torch.zeros_like(weight_sum), eps - weight_sum)
    # Out-of-place on purpose: `weights += ...` would silently modify the
    # tensor owned by the caller.
    weights = weights + padding / weights.shape[-1]
    weight_sum = weight_sum + padding

    # Compute the PDF and CDF for each weight vector, while ensuring that the CDF
    # starts with exactly 0 and ends with exactly 1.
    pdf = weights / weight_sum
    cdf = torch.cumsum(pdf[..., :-1], dim=-1)
    cdf = torch.minimum(torch.ones_like(cdf), cdf)
    batch_shape = list(cdf.shape[:-1])
    cdf = torch.cat([torch.zeros(batch_shape + [1], device=cdf.device),
                     cdf,
                     torch.ones(batch_shape + [1], device=cdf.device)],
                    dim=-1)

    # Draw uniform samples.
    float_eps = torch.finfo(torch.float32).eps
    if randomized:
        s = 1 / num_samples
        # Left edge of each stratum, plus a jitter strictly smaller than the
        # stratum width. The edge must be added exactly once; adding it twice
        # pushes the upper half of the samples past 1 where they all get
        # clamped to the same value.
        u = (torch.arange(num_samples, device=cdf.device) * s)[None, ...]
        jitter = torch.empty(batch_shape + [num_samples], device=cdf.device).uniform_(to=(s - float_eps))
        u = u + jitter
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u = torch.minimum(u, torch.full_like(u, 1. - float_eps))
    else:
        # Match the behavior of jax.random.uniform() by spanning [0, 1-eps].
        u = torch.linspace(0., 1. - float_eps, num_samples, device=cdf.device)
        u = torch.broadcast_to(u, batch_shape + [num_samples])

    # Identify the location in `cdf` that corresponds to a random sample.
    # The final `True` index in `mask` will be the start of the sampled interval.
    mask = u[..., None, :] >= cdf[..., :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0, _ = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2)
        x1, _ = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2)
        return x0, x1

    bins_g0, bins_g1 = find_interval(bins)
    cdf_g0, cdf_g1 = find_interval(cdf)

    # Fractional position of `u` inside its CDF interval; a zero-width interval
    # yields NaN, which is mapped to 0.
    t = torch.clip(torch.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
    samples = bins_g0 + t * (bins_g1 - bins_g0)
    return samples
def convert_to_ndc(origins, directions, focal, w, h, near=1.):
    """Map ray origins and directions into NDC space.

    Args:
      origins: array [..., 3] of ray origins.
      directions: array [..., 3] of ray directions.
      focal: focal length used in the projection.
      w: image width used in the projection.
      h: image height used in the projection.
      near: distance of the plane the origins are moved onto first.

    Returns:
      (origins, directions) as arrays [..., 3] in NDC.
    """
    tiny = 1e-15  # added to every denominator so none is exactly zero

    # Slide each origin along its ray onto the plane z = -near.
    shift = -(near + origins[..., 2]) / (directions[..., 2] + tiny)
    origins = origins + shift[..., None] * directions

    dx, dy, dz = tuple(np.moveaxis(directions, -1, 0))
    ox, oy, oz = tuple(np.moveaxis(origins, -1, 0))

    oz_safe = oz + tiny
    dz_safe = dz + tiny
    x_scale = -((2 * focal) / w)
    y_scale = -((2 * focal) / h)
    ox_over_oz = ox / oz_safe
    oy_over_oz = oy / oz_safe

    # Projection
    ndc_origins = np.stack(
        [x_scale * ox_over_oz,
         y_scale * oy_over_oz,
         1 + 2 * near / oz_safe],
        -1)
    ndc_directions = np.stack(
        [x_scale * (dx / dz_safe - ox_over_oz),
         y_scale * (dy / dz_safe - oy_over_oz),
         -2 * near / oz_safe],
        -1)
    return ndc_origins, ndc_directions
def lift_gaussian(d, t_mean, t_var, r_var, diag):
    """Turn a Gaussian parameterised along a ray into a 3D Gaussian.

    Args:
      d: tensor [..., 3], ray direction (not required to be unit length).
      t_mean: tensor [..., num_samples], mean distance along the ray.
      t_var: tensor [..., num_samples], variance along the ray.
      r_var: tensor [..., num_samples], variance perpendicular to the ray.
      diag: bool; True returns only the covariance diagonal, False returns the
        full matrix.

    Returns:
      (mean, cov): mean is [..., num_samples, 3]; cov is [..., num_samples, 3]
      when `diag` is True and [..., num_samples, 3, 3] otherwise.
    """
    mean = d[..., None, :] * t_mean[..., None]
    # Small constant keeps the division defined for a zero-length direction.
    d_mag_sq = torch.sum(d ** 2, dim=-1, keepdim=True) + 1e-10

    if not diag:
        d_outer = d[..., :, None] * d[..., None, :]
        eye = torch.eye(d.shape[-1], device=d.device)
        null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
        along_ray = t_var[..., None, None] * d_outer[..., None, :, :]
        across_ray = r_var[..., None, None] * null_outer[..., None, :, :]
        return mean, along_ray + across_ray

    # Diagonal-only path: keep just the diagonals of the two outer products.
    d_sq = d ** 2
    null_diag = 1 - d_sq / d_mag_sq
    along_ray_diag = t_var[..., None] * d_sq[..., None, :]
    across_ray_diag = r_var[..., None] * null_diag[..., None, :]
    return mean, along_ray_diag + across_ray_diag
def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True):
    """Fit a Gaussian (mean + covariance) to a conical frustum.

    The cone is taken to start at the origin, with `base_radius` being its
    radius at distance 1. `d` does not have to be normalized.

    Args:
      d: torch.float32 3-vector, the axis of the cone.
      t0: float, distance at which the frustum starts.
      t1: float, distance at which the frustum ends.
      base_radius: float, how the radius scales with distance.
      diag: boolean, whether the Gaussian is diagonal or full-covariance.
      stable: boolean, whether to use the numerically stable formulation
        (setting this to False will cause catastrophic failure).

    Returns:
      a Gaussian (mean and covariance), as produced by `lift_gaussian`.
    """
    if not stable:
        # Direct moment formulas in terms of the frustum end points.
        cube_diff = t1**3 - t0**3
        fifth_diff = t1**5 - t0**5
        t_mean = (3 * (t1**4 - t0**4)) / (4 * cube_diff)
        r_var = base_radius**2 * (3 / 20 * fifth_diff / cube_diff)
        t_mosq = 3 / 5 * fifth_diff / cube_diff
        t_var = t_mosq - t_mean**2
    else:
        # Same moments expressed through the midpoint and half-width.
        mu = (t0 + t1) / 2
        hw = (t1 - t0) / 2
        mu_sq = mu**2
        hw_sq = hw**2
        denom = 3 * mu_sq + hw_sq
        t_mean = mu + (2 * mu * hw_sq) / denom
        t_var = hw_sq / 3 - (4 / 15) * ((hw**4 * (12 * mu_sq - hw_sq)) /
                                        denom**2)
        r_var = base_radius**2 * (mu_sq / 4 + (5 / 12) * hw_sq - 4 / 15 *
                                  (hw**4) / denom)
    return lift_gaussian(d, t_mean, t_var, r_var, diag)
def cylinder_to_gaussian(d, t0, t1, radius, diag):
    """Fit a Gaussian (mean + covariance) to a cylinder.

    The cylinder axis is taken to start at the origin. `d` is used as given,
    without renormalization.

    Args:
      d: torch.float32 3-vector, the axis of the cylinder.
      t0: float, distance at which the cylinder starts.
      t1: float, distance at which the cylinder ends.
      radius: float, the radius of the cylinder.
      diag: boolean, whether the Gaussian is diagonal or full-covariance.

    Returns:
      a Gaussian (mean and covariance), as produced by `lift_gaussian`.
    """
    length = t1 - t0
    along_mean = (t0 + t1) / 2
    along_var = length ** 2 / 12
    radial_var = radius ** 2 / 4
    return lift_gaussian(d, along_mean, along_var, radial_var, diag)
def cast_rays(t_vals, origins, directions, radii, ray_shape, diag=True):
    """Cast rays (cone- or cylinder-shaped) and featurize sections of it.

    Args:
      t_vals: float array, the "fencepost" distances along the ray.
      origins: float array, the ray origin coordinates.
      directions: float array, the ray direction vectors.
      radii: float array, the radii (base radii for cones) of the rays.
      ray_shape: str, either 'cone' or 'cylinder'.
      diag: boolean, whether or not the covariance matrices should be diagonal.

    Returns:
      a tuple of arrays of means and covariances, one entry per interval
      between consecutive fenceposts; the means are offset by `origins`.

    Raises:
      ValueError: if `ray_shape` is neither 'cone' nor 'cylinder'.
    """
    # Validate up front with a real exception: a bare `assert` is removed
    # under `python -O`, which would leave `gaussian_fn` unbound below.
    if ray_shape == 'cone':
        gaussian_fn = conical_frustum_to_gaussian
    elif ray_shape == 'cylinder':
        gaussian_fn = cylinder_to_gaussian
    else:
        raise ValueError(
            "ray_shape must be 'cone' or 'cylinder', got {!r}".format(ray_shape))

    # Consecutive fenceposts delimit each section along the ray.
    t0 = t_vals[..., :-1]
    t1 = t_vals[..., 1:]
    means, covs = gaussian_fn(directions, t0, t1, radii, diag)
    means = means + origins[..., None, :]
    return means, covs
def sample_along_rays(origins, directions, radii, num_samples, near, far, randomized, lindisp, ray_shape):
    """Stratified sampling of interval endpoints along each ray.

    Args:
      origins: torch.tensor(float32), [batch_size, 3], ray origins.
      directions: torch.tensor(float32), [batch_size, 3], ray directions.
      radii: torch.tensor(float32), per-ray radii, forwarded to `cast_rays`.
      num_samples: int, number of intervals per ray.
      near: torch.tensor, [batch_size, 1], near clip.
      far: torch.tensor, [batch_size, 1], far clip.
      randomized: bool, use randomized stratified sampling.
      lindisp: bool, sampling linearly in disparity rather than depth.
      ray_shape: str, forwarded to `cast_rays`.

    Returns:
      t_vals: torch.tensor, [batch_size, num_samples + 1], interval endpoints.
      (means, covs): result of `cast_rays` on those endpoints with its default
        `diag` setting.
    """
    num_rays = origins.shape[0]
    device = origins.device

    # Evenly spaced fractions in [0, 1], mapped between near and far.
    frac = torch.linspace(0., 1., num_samples + 1, device=device)
    if lindisp:
        t_vals = 1. / (1. / near * (1. - frac) + 1. / far * frac)
    else:
        t_vals = near * (1. - frac) + far * frac

    if not randomized:
        # Broadcast t_vals to make the returned shape consistent.
        t_vals = torch.broadcast_to(t_vals, [num_rays, num_samples + 1])
    else:
        # Jitter every endpoint inside the bracket formed by the midpoints to
        # its neighbours (the outermost brackets end at the first/last value).
        mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
        upper = torch.cat([mids, t_vals[..., -1:]], -1)
        lower = torch.cat([t_vals[..., :1], mids], -1)
        t_rand = torch.rand(num_rays, num_samples + 1, device=device)
        t_vals = lower + (upper - lower) * t_rand

    means, covs = cast_rays(t_vals, origins, directions, radii, ray_shape)
    return t_vals, (means, covs)
def resample_along_rays(origins, directions, radii, t_vals, weights, randomized, stop_grad, resample_padding, ray_shape):
    """Resample interval endpoints along each ray according to `weights`.

    Args:
      origins: torch.tensor(float32), [batch_size, 3], ray origins.
      directions: torch.tensor(float32), [batch_size, 3], ray directions.
      radii: torch.tensor(float32), per-ray radii, forwarded to `cast_rays`.
      t_vals: torch.tensor(float32), [batch_size, num_samples+1].
      weights: torch.tensor(float32), one weight per interval of `t_vals`.
      randomized: bool, use randomized samples.
      stop_grad: bool, if True the new t values are computed under
        `torch.no_grad()`.
      resample_padding: float, added to the weights before normalizing.
      ray_shape: str, forwarded to `cast_rays`.

    Returns:
      new_t_vals: torch.tensor(float32), [batch_size, num_samples+1].
      (means, covs): result of `cast_rays` on the new endpoints.
    """

    def _draw_t_vals():
        # Blur the weights: max over each pair of neighbours (edges repeated),
        # then average adjacent maxima.
        padded = torch.cat([weights[..., :1], weights, weights[..., -1:]], dim=-1)
        pair_max = torch.maximum(padded[..., :-1], padded[..., 1:])
        blurred = 0.5 * (pair_max[..., :-1] + pair_max[..., 1:])
        # Add in a constant (the sampling function will renormalize the PDF).
        return sorted_piecewise_constant_pdf(
            t_vals,
            blurred + resample_padding,
            t_vals.shape[-1],
            randomized,
        )

    if stop_grad:
        with torch.no_grad():
            new_t_vals = _draw_t_vals()
    else:
        new_t_vals = _draw_t_vals()

    means, covs = cast_rays(new_t_vals, origins, directions, radii, ray_shape)
    return new_t_vals, (means, covs)
def volumetric_rendering(rgb, density, t_vals, dirs, white_bkgd):
    """Volumetric Rendering Function.

    Works for any number of leading batch dimensions (including none), as long
    as they agree across the inputs.

    Args:
      rgb: torch.tensor(float32), color, [..., num_samples, 3].
      density: torch.tensor(float32), density, [..., num_samples, 1].
      t_vals: torch.tensor(float32), [..., num_samples + 1], interval endpoints.
      dirs: torch.tensor(float32), [..., 3]; its norm scales the interval
        lengths.
      white_bkgd: bool, composite the color onto a white background.

    Returns:
      comp_rgb: torch.tensor(float32), [..., 3], composited color.
      distance: torch.tensor(float32), [...], weight-averaged interval
        midpoint, clamped to the first/last entry of `t_vals`.
      acc: torch.tensor(float32), [...], sum of the weights.
      weights: torch.tensor(float32), [..., num_samples].
      alpha: torch.tensor(float32), [..., num_samples], per-interval opacity.
    """
    t_mids = 0.5 * (t_vals[..., :-1] + t_vals[..., 1:])
    t_dists = t_vals[..., 1:] - t_vals[..., :-1]
    delta = t_dists * torch.linalg.norm(dirs[..., None, :], dim=-1)
    # Note that we're quietly turning density from [..., 0] to [...].
    density_delta = density[..., 0] * delta

    alpha = 1 - torch.exp(-density_delta)
    # Transmittance reaching each interval: exp of minus the density
    # accumulated over all preceding intervals (zero for the first one).
    trans = torch.exp(-torch.cat([
        torch.zeros_like(density_delta[..., :1]),
        torch.cumsum(density_delta[..., :-1], dim=-1)
    ], dim=-1))
    weights = alpha * trans

    comp_rgb = (weights[..., None] * rgb).sum(dim=-2)
    acc = weights.sum(dim=-1)
    distance = (weights * t_mids).sum(dim=-1) / acc
    # A ray with zero accumulated weight gives 0/0 = NaN; nan_to_num turns that
    # into 0, which the clamp then raises to the first t value.
    # Index with `...` (not `:`) so the bounds match `distance` for any number
    # of batch dimensions, like every other line in this function.
    distance = torch.clamp(torch.nan_to_num(distance), t_vals[..., 0], t_vals[..., -1])
    if white_bkgd:
        comp_rgb = comp_rgb + (1. - acc[..., None])
    return comp_rgb, distance, acc, weights, alpha