Commit

avoid compiling compute_loss (#511)
Summary:
Pull Request resolved: #511

# Context
As per the discussion in D48361308, it looks like we can remove this, since it is only used for torch dynamo and is not officially supported.

# This diff
Separating this change from the one above, since we also need to clean up some existing UTs that no longer make sense. The change above will be rebased on top of this one.
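
To make the intent concrete, here is a minimal sketch of the before/after behavior (not the upstream AutoUnit code; the DummySketch class and the plain backend argument are assumptions for illustration). Before this diff, compute_loss itself was wrapped with torch.compile, which is why tests could detect dynamo inside it via torch._dynamo.is_compiling(); after this diff, only the module is compiled, in place. The sketch assumes a PyTorch build that exposes the in-place nn.Module.compile() (the diff below guards that call with a try block).

import torch

class DummySketch:
    def __init__(self, module: torch.nn.Module, backend: str = "eager") -> None:
        self.module = module
        # Removed by this diff: compiling the loss computation directly.
        # self.compute_loss = torch.compile(self.compute_loss, backend=backend)
        # Kept by this diff: compiling only the module, in place.
        self.module.compile(backend=backend)

    def compute_loss(self, data):
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.cross_entropy(outputs, targets), outputs

unit = DummySketch(torch.nn.Linear(2, 2))
loss, _ = unit.compute_loss((torch.randn(4, 2), torch.randint(0, 2, (4,))))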

Reviewed By: JKSenthil

Differential Revision: D48581289

fbshipit-source-id: 0fc08554edb2ff726ab5d6249a9b1ce045c5e7d5
galrotem authored and facebook-github-bot committed Aug 23, 2023
1 parent 77e380f commit 43555dd
Showing 2 changed files with 2 additions and 100 deletions.
97 changes: 2 additions & 95 deletions tests/framework/test_auto_unit.py
@@ -266,6 +266,7 @@ def test_compile_state_dict(self) -> None:
"""
device = init_from_env()
my_module = torch.nn.Linear(2, 2, device=device)
self.assertIsNone(my_module._compiled_call_impl)
my_module_state_dict = my_module.state_dict()
auto_unit = DummyAutoUnit(
module=my_module,
@@ -283,94 +284,7 @@ def test_compile_state_dict(self) -> None:
self.assertTrue(
torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
def test_compile_eager(self) -> None:
"""
e2e torch compile test
"""

my_module = torch.nn.Linear(2, 2)

input_dim = 2
dataset_len = 16
batch_size = 2
max_epochs = 1

auto_unit = DummyAutoUnit(
module=my_module,
torch_compile_params=TorchCompileParams(backend="eager"),
)

train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
self.assertFalse(auto_unit._compile_used)
train(auto_unit, train_dl, max_epochs=max_epochs)
self.assertTrue(auto_unit._compile_used)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
def test_compile_train(self) -> None:
"""
e2e torch compile on train
"""

my_module = torch.nn.Linear(2, 2)

input_dim = 2
dataset_len = 16
batch_size = 2
max_epochs = 1

auto_unit = DummyAutoUnit(
module=my_module,
torch_compile_params=TorchCompileParams(backend="inductor"),
)

train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)

self.assertFalse(auto_unit._compile_used)
train(auto_unit, train_dl, max_epochs=max_epochs)
self.assertTrue(auto_unit._compile_used)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
def test_compile_eval(self) -> None:
"""
e2e torch compile on eval
"""

my_module = torch.nn.Linear(2, 2)

input_dim = 2
dataset_len = 16
batch_size = 2

auto_unit = DummyAutoUnit(
module=my_module,
torch_compile_params=TorchCompileParams(backend="inductor"),
)

input_dim = 2
dataset_len = 8
batch_size = 2

eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
self.assertFalse(auto_unit._compile_used)
evaluate(auto_unit, eval_dl)
self.assertTrue(auto_unit._compile_used)
self.assertIsNotNone(auto_unit.module._compiled_call_impl)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
@@ -983,16 +897,9 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:

# pyre-fixme[11]: Annotation `Batch` is not defined as a type.
class DummyAutoUnit(AutoUnit[Batch]):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._compile_used = False

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
if COMPILE_AVAIL:
self._compile_used = torch._dynamo.is_compiling()
inputs, targets = data
outputs = self.module(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
5 changes: 0 additions & 5 deletions torchtnt/framework/auto_unit.py
@@ -470,11 +470,6 @@ def __init__(
)

if torch_compile_params:
# pyre-ignore
self.compute_loss = torch.compile(
self.compute_loss,
**asdict(torch_compile_params),
)
try:
# use in-place compile to avoid altering the state_dict keys
module.compile(**asdict(torch_compile_params))
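
The "use in-place compile to avoid altering the state_dict keys" comment refers to the difference between wrapping a module with torch.compile(), which returns an OptimizedModule whose state_dict keys are typically prefixed with "_orig_mod.", and the in-place nn.Module.compile(), which keeps the original module object and its key names. A rough sketch, again assuming a PyTorch build that exposes nn.Module.compile():

import torch

wrapped = torch.compile(torch.nn.Linear(2, 2), backend="eager")
# Wrapper object; keys are commonly prefixed, e.g. "_orig_mod.weight".
print(sorted(wrapped.state_dict().keys()))

in_place = torch.nn.Linear(2, 2)
in_place.compile(backend="eager")  # compiles forward in place; same object, same keys
print(sorted(in_place.state_dict().keys()))  # ["bias", "weight"]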
