Skip to content

Commit 7c148a7

Browse files
authored
Add constraints for split_copy test
Differential Revision: D84104833 Pull Request resolved: #14870
1 parent 5af73eb commit 7c148a7

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,34 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
222222
cp.Value.Le(lambda deps, dtype, struct: 2),
223223
]
224224
)
225+
case "transpose_copy.int":
226+
tensor_constraints.extend(
227+
[
228+
cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
229+
]
230+
)
231+
case "permute_copy.default":
232+
tensor_constraints.extend(
233+
[
234+
cp.Dtype.In(lambda deps: [torch.float32, torch.int8, torch.uint8]),
235+
cp.Rank.Le(
236+
lambda deps: 5
237+
), # xa_nn_transpose only supports up to 5D
238+
cp.Rank.Ge(lambda deps: 1), # Must have at least 1 dimension
239+
]
240+
)
241+
case "sqrt.default":
242+
tensor_constraints.extend(
243+
[
244+
cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
245+
]
246+
)
247+
case "clamp.default":
248+
tensor_constraints.extend(
249+
[
250+
cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
251+
]
252+
)
225253
case "rsqrt.default":
226254
tensor_constraints.extend(
227255
[
@@ -232,6 +260,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
232260
cp.Value.Le(lambda deps, dtype, struct: 2**2),
233261
]
234262
)
263+
case "relu.default":
264+
tensor_constraints.extend(
265+
[
266+
cp.Dtype.In(lambda deps: [torch.float32]),
267+
]
268+
)
235269
case "mean.dim":
236270
tensor_constraints.extend(
237271
[
@@ -241,10 +275,17 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
241275
case "exp.default":
242276
tensor_constraints.extend(
243277
[
278+
cp.Dtype.In(lambda deps: [torch.float32]),
244279
cp.Value.Ge(lambda deps, dtype, struct: -(2**2)),
245280
cp.Value.Le(lambda deps, dtype, struct: 2**2),
246281
]
247282
)
283+
case "tanh.default":
284+
tensor_constraints.extend(
285+
[
286+
cp.Dtype.In(lambda deps: [torch.float32]),
287+
]
288+
)
248289
case "slice_copy.Tensor":
249290
tensor_constraints.extend(
250291
[
@@ -253,6 +294,34 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
253294
cp.Value.Le(lambda deps, dtype, struct: 2),
254295
]
255296
)
297+
case "div.Scalar" | "add.Tensor" | "mul.Tensor" | "sub.Tensor":
298+
tensor_constraints.extend(
299+
[
300+
cp.Dtype.In(
301+
lambda deps: [
302+
torch.int32,
303+
torch.int64,
304+
torch.float32,
305+
]
306+
),
307+
]
308+
)
309+
case "split_copy.Tensor":
310+
tensor_constraints.extend(
311+
[
312+
cp.Dtype.In(
313+
lambda deps: [
314+
torch.int32,
315+
torch.int64,
316+
torch.float32,
317+
]
318+
),
319+
cp.Value.Ge(lambda deps, dtype, struct: 1),
320+
cp.Value.Le(lambda deps, dtype, struct: 2**3),
321+
cp.Rank.Le(lambda deps: 3),
322+
cp.Size.Le(lambda deps, r, d: 2**2),
323+
]
324+
)
256325
case "constant_pad_nd.default":
257326
tensor_constraints.extend(
258327
[
@@ -283,6 +352,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
283352
cp.Rank.Le(lambda deps: 2**2),
284353
]
285354
)
355+
case "pow.Tensor_Scalar":
356+
tensor_constraints.extend(
357+
[
358+
cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
359+
]
360+
)
286361
case "div.Tensor_mode" | "minimum.default":
287362
if index == 0:
288363
tensor_constraints = [

0 commit comments

Comments
 (0)