Commits (102)
d64ee35
fix for float8 tensor fsdp2 training
vthumbe1503 Oct 7, 2025
e64f5bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2025
980330c
zeros_like should return fp32 for fsdp2 to work
vthumbe1503 Oct 7, 2025
737852c
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 7, 2025
2a3ca77
minor cleanup
vthumbe1503 Oct 7, 2025
65f50af
fix unsharded weights not releasing memory
vthumbe1503 Oct 8, 2025
1360381
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
2ad5624
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Oct 8, 2025
09c6848
implement using fsdp preallgather and postallgather functions
vthumbe1503 Oct 13, 2025
322853c
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 13, 2025
485cab3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2025
f5bedbb
FSDP2 works on Hopper/L40
vthumbe1503 Oct 13, 2025
4a72968
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2025
fddfd21
minor comment
vthumbe1503 Oct 13, 2025
099a92b
fix merge conflict
vthumbe1503 Oct 13, 2025
0b81d32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2025
ae2bde5
some fixes for fp8 + handwavy changes for mxfp8
vthumbe1503 Oct 17, 2025
c4c4347
only transpose saved for backward pass allgather in case of L40/Hoppe…
vthumbe1503 Oct 17, 2025
eec56da
missed minor change to hopper use-case
vthumbe1503 Oct 17, 2025
79475fe
communicate only required data in mxfp8, fix for updating weight usag…
vthumbe1503 Oct 20, 2025
f881535
changes for meta Dtensors for weights and better all gather data hand…
vthumbe1503 Oct 22, 2025
9d87fb7
better solution to figure out forward pass in FSDP2
vthumbe1503 Oct 22, 2025
40d4dfd
adress review comments
vthumbe1503 Oct 22, 2025
71d9f5d
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Oct 23, 2025
20228c4
resolve merge conflict
vthumbe1503 Oct 24, 2025
67dc86e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2025
0a6e23c
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Oct 25, 2025
631f01f
everything functioning except hack for transformerlayer
vthumbe1503 Oct 28, 2025
fb49cfe
fix merge conflict
vthumbe1503 Oct 28, 2025
9516213
fix merge conflict
vthumbe1503 Oct 28, 2025
c9dc10c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
8c8ef93
revert change of commit id for cudnnt-frontend
vthumbe1503 Oct 28, 2025
04d7416
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 28, 2025
c97e204
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
dbbbaab
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Oct 28, 2025
0ac3bfa
unnecessary change
vthumbe1503 Oct 28, 2025
2c3b958
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 28, 2025
36a6d01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
44d8c72
minor issues with linting, add some comments
vthumbe1503 Oct 28, 2025
bca48a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
a50928d
minor stuff
vthumbe1503 Oct 29, 2025
b5bd42e
revert space removal
vthumbe1503 Oct 29, 2025
2d75310
fix the fsdp state collection issue, and minor review comments addres…
vthumbe1503 Oct 31, 2025
eb6964d
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 31, 2025
c16a844
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
5db9a33
revert change for dgrad redundant computation
vthumbe1503 Oct 31, 2025
e207845
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 31, 2025
32365ce
bug: get fsdp param group's training state instead of root training s…
vthumbe1503 Nov 1, 2025
8c7f375
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
63b517f
address review comments
vthumbe1503 Nov 1, 2025
26d617c
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 1, 2025
6f7f037
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 1, 2025
6034f5d
address review comments
vthumbe1503 Nov 1, 2025
cd454e0
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 1, 2025
7269ab3
address coderabbit review comments
vthumbe1503 Nov 1, 2025
044d098
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
ee3c33e
address review comments
vthumbe1503 Nov 1, 2025
7d1726e
xMerge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/Transformer…
vthumbe1503 Nov 1, 2025
bf3546a
address review comments
vthumbe1503 Nov 1, 2025
fb58afe
adress review comments; fix fp8 allgather test to do after fsdp lazy …
vthumbe1503 Nov 2, 2025
7ca1961
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2025
9894d95
address review comments
vthumbe1503 Nov 3, 2025
c6b5b42
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 3, 2025
b661461
remove detach
vthumbe1503 Nov 3, 2025
c99a92b
do what makes sense
vthumbe1503 Nov 3, 2025
0d291d9
Update transformer_engine/pytorch/tensor/float8_tensor.py
vthumbe1503 Nov 3, 2025
5588e96
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
3f4957d
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
3a2b2c1
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
61f4fc8
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
94a8126
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
62ad0aa
address review comments
vthumbe1503 Nov 3, 2025
29db9af
address review comments
vthumbe1503 Nov 3, 2025
16c77c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2025
023a4f4
address review comments
vthumbe1503 Nov 3, 2025
651ebd8
adress review comments
vthumbe1503 Nov 3, 2025
15793fa
fix merge conflixts
vthumbe1503 Nov 3, 2025
c6e0a08
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2025
4aeb6ca
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 3, 2025
8e09624
address review comments
vthumbe1503 Nov 3, 2025
207c2eb
have better dtype for fsdp_post_all_gather arguments
vthumbe1503 Nov 3, 2025
948ab91
minor comment
vthumbe1503 Nov 3, 2025
1f1de13
improve comment
vthumbe1503 Nov 3, 2025
acec1b5
fix the error in CI
vthumbe1503 Nov 4, 2025
ff20145
minor comment add
vthumbe1503 Nov 4, 2025
2593a8a
accidentally removed view function
vthumbe1503 Nov 4, 2025
4cde1c3
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 4, 2025
892024e
fix minor bug for h100
vthumbe1503 Nov 4, 2025
205be41
minor addition
vthumbe1503 Nov 5, 2025
0300a5f
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 5, 2025
adbb31e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
fa71d84
implement padding removal/addition for allgather
vthumbe1503 Nov 5, 2025
cc41893
fix merge conflict
vthumbe1503 Nov 5, 2025
5a40f2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
2d0b2d5
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 6, 2025
0e849ce
address review comments
vthumbe1503 Nov 6, 2025
019ae7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2025
11f72dd
address review comments
vthumbe1503 Nov 6, 2025
d6a8dbb
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 6, 2025
92f932a
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 6, 2025
a72eedc
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 6, 2025
c69bd00
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 7, 2025
321 changes: 246 additions & 75 deletions tests/pytorch/distributed/run_fsdp2_model.py
@@ -9,57 +9,73 @@
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.common.recipe import (
Format,
DelayedScaling,
Float8CurrentScaling,
MXFP8BlockScaling,
)

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from transformer_engine.pytorch import QuantizedTensor
from contextlib import nullcontext

LOCAL_RANK = None

class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNet, self).__init__()
self.fc1 = te.Linear(input_size, hidden_size)
self.fc2 = te.Linear(hidden_size, output_size)

def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x


def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs


def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
def dist_print(msg):
if LOCAL_RANK == 0:
print(msg)


def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads")
parser.add_argument("--head-dim", type=int, default=64, help="Attention head size")
parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input")
parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input")
parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
parser.add_argument(
"--recipe",
type=str,
default="mx_fp8_block_scaling",
help="Quantizer type.",
choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"],
)
parser.add_argument(
"--layer-type",
type=str,
default="TransformerLayer",
choices=[
"Linear",
"LayerNormLinear",
"LayerNormMLP",
"MultiheadAttention",
"TransformerLayer",
],
help="Transformer Engine layer type",
)
parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model")
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
)
parser.add_argument(
"--device",
type=str,
default="meta",
help="Device to run the model on.",
choices=["cuda", "meta"],
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
# Adding hsdp_dim as a list argument, comma-separated
parser.add_argument(
@@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None):
return args


sub_modules_to_wrap = [te.Linear]
## Helpers to initialize the TE model in an FSDP2 setting
## with the required configuration based on command-line args
def get_te_layer_from_string(layer_name):
te_layer_types = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
te_layer_names = [layer.__name__ for layer in te_layer_types]
te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types))
if layer_name.lower() not in te_layer_map.keys():
raise argparse.ArgumentTypeError(
f'"{layer_name}" is not a valid Transformer Engine layer, '
f"please choose layer from {te_layer_names}."
)
return te_layer_map[layer_name.lower()]
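
# Example (illustrative): the lookup above is case-insensitive, so
# get_te_layer_from_string("linear") returns te.Linear and
# get_te_layer_from_string("TransformerLayer") returns te.TransformerLayer.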


def get_recipe_from_string(recipe, fp8_format=Format.HYBRID):
if recipe == "delayed_scaling":
return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
elif recipe == "current_scaling":
return Float8CurrentScaling(fp8_format=fp8_format)
elif recipe == "mx_fp8_block_scaling":
return MXFP8BlockScaling(fp8_format=fp8_format)
else:
raise ValueError(f"Unknown quantizer type: {recipe}")


def init_te_model(config):
hidden_size = config.num_heads * config.head_dim
args = [hidden_size, hidden_size]
inp_shape = [config.seq_length, config.batch_size, hidden_size]
out_shape = [config.seq_length, config.batch_size, hidden_size]
if config.params_dtype == "float16":
params_dtype = torch.float16
elif config.params_dtype == "bfloat16":
params_dtype = torch.bfloat16
else:
params_dtype = torch.float32
kwargs = {
"params_dtype": params_dtype,
}
kwargs["device"] = config.device

layer_type = get_te_layer_from_string(config.layer_type)
# We construct the model so that both the reshard_after_forward=True and
# reshard_after_forward=False cases are exercised; more details below.
if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
# For this case, we create a model that resembles production use-cases
# in which there are multiple TransformerLayers in the model, each of which
# needs to be sharded. Since each transformer layer is not a root module,
# FSDP2's fully_shard assigns reshard_after_forward=False for all parameters of the model.
args[1] *= 4 # FFN hidden size
args.append(config.num_heads)
kwargs["fuse_qkv_params"] = True
if layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
elif layer_type == te.LayerNormLinear:
# For this case, we create a model with just one LayerNormLinear layer
# so that the model itself is a root module, and FSDP2's fully_shard assigns
# reshard_after_forward=True for the parameters of this model.
args[1] *= 3 # QKV projection
out_shape[-1] *= 3
model = layer_type(*args, **kwargs)
else:
model = layer_type(*args, **kwargs)

return model, inp_shape, out_shape
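
# Illustrative note on the shapes above (assuming the argparse defaults): with
# num_heads=8 and head_dim=64, hidden_size = 8 * 64 = 512, so with batch_size=16 and
# seq_length=128 both inp_shape and out_shape are [128, 16, 512]. For
# --layer-type LayerNormLinear the last output dim is widened to 3 * 512 = 1536
# for the fused QKV projection.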


def get_device_mesh(world_size, sharding_dims):
dist_print(f"sharding-dims:{sharding_dims}")
device_ids = list(range(world_size))
if sharding_dims is None: # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(sharding_dims) == 1:
assert sharding_dims[0] == world_size
mesh = DeviceMesh("cuda", device_ids)
elif len(sharding_dims) == 2: # HSDP
assert sharding_dims[0] * sharding_dims[1] == world_size
mesh = init_device_mesh(
"cuda",
(sharding_dims[0], sharding_dims[1]),
mesh_dim_names=("replicate", "shard"),
)
else:
assert False
return mesh
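
# Example (illustrative): with world_size=8,
#   sharding_dims=[8]    -> a 1-D DeviceMesh over all 8 ranks (plain FSDP)
#   sharding_dims=[2, 4] -> a 2x4 mesh with dims ("replicate", "shard") (HSDP)
# The product of the two HSDP dims must equal world_size, as asserted above.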


def shard_model_with_fsdp2(model, mesh):
for child in model.children():
fully_shard(child, mesh=mesh)
fully_shard(model, mesh=mesh)
return model


#### Methods to save the custom attributes of QuantizedTensors before sharding
#### them with FSDP2, and restore them after sharding.
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
if isinstance(param, QuantizedTensor):
# Ignore FP8 metadata attributes. Otherwise we would save duplicate copies
# of the data/transpose FP8 tensors on top of the FP8 tensors that FSDP2 saves.
ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
else:
ignore_keys = []
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
return custom_attrs


def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
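
# Typical call order (sketch, mirroring the training code below): snapshot TE-specific
# attributes before fully_shard() replaces the parameters with DTensors, then restore them:
#   attrs = save_custom_attrs(model)
#   model = shard_model_with_fsdp2(model, mesh)
#   restore_custom_attrs(model, attrs)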


@torch.no_grad()
def test_fp8_fsdp2_allgather(model):
# Do a manual allgather in FP32 and compare it against the FP8 allgather
# done by FSDP2.
# FP32 manual weight allgather
fp32_allgathered_params = {}
for name, param in model.named_parameters():
assert isinstance(param, DTensor)
local_tensor = param._local_tensor
device_mesh = param.device_mesh
dist_group = (
device_mesh.get_group(mesh_dim="shard")
if device_mesh.ndim > 1
else device_mesh.get_group()
)
# Perform a manual allgather on local_tensor. zeros_like creates a high-precision
# tensor because torch_dispatch for local_tensor goes down the dequantization route.
gathered_tensor = [
torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group))
]
dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group)
full_tensor = torch.cat(gathered_tensor, dim=0)
fp32_allgathered_params[name] = full_tensor
# FP8 allgather using FSDP2
for module in model.modules():
# Not all modules are wrapped/sharded with FSDP2.
if hasattr(module, "unshard"):
module.unshard()
# Make sure allgathered parameters match exactly
for name, param in model.named_parameters():
assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])
# Revert model to original sharded state
for module in model.modules():
# Not all modules are wrapped/sharded with FSDP2.
if hasattr(module, "reshard"):
module.reshard()


def _train(args):
global LOCAL_RANK
assert "TORCHELASTIC_RUN_ID" in os.environ
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
Expand All @@ -103,74 +279,69 @@ def _train(args):

# FP8 Configuration
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

# Create build context manager
if args.fp8_init:
from transformer_engine.pytorch import quantized_model_init
fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)

build_model_context = quantized_model_init()
build_model_context_args = {}
if not args.fp8_init:
# Build model context (FP8 init)
build_model_context = nullcontext
else:
build_model_context = nullcontext()
from transformer_engine.pytorch import fp8_model_init

# Build the model with the specified context
with build_model_context:
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
build_model_context = fp8_model_init
build_model_context_args["enabled"] = True
build_model_context_args["recipe"] = fp8_recipe

# Move the model to the correct device
model.to(device)
dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB")
# Create the model on the meta/cuda device as per args
with build_model_context(**build_model_context_args):
model, inp_shape, out_shape = init_te_model(args)
dist_print(
f"Memory after model init on device {args.device}:"
f" {torch.cuda.memory_allocated(device)/1e6} MB"
)

if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
# Creating a DeviceMesh for fully_shard
world_size = int(WORLD_SIZE)
device_ids = list(range(world_size))
if LOCAL_RANK == 0:
print(f"sharding-dims:{args.sharding_dims}")
# Setup the sharding mesh for FSDP/HSDP
if args.sharding_dims == None: # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 1:
assert args.sharding_dims[0] == device_ids[-1] + 1
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 2: # HSDP
assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
mesh = init_device_mesh(
"cuda",
(args.sharding_dims[0], args.sharding_dims[1]),
mesh_dim_names=("replicate", "shard"),
)
else:
assert False

# Apply FSDP/HSDP
mesh = get_device_mesh(world_size, args.sharding_dims)
custom_attrs = save_custom_attrs(model)
for sub_module in model.modules():
if any(
isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, mesh=mesh)
fully_shard(model, mesh=mesh)
model = shard_model_with_fsdp2(model, mesh)
restore_custom_attrs(model, custom_attrs)
# model now has DTensors as its parameters

if args.device == "meta":
# After FSDP2 has been applied, materialize and initialize the sharded parameters
# TE base.py's reset_parameters() handles DTensors with FP8 initialization
for module in model.modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
dist_print(f" Sharded parameters materialized and initialized on cuda device.")

dist_print(
f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for iteration in range(args.iter):
# Zero the parameter gradients
optimizer.zero_grad()
input_data = torch.randn(args.batch_size, args.input_size).to(device)
input_data = torch.randn(inp_shape).to(device)
with te.autocast(enabled=True, recipe=fp8_recipe):
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
target = torch.randn(out_shape).to(device)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
dist_print(f"Iteration {iteration} completed with loss {loss.item()}")

# Some FSDP states are lazily initialized during the FSDP forward pass,
# so the FP8 allgather is tested at the end of the training loop.
if args.fp8_init:
test_fp8_fsdp2_allgather(model)

dist.destroy_process_group()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Done...")
return 0

