diff --git a/tripy/docs/README.md b/tripy/docs/README.md index a107706ed..612f0f326 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -45,7 +45,7 @@ which specifies doc metadata for each API (e.g. location). - Docstring must include *at least* **one [code example](#code-examples)**. - If the function accepts `tp.Tensor`s, must indicate **data type constraints** - with the [`wrappers.interface`](../nvtripy/utils/wrappers.py) decorator. + with the [`wrappers.interface`](../nvtripy/frontend/wrappers.py) decorator. **Example:** diff --git a/tripy/docs/post0_developer_guides/00-architecture.md b/tripy/docs/post0_developer_guides/00-architecture.md index 055dd0b3b..880cf5441 100644 --- a/tripy/docs/post0_developer_guides/00-architecture.md +++ b/tripy/docs/post0_developer_guides/00-architecture.md @@ -76,7 +76,7 @@ and various operations, e.g. {class}`nvtripy.resize`. :::{admonition} Info Most operations are decorated with: 1. [`@export.public_api`](source:/nvtripy/export.py): Enables documentation, type checking, and overloading. -2. [`@wrappers.interface`](source:/nvtripy/utils/wrappers.py): Enforces (and generates tests for) data type constraints. +2. [`@wrappers.interface`](source:/nvtripy/frontend/wrappers.py): Enforces (and generates tests for) data type constraints. ::: Operations are **lazily evaluated**. diff --git a/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md b/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md index bd354d582..8a16af3a6 100644 --- a/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md +++ b/tripy/docs/post0_developer_guides/01-how-to-add-new-ops.md @@ -129,7 +129,7 @@ from typing import Tuple from nvtripy import export from nvtripy.trace.ops.topn import TopN -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.frontend.ops import utils as op_utils @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/config.py b/tripy/nvtripy/config.py index ba9d8d10c..f274aa124 100644 --- a/tripy/nvtripy/config.py +++ b/tripy/nvtripy/config.py @@ -50,7 +50,14 @@ module=sys.modules[__name__], symbol="enable_dtype_checking", )(True) -"""Whether to enable data type checking in API functions.""" +"""[DEPRECATED - use enable_input_validation] Whether to enable data type checking in API functions.""" + +enable_input_validation: bool = export.public_api( + document_under="config.rst", + module=sys.modules[__name__], + symbol="enable_input_validation", +)(True) +"""Whether to enable input validation in API functions.""" extra_error_information: List[str] = export.public_api( document_under="config.rst", diff --git a/tripy/nvtripy/frontend/constraints/__init__.py b/tripy/nvtripy/frontend/constraints/__init__.py new file mode 100644 index 000000000..2889b36bb --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/__init__.py @@ -0,0 +1,19 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from nvtripy.frontend.constraints.base import Constraints +from nvtripy.frontend.constraints.fetcher import Fetcher, GetDataType, GetInput, GetReturn, ValueFetcher +from nvtripy.frontend.constraints.logic import And, Equal, Logic, NotEqual, NotOneOf, OneOf, Or diff --git a/tripy/nvtripy/frontend/constraints/base.py b/tripy/nvtripy/frontend/constraints/base.py new file mode 100644 index 000000000..ffecad43b --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/base.py @@ -0,0 +1,130 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +The constraints system has two purposes: + +1. Imposing input requirements. +2. Describing output guarantees. + +Constraints are specified by composing one or more `Constraints` subclasses: + +```py +constraint = And( + Equal(GetDataType(GetInput("input0")), GetInput("dtype")), + Equal(GetDataType(GetInput("input1")), GetInput("dtype")), + ) +``` + +We also override several bitwise operators and properties to provide a convenient shorthand. +For example, the above can be written as: + +```py +constraint = (GetInput("input0").dtype == GetInput("dtype")) & (GetInput("input1").dtype == GetInput("dtype")) +``` + +The constraints class also provides a pattern matcher. +For example, we may want to find all constraints that check the data type of an input (`None` is a wildcard). + +```py +matches = constraint.find(Equal(GetDataType(GetInput), None)) +``` +""" + +from abc import ABC +from typing import List + + +class Constraints(ABC): + """ + Base class for the entire constraints system. + """ + + def get_children(self) -> List["Constraints"]: + children = [] + for attr_value in vars(self).values(): + if isinstance(attr_value, Constraints): + children.append(attr_value) + elif isinstance(attr_value, (list, tuple)): + children.extend(v for v in attr_value if isinstance(v, Constraints)) + return children + + def find(self, pattern: "Constraints") -> List["Constraints"]: + """ + Find all constraints in the tree that match the given pattern. + + Performs a depth-first search through the constraint tree to find all + constraints that structurally match the given pattern, using the current + constraint as the root node. + + Args: + pattern: The pattern to search for (e.g., Equal(GetDataType, GetDataType)). + Use None as a wildcard to match anything. + + Returns: + A list of all matching constraints found in the tree. + + Example: + pattern = Equal(GetDataType(TensorFetcher), None) # None matches any second argument + matches = constraint_tree.find(pattern) + """ + + def matches_pattern(pattern: Constraints, constraint: Constraints) -> bool: + # None is a wildcard that matches anything + if pattern is None: + return True + + if isinstance(pattern, type): + return isinstance(constraint, pattern) + + if type(pattern) != type(constraint): + return False + + # Need to index into sequences rather than comparing directly since there may be patterns in the sequence. + if isinstance(pattern, (list, tuple)) and isinstance(constraint, (list, tuple)): + if len(pattern) != len(constraint): + return False + return all(matches_pattern(p_val, c_val) for p_val, c_val in zip(pattern, constraint)) + + if not isinstance(pattern, Constraints): + return pattern == constraint + + # Compare attributes + pattern_attrs = vars(pattern) + constraint_attrs = vars(constraint) + + for key, pattern_value in pattern_attrs.items(): + if key not in constraint_attrs: + return False + + constraint_value = constraint_attrs[key] + + if not matches_pattern(pattern_value, constraint_value): + return False + + return True + + matches = [] + + if matches_pattern(pattern, self): + matches.append(self) + + # Recursively search children + for child in self.get_children(): + matches.extend(child.find(pattern)) + + return matches diff --git a/tripy/nvtripy/frontend/constraints/fetcher.py b/tripy/nvtripy/frontend/constraints/fetcher.py new file mode 100644 index 000000000..4987426c8 --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/fetcher.py @@ -0,0 +1,110 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import abstractmethod +from typing import Any, List, Optional, Sequence, Tuple + +from nvtripy.common.datatype import dtype as tp_dtype +from nvtripy.common.exception import raise_error +from nvtripy.frontend.constraints.base import Constraints + + +class Fetcher(Constraints): + """ + Fetches a value based on the function parameters or return value. + """ + + @abstractmethod + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: ... + + def __eq__(self, other: "Fetcher") -> "Equal": + from nvtripy.frontend.constraints.logic import Equal + + return Equal(self, other) + + def __ne__(self, other: "Fetcher") -> "NotEqual": + from nvtripy.frontend.constraints.logic import NotEqual + + return NotEqual(self, other) + + +class ValueFetcher(Fetcher): + @property + def dtype(self) -> "GetDataType": + return GetDataType(self) + + +class GetInput(ValueFetcher): + def __init__(self, name: str): + self.name = name + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: + for name, value in args: + if name == self.name: + return value + assert False, f"Input '{self.name}' not found in arguments." + + def __str__(self): + return self.name + + +class GetReturn(ValueFetcher): + def __init__(self, index: int): + self.index = index + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: + assert returns is not None, "No return values available." + return returns[self.index] + + def __str__(self): + return f"return[{self.index}]" + + +class GetDataType(Fetcher): + def __init__(self, value_fetcher: ValueFetcher): + self.value_fetcher = value_fetcher + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Any: + from nvtripy.frontend.tensor import Tensor + + def get_arg_dtype(arg: Any) -> tp_dtype: + if isinstance(arg, Sequence): + arg_dtypes = [get_arg_dtype(elem) for elem in arg] + + if len(set(arg_dtypes)) != 1: + raise_error( + f"Could not determine data type of {self.value_fetcher}", + [ + f"Mismatched data types in sequence argument.\n", + f"For parameter: '{self.value_fetcher}', all arguments must have the same data type, but got: " + f"{arg_dtypes}", + ], + ) + arg_dtype = arg_dtypes[0] + elif isinstance(arg, Tensor): + arg_dtype = arg.dtype + else: + raise_error( + f"Could not determine data type of {self.value_fetcher}", + [f"Expected a tensor or data type argument for {self.value_fetcher}, but got: {arg}"], + ) + return arg_dtype + + tensor = self.value_fetcher(args, returns) + return get_arg_dtype(tensor) + + def __str__(self): + return f"{self.value_fetcher}.dtype" diff --git a/tripy/nvtripy/frontend/constraints/logic.py b/tripy/nvtripy/frontend/constraints/logic.py new file mode 100644 index 000000000..de5e6982d --- /dev/null +++ b/tripy/nvtripy/frontend/constraints/logic.py @@ -0,0 +1,187 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import abstractmethod +from typing import Any, List, Optional, Sequence, Tuple + +from nvtripy.frontend.constraints.base import Constraints +from nvtripy.frontend.constraints.fetcher import Fetcher +from nvtripy.utils.result import Result + + +class Logic(Constraints): + """ + Represents logical operations on constraints. + """ + + # When the constraint is not met, the error details should complete the sentence: "Expected ..." + @abstractmethod + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: ... + + @abstractmethod + def inverse(self) -> "Logic": + """ + Returns the logical inverse of this constraint. + """ + ... + + def __and__(self, other: "Logic") -> "Logic": + if isinstance(self, And): + return And(*self.constraints, other) + elif isinstance(other, And): + return And(self, *other.constraints) + return And(self, other) + + def __or__(self, other: "Logic") -> "Logic": + if isinstance(self, Or): + return Or(*self.constraints, other) + elif isinstance(other, Or): + return Or(self, *other.constraints) + return Or(self, other) + + def __invert__(self) -> "Logic": + return self.inverse() + + +class OneOf(Logic): + def __init__(self, fetcher: Fetcher, options: Sequence[Any]): + self.fetcher = fetcher + # Need to convert generator expressions so we can use them more than once + self.options = list(options) + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value = self.fetcher(args, returns) + if value in self.options: + return Result.ok() + + return Result.err([f"'{self.fetcher}' to be one of {self.options} (but it was '{value}')"]) + + def __str__(self): + return f"{self.fetcher} is one of {self.options}" + + def inverse(self) -> "Logic": + return NotOneOf(self.fetcher, self.options) + + +class NotOneOf(Logic): + def __init__(self, fetcher: Fetcher, options: Sequence[Any]): + self.fetcher = fetcher + self.options = list(options) + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value = self.fetcher(args, returns) + if value not in self.options: + return Result.ok() + + return Result.err([f"'{self.fetcher}' to not be one of {self.options} (but it was '{value}')"]) + + def __str__(self): + return f"{self.fetcher} is not one of {self.options}" + + def inverse(self) -> "Logic": + return OneOf(self.fetcher, self.options) + + +def get_val_or_call_fetcher( + fetcher_or_value: Any, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None +) -> Any: + if isinstance(fetcher_or_value, Fetcher): + return fetcher_or_value(args, returns) + return fetcher_or_value + + +class Equal(Logic): + def __init__(self, fetcher: Fetcher, fetcher_or_value: Any): + self.fetcher = fetcher + self.fetcher_or_value = fetcher_or_value + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value1 = self.fetcher(args, returns) + value2 = get_val_or_call_fetcher(self.fetcher_or_value, args, returns) + if value1 == value2: + return Result.ok() + + # TODO (pranavm): If fetcher_or_value is a Fetcher, include its value in the error message. + return Result.err([f"'{self.fetcher}' to be equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) + + def __str__(self): + return f"{self.fetcher} == {self.fetcher_or_value}" + + def inverse(self) -> "Logic": + return NotEqual(self.fetcher, self.fetcher_or_value) + + +class NotEqual(Logic): + def __init__(self, fetcher: Fetcher, fetcher_or_value: Fetcher): + self.fetcher = fetcher + self.fetcher_or_value = fetcher_or_value + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + value1 = self.fetcher(args, returns) + value2 = get_val_or_call_fetcher(self.fetcher_or_value, args, returns) + if value1 != value2: + return Result.ok() + + return Result.err([f"'{self.fetcher}' to be not equal to '{self.fetcher_or_value}' (but it was '{value1}')"]) + + def __str__(self): + return f"{self.fetcher} != {self.fetcher_or_value}" + + def inverse(self) -> "Logic": + return Equal(self.fetcher, self.fetcher_or_value) + + +class And(Logic): + def __init__(self, *constraints: Logic): + self.constraints = constraints + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + errors = [] + for constraint in self.constraints: + result = constraint(args, returns) + if not result: + errors.extend(([" and "] if errors else []) + result.error_details) + if errors: + return Result.err(errors) + return Result.ok() + + def __str__(self): + return "(" + " and ".join(str(constraint) for constraint in self.constraints) + ")" + + def inverse(self) -> "Logic": + # De Morgan's law: not (A and B) = (not A) or (not B) + return Or(*(constraint.inverse() for constraint in self.constraints)) + + +class Or(Logic): + def __init__(self, *constraints: Logic): + self.constraints = constraints + + def __call__(self, args: List[Tuple[str, Any]], returns: Optional[Tuple[Any]] = None) -> Result: + all_errors = [] + for constraint in self.constraints: + result = constraint(args, returns) + if result: + return Result.ok() + all_errors.extend(([" or "] if all_errors else []) + result.error_details) + return Result.err(all_errors) + + def __str__(self): + return "(" + " or ".join(str(constraint) for constraint in self.constraints) + ")" + + def inverse(self) -> "Logic": + # De Morgan's law: not (A or B) = (not A) and (not B) + return And(*(constraint.inverse() for constraint in self.constraints)) diff --git a/tripy/nvtripy/frontend/module/batchnorm.py b/tripy/nvtripy/frontend/module/batchnorm.py index f7c1380a2..d08a2ff78 100644 --- a/tripy/nvtripy/frontend/module/batchnorm.py +++ b/tripy/nvtripy/frontend/module/batchnorm.py @@ -22,11 +22,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_features"]) +@constant_fields(["num_features"]) class BatchNorm(Module): r""" Applies batch normalization over an N-dimensional input tensor using precomputed statistics: @@ -105,8 +106,8 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": Returns: A tensor of the same shape as the input. """ - from nvtripy.frontend.ops.unary.rsqrt import rsqrt from nvtripy.frontend.ops.reshape import reshape + from nvtripy.frontend.ops.unary.rsqrt import rsqrt x_shape = (1, self.num_features, *([1] * (x.rank - 2))) diff --git a/tripy/nvtripy/frontend/module/conv/base.py b/tripy/nvtripy/frontend/module/conv/base.py index 812e08fef..fcebf79f0 100644 --- a/tripy/nvtripy/frontend/module/conv/base.py +++ b/tripy/nvtripy/frontend/module/conv/base.py @@ -23,10 +23,11 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @dataclass -@utils.wrappers.constant_fields(["dtype", "padding", "stride", "groups", "dilation"]) +@constant_fields(["dtype", "padding", "stride", "groups", "dilation"]) class ConvBase(Module): r"""Base class for sharing common functionality between Conv and ConvTranspose.""" diff --git a/tripy/nvtripy/frontend/module/conv/conv.py b/tripy/nvtripy/frontend/module/conv/conv.py index 2c322fdf8..978cc8394 100644 --- a/tripy/nvtripy/frontend/module/conv/conv.py +++ b/tripy/nvtripy/frontend/module/conv/conv.py @@ -26,7 +26,7 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor from nvtripy.trace.ops.convolution import Convolution -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers # This function is added so that we can do dtype checking. diff --git a/tripy/nvtripy/frontend/module/conv/conv_transpose.py b/tripy/nvtripy/frontend/module/conv/conv_transpose.py index 3c5b5666e..4becc1103 100644 --- a/tripy/nvtripy/frontend/module/conv/conv_transpose.py +++ b/tripy/nvtripy/frontend/module/conv/conv_transpose.py @@ -26,7 +26,7 @@ from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor from nvtripy.trace.ops.deconvolution import Deconvolution -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers # This function is added so that we can do dtype checking. diff --git a/tripy/nvtripy/frontend/module/embedding.py b/tripy/nvtripy/frontend/module/embedding.py index b8e5b72f0..d72e4134c 100644 --- a/tripy/nvtripy/frontend/module/embedding.py +++ b/tripy/nvtripy/frontend/module/embedding.py @@ -22,11 +22,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype"]) +@constant_fields(["dtype"]) class Embedding(Module): """ A lookup table for embedding vectors of a fixed size. diff --git a/tripy/nvtripy/frontend/module/groupnorm.py b/tripy/nvtripy/frontend/module/groupnorm.py index 5f98807ca..390f07fa8 100644 --- a/tripy/nvtripy/frontend/module/groupnorm.py +++ b/tripy/nvtripy/frontend/module/groupnorm.py @@ -25,11 +25,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_groups", "num_channels", "dtype"]) +@constant_fields(["num_groups", "num_channels", "dtype"]) class GroupNorm(Module): r""" Applies group normalization over the input tensor: diff --git a/tripy/nvtripy/frontend/module/instancenorm.py b/tripy/nvtripy/frontend/module/instancenorm.py index 54b29a250..5a9a41ee4 100644 --- a/tripy/nvtripy/frontend/module/instancenorm.py +++ b/tripy/nvtripy/frontend/module/instancenorm.py @@ -17,15 +17,15 @@ from dataclasses import dataclass -from nvtripy import constants, export, utils +from nvtripy import constants, export from nvtripy.common import datatype from nvtripy.common.exception import raise_error +from nvtripy.frontend import wrappers from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter -from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields from nvtripy.trace.ops.instancenorm import InstanceNorm as InstanceNormOp @@ -81,7 +81,7 @@ def instancenorm( @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["num_channels", "dtype", "eps"]) +@constant_fields(["num_channels", "dtype", "eps"]) class InstanceNorm(Module): r""" Applies Instance Normalization over a mini-batch of inputs: diff --git a/tripy/nvtripy/frontend/module/layernorm.py b/tripy/nvtripy/frontend/module/layernorm.py index be2627ed1..1128d41cf 100644 --- a/tripy/nvtripy/frontend/module/layernorm.py +++ b/tripy/nvtripy/frontend/module/layernorm.py @@ -21,12 +21,12 @@ from nvtripy import export, utils from nvtripy.common import datatype from nvtripy.common.exception import raise_error +from nvtripy.frontend import wrappers from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter -from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields from nvtripy.trace.ops.layernorm import LayerNorm as LayerNormOp @@ -70,7 +70,7 @@ def layernorm( @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype", "normalized_shape"]) +@constant_fields(["dtype", "normalized_shape"]) class LayerNorm(Module): r""" Applies layer normalization over the input tensor: diff --git a/tripy/nvtripy/frontend/module/linear.py b/tripy/nvtripy/frontend/module/linear.py index 6c3c06eba..919c9e1f1 100644 --- a/tripy/nvtripy/frontend/module/linear.py +++ b/tripy/nvtripy/frontend/module/linear.py @@ -23,11 +23,12 @@ from nvtripy.frontend.module.module import Module from nvtripy.frontend.module.parameter import DefaultParameter, OptionalParameter from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields @export.public_api(document_under="operations/modules") @dataclass -@utils.wrappers.constant_fields(["dtype", "quant_dtype"]) +@constant_fields(["dtype", "quant_dtype"]) class Linear(Module): r""" Applies a linear transformation to the input: @@ -117,11 +118,11 @@ def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": A tensor of shape :math:`[*, \text{out_features}]`. """ from nvtripy.common.exception import raise_error - from nvtripy.frontend.tensor import Tensor - from nvtripy.frontend.ops.transpose import transpose - from nvtripy.frontend.ops.unsqueeze import unsqueeze from nvtripy.frontend.ops.dequantize import dequantize from nvtripy.frontend.ops.quantize import quantize + from nvtripy.frontend.ops.transpose import transpose + from nvtripy.frontend.ops.unsqueeze import unsqueeze + from nvtripy.frontend.tensor import Tensor if self.quant_dtype is not None: if isinstance(self.input_scale, Tensor): diff --git a/tripy/nvtripy/frontend/ops/allclose.py b/tripy/nvtripy/frontend/ops/allclose.py index 28ac163ed..0b4d756e2 100644 --- a/tripy/nvtripy/frontend/ops/allclose.py +++ b/tripy/nvtripy/frontend/ops/allclose.py @@ -16,7 +16,7 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/arange.py b/tripy/nvtripy/frontend/ops/arange.py index e20be7917..6aa07ca8f 100644 --- a/tripy/nvtripy/frontend/ops/arange.py +++ b/tripy/nvtripy/frontend/ops/arange.py @@ -22,7 +22,7 @@ from nvtripy.frontend.ops.cast import cast from nvtripy.frontend.ops.reshape import reshape from nvtripy.trace.ops.linspace import Linspace -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/binary/add.py b/tripy/nvtripy/frontend/ops/binary/add.py index f673c452c..e936129cd 100644 --- a/tripy/nvtripy/frontend/ops/binary/add.py +++ b/tripy/nvtripy/frontend/ops/binary/add.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Add from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__add__") diff --git a/tripy/nvtripy/frontend/ops/binary/div.py b/tripy/nvtripy/frontend/ops/binary/div.py index 125b42b46..482a06a59 100644 --- a/tripy/nvtripy/frontend/ops/binary/div.py +++ b/tripy/nvtripy/frontend/ops/binary/div.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Div from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__truediv__") diff --git a/tripy/nvtripy/frontend/ops/binary/equal.py b/tripy/nvtripy/frontend/ops/binary/equal.py index 55c899b36..699882a74 100644 --- a/tripy/nvtripy/frontend/ops/binary/equal.py +++ b/tripy/nvtripy/frontend/ops/binary/equal.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Equal from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__eq__") diff --git a/tripy/nvtripy/frontend/ops/binary/floor_div.py b/tripy/nvtripy/frontend/ops/binary/floor_div.py index 4ddbe0a5b..e221e0fb8 100644 --- a/tripy/nvtripy/frontend/ops/binary/floor_div.py +++ b/tripy/nvtripy/frontend/ops/binary/floor_div.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import FloorDiv from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__floordiv__") diff --git a/tripy/nvtripy/frontend/ops/binary/greater.py b/tripy/nvtripy/frontend/ops/binary/greater.py index c82ded594..cf5504d7a 100644 --- a/tripy/nvtripy/frontend/ops/binary/greater.py +++ b/tripy/nvtripy/frontend/ops/binary/greater.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Greater from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__gt__") diff --git a/tripy/nvtripy/frontend/ops/binary/greater_equal.py b/tripy/nvtripy/frontend/ops/binary/greater_equal.py index 6be66c9db..c75c3cb12 100644 --- a/tripy/nvtripy/frontend/ops/binary/greater_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/greater_equal.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__ge__") diff --git a/tripy/nvtripy/frontend/ops/binary/less.py b/tripy/nvtripy/frontend/ops/binary/less.py index 6495319fa..6b8431f39 100644 --- a/tripy/nvtripy/frontend/ops/binary/less.py +++ b/tripy/nvtripy/frontend/ops/binary/less.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Less from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__lt__") diff --git a/tripy/nvtripy/frontend/ops/binary/less_equal.py b/tripy/nvtripy/frontend/ops/binary/less_equal.py index 14bf35c9b..9694fa4e7 100644 --- a/tripy/nvtripy/frontend/ops/binary/less_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/less_equal.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__le__") diff --git a/tripy/nvtripy/frontend/ops/binary/logical_or.py b/tripy/nvtripy/frontend/ops/binary/logical_or.py index 2302415ce..268f42b73 100644 --- a/tripy/nvtripy/frontend/ops/binary/logical_or.py +++ b/tripy/nvtripy/frontend/ops/binary/logical_or.py @@ -15,7 +15,7 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import LogicalOr -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__or__") diff --git a/tripy/nvtripy/frontend/ops/binary/maximum.py b/tripy/nvtripy/frontend/ops/binary/maximum.py index 05023378d..c25d8152a 100644 --- a/tripy/nvtripy/frontend/ops/binary/maximum.py +++ b/tripy/nvtripy/frontend/ops/binary/maximum.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Max from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/binary/minimum.py b/tripy/nvtripy/frontend/ops/binary/minimum.py index 0a1954b1b..19a7e5a74 100644 --- a/tripy/nvtripy/frontend/ops/binary/minimum.py +++ b/tripy/nvtripy/frontend/ops/binary/minimum.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Min from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/binary/mod.py b/tripy/nvtripy/frontend/ops/binary/mod.py index f38203f43..f22ee1295 100644 --- a/tripy/nvtripy/frontend/ops/binary/mod.py +++ b/tripy/nvtripy/frontend/ops/binary/mod.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def mod_impl(lhs, rhs): diff --git a/tripy/nvtripy/frontend/ops/binary/mul.py b/tripy/nvtripy/frontend/ops/binary/mul.py index e6bcd9940..0fc2c9811 100644 --- a/tripy/nvtripy/frontend/ops/binary/mul.py +++ b/tripy/nvtripy/frontend/ops/binary/mul.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Mul from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__mul__") diff --git a/tripy/nvtripy/frontend/ops/binary/not_equal.py b/tripy/nvtripy/frontend/ops/binary/not_equal.py index 6e175dc84..d4a8b4fbb 100644 --- a/tripy/nvtripy/frontend/ops/binary/not_equal.py +++ b/tripy/nvtripy/frontend/ops/binary/not_equal.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__ne__") diff --git a/tripy/nvtripy/frontend/ops/binary/pow.py b/tripy/nvtripy/frontend/ops/binary/pow.py index 285e73141..17b55e1d3 100644 --- a/tripy/nvtripy/frontend/ops/binary/pow.py +++ b/tripy/nvtripy/frontend/ops/binary/pow.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Pow from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__pow__") diff --git a/tripy/nvtripy/frontend/ops/binary/sub.py b/tripy/nvtripy/frontend/ops/binary/sub.py index 94b51ec9e..7b4bc4d6e 100644 --- a/tripy/nvtripy/frontend/ops/binary/sub.py +++ b/tripy/nvtripy/frontend/ops/binary/sub.py @@ -16,7 +16,7 @@ from nvtripy.frontend.ops.binary.create import create_binary_op from nvtripy.trace.ops.binary import Sub from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__sub__") diff --git a/tripy/nvtripy/frontend/ops/cast.py b/tripy/nvtripy/frontend/ops/cast.py index 9c197af57..ecc4e44a1 100644 --- a/tripy/nvtripy/frontend/ops/cast.py +++ b/tripy/nvtripy/frontend/ops/cast.py @@ -17,19 +17,26 @@ from nvtripy import export -from nvtripy.common.datatype import bool as tp_bool -from nvtripy.common.datatype import float32, int8 +from nvtripy.common import datatype as dt +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.frontend.ops.dequantize import dequantize from nvtripy.frontend.ops.quantize import quantize from nvtripy.trace.ops.cast import Cast -from nvtripy.utils import wrappers @register_tensor_method("cast") @export.public_api(document_under="operations/functions") @wrappers.interface( + input_requirements=( + ((GetInput("input").dtype != dt.float8) | ~OneOf(GetInput("dtype"), [dt.int4, dt.int8])) + & ((GetInput("input").dtype != dt.int8) | (GetInput("dtype") != dt.float8)) + & ((GetInput("input").dtype != dt.int4) | ~OneOf(GetInput("dtype"), [dt.float8, dt.int8, dt.int64])) + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), + # TODO (pranavm): Remove old dtype constraints system: dtype_constraints={"input": "T1", "dtype": "T2", wrappers.RETURN_VALUE: "T2"}, dtype_variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], @@ -79,14 +86,14 @@ def cast(input: "nvtripy.Tensor", dtype: "nvtripy.dtype") -> "nvtripy.Tensor": # If given a quantized input, dequantize before converting. If bool is the target dtype, # we do still need to quantize int8s because it compiles into an MLIR-TRT *comparison* op - if op_utils.is_quantized_dtype(input.dtype) and (input.dtype != int8 or dtype == tp_bool): - dequant_dtype = float32 + if op_utils.is_quantized_dtype(input.dtype) and (input.dtype != dt.int8 or dtype == dt.bool): + dequant_dtype = dt.float32 input = dequantize(input, 1.0, dequant_dtype) if dtype == dequant_dtype: return input - if op_utils.is_quantized_dtype(dtype) and dtype != int8: - if input.dtype != float32: - input = op_utils.create_op(Cast, [input], float32) + if op_utils.is_quantized_dtype(dtype) and dtype != dt.int8: + if input.dtype != dt.float32: + input = op_utils.create_op(Cast, [input], dt.float32) return quantize(input, 1.0, dtype) return op_utils.create_op(Cast, [input], dtype) diff --git a/tripy/nvtripy/frontend/ops/concatenate.py b/tripy/nvtripy/frontend/ops/concatenate.py index de9d95de7..4f1826db6 100644 --- a/tripy/nvtripy/frontend/ops/concatenate.py +++ b/tripy/nvtripy/frontend/ops/concatenate.py @@ -21,7 +21,11 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.concatenate import Concatenate -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +# constraints = OneOf(GetInput("tensors").dtype, [tp.float32, tp.float16, tp.bfloat16, tp.float8, tp.int4, tp.int8, tp.int32, tp.int64, tp.bool]) +# output_guarantees = GetReturn(0).dtype == GetInput("tensors").dtype @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/copy.py b/tripy/nvtripy/frontend/ops/copy.py index 4e5c10d6a..488cde55a 100644 --- a/tripy/nvtripy/frontend/ops/copy.py +++ b/tripy/nvtripy/frontend/ops/copy.py @@ -22,7 +22,7 @@ from nvtripy.common.datatype import DATA_TYPES from nvtripy.common.exception import raise_error from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("copy") diff --git a/tripy/nvtripy/frontend/ops/cumsum.py b/tripy/nvtripy/frontend/ops/cumsum.py index c0dd76902..711d717ae 100644 --- a/tripy/nvtripy/frontend/ops/cumsum.py +++ b/tripy/nvtripy/frontend/ops/cumsum.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy import export from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/dequantize.py b/tripy/nvtripy/frontend/ops/dequantize.py index 816efa28c..cc0d6a39b 100644 --- a/tripy/nvtripy/frontend/ops/dequantize.py +++ b/tripy/nvtripy/frontend/ops/dequantize.py @@ -22,7 +22,7 @@ from nvtripy.common import datatype from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.dequantize import Dequantize -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/quantization") diff --git a/tripy/nvtripy/frontend/ops/equal.py b/tripy/nvtripy/frontend/ops/equal.py index 7d5fc3a7b..93ac7256f 100644 --- a/tripy/nvtripy/frontend/ops/equal.py +++ b/tripy/nvtripy/frontend/ops/equal.py @@ -14,7 +14,7 @@ # limitations under the License. from nvtripy import export from nvtripy.common.datatype import DATA_TYPES -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/expand.py b/tripy/nvtripy/frontend/ops/expand.py index 6cced90bc..41b975bd3 100644 --- a/tripy/nvtripy/frontend/ops/expand.py +++ b/tripy/nvtripy/frontend/ops/expand.py @@ -21,7 +21,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.broadcast import Broadcast from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def process_sizes(input: "nvtripy.Tensor", sizes: ShapeLike): diff --git a/tripy/nvtripy/frontend/ops/flatten.py b/tripy/nvtripy/frontend/ops/flatten.py index a356bf123..85b3b175c 100644 --- a/tripy/nvtripy/frontend/ops/flatten.py +++ b/tripy/nvtripy/frontend/ops/flatten.py @@ -18,7 +18,7 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("flatten") diff --git a/tripy/nvtripy/frontend/ops/flip.py b/tripy/nvtripy/frontend/ops/flip.py index 16912f97b..c2e3d71d1 100644 --- a/tripy/nvtripy/frontend/ops/flip.py +++ b/tripy/nvtripy/frontend/ops/flip.py @@ -19,7 +19,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/full.py b/tripy/nvtripy/frontend/ops/full.py index 044cfe1e5..e8882250a 100644 --- a/tripy/nvtripy/frontend/ops/full.py +++ b/tripy/nvtripy/frontend/ops/full.py @@ -22,7 +22,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.broadcast import Broadcast from nvtripy.types import ShapeLike, TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/gather.py b/tripy/nvtripy/frontend/ops/gather.py index 7af68d070..bced20763 100644 --- a/tripy/nvtripy/frontend/ops/gather.py +++ b/tripy/nvtripy/frontend/ops/gather.py @@ -19,7 +19,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.gather import Gather -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/iota.py b/tripy/nvtripy/frontend/ops/iota.py index fa5b247f1..6df4d2668 100644 --- a/tripy/nvtripy/frontend/ops/iota.py +++ b/tripy/nvtripy/frontend/ops/iota.py @@ -22,7 +22,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.linspace import Linspace from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: datatype.dtype) -> "nvtripy.Tensor": diff --git a/tripy/nvtripy/frontend/ops/masked_fill.py b/tripy/nvtripy/frontend/ops/masked_fill.py index 3bfb54226..540198fc4 100644 --- a/tripy/nvtripy/frontend/ops/masked_fill.py +++ b/tripy/nvtripy/frontend/ops/masked_fill.py @@ -15,7 +15,7 @@ import numbers from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/matmul.py b/tripy/nvtripy/frontend/ops/matmul.py index 1f323987f..43210a79a 100644 --- a/tripy/nvtripy/frontend/ops/matmul.py +++ b/tripy/nvtripy/frontend/ops/matmul.py @@ -19,7 +19,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.matmul import MatrixMultiply -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__matmul__") diff --git a/tripy/nvtripy/frontend/ops/ones.py b/tripy/nvtripy/frontend/ops/ones.py index 2abcd81cc..154d5863c 100644 --- a/tripy/nvtripy/frontend/ops/ones.py +++ b/tripy/nvtripy/frontend/ops/ones.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.common import datatype from nvtripy.frontend.ops.full import full, full_like -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/outer.py b/tripy/nvtripy/frontend/ops/outer.py index fae200134..c25652023 100644 --- a/tripy/nvtripy/frontend/ops/outer.py +++ b/tripy/nvtripy/frontend/ops/outer.py @@ -16,7 +16,7 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/pad.py b/tripy/nvtripy/frontend/ops/pad.py index f58afdfd7..42512cf49 100644 --- a/tripy/nvtripy/frontend/ops/pad.py +++ b/tripy/nvtripy/frontend/ops/pad.py @@ -23,7 +23,7 @@ from nvtripy.trace.ops.shape import Shape from nvtripy.trace.ops.slice import SliceFill, SliceReflect from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/permute.py b/tripy/nvtripy/frontend/ops/permute.py index a367ce939..8f798ec15 100644 --- a/tripy/nvtripy/frontend/ops/permute.py +++ b/tripy/nvtripy/frontend/ops/permute.py @@ -22,7 +22,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.permute import Permute -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("permute") diff --git a/tripy/nvtripy/frontend/ops/pooling/avgpool.py b/tripy/nvtripy/frontend/ops/pooling/avgpool.py index 50eb3a18d..68038cbd5 100644 --- a/tripy/nvtripy/frontend/ops/pooling/avgpool.py +++ b/tripy/nvtripy/frontend/ops/pooling/avgpool.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops.pooling import utils as pooling_utils from nvtripy.trace.ops.pooling import AvgPooling -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/pooling/maxpool.py b/tripy/nvtripy/frontend/ops/pooling/maxpool.py index c38ab13f6..dd177e780 100644 --- a/tripy/nvtripy/frontend/ops/pooling/maxpool.py +++ b/tripy/nvtripy/frontend/ops/pooling/maxpool.py @@ -21,7 +21,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops.pooling import utils as pooling_utils from nvtripy.trace.ops.pooling import MaxPooling -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/quantize.py b/tripy/nvtripy/frontend/ops/quantize.py index 69c062a0b..3d229789d 100644 --- a/tripy/nvtripy/frontend/ops/quantize.py +++ b/tripy/nvtripy/frontend/ops/quantize.py @@ -22,7 +22,7 @@ from nvtripy.common import datatype from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.quantize import Quantize -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/quantization") diff --git a/tripy/nvtripy/frontend/ops/reduce/all.py b/tripy/nvtripy/frontend/ops/reduce/all.py index fc3116287..d7dc507e4 100644 --- a/tripy/nvtripy/frontend/ops/reduce/all.py +++ b/tripy/nvtripy/frontend/ops/reduce/all.py @@ -15,7 +15,7 @@ from typing import Optional, Sequence, Union from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.common import datatype diff --git a/tripy/nvtripy/frontend/ops/reduce/any.py b/tripy/nvtripy/frontend/ops/reduce/any.py index 01630dd60..c9f4af47b 100644 --- a/tripy/nvtripy/frontend/ops/reduce/any.py +++ b/tripy/nvtripy/frontend/ops/reduce/any.py @@ -15,7 +15,7 @@ from typing import Optional, Sequence, Union from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.common import datatype diff --git a/tripy/nvtripy/frontend/ops/reduce/argmax.py b/tripy/nvtripy/frontend/ops/reduce/argmax.py index e3abdb4c0..7d5ee84ab 100644 --- a/tripy/nvtripy/frontend/ops/reduce/argmax.py +++ b/tripy/nvtripy/frontend/ops/reduce/argmax.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import arg_min_max_impl from nvtripy.trace.ops.topk import TopKMax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/argmin.py b/tripy/nvtripy/frontend/ops/reduce/argmin.py index 69d616162..bc09187a5 100644 --- a/tripy/nvtripy/frontend/ops/reduce/argmin.py +++ b/tripy/nvtripy/frontend/ops/reduce/argmin.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import arg_min_max_impl from nvtripy.trace.ops.topk import TopKMin -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/max.py b/tripy/nvtripy/frontend/ops/reduce/max.py index 5b28f99a4..b6b569b2d 100644 --- a/tripy/nvtripy/frontend/ops/reduce/max.py +++ b/tripy/nvtripy/frontend/ops/reduce/max.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Max -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/mean.py b/tripy/nvtripy/frontend/ops/reduce/mean.py index 5a85e6357..e9dd4d5a1 100644 --- a/tripy/nvtripy/frontend/ops/reduce/mean.py +++ b/tripy/nvtripy/frontend/ops/reduce/mean.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Avg -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/min.py b/tripy/nvtripy/frontend/ops/reduce/min.py index 1bd15785b..b08c5a494 100644 --- a/tripy/nvtripy/frontend/ops/reduce/min.py +++ b/tripy/nvtripy/frontend/ops/reduce/min.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Min -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/prod.py b/tripy/nvtripy/frontend/ops/reduce/prod.py index 75ab67f21..016bf20e8 100644 --- a/tripy/nvtripy/frontend/ops/reduce/prod.py +++ b/tripy/nvtripy/frontend/ops/reduce/prod.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Prod -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/sum.py b/tripy/nvtripy/frontend/ops/reduce/sum.py index e118eb397..9e85e4197 100644 --- a/tripy/nvtripy/frontend/ops/reduce/sum.py +++ b/tripy/nvtripy/frontend/ops/reduce/sum.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import reduce_impl from nvtripy.trace.ops.reduce import Sum -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/topk.py b/tripy/nvtripy/frontend/ops/reduce/topk.py index 57e1e2dfe..e5713c82a 100644 --- a/tripy/nvtripy/frontend/ops/reduce/topk.py +++ b/tripy/nvtripy/frontend/ops/reduce/topk.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +17,11 @@ from nvtripy import export from nvtripy.frontend.ops.reduce.utils import topk_impl from nvtripy.trace.ops.topk import TopKMax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers + + +# constraints = OneOf(GetInput("input").dtype, [tp.float32, tp.float16, tp.bfloat16, tp.int32, tp.int64]) +# output_guarantees = (GetReturn(0).dtype == GetInput("input").dtype) & (GetReturn(1).dtype == tp.int32)) @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reduce/var.py b/tripy/nvtripy/frontend/ops/reduce/var.py index ad8cd4f8d..585600034 100644 --- a/tripy/nvtripy/frontend/ops/reduce/var.py +++ b/tripy/nvtripy/frontend/ops/reduce/var.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/repeat.py b/tripy/nvtripy/frontend/ops/repeat.py index d7e531b41..9dc498e4c 100644 --- a/tripy/nvtripy/frontend/ops/repeat.py +++ b/tripy/nvtripy/frontend/ops/repeat.py @@ -18,7 +18,7 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/reshape.py b/tripy/nvtripy/frontend/ops/reshape.py index 262dc78b8..ae0eb52a0 100644 --- a/tripy/nvtripy/frontend/ops/reshape.py +++ b/tripy/nvtripy/frontend/ops/reshape.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.reshape import Reshape from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers def infer_dimensions(input: "nvtripy.Tensor", shape: ShapeLike) -> ShapeLike: diff --git a/tripy/nvtripy/frontend/ops/resize.py b/tripy/nvtripy/frontend/ops/resize.py index 477623062..31fba7ad3 100644 --- a/tripy/nvtripy/frontend/ops/resize.py +++ b/tripy/nvtripy/frontend/ops/resize.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.resize import ResizeCubic, ResizeLinear, ResizeNearest from nvtripy.types import ShapeLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers SUPPORTED_MODES = ("cubic", "linear", "nearest") diff --git a/tripy/nvtripy/frontend/ops/shape.py b/tripy/nvtripy/frontend/ops/shape.py index 94e3fa19c..08795206d 100644 --- a/tripy/nvtripy/frontend/ops/shape.py +++ b/tripy/nvtripy/frontend/ops/shape.py @@ -25,7 +25,7 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.shape import GetDimensionSize, Shape from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("shape") diff --git a/tripy/nvtripy/frontend/ops/slice.py b/tripy/nvtripy/frontend/ops/slice.py index 010e04dee..032fc70e5 100644 --- a/tripy/nvtripy/frontend/ops/slice.py +++ b/tripy/nvtripy/frontend/ops/slice.py @@ -23,7 +23,7 @@ from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.slice import Slice from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.utils.types import type_str_from_arg from nvtripy.utils.utils import make_list diff --git a/tripy/nvtripy/frontend/ops/softmax.py b/tripy/nvtripy/frontend/ops/softmax.py index f195eac81..002332319 100644 --- a/tripy/nvtripy/frontend/ops/softmax.py +++ b/tripy/nvtripy/frontend/ops/softmax.py @@ -20,7 +20,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.softmax import Softmax -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/split.py b/tripy/nvtripy/frontend/ops/split.py index 76a3d4a8d..829f74b52 100644 --- a/tripy/nvtripy/frontend/ops/split.py +++ b/tripy/nvtripy/frontend/ops/split.py @@ -21,7 +21,7 @@ from nvtripy.common.exception import raise_error from nvtripy.frontend.ops import utils as op_utils from nvtripy.types import IntLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/squeeze.py b/tripy/nvtripy/frontend/ops/squeeze.py index 5ec260e15..132fd7836 100644 --- a/tripy/nvtripy/frontend/ops/squeeze.py +++ b/tripy/nvtripy/frontend/ops/squeeze.py @@ -17,7 +17,7 @@ from nvtripy import export, utils from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("squeeze") diff --git a/tripy/nvtripy/frontend/ops/stack.py b/tripy/nvtripy/frontend/ops/stack.py index 0f2df667e..e46c8b69f 100644 --- a/tripy/nvtripy/frontend/ops/stack.py +++ b/tripy/nvtripy/frontend/ops/stack.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.common.exception import raise_error -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/transpose.py b/tripy/nvtripy/frontend/ops/transpose.py index 1d4b8cf3d..9119f6937 100644 --- a/tripy/nvtripy/frontend/ops/transpose.py +++ b/tripy/nvtripy/frontend/ops/transpose.py @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.common.exception import raise_error from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("transpose") diff --git a/tripy/nvtripy/frontend/ops/tril.py b/tripy/nvtripy/frontend/ops/tril.py index 563400e16..a90cc8bb4 100644 --- a/tripy/nvtripy/frontend/ops/tril.py +++ b/tripy/nvtripy/frontend/ops/tril.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +18,7 @@ from nvtripy.frontend.ops.iota import iota_like from nvtripy.frontend.ops.zeros import zeros_like from nvtripy.frontend.ops.where import where -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/triu.py b/tripy/nvtripy/frontend/ops/triu.py index 80dc08103..95a3bc02f 100644 --- a/tripy/nvtripy/frontend/ops/triu.py +++ b/tripy/nvtripy/frontend/ops/triu.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +18,7 @@ from nvtripy.frontend.ops.iota import iota_like from nvtripy.frontend.ops.where import where from nvtripy.frontend.ops.zeros import zeros_like -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/unary/abs.py b/tripy/nvtripy/frontend/ops/unary/abs.py index 9447cecbb..6ae52b523 100644 --- a/tripy/nvtripy/frontend/ops/unary/abs.py +++ b/tripy/nvtripy/frontend/ops/unary/abs.py @@ -19,7 +19,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Abs -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__abs__") diff --git a/tripy/nvtripy/frontend/ops/unary/cos.py b/tripy/nvtripy/frontend/ops/unary/cos.py index cabe9f697..fbf9ba35d 100644 --- a/tripy/nvtripy/frontend/ops/unary/cos.py +++ b/tripy/nvtripy/frontend/ops/unary/cos.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Cos -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/exp.py b/tripy/nvtripy/frontend/ops/unary/exp.py index 4895a4556..837f89051 100644 --- a/tripy/nvtripy/frontend/ops/unary/exp.py +++ b/tripy/nvtripy/frontend/ops/unary/exp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Exp -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/gelu.py b/tripy/nvtripy/frontend/ops/unary/gelu.py index a99dd240c..44a9d60bd 100644 --- a/tripy/nvtripy/frontend/ops/unary/gelu.py +++ b/tripy/nvtripy/frontend/ops/unary/gelu.py @@ -17,7 +17,7 @@ from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.trace.ops.unary import GeluErf from nvtripy.frontend.ops import utils as op_utils diff --git a/tripy/nvtripy/frontend/ops/unary/invert.py b/tripy/nvtripy/frontend/ops/unary/invert.py index 343f7525f..8fea29d6a 100644 --- a/tripy/nvtripy/frontend/ops/unary/invert.py +++ b/tripy/nvtripy/frontend/ops/unary/invert.py @@ -15,7 +15,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Not -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__invert__") diff --git a/tripy/nvtripy/frontend/ops/unary/log.py b/tripy/nvtripy/frontend/ops/unary/log.py index 74257a948..8d21af689 100644 --- a/tripy/nvtripy/frontend/ops/unary/log.py +++ b/tripy/nvtripy/frontend/ops/unary/log.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Log -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/neg.py b/tripy/nvtripy/frontend/ops/unary/neg.py index 3849364b4..0e08130a0 100644 --- a/tripy/nvtripy/frontend/ops/unary/neg.py +++ b/tripy/nvtripy/frontend/ops/unary/neg.py @@ -17,7 +17,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.unary import Neg -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("__neg__") diff --git a/tripy/nvtripy/frontend/ops/unary/relu.py b/tripy/nvtripy/frontend/ops/unary/relu.py index 2db6aac9f..531cff1dc 100644 --- a/tripy/nvtripy/frontend/ops/unary/relu.py +++ b/tripy/nvtripy/frontend/ops/unary/relu.py @@ -18,7 +18,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Relu -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.common import datatype diff --git a/tripy/nvtripy/frontend/ops/unary/rsqrt.py b/tripy/nvtripy/frontend/ops/unary/rsqrt.py index 5cd215073..9041deef6 100644 --- a/tripy/nvtripy/frontend/ops/unary/rsqrt.py +++ b/tripy/nvtripy/frontend/ops/unary/rsqrt.py @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Recip -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/sigmoid.py b/tripy/nvtripy/frontend/ops/unary/sigmoid.py index be7ea05a7..efee06568 100644 --- a/tripy/nvtripy/frontend/ops/unary/sigmoid.py +++ b/tripy/nvtripy/frontend/ops/unary/sigmoid.py @@ -16,7 +16,7 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.trace.ops.unary import Sigmoid from nvtripy.frontend.ops import utils as op_utils diff --git a/tripy/nvtripy/frontend/ops/unary/silu.py b/tripy/nvtripy/frontend/ops/unary/silu.py index 3813f06e3..6e359bec5 100644 --- a/tripy/nvtripy/frontend/ops/unary/silu.py +++ b/tripy/nvtripy/frontend/ops/unary/silu.py @@ -16,7 +16,7 @@ # from nvtripy import export -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/sin.py b/tripy/nvtripy/frontend/ops/unary/sin.py index 7078a30b7..bf7ae4e1b 100644 --- a/tripy/nvtripy/frontend/ops/unary/sin.py +++ b/tripy/nvtripy/frontend/ops/unary/sin.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Sin -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/sqrt.py b/tripy/nvtripy/frontend/ops/unary/sqrt.py index 6a67ed9db..b8681357c 100644 --- a/tripy/nvtripy/frontend/ops/unary/sqrt.py +++ b/tripy/nvtripy/frontend/ops/unary/sqrt.py @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Sqrt -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unary/tanh.py b/tripy/nvtripy/frontend/ops/unary/tanh.py index fac66c675..67f9bc595 100644 --- a/tripy/nvtripy/frontend/ops/unary/tanh.py +++ b/tripy/nvtripy/frontend/ops/unary/tanh.py @@ -15,7 +15,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary import Tanh -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/unsqueeze.py b/tripy/nvtripy/frontend/ops/unsqueeze.py index fa3e25045..e14612d49 100644 --- a/tripy/nvtripy/frontend/ops/unsqueeze.py +++ b/tripy/nvtripy/frontend/ops/unsqueeze.py @@ -18,7 +18,7 @@ from nvtripy import export from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @register_tensor_method("unsqueeze") diff --git a/tripy/nvtripy/frontend/ops/where.py b/tripy/nvtripy/frontend/ops/where.py index 7c41a4ae0..277a61490 100644 --- a/tripy/nvtripy/frontend/ops/where.py +++ b/tripy/nvtripy/frontend/ops/where.py @@ -20,7 +20,7 @@ from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.where import Where from nvtripy.types import TensorLike -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/ops/zeros.py b/tripy/nvtripy/frontend/ops/zeros.py index 177d20cf0..4f3028a21 100644 --- a/tripy/nvtripy/frontend/ops/zeros.py +++ b/tripy/nvtripy/frontend/ops/zeros.py @@ -17,7 +17,7 @@ from nvtripy import export from nvtripy.common import datatype from nvtripy.frontend.ops.full import full, full_like -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/utils/wrappers.py b/tripy/nvtripy/frontend/wrappers.py similarity index 85% rename from tripy/nvtripy/utils/wrappers.py rename to tripy/nvtripy/frontend/wrappers.py index 7f84427c3..a96d3aa7e 100644 --- a/tripy/nvtripy/utils/wrappers.py +++ b/tripy/nvtripy/frontend/wrappers.py @@ -15,17 +15,18 @@ # limitations under the License. # + import functools import inspect -import types from dataclasses import dataclass from textwrap import indent from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union from nvtripy import config, utils +from nvtripy.common.datatype import DATA_TYPES from nvtripy.common.exception import raise_error +from nvtripy.frontend.constraints import Constraints from nvtripy.utils import result -from nvtripy.common.datatype import DATA_TYPES @dataclass @@ -232,8 +233,63 @@ def add_arg(arg): return new_args, new_kwargs, new_merged_args +def _doc_str(obj: Any) -> str: + """ + Returns a string representation of an object for use in the public documentation. + """ + from nvtripy.common.datatype import dtype as tp_dtype + from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or + from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn + + if isinstance(obj, tp_dtype): + return f":class:`{obj.name}`" + + if isinstance(obj, GetInput): + return f"``{obj.name}``" + elif isinstance(obj, GetReturn): + return f"``return[{obj.index}]``" + elif isinstance(obj, GetDataType): + # Intentionally do not use _doc_str() on the value_fetcher so we can wrap it in backticks correctly. + return f"``{obj.value_fetcher}.dtype``" + elif isinstance(obj, OneOf): + return f"{_doc_str(obj.fetcher)} is one of [{', '.join(f'{_doc_str(opt)}' for opt in obj.options)}]" + elif isinstance(obj, NotOneOf): + return f"{_doc_str(obj.fetcher)} is not one of [{', '.join(f'{_doc_str(opt)}' for opt in obj.options)}]" + elif isinstance(obj, Equal): + return f"{_doc_str(obj.fetcher)} == {_doc_str(obj.fetcher_or_value)}" + elif isinstance(obj, NotEqual): + return f"{_doc_str(obj.fetcher)} != {_doc_str(obj.fetcher_or_value)}" + elif isinstance(obj, And): + return ", **and**\n".join("- " + indent(_doc_str(constraint), " ").lstrip() for constraint in obj.constraints) + elif isinstance(obj, Or): + return "(" + " *or* ".join(_doc_str(constraint) for constraint in obj.constraints) + ")" + + assert False, f"Unsupported object type for doc string generation: {type(obj)}. Please add handling here!" + + +# Modify the docstring to include constraints +def _update_docstring(func, input_requirements, output_guarantees): + if not func.__doc__: + return + + indentation = " " * 4 + code_block_index = func.__doc__.find(".. code-block:: python") + assert code_block_index != -1, f"No code example in docstring for {func.__name__}" + + input_requirements_str = f"\nINPUT REQUIREMENTS:\n{indent(_doc_str(input_requirements), indentation)}\n" + output_guarantees_str = f"\nOUTPUT GUARANTEES:\n{indent(_doc_str(output_guarantees), indentation)}\n" + + func.__doc__ = ( + func.__doc__[:code_block_index] + + indent(input_requirements_str + output_guarantees_str, indentation) + + "\n" + + indentation + + func.__doc__[code_block_index:] + ) + + # Modify the docstring to mention data type variables and exceptions -def _update_docstring(func, dtype_constraints, dtype_variables, dtype_exceptions): +def _update_docstring_legacy(func, dtype_constraints, dtype_variables, dtype_exceptions): if not func.__doc__: return @@ -295,6 +351,9 @@ def sorted_types(dtypes): def interface( + # TODO (pranavm): These should be required arguments eventually. + input_requirements: Constraints = None, + output_guarantees: Constraints = None, dtype_constraints: Dict[str, str] = {}, dtype_variables: Dict[str, List[str]] = {}, dtype_exceptions: List[Dict[str, str]] = [], @@ -367,7 +426,10 @@ def decorator(func): DataTypeConstraints(func, dtype_constraints, dtype_variables, dtype_exceptions) ) - _update_docstring(func, dtype_constraints, dtype_variables, dtype_exceptions) + if input_requirements is not None: + _update_docstring(func, input_requirements, output_guarantees) + elif dtype_constraints or dtype_variables or dtype_exceptions: + _update_docstring_legacy(func, dtype_constraints, dtype_variables, dtype_exceptions) @functools.wraps(func) def wrapper(*args, **kwargs): @@ -386,6 +448,17 @@ def wrapper(*args, **kwargs): shape_likes, ) + if config.enable_input_validation: + if input_requirements is not None: + result = input_requirements(merged_args) + if not result: + raise_error( + f"Invalid inputs for function: '{func.__qualname__}'.", + ["Expected: "] + + result.error_details + + [f".\n\nNote: Requirements are:\n {input_requirements}."], + ) + if config.enable_dtype_checking: from nvtripy.common.datatype import dtype from nvtripy.frontend.tensor import Tensor diff --git a/tripy/nvtripy/utils/stack_info.py b/tripy/nvtripy/utils/stack_info.py index 19fab4871..f8557552e 100644 --- a/tripy/nvtripy/utils/stack_info.py +++ b/tripy/nvtripy/utils/stack_info.py @@ -139,6 +139,6 @@ def get_module_names_to_exclude_from_stack_info(): or trying to retrieve column information from code. """ import nvtripy.utils.function_registry as function_registry - import nvtripy.utils.wrappers as wrappers + import nvtripy.frontend.wrappers as wrappers return {mod.__name__ for mod in [function_registry, wrappers]} diff --git a/tripy/tests/common/test_exception.py b/tripy/tests/common/test_exception.py index 0c9fbdd87..b8a8a0aec 100644 --- a/tripy/tests/common/test_exception.py +++ b/tripy/tests/common/test_exception.py @@ -122,7 +122,7 @@ def test_can_determine_column_range(self): ) def test_wrappers_is_excluded(self): - from nvtripy.utils import wrappers + from nvtripy.frontend import wrappers tensor = tp.ones((2, 3)) diff --git a/tripy/tests/frontend/constraints/__init__.py b/tripy/tests/frontend/constraints/__init__.py new file mode 100644 index 000000000..8bb95d5cb --- /dev/null +++ b/tripy/tests/frontend/constraints/__init__.py @@ -0,0 +1,16 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tripy/tests/frontend/constraints/test_base.py b/tripy/tests/frontend/constraints/test_base.py new file mode 100644 index 000000000..ac3b03cb1 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_base.py @@ -0,0 +1,116 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from nvtripy.frontend.constraints import And, Equal, GetDataType, GetInput, OneOf + + +class TestConstraints: + def test_find_exact_match(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(GetInput("a"), GetInput("b")) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_no_match(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = OneOf(GetInput("a"), [1, 2, 3]) + assert len(constraint.find(pattern)) == 0 + + def test_find_in_nested_and(self): + inner_constraint = Equal(GetInput("a"), GetInput("b")) + constraint = And(inner_constraint, OneOf(GetInput("c"), [1, 2, 3])) + pattern = Equal(GetInput, GetInput) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is inner_constraint + + def test_find_multiple_matches(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("c"), GetInput("d")) + constraint = And(equal1, equal2, OneOf(GetInput("e"), [1, 2, 3])) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 2 and matches[0] is equal1 and matches[1] is equal2 + + def test_find_with_dtype_pattern(self): + constraint = Equal(GetDataType(GetInput("tensor1")), GetDataType(GetInput("tensor2"))) + pattern = Equal(GetDataType(GetInput), GetDataType(GetInput)) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_deeply_nested_matches(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("d"), GetInput("e")) + constraint = And( + And(equal1, OneOf(GetInput("c"), [1, 2, 3])), + And(equal2, OneOf(GetInput("f"), [4, 5, 6])), + ) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 2 and matches[0] is equal1 and matches[1] is equal2 + + def test_find_with_specific_names(self): + match_constraint = Equal(GetInput("a"), GetInput("b")) + constraint = And(match_constraint, Equal(GetInput("c"), GetInput("d"))) + matches = constraint.find(Equal(GetInput("a"), GetInput("b"))) + assert len(matches) == 1 and matches[0] is match_constraint + + def test_find_with_multiple_children(self): + equal1 = Equal(GetInput("a"), GetInput("b")) + equal2 = Equal(GetInput("c"), GetInput("d")) + oneof1 = OneOf(GetInput("e"), [1, 2, 3]) + equal3 = Equal(GetInput("f"), GetInput("g")) + constraint = And(equal1, equal2, oneof1, equal3) + matches = constraint.find(Equal(GetInput, GetInput)) + assert len(matches) == 3 + assert equal1 in matches + assert equal2 in matches + assert equal3 in matches + + def test_find_and_constraint(self): + and1 = And(Equal(GetInput("a"), GetInput("b")), OneOf(GetInput("c"), [1, 2, 3])) + and2 = And(Equal(GetInput("d"), GetInput("e")), OneOf(GetInput("f"), [4, 5, 6])) + constraint = And(and1, and2) + matches = constraint.find(And(Equal, OneOf)) + assert len(matches) == 2 + assert and1 in matches + assert and2 in matches + + def test_find_with_none_wildcard_second_arg(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(GetInput("a"), None) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_with_none_wildcard_first_arg(self): + constraint = Equal(GetInput("a"), GetInput("b")) + pattern = Equal(None, GetInput("b")) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is constraint + + def test_find_with_none_wildcard_in_nested(self): + equal1 = Equal(GetDataType(GetInput("a")), GetDataType(GetInput("b"))) + equal2 = Equal(GetInput("c"), GetInput("d")) + constraint = And(equal1, equal2) + pattern = Equal(GetDataType(GetInput), None) + matches = constraint.find(pattern) + assert len(matches) == 1 and matches[0] is equal1 + + def test_find_with_none_wildcard_matches_different_types(self): + equal = Equal(GetInput("a"), GetInput("b")) + oneof = OneOf(GetInput("c"), [1, 2, 3]) + constraint = And(equal, oneof) + pattern = None + matches = constraint.find(pattern) + assert len(matches) == 6 + assert constraint in matches diff --git a/tripy/tests/frontend/constraints/test_fetcher.py b/tripy/tests/frontend/constraints/test_fetcher.py new file mode 100644 index 000000000..dbfcc3b40 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_fetcher.py @@ -0,0 +1,106 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import nvtripy as tp +from nvtripy.common.exception import TripyException +from nvtripy.frontend.constraints import Equal, GetDataType, GetInput, GetReturn, NotEqual +from tests import helper + + +class TestFetcher: + def test_eq_operator_returns_equal(self): + fetcher1 = GetInput("param1") + fetcher2 = GetInput("param2") + constraint = fetcher1 == fetcher2 + assert isinstance(constraint, Equal) + assert constraint.fetcher == fetcher1 + assert constraint.fetcher_or_value == fetcher2 + + def test_ne_operator_returns_not_equal(self): + fetcher1 = GetInput("param1") + fetcher2 = GetInput("param2") + constraint = fetcher1 != fetcher2 + assert isinstance(constraint, NotEqual) + assert constraint.fetcher == fetcher1 + assert constraint.fetcher_or_value == fetcher2 + + +class TestValueFetcher: + def test_dtype_property(self): + fetcher = GetInput("tensor") + dtype_fetcher = fetcher.dtype + assert isinstance(dtype_fetcher, GetDataType) + assert dtype_fetcher.value_fetcher == fetcher + + +class TestGetInput: + def test_call(self): + fetcher = GetInput("data") + args = [("data", 42), ("other", "hello")] + assert fetcher(args) == 42 + + def test_str(self): + fetcher = GetInput("data") + assert str(fetcher) == "data" + + +class TestGetReturn: + def test_init(self): + fetcher = GetReturn(0) + assert fetcher.index == 0 + + def test_call(self): + fetcher = GetReturn(0) + returns = (42, "hello", 3.14) + assert fetcher([], returns) == 42 + + fetcher2 = GetReturn(2) + assert fetcher2([], returns) == 3.14 + + def test_str(self): + fetcher = GetReturn(0) + assert str(fetcher) == "return[0]" + + fetcher2 = GetReturn(2) + assert str(fetcher2) == "return[2]" + + +class TestGetDataType: + def test_call(self): + tensor = tp.ones((2, 3), dtype=tp.float32) + fetcher = GetDataType(GetInput("input_tensor")) + assert fetcher([("input_tensor", tensor)]) == tp.float32 + + def test_call_with_sequence(self): + tensors = [tp.ones((2, 3), dtype=tp.float32)] * 2 + fetcher = GetDataType(GetInput("input_tensors")) + assert fetcher([("input_tensors", tensors)]) == tp.float32 + + def test_call_with_mismatched_dtypes_in_sequence(self): + tensors = [tp.ones((2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.int32)] + fetcher = GetDataType(GetInput("input_tensors")) + with helper.raises(TripyException, match="Could not determine data type"): + fetcher([("input_tensors", tensors)]) + + def test_call_with_non_tensor_argument(self): + fetcher = GetDataType(GetInput("input_data")) + with helper.raises(TripyException, match="Expected a tensor or data type argument"): + fetcher([("input_data", 42)]) + + def test_call_with_nested_sequence_error(self): + fetcher = GetDataType(GetInput("input_data")) + with helper.raises(TripyException, match="Could not determine data type"): + fetcher([("input_data", [tp.ones((2, 3), dtype=tp.float32), [42]])]) diff --git a/tripy/tests/frontend/constraints/test_logic.py b/tripy/tests/frontend/constraints/test_logic.py new file mode 100644 index 000000000..c23d500b5 --- /dev/null +++ b/tripy/tests/frontend/constraints/test_logic.py @@ -0,0 +1,219 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from nvtripy.frontend.constraints import And, Equal, GetInput, NotEqual, NotOneOf, OneOf, Or + + +class TestLogic: + def test_operator_and_basic(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + combined = constraint1 & constraint2 + assert isinstance(combined, And) + assert combined([("param1", 2), ("param2", "b")]) + + def test_operator_and_chaining(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + constraint3 = OneOf(GetInput("param3"), [True, False]) + combined = constraint1 & constraint2 & constraint3 + assert isinstance(combined, And) + assert len(combined.constraints) == 3 + assert combined([("param1", 2), ("param2", "b"), ("param3", True)]) + + def test_operator_or_basic(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + combined = constraint1 | constraint2 + assert isinstance(combined, Or) + assert combined([("param1", 5), ("param2", "b")]) + + def test_operator_or_chaining(self): + constraint1 = OneOf(GetInput("param1"), [1, 2, 3]) + constraint2 = OneOf(GetInput("param2"), ["a", "b", "c"]) + constraint3 = OneOf(GetInput("param3"), [True, False]) + combined = constraint1 | constraint2 | constraint3 + assert isinstance(combined, Or) + assert len(combined.constraints) == 3 + assert combined([("param1", 5), ("param2", "z"), ("param3", True)]) + + def test_operator_not_basic(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + negated = ~constraint + assert isinstance(negated, NotOneOf) + assert negated([("param", 5)]) + assert not negated([("param", 2)]) + + +class TestOneOf: + def test_call(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + assert constraint([("param", 2)]) + result = constraint([("param", 5)]) + assert not result + assert "'param' to be one of [1, 2, 3] (but it was '5')" in result.error_details + + def test_str(self): + assert str(OneOf(GetInput("param"), [1, 2, 3])) == "param is one of [1, 2, 3]" + + def test_inverse(self): + constraint = OneOf(GetInput("param"), [1, 2, 3]) + inverse = constraint.inverse() + assert isinstance(inverse, NotOneOf) + assert inverse([("param", 5)]) + assert not inverse([("param", 2)]) + + +class TestNotOneOf: + def test_call(self): + constraint = NotOneOf(GetInput("param"), [1, 2, 3]) + assert constraint([("param", 5)]) + result = constraint([("param", 2)]) + assert not result + assert "'param' to not be one of [1, 2, 3] (but it was '2')" in result.error_details + + def test_str(self): + assert str(NotOneOf(GetInput("param"), [1, 2, 3])) == "param is not one of [1, 2, 3]" + + def test_inverse(self): + constraint = NotOneOf(GetInput("param"), [1, 2, 3]) + inverse = constraint.inverse() + assert isinstance(inverse, OneOf) + assert inverse([("param", 2)]) + assert not inverse([("param", 5)]) + + +class TestAnd: + def test_call_all_pass(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert and_constraint([("param1", 2), ("param2", "b")]) + + def test_call_one_fails(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = and_constraint([("param1", 5), ("param2", "b")]) + assert not result + assert "'param1' to be one of [1, 2, 3] (but it was '5')" in result.error_details + + def test_call_all_fail(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = and_constraint([("param1", 5), ("param2", "z")]) + assert not result + assert ( + "".join(result.error_details) + == "'param1' to be one of [1, 2, 3] (but it was '5') and 'param2' to be one of ['a', 'b', 'c'] (but it was 'z')" + ) + + def test_str(self): + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + assert str(and_constraint) == "(param1 is one of [1, 2, 3] and param2 is one of ['a', 'b'])" + + def test_inverse(self): + # De Morgan's law: not (A and B) = (not A) or (not B) + and_constraint = And(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + inverse = and_constraint.inverse() + assert isinstance(inverse, Or) + assert str(inverse) == "(param1 is not one of [1, 2, 3] or param2 is not one of ['a', 'b'])" + + +class TestOr: + def test_call_first_passes(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 2), ("param2", "z")]) + + def test_call_second_passes(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 5), ("param2", "b")]) + + def test_call_all_pass(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + assert or_constraint([("param1", 2), ("param2", "b")]) + + def test_call_all_fail(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b", "c"])) + result = or_constraint([("param1", 5), ("param2", "z")]) + assert not result + assert ( + "".join(result.error_details) + == "'param1' to be one of [1, 2, 3] (but it was '5') or 'param2' to be one of ['a', 'b', 'c'] (but it was 'z')" + ) + + def test_str(self): + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + assert str(or_constraint) == "(param1 is one of [1, 2, 3] or param2 is one of ['a', 'b'])" + + def test_call_multiple_constraints(self): + or_constraint = Or( + OneOf(GetInput("param1"), [1, 2, 3]), + OneOf(GetInput("param2"), ["a", "b", "c"]), + OneOf(GetInput("param3"), [True, False]), + ) + assert or_constraint([("param1", 5), ("param2", "z"), ("param3", True)]) + assert not or_constraint([("param1", 5), ("param2", "z"), ("param3", None)]) + + def test_inverse(self): + # De Morgan's law: not (A or B) = (not A) and (not B) + or_constraint = Or(OneOf(GetInput("param1"), [1, 2, 3]), OneOf(GetInput("param2"), ["a", "b"])) + inverse = or_constraint.inverse() + assert isinstance(inverse, And) + assert str(inverse) == "(param1 is not one of [1, 2, 3] and param2 is not one of ['a', 'b'])" + + +class TestEqual: + def test_call(self): + constraint = Equal(GetInput("param1"), GetInput("param2")) + assert constraint([("param1", 5), ("param2", 5)]) + result = constraint([("param1", 5), ("param2", 10)]) + assert not result + assert "'param1' to be equal to 'param2' (but it was '5')" in result.error_details + + def test_str(self): + assert str(Equal(GetInput("param1"), GetInput("param2"))) == "param1 == param2" + assert str(Equal(GetInput("param1"), 5)) == "param1 == 5" + + def test_operator_on_fetcher(self): + constraint = GetInput("param1") == GetInput("param2") + assert isinstance(constraint, Equal) + + def test_inverse(self): + constraint = Equal(GetInput("param1"), 5) + inverse = constraint.inverse() + assert isinstance(inverse, NotEqual) + assert inverse([("param1", 10)]) + assert not inverse([("param1", 5)]) + + +class TestNotEqual: + def test_call(self): + constraint = NotEqual(GetInput("param1"), GetInput("param2")) + assert constraint([("param1", 5), ("param2", 10)]) + result = constraint([("param1", 5), ("param2", 5)]) + assert not result + assert "'param1' to be not equal to 'param2' (but it was '5')" in result.error_details + + def test_str(self): + assert str(NotEqual(GetInput("param1"), GetInput("param2"))) == "param1 != param2" + assert str(NotEqual(GetInput("param1"), 5)) == "param1 != 5" + + def test_operator_on_fetcher(self): + constraint = GetInput("param1") != GetInput("param2") + assert isinstance(constraint, NotEqual) + + def test_inverse(self): + constraint = NotEqual(GetInput("param1"), 5) + inverse = constraint.inverse() + assert isinstance(inverse, Equal) + assert inverse([("param1", 5)]) + assert not inverse([("param1", 10)]) diff --git a/tripy/tests/utils/wrappers/test_datatype_constraints.py b/tripy/tests/frontend/wrappers/test_datatype_constraints.py similarity index 98% rename from tripy/tests/utils/wrappers/test_datatype_constraints.py rename to tripy/tests/frontend/wrappers/test_datatype_constraints.py index 548023fe2..cccb14590 100644 --- a/tripy/tests/utils/wrappers/test_datatype_constraints.py +++ b/tripy/tests/frontend/wrappers/test_datatype_constraints.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# TODO (pranavm): Move into integration tests + import contextlib import inspect import itertools @@ -22,10 +25,10 @@ import nvtripy as tp import pytest from nvtripy.common.datatype import DATA_TYPES -from nvtripy.utils import wrappers +from nvtripy.frontend import wrappers from nvtripy.utils.types import str_from_type_annotation from nvtripy.utils.utils import make_list -from nvtripy.utils.wrappers import DATA_TYPE_CONSTRAINTS +from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS from tests import helper from tests.conftest import skip_if_older_than_sm89 diff --git a/tripy/tests/utils/wrappers/test_interface.py b/tripy/tests/frontend/wrappers/test_wrappers.py old mode 100755 new mode 100644 similarity index 73% rename from tripy/tests/utils/wrappers/test_interface.py rename to tripy/tests/frontend/wrappers/test_wrappers.py index 7a9466afe..602f9ba76 --- a/tripy/tests/utils/wrappers/test_interface.py +++ b/tripy/tests/frontend/wrappers/test_wrappers.py @@ -22,8 +22,10 @@ import nvtripy as tp import pytest from nvtripy.export import PUBLIC_APIS -from nvtripy.utils import wrappers -from nvtripy.utils.wrappers import DATA_TYPE_CONSTRAINTS +from nvtripy.frontend import wrappers +from nvtripy.frontend.constraints.fetcher import GetDataType, GetInput, GetReturn +from nvtripy.frontend.constraints.logic import And, Equal, NotEqual, NotOneOf, OneOf, Or +from nvtripy.frontend.wrappers import DATA_TYPE_CONSTRAINTS, _doc_str from tests import helper # Get all functions/methods which have tensors in the type signature @@ -59,6 +61,73 @@ def sequence_func(tensors: List[tp.Tensor]): return +class TestDocStr: + def test_basic_types(self): + assert _doc_str(tp.float32) == ":class:`float32`" + assert _doc_str(GetInput("x")) == "``x``" + assert _doc_str(GetReturn(0)) == "``return[0]``" + + def test_get_datatype(self): + assert _doc_str(GetDataType(GetInput("x"))) == "``x.dtype``" + assert _doc_str(GetDataType(GetReturn(0))) == "``return[0].dtype``" + + def test_one_of_and_not_one_of(self): + input_x = GetInput("x") + + assert ( + _doc_str(OneOf(input_x, [tp.float32, tp.float16])) == "``x`` is one of [:class:`float32`, :class:`float16`]" + ) + assert _doc_str(NotOneOf(input_x, [tp.int8, tp.int32])) == "``x`` is not one of [:class:`int8`, :class:`int32`]" + + def test_equal_and_not_equal(self): + input_a = GetInput("a") + input_b = GetInput("b") + + assert _doc_str(Equal(input_a, input_b)) == "``a`` == ``b``" + assert _doc_str(Equal(input_a, tp.float32)) == "``a`` == :class:`float32`" + assert _doc_str(NotEqual(input_a, input_b)) == "``a`` != ``b``" + + def test_and_constraint(self): + constraint1 = OneOf(GetInput("a"), [tp.float32]) + constraint2 = OneOf(GetInput("b"), [tp.int32]) + + assert ( + _doc_str(And(constraint1, constraint2)) + == "- ``a`` is one of [:class:`float32`]\n- ``b`` is one of [:class:`int32`]" + ) + + def test_or_constraint(self): + input_a = GetInput("a") + or_constraint = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) + + assert _doc_str(or_constraint) == "(``a`` == :class:`float32` or ``a`` == :class:`float16`)" + + def test_nested_constraints(self): + input_a = GetInput("a") + input_b = GetInput("b") + + or_part = Or(Equal(input_a, tp.float32), Equal(input_a, tp.float16)) + and_constraint = And(or_part, OneOf(input_b, [tp.int32])) + + assert ( + _doc_str(and_constraint) + == "- (``a`` == :class:`float32` or ``a`` == :class:`float16`)\n- ``b`` is one of [:class:`int32`]" + ) + + def test_complex_real_world_constraint(self): + input_a = GetInput("input") + input_b = GetInput("other") + dtype_a = GetDataType(input_a) + dtype_b = GetDataType(input_b) + + and_constraint = And(Equal(dtype_a, dtype_b), OneOf(dtype_a, [tp.float32, tp.float16])) + + assert ( + _doc_str(and_constraint) + == "- ``input.dtype`` == ``other.dtype``\n- ``input.dtype`` is one of [:class:`float32`, :class:`float16`]" + ) + + class TestDtypes: def test_works_with_sequences(self): sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)]) diff --git a/tripy/tests/utils/test_utils.py b/tripy/tests/utils/test_utils.py index 54725b4e1..568e6d223 100644 --- a/tripy/tests/utils/test_utils.py +++ b/tripy/tests/utils/test_utils.py @@ -20,6 +20,7 @@ import nvtripy as tp import pytest from nvtripy import utils +from nvtripy.frontend.wrappers import constant_fields from tests import helper @@ -46,7 +47,7 @@ def test_hash_equivalence(self, func): def make_with_constant_field(): - @utils.wrappers.constant_fields("field") + @constant_fields("field") class WithConstField: def __init__(self): self.custom_setter_called_count = defaultdict(int)