Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tripy/nvtripy/frontend/constraints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from nvtripy.frontend.constraints.base import Constraints
from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher
from nvtripy.frontend.constraints.logic import And, Equal, Logic, Not, NotEqual, OneOf
130 changes: 130 additions & 0 deletions tripy/nvtripy/frontend/constraints/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
The constraints system has two purposes:

1. Imposing input requirements.
2. Describing output guarantees.

Constraints are specified by composing one or more `Constraints` subclasses:

```py
constraint = And(
Equal(GetDataType(GetInput("input0")), GetInput("dtype")),
Equal(GetDataType(GetInput("input1")), GetInput("dtype")),
)
```

We also override several bitwise operators and properties to provide a convenient shorthand.
For example, the above can be written as:

```py
constraint = (GetInput("input0").dtype == GetInput("dtype")) & (GetInput("input1").dtype == GetInput("dtype"))
```

The constraints class also provides a pattern matcher.
For example, we may want to find all constraints that check the data type of an input (`None` is a wildcard).

```py
matches = constraint.find(Equal(GetDataType(GetInput), None))
```
"""

from abc import ABC
from typing import List


class Constraints(ABC):
"""
Base class for the entire constraints system.
"""

def get_children(self) -> List["Constraints"]:
children = []
for attr_value in vars(self).values():
if isinstance(attr_value, Constraints):
children.append(attr_value)
elif isinstance(attr_value, (list, tuple)):
children.extend(v for v in attr_value if isinstance(v, Constraints))
return children

def find(self, pattern: "Constraints") -> List["Constraints"]:
"""
Find all constraints in the tree that match the given pattern.

Performs a depth-first search through the constraint tree to find all
constraints that structurally match the given pattern, using the current
constraint as the root node.

Args:
pattern: The pattern to search for (e.g., Equal(GetDataType, GetDataType)).
Use None as a wildcard to match anything.

Returns:
A list of all matching constraints found in the tree.

Example:
pattern = Equal(GetDataType(TensorFetcher), None) # None matches any second argument
matches = constraint_tree.find(pattern)
"""

def matches_pattern(pattern: Constraints, constraint: Constraints) -> bool:
# None is a wildcard that matches anything
if pattern is None:
return True

if isinstance(pattern, type):
return isinstance(constraint, pattern)

if type(pattern) != type(constraint):
return False

# Need to index into sequences rather than comparing directly since there may be patterns in the sequence.
if isinstance(pattern, (list, tuple)) and isinstance(constraint, (list, tuple)):
if len(pattern) != len(constraint):
return False
return all(matches_pattern(p_val, c_val) for p_val, c_val in zip(pattern, constraint))

if not isinstance(pattern, Constraints):
return pattern == constraint

# Compare attributes
pattern_attrs = vars(pattern)
constraint_attrs = vars(constraint)

for key, pattern_value in pattern_attrs.items():
if key not in constraint_attrs:
return False

constraint_value = constraint_attrs[key]

if not matches_pattern(pattern_value, constraint_value):
return False

return True

matches = []

if matches_pattern(pattern, self):
matches.append(self)

# Recursively search children
for child in self.get_children():
matches.extend(child.find(pattern))

return matches
111 changes: 111 additions & 0 deletions tripy/nvtripy/frontend/constraints/fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import abstractmethod
from typing import Any, List, Sequence, Tuple

from nvtripy.common.datatype import dtype as tp_dtype
from nvtripy.common.exception import raise_error
from nvtripy.frontend.constraints.base import Constraints


class Fetcher(Constraints):
"""
Fetches a value based on the function parameters or return value.
"""

@abstractmethod
def __call__(self, args: List[Tuple[str, Any]]) -> Any: ...

def __eq__(self, other: "Fetcher") -> "Equal":
from nvtripy.frontend.constraints.logic import Equal

return Equal(self, other)

def __ne__(self, other: "Fetcher") -> "NotEqual":
from nvtripy.frontend.constraints.logic import NotEqual

return NotEqual(self, other)


class ValueFetcher(Fetcher):
@property
def dtype(self) -> "GetDataType":
return GetDataType(self)


class GetInput(ValueFetcher):
def __init__(self, name: str):
self.name = name

def __call__(self, args: List[Tuple[str, Any]]) -> Any:
for name, value in args:
if name == self.name:
return value
assert False, f"Input '{self.name}' not found in arguments."

def __str__(self):
return self.name


class GetReturn(ValueFetcher):
def __init__(self, index: int):
self.index = index

def __call__(self, args: List[Tuple[str, Any]]) -> Any:
raise NotImplementedError(
"GetReturn is only used to describe output guarantees and must not be called for input validation purposes."
)

def __str__(self):
return f"return[{self.index}]"


class GetDataType(Fetcher):
def __init__(self, value_fetcher: ValueFetcher):
self.value_fetcher = value_fetcher

def __call__(self, args: List[Tuple[str, Any]]) -> Any:
from nvtripy.frontend.tensor import Tensor

def get_arg_dtype(arg: Any) -> tp_dtype:
if isinstance(arg, Sequence):
arg_dtypes = [get_arg_dtype(elem) for elem in arg]

if len(set(arg_dtypes)) != 1:
raise_error(
f"Could not determine data type of {self.value_fetcher}",
[
f"Mismatched data types in sequence argument.\n",
f"For parameter: '{self.value_fetcher}', all arguments must have the same data type, but got: "
f"{arg_dtypes}",
],
)
arg_dtype = arg_dtypes[0]
elif isinstance(arg, Tensor):
arg_dtype = arg.dtype
else:
raise_error(
f"Could not determine data type of {self.value_fetcher}",
[f"Expected a tensor or data type argument for {self.value_fetcher}, but got: {arg}"],
)
return arg_dtype

tensor = self.value_fetcher(args)
return get_arg_dtype(tensor)

def __str__(self):
return f"{self.value_fetcher}.dtype"
125 changes: 125 additions & 0 deletions tripy/nvtripy/frontend/constraints/logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import abstractmethod
from typing import Any, List, Sequence, Tuple

from nvtripy.frontend.constraints.base import Constraints
from nvtripy.frontend.constraints.fetcher import Fetcher
from nvtripy.utils.result import Result


class Logic(Constraints):
"""
Represents logical operations on constraints.
"""

@abstractmethod
def __call__(self, args: List[Tuple[str, Any]]) -> Result: ...

def __and__(self, other: "Logic") -> "Logic":
if isinstance(self, And):
return And(*self.constraints, other)
elif isinstance(other, And):
return And(self, *other.constraints)
return And(self, other)

def __invert__(self) -> "Logic":
if isinstance(self, Equal):
return NotEqual(self.fetcher1, self.fetcher2)
return Not(self)


class OneOf(Logic):
def __init__(self, fetcher: Fetcher, options: Sequence[Any]):
self.fetcher = fetcher
self.options = options

def __call__(self, args: List[Tuple[str, Any]]) -> Result:
value = self.fetcher(args)
if value in self.options:
return Result.ok()

return Result.err([f"Expected {self.fetcher} to be one of {self.options}, but got {value}."])

def __str__(self):
return f"{self.fetcher} is one of {self.options}"


class Equal(Logic):
def __init__(self, fetcher1: Fetcher, fetcher2: Fetcher):
self.fetcher1 = fetcher1
self.fetcher2 = fetcher2

def __call__(self, args: List[Tuple[str, Any]]) -> Result:
value1 = self.fetcher1(args)
value2 = self.fetcher2(args)
if value1 == value2:
return Result.ok()

return Result.err([f"Expected {self.fetcher1} to be equal to {self.fetcher2}, but got {value1} and {value2}."])

def __str__(self):
return f"{self.fetcher1} == {self.fetcher2}"


class NotEqual(Logic):
def __init__(self, fetcher1: Fetcher, fetcher2: Fetcher):
self.fetcher1 = fetcher1
self.fetcher2 = fetcher2

def __call__(self, args: List[Tuple[str, Any]]) -> Result:
value1 = self.fetcher1(args)
value2 = self.fetcher2(args)
if value1 != value2:
return Result.ok()

return Result.err([f"Expected {self.fetcher1} to be not equal to {self.fetcher2}, but both were {value1}."])

def __str__(self):
return f"{self.fetcher1} != {self.fetcher2}"


class And(Logic):
def __init__(self, *constraints: Logic):
self.constraints = constraints

def __call__(self, args: List[Tuple[str, Any]]) -> Result:
errors = []
for constraint in self.constraints:
result = constraint(args)
if not result:
errors.extend(result.error_details)
if errors:
return Result.err(errors)
return Result.ok()

def __str__(self):
return " and ".join(str(constraint) for constraint in self.constraints)


class Not(Logic):
def __init__(self, constraint: Logic):
self.constraint = constraint

def __call__(self, args: List[Tuple[str, Any]]) -> Result:
result = self.constraint(args)
if result:
return Result.err([f"Expected NOT {self.constraint}, but it was satisfied."])
return Result.ok()

def __str__(self):
return f"NOT ({self.constraint})"
Loading