
Optimizations to PAG and t2i-zero #43

Merged
merged 3 commits into from
May 18, 2024

Conversation

drhead
Contributor

@drhead drhead commented May 17, 2024

I've made a few optimizations to PAG and t2i-zero.

For PAG, you probably won't notice much difference speed-wise, since it's already unavoidably slow. One tensor was being created on the CPU and then moved to the device, which was a problem particularly when combined with some proposed performance optimizations for A1111 (AUTOMATIC1111/stable-diffusion-webui#15821). With this fix, those optimizations no longer break PAG.
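A minimal sketch of that kind of fix (the function and variable names here are illustrative, not the extension's actual code): allocate the tensor on the target device at creation instead of building it on the CPU and copying it over.

```python
import torch

# Before: tensor built on CPU, then copied to the device. The extra
# host-to-device transfer can stall an otherwise asynchronous CUDA stream.
def make_mask_cpu_then_move(attn, device):
    mask = torch.ones(attn.shape[-1])       # created on CPU
    return mask.to(device)                  # separate host-to-device copy

# After: pass device= at creation so the tensor is allocated on-device directly.
def make_mask_on_device(attn, device):
    return torch.ones(attn.shape[-1], device=device)
```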

For t2i-zero, the changes are more substantial. There were a lot of forced device syncs happening. Summarizing the changes:

  • 🗞️💥 A number of scalar values were being initialized as CUDA tensors and then used for control flow. That means aten::item is called on them, which forces a device sync. None of them needed to be CUDA tensors for anything else they were doing, so I made them all CPU tensors, which removes the blocking syncs.
  • 🗞️💥 When token_indices was None or [], the code built a list of all tokens with torch.tensor(list(range(1, token_count.item()))), which causes a device sync for every single token in the sequence (!). torch.arange(1, token_count, device=output.device) is much more suitable and creates the tensor directly on-device.
  • 🗞️💥 This one is Torchvision's fault. GaussianBlur creates its kernels on the CPU every time it runs and then moves them to the GPU, which -- you guessed it -- forces a device sync. I submitted a patch to the torchvision repo; in the meantime, that forced device sync will still be there. As part of troubleshooting, I changed the code so the module is no longer initialized multiple times per inference step. Unfortunately the kernels are still created on every call regardless, so this doesn't fix the problem directly and probably isn't even necessary, but I would keep it: it's best not to initialize modules inside a forward pass anyway. I also suggested caching the kernels upstream, so it may help in the future.
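The first two fixes above can be sketched as follows (a hedged illustration with assumed names like `select_token_indices`, not the extension's exact code):

```python
import torch

def select_token_indices(token_count, device):
    # Before: torch.tensor(list(range(1, token_count.item()))) -- calling
    # .item() on a CUDA tensor forces a device sync, and the Python list
    # grows with sequence length. After: build the index tensor on-device.
    return torch.arange(1, token_count, device=device)

# Scalars used only for control flow should be CPU tensors (or plain Python
# numbers), so branching on them never syncs against the GPU.
step_threshold = torch.tensor(10)          # CPU tensor: cheap to branch on
current_step = 3
if current_step < step_threshold:          # no CUDA device sync here
    indices = select_token_indices(5, torch.device("cpu"))
```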

Once the associated torchvision patch and the A1111 optimizations go through, you can expect t2i-zero to run with no significant overhead compared to having it disabled.

@v0xie
Owner

v0xie commented May 18, 2024

This is great, thanks so much.

I am getting an error when I run with CTNMS enabled, perhaps related to making ema_factor a CPU tensor?

    Traceback (most recent call last):
      File "F:\stablediffusion\stable-diffusion-webui\modules\call_queue.py", line 57, in f
        res = list(func(*args, **kwargs))
      File "F:\stablediffusion\stable-diffusion-webui\modules\call_queue.py", line 36, in f
        res = func(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\modules\txt2img.py", line 109, in txt2img
        processed = processing.process_images(p)
      File "F:\stablediffusion\stable-diffusion-webui\modules\processing.py", line 839, in process_images
        res = process_images_inner(p)
      File "F:\stablediffusion\stable-diffusion-webui\extensions\sd-webui-controlnet\scripts\batch_hijack.py", line 59, in processing_process_images_hijack
        return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\modules\processing.py", line 975, in process_images_inner
        samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
      File "F:\stablediffusion\stable-diffusion-webui\modules\processing.py", line 1322, in sample
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 218, in sample
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_samplers_common.py", line 272, in launch_sampling
        return func()
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 218, in <lambda>
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\sampling.py", line 626, in sample_dpmpp_2m_sde
        denoised = model(x, sigmas[i] * s_in, **extra_args)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_samplers_cfg_denoiser.py", line 237, in forward
        x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\external.py", line 112, in forward
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\external.py", line 138, in get_eps
        return self.inner_model.apply_model(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_models_xl.py", line 43, in apply_model
        return self.model(x, t, cond)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_hijack_utils.py", line 22, in <lambda>
        setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_hijack_utils.py", line 34, in __call__
        return self.__sub_func(self.__orig_func, *args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_hijack_unet.py", line 48, in apply_model
        result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\generative-models\sgm\modules\diffusionmodules\wrappers.py", line 28, in forward
        return self.diffusion_model(
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\modules\sd_unet.py", line 91, in UNetModel_forward
        return original_forward(self, x, timesteps, context, *args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\generative-models\sgm\modules\diffusionmodules\openaimodel.py", line 993, in forward
        h = module(h, emb, context)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\generative-models\sgm\modules\diffusionmodules\openaimodel.py", line 100, in forward
        x = layer(x, context)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\generative-models\sgm\modules\attention.py", line 627, in forward
        x = block(x, context=context[i])
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\generative-models\sgm\modules\attention.py", line 459, in forward
        return checkpoint(
      File "F:\stablediffusion\stable-diffusion-webui\repositories\generative-models\sgm\modules\diffusionmodules\util.py", line 167, in checkpoint
        return func(*inputs)
      File "F:\stablediffusion\stable-diffusion-webui\repositories\generative-models\sgm\modules\attention.py", line 478, in _forward
        self.attn2(
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "F:\stablediffusion\stable-diffusion-webui\venv\lib\site-packages\torch\nn\modules\module.py", line 1579, in _call_impl
        hook_result = hook(self, args, kwargs, result)
      File "F:\stablediffusion\stable-diffusion-webui\extensions\sd-webui-incantations\scripts\t2i_zero.py", line 518, in cross_token_non_maximum_suppression
        ema = ema_factor * ema + (1 - ema_factor) * suppressed_attention_map
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@drhead
Contributor Author

drhead commented May 18, 2024

> I am getting an error when I run with CTNMS enabled, perhaps related to making ema_factor cpu?

Yeah... those are probably better off as Python scalars. I don't have time to test right now, but what I just committed should be correct. I can look it over again tomorrow if needed.
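The fix being described, sketched under assumptions (the tensors here are stand-ins for the real attention maps): keep ema_factor as a plain Python float, so the EMA update broadcasts onto whatever device the attention map lives on and can never raise the cross-device error from the traceback.

```python
import torch

ema_factor = 0.75                          # plain Python scalar: device-agnostic
ema = torch.zeros(4)                       # stands in for the running EMA state
suppressed_attention_map = torch.ones(4)   # stands in for the attention map

# The same update as t2i_zero.py line 518, now safe on CPU or CUDA because
# multiplying a tensor by a Python float never introduces a second device.
ema = ema_factor * ema + (1 - ema_factor) * suppressed_attention_map
```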

@v0xie
Owner

v0xie commented May 18, 2024

Works great. Thanks again!

@v0xie v0xie merged commit 1dd3b2b into v0xie:master May 18, 2024