@@ -53,7 +53,5 @@

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
-    print("before \n")
    result = pipe(prompt).images[0]
-    print("after ")
    result.save(f"result_{distributed_state.process_index}.png")
54 changes: 50 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2,7 +2,7 @@

import logging
import operator
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -217,21 +217,67 @@ def aten_ops_native_group_norm(
)


-@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
def parse_cat_args(
    args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Tuple[List[Any], int]:
    """
    Process inputs for torch.ops.aten.cat.default.

    Handles these valid patterns:
    1. args = ((t1, t2, ...), dim)
    2. args = ((t1, t2, ...),), with an optional dim passed in kwargs

    Returns:
        (input_tensors, dim)
        input_tensors: list of tensor arguments
        dim: integer concatenation dimension (default 0)
    """
    if len(args) > 1 and isinstance(args[0], (list, tuple)):
        input_tensors = list(args[0])
        dim = args_bounds_check(args, 1, 0)
    else:
        # If the single arg is itself a tuple/list, unwrap it
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            input_tensors = list(args[0])
        else:
            input_tensors = list(args)

        dim = kwargs.get("dim", 0)

    return input_tensors, dim
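For reference, a minimal sketch (not part of the PR) of the argument patterns the helper handles, with plain strings standing in for FX tensor arguments:

tensors = ("t1", "t2")

# Pattern 1: dim passed positionally -> args = ((t1, t2), 1)
inputs, dim = parse_cat_args((tensors, 1), {})
assert inputs == ["t1", "t2"] and dim == 1

# Pattern 2: dim passed as a keyword -> args = ((t1, t2),), kwargs = {"dim": 1}
inputs, dim = parse_cat_args((tensors,), {"dim": 1})
assert inputs == ["t1", "t2"] and dim == 1

# No dim anywhere -> defaults to 0
inputs, dim = parse_cat_args((tensors,), {})
assert dim == 0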


def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
    # An empty tensor in the cat inputs, when it arrives as an ITensor, leads to
    # "[RemoveDeadLayers] Input Tensor y is unused or used only at compile-time,
    # but is not being removed."
    inputs, _ = parse_cat_args(node.args, node.kwargs)
    for each_input in inputs:
        if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
            return False
    return True

Review thread on cat_validator:

Collaborator: I don't really understand this condition. So if we have a TRT ITensor that has a 0 in any dimension, then we should break the graph? I don't think any of these ITensors will be available at validation time, since validation is run prior to partitioning.

Collaborator: Should we be checking for empty PyTorch tensors?

Collaborator Author (@apbose, Oct 20, 2025): Yes, ideally. The validation would be based on the ITensor shape. Yes, we should use the metadata.

Collaborator Author (@apbose): But then this won't distinguish between the ITensor and torch.Tensor cases.
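A hedged sketch of the reviewers' suggestion above — validating against the FX node's shape metadata, which is available before partitioning, rather than ITensor shapes. It assumes input nodes carry FakeTensor values under node.meta["val"] (the usual dynamo convention); the helper name is illustrative, not part of the PR:

def _cat_has_empty_input(node: Node) -> bool:
    # aten.cat.default's first argument is the sequence of input nodes
    for arg in node.args[0]:
        meta_val = getattr(arg, "meta", {}).get("val", None)
        if meta_val is not None and any(s == 0 for s in meta_val.shape):
            return True
    return False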


@dynamo_tensorrt_converter(
    torch.ops.aten.cat.default,
    supports_dynamic_shapes=True,
    capability_validator=cat_validator,
)
def aten_ops_cat(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    inputs, dim = parse_cat_args(args, kwargs)
    return impl.cat.cat(
        ctx,
        target,
        SourceIR.ATEN,
        name,
-       input=args[0],
-       dim=args_bounds_check(args, 1, 0),
+       input=inputs,
+       dim=dim,
    )


10 changes: 10 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,3 +1,4 @@
import logging
from typing import Optional, Sequence, Union

import numpy as np
@@ -15,6 +16,8 @@
    set_layer_name,
)

logger = logging.getLogger(__name__)


def cat(
    ctx: ConversionContext,
@@ -27,6 +30,13 @@ def cat(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    trt_inputs = []
    for i, each_input in enumerate(input):
        if isinstance(each_input, torch.Tensor) and each_input.numel() == 0:
            logger.warning(
                f"Warning: empty tensor in cat input {i}, replacing with zeros"
            )
            # An ITensor with the same condition leads to "[RemoveDeadLayers]
            # Input Tensor y is unused or used only at compile-time, but is not
            # being removed." -- hence the validator
            continue
        if not isinstance(each_input, TRTTensor):
            each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
        if cast_dtype:

Review thread on the warning:

Collaborator: Can you make this warning much more specific? Print information like the current node and, if you can, where in the graph it comes from, because users will not understand what you mean by this. Also, where is the replacing with zeros?

Collaborator: Also, if this is caught by the validator, should this be an error? Will conversion fail, or can we just ignore it?

Collaborator Author (@apbose, Oct 20, 2025): Thanks for pointing out the error. I was earlier replacing with zeros, but later changed to continue, since replacing with zeros is not required. I will change the warning comment.

The difference between this and the validator is that if the empty tensor is a torch.Tensor, we can handle it in the converter, whereas if the empty tensor comes into the converter as an ITensor input, TensorRT complains. (I was trying to implement it earlier by replacing it with zeros, but that still leads to the error "[RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.") To point out the difference:

This will pass:

    def test_cat_with_empty_tensor(self, _, dim):
        # Handle empty tensor in concat
        class Cat(nn.Module):
            def forward(self, x):
                y = torch.empty(0, 2, 3, device="cuda")
                return torch.ops.aten.cat.default((x, y), dim)

        inputs = [
            torch.randn(1, 2, 3, device="cuda"),
        ]
        self.run_test(Cat(), inputs)

This will fail:

    def test_cat_with_empty_tensor(self, _, dim):
        # Handle empty tensor in concat
        class Cat(nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.cat.default((x, y), dim)

        inputs = [
            torch.randn(1, 2, 3, device="cuda"),
            torch.empty(0, 2, 3, device="cuda"),
        ]
        self.run_test(Cat(), inputs)
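As a side reference for the torch.Tensor path above: skipping an input that is empty along the concat dimension is a no-op in plain PyTorch, which is why the continue is safe there. A minimal check, outside TensorRT:

import torch

x = torch.randn(1, 2, 3)
y = torch.empty(0, 2, 3)  # zero-size along dim 0

out = torch.cat((x, y), dim=0)  # concatenating the empty tensor changes nothing
assert torch.equal(out, x)
assert torch.equal(torch.cat((x,), dim=0), out)  # same result with it dropped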
71 changes: 71 additions & 0 deletions tests/py/dynamo/conversion/test_cat_aten.py
@@ -25,6 +25,77 @@ def forward(self, x, y, z):
            inputs,
        )

    @parameterized.expand(
        [
            ("pos", 1),
            ("neg", -2),
        ]
    )
    def test_cat_dim_in_kwargs(self, _, dim):
        class Cat(nn.Module):
            def forward(self, x, y, z):
                return torch.ops.aten.cat.default((x, y, z), dim=dim)

        inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
        self.run_test(
            Cat(),
            inputs,
        )

    @parameterized.expand(
        [
            ("pos", 0),
            ("neg", -3),
        ]
    )
    def test_cat_with_scalar_inputs(self, _, dim):
        # Ensure scalar tensor wrapping works
        class Cat(nn.Module):
            def forward(self, x, y):
                # y simulates a broadcast scalar, x is an ordinary tensor
                return torch.ops.aten.cat.default((x, y), dim)

        x = torch.randn(1, 2, 3, device="cuda")
        y = torch.ones_like(x) * 5.0  # simulate scalar broadcast
        inputs = [x, y]
        self.run_test(Cat(), inputs)

    @parameterized.expand(
        [
            ("pos", 0),
            ("neg", -3),
        ]
    )
    def test_cat_with_empty_tensor(self, _, dim):
        # Handle empty tensor in concat
        class Cat(nn.Module):
            def forward(self, x):
                y = torch.empty(0, 2, 3, device="cuda")
                return torch.ops.aten.cat.default((x, y), dim)

        inputs = [
            torch.randn(1, 2, 3, device="cuda"),
        ]
        self.run_test(Cat(), inputs)

    @parameterized.expand(
        [
            ("pos", 2),
            ("neg", -1),
        ]
    )
    def test_cat_with_different_dtypes(self, _, dim):
        # Check the dtype promotion path in concat
        class Cat(nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.cat.default((x, y), dim)

        inputs = [
            torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
            torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
        ]
        self.run_test(Cat(), inputs)
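As a plain-PyTorch reference for the promotion path this test exercises (assuming a recent PyTorch, where torch.cat type-promotes mixed inputs):

import torch

x = torch.ones(1, 2, 3, dtype=torch.float32)
y = torch.ones(1, 2, 3, dtype=torch.float16)
out = torch.cat((x, y), dim=2)
assert out.dtype == torch.float32  # float16 is promoted to float32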

    @parameterized.expand(
        [
            ("pos", 1),