-
Notifications
You must be signed in to change notification settings - Fork 18
#343: Support collections of tensors in args/kwargs for compile #701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| else: | ||
| return [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When is this branch reached?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When there are constants in the containers that have inputs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
⚠️ Performance Alert ⚠️
Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.10.
| Benchmark suite | Current: 87ecff6 | Previous: 48b5d7d | Ratio |
|---|---|---|---|
tests/performance/test_perf.py::test_perf_regression[sdpa-float16] |
2983.0722283127584 iter/sec (stddev: 0.0002213579824486067) |
3302.897668253486 iter/sec (stddev: 0.00007090329654256612) |
1.11 |
This comment was automatically generated by workflow using github-action-benchmark.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds support for collections of tensors (nested dictionaries and lists) in the args and kwargs parameters of the compile function. Previously, only individual InputInfo or DimensionInputInfo objects were supported as runtime inputs, but now these can be organized within container structures.
Key Changes:
- Recursive processing of nested data structures (dicts, lists, tuples) containing
InputInfoobjects during compilation - Flattening of nested input structures at runtime to match the compilation structure
- New validation for input extraction ensuring all compiled inputs are provided at runtime
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tripy/nvtripy/backend/api/compile.py | Added recursive process_arg_and_flag function to handle nested containers during compilation and maintain flattened trace inputs |
| tripy/nvtripy/backend/api/executable.py | Added extract_inputs function to recursively flatten runtime tensor containers and validate against expected compiled inputs |
| tripy/tests/backend/api/test_compile.py | Added three comprehensive test cases covering nested dicts, nested sequences, and mixed container scenarios |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| for name_idx, tensor in enumerate(tensors): | ||
| arg_name = self._arg_names[name_idx] | ||
| extract_recursive(tensor, arg_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code is part of the critical path (that's probably why the perf tests are failing). Can we store information at compile time so that we don't attempt to flatten arguments that are not collections?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
I see that some additional docs tests are failing due to the new required argument (which caches the input structure) that I added for creating an executable. Initially I implemented it as an optional argument and then re-created it in the executable if it wasn't provided, but this didn't seem to make much sense to me and I was hoping to have the relevant logic only in one place. I could do this entirely in the |
| raise_error( | ||
| f"Missing runtime tensor for input `{missing.args[0]}`.", | ||
| [ | ||
| "Ensure your provided containers include tensors for all compiled inputs.", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| "Ensure your provided containers include tensors for all compiled inputs.", | |
| "Ensure your provided collections include tensors for all compiled inputs.", |
| nested_name = f"{name_prefix}.{key}" | ||
| extract_recursive(item, nested_name, allowed_names) | ||
| elif isinstance(value, (list, tuple)): | ||
| for idx, item in enumerate(value): | ||
| nested_name = f"{name_prefix}[{idx}]" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should actually know how to access the tensors within each collection at compile-time. I'm wondering if we can just build accessor lambas which will provide a fast way to directly access the right values. When we compile, we could create a mapping of trace input names to functions that will retrieve the necessary argument from the raw inputs - basically, we'd use it like so:
flattened_tensors = []
for name in input_info_names:
flattened_tensors.append(accessor_map[name](input_tensors))At compile time, we'd want to recursively build up this accessor map (probably just by adding an extra return value that's a dictionary of accessor functions). The most efficient way would probably be to build strings like:
"inp['key_1'][5]['key_2'][3]"and then eval them into callables (the alternative would be to return a recursive chain of lambdas, but the string approach avoids recursive calls).
This way we can remove all the name parsing logic and avoid looping over the collection inputs entirely.
No description provided.