Bits n Bytes NF4 training #1028

Merged: 10 commits, Oct 5, 2024. Changes from all commits are shown below.
7 changes: 4 additions & 3 deletions README.md
@@ -48,7 +48,7 @@ For memory-constrained systems, see the [DeepSpeed document](/documentation/DEEP
- Most models are trainable on a 24G GPU, or even down to 16G at lower base resolutions.
- LoRA/LyCORIS training for PixArt, SDXL, SD3, and SD 2.x that uses less than 16G VRAM
- DeepSpeed integration allowing for [training SDXL's full u-net on 12G of VRAM](/documentation/DEEPSPEED.md), albeit very slowly.
- Quantised LoRA training, using low-precision base model or text encoder weights to reduce VRAM consumption while still allowing DreamBooth.
- Quantised NF4/INT8/FP8 LoRA training, using low-precision base model to reduce VRAM consumption.
- Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability. **Note:** This does not apply to LoRA.
- Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3)
- For only SDXL and SD 1.x/2.x, full [ControlNet model training](/documentation/CONTROLNET.md) (not ControlLoRA or ControlLite)
@@ -105,7 +105,7 @@ RunwayML's SD 1.5 and StabilityAI's SD 2.x are both trainable under the `legacy`

### NVIDIA

Pretty much anything 3090 and up is a safe bet. YMMV.
Pretty much anything 3080 and up is a safe bet. YMMV.

### AMD

@@ -124,7 +124,8 @@ LoRA and full-rank tuning are tested to work on an M3 Max with 128G memory, taki
- A100-80G (Full tune with DeepSpeed)
- A100-40G (LoRA, LoKr)
- 3090 24G (LoRA, LoKr)
- 4060 Ti, 3080 16G (int8, LoRA, LoKr)
- 4060 Ti 16G, 4070 Ti 16G, 3080 16G (int8, LoRA, LoKr)
- 4070 Super 12G, 3080 10G, 3060 12GB (nf4, LoRA, LoKr)

Flux prefers being trained with multiple large GPUs but a single 16G card should be able to do it with quantisation of the transformer and text encoders.

19 changes: 18 additions & 1 deletion documentation/quickstart/FLUX.md
@@ -370,6 +370,23 @@ Inferencing the CFG-distilled LoRA is as easy as using a lower guidance_scale ar

## Notes & troubleshooting tips

### Lowest VRAM config

Currently, the lowest VRAM utilisation (9090M) can be attained with:

- OS: Ubuntu Linux 24
- GPU: A single NVIDIA CUDA device (10G, 12G)
- System memory: approximately 50G
- Base model precision: `bnb-nf4`
- Optimiser: Lion 8Bit Paged, `bnb-lion8bit-paged`
- Resolution: 512px
- 1024px requires >= 12G VRAM
- Batch size: 1, with no gradient accumulation
- DeepSpeed: disabled / unconfigured
- PyTorch: 2.6 Nightly (Sept 29th build)

Speed was approximately 1.4 iterations per second on a 4090.
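
For reference, the trainer path added in this PR builds the quantised transformer roughly as in the sketch below (compare `helpers/training/diffusion_model.py` in this diff). This is a minimal sketch, not part of the change: the repo id and the memory check are illustrative, and it assumes a diffusers build with bitsandbytes quantisation support installed.

```python
# Minimal sketch: load the Flux transformer in NF4 the same way
# load_diffusion_model() does when base_model_precision is "nf4-bnb".
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",   # illustrative repo id
    subfolder="transformer",
    quantization_config=quant_config,
)

# NF4 weights are packed into uint8 storage, so summing parameter bytes gives a
# rough view of the quantised footprint (the bulk of the ~9G figure above).
param_bytes = sum(p.numel() * p.element_size() for p in transformer.parameters())
print(f"quantised transformer parameters: {param_bytes / 1024**3:.2f} GiB")
```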

### Classifier-free guidance

#### Problem
@@ -402,7 +419,7 @@ We can partially reintroduce distillation to a de-distilled model by continuing
- It allows you to push higher batch sizes and possibly obtain a better result
- Behaves the same as full-precision training - fp32 won't make your model any better than bf16+int8.
- **int8** has hardware acceleration and `torch.compile()` support on newer NVIDIA hardware (3090 or better)
- **nf4** does not seem to benefit training as much as it benefits inference
- **nf4-bnb** brings VRAM requirements down to 9GB, fitting on a 10G card (with bfloat16 support)
- When loading the LoRA in ComfyUI later, you **must** use the same base model precision as you trained your LoRA on.
- **int4** is weird and really only works on A100 and H100 cards due to a reliance on custom bf16 kernels
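
The ComfyUI note above applies to any inference stack: the LoRA has to sit on top of a base quantised the same way it was during training. Below is a rough diffusers-side sketch, not part of this PR, assuming an illustrative repo id, placeholder LoRA paths, and a peft/diffusers combination that supports LoRA on 4-bit bitsandbytes layers:

```python
# Hedged sketch: assemble an inference pipeline whose base precision (nf4-bnb)
# matches the training run, then attach the trained LoRA on top of it.
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

repo = "black-forest-labs/FLUX.1-dev"  # illustrative repo id
transformer = FluxTransformer2DModel.from_pretrained(
    repo,
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
pipe = FluxPipeline.from_pretrained(repo, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.load_lora_weights(
    "output/flux-lora/checkpoint-1000",        # placeholder path to the trained LoRA
    weight_name="pytorch_lora_weights.safetensors",
)
pipe.enable_model_cpu_offload()
image = pipe("a photo of a corgi", num_inference_steps=28, guidance_scale=3.5).images[0]
image.save("sample.png")
```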

7 changes: 3 additions & 4 deletions helpers/data_backend/aws.py
@@ -45,7 +45,7 @@

class S3DataBackend(BaseDataBackend):
# Storing the list_files output in a local dict.
_list_cache = {}
_list_cache: dict = {}

def __init__(
self,
@@ -301,9 +301,8 @@ def torch_load(self, s3_key):
try:
stored_tensor = self._decompress_torch(stored_tensor)
except Exception as e:
logger.error(
f"Failed to decompress torch file, falling back to passthrough: {e}"
)
pass

if hasattr(stored_tensor, "seek"):
stored_tensor.seek(0)

5 changes: 2 additions & 3 deletions helpers/data_backend/local.py
@@ -194,9 +194,8 @@ def torch_load(self, filename):
try:
stored_tensor = self._decompress_torch(stored_tensor)
except Exception as e:
logger.error(
f"Failed to decompress torch file, falling back to passthrough: {e}"
)
pass

if hasattr(stored_tensor, "seek"):
stored_tensor.seek(0)
try:
2 changes: 1 addition & 1 deletion helpers/models/flux/__init__.py
@@ -118,4 +118,4 @@ def prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_id_channels,
)

return latent_image_ids.to(device=device, dtype=dtype)
return latent_image_ids.to(device=device, dtype=dtype)[0]
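
A note on the trailing `[0]` added here: recent diffusers Flux transformer revisions expect unbatched, 2D position-id tensors, so the `(batch, positions, 3)` grid is reduced to `(positions, 3)` before being handed to the model. A tiny illustration with made-up shapes:

```python
# Illustration only: the position-id grid is identical for every sample in the
# batch, so dropping the batch dimension loses no information.
import torch

batched_ids = torch.zeros(4, 4096, 3)   # (batch, packed latent positions, id channels)
unbatched_ids = batched_ids[0]          # (4096, 3), the shape newer Flux models expect
assert unbatched_ids.shape == (4096, 3)
```
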
26 changes: 13 additions & 13 deletions helpers/models/flux/pipeline.py
@@ -25,7 +25,7 @@
)

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import SD3LoraLoaderMixin
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import FluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
@@ -147,7 +147,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.

@@ -361,7 +361,7 @@ def encode_prompt(

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
@@ -395,12 +395,12 @@ def encode_prompt(
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

if self.text_encoder_2 is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)

@@ -794,9 +794,9 @@ def __call__(
self._num_timesteps = len(timesteps)

latents = latents.to(self.transformer.device)
latent_image_ids = latent_image_ids.to(self.transformer.device)
latent_image_ids = latent_image_ids.to(self.transformer.device)[0]
timesteps = timesteps.to(self.transformer.device)
text_ids = text_ids.to(self.transformer.device)
text_ids = text_ids.to(self.transformer.device)[0]

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -824,16 +824,16 @@

noise_pred = self.transformer(
hidden_states=latents.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
encoder_hidden_states=prompt_embeds.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
txt_ids=text_ids,
img_ids=latent_image_ids,
@@ -846,16 +846,16 @@
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
noise_pred_uncond = self.transformer(
hidden_states=latents.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
encoder_hidden_states=negative_prompt_embeds.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
txt_ids=negative_text_ids.to(device=self.transformer.device),
img_ids=latent_image_ids.to(device=self.transformer.device),
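
On the commented-out dtype casts above: once the base model is quantised with bitsandbytes, its parameters are packed uint8 tensors, `self.transformer.dtype` no longer reflects the compute dtype, and casting the quantised model with `.to(dtype=...)` is not supported, so the pipeline only moves tensors by device and lets bitsandbytes dequantise to `bnb_4bit_compute_dtype` inside each layer. A small, self-contained illustration (not part of this PR), assuming a CUDA device with bitsandbytes installed:

```python
# Illustration only: an NF4 bitsandbytes linear layer keeps packed uint8 weights
# and dequantises to compute_dtype internally, so casting the module by dtype is
# meaningless -- inputs only need to be on the right device.
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.bfloat16, quant_type="nf4").cuda()
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)

print(layer.weight.dtype)   # torch.uint8: packed 4-bit storage, not a compute dtype
print(layer(x).dtype)       # torch.bfloat16
```
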
1 change: 1 addition & 0 deletions helpers/training/__init__.py
@@ -1,5 +1,6 @@
quantised_precision_levels = [
"no_change",
"nf4-bnb",
# "fp4-bnb",
# "fp8-bnb",
"fp8-quanto",
15 changes: 11 additions & 4 deletions helpers/training/diffusion_model.py
@@ -17,10 +17,21 @@ def load_diffusion_model(args, weight_dtype):
pretrained_load_args = {
"revision": args.revision,
"variant": args.variant,
"torch_dtype": weight_dtype,
}
unet = None
transformer = None

if "nf4-bnb" == args.base_model_precision:
import torch
from diffusers import BitsAndBytesConfig
pretrained_load_args["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=weight_dtype,
)

if args.model_family == "sd3":
# Stable Diffusion 3 uses a Diffusion transformer.
logger.info("Loading Stable Diffusion 3 diffusion transformer..")
@@ -45,7 +56,6 @@ def load_diffusion_model(args, weight_dtype):
args.pretrained_transformer_model_name_or_path
or args.pretrained_model_name_or_path,
subfolder=determine_subfolder(args.pretrained_transformer_subfolder),
torch_dtype=weight_dtype,
**pretrained_load_args,
)
elif args.model_family.lower() == "flux" and args.flux_attention_masked_training:
@@ -56,7 +66,6 @@
transformer = FluxTransformer2DModelWithMasking.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=weight_dtype,
**pretrained_load_args,
)
elif args.model_family == "pixart_sigma":
@@ -66,7 +75,6 @@
args.pretrained_transformer_model_name_or_path
or args.pretrained_model_name_or_path,
subfolder=determine_subfolder(args.pretrained_transformer_subfolder),
torch_dtype=weight_dtype,
**pretrained_load_args,
)
elif args.model_family == "smoldit":
@@ -100,7 +108,6 @@
args.pretrained_unet_model_name_or_path
or args.pretrained_model_name_or_path,
subfolder=determine_subfolder(args.pretrained_unet_subfolder),
torch_dtype=weight_dtype,
**pretrained_load_args,
)

2 changes: 1 addition & 1 deletion helpers/training/optimizer_param.py
@@ -21,7 +21,7 @@
try:
from torchao.prototype.low_bit_optim import (
AdamW8bit as AOAdamW8Bit,
Adam4bit as AOAdamW4Bit,
AdamW4bit as AOAdamW4Bit,
AdamFp8 as AOAdamFp8,
AdamWFp8 as AOAdamWFp8,
CPUOffloadOptimizer as AOCPUOffloadOptimizer,
5 changes: 3 additions & 2 deletions helpers/training/save_hooks.py
@@ -432,13 +432,14 @@ def _load_lycoris(self, models, input_dir):
if len(state.keys()) > 0:
logging.error(f"LyCORIS failed to load: {state}")
raise RuntimeError("Loading of LyCORIS model failed")
weight_dtype = StateTracker.get_weight_dtype()
if self.transformer is not None:
self.accelerator._lycoris_wrapped_network.to(
device=self.accelerator.device, dtype=self.transformer.dtype
device=self.accelerator.device, dtype=weight_dtype
)
elif self.unet is not None:
self.accelerator._lycoris_wrapped_network.to(
device=self.accelerator.device, dtype=self.unet.dtype
device=self.accelerator.device, dtype=weight_dtype
)
else:
raise ValueError("No model found to load LyCORIS weights into.")
26 changes: 19 additions & 7 deletions helpers/training/trainer.py
@@ -733,9 +733,21 @@ def init_precision(self):
self.config.base_weight_dtype = self.config.weight_dtype
self.config.is_quanto = False
self.config.is_torchao = False
self.config.is_bnb = False
if "quanto" in self.config.base_model_precision:
self.config.is_quanto = True
elif "torchao" in self.config.base_model_precision:
self.config.is_torchao = True
elif "bnb" in self.config.base_model_precision:
self.config.is_bnb = True
quantization_device = (
"cpu" if self.config.quantize_via == "cpu" else self.accelerator.device
)

if 'bnb' in self.config.base_model_precision:
# can't cast or move bitsandbytes models
return

if not self.config.disable_accelerator and self.config.is_quantized:
if self.config.base_model_default_dtype == "fp32":
self.config.base_weight_dtype = torch.float32
@@ -755,10 +767,6 @@
self.transformer.to(
quantization_device, dtype=self.config.base_weight_dtype
)
if "quanto" in self.config.base_model_precision:
self.config.is_quanto = True
elif "torchao" in self.config.base_model_precision:
self.config.is_torchao = True

if self.config.is_quanto:
from helpers.training.quantisation import quantise_model
@@ -1917,7 +1925,10 @@ def train(self):
)
training_logger.debug(f"Working on batch size: {bsz}")
if self.config.flow_matching:
if not self.config.flux_fast_schedule and not self.config.flux_use_beta_schedule:
if (
not self.config.flux_fast_schedule
and not self.config.flux_use_beta_schedule
):
# imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
# also used by: https://github.com/XLabs-AI/x-flux/tree/main
# and: https://github.com/kohya-ss/sd-scripts/commit/8a0f12dde812994ec3facdcdb7c08b362dbceb0f
@@ -1936,7 +1947,9 @@
beta_dist = Beta(alpha, beta)

# Sample from the Beta distribution
sigmas = beta_dist.sample((bsz,)).to(device=self.accelerator.device)
sigmas = beta_dist.sample((bsz,)).to(
device=self.accelerator.device
)

sigmas = apply_flux_schedule_shift(
self.config, self.noise_scheduler, sigmas, noise
@@ -2204,7 +2217,6 @@
)

text_ids = torch.zeros(
packed_noisy_latents.shape[0],
batch["prompt_embeds"].shape[1],
3,
).to(
20 changes: 12 additions & 8 deletions install/apple/poetry.lock

Some generated files are not rendered by default.