Skip to content

Commit

Permalink
Faster stable diffusion (#1364)
Browse files Browse the repository at this point in the history
* stable_diffusion example: don't ignore the arguments

* simplify code

* Lower default iterations to just 10
  • Loading branch information
emilk authored Feb 21, 2023
1 parent 5bdf912 commit 7281184
Showing 1 changed file with 20 additions and 34 deletions.
54 changes: 20 additions & 34 deletions examples/python/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,6 @@
IMAGE_NAMES: Final = list(IMAGE_NAME_TO_URL.keys())


def run_stable_diffusion(
image_path: str, prompt: str, n_prompt: str, strength: float, guidance_scale: float, num_inference_steps: int
) -> None:

pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth", local_files_only=False, cache_dir=CACHE_DIR.absolute()
)

if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
pipe = pipe.to("mps")
elif torch.cuda.is_available():
pipe = pipe.to("cuda")
else:
pipe = pipe.to("cpu")

pipe.enable_attention_slicing()

image = Image.open(image_path)

pipe(
prompt=prompt,
strength=strength,
guidance_scale=11,
negative_prompt=n_prompt,
num_inference_steps=70,
image=image,
)


def get_downloaded_path(dataset_dir: Path, image_name: str) -> str:
image_url = IMAGE_NAME_TO_URL[image_name]
image_file_name = image_url.split("/")[-1]
Expand Down Expand Up @@ -118,7 +89,7 @@ def main() -> None:
parser.add_argument(
"--guidance_scale",
type=float,
default=8,
default=11,
help="""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Expand All @@ -130,7 +101,7 @@ def main() -> None:
parser.add_argument(
"--num_inference_steps",
type=int,
default=50,
default=10,
help="""
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
Expand All @@ -146,13 +117,28 @@ def main() -> None:
if not image_path:
image_path = get_downloaded_path(args.dataset_dir, args.image)

run_stable_diffusion(
image_path=image_path,
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth", local_files_only=False, cache_dir=CACHE_DIR.absolute()
)

if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
pipe = pipe.to("mps")
elif torch.cuda.is_available():
pipe = pipe.to("cuda")
else:
pipe = pipe.to("cpu")

pipe.enable_attention_slicing()

image = Image.open(image_path)

pipe(
prompt=args.prompt,
n_prompt=args.n_prompt,
strength=args.strength,
guidance_scale=args.guidance_scale,
negative_prompt=args.n_prompt,
num_inference_steps=args.num_inference_steps,
image=image,
)

rr.script_teardown(args)
Expand Down

0 comments on commit 7281184

Please sign in to comment.