Float8 tensor parallel for aqt_dynamic_act_weight
jainapurva committed Oct 15, 2024
1 parent e7b33bc commit 21711e0
Showing 4 changed files with 73 additions and 8 deletions.
17 changes: 16 additions & 1 deletion test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -1,7 +1,8 @@
import torch
from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
from torch.testing._internal.common_utils import run_tests
- from torchao.quantization import int8_weight_only, float8_weight_only
+ from torchao.quantization import int8_weight_only, float8_weight_only, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow, PerTensor

class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
QUANT_METHOD_FN = staticmethod(int8_weight_only)
@@ -13,5 +14,19 @@ class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
QUANT_METHOD_FN = staticmethod(float8_weight_only)
copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
class TestFloat8dqTensorAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqTensorAffineQuantizedTensorParallel, "fp8dqt_tp")

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
class TestFloat8dqRowAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {"granularity": PerRow()}
copy_tests(TorchAOTensorParallelTestCase, TestFloat8dqRowAffineQuantizedTensorParallel, "fp8dqr_tp")

if __name__ == "__main__":
run_tests()
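
For orientation (not part of the commit), the new test classes boil down to quantizing a module with float8 dynamic activation and weight quantization at per-tensor or per-row granularity before sharding it. A minimal sketch, assuming torchao's public quantize_ entry point and an fp8-capable GPU:

import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow

# Assumption: quantize_ is the public entry point; fp8 kernels need an H100-class GPU.
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
    model = torch.nn.Sequential(torch.nn.Linear(1024, 2048)).to("cuda", torch.bfloat16)
    # PerRow() gives one scale per weight row; PerTensor() gives a single scale.
    quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
    out = model(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))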
31 changes: 26 additions & 5 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -1107,12 +1107,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
elif func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
# TODO: scale replication should be dependent on block size
if self.scale.ndim == 1:
print("slice for dim 0, scale is 1")
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
else:
print("slice for dim 0, scale != 1")
return return_and_correct_aliasing(
func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
)
elif dim == 1:
print("slice for dim 1")
return return_and_correct_aliasing(
- func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+ func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
)
- elif dim == 1:
- assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
- return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout)
else:
raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported")
else:
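
As background for the branch above (not part of the commit): slicing the fp8 payload has to keep its scale consistent with the remaining data. The following standalone sketch shows that general rule; it does not reproduce the exact branching above, whose block-size handling the TODO marks as unfinished, and the names are illustrative only.

import torch

def slice_fp8_with_scale(data: torch.Tensor, scale: torch.Tensor, dim: int, start: int, end: int):
    """Illustrative only: slice fp8 data and return a scale that still matches it."""
    sliced = data[start:end] if dim == 0 else data[:, start:end]
    if scale.numel() == 1:
        # Per-tensor scale: one scalar covers every element, reuse it unchanged.
        return sliced, scale
    if dim == 0:
        # Per-row scales: rows were dropped, so drop the matching scale rows too.
        return sliced, scale[start:end]
    # Slicing columns keeps every row, so per-row scales are unchanged.
    return sliced, scale

data = torch.randn(8, 4).to(torch.float8_e4m3fn)   # stand-in for Float8AQTTensorImpl.float8_data
row_scale = torch.rand(8, 1)                        # one scale per output row
sub, sub_scale = slice_fp8_with_scale(data, row_scale, dim=0, start=0, end=4)
assert sub.shape == (4, 4) and sub_scale.shape == (4, 1)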
@@ -1621,7 +1631,7 @@ def _linear_fp8_act_fp8_weight_impl(
):
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
scaled_mm_config = weight_tensor._layout.mm_config
- out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
+ # out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Weight tensor preprocessing
w_tensor_impl = weight_tensor.tensor_impl
@@ -1641,9 +1651,15 @@
w_scale = w_scale.unsqueeze(-1).T
input_scale = preprocess_scale(input_scale, input_tensor.shape)

out_shape = get_out_shape(inpt_data.shape, w_data.shape)
# Preprocess data
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)


print(f"out_shape: {out_shape}")
print(f"input_tensor: {input_tensor.shape}, weight_tensor: {weight_tensor.shape}")
print(f"inpt_data: {inpt_data.shape}, w_data: {w_data.shape}")

# Perform the computation
return addmm_float8_unwrapped_inference(
inpt_data,
@@ -1858,12 +1874,17 @@ def _(func, types, args, kwargs):
end = self.shape[dim]
shape = list(self.shape)
shape[dim] = end - start
print(f"Shape: {self.shape} -> {shape}")
print(f"Block size: {self.block_size} -> {self.block_size}")
print(f"end: {end}, start: {start}")
block_size = self.block_size
assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}"
# with slice, some shape dimension might be smaller than block_size dimension, so
# we need to make sure there is no overflow
block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
print(f"slice (Outer tensor shape): {self.shape} -> {new.shape}")
print(f"slice (Inner data shape): {self.tensor_impl.float8_data.shape} -> {new.tensor_impl.float8_data.shape}")
return return_and_correct_aliasing(func, args, kwargs, new)

# this is needed for DTensor.from_local() and for flattening tensor
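As a side note (not part of the commit), the bookkeeping in the aten.slice override above, recomputing the outer shape and clamping block_size so no block extends past the sliced tensor, can be sketched in plain Python with illustrative numbers:

def sliced_shape_and_block_size(shape, block_size, dim, start, end):
    end = min(end, shape[dim])
    new_shape = list(shape)
    new_shape[dim] = end - start
    # A slice can leave a dimension smaller than its quantization block, so the
    # block size is clamped to the new shape to avoid overflowing the tensor.
    new_block_size = (min(new_shape[0], block_size[0]), min(new_shape[1], block_size[1]))
    return tuple(new_shape), new_block_size

# Per-tensor quantization of a 2048x1024 weight uses one block covering the whole tensor;
# a rowwise shard across 2 ranks halves dim 1, and the block shrinks with it.
print(sliced_shape_and_block_size((2048, 1024), (2048, 1024), dim=1, start=0, end=512))
# -> ((2048, 512), (2048, 512))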
22 changes: 22 additions & 0 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -124,6 +124,7 @@ def _(func, types, args, kwargs):
return func(bias, aqt, original_weight_tensor)
else:
# aten.mm.default
print('Args: ', args[0].shape, args[1].shape, type(args[0]), type(args[1]))
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
@@ -165,6 +166,27 @@ def _(func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
)

@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
print('Input quant func: ', args[0].input_quant_func)
x = return_and_correct_aliasing(
func, args, kwargs, LinearActivationQuantizedTensor(
func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func)
)
print(f'Linear act Post slice: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}')
return x

# this is needed for DTensor.from_local() and for flattening tensor
@implements(aten.view.default)
def _(func, types, args, kwargs):
print('Linear view args:', args[1:])
print('Device: ', args[0].original_weight_tensor.device)
x= return_and_correct_aliasing(
func, args, kwargs, LinearActivationQuantizedTensor(func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func)
)
print(f'Linear act Post view: {x.original_weight_tensor.shape} {x.original_weight_tensor.tensor_impl.float8_data.shape}')
return x

to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float

if TORCH_VERSION_AT_LEAST_2_5:
11 changes: 9 additions & 2 deletions torchao/testing/utils.py
@@ -250,6 +250,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
print('colwise shard Shapeof m.linear.weight : ', m.linear.weight.shape)
return m

@staticmethod
@@ -264,11 +265,15 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
rank = mesh.get_local_rank()
local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
# Construct DTensor from local shard
- dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+ dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
print(f'dtensor shape: {dtensor.shape}')
print(f'Other dtensor values: {local_shard.original_weight_tensor.tensor_impl.float8_data.shape}, {mesh}, {[Shard(1)]}')
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
print('rowwise shard Shapeof m.linear.weight : ', m.linear.weight.shape)

return m
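
For context (not part of the commit), the colwise/rowwise helpers above follow the usual DTensor.from_local pattern: slice the weight locally, then wrap the local shard with a Shard placement on the matching dim. A hardware-independent sketch, assuming a gloo process group launched with torchrun --nproc_per_node=2:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))
rank = mesh.get_local_rank()

weight = torch.randn(2048, 1024)                   # nn.Linear weight: out_features x in_features
rows_per_rank = weight.shape[0] // dist.get_world_size()
local = weight[rank * rows_per_rank:(rank + 1) * rows_per_rank]  # colwise TP keeps a block of rows
dweight = DTensor.from_local(local, mesh, [Shard(0)], run_check=False)
print(f"rank {rank}: local {tuple(local.shape)} -> global {tuple(dweight.shape)}")
dist.destroy_process_group()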

def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
@@ -302,11 +307,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
proj_dn = M(2048, 1024).to(device).to(dtype)
example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
y = proj_dn(proj_up(example_input))

print('Run before y')
# Quantize the model
up_quant = self.quantize(proj_up)
dn_quant = self.quantize(proj_dn)
print('Run before y_q')
y_q = dn_quant(up_quant(example_input))
print('Executed y_q')

mesh = self.build_device_mesh()
mesh.device_type = "cuda"
