Skip to content

Commit

Permalink
Proper fix of SD15 dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed May 17, 2024
1 parent dca9007 commit b57a70f
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def repair_config(sd_config):
if hasattr(sd_config.model.params, 'unet_config'):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling:
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True

if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
Expand Down Expand Up @@ -733,10 +733,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = instantiate_from_config(sd_config.model)

sd_model.used_config = checkpoint_config
# ldm's Unet is using self.dtype to cast input tensor. If we do not overwrite
# UnetModel.dtype, it will be the default dtype from config.
# sgm's Unet is not using dtype for casting. The value will be ignored.
sd_model.model.diffusion_model.dtype = devices.dtype_unet

timer.record("create model")

Expand Down

0 comments on commit b57a70f

Please sign in to comment.