fix torch.compile with no cl.
0x00-pl committed Nov 4, 2024
1 parent 5e6f24a commit 36d2af6
Showing 4 changed files with 53 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -160,3 +160,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
+
+debug.py
14 changes: 14 additions & 0 deletions tests/module_pool/simple_nn.py
@@ -0,0 +1,14 @@
+import torch
+import torch.nn as nn
+
+
+class SimpleNN(nn.Module):
+    def __init__(self):
+        super(SimpleNN, self).__init__()
+        self.fc1 = nn.Linear(10, 5)
+        self.fc2 = nn.Linear(5, 1)
+
+    def forward(self, x):
+        x = torch.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
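
For orientation (not part of the commit): SimpleNN is a two-layer MLP mapping 10 input features to a single output, now shared by both test files. A minimal usage sketch:

import torch

from tests.module_pool.simple_nn import SimpleNN

model = SimpleNN()
x = torch.randn(4, 10)  # batch of 4 samples, 10 features each
y = model(x)            # shape: (4, 1)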
12 changes: 1 addition & 11 deletions tests/test_torch.py
@@ -2,17 +2,7 @@
 import torch.nn as nn
 import torch.optim as optim
 
-
-class SimpleNN(nn.Module):
-    def __init__(self):
-        super(SimpleNN, self).__init__()
-        self.fc1 = nn.Linear(10, 5)
-        self.fc2 = nn.Linear(5, 1)
-
-    def forward(self, x):
-        x = torch.relu(self.fc1(x))
-        x = self.fc2(x)
-        return x
+from tests.module_pool.simple_nn import SimpleNN
 
 
 def test_forward_pass():
36 changes: 36 additions & 0 deletions tests/test_torch_compile.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from tests.module_pool.simple_nn import SimpleNN
+
+
+def test_torch_compile_forward():
+    model = SimpleNN()
+    compiled_model = torch.compile(model, backend="aot_eager")
+
+    input_data = torch.randn(1, 10)
+    expected_output = model(input_data)
+    actual_output = compiled_model(input_data)
+    assert torch.allclose(expected_output, actual_output), "Output mismatch between compiled and original model"
+
+
+def test_torch_compile_backward():
+    model = SimpleNN()
+    compiled_model = torch.compile(model, backend="aot_eager")
+
+    criterion = nn.MSELoss()
+    optimizer = optim.SGD(compiled_model.parameters(), lr=0.01)
+
+    input_data = torch.randn(1, 10)
+    target = torch.tensor([1.0])
+
+    output = compiled_model(input_data)
+    loss = criterion(output, target)
+
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+    for param in compiled_model.parameters():
+        assert param.grad is not None, "Gradient not computed for parameter in compiled model"
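
A note on the backend choice (my reading of the commit title, not stated in the diff): torch.compile defaults to the inductor backend, which generates and compiles C++/Triton code and therefore needs a host compiler -- cl.exe on Windows, which is likely the "no cl" the title refers to. The aot_eager backend used here still exercises graph capture and AOTAutograd, but runs the captured ops eagerly, so no compiler toolchain is required. A minimal sketch of the difference:

import torch

from tests.module_pool.simple_nn import SimpleNN

model = SimpleNN()

# Default backend is "inductor": its codegen step needs a working C++
# compiler (cl.exe on Windows), so this can fail on machines without one.
# inductor_model = torch.compile(model)

# "aot_eager" traces the graph through AOTAutograd but executes the ops
# eagerly, so it works without any compiler installed.
eager_compiled = torch.compile(model, backend="aot_eager")
out = eager_compiled(torch.randn(1, 10))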
