diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/mixed_precision.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/mixed_precision.py
index f3e1586e714..69a049ad955 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/mixed_precision.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/mixed_precision.py
@@ -66,6 +66,14 @@
 except:
     HAVE_TE_MXFP8TENSOR = False
 
+# Detect the Blockwise FP8 tensor class
+try:
+    from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
+
+    HAVE_TE_BLOCKWISE_FP8TENSOR = True
+except:
+    HAVE_TE_BLOCKWISE_FP8TENSOR = False
+
 # Detect the "cast_master_weights_to_fp8" function of Transformer Engine
 try:
     from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_fp8
@@ -151,6 +159,11 @@ def is_float8tensor(tensor: torch.Tensor) -> bool:
     return HAVE_TE and isinstance(tensor, FP8_TENSOR_CLASS)
 
 
+def is_blockwise_float8tensor(tensor: torch.Tensor) -> bool:
+    """Check if a tensor is a Blockwise FP8 tensor."""
+    return HAVE_TE_BLOCKWISE_FP8TENSOR and isinstance(tensor, Float8BlockwiseQTensor)
+
+
 def fp8_need_transpose_data(tensor: torch.Tensor) -> bool:
     """Check if a FP8 tensor needs transpose data."""
     return HAVE_TE_MXFP8TENSOR and isinstance(tensor, MXFP8Tensor)
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
index f117da9d188..16b7d9dd50c 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
@@ -41,6 +41,7 @@
     fp8_need_transpose_data_for_meta_device_init,
     fp8_quantize,
     fp8_set_raw_data,
+    is_blockwise_float8tensor,
     is_float8tensor,
     is_te_min_version,
 )
@@ -2220,10 +2221,11 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
                     # Nothing else needs to be done, because the main weights
                     # do not require autograd operations, only possibly sharding.
                     p_local = to_local_if_dtensor(p)
-                    if is_float8tensor(p_local):
-                        mbuf.set_item(item_id, fp8_dequantize(p_local))
-                    else:
-                        mbuf.set_item(item_id, p_local)
+                    assert not is_float8tensor(p_local), (
+                        self.param_to_name[p],
+                        "fp8 param should use get_high_precision_init_val method.",
+                    )
+                    mbuf.set_item(item_id, p_local)
 
            if wbuf and wbuf.is_data_distributed:
                # Free the memory backing the temporarily-allocated bucket associated
@@ -2607,6 +2609,50 @@ def copy_main_weights_to_model_weights(self):
         expert_param_quantize_kwargs = copy.deepcopy(dense_param_quantize_kwargs)
         data_parallel_group = None
         expert_data_parallel_group = None
+        clear_quantize_kwargs = lambda kwargs: [d.clear() for d in kwargs.values()]
+
+        def _fp8_quantize_params(dense_param_quantize_kwargs, expert_param_quantize_kwargs):
+            if len(dense_param_quantize_kwargs["model_params"]) > 0:
+                # If we have FP8 parameters, we need to quantize them.
+                fp8_quantize(data_parallel_group=data_parallel_group, **dense_param_quantize_kwargs)
+
+            if len(expert_param_quantize_kwargs["model_params"]) > 0:
+                # If we have FP8 expert parameters, we need to quantize them.
+                fp8_quantize(
+                    data_parallel_group=expert_data_parallel_group, **expert_param_quantize_kwargs
+                )
+
+            clear_quantize_kwargs(dense_param_quantize_kwargs)
+            clear_quantize_kwargs(expert_param_quantize_kwargs)
+
+        # Special handling of blockwise FP8
+        BATCH_QUANT_MEMORY_LIMIT_BYTES = 5 * 1024**3  # 5 GB
+        blockwise_fp8_weight_buffers = []
+        blockwise_fp8_param_buffers = []
+
+        def _batch_quantize_blockwise_fp8_params(
+            dense_param_quantize_kwargs, expert_param_quantize_kwargs, blockwise_fp8_param_buffers
+        ):
+            if len(blockwise_fp8_param_buffers) == 0:
+                return
+
+            # Copy original param shards into their blockwise FP8 working buffers
+            for bufs in blockwise_fp8_param_buffers:
+                bufs["bucket_param"].copy_(bufs["param"])
+
+            # Apply FP8 quantization to blockwise FP8 parameters
+            _fp8_quantize_params(dense_param_quantize_kwargs, expert_param_quantize_kwargs)
+
+            # Copy quantized params back from working buffers to original param tensors
+            for bufs in blockwise_fp8_param_buffers:
+                bufs["param"].copy_(bufs["bucket_param"])
+            blockwise_fp8_param_buffers.clear()
+
+            # Free bucket storage for blockwise FP8 weight buffers
+            for wbuf in blockwise_fp8_weight_buffers:
+                wbuf.free_bucket_storage()
+            blockwise_fp8_weight_buffers.clear()
+
         for pg in self.parameter_groups:
             mbuf = pg.main_weight_buffer
             wbuf = pg.model_weight_buffer
@@ -2626,6 +2672,7 @@ def copy_main_weights_to_model_weights(self):
                 shard_offsets_in_fp8 = quantize_func_kwargs["start_offsets"]
                 shard_model_params = quantize_func_kwargs["fsdp_shard_model_params"]
 
+            has_blockwise_fp8_param = False
             for param in pg.params:
                 item_id = mbuf.param_idx[param]
                 if wbuf:
@@ -2648,6 +2695,34 @@ def copy_main_weights_to_model_weights(self):
                 model_param = to_local_if_dtensor(param)
                 main_weight = mbuf.get_item(item_id)
 
+                if is_blockwise_float8tensor(param):
+                    fp8_params.append(param)
+                    if model_param.numel() == 0:
+                        shard_fp32_from_fp8.append(None)
+                        shard_offsets_in_fp8.append(None)
+                        shard_model_params.append([None, None])
+                    else:
+                        shard_fp32_from_fp8.append(main_weight)
+                        shard_offsets_in_fp8.append(wbuf.locate_item_in_global_item(item_id)[0])
+                        bucket = wbuf.fetch_bucket()
+                        b_model_param = wbuf.get_item_from_bucket(bucket, item_id)[
+                            slice(*wbuf.locate_item_in_global_item(item_id))
+                        ]
+                        assert (
+                            transpose_param is None
+                        ), "Blockwise FP8 does not support transpose param."
+                        shard_model_params.append([b_model_param, None])
+                        assert b_model_param.numel() == model_param.numel(), (
+                            f"Blockwise FP8 bucket param numel {b_model_param.numel()} does"
+                            f" not match model param numel {model_param.numel()}"
+                            f" name: {self.param_to_name[param]}"
+                        )
+                        blockwise_fp8_param_buffers.append(
+                            {"bucket_param": b_model_param, "param": model_param}
+                        )
+                    has_blockwise_fp8_param = True
+                    continue
+
                 if is_float8tensor(param):
                     fp8_params.append(param)
                     if model_param.numel() == 0:
@@ -2663,15 +2738,22 @@ def copy_main_weights_to_model_weights(self):
                 if model_param.numel() > 0:
                     model_param.data.copy_(main_weight.view(model_param.shape))
 
-            if len(dense_param_quantize_kwargs["model_params"]) > 0:
-                # If we have FP8 parameters, we need to quantize them.
-                dense_param_quantize_kwargs["data_parallel_group"] = data_parallel_group
-                fp8_quantize(**dense_param_quantize_kwargs)
+            if has_blockwise_fp8_param:
+                blockwise_fp8_weight_buffers.append(wbuf)
+                if (
+                    sum([wbuf.bucket_index.size for wbuf in blockwise_fp8_weight_buffers])
+                    > BATCH_QUANT_MEMORY_LIMIT_BYTES
+                ):
+                    _batch_quantize_blockwise_fp8_params(
+                        dense_param_quantize_kwargs,
+                        expert_param_quantize_kwargs,
+                        blockwise_fp8_param_buffers,
+                    )
 
-            if len(expert_param_quantize_kwargs["model_params"]) > 0:
-                # If we have FP8 expert parameters, we need to quantize them.
-                expert_param_quantize_kwargs["data_parallel_group"] = expert_data_parallel_group
-                fp8_quantize(**expert_param_quantize_kwargs)
+        _batch_quantize_blockwise_fp8_params(
+            dense_param_quantize_kwargs, expert_param_quantize_kwargs, blockwise_fp8_param_buffers
+        )
+        _fp8_quantize_params(dense_param_quantize_kwargs, expert_param_quantize_kwargs)
 
     @torch.no_grad()
     def copy_model_weights_to_main_weights(self):
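The core pattern added in `copy_main_weights_to_model_weights` is memory-capped batching: blockwise FP8 weight buffers are accumulated per parameter group, and `_batch_quantize_blockwise_fp8_params` is only invoked once their summed bucket size exceeds `BATCH_QUANT_MEMORY_LIMIT_BYTES` (with a final flush after the loop). The sketch below illustrates that pattern in isolation; it is not the Megatron-FSDP implementation. The names `quantize_in_memory_capped_batches` and `fake_quantize` are hypothetical, and `fake_quantize` merely stands in for the batched `fp8_quantize` call used in the diff.

```python
import torch

# Hypothetical cap mirroring BATCH_QUANT_MEMORY_LIMIT_BYTES in the diff; nothing here
# is Megatron or Transformer Engine API.
MEMORY_LIMIT_BYTES = 5 * 1024**3  # 5 GB


def quantize_in_memory_capped_batches(params, quantize_batch, limit_bytes=MEMORY_LIMIT_BYTES):
    """Quantize `params` in batches whose accumulated byte size is capped by `limit_bytes`.

    `quantize_batch` stands in for the real batched quantization call; it receives a
    list of tensors and may rewrite them in place.
    """
    pending = []
    pending_bytes = 0

    def flush():
        nonlocal pending, pending_bytes
        if pending:
            quantize_batch(pending)  # one batched quantization call per flush
            pending = []
            pending_bytes = 0

    for p in params:
        pending.append(p)
        pending_bytes += p.numel() * p.element_size()
        # As in the diff, flush once the accumulated working-buffer size exceeds the
        # cap, so the transient memory held between flushes stays bounded.
        if pending_bytes > limit_bytes:
            flush()

    flush()  # quantize whatever is left over


if __name__ == "__main__":
    # Toy usage: "quantization" is faked with an in-place round-trip through bfloat16.
    weights = [torch.randn(1024, 1024) for _ in range(8)]

    def fake_quantize(batch):
        print(f"quantizing a batch of {len(batch)} tensors")
        for t in batch:
            t.copy_(t.to(torch.bfloat16).to(t.dtype))

    quantize_in_memory_capped_batches(weights, fake_quantize, limit_bytes=16 * 1024**2)
```

The cap is also why the diff calls `wbuf.free_bucket_storage()` right after each flush: only the buffers accumulated since the previous flush are kept alive at any point, so peak working memory for blockwise FP8 quantization stays near the configured limit.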