
Commit ed281c7

Merge pull request #65 from prs-eth/pipeline_variable

Add pipeline variables

2 parents dfc2e11 + c91a70a, commit ed281c7

File tree

6 files changed: +262 -137 lines


README.md (+3 -5)

````diff
@@ -125,12 +125,10 @@ Activate the environment again after restarting the terminal session.
 
 ### 🚀 Run inference with LCM (faster)
 
-The [LCM checkpoint](https://huggingface.co/prs-eth/marigold-lcm-v1-0) is distilled from our original checkpoint towards faster inference speed (by reducing inference steps). The inference steps can be as few as 1 to 4:
+The [LCM checkpoint](https://huggingface.co/prs-eth/marigold-lcm-v1-0) is distilled from our original checkpoint towards faster inference speed (by reducing inference steps). The inference steps can be as few as 1 (default) to 4. Run with default LCM setting:
 
 ```bash
 python run.py \
-    --denoise_steps 4 \
-    --ensemble_size 5 \
     --input_rgb_dir input/in-the-wild_example \
     --output_dir output/in-the-wild_example_lcm
 ```
@@ -156,11 +154,11 @@ The default settings are optimized for the best result. However, the behavior of
 
 - Trade-offs between the **accuracy** and **speed** (for both options, larger values result in better accuracy at the cost of slower inference.)
   - `--ensemble_size`: Number of inference passes in the ensemble. For LCM `ensemble_size` is more important than `denoise_steps`. Default: ~~10~~ 5 (for LCM).
-  - `--denoise_steps`: Number of denoising steps of each inference pass. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps. Default: ~~10~~ 4 (for LCM).
+  - `--denoise_steps`: Number of denoising steps of each inference pass. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps. When unassigned (`None`), will read default setting from model config. Default: ~~10 4 (for LCM)~~ `None`.
 
 - By default, the inference script resizes input images to the *processing resolution*, and then resizes the prediction back to the original resolution. This gives the best quality, as Stable Diffusion, from which Marigold is derived, performs best at 768x768 resolution.
 
-  - `--processing_res`: the processing resolution; set 0 to process the input resolution directly. Default: 768.
+  - `--processing_res`: the processing resolution; set as 0 to process the input resolution directly. When unassigned (`None`), will read default setting from model config. Default: ~~768~~ `None`.
   - `--output_processing_res`: produce output at the processing resolution instead of upsampling it to the input resolution. Default: False.
   - `--resample_method`: resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`.
````
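With these `None` defaults, a caller can omit both flags and let the checkpoint decide. A minimal sketch of the same behavior through the pipeline API, assuming this repo's `MarigoldPipeline` export and its `depth_np` output field; the input path is hypothetical:

```python
import torch
from PIL import Image

from marigold import MarigoldPipeline

# Load the LCM checkpoint; after this PR its model config carries the
# recommended defaults (denoising steps, processing resolution).
pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-lcm-v1-0")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("input/in-the-wild_example/example_0.jpg")  # hypothetical path

# denoising_steps and processing_res stay None, so the pipeline resolves them
# from default_denoising_steps / default_processing_resolution in the config.
output = pipe(image, ensemble_size=5)
depth = output.depth_np  # H x W numpy array with values in [0, 1]
```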

infer.py (+4 -1)

```diff
@@ -1,4 +1,4 @@
-# Last modified: 2024-04-15
+# Last modified: 2024-05-24
 # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -213,6 +213,9 @@ def check_directory(directory):
         logging.debug("run without xformers")
 
     pipe = pipe.to(device)
+    logging.info(
+        f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}"
+    )
 
     # -------------------- Inference and saving --------------------
     with torch.no_grad():
```

marigold/marigold_pipeline.py (+79 -30)

```diff
@@ -1,4 +1,5 @@
 # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
+# Last modified: 2024-05-24
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +20,7 @@
 
 
 import logging
-from typing import Dict, Union
+from typing import Dict, Optional, Union
 
 import numpy as np
 import torch
@@ -33,13 +34,13 @@
 from diffusers.utils import BaseOutput
 from PIL import Image
 from torch.utils.data import DataLoader, TensorDataset
-from torchvision.transforms.functional import resize, pil_to_tensor
 from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import pil_to_tensor, resize
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from .util.batchsize import find_batch_size
-from .util.ensemble import ensemble_depths
+from .util.ensemble import ensemble_depth
 from .util.image_util import (
     chw2hwc,
     colorize_depth_maps,
@@ -85,6 +86,25 @@ class MarigoldPipeline(DiffusionPipeline):
             Text-encoder, for empty text embedding.
         tokenizer (`CLIPTokenizer`):
             CLIP tokenizer.
+        scale_invariant (`bool`, *optional*):
+            A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
+            the model config. When used together with the `shift_invariant=True` flag, the model is also called
+            "affine-invariant". NB: overriding this value is not supported.
+        shift_invariant (`bool`, *optional*):
+            A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
+            the model config. When used together with the `scale_invariant=True` flag, the model is also called
+            "affine-invariant". NB: overriding this value is not supported.
+        default_denoising_steps (`int`, *optional*):
+            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
+            quality with the given model. This value must be set in the model config. When the pipeline is called
+            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
+            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
+            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
+        default_processing_resolution (`int`, *optional*):
+            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
+            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
+            default value is used. This is required to ensure reasonable results with various model flavors trained
+            with varying optimal processing resolution values.
     """
 
     rgb_latent_scale_factor = 0.18215
```
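Background for the two new flags: an affine-invariant prediction is defined only up to an unknown scale s and shift t, i.e. metric depth ≈ s·d + t. A short illustrative sketch (not code from this PR) of recovering the two parameters against reference depth via least squares:

```python
import numpy as np

def align_affine(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """Fit scale s and shift t minimizing ||s * pred + t - gt||^2, then apply them."""
    a = np.stack([pred.ravel(), np.ones(pred.size)], axis=1)  # [N, 2] design matrix
    (s, t), *_ = np.linalg.lstsq(a, gt.ravel(), rcond=None)
    return s * pred + t
```

With `shift_invariant=False` the near plane is fixed at 0, so t is pinned to zero and only s is fitted, e.g. s = (pred·gt) / (pred·pred).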
```diff
@@ -97,26 +117,40 @@ def __init__(
         scheduler: Union[DDIMScheduler, LCMScheduler],
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
+        scale_invariant: Optional[bool] = True,
+        shift_invariant: Optional[bool] = True,
+        default_denoising_steps: Optional[int] = None,
+        default_processing_resolution: Optional[int] = None,
     ):
         super().__init__()
-
         self.register_modules(
             unet=unet,
             vae=vae,
             scheduler=scheduler,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
         )
+        self.register_to_config(
+            scale_invariant=scale_invariant,
+            shift_invariant=shift_invariant,
+            default_denoising_steps=default_denoising_steps,
+            default_processing_resolution=default_processing_resolution,
+        )
+
+        self.scale_invariant = scale_invariant
+        self.shift_invariant = shift_invariant
+        self.default_denoising_steps = default_denoising_steps
+        self.default_processing_resolution = default_processing_resolution
 
         self.empty_text_embed = None
 
     @torch.no_grad()
     def __call__(
         self,
         input_image: Union[Image.Image, torch.Tensor],
-        denoising_steps: int = 10,
-        ensemble_size: int = 10,
-        processing_res: int = 768,
+        denoising_steps: Optional[int] = None,
+        ensemble_size: int = 5,
+        processing_res: Optional[int] = None,
         match_input_res: bool = True,
         resample_method: str = "bilinear",
         batch_size: int = 0,
```
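A note on why the four values are both registered to the config and mirrored as attributes: in diffusers, `register_to_config` (from `ConfigMixin`) makes constructor arguments part of the serialized pipeline config, so they persist through `save_pretrained` and are passed back to `__init__` by `from_pretrained`; the plain attributes just expose them for direct reads, as the new `logging.info` in `infer.py` does. A rough sketch of the round trip, assuming a checkpoint whose config already includes the new fields:

```python
from marigold import MarigoldPipeline

# The new values are read from the checkpoint config and passed to __init__
# automatically by diffusers' ConfigMixin machinery.
pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-lcm-v1-0")
print(pipe.scale_invariant, pipe.shift_invariant)
print(pipe.default_denoising_steps, pipe.default_processing_resolution)

# Saving writes them back into the config, so the defaults travel with
# the checkpoint.
pipe.save_pretrained("./marigold-lcm-copy")  # hypothetical local path
```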
```diff
@@ -131,18 +165,21 @@ def __call__(
         Args:
             input_image (`Image`):
                 Input RGB (or gray-scale) image.
-            processing_res (`int`, *optional*, defaults to `768`):
-                Maximum resolution of processing.
-                If set to 0: will not resize at all.
+            denoising_steps (`int`, *optional*, defaults to `None`):
+                Number of denoising diffusion steps during inference. The default value `None` results in automatic
+                selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
+                for Marigold-LCM models.
+            ensemble_size (`int`, *optional*, defaults to `10`):
+                Number of predictions to be ensembled.
+            processing_res (`int`, *optional*, defaults to `None`):
+                Effective processing resolution. When set to `0`, processes at the original image resolution. This
+                produces crisper predictions, but may also lead to the overall loss of global context. The default
+                value `None` resolves to the optimal value from the model config.
             match_input_res (`bool`, *optional*, defaults to `True`):
                 Resize depth prediction to match input resolution.
                 Only valid if `processing_res` > 0.
             resample_method: (`str`, *optional*, defaults to `bilinear`):
                 Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
-            denoising_steps (`int`, *optional*, defaults to `10`):
-                Number of diffusion denoising steps (DDIM) during inference.
-            ensemble_size (`int`, *optional*, defaults to `10`):
-                Number of predictions to be ensembled.
             batch_size (`int`, *optional*, defaults to `0`):
                 Inference batch size, no bigger than `num_ensemble`.
                 If set to 0, the script will automatically decide the proper batch size.
@@ -152,6 +189,10 @@ def __call__(
                 Display a progress bar of diffusion denoising.
             color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
                 Colormap used to colorize the depth map.
+            scale_invariant (`str`, *optional*, defaults to `True`):
+                Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction.
+            shift_invariant (`str`, *optional*, defaults to `True`):
+                Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False, near plane will be fixed at 0m.
             ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                 Arguments for detailed ensembling settings.
         Returns:
@@ -161,6 +202,12 @@ def __call__(
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
                 coming from ensembling. None if `ensemble_size = 1`
         """
+        # Model-specific optimal default values leading to fast and reasonable results.
+        if denoising_steps is None:
+            denoising_steps = self.default_denoising_steps
+        if processing_res is None:
+            processing_res = self.default_processing_resolution
+
         assert processing_res >= 0
         assert ensemble_size >= 1
```

```diff
@@ -175,14 +222,15 @@ def __call__(
             input_image = input_image.convert("RGB")
             # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
             rgb = pil_to_tensor(input_image)
+            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
         elif isinstance(input_image, torch.Tensor):
-            rgb = input_image.squeeze()
+            rgb = input_image
         else:
             raise TypeError(f"Unknown input type: {type(input_image) = }")
         input_size = rgb.shape
         assert (
-            3 == rgb.dim() and 3 == input_size[0]
-        ), f"Wrong input shape {input_size}, expected [rgb, H, W]"
+            4 == rgb.dim() and 3 == input_size[-3]
+        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
 
         # Resize image
         if processing_res > 0:
```
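This is a small breaking change for tensor inputs: the pipeline used to `squeeze()` whatever it got and assert `[rgb, H, W]`; it now asserts a 4-D `[1, rgb, H, W]` batch. A quick sketch of the caller-side adjustment (shapes illustrative):

```python
import torch

rgb = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # [rgb, H, W]
rgb = rgb.unsqueeze(0)  # -> [1, rgb, H, W], what the pipeline now expects

# Mirrors the new assertion in __call__:
assert 4 == rgb.dim() and 3 == rgb.shape[-3]
```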
```diff
@@ -199,7 +247,7 @@ def __call__(
 
         # ----------------- Predicting depth -----------------
         # Batch repeated input image
-        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
+        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
         single_rgb_dataset = TensorDataset(duplicated_rgb)
         if batch_size > 0:
             _bs = batch_size
```
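The `torch.stack` to `Tensor.expand` switch is a memory optimization: `expand` broadcasts the singleton batch dimension as a zero-copy view instead of materializing `ensemble_size` copies. A small demonstration:

```python
import torch

rgb_norm = torch.rand(1, 3, 768, 768)  # normalized input, [1, rgb, H, W]

batch = rgb_norm.expand(10, -1, -1, -1)  # [10, 3, 768, 768], no data copied
print(batch.shape)
print(batch.data_ptr() == rgb_norm.data_ptr())  # True: same underlying storage

# The old torch.stack([rgb_norm] * 10) allocated ten full copies up front.
```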
```diff
@@ -231,35 +279,36 @@ def __call__(
                 generator=generator,
             )
             depth_pred_ls.append(depth_pred_raw.detach())
-        depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
+        depth_preds = torch.concat(depth_pred_ls, dim=0)
         torch.cuda.empty_cache()  # clear vram cache for ensembling
 
         # ----------------- Test-time ensembling -----------------
         if ensemble_size > 1:
-            depth_pred, pred_uncert = ensemble_depths(
-                depth_preds, **(ensemble_kwargs or {})
+            depth_pred, pred_uncert = ensemble_depth(
+                depth_preds,
+                scale_invariant=self.scale_invariant,
+                shift_invariant=self.shift_invariant,
+                max_res=50,
+                **(ensemble_kwargs or {}),
             )
         else:
             depth_pred = depth_preds
             pred_uncert = None
 
-        # ----------------- Post processing -----------------
-        # Scale prediction to [0, 1]
-        min_d = torch.min(depth_pred)
-        max_d = torch.max(depth_pred)
-        depth_pred = (depth_pred - min_d) / (max_d - min_d)
-
         # Resize back to original resolution
         if match_input_res:
             depth_pred = resize(
-                depth_pred.unsqueeze(0),
-                input_size[1:],
+                depth_pred,
+                input_size[-2:],
                 interpolation=resample_method,
                 antialias=True,
-            ).squeeze()
+            )
 
         # Convert to numpy
+        depth_pred = depth_pred.squeeze()
         depth_pred = depth_pred.cpu().numpy()
+        if pred_uncert is not None:
+            pred_uncert = pred_uncert.squeeze().cpu().numpy()
 
         # Clip output range
         depth_pred = depth_pred.clip(0, 1)
```
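The deleted min-max post-processing is not lost: normalization moves into `ensemble_depth` in `marigold/util/ensemble.py`, which can use the invariance flags when aligning ensemble members (`max_res=50` presumably caps the resolution at which that alignment runs). A conceptual sketch of the idea, deliberately simpler than the repo's actual implementation:

```python
import torch

def ensemble_depth_sketch(preds: torch.Tensor):
    """preds: [N, 1, H, W] raw predictions, each valid up to scale and shift.

    Aligns members by per-prediction affine normalization, then takes the
    pixelwise median as consensus and the MAD as an uncertainty proxy.
    """
    n = preds.shape[0]
    flat = preds.reshape(n, -1)
    lo = flat.min(dim=1, keepdim=True).values
    hi = flat.max(dim=1, keepdim=True).values
    aligned = ((flat - lo) / (hi - lo + 1e-6)).reshape_as(preds)  # each in [0, 1]
    depth = aligned.median(dim=0).values                   # [1, H, W] consensus
    uncert = (aligned - depth).abs().median(dim=0).values  # [1, H, W] MAD
    return depth, uncert
```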
