# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

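"""Quantize the Flux text-to-image transformer with AutoRound (MXFP8 or FP8),
generate images with the resulting pipeline, and score them against captions.

The three stages are selected independently via --quantize, --inference,
and --accuracy.
"""
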
import argparse
import os
import sys

import pandas as pd
import tabulate
import torch

from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel
from functools import partial
from neural_compressor.torch.quantization import (
    AutoRoundConfig,
    convert,
    prepare,
)
from auto_round.data_type.mxfp import quant_mx_rceil
from auto_round.data_type.fp8 import quant_fp8_sym
from auto_round.utils import get_block_names, get_module
from auto_round.compressors.diffusion.eval import metric_map
from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader


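# Each stage below is opt-in via its own flag: --quantize, --inference, --accuracy.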
parser = argparse.ArgumentParser(
    description="Flux quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--model", "--model_name", "--model_name_or_path", help="model name or path")
parser.add_argument("--scheme", default="MXFP8", type=str, help="quantization scheme")
parser.add_argument("--quantize", action="store_true")
parser.add_argument("--inference", action="store_true")
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--dataset", type=str, default="coco2014", help="the dataset used for quantization tuning")
parser.add_argument("--output_dir", "--quantized_model_path", default="./tmp_autoround", type=str, help="the directory to save the quantized model")
parser.add_argument("--eval_dataset", default="captions_source.tsv", type=str, help="evaluation dataset (TSV with id and caption columns)")
parser.add_argument("--output_image_path", default="./tmp_imgs", type=str, help="the directory to save generated images")
parser.add_argument("--iters", "--iter", default=1000, type=int, help="tuning iterations")
parser.add_argument("--limit", default=-1, type=int, help="limit the number of prompts for evaluation")

args = parser.parse_args()


def inference_worker(eval_file, pipe, image_save_dir):
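    """Generate an image for every prompt in eval_file and save it as <id>.png.

    Prompts whose image already exists in image_save_dir are skipped.
    """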
    gen_kwargs = {
        "guidance_scale": 7.5,
        "num_inference_steps": 50,
        "generator": None,
    }

    dataloader, _, _ = get_diffusion_dataloader(eval_file, nsamples=args.limit, bs=1)
    for image_ids, prompts in dataloader:
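        # Re-batch: keep only prompts whose output image doesn't exist yet,
        # so an interrupted run can resume where it left off.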
        new_ids = []
        new_prompts = []
        for idx, image_id in enumerate(image_ids):
            image_id = image_id.item()
            if os.path.exists(os.path.join(image_save_dir, str(image_id) + ".png")):
                continue
            new_ids.append(image_id)
            new_prompts.append(prompts[idx])

        if not new_prompts:
            continue

        output = pipe(prompt=new_prompts, **gen_kwargs)
        for idx, image_id in enumerate(new_ids):
            output.images[idx].save(os.path.join(image_save_dir, str(image_id) + ".png"))


def tune():
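    """Quantize pipe.transformer with AutoRound and export it to args.output_dir."""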
    pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
    model = pipe.transformer
    layer_config = {}
    kwargs = {}
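    # FP8: tag every Linear layer for fp8 weight quantization with a per-tensor
    # scale (group_size=0). MXFP8: one global mx_fp scheme with 32-element groups.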
    if args.scheme == "FP8":
        for n, m in model.named_modules():
            if m.__class__.__name__ == "Linear":
                layer_config[n] = {"bits": 8, "data_type": "fp", "group_size": 0}
    elif args.scheme == "MXFP8":
        kwargs["scheme"] = {
            "bits": 8,
            "group_size": 32,
            "data_type": "mx_fp",
        }

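    # export_format="fake" saves a fake-quantized (QDQ) checkpoint: weights are
    # rounded but stored in bf16, so it reloads through FluxTransformer2DModel.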
    qconfig = AutoRoundConfig(
        iters=args.iters,
        dataset=args.dataset,
        layer_config=layer_config,
        num_inference_steps=3,
        export_format="fake",
        nsamples=128,
        batch_size=1,
        output_dir=args.output_dir,
        **kwargs,
    )
    model = prepare(model, qconfig)
    model = convert(model, qconfig, pipeline=pipe)


if __name__ == "__main__":
    device = "cpu" if torch.cuda.device_count() == 0 else "cuda"

    if args.quantize:
        print(f"Quantizing {args.model}.")
        tune()
        sys.exit(0)

    if args.inference:
        pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
        os.makedirs(args.output_image_path, exist_ok=True)

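        # A sharded-safetensors index file is the marker that a quantized
        # checkpoint was previously exported to --output_dir.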
        if os.path.exists(os.path.join(args.output_dir, "diffusion_pytorch_model.safetensors.index.json")):
            print(f"Loading quantized model from {args.output_dir}")
            model = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=torch.bfloat16)

            # The saved weights are already fake-quantized; wrap each Linear's
            # forward so its input activations are quantized-dequantized too.
            if args.scheme == "MXFP8":
                def act_qdq_forward(module, x, *args, **kwargs):
                    qdq_x, _, _ = quant_mx_rceil(x, bits=8, group_size=32, data_type="mx_fp_rceil")
                    return module.orig_forward(qdq_x, *args, **kwargs)

                all_quant_blocks = get_block_names(model)
                for block_names in all_quant_blocks:
                    for block_name in block_names:
                        block = get_module(model, block_name)
                        for _, m in block.named_modules():
                            if m.__class__.__name__ == "Linear":
                                m.orig_forward = m.forward
                                m.forward = partial(act_qdq_forward, m)

            elif args.scheme == "FP8":
                def act_qdq_forward(module, x, *args, **kwargs):
                    qdq_x, _, _ = quant_fp8_sym(x, group_size=0)
                    return module.orig_forward(qdq_x, *args, **kwargs)

                for _, m in model.named_modules():
                    if m.__class__.__name__ == "Linear":
                        m.orig_forward = m.forward
                        m.forward = partial(act_qdq_forward, m)

            pipe.transformer = model

        else:
            print("No quantized model found under --output_dir; evaluating the BF16 baseline.")

        inference_worker(args.eval_dataset, pipe.to(device), args.output_image_path)

    if args.accuracy:
        df = pd.read_csv(args.eval_dataset, sep="\t")
        prompt_list = []
        image_list = []
        for _, row in df.iterrows():
            assert "id" in row and "caption" in row
            caption_id = row["id"]
            caption_text = row["caption"]
            image_file = os.path.join(args.output_image_path, str(caption_id) + ".png")
            if os.path.exists(image_file):
                prompt_list.append(caption_text)
                image_list.append(image_file)

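        # Each metric_map entry maps (prompts, images, device) to a
        # {metric_name: score} dict; merge all three into one table.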
        result = {}
        metrics = ["clip", "clip-iqa", "imagereward"]
        for metric in metrics:
            result.update(metric_map[metric](prompt_list, image_list, device))

        print(tabulate.tabulate(result.items(), tablefmt="grid"))