diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml new file mode 100644 index 000000000..1f408ad17 --- /dev/null +++ b/.github/actionlint.yaml @@ -0,0 +1,5 @@ +# actionlint.yaml + +self-hosted-runner: + labels: + - zkir diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d96b4395..f385b10c3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,21 +18,21 @@ concurrency: jobs: build-and-test: - runs-on: self-hosted + runs-on: [self-hosted, zkir] steps: - name: Checkout ZKIR uses: actions/checkout@v4 - name: Run `bazel build` run: | - bazel build --remote_cache=grpc://127.0.0.1:9092 --test_output=errors //... + bazel build --remote_cache=grpc://127.0.0.1:9092 --remote_download_outputs=all --test_output=errors //... - name: Run `bazel test` run: | bazel test --config ci --remote_cache=grpc://127.0.0.1:9092 --test_output=errors //... pre-commit-style: - runs-on: self-hosted + runs-on: [self-hosted, zkir] steps: - name: Checkout ZKIR uses: actions/checkout@v4 @@ -47,7 +47,7 @@ jobs: - name: Refresh compile commands run: | - bazel run @hedron_compile_commands//:refresh_all + bazel run --remote_cache=grpc://127.0.0.1:9092 --remote_download_outputs=all @hedron_compile_commands//:refresh_all - name: Run pre-commit github action uses: pre-commit/action@646c83fcd040023954eafda54b4db0192ce70507 # pin@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f22131bc1..75ed6fb5e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,14 +11,7 @@ repos: - id: check-added-large-files args: ["--maxkb=500000"] # 500MB - # clang-format - - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v19.1.7 - hooks: - - id: clang-format - args: ["--style=file"] - - # cpplint + clang-tidy + # cpplint - repo: https://github.com/pocc/pre-commit-hooks rev: v1.3.5 hooks: @@ -27,11 +20,27 @@ repos: [ "--filter=-whitespace/indent_namespace, -legal/copyright, -build/namespaces, -runtime/references", ] + + # clang-tidy + # This hook now relies on a local shell script. + - repo: local + hooks: - id: clang-tidy - args: - [ - "--checks=readability-static-definition-in-anonymous-namespace,modernize-concat-nested-namespaces", - ] + name: clang-tidy + entry: ./run_clang_tidy.sh + language: script + files: '\.(c|cc|cpp|cxx|h|hh|hpp|hxx)$' + exclude: | + (?x) + (^|/)(bazel-.*|bazel-out/|build/|out/|third_party/|external/) + pass_filenames: true + + # clang-format + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v19.1.7 + hooks: + - id: clang-format + args: ["--style=file"] # Starlark - repo: https://github.com/keith/pre-commit-buildifier @@ -60,6 +69,7 @@ repos: hooks: - id: actionlint additional_dependencies: [pyflakes>=3.0.1, shellcheck-py>=0.9.0.5] + args: ["--config-file=.github/actionlint.yaml"] # mdformat - repo: https://github.com/executablebooks/mdformat @@ -78,4 +88,4 @@ repos: - id: pyink language_version: python3 -exclude: patches/.*\.patch$ +exclude: third_party/llvm-project/.*\.patch$ diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index cd1dcfad9..fc0bcb3d7 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -20,9 +20,9 @@ zkir_deps() load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "9ae3bce17543f92ce0237597cc66503d58cce317" +LLVM_COMMIT = "5ed852f7f72855710eeff53179e6a6f2271a3c2a" -LLVM_SHA256 = "19a170bb0931bcf3e698f37fef24b2c52871525987731b5d56be6a4afc705b22" +LLVM_SHA256 = "95792e50d5f84847721545b645a6ca2c2b3b7610d02e3de07d65a6148e68508c" http_archive( name = "llvm-raw", diff --git a/run_clang_tidy.sh b/run_clang_tidy.sh new file mode 100755 index 000000000..1889983ad --- /dev/null +++ b/run_clang_tidy.sh @@ -0,0 +1,43 @@ +#!/bin/bash +set -euo pipefail + +# This script runs clang-tidy with OS-specific arguments. + +# Get the path to the directory containing this script. +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_ROOT="$SCRIPT_DIR" + +# Define the base clang-tidy command. +CLANG_TIDY_BIN=$(command -v clang-tidy) +if [ -z "$CLANG_TIDY_BIN" ]; then + echo "Error: clang-tidy not found in PATH." >&2 + exit 1 +fi + +# Define the common arguments. +CLANG_TIDY_ARGS=( + "--checks=readability-static-definition-in-anonymous-namespace,modernize-concat-nested-namespaces" + "-fix" +) + +# Add macOS-specific arguments. +if [ "$(uname)" == "Darwin" ]; then + echo "Running clang-tidy with macOS toolchain." + CLANG_TIDY_ARGS+=( + "--extra-arg=-isysroot" + "--extra-arg=$(xcrun --show-sdk-path)" + ) +else + # This branch would be for Linux, Windows, etc. + echo "Running clang-tidy with default toolchain." +fi + +# Pass the compile_commands.json directory. +# This assumes the file is in the project root. +if [ -f "$PROJECT_ROOT/compile_commands.json" ]; then + CLANG_TIDY_ARGS+=("-p=$PROJECT_ROOT") +fi + +# Run clang-tidy on the files passed by pre-commit. +# $@ is the array of all arguments passed to this script (i.e., the filenames). +$CLANG_TIDY_BIN "${CLANG_TIDY_ARGS[@]}" "$@" diff --git a/tests/Dialect/EllipticCurve/bufferization.mlir b/tests/Dialect/EllipticCurve/bufferization.mlir index b2f0ea63e..320e505ba 100644 --- a/tests/Dialect/EllipticCurve/bufferization.mlir +++ b/tests/Dialect/EllipticCurve/bufferization.mlir @@ -27,11 +27,11 @@ func.func @test_bufferization_materialize_in_destination(%tensor : tensor<2x!aff return } -// CHECK-LABEL: @test_bufferization_to_memref +// CHECK-LABEL: @test_bufferization_to_buffer // CHECK-SAME: (%[[TENSOR:.*]]: [[TENSOR_TYPE:.*]]) -> [[T:.*]] { -func.func @test_bufferization_to_memref(%tensor : tensor<2x!affine>) -> memref<2x!affine> { - // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : [[TENSOR_TYPE]] to [[T]] - %memref = bufferization.to_memref %tensor : tensor<2x!affine> to memref<2x!affine> +func.func @test_bufferization_to_buffer(%tensor : tensor<2x!affine>) -> memref<2x!affine> { + // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[TENSOR]] : [[TENSOR_TYPE]] to [[T]] + %memref = bufferization.to_buffer %tensor : tensor<2x!affine> to memref<2x!affine> // CHECK: return %[[MEMREF]] : [[T]] return %memref : memref<2x!affine> } diff --git a/tests/Dialect/EllipticCurve/elliptic_curve_msm_runner.mlir b/tests/Dialect/EllipticCurve/elliptic_curve_msm_runner.mlir index 781708226..fedd10d73 100644 --- a/tests/Dialect/EllipticCurve/elliptic_curve_msm_runner.mlir +++ b/tests/Dialect/EllipticCurve/elliptic_curve_msm_runner.mlir @@ -58,7 +58,7 @@ func.func @test_msm() { %extract_point_reduced = field.from_mont %extract_point : tensor<2x!PF> %extract = field.extract %extract_point_reduced : tensor<2x!PF> -> tensor<2xi256> %trunc = arith.trunci %extract : tensor<2xi256> to tensor<2xi32> - %mem = bufferization.to_memref %trunc : tensor<2xi32> to memref<2xi32> + %mem = bufferization.to_buffer %trunc : tensor<2xi32> to memref<2xi32> %U = memref.cast %mem : memref<2xi32> to memref<*xi32> func.call @printMemrefI32(%U) : (memref<*xi32>) -> () @@ -73,7 +73,7 @@ func.func @test_msm() { %extract_point1_reduced = field.from_mont %extract_point1 : tensor<2x!PF> %extract1 = field.extract %extract_point1_reduced : tensor<2x!PF> -> tensor<2xi256> %trunc1 = arith.trunci %extract1 : tensor<2xi256> to tensor<2xi32> - %mem1 = bufferization.to_memref %trunc1 : tensor<2xi32> to memref<2xi32> + %mem1 = bufferization.to_buffer %trunc1 : tensor<2xi32> to memref<2xi32> %U1 = memref.cast %mem1 : memref<2xi32> to memref<*xi32> func.call @printMemrefI32(%U1) : (memref<*xi32>) -> () return diff --git a/tests/Dialect/EllipticCurve/elliptic_curve_to_field_runner.mlir b/tests/Dialect/EllipticCurve/elliptic_curve_to_field_runner.mlir index 13fbc12ff..c9d7f6a59 100644 --- a/tests/Dialect/EllipticCurve/elliptic_curve_to_field_runner.mlir +++ b/tests/Dialect/EllipticCurve/elliptic_curve_to_field_runner.mlir @@ -49,7 +49,7 @@ func.func @test_ops_in_order() { %from_mont1 = field.from_mont %extract_point1 : tensor<3x!PF> %extract1 = field.extract %from_mont1 : tensor<3x!PF> -> tensor<3xi256> %trunc1 = arith.trunci %extract1 : tensor<3xi256> to tensor<3xi32> - %1 = bufferization.to_memref %trunc1 : tensor<3xi32> to memref<3xi32> + %1 = bufferization.to_buffer %trunc1 : tensor<3xi32> to memref<3xi32> %U1 = memref.cast %1 : memref<3xi32> to memref<*xi32> func.call @printMemrefI32(%U1) : (memref<*xi32>) -> () @@ -59,7 +59,7 @@ func.func @test_ops_in_order() { %from_mont2 = field.from_mont %extract_point2 : tensor<3x!PF> %extract2 = field.extract %from_mont2 : tensor<3x!PF> -> tensor<3xi256> %trunc2 = arith.trunci %extract2 : tensor<3xi256> to tensor<3xi32> - %2 = bufferization.to_memref %trunc2 : tensor<3xi32> to memref<3xi32> + %2 = bufferization.to_buffer %trunc2 : tensor<3xi32> to memref<3xi32> %U2 = memref.cast %2 : memref<3xi32> to memref<*xi32> func.call @printMemrefI32(%U2) : (memref<*xi32>) -> () @@ -69,7 +69,7 @@ func.func @test_ops_in_order() { %from_mont3 = field.from_mont %extract_point3 : tensor<3x!PF> %extract3 = field.extract %from_mont3 : tensor<3x!PF> -> tensor<3xi256> %trunc3 = arith.trunci %extract3 : tensor<3xi256> to tensor<3xi32> - %3 = bufferization.to_memref %trunc3 : tensor<3xi32> to memref<3xi32> + %3 = bufferization.to_buffer %trunc3 : tensor<3xi32> to memref<3xi32> %U3 = memref.cast %3 : memref<3xi32> to memref<*xi32> func.call @printMemrefI32(%U3) : (memref<*xi32>) -> () @@ -79,7 +79,7 @@ func.func @test_ops_in_order() { %from_mont4 = field.from_mont %extract_point4 : tensor<3x!PF> %extract4 = field.extract %from_mont4 : tensor<3x!PF> -> tensor<3xi256> %trunc4 = arith.trunci %extract4 : tensor<3xi256> to tensor<3xi32> - %4 = bufferization.to_memref %trunc4 : tensor<3xi32> to memref<3xi32> + %4 = bufferization.to_buffer %trunc4 : tensor<3xi32> to memref<3xi32> %U4 = memref.cast %4 : memref<3xi32> to memref<*xi32> func.call @printMemrefI32(%U4) : (memref<*xi32>) -> () @@ -89,7 +89,7 @@ func.func @test_ops_in_order() { %from_mont5 = field.from_mont %extract_point5 : tensor<4x!PF> %extract5 = field.extract %from_mont5 : tensor<4x!PF> -> tensor<4xi256> %trunc5 = arith.trunci %extract5 : tensor<4xi256> to tensor<4xi32> - %5 = bufferization.to_memref %trunc5 : tensor<4xi32> to memref<4xi32> + %5 = bufferization.to_buffer %trunc5 : tensor<4xi32> to memref<4xi32> %U5 = memref.cast %5 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U5) : (memref<*xi32>) -> () @@ -99,7 +99,7 @@ func.func @test_ops_in_order() { %from_mont6 = field.from_mont %extract_point6 : tensor<2x!PF> %extract6 = field.extract %from_mont6 : tensor<2x!PF> -> tensor<2xi256> %trunc6 = arith.trunci %extract6 : tensor<2xi256> to tensor<2xi32> - %6 = bufferization.to_memref %trunc6 : tensor<2xi32> to memref<2xi32> + %6 = bufferization.to_buffer %trunc6 : tensor<2xi32> to memref<2xi32> %U6 = memref.cast %6 : memref<2xi32> to memref<*xi32> func.call @printMemrefI32(%U6) : (memref<*xi32>) -> () @@ -111,7 +111,7 @@ func.func @test_ops_in_order() { %from_mont7 = field.from_mont %extract_point7 : tensor<3x!PF> %extract7 = field.extract %from_mont7 : tensor<3x!PF> -> tensor<3xi256> %trunc7 = arith.trunci %extract7 : tensor<3xi256> to tensor<3xi32> - %7 = bufferization.to_memref %trunc7 : tensor<3xi32> to memref<3xi32> + %7 = bufferization.to_buffer %trunc7 : tensor<3xi32> to memref<3xi32> %U7 = memref.cast %7 : memref<3xi32> to memref<*xi32> func.call @printMemrefI32(%U7) : (memref<*xi32>) -> () @@ -121,7 +121,7 @@ func.func @test_ops_in_order() { %from_mont8 = field.from_mont %extract_point8 : tensor<2x!PF> %extract8 = field.extract %from_mont8 : tensor<2x!PF> -> tensor<2xi256> %trunc8 = arith.trunci %extract8 : tensor<2xi256> to tensor<2xi32> - %8 = bufferization.to_memref %trunc8 : tensor<2xi32> to memref<2xi32> + %8 = bufferization.to_buffer %trunc8 : tensor<2xi32> to memref<2xi32> %U8 = memref.cast %8 : memref<2xi32> to memref<*xi32> func.call @printMemrefI32(%U8) : (memref<*xi32>) -> () @@ -133,7 +133,7 @@ func.func @test_ops_in_order() { %from_mont9 = field.from_mont %extract_point9 : tensor<4x!PF> %extract9 = field.extract %from_mont9 : tensor<4x!PF> -> tensor<4xi256> %trunc9 = arith.trunci %extract9 : tensor<4xi256> to tensor<4xi32> - %9 = bufferization.to_memref %trunc9 : tensor<4xi32> to memref<4xi32> + %9 = bufferization.to_buffer %trunc9 : tensor<4xi32> to memref<4xi32> %U9 = memref.cast %9 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U9) : (memref<*xi32>) -> () diff --git a/tests/Dialect/Field/bufferization.mlir b/tests/Dialect/Field/bufferization.mlir index 139ceacd8..2c90c17a3 100644 --- a/tests/Dialect/Field/bufferization.mlir +++ b/tests/Dialect/Field/bufferization.mlir @@ -19,11 +19,11 @@ func.func @test_bufferization_materialize_in_destination(%tensor : tensor<2x!PF> return } -// CHECK-LABEL: @test_bufferization_to_memref +// CHECK-LABEL: @test_bufferization_to_buffer // CHECK-SAME: (%[[TENSOR:.*]]: [[TENSOR_TYPE:.*]]) -> [[T:.*]] { -func.func @test_bufferization_to_memref(%tensor : tensor<2x!PF>) -> memref<2x!PF> { - // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : [[TENSOR_TYPE]] to [[T]] - %memref = bufferization.to_memref %tensor : tensor<2x!PF> to memref<2x!PF> +func.func @test_bufferization_to_buffer(%tensor : tensor<2x!PF>) -> memref<2x!PF> { + // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[TENSOR]] : [[TENSOR_TYPE]] to [[T]] + %memref = bufferization.to_buffer %tensor : tensor<2x!PF> to memref<2x!PF> // CHECK: return %[[MEMREF]] : [[T]] return %memref : memref<2x!PF> } diff --git a/tests/Dialect/Field/field_runner.mlir b/tests/Dialect/Field/field_runner.mlir index aaa7b71ab..c41751371 100644 --- a/tests/Dialect/Field/field_runner.mlir +++ b/tests/Dialect/Field/field_runner.mlir @@ -23,7 +23,7 @@ func.func @test_power() { %res1 = field.pow %base_pf, %exp : !PF, i64 %1 = field.extract %res1 : !PF -> i32 %2 = tensor.from_elements %1 : tensor<1xi32> - %3 = bufferization.to_memref %2 : tensor<1xi32> to memref<1xi32> + %3 = bufferization.to_buffer %2 : tensor<1xi32> to memref<1xi32> %U1 = memref.cast %3 : memref<1xi32> to memref<*xi32> func.call @printMemrefI32(%U1) : (memref<*xi32>) -> () @@ -32,7 +32,7 @@ func.func @test_power() { %res1_standard = field.from_mont %res1_mont : !PF %4 = field.extract %res1_standard : !PF -> i32 %5 = tensor.from_elements %4 : tensor<1xi32> - %6 = bufferization.to_memref %5 : tensor<1xi32> to memref<1xi32> + %6 = bufferization.to_buffer %5 : tensor<1xi32> to memref<1xi32> %U2 = memref.cast %6 : memref<1xi32> to memref<*xi32> func.call @printMemrefI32(%U2) : (memref<*xi32>) -> () @@ -40,7 +40,7 @@ func.func @test_power() { %res2 = field.pow %base_f2, %exp : !QF, i64 %9, %10 = field.extract %res2 : !QF -> i32, i32 %11 = tensor.from_elements %9, %10 : tensor<2xi32> - %12 = bufferization.to_memref %11 : tensor<2xi32> to memref<2xi32> + %12 = bufferization.to_buffer %11 : tensor<2xi32> to memref<2xi32> %U3 = memref.cast %12 : memref<2xi32> to memref<*xi32> func.call @printMemrefI32(%U3) : (memref<*xi32>) -> () @@ -49,7 +49,7 @@ func.func @test_power() { %res2_standard = field.from_mont %res2_mont : !QF %13, %14 = field.extract %res2_standard : !QF -> i32, i32 %15 = tensor.from_elements %13, %14 : tensor<2xi32> - %16 = bufferization.to_memref %15 : tensor<2xi32> to memref<2xi32> + %16 = bufferization.to_buffer %15 : tensor<2xi32> to memref<2xi32> %U4 = memref.cast %16 : memref<2xi32> to memref<*xi32> func.call @printMemrefI32(%U4) : (memref<*xi32>) -> () diff --git a/tests/Dialect/ModArith/bufferization.mlir b/tests/Dialect/ModArith/bufferization.mlir index 6e0027b04..f95aa5a36 100644 --- a/tests/Dialect/ModArith/bufferization.mlir +++ b/tests/Dialect/ModArith/bufferization.mlir @@ -19,11 +19,11 @@ func.func @test_bufferization_materialize_in_destination(%tensor : tensor<2x!Zp> return } -// CHECK-LABEL: @test_bufferization_to_memref +// CHECK-LABEL: @test_bufferization_to_buffer // CHECK-SAME: (%[[TENSOR:.*]]: [[TENSOR_TYPE:.*]]) -> [[T:.*]] { -func.func @test_bufferization_to_memref(%tensor : tensor<2x!Zp>) -> memref<2x!Zp> { - // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : [[TENSOR_TYPE]] to [[T]] - %memref = bufferization.to_memref %tensor : tensor<2x!Zp> to memref<2x!Zp> +func.func @test_bufferization_to_buffer(%tensor : tensor<2x!Zp>) -> memref<2x!Zp> { + // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[TENSOR]] : [[TENSOR_TYPE]] to [[T]] + %memref = bufferization.to_buffer %tensor : tensor<2x!Zp> to memref<2x!Zp> // CHECK: return %[[MEMREF]] : [[T]] return %memref : memref<2x!Zp> } diff --git a/tests/Dialect/ModArith/mod_arith_runner.mlir b/tests/Dialect/ModArith/mod_arith_runner.mlir index 7230be8d6..e8e99ef6e 100644 --- a/tests/Dialect/ModArith/mod_arith_runner.mlir +++ b/tests/Dialect/ModArith/mod_arith_runner.mlir @@ -41,7 +41,7 @@ func.func @test_lower_inverse() { %2 = mod_arith.extract %from_mont : !Fq -> i256 %3 = arith.trunci %2 : i256 to i32 %4 = tensor.from_elements %3 : tensor<1xi32> - %5 = bufferization.to_memref %4 : tensor<1xi32> to memref<1xi32> + %5 = bufferization.to_buffer %4 : tensor<1xi32> to memref<1xi32> %U = memref.cast %5 : memref<1xi32> to memref<*xi32> func.call @printMemrefI32(%U) : (memref<*xi32>) -> () @@ -50,7 +50,7 @@ func.func @test_lower_inverse() { %6 = mod_arith.extract %mul2 : !Fq -> i256 %7 = arith.trunci %6 : i256 to i32 %8 = tensor.from_elements %7 : tensor<1xi32> - %9 = bufferization.to_memref %8 : tensor<1xi32> to memref<1xi32> + %9 = bufferization.to_buffer %8 : tensor<1xi32> to memref<1xi32> %10 = memref.cast %9 : memref<1xi32> to memref<*xi32> func.call @printMemrefI32(%10) : (memref<*xi32>) -> () @@ -61,7 +61,7 @@ func.func @test_lower_inverse() { %from_mont_r = mod_arith.from_mont %mul_r : !Fr %r_ext = mod_arith.extract %from_mont_r : !Fr -> i32 %r_tensor = tensor.from_elements %r_ext : tensor<1xi32> - %r_mem = bufferization.to_memref %r_tensor : tensor<1xi32> to memref<1xi32> + %r_mem = bufferization.to_buffer %r_tensor : tensor<1xi32> to memref<1xi32> %r_mem_cast = memref.cast %r_mem : memref<1xi32> to memref<*xi32> func.call @printMemrefI32(%r_mem_cast) : (memref<*xi32>) -> () return @@ -84,7 +84,7 @@ func.func @test_lower_inverse_tensor() { %mul = mod_arith.mul %inv, %tensor2 : tensor<3x!Fq> %ext = mod_arith.extract %mul : tensor<3x!Fq> -> tensor<3xi256> %trunc = arith.trunci %ext : tensor<3xi256> to tensor<3xi32> - %1 = bufferization.to_memref %trunc : tensor<3xi32> to memref<3xi32> + %1 = bufferization.to_buffer %trunc : tensor<3xi32> to memref<3xi32> %U1 = memref.cast %1 : memref<3xi32> to memref<*xi32> func.call @printMemrefI32(%U1) : (memref<*xi32>) -> () @@ -94,7 +94,7 @@ func.func @test_lower_inverse_tensor() { %from_mont = mod_arith.from_mont %mul_mont : tensor<3x!Fq> %ext2 = mod_arith.extract %from_mont : tensor<3x!Fq> -> tensor<3xi256> %trunc2 = arith.trunci %ext2 : tensor<3xi256> to tensor<3xi32> - %2 = bufferization.to_memref %trunc2 : tensor<3xi32> to memref<3xi32> + %2 = bufferization.to_buffer %trunc2 : tensor<3xi32> to memref<3xi32> %U2 = memref.cast %2 : memref<3xi32> to memref<*xi32> func.call @printMemrefI32(%U2) : (memref<*xi32>) -> () return diff --git a/tests/Dialect/Poly/poly_ntt_runner.mlir b/tests/Dialect/Poly/poly_ntt_runner.mlir index 9f8770a50..ae3b8b3bb 100644 --- a/tests/Dialect/Poly/poly_ntt_runner.mlir +++ b/tests/Dialect/Poly/poly_ntt_runner.mlir @@ -34,7 +34,7 @@ func.func @test_poly_ntt() { %res_standard = field.from_mont %res : tensor<4x!coeff_ty> %extract = field.extract %res_standard : tensor<4x!coeff_ty> -> tensor<4xi32> - %1 = bufferization.to_memref %extract : tensor<4xi32> to memref<4xi32> + %1 = bufferization.to_buffer %extract : tensor<4xi32> to memref<4xi32> %U = memref.cast %1 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U) : (memref<*xi32>) -> () @@ -43,7 +43,7 @@ func.func @test_poly_ntt() { %res2 = poly.to_tensor %poly : !poly_ty -> tensor<4x!coeff_ty_mont> %res2_standard = field.from_mont %res2 : tensor<4x!coeff_ty> %extract2 = field.extract %res2_standard : tensor<4x!coeff_ty> -> tensor<4xi32> - %2= bufferization.to_memref %extract2 : tensor<4xi32> to memref<4xi32> + %2= bufferization.to_buffer %extract2 : tensor<4xi32> to memref<4xi32> %U2 = memref.cast %2 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U2) : (memref<*xi32>) -> () return @@ -61,7 +61,7 @@ func.func @test_poly_ntt_with_twiddles() { %res_standard = field.from_mont %res : tensor<4x!coeff_ty> %extract = field.extract %res_standard : tensor<4x!coeff_ty> -> tensor<4xi32> - %1 = bufferization.to_memref %extract : tensor<4xi32> to memref<4xi32> + %1 = bufferization.to_buffer %extract : tensor<4xi32> to memref<4xi32> %U = memref.cast %1 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U) : (memref<*xi32>) -> () @@ -72,7 +72,7 @@ func.func @test_poly_ntt_with_twiddles() { %res2 = poly.to_tensor %poly : !poly_ty -> tensor<4x!coeff_ty_mont> %res2_standard = field.from_mont %res2 : tensor<4x!coeff_ty> %extract2 = field.extract %res2_standard : tensor<4x!coeff_ty> -> tensor<4xi32> - %2= bufferization.to_memref %extract2 : tensor<4xi32> to memref<4xi32> + %2= bufferization.to_buffer %extract2 : tensor<4xi32> to memref<4xi32> %U2 = memref.cast %2 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U2) : (memref<*xi32>) -> () return @@ -90,7 +90,7 @@ func.func @test_poly_ntt_out_of_place() { %res_standard = field.from_mont %res : tensor<4x!coeff_ty> %extract = field.extract %res_standard : tensor<4x!coeff_ty> -> tensor<4xi32> - %1 = bufferization.to_memref %extract : tensor<4xi32> to memref<4xi32> + %1 = bufferization.to_buffer %extract : tensor<4xi32> to memref<4xi32> %U = memref.cast %1 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U) : (memref<*xi32>) -> () @@ -100,7 +100,7 @@ func.func @test_poly_ntt_out_of_place() { %res2 = poly.to_tensor %poly : !poly_ty -> tensor<4x!coeff_ty_mont> %res2_standard = field.from_mont %res2 : tensor<4x!coeff_ty> %extract2 = field.extract %res2_standard : tensor<4x!coeff_ty> -> tensor<4xi32> - %2= bufferization.to_memref %extract2 : tensor<4xi32> to memref<4xi32> + %2= bufferization.to_buffer %extract2 : tensor<4xi32> to memref<4xi32> %U2 = memref.cast %2 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U2) : (memref<*xi32>) -> () return @@ -118,7 +118,7 @@ func.func @test_poly_ntt_out_of_place_no_bit_reversal() { %res_standard = field.from_mont %res : tensor<4x!coeff_ty> %extract = field.extract %res_standard : tensor<4x!coeff_ty> -> tensor<4xi32> - %1 = bufferization.to_memref %extract : tensor<4xi32> to memref<4xi32> + %1 = bufferization.to_buffer %extract : tensor<4xi32> to memref<4xi32> %U = memref.cast %1 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U) : (memref<*xi32>) -> () @@ -128,7 +128,7 @@ func.func @test_poly_ntt_out_of_place_no_bit_reversal() { %res2 = poly.to_tensor %poly : !poly_ty -> tensor<4x!coeff_ty_mont> %res2_standard = field.from_mont %res2 : tensor<4x!coeff_ty> %extract2 = field.extract %res2_standard : tensor<4x!coeff_ty> -> tensor<4xi32> - %2= bufferization.to_memref %extract2 : tensor<4xi32> to memref<4xi32> + %2= bufferization.to_buffer %extract2 : tensor<4xi32> to memref<4xi32> %U2 = memref.cast %2 : memref<4xi32> to memref<*xi32> func.call @printMemrefI32(%U2) : (memref<*xi32>) -> () return diff --git a/third_party/llvm-project/linalg_type_support.patch b/third_party/llvm-project/linalg_type_support.patch index 9b8cfa412..2fc62836f 100644 --- a/third_party/llvm-project/linalg_type_support.patch +++ b/third_party/llvm-project/linalg_type_support.patch @@ -1,5 +1,5 @@ diff --git mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp w/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp -index 7d1844df4219..53ee42ae95b1 100644 +index ca7f31dd6b51..5b0e32f98e8e 100644 --- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -25,12 +25,18 @@ @@ -21,7 +21,7 @@ index 7d1844df4219..53ee42ae95b1 100644 /// Include the definitions of the copy operation interface. #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc" -@@ -489,6 +495,9 @@ mlir::linalg::detail::isContractionInterfaceImpl( +@@ -525,6 +531,9 @@ mlir::linalg::detail::isContractionInterfaceImpl( // TODO: more fields than add/mul. // clang-format off if (!::isContractionBody< @@ -32,7 +32,7 @@ index 7d1844df4219..53ee42ae95b1 100644 arith::MulIOp, arith::AddIOp, complex::MulOp, complex::AddOp, diff --git mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp w/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp -index 96106cf7ae12..abda2b01ddfb 100644 +index 8075df730ccc..eed50cf6fb63 100644 --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -51,6 +51,14 @@ @@ -50,7 +50,7 @@ index 96106cf7ae12..abda2b01ddfb 100644 #include #include -@@ -429,8 +437,16 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, +@@ -439,8 +447,16 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, namespace { @@ -67,11 +67,11 @@ index 96106cf7ae12..abda2b01ddfb 100644 RegionBuilderHelper(OpBuilder &builder, Block &block) : builder(builder), block(block) {} -@@ -482,8 +498,16 @@ public: +@@ -506,7 +522,14 @@ public: bool allInteger = isInteger(arg0) && isInteger(arg1); bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 && arg1.getType().getIntOrFloatBitWidth() == 1; -- if (!allComplex && !allFloatingPoint && !allInteger) +- if (!allComplex && !allFloatingPoint && !allInteger) { + + bool allModular = isModular(arg0) && isModular(arg1); + bool allField = isField(arg0) && isField(arg1); @@ -79,13 +79,19 @@ index 96106cf7ae12..abda2b01ddfb 100644 + bool allECPoint = isECPoint(arg0) && isECPoint(arg1); + + if (!allComplex && !allFloatingPoint && !allInteger && !allModular && -+ !allField && !scalarByPoint && !allECPoint) ++ !allField && !scalarByPoint && !allECPoint) { + if (emitError) { + emitError() + << "Cannot build binary Linalg operation: expects allComplex, " +@@ -516,6 +539,7 @@ public: + } llvm_unreachable("unsupported non numeric type"); + } + OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToEnd(&block); switch (binaryFn) { -@@ -494,6 +518,15 @@ public: +@@ -526,6 +550,15 @@ public: return builder.create(arg0.getLoc(), arg0, arg1); if (allBool) return builder.create(arg0.getLoc(), arg0, arg1); @@ -101,10 +107,10 @@ index 96106cf7ae12..abda2b01ddfb 100644 return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::sub: if (allComplex) -@@ -502,6 +535,15 @@ public: - return builder.create(arg0.getLoc(), arg0, arg1); - if (allBool) +@@ -539,6 +572,15 @@ public: + } llvm_unreachable("unsupported operation: sub with bools"); + } + if (allModular) + return builder.create(arg0.getLoc(), arg0, arg1); + if (allField) @@ -117,7 +123,7 @@ index 96106cf7ae12..abda2b01ddfb 100644 return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::mul: if (allComplex) -@@ -510,6 +552,15 @@ public: +@@ -547,6 +589,15 @@ public: return builder.create(arg0.getLoc(), arg0, arg1); if (allBool) return builder.create(arg0.getLoc(), arg0, arg1); @@ -133,10 +139,10 @@ index 96106cf7ae12..abda2b01ddfb 100644 return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::div: if (allComplex) -@@ -518,30 +569,96 @@ public: - return builder.create(arg0.getLoc(), arg0, arg1); - if (allBool) +@@ -560,6 +611,14 @@ public: + } llvm_unreachable("unsupported operation: div with bools"); + } + if (allModular) + llvm_unreachable("unsupported operation: div with modular int"); + if (allField) @@ -147,8 +153,11 @@ index 96106cf7ae12..abda2b01ddfb 100644 + llvm_unreachable("unsupported operation: div with two EC points"); return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::div_unsigned: - if (!allInteger || allBool) + if (!allInteger || allBool) { +@@ -569,26 +628,84 @@ public: + } llvm_unreachable("unsupported operation: unsigned div not on uint"); + } + if (allModular) + llvm_unreachable( + "unsupported operation: div_unsigned with modular int"); @@ -230,17 +239,17 @@ index 96106cf7ae12..abda2b01ddfb 100644 return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::powf: assert(allFloatingPoint); -@@ -571,6 +688,9 @@ public: - +@@ -627,6 +744,9 @@ public: // Build the type functions defined by OpDSL. - Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { + Value buildTypeFn(TypeFn typeFn, Type toType, Value operand, + function_ref emitError = {}) { + if (isField(operand) || isECPoint(operand) || isModular(operand)) { + return operand; // Do not cast ZKIR types. + } switch (typeFn) { case TypeFn::cast_signed: return cast(toType, operand, false); -@@ -629,6 +749,21 @@ private: +@@ -696,6 +816,21 @@ private: bool isInteger(Value value) { return llvm::isa(value.getType()); } @@ -263,10 +272,10 @@ index 96106cf7ae12..abda2b01ddfb 100644 OpBuilder &builder; Block █ diff --git utils/bazel/llvm-project-overlay/mlir/BUILD.bazel w/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -index 05c2fb481980..b5c63c72ac1e 100644 +index cc266c2fe3a7..13b7e256ea64 100644 --- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -10349,5 +10349,8 @@ cc_library( +@@ -10544,5 +10544,8 @@ cc_library( ":ValueBoundsOpInterface", ":ViewLikeInterface", "//llvm:Support", diff --git a/third_party/llvm-project/memref_folding.patch b/third_party/llvm-project/memref_folding.patch index 0b08e86d4..4d0091276 100644 --- a/third_party/llvm-project/memref_folding.patch +++ b/third_party/llvm-project/memref_folding.patch @@ -1,11 +1,13 @@ -diff --git i/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp w/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp -index 4fce9be390bd..984bbd37436c 100644 +diff --git mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +index 9bd87d66c7d3..0d0b41f51926 100644 --- mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp -@@ -827,5 +827,5 @@ struct LoadOfToMemref : public OpRewritePattern { +@@ -835,7 +835,7 @@ struct LoadOfToBuffer : public OpRewritePattern { LogicalResult matchAndRewrite(memref::LoadOp load, PatternRewriter &rewriter) const override { - auto toMemref = load.getMemref().getDefiningOp(); -- if (!toMemref) -+ if (!toMemref || !toMemref.getReadOnly()) + auto toBuffer = load.getMemref().getDefiningOp(); +- if (!toBuffer) ++ if (!toBuffer || !toBuffer.getReadOnly()) return failure(); + + rewriter.replaceOpWithNewOp(load, toBuffer.getTensor(), diff --git a/third_party/llvm-project/tensor_type_support.patch b/third_party/llvm-project/tensor_type_support.patch index e246cd0e8..9a6149be5 100644 --- a/third_party/llvm-project/tensor_type_support.patch +++ b/third_party/llvm-project/tensor_type_support.patch @@ -1,17 +1,17 @@ diff --git mlir/lib/Dialect/Tensor/IR/TensorOps.cpp mlir/lib/Dialect/Tensor/IR/TensorOps.cpp -index 815806f06b47..ec3a4a333f45 100644 +index 22a25fd1a5af..bbfab789d58a 100644 --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp -@@ -37,6 +37,8 @@ - #include +@@ -41,6 +41,8 @@ #include + #include +#include "zkir/Dialect/ModArith/IR/ModArithTypes.h" + using namespace mlir; using namespace mlir::tensor; -@@ -1275,8 +1277,15 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result, +@@ -1465,8 +1467,16 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result, } OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { @@ -21,7 +21,8 @@ index 815806f06b47..ec3a4a333f45 100644 + getType().getElementType())) { + auto st = cast(getType()); + return DenseElementsAttr::get( -+ st.cloneWith(std::nullopt, modArithType.getModulus().getType()), adaptor.getElements()); ++ st.cloneWith(std::nullopt, modArithType.getModulus().getType()), ++ adaptor.getElements()); + } return DenseElementsAttr::get(getType(), adaptor.getElements()); + } @@ -29,10 +30,10 @@ index 815806f06b47..ec3a4a333f45 100644 } diff --git utils/bazel/llvm-project-overlay/mlir/BUILD.bazel utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -index b5c63c72ac1e..017915a8fe08 100644 +index cc266c2fe3a7..9ef338d525f1 100644 --- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -7016,5 +7016,6 @@ cc_library( +@@ -7134,5 +7134,6 @@ cc_library( ":ValueBoundsOpInterface", ":ViewLikeInterface", "//llvm:Support", diff --git a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp index aa8e0a5f3..52074a2fd 100644 --- a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp +++ b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp @@ -380,7 +380,7 @@ struct ConvertAdd : public OpConversionPattern { sum = jacobianAdd(p1Coords, p2Coords, jacobianType.getCurve(), b); } else { - assert(false && "Unsupported point types for addition"); + llvm_unreachable("Unsupported point types for addition"); } b.create(loc, sum); }); @@ -411,7 +411,7 @@ struct ConvertDouble : public OpConversionPattern { } else if (auto jacobianType = dyn_cast(outputType)) { doubled = jacobianDouble(coords, jacobianType.getCurve(), b); } else { - assert(false && "Unsupported point type for doubling"); + llvm_unreachable("Unsupported point type for doubling"); } rewriter.replaceOpWithMultiple(op, {doubled}); @@ -517,8 +517,8 @@ struct ConvertScalarMul : public OpConversionPattern { ? b.create(outputType, point) : point; - auto arithOne = b.create(1, scalarIntType); - auto arithZero = b.create(0, scalarIntType); + auto arithOne = b.create(scalarIntType, 1); + auto arithZero = b.create(scalarIntType, 0); auto result = zeroPoint; auto ifOp = b.create( b.create(arith::CmpIPredicate::ne, @@ -655,7 +655,7 @@ void EllipticCurveToField::runOnOperation() { ConvertSub, ConvertAny, ConvertAny, - ConvertAny, + ConvertAny, ConvertAny, ConvertAny, ConvertAny, @@ -674,7 +674,7 @@ void EllipticCurveToField::runOnOperation() { // clang-format off bufferization::AllocTensorOp, bufferization::MaterializeInDestinationOp, - bufferization::ToMemrefOp, + bufferization::ToBufferOp, bufferization::ToTensorOp, memref::AllocOp, memref::AllocaOp, diff --git a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/MSM/Pippengers/Generic.cpp b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/MSM/Pippengers/Generic.cpp index 4696db064..428427827 100644 --- a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/MSM/Pippengers/Generic.cpp +++ b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/MSM/Pippengers/Generic.cpp @@ -59,8 +59,8 @@ void PippengersGeneric::scalarIsNotOneBranch(Value scalar, Value point, ImplicitLocOpBuilder &b) { auto windowBitIntType = IntegerType::get(b.getContext(), bitsPerWindow_); - Value zeroInt = b.create(0, windowBitIntType); - Value oneInt = b.create(1, windowBitIntType); + Value zeroInt = b.create(windowBitIntType, 0); + Value oneInt = b.create(windowBitIntType, 1); Value scalarForWindow = scalarDecomposition(scalar, windowOffset, b); // If the scalar is non-zero, we update the corresponding bucket. (Recall that diff --git a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/MSM/Pippengers/Generic.h b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/MSM/Pippengers/Generic.h index 1a302b9bb..d762a71f3 100644 --- a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/MSM/Pippengers/Generic.h +++ b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/MSM/Pippengers/Generic.h @@ -33,8 +33,8 @@ class PippengersGeneric : public Pippengers { // - scalarDecomposition(): calculate scalar slice @ window // - populate bucket // } - // bucketReduction(): reduce buckets to one window sum per window // } + // bucketReduction(): reduce buckets to one window sum per window // } // } // windowReduction(): reduce window sums to MSM result diff --git a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/Jacobian/Add.cpp b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/Jacobian/Add.cpp index f4ca619a2..764795a1f 100644 --- a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/Jacobian/Add.cpp +++ b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/Jacobian/Add.cpp @@ -235,7 +235,7 @@ SmallVector jacobianAdd(ValueRange p1, ValueRange p2, } else if (p1.size() == 3 && p2.size() == 3) { return jacobianAndJacobian(p1, p2, curve, b); } else { - assert(false && "Unsupported point types for Jacobian addition"); + llvm_unreachable("Unsupported point types for jacobian addition"); return {}; } } diff --git a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/Jacobian/Double.cpp b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/Jacobian/Double.cpp index 0483a5b53..64ea5be67 100644 --- a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/Jacobian/Double.cpp +++ b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/Jacobian/Double.cpp @@ -112,7 +112,7 @@ SmallVector jacobianDouble(ValueRange point, ShortWeierstrassAttr curve, } else if (point.size() == 3) { return jacobianToJacobian(point, a, b); } else { - assert(false && "Unsupported point type for jacobian doubling"); + llvm_unreachable("Unsupported point type for jacobian doubling"); return {}; } } diff --git a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/XYZZ/Add.cpp b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/XYZZ/Add.cpp index 8e4ecd00f..398e44941 100644 --- a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/XYZZ/Add.cpp +++ b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/XYZZ/Add.cpp @@ -210,7 +210,7 @@ SmallVector xyzzAdd(ValueRange p1, ValueRange p2, } else if (p1.size() == 4 && p2.size() == 4) { return xyzzAndXyzz(p1, p2, curve, b); } else { - assert(false && "Unsupported point types for XYZZ addition"); + llvm_unreachable("Unsupported point types for xyzz addition"); return {}; } } diff --git a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/XYZZ/Double.cpp b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/XYZZ/Double.cpp index 219184e7e..cde2db559 100644 --- a/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/XYZZ/Double.cpp +++ b/zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/PointOperations/XYZZ/Double.cpp @@ -94,7 +94,7 @@ SmallVector xyzzDouble(ValueRange point, ShortWeierstrassAttr curve, } else if (point.size() == 4) { return xyzzToXyzz(point, curve, b); } else { - assert(false && "Unsupported point type for xyzz doubling"); + llvm_unreachable("Unsupported point type for xyzz doubling"); return {}; } } diff --git a/zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.cpp b/zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.cpp index ceb468739..80811a0d6 100644 --- a/zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.cpp +++ b/zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.cpp @@ -76,9 +76,15 @@ LogicalResult MSMOp::verify() { if (scalarsType.getRank() != pointsType.getRank()) { return emitError() << "scalars and points must have the same rank"; } + Type inputType = pointsType.getElementType(); Type outputType = getOutput().getType(); - if (isa(outputType)) { - return emitError() << "output type cannot be affine"; + if (isa(inputType) && isa(outputType)) { + return emitError() << "jacobian input points require a jacobian or xyzz " + "output type for msm"; + } else if (isa(inputType) && + (isa(outputType) || isa(outputType))) { + return emitError() + << "xyzz input points require an xyzz output type for msm"; } int32_t degree = getDegree(); diff --git a/zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.td b/zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.td index fb48583ef..99ececd1a 100644 --- a/zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.td +++ b/zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.td @@ -213,6 +213,10 @@ def EllipticCurve_MSMOp : EllipticCurve_Op<"msm", []> { When parallel is enabled, window-based parallelization is applied. If term-based parallelization is needed, parallel should be disabled and the caller should implement it. + affine -> affine, jacobian, xyzz + jacobian -> jacobian, xyzz + xyzz -> xyzz + Example: If you want to run window-based parallelization, use the parallel attribute as follows: diff --git a/zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp b/zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp index 2aa07998e..b46cc6a67 100644 --- a/zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp +++ b/zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp @@ -539,7 +539,7 @@ struct ConvertPow : public OpConversionPattern { auto fieldType = getElementTypeOrSelf(base); auto isNonNegative = b.create( arith::CmpIPredicate::sge, exp, - b.create(IntegerAttr::get(exp.getType(), 0))); + b.create(exp.getType(), 0)); b.create( isNonNegative, StringAttr::get( @@ -589,8 +589,7 @@ struct ConvertPow : public OpConversionPattern { // x^n ≡ x^(n mod (p²-1)) mod p² if (isa(fieldType)) { exp = b.create( - exp, - b.create(IntegerAttr::get(intType, modulus - 1))); + exp, b.create(intType, modulus - 1)); } else if (isa(fieldType)) { modulus = modulus.zext(modBitWidth * 2); modulus = modulus * modulus - 1; @@ -598,11 +597,11 @@ struct ConvertPow : public OpConversionPattern { IntegerType::get(exp.getContext(), modulus.getBitWidth()), exp); intType = IntegerType::get(exp.getContext(), modulus.getBitWidth()); exp = b.create( - exp, b.create(IntegerAttr::get(intType, modulus))); + exp, b.create(intType, modulus)); } - Value zero = b.create(IntegerAttr::get(intType, 0)); - Value one = b.create(IntegerAttr::get(intType, 1)); + Value zero = b.create(intType, 0); + Value one = b.create(intType, 1); Value powerOfP = base; auto ifOp = b.create( b.create(arith::CmpIPredicate::ne, @@ -757,7 +756,7 @@ void FieldToModArith::runOnOperation() { ConvertAny, ConvertAny, ConvertAny, - ConvertAny, + ConvertAny, ConvertAny, ConvertAny, ConvertAny, @@ -798,7 +797,7 @@ void FieldToModArith::runOnOperation() { affine::AffineYieldOp, bufferization::AllocTensorOp, bufferization::MaterializeInDestinationOp, - bufferization::ToMemrefOp, + bufferization::ToBufferOp, bufferization::ToTensorOp, linalg::BroadcastOp, linalg::GenericOp, diff --git a/zkir/Dialect/ModArith/Conversions/ModArithToArith/Inverter/BYInverter.cpp b/zkir/Dialect/ModArith/Conversions/ModArithToArith/Inverter/BYInverter.cpp index 60f91b84d..43ad8210a 100644 --- a/zkir/Dialect/ModArith/Conversions/ModArithToArith/Inverter/BYInverter.cpp +++ b/zkir/Dialect/ModArith/Conversions/ModArithToArith/Inverter/BYInverter.cpp @@ -31,23 +31,21 @@ BYInverter::BYInverter(ImplicitLocOpBuilder &b, Type inputType) extIntType_ = IntegerType::get(b.getContext(), extModBitWidth); limbType_ = IntegerType::get(b.getContext(), limbBitWidth); - maskN_ = b.create(IntegerAttr::get( - limbType_, APInt::getAllOnes(n).zextOrTrunc(limbBitWidth))); + maskN_ = b.create( + limbType_, APInt::getAllOnes(n).zextOrTrunc(limbBitWidth)); APInt m = modulus.getValue().zext(extModBitWidth); APInt mInv = byAttr.getMInv().getValue(); - m_ = b.create(IntegerAttr::get(extIntType_, m)); - mInv_ = b.create( - IntegerAttr::get(limbType_, mInv.zextOrTrunc(limbBitWidth))); - - limbTypeOne_ = b.create(IntegerAttr::get(limbType_, 1)); - limbTypeZero_ = b.create(IntegerAttr::get(limbType_, 0)); - extIntTypeOne_ = - b.create(IntegerAttr::get(extIntType_, 1)); - extIntTypeZero_ = - b.create(IntegerAttr::get(extIntType_, 0)); - - extIntTypeN_ = b.create(IntegerAttr::get(extIntType_, n)); - limbTypeN_ = b.create(IntegerAttr::get(limbType_, n)); + m_ = b.create(extIntType_, m); + mInv_ = + b.create(limbType_, mInv.zextOrTrunc(limbBitWidth)); + + limbTypeOne_ = b.create(limbType_, 1); + limbTypeZero_ = b.create(limbType_, 0); + extIntTypeOne_ = b.create(extIntType_, 1); + extIntTypeZero_ = b.create(extIntType_, 0); + + extIntTypeN_ = b.create(extIntType_, n); + limbTypeN_ = b.create(limbType_, n); } BYInverter::JumpResult BYInverter::GenerateJump(Value f, Value g, Value eta) { @@ -108,8 +106,7 @@ BYInverter::JumpResult BYInverter::GenerateJump(Value f, Value g, Value eta) { f = b.create(deltaPos, g, f); g = b.create(deltaPos, negF, g); - Value five = - b.create(IntegerAttr::get(limbType_, 5)); + Value five = b.create(limbType_, 5); Value oneMinusEta = b.create(limbTypeOne_, eta); Value shift = b.create( b.create(steps, oneMinusEta), five); @@ -117,9 +114,8 @@ BYInverter::JumpResult BYInverter::GenerateJump(Value f, Value g, Value eta) { b.create(limbTypeOne_, shift), limbTypeOne_); Value threeF = b.create( - b.create(IntegerAttr::get(limbType_, 3)), f); - Value twentyEight = - b.create(IntegerAttr::get(limbType_, 28)); + b.create(limbType_, 3), f); + Value twentyEight = b.create(limbType_, 28); Value w = b.create( b.create( g, b.create(threeF, twentyEight)), @@ -275,8 +271,8 @@ Value BYInverter::Generate(Value input, bool isMont) { f = whileOp.getResult(0); d = whileOp.getResult(2); - Value minusOneIntType = b_.create( - IntegerAttr::get(extIntType_, APInt::getAllOnes(extIntType_.getWidth()))); + Value minusOneIntType = b_.create( + extIntType_, APInt::getAllOnes(extIntType_.getWidth())); Value antiunit = b_.create(arith::CmpIPredicate::eq, f, minusOneIntType); diff --git a/zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp b/zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp index e0639e94a..864ab1736 100644 --- a/zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp +++ b/zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp @@ -155,7 +155,7 @@ struct ConvertNegate : public OpConversionPattern { cast(intType), IntegerAttr::get(getElementTypeOrSelf(intType), 0))); } else { - zero = b.create(IntegerAttr::get(intType, 0)); + zero = b.create(intType, 0); } auto cmp = b.create(arith::CmpIPredicate::eq, adaptor.getInput(), zero); @@ -462,7 +462,7 @@ struct ConvertDouble : public OpConversionPattern { auto intType = modType.getModulus().getType(); Value cmod = b.create(modulusAttr(op)); - Value one = b.create(IntegerAttr::get(intType, 1)); + Value one = b.create(intType, 1); auto shifted = b.create(adaptor.getInput(), one); auto ifge = b.create(arith::CmpIPredicate::uge, shifted, cmod); @@ -629,7 +629,7 @@ MulExtendedResult squareExtended(ImplicitLocOpBuilder &b, Op op, Value input) { const unsigned numLimbs = (modBitWidth + limbWidth - 1) / limbWidth; Type limbType = IntegerType::get(b.getContext(), limbWidth); - Value zeroLimb = b.create(IntegerAttr::get(limbType, 0)); + Value zeroLimb = b.create(limbType, 0); auto decomposeToLimbs = [&b, limbType, limbWidth, numLimbs]( SmallVector &limbs, Value input, @@ -640,8 +640,7 @@ MulExtendedResult squareExtended(ImplicitLocOpBuilder &b, Op op, Value input) { } limbs[0] = b.create(limbType, input); Value remaining = input; - Value shift = - b.create(IntegerAttr::get(type, limbWidth)); + Value shift = b.create(type, limbWidth); for (unsigned i = 1; i < limbs.size(); ++i) { remaining = b.create(remaining, shift); limbs[i] = b.create(limbType, remaining); @@ -687,19 +686,17 @@ MulExtendedResult squareExtended(ImplicitLocOpBuilder &b, Op op, Value input) { } // Reconstruct a single integer value by combining all limbs - Value result = b.create(IntegerAttr::get(resultType, 0)); + Value result = b.create(resultType, 0); for (unsigned i = 0; i < 2 * numLimbs; ++i) { Value rAtI = b.create(resultType, resultVec[i]); Value shifted = b.create( - rAtI, b.create( - IntegerAttr::get(resultType, i * limbWidth))); + rAtI, b.create(resultType, i * limbWidth)); result = b.create(result, shifted); } // Multiply result by 2. It's safe to assume no overflow result = b.create( - result, b.create(IntegerAttr::get(resultType, 1)), - noOverflow); + result, b.create(resultType, 1), noOverflow); decomposeToLimbs(resultVec, result, resultType); @@ -719,7 +716,7 @@ MulExtendedResult squareExtended(ImplicitLocOpBuilder &b, Op op, Value input) { } // Reconstruct `lo` and `hi` values by composing individual limbs - Value zero = b.create(IntegerAttr::get(intType, 0)); + Value zero = b.create(intType, 0); Value resultLow = zero; Value resultHigh = zero; for (unsigned i = 0; i < 2 * numLimbs; ++i) { @@ -728,13 +725,12 @@ MulExtendedResult squareExtended(ImplicitLocOpBuilder &b, Op op, Value input) { : b.create(intType, resultVec[i]); if (i < numLimbs) { auto shifted = b.create( - rAtI, b.create( - IntegerAttr::get(intType, i * limbWidth))); + rAtI, b.create(intType, i * limbWidth)); resultLow = b.create(resultLow, shifted); } else { auto shifted = b.create( - rAtI, b.create( - IntegerAttr::get(intType, (i - numLimbs) * limbWidth))); + rAtI, + b.create(intType, (i - numLimbs) * limbWidth)); resultHigh = b.create(resultHigh, shifted); } } @@ -766,8 +762,8 @@ struct ConvertSquare : public OpConversionPattern { MulExtendedResult result = squareExtended(b, op, adaptor.getInput()); Value lowExt = b.create(mulResultType, result.lo); Value highExt = b.create(mulResultType, result.hi); - Value shift = b.create(IntegerAttr::get( - mulResultType, resultType.getModulus().getValue().getBitWidth())); + Value shift = b.create( + mulResultType, resultType.getModulus().getValue().getBitWidth()); highExt = b.create(highExt, shift); Value squared = b.create(lowExt, highExt); @@ -888,7 +884,7 @@ void ModArithToArith::runOnOperation() { ConvertAny, ConvertAny, ConvertAny, - ConvertAny, + ConvertAny, ConvertAny, ConvertAny, ConvertAny, @@ -931,7 +927,7 @@ void ModArithToArith::runOnOperation() { arith::SelectOp, bufferization::AllocTensorOp, bufferization::MaterializeInDestinationOp, - bufferization::ToMemrefOp, + bufferization::ToBufferOp, bufferization::ToTensorOp, linalg::BroadcastOp, linalg::GenericOp, diff --git a/zkir/Dialect/Poly/Conversions/PolyToField/PolyToField.cpp b/zkir/Dialect/Poly/Conversions/PolyToField/PolyToField.cpp index 8dace1b89..027689acf 100644 --- a/zkir/Dialect/Poly/Conversions/PolyToField/PolyToField.cpp +++ b/zkir/Dialect/Poly/Conversions/PolyToField/PolyToField.cpp @@ -265,9 +265,9 @@ static Value fastNTT(ImplicitLocOpBuilder &b, NTTOpAdaptor adaptor, // Create a memref buffer for in-place updates auto memrefType = MemRefType::get(intTensorType.getShape(), coeffType); - Value srcMemref = b.create(memrefType, source, + Value srcMemref = b.create(memrefType, source, /*read_only=*/true); - Value destMemref = b.create(memrefType, dest); + Value destMemref = b.create(memrefType, dest); // Begin the outer loop over the stages of the NTT. // The iterative loop carries three values: @@ -439,7 +439,7 @@ struct ConvertNTT : public OpConversionPattern { } // NOTE(batzor): We should not use `dest` operand for the destination - // here. Otherwise, writable `ToMemrefOp` will be called twice on the same + // here. Otherwise, writable `ToBufferOp` will be called twice on the same // `dest` SSA Value causing conflict and force memory copy. nttResult = fastNTT(b, adaptor, bitReversed.getResult(), bitReversed.getResult(), nttMappingAttr); diff --git a/zkir/Dialect/TensorExt/Conversions/TensorExtToTensor/TensorExtToTensor.cpp b/zkir/Dialect/TensorExt/Conversions/TensorExtToTensor/TensorExtToTensor.cpp index 6c3bbc67f..e9698957b 100644 --- a/zkir/Dialect/TensorExt/Conversions/TensorExtToTensor/TensorExtToTensor.cpp +++ b/zkir/Dialect/TensorExt/Conversions/TensorExtToTensor/TensorExtToTensor.cpp @@ -38,10 +38,10 @@ struct ConvertBitReverse : public OpConversionPattern { auto c1 = b.create(1); auto cN = b.create(numCoeffs); auto sourceMemref = - b.create(memrefType, adaptor.getSource(), + b.create(memrefType, adaptor.getSource(), /*read_only=*/true); auto destMemref = - b.create(memrefType, adaptor.getDest()); + b.create(memrefType, adaptor.getDest()); auto parallelOp = b.create( /*lowerBound=*/ValueRange{c0}, /*lowerBound=*/ValueRange{cN}, diff --git a/zkir/Utils/ShapedTypeConverter.cpp b/zkir/Utils/ShapedTypeConverter.cpp index 5238da2f3..57a5ddf2f 100644 --- a/zkir/Utils/ShapedTypeConverter.cpp +++ b/zkir/Utils/ShapedTypeConverter.cpp @@ -36,7 +36,7 @@ Type ShapedTypeConverter::convertShapedType(ShapedType oldType, } else if (auto tensorType = dyn_cast(oldType)) { return tensorType.cloneWith(shape, elementType); } - assert(false && "Unsupported shaped type"); + llvm_unreachable("Unsupported shaped type"); return oldType; }