addressing cat empty tensor case. Fixes gpt2 data distributed example #3866
```diff
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional, Sequence, Union

 import numpy as np
@@ -15,6 +16,8 @@
     set_layer_name,
 )

+logger = logging.getLogger(__name__)
+

 def cat(
     ctx: ConversionContext,
@@ -27,6 +30,13 @@ def cat(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     trt_inputs = []
     for i, each_input in enumerate(input):
+        if isinstance(each_input, torch.Tensor) and each_input.numel() == 0:
+            logger.warning(
+                f"Warning: empty tensor in cat input {i}, replacing with zeros"
+            )
+            # ITensor with same condition leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
+            # hence the validator
+            continue
         if not isinstance(each_input, TRTTensor):
             each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
         if cast_dtype:
```
Review thread on the logger.warning call:

Reviewer: Can you make this warning much more specific? Print information such as the current node and, if possible, where in the graph it comes from, because users will not understand what is meant by this. Also, where is the "replacing with zeros"?

Reviewer: Also, if this is caught by the validator, should this be an error? Will conversion fail, or can we just ignore it?

Author: Thanks for pointing out the error. I was replacing the empty tensor with zeros earlier, but later changed it to skipping the input (`continue`). The difference between this and the validator is that if the empty tensor is a torch.Tensor, we can handle it in the converter, whereas if the empty tensor comes in as an ITensor input to the converter, TensorRT complains. (I tried implementing it earlier by replacing the empty tensor with zeros, but that still leads to the [RemoveDeadLayers] error noted in the inline comment.)

This will pass:

This will fail:
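The passing and failing snippets were not captured above. As a rough sketch of the two situations being described, assuming the dynamo frontend (`CatEmptyConstant`, `CatEmptyInput`, and the inputs below are hypothetical names, not the author's original code):

```python
import torch
import torch_tensorrt


class CatEmptyConstant(torch.nn.Module):
    # The empty tensor is a registered buffer, so after freezing it should
    # reach the cat converter as a plain torch.Tensor and be skipped by the
    # new check ("this will pass").
    def __init__(self):
        super().__init__()
        self.register_buffer("empty", torch.empty(0))

    def forward(self, x):
        return torch.cat([x, self.empty], dim=0)


class CatEmptyInput(torch.nn.Module):
    # The empty tensor arrives as a graph input, so the converter only sees
    # an ITensor and TensorRT emits the [RemoveDeadLayers] complaint
    # ("this will fail"), which is what the validator is meant to catch.
    def forward(self, x, y):
        return torch.cat([x, y], dim=0)


x = torch.randn(4)
y = torch.empty(0)

# Expected to compile cleanly with the converter-side skip.
torch_tensorrt.compile(CatEmptyConstant(), ir="dynamo", inputs=[x])

# Expected to be rejected by the validator (falling back to PyTorch)
# rather than erroring inside TensorRT.
torch_tensorrt.compile(CatEmptyInput(), ir="dynamo", inputs=[x, y])
```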
A second review thread discusses the validator condition:

I don't really understand this condition. If we have a TRT ITensor that has a 0 in any dimension, should we break the graph? I don't think any of these ITensors will be available at validation time, since validation is run prior to partitioning.

Should we be checking for empty PyTorch tensors?
Yes, ideally. The validation would be based on the ITensor shape. Yes, it should use the metadata.
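A minimal sketch of what such a metadata-based check could look like, assuming the validator receives the FX node before partitioning (the name `cat_validator` and the exact registration hook are assumptions, not this PR's actual code):

```python
import torch
from torch.fx.node import Node


def cat_validator(node: Node) -> bool:
    # Reject aten.cat nodes that have an empty input, using only the
    # FakeTensor metadata recorded during tracing, since no ITensors
    # exist before partitioning.
    tensors = node.args[0] if node.args else []
    for t in tensors:
        if isinstance(t, Node):
            fake = t.meta.get("val")  # FakeTensor, or None if missing
            if isinstance(fake, torch.Tensor) and fake.numel() == 0:
                # Returning False keeps the node in PyTorch instead of
                # letting it fail inside TensorRT.
                return False
    return True
```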
But then this won't distinguish between the ITensor and torch.Tensor cases.
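If that distinction is needed at validation time, one rough heuristic, offered only as an assumption about how constants are frozen during lowering, is the FX node kind: `get_attr` nodes are materialized as plain torch.Tensor constants during conversion, while placeholders and computed intermediates surface as ITensors.

```python
from torch.fx.node import Node


def likely_becomes_itensor(t: Node) -> bool:
    # get_attr nodes (frozen weights/buffers) become torch.Tensor constants
    # in the converter; graph inputs and computed values become ITensors.
    return t.op != "get_attr"
```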