[SD3] the training script of SD3 dreambooth has wrong logit-normal weighting #8534

@Luciennnnnnn

Describe the bug

I think the implementation of the logit-normal weighting is wrong:

u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)

Since the timestep is sampled uniformly first, the weighting should be a function of that pre-sampled timestep. The current version, however, just draws a fresh random sample from a logit-normal distribution, independent of the timestep actually used.
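To make the point concrete, here is a minimal sketch of the two consistent options as I understand them: either keep uniform timestep sampling and weight each sample by the logit-normal density evaluated at that timestep, or draw the timestep itself from the logit-normal distribution and use a unit weight. The function names are hypothetical (they are not from the training script), and the code is scalar and stdlib-only for clarity; the script would do the same with batched tensors.

```python
import math
import random


def logit_normal_pdf(t, mean=0.0, std=1.0):
    """Density of the logit-normal distribution at t in (0, 1)."""
    logit_t = math.log(t / (1.0 - t))
    norm = 1.0 / (std * math.sqrt(2.0 * math.pi))
    return norm / (t * (1.0 - t)) * math.exp(-((logit_t - mean) ** 2) / (2.0 * std ** 2))


def sample_logit_normal_timestep(rng, mean=0.0, std=1.0):
    """Draw t in (0, 1) by pushing a normal sample through a sigmoid."""
    u = rng.gauss(mean, std)
    return 1.0 / (1.0 + math.exp(-u))


rng = random.Random(0)

# Option A: uniform timestep, weight depends on that same timestep.
t_uniform = rng.uniform(1e-3, 1.0 - 1e-3)
weight_a = logit_normal_pdf(t_uniform)

# Option B: timestep drawn from the logit-normal distribution, unit weight.
t_logit_normal = sample_logit_normal_timestep(rng)
weight_b = 1.0
```

In both options the weight is tied to the timestep that is actually used to noise the sample, which is what the current code does not do.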

Reproduction

no

Logs

No response

System Info

no

Who can help?

cc @DN6 @yiyixuxu @sayakpaul
