diff --git a/tripy/nvtripy/backend/api/compile.py b/tripy/nvtripy/backend/api/compile.py
index 9ba38ecd4..558e279d8 100644
--- a/tripy/nvtripy/backend/api/compile.py
+++ b/tripy/nvtripy/backend/api/compile.py
@@ -28,7 +28,6 @@ from nvtripy.utils.types import obj_name_or_type_name
 
 
-# TODO (#230): Support collections of tensors in args/kwargs
 @export.public_api(document_under="compiling_code/compile.rst")
 def compile(
     func: Callable, optimization_level: int = 3, *, args: Sequence[Any] = [], kwargs: Dict[str, Any] = {}
@@ -157,13 +156,15 @@ def add(a, b):
     trace_input_map = {}
     input_names = set()
     input_infos = {}
+    trace_inputs = []  # flattened list of trace input tensors in argument order
 
     # Set up names for the weights in the module to make the trace easier to read.
     if isinstance(func, Module):
         for name, weight in func.state_dict().items():
             weight.name = name
 
-    def process_arg(name, arg):
+    def process_arg_input_info(name, arg):
+        """Process InputInfo or DimensionInputInfo objects and create corresponding tensors."""
         if isinstance(arg, InputInfo):
             # Make new tensors for tracing.
             from nvtripy.common.datatype import floating, integer
@@ -184,6 +185,7 @@ def process_arg(name, arg):
 
             trace_input_map[name] = tensor
             input_names.add(name)
+            trace_inputs.append(tensor.trace_tensor)
 
             return tensor
 
@@ -199,11 +201,45 @@ def process_arg(name, arg):
 
             trace_input_map[name] = tensor
             input_names.add(name)
+            trace_inputs.append(tensor.trace_tensor)
 
             return tensor
 
         return arg
 
+    def process_arg_and_flag(name, arg):
+        # Handle individual InputInfo or DimensionInputInfo objects.
+        if isinstance(arg, (InputInfo, DimensionInputInfo)):
+            return process_arg_input_info(name, arg), True
+
+        # Handle containers of InputInfo objects.
+        if isinstance(arg, dict):
+            result = {}
+            has_input = False
+            for key, value in arg.items():
+                nested_name = f"{name}.{key}"
+                processed_child, child_has_input = process_arg_and_flag(nested_name, value)
+                result[key] = processed_child
+                has_input = has_input or child_has_input
+            return result, has_input
+        elif isinstance(arg, (list, tuple)):
+            result_list = []
+            has_input = False
+            for idx, value in enumerate(arg):
+                nested_name = f"{name}[{idx}]"
+                processed_child, child_has_input = process_arg_and_flag(nested_name, value)
+                result_list.append(processed_child)
+                has_input = has_input or child_has_input
+            return type(arg)(result_list), has_input  # preserve sequence type
+
+        return arg, False
+
+    def process_arg(name, arg):
+        processed, has_input = process_arg_and_flag(name, arg)
+        if has_input:
+            input_names.add(name)
+        return processed
+
     compiled_arg_names = []
     new_args = []
 
@@ -258,8 +294,7 @@ def process_arg(name, arg):
                 [f"Return value {index} was not a tensor: {repr(trace_out)}"],
             )
 
-    # Order of trace inputs also needs to match that of the compiled_arg_names
-    trace_inputs = [trace_input_map[name].trace_tensor for name in compiled_arg_names]
+    # Trace inputs were already collected in flattened argument order during traversal.
     trace = Trace(
         [tensor.trace_tensor for tensor in trace_outputs],
         trace_inputs,
@@ -281,9 +316,22 @@ def process_arg(name, arg):
     assert isinstance(func_out, Tensor) or isinstance(
         func_out, Sequence
     ), "This function is only implemented for Tensors or sequences of Tensors"
+
+    # Group leaf input names by top-level argument for efficient runtime extraction.
+    leaf_names_by_arg = {}
+    leaf_names = list(input_infos.keys())
+    for arg_name in compiled_arg_names:
+        matching = [
+            leaf
+            for leaf in leaf_names
+            if leaf == arg_name or leaf.startswith(f"{arg_name}.") or leaf.startswith(f"{arg_name}[")
+        ]
+        leaf_names_by_arg[arg_name] = matching
+
     return Executable(
         executable,
         compiled_arg_names,
         return_single_tensor_as_sequence=isinstance(func_out, Sequence),
         input_infos=input_infos,
+        leaf_names_by_arg=leaf_names_by_arg,
     )
diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py
index 57b3a78c7..1ca480f74 100644
--- a/tripy/nvtripy/backend/api/executable.py
+++ b/tripy/nvtripy/backend/api/executable.py
@@ -46,6 +46,7 @@ def __init__(
         arg_names,
         return_single_tensor_as_sequence: bool,
         input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]],
+        leaf_names_by_arg: Dict[str, Sequence[str]],
     ):
         self._executable = executable
 
@@ -78,6 +79,8 @@ def __init__(
         Stores metadata, like shapes and data types, for each input to the executable.
         """
 
+        self._leaf_names_by_arg = leaf_names_by_arg
+
     def __str__(self) -> str:
         params = [
             f"{name}: {str_from_type_annotation(param.annotation)}"
@@ -195,20 +198,67 @@ def add(a, b):
                 ],
             )
 
+        # Build a name->tensor map using precomputed leaf names to avoid unnecessary recursion
+        input_info_names = list(self.input_infos.keys())
+        name_to_tensor: Dict[str, Tensor] = {}
+
+        def extract_recursive(value, name_prefix, allowed_names):
+            if name_prefix in allowed_names:
+                name_to_tensor[name_prefix] = value
+                return
+            if isinstance(value, dict):
+                for key, item in value.items():
+                    nested_name = f"{name_prefix}.{key}"
+                    extract_recursive(item, nested_name, allowed_names)
+            elif isinstance(value, (list, tuple)):
+                for idx, item in enumerate(value):
+                    nested_name = f"{name_prefix}[{idx}]"
+                    extract_recursive(item, nested_name, allowed_names)
+            else:
+                return
+
+        for name_idx, tensor in enumerate(input_tensors):
+            arg_name = self._arg_names[name_idx]
+            # Fast path: direct leaf input
+            if arg_name in self.input_infos:
+                name_to_tensor[arg_name] = tensor
+                continue
+            # If this arg has no compiled leaves beneath it, skip any recursion
+            allowed = self._leaf_names_by_arg.get(arg_name)
+            if not allowed:
+                continue
+            extract_recursive(tensor, arg_name, set(allowed))
+        try:
+            flattened_tensors = [name_to_tensor[name] for name in input_info_names]
+        except KeyError as missing:
+            raise_error(
+                f"Missing runtime tensor for input `{missing.args[0]}`.",
+                [
+                    "Ensure your provided containers include tensors for all compiled inputs.",
+                    f"Expected inputs: {input_info_names}",
+                ],
+            )
         expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
-        for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names):
+
+        # Validate flattened tensors against input_infos
+        if len(flattened_tensors) != len(expected_devices):
+            raise_error(
+                f"Mismatch between number of flattened tensors ({len(flattened_tensors)}) and expected inputs ({len(expected_devices)})."
+            )
+
+        for tensor, expected_device, info_name in zip(flattened_tensors, expected_devices, self.input_infos.keys()):
             producer = tensor.trace_tensor.producer
             if not isinstance(producer, Constant):
-                raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
+                raise_error(f"Tensor `{info_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
 
             if tensor.device.kind != expected_device:
                 raise_error(
                     "Unexpected tensor device.",
                     [
-                        f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
+                        f"For tensor: `{info_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
                     ],
                 )
 
-        input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors]
+        input_memrefs = [inp.trace_tensor.producer.data for inp in flattened_tensors]
         try:
             output_memrefs = self._session.execute_function(
                 "main", in_args=input_memrefs, stream=self.stream._active_cuda_stream, client=self._runtime_client
@@ -222,7 +272,7 @@
         expected_input_dtypes = [
             info.dtype if isinstance(info, InputInfo) else int32 for info in self.input_infos.values()
         ]
-        for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names):
+        for tensor, dtype, arg_name in zip(flattened_tensors, expected_input_dtypes, self.input_infos.keys()):
             if tensor.dtype != dtype:
                 raise_error(
                     f"Unexpected tensor data type.",
@@ -237,7 +287,9 @@
         expected_input_shapes = [
             info.shape_bounds if isinstance(info, InputInfo) else tuple() for info in self.input_infos.values()
         ]
-        for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names):
+        for tensor, expected_bounds, arg_name in zip(
+            flattened_tensors, expected_input_shapes, self.input_infos.keys()
+        ):
             shape = tensor.shape
 
             if len(shape) != len(expected_bounds.min):
@@ -346,6 +398,7 @@ def encode_executable(executable):
         "executable": base64.b64encode(executable._executable.serialize()).decode(),
         "_return_single_tensor_as_sequence": executable._return_single_tensor_as_sequence,
         "input_infos": executable.input_infos,
+        "leaf_names_by_arg": executable._leaf_names_by_arg,
     }
@@ -357,4 +410,5 @@ def decode_executable(executable_dict):
         executable_dict["arg_names"],
         return_single_tensor_as_sequence=executable_dict["_return_single_tensor_as_sequence"],
         input_infos=executable_dict["input_infos"],
+        leaf_names_by_arg=executable_dict.get("leaf_names_by_arg") or {},  # tolerate executables serialized before this field existed
     )
diff --git a/tripy/nvtripy/frontend/tensor.py b/tripy/nvtripy/frontend/tensor.py
index 91146a3d6..0268acf55 100644
--- a/tripy/nvtripy/frontend/tensor.py
+++ b/tripy/nvtripy/frontend/tensor.py
@@ -237,6 +237,7 @@ def eval(self) -> "nvtripy.Tensor":
                 name: InputInfo(list(map(int, inp.trace_tensor.shape)), inp.dtype)
                 for name, inp in zip(arg_names, inputs)
             },
+            leaf_names_by_arg={name: [name] for name in arg_names},  # every argument is a direct input
         )
         data = executable(*inputs).trace_tensor.producer.data
diff --git a/tripy/tests/backend/api/test_compile.py b/tripy/tests/backend/api/test_compile.py
index b25ba4b92..6cc935628 100644
--- a/tripy/tests/backend/api/test_compile.py
+++ b/tripy/tests/backend/api/test_compile.py
@@ -241,3 +241,131 @@ def test_dimension_input(self):
         out = compiled(inp, dim_inp)
         expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
         assert cp.array_equal(cp.from_dlpack(out), expected)
+
+    def test_compile_nested_dict_input_info(self):
+        def func(data_dict):
data_dict["b"]["list"][1] + + dict_input = { + "a": { + "inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32), + }, + "b": { + "list": [ + tp.InputInfo(shape=(2, 3), dtype=tp.float32), + tp.InputInfo(shape=(2, 3), dtype=tp.float32), + ], + }, + } + compiled_func = tp.compile(func, args=[dict_input]) + + test_dict = { + "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()}, + "b": { + "list": [ + (tp.ones((2, 3), dtype=tp.float32) * 2).eval(), + (tp.ones((2, 3), dtype=tp.float32) * 3).eval(), + ] + }, + } + result = compiled_func(test_dict) + expected = test_dict["a"]["inner"] + test_dict["b"]["list"][0] + test_dict["b"]["list"][1] + assert tp.equal(result, expected) + + def test_compile_nested_sequence_input_info(self): + def func(data_list): + return data_list[0] + data_list[1][0] + data_list[1][1] + + list_input = [ + tp.InputInfo(shape=(2, 3), dtype=tp.float32), + [ + tp.InputInfo(shape=(2, 3), dtype=tp.float32), + tp.ones((2, 3), dtype=tp.float32) * 2, + ], + ] + compiled_func = tp.compile(func, args=[list_input]) + + test_list = [ + tp.ones((2, 3), dtype=tp.float32).eval(), + ( + (tp.ones((2, 3), dtype=tp.float32) * 3).eval(), + tp.ones((2, 3), dtype=tp.float32) * 2, + ), + ] + result = compiled_func(test_list) + expected = test_list[0] + test_list[1][0] + test_list[1][1] + assert tp.equal(result, expected) + + def test_compile_mixed_containers_and_constants(self): + def func(regular_input, data_dict, data_list, const_in_dict, const): + return ( + regular_input + + data_dict["x"] + + data_dict["y"] + + data_list[0] + + data_list[1] + + const_in_dict["z"] + + const + ) + + regular_input = tp.InputInfo(shape=(2, 3), dtype=tp.float32) + dict_input = { + "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32), + "y": tp.zeros((2, 3), dtype=tp.float32), + } + list_input = [tp.ones((2, 3), dtype=tp.float32) * 3, tp.InputInfo(shape=(2, 3), dtype=tp.float32)] + const_in_dict = {"z": tp.ones((2, 3), dtype=tp.float32) * 5} + const = tp.ones((2, 3), dtype=tp.float32) * 6 + + compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, const_in_dict, const]) + + # Only InputInfo arguments should be in function signature + test_regular = tp.ones((2, 3), dtype=tp.float32).eval() + test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()} + test_list = [None, (tp.ones((2, 3), dtype=tp.float32) * 4).eval()] + + result = compiled_func(test_regular, test_dict, test_list) + expected = ( + test_regular + test_dict["x"] + dict_input["y"] + test_list[1] + list_input[0] + const_in_dict["z"] + const + ) + assert tp.equal(result, expected) + + def test_compile_missing_nested_input_fails(self): + def func(data_dict): + return data_dict["a"]["inner"] + data_dict["b"]["list"][1] + + dict_input = { + "a": {"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32)}, + "b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.InputInfo(shape=(2, 3), dtype=tp.float32)]}, + } + + compiled_func = tp.compile(func, args=[dict_input]) + + # Missing b.list[1] + bad_dict = { + "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()}, + "b": {"list": [tp.ones((2, 3), dtype=tp.float32).eval()]}, + } + with helper.raises(tp.TripyException, match="Missing runtime tensor for input `data_dict\.b\.list\[1\]`."): + compiled_func(bad_dict) + + # Wrong shape for b.list[1] should trigger a shape/device validation error + wrong_shape = { + "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()}, + "b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32).eval()]}, + } + with 
+        with helper.raises(tp.TripyException, match="Unexpected tensor shape."):
+            compiled_func(wrong_shape)
+
+    def test_compile_container_mismatch_fails(self):
+        def func(data_list):
+            return data_list[0] + data_list[1][0]
+
+        list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), [tp.InputInfo(shape=(2, 3), dtype=tp.float32)]]
+
+        compiled_func = tp.compile(func, args=[list_input])
+
+        bad_list = [tp.ones((2, 3), dtype=tp.float32).eval(), {"not": tp.ones((2, 3), dtype=tp.float32).eval()}]
+
+        with helper.raises(tp.TripyException, match=r"Missing runtime tensor for input `data_list\[1\]\[0\]`."):
+            compiled_func(bad_list)
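
For reference, a minimal usage sketch of the nested-container inputs this patch enables, distilled from the new tests above. The function, variable names, and shapes here are illustrative only, not part of the patch:

    import nvtripy as tp

    def combine(inputs):
        # Only the InputInfo leaves become runtime inputs; other tensor
        # leaves are baked into the executable as constants at compile time.
        return inputs["x"] + inputs["pair"][0] + inputs["pair"][1]

    compiled = tp.compile(
        combine,
        args=[
            {
                "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
                "pair": [
                    tp.InputInfo(shape=(2, 3), dtype=tp.float32),
                    tp.ones((2, 3), dtype=tp.float32) * 2,  # baked in as a constant
                ],
            }
        ],
    )

    # At runtime, supply evaluated tensors at the same nested positions.
    # Positions that were baked as constants may hold a placeholder (e.g. None).
    out = compiled(
        {
            "x": tp.ones((2, 3), dtype=tp.float32).eval(),
            "pair": [(tp.ones((2, 3), dtype=tp.float32) * 3).eval(), None],
        }
    )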