Skip to content

Commit

Permalink
Sample with DPM-Solver++
Browse files Browse the repository at this point in the history
  • Loading branch information
crowsonkb committed Nov 4, 2022
1 parent 15717b0 commit 21afd8b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
4 changes: 1 addition & 3 deletions sample_clip_guided.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def main():
help='the batch size')
p.add_argument('--checkpoint', type=str, required=True,
help='the checkpoint to use')
p.add_argument('--churn', type=float, default=50.,
help='the amount of noise to add during sampling')
p.add_argument('--clip-guidance-scale', '-cgs', type=float, default=500.,
help='the CLIP guidance scale')
p.add_argument('--clip-model', type=str, default='ViT-B/16', choices=clip.available_models(),
Expand Down Expand Up @@ -115,7 +113,7 @@ def run():
sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
def sample_fn(n):
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigmas[0]
x_0 = K.sampling.sample_dpm_2(model_fn, x, sigmas, s_churn=args.churn, disable=not accelerator.is_local_main_process)
x_0 = K.sampling.sample_dpmpp_2s_ancestral(model_fn, x, sigmas, eta=1., disable=not accelerator.is_local_main_process)
return x_0
x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
if accelerator.is_main_process:
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def demo():
n_per_proc = math.ceil(args.sample_n / accelerator.num_processes)
x = torch.randn([n_per_proc, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process)
x_0 = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, disable=not accelerator.is_main_process)
x_0 = accelerator.gather(x_0)[:args.sample_n]
if accelerator.is_main_process:
grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0)
Expand All @@ -260,7 +260,7 @@ def evaluate():
sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
def sample_fn(n):
x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=True)
x_0 = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, disable=True)
return x_0
fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, args.evaluate_n, args.batch_size)
if accelerator.is_main_process:
Expand Down

0 comments on commit 21afd8b

Please sign in to comment.