Skip to content

Commit 9403b0f

Browse files
committed
changing function name
1 parent 3fcf398 commit 9403b0f

File tree

3 files changed

+4
-89
lines changed

3 files changed

+4
-89
lines changed

py/torch_tensorrt/dynamo/conversion/impl/cat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717

1818

19-
def unify_trt_tensors(
19+
def unify_and_concat_trt_tensors(
2020
ctx: ConversionContext,
2121
target: Target,
2222
name: str,
@@ -115,7 +115,7 @@ def cat(
115115
trt_promoted_type = None
116116

117117
dim = get_positive_dim(dim, len(trt_inputs[0].shape))
118-
return unify_trt_tensors(
118+
return unify_and_concat_trt_tensors(
119119
ctx,
120120
target,
121121
name,

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 1 addition & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import List, Optional, Sequence, Tuple, Union
3+
from typing import List, Optional, Tuple
44

55
import numpy as np
66
import tensorrt as trt
@@ -123,88 +123,3 @@ def get_shape_with_dynamic_shape(
123123
select_layer = ctx.net.add_select(condition_val, input_shape, scale_res)
124124
set_layer_name(select_layer, target, f"{name}_select")
125125
return select_layer.get_output(0)
126-
127-
128-
def to_trt_shape_tensor(
129-
ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor]
130-
) -> TRTTensor:
131-
"""
132-
Convert a mixed shape list (ints + ITensors) into a single ITensor.
133-
134-
Args:
135-
ctx (ConversionContext): TensorRT ConversionContext object.
136-
target (Target): Target of fx node.
137-
name (str): base name for layer naming.
138-
shape_list (list[int | ITensor]): list containing static ints and/or ITensors.
139-
140-
Returns:
141-
ITensor if shape_list contains any ITensors, else plain Python list of ints.
142-
"""
143-
trt_tensors = []
144-
145-
for i, s in enumerate(shape_list):
146-
if isinstance(s, (int, torch.Tensor)):
147-
const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32))
148-
set_layer_name(const, target, f"{name}_dim{i}_const")
149-
trt_tensors.append(const.get_output(0))
150-
else:
151-
trt_tensors.append(s)
152-
153-
if any(not isinstance(s, int) for s in shape_list):
154-
# Concatenate everything into a single ITensor if there are any ITensors/Tensors
155-
concat_layer = ctx.net.add_concatenation(trt_tensors)
156-
concat_layer.axis = 0
157-
set_layer_name(concat_layer, target, f"{name}_shape_concat")
158-
return concat_layer.get_output(0)
159-
160-
# If no ITensor found, return plain list of ints
161-
return shape_list
162-
163-
164-
def collect_and_concat_trt_inputs(
165-
ctx: ConversionContext,
166-
target: Target,
167-
name: str,
168-
inputs: Sequence[Union[int, TRTTensor, torch.Tensor, np.ndarray]],
169-
concat_axis: int = 0,
170-
allow_static_return: bool = False,
171-
) -> Union[TRTTensor, List[int]]:
172-
"""
173-
Normalize a sequence of values into TRT ITensors and concatenate them.
174-
If `allow_static_return=True` and all inputs are ints, return a Python
175-
list of ints instead of creating any TRT layers.
176-
"""
177-
trt_tensors = []
178-
has_dynamic = False
179-
180-
for i, x in enumerate(inputs):
181-
if isinstance(x, TRTTensor):
182-
trt_tensors.append(x)
183-
has_dynamic = True
184-
185-
elif isinstance(x, (int, np.integer)):
186-
# keep raw for now, convert only if dynamic found
187-
trt_tensors.append(int(x))
188-
189-
else:
190-
# torch/np tensor -> TRT tensor
191-
t = get_trt_tensor(ctx, x, f"{name}_tensor_{i}")
192-
trt_tensors.append(t)
193-
has_dynamic = True
194-
195-
# fully static shape case
196-
if not has_dynamic and allow_static_return:
197-
return [int(v) for v in trt_tensors]
198-
199-
# promote remaining ints to TRT constants
200-
for i, v in enumerate(trt_tensors):
201-
if isinstance(v, int):
202-
const = ctx.net.add_constant((1,), np.array([v], dtype=np.int32))
203-
set_layer_name(const, target, f"{name}_static_dim{i}_const")
204-
trt_tensors[i] = const.get_output(0)
205-
206-
# concatenate
207-
concat = ctx.net.add_concatenation(trt_tensors)
208-
concat.axis = concat_axis
209-
set_layer_name(concat, target, f"{name}_concat")
210-
return concat.get_output(0)

py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
set_layer_name,
1111
)
1212
from torch_tensorrt.dynamo.conversion.impl.cat import (
13-
unify_trt_tensors as unify_trt_shape_tensors,
13+
unify_and_concat_trt_tensors as unify_trt_shape_tensors,
1414
)
1515
from torch_tensorrt.dynamo.conversion.impl.shape import (
1616
get_shape_with_dynamic_shape,

0 commit comments

Comments
 (0)