diff --git a/sdxl_train.py b/sdxl_train.py
index 3b28575ed..c7eea2224 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # calculate number of trainable parameters
     n_params = 0
-    for params in params_to_optimize:
-        for p in params["params"]:
+    for group in params_to_optimize:
+        for p in group["params"]:
             n_params += p.numel()
 
     accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
@@ -355,7 +355,44 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+
+    if args.fused_optimizer_groups:
+        # calculate total number of parameters
+        n_total_params = sum(len(params["params"]) for params in params_to_optimize)
+        params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
+
+        # split params into groups
+        grouped_params = []
+        param_group = []
+        param_group_lr = -1
+        for group in params_to_optimize:
+            lr = group["lr"]
+            for p in group["params"]:
+                if lr != param_group_lr:
+                    if param_group:
+                        grouped_params.append({"params": param_group, "lr": param_group_lr})
+                        param_group = []
+                    param_group_lr = lr
+                param_group.append(p)
+                if len(param_group) == params_per_group:
+                    grouped_params.append({"params": param_group, "lr": param_group_lr})
+                    param_group = []
+                    param_group_lr = -1
+        if param_group:
+            grouped_params.append({"params": param_group, "lr": param_group_lr})
+
+        # prepare optimizers for each group
+        optimizers = []
+        for group in grouped_params:
+            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+            optimizers.append(optimizer)
+        optimizer = optimizers[0]  # avoid error in the following code
+
+        print(len(grouped_params))
+        logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
+
+    else:
+        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
     # dataloaderを準備する
     # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
@@ -382,7 +419,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     train_dataset_group.set_max_train_steps(args.max_train_steps)
 
     # lr schedulerを用意する
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+    if args.fused_optimizer_groups:
+        lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
+        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
+    else:
+        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
     if args.full_fp16:
@@ -432,10 +473,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     if args.fused_backward_pass:
         import library.adafactor_fused
+
         library.adafactor_fused.patch_adafactor_fused(optimizer)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:
+
                     def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                             accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
@@ -444,6 +487,36 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                     parameter.register_post_accumulate_grad_hook(__grad_hook)
 
+    elif args.fused_optimizer_groups:
+        for i in range(1, len(optimizers)):
+            optimizers[i] = accelerator.prepare(optimizers[i])
+            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
+
+        global optimizer_hooked_count
+        global num_parameters_per_group
+        global parameter_optimizer_map
+        optimizer_hooked_count = {}
+        num_parameters_per_group = [0] * len(optimizers)
+        parameter_optimizer_map = {}
+        for opt_idx, optimizer in enumerate(optimizers):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
+                    if parameter.requires_grad:
+
+                        def optimizer_hook(parameter: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                            i = parameter_optimizer_map[parameter]
+                            optimizer_hooked_count[i] += 1
+                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+                                optimizers[i].step()
+                                optimizers[i].zero_grad()
+
+                        parameter.register_post_accumulate_grad_hook(optimizer_hook)
+                        parameter_optimizer_map[parameter] = opt_idx
+                        num_parameters_per_group[opt_idx] += 1
+
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -518,6 +591,10 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
+
+            if args.fused_optimizer_groups:
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -596,7 +673,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -614,7 +693,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -630,11 +711,13 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )
 
                 accelerator.backward(loss)
 
-                if not args.fused_backward_pass:
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                         params_to_clip = []
                         for m in training_models:
@@ -642,9 +725,14 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
+                elif args.fused_optimizer_groups:
+                    for i in range(1, len(optimizers)):
+                        lr_schedulers[i].step()
 
                 lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    optimizer.zero_grad(set_to_none=True)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -753,7 +841,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
     accelerator.end_training()
 
-    if args.save_state or args.save_state_on_train_end:        
+    if args.save_state or args.save_state_on_train_end:
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # この後メモリを使うのでこれは消す
@@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser:
         help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
         + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
     )
+    parser.add_argument(
+        "--fused_optimizer_groups",
+        type=int,
+        default=None,
+        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
+    )
     return parser