@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Tuple
 
 import numpy as np
 import tensorrt as trt
@@ -123,88 +123,3 @@ def get_shape_with_dynamic_shape(
     select_layer = ctx.net.add_select(condition_val, input_shape, scale_res)
     set_layer_name(select_layer, target, f"{name}_select")
     return select_layer.get_output(0)
-
-
-def to_trt_shape_tensor(
-    ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor]
-) -> TRTTensor:
-    """
-    Convert a mixed shape list (ints + ITensors) into a single ITensor.
-
-    Args:
-        ctx (ConversionContext): TensorRT ConversionContext object.
-        target (Target): Target of fx node.
-        name (str): base name for layer naming.
-        shape_list (list[int | ITensor]): list containing static ints and/or ITensors.
-
-    Returns:
-        ITensor if shape_list contains any ITensors, else plain Python list of ints.
-    """
-    trt_tensors = []
-
-    for i, s in enumerate(shape_list):
-        if isinstance(s, (int, torch.Tensor)):
-            const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32))
-            set_layer_name(const, target, f"{name}_dim{i}_const")
-            trt_tensors.append(const.get_output(0))
-        else:
-            trt_tensors.append(s)
-
-    if any(not isinstance(s, int) for s in shape_list):
-        # Concatenate everything into a single ITensor if there are any ITensors/Tensors
-        concat_layer = ctx.net.add_concatenation(trt_tensors)
-        concat_layer.axis = 0
-        set_layer_name(concat_layer, target, f"{name}_shape_concat")
-        return concat_layer.get_output(0)
-
-    # If no ITensor found, return plain list of ints
-    return shape_list
-
-
-def collect_and_concat_trt_inputs(
-    ctx: ConversionContext,
-    target: Target,
-    name: str,
-    inputs: Sequence[Union[int, TRTTensor, torch.Tensor, np.ndarray]],
-    concat_axis: int = 0,
-    allow_static_return: bool = False,
-) -> Union[TRTTensor, List[int]]:
-    """
-    Normalize a sequence of values into TRT ITensors and concatenate them.
-    If `allow_static_return=True` and all inputs are ints, return a Python
-    list of ints instead of creating any TRT layers.
-    """
-    trt_tensors = []
-    has_dynamic = False
-
-    for i, x in enumerate(inputs):
-        if isinstance(x, TRTTensor):
-            trt_tensors.append(x)
-            has_dynamic = True
-
-        elif isinstance(x, (int, np.integer)):
-            # keep raw for now, convert only if dynamic found
-            trt_tensors.append(int(x))
-
-        else:
-            # torch/np tensor -> TRT tensor
-            t = get_trt_tensor(ctx, x, f"{name}_tensor_{i}")
-            trt_tensors.append(t)
-            has_dynamic = True
-
-    # fully static shape case
-    if not has_dynamic and allow_static_return:
-        return [int(v) for v in trt_tensors]
-
-    # promote remaining ints to TRT constants
-    for i, v in enumerate(trt_tensors):
-        if isinstance(v, int):
-            const = ctx.net.add_constant((1,), np.array([v], dtype=np.int32))
-            set_layer_name(const, target, f"{name}_static_dim{i}_const")
-            trt_tensors[i] = const.get_output(0)
-
-    # concatenate
-    concat = ctx.net.add_concatenation(trt_tensors)
-    concat.axis = concat_axis
-    set_layer_name(concat, target, f"{name}_concat")
-    return concat.get_output(0)
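
For context, both removed helpers implement the same pattern: take a list that mixes static Python ints with runtime ITensors, promote the ints to 1-element INT32 constants, and concatenate everything along axis 0 into a single shape tensor, falling back to a plain list when every element is static. The sketch below re-illustrates that pattern against the raw TensorRT network API only; it is a minimal, hypothetical example, not part of Torch-TensorRT. The function name concat_shape_components is invented for illustration, and the ConversionContext / set_layer_name bookkeeping used by the removed helpers is deliberately omitted.

import numpy as np
import tensorrt as trt


def concat_shape_components(network: trt.INetworkDefinition, name: str, components):
    """Fold a mixed list of Python ints and ITensors into one 1-D shape tensor.

    If every element is a static int, return the list unchanged; otherwise
    promote each int to a 1-element INT32 constant and concatenate on axis 0.
    """
    if all(isinstance(c, int) for c in components):
        # Fully static shape: no TRT layers are needed at all.
        return list(components)

    tensors = []
    for i, c in enumerate(components):
        if isinstance(c, int):
            # Static dimension -> shape-(1,) INT32 constant layer.
            const = network.add_constant((1,), np.array([c], dtype=np.int32))
            const.name = f"{name}_dim{i}_const"
            tensors.append(const.get_output(0))
        else:
            # Already an ITensor, e.g. a slice taken from an IShapeLayer output.
            tensors.append(c)

    # Join all 1-D pieces into a single shape tensor.
    concat = network.add_concatenation(tensors)
    concat.axis = 0
    concat.name = f"{name}_shape_concat"
    return concat.get_output(0)

Returning a plain Python list in the all-static case is the same design choice the removed helpers made: it keeps constant and concatenation layers out of the engine when they are not needed and lets callers pass the shape straight to APIs that accept static dimensions, while the ITensor path is reserved for shapes where at least one dimension is only known at runtime.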