diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index d3f8c82baa..c343299242 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -9,57 +9,73 @@ import argparse import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import ( + Format, + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, +) import torch import torch.distributed as dist +from torch.distributed.tensor import DTensor import torch.nn.functional as F from torch import nn, optim from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import init_device_mesh +from transformer_engine.pytorch import QuantizedTensor from contextlib import nullcontext +LOCAL_RANK = None -class SimpleNet(nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super(SimpleNet, self).__init__() - self.fc1 = te.Linear(input_size, hidden_size) - self.fc2 = te.Linear(hidden_size, output_size) - def forward(self, x): - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return x - - -def save_custom_attrs(module): - custom_attrs = {} - for name, param in module.named_parameters(): - attrs = vars(param) - custom_attrs[name] = {k: v for k, v in attrs.items()} - return custom_attrs - - -def restore_custom_attrs(module, custom_attrs): - for name, param in module.named_parameters(): - if name in custom_attrs: - for attr_name, attr_value in custom_attrs[name].items(): - setattr(param, attr_name, attr_value) +def dist_print(msg): + if LOCAL_RANK == 0: + print(msg) def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") - parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model") - parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size") - parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") - parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") + parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads") + parser.add_argument("--head-dim", type=int, default=64, help="Attention head size") + parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input") + parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input") + parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.") parser.add_argument( "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." 
) + parser.add_argument( + "--recipe", + type=str, + default="mx_fp8_block_scaling", + help="Quantizer type.", + choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"], + ) + parser.add_argument( + "--layer-type", + type=str, + default="TransformerLayer", + choices=[ + "Linear", + "LayerNormLinear", + "LayerNormMLP", + "MultiheadAttention", + "TransformerLayer", + ], + help="Transformer Engine layer type", + ) + parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model") parser.add_argument( "--iter", type=int, default=10, help="Number of iterations for forward pass" ) + parser.add_argument( + "--device", + type=str, + default="meta", + help="Device to run the model on.", + choices=["cuda", "meta"], + ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") # Adding hsdp_dim as a list argument, comma-separated parser.add_argument( @@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None): return args -sub_modules_to_wrap = [te.Linear] +## Methods to help initialize the TE model in an FSDP2 setting +## with required configurations based on command line args +def get_te_layer_from_string(layer_name): + te_layer_types = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + te_layer_names = [layer.__name__ for layer in te_layer_types] + te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types)) + if layer_name.lower() not in te_layer_map.keys(): + raise argparse.ArgumentTypeError( + f'"{layer_name}" is not a valid Transformer Engine layer, ' + f"please choose layer from {te_layer_names}." + ) + return te_layer_map[layer_name.lower()] + + +def get_recipe_from_string(recipe, fp8_format=Format.HYBRID): + if recipe == "delayed_scaling": + return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + elif recipe == "current_scaling": + return Float8CurrentScaling(fp8_format=fp8_format) + elif recipe == "mx_fp8_block_scaling": + return MXFP8BlockScaling(fp8_format=fp8_format) + else: + raise ValueError(f"Unknown quantizer type: {recipe}") + + +def init_te_model(config): + hidden_size = config.num_heads * config.head_dim + args = [hidden_size, hidden_size] + inp_shape = [config.seq_length, config.batch_size, hidden_size] + out_shape = [config.seq_length, config.batch_size, hidden_size] + if config.params_dtype == "float16": + params_dtype = torch.float16 + elif config.params_dtype == "bfloat16": + params_dtype = torch.bfloat16 + else: + params_dtype = torch.float32 + kwargs = { + "params_dtype": params_dtype, + } + kwargs["device"] = config.device + + layer_type = get_te_layer_from_string(config.layer_type) + # We are creating model in a way so that we can test both reshard_after_forward=True/False cases. + # more details below. + if layer_type in [te.MultiheadAttention, te.TransformerLayer]: + # For this case, we are creating a model that resemebles production use-cases + # wherein there are mltiple TransformerLayers in the model. And we would need + # to shard each transformer layer. Since each transformer layer is not a root module, + # FSDP2's fully_shard assigns reshard_after_forward=False for all parameters of the model. 
+ args[1] *= 4 # FFN hidden size + args.append(config.num_heads) + kwargs["fuse_qkv_params"] = True + if layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)]) + elif layer_type == te.LayerNormLinear: + # For this case, we are creating a model with just one LayerNormLinear layer + # so that the model itself is a root module, and FSDP2's fully_shard assigns + # reshard_after_forward=True for the parameters of these model. + args[1] *= 3 # QKV projection + out_shape[-1] *= 3 + model = layer_type(*args, **kwargs) + else: + model = layer_type(*args, **kwargs) + + return model, inp_shape, out_shape + + +def get_device_mesh(world_size, sharding_dims): + dist_print(f"sharding-dims:{sharding_dims}") + device_ids = list(range(world_size)) + if sharding_dims is None: # FSDP + mesh = DeviceMesh("cuda", device_ids) + elif len(sharding_dims) == 1: + assert sharding_dims[0] == world_size + mesh = DeviceMesh("cuda", device_ids) + elif len(sharding_dims) == 2: # HSDP + assert sharding_dims[0] * sharding_dims[1] == world_size + mesh = init_device_mesh( + "cuda", + (sharding_dims[0], sharding_dims[1]), + mesh_dim_names=("replicate", "shard"), + ) + else: + assert False + return mesh + + +def shard_model_with_fsdp2(model, mesh): + for child in model.children(): + fully_shard(child, mesh=mesh) + fully_shard(model, mesh=mesh) + return model + + +#### Methods to save the custom attributes of QuantizedTensors before sharding +#### them with FSDP2, and restore them after sharding. +def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + if isinstance(param, QuantizedTensor): + # Ignore FP8 metadata attributes. Otherwise we will save duplicate copies + # for data/transpose FP8 tensors on top of FP8 tensors that FSDP2 will save. + ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")] + else: + ignore_keys = [] + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys} + return custom_attrs + + +def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + +@torch.no_grad() +def test_fp8_fsdp2_allgather(model): + # Do manual allgather in fp32 and match against fp8 allgather done + # with fsdp2 + # FP32 manual weight allgather + fp32_allgathered_params = {} + for name, param in model.named_parameters(): + assert isinstance(param, DTensor) + local_tensor = param._local_tensor + device_mesh = param.device_mesh + dist_group = ( + device_mesh.get_group(mesh_dim="shard") + if device_mesh.ndim > 1 + else device_mesh.get_group() + ) + # Perform manual allgather on local_tensor. zeros_like will create hp tensor since torch_dispatch + # for local_tensor will go down the dequantization route. + gathered_tensor = [ + torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group)) + ] + dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group) + full_tensor = torch.cat(gathered_tensor, dim=0) + fp32_allgathered_params[name] = full_tensor + # FP8 allgather using FSDP2 + for module in model.modules(): + # Not all modules are wrapped/sharded with FSDP2. 
+ if hasattr(module, "unshard"): + module.unshard() + # Make sure allgathered parameters match exactly + for name, param in model.named_parameters(): + assert torch.allclose(param.dequantize(), fp32_allgathered_params[name]) + # Revert model to original sharded state + for module in model.modules(): + # Not all modules are wrapped/sharded with FSDP2. + if hasattr(module, "reshard"): + module.reshard() def _train(args): + global LOCAL_RANK assert "TORCHELASTIC_RUN_ID" in os.environ WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) @@ -103,74 +279,69 @@ def _train(args): # FP8 Configuration fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") - - # Create build context manager - if args.fp8_init: - from transformer_engine.pytorch import quantized_model_init + fp8_recipe = get_recipe_from_string(args.recipe, fp8_format) - build_model_context = quantized_model_init() + build_model_context_args = {} + if not args.fp8_init: + # Build model context (FP8 init) + build_model_context = nullcontext else: - build_model_context = nullcontext() + from transformer_engine.pytorch import fp8_model_init - # Build the model with the specified context - with build_model_context: - model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + build_model_context_args["recipe"] = fp8_recipe - # Move the model to the correct device - model.to(device) + dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB") + # Create the model on the meta/cuda device as per args + with build_model_context(**build_model_context_args): + model, inp_shape, out_shape = init_te_model(args) + dist_print( + f"Memory after model init on device {args.device}:" + f" {torch.cuda.memory_allocated(device)/1e6} MB" + ) - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...") # Creating a DeviceMesh for fully_shard world_size = int(WORLD_SIZE) - device_ids = list(range(world_size)) - if LOCAL_RANK == 0: - print(f"sharding-dims:{args.sharding_dims}") # Setup the sharding mesh for FSDP/HSDP - if args.sharding_dims == None: # FSDP - mesh = DeviceMesh("cuda", device_ids) - elif len(args.sharding_dims) == 1: - assert args.sharding_dims[0] == device_ids[-1] + 1 - mesh = DeviceMesh("cuda", device_ids) - elif len(args.sharding_dims) == 2: # HSDP - assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1 - mesh = init_device_mesh( - "cuda", - (args.sharding_dims[0], args.sharding_dims[1]), - mesh_dim_names=("replicate", "shard"), - ) - else: - assert False - - # Apply FSDP/HSDP + mesh = get_device_mesh(world_size, args.sharding_dims) custom_attrs = save_custom_attrs(model) - for sub_module in model.modules(): - if any( - isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap - ): - fully_shard(sub_module, mesh=mesh) - fully_shard(model, mesh=mesh) + model = shard_model_with_fsdp2(model, mesh) restore_custom_attrs(model, custom_attrs) + # model now has DTensors as its parameters + + if args.device == "meta": + # After FSDP2 has been applied, materialize and initialize the sharded parameters + # TE base.py's reset_parameters() handles DTensors with FP8 initialization + for module in model.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + dist_print(f" Sharded parameters materialized and initialized 
on cuda device.") + + dist_print( + f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB" + ) optimizer = optim.Adam(model.parameters(), lr=1e-3) for iteration in range(args.iter): # Zero the parameter gradients optimizer.zero_grad() - input_data = torch.randn(args.batch_size, args.input_size).to(device) + input_data = torch.randn(inp_shape).to(device) with te.autocast(enabled=True, recipe=fp8_recipe): output = model(input_data) - target = torch.randn(args.batch_size, args.output_size).to(device) + target = torch.randn(out_shape).to(device) loss = F.mse_loss(output, target) loss.backward() optimizer.step() - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.") + dist_print(f"Iteration {iteration} completed with loss {loss.item()}") + + # Some of the FSDP states are lazy initialized during FSDP forward pass + # so testing fp8 allgather at the end of the training loop. + if args.fp8_init: + test_fp8_fsdp2_allgather(model) dist.destroy_process_group() - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Done...") return 0 diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 8fe4e8bc7c..91d6fc6ed1 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -12,22 +12,26 @@ fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) - +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) NUM_PROCS: int = torch.cuda.device_count() -def _run_test(fp_init, sharding_dims): +def _run_test(fp_init, sharding_dims, recipe, layer_type): test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] if fp_init: test_cmd += ["--fp8-init"] + if len(sharding_dims) == 1: test_cmd += ["--sharding-dims", str(sharding_dims[0])] elif len(sharding_dims) == 2: test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] else: assert False + test_cmd += ["--recipe", recipe] + test_cmd += ["--layer-type", layer_type] + result = subprocess.run(test_cmd, env=os.environ, check=True) @@ -36,16 +40,20 @@ def _run_test(fp_init, sharding_dims): @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) @pytest.mark.parametrize("fp8_init", (False, True)) -def test_distributed(fp8_init, sharding_dims): +@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling")) +@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer")) +def test_distributed(fp8_init, sharding_dims, recipe, layer_type): # Skip invalid configurations if torch.cuda.device_count() < 4: pytest.skip("FSDP2 test requires at least 4 GPUs") - if fp8_init and not fp8_available: + if recipe == "mx_fp8_block_scaling" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + elif not fp8_available: pytest.skip(reason_for_no_fp8) - _run_test(fp8_init, sharding_dims) + _run_test(fp8_init, sharding_dims, recipe, layer_type) def test_dummy() -> None: diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 8c14d5ab7f..2bfcc2e06f 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1885,6 +1885,43 @@ def allreduce( return inp, handle +def _get_module_fsdp_state(module): + """ + If module is an 
+    FSDP module, return its _FSDPState.
+    Otherwise, return the _FSDPState of the closest parent FSDP module
+    in the module hierarchy that the module belongs to.
+    """
+
+    if hasattr(module, "_get_fsdp_state"):
+        # This returns the correct FSDP state if the module itself is an FSDP module.
+        fsdp_state = module._get_fsdp_state()
+    elif getattr(module, "_te_cached_parent_fsdp_state", None) is not None:
+        # See if we have already cached the parent FSDP state of the module.
+        fsdp_state = module._te_cached_parent_fsdp_state
+    else:
+        from torch.distributed._composable_state import _module_state_mapping

+        # Otherwise, find the closest FSDP-wrapped ancestor of the module in the
+        # module hierarchy and use its FSDP state.
+        min_nodes_in_parent = float("inf")
+        closest_parent_fsdp_mod = None
+        for fsdp_mod in _module_state_mapping.keys():
+            all_submodules = list(fsdp_mod.modules())
+            for submodule in all_submodules:
+                if submodule is module:
+                    if min_nodes_in_parent > len(all_submodules):
+                        closest_parent_fsdp_mod = fsdp_mod
+                        min_nodes_in_parent = len(all_submodules)
+        if closest_parent_fsdp_mod is None:
+            raise RuntimeError(
+                "Module is not FSDP-wrapped and does not have any FSDP-wrapped parent modules."
+            )
+        fsdp_state = closest_parent_fsdp_mod._get_fsdp_state()
+        # Cache the parent FSDP state on the module to avoid recomputing
+        # the closest parent FSDP module.
+        module._te_cached_parent_fsdp_state = fsdp_state
+    return fsdp_state
+
+
 def _fsdp_scatter_tensors(
     fsdp_group: dist_group_type,
     *tensors: torch.Tensor,
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 9b6ca9d9cd..041cc396cf 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -17,6 +17,7 @@
 import torch
 import torch.nn.functional as F
+from torch.distributed.tensor import DTensor

 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
@@ -1244,7 +1245,12 @@ def register_parameter(self, name, param, **kwargs):
             metedata used in deferred initialization.
         """
         super().register_parameter(name, param)
-        self.param_init_meta[name] = _ParameterInitMeta(**kwargs)
+        # Initialize param_init_meta exactly once during init. FSDP2 may call
+        # register_parameter again to replace parameters with DTensors, and it does
+        # so without the custom FP8-specific kwargs that we need, so we do not want
+        # to reset or lose our FP8 init attributes.
+ if hasattr(self, "param_init_meta") and name not in self.param_init_meta: + self.param_init_meta[name] = _ParameterInitMeta(**kwargs) def reset_parameters(self, defer_init: Optional[bool] = False) -> None: """ @@ -1256,10 +1262,22 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: return for name, param in self.named_parameters(recurse=False): + # Check if parameter is a DTensor (FSDP2) or regular tensor + is_dtensor = isinstance(param, DTensor) + dtensor_param = param if is_dtensor else None + # Need to update/quantize local tensor in case of DTensor + param = param._local_tensor if is_dtensor else param # Ensure parameter is on a real device if param.device == torch.device("meta"): param = torch.empty_like(param, device="cuda") - + if is_dtensor: + dtensor_param = DTensor.from_local( + param, + device_mesh=dtensor_param.device_mesh, + placements=dtensor_param.placements, + shape=dtensor_param.size(), + stride=dtensor_param.stride(), + ) # Initialize the parameter values on device init_fn = self.param_init_meta[name].init_fn get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker @@ -1288,7 +1306,15 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: raise RuntimeError("Weight quantizer has not been initialized") quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False - + if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): + device_mesh = dtensor_param.device_mesh + amax_reduction_group = ( + device_mesh.get_group(mesh_dim="shard") + if device_mesh.ndim > 1 + else device_mesh.get_group() + ) + quantizer.amax_reduction_group = amax_reduction_group + quantizer.with_amax_reduction = True # Quantize parameter param = quantizer(param) @@ -1296,7 +1322,11 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. 
- param = torch.nn.Parameter(param) + if is_dtensor: + dtensor_param._local_tensor = param + dtensor_param = torch.nn.Parameter(dtensor_param) + else: + param = torch.nn.Parameter(param) # Keep high-precision values on CPU if needed if high_precision_init_val is not None: @@ -1324,8 +1354,12 @@ def clear(self): param._high_precision_init_val = high_precision_init_val param.get_high_precision_init_val = MethodType(get, param) param.clear_high_precision_init_val = MethodType(clear, param) + # Update the parameter based on its type - setattr(self, name, param) + if not is_dtensor: + setattr(self, name, param) + else: + setattr(self, name, dtensor_param) @abstractmethod def forward(self): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4d6b2f23b9..59dc2b2997 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -108,9 +108,15 @@ def forward( is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() ) - if weight_quantizers[0] is not None: + # No need to set the quantizer states if weight is already quantized + if weight_quantizers[0] is not None and not isinstance( + weights[0], QuantizedTensorStorage + ): for weight_quantizer in weight_quantizers: weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + elif isinstance(weights[0], QuantizedTensorStorage): + # If weights are already quantized, no need to set quantizer states + weight_quantizers = [weight._quantizer for weight in weights] if output_quantizers[0] is not None: for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) @@ -205,10 +211,6 @@ def forward( inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms - if inp.requires_grad: - for weight in weights_fp8: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) if cpu_offloading: ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad") @@ -354,13 +356,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - - for weight, quantizer in zip(weights, ctx.weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorStorage): - weight.update_usage( - rowwise_usage=quantizer.rowwise_usage, - columnwise_usage=quantizer.columnwise_usage, - ) + # Make sure weights are available in column-wise format + # for dgrad computation. 
+ for weight in weights: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( weights, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 933c7cde53..d0ef0e1c9a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -276,12 +276,15 @@ def forward( # Prepare weight tensor # ------------------------------------------------------ weightmat = weight - quantized_weight = False + is_weight_param_quantized = False if fp8 or debug: - quantized_weight = not isinstance(weight, QuantizedTensorStorage) + is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) # Configure quantizer - if weight_quantizer is not None: + # If weight is already quantized, no need to set quantizer states + if is_weight_param_quantized: + weight_quantizer = weight._quantizer + elif weight_quantizer is not None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) # Get quantized weight @@ -413,10 +416,6 @@ def forward( ): ln_out.update_usage(rowwise_usage=False) - # Weight with column-wise usage is needed for dgrad GEMM. - if isinstance(weightmat, QuantizedTensorStorage): - weightmat.update_usage(columnwise_usage=True) - if cpu_offloading: mark_activation_offload(inputmat, mu, rsigma, ln_out) @@ -429,7 +428,7 @@ def forward( fsdp_group, mu, rsigma, - weightmat if quantized_weight else None, + weightmat if is_weight_param_quantized else None, ln_out if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -459,7 +458,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.requires_dgrad = inp_requires_grad ctx.requires_wgrad = weight.requires_grad - ctx.quantized_weight = quantized_weight + ctx.is_weight_param_quantized = is_weight_param_quantized if fuse_wgrad_accumulation and weight.requires_grad: # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates @@ -563,7 +562,7 @@ def backward( ctx.fsdp_shapes, mu, rsigma, - weight if ctx.fp8 and ctx.quantized_weight else None, + weight if ctx.fp8 and ctx.is_weight_param_quantized else None, ln_out, ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 889f545c1e..a358ae7ddf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -351,8 +351,17 @@ def forward( # which handles weight caching etc. 
# FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + # No need to set the quantizer states if weights are already quantized + if isinstance(fc1_weight, QuantizedTensorStorage): + fc1_weight_quantizer = fc1_weight._quantizer + elif fc1_weight_quantizer is not None: + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + + if isinstance(fc2_weight, QuantizedTensorStorage): + fc2_weight_quantizer = fc2_weight._quantizer + elif fc2_weight_quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, @@ -538,13 +547,6 @@ def forward( # Cache state for backward pass if is_grad_enabled: - - # Weight with column-wise usage is needed for dgrad GEMM. - if isinstance(fc1_weight_final, QuantizedTensorStorage): - fc1_weight_final.update_usage(columnwise_usage=True) - if isinstance(fc2_weight_final, QuantizedTensorStorage): - fc2_weight_final.update_usage(columnwise_usage=True) - if cpu_offloading: mark_activation_offload( inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ccb84e6642..0e2310a5a2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -240,7 +240,8 @@ def forward( weightmat = weight if fp8 or debug: # Configure quantizer - if weight_quantizer is not None: + # No need to set the quantizer states if weight is already quantized + if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): columnwise_usage = is_grad_enabled and inp.requires_grad if not columnwise_usage: columnwise_usage = ( @@ -248,7 +249,9 @@ def forward( and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - + elif isinstance(weight, QuantizedTensor): + # If weight is already quantized, no need to set quantizer states + weight_quantizer = weight._quantizer # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch weightmat = module.get_weight_workspace( @@ -389,11 +392,6 @@ def forward( if backward_needs_input: saved_inputmat = inputmat - # Weight with column-wise usage is needed for dgrad GEMM. 
- if inp.requires_grad: - if isinstance(weightmat, QuantizedTensorStorage): - weightmat.update_usage(columnwise_usage=True) - if cpu_offloading and saved_inputmat is not None: mark_activation_offload(saved_inputmat) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 15f5b6bd5e..7d49e3964f 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -433,6 +433,10 @@ def maybe_update_inplace(arg, new_arg, schema_arg): and schema_arg.alias_info.is_write ): arg.quantize_(new_arg) + elif isinstance(arg, list) and isinstance(new_arg, list): + # Recursively handle update for lists of tensors + for a, na in zip(arg, new_arg): + maybe_update_inplace(a, na, schema_arg) # In-place op: dequantize, perform op, and quantize if func._schema.is_mutable: @@ -489,20 +493,16 @@ def make_like( shape: Optional[Iterable[int]] = None, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, - data: Optional[torch.Tensor] = None, ) -> QuantizedTensor: """Create new quantized tensor By default, new tensor has the same attributes and underlying - data. + data. This function is intended to create view of tensors. """ - if shape is None: - shape = data.shape if data is not None else tensor.shape + shape = shape if shape is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype kwargs = tensor.get_metadata() - if data is not None: - kwargs["data"] = data return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index de112bb3fd..24e814bf69 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,10 +4,10 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Optional, Tuple, Iterable, Union +from typing import Any, Optional, Tuple, Iterable, Union import warnings - import torch +from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -299,14 +299,12 @@ def make_empty( # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - inner_dim = data.size(-1) + transpose_shape = [data.size(-1)] + list(data.shape[:-1]) data_transpose = torch.empty( - inner_dim, - data.numel() // inner_dim, + transpose_shape, dtype=torch.uint8, device=device, ) - # Construct FP8 tensor return Float8Tensor( shape=shape, @@ -534,9 +532,36 @@ def remove_caches(self) -> None: self._transpose = None @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): + def make_like( + cls, + tensor: QuantizedTensor, + *, + shape: Optional[Iterable[int]] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, + data: Optional[torch.Tensor] = None, + data_transpose: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Create new quantized tensor + + By default, new tensor has the same attributes and underlying + data. 
- # View op + """ + if shape is None and data is not None: + shape = data.shape + new_tensor = super().make_like( + tensor, shape=shape, dtype=dtype, requires_grad=requires_grad + ) + if data is not None: + new_tensor._data = data + if data_transpose is not None: + new_tensor._transpose = data_transpose + new_tensor._transpose_invalid = False + return new_tensor + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == aten.view.default: tensor = args[0] data = tensor._data @@ -555,6 +580,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): or out_transpose_shape[1:] != out_shape[:-1] ): out_transpose = None + else: + view_shape_for_transpose = [out_shape[-1]] + list(out_shape[:-1]) + out_transpose = out_transpose.view(*view_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, @@ -587,11 +615,37 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return [ - Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) - for split_tensor in func_out + t_func_out = [None] * len(func_out) + # Compute corresponding split of the transpose cache if available + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + ndim = data.dim() + # Figure out the original split dim + if "dim" in kwargs: + dim_to_split = kwargs["dim"] + else: + dim_to_split = args[2] if len(args) > 2 else 0 + # Dimension along which transpose needs to be split + t_dim = 0 if dim_to_split == ndim - 1 else dim_to_split + 1 + t_func_out = transpose.__torch_dispatch__( + func, + types, + [transpose, args[1], t_dim], + kwargs, + ) + outs = [ + Float8Tensor.make_like( + tensor, + data=split_tensor, + data_transpose=split_transpose_tensor, + shape=split_tensor.shape, + ) + for split_tensor, split_transpose_tensor in zip(func_out, t_func_out) ] + return outs + if func == aten.new_zeros.default: + # create fresh new tensor with zeros. tensor = args[0] data = tensor._data func_out = data.__torch_dispatch__( @@ -600,17 +654,63 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) + func_transposed_out = None + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + size = args[1] + t_shape = [size[-1]] + list(size[:-1]) + func_transposed_out = transpose.__torch_dispatch__( + func, + types, + [transpose, t_shape] + list(args[2:]), + kwargs, + ) + # deep copy the scale inverse tensor and quantizer as well. 
+ scale_inv = tensor._scale_inv.detach().clone() + quantizer = tensor._quantizer.copy() + out_tensor = Float8Tensor( + data=func_out, + shape=func_out.shape, + dtype=tensor.dtype, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=scale_inv, + data_transpose=func_transposed_out, + quantizer=quantizer, + ) + return out_tensor + if func == torch.ops.aten.as_strided.default: tensor = args[0] data = tensor._data + # Apply as_strided to the primary uint8 data func_out = data.__torch_dispatch__( func, types, [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) + func_transposed_out = None + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + size = args[1] + stride = args[2] + if "storage_offset" in kwargs: + storage_offset = kwargs["storage_offset"] + else: + storage_offset = args[3] if len(args) > 3 else 0 + # Shape and strided needed for transpose matrix + t_size = [size[-1]] + list(size[:-1]) + t_stride = [stride[-1]] + list(stride[:-1]) + func_transposed_out = transpose.__torch_dispatch__( + func, + types, + [transpose, t_size, t_stride, storage_offset] + list(args[4:]), + kwargs, + ) + return Float8Tensor.make_like( + tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape + ) + if func == torch.ops.aten.detach.default: return cls.detach(args[0]) if func == torch.ops.aten.clone.default: @@ -632,9 +732,108 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) else: pass - return super().__torch_dispatch__(func, types, args, kwargs) + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Functions FSDP2 calls before all-gather of the + weights for both forward and backward passes. + Args: + mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 + to shard the weights. + orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape) + contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor + (For us same as self.stride()) + module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard + that contains this FP8 tensor. + mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2. + + Returns: + shareded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors + that need to be all-gathered.(In this case uint8 data tensor) + metadata: Tuple[Any]: Metadata needed for reconstructing the + Float8Tensor after all-gather. + """ + # pylint: disable=unused-argument + # Importing here to avoid circular imports + from transformer_engine.pytorch.distributed import _get_module_fsdp_state + + if isinstance(self._quantizer, Float8CurrentScalingQuantizer) and mesh is not None: + # When sharded weight is updated after reduce scattering the gradients in FSDP2, + # we need to do amax reduction across the mesh to make sure all weight shards are + # updated with same scale inverse. Setting the state below in the quantizer will make + # sure that updated Quantized weight tensor have same scale inverse across all shards. + self._quantizer.amax_reduction_group = mesh.get_group() + self._quantizer.with_amax_reduction = True + # Allgathered weights might only need one of data or transpose based on + # L40/Hopper based on forward or backward pass in fsdp state. 
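+        # (As noted in fsdp_post_all_gather below, on Blackwell+ the transpose is not
+        # needed at all, so gathering the row-wise uint8 data is always sufficient there.)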
+        quantizer = self._quantizer.copy()  # quantizer to be used for the all-gathered weights
+        fsdp_state = _get_module_fsdp_state(module)
+        reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
+        # If weights are resharded after the forward pass, it is enough to set the quantizer
+        # usages based on whether this is the forward or backward pass. Otherwise, the weights
+        # all-gathered in forward are reused in backward and the pre/post all-gather methods
+        # won't be called again in the backward pass.
+        if reshard_after_forward:
+            training_state = fsdp_state._fsdp_param_group._training_state
+            is_backward_pass = training_state == TrainingState.PRE_BACKWARD
+            # On Hopper/L40, only one of data/transpose is needed,
+            # depending on whether this is the forward or backward pass.
+            quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
+        sharded_tensors = (self._data,)
+        metadata = (self._scale_inv, self._fp8_dtype, quantizer)
+        return sharded_tensors, metadata
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[Float8Tensor] = None,
+    ):
+        """Method FSDP2 calls after the all-gather of the
+        weights, for both forward and backward passes.
+        Args:
+            all_gather_outputs (Tuple[torch.Tensor, ...]): the sharded_tensors sent out in
+                fsdp_pre_all_gather from each rank, all-gathered and received here as a tuple.
+            metadata (Any): metadata sent out in fsdp_pre_all_gather, used for reconstructing
+                the Float8Tensor.
+            param_dtype (torch.dtype): high-precision dtype of the Float8Tensor.
+            out (Optional[Float8Tensor], optional): pre-allocated Float8Tensor to update
+                in place with the all-gathered data. Defaults to None.
+
+        Returns:
+            Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: the all-gathered Float8Tensor and
+                the tuple of internal tensors backing it.
+        """
+
+        (data,) = all_gather_outputs
+        (fp8_scale_inv, fp8_dtype, quantizer) = metadata
+        orig_shape = data.size()
+        # The quantizer has only column-wise usage set for the backward pass.
+        # On Blackwell+ architectures the transpose is not needed at all, even if
+        # column-wise usage is set; this is handled internally by the update_usage method.
+ if out is not None: + out._data = data + out.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) + else: + fp8_args = { + "shape": orig_shape, + "dtype": param_dtype, + "fp8_scale_inv": fp8_scale_inv, + "fp8_dtype": fp8_dtype, + "quantizer": quantizer, + "requires_grad": False, + "data": data, + } + out = Float8Tensor(**fp8_args) + out.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) + return out, all_gather_outputs + @classmethod def _make_in_reduce_ex( cls, @@ -752,6 +951,9 @@ def forward( out_transpose_shape = out_transpose.size() if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: out_transpose = None + else: + view_shape_for_transpose = [shape[-1]] + list(shape[:-1]) + out_transpose = out_transpose.view(*view_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, @@ -796,6 +998,9 @@ def forward( out_transpose_shape = out_transpose.size() if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: out_transpose = None + else: + reshape_shape_for_transpose = [shape[-1]] + list(shape[:-1]) + out_transpose = out_transpose.reshape(*reshape_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5ef5708fdb..458b87320a 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -6,16 +6,17 @@ from __future__ import annotations from collections.abc import Iterable import math -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Any +import warnings import torch +from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple - from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc @@ -298,7 +299,6 @@ def contiguous( memory_format: torch.memory_format = torch.contiguous_format, ) -> MXFP8Tensor: """Returns tensor with data in provided memory format - Returns `self` if data is already in correct memory format. """ @@ -314,7 +314,6 @@ def contiguous( @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # View op if func == aten.view.default: tensor = args[0] @@ -338,9 +337,320 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_dtype=tensor._fp8_dtype, ) + if func == torch.ops.aten.copy_.default: + dst, src = args[0], args[1] + if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor): + if src._rowwise_data is not None and dst._rowwise_data is not None: + dst._rowwise_data.copy_(src._rowwise_data.detach()) + dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach()) + if src._columnwise_data is not None and dst._columnwise_data is not None: + dst._columnwise_data.copy_(src._columnwise_data.detach()) + dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach()) + return dst + + # FSDP2 related functions. 
+ if func == aten.split.Tensor: + # This is called if entire model is initialized on CUDA device and + # then splitted. Finally the shard needed by the process is used + # and other splitted shards are discarded. + if "dim" in kwargs: + dim_to_split = kwargs["dim"] + else: + dim_to_split = args[2] if len(args) > 2 else 0 + tensor = args[0] + split_size = args[1] + dim0_size = tensor.size(0) + dimlast_size = math.prod(tensor.shape[1:]) + if ( + dim0_size % split_size != 0 + or dim_to_split != 0 + or split_size % MXFP8_BLOCK_SCALING_SIZE != 0 + or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + # Handle splitting by dequantizing and splitting the hp tensor + return super().__torch_dispatch__(func, types, args, kwargs) + + out_data = [] + for data in [tensor._rowwise_data, tensor._columnwise_data]: + func_out = ( + data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + if data is not None + else None + ) + out_data.append(func_out) + + scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv] + split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE] + # Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4 + padding_multiples = [128, 4] + for scale_inv, scale_split_size, pad_multiple in zip( + scale_invs, split_sizes_for_scale, padding_multiples + ): + scale_inv_out = ( + scale_inv.__torch_dispatch__( + func, + types, + [scale_inv, scale_split_size] + list(args[2:]), + kwargs, + ) + if scale_inv is not None + else None + ) + # Pad scale_inv_out to be a multiple of pad_multiple + if scale_inv_out is not None: + current_shape = scale_inv_out.shape + pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple + if pad_dim0 > 0: + scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0)) + + out_data.append(scale_inv_out) + return [ + MXFP8Tensor( + shape=splitted_tensor_data[0].size(), + dtype=tensor.dtype, + rowwise_data=splitted_tensor_data[0], + rowwise_scale_inv=splitted_tensor_data[2], + columnwise_data=splitted_tensor_data[1], + columnwise_scale_inv=splitted_tensor_data[3], + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + for splitted_tensor_data in zip(*out_data) + ] + if func == torch.ops.aten.as_strided.default: + # Applied on unsharded param in FSDP2. In our case, this should be a no-op + # This is needed for the case where some MXFP8 shards need padding i.e dimension 0 + # of the unsharded param is not a multiple of the world size. If that is the case, + # we down the dequantization route and weights are allgathered in high precision. + # If weight doesnt need padding, this is just a no-op. + shape = args[1] + strides = args[2] + tensor = args[0] + if ( + len(shape) != 2 + or len(strides) != 2 + or strides[1] != 1 + or shape[0] != tensor.shape[0] + or shape[1] != tensor.shape[1] + ): + return super().__torch_dispatch__(func, types, args, kwargs) + + return MXFP8Tensor.make_like(tensor) + + if func == aten.slice.Tensor: + # FSDP2 needed function. + # We need slicing for the case where some MXFP8 weight shards need padding i.e dimension 0 + # of the unsharded param is not a multiple of the world size. If that is the case, + # we down the dequantization route and weights are allgathered in high precision instead. + # If sharded weight doesnt have padding, this is just a no-op. 
+ dim = args[1] + start = args[2] + length = args[3] + tensor = args[0] + if ( + dim != 0 + or length != tensor.shape[0] + or start != 0 + or length % MXFP8_BLOCK_SCALING_SIZE != 0 + or start % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + return super().__torch_dispatch__(func, types, args, kwargs) + return MXFP8Tensor.make_like(tensor) + + if func == aten.new_zeros.default: + rowwise_data = None + columnwise_data = None + rowwise_scale_inv = None + columnwise_scale_inv = None + tensor = args[0] + shape = args[1] + first_dim = math.prod(shape[:-1]) + last_dim = shape[-1] + if ( + first_dim % MXFP8_BLOCK_SCALING_SIZE != 0 + or last_dim % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + return super().__torch_dispatch__(func, types, args, kwargs) + rowwise_scale_inv_shape = [first_dim, last_dim // MXFP8_BLOCK_SCALING_SIZE] + columnwise_scale_inv_shape = [ + first_dim // MXFP8_BLOCK_SCALING_SIZE, + last_dim, + ] + if tensor._rowwise_data is not None: + rowwise_data = tensor._rowwise_data.__torch_dispatch__( + func, + types, + [tensor._rowwise_data] + list(args[1:]), + kwargs, + ) + rowwise_scale_inv = tensor._rowwise_scale_inv.__torch_dispatch__( + func, + types, + [tensor._rowwise_scale_inv, rowwise_scale_inv_shape] + list(args[2:]), + kwargs, + ) + if tensor._columnwise_data is not None: + columnwise_data = tensor._columnwise_data.__torch_dispatch__( + func, + types, + [tensor._columnwise_data] + list(args[1:]), + kwargs, + ) + columnwise_scale_inv = tensor._columnwise_scale_inv.__torch_dispatch__( + func, + types, + [tensor._columnwise_scale_inv, columnwise_scale_inv_shape] + list(args[2:]), + kwargs, + ) + return MXFP8Tensor( + shape=args[1], + dtype=tensor.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=tensor._quantizer.copy(), + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) # Default case return super().__torch_dispatch__(func, types, args, kwargs) + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Functions FSDP2 calls before all-gather of the + weights for both forward and backward passes. + Args: + mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 + to shard the weights. + orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape) + contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor + (For us same as self.stride()). + module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard + that contains this MXFP8 tensor. + mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2. + + Returns: + sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors + that need to be all-gathered. + metadata: Tuple[Any]: Metadata needed for reconstructing the + MXFP8Tensor after all-gather. 
+ """ + # pylint: disable=unused-argument + from transformer_engine.pytorch.distributed import _get_module_fsdp_state + + fsdp_state = _get_module_fsdp_state(module) + reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + quantizer = self._quantizer.copy() + # Remove padding from scale inverses before allgather + # Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128] + rowwise_scale_inv = self._rowwise_scale_inv + columnwise_scale_inv = self._columnwise_scale_inv + shape = self.shape + if rowwise_scale_inv is not None: + # Remove padding from rowwise scale_inv + flattened_in_shape0 = math.prod(shape[:-1]) + if rowwise_scale_inv.size(0) != flattened_in_shape0: + rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0] + + if columnwise_scale_inv is not None: + # Remove padding from columnwise scale_inv + flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE + if columnwise_scale_inv.size(0) != flattened_in_shape0: + columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0] + + sharded_tensors = (self._rowwise_data, rowwise_scale_inv) + # If weights are resharded after forward pass, then its enough to set the quantizer usages + # based on whether its forward or backward pass. If weights are not resharded after forward pass, + # weights allgathered in forward are used in backward and pre/post allgather methods wont be called. + if reshard_after_forward: + training_state = fsdp_state._fsdp_param_group._training_state + is_backward_pass = training_state == TrainingState.PRE_BACKWARD + # Allgather only the necessary tensors based on forward/backward pass + quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass) + sharded_tensors = ( + (self._columnwise_data, columnwise_scale_inv) + if is_backward_pass + else sharded_tensors + ) + else: + if quantizer.columnwise_usage: + # If weights are not resharded after forward, then both + # rowwise and columnwise data/scale_inv need to be allgathered. + sharded_tensors += (self._columnwise_data, columnwise_scale_inv) + metadata = (self._fp8_dtype, quantizer) + return sharded_tensors, metadata + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[MXFP8Tensor] = None, + ): + """Functions FSDP2 calls after all-gather of the + weights for both forward and backward passes. + Args: + all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank + are all-gathered and received here as a tuple. + metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the MXFP8Tensor. + param_dtype (torch.dtype): high precision dtype of the MXFP8Tensor. + out (Optional[torch.Tensor], optional): _description_. Defaults to None. + Returns: + Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors + used by the MXFP8Tensor that was being computed after allgather. 
+ """ + fp8_dtype, quantizer = metadata + rowwise_data, rowwise_scale_inv = ( + all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None) + ) + columnwise_data, columnwise_scale_inv = ( + all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None) + ) + + # Add padding to scale_inv tensors to be multiples of [128, 4]for rowwise and [4, 128] for columnwise + if rowwise_scale_inv is not None: + # Pad rowwise_scale_inv to be a multiple of [128, 4] + current_shape = rowwise_scale_inv.shape + pad_dim0 = (128 - current_shape[0] % 128) % 128 + if pad_dim0 > 0: + rowwise_scale_inv = torch.nn.functional.pad(rowwise_scale_inv, (0, 0, 0, pad_dim0)) + + if columnwise_scale_inv is not None: + # Pad columnwise_scale_inv to be a multiple of [4, 128] + current_shape = columnwise_scale_inv.shape + pad_dim0 = (4 - current_shape[0] % 4) % 4 + if pad_dim0 > 0: + columnwise_scale_inv = torch.nn.functional.pad( + columnwise_scale_inv, (0, 0, 0, pad_dim0) + ) + + if out is not None: + out._rowwise_data = rowwise_data + out._rowwise_scale_inv = rowwise_scale_inv + out._columnwise_data = columnwise_data + out._columnwise_scale_inv = columnwise_scale_inv + out._quantizer = quantizer + else: + out = MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=fp8_dtype, + dtype=param_dtype, + shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + quantizer=quantizer, + ) + + return out, all_gather_outputs + @classmethod def _make_in_reduce_ex( cls, @@ -478,10 +788,14 @@ def forward( shape[i] = d_inferred break if shape[-1] != ctx.shape[-1]: - raise RuntimeError( - "MXFP8Tensor does not support reshaping inner dimension " + warnings.warn( + "MXFP8Tensor does not support reshaping inner dimension. " f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + "If you are using this for FSDP2 without compiled_autograd_enabled," + "then ignore this warning. Since this view is not going to be used anywhere. ", + stacklevel=2, ) + return tensor.dequantize().view(*shape) # Construct new tensor if shape is provided new_rowwise_data = None