Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions tripy/nvtripy/backend/api/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from nvtripy.utils.types import obj_name_or_type_name


# TODO (#230): Support collections of tensors in args/kwargs
@export.public_api(document_under="compiling_code/compile.rst")
def compile(
func: Callable, optimization_level: int = 3, *, args: Sequence[Any] = [], kwargs: Dict[str, Any] = {}
Expand Down Expand Up @@ -157,13 +156,16 @@ def add(a, b):
trace_input_map = {}
input_names = set()
input_infos = {}
trace_inputs = [] # flattened list of trace input tensors in argument order
access_plan_by_name: Dict[str, tuple] = {}

# Set up names for the weights in the module to make the trace easier to read.
if isinstance(func, Module):
for name, weight in func.state_dict().items():
weight.name = name

def process_arg(name, arg):
def process_arg_input_info(name, arg):
"""Process InputInfo or DimensionInputInfo objects and create corresponding tensors."""
if isinstance(arg, InputInfo):
# Make new tensors for tracing.
from nvtripy.common.datatype import floating, integer
Expand All @@ -184,6 +186,7 @@ def process_arg(name, arg):

trace_input_map[name] = tensor
input_names.add(name)
trace_inputs.append(tensor.trace_tensor)

return tensor

Expand All @@ -199,11 +202,49 @@ def process_arg(name, arg):

trace_input_map[name] = tensor
input_names.add(name)
trace_inputs.append(tensor.trace_tensor)

return tensor

return arg

def process_arg_and_flag(top_arg_name, name, arg, steps):
    """Recursively process one compile-time argument, which may be a nested container.

    Returns ``(processed_arg, has_input)`` where ``has_input`` is True iff any
    InputInfo/DimensionInputInfo leaf was found anywhere inside ``arg``.
    ``steps`` accumulates the access path (dict keys / sequence indices) from
    the top-level argument down to the current value; for each leaf input it is
    recorded in ``access_plan_by_name`` so the executable can retrieve the
    corresponding runtime tensor from the caller's containers.
    """
    # Handle individual InputInfo or DimensionInputInfo objects
    if isinstance(arg, (InputInfo, DimensionInputInfo)):
        tensor_or_dim = process_arg_input_info(name, arg)
        # Record how to reach this leaf at runtime: (top-level arg name, path).
        access_plan_by_name[name] = (top_arg_name, tuple(steps))
        return tensor_or_dim, True

    # Handle containers of InputInfo objects
    if isinstance(arg, dict):
        result = {}
        has_input = False
        for key, value in arg.items():
            # NOTE(review): keys are stringified into both the leaf name and the
            # access path; a non-string key (e.g. an int) would be recorded as
            # str(key) but looked up with that string at runtime — confirm dict
            # keys are always strings.
            nested_name = f"{name}.{key}"
            processed_child, child_has_input = process_arg_and_flag(
                top_arg_name, nested_name, value, (*steps, str(key))
            )
            result[key] = processed_child
            has_input = has_input or child_has_input
        return result, has_input
    elif isinstance(arg, (list, tuple)):
        result_list = []
        has_input = False
        for idx, value in enumerate(arg):
            nested_name = f"{name}[{idx}]"
            processed_child, child_has_input = process_arg_and_flag(top_arg_name, nested_name, value, (*steps, idx))
            result_list.append(processed_child)
            has_input = has_input or child_has_input
        # NOTE(review): type(arg)(iterable) assumes the constructor accepts a
        # single iterable — this breaks for namedtuple subclasses; confirm no
        # callers pass namedtuples.
        return type(arg)(result_list), has_input  # preserve sequence type

    # Leaf that is not an input spec (e.g. a constant tensor baked into the trace).
    return arg, False

def process_arg(name, arg):
    """Process one top-level argument; register its name as a compiled input
    when it is, or contains, any InputInfo/DimensionInputInfo leaf."""
    result, contains_input = process_arg_and_flag(name, name, arg, ())
    if contains_input:
        input_names.add(name)
    return result

compiled_arg_names = []

new_args = []
Expand Down Expand Up @@ -258,8 +299,7 @@ def process_arg(name, arg):
[f"Return value {index} was not a tensor: {repr(trace_out)}"],
)

# Order of trace inputs also needs to match that of the compiled_arg_names
trace_inputs = [trace_input_map[name].trace_tensor for name in compiled_arg_names]
# We collected flattened trace inputs during traversal
trace = Trace(
[tensor.trace_tensor for tensor in trace_outputs],
trace_inputs,
Expand All @@ -281,9 +321,11 @@ def process_arg(name, arg):
assert isinstance(func_out, Tensor) or isinstance(
func_out, Sequence
), "This function is only implemented for Tensors or sequences of Tensors"

return Executable(
executable,
compiled_arg_names,
return_single_tensor_as_sequence=isinstance(func_out, Sequence),
input_infos=input_infos,
access_plan_by_name=access_plan_by_name,
)
56 changes: 50 additions & 6 deletions tripy/nvtripy/backend/api/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
arg_names,
return_single_tensor_as_sequence: bool,
input_infos: Dict[str, Union[InputInfo, DimensionInputInfo]],
access_plan_by_name: Dict[str, Tuple[str, Tuple[Union[str, int], ...]]],
):
self._executable = executable

Expand Down Expand Up @@ -78,6 +79,24 @@ def __init__(
Stores metadata, like shapes and data types, for each input to the executable.
"""

# Build accessor map from compile-time access plans
self._accessor_map: Dict[str, callable] = {}
name_to_index = {name: idx for idx, name in enumerate(self._arg_names)}

def make_accessor(arg_index: int, steps: Tuple[Union[str, int], ...]):
    """Build a closure that digs the leaf tensor for one compiled input out of
    the positional runtime arguments by following the recorded access path."""

    def accessor(inputs, _index=arg_index, _steps=steps):
        # Defaults bind the loop variables eagerly, avoiding late-binding bugs.
        value = inputs[_index]
        for step in _steps:
            value = value[step]
        return value

    return accessor

self._access_plan_by_name = access_plan_by_name
for leaf_name, (arg_name, steps) in self._access_plan_by_name.items():
idx = name_to_index[arg_name]
self._accessor_map[leaf_name] = make_accessor(idx, steps)

def __str__(self) -> str:
params = [
f"{name}: {str_from_type_annotation(param.annotation)}"
Expand Down Expand Up @@ -195,20 +214,41 @@ def add(a, b):
],
)

# Fetch flattened tensors directly via accessors
input_info_names = list(self.input_infos.keys())
flattened_tensors = []
for name in input_info_names:
try:
flattened_tensors.append(self._accessor_map[name](input_tensors))
except Exception:
raise_error(
f"Missing runtime tensor for input `{name}`.",
[
"Ensure your provided collections include tensors for all compiled inputs.",
f"Expected inputs: {input_info_names}",
],
)
expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
for tensor, expected_device, arg_name in zip(input_tensors, expected_devices, self._arg_names):

# Validate flattened tensors against input_infos
if len(flattened_tensors) != len(expected_devices):
raise_error(
f"Mismatch between number of flattened tensors ({len(flattened_tensors)}) and expected inputs ({len(expected_devices)})."
)

for tensor, expected_device, info_name in zip(flattened_tensors, expected_devices, self.input_infos.keys()):
producer = tensor.trace_tensor.producer
if not isinstance(producer, Constant):
raise_error(f"Tensor `{arg_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
raise_error(f"Tensor `{info_name}` is not evaluated.", ["Hint: Try calling `.eval()` on the tensor."])
if tensor.device.kind != expected_device:
raise_error(
"Unexpected tensor device.",
[
f"For tensor: `{arg_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
f"For tensor: `{info_name}`, expected to be on device: {expected_device} but got: {tensor.device.kind}.\n",
],
)

input_memrefs = [inp.trace_tensor.producer.data for inp in input_tensors]
input_memrefs = [inp.trace_tensor.producer.data for inp in flattened_tensors]
try:
output_memrefs = self._session.execute_function(
"main", in_args=input_memrefs, stream=self.stream._active_cuda_stream, client=self._runtime_client
Expand All @@ -222,7 +262,7 @@ def add(a, b):
expected_input_dtypes = [
info.dtype if isinstance(info, InputInfo) else int32 for info in self.input_infos.values()
]
for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names):
for tensor, dtype, arg_name in zip(flattened_tensors, expected_input_dtypes, self.input_infos.keys()):
if tensor.dtype != dtype:
raise_error(
f"Unexpected tensor data type.",
Expand All @@ -237,7 +277,9 @@ def add(a, b):
expected_input_shapes = [
info.shape_bounds if isinstance(info, InputInfo) else tuple() for info in self.input_infos.values()
]
for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names):
for tensor, expected_bounds, arg_name in zip(
flattened_tensors, expected_input_shapes, self.input_infos.keys()
):
shape = tensor.shape

if len(shape) != len(expected_bounds.min):
Expand Down Expand Up @@ -346,6 +388,7 @@ def encode_executable(executable):
"executable": base64.b64encode(executable._executable.serialize()).decode(),
"_return_single_tensor_as_sequence": executable._return_single_tensor_as_sequence,
"input_infos": executable.input_infos,
"access_plan_by_name": executable._access_plan_by_name,
}


Expand All @@ -357,4 +400,5 @@ def decode_executable(executable_dict):
executable_dict["arg_names"],
return_single_tensor_as_sequence=executable_dict["_return_single_tensor_as_sequence"],
input_infos=executable_dict["input_infos"],
access_plan_by_name=executable_dict["access_plan_by_name"],
)
1 change: 1 addition & 0 deletions tripy/nvtripy/frontend/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def eval(self) -> "nvtripy.Tensor":
name: InputInfo(list(map(int, inp.trace_tensor.shape)), inp.dtype)
for name, inp in zip(arg_names, inputs)
},
access_plan_by_name={name: (name, tuple()) for name in arg_names},
)
data = executable(*inputs).trace_tensor.producer.data

Expand Down
128 changes: 128 additions & 0 deletions tripy/tests/backend/api/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,131 @@ def test_dimension_input(self):
out = compiled(inp, dim_inp)
expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
assert cp.array_equal(cp.from_dlpack(out), expected)

def test_compile_nested_dict_input_info(self):
    """Nested dicts (containing dicts and lists) of InputInfo compile and execute."""

    def func(data_dict):
        return data_dict["a"]["inner"] + data_dict["b"]["list"][0] + data_dict["b"]["list"][1]

    compile_spec = {
        "a": {
            "inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        },
        "b": {
            "list": [
                tp.InputInfo(shape=(2, 3), dtype=tp.float32),
                tp.InputInfo(shape=(2, 3), dtype=tp.float32),
            ],
        },
    }
    compiled_func = tp.compile(func, args=[compile_spec])

    runtime_inputs = {
        "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
        "b": {
            "list": [
                (tp.ones((2, 3), dtype=tp.float32) * 2).eval(),
                (tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
            ]
        },
    }
    result = compiled_func(runtime_inputs)
    expected = (
        runtime_inputs["a"]["inner"] + runtime_inputs["b"]["list"][0] + runtime_inputs["b"]["list"][1]
    )
    assert tp.equal(result, expected)

def test_compile_nested_sequence_input_info(self):
    """Nested lists/tuples mixing InputInfo leaves and baked-in constants work;
    a tuple may be passed at runtime where a list was used at compile time."""

    def func(data_list):
        return data_list[0] + data_list[1][0] + data_list[1][1]

    compile_spec = [
        tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        [
            tp.InputInfo(shape=(2, 3), dtype=tp.float32),
            tp.ones((2, 3), dtype=tp.float32) * 2,
        ],
    ]
    compiled_func = tp.compile(func, args=[compile_spec])

    runtime_inputs = [
        tp.ones((2, 3), dtype=tp.float32).eval(),
        (
            (tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
            tp.ones((2, 3), dtype=tp.float32) * 2,
        ),
    ]
    result = compiled_func(runtime_inputs)
    expected = runtime_inputs[0] + runtime_inputs[1][0] + runtime_inputs[1][1]
    assert tp.equal(result, expected)

def test_compile_mixed_containers_and_constants(self):
    """Plain InputInfos, containers mixing inputs with constants, all-constant
    containers, and bare constants can be combined in one compiled function."""

    def func(regular_input, data_dict, data_list, const_in_dict, const):
        return (
            regular_input
            + data_dict["x"]
            + data_dict["y"]
            + data_list[0]
            + data_list[1]
            + const_in_dict["z"]
            + const
        )

    regular_spec = tp.InputInfo(shape=(2, 3), dtype=tp.float32)
    dict_spec = {
        "x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
        "y": tp.zeros((2, 3), dtype=tp.float32),
    }
    list_spec = [tp.ones((2, 3), dtype=tp.float32) * 3, tp.InputInfo(shape=(2, 3), dtype=tp.float32)]
    const_dict_spec = {"z": tp.ones((2, 3), dtype=tp.float32) * 5}
    plain_const = tp.ones((2, 3), dtype=tp.float32) * 6

    compiled_func = tp.compile(func, args=[regular_spec, dict_spec, list_spec, const_dict_spec, plain_const])

    # Only InputInfo arguments should be in function signature
    run_regular = tp.ones((2, 3), dtype=tp.float32).eval()
    run_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()}
    run_list = [None, (tp.ones((2, 3), dtype=tp.float32) * 4).eval()]

    result = compiled_func(run_regular, run_dict, run_list)
    expected = (
        run_regular + run_dict["x"] + dict_spec["y"] + run_list[1] + list_spec[0] + const_dict_spec["z"] + plain_const
    )
    assert tp.equal(result, expected)

def test_compile_missing_nested_input_fails(self):
    """A missing or wrongly-shaped nested runtime tensor must raise a clear error."""

    def func(data_dict):
        return data_dict["a"]["inner"] + data_dict["b"]["list"][1]

    dict_input = {
        "a": {"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32)},
        "b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.InputInfo(shape=(2, 3), dtype=tp.float32)]},
    }

    compiled_func = tp.compile(func, args=[dict_input])

    # Missing b.list[1]
    bad_dict = {
        "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
        "b": {"list": [tp.ones((2, 3), dtype=tp.float32).eval()]},
    }
    # Fix: the match pattern must be a raw string — "\." and "\[" are invalid
    # escape sequences in a plain string literal (SyntaxWarning on Python 3.12+).
    with helper.raises(tp.TripyException, match=r"Missing runtime tensor for input `data_dict\.b\.list\[1\]`."):
        compiled_func(bad_dict)

    # Wrong shape for b.list[1] should trigger a shape/device validation error
    wrong_shape = {
        "a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
        "b": {"list": [tp.zeros((2, 3), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32).eval()]},
    }
    with helper.raises(tp.TripyException, match="Unexpected tensor shape."):
        compiled_func(wrong_shape)

def test_compile_container_mismatch_fails(self):
    """A runtime container whose structure doesn't match the compiled spec fails
    with a missing-input error (the accessor path cannot be resolved)."""

    def func(data_list):
        return data_list[0] + data_list[1][0]

    list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), [tp.InputInfo(shape=(2, 3), dtype=tp.float32)]]

    compiled_func = tp.compile(func, args=[list_input])

    # A dict where a list is expected: indexing it with 0 fails inside the accessor.
    bad_list = [tp.ones((2, 3), dtype=tp.float32).eval(), {"not": tp.ones((2, 3), dtype=tp.float32).eval()}]

    # Fix: raw string — "\[" is an invalid escape sequence in a plain string
    # literal (SyntaxWarning on Python 3.12+).
    with helper.raises(tp.TripyException, match=r"Missing runtime tensor for input `data_list\[1\]\[0\]`."):
        compiled_func(bad_list)