Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,26 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import functools
import math
import sys
from dataclasses import dataclass, replace
from enum import Enum, auto
from typing import Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch._prims_common import make_contiguous_strides_for
from torch.distributed.device_mesh import DeviceMesh
if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh

from torchao.utils import torch_version_at_least

aten = torch.ops.aten

c10d_functional = torch.ops.c10d_functional
# The functional-collective op namespace is only meaningful when PyTorch was
# built with distributed support.  Bind the name unconditionally (``None`` on
# non-distributed builds) so later references can test for availability
# instead of hitting a NameError.
if torch.distributed.is_available():
    c10d_functional = torch.ops.c10d_functional
else:
    c10d_functional = None


def nf4_all_gather_into_tensor(func, *args, **kwargs):
Expand Down Expand Up @@ -63,10 +66,10 @@ def scatter_nf4tensor(func, *args, **kwargs):
return new_attr, update_work


NF4_OPS_TABLE: Dict[Any, Any] = {
torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor,
torch.ops.c10d.scatter_.default: scatter_nf4tensor,
}
# Dispatch table mapping collective ops to their NF4 handlers.  It stays empty
# when torch.distributed is unavailable, so the c10d op lookups below are
# never evaluated on such builds.
NF4_OPS_TABLE: Dict[Any, Any] = {}
if torch.distributed.is_available():
    NF4_OPS_TABLE.update(
        {
            torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor,
            torch.ops.c10d.scatter_.default: scatter_nf4tensor,
        }
    )


_INNER_TENSOR_NAMES_FOR_SHARDING = [
Expand Down Expand Up @@ -518,22 +521,24 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
return tensors


@implements(
[
torch.ops._c10d_functional.wait_tensor.default,
]
)
def wait_tensor(func, *args, **kwargs):
nf4tensor = args[0][0]
updated_attrs = {}
for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
updated_attrs[attr] = func(getattr(nf4tensor, attr))
updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
return updatedNF4Tensor
if torch.distributed.is_available():

    @implements([torch.ops._c10d_functional.wait_tensor.default])
    def wait_tensor(func, *args, **kwargs):
        """Apply ``func`` to every inner tensor of an NF4Tensor and rebuild it.

        ``args[0][0]`` is taken as the NF4Tensor operand.  ``func`` is called
        on each attribute named in ``_INNER_TENSOR_NAMES_FOR_SHARDING`` and a
        new NF4Tensor is constructed from the results via
        ``construct_nf4_args``.
        """
        source = args[0][0]
        waited = {
            name: func(getattr(source, name))
            for name in _INNER_TENSOR_NAMES_FOR_SHARDING
        }
        return NF4Tensor(*construct_nf4_args(source, waited))


# _wrap_tensor_autograd is only available in PyTorch 2.11.0.dev and later
if torch_version_at_least("2.11.0.dev"):
if torch_version_at_least("2.11.0.dev") and torch.distributed.is_available():

@implements(
[
Expand Down Expand Up @@ -1199,4 +1204,3 @@ def nf4_constructor(


torch.serialization.add_safe_globals([NF4Tensor])
torch.serialization.add_safe_globals([NF4Tensor])
13 changes: 9 additions & 4 deletions torchao/float8/distributed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor

from torchao.float8.float8_training_tensor import Float8TrainingTensor

if torch.distributed.is_available():
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
else:
funcol = None
DTensor = None


def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
"""
Expand All @@ -18,10 +23,10 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
"""
if isinstance(tensor, Float8TrainingTensor):
return True
elif isinstance(tensor, DTensor):
elif DTensor is not None and isinstance(tensor, DTensor):
# TODO: shall we stick to public API and directly use tensor.to_local() here?
return tensor_already_casted_to_fp8(tensor._local_tensor)
elif isinstance(tensor, funcol.AsyncCollectiveTensor):
elif funcol is not None and isinstance(tensor, funcol.AsyncCollectiveTensor):
return tensor_already_casted_to_fp8(tensor.elem)

return False
8 changes: 6 additions & 2 deletions torchao/float8/float8_training_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
from typing import Dict, NamedTuple, Optional

import torch
from torch.distributed._tensor import DTensor

from torchao.float8.float8_utils import (
to_fp8_saturated,
)

if torch.distributed.is_available():
from torch.distributed._tensor import DTensor
else:
DTensor = None

aten = torch.ops.aten

#
Expand Down Expand Up @@ -153,7 +157,7 @@ def forward(
tensor_scaled = tensor.to(torch.float32) * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)

if isinstance(bits_fp8, DTensor):
if DTensor is not None and isinstance(bits_fp8, DTensor):
assert isinstance(scale, DTensor), (
"Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
)
Expand Down
11 changes: 8 additions & 3 deletions torchao/float8/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce

from torchao.float8.config import ScalingGranularity

if torch.distributed.is_available():
from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce
else:
AsyncCollectiveTensor = None
all_reduce = None

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html

Expand Down Expand Up @@ -71,12 +76,12 @@ def tensor_to_amax(
# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
# happen elsewhere.
if reduce_amax and dist.is_initialized():
if reduce_amax and dist.is_available() and dist.is_initialized():
pg = device_mesh.get_group() if device_mesh is not None else None
# dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)
group = list(range(dist.get_world_size())) if pg is None else pg
amax = all_reduce(amax, "MAX", group)
if isinstance(amax, AsyncCollectiveTensor):
if AsyncCollectiveTensor is not None and isinstance(amax, AsyncCollectiveTensor):
amax = amax.wait()

return amax
Expand Down
9 changes: 6 additions & 3 deletions torchao/optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import torch
from torch import Tensor
from torch.distributed._tensor import DTensor
if torch.distributed.is_available():
from torch.distributed._tensor import DTensor
else:
DTensor = None
from torch.optim import Optimizer

from .quant_utils import _fp32_to_bf16_sr
Expand Down Expand Up @@ -69,7 +72,7 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
raise NotImplementedError

def _new_buffer(self, p: Tensor, signed: bool):
local_p = p.to_local() if isinstance(p, DTensor) else p
local_p = p.to_local() if (DTensor is not None and isinstance(p, DTensor)) else p

# follow bitsandbytes, only quantize tensors >= 4096 values
if local_p.numel() >= 4096 and local_p.numel() % self.block_size == 0:
Expand All @@ -81,7 +84,7 @@ def _new_buffer(self, p: Tensor, signed: bool):
# NOTE: local tensor may have different shapes across ranks.
# this happens when the 1st dim is not divisible by WORLD_SIZE.
# thus, we must supply shape (and stride) to DTensor.from_local()
if isinstance(p, DTensor):
if DTensor is not None and isinstance(p, DTensor):
out = DTensor.from_local(
local_tensor=out,
device_mesh=p.device_mesh,
Expand Down
7 changes: 5 additions & 2 deletions torchao/optim/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
# LICENSE file in the root directory of this source tree.
import torch
from torch import Tensor
from torch.distributed.tensor import DTensor
if torch.distributed.is_available():
from torch.distributed.tensor import DTensor
else:
DTensor = None


# https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
Expand Down Expand Up @@ -128,7 +131,7 @@ def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
# [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
#
# we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
is_dt = isinstance(_x_f32, DTensor)
is_dt = DTensor is not None and isinstance(_x_f32, DTensor)
x_f32 = _x_f32.to_local() if is_dt else _x_f32

rand_16bit = torch.randint(
Expand Down
32 changes: 18 additions & 14 deletions torchao/optim/subclass_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
)

aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional
if torch.distributed.is_available():
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional
else:
c10d_functional = None
_c10d_functional = None

# https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
# NOTE: power-1 is linear
Expand Down Expand Up @@ -203,18 +207,18 @@ def _(func, types, args, kwargs):


# Build the list of c10d operations to implement
_optim_state_4bit_c10d_ops = [
# required by DTensor.full_tensor()
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
# required by torch.distributed.checkpoint.save
aten.detach.default,
]
# _wrap_tensor_autograd was added in PyTorch 2.11.0.dev
if torch_version_at_least("2.11.0.dev"):
_optim_state_4bit_c10d_ops.append(_c10d_functional._wrap_tensor_autograd.default)
# aten.detach is registered unconditionally; it is required by
# torch.distributed.checkpoint.save.
_optim_state_4bit_c10d_ops = [aten.detach.default]
if torch.distributed.is_available():
    # Collective ops required by DTensor.full_tensor().
    _optim_state_4bit_c10d_ops.extend(
        (
            c10d_functional.all_gather_into_tensor.default,
            _c10d_functional.all_gather_into_tensor.default,
            c10d_functional.wait_tensor.default,
            _c10d_functional.wait_tensor.default,
        )
    )
    # _wrap_tensor_autograd only exists from PyTorch 2.11.0.dev onwards.
    if torch_version_at_least("2.11.0.dev"):
        _optim_state_4bit_c10d_ops.append(
            _c10d_functional._wrap_tensor_autograd.default
        )


@OptimState4bit.implements(_optim_state_4bit_c10d_ops)
Expand Down
32 changes: 18 additions & 14 deletions torchao/optim/subclass_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
)

aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional
if torch.distributed.is_available():
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional
else:
c10d_functional = None
_c10d_functional = None

# Lazy initialization to avoid meta device issues during import
from functools import lru_cache
Expand Down Expand Up @@ -175,18 +179,18 @@ def _(func, types, args, kwargs):


# Build the list of c10d operations to implement
_optim_state_8bit_c10d_ops = [
# required by DTensor.full_tensor()
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
# required by torch.distributed.checkpoint.save
aten.detach.default,
]
# _wrap_tensor_autograd was added in PyTorch 2.11.0.dev
if torch_version_at_least("2.11.0.dev"):
_optim_state_8bit_c10d_ops.append(_c10d_functional._wrap_tensor_autograd.default)
# aten.detach is registered unconditionally; it is required by
# torch.distributed.checkpoint.save.
_optim_state_8bit_c10d_ops = [aten.detach.default]
if torch.distributed.is_available():
    # Collective ops required by DTensor.full_tensor().
    _optim_state_8bit_c10d_ops.extend(
        (
            c10d_functional.all_gather_into_tensor.default,
            _c10d_functional.all_gather_into_tensor.default,
            c10d_functional.wait_tensor.default,
            _c10d_functional.wait_tensor.default,
        )
    )
    # _wrap_tensor_autograd only exists from PyTorch 2.11.0.dev onwards.
    if torch_version_at_least("2.11.0.dev"):
        _optim_state_8bit_c10d_ops.append(
            _c10d_functional._wrap_tensor_autograd.default
        )


@OptimState8bit.implements(_optim_state_8bit_c10d_ops)
Expand Down
32 changes: 18 additions & 14 deletions torchao/optim/subclass_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
from torchao.utils import TorchAOBaseTensor, torch_version_at_least

aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional
if torch.distributed.is_available():
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional
else:
c10d_functional = None
_c10d_functional = None

DTYPE = torch.float8_e4m3fn

Expand Down Expand Up @@ -150,18 +154,18 @@ def _(func, types, args, kwargs):


# Build the list of c10d operations to implement
_optim_state_fp8_c10d_ops = [
# required by DTensor.full_tensor()
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
# required by torch.distributed.checkpoint.save
aten.detach.default,
]
# _wrap_tensor_autograd was added in PyTorch 2.11.0.dev
if torch_version_at_least("2.11.0.dev"):
_optim_state_fp8_c10d_ops.append(_c10d_functional._wrap_tensor_autograd.default)
# aten.detach is registered unconditionally; it is required by
# torch.distributed.checkpoint.save.
_optim_state_fp8_c10d_ops = [aten.detach.default]
if torch.distributed.is_available():
    # Collective ops required by DTensor.full_tensor().
    _optim_state_fp8_c10d_ops.extend(
        (
            c10d_functional.all_gather_into_tensor.default,
            _c10d_functional.all_gather_into_tensor.default,
            c10d_functional.wait_tensor.default,
            _c10d_functional.wait_tensor.default,
        )
    )
    # _wrap_tensor_autograd only exists from PyTorch 2.11.0.dev onwards.
    if torch_version_at_least("2.11.0.dev"):
        _optim_state_fp8_c10d_ops.append(
            _c10d_functional._wrap_tensor_autograd.default
        )


@OptimStateFp8.implements(_optim_state_fp8_c10d_ops)
Expand Down