@@ -540,3 +540,37 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
540
540
return %0 : !torch.vtensor <[?,?],i1 >
541
541
}
542
542
543
+ // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$trunc(
544
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
545
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
546
+ // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
547
+ // CHECK: %[[STR:.*]] = torch.constant.str "trunc"
548
+ // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
549
+ // CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
550
+ // CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
551
+ // CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
552
+ // CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
553
+ // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
554
+ // CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
555
+ func.func @torch.aten.div.Tensor_mode$trunc (%arg0: !torch.vtensor <[?,?,?,?],f32 >, %arg1: !torch.vtensor <[?,?,?,?],f32 >) -> !torch.vtensor <[?,?,?,?],f32 > {
556
+ %str = torch.constant.str " trunc"
557
+ %0 = torch.aten.div.Tensor_mode %arg0 , %arg1 , %str : !torch.vtensor <[?,?,?,?],f32 >, !torch.vtensor <[?,?,?,?],f32 >, !torch.str -> !torch.vtensor <[?,?,?,?],f32 >
558
+ return %0 : !torch.vtensor <[?,?,?,?],f32 >
559
+ }
560
+
561
+ // -----
562
+
563
+ // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$floor(
564
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
565
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
566
+ // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
567
+ // CHECK: %[[STR:.*]] = torch.constant.str "floor"
568
+ // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
569
+ // CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
570
+ // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
571
+ // CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
572
+ func.func @torch.aten.div.Tensor_mode$floor (%arg0: !torch.vtensor <[?,?,?,?],f32 >, %arg1: !torch.vtensor <[?,?,?,?],f32 >) -> !torch.vtensor <[?,?,?,?],f32 > {
573
+ %str = torch.constant.str " floor"
574
+ %0 = torch.aten.div.Tensor_mode %arg0 , %arg1 , %str : !torch.vtensor <[?,?,?,?],f32 >, !torch.vtensor <[?,?,?,?],f32 >, !torch.str -> !torch.vtensor <[?,?,?,?],f32 >
575
+ return %0 : !torch.vtensor <[?,?,?,?],f32 >
576
+ }
0 commit comments