Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/tpu/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ def extract_compiled_index(s):
numbers = [int(part) for part in parts if part.isdigit()]
return numbers[0]

# Check all the compilations are as expected
# Check all the compilations are as expected. The dump files include the
# captured graph for the forward function of the nn.Module.
compiled_fns = sorted(glob.glob(
os.path.join(temp_dir, "__compiled_fn*Captured*.py")),
os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")),
Comment on lines 69 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider adding a comment explaining why __compiled_fn*Forward_graph*.py is now being checked instead of __compiled_fn*Captured*.py. This will help future developers understand the change and its context. It would be useful to add a brief explanation of what the Forward_graph represents.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks gemini, added a comment

key=lambda s: extract_compiled_index(s))

for i, compiled_fn in enumerate(compiled_fns):
Expand Down