-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Take the interpreter holding references to the arguments into account #1032
Conversation
When rewriting the test, I faced #1043. We previously did not test the backward for split. |
8c6306d
to
ec16a29
Compare
I tried to eliminate the "golden" values for tests, but this seems impossible when nvFuser is involved. For example, on the following test case, nvFuser returns as def bar(a, b): # [4] [2,2]
a_1 = torch.unsqueeze(a, 0) # [1,4]
a_2 = torch.unsqueeze(a_1, 1) # [1,1,4]
a_3 = a_2.expand(2, 3, 4) # [2,3,4]
b_1 = torch.reshape(b, (4,)) # [4]
b_2 = torch.unsqueeze(b_1, 0) # [1,4]
b_3 = torch.unsqueeze(b_2, 1) # [1,1,4]
b_4 = b_3.expand(2, 3, 4) # [2,3,4]
result1 = a_2 + b_3
result2 = b_4 + a_3
return result1, result2 generated tracedef augmented_forward_fn(a, b):
# a: "cuda:0 f32[4]"
# b: "cuda:0 f32[2, 2]"
[t14, t15] = nvFusion0(a, b)
# t0 = prims.broadcast_in_dim(a, [1, 4], [1]) # t0: "cuda:0 f32[1, 4]"
# t1 = prims.broadcast_in_dim(t0, [1, 1, 4], [0, 2]) # t1: "cuda:0 f32[1, 1, 4]"
# t5 = prims.broadcast_in_dim(t1, (2, 3, 4), (0, 1, 2)) # t5: "cuda:0 f32[2, 3, 4]"
# t7 = prims.reshape(b, (4,)) # t7: "cuda:0 f32[4]"
# t8 = prims.broadcast_in_dim(t7, [1, 4], [1]) # t8: "cuda:0 f32[1, 4]"
# t9 = prims.broadcast_in_dim(t8, [1, 1, 4], [0, 2]) # t9: "cuda:0 f32[1, 1, 4]"
# t13 = prims.broadcast_in_dim(t9, (2, 3, 4), (0, 1, 2)) # t13: "cuda:0 f32[2, 3, 4]"
# t14 = prims.add(t1, t9) # t14: "cuda:0 f32[1, 1, 4]"
# t15 = prims.add(t13, t5) # t15: "cuda:0 f32[2, 3, 4]"
return {'output': (t14, t15), 'flat_args': [a, b], 'flat_output': (t14, t15)}, ((), ())
(144, {'unpack_trivial a': 16, 'unpack_trivial b': 16, 'nvFusion0 t14, t15': 112}) Can we skip |
Yes, let's skip nvFuser with a comment explaining why. In the future, we need to update |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Python has a really interesting behavior with respect to holding references to input variables for the duration of the function call. In order to make get_alloc_memory
more accurately represent Python's memory deallocation we need to take into account the type of containers in which tensors are stored, they could be immutable and then Python would hold a reference to an immutable object and not let us modify its content so tensors will never be freed during the function call, but containers could be mutable and then it's possible to free the tensors if the container itself doesn't hold a reference to tensor anymore. We should take this into consideration in a follow-up.
I'm disappointed that eager doesn't do so by default. 😢 |
As Ivan pointed out, the previous code did not consider def f(x, y, z):
return x * y * z
{'unpack_trivial ': 0, 'unpack_sequence ': 0, 'clear_mutable_collection ': 0, 'unpack_sequence t2': 16, 'unpack_sequence t0, x, y, z': 64, 'mul t8': 16, 'mul t9': 16, 'mul t10': 16, 'mul t11': 16, 'python_del t8': -16}
peak = 144, after = 128 Actual memory usage: peak = 96, after = 64 I slightly changed Now (commit 7c5318c) {'argument t0': 16, 'argument x': 16, 'argument y': 16, 'argument z': 16, 'argument t2': 16, 'mul t8': 16, 'python_del z': -16, 'mul t9': 16, 'python_del t0': -16, 'mul t10': 16, 'python_del y': -16, 'mul t11': 16, 'python_del x': -16, 'python_del t8': -16}
peak = 96, after = 64 |
@t-vi This is ready for review. Thank you! |
Unless I missed it, we would want to add to add the comment. |
Thank you for pointing that out! I added the comment. The current test failure is apparently due to a breaking change in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @shino16 @IvanYashchuk @jjsjann123 @crcrpar
Fixes #1029. The tests in
test_examine_memory.py
uses some magic numbers made with an incorrect assumption, and I am yet to update them.