Skip to content

Commit

Permalink
realizing jax is only for linux
Browse files Browse the repository at this point in the history
  • Loading branch information
codesavory committed May 7, 2023
1 parent 066d085 commit e85ef3c
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,30 +1,108 @@

import carb
import torch
import asyncio
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from PIL import Image
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionImg2ImgPipeline

async def generateImage(progress_widget, image_prompt: str, image_widget):
carb.log_info("Stable Diffusion Stage")
async def generateTextToImage(progress_widget, outputImage_widget, image_prompt: str):
carb.log_info("Stable Diffusion Stage: Text to Image")

if (len(image_prompt) != 0):
run_loop = asyncio.get_event_loop()
progress_widget.show_bar(True)
task = run_loop.create_task(progress_widget.play_anim_forever())

print("creating image with prompt: "+image_prompt)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16)
model_id = "runwayml/stable-diffusion-v1-5"
#model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe.to("cuda")
#prompt = "a photograph of an astronaut riding a horse"
prompt = image_prompt
image = pipe(prompt).images[0]
image = pipe(image_prompt).images[0]
# you can save the image with
image_url = "D:\\CG_Source\\Omniverse\\Extensions\\3DAvatarExtensionPOC\\stable3D\\"+prompt.replace(" ", "")+".png"
image_url = "D:\\CG_Source\\Omniverse\\Extensions\\3DAvatarExtensionPOC\\stable3D\\"+image_prompt.replace(" ", "")+".png"
image.save(image_url)
print("image created")

task.cancel()
await asyncio.sleep(1)
# todo bug fix: reload image if same prompt
image_widget.source_url = image_url
progress_widget.show_bar(False)
outputImage_widget.source_url = image_url
progress_widget.show_bar(False)

async def generateImageToImage(progress_widget, outputImage_widget, image_prompt: str, inputImageUrl):
carb.log_info("Stable Diffusion Stage: Image to Image")

if (len(image_prompt) != 0):
run_loop = asyncio.get_event_loop()
progress_widget.show_bar(True)
task = run_loop.create_task(progress_widget.play_anim_forever())

print("creating image with prompt+image: "+image_prompt)
model_id = "runwayml/stable-diffusion-v1-5"
#model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe.to("cuda")

init_image = Image.open(inputImageUrl).convert("RGB")
init_image = init_image.resize((768, 512))
image = pipe(prompt=image_prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]

image_url = "D:\\CG_Source\\Omniverse\\Extensions\\3DAvatarExtensionPOC\\stable3D\\"+image_prompt.replace(" ", "")+".png"
image.save(image_url)
print("image created")

task.cancel()
await asyncio.sleep(1)
# todo bug fix: reload image if same prompt
outputImage_widget.source_url = image_url
progress_widget.show_bar(False)

def create_key(seed=0):
return jax.random.PRNGKey(seed)

def generateImageToImage2(progress_widget, outputImage_widget, image_prompt: str, inputImageUrl):
rng = create_key(0)

init_img = Image.open(inputImageUrl).convert("RGB")
init_img = init_img.resize((768, 512))

prompts = image_prompt

pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="flax",
dtype=jnp.bfloat16,
)

num_samples = jax.device_count()
rng = jax.random.split(rng, jax.device_count())
prompt_ids, processed_image = pipeline.prepare_inputs(
prompt=[prompts] * num_samples, image=[init_img] * num_samples
)
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)

output = pipeline(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
strength=0.75,
num_inference_steps=50,
jit=True,
height=512,
width=768,
).images

output_image = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))[0]
image_url = "D:\\CG_Source\\Omniverse\\Extensions\\3DAvatarExtensionPOC\\stable3D\\"+image_prompt.replace(" ", "")+".png"
output_image.save(image_url)
print("image created")
36 changes: 34 additions & 2 deletions exts/suriya.avatar.generator/suriya/avatar/generator/window.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#import omni.ext
import omni.ui as ui
import asyncio
from .stablediffusion import generateImage
from .stablediffusion import generateTextToImage, generateImageToImage
from .widgets import ProgressBar

class AvatarWindow(ui.Window):
Expand All @@ -14,7 +14,34 @@ def __init__(self, title: str, **kwargs) -> None:

def _generate(self):
run_loop = asyncio.get_event_loop()
run_loop.create_task(generateImage(self.progress, self.prompt.as_string, self.stableImage))
run_loop.create_task(generateTextToImage(self.progress, self.stableImage, self.prompt.as_string))

def _generateI(self):
run_loop = asyncio.get_event_loop()
run_loop.create_task(generateImageToImage2(self.progress, self.stableImage, self.prompt.as_string, self.inputImage.source_url))

#-------------------------Drag and Drop Functions----------------------------
def drop_accept(self, url):
print("drop accept")
# accepts drop of specific extension only
return True

def drop(self, widget, event):
print("drop")
# called when dropping the data
widget.source_url = event.mime_data

def drop_area(self):
print("drop area")
# a drop area that shows image when drop
stack = ui.ZStack()
with stack:
ui.Rectangle()
#ui.Label(f"Accepts {ext}")
self.inputImage = ui.Image()

self.inputImage.set_accept_drop_fn(lambda d: self.drop_accept(d))
self.inputImage.set_drop_fn(lambda a, w=self.inputImage: self.drop(w, a))

def _build_fn(self):
with self.frame:
Expand All @@ -23,8 +50,13 @@ def _build_fn(self):
with ui.HStack(height=0):
ui.StringField(model=self.prompt)

with ui.HStack():
self.drop_area()
#self.inputImage

with ui.HStack(height=0):
ui.Button("Generate", clicked_fn=self._generate)
ui.Button("Generate+I", clicked_fn=self._generateI)
self.progress = ProgressBar()
with ui.HStack():
self.stableImage = ui.Image()
Binary file added image archives/imageofapersonsmokingapot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/Flintstones_mustache.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/base.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed skatingonice.png
Binary file not shown.

0 comments on commit e85ef3c

Please sign in to comment.