Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
</tr>
</thead>
<tbody>
<tr>
<td>FLUX.1-dev</td>
<td>Text to Image</td>
<td>Quantization (MXFP8+FP8)</td>
<td><a href="./pytorch/diffusion_model/diffusers/flux">link</a></td>
</tr>
<tr>
<td>Llama-4-Scout-17B-16E-Instruct</td>
<td>Multimodal Modeling</td>
Expand Down
44 changes: 44 additions & 0 deletions examples/pytorch/diffusion_model/diffusers/flux/README.md
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Step-by-Step

This example quantizes and validates the accuracy of Flux.

# Prerequisite

## 1. Environment

```shell
pip install -r requirements.txt
# Use `INC_PT_ONLY=1 pip install git+https://github.com/intel/[email protected]` for the latest updates before neural-compressor v3.6 release
pip install neural-compressor-pt==3.6
# Use `pip install git+https://github.com/intel/[email protected]` for the latest updates before auto-round v0.8.0 release
pip install auto-round==0.8.0
```

## 2. Prepare Model

```shell
hf download black-forest-labs/FLUX.1-dev --local-dir FLUX.1-dev
```

## 3. Prepare Dataset
```shell
wget https://github.com/mlcommons/inference/raw/refs/heads/master/text_to_image/coco2014/captions/captions_source.tsv
```

# Run

## Quantization

```bash
bash run_quant.sh --topology=flux_mxfp8 --input_model=FLUX.1-dev --output_model=mxfp8_model
```
- topology: support flux_fp8 and flux_mxfp8


## Evaluation

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_benchmark.sh --topology=flux_mxfp8 --input_model=FLUX.1-dev --quantized_model=mxfp8_model
```

- CUDA_VISIBLE_DEVICES: split the evaluation file into the number of GPUs' subset to speed up the evaluation
22 changes: 22 additions & 0 deletions examples/pytorch/diffusion_model/diffusers/flux/dataset_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import argparse
import pandas as pd

parser = argparse.ArgumentParser()
parser.add_argument('--split_num', type=int)
parser.add_argument('--limit', default=-1, type=int)
parser.add_argument('--input_file', type=str)
parser.add_argument('--output_file', default="subset", type=str)
args = parser.parse_args()

# load the TSV file
df = pd.read_csv(args.input_file, sep='\t')

if args.limit > 0:
df = df.iloc[0:args.limit]

num = round(len(df) / args.split_num)
for i in range(args.split_num):
start = i * num
end = min((i + 1) * num, len(df))
df_subset = df.iloc[start:end]
df_subset.to_csv(f"{args.output_file}_{i}.tsv", sep='\t', index=False)
182 changes: 182 additions & 0 deletions examples/pytorch/diffusion_model/diffusers/flux/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import sys
import argparse

import pandas as pd
import tabulate
import torch

from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel
from functools import partial
from neural_compressor.torch.quantization import (
AutoRoundConfig,
convert,
prepare,
)
from auto_round.data_type.mxfp import quant_mx_rceil
from auto_round.data_type.fp8 import quant_fp8_sym
from auto_round.utils import get_block_names, get_module
from auto_round.compressors.diffusion.eval import metric_map
from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader


parser = argparse.ArgumentParser(
description="Flux quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--model", "--model_name", "--model_name_or_path", help="model name or path")
parser.add_argument('--scheme', default="MXFP8", type=str, help="quantizaion scheme.")
parser.add_argument("--quantize", action="store_true")
parser.add_argument("--inference", action="store_true")
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--dataset", type=str, default="coco2014", help="the dataset for quantization training.")
parser.add_argument("--output_dir", "--quantized_model_path", default="./tmp_autoround", type=str, help="the directory to save quantized model")
parser.add_argument("--eval_dataset", default="captions_source.tsv", type=str, help="eval datasets")
parser.add_argument("--output_image_path", default="./tmp_imgs", type=str, help="the directory to save quantized model")
parser.add_argument("--iters", "--iter", default=1000, type=int, help="tuning iters")
parser.add_argument("--limit", default=-1, type=int, help="limit the number of prompts for evaluation")

args = parser.parse_args()


def inference_worker(eval_file, pipe, image_save_dir):
gen_kwargs = {
"guidance_scale": 7.5,
"num_inference_steps": 50,
"generator": None,
}

dataloader, _, _ = get_diffusion_dataloader(eval_file, nsamples=args.limit, bs=1)
for image_ids, prompts in dataloader:

new_ids = []
new_prompts = []
for idx, image_id in enumerate(image_ids):
image_id = image_id.item()

if os.path.exists(os.path.join(image_save_dir, str(image_id) + ".png")):
continue
new_ids.append(image_id)
new_prompts.append(prompts[idx])

if len(new_prompts) == 0:
continue

output = pipe(prompt=new_prompts, **gen_kwargs)
for idx, image_id in enumerate(new_ids):
output.images[idx].save(os.path.join(image_save_dir, str(image_id) + ".png"))


def tune():
pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
model = pipe.transformer
layer_config = {}
kwargs = {}
if args.scheme == "FP8":
for n, m in model.named_modules():
if m.__class__.__name__ == "Linear":
layer_config[n] = {"bits": 8, "data_type": "fp", "group_size": 0}
elif args.scheme == "MXFP8":
kwargs["scheme"] = {
"bits": 8,
"group_size": 32,
"data_type": "mx_fp",
}

qconfig = AutoRoundConfig(
iters=args.iters,
dataset=args.dataset,
layer_config=layer_config,
num_inference_steps=3,
export_format="fake",
nsamples=128,
batch_size=1,
output_dir=args.output_dir,
**kwargs
)
model = prepare(model, qconfig)
model = convert(model, qconfig, pipeline=pipe)

if __name__ == '__main__':
device = "cpu" if torch.cuda.device_count() == 0 else "cuda"

if args.quantize:
print(f"Start to quantize {args.model}.")
tune()
exit(0)

if args.inference:
pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)

if not os.path.exists(args.output_image_path):
os.makedirs(args.output_image_path)

if os.path.exists(args.output_dir) and os.path.exists(os.path.join(args.output_dir, "diffusion_pytorch_model.safetensors.index.json")):
print(f"Loading quantized model from {args.output_dir}")
model = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=torch.bfloat16)

# replace Linear's forward function
if args.scheme == "MXFP8":
def act_qdq_forward(module, x, *args, **kwargs):
qdq_x, _, _ = quant_mx_rceil(x, bits=8, group_size=32, data_type="mx_fp_rceil")
return module.orig_forward(qdq_x, *args, **kwargs)

all_quant_blocks = get_block_names(model)

for block_names in all_quant_blocks:
for block_name in block_names:
block = get_module(model, block_name)
for n, m in block.named_modules():
if m.__class__.__name__ == "Linear":
m.orig_forward = m.forward
m.forward = partial(act_qdq_forward, m)

if args.scheme == "FP8":
def act_qdq_forward(module, x, *args, **kwargs):
qdq_x, _, _ = quant_fp8_sym(x, group_size=0)
return module.orig_forward(qdq_x, *args, **kwargs)

for n, m in model.named_modules():
if m.__class__.__name__ == "Linear":
m.orig_forward = m.forward
m.forward = partial(act_qdq_forward, m)

pipe.transformer = model

else:
print("Don't supply quantized_model_path or quantized model doesn't exist, evaluate BF16 accuracy.")

inference_worker(args.eval_dataset, pipe.to(device), args.output_image_path)

if args.accuracy:
df = pd.read_csv(args.eval_dataset, sep="\t")
prompt_list = []
image_list = []
for index, row in df.iterrows():
assert "id" in row and "caption" in row
caption_id = row["id"]
caption_text = row["caption"]
if os.path.exists(os.path.join(args.output_image_path, str(caption_id) + ".png")):
prompt_list.append(caption_text)
image_list.append(os.path.join(args.output_image_path, str(caption_id) + ".png"))

result = {}
metrics = ["clip", "clip-iqa", "imagereward"]
for metric in metrics:
result.update(metric_map[metric](prompt_list, image_list, device))

print(tabulate.tabulate(result.items(), tablefmt="grid"))
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
diffusers==0.35.1
pandas==2.2.2
clip==0.2.0
image-reward==1.5
torchmetrics==1.8.2
transformers==4.55.0
92 changes: 92 additions & 0 deletions examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/bin/bash
set -x

function main {

init_params "$@"
run_benchmark

}

# init params
function init_params {
for var in "$@"
do
case $var in
--topology=*)
topology=$(echo $var |cut -f2 -d=)
;;
--dataset_location=*)
dataset_location=$(echo $var |cut -f2 -d=)
;;
--input_model=*)
input_model=$(echo $var |cut -f2 -d=)
;;
--quantized_model=*)
tuned_checkpoint=$(echo $var |cut -f2 -d=)
;;
--limit=*)
limit=$(echo $var |cut -f2 -d=)
;;
--output_image_path=*)
output_image_path=$(echo $var |cut -f2 -d=)
;;
*)
echo "Error: No such parameter: ${var}"
exit 1
;;
esac
done

}


# run_benchmark
function run_benchmark {
dataset_location=${dataset_location:="captions_source.tsv"}
limit=${limit:=-1}
output_image_path=${output_image_path:="./tmp_imgs"}

if [ "${topology}" = "flux_fp8" ]; then
extra_cmd="--scheme FP8 --inference"
elif [ "${topology}" = "flux_mxfp8" ]; then
extra_cmd="--scheme MXFP8 --inference"
fi

if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
gpu_list="${CUDA_VISIBLE_DEVICES:-}"
IFS=',' read -ra gpu_ids <<< "$gpu_list"
visible_gpus=${#gpu_ids[@]}
echo "visible_gpus: ${visible_gpus}"

python dataset_split.py --split_num ${visible_gpus} --input_file ${dataset_location} --limit ${limit}

for ((i=0; i<visible_gpus; i++)); do
export CUDA_VISIBLE_DEVICES=${i}

python3 main.py \
--model ${input_model} \
--quantized_model_path ${tuned_checkpoint} \
--output_image_path ${output_image_path} \
--eval_dataset "subset_$i.tsv" \
${extra_cmd} &
program_pid+=($!)
echo "Start (PID: ${program_pid[-1]}, GPU: ${i})"
done
wait "${program_pid[@]}"
else
python3 main.py \
--model ${input_model} \
--quantized_model_path ${tuned_checkpoint} \
--output_image_path ${output_image_path} \
--eval_dataset ${dataset_location} \
--limit ${limit} \
${extra_cmd}
fi

echo "Start calculating final score..."

python3 main.py --output_image_path ${output_image_path} --accuracy
}

main "$@"
Loading