First changed file:
@@ -66,6 +66,14 @@
except:
HAVE_TE_MXFP8TENSOR = False

# Detect the Blockwise FP8 tensor class
try:
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor

HAVE_TE_BLOCKWISE_FP8TENSOR = True
except:
HAVE_TE_BLOCKWISE_FP8TENSOR = False

# Detect the "cast_master_weights_to_fp8" function of Transformer Engine
try:
from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_fp8
@@ -151,6 +159,11 @@ def is_float8tensor(tensor: torch.Tensor) -> bool:
return HAVE_TE and isinstance(tensor, FP8_TENSOR_CLASS)


def is_blockwise_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Blockwise FP8 tensor."""
return HAVE_TE_BLOCKWISE_FP8TENSOR and isinstance(tensor, Float8BlockwiseQTensor)


def fp8_need_transpose_data(tensor: torch.Tensor) -> bool:
"""Check if a FP8 tensor needs transpose data."""
return HAVE_TE_MXFP8TENSOR and isinstance(tensor, MXFP8Tensor)
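A minimal usage sketch for the new helper (illustrative, not part of the diff). The dispatch order matters: in the second file below, the blockwise check runs before the generic is_float8tensor check, so a Float8BlockwiseQTensor takes its dedicated batched path. The import path here is an assumption.

# Illustrative sketch only; the import path is assumed, not taken from the diff.
from megatron.core.utils import is_blockwise_float8tensor, is_float8tensor

def classify_param(param):
    """Mirror the dispatch order used in copy_main_weights_to_model_weights below."""
    if is_blockwise_float8tensor(param):
        return "blockwise_fp8"   # staged and batch-quantized under a memory cap
    if is_float8tensor(param):
        return "fp8"             # existing per-parameter FP8 path
    return "high_precision"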
Second changed file:
Expand Up @@ -41,6 +41,7 @@
fp8_need_transpose_data_for_meta_device_init,
fp8_quantize,
fp8_set_raw_data,
is_blockwise_float8tensor,
is_float8tensor,
is_te_min_version,
)
@@ -2220,10 +2221,11 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
# Nothing else needs to be done, because the main weights
# do not require autograd operations, only possibly sharding.
p_local = to_local_if_dtensor(p)
if is_float8tensor(p_local):
mbuf.set_item(item_id, fp8_dequantize(p_local))
else:
mbuf.set_item(item_id, p_local)
assert not is_float8tensor(p_local), (
self.param_to_name[p],
"fp8 param should use get_high_precision_init_val method.",
)
mbuf.set_item(item_id, p_local)

if wbuf and wbuf.is_data_distributed:
# Free the memory backing the temporarily-allocated bucket associated
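For context on the assert above, a hedged sketch of where the high-precision copy of an FP8 parameter is expected to come from. It assumes Transformer Engine's fp8_model_init with preserve_high_precision_init_val=True, which attaches the get_high_precision_init_val()/clear_high_precision_init_val() methods named in the assert message; build_layer() and the import path are hypothetical.

import transformer_engine.pytorch as te
from megatron.core.utils import is_float8tensor  # assumed import path

# Hedged sketch, not Megatron code: keep a high-precision copy of each FP8 param
# at init time so main weights can be seeded from it instead of dequantizing FP8 data.
with te.fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    layer = build_layer()  # hypothetical module constructor

for p in layer.parameters():
    if is_float8tensor(p):
        main_weight = p.get_high_precision_init_val()  # high-precision init value kept by TE
        p.clear_high_precision_init_val()               # release it once consumed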
@@ -2607,6 +2609,50 @@ def copy_main_weights_to_model_weights(self):
expert_param_quantize_kwargs = copy.deepcopy(dense_param_quantize_kwargs)
data_parallel_group = None
expert_data_parallel_group = None
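# The helper below resets the accumulated per-flush lists (model_params, start_offsets,
# fsdp_shard_model_params, ...) inside the quantize kwargs once fp8_quantize() has consumed them.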
clear_quantize_kwargs = lambda kwargs: [d.clear() for d in kwargs.values()]

def _fp8_quantize_params(dense_param_quantize_kwargs, expert_param_quantize_kwargs):
if len(dense_param_quantize_kwargs["model_params"]) > 0:
# If we have FP8 parameters, we need to quantize them.
fp8_quantize(data_parallel_group=data_parallel_group, **dense_param_quantize_kwargs)

if len(expert_param_quantize_kwargs["model_params"]) > 0:
# If we have FP8 expert parameters, we need to quantize them.
fp8_quantize(
data_parallel_group=expert_data_parallel_group, **expert_param_quantize_kwargs
)

clear_quantize_kwargs(dense_param_quantize_kwargs)
clear_quantize_kwargs(expert_param_quantize_kwargs)

# Special handling of blockwise FP8
BATCH_QUANT_MEMORY_LIMIT_BYTES = 5 * 1024**3  # 5 GiB
blockwise_fp8_weight_buffers = []
blockwise_fp8_param_buffers = []
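# Blockwise FP8 shards are not quantized bucket-by-bucket here. Each shard is first staged
# in a bucket-backed working buffer; the staged set is flushed through fp8_quantize()
# whenever the accumulated bucket size exceeds BATCH_QUANT_MEMORY_LIMIT_BYTES, and once
# more after the parameter-group loop.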

def _batch_quantize_blockwise_fp8_params(
dense_param_quantize_kwargs, expert_param_quantize_kwargs, blockwise_fp8_param_buffers
):
if len(blockwise_fp8_param_buffers) == 0:
return

# Copy original param shards into their blockwise FP8 working buffers
for bufs in blockwise_fp8_param_buffers:
bufs["bucket_param"].copy_(bufs["param"])

# Apply FP8 quantization to blockwise FP8 parameters
_fp8_quantize_params(dense_param_quantize_kwargs, expert_param_quantize_kwargs)

# Copy quantized params back from working buffers to original param tensors
for bufs in blockwise_fp8_param_buffers:
bufs["param"].copy_(bufs["bucket_param"])
blockwise_fp8_param_buffers.clear()

# Free bucket storage for blockwise FP8 weight buffers
for wbuf in blockwise_fp8_weight_buffers:
wbuf.free_bucket_storage()
blockwise_fp8_weight_buffers.clear()

for pg in self.parameter_groups:
mbuf = pg.main_weight_buffer
wbuf = pg.model_weight_buffer
@@ -2626,6 +2672,7 @@ def copy_main_weights_to_model_weights(self):
shard_offsets_in_fp8 = quantize_func_kwargs["start_offsets"]
shard_model_params = quantize_func_kwargs["fsdp_shard_model_params"]

has_blockwise_fp8_param = False
for param in pg.params:
item_id = mbuf.param_idx[param]
if wbuf:
@@ -2648,6 +2695,34 @@
model_param = to_local_if_dtensor(param)
main_weight = mbuf.get_item(item_id)

if is_blockwise_float8tensor(param):
fp8_params.append(param)
if model_param.numel() == 0:
shard_fp32_from_fp8.append(None)
shard_offsets_in_fp8.append(None)
shard_model_params.append([None, None])
else:
shard_fp32_from_fp8.append(main_weight)
shard_offsets_in_fp8.append(wbuf.locate_item_in_global_item(item_id)[0])
bucket = wbuf.fetch_bucket()
b_model_param = wbuf.get_item_from_bucket(bucket, item_id)[
slice(*wbuf.locate_item_in_global_item(item_id))
]
assert (
transpose_param is None
), "Blockwise FP8 does not support transpose param."
shard_model_params.append([b_model_param, None])
assert b_model_param.numel() == model_param.numel(), (
f"Blockwise FP8 bucket param numel {b_model_param.numel()} does"
f" not match model param numel {model_param.numel()}"
f" name: {self.param_to_name[param]}"
)
blockwise_fp8_param_buffers.append(
{"bucket_param": b_model_param, "param": model_param}
)
has_blockwise_fp8_param = True
continue

if is_float8tensor(param):
fp8_params.append(param)
if model_param.numel() == 0:
@@ -2663,15 +2738,22 @@
if model_param.numel() > 0:
model_param.data.copy_(main_weight.view(model_param.shape))

if len(dense_param_quantize_kwargs["model_params"]) > 0:
# If we have FP8 parameters, we need to quantize them.
dense_param_quantize_kwargs["data_parallel_group"] = data_parallel_group
fp8_quantize(**dense_param_quantize_kwargs)
if has_blockwise_fp8_param:
blockwise_fp8_weight_buffers.append(wbuf)
if (
sum([wbuf.bucket_index.size for wbuf in blockwise_fp8_weight_buffers])
> BATCH_QUANT_MEMORY_LIMIT_BYTES
):
_batch_quantize_blockwise_fp8_params(
dense_param_quantize_kwargs,
expert_param_quantize_kwargs,
blockwise_fp8_param_buffers,
)

if len(expert_param_quantize_kwargs["model_params"]) > 0:
# If we have FP8 expert parameters, we need to quantize them.
expert_param_quantize_kwargs["data_parallel_group"] = expert_data_parallel_group
fp8_quantize(**expert_param_quantize_kwargs)
_batch_quantize_blockwise_fp8_params(
dense_param_quantize_kwargs, expert_param_quantize_kwargs, blockwise_fp8_param_buffers
)
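# Any FP8 params still pending after the final blockwise flush above are quantized here.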
_fp8_quantize_params(dense_param_quantize_kwargs, expert_param_quantize_kwargs)

@torch.no_grad()
def copy_model_weights_to_main_weights(self):