Enable float8 CI on sm89 #587

Merged · 29 commits · Aug 7, 2024
46 changes: 46 additions & 0 deletions .github/workflows/float8_test.yml
@@ -0,0 +1,46 @@
name: Run Float8 Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

concurrency:
  group: float8_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

env:
  HF_TOKEN: ${{ secrets.HF_TOKEN }}

jobs:
  test:
    strategy:
      fail-fast: false
      matrix:
        include:
          - name: SM-89
            runs-on: amz2023.linux.g6.4xlarge.experimental.nvidia.gpu
            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
            gpu-arch-type: "cuda"
            gpu-arch-version: "12.1"

    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
      timeout: 60
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      script: |
        conda create -n venv python=3.9 -y
        conda activate venv
        echo "::group::Install newer objcopy that supports --set-section-alignment"
        yum install -y devtoolset-10-binutils
        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
        python -m pip install --upgrade pip
        pip install ${{ matrix.torch-spec }}
        pip install -r dev-requirements.txt
        pip install .
        pytest test/float8 --verbose -s
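
The new workflow targets a g6.4xlarge runner, which exposes an Ada-generation GPU (compute capability 8.9), the minimum the hardware float8 path in these tests requires. As a point of reference, and not part of this PR, here is a minimal sketch of the capability gate the CI job and the updated tests rely on; the helper name supports_float8_hardware is hypothetical.

import torch

def supports_float8_hardware() -> bool:
    # Mirrors the is_cuda_8_9 check added throughout the float8 tests:
    # CUDA must be available and the device must report compute
    # capability (8, 9) -- Ada/SM89 -- or newer.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

if __name__ == "__main__":
    print("float8 hardware path available:", supports_float8_hardware())
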
14 changes: 7 additions & 7 deletions test/float8/test_base.py
@@ -54,8 +54,8 @@
random.seed(0)
torch.manual_seed(0)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._data == b._data).item(), "scales are not identical"
@@ -223,7 +223,7 @@ def _test_linear_impl(
# verify initialization flags got updated
assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
@@ -271,7 +271,7 @@ def test_linear(
config,
)

@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@@ -325,7 +325,7 @@ def test_autocast_outputs(
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
emulate = (
@@ -393,7 +393,7 @@ def test_repr(self):

class TestScaledMM:
@unittest.skipIf(
not is_H100,
not is_cuda_8_9,
"CUDA not available",
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -437,7 +437,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
atol, rtol = 2e-3, 2e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

@unittest.skipIf(not is_H100, "CUDA not available")
@unittest.skipIf(not is_cuda_8_9, "CUDA not available")
def test_different_configs_error(self):
x_fp32 = torch.randn(16, 16, device="cuda")
x_scale = torch.tensor(1.0, device="cuda")
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_different_configs_error(self):
a @ b

@unittest.skipIf(
not is_H100,
not is_cuda_8_9,
"CUDA not available",
)
@pytest.mark.parametrize(
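
The test_base.py changes above replace the is_H100 gate with is_cuda_8_9 so that the hardware path is also parametrized on SM89 machines. A self-contained sketch of that parametrization pattern follows; the test name and body are hypothetical and not part of the PR.

import pytest
import torch

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

# On SM89-or-newer GPUs both the emulated and hardware float8 paths run;
# on older or CPU-only machines only the emulated path is exercised.
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
def test_emulate_axis(emulate):
    assert isinstance(emulate, bool)
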
16 changes: 8 additions & 8 deletions test/float8/test_compile.py
@@ -33,7 +33,7 @@
from torch._dynamo.testing import CompileCounterWithBackend

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def _test_compile_base(
backend: str,
@@ -77,7 +77,7 @@ def _test_compile_base(
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
)
@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_eager_only(
@@ -104,7 +104,7 @@ def test_eager_only(


@pytest.mark.parametrize("fullgraph", [True])
@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
)
@@ -150,7 +150,7 @@ def test_aot_eager(
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
)
@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_inductor(
fullgraph,
@@ -210,7 +210,7 @@ def test_float8_with_graph_break_in_the_middle(self):
self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
torch.testing.assert_close(y_eager, y_compiled)

@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
def test_float8_graph_input(self):
"""Test that having Float8Tensor object as a graph input"""

@@ -231,7 +231,7 @@ def to_float(x):
)
torch.testing.assert_close(y2_eager, y2_compiled)

@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
def test_float8_graph_output(self):
"""Test that having Float8Tensor object as a graph output works"""
cnts = CompileCounterWithBackend("inductor")
@@ -258,7 +258,7 @@ def test_float8_graph_output(self):
)


@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
def test_sync_amax_func():
torch._dynamo.reset()
cnts = CompileCounterWithBackend("inductor")
@@ -296,7 +296,7 @@ def __exit__(self, *args):
sys.stderr = self.sys_stderr


@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
def test_sync_amax_func_cuda_graph_success():
torch._dynamo.reset()
with capture_stderr() as stderr:
4 changes: 2 additions & 2 deletions test/float8/test_fsdp.py
@@ -197,9 +197,9 @@ def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False):
if not torch.cuda.is_available():
warnings.warn("CUDA not available, running in emulation_mode")
emulate = True
elif torch.cuda.get_device_capability() < (9, 0):
elif torch.cuda.get_device_capability() < (8, 9):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
f"CUDA capability {torch.cuda.get_device_capability()} < (8.9), running in emulation mode"
)
emulate = True

4 changes: 2 additions & 2 deletions test/float8/test_fsdp2/test_fsdp2.py
@@ -36,8 +36,8 @@
TransformerBlock,
)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
if not is_H100:
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
if not is_cuda_8_9:
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

class TestFloat8Common:
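
Note the difference in fallback behavior: test_fsdp.py above warns and switches to emulation below capability (8, 9), while test_fsdp2.py skips the entire module. A minimal standalone sketch of that module-level gate, mirroring the lines in the diff:

import pytest
import torch

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

# Skip every test in this module when the GPU predates SM89 (or there is no GPU);
# unlike test_fsdp.py, there is no emulated fallback here.
if not is_cuda_8_9:
    pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
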
18 changes: 9 additions & 9 deletions test/float8/test_inference_flows.py
@@ -34,7 +34,7 @@
torch.manual_seed(0)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

class FeedForward(nn.Module):
def __init__(self) -> None:
@@ -65,8 +65,8 @@ def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
@pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
not torch.cuda.is_available() or not is_H100,
"CUDA not available or on non H100 machine",
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA not available or machine does not support SM89",
)
def test_dynamic_fp8_mlp(self, compile_backend, dtype):
original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -100,8 +100,8 @@ def test_dynamic_fp8_mlp(self, compile_backend, dtype):
@pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
not torch.cuda.is_available() or not is_H100,
"CUDA not available or on non H100 machine",
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA not available or machine does not support SM89",
)
def test_static_fp8_mlp(self, compile_backend, dtype):
original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -139,8 +139,8 @@ def test_static_fp8_mlp(self, compile_backend, dtype):
@pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
not torch.cuda.is_available() or not is_H100,
"CUDA not available or on non H100 machine",
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA not available or machine does not support SM89",
)
def test_weight_only_fp8_mlp(self, compile_backend, dtype):
original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -189,8 +189,8 @@ def train(self, model: nn.Module, dtype: torch.dtype):

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
not torch.cuda.is_available() or not is_H100,
"CUDA not available or on non H100 machine",
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA not available or machine does not support SM89",
)
def test_fp8_save_and_load(self, dtype: torch.dtype):
# Initialize FP8 model
5 changes: 2 additions & 3 deletions test/float8/test_numerics_integration.py
@@ -27,8 +27,7 @@
)
from torchao.float8.float8_utils import compute_error, IS_ROCM

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

torch.manual_seed(0)

@@ -89,7 +88,7 @@ class TestFloat8NumericsIntegrationTest:
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC],
)
@pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
@pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine")
@pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
def test_encoder_fw_bw(
self,