Add support for quantized and packed 4 bit embedding (pytorch#188)
* add 4 bit packed embedding

* add unit test

* typo

* fix json syntax

* stack dimension is last

* handle json dict conversion

* add packed parameter to module rewrite

* typo

* typo

* typo

* typo

* fix obscure div error: without rounding_mode, div returns a float result instead of an int (see the sketch below)
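The last item refers to torch.div's default behaviour: without rounding_mode it performs true division and returns a float tensor, which breaks the integer nibble arithmetic used when unpacking. A minimal standalone sketch of the difference (the values are made up):

import torch

packed = torch.tensor([0x5A, 0x3F], dtype=torch.uint8)   # each byte holds two 4-bit codes
high_float = torch.div(packed, 16)                        # true division -> float (5.625, 3.9375)
high = torch.div(packed, 16, rounding_mode='trunc')       # integer result (5, 3)
low = packed.remainder(16)                                # low nibble (10, 15)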
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 8be8516 commit 6e9d0c5
Showing 3 changed files with 71 additions and 12 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/compile.yml
@@ -73,6 +73,28 @@ jobs:
python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "***********************************************"
echo "******* Emb: 4bit channel-wise quantized ******"
echo "***********************************************"
python generate.py --quant '{"embedding" : {"bitwidth": 4, "group_size": 0, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --compile --quant '{"embedding" : {"bitwidth": 4, "group_size": 0, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
cat ./output_compiled
python export.py --quant '{"embedding" : {"bitwidth": 4, "group_size": 0, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "***********************************************"
echo "******** Emb: 4bit group-wise quantized *******"
echo "***********************************************"
python generate.py --quant '{"embedding" : {"bitwidth": 4, "group_size": 8, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --compile --quant '{"embedding" : {"bitwidth": 4, "group_size": 8, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
cat ./output_compiled
python export.py --quant '{"embedding" : {"bitwidth": 4, "group_size": 8, "packed": "True"}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "******************************************"
echo "******* INT8 channel-wise quantized ******"
echo "******************************************"
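The --quant flag takes a JSON string. Note that "packed": "True" is passed as a JSON string rather than a boolean, which is why the handler in quantize.py compares it against "True". A rough sketch of the parsing (variable names here are illustrative, not taken from the repo):

import json

quant_arg = '{"embedding": {"bitwidth": 4, "group_size": 8, "packed": "True"}}'
quant_options = json.loads(quant_arg)        # per-layer-type settings as a dict

emb = quant_options["embedding"]
packed = emb["packed"]
if isinstance(packed, str):                  # JSON carried "True" as a string
    packed = (packed == "True")
# -> bitwidth=4, group_size=8, packed=True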
2 changes: 1 addition & 1 deletion cli.py
@@ -87,7 +87,7 @@ def cli_args():
parser.add_argument(
"--num-samples",
type=int,
default=5,
default=1,
help="Number of samples.")
parser.add_argument(
"--max-new-tokens",
59 changes: 48 additions & 11 deletions quantize.py
@@ -492,7 +492,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


def replace_embedding_weight_only_grouped_int8_per_channel(
module, bitwidth: int = 8, group_size: Optional[int] = None
module, bitwidth: int = 8, group_size: Optional[int] = None, packed = False
):
for name, child in module.named_children():
# print(f"name: {name}")
@@ -506,22 +506,29 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
vocab_size=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
group_size=group_size,
packed=packed,
),
)
else:
replace_embedding_weight_only_grouped_int8_per_channel(
child, bitwidth, group_size
child, bitwidth, group_size, packed
)


class EmbeddingOnlyInt8QuantHandler(QuantHandler):
def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None, packed = False):
if isinstance(packed, str):
packed = (packed == "True")
self.mod = mod
self.group_size = group_size
self.bitwidth = bitwidth
self.packed = packed
if (bitwidth != 4) and packed:
raise RuntimeError("pack only works with bitsize 4")


@torch.no_grad()
def create_quantized_state_dict(self) -> Dict:
def create_quantized_state_dict(self, packed=False) -> Dict:
cur_state_dict = self.mod.state_dict()

if self.bitwidth == 4:
@@ -553,7 +560,22 @@ def create_quantized_state_dict(self) -> Dict:
self.group_size,
scales_dtype=mod.weight.dtype,
)


if packed:
if weight.shape[-1] %2 != 0:
raise RuntimeError("automatic padding not implemented yet")

weight_range_shifted = weight.add(8).view(torch.uint8)
weight_view = weight_range_shifted.view(
weight.shape[0],
weight.shape[1] //2,
2
)
weight_even = weight_view[:,:,0] * 16 # left shift 4
weight_odd = weight_view[:,:,1]
weight_packed = weight_even + weight_odd
weight = weight_packed

# Update state dict
cur_state_dict[f"{fqn}.weight"] = weight
# squeeze makes groupsize=rowsize unidimensional
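The block above packs two signed 4-bit codes per byte: each int8 value in [-8, 7] is shifted into [0, 15], and adjacent even/odd columns become the high and low nibble of one uint8. A standalone round-trip sketch of that scheme (toy values, not from the commit):

import torch

weight = torch.tensor([[-8, 7, 0, -1],
                       [ 3, -4, 5, 2]], dtype=torch.int8)   # 4-bit codes in [-8, 7]

# Pack: shift to [0, 15], then pair even/odd columns into one byte each.
shifted = weight.add(8).view(torch.uint8)
pairs = shifted.view(weight.shape[0], weight.shape[1] // 2, 2)
packed = pairs[:, :, 0] * 16 + pairs[:, :, 1]               # high nibble * 16 + low nibble

# Unpack: split the nibbles again and undo the +8 shift.
high = packed.div(16, rounding_mode='trunc')
low = packed.remainder(16)
unpacked = torch.stack((high, low), dim=-1).view(weight.shape[0], -1)
assert torch.equal(unpacked.view(torch.int8).add(-8), weight)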
@@ -563,12 +585,12 @@

def convert_for_runtime(self) -> nn.Module:
replace_embedding_weight_only_grouped_int8_per_channel(
self.mod, self.bitwidth, self.group_size
self.mod, self.bitwidth, self.group_size, self.packed
)
return self.mod

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
model_updated_state_dict = self.create_quantized_state_dict(self.packed)
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict)
return self.mod
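For reference, a hypothetical end-to-end use of the handler added here (the toy model and sizes are made up; it assumes this repo's quantize.py is importable):

import torch.nn as nn
from quantize import EmbeddingOnlyInt8QuantHandler

model = nn.Sequential(nn.Embedding(num_embeddings=32, embedding_dim=16))

handler = EmbeddingOnlyInt8QuantHandler(model, bitwidth=4, group_size=8, packed=True)
model = handler.quantized_model()   # quantizes the state dict and swaps in the packed embedding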
@@ -582,15 +604,22 @@ def __init__(
group_size: Optional[int] = None,
device=None,
dtype=torch.half,
packed=False,
) -> None:
super().__init__()
if group_size is None or group_size == 0:
group_size = embedding_dim
self.group_size = group_size
self.dtype = dtype
self.register_buffer(
"weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
)
self.packed = packed
if not packed:
self.register_buffer(
"weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
)
else: # packed
self.register_buffer(
"weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8)
)
groups_per_row = (embedding_dim + group_size - 1) // group_size
if groups_per_row > 1:
self.register_buffer(
@@ -612,7 +641,15 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
# result_weights = self.weight.index_select(0, indices.view(-1))
# result_scales = self.scales.index_select(0, indices.view(-1))

weight = self.weight
if self.packed:
weight_even = self.weight.div(16, rounding_mode='trunc')
weight_odd = self.weight.remainder(16)
weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
weight = weight_unpacked.view(self.weight.shape[0], -1)
weight = weight.view(torch.int8).add(-8)
else:
weight = self.weight

scales = self.scales.view(weight.shape[0], -1)

result_weights = F.embedding(indices, weight)
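The rest of forward is collapsed above, but the visible pieces (per-row scales, F.embedding over the int8 weight) follow the usual group-wise dequantization pattern: each group of group_size columns in a looked-up row is multiplied by its own scale. A generic illustration of that pattern, not the repo's exact code:

import torch
import torch.nn.functional as F

vocab_size, embedding_dim, group_size = 8, 16, 8
groups_per_row = (embedding_dim + group_size - 1) // group_size   # same formula as __init__ above

weight = torch.randint(-8, 8, (vocab_size, embedding_dim), dtype=torch.int8)
scales = torch.rand(vocab_size, groups_per_row, dtype=torch.half)

indices = torch.tensor([1, 5])
rows = F.embedding(indices, weight)            # int8 rows, shape (2, 16)
row_scales = F.embedding(indices, scales)      # matching scales, shape (2, 2)

dequant = (rows.to(torch.half).view(-1, groups_per_row, group_size)
           * row_scales.unsqueeze(-1)).view(-1, embedding_dim)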
