Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Render cameras with render times for dynamic nerfs #1199

Merged
merged 5 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions nerfstudio/cameras/camera_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ def get_path_from_json(camera_path: Dict[str, Any]) -> Cameras:
fxs.append(focal_length)
fys.append(focal_length)

# Iff ALL cameras in the path have a "time" value, construct Cameras with times
if all("render_time" in camera for camera in camera_path["camera_path"]):
times = torch.tensor([camera["render_time"] for camera in camera_path["camera_path"]])
else:
times = None

camera_to_worlds = torch.stack(c2ws, dim=0)
fx = torch.tensor(fxs)
fy = torch.tensor(fys)
Expand All @@ -162,4 +168,5 @@ def get_path_from_json(camera_path: Dict[str, Any]) -> Cameras:
cy=image_height / 2,
camera_to_worlds=camera_to_worlds,
camera_type=camera_type,
times=times,
)
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,104 @@ function set_camera_position(camera, matrix) {
mat.decompose(camera.position, camera.quaternion, camera.scale);
}


function RenderTimeSelector(props) {
const disabled = props.disabled;
const isGlobal = props.isGlobal;
const camera = props.camera;
const dispatch = props.dispatch;
const globalRenderTime = props.globalRenderTime;
const setGlobalRenderTime = props.setGlobalRenderTime;
const applyAll = props.applyAll;
const changeMain = props.changeMain;
const setAllCameraRenderTime = props.setAllCameraRenderTime;

const getRenderTimeLabel = () => {
if (!isGlobal) {
return camera.renderTime;
}
camera.renderTime = globalRenderTime
return globalRenderTime;
};

const [UIRenderTime, setUIRenderTime] = React.useState(
isGlobal ? globalRenderTime : getRenderTimeLabel(),
);

const [valid, setValid] = React.useState(true);

useEffect(
() => setUIRenderTime(getRenderTimeLabel()),
[camera, globalRenderTime],
);

const setRndrTime = (val) => {
if (!isGlobal) {
camera.renderTime = val;
} else {
camera.renderTime = val;
setGlobalRenderTime(val);
}

if (applyAll) {
setAllCameraRenderTime(val);
}

if (changeMain) {
dispatch({
type: 'write',
path: 'renderingState/render_time',
data: camera.renderTime,
});
}
};

const handleValidation = (e) => {
const valueFloat = parseFloat(e.target.value);
let valueStr = String(valueFloat);
if (e.target.value >= 0 && e.target.value <= 1){
setValid(true);
if (valueFloat === 1.0) {
valueStr = '1.0';
}
if (valueFloat === 0.0) {
valueStr = '0.0';
}
setUIRenderTime(parseFloat(valueStr));
setRndrTime(parseFloat(valueStr));
} else {
setValid(false);
}
};

return (
<TextField
label='Render Time'
InputLabelProps={{
style: { color: '#8E8E8E' },
}}
inputProps={{
inputMode: 'numeric',
}}
onChange={(e) => setUIRenderTime(e.target.value)}
onBlur={(e) => handleValidation(e)}
disabled={disabled}
sx={{
input: {
'-webkit-text-fill-color': `${
disabled ? '#24B6FF' : '#EBEBEB'
} !important`,
color: `${disabled ? '#24B6FF' : '#EBEBEB'} !important`,
},
}}
value={UIRenderTime}
error={!valid}
helperText={!valid ? 'RenderTime should be between 0.0 and 1.0' : ''}
variant="standard"
/>
);
}

function FovSelector(props) {
const fovLabel = props.fovLabel;
const setFovLabel = props.setFovLabel;
Expand Down Expand Up @@ -279,6 +377,7 @@ function CameraList(props) {
set_camera_position(camera_render, first_camera.matrix);
camera_render_helper.set_visibility(true);
camera_render.fov = first_camera.fov;
camera_render.renderTime = first_camera.renderTime;
}
set_slider_value(slider_min);
};
Expand Down Expand Up @@ -392,6 +491,7 @@ function CameraList(props) {
e.stopPropagation();
set_camera_position(camera_main, camera.matrix);
camera_main.fov = camera.fov;
camera_main.renderTime = camera.renderTime;
set_slider_value(camera.properties.get('TIME'));
}}
>
Expand All @@ -414,7 +514,16 @@ function CameraList(props) {
changeMain={false}
/>
)}
{!isAnimated('FOV') && (
{isAnimated('RenderTime') && (
<RenderTimeSelector
camera={camera}
dispatch={dispatch}
disabled={!isAnimated('RenderTime')}
isGlobal={false}
changeMain={false}
/>
)}
{!isAnimated('FOV') && !isAnimated('RenderTime') && (
<p style={{ fontSize: 'smaller', color: '#999999' }}>
Animated camera properties will show up here!
</p>
Expand Down Expand Up @@ -452,6 +561,7 @@ export default function CameraPanel(props) {
);
const websocket = useContext(WebSocketContext).socket;
const DEFAULT_FOV = 50;
const DEFAULT_RENDER_TIME = 0.0;

// react state
const [cameras, setCameras] = React.useState([]);
Expand All @@ -466,6 +576,7 @@ export default function CameraPanel(props) {
const [render_modal_open, setRenderModalOpen] = React.useState(false);
const [animate, setAnimate] = React.useState(new Set());
const [globalFov, setGlobalFov] = React.useState(DEFAULT_FOV);
const [globalRenderTime, setGlobalRenderTime] = React.useState(DEFAULT_RENDER_TIME);

// leva store
const cameraPropsStore = useCreateStore();
Expand All @@ -491,6 +602,17 @@ export default function CameraPanel(props) {
);
const camera_type = useSelector((state) => state.renderingState.camera_type);

const [display_render_time, set_display_render_time] = React.useState(false);

const receive_temporal_dist = e => {
const msg = msgpack.decode(new Uint8Array(e.data));
if (msg.path === "/model/has_temporal_distortion") {
set_display_render_time(msg.data === "true");
websocket.removeEventListener("message", receive_temporal_dist);
}
}
websocket.addEventListener('message', receive_temporal_dist);

const setRenderHeight = (value) => {
dispatch({
type: 'write',
Expand Down Expand Up @@ -522,6 +644,14 @@ export default function CameraPanel(props) {
});
};

const setRenderTime = (value) => {
dispatch({
type: 'write',
path: 'renderingState/render_time',
data: parseFloat(value),
});
};

// ui state
const [fovLabel, setFovLabel] = React.useState(FOV_LABELS.FOV);

Expand All @@ -538,6 +668,7 @@ export default function CameraPanel(props) {
if (new_camera_list.length >= 1) {
set_camera_position(camera_render, new_camera_list[0].matrix);
setFieldOfView(new_camera_list[0].fov);
setRenderTime(new_camera_list[0].renderTime)
set_slider_value(slider_min);
}
};
Expand All @@ -546,6 +677,7 @@ export default function CameraPanel(props) {
const camera_main_copy = camera_main.clone();
camera_main_copy.aspect = 1.0;
camera_main_copy.fov = globalFov;
camera_main_copy.renderTime = globalRenderTime;
const new_camera_properties = new Map();
camera_main_copy.properties = new_camera_properties;
new_camera_properties.set('FOV', globalFov);
Expand Down Expand Up @@ -828,11 +960,21 @@ export default function CameraPanel(props) {

const mat = get_transform_matrix(position, lookat, up);

camera_path.push({
camera_to_world: mat.transpose().elements, // convert from col-major to row-major matrix
fov,
aspect: camera_render.aspect,
});
if (display_render_time) {
const renderTime = curve_object.curve_render_times.getPoint(pt).z;
camera_path.push({
camera_to_world: mat.transpose().elements, // convert from col-major to row-major matrix
fov,
aspect: camera_render.aspect,
render_time: Math.max(Math.min(renderTime, 1.0), 0.0), // clamp time values to [0, 1]
});
} else {
camera_path.push({
camera_to_world: mat.transpose().elements, // convert from col-major to row-major matrix
fov,
aspect: camera_render.aspect,
});
}
}

const keyframes = [];
Expand Down Expand Up @@ -985,6 +1127,12 @@ export default function CameraPanel(props) {
}
};

const setAllCameraRenderTime = (val) => {
for (let i = 0; i < cameras.length; i += 1) {
cameras[i].renderTime = val;
}
};

return (
<div className="CameraPanel">
<div>
Expand Down Expand Up @@ -1047,6 +1195,51 @@ export default function CameraPanel(props) {
/>
</LevaStoreProvider>
</div>
{display_render_time && (
<div className="CameraList-row-animation-properties">
<Tooltip title="Animate Render Time for Each Camera">
<Button
value="animateRenderTime"
selected={isAnimated('RenderTime')}
onClick={() => {
toggleAnimate('RenderTime');
}}
style={{
maxWidth: '20px',
maxHeight: '20px',
minWidth: '20px',
minHeight: '20px',
position: 'relative',
top: '22px',
}}
sx={{
mt: 1,
}}
>
<Animation
style={{
color: isAnimated('RenderTime') ? '#24B6FF' : '#EBEBEB',
maxWidth: '20px',
maxHeight: '20px',
minWidth: '20px',
minHeight: '20px',
}}
/>
</Button>
</Tooltip>
<RenderTimeSelector
disabled={false}
isGlobal
camera={camera_main}
dispatch={dispatch}
globalRenderTime={globalRenderTime}
setGlobalRenderTime={setGlobalRenderTime}
applyAll={!isAnimated('RenderTime')}
setAllCameraRenderTime={setAllCameraRenderTime}
changeMain
/>
</div>
)}
{camera_type !== 'equirectangular' && (
<div className="CameraList-row-animation-properties">
<Tooltip title="Animate FOV for Each Camera">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export function get_curve_object_from_cameras(
const lookats = [];
const ups = [];
const fovs = [];
const render_times = [];

for (let i = 0; i < cameras.length; i += 1) {
const camera = cameras[i];
Expand All @@ -45,23 +46,27 @@ export function get_curve_object_from_cameras(
lookats.push(lookat);
// Reuse catmullromcurve3 for 1d values. TODO fix this
fovs.push(new THREE.Vector3(0, 0, camera.fov));
render_times.push(new THREE.Vector3(0, 0, camera.renderTime));
}

let curve_positions = null;
let curve_lookats = null;
let curve_ups = null;
let curve_fovs = null;
let curve_render_times = null;

curve_positions = get_catmull_rom_curve(positions, is_cycle, smoothness_value);
curve_lookats = get_catmull_rom_curve(lookats, is_cycle, smoothness_value);
curve_ups = get_catmull_rom_curve(ups, is_cycle, smoothness_value);
curve_fovs = get_catmull_rom_curve(fovs, is_cycle, smoothness_value / 10);
curve_render_times = get_catmull_rom_curve(render_times, is_cycle, smoothness_value);

const curve_object = {
curve_positions,
curve_lookats,
curve_ups,
curve_fovs,
curve_render_times,
};
return curve_object;
}
Expand Down