diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml
index 2b70beea66..1c5868bc97 100644
--- a/.github/workflows/compile.yml
+++ b/.github/workflows/compile.yml
@@ -73,6 +73,28 @@ jobs:
           python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
           cat ./output_aoti
 
+          echo "***********************************************"
+          echo "******* Emb: 4bit channel-wise quantized ******"
+          echo "***********************************************"
+          python generate.py --quant '{"embedding" : {"bitwidth": 4, "group_size": 0, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --compile --quant '{"embedding" : {"bitwidth": 4, "group_size": 0, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --quant '{"embedding" : {"bitwidth": 4, "group_size": 0, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti
+
+          echo "***********************************************"
+          echo "******** Emb: 4bit group-wise quantized *******"
+          echo "***********************************************"
+          python generate.py --quant '{"embedding" : {"bitwidth": 4, "group_size": 8, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+          cat ./output_eager
+          python generate.py --compile --quant '{"embedding" : {"bitwidth": 4, "group_size": 8, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
+          cat ./output_compiled
+          python export.py --quant '{"embedding" : {"bitwidth": 4, "group_size": 8, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
+          python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
+          cat ./output_aoti
+
           echo "******************************************"
           echo "******* INT8 channel-wise quantized ******"
           echo "******************************************"
diff --git a/cli.py b/cli.py
index 06d1f64667..b14f0a9444 100644
--- a/cli.py
+++ b/cli.py
@@ -87,7 +87,7 @@ def cli_args():
     parser.add_argument(
         "--num-samples",
         type=int,
-        default=5,
+        default=1,
         help="Number of samples.")
     parser.add_argument(
         "--max-new-tokens",
diff --git a/quantize.py b/quantize.py
index 1bb48ff39f..80c525b9a9 100644
--- a/quantize.py
+++ b/quantize.py
@@ -492,7 +492,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, group_size: Optional[int] = None
+    module, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -506,22 +506,29 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     group_size=group_size,
+                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, bitwidth, group_size
+                child, bitwidth, group_size, packed
             )
 
 
 class EmbeddingOnlyInt8QuantHandler(QuantHandler):
-    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
+    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None, packed=False):
+        if isinstance(packed, str):
+            packed = (packed == "True")
         self.mod = mod
         self.group_size = group_size
         self.bitwidth = bitwidth
+        self.packed = packed
+        if (bitwidth != 4) and packed:
+            raise RuntimeError("packed only works with bitwidth 4")
+
     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
+    def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()
 
         if self.bitwidth == 4:
@@ -553,7 +560,22 @@ def create_quantized_state_dict(self) -> Dict:
                     self.group_size,
                     scales_dtype=mod.weight.dtype,
                 )
-
+
+                if packed:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0],
+                        weight.shape[1] // 2,
+                        2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
+
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
@@ -563,12 +585,12 @@ def create_quantized_state_dict(self) -> Dict:
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.group_size
+            self.mod, self.bitwidth, self.group_size, self.packed
         )
         return self.mod
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
         self.mod.load_state_dict(model_updated_state_dict)
         return self.mod
@@ -582,15 +604,22 @@ def __init__(
         group_size: Optional[int] = None,
         device=None,
         dtype=torch.half,
+        packed=False,
     ) -> None:
         super().__init__()
         if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
-        )
+        self.packed = packed
+        if not packed:
+            self.register_buffer(
+                "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+            )
+        else:  # packed
+            self.register_buffer(
+                "weight", torch.empty((vocab_size, embedding_dim // 2), dtype=torch.uint8)
+            )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
@@ -612,7 +641,15 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))
 
-        weight = self.weight
+        if self.packed:
+            weight_even = self.weight.div(16, rounding_mode='trunc')
+            weight_odd = self.weight.remainder(16)
+            weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+            weight = weight_unpacked.view(self.weight.shape[0], -1)
+            weight = weight.view(torch.int8).add(-8)
+        else:
+            weight = self.weight
+
         scales = self.scales.view(weight.shape[0], -1)
 
         result_weights = F.embedding(indices, weight)
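
For reference, here is a minimal standalone sketch (not part of the diff) of the nibble pack/unpack round trip that the changes above implement; the tensor names and sizes are illustrative only, and it assumes quantized values are int8 in [-8, 7] with two values stored per uint8 byte:

import torch

# Illustrative weight: int8 values in [-8, 7], even number of columns.
vocab_size, embedding_dim = 4, 6
weight = torch.randint(-8, 8, (vocab_size, embedding_dim), dtype=torch.int8)

# Pack: shift to [0, 15], put even columns in the high nibble, odd in the low.
shifted = weight.add(8).view(torch.uint8)
pairs = shifted.view(vocab_size, embedding_dim // 2, 2)
packed = pairs[:, :, 0] * 16 + pairs[:, :, 1]

# Unpack: split nibbles, interleave them back, and shift back to [-8, 7].
high = packed.div(16, rounding_mode="trunc")
low = packed.remainder(16)
unpacked = torch.stack((high, low), dim=-1).view(vocab_size, -1)
restored = unpacked.view(torch.int8).add(-8)

assert torch.equal(restored, weight)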