From 7b54e09164533601468abfcb2daa8ff34317b980 Mon Sep 17 00:00:00 2001 From: Jimmy <39@🇺🇸.com> Date: Wed, 2 Oct 2024 20:09:04 -0400 Subject: [PATCH 01/14] Add the ability to use a Beta schedule to select Flux timesteps --- helpers/configuration/cmd_args.py | 24 ++++++++++++++++++++++++ helpers/training/trainer.py | 16 +++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 4be1ffce..b09704e7 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -148,6 +148,30 @@ def get_argument_parser(): " which has improved results in short experiments. Thanks to @mhirki for the contribution." ), ) + parser.add_argument( + "--flux_use_beta_schedule", + action="store_true", + help=( + "Whether or not to use a beta schedule with Flux instead of sigmoid. The default values of alpha" + " and beta approximate a sigmoid." + ), + ) + parser.add_argument( + "--flux_beta_schedule_alpha", + type=float, + default=2.0, + help=( + "The alpha value of the flux beta schedule. Default is 2.0" + ), + ) + parser.add_argument( + "--flux_beta_schedule_beta", + type=float, + default=2.0, + help=( + "The beta value of the flux beta schedule. Default is 2.0" + ), + ) parser.add_argument( "--flux_schedule_shift", type=float, diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 2ef7e143..78d405f0 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -81,6 +81,7 @@ from accelerate import Accelerator from accelerate.utils import set_seed from configure import model_classes +from torch.distributions import Beta try: from lycoris import LycorisNetwork @@ -1916,7 +1917,7 @@ def train(self): ) training_logger.debug(f"Working on batch size: {bsz}") if self.config.flow_matching: - if not self.config.flux_fast_schedule: + if not self.config.flux_fast_schedule and not self.config.flux_use_beta_schedule: # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF # also used by: https://github.com/XLabs-AI/x-flux/tree/main # and: https://github.com/kohya-ss/sd-scripts/commit/8a0f12dde812994ec3facdcdb7c08b362dbceb0f @@ -1924,6 +1925,19 @@ def train(self): self.config.flow_matching_sigmoid_scale * torch.randn((bsz,), device=self.accelerator.device) ) + sigmas = apply_flux_schedule_shift( + self.config, self.noise_scheduler, sigmas, noise + ) + elif self.config.flux_use_beta_schedule: + alpha = self.config.flux_beta_schedule_alpha + beta = self.config.flux_beta_schedule_beta + + # Create a Beta distribution instance + beta_dist = Beta(alpha, beta) + + # Sample from the Beta distribution + sigmas = beta_dist.sample((bsz,)).to(device=self.accelerator.device) + sigmas = apply_flux_schedule_shift( self.config, self.noise_scheduler, sigmas, noise ) From d2e47058d00ab0c2cf2bd6933b95ef03e3b2a9ac Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 4 Oct 2024 09:05:13 -0600 Subject: [PATCH 02/14] bnb optims when possible to load them --- helpers/training/optimizer_param.py | 216 ++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) diff --git a/helpers/training/optimizer_param.py b/helpers/training/optimizer_param.py index 7e3b8bc4..9c9c05a5 100644 --- a/helpers/training/optimizer_param.py +++ b/helpers/training/optimizer_param.py @@ -44,6 +44,17 @@ "Could not load optimi library. Please install `torch-optimi` for better memory efficiency." 
) +is_bitsandbytes_available = False +try: + import bitsandbytes + + is_bitsandbytes_available = True +except: + if torch.cuda.is_available(): + logger.warning( + "Could not load bitsandbytes library. BnB-specific optimisers and other functionality will be unavailable." + ) + optimizer_choices = { "adamw_bf16": { "precision": "bf16", @@ -255,6 +266,211 @@ "class": SOAP, }, } + +if is_bitsandbytes_available: + optimizer_choices.update( + { + "bnb-adagrad": { + "precision": "any", + "default_settings": { + "lr": 0.01, + "lr_decay": 0, + "weight_decay": 0, + "initial_accumulator_value": 0, + "eps": 1e-10, + "optim_bits": 32, + "args": None, + "min_8bit_size": 4096, + "percentile_clipping": 100, + "block_wise": True, + }, + "class": bitsandbytes.optim.Adagrad, + }, + "bnb-adagrad8bit": { + "precision": "any", + "default_settings": { + "lr": 0.01, + "lr_decay": 0, + "weight_decay": 0, + "initial_accumulator_value": 0, + "eps": 1e-10, + "optim_bits": 32, + "args": None, + "min_8bit_size": 4096, + "percentile_clipping": 100, + "block_wise": True, + }, + "class": bitsandbytes.optim.Adagrad8bit, + }, + "bnb-adam": { + "precision": "any", + "default_settings": { + "lr": 0.001, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "amsgrad": False, + "min_8bit_size": 4096, + "percentile_clipping": 100, + "block_wise": True, + "is_paged": False, + }, + "class": bitsandbytes.optim.Adam, + }, + "bnb-adam8bit": { + "precision": "any", + "default_settings": { + "lr": 0.001, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "amsgrad": False, + "min_8bit_size": 4096, + "percentile_clipping": 100, + "block_wise": True, + "is_paged": False, + }, + "class": bitsandbytes.optim.Adam8bit, + }, + "bnb-adamw": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999), + "weight_decay": 1e-2, + "eps": 1e-6, + }, + "class": bitsandbytes.optim.AdamW, + }, + "bnb-adamw8bit": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999), + "weight_decay": 1e-2, + "eps": 1e-6, + }, + "class": bitsandbytes.optim.AdamW8bit, + }, + "bnb-adamw-paged": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999), + "weight_decay": 1e-2, + "eps": 1e-6, + }, + "class": bitsandbytes.optim.PagedAdamW, + }, + "bnb-adamw8bit-paged": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999), + "weight_decay": 1e-2, + "eps": 1e-6, + }, + "class": bitsandbytes.optim.PagedAdamW8bit, + }, + "bnb-ademamix": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999, 0.9999), + "alpha": 5.0, + "t_alpha": None, + "t_beta3": None, + "eps": 1e-08, + "weight_decay": 0.01, + "optim_bits": 32, + "min_8bit_size": 4096, + "is_paged": False, + }, + "class": bitsandbytes.optim.AdEMAMix, + }, + "bnb-ademamix8bit": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999, 0.9999), + "alpha": 5.0, + "t_alpha": None, + "t_beta3": None, + "eps": 1e-08, + "weight_decay": 0.01, + "optim_bits": 32, + "min_8bit_size": 4096, + "is_paged": False, + }, + "class": bitsandbytes.optim.AdEMAMix8bit, + }, + "bnb-ademamix-paged": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999, 0.9999), + "alpha": 5.0, + "t_alpha": None, + "t_beta3": None, + "eps": 1e-08, + "weight_decay": 0.01, + "optim_bits": 32, + "min_8bit_size": 4096, + }, + "class": bitsandbytes.optim.PagedAdEMAMix, + }, + "bnb-ademamix8bit-paged": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.999, 0.9999), + "alpha": 5.0, + "t_alpha": None, + "t_beta3": None, + "eps": 
1e-08, + "weight_decay": 0.01, + "optim_bits": 32, + "min_8bit_size": 4096, + }, + "class": bitsandbytes.optim.PagedAdEMAMix8bit, + }, + "bnb-lion": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.99), + "weight_decay": 0.0, + "optim_bits": 32, + "min_8bit_size": 4096, + "is_paged": False, + }, + "class": bitsandbytes.optim.Lion, + }, + "bnb-lion8bit": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.99), + "weight_decay": 0.0, + "optim_bits": 32, + "min_8bit_size": 4096, + "is_paged": False, + }, + "class": bitsandbytes.optim.Lion8bit, + }, + "bnb-lion-paged": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.99), + "weight_decay": 0.0, + "optim_bits": 32, + "min_8bit_size": 4096, + }, + "class": bitsandbytes.optim.PagedLion, + }, + "bnb-lion8bit-paged": { + "precision": "any", + "default_settings": { + "betas": (0.9, 0.99), + "weight_decay": 0.0, + "optim_bits": 32, + "min_8bit_size": 4096, + }, + "class": bitsandbytes.optim.PagedLion8bit, + }, + } + ) + args_to_optimizer_mapping = { "use_adafactor_optimizer": "adafactor", "use_prodigy_optimizer": "prodigy", From 3f86c8da0435e31881a2a26c9ddf1885a767f4ad Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 4 Oct 2024 09:16:18 -0600 Subject: [PATCH 03/14] update bnb for ademamix --- install/rocm/poetry.lock | 21 ++++++++++++++++++++- install/rocm/pyproject.toml | 1 + poetry.lock | 18 +++++++++++------- pyproject.toml | 2 +- 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/install/rocm/poetry.lock b/install/rocm/poetry.lock index 9f7bfcde..4c1cc2a4 100644 --- a/install/rocm/poetry.lock +++ b/install/rocm/poetry.lock @@ -227,6 +227,25 @@ tests = ["attrs[tests-no-zope]", "zope-interface"] tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] +[[package]] +name = "bitsandbytes" +version = "0.44.1" +description = "k-bit optimizers and matrix multiplication routines." 
+optional = false +python-versions = "*" +files = [ + {file = "bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:b2f24c6cbf11fc8c5d69b3dcecee9f7011451ec59d6ac833e873c9f105259668"}, + {file = "bitsandbytes-0.44.1-py3-none-win_amd64.whl", hash = "sha256:8e68e12aa25d2cf9a1730ad72890a5d1a19daa23f459a6a4679331f353d58cb4"}, +] + +[package.dependencies] +numpy = "*" +torch = "*" + +[package.extras] +benchmark = ["matplotlib", "pandas"] +test = ["lion-pytorch", "scipy"] + [[package]] name = "boto3" version = "1.35.24" @@ -4113,4 +4132,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "90e439290743857506f87e10db930629bd13a0c642cf493ae01ac6a28fc09351" +content-hash = "7b3e351bb5248016430e209937848d963dac32d950f691463d74b0fd82687d84" diff --git a/install/rocm/pyproject.toml b/install/rocm/pyproject.toml index 0165d8e7..96cf16e2 100644 --- a/install/rocm/pyproject.toml +++ b/install/rocm/pyproject.toml @@ -43,6 +43,7 @@ optimum-quanto = {git = "https://github.com/huggingface/optimum-quanto"} lycoris-lora = {git = "https://github.com/kohakublueleaf/lycoris", rev = "dev"} torch-optimi = "^0.2.1" fastapi = {extras = ["standard"], version = "^0.115.0"} +bitsandbytes = "^0.44.1" [build-system] diff --git a/poetry.lock b/poetry.lock index 207f5e15..de2cc6d4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -244,13 +244,13 @@ tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "bitsandbytes" -version = "0.43.3" +version = "0.44.1" description = "k-bit optimizers and matrix multiplication routines." optional = false python-versions = "*" files = [ - {file = "bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:cc99507c352be0715098b2c7577b690dd158972dc4ea10c7495bac104c7c79f0"}, - {file = "bitsandbytes-0.43.3-py3-none-win_amd64.whl", hash = "sha256:257f6552f2144748a84e6c44e1f7a98f3da888f675ed74e18fd7f7eb13c6cafa"}, + {file = "bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:b2f24c6cbf11fc8c5d69b3dcecee9f7011451ec59d6ac833e873c9f105259668"}, + {file = "bitsandbytes-0.44.1-py3-none-win_amd64.whl", hash = "sha256:8e68e12aa25d2cf9a1730ad72890a5d1a19daa23f459a6a4679331f353d58cb4"}, ] [package.dependencies] @@ -259,7 +259,7 @@ torch = "*" [package.extras] benchmark = ["matplotlib", "pandas"] -test = ["scipy"] +test = ["lion-pytorch", "scipy"] [[package]] name = "boto3" @@ -2128,9 +2128,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.5", markers = "python_version >= \"3.11\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\""}, ] [[package]] @@ -2223,8 +2223,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4032,9 +4032,13 @@ optional = false python-versions = ">=3.8" files = [ {file = "torchvision-0.20.0.dev20240929+cu124-cp310-cp310-linux_x86_64.whl", hash = "sha256:f525a61b532baf70b9f798b46c951143ee2c103c529b83365a2995fb5a4d3aa6"}, + {file = 
"torchvision-0.20.0.dev20240929+cu124-cp310-cp310-win_amd64.whl", hash = "sha256:f4a4fce7e0f98938682e2faae181d31553e5c05aa7892fc05fd2fb06673949f2"}, {file = "torchvision-0.20.0.dev20240929+cu124-cp311-cp311-linux_x86_64.whl", hash = "sha256:255f7f5142b22430fd0c50ac53659dec98826e69576c449a5483d39446bbc471"}, + {file = "torchvision-0.20.0.dev20240929+cu124-cp311-cp311-win_amd64.whl", hash = "sha256:3646651e57a25c4156f9cbc431ef112f88f29fac198167b3c3bbb21cfea43467"}, {file = "torchvision-0.20.0.dev20240929+cu124-cp312-cp312-linux_x86_64.whl", hash = "sha256:ef9c0c3201b8f383e9cdbdb9908d61b9a06e066e707015ffd4e03c69d46a660d"}, + {file = "torchvision-0.20.0.dev20240929+cu124-cp312-cp312-win_amd64.whl", hash = "sha256:31a1551aaf080e82205cd3a21583ab5a9f9a859f8fb2b45265ba741599ef97ef"}, {file = "torchvision-0.20.0.dev20240929+cu124-cp39-cp39-linux_x86_64.whl", hash = "sha256:f3e2e23952c4e2f472e4301d4b2c7554d76ce45dece4d2cc33a34f99be39e75a"}, + {file = "torchvision-0.20.0.dev20240929+cu124-cp39-cp39-win_amd64.whl", hash = "sha256:d892a3f2c3689806c72d9561c8be0495913c0422ca7739bb7a30eec1343e41e5"}, ] [package.dependencies] @@ -5010,4 +5014,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "7163fb1af3244066c1c66e19eee8a4ab7d9a24fe3089b928026ec353b35b5d81" +content-hash = "5b1dfcb226809610ee373bbb76b32f59906e883bdfc696fb835750fcf6da9909" diff --git a/pyproject.toml b/pyproject.toml index 1f707c05..5ec456a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ torchvision = { version = ">0.19", source = "pytorch-nightly" } diffusers = "^0.30.3" transformers = "^4.45.1" datasets = "^3.0.1" -bitsandbytes = { version = "^0.43.3" } +bitsandbytes = "^0.44.1" wandb = "^0.18.2" requests = "^2.32.3" pillow = "^10.4.0" From 4c6f2244c4d5992229318eb4d852684ace1a5bae Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 4 Oct 2024 09:21:10 -0600 Subject: [PATCH 04/14] update bnb params --- helpers/training/optimizer_param.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/helpers/training/optimizer_param.py b/helpers/training/optimizer_param.py index 9c9c05a5..a106b86d 100644 --- a/helpers/training/optimizer_param.py +++ b/helpers/training/optimizer_param.py @@ -273,62 +273,48 @@ "bnb-adagrad": { "precision": "any", "default_settings": { - "lr": 0.01, "lr_decay": 0, "weight_decay": 0, "initial_accumulator_value": 0, "eps": 1e-10, - "optim_bits": 32, - "args": None, "min_8bit_size": 4096, "percentile_clipping": 100, - "block_wise": True, }, "class": bitsandbytes.optim.Adagrad, }, "bnb-adagrad8bit": { "precision": "any", "default_settings": { - "lr": 0.01, "lr_decay": 0, "weight_decay": 0, "initial_accumulator_value": 0, "eps": 1e-10, - "optim_bits": 32, - "args": None, "min_8bit_size": 4096, "percentile_clipping": 100, - "block_wise": True, }, "class": bitsandbytes.optim.Adagrad8bit, }, "bnb-adam": { "precision": "any", "default_settings": { - "lr": 0.001, "betas": (0.9, 0.999), "eps": 1e-08, "weight_decay": 0, "amsgrad": False, "min_8bit_size": 4096, "percentile_clipping": 100, - "block_wise": True, - "is_paged": False, }, "class": bitsandbytes.optim.Adam, }, "bnb-adam8bit": { "precision": "any", "default_settings": { - "lr": 0.001, "betas": (0.9, 0.999), "eps": 1e-08, "weight_decay": 0, "amsgrad": False, "min_8bit_size": 4096, "percentile_clipping": 100, - "block_wise": True, - "is_paged": False, }, "class": bitsandbytes.optim.Adam8bit, }, @@ -377,9 +363,7 @@ "t_beta3": None, "eps": 1e-08, "weight_decay": 0.01, - 
"optim_bits": 32, "min_8bit_size": 4096, - "is_paged": False, }, "class": bitsandbytes.optim.AdEMAMix, }, @@ -392,9 +376,7 @@ "t_beta3": None, "eps": 1e-08, "weight_decay": 0.01, - "optim_bits": 32, "min_8bit_size": 4096, - "is_paged": False, }, "class": bitsandbytes.optim.AdEMAMix8bit, }, @@ -407,7 +389,6 @@ "t_beta3": None, "eps": 1e-08, "weight_decay": 0.01, - "optim_bits": 32, "min_8bit_size": 4096, }, "class": bitsandbytes.optim.PagedAdEMAMix, @@ -421,7 +402,6 @@ "t_beta3": None, "eps": 1e-08, "weight_decay": 0.01, - "optim_bits": 32, "min_8bit_size": 4096, }, "class": bitsandbytes.optim.PagedAdEMAMix8bit, @@ -431,9 +411,7 @@ "default_settings": { "betas": (0.9, 0.99), "weight_decay": 0.0, - "optim_bits": 32, "min_8bit_size": 4096, - "is_paged": False, }, "class": bitsandbytes.optim.Lion, }, @@ -442,9 +420,7 @@ "default_settings": { "betas": (0.9, 0.99), "weight_decay": 0.0, - "optim_bits": 32, "min_8bit_size": 4096, - "is_paged": False, }, "class": bitsandbytes.optim.Lion8bit, }, @@ -453,7 +429,6 @@ "default_settings": { "betas": (0.9, 0.99), "weight_decay": 0.0, - "optim_bits": 32, "min_8bit_size": 4096, }, "class": bitsandbytes.optim.PagedLion, @@ -463,7 +438,6 @@ "default_settings": { "betas": (0.9, 0.99), "weight_decay": 0.0, - "optim_bits": 32, "min_8bit_size": 4096, }, "class": bitsandbytes.optim.PagedLion8bit, From 75e64004738b9e8fc82838e8c5b841f7282016a2 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 4 Oct 2024 16:42:43 -0600 Subject: [PATCH 05/14] nf4: initial commit --- helpers/data_backend/aws.py | 7 +++---- helpers/data_backend/local.py | 5 ++--- helpers/models/flux/__init__.py | 1 - helpers/training/diffusion_model.py | 8 ++++++++ helpers/training/trainer.py | 10 +++++++--- install/apple/poetry.lock | 20 ++++++++++++-------- install/apple/pyproject.toml | 2 +- install/rocm/poetry.lock | 20 ++++++++++++-------- install/rocm/pyproject.toml | 2 +- poetry.lock | 16 ++++++++++------ pyproject.toml | 2 +- 11 files changed, 57 insertions(+), 36 deletions(-) diff --git a/helpers/data_backend/aws.py b/helpers/data_backend/aws.py index 0185df7b..1e3d5dc2 100644 --- a/helpers/data_backend/aws.py +++ b/helpers/data_backend/aws.py @@ -45,7 +45,7 @@ class S3DataBackend(BaseDataBackend): # Storing the list_files output in a local dict. 
- _list_cache = {} + _list_cache: dict = {} def __init__( self, @@ -301,9 +301,8 @@ def torch_load(self, s3_key): try: stored_tensor = self._decompress_torch(stored_tensor) except Exception as e: - logger.error( - f"Failed to decompress torch file, falling back to passthrough: {e}" - ) + pass + if hasattr(stored_tensor, "seek"): stored_tensor.seek(0) diff --git a/helpers/data_backend/local.py b/helpers/data_backend/local.py index 18d0036d..2495f9bb 100644 --- a/helpers/data_backend/local.py +++ b/helpers/data_backend/local.py @@ -194,9 +194,8 @@ def torch_load(self, filename): try: stored_tensor = self._decompress_torch(stored_tensor) except Exception as e: - logger.error( - f"Failed to decompress torch file, falling back to passthrough: {e}" - ) + pass + if hasattr(stored_tensor, "seek"): stored_tensor.seek(0) try: diff --git a/helpers/models/flux/__init__.py b/helpers/models/flux/__init__.py index 271c738e..00dbdfab 100644 --- a/helpers/models/flux/__init__.py +++ b/helpers/models/flux/__init__.py @@ -113,7 +113,6 @@ def prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels, ) diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index cec6406f..20e725e8 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -21,6 +21,14 @@ def load_diffusion_model(args, weight_dtype): unet = None transformer = None + if "bnb-nf4" == args.base_model_precision: + pretrained_load_args["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + if args.model_family == "sd3": # Stable Diffusion 3 uses a Diffusion transformer. logger.info("Loading Stable Diffusion 3 diffusion transformer..") diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 78d405f0..2b3d9a4d 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1917,7 +1917,10 @@ def train(self): ) training_logger.debug(f"Working on batch size: {bsz}") if self.config.flow_matching: - if not self.config.flux_fast_schedule and not self.config.flux_use_beta_schedule: + if ( + not self.config.flux_fast_schedule + and not self.config.flux_use_beta_schedule + ): # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF # also used by: https://github.com/XLabs-AI/x-flux/tree/main # and: https://github.com/kohya-ss/sd-scripts/commit/8a0f12dde812994ec3facdcdb7c08b362dbceb0f @@ -1936,7 +1939,9 @@ def train(self): beta_dist = Beta(alpha, beta) # Sample from the Beta distribution - sigmas = beta_dist.sample((bsz,)).to(device=self.accelerator.device) + sigmas = beta_dist.sample((bsz,)).to( + device=self.accelerator.device + ) sigmas = apply_flux_schedule_shift( self.config, self.noise_scheduler, sigmas, noise @@ -2204,7 +2209,6 @@ def train(self): ) text_ids = torch.zeros( - packed_noisy_latents.shape[0], batch["prompt_embeds"].shape[1], 3, ).to( diff --git a/install/apple/poetry.lock b/install/apple/poetry.lock index b3b9a16b..eaf95d6d 100644 --- a/install/apple/poetry.lock +++ b/install/apple/poetry.lock @@ -535,19 +535,17 @@ triton = ["triton (==2.1.0)"] [[package]] name = "diffusers" -version = "0.30.3" +version = "0.31.0.dev0" description = "State-of-the-art diffusion in PyTorch and JAX." 
optional = false python-versions = ">=3.8.0" -files = [ - {file = "diffusers-0.30.3-py3-none-any.whl", hash = "sha256:1b70209e4d2c61223b96a7e13bc4d70869c8b0b68f54a35ce3a67fcf813edeee"}, - {file = "diffusers-0.30.3.tar.gz", hash = "sha256:67c5eb25d5b50bf0742624ef43fe0f6d1e1604f64aad3e8558469cbe89ecf72f"}, -] +files = [] +develop = false [package.dependencies] filelock = "*" huggingface-hub = ">=0.23.2" -importlib-metadata = "*" +importlib_metadata = "*" numpy = "*" Pillow = "*" regex = "!=2019.12.17" @@ -555,7 +553,7 @@ requests = "*" safetensors = ">=0.3.1" [package.extras] -dev = ["GitPython (<3.1.19)", "Jinja2", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"] +dev = ["GitPython (<3.1.19)", "Jinja2", "Jinja2", "accelerate (>=0.31.0)", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"] docs = ["hf-doc-builder (>=0.3.0)"] flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"] quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"] @@ -563,6 +561,12 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi torch = ["accelerate (>=0.31.0)", "torch (>=1.4)"] training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] +[package.source] +type = "git" +url = "https://github.com/huggingface/diffusers" +reference = "quantization-config" +resolved_reference = "9b9a6107a76239700d4fa7c365ac8b69a1792637" + [[package]] name = "dill" version = "0.3.7" @@ -4166,4 +4170,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "2f53e4d801bfcdb09650ba58aaa4edf50171f73e4aea46f83dc1503353734aba" +content-hash = "b8863ab1dc1441031333e849293f2b99bb3942c196fb51ba0b90d6b663c2e613" diff --git a/install/apple/pyproject.toml b/install/apple/pyproject.toml index eaa61e2e..cda8df64 100644 --- a/install/apple/pyproject.toml +++ b/install/apple/pyproject.toml @@ -11,7 +11,7 @@ package-mode = false python = ">=3.10,<3.13" torch = "^2.4.1" torchvision = "^0.19.0" -diffusers = "^0.30.3" +diffusers = {git = "https://github.com/huggingface/diffusers", rev = "quantization-config"} transformers = "^4.44.2" datasets = "^3.0.0" wandb = "^0.18.1" diff --git a/install/rocm/poetry.lock b/install/rocm/poetry.lock index 4c1cc2a4..16c05dce 100644 --- a/install/rocm/poetry.lock +++ b/install/rocm/poetry.lock @@ -539,19 +539,17 @@ triton = ["triton (==2.1.0)"] [[package]] name = "diffusers" -version = "0.30.3" +version = "0.31.0.dev0" 
description = "State-of-the-art diffusion in PyTorch and JAX." optional = false python-versions = ">=3.8.0" -files = [ - {file = "diffusers-0.30.3-py3-none-any.whl", hash = "sha256:1b70209e4d2c61223b96a7e13bc4d70869c8b0b68f54a35ce3a67fcf813edeee"}, - {file = "diffusers-0.30.3.tar.gz", hash = "sha256:67c5eb25d5b50bf0742624ef43fe0f6d1e1604f64aad3e8558469cbe89ecf72f"}, -] +files = [] +develop = false [package.dependencies] filelock = "*" huggingface-hub = ">=0.23.2" -importlib-metadata = "*" +importlib_metadata = "*" numpy = "*" Pillow = "*" regex = "!=2019.12.17" @@ -559,7 +557,7 @@ requests = "*" safetensors = ">=0.3.1" [package.extras] -dev = ["GitPython (<3.1.19)", "Jinja2", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"] +dev = ["GitPython (<3.1.19)", "Jinja2", "Jinja2", "accelerate (>=0.31.0)", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"] docs = ["hf-doc-builder (>=0.3.0)"] flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"] quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"] @@ -567,6 +565,12 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi torch = ["accelerate (>=0.31.0)", "torch (>=1.4)"] training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] +[package.source] +type = "git" +url = "https://github.com/huggingface/diffusers" +reference = "quantization-config" +resolved_reference = "9b9a6107a76239700d4fa7c365ac8b69a1792637" + [[package]] name = "dill" version = "0.3.8" @@ -4132,4 +4136,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "7b3e351bb5248016430e209937848d963dac32d950f691463d74b0fd82687d84" +content-hash = "ae72422703750e4026b790c25fe8a505b79667bdc99a53d409c8ec2298ec2cc5" diff --git a/install/rocm/pyproject.toml b/install/rocm/pyproject.toml index 96cf16e2..52a33fd6 100644 --- a/install/rocm/pyproject.toml +++ b/install/rocm/pyproject.toml @@ -22,7 +22,7 @@ colorama = "^0.4.6" compel = "^2" datasets = "^3.0.0" deepspeed = "^0.15.1" -diffusers = "^0.30.3" +diffusers = {git = "https://github.com/huggingface/diffusers", rev = "quantization-config"} iterutils = "^0.1.6" numpy = "1.26" open-clip-torch = "^2.26.1" diff --git a/poetry.lock b/poetry.lock index de2cc6d4..70c5eeba 100644 --- a/poetry.lock +++ b/poetry.lock @@ -663,14 +663,12 @@ triton = ["triton (==2.1.0)"] [[package]] name = 
"diffusers" -version = "0.30.3" +version = "0.31.0.dev0" description = "State-of-the-art diffusion in PyTorch and JAX." optional = false python-versions = ">=3.8.0" -files = [ - {file = "diffusers-0.30.3-py3-none-any.whl", hash = "sha256:1b70209e4d2c61223b96a7e13bc4d70869c8b0b68f54a35ce3a67fcf813edeee"}, - {file = "diffusers-0.30.3.tar.gz", hash = "sha256:67c5eb25d5b50bf0742624ef43fe0f6d1e1604f64aad3e8558469cbe89ecf72f"}, -] +files = [] +develop = false [package.dependencies] filelock = "*" @@ -691,6 +689,12 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi torch = ["accelerate (>=0.31.0)", "torch (>=1.4)"] training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] +[package.source] +type = "git" +url = "https://github.com/huggingface/diffusers" +reference = "quantization-config" +resolved_reference = "9b9a6107a76239700d4fa7c365ac8b69a1792637" + [[package]] name = "dill" version = "0.3.8" @@ -5014,4 +5018,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "5b1dfcb226809610ee373bbb76b32f59906e883bdfc696fb835750fcf6da9909" +content-hash = "eea1cd458cf26ac71571071e4fae1ed343ea6654f48f0dcae85ca266dd81b48b" diff --git a/pyproject.toml b/pyproject.toml index 5ec456a8..b6f3f042 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ package-mode = false python = ">=3.10,<3.12" torch = { version = ">2.5.0", source = "pytorch-nightly" } torchvision = { version = ">0.19", source = "pytorch-nightly" } -diffusers = "^0.30.3" +diffusers = {git = "https://github.com/huggingface/diffusers", rev = "quantization-config"} transformers = "^4.45.1" datasets = "^3.0.1" bitsandbytes = "^0.44.1" From 63d386e3ca3cb6a1d4b085b862ec55ded6442b89 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 00:45:33 +0100 Subject: [PATCH 06/14] diffusers git compatibility --- helpers/models/flux/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/helpers/models/flux/__init__.py b/helpers/models/flux/__init__.py index 00dbdfab..2b4c16e9 100644 --- a/helpers/models/flux/__init__.py +++ b/helpers/models/flux/__init__.py @@ -113,8 +113,9 @@ def prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels, ) - return latent_image_ids.to(device=device, dtype=dtype) + return latent_image_ids.to(device=device, dtype=dtype)[0] From 092c387474962aba7c579dda141daca32502b1ed Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 00:45:43 +0100 Subject: [PATCH 07/14] add nf4 notes --- README.md | 7 ++++--- documentation/quickstart/FLUX.md | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a4ce5120..4a6aef63 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ For memory-constrained systems, see the [DeepSpeed document](/documentation/DEEP - Most models are trainable on a 24G GPU, or even down to 16G at lower base resolutions. - LoRA/LyCORIS training for PixArt, SDXL, SD3, and SD 2.x that uses less than 16G VRAM - DeepSpeed integration allowing for [training SDXL's full u-net on 12G of VRAM](/documentation/DEEPSPEED.md), albeit very slowly. -- Quantised LoRA training, using low-precision base model or text encoder weights to reduce VRAM consumption while still allowing DreamBooth. 
+- Quantised NF4/INT8/FP8 LoRA training, using low-precision base model to reduce VRAM consumption. - Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability. **Note:** This does not apply to LoRA. - Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3) - For only SDXL and SD 1.x/2.x, full [ControlNet model training](/documentation/CONTROLNET.md) (not ControlLoRA or ControlLite) @@ -105,7 +105,7 @@ RunwayML's SD 1.5 and StabilityAI's SD 2.x are both trainable under the `legacy` ### NVIDIA -Pretty much anything 3090 and up is a safe bet. YMMV. +Pretty much anything 3080 and up is a safe bet. YMMV. ### AMD @@ -124,7 +124,8 @@ LoRA and full-rank tuning are tested to work on an M3 Max with 128G memory, taki - A100-80G (Full tune with DeepSpeed) - A100-40G (LoRA, LoKr) - 3090 24G (LoRA, LoKr) -- 4060 Ti, 3080 16G (int8, LoRA, LoKr) +- 4060 Ti 16G, 4070 Ti 16G, 3080 16G (int8, LoRA, LoKr) +- 4070 Super 12G, 3080 10G, 3060 12GB (nf4, LoRA, LoKr) Flux prefers being trained with multiple large GPUs but a single 16G card should be able to do it with quantisation of the transformer and text encoders. diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 4e4f61c9..12d41a09 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -402,7 +402,7 @@ We can partially reintroduce distillation to a de-distilled model by continuing - It allows you to push higher batch sizes and possibly obtain a better result - Behaves the same as full-precision training - fp32 won't make your model any better than bf16+int8. - **int8** has hardware acceleration and `torch.compile()` support on newer NVIDIA hardware (3090 or better) -- **nf4** does not seem to benefit training as much as it benefits inference +- **nf4-bnb** brings VRAM requirements down to 9GB, fitting on a 10G card (with bfloat16 support) - When loading the LoRA in ComfyUI later, you **must** use the same base model precision as you trained your LoRA on. 
- **int4** is weird and really only works on A100 and H100 cards due to a reliance on custom bf16 kernels From 9bb68f903618abd3aac6b8fea6dc4943d0c5d349 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 00:46:34 +0100 Subject: [PATCH 08/14] add nf4 to list of quantisation levels --- helpers/training/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/helpers/training/__init__.py b/helpers/training/__init__.py index d5386028..0127f470 100644 --- a/helpers/training/__init__.py +++ b/helpers/training/__init__.py @@ -1,5 +1,6 @@ quantised_precision_levels = [ "no_change", + "nf4-bnb", # "fp4-bnb", # "fp8-bnb", "fp8-quanto", From db31a9c6dffd16549595162f4b9579af74307f5d Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 00:46:55 +0100 Subject: [PATCH 09/14] nf4 loading for base model --- helpers/training/diffusion_model.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index 20e725e8..ad707599 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -17,16 +17,19 @@ def load_diffusion_model(args, weight_dtype): pretrained_load_args = { "revision": args.revision, "variant": args.variant, + "torch_dtype": weight_dtype, } unet = None transformer = None - if "bnb-nf4" == args.base_model_precision: + if "nf4-bnb" == args.base_model_precision: + import torch + from diffusers import BitsAndBytesConfig pretrained_load_args["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_compute_dtype=weight_dtype, ) if args.model_family == "sd3": @@ -53,7 +56,6 @@ def load_diffusion_model(args, weight_dtype): args.pretrained_transformer_model_name_or_path or args.pretrained_model_name_or_path, subfolder=determine_subfolder(args.pretrained_transformer_subfolder), - torch_dtype=weight_dtype, **pretrained_load_args, ) elif args.model_family.lower() == "flux" and args.flux_attention_masked_training: @@ -64,7 +66,6 @@ def load_diffusion_model(args, weight_dtype): transformer = FluxTransformer2DModelWithMasking.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", - torch_dtype=weight_dtype, **pretrained_load_args, ) elif args.model_family == "pixart_sigma": @@ -74,7 +75,6 @@ def load_diffusion_model(args, weight_dtype): args.pretrained_transformer_model_name_or_path or args.pretrained_model_name_or_path, subfolder=determine_subfolder(args.pretrained_transformer_subfolder), - torch_dtype=weight_dtype, **pretrained_load_args, ) elif args.model_family == "smoldit": @@ -108,7 +108,6 @@ def load_diffusion_model(args, weight_dtype): args.pretrained_unet_model_name_or_path or args.pretrained_model_name_or_path, subfolder=determine_subfolder(args.pretrained_unet_subfolder), - torch_dtype=weight_dtype, **pretrained_load_args, ) From 5543065a07716988aad578b84a570eb3176ea3aa Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 00:47:18 +0100 Subject: [PATCH 10/14] lycoris: fix casting weights to transformer dtype --- helpers/training/save_hooks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index a847faf6..66330e77 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -432,13 +432,14 @@ def _load_lycoris(self, models, input_dir): if len(state.keys()) > 0: logging.error(f"LyCORIS failed to load: 
{state}") raise RuntimeError("Loading of LyCORIS model failed") + weight_dtype = StateTracker.get_weight_dtype() if self.transformer is not None: self.accelerator._lycoris_wrapped_network.to( - device=self.accelerator.device, dtype=self.transformer.dtype + device=self.accelerator.device, dtype=weight_dtype ) elif self.unet is not None: self.accelerator._lycoris_wrapped_network.to( - device=self.accelerator.device, dtype=self.unet.dtype + device=self.accelerator.device, dtype=weight_dtype ) else: raise ValueError("No model found to load LyCORIS weights into.") From 1f8d9575f09950688416ab5193d2e95a21b9b7ab Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 00:47:50 +0100 Subject: [PATCH 11/14] bnb nf4 changes for trainer class --- helpers/training/trainer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 2b3d9a4d..ba6017cc 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -733,9 +733,21 @@ def init_precision(self): self.config.base_weight_dtype = self.config.weight_dtype self.config.is_quanto = False self.config.is_torchao = False + self.config.is_bnb = False + if "quanto" in self.config.base_model_precision: + self.config.is_quanto = True + elif "torchao" in self.config.base_model_precision: + self.config.is_torchao = True + elif "bnb" in self.config.base_model_precision: + self.config.is_bnb = True quantization_device = ( "cpu" if self.config.quantize_via == "cpu" else self.accelerator.device ) + + if 'bnb' in self.config.base_model_precision: + # can't cast or move bitsandbytes models + return + if not self.config.disable_accelerator and self.config.is_quantized: if self.config.base_model_default_dtype == "fp32": self.config.base_weight_dtype = torch.float32 @@ -755,10 +767,6 @@ def init_precision(self): self.transformer.to( quantization_device, dtype=self.config.base_weight_dtype ) - if "quanto" in self.config.base_model_precision: - self.config.is_quanto = True - elif "torchao" in self.config.base_model_precision: - self.config.is_torchao = True if self.config.is_quanto: from helpers.training.quantisation import quantise_model From 3cf9d38eb8c07df384c92aed695cab16ee5f32b9 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 00:52:03 +0100 Subject: [PATCH 12/14] use adamw4bit instead of adam4bit --- helpers/training/optimizer_param.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/training/optimizer_param.py b/helpers/training/optimizer_param.py index a106b86d..bd364658 100644 --- a/helpers/training/optimizer_param.py +++ b/helpers/training/optimizer_param.py @@ -21,7 +21,7 @@ try: from torchao.prototype.low_bit_optim import ( AdamW8bit as AOAdamW8Bit, - Adam4bit as AOAdamW4Bit, + AdamW4bit as AOAdamW4Bit, AdamFp8 as AOAdamFp8, AdamWFp8 as AOAdamWFp8, CPUOffloadOptimizer as AOCPUOffloadOptimizer, From e0deb56a5b5f9f6a5a5b56b5af2a2ed84ae6bf28 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 01:16:09 +0100 Subject: [PATCH 13/14] nf4: fix pipeline --- helpers/models/flux/pipeline.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/helpers/models/flux/pipeline.py b/helpers/models/flux/pipeline.py index 7c74ed9f..e40e7703 100644 --- a/helpers/models/flux/pipeline.py +++ b/helpers/models/flux/pipeline.py @@ -25,7 +25,7 @@ ) from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import SD3LoraLoaderMixin +from diffusers.loaders import 
FluxLoraLoaderMixin from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.transformers import FluxTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler @@ -147,7 +147,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin): +class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin): r""" The Flux pipeline for text-to-image generation. @@ -361,7 +361,7 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -395,12 +395,12 @@ def encode_prompt( ) if self.text_encoder is not None: - if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: - if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) @@ -824,16 +824,16 @@ def __call__( noise_pred = self.transformer( hidden_states=latents.to( - device=self.transformer.device, dtype=self.transformer.dtype + device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4 ), # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds.to( - device=self.transformer.device, dtype=self.transformer.dtype + device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4 ), encoder_hidden_states=prompt_embeds.to( - device=self.transformer.device, dtype=self.transformer.dtype + device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4 ), txt_ids=text_ids, img_ids=latent_image_ids, @@ -846,16 +846,16 @@ def __call__( if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep: noise_pred_uncond = self.transformer( hidden_states=latents.to( - device=self.transformer.device, dtype=self.transformer.dtype + device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4 ), # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timestep / 1000, guidance=guidance, pooled_projections=negative_pooled_prompt_embeds.to( - device=self.transformer.device, dtype=self.transformer.dtype + device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4 ), encoder_hidden_states=negative_prompt_embeds.to( - device=self.transformer.device, dtype=self.transformer.dtype + device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4 ), txt_ids=negative_text_ids.to(device=self.transformer.device), img_ids=latent_image_ids.to(device=self.transformer.device), From 
7e2d97c1fafa8174300e21e8ecd1bba0939e09db Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 5 Oct 2024 01:31:49 +0100 Subject: [PATCH 14/14] flux: nf4 pipeline fixes for validations, guidance for low vram config in quickstart --- documentation/quickstart/FLUX.md | 17 +++++++++++++++++ helpers/models/flux/pipeline.py | 4 ++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 12d41a09..52df5ad9 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -370,6 +370,23 @@ Inferencing the CFG-distilled LoRA is as easy as using a lower guidance_scale ar ## Notes & troubleshooting tips +### Lowest VRAM config + +Currently, the lowest VRAM utilisation (9090M) can be attained with: + +- OS: Ubuntu Linux 24 +- GPU: A single NVIDIA CUDA device (10G, 12G) +- System memory: 50G of system memory approximately +- Base model precision: `bnb-nf4` +- Optimiser: Lion 8Bit Paged, `bnb-lion8bit-paged` +- Resolution: 512px + - 1024px requires >= 12G VRAM +- Batch size: 1, zero gradient accumulation steps +- DeepSpeed: disabled / unconfigured +- PyTorch: 2.6 Nightly (Sept 29th build) + +Speed was approximately 1.4 iterations per second on a 4090. + ### Classifier-free guidance #### Problem diff --git a/helpers/models/flux/pipeline.py b/helpers/models/flux/pipeline.py index e40e7703..885bde7a 100644 --- a/helpers/models/flux/pipeline.py +++ b/helpers/models/flux/pipeline.py @@ -794,9 +794,9 @@ def __call__( self._num_timesteps = len(timesteps) latents = latents.to(self.transformer.device) - latent_image_ids = latent_image_ids.to(self.transformer.device) + latent_image_ids = latent_image_ids.to(self.transformer.device)[0] timesteps = timesteps.to(self.transformer.device) - text_ids = text_ids.to(self.transformer.device) + text_ids = text_ids.to(self.transformer.device)[0] # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar:
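The core of PATCH 01/14 is swapping the `sigmoid(randn)` sigma sampling for draws from `torch.distributions.Beta`, enabled with `--flux_use_beta_schedule` and tuned via `--flux_beta_schedule_alpha` / `--flux_beta_schedule_beta` (both defaulting to 2.0, which the help text says approximates the sigmoid shape). A minimal standalone sketch of that sampling follows; the shift step is an assumed stand-in, since the actual `apply_flux_schedule_shift` implementation is not part of this series.

```python
from torch.distributions import Beta


def sample_beta_sigmas(bsz, alpha=2.0, beta=2.0, shift=None, device="cpu"):
    """Draw per-sample sigmas in (0, 1) from Beta(alpha, beta).

    With alpha == beta == 2.0 the density is bell-shaped around 0.5, roughly
    matching the default sigmoid(randn) schedule used by the trainer.
    """
    sigmas = Beta(alpha, beta).sample((bsz,)).to(device=device)
    if shift:
        # Assumed form of the resolution-dependent shift; the patched trainer
        # instead calls apply_flux_schedule_shift(config, noise_scheduler, sigmas, noise).
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
    return sigmas


# e.g. a batch of 4 timestep fractions biased toward the middle of the schedule
print(sample_beta_sigmas(4, shift=3.0))
```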
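PATCH 05/14 and 09/14 route the `nf4-bnb` base model precision through diffusers' `BitsAndBytesConfig` rather than quanto/torchao, which is why the lockfiles pin `diffusers` to the `quantization-config` branch. A sketch of what that loading path amounts to outside the trainer is shown below; the model id is an assumed placeholder and `torch.bfloat16` stands in for the trainer's `weight_dtype`.

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Mirrors the pretrained_load_args built in helpers/training/diffusion_model.py
# when args.base_model_precision == "nf4-bnb".
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # weight_dtype in the trainer
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder; any Flux-layout checkpoint
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# As noted in PATCH 11/14, a bitsandbytes-quantised module cannot simply be
# cast or moved with .to(), so the trainer skips that step for "bnb" precisions.
```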