Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tripy/docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ which specifies doc metadata for each API (e.g. location).
- Docstring must include *at least* **one [code example](#code-examples)**.

- If the function accepts `tp.Tensor`s, must indicate **data type constraints**
with the [`wrappers.interface`](../nvtripy/utils/wrappers.py) decorator.
with the [`wrappers.interface`](../nvtripy/frontend/wrappers.py) decorator.

**Example:**

Expand Down
2 changes: 1 addition & 1 deletion tripy/docs/post0_developer_guides/00-architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ and various operations, e.g. {class}`nvtripy.resize`.
:::{admonition} Info
Most operations are decorated with:
1. [`@export.public_api`](source:/nvtripy/export.py): Enables documentation, type checking, and overloading.
2. [`@wrappers.interface`](source:/nvtripy/utils/wrappers.py): Enforces (and generates tests for) data type constraints.
2. [`@wrappers.interface`](source:/nvtripy/frontend/wrappers.py): Enforces (and generates tests for) data type constraints.
:::

Operations are **lazily evaluated**.
Expand Down
2 changes: 1 addition & 1 deletion tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ from typing import Tuple

from nvtripy import export
from nvtripy.trace.ops.topn import TopN
from nvtripy.utils import wrappers
from nvtripy.frontend import wrappers
from nvtripy.frontend.ops import utils as op_utils

@export.public_api(document_under="operations/functions")
Expand Down
9 changes: 8 additions & 1 deletion tripy/nvtripy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,14 @@
module=sys.modules[__name__],
symbol="enable_dtype_checking",
)(True)
"""Whether to enable data type checking in API functions."""
"""[DEPRECATED - use enable_input_validation] Whether to enable data type checking in API functions."""

enable_input_validation: bool = export.public_api(
document_under="config.rst",
module=sys.modules[__name__],
symbol="enable_input_validation",
)(True)
"""Whether to enable input validation in API functions."""

extra_error_information: List[str] = export.public_api(
document_under="config.rst",
Expand Down
19 changes: 19 additions & 0 deletions tripy/nvtripy/frontend/constraints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from nvtripy.frontend.constraints.base import Constraints
from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher
from nvtripy.frontend.constraints.logic import And, Equal, Logic, NotEqual, NotOneOf, OneOf, Or
130 changes: 130 additions & 0 deletions tripy/nvtripy/frontend/constraints/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
The constraints system has two purposes:

1. Imposing input requirements.
2. Describing output guarantees.

Constraints are specified by composing one or more `Constraints` subclasses:

```py
constraint = And(
Equal(GetDataType(GetInput("input0")), GetInput("dtype")),
Equal(GetDataType(GetInput("input1")), GetInput("dtype")),
)
```

We also override several bitwise operators and properties to provide a convenient shorthand.
For example, the above can be written as:

```py
constraint = (GetInput("input0").dtype == GetInput("dtype")) & (GetInput("input1").dtype == GetInput("dtype"))
```

The constraints class also provides a pattern matcher.
For example, we may want to find all constraints that check the data type of an input (`None` is a wildcard).

```py
matches = constraint.find(Equal(GetDataType(GetInput), None))
```
"""

from abc import ABC
from typing import List


class Constraints(ABC):
"""
Base class for the entire constraints system.
"""

def get_children(self) -> List["Constraints"]:
children = []
for attr_value in vars(self).values():
if isinstance(attr_value, Constraints):
children.append(attr_value)
elif isinstance(attr_value, (list, tuple)):
children.extend(v for v in attr_value if isinstance(v, Constraints))
return children

def find(self, pattern: "Constraints") -> List["Constraints"]:
"""
Find all constraints in the tree that match the given pattern.

Performs a depth-first search through the constraint tree to find all
constraints that structurally match the given pattern, using the current
constraint as the root node.

Args:
pattern: The pattern to search for (e.g., Equal(GetDataType, GetDataType)).
Use None as a wildcard to match anything.

Returns:
A list of all matching constraints found in the tree.

Example:
pattern = Equal(GetDataType(TensorFetcher), None) # None matches any second argument
matches = constraint_tree.find(pattern)
"""

def matches_pattern(pattern: Constraints, constraint: Constraints) -> bool:
# None is a wildcard that matches anything
if pattern is None:
return True

if isinstance(pattern, type):
return isinstance(constraint, pattern)

if type(pattern) != type(constraint):
return False

# Need to index into sequences rather than comparing directly since there may be patterns in the sequence.
if isinstance(pattern, (list, tuple)) and isinstance(constraint, (list, tuple)):
if len(pattern) != len(constraint):
return False
return all(matches_pattern(p_val, c_val) for p_val, c_val in zip(pattern, constraint))

if not isinstance(pattern, Constraints):
return pattern == constraint

# Compare attributes
pattern_attrs = vars(pattern)
constraint_attrs = vars(constraint)

for key, pattern_value in pattern_attrs.items():
if key not in constraint_attrs:
return False

constraint_value = constraint_attrs[key]

if not matches_pattern(pattern_value, constraint_value):
return False

return True

matches = []

if matches_pattern(pattern, self):
matches.append(self)

# Recursively search children
for child in self.get_children():
matches.extend(child.find(pattern))

return matches
110 changes: 110 additions & 0 deletions tripy/nvtripy/frontend/constraints/fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import abstractmethod
from typing import Any, List, Optional, Sequence, Tuple

from nvtripy.common.datatype import dtype as tp_dtype
from nvtripy.common.exception import raise_error
from nvtripy.frontend.constraints.base import Constraints


class Fetcher(Constraints):
"""
Fetches a value based on the function parameters or return value.
"""

@abstractmethod
def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: ...

def __eq__(self, other: "Fetcher") -> "Equal":
from nvtripy.frontend.constraints.logic import Equal

return Equal(self, other)

def __ne__(self, other: "Fetcher") -> "NotEqual":
from nvtripy.frontend.constraints.logic import NotEqual

return NotEqual(self, other)


class ValueFetcher(Fetcher):
@property
def dtype(self) -> "GetDataType":
return GetDataType(self)


class GetInput(ValueFetcher):
def __init__(self, name: str):
self.name = name

def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any:
for name, value in args:
if name == self.name:
return value
assert False, f"Input '{self.name}' not found in arguments."

def __str__(self):
return self.name


class GetReturn(ValueFetcher):
def __init__(self, index: int):
self.index = index

def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any:
assert returns is not None, "No return values available."
return returns[self.index]

def __str__(self):
return f"return[{self.index}]"


class GetDataType(Fetcher):
def __init__(self, value_fetcher: ValueFetcher):
self.value_fetcher = value_fetcher

def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any:
from nvtripy.frontend.tensor import Tensor

def get_arg_dtype(arg: Any) -> tp_dtype:
if isinstance(arg, Sequence):
arg_dtypes = [get_arg_dtype(elem) for elem in arg]

if len(set(arg_dtypes)) != 1:
raise_error(
f"Could not determine data type of {self.value_fetcher}",
[
f"Mismatched data types in sequence argument.\n",
f"For parameter: '{self.value_fetcher}', all arguments must have the same data type, but got: "
f"{arg_dtypes}",
],
)
arg_dtype = arg_dtypes[0]
elif isinstance(arg, Tensor):
arg_dtype = arg.dtype
else:
raise_error(
f"Could not determine data type of {self.value_fetcher}",
[f"Expected a tensor or data type argument for {self.value_fetcher}, but got: {arg}"],
)
return arg_dtype

tensor = self.value_fetcher(args, returns)
return get_arg_dtype(tensor)

def __str__(self):
return f"{self.value_fetcher}.dtype"
Loading