Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/tpu/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ def extract_compiled_index(s):

# Check all the compilations are as expected
compiled_fns = sorted(glob.glob(
os.path.join(temp_dir, "__compiled_fn*Captured*.py")),
os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")),
Comment on lines 69 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider adding a comment explaining why __compiled_fn*Forward_graph*.py is now being checked instead of __compiled_fn*Captured*.py. This will help future developers understand the change and its context. It would be useful to add a brief explanation of what the Forward_graph represents.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks gemini, added a comment

key=lambda s: extract_compiled_index(s))

for i, compiled_fn in enumerate(compiled_fns):
print("{} file: {}".format(i + 1, compiled_fn))

breakpoint()
# The first compilation should not have any kv_caches
with open(compiled_fns[0]) as f:
content = f.read()
Expand Down