Add new community pipeline for 'Adaptive Mask Inpainting', introduced in [ECCV2024] ComA #9228

Merged
180 changes: 180 additions & 0 deletions adaptive_mask_inpainting_example.py
@@ -0,0 +1,180 @@
message = """

Example Demo of Adaptive Mask Inpainting

Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models
Kim et al.
ECCV-2024 (Oral)


Please prepare the environment via

```
conda create --name ami python=3.9
conda activate ami

conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
pip install easydict
pip install diffusers==0.20.2 accelerate safetensors transformers
pip install setuptools==59.5.0
pip install opencv-python
```


Place this file in the root of the diffusers repository (e.g., as '/home/username/diffusers/adaptive_mask_inpainting_example.py') and run it with Python.




"""
print(message)


import numpy as np
import torch
from easydict import EasyDict
from PIL import Image


from diffusers import DDIMScheduler
from diffusers import DiffusionPipeline
from diffusers.utils import load_image


from examples.community.adaptive_mask_inpainting import (
    download_file,
    AdaptiveMaskInpaintPipeline,
    PointRendPredictor,
    MaskDilateScheduler,
    ProvokeScheduler,
)




if __name__ == "__main__":
    """
    Download Necessary Files

@asomoza (Member), Sep 19, 2024:

All of this example should go inside the community pipelines README. This file can't go in the root of the project, and in general we don't use a separate example file for each community pipeline.

    """
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/model_final_edd263.pkl?download=true",
        output_file="model_final_edd263.pkl",
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/pointrend_rcnn_R_50_FPN_3x_coco.yaml?download=true",
        output_file="pointrend_rcnn_R_50_FPN_3x_coco.yaml",
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_img.png?download=true",
        output_file="input_img.png",
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_mask.png?download=true",
        output_file="input_mask.png",
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-PointRend-RCNN-FPN.yaml?download=true",
        output_file="Base-PointRend-RCNN-FPN.yaml",
    )
    download_file(
        url="https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-RCNN-FPN.yaml?download=true",
        output_file="Base-RCNN-FPN.yaml",
    )

    """
    Prepare Adaptive Mask Inpainting Pipeline
    """
    # device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    num_steps = 50

    # Scheduler
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False

Member:

You can simplify this by just changing the scheduler on the pipeline; there is no need to add the same config manually.

Contributor Author:

Sorry, I'm fairly new to contributing to Hugging Face. What do you mean by changing the scheduler? Do you mean changing the config?

Member:

Oh, no problem. Normally we can replace the scheduler with any other scheduler like this:

    from diffusers import DDIMScheduler
    ....
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

which is cleaner and simpler.

Member:

And if it needs a different config (which doesn't seem to be the case here), you can just add the changed param as an additional kwarg.
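
To make the suggestion in this thread concrete, here is a minimal sketch. It assumes only that the pipeline exposes a `scheduler` attribute whose `config` can be handed to `from_config`, as in the snippet above. `swap_scheduler` is a hypothetical helper name and is not part of diffusers; any keyword overrides are simply forwarded to `from_config`.

```python
def swap_scheduler(pipe, scheduler_cls, **overrides):
    """Rebuild pipe.scheduler as scheduler_cls from the current scheduler config.

    Hypothetical helper (not a diffusers API). Extra keyword arguments are
    forwarded to from_config so individual config values can be overridden.
    """
    pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config, **overrides)
    return pipe


# Intended use with diffusers (shown as comments, not executed here):
#   from diffusers import DDIMScheduler
#   pipeline = swap_scheduler(pipeline, DDIMScheduler)
#   pipeline = swap_scheduler(pipeline, DDIMScheduler, set_alpha_to_one=False)
```

Whether the checkpoint's stored scheduler config already matches the values written by hand in the diff is something to verify against the model repository; if it does not, the differing values would be passed as overrides.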

    )
    scheduler.set_timesteps(num_inference_steps=num_steps)

    ## load models as pipelines
    pipeline = AdaptiveMaskInpaintPipeline.from_pretrained(
        "Uminosachi/realisticVisionV51_v51VAE-inpainting",
        scheduler=scheduler,
        torch_dtype=torch.float16,
        requires_safety_checker=False
    ).to(device)

    ## disable safety checker
    enable_safety_checker = False
    if not enable_safety_checker:
        pipeline.safety_checker = None

    # declare segmentation model used for mask adaptation
    use_visualizer = False
    assert not use_visualizer, \
        """
        If you plan to 'use_visualizer', USE WITH CAUTION.
        It creates a directory of images and masks, which is used for merging into a video.
        The procedure involves deleting the directory of images, which means that
        if you set the directory wrong you can have other important files blown away.
        """

    adaptive_mask_model = PointRendPredictor(
        pointrend_thres=0.2,
        device="cuda" if torch.cuda.is_available() else "cpu",
        use_visualizer=use_visualizer,
        config_pth="pointrend_rcnn_R_50_FPN_3x_coco.yaml",
        weights_pth="model_final_edd263.pkl",
    )
    pipeline.register_adaptive_mask_model(adaptive_mask_model)

@asomoza (Member), Sep 19, 2024:

You can probably remove this from here and add it as code inside the pipeline to make it easier for the user. But it's your decision; if you do, you should probably download the files inside the pipeline too.
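
If the downloads do move inside the pipeline, one possible shape is a helper that fetches only the files that are missing. This is a sketch under assumptions, not the pipeline's actual code: `ensure_files`, `BASE_URL`, and `POINTREND_FILES` are hypothetical names, the URL pattern is copied from the `download_file` calls in the diff, and the `fetch` parameter defaults to the standard library's `urlretrieve` instead of the PR's `download_file`.

```python
from pathlib import Path
from urllib.request import urlretrieve

# URL prefix taken from the download_file calls in the diff.
BASE_URL = (
    "https://huggingface.co/datasets/jellyheadnadrew/"
    "adaptive-mask-inpainting-test-images/resolve/main"
)

# Assumed set of files the PointRend predictor needs (config, base configs, weights).
POINTREND_FILES = (
    "model_final_edd263.pkl",
    "pointrend_rcnn_R_50_FPN_3x_coco.yaml",
    "Base-PointRend-RCNN-FPN.yaml",
    "Base-RCNN-FPN.yaml",
)


def ensure_files(names, target_dir=".", base_url=BASE_URL, fetch=urlretrieve):
    """Download each file that is not already in target_dir; return local paths."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    paths = []
    for name in names:
        path = target / name
        if not path.exists():
            # Only hit the network for files that are missing locally.
            fetch(f"{base_url}/{name}?download=true", str(path))
        paths.append(path)
    return paths
```

Calling `ensure_files(POINTREND_FILES)` before constructing the predictor would make repeated runs skip files that are already present.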


    step_num = int(num_steps * 0.1)
    final_step_num = num_steps - step_num * 7
    # adaptive mask settings
    adaptive_mask_settings = EasyDict(
        dict(
            dilate_scheduler=MaskDilateScheduler(
                max_dilate_num=20,
                num_inference_steps=num_steps,
                schedule=[20] * step_num + [10] * step_num + [5] * step_num + [4] * step_num + [3] * step_num + [2] * step_num + [1] * step_num + [0] * final_step_num
            ),
            dilate_kernel=np.ones((3, 3), dtype=np.uint8),
            provoke_scheduler=ProvokeScheduler(
                num_inference_steps=num_steps,
                schedule=list(range(2, 10 + 1, 2)) + list(range(12, 40 + 1, 2)) + [45],
                is_zero_indexing=False,
            ),
        )
    )
    pipeline.register_adaptive_mask_settings(adaptive_mask_settings)

Member:

Same as before.

Contributor Author:

Thanks for the thorough feedback! I've applied the suggestions and will open a new PR.

(Please check my minor questions above.)
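
For readers checking the settings discussed in this thread, the two schedule expressions from the diff can be expanded in plain Python. This sketch only reproduces the arithmetic with `num_steps = 50`; `dilate_levels` is a name introduced here for illustration, and how `MaskDilateScheduler` and `ProvokeScheduler` consume the lists is not shown in this diff.

```python
num_steps = 50
step_num = int(num_steps * 0.1)             # 5 steps per dilation level
final_step_num = num_steps - step_num * 7   # remaining steps with no dilation

# Same list as the expression in the diff: each level repeated step_num times,
# then zeros for the remaining steps.
dilate_levels = [20, 10, 5, 4, 3, 2, 1]
dilate_schedule = [lvl for lvl in dilate_levels for _ in range(step_num)]
dilate_schedule += [0] * final_step_num

# Step numbers handed to ProvokeScheduler, copied from the diff.
provoke_schedule = list(range(2, 10 + 1, 2)) + list(range(12, 40 + 1, 2)) + [45]

# The dilation schedule has exactly one entry per inference step.
assert len(dilate_schedule) == num_steps
```

Since `is_zero_indexing=False` is passed in the diff, the provoke entries are presumably 1-based step numbers, though that reading is an assumption about the scheduler's semantics.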


"""
Run Adaptive Mask Inpainting
"""
    default_mask_image = Image.open("./input_mask.png").convert("L")
    init_image = Image.open("./input_img.png").convert("RGB")


    seed = 46
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    image = pipeline(
        prompt="a man sitting on a couch",
        negative_prompt="worst quality, normal quality, low quality, bad anatomy, artifacts, blurry, cropped, watermark, greyscale, nsfw",
        image=init_image,
        default_mask_image=default_mask_image,
        guidance_scale=11.0,
        strength=0.98,
        use_adaptive_mask=True,
        generator=generator,
        enforce_full_mask_ratio=0.0,
        visualization_save_dir="./ECCV2024_adaptive_mask_inpainting_demo",  # DON'T EVER CHANGE THIS!!!
        human_detection_thres=0.015,
    ).images[0]


    image.save("final_img.png")