Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
091d81c
Support GLM-Image model quantization
lvliang-intel Mar 8, 2026
b367bf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2026
e73a31d
fix test script
lvliang-intel Mar 8, 2026
c1d819b
Merge branch 'lvl/support_glm_image' of https://github.com/intel/auto…
lvliang-intel Mar 8, 2026
510b6c4
support hybrid mode
lvliang-intel Mar 10, 2026
102f8ae
Merge branch 'main' of https://github.com/intel/auto-round into lvl/s…
lvliang-intel Mar 10, 2026
1691ac6
Merge branch 'main' of https://github.com/intel/auto-round into lvl/s…
lvliang-intel Mar 17, 2026
27cddba
fix hybrid mode
lvliang-intel Mar 17, 2026
e80472f
Merge branch 'main' of https://github.com/intel/auto-round into lvl/s…
lvliang-intel Mar 17, 2026
0ec767a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2026
8f9f607
Merge branch 'main' into lvl/support_glm_image
lvliang-intel Mar 19, 2026
dcf3b52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
0e4dafc
fix issue
lvliang-intel Mar 19, 2026
19f8484
fix comments
lvliang-intel Mar 20, 2026
344825f
Merge branch 'main' of https://github.com/intel/auto-round into lvl/s…
lvliang-intel Mar 20, 2026
57c58d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2026
1367356
Merge branch 'main' into lvl/support_glm_image
chensuyue Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
LLMCompressor,
MLLMCompressor,
)
from auto_round.compressors.diffusion.hybrid import HybridCompressor, is_hybrid_diffusion_model
from auto_round.logger import deprecated, logger
from auto_round.schemes import QuantizationScheme
from auto_round.utils import is_diffusion_model, is_mllm_model
Expand Down Expand Up @@ -162,7 +163,19 @@ def __new__(

model_cls = []

if (extra_config and not extra_config.mllm_config.is_default()) or is_mllm_model(model, platform=platform):
has_multimodal_assets = kwargs.get("processor") is not None or kwargs.get("image_processor") is not None

if is_hybrid_diffusion_model(model):
logger.info("using Hybrid AR+Diffusion mode for hybrid model.")
model_cls.append(HybridCompressor)
if extra_config:
extra_config.mllm_config = None
extra_config.diffusion_config = None
elif (
(extra_config and not extra_config.mllm_config.is_default())
or has_multimodal_assets
or is_mllm_model(model, platform=platform)
):
logger.info("using MLLM mode for multimodal model.")
model_cls.append(MLLMCompressor)
if extra_config:
Expand Down
1 change: 1 addition & 0 deletions auto_round/compressors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from auto_round.compressors.base import LLMCompressor
from auto_round.compressors.mllm.compressor import MLLMCompressor
from auto_round.compressors.diffusion.compressor import DiffusionCompressor
from auto_round.compressors.diffusion.hybrid import HybridCompressor
from auto_round.compressors.config import (
DiffusionExtraConfig,
ExtraConfig,
Expand Down
9 changes: 7 additions & 2 deletions auto_round/compressors/diffusion/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
extract_block_names_to_str,
find_matching_blocks,
get_block_names,
merge_block_output_keys,
wrap_block_forward_positional_to_kwargs,
)

pipeline_utils = LazyImport("diffusers.pipelines.pipeline_utils")
Expand Down Expand Up @@ -168,6 +170,9 @@ def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, dict]:
q_inputs = {k: q_inputs.pop(k, None) for k in input_id_str}
return inputs, q_inputs

def _get_block_forward_func(self, name):
    """Return the parent class's block-forward callable for *name*, wrapped so
    positional block arguments are converted into keyword arguments before the
    forward call (required by diffusion transformer blocks)."""
    base_forward = super()._get_block_forward_func(name)
    return wrap_block_forward_positional_to_kwargs(base_forward)

def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[dict, dict]:
input_id_str = [key for key in inputs.keys() if "hidden_state" in key]
input_ids = {k: inputs.pop(k, None) for k in input_id_str}
Expand Down Expand Up @@ -201,7 +206,7 @@ def _get_current_q_output(
)
if isinstance(current_input_ids, dict):
hidden_states = current_input_ids.pop("hidden_states")
current_input_others.update(current_input_ids)
merge_block_output_keys(block, current_input_others, current_input_ids)
current_input_ids = hidden_states
output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device, idx)
return output_q.to(cache_device)
Expand Down Expand Up @@ -247,7 +252,7 @@ def _get_block_outputs(
)
if isinstance(tmp_input_ids, dict):
hidden_states = tmp_input_ids.pop("hidden_states")
tmp_input_others.update(tmp_input_ids)
merge_block_output_keys(block, tmp_input_others, tmp_input_ids)
tmp_input_ids = hidden_states

tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device, None)
Expand Down
Loading
Loading