Enable float8 CI on sm89 (#587)
jainapurva authored Aug 7, 2024
1 parent f6595ac commit 1cfe69e
Showing 7 changed files with 76 additions and 31 deletions.
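The common thread of this commit: the float8 hardware gate is lowered from compute capability (9, 0) (SM90, H100) to (8, 9) (SM89, Ada parts such as the NVIDIA L4 on the g6 runner added below), so the tests exercise real fp8 kernels on cheaper hardware. A minimal sketch of what that gate means in practice, using only the public `torch.cuda` API (`describe_fp8_support` is a hypothetical helper, for illustration):

```python
import torch

def describe_fp8_support() -> str:
    """Hypothetical helper: reports how this GPU relates to the new (8, 9) floor."""
    if not torch.cuda.is_available():
        return "no CUDA device: float8 tests fall back to emulation"
    cap = torch.cuda.get_device_capability()  # e.g. (8, 9) on L4, (9, 0) on H100
    if cap >= (8, 9):
        return f"capability {cap} >= (8, 9): native float8 kernels are exercised"
    return f"capability {cap} < (8, 9): only emulated float8 paths are exercised"

print(describe_fp8_support())
```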
46 changes: 46 additions & 0 deletions .github/workflows/float8_test.yml
@@ -0,0 +1,46 @@
name: Run Float8 Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

concurrency:
  group: float8_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

env:
  HF_TOKEN: ${{ secrets.HF_TOKEN }}

jobs:
  test:
    strategy:
      fail-fast: false
      matrix:
        include:
          - name: SM-89
            runs-on: amz2023.linux.g6.4xlarge.experimental.nvidia.gpu
            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
            gpu-arch-type: "cuda"
            gpu-arch-version: "12.1"

    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
      timeout: 60
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      script: |
        conda create -n venv python=3.9 -y
        conda activate venv
        echo "::group::Install newer objcopy that supports --set-section-alignment"
        yum install -y devtoolset-10-binutils
        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
        python -m pip install --upgrade pip
        pip install ${{ matrix.torch-spec }}
        pip install -r dev-requirements.txt
        pip install .
        pytest test/float8 --verbose -s
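To approximate this job on a local machine, a small driver along these lines could work (a sketch, not part of the repo; it assumes torch and the dev requirements are already installed, and reuses the same pytest invocation as the `script` block above):

```python
import subprocess
import torch

def run_float8_suite() -> None:
    # Same fallback the tests use: below capability (8, 9) only emulated paths run.
    cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
    if cap is None or cap < (8, 9):
        print(f"device capability {cap} < (8, 9): expect emulation-only coverage")
    # Mirrors the CI invocation above.
    subprocess.run(["pytest", "test/float8", "--verbose", "-s"], check=True)

if __name__ == "__main__":
    run_float8_suite()
```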
14 changes: 7 additions & 7 deletions test/float8/test_base.py
@@ -54,8 +54,8 @@
random.seed(0)
torch.manual_seed(0)

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._data == b._data).item(), "scales are not identical"
@@ -223,7 +223,7 @@ def _test_linear_impl(
# verify initialization flags got updated
assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
@@ -271,7 +271,7 @@ def test_linear(
config,
)

-@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@@ -325,7 +325,7 @@ def test_autocast_outputs(
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
-@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
emulate = (
@@ -393,7 +393,7 @@ def test_repr(self):

class TestScaledMM:
@unittest.skipIf(
-not is_H100,
+not is_cuda_8_9,
"CUDA not available",
)
@pytest.mark.parametrize(
@@ -437,7 +437,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
atol, rtol = 2e-3, 2e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

-@unittest.skipIf(not is_H100, "CUDA not available")
+@unittest.skipIf(not is_cuda_8_9, "CUDA not available")
def test_different_configs_error(self):
x_fp32 = torch.randn(16, 16, device="cuda")
x_scale = torch.tensor(1.0, device="cuda")
@@ -473,7 +473,7 @@ def test_different_configs_error(self):
a @ b

@unittest.skipIf(
-not is_H100,
+not is_cuda_8_9,
"CUDA not available",
)
@pytest.mark.parametrize(
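The recurring pattern in this file: `emulate` is parametrized over both the hardware and emulated float8 paths only when the device qualifies. A self-contained sketch of the pattern (the test body is hypothetical):

```python
import pytest
import torch

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

# On SM89+ hardware both real and emulated float8 paths run; elsewhere only emulation.
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
def test_emulate_gate(emulate: bool) -> None:
    # Hypothetical body: a real test would compare the emulated high-precision
    # path against the hardware float8 kernel path.
    if not emulate:
        assert torch.cuda.get_device_capability() >= (8, 9)
```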
16 changes: 8 additions & 8 deletions test/float8/test_compile.py
@@ -33,7 +33,7 @@
from torch._dynamo.testing import CompileCounterWithBackend

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def _test_compile_base(
backend: str,
@@ -77,7 +77,7 @@ def _test_compile_base(
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
)
-@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_eager_only(
@@ -104,7 +104,7 @@ def test_eager_only(


@pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC]
)
@@ -150,7 +150,7 @@ def test_aot_eager(
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC]
)
-@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_inductor(
fullgraph,
@@ -210,7 +210,7 @@ def test_float8_with_graph_break_in_the_middle(self):
self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
torch.testing.assert_close(y_eager, y_compiled)

-@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
def test_float8_graph_input(self):
"""Test that having Float8Tensor object as a graph input"""

@@ -231,7 +231,7 @@ def to_float(x):
)
torch.testing.assert_close(y2_eager, y2_compiled)

-@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
def test_float8_graph_output(self):
"""Test that having Float8Tensor object as a graph output works"""
cnts = CompileCounterWithBackend("inductor")
@@ -258,7 +258,7 @@ def test_float8_graph_output(self):
)


-@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
def test_sync_amax_func():
torch._dynamo.reset()
cnts = CompileCounterWithBackend("inductor")
@@ -296,7 +296,7 @@ def __exit__(self, *args):
sys.stderr = self.sys_stderr


-@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
def test_sync_amax_func_cuda_graph_success():
torch._dynamo.reset()
with capture_stderr() as stderr:
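All compile tests above share the `is_cuda_8_9` gate. A standalone smoke test of the same gate plus the inductor backend, with no torchao dependency (a sketch; the helper name is illustrative):

```python
import torch

def smoke_test_inductor() -> None:
    # Mirrors the skipIf condition used by the compile tests above.
    if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
        print("skipping: CUDA with float8-capable hardware not available")
        return
    m = torch.nn.Linear(16, 16, device="cuda", dtype=torch.bfloat16)
    compiled = torch.compile(m, backend="inductor", fullgraph=True)
    x = torch.randn(4, 16, device="cuda", dtype=torch.bfloat16)
    torch.testing.assert_close(compiled(x), m(x))

smoke_test_inductor()
```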
4 changes: 2 additions & 2 deletions test/float8/test_fsdp.py
@@ -197,9 +197,9 @@ def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False):
if not torch.cuda.is_available():
warnings.warn("CUDA not available, running in emulation_mode")
emulate = True
-elif torch.cuda.get_device_capability() < (9, 0):
+elif torch.cuda.get_device_capability() < (8, 9):
warnings.warn(
-f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
+f"CUDA capability {torch.cuda.get_device_capability()} < (8.9), running in emulation mode"
)
emulate = True

4 changes: 2 additions & 2 deletions test/float8/test_fsdp2/test_fsdp2.py
@@ -36,8 +36,8 @@
TransformerBlock,
)

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-if not is_H100:
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+if not is_cuda_8_9:
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

class TestFloat8Common:
18 changes: 9 additions & 9 deletions test/float8/test_inference_flows.py
@@ -38,7 +38,7 @@
torch.manual_seed(0)

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

class FeedForward(nn.Module):
def __init__(self) -> None:
@@ -74,8 +74,8 @@ def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
@pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
-not torch.cuda.is_available() or not is_H100,
-"CUDA not available or on non H100 machine",
+not torch.cuda.is_available() or not is_cuda_8_9,
+"CUDA not available or machine does not support SM89",
)
def test_dynamic_fp8_mlp(self, compile_backend, dtype):
original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -109,8 +109,8 @@ def test_dynamic_fp8_mlp(self, compile_backend, dtype):
@pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
-not torch.cuda.is_available() or not is_H100,
-"CUDA not available or on non H100 machine",
+not torch.cuda.is_available() or not is_cuda_8_9,
+"CUDA not available or machine does not support SM89",
)
def test_static_fp8_mlp(self, compile_backend, dtype):
original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -150,8 +150,8 @@ def test_static_fp8_mlp(self, compile_backend, dtype):
@pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
-not torch.cuda.is_available() or not is_H100,
-"CUDA not available or on non H100 machine",
+not torch.cuda.is_available() or not is_cuda_8_9,
+"CUDA not available or machine does not support SM89",
)
def test_weight_only_fp8_mlp(self, compile_backend, dtype):
original_mlp = FeedForward().to("cuda", dtype=dtype)
@@ -205,8 +205,8 @@ def train(self, model: nn.Module, dtype: torch.dtype):

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(
-not torch.cuda.is_available() or not is_H100,
-"CUDA not available or on non H100 machine",
+not torch.cuda.is_available() or not is_cuda_8_9,
+"CUDA not available or machine does not support SM89",
)
def test_fp8_save_and_load(self, dtype: torch.dtype):
# Initialize FP8 model
5 changes: 2 additions & 3 deletions test/float8/test_numerics_integration.py
@@ -27,8 +27,7 @@
)
from torchao.float8.float8_utils import compute_error, IS_ROCM

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

torch.manual_seed(0)

@@ -89,7 +88,7 @@ class TestFloat8NumericsIntegrationTest:
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC],
)
-@pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+@pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine")
@pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
def test_encoder_fw_bw(
self,
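For reference, the capability tuples compared throughout this commit map to hardware roughly as follows (informational sketch; "fp8 support" here means native fp8 tensor-core matmul):

```python
CAPABILITY_NOTES = {
    (8, 0): "A100 (SM80): no native fp8 matmul; emulation only",
    (8, 9): "L4 / RTX 4090 (SM89, Ada): native fp8 matmul; the new CI floor",
    (9, 0): "H100 (SM90, Hopper): native fp8 matmul; the previous floor",
}
```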
