diff --git a/train_network.py b/train_network.py index 3a5255160..aad5a7194 100644 --- a/train_network.py +++ b/train_network.py @@ -481,7 +481,8 @@ def save_model_hook(models, weights, output_dir): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") def load_model_hook(models, input_dir):