- 
                Notifications
    You must be signed in to change notification settings 
- Fork 368
addressing cat empty tensor case.Fixes gpt2 data distributed example #3866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
bee0f1d    to
    88659a1      
    Compare
  
    6e386c1    to
    e6fc22b      
    Compare
  
    | for i, each_input in enumerate(input): | ||
| if isinstance(each_input, torch.Tensor) and each_input.numel() == 0: | ||
| logger.warning( | ||
| f"Warning: empty tensor in cat input {i}, replacing with zeros" | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you make this warning much more specific? Print information like the current node, if you can where in the graph it comes from etc. Because users will not understand what you mean by this. Also where is the replacing with zeros?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also if this is caught by the validator then should this be an error? Will conversion fail or can we just ignore it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing out the error. I was earlier replacing with zeros, but later changed to continue since replacing with zeros is not required. I will change the warning comment.
The difference between this and the validator is that, if the empty tensor is a torch.Tensor, we can handle it in the converter.
Whereas if the empty tensor comes as an ITensor input to the converter, TensorRT complains. (I was trying to implement it earlier via replacing it with zeros, but that still leads to the error [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed. To point the difference,
This will pass
def test_cat_with_empty_tensor(self, _, dim):
       # Handle empty tensor in concat
       class Cat(nn.Module):
           def forward(self, x):
               y = torch.empty(0, 2, 3, device="cuda")
               return torch.ops.aten.cat.default((x, y), dim)
       inputs = [
           torch.randn(1, 2, 3, device="cuda"),
       ]
       self.run_test(Cat(), inputs)
This will fail
 def test_cat_with_empty_tensor(self, _, dim):
        # Handle empty tensor in concat
        class Cat(nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.cat.default((x, y), dim)
        inputs = [
            torch.randn(1, 2, 3, device="cuda"),
            y = torch.empty(0, 2, 3, device="cuda")
        ]
        self.run_test(Cat(), inputs)
| return input_tensors, dim | ||
|  | ||
|  | ||
| def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont really understand this condition. So if we have a TRT ITensor that has a 0 in any dimension then we should break the graph? I dont think at validation time any of these ITensors will be available. Since validation is run prior to paritioning
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we be checking for empty PyTorch tensors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes ideally. The validation would be based on the ITensor shape. Yes should use the meta data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But then this won't distinguish between ITensor and torch Tensor case.
Fixes #3865