diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 449824a33b9..a20c006677c 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -539,7 +539,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         bitwidth = int(bitwidth)
         transforms.append(
             lambda model: EmbeddingQuantHandler(
-                model, bitwidth=bitwidth, group_size=group_size
+                model,
+                bitwidth=bitwidth,
+                group_size=group_size,
+                packed=(bitwidth == 4),
             ).quantized_model()
         )
diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py
index e198ff383e9..ec543560b86 100644
--- a/exir/passes/_quant_patterns_and_replacements.py
+++ b/exir/passes/_quant_patterns_and_replacements.py
@@ -189,7 +189,7 @@ def embedding_byte_dtype_out_meta(
 quantized_decomposed_lib.define(
     "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )
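
For context on the first hunk: `packed=(bitwidth == 4)` enables packed storage only when the embedding quantization bitwidth is 4. Below is a minimal sketch of what packing means here, assuming the usual two-codes-per-byte layout; this is an illustration, not `EmbeddingQuantHandler`'s actual implementation, and the helper names are hypothetical:

```python
import torch

# Sketch only: two unsigned 4-bit codes share one byte, halving storage
# versus keeping one int8 element per 4-bit code.
def pack_4bit(codes: torch.Tensor) -> torch.Tensor:
    # `codes` holds values in [0, 15]; last dimension must be even.
    codes = codes.to(torch.uint8)
    return (codes[..., ::2] << 4) | codes[..., 1::2]

def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
    hi = (packed >> 4) & 0x0F  # even-index codes
    lo = packed & 0x0F         # odd-index codes
    return torch.stack([hi, lo], dim=-1).flatten(-2)

codes = torch.randint(0, 16, (8, 64), dtype=torch.uint8)
assert torch.equal(unpack_4bit(pack_4bit(codes)), codes)
```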
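
For the second hunk: moving `ScalarType? dtype=None` behind the `*` makes `dtype` keyword-only in the `embedding_4bit.dtype_out` schema, presumably aligning it with the `embedding_byte` overloads defined in the same file. A plain-Python analogue of the before/after signatures (names are illustrative, not the op's real binding):

```python
# Before: dtype sits in front of `*`, so a caller can bind it positionally
# after `indices` and silently shift arguments.
def old_sig(indices, dtype=None, *, out):
    return dtype, out

# After: dtype is keyword-only, like `out`, so it must be named explicitly.
def new_sig(indices, *, dtype=None, out):
    return dtype, out

old_sig([0], "f32", out="buf")        # legal under the old schema
new_sig([0], dtype="f32", out="buf")  # required form under the new schema
```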