Val loss improvement #1903
base: sd3
Conversation
kohya-ss commented Jan 27, 2025 (edited)
- train/eval state for the network and the optimizer.
- stable timesteps
- stable noise
- support block swap
I love the approach of holding the rng_state aside, setting the validation state using the validation seed, and then restoring the rng_state afterwards. It's much more elegant than tracking the state separately and has no overhead.
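As a hedged illustration of that pattern (a minimal sketch, not the PR's actual code; the function and variable names here are made up):

```python
import torch

def run_validation(validation_seed, validation_step_fn, val_dataloader):
    # Hold the training RNG state aside, seed deterministically for validation,
    # then restore the training state so validation never perturbs training.
    rng_state = torch.get_rng_state()
    torch.manual_seed(validation_seed)

    losses = []
    with torch.no_grad():
        for batch in val_dataloader:
            losses.append(validation_step_fn(batch).item())

    torch.set_rng_state(rng_state)
    return sum(losses) / max(len(losses), 1)
```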
I would also add that once this is in place, there won't be a need for a moving average to track the validation loss. Using consistent timesteps and noise will make it almost entirely stable, so displaying the mean of the validation losses for each validation run should be all that's needed. Since the validation set is subject to change if the core dataset changes, I've found that tracking the validation loss relative to the initial loss also helps make progress comparable across different training runs.
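A tiny sketch of that relative-to-initial idea (hypothetical helper, not code from this PR):

```python
class RelativeValLossTracker:
    """Report validation loss relative to the first recorded value so
    curves from different datasets or runs are roughly comparable."""

    def __init__(self):
        self.initial_loss = None

    def update(self, mean_val_loss: float) -> float:
        if self.initial_loss is None:
            self.initial_loss = mean_val_loss
        # Values below 1.0 mean the model has improved on its starting point.
        return mean_val_loss / self.initial_loss
```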
This looks great! What are you using for formatting the code? I've been formatting manually, but it might be easier to align the formatting if I use the same tool.
That makes sense. Currently, there is a problem viewing logs in TensorBoard, but I would like to at least get the mean of the validation loss to be displayed correctly.
For formatting, I use black with the
It seems that correcting for timestep sampling works better (I previously used debiased 1/√SNR, which is similar in meaning). Additionally, I have some thoughts on the args.
https://github.com/spacepxl/demystifying-sd-finetuning
This makes some sense.
Although it means giving multiple meanings to a single setting value, it is worth considering.
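For reference, a rough sketch of what a debiased 1/√SNR-style weighting mentioned above could look like, assuming `alphas_cumprod` is the scheduler's cumulative alpha table (illustrative only, not necessarily the exact formulation used in sd-scripts):

```python
import torch

def debiased_snr_weight(timesteps: torch.Tensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); weighting the per-sample MSE by
    # 1/sqrt(SNR) downweights low-noise (high-SNR) timesteps.
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    return 1.0 / torch.sqrt(snr)
```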
@gesen2egee you would need a different fit equation for each new model, and it's not really relevant when you make validation fully deterministic. I've tried applying it to training loss and it was extremely harmful. You can also visualize the raw training loss by plotting loss against timestep, colored by training step. That was done by storing all loss and timestep values. I'm not sure if there's a way to do that natively in tensorboard/wandb; I did this with matplotlib and just logged it as an image.
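A sketch of that plotting approach: a scatter of loss against timestep, colored by training step, rendered with matplotlib and logged to TensorBoard as an image (the helper name and surrounding setup are my assumptions, not the commenter's exact code; `writer` is a torch.utils.tensorboard SummaryWriter):

```python
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen, no display needed
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def log_loss_vs_timestep(writer, timesteps, losses, steps, global_step):
    """timesteps/losses/steps are parallel lists accumulated during training."""
    fig, ax = plt.subplots(figsize=(8, 5))
    sc = ax.scatter(timesteps, losses, c=steps, s=4, cmap="viridis")
    fig.colorbar(sc, ax=ax, label="training step")
    ax.set_xlabel("timestep")
    ax.set_ylabel("loss")

    # Rasterize the figure and hand it to TensorBoard as an HWC uint8 image.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    image = np.array(Image.open(buf).convert("RGB"))
    writer.add_image("loss_vs_timestep", image, global_step, dataformats="HWC")
```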
In get_timesteps, maybe:

```python
if min_timestep < max_timestep:
    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
else:
    timesteps = torch.ones(b_size, device="cpu") * min_timestep
```

I know this isn't completed, but I tried it anyways.
To add to this, here's a rough example of what the loss per timestep looks like on Flux at 1024x1024:
I just realized that if there are any repeats in the dataset, validation error will not be calculated properly, as the images used for validation can still exist in the training data. Guidance should likely be provided that repeats should not be used when validation error is being calculated; otherwise it will not be a useful metric for identifying overfitting. Since the validation set is now going through a dataloader, it might also be easier to set up a separate directory of validation images to use.
Repeats are done after the split. I also made sure regularization images do not go into the validation dataset when splitting.
Nice! That's good to hear!
I am not able to get a stable loss from this PR. I calculate the average loss on each validation run and graph the results; the line is quite volatile, and it has been on all of the attempts I've tried so far. It could be user error. Here are the arguments I'm using:

--validation_seed 43 --validation_split .05 --validate_every_n_steps 100

To calculate the error per cycle:
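As a rough, hypothetical sketch (not the commenter's actual snippet), the per-cycle value could simply be the mean of that run's per-sample losses:

```python
def mean_validation_loss(per_sample_losses):
    # One point per validation run: the plain mean of that cycle's per-sample
    # losses. With a fixed seed, fixed timesteps, and fixed noise, this should
    # be stable as long as the validation inputs themselves are deterministic.
    return sum(per_sample_losses) / len(per_sample_losses)
```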
I also tried printing the individual losses, and based on their magnitudes it looks like the same images are being run each cycle. The timesteps are being set manually, so that's not the source of the variance, which means it must be the noise that is varying from one iteration to the next. I'll do some additional testing tomorrow to identify whether it is the noise and, if so, why it would vary in consecutive runs even though the same seed is used.
In SDXL training, the random numbers are generated using the device, so this may be the cause (see sd-scripts/library/train_util.py, line 5945 at commit 45ec02b).
Also, torch.manual_seed initializes the random seed for all devices, but rng_state seems to only work for the CPU (https://pytorch.org/docs/stable/generated/torch.manual_seed.html#torch.manual_seed). Although it would be a breaking change, it would be better to unify the way random numbers are generated.
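To illustrate that point with a minimal sketch (not code from this PR), a device-aware save/restore would have to capture the CUDA generators in addition to the CPU one:

```python
import torch

def snapshot_rng():
    # torch.get_rng_state() covers only the CPU generator, so the CUDA
    # generators must be captured separately.
    state = {"cpu": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda"] = torch.cuda.get_rng_state_all()
    return state

def restore_rng(state):
    torch.set_rng_state(state["cpu"])
    if "cuda" in state:
        torch.cuda.set_rng_state_all(state["cuda"])
```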
I probably should have specified, but I am doing Flux LoRAs in this case, not SDXL. Also, if manual_seed is setting it for all devices, then it should be providing the desired consistency, so it's odd that it isn't. I'll do some digging to confirm whether what I'm experiencing is even due to the noise, or whether it's something else. There is always the possibility of caching the noise and reusing it on repeat runs, which doesn't interfere with any random state. Based on my initial testing, this didn't seem to have significant overhead, but it isn't quite as pretty as just setting the seed.
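A minimal sketch of that noise-caching idea (hypothetical helper, keyed by validation sample index; not the PR's implementation):

```python
import torch

_val_noise_cache = {}

def get_validation_noise(sample_index, shape, device, dtype, validation_seed=43):
    # Generate the noise once per validation sample with its own generator and
    # reuse it on every later validation run; this sidesteps global RNG state.
    if sample_index not in _val_noise_cache:
        generator = torch.Generator(device="cpu").manual_seed(validation_seed + sample_index)
        _val_noise_cache[sample_index] = torch.randn(shape, generator=generator)
    return _val_noise_cache[sample_index].to(device=device, dtype=dtype)
```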
After some digging, the noise is absolutely consistent over time, but the noisy_model_input is not. I'm trying to understand why it is changing from one iteration to the next.
That's the culprit. It wasn't the noise changing, it was the latents changing! Digging deeper: the random cropping comes from `random.randint(0, range)`, i.e. Python's random module rather than torch, so setting manual_seed doesn't fix the cropping in place. I'm going to turn it off and do a longer run to confirm. If you want to test it yourself, you'll need the fix from 8dfb42f to allow Flux to run with validation and without latents cached. Turning it off leads to a perfectly smooth and beautiful loss curve.
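If random cropping were kept on during validation, the Python random module's state would also need to be pinned alongside torch's; a hedged sketch of what that could look like (illustrative, not the PR's implementation):

```python
import random

import torch

def seed_validation_rng(validation_seed):
    # torch.manual_seed does not touch Python's `random`, which is what the
    # random-crop path uses, so both states are saved, seeded, and later restored.
    py_state = random.getstate()
    torch_state = torch.get_rng_state()
    random.seed(validation_seed)
    torch.manual_seed(validation_seed)
    return py_state, torch_state

def restore_training_rng(py_state, torch_state):
    random.setstate(py_state)
    torch.set_rng_state(torch_state)
```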
…or consistent validation
Fixed timestep generation in SD/SDXL.
Fixed the issue where the Python random seed was not set. I think it's probably ready to be merged, so please let me know if you have any suggestions.
For the main progress_bar, it may help to use tqdm's unpause (https://tqdm.github.io/docs/tqdm/#unpause). Otherwise it has been working for me, and block swap is working as well.
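Assuming the suggestion is to call unpause right after each validation pass so the pause doesn't count toward the main bar's timing (my reading, not stated explicitly), a runnable sketch with stand-in steps:

```python
import time

from tqdm import tqdm

def run_validation():
    time.sleep(0.5)  # stand-in for a (comparatively slow) validation pass

max_train_steps, validate_every_n_steps = 20, 5
progress_bar = tqdm(total=max_train_steps, desc="steps")

for step in range(max_train_steps):
    time.sleep(0.05)  # stand-in for one training step
    progress_bar.update(1)

    if (step + 1) % validate_every_n_steps == 0:
        run_validation()
        # Don't count the validation pause toward the main bar's elapsed time.
        progress_bar.unpause()

progress_bar.close()
```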
Not at my computer today to test, but were the fixes from #1900 implemented to fix the issue with running validation when latents are not cached and to fix bar positioning?