Skip to content

Commit

Permalink
flux autoencoder unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinpelletier committed Dec 2, 2024
1 parent 0e6611e commit e7504b6
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 5 deletions.
5 changes: 5 additions & 0 deletions tests/torchtune/models/flux/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
81 changes: 81 additions & 0 deletions tests/torchtune/models/flux/test_flux_autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.flux._autoencoder import FluxAutoencoder
from torchtune.training.seed import set_seed

# Small dimensions so the tests run quickly on CPU.
BSZ = 32  # batch size
CH_IN = 3  # image channels fed to the encoder
RESOLUTION = 16  # input image height and width
CH_MULTS = [1, 2]  # channel multiplier per resolution level
CH_Z = 4  # latent channels
# The Flux encoder halves the spatial size between consecutive resolution
# levels, i.e. len(CH_MULTS) - 1 times. Dividing by len(CH_MULTS) only gives
# the right answer by coincidence when there are exactly two levels.
RES_Z = RESOLUTION // 2 ** (len(CH_MULTS) - 1)


@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs before every test so random inputs are reproducible."""
    set_seed(seed=0)


class TestFluxAutoencoder:
    """Shape, pinned-value and gradient checks for ``FluxAutoencoder``."""

    @staticmethod
    def _assert_channel_means(output, expected):
        """Compare per-channel means of a (B, C, H, W) tensor to pinned values.

        Args:
            output: tensor produced by the model, reduced over dims (0, 2, 3).
            expected: list of reference per-channel means.
        """
        channel_means = torch.mean(output, dim=(0, 2, 3))
        torch.testing.assert_close(
            channel_means, torch.tensor(expected), atol=1e-4, rtol=1e-4
        )

    @pytest.fixture
    def model(self):
        """Build a tiny autoencoder with deterministic, seeded weights."""
        autoencoder = FluxAutoencoder(
            resolution=RESOLUTION,
            ch_in=CH_IN,
            # Same value as before (3); the shape assertions below assume the
            # reconstruction has as many channels as the input.
            ch_out=CH_IN,
            ch_base=32,
            ch_mults=CH_MULTS,
            ch_z=CH_Z,
            n_layers_per_resample_block=2,
            scale_factor=1.0,
            shift_factor=0.0,
        )

        # Overwrite every parameter so the pinned values below do not depend
        # on the module's default initialisation scheme.
        for param in autoencoder.parameters():
            param.data.uniform_(0, 0.1)

        return autoencoder

    @pytest.fixture
    def img(self):
        """Random image batch of shape (BSZ, CH_IN, RESOLUTION, RESOLUTION)."""
        return torch.randn(BSZ, CH_IN, RESOLUTION, RESOLUTION)

    @pytest.fixture
    def z(self):
        """Random latent batch of shape (BSZ, CH_Z, RES_Z, RES_Z)."""
        return torch.randn(BSZ, CH_Z, RES_Z, RES_Z)

    def test_forward(self, model, img):
        """Full encode/decode round trip keeps the image shape and values."""
        actual = model(img)
        assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
        self._assert_channel_means(actual, [0.4286, 0.4276, 0.4054])

    def test_backward(self, model, img):
        """Gradients flow from a scalar loss back into the parameters."""
        loss = model(img).mean()
        loss.backward()

        # Previously this test only checked that backward() did not raise.
        assert torch.isfinite(loss)
        assert any(param.grad is not None for param in model.parameters())

    def test_encode(self, model, img):
        """Encoder maps images to latents of the expected shape and values."""
        actual = model.encode(img)
        assert actual.shape == (BSZ, CH_Z, RES_Z, RES_Z)
        self._assert_channel_means(actual, [0.6150, 0.7959, 0.7178, 0.7011])

    def test_decode(self, model, z):
        """Decoder maps latents to images of the expected shape and values."""
        actual = model.decode(z)
        assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
        self._assert_channel_means(actual, [0.4246, 0.4241, 0.4014])
8 changes: 3 additions & 5 deletions torchtune/models/flux/_convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _convert_key(key: str) -> str:
i += 1
for layer_idx, layer_name in enumerate(["block_1", "attn_1", "block_2"]):
if layer == layer_name:
new_parts.append(str(layer_idx)) # layer name -> idx in the sequence
new_parts.append(str(layer_idx))
if layer_name.startswith("attn"):
_convert_attn_layer(new_parts, parts, i)
else:
Expand All @@ -90,9 +90,8 @@ def _convert_key(key: str) -> str:
elif section == "down":
new_parts.append(parts[i]) # add the down block idx
i += 1
# resnet layers are preceded by "block"
if parts[i] == "block":
new_parts.append("layers") # "block" -> "layers"
new_parts.append("layers")
i += 1
new_parts.append(parts[i]) # add the resnet layer idx
i += 1
Expand All @@ -109,9 +108,8 @@ def _convert_key(key: str) -> str:
# so we need to convert [0, 1, 2, 3] -> [3, 2, 1, 0]
new_parts.append(str(3 - int(parts[i])))
i += 1
# resnet layers are preceded by "block"
if parts[i] == "block":
new_parts.append("layers") # "block" -> "layers"
new_parts.append("layers")
i += 1
new_parts.append(parts[i]) # add the resnet layer idx
i += 1
Expand Down

0 comments on commit e7504b6

Please sign in to comment.