2222
2323
2424def apply_tensor_contraints (op_name : str , tensor_constraints : list [object ]) -> None :
25+ additional_tensor_constraints = [
26+ cp .Dtype .In (lambda deps : [torch .int , torch .float ]),
27+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
28+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
29+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
30+ cp .Rank .Ge (lambda deps : 1 ),
31+ cp .Size .Ge (lambda deps , r , d : 1 ),
32+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
33+ ]
34+
2535 match op_name :
36+ case "where.self" :
37+ additional_tensor_constraints = [
38+ cp .Dtype .In (lambda deps : [torch .float , torch .int , torch .bool ]),
39+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
40+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
41+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
42+ cp .Rank .Ge (lambda deps : 1 ),
43+ cp .Size .Ge (lambda deps , r , d : 1 ),
44+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
45+ ]
2646 case "sigmoid.default" | "rsqrt.default" :
27- tensor_constraints .extend (
47+ additional_tensor_constraints .extend (
2848 [
2949 cp .Dtype .In (lambda deps : [torch .float ]),
3050 cp .Rank .Le (lambda deps : 2 ** 2 ),
@@ -33,45 +53,35 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
3353 ]
3454 )
3555 case "mean.dim" :
36- tensor_constraints .extend (
56+ additional_tensor_constraints .extend (
3757 [
3858 cp .Dtype .In (lambda deps : [torch .float ]),
3959 cp .Rank .Le (lambda deps : 2 ** 2 ),
4060 ]
4161 )
4262 case "exp.default" :
43- tensor_constraints .extend (
63+ additional_tensor_constraints .extend (
4464 [
4565 cp .Rank .Le (lambda deps : 2 ** 3 ),
4666 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
4767 cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
4868 ]
4969 )
5070 case "slice_copy.Tensor" :
51- tensor_constraints .extend (
71+ additional_tensor_constraints .extend (
5272 [
5373 cp .Rank .Le (lambda deps : 2 ),
5474 cp .Value .Ge (lambda deps , dtype , struct : 1 ),
5575 cp .Value .Le (lambda deps , dtype , struct : 2 ),
5676 ]
5777 )
5878 case _:
59- tensor_constraints .extend (
79+ additional_tensor_constraints .extend (
6080 [
6181 cp .Rank .Le (lambda deps : 2 ** 2 ),
6282 ]
6383 )
64- tensor_constraints .extend (
65- [
66- cp .Dtype .In (lambda deps : [torch .int , torch .float ]),
67- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
68- cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
69- cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
70- cp .Rank .Ge (lambda deps : 1 ),
71- cp .Size .Ge (lambda deps , r , d : 1 ),
72- cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
73- ]
74- )
84+ tensor_constraints .extend (additional_tensor_constraints )
7585
7686
7787def apply_scalar_contraints (op_name : str ) -> list [ScalarDtype ]:
0 commit comments