Skip to content

Commit

Permalink
keep up-to-date with recent pytorch (#64)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 4, 2024
1 parent a65f043 commit 6ca9d8c
Show file tree
Hide file tree
Showing 464 changed files with 22,734 additions and 3,829 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"] # Python 3.11 bug waits for fix https://github.com/thuml/depyf/actions/runs/7004325219/job/19051829613 .
python-version: ["3.9", "3.10", "3.11"] # Python 3.11 bug waits for fix https://github.com/thuml/depyf/actions/runs/7004325219/job/19051829613 .

steps:
- uses: actions/checkout@v3
Expand Down
29 changes: 20 additions & 9 deletions depyf/explain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,27 @@ def __init__(self, original_code, module, cache):

cpp_guard = False

# starting from https://github.com/pytorch/pytorch/pull/138896 ,
# pytorch uses `guard_manager` instead of `check_fn` to store the
# guards
attr_name = "guard_manager" if hasattr(cache, "guard_manager") else "check_fn"

guard_manager = getattr(cache, attr_name)

try:
from torch._dynamo.guards import GuardManager
cpp_guard = isinstance(cache.check_fn, GuardManager)
klass = getattr(torch._dynamo.guards, "GuardManagerWrapper", None) or \
getattr(torch._C._dynamo.guards, "GuardManager", None)
assert klass is not None
cpp_guard = isinstance(guard_manager, klass)
except Exception:
pass

if not cpp_guard:
guard = cache.check_fn.code_parts
freevar_names = cache.check_fn.__code__.co_freevars
freevar_values = [x.cell_contents for x in cache.check_fn.__closure__]
# for old version of pytorch,
# `guard_manager` is a plain python function
guard = guard_manager.code_parts
freevar_names = guard_manager.__code__.co_freevars
freevar_values = [x.cell_contents for x in guard_manager.__closure__]
else:
# keep the logic synced with
# https://github.com/pytorch/pytorch/blob/7b6b10417d8616ebd7a42b06528c5c2b2fded55a/torch/_dynamo/guards.py#L262
Expand All @@ -118,14 +129,14 @@ def visit(root, ans):
for child in root.get_child_managers():
visit(child, ans)
guard = []
root = cache.check_fn.root
root = guard_manager.root
visit(root, guard)
if cache.check_fn.closure_vars is None:
if guard_manager.closure_vars is None:
freevar_names = tuple()
freevar_values = []
else:
freevar_names = tuple(cache.check_fn.closure_vars.keys())
freevar_values = list(cache.check_fn.closure_vars.values())
freevar_names = tuple(guard_manager.closure_vars.keys())
freevar_values = list(guard_manager.closure_vars.values())

self.guard = guard
self.freevars = {name: value for name, value in zip(freevar_names, freevar_values)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@



def forward(self, primals_1: "f32[10]", div: "f32[10]", tangents_1: "f32[10]"):
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
def forward(self, primals_1: "f32[10]", tangents_1: "f32[10]"):
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
neg: "f32[10]" = torch.ops.aten.neg.default(tangents_1)
abs_1: "f32[10]" = torch.ops.aten.abs.default(primals_1)
add: "f32[10]" = torch.ops.aten.add.Tensor(abs_1, 1); abs_1 = None
div_2: "f32[10]" = torch.ops.aten.div.Tensor(div, add); div = None
div_1: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)
div_2: "f32[10]" = torch.ops.aten.div.Tensor(div_1, add); div_1 = None
mul: "f32[10]" = torch.ops.aten.mul.Tensor(neg, div_2); neg = div_2 = None
div_3: "f32[10]" = torch.ops.aten.div.Tensor(tangents_1, add); tangents_1 = add = None
sgn: "f32[10]" = torch.ops.aten.sgn.default(primals_1); primals_1 = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(mul, sgn); mul = sgn = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
add_1: "f32[10]" = torch.ops.aten.add.Tensor(div_3, mul_1); div_3 = mul_1 = None
return [add_1, None]
return (add_1, None)

Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ def forward(self, L_a_: "f32[10]", L_b_: "f32[10]"):
l_a_ = L_a_
l_b_ = L_b_

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10]" = torch.abs(l_a_)
add: "f32[10]" = abs_1 + 1; abs_1 = None
x: "f32[10]" = l_a_ / add; l_a_ = add = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:5 in toy_function, code: if b.sum() < 0:
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:5 in toy_function, code: if b.sum() < 0:
sum_1: "f32[]" = l_b_.sum(); l_b_ = None
lt: "b8[]" = sum_1 < 0; sum_1 = None
return (x, lt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@


def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10]" = torch.ops.aten.abs.default(primals_1)
add: "f32[10]" = torch.ops.aten.add.Tensor(abs_1, 1); abs_1 = None
div: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add); add = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:5 in toy_function, code: if b.sum() < 0:
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:5 in toy_function, code: if b.sum() < 0:
sum_1: "f32[]" = torch.ops.aten.sum.default(primals_2); primals_2 = None
lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0); sum_1 = None
return [div, lt, primals_1, div]
return (div, lt, primals_1)

Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ def forward(self, primals, tangents):
primals_1: "f32[10]"; primals_2: "f32[10]"; tangents_1: "f32[10]";

primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10]" = torch.ops.aten.abs.default(primals_1)
add: "f32[10]" = torch.ops.aten.add.Tensor(abs_1, 1); abs_1 = None
div: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:5 in toy_function, code: if b.sum() < 0:
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:5 in toy_function, code: if b.sum() < 0:
sum_1: "f32[]" = torch.ops.aten.sum.default(primals_2); primals_2 = None
lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0); sum_1 = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
neg: "f32[10]" = torch.ops.aten.neg.default(tangents_1)
div_1: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)
div_2: "f32[10]" = torch.ops.aten.div.Tensor(div_1, add); div_1 = None
Expand All @@ -24,7 +24,7 @@ def forward(self, primals, tangents):
sgn: "f32[10]" = torch.ops.aten.sgn.default(primals_1); primals_1 = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(mul, sgn); mul = sgn = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
add_1: "f32[10]" = torch.ops.aten.add.Tensor(div_3, mul_1); div_3 = mul_1 = None
return pytree.tree_unflatten([div, lt, add_1, None], self._out_spec)

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations



def forward(self, L_a_: "f32[10]", L_b_: "f32[10]"):
l_a_ = L_a_
l_b_ = L_b_

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10]" = torch.abs(l_a_)
add: "f32[10]" = abs_1 + 1; abs_1 = None
x: "f32[10]" = l_a_ / add; l_a_ = add = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:5 in toy_function, code: if b.sum() < 0:
sum_1: "f32[]" = l_b_.sum(); l_b_ = None
lt: "b8[]" = sum_1 < 0; sum_1 = None
return (x, lt)

Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations



def forward(self, primals, tangents):
primals_1: "f32[10]"; primals_2: "f32[10]"; tangents_1: "f32[10]";

primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10]" = torch.ops.aten.abs.default(primals_1)
add: "f32[10]" = torch.ops.aten.add.Tensor(abs_1, 1); abs_1 = None
div: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:5 in toy_function, code: if b.sum() < 0:
sum_1: "f32[]" = torch.ops.aten.sum.default(primals_2); primals_2 = None
lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0); sum_1 = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
neg: "f32[10]" = torch.ops.aten.neg.default(tangents_1)
div_1: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)
div_2: "f32[10]" = torch.ops.aten.div.Tensor(div_1, add); div_1 = None
mul: "f32[10]" = torch.ops.aten.mul.Tensor(neg, div_2); neg = div_2 = None
div_3: "f32[10]" = torch.ops.aten.div.Tensor(tangents_1, add); tangents_1 = add = None
sgn: "f32[10]" = torch.ops.aten.sgn.default(primals_1); primals_1 = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(mul, sgn); mul = sgn = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:4 in toy_function, code: x = a / (torch.abs(a) + 1)
add_1: "f32[10]" = torch.ops.aten.add.Tensor(div_3, mul_1); div_3 = mul_1 = None
return pytree.tree_unflatten([div, lt, add_1, None], self._out_spec)

Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@



def forward(self, primals_1: "f32[8]", primals_2: "f32[8]", tangents_1: "f32[8]"):
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_2: "f32[8]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); primals_2 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[s0]", primals_3: "f32[s0]", tangents_1: "f32[s0]"):
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_4: "f32[s0]" = torch.ops.aten.mul.Tensor(tangents_1, primals_3); primals_3 = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_1, -1); primals_1 = None
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul: "f32[s0]" = torch.ops.aten.mul.Tensor(primals_2, -1); primals_2 = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_3: "f32[8]" = torch.ops.aten.mul.Tensor(tangents_1, mul); tangents_1 = mul = None
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_5: "f32[s0]" = torch.ops.aten.mul.Tensor(tangents_1, mul); tangents_1 = mul = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul_4: "f32[8]" = torch.ops.aten.mul.Tensor(mul_2, -1); mul_2 = None
return [mul_4, mul_3]
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul_6: "f32[s0]" = torch.ops.aten.mul.Tensor(mul_4, -1); mul_4 = None
return (None, mul_6, mul_5)

Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@



def forward(self, L_b_: "f32[8]", L_x_: "f32[8]"):
def forward(self, s0: "Sym(s0)", L_b_: "f32[s0]", L_x_: "f32[s0]"):
l_b_ = L_b_
l_x_ = L_x_

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
b: "f32[8]" = l_b_ * -1; l_b_ = None
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
b: "f32[s0]" = l_b_ * -1; l_b_ = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_1: "f32[8]" = l_x_ * b; l_x_ = b = None
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_1: "f32[s0]" = l_x_ * b; l_x_ = b = None
return (mul_1,)

Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@



def forward(self, primals_1: "f32[8]", primals_2: "f32[8]"):
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_1, -1)
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[s0]", primals_3: "f32[s0]"):
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul: "f32[s0]" = torch.ops.aten.mul.Tensor(primals_2, -1)

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(primals_2, mul); mul = None
return [mul_1, primals_1, primals_2]
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_2: "f32[s0]" = torch.ops.aten.mul.Tensor(primals_3, mul); mul = None
return (mul_2, primals_2, primals_3, primals_1)

Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@


def forward(self, primals, tangents):
primals_1: "f32[8]"; primals_2: "f32[8]"; tangents_1: "f32[8]";
primals_1: "Sym(s0)"; primals_2: "f32[s0]"; primals_3: "f32[s0]"; tangents_1: "f32[s0]";

primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_1, -1); primals_1 = None
primals_1, primals_2, primals_3, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul: "f32[s0]" = torch.ops.aten.mul.Tensor(primals_2, -1); primals_2 = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(primals_2, mul)
mul_2: "f32[8]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); primals_2 = None
mul_3: "f32[8]" = torch.ops.aten.mul.Tensor(tangents_1, mul); tangents_1 = mul = None
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_2: "f32[s0]" = torch.ops.aten.mul.Tensor(primals_3, mul)
mul_4: "f32[s0]" = torch.ops.aten.mul.Tensor(tangents_1, primals_3); primals_3 = None
mul_5: "f32[s0]" = torch.ops.aten.mul.Tensor(tangents_1, mul); tangents_1 = mul = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul_4: "f32[8]" = torch.ops.aten.mul.Tensor(mul_2, -1); mul_2 = None
return pytree.tree_unflatten([mul_1, mul_4, mul_3], self._out_spec)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul_6: "f32[s0]" = torch.ops.aten.mul.Tensor(mul_4, -1); mul_4 = None
return pytree.tree_unflatten([mul_2, None, mul_6, mul_5], self._out_spec)

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations



def forward(self, s0: "Sym(s0)", L_b_: "f32[s0]", s1: "Sym(s0)", L_x_: "f32[s0]"):
l_b_ = L_b_
l_x_ = L_x_

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
b: "f32[s0]" = l_b_ * -1; l_b_ = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_1: "f32[s0]" = l_x_ * b; l_x_ = b = None
return (mul_1,)

Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations



def forward(self, primals, tangents):
primals_1: "Sym(s0)"; primals_2: "f32[s0]"; primals_3: "f32[s0]"; tangents_1: "f32[s0]";

primals_1, primals_2, primals_3, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul: "f32[s0]" = torch.ops.aten.mul.Tensor(primals_2, -1); primals_2 = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:7 in torch_dynamo_resume_in_toy_function_at_5, code: return x * b
mul_2: "f32[s0]" = torch.ops.aten.mul.Tensor(primals_3, mul)
mul_4: "f32[s0]" = torch.ops.aten.mul.Tensor(tangents_1, primals_3); primals_3 = None
mul_5: "f32[s0]" = torch.ops.aten.mul.Tensor(tangents_1, mul); tangents_1 = mul = None

# File: /Users/youkaichao/data/DeepLearning/depyf/tests/test_pytorch/test_pytorch.py:6 in torch_dynamo_resume_in_toy_function_at_5, code: b = b * -1
mul_6: "f32[s0]" = torch.ops.aten.mul.Tensor(mul_4, -1); mul_4 = None
return pytree.tree_unflatten([mul_2, None, mul_6, mul_5], self._out_spec)

Loading

0 comments on commit 6ca9d8c

Please sign in to comment.