Skip to content

Commit

Permalink
Render cameras with render times for dynamic nerfs (#1199)
Browse files Browse the repository at this point in the history
* Render cameras with render times for dynamic nerfs

* Add --dynamic flag to cmd. Clamp render times.

* Fix for issues with clamping renderTime

* Only add render_time to camera path if dynamic nerf

* change DEFAULT_TIME to DEFAULT_RENDER_TIME

Co-authored-by: Liam Schoneveld <[email protected]>
  • Loading branch information
nlml and Liam Schoneveld authored Jan 4, 2023
1 parent 102c00a commit 31be97d
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 6 deletions.
7 changes: 7 additions & 0 deletions nerfstudio/cameras/camera_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ def get_path_from_json(camera_path: Dict[str, Any]) -> Cameras:
fxs.append(focal_length)
fys.append(focal_length)

# Iff ALL cameras in the path have a "time" value, construct Cameras with times
if all("render_time" in camera for camera in camera_path["camera_path"]):
times = torch.tensor([camera["render_time"] for camera in camera_path["camera_path"]])
else:
times = None

camera_to_worlds = torch.stack(c2ws, dim=0)
fx = torch.tensor(fxs)
fy = torch.tensor(fys)
Expand All @@ -162,4 +168,5 @@ def get_path_from_json(camera_path: Dict[str, Any]) -> Cameras:
cy=image_height / 2,
camera_to_worlds=camera_to_worlds,
camera_type=camera_type,
times=times,
)
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,104 @@ function set_camera_position(camera, matrix) {
mat.decompose(camera.position, camera.quaternion, camera.scale);
}


function RenderTimeSelector(props) {
const disabled = props.disabled;
const isGlobal = props.isGlobal;
const camera = props.camera;
const dispatch = props.dispatch;
const globalRenderTime = props.globalRenderTime;
const setGlobalRenderTime = props.setGlobalRenderTime;
const applyAll = props.applyAll;
const changeMain = props.changeMain;
const setAllCameraRenderTime = props.setAllCameraRenderTime;

const getRenderTimeLabel = () => {
if (!isGlobal) {
return camera.renderTime;
}
camera.renderTime = globalRenderTime
return globalRenderTime;
};

const [UIRenderTime, setUIRenderTime] = React.useState(
isGlobal ? globalRenderTime : getRenderTimeLabel(),
);

const [valid, setValid] = React.useState(true);

useEffect(
() => setUIRenderTime(getRenderTimeLabel()),
[camera, globalRenderTime],
);

const setRndrTime = (val) => {
if (!isGlobal) {
camera.renderTime = val;
} else {
camera.renderTime = val;
setGlobalRenderTime(val);
}

if (applyAll) {
setAllCameraRenderTime(val);
}

if (changeMain) {
dispatch({
type: 'write',
path: 'renderingState/render_time',
data: camera.renderTime,
});
}
};

const handleValidation = (e) => {
const valueFloat = parseFloat(e.target.value);
let valueStr = String(valueFloat);
if (e.target.value >= 0 && e.target.value <= 1){
setValid(true);
if (valueFloat === 1.0) {
valueStr = '1.0';
}
if (valueFloat === 0.0) {
valueStr = '0.0';
}
setUIRenderTime(parseFloat(valueStr));
setRndrTime(parseFloat(valueStr));
} else {
setValid(false);
}
};

return (
<TextField
label='Render Time'
InputLabelProps={{
style: { color: '#8E8E8E' },
}}
inputProps={{
inputMode: 'numeric',
}}
onChange={(e) => setUIRenderTime(e.target.value)}
onBlur={(e) => handleValidation(e)}
disabled={disabled}
sx={{
input: {
'-webkit-text-fill-color': `${
disabled ? '#24B6FF' : '#EBEBEB'
} !important`,
color: `${disabled ? '#24B6FF' : '#EBEBEB'} !important`,
},
}}
value={UIRenderTime}
error={!valid}
helperText={!valid ? 'RenderTime should be between 0.0 and 1.0' : ''}
variant="standard"
/>
);
}

function FovSelector(props) {
const fovLabel = props.fovLabel;
const setFovLabel = props.setFovLabel;
Expand Down Expand Up @@ -279,6 +377,7 @@ function CameraList(props) {
set_camera_position(camera_render, first_camera.matrix);
camera_render_helper.set_visibility(true);
camera_render.fov = first_camera.fov;
camera_render.renderTime = first_camera.renderTime;
}
set_slider_value(slider_min);
};
Expand Down Expand Up @@ -392,6 +491,7 @@ function CameraList(props) {
e.stopPropagation();
set_camera_position(camera_main, camera.matrix);
camera_main.fov = camera.fov;
camera_main.renderTime = camera.renderTime;
set_slider_value(camera.properties.get('TIME'));
}}
>
Expand All @@ -414,7 +514,16 @@ function CameraList(props) {
changeMain={false}
/>
)}
{!isAnimated('FOV') && (
{isAnimated('RenderTime') && (
<RenderTimeSelector
camera={camera}
dispatch={dispatch}
disabled={!isAnimated('RenderTime')}
isGlobal={false}
changeMain={false}
/>
)}
{!isAnimated('FOV') && !isAnimated('RenderTime') && (
<p style={{ fontSize: 'smaller', color: '#999999' }}>
Animated camera properties will show up here!
</p>
Expand Down Expand Up @@ -452,6 +561,7 @@ export default function CameraPanel(props) {
);
const websocket = useContext(WebSocketContext).socket;
const DEFAULT_FOV = 50;
const DEFAULT_RENDER_TIME = 0.0;

// react state
const [cameras, setCameras] = React.useState([]);
Expand All @@ -466,6 +576,7 @@ export default function CameraPanel(props) {
const [render_modal_open, setRenderModalOpen] = React.useState(false);
const [animate, setAnimate] = React.useState(new Set());
const [globalFov, setGlobalFov] = React.useState(DEFAULT_FOV);
const [globalRenderTime, setGlobalRenderTime] = React.useState(DEFAULT_RENDER_TIME);

// leva store
const cameraPropsStore = useCreateStore();
Expand All @@ -491,6 +602,17 @@ export default function CameraPanel(props) {
);
const camera_type = useSelector((state) => state.renderingState.camera_type);

const [display_render_time, set_display_render_time] = React.useState(false);

const receive_temporal_dist = e => {
const msg = msgpack.decode(new Uint8Array(e.data));
if (msg.path === "/model/has_temporal_distortion") {
set_display_render_time(msg.data === "true");
websocket.removeEventListener("message", receive_temporal_dist);
}
}
websocket.addEventListener('message', receive_temporal_dist);

const setRenderHeight = (value) => {
dispatch({
type: 'write',
Expand Down Expand Up @@ -522,6 +644,14 @@ export default function CameraPanel(props) {
});
};

const setRenderTime = (value) => {
dispatch({
type: 'write',
path: 'renderingState/render_time',
data: parseFloat(value),
});
};

// ui state
const [fovLabel, setFovLabel] = React.useState(FOV_LABELS.FOV);

Expand All @@ -538,6 +668,7 @@ export default function CameraPanel(props) {
if (new_camera_list.length >= 1) {
set_camera_position(camera_render, new_camera_list[0].matrix);
setFieldOfView(new_camera_list[0].fov);
setRenderTime(new_camera_list[0].renderTime)
set_slider_value(slider_min);
}
};
Expand All @@ -546,6 +677,7 @@ export default function CameraPanel(props) {
const camera_main_copy = camera_main.clone();
camera_main_copy.aspect = 1.0;
camera_main_copy.fov = globalFov;
camera_main_copy.renderTime = globalRenderTime;
const new_camera_properties = new Map();
camera_main_copy.properties = new_camera_properties;
new_camera_properties.set('FOV', globalFov);
Expand Down Expand Up @@ -828,11 +960,21 @@ export default function CameraPanel(props) {

const mat = get_transform_matrix(position, lookat, up);

camera_path.push({
camera_to_world: mat.transpose().elements, // convert from col-major to row-major matrix
fov,
aspect: camera_render.aspect,
});
if (display_render_time) {
const renderTime = curve_object.curve_render_times.getPoint(pt).z;
camera_path.push({
camera_to_world: mat.transpose().elements, // convert from col-major to row-major matrix
fov,
aspect: camera_render.aspect,
render_time: Math.max(Math.min(renderTime, 1.0), 0.0), // clamp time values to [0, 1]
});
} else {
camera_path.push({
camera_to_world: mat.transpose().elements, // convert from col-major to row-major matrix
fov,
aspect: camera_render.aspect,
});
}
}

const keyframes = [];
Expand Down Expand Up @@ -985,6 +1127,12 @@ export default function CameraPanel(props) {
}
};

const setAllCameraRenderTime = (val) => {
for (let i = 0; i < cameras.length; i += 1) {
cameras[i].renderTime = val;
}
};

return (
<div className="CameraPanel">
<div>
Expand Down Expand Up @@ -1047,6 +1195,51 @@ export default function CameraPanel(props) {
/>
</LevaStoreProvider>
</div>
{display_render_time && (
<div className="CameraList-row-animation-properties">
<Tooltip title="Animate Render Time for Each Camera">
<Button
value="animateRenderTime"
selected={isAnimated('RenderTime')}
onClick={() => {
toggleAnimate('RenderTime');
}}
style={{
maxWidth: '20px',
maxHeight: '20px',
minWidth: '20px',
minHeight: '20px',
position: 'relative',
top: '22px',
}}
sx={{
mt: 1,
}}
>
<Animation
style={{
color: isAnimated('RenderTime') ? '#24B6FF' : '#EBEBEB',
maxWidth: '20px',
maxHeight: '20px',
minWidth: '20px',
minHeight: '20px',
}}
/>
</Button>
</Tooltip>
<RenderTimeSelector
disabled={false}
isGlobal
camera={camera_main}
dispatch={dispatch}
globalRenderTime={globalRenderTime}
setGlobalRenderTime={setGlobalRenderTime}
applyAll={!isAnimated('RenderTime')}
setAllCameraRenderTime={setAllCameraRenderTime}
changeMain
/>
</div>
)}
{camera_type !== 'equirectangular' && (
<div className="CameraList-row-animation-properties">
<Tooltip title="Animate FOV for Each Camera">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export function get_curve_object_from_cameras(
const lookats = [];
const ups = [];
const fovs = [];
const render_times = [];

for (let i = 0; i < cameras.length; i += 1) {
const camera = cameras[i];
Expand All @@ -45,23 +46,27 @@ export function get_curve_object_from_cameras(
lookats.push(lookat);
// Reuse catmullromcurve3 for 1d values. TODO fix this
fovs.push(new THREE.Vector3(0, 0, camera.fov));
render_times.push(new THREE.Vector3(0, 0, camera.renderTime));
}

let curve_positions = null;
let curve_lookats = null;
let curve_ups = null;
let curve_fovs = null;
let curve_render_times = null;

curve_positions = get_catmull_rom_curve(positions, is_cycle, smoothness_value);
curve_lookats = get_catmull_rom_curve(lookats, is_cycle, smoothness_value);
curve_ups = get_catmull_rom_curve(ups, is_cycle, smoothness_value);
curve_fovs = get_catmull_rom_curve(fovs, is_cycle, smoothness_value / 10);
curve_render_times = get_catmull_rom_curve(render_times, is_cycle, smoothness_value);

const curve_object = {
curve_positions,
curve_lookats,
curve_ups,
curve_fovs,
curve_render_times,
};
return curve_object;
}
Expand Down

0 comments on commit 31be97d

Please sign in to comment.