diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 633970f9fe..733681fe09 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -1324,9 +1324,12 @@ def make_embedding(self, embedding): self.make_reshape( weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim] ) + input_names = [weight_reshape_output, "input_ids", "lm_head.MatMul.weight_scale"]; + if not self.quant_attrs["int4"]["is_symmetric"]: + input_names.append("lm_head.MatMul.weight_zp") self.make_node( "GatherBlockQuantized", - inputs=[weight_reshape_output, "input_ids", "lm_head.MatMul.weight_scale", "lm_head.MatMul.weight_zp"], + inputs=input_names, outputs=[gather_output], name=gather_name, domain="com.microsoft",