Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions tripy/nvtripy/backend/api/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from nvtripy.utils.types import obj_name_or_type_name


# TODO (#230): Support collections of tensors in args/kwargs
@export.public_api(document_under="compiling_code/compile.rst")
def compile(
func: Callable, optimization_level: int = 3, *, args: Sequence[Any] = [], kwargs: Dict[str, Any] = {}
Expand Down Expand Up @@ -163,7 +162,8 @@ def add(a, b):
for name, weight in func.state_dict().items():
weight.name = name

def process_arg(name, arg):
def process_arg_input_info(name, arg):
"""Process InputInfo or DimensionInputInfo objects and create corresponding tensors."""
if isinstance(arg, InputInfo):
# Make new tensors for tracing.
from nvtripy.common.datatype import floating, integer
Expand Down Expand Up @@ -204,6 +204,31 @@ def process_arg(name, arg):

return arg

def process_arg(name, arg):
    """Process one compile-time argument, recursing into dict/list/tuple containers.

    Plain ``InputInfo``/``DimensionInputInfo`` arguments are turned into trace
    tensors via ``process_arg_input_info``. For containers, each element is
    processed under a derived nested name (``name.key`` for dicts,
    ``name[idx]`` for sequences) so that nested inputs get unique names in the
    trace input map. Any other value is returned unchanged (treated as a
    constant baked into the compiled function).
    """
    # Handle individual InputInfo or DimensionInputInfo objects
    if isinstance(arg, (InputInfo, DimensionInputInfo)):
        return process_arg_input_info(name, arg)

    # Handle containers of InputInfo objects
    if isinstance(arg, dict):
        # NOTE(review): only *direct* values are checked here — a dict whose
        # InputInfo sits inside a nested container will not have `name` added
        # to input_names (only the deeper nested names will). Confirm this is
        # the intended marking for compiled_arg_names.
        if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg.values()):
            input_names.add(name)
        result = {}
        for key, value in arg.items():
            nested_name = f"{name}.{key}"
            result[key] = process_arg(nested_name, value)
        return result
    elif isinstance(arg, (list, tuple)):
        # Same direct-children-only check as the dict branch above.
        if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg):
            input_names.add(name)
        result = []
        for idx, value in enumerate(arg):
            nested_name = f"{name}[{idx}]"
            result.append(process_arg(nested_name, value))
        # NOTE(review): type(arg)(result) rebuilds plain lists/tuples, but
        # raises for tuple subclasses such as namedtuples (their constructors
        # take positional fields, not an iterable) — TODO confirm callers
        # never pass namedtuples.
        return type(arg)(result)

    # Anything else is passed through untouched (compile-time constant).
    return arg

compiled_arg_names = []

new_args = []
Expand Down Expand Up @@ -259,7 +284,24 @@ def process_arg(name, arg):
)

# Order of trace inputs also needs to match that of the compiled_arg_names.
# For containers, we need to collect all individual trace tensors.
def collect_trace_tensors(name):
    """Collect the trace tensors registered under `name`, flattening containers.

    If `name` was registered directly (a plain InputInfo/DimensionInputInfo
    argument), its single trace tensor is returned. Otherwise `name` is a
    container argument, and every tensor registered under a nested path
    (``name.key`` or ``name[idx]``) is collected.
    """
    if name in trace_input_map:
        # Regular InputInfo or DimensionInputInfo
        return [trace_input_map[name].trace_tensor]

    import re

    def nested_sort_key(path):
        # Split out "[<digits>]" segments so list indices compare numerically:
        # a plain lexicographic sort would order "x[10]" before "x[2]", which
        # would not match the numeric (enumerate) order used when flattening
        # runtime arguments at execution time.
        return [int(part) if part.isdigit() else part for part in re.split(r"\[(\d+)\]", path)]

    # Collect all nested trace tensors inside the container, in the same order
    # the executable flattens runtime arguments (sorted dict keys, numeric
    # list indices).
    return [
        trace_input_map[nested_name].trace_tensor
        for nested_name in sorted(trace_input_map.keys(), key=nested_sort_key)
        if nested_name.startswith((f"{name}.", f"{name}["))
    ]

# Flatten all trace tensors from containers and individual inputs
trace_inputs = []
for name in compiled_arg_names:
    trace_inputs.extend(collect_trace_tensors(name))
trace = Trace(
[tensor.trace_tensor for tensor in trace_outputs],
trace_inputs,
Expand Down
49 changes: 45 additions & 4 deletions tripy/nvtripy/backend/api/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,20 +195,61 @@ def add(a, b):
],
)

# Recursively extract inputs from containers to get individual tensors for validation and execution
def extract_inputs(tensors, input_info_names):
def extract_recursive(tensor, name_prefix):
if isinstance(tensor, dict):
result = []
for key in sorted(tensor.keys()):
nested_name = f"{name_prefix}.{key}"
if nested_name in input_info_names:
result.append(tensor[key])
else:
result.extend(extract_recursive(tensor[key], nested_name))
return result
elif isinstance(tensor, (list, tuple)):
result = []
for idx, value in enumerate(tensor):
nested_name = f"{name_prefix}[{idx}]"
if nested_name in input_info_names:
result.append(value)
else:
result.extend(extract_recursive(value, nested_name))
return result
else: # Regular tensor
if name_prefix in input_info_names:
return [tensor]
else:
return []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this branch reached?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When there are constants in the containers that have inputs


flattened = []
for name_idx, tensor in enumerate(tensors):
arg_name = self._arg_names[name_idx]
flattened.extend(extract_recursive(tensor, arg_name))
return flattened

flattened_tensors = extract_inputs(input_tensors, set(self.input_infos.keys()))
expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names):

# Validate flattened tensors against input_infos
if len(flattened_tensors) != len(expected_devices):
raise_error(
f"Mismatch between number of flattened tensors ({len(flattened_tensors)}) and expected inputs ({len(expected_devices)})."
)

for tensor, expected_device, info_name in zip(flattened_tensors, expected_devices, self.input_infos.keys()):
producer = tensor.trace_tensor.producer
if not isinstance(producer, Constant):
raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
raise_error(f"Tensor `{info_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
if tensor.device.kind != expected_device:
raise_error(
"Unexpected tensor device.",
[
f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
f"For tensor: `{info_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
],
)

input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors]
input_memrefs = [inp.trace_tensor.producer.data for inp in flattened_tensors]
try:
output_memrefs = self._session.execute_function(
"main", in_args=input_memrefs, stream=self.stream._active_cuda_stream, client=self._runtime_client
Expand Down
68 changes: 68 additions & 0 deletions tripy/tests/backend/api/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,71 @@ def test_dimension_input(self):
out = compiled(inp, dim_inp)
expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
assert cp.array_equal(cp.from_dlpack(out), expected)

def test_compile_dict_input_info(self):
    """Compile a function whose single argument is a dict of InputInfo objects."""

    def add_entries(mapping):
        return mapping["a"] + mapping["b"]

    spec = {
        "a": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        "b": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
    }
    compiled = tp.compile(add_entries, args=[spec])

    inputs = {
        "a": tp.ones((2, 3), dtype=tp.float32).eval(),
        "b": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(),
    }
    out = compiled(inputs)
    ref = inputs["a"] + inputs["b"]
    assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(ref))

def test_compile_nested_list_input_info(self):
    """Compile a function whose argument is a list containing a nested list of inputs."""

    def add_entries(items):
        return items[0] + items[1][0] + items[1][1]

    spec = [
        tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        [
            # Nested list mixing an input with a constant that gets baked in.
            tp.InputInfo(shape=(2, 3), dtype=tp.float32),
            tp.ones((2, 3), dtype=tp.float32) * 2,
        ],
    ]
    compiled = tp.compile(add_entries, args=[spec])

    runtime_arg = [
        tp.ones((2, 3), dtype=tp.float32).eval(),
        [
            (tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
            # Must match the constant baked in at compile time.
            tp.ones((2, 3), dtype=tp.float32) * 2,
        ],
    ]
    out = compiled(runtime_arg)
    ref = runtime_arg[0] + runtime_arg[1][0] + runtime_arg[1][1]
    assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(ref))

def test_compile_mixed_containers_and_constants(self):
    """Compile with a mix of a plain InputInfo, dict/list containers, and a standalone constant."""

    def combine(plain, mapping, seq, const):
        return plain + mapping["x"] + mapping["y"] + seq[0] + seq[1] + const

    plain_spec = tp.InputInfo(shape=(2, 3), dtype=tp.float32)
    mapping_spec = {
        "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        "y": tp.zeros((2, 3), dtype=tp.float32),  # constant inside the dict
    }
    seq_spec = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.float32) * 3]
    baked_const = tp.ones((2, 3), dtype=tp.float32) * 5

    compiled = tp.compile(combine, args=[plain_spec, mapping_spec, seq_spec, baked_const])

    # Only the InputInfo entries remain part of the compiled signature; the
    # standalone constant argument is baked in and not passed at call time.
    plain_arg = tp.ones((2, 3), dtype=tp.float32).eval()
    mapping_arg = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(), "y": tp.zeros((2, 3), dtype=tp.float32)}
    seq_arg = [(tp.ones((2, 3), dtype=tp.float32) * 4).eval(), tp.ones((2, 3), dtype=tp.float32) * 3]

    out = compiled(plain_arg, mapping_arg, seq_arg)
    ref = plain_arg + mapping_arg["x"] + mapping_arg["y"] + seq_arg[0] + seq_arg[1] + baked_const
    assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(ref))