diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index a0162189b9..822a3f72ac 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -1,6 +1,8 @@ // RUN: xdsl-opt %s -p "lower-csl-stencil" | filecheck %s builtin.module { +// CHECK-NEXT: builtin.module { + "csl_wrapper.module"() <{"width" = 1022 : i16, "height" = 510 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=2 : i16>, #csl_wrapper.param<"chunk_size" default=255 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "gauss_seidel_func"}> ({ ^0(%0 : i16, %1 : i16, %2 : i16, %3 : i16, %4 : i16, %5 : i16, %6 : i16, %7 : i16, %8 : i16): %9 = arith.constant 0 : i16 @@ -59,10 +61,7 @@ builtin.module { } "csl_wrapper.yield"() <{"fields" = []}> : () -> () }) : () -> () -} - -// CHECK-NEXT: builtin.module { // CHECK-NEXT: "csl_wrapper.module"() <{"width" = 1022 : i16, "height" = 510 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=2 : i16>, #csl_wrapper.param<"chunk_size" default=255 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "gauss_seidel_func"}> ({ // CHECK-NEXT: ^0(%0 : i16, %1 : i16, %2 : i16, %3 : i16, %4 : i16, %5 : i16, %6 : i16, %7 : i16, %8 : i16): // CHECK-NEXT: %9 = arith.constant 0 : i16 @@ -138,4 +137,244 @@ builtin.module { // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () // CHECK-NEXT: }) : () -> () + + + "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, 
#csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "loop", "width" = 1024 : i16}> ({ + ^0(%arg0 : i16, %arg1 : i16, %arg2 : i16, %arg3 : i16, %arg4 : i16, %arg5 : i16, %arg6 : i16, %arg7 : i16, %arg8 : i16): + %0 = arith.constant 0 : i16 + %1 = "csl.get_color"(%0) : (i16) -> !csl.color + %2 = "csl_wrapper.import"(%arg2, %arg3, %1) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module + %3 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module + %4 = "csl.member_call"(%3, %arg0, %arg1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct + %5 = "csl.member_call"(%2, %arg0) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct + %6 = arith.constant 1 : i16 + %7 = arith.subi %arg5, %6 : i16 + %8 = arith.subi %arg2, %arg0 : i16 + %9 = arith.subi %arg3, %arg1 : i16 + %10 = arith.cmpi slt, %arg0, %7 : i16 + %11 = arith.cmpi slt, %arg1, %7 : i16 + %12 = arith.cmpi slt, %8, %arg5 : i16 + %13 = arith.cmpi slt, %9, %arg5 : i16 + %14 = arith.ori %10, %11 : i1 + %15 = arith.ori %14, %12 : i1 + %16 = arith.ori %15, %13 : i1 + "csl_wrapper.yield"(%5, %4, %16) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () + }, { + ^1(%arg0_1 : i16, %arg1_1 : i16, %arg2_1 : i16, %arg3_1 : i16, %arg4_1 : i16, %arg5_1 : i16, %arg6_1 : i16, %arg7_1 : !csl.comptime_struct, %arg8_1 : !csl.comptime_struct, %arg9 : i1): + %17 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module + %18 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module + %19 = 
memref.alloc() : memref<512xf32> + %20 = memref.alloc() : memref<512xf32> + %21 = "csl.addressof"(%19) : (memref<512xf32>) -> !csl.ptr, #csl> + %22 = "csl.addressof"(%20) : (memref<512xf32>) -> !csl.ptr, #csl> + "csl.export"(%21) <{"type" = !csl.ptr, #csl>, "var_name" = "a"}> : (!csl.ptr, #csl>) -> () + "csl.export"(%22) <{"type" = !csl.ptr, #csl>, "var_name" = "b"}> : (!csl.ptr, #csl>) -> () + "csl.export"() <{"type" = () -> (), "var_name" = @gauss_seidel_func}> : () -> () + %23 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var + %24 = "csl.variable"() : () -> !csl.var> + %25 = "csl.variable"() : () -> !csl.var> + csl.func @loop() { + %26 = arith.constant 0 : index + %27 = arith.constant 1000 : index + %28 = arith.constant 1 : index + "csl.store_var"(%24, %19) : (!csl.var>, memref<512xf32>) -> () + "csl.store_var"(%25, %20) : (!csl.var>, memref<512xf32>) -> () + csl.activate local, 1 : i32 + csl.return + } + csl.task @for_cond0() attributes {"kind" = #csl, "id" = 1 : i5}{ + %29 = arith.constant 1000 : i16 + %30 = "csl.load_var"(%23) : (!csl.var) -> i16 + %31 = arith.cmpi slt, %30, %29 : i16 + scf.if %31 { + "csl.call"() <{"callee" = @for_body0}> : () -> () + } else { + "csl.call"() <{"callee" = @for_post0}> : () -> () + } + csl.return + } + csl.func @for_body0() { + %arg10 = "csl.load_var"(%23) : (!csl.var) -> i16 + %arg11 = "csl.load_var"(%24) : (!csl.var>) -> memref<512xf32> + %arg12 = "csl.load_var"(%25) : (!csl.var>) -> memref<512xf32> + %32 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> + csl_stencil.apply(%arg11 : memref<512xf32>, %32 : memref<510xf32>, %arg12 : memref<512xf32>, %arg9 : i1) outs (%arg12 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + ^2(%arg13 : memref<4x510xf32>, %arg14 : index, %arg15 : memref<510xf32>): + 
%33 = csl_stencil.access %arg13[1, 0] : memref<4x510xf32> + %34 = csl_stencil.access %arg13[-1, 0] : memref<4x510xf32> + %35 = csl_stencil.access %arg13[0, 1] : memref<4x510xf32> + %36 = csl_stencil.access %arg13[0, -1] : memref<4x510xf32> + %37 = memref.subview %arg15[%arg14] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> + "csl.fadds"(%37, %36, %35) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>, memref<510xf32>) -> () + "csl.fadds"(%37, %37, %34) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () + "csl.fadds"(%37, %37, %33) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () + "memref.copy"(%37, %37) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () + csl_stencil.yield %arg15 : memref<510xf32> + }, { + ^3(%arg13_1 : memref<512xf32>, %arg14_1 : memref<510xf32>, %38 : memref<512xf32>, %39 : i1): + scf.if %39 { + } else { + %40 = memref.subview %arg13_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> + %41 = memref.subview %arg13_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> + "csl.fadds"(%arg14_1, %arg14_1, %41) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () + "csl.fadds"(%arg14_1, %arg14_1, %40) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () + %42 = arith.constant 1.666600e-01 : f32 + "csl.fmuls"(%arg14_1, %arg14_1, %42) : (memref<510xf32>, memref<510xf32>, f32) -> () + %43 = memref.subview %38[1] [510] [1] : memref<512xf32> to memref<510xf32> + "memref.copy"(%arg14_1, %43) : (memref<510xf32>, memref<510xf32>) -> () + } + "csl.call"() <{"callee" = @for_inc0}> : () -> () + csl_stencil.yield + }) to <[0, 0], [1, 1]> + csl.return + } + csl.func @for_inc0() { + %33 = arith.constant 1 : i16 + %34 = "csl.load_var"(%23) : (!csl.var) -> i16 + %35 = arith.addi %34, 
%33 : i16 + "csl.store_var"(%23, %35) : (!csl.var, i16) -> () + %36 = "csl.load_var"(%24) : (!csl.var>) -> memref<512xf32> + %37 = "csl.load_var"(%25) : (!csl.var>) -> memref<512xf32> + "csl.store_var"(%24, %37) : (!csl.var>, memref<512xf32>) -> () + "csl.store_var"(%25, %36) : (!csl.var>, memref<512xf32>) -> () + csl.activate local, 1 : i32 + csl.return + } + csl.func @for_post0() { + "csl.member_call"(%17) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () + csl.return + } + "csl_wrapper.yield"() <{"fields" = []}> : () -> () + }) : () -> () + +// CHECK-NEXT: "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "loop", "width" = 1024 : i16}> ({ +// CHECK-NEXT: ^2(%arg0_1 : i16, %arg1_1 : i16, %arg2 : i16, %arg3 : i16, %arg4 : i16, %arg5 : i16, %arg6 : i16, %arg7 : i16, %arg8 : i16): +// CHECK-NEXT: %61 = arith.constant 0 : i16 +// CHECK-NEXT: %62 = "csl.get_color"(%61) : (i16) -> !csl.color +// CHECK-NEXT: %63 = "csl_wrapper.import"(%arg2, %arg3, %62) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module +// CHECK-NEXT: %64 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module +// CHECK-NEXT: %65 = "csl.member_call"(%64, %arg0_1, %arg1_1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct +// CHECK-NEXT: %66 = "csl.member_call"(%63, %arg0_1) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct +// CHECK-NEXT: %67 = arith.constant 1 : i16 +// CHECK-NEXT: %68 = arith.subi %arg5, %67 : i16 +// CHECK-NEXT: %69 = arith.subi 
%arg2, %arg0_1 : i16 +// CHECK-NEXT: %70 = arith.subi %arg3, %arg1_1 : i16 +// CHECK-NEXT: %71 = arith.cmpi slt, %arg0_1, %68 : i16 +// CHECK-NEXT: %72 = arith.cmpi slt, %arg1_1, %68 : i16 +// CHECK-NEXT: %73 = arith.cmpi slt, %69, %arg5 : i16 +// CHECK-NEXT: %74 = arith.cmpi slt, %70, %arg5 : i16 +// CHECK-NEXT: %75 = arith.ori %71, %72 : i1 +// CHECK-NEXT: %76 = arith.ori %75, %73 : i1 +// CHECK-NEXT: %77 = arith.ori %76, %74 : i1 +// CHECK-NEXT: "csl_wrapper.yield"(%66, %65, %77) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () +// CHECK-NEXT: }, { +// CHECK-NEXT: ^3(%arg0_2 : i16, %arg1_2 : i16, %arg2_1 : i16, %arg3_1 : i16, %arg4_1 : i16, %arg5_1 : i16, %arg6_1 : i16, %arg7_1 : !csl.comptime_struct, %arg8_1 : !csl.comptime_struct, %arg9 : i1): +// CHECK-NEXT: %78 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %79 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %80 = memref.alloc() : memref<512xf32> +// CHECK-NEXT: %81 = memref.alloc() : memref<512xf32> +// CHECK-NEXT: %82 = "csl.addressof"(%80) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: %83 = "csl.addressof"(%81) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: "csl.export"(%82) <{"type" = !csl.ptr, #csl>, "var_name" = "a"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"(%83) <{"type" = !csl.ptr, #csl>, "var_name" = "b"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"() <{"type" = () -> (), "var_name" = @gauss_seidel_func}> : () -> () +// CHECK-NEXT: %84 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var +// CHECK-NEXT: %85 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %86 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: csl.func @loop() { 
+// CHECK-NEXT: %87 = arith.constant 0 : index +// CHECK-NEXT: %88 = arith.constant 1000 : index +// CHECK-NEXT: %89 = arith.constant 1 : index +// CHECK-NEXT: "csl.store_var"(%85, %80) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%86, %81) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: csl.activate local, 1 : i32 +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.task @for_cond0() attributes {"kind" = #csl, "id" = 1 : i5}{ +// CHECK-NEXT: %90 = arith.constant 1000 : i16 +// CHECK-NEXT: %91 = "csl.load_var"(%84) : (!csl.var) -> i16 +// CHECK-NEXT: %92 = arith.cmpi slt, %91, %90 : i16 +// CHECK-NEXT: scf.if %92 { +// CHECK-NEXT: "csl.call"() <{"callee" = @for_body0}> : () -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: "csl.call"() <{"callee" = @for_post0}> : () -> () +// CHECK-NEXT: } +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @for_body0() { +// CHECK-NEXT: %arg10 = "csl.load_var"(%84) : (!csl.var) -> i16 +// CHECK-NEXT: %arg11 = "csl.load_var"(%85) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg12 = "csl.load_var"(%86) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %accumulator_1 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> +// CHECK-NEXT: %93 = arith.constant 1 : i16 +// CHECK-NEXT: %94 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %95 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %96 = memref.subview %arg11[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "csl.member_call"(%79, %96, %93, %94, %95) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @receive_chunk_cb1(%offset_2 : i16) { +// CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index 
+// CHECK-NEXT: %97 = arith.constant 1 : i16 +// CHECK-NEXT: %98 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %99 = "csl.member_call"(%79, %98, %97) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %100 = builtin.unrealized_conversion_cast %99 : !csl to memref<510xf32> +// CHECK-NEXT: %101 = arith.constant 1 : i16 +// CHECK-NEXT: %102 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %103 = "csl.member_call"(%79, %102, %101) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %104 = builtin.unrealized_conversion_cast %103 : !csl to memref<510xf32> +// CHECK-NEXT: %105 = arith.constant 1 : i16 +// CHECK-NEXT: %106 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %107 = "csl.member_call"(%79, %106, %105) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %108 = builtin.unrealized_conversion_cast %107 : !csl to memref<510xf32> +// CHECK-NEXT: %109 = arith.constant 1 : i16 +// CHECK-NEXT: %110 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %111 = "csl.member_call"(%79, %110, %109) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %112 = builtin.unrealized_conversion_cast %111 : !csl to memref<510xf32> +// CHECK-NEXT: %113 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> +// CHECK-NEXT: "csl.fadds"(%113, %112, %108) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: "csl.fadds"(%113, %113, %104) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () +// CHECK-NEXT: "csl.fadds"(%113, %113, %100) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, 
strided<[1], offset: ?>>, memref<510xf32>) -> () +// CHECK-NEXT: "memref.copy"(%113, %113) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @done_exchange_cb1() { +// CHECK-NEXT: %arg12_1 = "csl.load_var"(%86) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg11_1 = "csl.load_var"(%85) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: scf.if %arg9 { +// CHECK-NEXT: } else { +// CHECK-NEXT: %114 = memref.subview %arg11_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> +// CHECK-NEXT: %115 = memref.subview %arg11_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %115) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () +// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %114) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () +// CHECK-NEXT: %116 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%accumulator_1, %accumulator_1, %116) : (memref<510xf32>, memref<510xf32>, f32) -> () +// CHECK-NEXT: %117 = memref.subview %arg12_1[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "memref.copy"(%accumulator_1, %117) : (memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: } +// CHECK-NEXT: "csl.call"() <{"callee" = @for_inc0}> : () -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @for_inc0() { +// CHECK-NEXT: %118 = arith.constant 1 : i16 +// CHECK-NEXT: %119 = "csl.load_var"(%84) : (!csl.var) -> i16 +// CHECK-NEXT: %120 = arith.addi %119, %118 : i16 +// CHECK-NEXT: "csl.store_var"(%84, %120) : (!csl.var, i16) -> () +// CHECK-NEXT: %121 = "csl.load_var"(%85) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %122 = "csl.load_var"(%86) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: "csl.store_var"(%85, %122) : (!csl.var>, memref<512xf32>) 
-> ()
+// CHECK-NEXT: "csl.store_var"(%86, %121) : (!csl.var>, memref<512xf32>) -> ()
+// CHECK-NEXT: csl.activate local, 1 : i32
+// CHECK-NEXT: csl.return
+// CHECK-NEXT: }
+// CHECK-NEXT: csl.func @for_post0() {
+// CHECK-NEXT: "csl.member_call"(%78) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> ()
+// CHECK-NEXT: csl.return
+// CHECK-NEXT: }
+// CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> ()
+// CHECK-NEXT: }) : () -> ()
+
+
+
+}
 // CHECK-NEXT: }
diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py
index 51233f6c3e..42ee90a17e 100644
--- a/xdsl/transforms/lower_csl_stencil.py
+++ b/xdsl/transforms/lower_csl_stencil.py
@@ -11,7 +11,7 @@
     i16,
 )
 from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper
-from xdsl.ir import Attribute, Block, Operation, Region
+from xdsl.ir import Attribute, Block, Operation, OpResult, Region, SSAValue
 from xdsl.passes import ModulePass
 from xdsl.pattern_rewriter import (
     GreedyRewritePatternApplier,
@@ -21,6 +21,7 @@
     op_type_rewrite_pattern,
 )
 from xdsl.rewriter import InsertPoint
+from xdsl.traits import is_side_effect_free
 from xdsl.utils.hints import isa
 
 
@@ -221,6 +222,53 @@ def match_and_rewrite(self, op: csl_stencil.YieldOp, rewriter: PatternRewriter,
         rewriter.replace_matched_op(csl.ReturnOp())
 
 
+@dataclass(frozen=True)
+class InlineApplyOpArgs(RewritePattern):
+    """
+    Inlines apply op args into the `done_exchange` callback.
+
+    A `done_exchange` block arg is inlined when the value passed for it is an
+    op result whose defining op has the same parent as the apply op: that op
+    is cloned to the start of the `done_exchange` block and the block arg is
+    replaced by the matching result of the clone.
+    """
+
+    @op_type_rewrite_pattern
+    def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /):
+        """
+        Clones the defining ops of inlinable args into `done_exchange`.
+
+        Raises ValueError if a defining op is neither a `csl.LoadVarOp` nor
+        side-effect free.
+        """
+        done_args = op.done_exchange.block.args
+        # Block arg 0 is paired with `op.field`; block args from index 2 on
+        # are paired with the trailing entries of `op.args`.
+        arg_mapping = zip(
+            done_args[2:],
+            op.args[-(len(done_args) - 2) :],
+        )
+        for block_arg, arg in [
+            (done_args[0], op.field),
+            *arg_mapping,
+        ]:
+            if isinstance(arg, OpResult) and arg.op.parent == op.parent:
+                if not (
+                    isinstance(arg.op, csl.LoadVarOp) or is_side_effect_free(arg.op)
+                ):
+                    raise ValueError(
+                        "Can only promote csl.LoadVarOp or side_effect_free op"
+                    )
+                rewriter.insert_op(
+                    new_arg := arg.op.clone(),
+                    InsertPoint.at_start(op.done_exchange.block),
+                )
+                # Use the clone's result at the same position as `arg`, so that
+                # defining ops with more than one result are handled too.
+                replacement: SSAValue = new_arg.results[arg.index]
+                block_arg.replace_by(replacement)
+
+
 @dataclass(frozen=True)
 class LowerCslStencil(ModulePass):
     """
@@ -239,7 +287,13 @@ class LowerCslStencil(ModulePass):
 
     def apply(self, ctx: MLContext, op: ModuleOp) -> None:
         PatternRewriteWalker(
-            LowerYieldOp(),
+            GreedyRewritePatternApplier(
+                [
+                    LowerYieldOp(),
+                    InlineApplyOpArgs(),
+                ]
+            ),
+            apply_recursively=False,
         ).rewrite_module(op)
         module_pass = PatternRewriteWalker(
             GreedyRewritePatternApplier(