Describe the bug
I believe the implementation of the logit-normal weighting is incorrect:
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
Since the timestep is sampled uniformly first, the weighting function should be a function of that pre-sampled timestep. The current version, however, just draws an independent sample from the logit-normal distribution, which is unrelated to the sampled timestep.
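For reference, here is a minimal sketch of what I would expect instead. The helper `logit_normal_weight` and its default mean/std values are my own illustration, not code from the script (in the script, `args.logit_mean` / `args.logit_std` would take their place): the timestep is sampled uniformly, and the weight is the logit-normal density evaluated at that timestep.

```python
import math

import torch

# Minimal sketch (an assumption about the intended behavior, not the script's actual
# implementation): sample timesteps uniformly first, then weight each sample by the
# logit-normal density evaluated at its own t, instead of drawing an independent
# logit-normal sample as in the snippet above.
def logit_normal_weight(t, logit_mean=0.0, logit_std=1.0):
    # Density of the logit-normal distribution at t in (0, 1):
    # pdf(t) = N(logit(t); logit_mean, logit_std) / (t * (1 - t))
    logit_t = torch.log(t) - torch.log1p(-t)
    normal_pdf = torch.exp(-0.5 * ((logit_t - logit_mean) / logit_std) ** 2) / (
        logit_std * math.sqrt(2.0 * math.pi)
    )
    return normal_pdf / (t * (1.0 - t))

bsz = 4
t = torch.rand(bsz).clamp(1e-5, 1 - 1e-5)  # uniformly pre-sampled timesteps
weights = logit_normal_weight(t)           # weight is a function of the sampled t
```

This way the weight is a deterministic function of the pre-sampled timestep rather than an unrelated random draw.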
Reproduction
no
Logs
No response
System Info
no
Who can help?
cc @DN6 @yiyixuxu @sayakpaul