@@ -222,6 +222,34 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
222222 cp .Value .Le (lambda deps , dtype , struct : 2 ),
223223 ]
224224 )
225+ case "transpose_copy.int" :
226+ tensor_constraints .extend (
227+ [
228+ cp .Dtype .In (lambda deps : [torch .float32 , torch .int32 ]),
229+ ]
230+ )
231+ case "permute_copy.default" :
232+ tensor_constraints .extend (
233+ [
234+ cp .Dtype .In (lambda deps : [torch .float32 , torch .int8 , torch .uint8 ]),
235+ cp .Rank .Le (
236+ lambda deps : 5
237+ ), # xa_nn_transpose only supports up to 5D
238+ cp .Rank .Ge (lambda deps : 1 ), # Must have at least 1 dimension
239+ ]
240+ )
241+ case "sqrt.default" :
242+ tensor_constraints .extend (
243+ [
244+ cp .Dtype .In (lambda deps : [torch .float32 , torch .int32 ]),
245+ ]
246+ )
247+ case "clamp.default" :
248+ tensor_constraints .extend (
249+ [
250+ cp .Dtype .In (lambda deps : [torch .float32 , torch .int32 ]),
251+ ]
252+ )
225253 case "rsqrt.default" :
226254 tensor_constraints .extend (
227255 [
@@ -232,6 +260,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
232260 cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
233261 ]
234262 )
263+ case "relu.default" :
264+ tensor_constraints .extend (
265+ [
266+ cp .Dtype .In (lambda deps : [torch .float32 ]),
267+ ]
268+ )
235269 case "mean.dim" :
236270 tensor_constraints .extend (
237271 [
@@ -241,10 +275,17 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
241275 case "exp.default" :
242276 tensor_constraints .extend (
243277 [
278+ cp .Dtype .In (lambda deps : [torch .float32 ]),
244279 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
245280 cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
246281 ]
247282 )
283+ case "tanh.default" :
284+ tensor_constraints .extend (
285+ [
286+ cp .Dtype .In (lambda deps : [torch .float32 ]),
287+ ]
288+ )
248289 case "slice_copy.Tensor" :
249290 tensor_constraints .extend (
250291 [
@@ -253,6 +294,34 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
253294 cp .Value .Le (lambda deps , dtype , struct : 2 ),
254295 ]
255296 )
297+ case "div.Scalar" | "add.Tensor" | "mul.Tensor" | "sub.Tensor" :
298+ tensor_constraints .extend (
299+ [
300+ cp .Dtype .In (
301+ lambda deps : [
302+ torch .int32 ,
303+ torch .int64 ,
304+ torch .float32 ,
305+ ]
306+ ),
307+ ]
308+ )
309+ case "split_copy.Tensor" :
310+ tensor_constraints .extend (
311+ [
312+ cp .Dtype .In (
313+ lambda deps : [
314+ torch .int32 ,
315+ torch .int64 ,
316+ torch .float32 ,
317+ ]
318+ ),
319+ cp .Value .Ge (lambda deps , dtype , struct : 1 ),
320+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 3 ),
321+ cp .Rank .Le (lambda deps : 3 ),
322+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
323+ ]
324+ )
256325 case "constant_pad_nd.default" :
257326 tensor_constraints .extend (
258327 [
@@ -283,6 +352,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
283352 cp .Rank .Le (lambda deps : 2 ** 2 ),
284353 ]
285354 )
355+ case "pow.Tensor_Scalar" :
356+ tensor_constraints .extend (
357+ [
358+ cp .Dtype .In (lambda deps : [torch .float32 , torch .int32 ]),
359+ ]
360+ )
286361 case "div.Tensor_mode" | "minimum.default" :
287362 if index == 0 :
288363 tensor_constraints = [
0 commit comments