56 changes: 52 additions & 4 deletions tripy/nvtripy/backend/api/compile.py
@@ -28,7 +28,6 @@
from nvtripy.utils.types import obj_name_or_type_name


# TODO (#230): Support collections of tensors in args/kwargs
@export.public_api(document_under="compiling_code/compile.rst")
def compile(
func: Callable, optimization_level: int = 3, *, args: Sequence[Any] = [], kwargs: Dict[str, Any] = {}
@@ -157,13 +156,15 @@ def add(a, b):
trace_input_map = {}
input_names = set()
input_infos = {}
trace_inputs = [] # flattened list of trace input tensors in argument order

# Set up names for the weights in the module to make the trace easier to read.
if isinstance(func, Module):
for name, weight in func.state_dict().items():
weight.name = name

def process_arg(name, arg):
def process_arg_input_info(name, arg):
"""Process InputInfo or DimensionInputInfo objects and create corresponding tensors."""
if isinstance(arg, InputInfo):
# Make new tensors for tracing.
from nvtripy.common.datatype import floating, integer
@@ -184,6 +185,7 @@ def process_arg(name, arg):

trace_input_map[name] = tensor
input_names.add(name)
trace_inputs.append(tensor.trace_tensor)

return tensor

@@ -199,11 +201,45 @@ def process_arg(name, arg):

trace_input_map[name] = tensor
input_names.add(name)
trace_inputs.append(tensor.trace_tensor)

return tensor

return arg

def process_arg_and_flag(name, arg):
# Handle individual InputInfo or DimensionInputInfo objects
if isinstance(arg, (InputInfo, DimensionInputInfo)):
return process_arg_input_info(name, arg), True

# Handle containers of InputInfo objects
if isinstance(arg, dict):
result = {}
has_input = False
for key, value in arg.items():
nested_name = f"{name}.{key}"
processed_child, child_has_input = process_arg_and_flag(nested_name, value)
result[key] = processed_child
has_input = has_input or child_has_input
return result, has_input
elif isinstance(arg, (list, tuple)):
result_list = []
has_input = False
for idx, value in enumerate(arg):
nested_name = f"{name}[{idx}]"
processed_child, child_has_input = process_arg_and_flag(nested_name, value)
result_list.append(processed_child)
has_input = has_input or child_has_input
return type(arg)(result_list), has_input # preserve sequence type

return arg, False

def process_arg(name, arg):
processed, has_input = process_arg_and_flag(name, arg)
if has_input:
input_names.add(name)
return processed

compiled_arg_names = []

new_args = []
@@ -258,8 +294,7 @@ def process_arg(name, arg):
[f"Return value {index} was not a tensor: {repr(trace_out)}"],
)

# Order of trace inputs also needs to match that of the compiled_arg_names
trace_inputs = [trace_input_map[name].trace_tensor for name in compiled_arg_names]
# Flattened trace inputs were collected in argument order during the traversal above
trace = Trace(
[tensor.trace_tensor for tensor in trace_outputs],
trace_inputs,
@@ -281,9 +316,22 @@ def process_arg(name, arg):
assert isinstance(func_out, Tensor) or isinstance(
func_out, Sequence
), "This function is only implemented for Tensors or sequences of Tensors"

# Group leaf input names by top-level argument for efficient runtime extraction
leaf_names_by_arg = {}
leaf_names = list(input_infos.keys())
for arg_name in compiled_arg_names:
matching = [
leaf
for leaf in leaf_names
if leaf == arg_name or leaf.startswith(f"{arg_name}.") or leaf.startswith(f"{arg_name}[")
]
leaf_names_by_arg[arg_name] = matching

return Executable(
executable,
compiled_arg_names,
return_single_tensor_as_sequence=isinstance(func_out, Sequence),
input_infos=input_infos,
leaf_names_by_arg=leaf_names_by_arg,
)
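For reference, a minimal sketch of the leaf-naming scheme produced by `process_arg_and_flag` and the grouping loop above, using a hypothetical argument named `data` (not part of this diff):

# Hypothetical nested argument: only the InputInfo leaves become named trace inputs.
# data = {"a": {"inner": InputInfo(...)}, "b": {"list": [InputInfo(...), some_constant]}}
#
# Leaf names recorded in input_infos, in argument order:
#   "data.a.inner", "data.b.list[0]"
# Constants such as data.b.list[1] are baked into the trace and get no leaf name.
#
# Grouped by top-level argument for runtime extraction:
leaf_names_by_arg = {"data": ["data.a.inner", "data.b.list[0]"]}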
66 changes: 60 additions & 6 deletions tripy/nvtripy/backend/api/executable.py
@@ -46,6 +46,7 @@ def __init__(
arg_names,
return_single_tensor_as_sequence: bool,
input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]],
leaf_names_by_arg: Dict[str, Sequence[str]],
):
self._executable = executable

@@ -78,6 +79,8 @@ def __init__(
Stores metadata, like shapes and data types, for each input to the executable.
"""

self._leaf_names_by_arg = leaf_names_by_arg

def __str__(self) -> str:
params = [
f"{name}: {str_from_type_annotation(param.annotation)}"
@@ -195,20 +198,67 @@ def add(a, b):
],
)

# Build a name->tensor map using precomputed leaf names to avoid unnecessary recursion
input_info_names = list(self.input_infos.keys())
name_to_tensor: Dict[str, Tensor] = {}

def extract_recursive(value, name_prefix, allowed_names):
if name_prefix in allowed_names:
name_to_tensor[name_prefix] = value
return
if isinstance(value, dict):
for key, item in value.items():
nested_name = f"{name_prefix}.{key}"
extract_recursive(item, nested_name, allowed_names)
elif isinstance(value, (list, tuple)):
for idx, item in enumerate(value):
nested_name = f"{name_prefix}[{idx}]"
Comment on lines +211 to +215

@pranavm-nvidia pranavm-nvidia Oct 31, 2025

We should actually know how to access the tensors within each collection at compile-time. I'm wondering if we can just build accessor lambdas which will provide a fast way to directly access the right values. When we compile, we could create a mapping of trace input names to functions that will retrieve the necessary argument from the raw inputs - basically, we'd use it like so:

flattened_tensors = []
for name in input_info_names:
    flattened_tensors.append(accessor_map[name](input_tensors))

At compile time, we'd want to recursively build up this accessor map (probably just by adding an extra return value that's a dictionary of accessor functions). The most efficient way would probably be to build strings like:

"inp['key_1'][5]['key_2'][3]"

and then eval them into callables (the alternative would be to return a recursive chain of lambdas, but the string approach avoids recursive calls).

This way we can remove all the name parsing logic and avoid looping over the collection inputs entirely.
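A minimal sketch of the accessor-map idea described here; the helper name `build_accessor_map` and the eval-built accessor strings are illustrative assumptions, not part of this PR:

# Sketch only: build eval-based accessors for nested container inputs at compile time,
# reusing the same leaf-naming scheme as process_arg_and_flag ("arg.key" / "arg[idx]").
from typing import Any, Callable, Dict

def build_accessor_map(arg_name: str, arg: Any) -> Dict[str, Callable[[Any], Any]]:
    """Map each leaf name to a callable that pulls that leaf out of the raw argument."""
    accessors: Dict[str, Callable[[Any], Any]] = {}

    def recurse(value: Any, leaf_name: str, expr: str) -> None:
        if isinstance(value, dict):
            for key, item in value.items():
                recurse(item, f"{leaf_name}.{key}", f"{expr}[{key!r}]")
        elif isinstance(value, (list, tuple)):
            for idx, item in enumerate(value):
                recurse(item, f"{leaf_name}[{idx}]", f"{expr}[{idx}]")
        else:
            # In the real compile pass, only InputInfo leaves would be recorded here.
            # e.g. expr == "['b']['list'][1]" -> lambda arg: arg['b']['list'][1]
            accessors[leaf_name] = eval(f"lambda arg: arg{expr}")

    recurse(arg, arg_name, "")
    return accessors

# Call-time usage (as in the snippet above): no name parsing, no container re-traversal.
raw = {"a": {"inner": 1}, "b": {"list": [2, 3]}}
accessor_map = build_accessor_map("data", raw)
assert accessor_map["data.b.list[1]"](raw) == 3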

extract_recursive(item, nested_name, allowed_names)
else:
return

for name_idx, tensor in enumerate(input_tensors):
arg_name = self._arg_names[name_idx]
# Fast path: direct leaf input
if arg_name in self.input_infos:
name_to_tensor[arg_name] = tensor
continue
# If this arg has no compiled leaves beneath it, skip any recursion
allowed = self._leaf_names_by_arg.get(arg_name)
if not allowed:
continue
extract_recursive(tensor, arg_name, set(allowed))
try:
flattened_tensors = [name_to_tensor[name] for name in input_info_names]
except KeyError as missing:
raise_error(
f"Missing runtime tensor for input `{missing.args[0]}`.",
[
"Ensure your provided containers include tensors for all compiled inputs.",

Suggested change
"Ensure your provided containers include tensors for all compiled inputs.",
"Ensure your provided collections include tensors for all compiled inputs.",

f"Expected inputs: {input_info_names}",
],
)
expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names):

# Validate flattened tensors against input_infos
if len(flattened_tensors) != len(expected_devices):
raise_error(
f"Mismatch between number of flattened tensors ({len(flattened_tensors)}) and expected inputs ({len(expected_devices)})."
)

for tensor, expected_device, info_name in zip(flattened_tensors, expected_devices, self.input_infos.keys()):
producer = tensor.trace_tensor.producer
if not isinstance(producer, Constant):
raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
raise_error(f"Tensor `{info_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
if tensor.device.kind != expected_device:
raise_error(
"Unexpected tensor device.",
[
f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
f"For tensor: `{info_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
],
)

input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors]
input_memrefs = [inp.trace_tensor.producer.data for inp in flattened_tensors]
try:
output_memrefs = self._session.execute_function(
"main", in_args=input_memrefs, stream=self.stream._active_cuda_stream, client=self._runtime_client
@@ -222,7 +272,7 @@ def add(a, b):
expected_input_dtypes = [
info.dtype if isinstance(info, InputInfo) else int32 for info in self.input_infos.values()
]
for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names):
for tensor, dtype, arg_name in zip(flattened_tensors, expected_input_dtypes, self.input_infos.keys()):
if tensor.dtype != dtype:
raise_error(
f"Unexpected tensor data type.",
@@ -237,7 +287,9 @@ def process_arg(name, arg):
expected_input_shapes = [
info.shape_bounds if isinstance(info, InputInfo) else tuple() for info in self.input_infos.values()
]
for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names):
for tensor, expected_bounds, arg_name in zip(
flattened_tensors, expected_input_shapes, self.input_infos.keys()
):
shape = tensor.shape

if len(shape) != len(expected_bounds.min):
@@ -346,6 +398,7 @@ def encode_executable(executable):
"executable": base64.b64encode(executable._executable.serialize()).decode(),
"_return_single_tensor_as_sequence": executable._return_single_tensor_as_sequence,
"input_infos": executable.input_infos,
"leaf_names_by_arg": executable._leaf_names_by_arg,
}


@@ -357,4 +410,5 @@ def decode_executable(executable_dict):
executable_dict["arg_names"],
return_single_tensor_as_sequence=executable_dict["_return_single_tensor_as_sequence"],
input_infos=executable_dict["input_infos"],
leaf_names_by_arg=executable_dict.get("leaf_names_by_arg"),
)
1 change: 1 addition & 0 deletions tripy/nvtripy/frontend/tensor.py
@@ -237,6 +237,7 @@ def eval(self) -> "nvtripy.Tensor":
name: InputInfo(list(map(int, inp.trace_tensor.shape)), inp.dtype)
for name, inp in zip(arg_names, inputs)
},
leaf_names_by_arg={name: [name] for name in arg_names}, # every argument is a direct input
)
data = executable(*inputs).trace_tensor.producer.data

128 changes: 128 additions & 0 deletions tripy/tests/backend/api/test_compile.py
@@ -241,3 +241,131 @@ def test_dimension_input(self):
out = compiled(inp, dim_inp)
expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
assert cp.array_equal(cp.from_dlpack(out), expected)

def test_compile_nested_dict_input_info(self):
def func(data_dict):
return data_dict["a"]["inner"] + data_dict["b"]["list"][0] + data_dict["b"]["list"][1]

dict_input = {
"a": {
"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
},
"b": {
"list": [
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
],
},
}
compiled_func = tp.compile(func, args=[dict_input])

test_dict = {
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
"b": {
"list": [
(tp.ones((2, 3), dtype=tp.float32) * 2).eval(),
(tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
]
},
}
result = compiled_func(test_dict)
expected = test_dict["a"]["inner"] + test_dict["b"]["list"][0] + test_dict["b"]["list"][1]
assert tp.equal(result, expected)

def test_compile_nested_sequence_input_info(self):
def func(data_list):
return data_list[0] + data_list[1][0] + data_list[1][1]

list_input = [
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
[
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
tp.ones((2, 3), dtype=tp.float32) * 2,
],
]
compiled_func = tp.compile(func, args=[list_input])

test_list = [
tp.ones((2, 3), dtype=tp.float32).eval(),
(
(tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
tp.ones((2, 3), dtype=tp.float32) * 2,
),
]
result = compiled_func(test_list)
expected = test_list[0] + test_list[1][0] + test_list[1][1]
assert tp.equal(result, expected)

def test_compile_mixed_containers_and_constants(self):
def func(regular_input, data_dict, data_list, const_in_dict, const):
return (
regular_input
+ data_dict["x"]
+ data_dict["y"]
+ data_list[0]
+ data_list[1]
+ const_in_dict["z"]
+ const
)

regular_input = tp.InputInfo(shape=(2, 3), dtype=tp.float32)
dict_input = {
"x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
"y": tp.zeros((2, 3), dtype=tp.float32),
}
list_input = [tp.ones((2, 3), dtype=tp.float32) * 3, tp.InputInfo(shape=(2, 3), dtype=tp.float32)]
const_in_dict = {"z": tp.ones((2, 3), dtype=tp.float32) * 5}
const = tp.ones((2, 3), dtype=tp.float32) * 6

compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, const_in_dict, const])

# Only arguments containing InputInfo leaves remain in the compiled function's signature
test_regular = tp.ones((2, 3), dtype=tp.float32).eval()
test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()}
test_list = [None, (tp.ones((2, 3), dtype=tp.float32) * 4).eval()]

result = compiled_func(test_regular, test_dict, test_list)
expected = (
test_regular + test_dict["x"] + dict_input["y"] + test_list[1] + list_input[0] + const_in_dict["z"] + const
)
assert tp.equal(result, expected)

def test_compile_missing_nested_input_fails(self):
def func(data_dict):
return data_dict["a"]["inner"] + data_dict["b"]["list"][1]

dict_input = {
"a": {"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32)},
"b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.InputInfo(shape=(2, 3), dtype=tp.float32)]},
}

compiled_func = tp.compile(func, args=[dict_input])

# Missing b.list[1]
bad_dict = {
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
"b": {"list": [tp.ones((2, 3), dtype=tp.float32).eval()]},
}
with helper.raises(tp.TripyException, match=r"Missing runtime tensor for input `data_dict\.b\.list\[1\]`."):
compiled_func(bad_dict)

# Wrong shape for b.list[1] should trigger a shape validation error
wrong_shape = {
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
"b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32).eval()]},
}
with helper.raises(tp.TripyException, match="Unexpected tensor shape."):
compiled_func(wrong_shape)

def test_compile_container_mismatch_fails(self):
def func(data_list):
return data_list[0] + data_list[1][0]

list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), [tp.InputInfo(shape=(2, 3), dtype=tp.float32)]]

compiled_func = tp.compile(func, args=[list_input])

bad_list = [tp.ones((2, 3), dtype=tp.float32).eval(), {"not": tp.ones((2, 3), dtype=tp.float32).eval()}]

with helper.raises(tp.TripyException, match=r"Missing runtime tensor for input `data_list\[1\]\[0\]`."):
compiled_func(bad_list)