
Commit a0f0504

Add flux example (#2311)
Signed-off-by: Mengni Wang <[email protected]>
1 parent 4151ffe commit a0f0504

File tree

10 files changed (+433, −1 lines)

examples/README.md

Lines changed: 6 additions & 0 deletions
@@ -15,6 +15,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
 </tr>
 </thead>
 <tbody>
+<tr>
+  <td>FLUX.1-dev</td>
+  <td>Text to Image</td>
+  <td>Quantization (MXFP8+FP8)</td>
+  <td><a href="./pytorch/diffusion_model/diffusers/flux">link</a></td>
+</tr>
 <tr>
   <td>Llama-4-Scout-17B-16E-Instruct</td>
   <td>Multimodal Modeling</td>
examples/pytorch/diffusion_model/diffusers/flux/README.md

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# Step-by-Step

This example quantizes the Flux model and validates the accuracy of the quantized model.

# Prerequisite

## 1. Environment

```shell
pip install -r requirements.txt
# Use `INC_PT_ONLY=1 pip install git+https://github.com/intel/[email protected]` for the latest updates before the neural-compressor v3.6 release
pip install neural-compressor-pt==3.6
# Use `pip install git+https://github.com/intel/[email protected]` for the latest updates before the auto-round v0.8.0 release
pip install auto-round==0.8.0
```

## 2. Prepare Model

```shell
hf download black-forest-labs/FLUX.1-dev --local-dir FLUX.1-dev
```

## 3. Prepare Dataset

```shell
wget https://github.com/mlcommons/inference/raw/refs/heads/master/text_to_image/coco2014/captions/captions_source.tsv
```

# Run

## Quantization

```bash
bash run_quant.sh --topology=flux_mxfp8 --input_model=FLUX.1-dev --output_model=mxfp8_model
```

- topology: supported values are `flux_fp8` and `flux_mxfp8`; a sketch of the underlying quantization flow follows below
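For orientation, `run_quant.sh` drives the AutoRound quantization flow implemented in `main.py` (shown in full below). The following is a minimal sketch of that flow for the MXFP8 scheme, assuming the `FLUX.1-dev` checkpoint from step 2; see `tune()` in `main.py` for the authoritative version:

```python
import torch
from diffusers import AutoPipelineForText2Image
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

# Load the BF16 pipeline and pull out the transformer to quantize.
pipe = AutoPipelineForText2Image.from_pretrained("FLUX.1-dev", torch_dtype=torch.bfloat16)
model = pipe.transformer

# MXFP8: 8-bit microscaling floating point, scales shared per group of 32.
qconfig = AutoRoundConfig(
    iters=1000,
    dataset="coco2014",
    scheme={"bits": 8, "group_size": 32, "data_type": "mx_fp"},
    num_inference_steps=3,
    export_format="fake",  # export a QDQ ("fake-quant") model
    nsamples=128,
    batch_size=1,
    output_dir="mxfp8_model",
)
model = prepare(model, qconfig)
model = convert(model, qconfig, pipeline=pipe)  # calibrates and saves to output_dir
```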
## Evaluation

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_benchmark.sh --topology=flux_mxfp8 --input_model=FLUX.1-dev --quantized_model=mxfp8_model
```

- CUDA_VISIBLE_DEVICES: when set, the evaluation file is split into one subset per visible GPU and the subsets are evaluated in parallel to speed up evaluation (see `dataset_split.py` and `run_benchmark.sh` below)
examples/pytorch/diffusion_model/diffusers/flux/dataset_split.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import argparse
import pandas as pd

parser = argparse.ArgumentParser()
parser.add_argument('--split_num', type=int)
parser.add_argument('--limit', default=-1, type=int)
parser.add_argument('--input_file', type=str)
parser.add_argument('--output_file', default="subset", type=str)
args = parser.parse_args()

# load the TSV file
df = pd.read_csv(args.input_file, sep='\t')

if args.limit > 0:
    df = df.iloc[0:args.limit]

# ceiling division, so the last subset keeps the remainder rows
num = -(-len(df) // args.split_num)
for i in range(args.split_num):
    start = i * num
    end = min((i + 1) * num, len(df))
    df_subset = df.iloc[start:end]
    df_subset.to_csv(f"{args.output_file}_{i}.tsv", sep='\t', index=False)
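For reference, `run_benchmark.sh` below invokes this helper once per run, e.g. `python dataset_split.py --split_num 4 --input_file captions_source.tsv --limit -1`, which writes `subset_0.tsv` through `subset_3.tsv`, one per GPU worker.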
examples/pytorch/diffusion_model/diffusers/flux/main.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys

import pandas as pd
import tabulate
import torch

from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel
from functools import partial
from neural_compressor.torch.quantization import (
    AutoRoundConfig,
    convert,
    prepare,
)
from auto_round.data_type.mxfp import quant_mx_rceil
from auto_round.data_type.fp8 import quant_fp8_sym
from auto_round.utils import get_block_names, get_module
from auto_round.compressors.diffusion.eval import metric_map
from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader


parser = argparse.ArgumentParser(
    description="Flux quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--model", "--model_name", "--model_name_or_path", help="model name or path")
parser.add_argument('--scheme', default="MXFP8", type=str, help="quantization scheme")
parser.add_argument("--quantize", action="store_true")
parser.add_argument("--inference", action="store_true")
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--dataset", type=str, default="coco2014", help="the dataset for quantization training")
parser.add_argument("--output_dir", "--quantized_model_path", default="./tmp_autoround", type=str, help="the directory to save the quantized model")
parser.add_argument("--eval_dataset", default="captions_source.tsv", type=str, help="eval dataset")
parser.add_argument("--output_image_path", default="./tmp_imgs", type=str, help="the directory to save generated images")
parser.add_argument("--iters", "--iter", default=1000, type=int, help="tuning iters")
parser.add_argument("--limit", default=-1, type=int, help="limit the number of prompts for evaluation")

args = parser.parse_args()


def inference_worker(eval_file, pipe, image_save_dir):
    gen_kwargs = {
        "guidance_scale": 7.5,
        "num_inference_steps": 50,
        "generator": None,
    }

    dataloader, _, _ = get_diffusion_dataloader(eval_file, nsamples=args.limit, bs=1)
    for image_ids, prompts in dataloader:
        # Skip prompts whose images were already generated in a previous run.
        new_ids = []
        new_prompts = []
        for idx, image_id in enumerate(image_ids):
            image_id = image_id.item()
            if os.path.exists(os.path.join(image_save_dir, str(image_id) + ".png")):
                continue
            new_ids.append(image_id)
            new_prompts.append(prompts[idx])

        if len(new_prompts) == 0:
            continue

        output = pipe(prompt=new_prompts, **gen_kwargs)
        for idx, image_id in enumerate(new_ids):
            output.images[idx].save(os.path.join(image_save_dir, str(image_id) + ".png"))


def tune():
    pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
    model = pipe.transformer
    layer_config = {}
    kwargs = {}
    if args.scheme == "FP8":
        # FP8: per-tensor (group_size=0) 8-bit float for every Linear layer.
        for n, m in model.named_modules():
            if m.__class__.__name__ == "Linear":
                layer_config[n] = {"bits": 8, "data_type": "fp", "group_size": 0}
    elif args.scheme == "MXFP8":
        # MXFP8: microscaling 8-bit float, scales shared per group of 32.
        kwargs["scheme"] = {
            "bits": 8,
            "group_size": 32,
            "data_type": "mx_fp",
        }

    qconfig = AutoRoundConfig(
        iters=args.iters,
        dataset=args.dataset,
        layer_config=layer_config,
        num_inference_steps=3,
        export_format="fake",
        nsamples=128,
        batch_size=1,
        output_dir=args.output_dir,
        **kwargs
    )
    model = prepare(model, qconfig)
    model = convert(model, qconfig, pipeline=pipe)


if __name__ == '__main__':
    device = "cpu" if torch.cuda.device_count() == 0 else "cuda"

    if args.quantize:
        print(f"Start to quantize {args.model}.")
        tune()
        sys.exit(0)

    if args.inference:
        pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)

        if not os.path.exists(args.output_image_path):
            os.makedirs(args.output_image_path)

        if os.path.exists(args.output_dir) and os.path.exists(os.path.join(args.output_dir, "diffusion_pytorch_model.safetensors.index.json")):
            print(f"Loading quantized model from {args.output_dir}")
            model = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=torch.bfloat16)

            # Replace each Linear's forward function: the exported model stores
            # QDQ weights only, so activation quantization is emulated by
            # fake-quantizing the input before the original forward runs.
            if args.scheme == "MXFP8":
                def act_qdq_forward(module, x, *args, **kwargs):
                    qdq_x, _, _ = quant_mx_rceil(x, bits=8, group_size=32, data_type="mx_fp_rceil")
                    return module.orig_forward(qdq_x, *args, **kwargs)

                all_quant_blocks = get_block_names(model)

                for block_names in all_quant_blocks:
                    for block_name in block_names:
                        block = get_module(model, block_name)
                        for n, m in block.named_modules():
                            if m.__class__.__name__ == "Linear":
                                m.orig_forward = m.forward
                                m.forward = partial(act_qdq_forward, m)

            if args.scheme == "FP8":
                def act_qdq_forward(module, x, *args, **kwargs):
                    qdq_x, _, _ = quant_fp8_sym(x, group_size=0)
                    return module.orig_forward(qdq_x, *args, **kwargs)

                for n, m in model.named_modules():
                    if m.__class__.__name__ == "Linear":
                        m.orig_forward = m.forward
                        m.forward = partial(act_qdq_forward, m)

            pipe.transformer = model

        else:
            print("quantized_model_path was not supplied or the quantized model does not exist; evaluating BF16 accuracy.")

        inference_worker(args.eval_dataset, pipe.to(device), args.output_image_path)

    if args.accuracy:
        df = pd.read_csv(args.eval_dataset, sep="\t")
        prompt_list = []
        image_list = []
        for index, row in df.iterrows():
            assert "id" in row and "caption" in row
            caption_id = row["id"]
            caption_text = row["caption"]
            if os.path.exists(os.path.join(args.output_image_path, str(caption_id) + ".png")):
                prompt_list.append(caption_text)
                image_list.append(os.path.join(args.output_image_path, str(caption_id) + ".png"))

        result = {}
        metrics = ["clip", "clip-iqa", "imagereward"]
        for metric in metrics:
            result.update(metric_map[metric](prompt_list, image_list, device))

        print(tabulate.tabulate(result.items(), tablefmt="grid"))
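The shell scripts drive `main.py` end to end, but it can also be invoked directly with the flags defined above; for example, `python main.py --model FLUX.1-dev --quantize --scheme MXFP8 --output_dir mxfp8_model` would run quantization alone (an illustrative invocation based on the argparse flags, not one taken from the shipped scripts).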
examples/pytorch/diffusion_model/diffusers/flux/requirements.txt

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
diffusers==0.35.1
pandas==2.2.2
clip==0.2.0
image-reward==1.5
torchmetrics==1.8.2
transformers==4.55.0
examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
#!/bin/bash
set -x

function main {

  init_params "$@"
  run_benchmark

}

# init params
function init_params {
  for var in "$@"
  do
    case $var in
      --topology=*)
          topology=$(echo $var |cut -f2 -d=)
      ;;
      --dataset_location=*)
          dataset_location=$(echo $var |cut -f2 -d=)
      ;;
      --input_model=*)
          input_model=$(echo $var |cut -f2 -d=)
      ;;
      --quantized_model=*)
          tuned_checkpoint=$(echo $var |cut -f2 -d=)
      ;;
      --limit=*)
          limit=$(echo $var |cut -f2 -d=)
      ;;
      --output_image_path=*)
          output_image_path=$(echo $var |cut -f2 -d=)
      ;;
      *)
          echo "Error: No such parameter: ${var}"
          exit 1
      ;;
    esac
  done

}


# run_benchmark
function run_benchmark {
    dataset_location=${dataset_location:="captions_source.tsv"}
    limit=${limit:=-1}
    output_image_path=${output_image_path:="./tmp_imgs"}

    if [ "${topology}" = "flux_fp8" ]; then
        extra_cmd="--scheme FP8 --inference"
    elif [ "${topology}" = "flux_mxfp8" ]; then
        extra_cmd="--scheme MXFP8 --inference"
    fi

    if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
        gpu_list="${CUDA_VISIBLE_DEVICES:-}"
        IFS=',' read -ra gpu_ids <<< "$gpu_list"
        visible_gpus=${#gpu_ids[@]}
        echo "visible_gpus: ${visible_gpus}"

        # Split the evaluation set into one subset per visible GPU;
        # --limit is applied here, so it is not passed to the workers.
        python dataset_split.py --split_num ${visible_gpus} --input_file ${dataset_location} --limit ${limit}

        # Launch one background worker per GPU, each pinned to its own device
        # from the user-supplied list.
        for ((i=0; i<visible_gpus; i++)); do
            export CUDA_VISIBLE_DEVICES=${gpu_ids[i]}

            python3 main.py \
                --model ${input_model} \
                --quantized_model_path ${tuned_checkpoint} \
                --output_image_path ${output_image_path} \
                --eval_dataset "subset_$i.tsv" \
                ${extra_cmd} &
            program_pid+=($!)
            echo "Start (PID: ${program_pid[-1]}, GPU: ${gpu_ids[i]})"
        done
        wait "${program_pid[@]}"
    else
        python3 main.py \
            --model ${input_model} \
            --quantized_model_path ${tuned_checkpoint} \
            --output_image_path ${output_image_path} \
            --eval_dataset ${dataset_location} \
            --limit ${limit} \
            ${extra_cmd}
    fi

    echo "Start calculating final score..."

    python3 main.py --output_image_path ${output_image_path} --accuracy
}

main "$@"
