Commit 9845569
fix fp8 weight quant need contiguous tensor (#632)
Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun authored Dec 2, 2024
1 parent bd0712e commit 9845569
Showing 1 changed file with 6 additions and 2 deletions.
lightllm/common/quantization/vllm_quant.py (6 additions, 2 deletions)

@@ -62,7 +62,9 @@ def __init__(self):
     def quantize(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
-        qweight, weight_scale = ops.scaled_fp8_quant(weight.cuda(), scale=None, use_per_token_if_dynamic=True)
+        qweight, weight_scale = ops.scaled_fp8_quant(
+            weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=True
+        )
         return qweight.transpose(0, 1), weight_scale

     def quantize_moe(self, weight):
@@ -71,7 +73,9 @@ def quantize_moe(self, weight):
         weight_scales = []
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda()
         for i in range(num_experts):
-            qweight, weight_scale = ops.scaled_fp8_quant(weight[i].cuda(), scale=None, use_per_token_if_dynamic=False)
+            qweight, weight_scale = ops.scaled_fp8_quant(
+                weight[i].contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
+            )
             qweights[i] = qweight
             weight_scales.append(weight_scale)
         weight_scale = torch.cat(weight_scales, dim=0).reshape(-1)
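For context, a minimal sketch (not part of the commit) of the situation the change guards against: the weight handed to the fp8 quant op may be a non-contiguous view (for example a transposed or sliced tensor), while the kernel expects a dense, row-major buffer. The shapes and the transpose below are illustrative assumptions only, not taken from lightllm.

import torch

# Illustrative shapes; the point is memory layout, not the sizes.
weight = torch.randn(128, 64, dtype=torch.bfloat16)

# A transposed view shares storage with `weight` but is no longer row-major.
w_view = weight.transpose(0, 1)
print(w_view.is_contiguous())   # False

# .contiguous() copies the data into a fresh row-major buffer, the layout
# a kernel that walks memory linearly expects to see.
w_dense = w_view.contiguous()
print(w_dense.is_contiguous())  # True

# The patched call sites therefore do, in effect:
#   qweight, weight_scale = ops.scaled_fp8_quant(
#       weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=True
#   )
# so ops.scaled_fp8_quant always receives a contiguous CUDA tensor.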
