Skip to content

Commit 2c461b4

Browse files
Author: Guang Yang (committed)
Commit message: rebase on gemma3 ci and log pte file size
1 parent: bf5605b · commit: 2c461b4

File tree

5 files changed: +15 additions, −11 deletions

optimum/commands/export/executorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def parse_args_executorch(parser):
5858
help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
5959
)
6060
required_group.add_argument(
61-
"-q",
62-
"--quantize",
61+
"-qmode",
62+
"--quantization_mode",
6363
required=False,
6464
choices=["8da4w"],
6565
help="Quantization recipe to use. Defaults to None.",
@@ -79,8 +79,8 @@ def run(self):
7979
kwargs = {}
8080
if self.args.use_custom_sdpa:
8181
kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
82-
if self.args.quantize:
83-
kwargs["quantize"] = self.args.quantize
82+
if self.args.quantization_mode:
83+
kwargs["quantization_mode"] = self.args.quantization_mode
8484

8585
main_export(
8686
model_name_or_path=self.args.model,

optimum/executorch/modeling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def _from_pretrained(
180180
local_files_only=local_files_only,
181181
)
182182
model = _load_for_executorch(model_cache_path)
183-
logging.info(f"Loaded model from {model_cache_path}")
183+
logging.info(
184+
f"Loaded model from {model_cache_path} ({os.path.getsize(model_cache_path) / (1024 * 1024):.2f} MB)"
185+
)
184186

185187
return {default_file_name.removesuffix(_PTE_SUFFIX): model}
186188

optimum/exporters/executorch/convert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from .recipe_registry import discover_recipes, recipe_registry
2727

2828

29-
logger = logging.getLogger(__name__)
30-
3129
AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
3230

3331

@@ -82,6 +80,8 @@ def export_to_executorch(
8280
full_path = os.path.join(f"{output_dir}", f"{name}.pte")
8381
with open(full_path, "wb") as f:
8482
prog.write_to_file(f)
85-
logger.info(f"Saved exported program to {full_path}")
83+
logging.info(
84+
f"Saved exported program to {full_path} ({os.path.getsize(full_path) / (1024 * 1024):.2f} MB)"
85+
)
8686

8787
return executorch_progs

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
5757
cache_implementation = kwargs.get("cache_implementation", "static")
5858
max_length = kwargs.get("max_length", 2048)
5959
config = kwargs.get("config", None)
60-
quantization_recipe = kwargs.get("quantize", None)
60+
quantization_mode = kwargs.get("quantization_mode", None)
6161

6262
eager_model = AutoModelForCausalLM.from_pretrained(
6363
model_name_or_path,
@@ -77,7 +77,7 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
7777
),
7878
)
7979

80-
if quantization_recipe == "8da4w":
80+
if quantization_mode == "8da4w":
8181
if parse(torchao.__version__) < parse("0.11.0.dev0"):
8282
raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")
8383

tests/models/test_modeling_gemma3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ def test_gemma3_text_generation_with_custom_sdpa_float16(self):
177177
reason="Only available on torchao >= 0.11.0.dev0",
178178
)
179179
def test_gemma3_text_generation_with_custom_sdpa_8da4w(self):
180-
model_id = "google/gemma-3-1b-it"
180+
# TODO: Until https://github.com/huggingface/optimum/issues/2127 is fixed, have to use non-gated model on CI
181+
# model_id = "google/gemma-3-1b-it"
182+
model_id = "unsloth/gemma-3-1b-it"
181183
prompt = "Write a poem about a machine learning."
182184
tokenizer = AutoTokenizer.from_pretrained(model_id)
183185
kwargs = {"quantize": "8da4w"}

Comments (0)