diff --git a/.noir-sync-commit b/.noir-sync-commit index d995cb8c0aae..4ebd701fd95e 100644 --- a/.noir-sync-commit +++ b/.noir-sync-commit @@ -1 +1 @@ -b7ace682af1ab8a43308457302f08b151af342db +49a095ded5cd33795bcdac60cbd98ce7c5ab9198 diff --git a/noir/noir-repo/.github/benchmark_projects.yml b/noir/noir-repo/.github/benchmark_projects.yml new file mode 100644 index 000000000000..17eb7aa96a67 --- /dev/null +++ b/noir/noir-repo/.github/benchmark_projects.yml @@ -0,0 +1,99 @@ +define: &AZ_COMMIT 6c0b83d4b73408f87acfa080d52a81c411e47336 +projects: + private-kernel-inner: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/private-kernel-inner + num_runs: 5 + compilation-timeout: 2.5 + execution-timeout: 0.08 + compilation-memory-limit: 300 + execution-memory-limit: 250 + private-kernel-tail: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/private-kernel-tail + num_runs: 5 + timeout: 4 + compilation-timeout: 1.2 + execution-timeout: 0.02 + compilation-memory-limit: 250 + execution-memory-limit: 200 + private-kernel-reset: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/private-kernel-reset + num_runs: 5 + timeout: 250 + compilation-timeout: 8 + execution-timeout: 0.35 + compilation-memory-limit: 750 + execution-memory-limit: 300 + rollup-base-private: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/rollup-base-private + num_runs: 5 + timeout: 15 + compilation-timeout: 10 + execution-timeout: 0.5 + compilation-memory-limit: 1100 + execution-memory-limit: 500 + rollup-base-public: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/rollup-base-public + num_runs: 5 + timeout: 15 + compilation-timeout: 8 + execution-timeout: 0.4 + compilation-memory-limit: 1000 + execution-memory-limit: 500 
+ rollup-block-root-empty: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/rollup-block-root-empty + cannot_execute: true + num_runs: 5 + timeout: 60 + compilation-timeout: 1.5 + compilation-memory-limit: 400 + rollup-block-root-single-tx: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/rollup-block-root-single-tx + cannot_execute: true + num_runs: 1 + timeout: 60 + compilation-timeout: 100 + compilation-memory-limit: 7000 + rollup-block-root: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/rollup-block-root + num_runs: 1 + timeout: 60 + compilation-timeout: 100 + execution-timeout: 40 + compilation-memory-limit: 7000 + execution-memory-limit: 1500 + rollup-merge: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/rollup-merge + num_runs: 5 + timeout: 300 + compilation-timeout: 1.5 + execution-timeout: 0.01 + compilation-memory-limit: 400 + execution-memory-limit: 400 + rollup-root: + repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT + path: noir-projects/noir-protocol-circuits/crates/rollup-root + num_runs: 5 + timeout: 300 + compilation-timeout: 2 + execution-timeout: 0.6 + compilation-memory-limit: 450 + execution-memory-limit: 400 diff --git a/noir/noir-repo/.github/workflows/reports.yml b/noir/noir-repo/.github/workflows/reports.yml index 74967b901c8b..4e49b2507527 100644 --- a/noir/noir-repo/.github/workflows/reports.yml +++ b/noir/noir-repo/.github/workflows/reports.yml @@ -7,6 +7,23 @@ on: pull_request: jobs: + benchmark-projects-list: + name: Load benchmark projects list + runs-on: ubuntu-22.04 + outputs: + projects: ${{ steps.get_bench_projects.outputs.projects }} + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Build list of projects + id: get_bench_projects + run: | + PROJECTS=$(yq 
./.github/benchmark_projects.yml -o json | jq -c '.projects | map(.)') + echo "projects=$PROJECTS" + echo "projects=$PROJECTS" >> $GITHUB_OUTPUT + build-nargo: runs-on: ubuntu-22.04 @@ -98,10 +115,10 @@ jobs: ./gates_report_brillig.sh 9223372036854775807 jq '.programs |= map(.package_name |= (. + "_inliner_max"))' gates_report_brillig.json > ./reports/gates_report_brillig_inliner_max.json - + ./gates_report_brillig.sh 0 jq '.programs |= map(.package_name |= (. + "_inliner_zero"))' gates_report_brillig.json > ./reports/gates_report_brillig_inliner_zero.json - + ./gates_report_brillig.sh -9223372036854775808 jq '.programs |= map(.package_name |= (. + "_inliner_min"))' gates_report_brillig.json > ./reports/gates_report_brillig_inliner_min.json @@ -143,14 +160,14 @@ jobs: - name: Generate Brillig execution report working-directory: ./test_programs run: | - mkdir ./reports - + mkdir ./reports + ./gates_report_brillig_execution.sh 9223372036854775807 jq '.programs |= map(.package_name |= (. + "_inliner_max"))' gates_report_brillig_execution.json > ./reports/gates_report_brillig_execution_inliner_max.json ./gates_report_brillig_execution.sh 0 jq '.programs |= map(.package_name |= (. + "_inliner_zero"))' gates_report_brillig_execution.json > ./reports/gates_report_brillig_execution_inliner_zero.json - + ./gates_report_brillig_execution.sh -9223372036854775808 jq '.programs |= map(.package_name |= (. 
+ "_inliner_min"))' gates_report_brillig_execution.json > ./reports/gates_report_brillig_execution_inliner_min.json @@ -241,7 +258,7 @@ jobs: run: | ./execution_report.sh 0 1 mv execution_report.json ../execution_report.json - + - name: Upload compilation report uses: actions/upload-artifact@v4 with: @@ -249,7 +266,7 @@ jobs: path: compilation_report.json retention-days: 3 overwrite: true - + - name: Upload execution report uses: actions/upload-artifact@v4 with: @@ -259,25 +276,13 @@ jobs: overwrite: true external_repo_compilation_and_execution_report: - needs: [build-nargo] + needs: [build-nargo, benchmark-projects-list] runs-on: ubuntu-22.04 timeout-minutes: 15 strategy: fail-fast: false matrix: - include: - # - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-contracts, cannot_execute: true } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/private-kernel-inner, num_runs: 5 } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/private-kernel-tail, num_runs: 5 } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/private-kernel-reset, num_runs: 5 } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-base-private, num_runs: 5 } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-base-public, num_runs: 5 } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-block-merge, num_runs: 5 } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-block-root-empty, num_runs: 5, cannot_execute: true } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-block-root-single-tx, num_runs: 1, flags: "--skip-brillig-constraints-check 
--skip-underconstrained-check", cannot_execute: true } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-block-root, num_runs: 1, flags: "--skip-brillig-constraints-check --skip-underconstrained-check" } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-merge, num_runs: 5 } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-root, num_runs: 5 } + project: ${{ fromJson( needs.benchmark-projects-list.outputs.projects )}} name: External repo compilation and execution reports - ${{ matrix.project.repo }}/${{ matrix.project.path }} steps: @@ -300,7 +305,7 @@ jobs: repository: ${{ matrix.project.repo }} path: test-repo ref: ${{ matrix.project.ref }} - + - name: Fetch noir dependencies working-directory: ./test-repo/${{ matrix.project.path }} run: | @@ -318,11 +323,11 @@ jobs: ./compilation_report.sh 1 ${{ matrix.project.num_runs }} env: FLAGS: ${{ matrix.project.flags }} - + - name: Check compilation time limit run: | TIME=$(jq '.[0].value' ./test-repo/${{ matrix.project.path }}/compilation_report.json) - TIME_LIMIT=80 + TIME_LIMIT=${{ matrix.project.compilation-timeout }} if awk 'BEGIN{exit !(ARGV[1]>ARGV[2])}' "$TIME" "$TIME_LIMIT"; then # Don't bump this timeout without understanding why this has happened and confirming that you're not the cause. echo "Failing due to compilation exceeding timeout..." 
@@ -338,12 +343,12 @@ jobs: mv /home/runner/work/noir/noir/scripts/test_programs/execution_report.sh ./execution_report.sh mv /home/runner/work/noir/noir/scripts/test_programs/parse_time.sh ./parse_time.sh ./execution_report.sh 1 ${{ matrix.project.num_runs }} - + - name: Check execution time limit if: ${{ !matrix.project.cannot_execute }} run: | TIME=$(jq '.[0].value' ./test-repo/${{ matrix.project.path }}/execution_report.json) - TIME_LIMIT=60 + TIME_LIMIT=${{ matrix.project.execution-timeout }} if awk 'BEGIN{exit !(ARGV[1]>ARGV[2])}' "$TIME" "$TIME_LIMIT"; then # Don't bump this timeout without understanding why this has happened and confirming that you're not the cause. echo "Failing due to execution exceeding timeout..." @@ -370,7 +375,7 @@ jobs: PACKAGE_NAME=$(basename $PACKAGE_NAME) mv ./test-repo/${{ matrix.project.path }}/execution_report.json ./execution_report_$PACKAGE_NAME.json echo "execution_report_name=$PACKAGE_NAME" >> $GITHUB_OUTPUT - + - name: Upload compilation report uses: actions/upload-artifact@v4 with: @@ -388,28 +393,15 @@ jobs: overwrite: true external_repo_memory_report: - needs: [build-nargo] + needs: [build-nargo, benchmark-projects-list] runs-on: ubuntu-22.04 timeout-minutes: 30 strategy: fail-fast: false matrix: - include: - # TODO: Bring this report back under a flag. The `noir-contracts` report takes just under 30 minutes. 
- # - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-contracts } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/private-kernel-inner } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/private-kernel-tail } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/private-kernel-reset } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-base-private } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-base-public } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-block-merge } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-block-root-empty, cannot_execute: true } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-block-root-single-tx, flags: "--skip-brillig-constraints-check --skip-underconstrained-check", cannot_execute: true } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-block-root, flags: "--skip-brillig-constraints-check --skip-underconstrained-check" } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-merge } - - project: { repo: AztecProtocol/aztec-packages, path: noir-projects/noir-protocol-circuits/crates/rollup-root } - - name: External repo memory report - ${{ matrix.project.repo }}/${{ matrix.project.path }} + include: ${{ fromJson( needs.benchmark-projects-list.outputs.projects )}} + + name: External repo memory report - ${{ matrix.repo }}/${{ matrix.path }} steps: - uses: actions/checkout@v4 with: @@ -419,19 +411,19 @@ jobs: test_programs/memory_report.sh 
test_programs/parse_memory.sh sparse-checkout-cone-mode: false - + - name: Download nargo binary uses: ./scripts/.github/actions/download-nargo - name: Checkout uses: actions/checkout@v4 with: - repository: ${{ matrix.project.repo }} + repository: ${{ matrix.repo }} path: test-repo - ref: ${{ matrix.project.ref }} + ref: ${{ matrix.ref }} - name: Generate compilation memory report - working-directory: ./test-repo/${{ matrix.project.path }} + working-directory: ./test-repo/${{ matrix.path }} run: | mv /home/runner/work/noir/noir/scripts/test_programs/memory_report.sh ./memory_report.sh mv /home/runner/work/noir/noir/scripts/test_programs/parse_memory.sh ./parse_memory.sh @@ -439,7 +431,31 @@ jobs: # Rename the memory report as the execution report is about to write to the same file cp memory_report.json compilation_memory_report.json env: - FLAGS: ${{ matrix.project.flags }} + FLAGS: ${{ matrix.flags }} + + - name: Check compilation memory limit + run: | + MEMORY=$(jq '.[0].value' ./test-repo/${{ matrix.path }}/compilation_memory_report.json) + MEMORY_LIMIT=${{ matrix.compilation-memory-limit }} + if awk 'BEGIN{exit !(ARGV[1]>ARGV[2])}' "$MEMORY" "$MEMORY_LIMIT"; then + # Don't bump this limit without understanding why this has happened and confirming that you're not the cause. + echo "Failing due to compilation exceeding memory limit..." + echo "Limit: "$MEMORY_LIMIT"MB" + echo "Compilation took: "$MEMORY"MB". + exit 1 + fi + + - name: Check compilation memory limit + run: | + MEMORY=$(jq '.[0].value' ./test-repo/${{ matrix.project.path }}/compilation_memory_report.json) + MEMORY_LIMIT=6000 + if awk 'BEGIN{exit !(ARGV[1]>ARGV[2])}' "$MEMORY" "$MEMORY_LIMIT"; then + # Don't bump this limit without understanding why this has happened and confirming that you're not the cause. + echo "Failing due to compilation exceeding memory limit..." + echo "Limit: "$MEMORY_LIMIT"MB" + echo "Compilation took: "$MEMORY"MB". 
+ exit 1 + fi - name: Check compilation memory limit run: | @@ -454,16 +470,16 @@ jobs: fi - name: Generate execution memory report - working-directory: ./test-repo/${{ matrix.project.path }} - if: ${{ !matrix.project.cannot_execute }} + working-directory: ./test-repo/${{ matrix.path }} + if: ${{ !matrix.cannot_execute }} run: | ./memory_report.sh 1 1 - name: Check execution memory limit - if: ${{ !matrix.project.cannot_execute }} + if: ${{ !matrix.cannot_execute }} run: | - MEMORY=$(jq '.[0].value' ./test-repo/${{ matrix.project.path }}/memory_report.json) - MEMORY_LIMIT=1300 + MEMORY=$(jq '.[0].value' ./test-repo/${{ matrix.path }}/memory_report.json) + MEMORY_LIMIT=${{ matrix.execution-memory-limit }} if awk 'BEGIN{exit !(ARGV[1]>ARGV[2])}' "$MEMORY" "$MEMORY_LIMIT"; then # Don't bump this limit without understanding why this has happened and confirming that you're not the cause. echo "Failing due to execution exceeding memory limit..." @@ -476,9 +492,9 @@ jobs: id: compilation_mem_report shell: bash run: | - PACKAGE_NAME=${{ matrix.project.path }} + PACKAGE_NAME=${{ matrix.path }} PACKAGE_NAME=$(basename $PACKAGE_NAME) - mv ./test-repo/${{ matrix.project.path }}/compilation_memory_report.json ./memory_report_$PACKAGE_NAME.json + mv ./test-repo/${{ matrix.path }}/compilation_memory_report.json ./memory_report_$PACKAGE_NAME.json echo "memory_report_name=$PACKAGE_NAME" >> $GITHUB_OUTPUT - name: Upload compilation memory report @@ -491,11 +507,11 @@ jobs: - name: Move execution report id: execution_mem_report - if: ${{ !matrix.project.cannot_execute }} + if: ${{ !matrix.cannot_execute }} run: | - PACKAGE_NAME=${{ matrix.project.path }} + PACKAGE_NAME=${{ matrix.path }} PACKAGE_NAME=$(basename $PACKAGE_NAME) - mv ./test-repo/${{ matrix.project.path }}/memory_report.json ./memory_report_$PACKAGE_NAME.json + mv ./test-repo/${{ matrix.path }}/memory_report.json ./memory_report_$PACKAGE_NAME.json echo "memory_report_name=$PACKAGE_NAME" >> $GITHUB_OUTPUT - name: Upload 
execution memory report @@ -507,10 +523,10 @@ jobs: overwrite: true upload_compilation_report: - name: Upload compilation report + name: Upload compilation report needs: [generate_compilation_and_execution_report, external_repo_compilation_and_execution_report] # We want this job to run even if one variation of the matrix in `external_repo_compilation_and_execution_report` fails - if: always() + if: always() runs-on: ubuntu-22.04 permissions: pull-requests: write @@ -558,10 +574,10 @@ jobs: max-items-in-chart: 50 upload_compilation_memory_report: - name: Upload compilation memory report + name: Upload compilation memory report needs: [generate_memory_report, external_repo_memory_report] # We want this job to run even if one variation of the matrix in `external_repo_memory_report` fails - if: always() + if: always() runs-on: ubuntu-22.04 permissions: pull-requests: write @@ -608,10 +624,10 @@ jobs: max-items-in-chart: 50 upload_execution_memory_report: - name: Upload execution memory report + name: Upload execution memory report needs: [generate_memory_report, external_repo_memory_report] # We want this job to run even if one variation of the matrix in `external_repo_memory_report` fails - if: always() + if: always() runs-on: ubuntu-22.04 permissions: pull-requests: write @@ -659,10 +675,10 @@ jobs: upload_execution_report: - name: Upload execution report + name: Upload execution report needs: [generate_compilation_and_execution_report, external_repo_compilation_and_execution_report] # We want this job to run even if one variation of the matrix in `external_repo_compilation_and_execution_report` fails - if: always() + if: always() runs-on: ubuntu-22.04 permissions: pull-requests: write @@ -720,7 +736,7 @@ jobs: - upload_compilation_memory_report - upload_execution_report - upload_execution_memory_report - + steps: - name: Report overall success run: | diff --git a/noir/noir-repo/.github/workflows/test-js-packages.yml 
b/noir/noir-repo/.github/workflows/test-js-packages.yml index 2e0f226b0a29..fe573b608893 100644 --- a/noir/noir-repo/.github/workflows/test-js-packages.yml +++ b/noir/noir-repo/.github/workflows/test-js-packages.yml @@ -492,9 +492,9 @@ jobs: strategy: fail-fast: false matrix: - project: ${{ fromJson( needs.critical-library-list.outputs.libraries )}} - - name: Check external repo - ${{ matrix.project.repo }}/${{ matrix.project.path }} + include: ${{ fromJson( needs.critical-library-list.outputs.libraries )}} + + name: Check external repo - ${{ matrix.repo }}/${{ matrix.path }} steps: - name: Checkout uses: actions/checkout@v4 @@ -504,9 +504,9 @@ jobs: - name: Checkout uses: actions/checkout@v4 with: - repository: ${{ matrix.project.repo }} + repository: ${{ matrix.repo }} path: test-repo - ref: ${{ matrix.project.ref }} + ref: ${{ matrix.ref }} - name: Download nargo binary uses: ./noir-repo/.github/actions/download-nargo @@ -520,14 +520,14 @@ jobs: - name: Run nargo test id: test_report - working-directory: ./test-repo/${{ matrix.project.path }} + working-directory: ./test-repo/${{ matrix.path }} run: | - output_file=${{ github.workspace }}/noir-repo/.github/critical_libraries_status/${{ matrix.project.repo }}/${{ matrix.project.path }}.actual.jsonl + output_file=${{ github.workspace }}/noir-repo/.github/critical_libraries_status/${{ matrix.repo }}/${{ matrix.path }}.actual.jsonl BEFORE=$SECONDS - nargo test --silence-warnings --skip-brillig-constraints-check --format json ${{ matrix.project.nargo_args }} | tee $output_file + nargo test --silence-warnings --skip-brillig-constraints-check --format json ${{ matrix.nargo_args }} | tee $output_file TIME=$(($SECONDS-$BEFORE)) - NAME=${{ matrix.project.repo }}/${{ matrix.project.path }} + NAME=${{ matrix.repo }}/${{ matrix.path }} # Replace any slashes with underscores NAME=${NAME//\//_} TEST_REPORT_NAME=test_report_$NAME @@ -541,28 +541,28 @@ jobs: fi env: NARGO_IGNORE_TEST_FAILURES_FROM_FOREIGN_CALLS: true - + - name: 
Check test time limit run: | - TIME=$(jq '.[0].value' ./test-repo/${{ matrix.project.path }}/${{ steps.test_report.outputs.test_report_name }}.json) - if awk 'BEGIN{exit !(ARGV[1]>ARGV[2])}' "$TIME" "${{ matrix.project.timeout }}"; then + TIME=$(jq '.[0].value' ./test-repo/${{ matrix.path }}/${{ steps.test_report.outputs.test_report_name }}.json) + if awk 'BEGIN{exit !(ARGV[1]>ARGV[2])}' "$TIME" "${{ matrix.timeout }}"; then # Don't bump this timeout without understanding why this has happened and confirming that you're not the cause. echo "Failing due to test suite exceeding timeout..." - echo "Timeout: ${{ matrix.project.timeout }}" + echo "Timeout: ${{ matrix.timeout }}" echo "Test suite took: $TIME". exit 1 fi - name: Compare test results working-directory: ./noir-repo - run: .github/scripts/check_test_results.sh .github/critical_libraries_status/${{ matrix.project.repo }}/${{ matrix.project.path }}.failures.jsonl .github/critical_libraries_status/${{ matrix.project.repo }}/${{ matrix.project.path }}.actual.jsonl + run: .github/scripts/check_test_results.sh .github/critical_libraries_status/${{ matrix.repo }}/${{ matrix.path }}.failures.jsonl .github/critical_libraries_status/${{ matrix.repo }}/${{ matrix.path }}.actual.jsonl - name: Upload test report - if: ${{ matrix.project.timeout > 10 }} # We want to avoid recording benchmarking for a ton of tiny libraries, these should be covered with aggressive timeouts + if: ${{ matrix.timeout > 10 }} # We want to avoid recording benchmarking for a ton of tiny libraries, these should be covered with aggressive timeouts uses: actions/upload-artifact@v4 with: name: ${{ steps.test_report.outputs.test_report_name }} - path: ./test-repo/${{ matrix.project.path }}/${{ steps.test_report.outputs.test_report_name }}.json + path: ./test-repo/${{ matrix.path }}/${{ steps.test_report.outputs.test_report_name }}.json retention-days: 3 overwrite: true @@ -576,7 +576,7 @@ jobs: uses: actions/checkout@v4 with: path: noir-repo - + - 
name: Checkout uses: actions/checkout@v4 with: @@ -595,13 +595,13 @@ jobs: - name: Run nargo compile working-directory: ./test-repo/noir-projects/noir-contracts - run: nargo compile --inliner-aggressiveness 0 + run: nargo compile --inliner-aggressiveness 0 upload_critical_library_report: - name: Upload critical library report + name: Upload critical library report needs: [external-repo-checks] # We want this job to run even if one variation of the matrix in `external-repo-checks` fails - if: always() + if: always() runs-on: ubuntu-22.04 permissions: pull-requests: write diff --git a/noir/noir-repo/Cargo.lock b/noir/noir-repo/Cargo.lock index f1162d2bcf68..a522ccb3bed5 100644 --- a/noir/noir-repo/Cargo.lock +++ b/noir/noir-repo/Cargo.lock @@ -3221,6 +3221,7 @@ dependencies = [ "noirc_driver", "noirc_errors", "noirc_frontend", + "num-bigint", "rayon", "serde", "serde_json", diff --git a/noir/noir-repo/EXTERNAL_NOIR_LIBRARIES.yml b/noir/noir-repo/EXTERNAL_NOIR_LIBRARIES.yml index 2d41e8c5ec5b..0ea374928031 100644 --- a/noir/noir-repo/EXTERNAL_NOIR_LIBRARIES.yml +++ b/noir/noir-repo/EXTERNAL_NOIR_LIBRARIES.yml @@ -1,4 +1,4 @@ - +define: &AZ_COMMIT 6c0b83d4b73408f87acfa080d52a81c411e47336 libraries: noir_check_shuffle: repo: noir-lang/noir_check_shuffle @@ -47,34 +47,42 @@ libraries: timeout: 3 aztec_nr: repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT path: noir-projects/aztec-nr timeout: 60 noir_contracts: repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT path: noir-projects/noir-contracts timeout: 80 blob: repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT path: noir-projects/noir-protocol-circuits/crates/blob timeout: 70 protocol_circuits_parity_lib: repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT path: noir-projects/noir-protocol-circuits/crates/parity-lib timeout: 4 protocol_circuits_private_kernel_lib: repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT path: noir-projects/noir-protocol-circuits/crates/private-kernel-lib timeout: 250 
protocol_circuits_reset_kernel_lib: repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT path: noir-projects/noir-protocol-circuits/crates/reset-kernel-lib timeout: 15 protocol_circuits_types: repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT path: noir-projects/noir-protocol-circuits/crates/types timeout: 60 protocol_circuits_rollup_lib: repo: AztecProtocol/aztec-packages + ref: *AZ_COMMIT path: noir-projects/noir-protocol-circuits/crates/rollup-lib timeout: 300 # Use 1 test threads for rollup-lib because each test requires a lot of memory, and multiple ones in parallel exceed the maximum memory limit. diff --git a/noir/noir-repo/acvm-repo/acir_field/src/field_element.rs b/noir/noir-repo/acvm-repo/acir_field/src/field_element.rs index 0249b410aa74..8afc76da9d89 100644 --- a/noir/noir-repo/acvm-repo/acir_field/src/field_element.rs +++ b/noir/noir-repo/acvm-repo/acir_field/src/field_element.rs @@ -274,12 +274,15 @@ impl AcirField for FieldElement { } fn to_be_bytes(self) -> Vec { - // to_be_bytes! uses little endian which is why we reverse the output - // TODO: Add a little endian equivalent, so the caller can use whichever one - // TODO they desire + let mut bytes = self.to_le_bytes(); + bytes.reverse(); + bytes + } + + /// Converts the field element to a vector of bytes in little-endian order + fn to_le_bytes(self) -> Vec { let mut bytes = Vec::new(); self.0.serialize_uncompressed(&mut bytes).unwrap(); - bytes.reverse(); bytes } @@ -289,6 +292,12 @@ impl AcirField for FieldElement { FieldElement(F::from_be_bytes_mod_order(bytes)) } + /// Converts bytes in little-endian order into a FieldElement and applies a + /// reduction if needed. 
+ fn from_le_bytes_reduce(bytes: &[u8]) -> FieldElement { + FieldElement(F::from_le_bytes_mod_order(bytes)) + } + /// Returns the closest number of bytes to the bits specified /// This method truncates fn fetch_nearest_bytes(&self, num_bits: usize) -> Vec { @@ -405,6 +414,50 @@ mod tests { assert_eq!(max_num_bits_bn254, 254); } + proptest! { + #[test] + fn test_endianness_prop(value in any::()) { + let field = FieldElement::::from(value); + // Test serialization consistency + let le_bytes = field.to_le_bytes(); + let be_bytes = field.to_be_bytes(); + + let mut reversed_le = le_bytes.clone(); + reversed_le.reverse(); + prop_assert_eq!(&be_bytes, &reversed_le, "BE bytes should be reverse of LE bytes"); + + // Test deserialization consistency + let from_le = FieldElement::from_le_bytes_reduce(&le_bytes); + let from_be = FieldElement::from_be_bytes_reduce(&be_bytes); + prop_assert_eq!(from_le, from_be, "Deserialization should be consistent between LE and BE"); + prop_assert_eq!(from_le, field, "Deserialized value should match original"); + } + } + + #[test] + fn test_endianness() { + let field = FieldElement::::from(0x1234_5678_u32); + let le_bytes = field.to_le_bytes(); + let be_bytes = field.to_be_bytes(); + + // Check that the bytes are reversed between BE and LE + let mut reversed_le = le_bytes.clone(); + reversed_le.reverse(); + assert_eq!(&be_bytes, &reversed_le); + + // Verify we can reconstruct the same field element from either byte order + let from_le = FieldElement::from_le_bytes_reduce(&le_bytes); + let from_be = FieldElement::from_be_bytes_reduce(&be_bytes); + assert_eq!(from_le, from_be); + assert_eq!(from_le, field); + + // Additional test with a larger number to ensure proper byte handling + let large_field = FieldElement::::from(0x0123_4567_89AB_CDEF_u64); + let large_le = large_field.to_le_bytes(); + let reconstructed = FieldElement::from_le_bytes_reduce(&large_le); + assert_eq!(reconstructed, large_field); + } + proptest! 
{ // This currently panics due to the fact that we allow inputs which are greater than the field modulus, // automatically reducing them to fit within the canonical range. diff --git a/noir/noir-repo/acvm-repo/acir_field/src/generic_ark.rs b/noir/noir-repo/acvm-repo/acir_field/src/generic_ark.rs index 74927c07a363..04761dd1ed0f 100644 --- a/noir/noir-repo/acvm-repo/acir_field/src/generic_ark.rs +++ b/noir/noir-repo/acvm-repo/acir_field/src/generic_ark.rs @@ -75,6 +75,12 @@ pub trait AcirField: /// Converts bytes into a FieldElement and applies a reduction if needed. fn from_be_bytes_reduce(bytes: &[u8]) -> Self; + /// Converts bytes in little-endian order into a FieldElement and applies a reduction if needed. + fn from_le_bytes_reduce(bytes: &[u8]) -> Self; + + /// Converts the field element to a vector of bytes in little-endian order + fn to_le_bytes(self) -> Vec; + /// Returns the closest number of bytes to the bits specified /// This method truncates fn fetch_nearest_bytes(&self, num_bits: usize) -> Vec; diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs index a2cfb6be5adf..957ebc2b0695 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs @@ -4,7 +4,7 @@ pub(crate) mod brillig_block_variables; pub(crate) mod brillig_fn; pub(crate) mod brillig_globals; pub(crate) mod brillig_slice_ops; -mod constant_allocation; +pub(crate) mod constant_allocation; mod variable_liveness; use acvm::FieldElement; @@ -20,7 +20,7 @@ use super::{ }; use crate::{ errors::InternalError, - ssa::ir::{call_stack::CallStack, function::Function}, + ssa::ir::{call_stack::CallStack, function::Function, types::NumericType}, }; /// Converting an SSA function into Brillig bytecode. 
@@ -28,6 +28,7 @@ pub(crate) fn convert_ssa_function( func: &Function, options: &BrilligOptions, globals: &HashMap, + hoisted_global_constants: &HashMap<(FieldElement, NumericType), BrilligVariable>, ) -> BrilligArtifact { let mut brillig_context = BrilligContext::new(options); @@ -44,6 +45,7 @@ pub(crate) fn convert_ssa_function( block, &func.dfg, globals, + hoisted_global_constants, ); } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index fea7f57d3f88..40dd825be35d 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -25,11 +25,13 @@ use acvm::{acir::AcirField, FieldElement}; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use iter_extended::vecmap; use num_bigint::BigUint; +use std::collections::BTreeSet; use std::sync::Arc; use super::brillig_black_box::convert_black_box_call; -use super::brillig_block_variables::BlockVariables; +use super::brillig_block_variables::{allocate_value_with_type, BlockVariables}; use super::brillig_fn::FunctionContext; +use super::brillig_globals::HoistedConstantsToBrilligGlobals; use super::constant_allocation::InstructionLocation; /// Generate the compilation artifacts for compiling a function into brillig bytecode. 
@@ -45,6 +47,7 @@ pub(crate) struct BrilligBlock<'block, Registers: RegisterAllocator> { pub(crate) last_uses: HashMap>, pub(crate) globals: &'block HashMap, + pub(crate) hoisted_global_constants: &'block HoistedConstantsToBrilligGlobals, pub(crate) building_globals: bool, } @@ -57,11 +60,17 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { block_id: BasicBlockId, dfg: &DataFlowGraph, globals: &'block HashMap, + hoisted_global_constants: &'block HoistedConstantsToBrilligGlobals, ) { let live_in = function_context.liveness.get_live_in(&block_id); let mut live_in_no_globals = HashSet::default(); for value in live_in { + if let Value::NumericConstant { constant, typ } = dfg[*value] { + if hoisted_global_constants.contains_key(&(constant, typ)) { + continue; + } + } if !dfg.is_global(*value) { live_in_no_globals.insert(*value); } @@ -85,6 +94,7 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { variables, last_uses, globals, + hoisted_global_constants, building_globals: false, }; @@ -95,7 +105,8 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { &mut self, globals: &DataFlowGraph, used_globals: &HashSet, - ) { + hoisted_global_constants: &BTreeSet<(FieldElement, NumericType)>, + ) -> HashMap<(FieldElement, NumericType), BrilligVariable> { for (id, value) in globals.values_iter() { if !used_globals.contains(&id) { continue; @@ -114,6 +125,16 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { } } } + + let mut new_hoisted_constants = HashMap::default(); + for (constant, typ) in hoisted_global_constants.iter().copied() { + let new_variable = allocate_value_with_type(self.brillig_context, Type::Numeric(typ)); + self.brillig_context.const_instruction(new_variable.extract_single_addr(), constant); + if new_hoisted_constants.insert((constant, typ), new_variable).is_some() { + unreachable!("ICE: ({constant:?}, {typ:?}) was already in cache"); + } + } + 
new_hoisted_constants } fn convert_block(&mut self, dfg: &DataFlowGraph) { @@ -746,7 +767,10 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { let index_variable = self.convert_ssa_single_addr_value(*index, dfg); - if !dfg.is_safe_index(*index, *array) { + // Slice access checks are generated separately against the slice's dynamic length field. + if matches!(dfg.type_of_value(*array), Type::Array(..)) + && !dfg.is_safe_index(*index, *array) + { self.validate_array_index(array_variable, index_variable); } @@ -774,7 +798,10 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { dfg, ); - if !dfg.is_safe_index(*index, *array) { + // Slice access checks are generated separately against the slice's dynamic length field. + if matches!(dfg.type_of_value(*array), Type::Array(..)) + && !dfg.is_safe_index(*index, *array) + { self.validate_array_index(source_variable, index_register); } @@ -950,7 +977,8 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { for dead_variable in dead_variables { // Globals are reserved throughout the entirety of the program - if !dfg.is_global(*dead_variable) { + let not_hoisted_global = self.get_hoisted_global(dfg, *dead_variable).is_none(); + if !dfg.is_global(*dead_variable) && not_hoisted_global { self.variables.remove_variable( dead_variable, self.function_context, @@ -1668,6 +1696,10 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { let value_id = dfg.resolve(value_id); let value = &dfg[value_id]; + if let Some(variable) = self.get_hoisted_global(dfg, value_id) { + return variable; + } + match value { Value::Global(_) => { unreachable!("Expected global value to be resolve to its inner value"); @@ -2008,6 +2040,19 @@ impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { } } } + + fn get_hoisted_global( + &self, + dfg: &DataFlowGraph, + value_id: ValueId, + ) -> Option { + if let Value::NumericConstant { 
constant, typ } = &dfg[value_id] { + if let Some(variable) = self.hoisted_global_constants.get(&(*constant, *typ)) { + return Some(*variable); + } + } + None + } } /// Returns the type of the operation considering the types of the operands diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs index 04c43df9d7e6..66f0980ac639 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use acvm::FieldElement; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; @@ -6,9 +6,11 @@ use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use super::{ BrilligArtifact, BrilligBlock, BrilligVariable, Function, FunctionContext, Label, ValueId, }; -use crate::brillig::{ - brillig_ir::BrilligContext, called_functions_vec, Brillig, BrilligOptions, DataFlowGraph, - FunctionId, Instruction, Value, +use crate::brillig::{Brillig, BrilligOptions, FunctionId}; +use crate::{ + brillig::{brillig_ir::BrilligContext, ConstantAllocation, DataFlowGraph}, + ssa::ir::types::NumericType, + ssa::opt::brillig_entry_points::{build_inner_call_to_entry_points, get_brillig_entry_points}, }; /// Context structure for generating Brillig globals @@ -24,124 +26,101 @@ pub(crate) struct BrilligGlobals { /// Maps a Brillig entry point to all functions called in that entry point. /// This includes any nested calls as well, as we want to be able to associate /// any Brillig function with the appropriate global allocations. - brillig_entry_points: HashMap>, + brillig_entry_points: BTreeMap>, + /// Maps a Brillig entry point to constants shared across the entry point and its nested calls. 
+ hoisted_global_constants: HashMap, /// Maps an inner call to its Brillig entry point /// This is simply used to simplify fetching global allocations when compiling /// individual Brillig functions. - inner_call_to_entry_point: HashMap>, + inner_call_to_entry_point: HashMap>, /// Final map that associated an entry point with its Brillig global allocations entry_point_globals_map: HashMap, + /// Final map that associates an entry point with any local function constants + /// that are shared and were hoisted to the global space. + /// This map is kept separate from `entry_point_globals_map` to clearly distinguish + /// the two types of globals. + entry_point_hoisted_globals_map: HashMap, } /// Mapping of SSA value ids to their Brillig allocations pub(crate) type SsaToBrilligGlobals = HashMap; +pub(crate) type HoistedConstantsToBrilligGlobals = + HashMap<(FieldElement, NumericType), BrilligVariable>; +/// Mapping of a constant value and the number of functions in which it occurs +pub(crate) type ConstantCounterMap = HashMap<(FieldElement, NumericType), usize>; + impl BrilligGlobals { pub(crate) fn new( functions: &BTreeMap, mut used_globals: HashMap>, main_id: FunctionId, ) -> Self { - let mut brillig_entry_points = HashMap::default(); - let acir_functions = functions.iter().filter(|(_, func)| func.runtime().is_acir()); - for (_, function) in acir_functions { - for block_id in function.reachable_blocks() { - for instruction_id in function.dfg[block_id].instructions() { - let instruction = &function.dfg[*instruction_id]; - let Instruction::Call { func: func_id, arguments: _ } = instruction else { - continue; - }; - - let func_value = &function.dfg[*func_id]; - let Value::Function(func_id) = func_value else { continue }; - - let called_function = &functions[func_id]; - if called_function.runtime().is_acir() { - continue; - } - - // We have now found a Brillig entry point. 
- // Let's recursively build a call graph to determine any functions - // whose parent is this entry point and any globals used in those internal calls. - brillig_entry_points.insert(*func_id, HashSet::default()); - Self::mark_entry_points_calls_recursive( - functions, - *func_id, - called_function, - &mut used_globals, - &mut brillig_entry_points, - im::HashSet::new(), - ); - } - } - } + let brillig_entry_points = get_brillig_entry_points(functions, main_id); - // If main has been marked as Brillig, it is itself an entry point. - // Run the same analysis from above on main. - let main_func = &functions[&main_id]; - if main_func.runtime().is_brillig() { - brillig_entry_points.insert(main_id, HashSet::default()); - Self::mark_entry_points_calls_recursive( - functions, - main_id, - main_func, - &mut used_globals, - &mut brillig_entry_points, - im::HashSet::new(), + let mut hoisted_global_constants: HashMap = + HashMap::default(); + // Mark any globals used in a Brillig entry point. + // Using the information collected we can determine which globals + // an entry point must initialize. 
+ for (entry_point, entry_point_inner_calls) in brillig_entry_points.iter() { + Self::mark_globals_for_hoisting( + &mut hoisted_global_constants, + *entry_point, + &functions[entry_point], ); - } - // NB: Temporary fix to override entry point analysis - let merged_set = - used_globals.values().flat_map(|set| set.iter().copied()).collect::>(); - for set in used_globals.values_mut() { - *set = merged_set.clone(); + for inner_call in entry_point_inner_calls.iter() { + Self::mark_globals_for_hoisting( + &mut hoisted_global_constants, + *entry_point, + &functions[inner_call], + ); + + let inner_globals = used_globals + .get(inner_call) + .expect("Should have a slot for each function") + .clone(); + used_globals + .get_mut(entry_point) + .expect("ICE: should have func") + .extend(inner_globals); + } } - Self { used_globals, brillig_entry_points, ..Default::default() } + let inner_call_to_entry_point = build_inner_call_to_entry_points(&brillig_entry_points); + + Self { + used_globals, + brillig_entry_points, + inner_call_to_entry_point, + hoisted_global_constants, + ..Default::default() + } } - /// Recursively mark any functions called in an entry point as well as - /// any globals used in those functions. - /// Using the information collected we can determine which globals - /// an entry point must initialize. - fn mark_entry_points_calls_recursive( - functions: &BTreeMap, + /// Helper for marking that a constant was instantiated in a given function. + /// For a given entry point, we want to determine which constants are shared across multiple functions. 
+ fn mark_globals_for_hoisting( + hoisted_global_constants: &mut HashMap, entry_point: FunctionId, - called_function: &Function, - used_globals: &mut HashMap>, - brillig_entry_points: &mut HashMap>, - mut explored_functions: im::HashSet, + function: &Function, ) { - if explored_functions.insert(called_function.id()).is_some() { - return; - } - - let inner_calls = called_functions_vec(called_function).into_iter().collect::>(); - - for inner_call in inner_calls { - let inner_globals = used_globals - .get(&inner_call) - .expect("Should have a slot for each function") - .clone(); - used_globals - .get_mut(&entry_point) - .expect("ICE: should have func") - .extend(inner_globals); - - if let Some(inner_calls) = brillig_entry_points.get_mut(&entry_point) { - inner_calls.insert(inner_call); + // We can potentially have multiple local constants with the same value and type + let constants = ConstantAllocation::from_function(function); + for constant in constants.get_constants() { + let value = function.dfg.get_numeric_constant(constant); + let value = value.unwrap(); + let typ = function.dfg.type_of_value(constant); + if !function.dfg.is_global(constant) { + hoisted_global_constants + .entry(entry_point) + .or_default() + .entry((value, typ.unwrap_numeric())) + .and_modify(|counter| *counter += 1) + .or_insert(1); } - - Self::mark_entry_points_calls_recursive( - functions, - entry_point, - &functions[&inner_call], - used_globals, - brillig_entry_points, - explored_functions.clone(), - ); } } @@ -151,77 +130,122 @@ impl BrilligGlobals { brillig: &mut Brillig, options: &BrilligOptions, ) { - // Map for fetching the correct entry point globals when compiling any function - let mut inner_call_to_entry_point: HashMap> = - HashMap::default(); let mut entry_point_globals_map = HashMap::default(); + let mut entry_point_hoisted_globals_map = HashMap::default(); + // We only need to generate globals for entry points - for (entry_point, entry_point_inner_calls) in 
self.brillig_entry_points.iter() { + for (entry_point, _) in self.brillig_entry_points.iter() { let entry_point = *entry_point; - for inner_call in entry_point_inner_calls { - inner_call_to_entry_point.entry(*inner_call).or_default().push(entry_point); - } - let used_globals = self.used_globals.remove(&entry_point).unwrap_or_default(); - let (artifact, brillig_globals, globals_size) = - convert_ssa_globals(options, globals_dfg, &used_globals, entry_point); + // Select set of constants which can be hoisted from function's to the global memory space + // for a given entry point. + let hoisted_global_constants = self + .hoisted_global_constants + .remove(&entry_point) + .unwrap_or_default() + .iter() + .filter_map( + |(&value, &num_occurrences)| { + if num_occurrences > 1 { + Some(value) + } else { + None + } + }, + ) + .collect(); + let (artifact, brillig_globals, globals_size, hoisted_global_constants) = + convert_ssa_globals( + options, + globals_dfg, + &used_globals, + &hoisted_global_constants, + entry_point, + ); entry_point_globals_map.insert(entry_point, brillig_globals); + entry_point_hoisted_globals_map.insert(entry_point, hoisted_global_constants); brillig.globals.insert(entry_point, artifact); brillig.globals_memory_size.insert(entry_point, globals_size); } - self.inner_call_to_entry_point = inner_call_to_entry_point; self.entry_point_globals_map = entry_point_globals_map; + self.entry_point_hoisted_globals_map = entry_point_hoisted_globals_map; } /// Fetch the global allocations that can possibly be accessed /// by any given Brillig function (non-entry point or entry point). /// The allocations available to a function are determined by its entry point. /// For a given function id input, this function will search for that function's - /// entry point (or multiple entry points) and fetch the global allocations - /// associated with those entry points. + /// entry point and fetch the global allocations associated with that entry point. 
/// These allocations can then be used when compiling the Brillig function /// and resolving global variables. pub(crate) fn get_brillig_globals( &self, brillig_function_id: FunctionId, - ) -> SsaToBrilligGlobals { - let entry_points = self.inner_call_to_entry_point.get(&brillig_function_id); + ) -> Option<(&SsaToBrilligGlobals, &HoistedConstantsToBrilligGlobals)> { + // Check whether `brillig_function_id` is itself an entry point. + // If so, return the global allocations directly. + let entry_point_globals = self.get_entry_point_globals(&brillig_function_id); + if entry_point_globals.is_some() { + return entry_point_globals; + } - let mut globals_allocations = HashMap::default(); - if let Some(entry_points) = entry_points { - // A Brillig function is used by multiple entry points. Fetch both globals allocations - // in case one is used by the internal call. - let entry_point_allocations = entry_points - .iter() - .flat_map(|entry_point| self.entry_point_globals_map.get(entry_point)) - .collect::>(); - for map in entry_point_allocations { - globals_allocations.extend(map); - } - } else if let Some(globals) = self.entry_point_globals_map.get(&brillig_function_id) { - // If there is no mapping from an inner call to an entry point, that means `brillig_function_id` - // is itself an entry point and we can fetch the global allocations directly from `self.entry_point_globals_map`. 
- // vec![globals] - globals_allocations.extend(globals); - } else { + let entry_points = self.inner_call_to_entry_point.get(&brillig_function_id); + let Some(entry_points) = entry_points else { unreachable!( "ICE: Expected global allocation to be set for function {brillig_function_id}" ); + }; + + // Sanity check: We should have guaranteed earlier that an inner call has only a single entry point + assert_eq!(entry_points.len(), 1, "{brillig_function_id} has multiple entry points"); + let entry_point = entry_points.first().expect("ICE: Inner call should have an entry point"); + + self.get_entry_point_globals(entry_point) + } + + /// Fetch the global allocations for a given entry point. + /// This contains both the user specified globals, as well as any constants shared + /// across functions that have been hoisted into the global space. + fn get_entry_point_globals( + &self, + entry_point: &FunctionId, + ) -> Option<(&SsaToBrilligGlobals, &HoistedConstantsToBrilligGlobals)> { + if let (Some(globals), Some(hoisted_constants)) = ( + self.entry_point_globals_map.get(entry_point), + self.entry_point_hoisted_globals_map.get(entry_point), + ) { + Some((globals, hoisted_constants)) + } else { + None } - globals_allocations } } +/// A globals artifact containing all information necessary for utilizing +/// globals from SSA during Brillig code generation. +pub(crate) type BrilligGlobalsArtifact = ( + // The actual bytecode declaring globals and any metadata needing for linking + BrilligArtifact, + // The SSA value -> Brillig global allocations + // This will be used for fetching global values when compiling functions to Brillig. 
+ HashMap, + // The size of the global memory + usize, + // Duplicate SSA constants local to a function -> Brillig global allocations + HashMap<(FieldElement, NumericType), BrilligVariable>, +); + pub(crate) fn convert_ssa_globals( options: &BrilligOptions, globals_dfg: &DataFlowGraph, used_globals: &HashSet, + hoisted_global_constants: &BTreeSet<(FieldElement, NumericType)>, entry_point: FunctionId, -) -> (BrilligArtifact, HashMap, usize) { +) -> BrilligGlobalsArtifact { let mut brillig_context = BrilligContext::new_for_global_init(options, entry_point); // The global space does not have globals itself let empty_globals = HashMap::default(); @@ -238,23 +262,25 @@ pub(crate) fn convert_ssa_globals( variables: Default::default(), last_uses: HashMap::default(), globals: &empty_globals, + hoisted_global_constants: &HashMap::default(), building_globals: true, }; - brillig_block.compile_globals(globals_dfg, used_globals); + let hoisted_global_constants = + brillig_block.compile_globals(globals_dfg, used_globals, hoisted_global_constants); let globals_size = brillig_context.global_space_size(); brillig_context.return_instruction(); let artifact = brillig_context.artifact(); - (artifact, function_context.ssa_value_allocations, globals_size) + (artifact, function_context.ssa_value_allocations, globals_size, hoisted_global_constants) } #[cfg(test)] mod tests { use acvm::{ - acir::brillig::{BitSize, Opcode}, + acir::brillig::{BitSize, IntegerBitSize, Opcode}, FieldElement, }; @@ -262,6 +288,8 @@ mod tests { brillig_ir::registers::RegisterAllocator, BrilligOptions, GlobalSpace, LabelType, Ssa, }; + use super::ConstantAllocation; + #[test] fn entry_points_different_globals() { let src = " @@ -312,11 +340,10 @@ mod tests { if func_id.to_u32() == 1 { assert_eq!( artifact.byte_code.len(), - 2, + 1, "Expected just a `Return`, but got more than a single opcode" ); - // TODO: Bring this back (https://github.com/noir-lang/noir/issues/7306) - // 
assert!(matches!(&artifact.byte_code[0], Opcode::Return)); + assert!(matches!(&artifact.byte_code[0], Opcode::Return)); } else if func_id.to_u32() == 2 { assert_eq!( artifact.byte_code.len(), @@ -430,17 +457,16 @@ mod tests { if func_id.to_u32() == 1 { assert_eq!( artifact.byte_code.len(), - 30, + 2, "Expected enough opcodes to initialize the globals" ); - // TODO: Bring this back (https://github.com/noir-lang/noir/issues/7306) - // let Opcode::Const { destination, bit_size, value } = &artifact.byte_code[0] else { - // panic!("First opcode is expected to be `Const`"); - // }; - // assert_eq!(destination.unwrap_direct(), GlobalSpace::start()); - // assert!(matches!(bit_size, BitSize::Field)); - // assert_eq!(*value, FieldElement::from(1u128)); - // assert!(matches!(&artifact.byte_code[1], Opcode::Return)); + let Opcode::Const { destination, bit_size, value } = &artifact.byte_code[0] else { + panic!("First opcode is expected to be `Const`"); + }; + assert_eq!(destination.unwrap_direct(), GlobalSpace::start()); + assert!(matches!(bit_size, BitSize::Field)); + assert_eq!(*value, FieldElement::from(1u128)); + assert!(matches!(&artifact.byte_code[1], Opcode::Return)); } else if func_id.to_u32() == 2 || func_id.to_u32() == 3 { // We want the entry point which uses globals (f2) and the entry point which calls f2 function internally (f3 through f4) // to have the same globals initialized. 
@@ -460,4 +486,154 @@ mod tests { } } } + + #[test] + fn hoist_shared_constants() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: Field, v1: Field): + call f1(v0, v1) + return + } + brillig(inline) predicate_pure fn entry_point f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + v4 = add v2, Field 1 + v6 = eq v4, Field 5 + constrain v6 == u1 0 + call f2(v0, v1) + return + } + brillig(inline) predicate_pure fn inner_func f2 { + b0(v0: Field, v1: Field): + v3 = eq v0, Field 20 + constrain v3 == u1 0 + v5 = add v0, v1 + v7 = add v5, Field 10 + v9 = add v7, Field 1 + v11 = eq v9, Field 20 + constrain v11 == u1 0 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + // Need to run DIE to generate the used globals map, which is necessary for Brillig globals generation. + let mut ssa = ssa.dead_instruction_elimination(); + + // Show that the constants in each function have different SSA value IDs + for (func_id, function) in &ssa.functions { + let constant_allocation = ConstantAllocation::from_function(function); + let mut constants = constant_allocation.get_constants().into_iter().collect::>(); + // We want to order the constants by ID + constants.sort(); + if func_id.to_u32() == 1 { + assert_eq!(constants.len(), 3); + let one = function.dfg.get_numeric_constant(constants[0]).unwrap(); + assert_eq!(one, FieldElement::from(1u128)); + let five = function.dfg.get_numeric_constant(constants[1]).unwrap(); + assert_eq!(five, FieldElement::from(5u128)); + let zero = function.dfg.get_numeric_constant(constants[2]).unwrap(); + assert_eq!(zero, FieldElement::from(0u128)); + } else if func_id.to_u32() == 2 { + assert_eq!(constants.len(), 4); + let twenty = function.dfg.get_numeric_constant(constants[0]).unwrap(); + assert_eq!(twenty, FieldElement::from(20u128)); + let zero = function.dfg.get_numeric_constant(constants[1]).unwrap(); + assert_eq!(zero, FieldElement::from(0u128)); + let ten = function.dfg.get_numeric_constant(constants[2]).unwrap(); + 
assert_eq!(ten, FieldElement::from(10u128)); + let one = function.dfg.get_numeric_constant(constants[3]).unwrap(); + assert_eq!(one, FieldElement::from(1u128)); + } + } + + let used_globals_map = std::mem::take(&mut ssa.used_globals); + let brillig = ssa.to_brillig_with_globals(&BrilligOptions::default(), used_globals_map); + + assert_eq!(brillig.globals.len(), 1, "Should have a single entry point"); + for (func_id, artifact) in brillig.globals { + assert_eq!(func_id.to_u32(), 1); + assert_eq!( + artifact.byte_code.len(), + 3, + "Expected enough opcodes to initialize the hoisted constants" + ); + let Opcode::Const { destination, bit_size, value } = &artifact.byte_code[0] else { + panic!("First opcode is expected to be `Const`"); + }; + assert_eq!(destination.unwrap_direct(), GlobalSpace::start()); + assert!(matches!(bit_size, BitSize::Integer(IntegerBitSize::U1))); + assert_eq!(*value, FieldElement::from(0u128)); + + let Opcode::Const { destination, bit_size, value } = &artifact.byte_code[1] else { + panic!("First opcode is expected to be `Const`"); + }; + assert_eq!(destination.unwrap_direct(), GlobalSpace::start() + 1); + assert!(matches!(bit_size, BitSize::Field)); + assert_eq!(*value, FieldElement::from(1u128)); + + assert!(matches!(&artifact.byte_code[2], Opcode::Return)); + } + } + + #[test] + fn do_not_hoist_shared_constants_different_entry_points() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: Field, v1: Field): + call f1(v0, v1) + call f2(v0, v1) + return + } + brillig(inline) predicate_pure fn entry_point f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + v4 = add v2, Field 1 + v6 = eq v4, Field 5 + constrain v6 == u1 0 + return + } + brillig(inline) predicate_pure fn entry_point_two f2 { + b0(v0: Field, v1: Field): + v3 = eq v0, Field 20 + constrain v3 == u1 0 + v5 = add v0, v1 + v7 = add v5, Field 10 + v9 = add v7, Field 1 + v10 = eq v9, Field 20 + constrain v10 == u1 0 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); 
+ // Need to run DIE to generate the used globals map, which is necessary for Brillig globals generation. + let mut ssa = ssa.dead_instruction_elimination(); + + let used_globals_map = std::mem::take(&mut ssa.used_globals); + let brillig = ssa.to_brillig_with_globals(&BrilligOptions::default(), used_globals_map); + + assert_eq!( + brillig.globals.len(), + 2, + "Should have a globals artifact associated with each entry point" + ); + for (func_id, mut artifact) in brillig.globals { + let labels = artifact.take_labels(); + // When entering a context two labels are created. + // One is a context label and another is a section label. + assert_eq!(labels.len(), 2); + for (label, position) in labels { + assert_eq!(label.label_type, LabelType::GlobalInit(func_id)); + assert_eq!(position, 0); + } + assert_eq!( + artifact.byte_code.len(), + 1, + "Expected enough opcodes to initialize the hoisted constants" + ); + assert!(matches!(&artifact.byte_code[0], Opcode::Return)); + } + } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs index 1ec2d165b121..99645f84ed3d 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs @@ -179,6 +179,7 @@ mod tests { use crate::ssa::function_builder::FunctionBuilder; use crate::ssa::ir::function::RuntimeType; use crate::ssa::ir::map::Id; + use crate::ssa::ir::types::NumericType; use crate::ssa::ssa_gen::Ssa; fn create_test_environment() -> (Ssa, FunctionContext, BrilligContext) { @@ -197,6 +198,7 @@ mod tests { function_context: &'a mut FunctionContext, brillig_context: &'a mut BrilligContext, globals: &'a HashMap, + hoisted_global_constants: &'a HashMap<(FieldElement, NumericType), BrilligVariable>, ) -> BrilligBlock<'a, Stack> { let variables = BlockVariables::default(); 
BrilligBlock { @@ -206,6 +208,7 @@ mod tests { variables, last_uses: Default::default(), globals, + hoisted_global_constants, building_globals: false, } } @@ -249,8 +252,13 @@ mod tests { let target_vector = BrilligVector { pointer: context.allocate_register() }; let brillig_globals = HashMap::default(); - let mut block = - create_brillig_block(&mut function_context, &mut context, &brillig_globals); + let hoisted_globals = HashMap::default(); + let mut block = create_brillig_block( + &mut function_context, + &mut context, + &brillig_globals, + &hoisted_globals, + ); if push_back { block.slice_push_back_operation( @@ -367,8 +375,13 @@ mod tests { }; let brillig_globals = HashMap::default(); - let mut block = - create_brillig_block(&mut function_context, &mut context, &brillig_globals); + let hoisted_globals = HashMap::default(); + let mut block = create_brillig_block( + &mut function_context, + &mut context, + &brillig_globals, + &hoisted_globals, + ); if pop_back { block.slice_pop_back_operation( @@ -475,8 +488,13 @@ mod tests { let target_vector = BrilligVector { pointer: context.allocate_register() }; let brillig_globals = HashMap::default(); - let mut block = - create_brillig_block(&mut function_context, &mut context, &brillig_globals); + let hoisted_globals = HashMap::default(); + let mut block = create_brillig_block( + &mut function_context, + &mut context, + &brillig_globals, + &hoisted_globals, + ); block.slice_insert_operation( target_vector, @@ -617,8 +635,13 @@ mod tests { }; let brillig_globals = HashMap::default(); - let mut block = - create_brillig_block(&mut function_context, &mut context, &brillig_globals); + let hoisted_globals = HashMap::default(); + let mut block = create_brillig_block( + &mut function_context, + &mut context, + &brillig_globals, + &hoisted_globals, + ); block.slice_remove_operation( target_vector, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs 
b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs index 64741393dd70..d49876ef1d92 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs @@ -163,6 +163,10 @@ impl ConstantAllocation { } current_block } + + pub(crate) fn get_constants(&self) -> HashSet { + self.constant_usage.keys().copied().collect() + } } pub(crate) fn is_constant_value(id: ValueId, dfg: &DataFlowGraph) -> bool { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/mod.rs index a7e188602ce6..901ab07e263f 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod brillig_ir; use acvm::FieldElement; use brillig_gen::brillig_globals::BrilligGlobals; +use brillig_gen::constant_allocation::ConstantAllocation; use brillig_ir::{artifact::LabelType, brillig_variable::BrilligVariable, registers::GlobalSpace}; use self::{ @@ -17,10 +18,9 @@ use crate::ssa::{ ir::{ dfg::DataFlowGraph, function::{Function, FunctionId}, - instruction::Instruction, - value::{Value, ValueId}, + types::NumericType, + value::ValueId, }, - opt::inlining::called_functions_vec, ssa_gen::Ssa, }; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; @@ -52,8 +52,9 @@ impl Brillig { func: &Function, options: &BrilligOptions, globals: &HashMap, + hoisted_global_constants: &HashMap<(FieldElement, NumericType), BrilligVariable>, ) { - let obj = convert_ssa_function(func, options, globals); + let obj = convert_ssa_function(func, options, globals, hoisted_global_constants); self.ssa_function_to_brillig.insert(func.id(), obj); } @@ -122,10 +123,14 @@ impl Ssa { brillig_globals.declare_globals(&globals_dfg, &mut brillig, options); for brillig_function_id in 
brillig_reachable_function_ids { - let globals_allocations = brillig_globals.get_brillig_globals(brillig_function_id); + let empty_allocations = HashMap::default(); + let empty_const_allocations = HashMap::default(); + let (globals_allocations, hoisted_constant_allocations) = brillig_globals + .get_brillig_globals(brillig_function_id) + .unwrap_or((&empty_allocations, &empty_const_allocations)); let func = &self.functions[&brillig_function_id]; - brillig.compile(func, options, &globals_allocations); + brillig.compile(func, options, globals_allocations, hoisted_constant_allocations); } brillig diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs index 6d051da45503..74196d9a7666 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs @@ -131,8 +131,8 @@ pub(crate) fn optimize_into_acir( .run_pass(|ssa| ssa.fold_constants_with_brillig(&brillig), "Inlining Brillig Calls Inlining") // It could happen that we inlined all calls to a given brillig function. // In that case it's unused so we can remove it. This is what we check next. 
- .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (3rd)") - .run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (2nd)") + .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (4th)") + .run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (3rd)") .finish(); if !options.skip_underconstrained_check { @@ -217,6 +217,12 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Ssa { + if self.main().runtime().is_brillig() { + return self; + } + + let brillig_entry_points = get_brillig_entry_points(&self.functions, self.main_id); + + let functions_to_clone_map = build_functions_to_clone(&brillig_entry_points); + + let calls_to_update = build_calls_to_update(&mut self, functions_to_clone_map); + + let mut new_functions_map = HashMap::default(); + for (entry_point, inner_calls) in brillig_entry_points { + let new_entry_point = + new_functions_map.get(&entry_point).copied().unwrap_or(entry_point); + + let function = + self.functions.get_mut(&new_entry_point).expect("ICE: Function does not exist"); + update_function_calls(function, entry_point, &mut new_functions_map, &calls_to_update); + + for inner_call in inner_calls { + let new_inner_call = + new_functions_map.get(&inner_call).copied().unwrap_or(inner_call); + + let function = + self.functions.get_mut(&new_inner_call).expect("ICE: Function does not exist"); + update_function_calls( + function, + entry_point, + &mut new_functions_map, + &calls_to_update, + ); + } + } + + self + } +} + +/// For every call site, we can determine the entry point for a given callee. +/// Once we know that we can determine which functions are in need of duplication. +/// We duplicate when the following occurs: +/// 1. A function is called from two different entry points +/// 2. 
An entry point function is called from another entry point +fn build_functions_to_clone( + brillig_entry_points: &BTreeMap>, +) -> HashMap> { + let inner_call_to_entry_point = build_inner_call_to_entry_points(brillig_entry_points); + let entry_points = brillig_entry_points.keys().copied().collect::>(); + + let mut functions_to_clone_map: HashMap> = HashMap::default(); + + for (inner_call, inner_call_entry_points) in inner_call_to_entry_point { + let should_clone = inner_call_entry_points.len() > 1 || entry_points.contains(&inner_call); + if should_clone { + for entry_point in inner_call_entry_points { + functions_to_clone_map.entry(entry_point).or_default().push(inner_call); + } + } + } + + functions_to_clone_map +} + +/// Clones new functions and returns a mapping representing the calls to update. +/// +/// Returns a map of (entry point, callee function) -> new callee function id. +fn build_calls_to_update( + ssa: &mut Ssa, + functions_to_clone_map: HashMap>, +) -> HashMap<(FunctionId, FunctionId), FunctionId> { + let mut calls_to_update: HashMap<(FunctionId, FunctionId), FunctionId> = HashMap::default(); + + for (entry_point, functions_to_clone) in functions_to_clone_map { + for old_id in functions_to_clone { + let function = ssa.functions[&old_id].clone(); + ssa.add_fn(|id| { + calls_to_update.insert((entry_point, old_id), id); + Function::clone_with_id(id, &function) + }); + } + } + + calls_to_update +} + +fn update_function_calls( + function: &mut Function, + entry_point: FunctionId, + new_functions_map: &mut HashMap, + // Maps (entry point, callee function) -> new callee function id + calls_to_update: &HashMap<(FunctionId, FunctionId), FunctionId>, +) { + for block_id in function.reachable_blocks() { + #[allow(clippy::unnecessary_to_owned)] // clippy is wrong here + for instruction_id in function.dfg[block_id].instructions().to_vec() { + let instruction = function.dfg[instruction_id].clone(); + let Instruction::Call { func: func_id, arguments } = instruction 
else { + continue; + }; + + let func_value = &function.dfg[func_id]; + let Value::Function(func_id) = func_value else { continue }; + let Some(new_id) = calls_to_update.get(&(entry_point, *func_id)) else { + continue; + }; + + new_functions_map.insert(*func_id, *new_id); + let new_function_value_id = function.dfg.import_function(*new_id); + function.dfg[instruction_id] = + Instruction::Call { func: new_function_value_id, arguments }; + } + } +} + +/// Returns a map of Brillig entry points to all functions called in that entry point. +/// This includes any nested calls as well, as we want to be able to associate +/// any Brillig function with the appropriate global allocations. +pub(crate) fn get_brillig_entry_points( + functions: &BTreeMap, + main_id: FunctionId, +) -> BTreeMap> { + let mut brillig_entry_points = BTreeMap::default(); + let acir_functions = functions.iter().filter(|(_, func)| func.runtime().is_acir()); + for (_, function) in acir_functions { + for block_id in function.reachable_blocks() { + for instruction_id in function.dfg[block_id].instructions() { + let instruction = &function.dfg[*instruction_id]; + let Instruction::Call { func: func_id, arguments: _ } = instruction else { + continue; + }; + + let func_value = &function.dfg[*func_id]; + let Value::Function(func_id) = func_value else { continue }; + + let called_function = &functions[func_id]; + if called_function.runtime().is_acir() { + continue; + } + + // We have now found a Brillig entry point. + brillig_entry_points.insert(*func_id, BTreeSet::default()); + build_entry_points_map_recursive( + functions, + *func_id, + *func_id, + &mut brillig_entry_points, + im::HashSet::new(), + ); + } + } + } + + // If main has been marked as Brillig, it is itself an entry point. + // Run the same analysis from above on main. 
+ let main_func = &functions[&main_id]; + if main_func.runtime().is_brillig() { + brillig_entry_points.insert(main_id, BTreeSet::default()); + build_entry_points_map_recursive( + functions, + main_id, + main_id, + &mut brillig_entry_points, + im::HashSet::new(), + ); + } + + brillig_entry_points +} + +/// Recursively mark any functions called in an entry point +fn build_entry_points_map_recursive( + functions: &BTreeMap, + entry_point: FunctionId, + called_function: FunctionId, + brillig_entry_points: &mut BTreeMap>, + mut explored_functions: im::HashSet, +) { + if explored_functions.insert(called_function).is_some() { + return; + } + + let inner_calls: HashSet = + called_functions_vec(&functions[&called_function]).into_iter().collect(); + + for inner_call in inner_calls { + if let Some(inner_calls) = brillig_entry_points.get_mut(&entry_point) { + inner_calls.insert(inner_call); + } + + build_entry_points_map_recursive( + functions, + entry_point, + inner_call, + brillig_entry_points, + explored_functions.clone(), + ); + } +} + +/// Builds a mapping from a [`FunctionId`] to the set of [`FunctionId`s][`FunctionId`] of all the brillig entrypoints +/// from which this function is reachable. 
+pub(crate) fn build_inner_call_to_entry_points( + brillig_entry_points: &BTreeMap>, +) -> HashMap> { + // Map for fetching the correct entry point globals when compiling any function + let mut inner_call_to_entry_point: HashMap> = + HashMap::default(); + + // We only need to generate globals for entry points + for (entry_point, entry_point_inner_calls) in brillig_entry_points.iter() { + for inner_call in entry_point_inner_calls { + inner_call_to_entry_point.entry(*inner_call).or_default().insert(*entry_point); + } + } + + inner_call_to_entry_point +} + +#[cfg(test)] +mod tests { + use crate::ssa::opt::assert_normalized_ssa_equals; + + use super::Ssa; + + #[test] + fn duplicate_inner_call_with_multiple_entry_points() { + let src = " + g0 = Field 1 + g1 = Field 2 + g2 = Field 3 + + acir(inline) fn main f0 { + b0(v3: Field, v4: Field): + call f1(v3, v4) + call f2(v3, v4) + return + } + brillig(inline) fn entry_point_one f1 { + b0(v3: Field, v4: Field): + v5 = add g0, v3 + v6 = add v5, v4 + constrain v6 == Field 2 + call f3(v3, v4) + return + } + brillig(inline) fn entry_point_two f2 { + b0(v3: Field, v4: Field): + v5 = add g1, v3 + v6 = add v5, v4 + constrain v6 == Field 3 + call f3(v3, v4) + return + } + brillig(inline) fn inner_func f3 { + b0(v3: Field, v4: Field): + v5 = add g2, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.brillig_entry_point_analysis(); + let ssa = ssa.remove_unreachable_functions(); + + // We expect `inner_func` to be duplicated + let expected = " + g0 = Field 1 + g1 = Field 2 + g2 = Field 3 + + acir(inline) fn main f0 { + b0(v3: Field, v4: Field): + call f1(v3, v4) + call f2(v3, v4) + return + } + brillig(inline) fn entry_point_one f1 { + b0(v3: Field, v4: Field): + v5 = add Field 1, v3 + v6 = add v5, v4 + constrain v6 == Field 2 + call f3(v3, v4) + return + } + brillig(inline) fn entry_point_two f2 { + b0(v3: Field, v4: Field): + v5 = add Field 2, v3 + v6 = add v5, 
v4 + constrain v6 == Field 3 + call f4(v3, v4) + return + } + brillig(inline) fn inner_func f3 { + b0(v3: Field, v4: Field): + v5 = add Field 3, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + brillig(inline) fn inner_func f4 { + b0(v3: Field, v4: Field): + v5 = add Field 3, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + "; + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn duplicate_inner_call_with_multiple_entry_points_nested() { + let src = " + g0 = Field 2 + g1 = Field 3 + + acir(inline) fn main f0 { + b0(v2: Field, v3: Field): + call f1(v2, v3) + call f2(v2, v3) + return + } + brillig(inline) fn entry_point_one f1 { + b0(v2: Field, v3: Field): + v4 = add g0, v2 + v5 = add v4, v3 + constrain v5 == Field 3 + call f3(v2, v3) + return + } + brillig(inline) fn entry_point_two f2 { + b0(v2: Field, v3: Field): + v4 = add g0, v2 + v5 = add v4, v3 + constrain v5 == Field 3 + call f3(v2, v3) + return + } + brillig(inline) fn inner_func f3 { + b0(v2: Field, v3: Field): + v4 = add g0, v2 + v5 = add v4, v3 + constrain v5 == Field 3 + call f4(v2, v3) + return + } + brillig(inline) fn nested_inner_func f4 { + b0(v2: Field, v3: Field): + v4 = add g1, v2 + v5 = add v4, v3 + constrain v5 == Field 4 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.brillig_entry_point_analysis(); + let ssa = ssa.remove_unreachable_functions(); + + // We expect both `inner_func` and `nested_inner_func` to be duplicated + let expected = " + g0 = Field 2 + g1 = Field 3 + + acir(inline) fn main f0 { + b0(v2: Field, v3: Field): + call f1(v2, v3) + call f2(v2, v3) + return + } + brillig(inline) fn entry_point_one f1 { + b0(v2: Field, v3: Field): + v4 = add Field 2, v2 + v5 = add v4, v3 + constrain v5 == Field 3 + call f4(v2, v3) + return + } + brillig(inline) fn entry_point_two f2 { + b0(v2: Field, v3: Field): + v4 = add Field 2, v2 + v5 = add v4, v3 + constrain v5 == Field 3 + call f6(v2, v3) + return + } + brillig(inline) fn 
nested_inner_func f3 { + b0(v2: Field, v3: Field): + v4 = add Field 3, v2 + v5 = add v4, v3 + constrain v5 == Field 4 + return + } + brillig(inline) fn inner_func f4 { + b0(v2: Field, v3: Field): + v4 = add Field 2, v2 + v5 = add v4, v3 + constrain v5 == Field 3 + call f3(v2, v3) + return + } + brillig(inline) fn nested_inner_func f5 { + b0(v2: Field, v3: Field): + v4 = add Field 3, v2 + v5 = add v4, v3 + constrain v5 == Field 4 + return + } + brillig(inline) fn inner_func f6 { + b0(v2: Field, v3: Field): + v4 = add Field 2, v2 + v5 = add v4, v3 + constrain v5 == Field 3 + call f5(v2, v3) + return + } + "; + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn duplicate_entry_point_called_from_entry_points() { + // Check that we duplicate entry points that are also called from another entry point. + // In this test the entry points used in other entry points are f2 and f3. + // These functions are also called within the wrapper function f4, as we also want to make sure + // that we duplicate entry points called from another entry point's inner calls. 
+ let src = " + g0 = Field 2 + g1 = Field 3 + g2 = Field 1 + + acir(inline) fn main f0 { + b0(v3: Field, v4: Field): + call f1(v3, v4) + call f2(v3, v4) + call f3(v3, v4) + return + } + brillig(inline) fn entry_point_inner_func_globals f1 { + b0(v3: Field, v4: Field): + call f4(v3, v4) + return + } + brillig(inline) fn entry_point_one_global f2 { + b0(v3: Field, v4: Field): + v5 = add g0, v3 + v6 = add v5, v4 + constrain v6 == Field 3 + return + } + brillig(inline) fn entry_point_one_diff_global f3 { + b0(v3: Field, v4: Field): + v5 = add g1, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + brillig(inline) fn wrapper f4 { + b0(v3: Field, v4: Field): + v5 = add g2, v3 + v6 = add v5, v4 + constrain v6 == Field 2 + call f2(v3, v4) + call f3(v4, v3) + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.brillig_entry_point_analysis(); + + // We expect `entry_point_one_global` and `entry_point_one_diff_global` to be duplicated + let expected = " + g0 = Field 2 + g1 = Field 3 + g2 = Field 1 + + acir(inline) fn main f0 { + b0(v3: Field, v4: Field): + call f1(v3, v4) + call f2(v3, v4) + call f3(v3, v4) + return + } + brillig(inline) fn entry_point_inner_func_globals f1 { + b0(v3: Field, v4: Field): + call f4(v3, v4) + return + } + brillig(inline) fn entry_point_one_global f2 { + b0(v3: Field, v4: Field): + v5 = add Field 2, v3 + v6 = add v5, v4 + constrain v6 == Field 3 + return + } + brillig(inline) fn entry_point_one_diff_global f3 { + b0(v3: Field, v4: Field): + v5 = add Field 3, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + brillig(inline) fn wrapper f4 { + b0(v3: Field, v4: Field): + v5 = add Field 1, v3 + v6 = add v5, v4 + constrain v6 == Field 2 + call f5(v3, v4) + call f6(v4, v3) + return + } + brillig(inline) fn entry_point_one_global f5 { + b0(v3: Field, v4: Field): + v5 = add Field 2, v3 + v6 = add v5, v4 + constrain v6 == Field 3 + return + } + brillig(inline) fn entry_point_one_diff_global f6 { + b0(v3: Field, 
v4: Field): + v5 = add Field 3, v3 + v6 = add v5, v4 + constrain v6 == Field 4 + return + } + "; + assert_normalized_ssa_equals(ssa, expected); + } +} diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index 9adc73b64b84..19fc6a7f5a20 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -228,7 +228,6 @@ impl<'f> LoopInvariantContext<'f> { let can_be_deduplicated = instruction.can_be_deduplicated(self.inserter.function, false) || matches!(instruction, Instruction::MakeArray { .. }) - || matches!(instruction, Instruction::Binary(_)) || self.can_be_deduplicated_from_loop_bound(&instruction); is_loop_invariant && can_be_deduplicated @@ -313,13 +312,6 @@ impl<'f> LoopInvariantContext<'f> { binary: &Binary, induction_vars: &HashMap, ) -> bool { - if !matches!( - binary.operator, - BinaryOp::Add { .. } | BinaryOp::Mul { .. } | BinaryOp::Sub { .. } - ) { - return false; - } - let operand_type = self.inserter.function.dfg.type_of_value(binary.lhs).unwrap_numeric(); let lhs_const = self.inserter.function.dfg.get_numeric_constant_with_type(binary.lhs); @@ -330,7 +322,15 @@ impl<'f> LoopInvariantContext<'f> { induction_vars.get(&binary.lhs), induction_vars.get(&binary.rhs), ) { - (Some((lhs, _)), None, None, Some((_, upper_bound))) => (lhs, *upper_bound), + (Some((lhs, _)), None, None, Some((lower_bound, upper_bound))) => { + if matches!(binary.operator, BinaryOp::Div | BinaryOp::Mod) { + // If we have a Div/Mod operation we want to make sure that the + // lower bound is not zero. + (lhs, *lower_bound) + } else { + (lhs, *upper_bound) + } + } (None, Some((rhs, _)), Some((lower_bound, upper_bound)), None) => { if matches!(binary.operator, BinaryOp::Sub { .. 
}) { // If we are subtracting and the induction variable is on the lhs, @@ -343,7 +343,8 @@ impl<'f> LoopInvariantContext<'f> { _ => return false, }; - // We evaluate this expression using the upper bounds of its inputs to check whether it will ever overflow. + // We evaluate this expression using the upper bounds (or lower in the case of div/mod) + // of its inputs to check whether it will ever overflow. // If so, this will cause `eval_constant_binary_op` to return `None`. // Therefore a `Some` value shows that this operation is safe. eval_constant_binary_op(lhs, rhs, binary.operator, operand_type).is_some() @@ -870,8 +871,6 @@ mod test { b2(): return b3(): - v6 = mul v0, v1 - constrain v6 == u32 6 v8 = sub v2, u32 1 jmp b1(v8) } @@ -883,21 +882,116 @@ mod test { let expected = " brillig(inline) fn main f0 { b0(v0: u32, v1: u32): - v3 = mul v0, v1 jmp b1(u32 1) b1(v2: u32): - v6 = lt v2, u32 4 - jmpif v6 then: b3, else: b2 + v5 = lt v2, u32 4 + jmpif v5 then: b3, else: b2 b2(): return b3(): - constrain v3 == u32 6 - v8 = unchecked_sub v2, u32 1 - jmp b1(v8) + v6 = unchecked_sub v2, u32 1 + jmp b1(v6) } "; let ssa = ssa.loop_invariant_code_motion(); assert_normalized_ssa_equals(ssa, expected); } + + #[test] + fn do_not_hoist_unsafe_div() { + // This test is similar to `nested_loop_invariant_code_motion`, the operation + // in question we are trying to hoist is `v9 = div i32 10, v0`. + // Check that the lower bound of the outer loop is checked and that we do not + // hoist an operation that can potentially error with a division by zero.
+ let src = " + brillig(inline) fn main f0 { + b0(): + jmp b1(i32 0) + b1(v0: i32): + v4 = lt v0, i32 4 + jmpif v4 then: b3, else: b2 + b2(): + return + b3(): + jmp b4(i32 0) + b4(v1: i32): + v5 = lt v1, i32 4 + jmpif v5 then: b6, else: b5 + b5(): + v7 = unchecked_add v0, i32 1 + jmp b1(v7) + b6(): + v9 = div i32 10, v0 + constrain v9 == i32 6 + v11 = unchecked_add v1, i32 1 + jmp b4(v11) + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + let ssa = ssa.loop_invariant_code_motion(); + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn hoist_safe_div() { + // This test is identical to `do_not_hoist_unsafe_div`, except the loop + // in this test starts with a lower bound of `1`. + let src = " + brillig(inline) fn main f0 { + b0(): + jmp b1(i32 1) + b1(v0: i32): + v4 = lt v0, i32 4 + jmpif v4 then: b3, else: b2 + b2(): + return + b3(): + jmp b4(i32 0) + b4(v1: i32): + v5 = lt v1, i32 4 + jmpif v5 then: b6, else: b5 + b5(): + v7 = unchecked_add v0, i32 1 + jmp b1(v7) + b6(): + v9 = div i32 10, v0 + constrain v9 == i32 6 + v11 = unchecked_add v1, i32 1 + jmp b4(v11) + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + let ssa = ssa.loop_invariant_code_motion(); + let expected = " + brillig(inline) fn main f0 { + b0(): + jmp b1(i32 1) + b1(v0: i32): + v4 = lt v0, i32 4 + jmpif v4 then: b3, else: b2 + b2(): + return + b3(): + v6 = div i32 10, v0 + jmp b4(i32 0) + b4(v1: i32): + v8 = lt v1, i32 4 + jmpif v8 then: b6, else: b5 + b5(): + v9 = unchecked_add v0, i32 1 + jmp b1(v9) + b6(): + constrain v6 == i32 6 + v11 = unchecked_add v1, i32 1 + jmp b4(v11) + } + "; + + assert_normalized_ssa_equals(ssa, expected); + } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs index 7ec419890c0f..4d8a652b94d9 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -7,6 +7,7 @@ mod array_set; mod 
as_slice_length; mod assert_constant; +pub(crate) mod brillig_entry_points; mod constant_folding; mod defunctionalize; mod die; diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index c65bc9ba7cf5..91aa8b2914d2 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -3,6 +3,7 @@ mod program; mod value; use acvm::AcirField; +use noirc_frontend::hir_def::expr::Constructor; use noirc_frontend::token::FmtStrFragment; pub(crate) use program::Ssa; @@ -11,7 +12,7 @@ use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; use noirc_frontend::ast::{UnaryOp, Visibility}; use noirc_frontend::hir_def::types::Type as HirType; -use noirc_frontend::monomorphization::ast::{self, Expression, Program}; +use noirc_frontend::monomorphization::ast::{self, Expression, MatchCase, Program, While}; use crate::{ errors::RuntimeError, @@ -153,7 +154,9 @@ impl<'a> FunctionContext<'a> { Expression::Cast(cast) => self.codegen_cast(cast), Expression::For(for_expr) => self.codegen_for(for_expr), Expression::Loop(block) => self.codegen_loop(block), + Expression::While(while_) => self.codegen_while(while_), Expression::If(if_expr) => self.codegen_if(if_expr), + Expression::Match(match_expr) => self.codegen_match(match_expr), Expression::Tuple(tuple) => self.codegen_tuple(tuple), Expression::ExtractTupleField(tuple, index) => { self.codegen_extract_tuple_field(tuple, *index) @@ -588,7 +591,7 @@ impl<'a> FunctionContext<'a> { Ok(Self::unit_value()) } - /// Codegens a loop, creating three new blocks in the process. + /// Codegens a loop, creating two new blocks in the process. /// The return value of a loop is always a unit literal. 
/// /// For example, the loop `loop { body }` is codegen'd as: @@ -620,6 +623,47 @@ impl<'a> FunctionContext<'a> { Ok(Self::unit_value()) } + /// Codegens a while loop, creating three new blocks in the process. + /// The return value of a while is always a unit literal. + /// + /// For example, the loop `while cond { body }` is codegen'd as: + /// + /// ```text + /// jmp while_entry() + /// while_entry: + /// v0 = ... codegen cond ... + /// jmpif v0, then: while_body, else: while_end + /// while_body(): + /// v3 = ... codegen body ... + /// jmp while_entry() + /// while_end(): + /// ... This is the current insert point after codegen_while finishes ... + /// ``` + fn codegen_while(&mut self, while_: &While) -> Result { + let while_entry = self.builder.insert_block(); + let while_body = self.builder.insert_block(); + let while_end = self.builder.insert_block(); + + self.builder.terminate_with_jmp(while_entry, vec![]); + + // Codegen the entry (where the condition is) + self.builder.switch_to_block(while_entry); + let condition = self.codegen_non_tuple_expression(&while_.condition)?; + self.builder.terminate_with_jmpif(condition, while_body, while_end); + + self.enter_loop(Loop { loop_entry: while_entry, loop_index: None, loop_end: while_end }); + + // Codegen the body + self.builder.switch_to_block(while_body); + self.codegen_expression(&while_.body)?; + self.builder.terminate_with_jmp(while_entry, vec![]); + + // Finish by switching to the end of the while + self.builder.switch_to_block(while_end); + self.exit_loop(); + Ok(Self::unit_value()) + } + /// Codegens an if expression, handling the case of what to do if there is no 'else'. 
/// /// For example, the expression `if cond { a } else { b }` is codegen'd as: @@ -710,6 +754,157 @@ impl<'a> FunctionContext<'a> { }) } + fn codegen_match(&mut self, match_expr: &ast::Match) -> Result { + let variable = self.lookup(match_expr.variable_to_match); + + // Any matches with only a single case we don't need to check the tag at all. + // Note that this includes all matches on struct / tuple values. + if match_expr.cases.len() == 1 && match_expr.default_case.is_none() { + return self.no_match(variable, &match_expr.cases[0]); + } + + // From here on we can assume `variable` is an enum, int, or bool value (not a struct/tuple) + let tag = self.enum_tag(&variable); + let tag_type = self.builder.type_of_value(tag).unwrap_numeric(); + + let end_block = self.builder.insert_block(); + + // Optimization: if there is no default case we can jump directly to the last case + // when finished with the previous case instead of using a jmpif with an unreachable + // else block. + let last_case = if match_expr.default_case.is_some() { + match_expr.cases.len() + } else { + match_expr.cases.len() - 1 + }; + + for i in 0..last_case { + let case = &match_expr.cases[i]; + let variant_tag = self.variant_index_value(&case.constructor, tag_type)?; + let eq = self.builder.insert_binary(tag, BinaryOp::Eq, variant_tag); + + let case_block = self.builder.insert_block(); + let else_block = self.builder.insert_block(); + self.builder.terminate_with_jmpif(eq, case_block, else_block); + + self.builder.switch_to_block(case_block); + self.bind_case_arguments(variable.clone(), case); + let results = self.codegen_expression(&case.branch)?.into_value_list(self); + self.builder.terminate_with_jmp(end_block, results); + + self.builder.switch_to_block(else_block); + } + + if let Some(branch) = &match_expr.default_case { + let results = self.codegen_expression(branch)?.into_value_list(self); + self.builder.terminate_with_jmp(end_block, results); + } else { + // If there is no default case, 
assume we saved the last case from the + // last_case optimization above + let case = match_expr.cases.last().unwrap(); + self.bind_case_arguments(variable, case); + let results = self.codegen_expression(&case.branch)?.into_value_list(self); + self.builder.terminate_with_jmp(end_block, results); + } + + self.builder.switch_to_block(end_block); + let result = Self::map_type(&match_expr.typ, |typ| { + self.builder.add_block_parameter(end_block, typ).into() + }); + Ok(result) + } + + fn variant_index_value( + &mut self, + constructor: &Constructor, + typ: NumericType, + ) -> Result { + match constructor { + Constructor::Int(value) => { + self.checked_numeric_constant(value.field, value.is_negative, typ) + } + other => Ok(self.builder.numeric_constant(other.variant_index(), typ)), + } + } + + fn no_match(&mut self, variable: Values, case: &MatchCase) -> Result { + if !case.arguments.is_empty() { + self.bind_case_arguments(variable, case); + } + self.codegen_expression(&case.branch) + } + + /// Extracts the tag value from an enum. Assumes enums are represented as a tuple + /// where the tag is always the first field of the tuple. + /// + /// If the enum is only a single Leaf value, this expects the enum to consist only of the tag value. + fn enum_tag(&mut self, enum_value: &Values) -> ValueId { + match enum_value { + Tree::Branch(values) => self.enum_tag(&values[0]), + Tree::Leaf(value) => value.clone().eval(self), + } + } + + /// Bind the given variable ids to each argument of the given enum, using the + /// variant at the given variant index. Note that this function makes assumptions that the + /// representation of an enum is: + /// + /// ( + /// tag_value, + /// (field0_0, .. field0_N), // fields of variant 0, + /// (field1_0, .. field1_N), // fields of variant 1, + /// .., + /// (fieldM_0, .. 
fieldM_N), // fields of variant N, + /// ) + fn bind_case_arguments(&mut self, enum_value: Values, case: &MatchCase) { + if !case.arguments.is_empty() { + if case.constructor.is_enum() { + self.bind_enum_case_arguments(enum_value, case); + } else if case.constructor.is_tuple_or_struct() { + self.bind_tuple_or_struct_case_arguments(enum_value, case); + } + } + } + + fn bind_enum_case_arguments(&mut self, enum_value: Values, case: &MatchCase) { + let Tree::Branch(mut variants) = enum_value else { + unreachable!("Expected enum value to contain each variant"); + }; + + let variant_index = case.constructor.variant_index(); + + // variant_index + 1 to account for the extra tag value + let Tree::Branch(variant) = variants.swap_remove(variant_index + 1) else { + unreachable!("Expected enum variant to contain a tag and each variant's arguments"); + }; + + assert_eq!( + variant.len(), + case.arguments.len(), + "Expected enum variant to contain a value for each variant argument" + ); + + for (value, arg) in variant.into_iter().zip(&case.arguments) { + self.define(*arg, value); + } + } + + fn bind_tuple_or_struct_case_arguments(&mut self, struct_value: Values, case: &MatchCase) { + let Tree::Branch(fields) = struct_value else { + unreachable!("Expected struct value to contain each field"); + }; + + assert_eq!( + fields.len(), + case.arguments.len(), + "Expected field length to match constructor argument count" + ); + + for (value, arg) in fields.into_iter().zip(&case.arguments) { + self.define(*arg, value); + } + } + fn codegen_tuple(&mut self, tuple: &[Expression]) -> Result { Ok(Tree::Branch(try_vecmap(tuple, |expr| self.codegen_expression(expr))?)) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs index d7eeb64cc1be..145cff0a3415 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs @@ -46,6 +46,7 @@ pub 
enum StatementKind { Assign(AssignStatement), For(ForLoopStatement), Loop(Expression, Span /* loop keyword span */), + While(WhileStatement), Break, Continue, /// This statement should be executed at compile-time @@ -103,11 +104,8 @@ impl StatementKind { statement.add_semicolon(semi, span, last_statement_in_block, emit_error); StatementKind::Comptime(statement) } - // A semicolon on a for loop is optional and does nothing - StatementKind::For(_) => self, - - // A semicolon on a loop is optional and does nothing - StatementKind::Loop(..) => self, + // A semicolon on a for loop, loop or while is optional and does nothing + StatementKind::For(_) | StatementKind::Loop(..) | StatementKind::While(..) => self, // No semicolon needed for a resolved statement StatementKind::Interned(_) => self, @@ -119,6 +117,7 @@ impl StatementKind { | (ExpressionKind::Unsafe(..), semi, _) | (ExpressionKind::Interned(..), semi, _) | (ExpressionKind::InternedStatement(..), semi, _) + | (ExpressionKind::Match(..), semi, _) | (ExpressionKind::If(_), semi, _) => { if semi.is_some() { StatementKind::Semi(expr) @@ -856,6 +855,13 @@ pub struct ForLoopStatement { pub span: Span, } +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct WhileStatement { + pub condition: Expression, + pub body: Expression, + pub while_keyword_span: Span, +} + impl Display for StatementKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -864,6 +870,9 @@ impl Display for StatementKind { StatementKind::Assign(assign) => assign.fmt(f), StatementKind::For(for_loop) => for_loop.fmt(f), StatementKind::Loop(block, _) => write!(f, "loop {}", block), + StatementKind::While(while_) => { + write!(f, "while {} {}", while_.condition, while_.body) + } StatementKind::Break => write!(f, "break"), StatementKind::Continue => write!(f, "continue"), StatementKind::Comptime(statement) => write!(f, "comptime {}", statement.kind), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/traits.rs 
b/noir/noir-repo/compiler/noirc_frontend/src/ast/traits.rs index 475e3ff1be9a..e5040463d174 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/traits.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/traits.rs @@ -69,9 +69,7 @@ pub struct TypeImpl { pub struct NoirTraitImpl { pub impl_generics: UnresolvedGenerics, - pub trait_name: Path, - - pub trait_generics: GenericTypeArgs, + pub r#trait: UnresolvedType, pub object_type: UnresolvedType, @@ -247,7 +245,7 @@ impl Display for NoirTraitImpl { )?; } - write!(f, " {}{} for {}", self.trait_name, self.trait_generics, self.object_type)?; + write!(f, " {} for {}", self.r#trait, self.object_type)?; if !self.where_clause.is_empty() { write!( f, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs index e40c534c3b96..7fe89489c3cf 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs @@ -162,25 +162,25 @@ pub trait Visitor { true } - fn visit_literal_array(&mut self, _: &ArrayLiteral) -> bool { + fn visit_literal_array(&mut self, _: &ArrayLiteral, _: Span) -> bool { true } - fn visit_literal_slice(&mut self, _: &ArrayLiteral) -> bool { + fn visit_literal_slice(&mut self, _: &ArrayLiteral, _: Span) -> bool { true } - fn visit_literal_bool(&mut self, _: bool) {} + fn visit_literal_bool(&mut self, _: bool, _: Span) {} - fn visit_literal_integer(&mut self, _value: FieldElement, _negative: bool) {} + fn visit_literal_integer(&mut self, _value: FieldElement, _negative: bool, _: Span) {} - fn visit_literal_str(&mut self, _: &str) {} + fn visit_literal_str(&mut self, _: &str, _: Span) {} - fn visit_literal_raw_str(&mut self, _: &str, _: u8) {} + fn visit_literal_raw_str(&mut self, _: &str, _: u8, _: Span) {} - fn visit_literal_fmt_str(&mut self, _: &[FmtStrFragment], _length: u32) {} + fn visit_literal_fmt_str(&mut self, _: &[FmtStrFragment], _length: u32, _: 
Span) {} - fn visit_literal_unit(&mut self) {} + fn visit_literal_unit(&mut self, _: Span) {} fn visit_block_expression(&mut self, _: &BlockExpression, _: Option) -> bool { true @@ -262,11 +262,11 @@ pub trait Visitor { true } - fn visit_array_literal(&mut self, _: &ArrayLiteral) -> bool { + fn visit_array_literal(&mut self, _: &ArrayLiteral, _: Span) -> bool { true } - fn visit_array_literal_standard(&mut self, _: &[Expression]) -> bool { + fn visit_array_literal_standard(&mut self, _: &[Expression], _: Span) -> bool { true } @@ -274,6 +274,7 @@ pub trait Visitor { &mut self, _repeated_element: &Expression, _length: &Expression, + _: Span, ) -> bool { true } @@ -310,6 +311,10 @@ pub trait Visitor { true } + fn visit_while_statement(&mut self, _condition: &Expression, _body: &Expression) -> bool { + true + } + fn visit_comptime_statement(&mut self, _: &Statement) -> bool { true } @@ -598,7 +603,7 @@ impl NoirTraitImpl { } pub fn accept_children(&self, visitor: &mut impl Visitor) { - self.trait_name.accept(visitor); + self.r#trait.accept(visitor); self.object_type.accept(visitor); for item in &self.items { @@ -923,28 +928,32 @@ impl Expression { impl Literal { pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { if visitor.visit_literal(self, span) { - self.accept_children(visitor); + self.accept_children(span, visitor); } } - pub fn accept_children(&self, visitor: &mut impl Visitor) { + pub fn accept_children(&self, span: Span, visitor: &mut impl Visitor) { match self { Literal::Array(array_literal) => { - if visitor.visit_literal_array(array_literal) { - array_literal.accept(visitor); + if visitor.visit_literal_array(array_literal, span) { + array_literal.accept(span, visitor); } } Literal::Slice(array_literal) => { - if visitor.visit_literal_slice(array_literal) { - array_literal.accept(visitor); + if visitor.visit_literal_slice(array_literal, span) { + array_literal.accept(span, visitor); } } - Literal::Bool(value) => 
visitor.visit_literal_bool(*value), - Literal::Integer(value, negative) => visitor.visit_literal_integer(*value, *negative), - Literal::Str(str) => visitor.visit_literal_str(str), - Literal::RawStr(str, length) => visitor.visit_literal_raw_str(str, *length), - Literal::FmtStr(fragments, length) => visitor.visit_literal_fmt_str(fragments, *length), - Literal::Unit => visitor.visit_literal_unit(), + Literal::Bool(value) => visitor.visit_literal_bool(*value, span), + Literal::Integer(value, negative) => { + visitor.visit_literal_integer(*value, *negative, span); + } + Literal::Str(str) => visitor.visit_literal_str(str, span), + Literal::RawStr(str, length) => visitor.visit_literal_raw_str(str, *length, span), + Literal::FmtStr(fragments, length) => { + visitor.visit_literal_fmt_str(fragments, *length, span); + } + Literal::Unit => visitor.visit_literal_unit(span), } } } @@ -1116,21 +1125,21 @@ impl Lambda { } impl ArrayLiteral { - pub fn accept(&self, visitor: &mut impl Visitor) { - if visitor.visit_array_literal(self) { - self.accept_children(visitor); + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_array_literal(self, span) { + self.accept_children(span, visitor); } } - pub fn accept_children(&self, visitor: &mut impl Visitor) { + pub fn accept_children(&self, span: Span, visitor: &mut impl Visitor) { match self { ArrayLiteral::Standard(expressions) => { - if visitor.visit_array_literal_standard(expressions) { + if visitor.visit_array_literal_standard(expressions, span) { visit_expressions(expressions, visitor); } } ArrayLiteral::Repeated { repeated_element, length } => { - if visitor.visit_array_literal_repeated(repeated_element, length) { + if visitor.visit_array_literal_repeated(repeated_element, length, span) { repeated_element.accept(visitor); length.accept(visitor); } @@ -1165,6 +1174,12 @@ impl Statement { block.accept(visitor); } } + StatementKind::While(while_) => { + if visitor.visit_while_statement(&while_.condition, 
&while_.body) { + while_.condition.accept(visitor); + while_.body.accept(visitor); + } + } StatementKind::Comptime(statement) => { if visitor.visit_comptime_statement(statement) { statement.accept(visitor); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/comptime.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/comptime.rs index c13c74f44cb3..b78787249cb7 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -397,8 +397,7 @@ impl<'context> Elaborator<'context> { generated_items.trait_impls.push(UnresolvedTraitImpl { file_id: self.file, module_id: self.local_module, - trait_generics: trait_impl.trait_generics, - trait_path: trait_impl.trait_name, + r#trait: trait_impl.r#trait, object_type: trait_impl.object_type, methods, generics: trait_impl.impl_generics, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/enums.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/enums.rs index 5153845a57c1..ffc2f0221209 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/enums.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/enums.rs @@ -1,16 +1,25 @@ -use iter_extended::vecmap; -use noirc_errors::Location; +use fxhash::FxHashMap as HashMap; +use iter_extended::{try_vecmap, vecmap}; +use noirc_errors::{Location, Span}; use crate::{ - ast::{EnumVariant, FunctionKind, NoirEnumeration, UnresolvedType, Visibility}, + ast::{ + EnumVariant, Expression, ExpressionKind, FunctionKind, Literal, NoirEnumeration, + StatementKind, UnresolvedType, Visibility, + }, + elaborator::path_resolution::PathResolutionItem, + hir::{comptime::Value, resolution::errors::ResolverError, type_check::TypeCheckError}, hir_def::{ - expr::{HirEnumConstructorExpression, HirExpression, HirIdent}, + expr::{ + Case, Constructor, HirBlockExpression, HirEnumConstructorExpression, HirExpression, + HirIdent, HirMatch, SignedField, + }, 
function::{FuncMeta, FunctionBody, HirFunction, Parameters}, - stmt::HirPattern, + stmt::{HirLetStatement, HirPattern, HirStatement}, }, - node_interner::{DefinitionKind, ExprId, FunctionModifiers, GlobalValue, TypeId}, + node_interner::{DefinitionId, DefinitionKind, ExprId, FunctionModifiers, GlobalValue, TypeId}, token::Attributes, - DataType, Shared, Type, + DataType, Kind, Shared, Type, }; use super::Elaborator; @@ -243,4 +252,676 @@ impl Elaborator<'_> { (pattern, parameter_type, Visibility::Private) })) } + + /// To elaborate the rules of a match we need to go through the pattern first to define all + /// the variables within, then compile the corresponding branch. For each branch we do this + /// way we'll need to keep a distinct scope so that branches cannot access the pattern + /// variables from other branches. + /// + /// Returns (rows, result type) where rows is a pattern matrix used to compile the + /// match into a decision tree. + pub(super) fn elaborate_match_rules( + &mut self, + variable_to_match: DefinitionId, + rules: Vec<(Expression, Expression)>, + ) -> (Vec, Type) { + let result_type = self.interner.next_type_variable(); + let expected_pattern_type = self.interner.definition_type(variable_to_match); + + let rows = vecmap(rules, |(pattern, branch)| { + self.push_scope(); + let pattern = self.expression_to_pattern(pattern, &expected_pattern_type); + let columns = vec![Column::new(variable_to_match, pattern)]; + + let guard = None; + let body_span = branch.span; + let (body, body_type) = self.elaborate_expression(branch); + + self.unify(&body_type, &result_type, || TypeCheckError::TypeMismatch { + expected_typ: result_type.to_string(), + expr_typ: body_type.to_string(), + expr_span: body_span, + }); + + self.pop_scope(); + Row::new(columns, guard, body) + }); + (rows, result_type) + } + + /// Convert an expression into a Pattern, defining any variables within. 
+ fn expression_to_pattern(&mut self, expression: Expression, expected_type: &Type) -> Pattern { + let expr_span = expression.span; + let unify_with_expected_type = |this: &mut Self, actual| { + this.unify(actual, expected_type, || TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: actual.to_string(), + expr_span, + }); + }; + + match expression.kind { + ExpressionKind::Literal(Literal::Integer(value, negative)) => { + let actual = self.interner.next_type_variable_with_kind(Kind::IntegerOrField); + unify_with_expected_type(self, &actual); + Pattern::Int(SignedField::new(value, negative)) + } + ExpressionKind::Literal(Literal::Bool(value)) => { + unify_with_expected_type(self, &Type::Bool); + let constructor = if value { Constructor::True } else { Constructor::False }; + Pattern::Constructor(constructor, Vec::new()) + } + ExpressionKind::Variable(path) => { + // A variable can be free or bound if it refers to an enum constant: + // - in `(a, b)`, both variables may be free and should be defined, or + // may refer to an enum variant named `a` or `b` in scope. + // - Possible diagnostics improvement: warn if `a` is defined as a variable + // when there is a matching enum variant with name `Foo::a` which can + // be imported. The user likely intended to reference the enum variant. + let path_len = path.segments.len(); + let location = Location::new(path.span(), self.file); + let last_ident = path.last_ident(); + + match self.resolve_path_or_error(path) { + Ok(resolution) => self.path_resolution_to_constructor( + resolution, + Vec::new(), + expected_type, + location.span, + ), + Err(_) if path_len == 1 => { + // Define the variable + let kind = DefinitionKind::Local(None); + // TODO: `allow_shadowing` is false while I'm too lazy to add a check that we + // don't define the same name multiple times in one pattern. 
+ let id = self.add_variable_decl(last_ident, false, false, true, kind).id; + self.interner.push_definition_type(id, expected_type.clone()); + Pattern::Binding(id) + } + Err(error) => { + self.push_err(error); + // Default to defining a variable of the same name although this could + // cause further match warnings/errors (e.g. redundant cases). + let id = self.fresh_match_variable(expected_type.clone(), location); + Pattern::Binding(id) + } + } + } + ExpressionKind::Call(call) => { + self.expression_to_constructor(*call.func, call.arguments, expected_type) + } + ExpressionKind::Constructor(_) => todo!("handle constructors"), + ExpressionKind::Tuple(fields) => { + let field_types = vecmap(0..fields.len(), |_| self.interner.next_type_variable()); + let actual = Type::Tuple(field_types.clone()); + unify_with_expected_type(self, &actual); + + let fields = vecmap(fields.into_iter().enumerate(), |(i, field)| { + let expected = field_types.get(i).unwrap_or(&Type::Error); + self.expression_to_pattern(field, expected) + }); + + Pattern::Constructor(Constructor::Tuple(field_types.clone()), fields) + } + + ExpressionKind::Parenthesized(expr) => self.expression_to_pattern(*expr, expected_type), + ExpressionKind::Interned(id) => { + let kind = self.interner.get_expression_kind(id); + let expr = Expression::new(kind.clone(), expression.span); + self.expression_to_pattern(expr, expected_type) + } + ExpressionKind::InternedStatement(id) => { + if let StatementKind::Expression(expr) = self.interner.get_statement_kind(id) { + self.expression_to_pattern(expr.clone(), expected_type) + } else { + panic!("Invalid expr kind {expression}") + } + } + + ExpressionKind::Literal(_) + | ExpressionKind::Block(_) + | ExpressionKind::Prefix(_) + | ExpressionKind::Index(_) + | ExpressionKind::MethodCall(_) + | ExpressionKind::MemberAccess(_) + | ExpressionKind::Cast(_) + | ExpressionKind::Infix(_) + | ExpressionKind::If(_) + | ExpressionKind::Match(_) + | ExpressionKind::Constrain(_) + | 
ExpressionKind::Lambda(_) + | ExpressionKind::Quote(_) + | ExpressionKind::Unquote(_) + | ExpressionKind::Comptime(_, _) + | ExpressionKind::Unsafe(_, _) + | ExpressionKind::AsTraitPath(_) + | ExpressionKind::TypePath(_) + | ExpressionKind::Resolved(_) + | ExpressionKind::Error => { + panic!("Invalid expr kind {expression}") + } + } + } + + fn expression_to_constructor( + &mut self, + name: Expression, + args: Vec, + expected_type: &Type, + ) -> Pattern { + match name.kind { + ExpressionKind::Variable(path) => { + let span = path.span(); + let location = Location::new(span, self.file); + + match self.resolve_path_or_error(path) { + Ok(resolution) => { + self.path_resolution_to_constructor(resolution, args, expected_type, span) + } + Err(error) => { + self.push_err(error); + let id = self.fresh_match_variable(expected_type.clone(), location); + Pattern::Binding(id) + } + } + } + ExpressionKind::Parenthesized(expr) => { + self.expression_to_constructor(*expr, args, expected_type) + } + ExpressionKind::Interned(id) => { + let kind = self.interner.get_expression_kind(id); + let expr = Expression::new(kind.clone(), name.span); + self.expression_to_constructor(expr, args, expected_type) + } + ExpressionKind::InternedStatement(id) => { + if let StatementKind::Expression(expr) = self.interner.get_statement_kind(id) { + self.expression_to_constructor(expr.clone(), args, expected_type) + } else { + panic!("Invalid expr kind {name}") + } + } + other => todo!("invalid constructor `{other}`"), + } + } + + fn path_resolution_to_constructor( + &mut self, + name: PathResolutionItem, + args: Vec, + expected_type: &Type, + span: Span, + ) -> Pattern { + let (actual_type, expected_arg_types, variant_index) = match name { + PathResolutionItem::Global(id) => { + // variant constant + let global = self.interner.get_global(id); + let variant_index = match global.value { + GlobalValue::Resolved(Value::Enum(tag, ..)) => tag, + _ => todo!("Value is not an enum constant"), + }; + + let 
global_type = self.interner.definition_type(global.definition_id); + let actual_type = global_type.instantiate(self.interner).0; + (actual_type, Vec::new(), variant_index) + } + PathResolutionItem::Method(_type_id, _type_turbofish, func_id) => { + // TODO(#7430): Take type_turbofish into account when instantiating the function's type + let meta = self.interner.function_meta(&func_id); + let Some(variant_index) = meta.enum_variant_index else { todo!("not a variant") }; + + let (actual_type, expected_arg_types) = match meta.typ.instantiate(self.interner).0 + { + Type::Function(args, ret, _env, _) => (*ret, args), + other => unreachable!("Not a function! Found {other}"), + }; + + (actual_type, expected_arg_types, variant_index) + } + PathResolutionItem::Module(_) => todo!("path_resolution_to_constructor {name:?}"), + PathResolutionItem::Type(_) => todo!("path_resolution_to_constructor {name:?}"), + PathResolutionItem::TypeAlias(_) => todo!("path_resolution_to_constructor {name:?}"), + PathResolutionItem::Trait(_) => todo!("path_resolution_to_constructor {name:?}"), + PathResolutionItem::ModuleFunction(_) => { + todo!("path_resolution_to_constructor {name:?}") + } + PathResolutionItem::TypeAliasFunction(_, _, _) => { + todo!("path_resolution_to_constructor {name:?}") + } + PathResolutionItem::TraitFunction(_, _, _) => { + todo!("path_resolution_to_constructor {name:?}") + } + }; + + // We must unify the actual type before `expected_arg_types` are used since those + // are instantiated and rely on this already being unified. + self.unify(&actual_type, expected_type, || TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: actual_type.to_string(), + expr_span: span, + }); + + if args.len() != expected_arg_types.len() { + // error expected N args, found M? 
+ } + + let args = args.into_iter().zip(expected_arg_types); + let args = vecmap(args, |(arg, expected_arg_type)| { + self.expression_to_pattern(arg, &expected_arg_type) + }); + let constructor = Constructor::Variant(actual_type, variant_index); + Pattern::Constructor(constructor, args) + } + + /// Compiles the rows of a match expression, outputting a decision tree for the match. + /// + /// This is an adaptation of https://github.com/yorickpeterse/pattern-matching-in-rust/tree/main/jacobs2021 + /// which is an implementation of https://julesjacobs.com/notes/patternmatching/patternmatching.pdf + pub(super) fn elaborate_match_rows(&mut self, rows: Vec) -> HirMatch { + self.compile_rows(rows).unwrap_or_else(|error| { + self.push_err(error); + HirMatch::Failure + }) + } + + fn compile_rows(&mut self, mut rows: Vec) -> Result { + if rows.is_empty() { + eprintln!("Warning: missing case"); + return Ok(HirMatch::Failure); + } + + self.push_tests_against_bare_variables(&mut rows); + + // If the first row is a match-all we match it and the remaining rows are ignored. 
+ if rows.first().map_or(false, |row| row.columns.is_empty()) { + let row = rows.remove(0); + + return Ok(match row.guard { + None => HirMatch::Success(row.body), + Some(cond) => { + let remaining = self.compile_rows(rows)?; + HirMatch::Guard { cond, body: row.body, otherwise: Box::new(remaining) } + } + }); + } + + let branch_var = self.branch_variable(&rows); + let location = self.interner.definition(branch_var).location; + + match self.interner.definition_type(branch_var).follow_bindings_shallow().into_owned() { + Type::FieldElement | Type::Integer(_, _) => { + let (cases, fallback) = self.compile_int_cases(rows, branch_var)?; + Ok(HirMatch::Switch(branch_var, cases, Some(fallback))) + } + Type::TypeVariable(typevar) if typevar.is_integer_or_field() => { + let (cases, fallback) = self.compile_int_cases(rows, branch_var)?; + Ok(HirMatch::Switch(branch_var, cases, Some(fallback))) + } + + Type::Array(_, _) => todo!(), + Type::Slice(_) => todo!(), + Type::Bool => { + let cases = vec![ + (Constructor::False, Vec::new(), Vec::new()), + (Constructor::True, Vec::new(), Vec::new()), + ]; + + let (cases, fallback) = self.compile_constructor_cases(rows, branch_var, cases)?; + Ok(HirMatch::Switch(branch_var, cases, fallback)) + } + Type::Unit => { + let cases = vec![(Constructor::Unit, Vec::new(), Vec::new())]; + let (cases, fallback) = self.compile_constructor_cases(rows, branch_var, cases)?; + Ok(HirMatch::Switch(branch_var, cases, fallback)) + } + Type::Tuple(fields) => { + let field_variables = self.fresh_match_variables(fields.clone(), location); + let cases = vec![(Constructor::Tuple(fields), field_variables, Vec::new())]; + let (cases, fallback) = self.compile_constructor_cases(rows, branch_var, cases)?; + Ok(HirMatch::Switch(branch_var, cases, fallback)) + } + Type::DataType(type_def, generics) => { + let def = type_def.borrow(); + if let Some(variants) = def.get_variants(&generics) { + drop(def); + let typ = Type::DataType(type_def, generics); + + let cases = 
vecmap(variants.iter().enumerate(), |(idx, (_name, args))| { + let constructor = Constructor::Variant(typ.clone(), idx); + let args = self.fresh_match_variables(args.clone(), location); + (constructor, args, Vec::new()) + }); + + let (cases, fallback) = + self.compile_constructor_cases(rows, branch_var, cases)?; + Ok(HirMatch::Switch(branch_var, cases, fallback)) + } else if let Some(fields) = def.get_fields(&generics) { + drop(def); + let typ = Type::DataType(type_def, generics); + + // Just treat structs as a single-variant type + let fields = vecmap(fields, |(_name, typ)| typ); + let constructor = Constructor::Variant(typ, 0); + let field_variables = self.fresh_match_variables(fields, location); + let cases = vec![(constructor, field_variables, Vec::new())]; + let (cases, fallback) = + self.compile_constructor_cases(rows, branch_var, cases)?; + Ok(HirMatch::Switch(branch_var, cases, fallback)) + } else { + drop(def); + let typ = Type::DataType(type_def, generics); + todo!("Cannot match on type {typ}") + } + } + typ @ (Type::Alias(_, _) + | Type::TypeVariable(_) + | Type::String(_) + | Type::FmtString(_, _) + | Type::TraitAsType(_, _, _) + | Type::NamedGeneric(_, _) + | Type::CheckedCast { .. 
} + | Type::Function(_, _, _, _) + | Type::MutableReference(_) + | Type::Forall(_, _) + | Type::Constant(_, _) + | Type::Quoted(_) + | Type::InfixExpr(_, _, _, _) + | Type::Error) => todo!("Cannot match on type {typ:?}"), + } + } + + fn fresh_match_variables( + &mut self, + variable_types: Vec, + location: Location, + ) -> Vec { + vecmap(variable_types, |typ| self.fresh_match_variable(typ, location)) + } + + fn fresh_match_variable(&mut self, variable_type: Type, location: Location) -> DefinitionId { + let name = "internal_match_variable".to_string(); + let kind = DefinitionKind::Local(None); + let id = self.interner.push_definition(name, false, false, kind, location); + self.interner.push_definition_type(id, variable_type); + id + } + + /// Compiles the cases and fallback cases for integer and range patterns. + /// + /// Integers have an infinite number of constructors, so we specialise the + /// compilation of integer and range patterns. + fn compile_int_cases( + &mut self, + rows: Vec, + branch_var: DefinitionId, + ) -> Result<(Vec, Box), ResolverError> { + let mut raw_cases: Vec<(Constructor, Vec, Vec)> = Vec::new(); + let mut fallback_rows = Vec::new(); + let mut tested: HashMap<(SignedField, SignedField), usize> = HashMap::default(); + + for mut row in rows { + if let Some(col) = row.remove_column(branch_var) { + let (key, cons) = match col.pattern { + Pattern::Int(val) => ((val, val), Constructor::Int(val)), + Pattern::Range(start, stop) => ((start, stop), Constructor::Range(start, stop)), + pattern => { + eprintln!("Unexpected pattern for integer type: {pattern:?}"); + continue; + } + }; + + if let Some(index) = tested.get(&key) { + raw_cases[*index].2.push(row); + continue; + } + + tested.insert(key, raw_cases.len()); + + let mut rows = fallback_rows.clone(); + + rows.push(row); + raw_cases.push((cons, Vec::new(), rows)); + } else { + for (_, _, rows) in &mut raw_cases { + rows.push(row.clone()); + } + + fallback_rows.push(row); + } + } + + let cases = 
try_vecmap(raw_cases, |(cons, vars, rows)| { + let rows = self.compile_rows(rows)?; + Ok::<_, ResolverError>(Case::new(cons, vars, rows)) + })?; + + Ok((cases, Box::new(self.compile_rows(fallback_rows)?))) + } + + /// Compiles the cases and sub cases for the constructor located at the + /// column of the branching variable. + /// + /// What exactly this method does may be a bit hard to understand from the + /// code, as there's simply quite a bit going on. Roughly speaking, it does + /// the following: + /// + /// 1. It takes the column we're branching on (based on the branching + /// variable) and removes it from every row. + /// 2. We add additional columns to this row, if the constructor takes any + /// arguments (which we'll handle in a nested match). + /// 3. We turn the resulting list of rows into a list of cases, then compile + /// those into decision (sub) trees. + /// + /// If a row didn't include the branching variable, we simply copy that row + /// into the list of rows for every constructor to test. + /// + /// For this to work, the `cases` variable must be prepared such that it has + /// a triple for every constructor we need to handle. For an ADT with 10 + /// constructors, that means 10 triples. This is needed so this method can + /// assign the correct sub matches to these constructors. + /// + /// Types with infinite constructors (e.g. int and string) are handled + /// separately; they don't need most of this work anyway. 
+ fn compile_constructor_cases( + &mut self, + rows: Vec, + branch_var: DefinitionId, + mut cases: Vec<(Constructor, Vec, Vec)>, + ) -> Result<(Vec, Option>), ResolverError> { + for mut row in rows { + if let Some(col) = row.remove_column(branch_var) { + if let Pattern::Constructor(cons, args) = col.pattern { + let idx = cons.variant_index(); + let mut cols = row.columns; + + for (var, pat) in cases[idx].1.iter().zip(args.into_iter()) { + cols.push(Column::new(*var, pat)); + } + + cases[idx].2.push(Row::new(cols, row.guard, row.body)); + } + } else { + for (_, _, rows) in &mut cases { + rows.push(row.clone()); + } + } + } + + let cases = try_vecmap(cases, |(cons, vars, rows)| { + let rows = self.compile_rows(rows)?; + Ok::<_, ResolverError>(Case::new(cons, vars, rows)) + })?; + + Ok(Self::deduplicate_cases(cases)) + } + + /// Move any cases with duplicate branches into a shared 'else' branch + fn deduplicate_cases(mut cases: Vec) -> (Vec, Option>) { + let mut else_case = None; + let mut ending_cases = Vec::with_capacity(cases.len()); + let mut previous_case: Option = None; + + // Go through each of the cases, looking for duplicates. + // This is simplified such that the first (consecutive) duplicates + // we find we move to an else case. Each case afterward is then compared + // to the else case. This could be improved in a couple ways: + // - Instead of the the first consecutive duplicates we find, we could + // expand the check to find non-consecutive duplicates as well. + // - We should also ideally move the most duplicated case to the else + // case, not just the first duplicated case we find. I suspect in most + // actual code snippets these are the same but it could still be nice to guarantee. 
+ while let Some(case) = cases.pop() { + if let Some(else_case) = &else_case { + if case.body == *else_case { + // Delete the current case by not pushing it to `ending_cases` + continue; + } else { + ending_cases.push(case); + } + } else if let Some(previous) = previous_case { + if case.body == previous.body { + // else_case is known to be None here + else_case = Some(previous.body); + + // Delete both previous_case and case + previous_case = None; + continue; + } else { + previous_case = Some(case); + ending_cases.push(previous); + } + } else { + previous_case = Some(case); + } + } + + if let Some(case) = previous_case { + ending_cases.push(case); + } + + ending_cases.reverse(); + (ending_cases, else_case.map(Box::new)) + } + + /// Return the variable that was referred to the most in `rows` + fn branch_variable(&mut self, rows: &[Row]) -> DefinitionId { + let mut counts = HashMap::default(); + + for row in rows { + for col in &row.columns { + *counts.entry(&col.variable_to_match).or_insert(0_usize) += 1; + } + } + + rows[0] + .columns + .iter() + .map(|col| col.variable_to_match) + .max_by_key(|var| counts[var]) + .unwrap() + } + + fn push_tests_against_bare_variables(&mut self, rows: &mut Vec) { + for row in rows { + row.columns.retain(|col| { + if let Pattern::Binding(variable) = col.pattern { + row.body = self.let_binding(variable, col.variable_to_match, row.body); + false + } else { + true + } + }); + } + } + + /// Creates: + /// `{ let = ; }` + fn let_binding(&mut self, variable: DefinitionId, rhs: DefinitionId, body: ExprId) -> ExprId { + let location = self.interner.definition(rhs).location; + + let r#type = self.interner.definition_type(variable); + let rhs_type = self.interner.definition_type(rhs); + let variable = HirIdent::non_trait_method(variable, location); + + let rhs = HirExpression::Ident(HirIdent::non_trait_method(rhs, location), None); + let rhs = self.interner.push_expr(rhs); + self.interner.push_expr_type(rhs, rhs_type); + 
self.interner.push_expr_location(rhs, location.span, location.file); + + let let_ = HirStatement::Let(HirLetStatement { + pattern: HirPattern::Identifier(variable), + r#type, + expression: rhs, + attributes: Vec::new(), + comptime: false, + is_global_let: false, + }); + + let body_type = self.interner.id_type(body); + let let_ = self.interner.push_stmt(let_); + let body = self.interner.push_stmt(HirStatement::Expression(body)); + + self.interner.push_stmt_location(let_, location.span, location.file); + self.interner.push_stmt_location(body, location.span, location.file); + + let block = HirExpression::Block(HirBlockExpression { statements: vec![let_, body] }); + let block = self.interner.push_expr(block); + self.interner.push_expr_type(block, body_type); + self.interner.push_expr_location(block, location.span, location.file); + block + } +} + +/// A Pattern is anything that can appear before the `=>` in a match rule. +#[derive(Debug, Clone)] +enum Pattern { + /// A pattern checking for a tag and possibly binding variables such as `Some(42)` + Constructor(Constructor, Vec), + /// An integer literal pattern such as `4`, `12345`, or `-56` + Int(SignedField), + /// A pattern binding a variable such as `a` or `_` + Binding(DefinitionId), + + /// Multiple patterns combined with `|` where we should match this pattern if any + /// constituent pattern matches. e.g. `Some(3) | None` or `Some(1) | Some(2) | None` + #[allow(unused)] + Or(Vec), + + /// An integer range pattern such as `1..20` which will match any integer n such that + /// 1 <= n < 20. 
+ #[allow(unused)] + Range(SignedField, SignedField), +} + +#[derive(Clone)] +struct Column { + variable_to_match: DefinitionId, + pattern: Pattern, +} + +impl Column { + fn new(variable_to_match: DefinitionId, pattern: Pattern) -> Self { + Column { variable_to_match, pattern } + } +} + +#[derive(Clone)] +pub(super) struct Row { + columns: Vec, + guard: Option, + body: ExprId, +} + +impl Row { + fn new(columns: Vec, guard: Option, body: ExprId) -> Row { + Row { columns, guard, body } + } +} + +impl Row { + fn remove_column(&mut self, variable: DefinitionId) -> Option { + self.columns + .iter() + .position(|c| c.variable_to_match == variable) + .map(|idx| self.columns.remove(idx)) + } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs index 0ca1d99f4d2f..18d5e3be82ed 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -25,10 +25,12 @@ use crate::{ HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, HirMethodCallExpression, HirPrefixExpression, }, - stmt::HirStatement, + stmt::{HirLetStatement, HirPattern, HirStatement}, traits::{ResolvedTraitBound, TraitConstraint}, }, - node_interner::{DefinitionKind, ExprId, FuncId, InternedStatementKind, TraitMethodId}, + node_interner::{ + DefinitionId, DefinitionKind, ExprId, FuncId, InternedStatementKind, StmtId, TraitMethodId, + }, token::{FmtStrFragment, Tokens}, DataType, Kind, QuotedType, Shared, Type, }; @@ -60,7 +62,7 @@ impl<'context> Elaborator<'context> { ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span), ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span), ExpressionKind::If(if_) => self.elaborate_if(*if_, target_type), - ExpressionKind::Match(match_) => self.elaborate_match(*match_), + ExpressionKind::Match(match_) => 
self.elaborate_match(*match_, expr.span), ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple, target_type), ExpressionKind::Lambda(lambda) => { @@ -1015,8 +1017,39 @@ impl<'context> Elaborator<'context> { (HirExpression::If(if_expr), ret_type) } - fn elaborate_match(&mut self, _match_expr: MatchExpression) -> (HirExpression, Type) { - (HirExpression::Error, Type::Error) + fn elaborate_match( + &mut self, + match_expr: MatchExpression, + span: Span, + ) -> (HirExpression, Type) { + let (expression, typ) = self.elaborate_expression(match_expr.expression); + let (let_, variable) = self.wrap_in_let(expression, typ); + + let (rows, result_type) = self.elaborate_match_rules(variable, match_expr.rules); + let tree = HirExpression::Match(self.elaborate_match_rows(rows)); + let tree = self.interner.push_expr(tree); + self.interner.push_expr_type(tree, result_type.clone()); + self.interner.push_expr_location(tree, span, self.file); + + let tree = self.interner.push_stmt(HirStatement::Expression(tree)); + self.interner.push_stmt_location(tree, span, self.file); + + let block = HirExpression::Block(HirBlockExpression { statements: vec![let_, tree] }); + (block, result_type) + } + + fn wrap_in_let(&mut self, expr_id: ExprId, typ: Type) -> (StmtId, DefinitionId) { + let location = self.interner.expr_location(&expr_id); + let name = "internal variable".to_string(); + let definition = DefinitionKind::Local(None); + let variable = self.interner.push_definition(name, false, false, definition, location); + self.interner.push_definition_type(variable, typ.clone()); + + let pattern = HirPattern::Identifier(HirIdent::non_trait_method(variable, location)); + let let_ = HirStatement::Let(HirLetStatement::basic(pattern, typ, expr_id)); + let let_ = self.interner.push_stmt(let_); + self.interner.push_stmt_location(let_, location.span, location.file); + (let_, variable) } fn elaborate_tuple( diff --git 
a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/lints.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/lints.rs index 7910d8cebdb2..3a5844ab21d2 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/lints.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/lints.rs @@ -6,7 +6,7 @@ use crate::{ type_check::TypeCheckError, }, hir_def::{ - expr::{HirBlockExpression, HirExpression, HirIdent, HirLiteral}, + expr::{HirBlockExpression, HirExpression, HirIdent, HirLiteral, HirMatch}, function::FuncMeta, stmt::HirStatement, }, @@ -283,6 +283,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i // Rust doesn't seem to check the for loop body (it's bounds might mean it's never called). HirStatement::For(e) => check(e.start_range) && check(e.end_range), HirStatement::Loop(e) => check(e), + HirStatement::While(condition, block) => check(condition) && check(block), HirStatement::Comptime(_) | HirStatement::Break | HirStatement::Continue @@ -314,6 +315,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i HirExpression::If(e) => { check(e.condition) && (check(e.consequence) || e.alternative.map(check).unwrap_or(true)) } + HirExpression::Match(e) => can_return_without_recursing_match(interner, func_id, &e), HirExpression::Tuple(e) => e.iter().cloned().all(check), HirExpression::Unsafe(b) => check_block(b), // Rust doesn't check the lambda body (it might not be called). 
@@ -327,3 +329,22 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i | HirExpression::Error => true, } } + +fn can_return_without_recursing_match( + interner: &NodeInterner, + func_id: FuncId, + match_expr: &HirMatch, +) -> bool { + let check_match = |e| can_return_without_recursing_match(interner, func_id, e); + let check = |e| can_return_without_recursing(interner, func_id, e); + + match match_expr { + HirMatch::Success(expr) => check(*expr), + HirMatch::Failure => true, + HirMatch::Guard { cond: _, body, otherwise } => check(*body) && check_match(otherwise), + HirMatch::Switch(_, cases, otherwise) => { + cases.iter().all(|case| check_match(&case.body)) + && otherwise.as_ref().map_or(true, |case| check_match(case)) + } + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs index a8e722a92056..d007bee0d8d3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs @@ -11,14 +11,15 @@ use crate::{ }, graph::CrateId, hir::{ - def_collector::dc_crate::{ - filter_literal_globals, CollectedItems, CompilationError, ImplMap, UnresolvedEnum, - UnresolvedFunctions, UnresolvedGlobal, UnresolvedStruct, UnresolvedTraitImpl, - UnresolvedTypeAlias, + def_collector::{ + dc_crate::{ + filter_literal_globals, CollectedItems, CompilationError, ImplMap, UnresolvedEnum, + UnresolvedFunctions, UnresolvedGlobal, UnresolvedStruct, UnresolvedTraitImpl, + UnresolvedTypeAlias, + }, + errors::DefCollectorErrorKind, }, - def_collector::errors::DefCollectorErrorKind, - def_map::{DefMaps, ModuleData}, - def_map::{LocalModuleId, ModuleId, MAIN_FUNCTION}, + def_map::{DefMaps, LocalModuleId, ModuleData, ModuleId, MAIN_FUNCTION}, resolution::errors::ResolverError, scope::ScopeForest as GenericScopeForest, type_check::{generics::TraitGenerics, TypeCheckError}, @@ -1458,9 +1459,26 @@ impl<'context> 
Elaborator<'context> { } let trait_generics = trait_impl.resolved_trait_generics.clone(); + let ident = match &trait_impl.r#trait.typ { + UnresolvedTypeData::Named(trait_path, _, _) => trait_path.last_ident(), + UnresolvedTypeData::Resolved(quoted_type_id) => { + let typ = self.interner.get_quoted_type(*quoted_type_id); + let name = if let Type::TraitAsType(_, name, _) = typ { + name.to_string() + } else { + typ.to_string() + }; + Ident::new(name, trait_impl.r#trait.span) + } + _ => { + // We don't error in this case because an error will be produced later on when + // solving the trait impl trait type + Ident::new(trait_impl.r#trait.to_string(), trait_impl.r#trait.span) + } + }; let resolved_trait_impl = Shared::new(TraitImpl { - ident: trait_impl.trait_path.last_ident(), + ident, typ: self_type.clone(), trait_id, trait_generics, @@ -1983,7 +2001,49 @@ impl<'context> Elaborator<'context> { self.file = trait_impl.file_id; self.local_module = trait_impl.module_id; - let trait_id = self.resolve_trait_by_path(trait_impl.trait_path.clone()); + let (trait_id, mut trait_generics, path_span) = match &trait_impl.r#trait.typ { + UnresolvedTypeData::Named(trait_path, trait_generics, _) => { + let trait_id = self.resolve_trait_by_path(trait_path.clone()); + (trait_id, trait_generics.clone(), trait_path.span) + } + UnresolvedTypeData::Resolved(quoted_type_id) => { + let typ = self.interner.get_quoted_type(*quoted_type_id); + let span = trait_impl.r#trait.span; + let Type::TraitAsType(trait_id, _, trait_generics) = typ else { + let found = typ.to_string(); + self.push_err(ResolverError::ExpectedTrait { span, found }); + continue; + }; + + // In order to take associated types into account we turn these resolved generics + // into unresolved ones, but ones that point to solved types. 
+ let trait_id = *trait_id; + let trait_generics = trait_generics.clone(); + let trait_generics = GenericTypeArgs { + ordered_args: vecmap(&trait_generics.ordered, |typ| { + let quoted_type_id = self.interner.push_quoted_type(typ.clone()); + let typ = UnresolvedTypeData::Resolved(quoted_type_id); + UnresolvedType { typ, span } + }), + named_args: vecmap(&trait_generics.named, |named_type| { + let quoted_type_id = + self.interner.push_quoted_type(named_type.typ.clone()); + let typ = UnresolvedTypeData::Resolved(quoted_type_id); + (named_type.name.clone(), UnresolvedType { typ, span }) + }), + kinds: Vec::new(), + }; + + (Some(trait_id), trait_generics, span) + } + _ => { + let span = trait_impl.r#trait.span; + let found = trait_impl.r#trait.typ.to_string(); + self.push_err(ResolverError::ExpectedTrait { span, found }); + continue; + } + }; + trait_impl.trait_id = trait_id; let unresolved_type = trait_impl.object_type.clone(); @@ -2006,14 +2066,12 @@ impl<'context> Elaborator<'context> { method.def.where_clause.append(&mut trait_impl.where_clause.clone()); } - // Add each associated type to the list of named type arguments - let mut trait_generics = trait_impl.trait_generics.clone(); - trait_generics.named_args.extend(self.take_unresolved_associated_types(trait_impl)); - let impl_id = self.interner.next_trait_impl_id(); self.current_trait_impl = Some(impl_id); - let path_span = trait_impl.trait_path.span; + // Add each associated type to the list of named type arguments + trait_generics.named_args.extend(self.take_unresolved_associated_types(trait_impl)); + let (ordered_generics, named_generics) = trait_impl .trait_id .map(|trait_id| { @@ -2038,12 +2096,15 @@ impl<'context> Elaborator<'context> { self.generics.clear(); if let Some(trait_id) = trait_id { - let trait_name = trait_impl.trait_path.last_ident(); - self.interner.add_trait_reference( - trait_id, - Location::new(trait_name.span(), trait_impl.file_id), - trait_name.is_self_type_name(), - ); + let (span, 
is_self_type_name) = match &trait_impl.r#trait.typ { + UnresolvedTypeData::Named(trait_path, _, _) => { + let trait_name = trait_path.last_ident(); + (trait_name.span(), trait_name.is_self_type_name()) + } + _ => (trait_impl.r#trait.span, false), + }; + let location = Location::new(span, trait_impl.file_id); + self.interner.add_trait_reference(trait_id, location, is_self_type_name); } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs index c401646332f6..8b60a660a161 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs @@ -3,7 +3,7 @@ use noirc_errors::{Location, Span}; use crate::{ ast::{ AssignStatement, Expression, ForLoopStatement, ForRange, Ident, ItemVisibility, LValue, - LetStatement, Path, Statement, StatementKind, + LetStatement, Path, Statement, StatementKind, WhileStatement, }, hir::{ resolution::{ @@ -37,6 +37,7 @@ impl<'context> Elaborator<'context> { StatementKind::Assign(assign) => self.elaborate_assign(assign), StatementKind::For(for_stmt) => self.elaborate_for(for_stmt), StatementKind::Loop(block, span) => self.elaborate_loop(block, span), + StatementKind::While(while_) => self.elaborate_while(while_), StatementKind::Break => self.elaborate_jump(true, statement.span), StatementKind::Continue => self.elaborate_jump(false, statement.span), StatementKind::Comptime(statement) => self.elaborate_comptime_statement(*statement), @@ -258,6 +259,35 @@ impl<'context> Elaborator<'context> { (statement, Type::Unit) } + pub(super) fn elaborate_while(&mut self, while_: WhileStatement) -> (HirStatement, Type) { + let in_constrained_function = self.in_constrained_function(); + if in_constrained_function { + self.push_err(ResolverError::WhileInConstrainedFn { span: while_.while_keyword_span }); + } + + let old_loop = std::mem::take(&mut self.current_loop); + 
self.current_loop = Some(Loop { is_for: false, has_break: false }); + self.push_scope(); + + let condition_span = while_.condition.span; + let (condition, cond_type) = self.elaborate_expression(while_.condition); + let (block, _block_type) = self.elaborate_expression(while_.body); + + self.unify(&cond_type, &Type::Bool, || TypeCheckError::TypeMismatch { + expected_typ: Type::Bool.to_string(), + expr_typ: cond_type.to_string(), + expr_span: condition_span, + }); + + self.pop_scope(); + + std::mem::replace(&mut self.current_loop, old_loop).expect("Expected a loop"); + + let statement = HirStatement::While(condition, block); + + (statement, Type::Unit) + } + fn elaborate_jump(&mut self, is_break: bool, span: noirc_errors::Span) -> (HirStatement, Type) { let in_constrained_function = self.in_constrained_function(); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs index 0ccfaf59494c..aa2f7d99fab9 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs @@ -10,7 +10,7 @@ use crate::{ ForBounds, ForLoopStatement, ForRange, GenericTypeArgs, IfExpression, IndexExpression, InfixExpression, LValue, Lambda, LetStatement, Literal, MatchExpression, MemberAccessExpression, MethodCallExpression, Pattern, PrefixExpression, Statement, - StatementKind, UnresolvedType, UnresolvedTypeData, + StatementKind, UnresolvedType, UnresolvedTypeData, WhileStatement, }, hir_def::traits::TraitConstraint, node_interner::{InternedStatementKind, NodeInterner}, @@ -766,6 +766,11 @@ fn remove_interned_in_statement_kind( StatementKind::Loop(block, span) => { StatementKind::Loop(remove_interned_in_expression(interner, block), span) } + StatementKind::While(while_) => StatementKind::While(WhileStatement { + condition: remove_interned_in_expression(interner, while_.condition), + body: 
remove_interned_in_expression(interner, while_.body), + while_keyword_span: while_.while_keyword_span, + }), StatementKind::Comptime(statement) => { StatementKind::Comptime(Box::new(remove_interned_in_statement(interner, *statement))) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/errors.rs index db16a579c103..64297a240621 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -50,6 +50,10 @@ pub enum InterpreterError { typ: Type, location: Location, }, + NonBoolUsedInWhile { + typ: Type, + location: Location, + }, NonBoolUsedInConstrain { typ: Type, location: Location, @@ -285,6 +289,7 @@ impl InterpreterError { | InterpreterError::ErrorNodeEncountered { location, .. } | InterpreterError::NonFunctionCalled { location, .. } | InterpreterError::NonBoolUsedInIf { location, .. } + | InterpreterError::NonBoolUsedInWhile { location, .. } | InterpreterError::NonBoolUsedInConstrain { location, .. } | InterpreterError::FailingConstraint { location, .. } | InterpreterError::NoMethodFound { location, .. 
} @@ -413,6 +418,11 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let secondary = "If conditions must be a boolean value".to_string(); CustomDiagnostic::simple_error(msg, secondary, location.span) } + InterpreterError::NonBoolUsedInWhile { typ, location } => { + let msg = format!("Expected a `bool` but found `{typ}`"); + let secondary = "While conditions must be a boolean value".to_string(); + CustomDiagnostic::simple_error(msg, secondary, location.span) + } InterpreterError::NonBoolUsedInConstrain { typ, location } => { let msg = format!("Expected a `bool` but found `{typ}`"); CustomDiagnostic::simple_error(msg, String::new(), location.span) @@ -648,7 +658,7 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { } InterpreterError::GenericNameShouldBeAnIdent { name, location } => { let msg = - "Generic name needs to be a valid identifer (one word beginning with a letter)" + "Generic name needs to be a valid identifier (one word beginning with a letter)" .to_string(); let secondary = format!("`{name}` is not a valid identifier"); CustomDiagnostic::simple_error(msg, secondary, location.span) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs index 3ba7ae429502..05a0435b4503 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs @@ -4,17 +4,17 @@ use noirc_errors::{Span, Spanned}; use crate::ast::{ ArrayLiteral, AssignStatement, BlockExpression, CallExpression, CastExpression, ConstrainKind, ConstructorExpression, ExpressionKind, ForLoopStatement, ForRange, GenericTypeArgs, Ident, - IfExpression, IndexExpression, InfixExpression, LValue, Lambda, Literal, + IfExpression, IndexExpression, InfixExpression, LValue, Lambda, Literal, MatchExpression, MemberAccessExpression, MethodCallExpression, Path, PathKind, 
PathSegment, Pattern, - PrefixExpression, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, + PrefixExpression, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, WhileStatement, }; use crate::ast::{ConstrainExpression, Expression, Statement, StatementKind}; use crate::hir_def::expr::{ - HirArrayLiteral, HirBlockExpression, HirExpression, HirIdent, HirLiteral, + Constructor, HirArrayLiteral, HirBlockExpression, HirExpression, HirIdent, HirLiteral, HirMatch, }; use crate::hir_def::stmt::{HirLValue, HirPattern, HirStatement}; use crate::hir_def::types::{Type, TypeBinding}; -use crate::node_interner::{ExprId, NodeInterner, StmtId}; +use crate::node_interner::{DefinitionId, ExprId, NodeInterner, StmtId}; // TODO: // - Full path for idents & types @@ -46,6 +46,11 @@ impl HirStatement { span, }), HirStatement::Loop(block) => StatementKind::Loop(block.to_display_ast(interner), span), + HirStatement::While(condition, block) => StatementKind::While(WhileStatement { + condition: condition.to_display_ast(interner), + body: block.to_display_ast(interner), + while_keyword_span: span, + }), HirStatement::Break => StatementKind::Break, HirStatement::Continue => StatementKind::Continue, HirStatement::Expression(expr) => { @@ -77,19 +82,7 @@ impl HirExpression { pub fn to_display_ast(&self, interner: &NodeInterner, span: Span) -> Expression { let kind = match self { HirExpression::Ident(ident, generics) => { - let ident = ident.to_display_ast(interner); - let segment = PathSegment { - ident, - generics: generics.as_ref().map(|option| { - option.iter().map(|generic| generic.to_display_ast()).collect() - }), - span, - }; - - let path = - Path { segments: vec![segment], kind: crate::ast::PathKind::Plain, span }; - - ExpressionKind::Variable(path) + ident.to_display_expr(interner, generics, span) } HirExpression::Literal(HirLiteral::Array(array)) => { let array = array.to_display_ast(interner, span); @@ -190,6 +183,7 @@ impl HirExpression { consequence: 
if_expr.consequence.to_display_ast(interner), alternative: if_expr.alternative.map(|expr| expr.to_display_ast(interner)), })), + HirExpression::Match(match_expr) => match_expr.to_display_ast(interner, span), HirExpression::Tuple(fields) => { ExpressionKind::Tuple(vecmap(fields, |field| field.to_display_ast(interner))) } @@ -231,6 +225,95 @@ impl HirExpression { } } +impl HirMatch { + fn to_display_ast(&self, interner: &NodeInterner, span: Span) -> ExpressionKind { + match self { + HirMatch::Success(expr) => expr.to_display_ast(interner).kind, + HirMatch::Failure => ExpressionKind::Error, + HirMatch::Guard { cond, body, otherwise } => { + let condition = cond.to_display_ast(interner); + let consequence = body.to_display_ast(interner); + let alternative = + Some(Expression::new(otherwise.to_display_ast(interner, span), span)); + + ExpressionKind::If(Box::new(IfExpression { condition, consequence, alternative })) + } + HirMatch::Switch(variable, cases, default) => { + let location = interner.definition(*variable).location; + let ident = HirIdent::non_trait_method(*variable, location); + let expression = ident.to_display_expr(interner, &None, location.span); + let expression = Expression::new(expression, location.span); + + let mut rules = vecmap(cases, |case| { + let args = vecmap(&case.arguments, |arg| arg.to_display_ast(interner)); + let constructor = case.constructor.to_display_ast(args); + let constructor = Expression::new(constructor, span); + let branch = case.body.to_display_ast(interner, span); + (constructor, Expression::new(branch, span)) + }); + + if let Some(case) = default { + let kind = ExpressionKind::Variable(Path::from_single("_".to_string(), span)); + let pattern = Expression::new(kind, span); + let branch = Expression::new(case.to_display_ast(interner, span), span); + rules.push((pattern, branch)); + } + + ExpressionKind::Match(Box::new(MatchExpression { expression, rules })) + } + } + } +} + +impl DefinitionId { + fn to_display_ast(self, interner: 
&NodeInterner) -> Expression { + let location = interner.definition(self).location; + let kind = HirIdent::non_trait_method(self, location).to_display_expr( + interner, + &None, + location.span, + ); + Expression::new(kind, location.span) + } +} + +impl Constructor { + fn to_display_ast(&self, arguments: Vec) -> ExpressionKind { + match self { + Constructor::True => ExpressionKind::Literal(Literal::Bool(true)), + Constructor::False => ExpressionKind::Literal(Literal::Bool(false)), + Constructor::Unit => ExpressionKind::Literal(Literal::Unit), + Constructor::Int(value) => { + ExpressionKind::Literal(Literal::Integer(value.field, value.is_negative)) + } + Constructor::Tuple(_) => ExpressionKind::Tuple(arguments), + Constructor::Variant(typ, index) => { + let typ = typ.follow_bindings_shallow(); + let Type::DataType(def, _) = typ.as_ref() else { + return ExpressionKind::Error; + }; + + let Some(variants) = def.borrow().get_variants_as_written() else { + return ExpressionKind::Error; + }; + + let Some(name) = variants.get(*index).map(|variant| variant.name.clone()) else { + return ExpressionKind::Error; + }; + + let span = name.span(); + let name = ExpressionKind::Variable(Path::from_ident(name)); + let func = Box::new(Expression::new(name, span)); + let is_macro_call = false; + ExpressionKind::Call(Box::new(CallExpression { func, arguments, is_macro_call })) + } + Constructor::Range(_start, _end) => { + unreachable!("Range is unimplemented") + } + } + } +} + impl ExprId { /// Convert to AST for display (some details lost) pub fn to_display_ast(self, interner: &NodeInterner) -> Expression { @@ -282,6 +365,26 @@ impl HirIdent { let name = interner.definition_name(self.id).to_owned(); Ident(Spanned::from(self.location.span, name)) } + + fn to_display_expr( + &self, + interner: &NodeInterner, + generics: &Option>, + span: Span, + ) -> ExpressionKind { + let ident = self.to_display_ast(interner); + let segment = PathSegment { + ident, + generics: generics + .as_ref() + 
.map(|option| option.iter().map(|generic| generic.to_display_ast()).collect()), + span, + }; + + let path = Path { segments: vec![segment], kind: crate::ast::PathKind::Plain, span }; + + ExpressionKind::Variable(path) + } } impl Type { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 1096835ae5e9..1679c069bd56 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -535,6 +535,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { HirExpression::Constrain(constrain) => self.evaluate_constrain(constrain), HirExpression::Cast(cast) => self.evaluate_cast(&cast, id), HirExpression::If(if_) => self.evaluate_if(if_, id), + HirExpression::Match(match_) => todo!("Evaluate match in comptime code"), HirExpression::Tuple(tuple) => self.evaluate_tuple(tuple), HirExpression::Lambda(lambda) => self.evaluate_lambda(lambda, id), HirExpression::Quote(tokens) => self.evaluate_quote(tokens, id), @@ -1516,7 +1517,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let condition = match self.evaluate(if_.condition)? 
{ Value::Bool(value) => value, value => { - let location = self.elaborator.interner.expr_location(&id); + let location = self.elaborator.interner.expr_location(&if_.condition); let typ = value.get_type().into_owned(); return Err(InterpreterError::NonBoolUsedInIf { typ, location }); } @@ -1571,6 +1572,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { HirStatement::Assign(assign) => self.evaluate_assign(assign), HirStatement::For(for_) => self.evaluate_for(for_), HirStatement::Loop(expression) => self.evaluate_loop(expression), + HirStatement::While(condition, block) => self.evaluate_while(condition, block), HirStatement::Break => self.evaluate_break(statement), HirStatement::Continue => self.evaluate_continue(statement), HirStatement::Expression(expression) => self.evaluate(expression), @@ -1808,6 +1810,55 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { result } + fn evaluate_while(&mut self, condition: ExprId, block: ExprId) -> IResult { + let was_in_loop = std::mem::replace(&mut self.in_loop, true); + let in_lsp = self.elaborator.interner.is_in_lsp_mode(); + let mut counter = 0; + let mut result = Ok(Value::Unit); + + loop { + let condition = match self.evaluate(condition)? 
{ + Value::Bool(value) => value, + value => { + let location = self.elaborator.interner.expr_location(&condition); + let typ = value.get_type().into_owned(); + return Err(InterpreterError::NonBoolUsedInWhile { typ, location }); + } + }; + if !condition { + break; + } + + self.push_scope(); + + let must_break = match self.evaluate(block) { + Ok(_) => false, + Err(InterpreterError::Break) => true, + Err(InterpreterError::Continue) => false, + Err(error) => { + result = Err(error); + true + } + }; + + self.pop_scope(); + + if must_break { + break; + } + + counter += 1; + if in_lsp && counter == 10_000 { + let location = self.elaborator.interner.expr_location(&block); + result = Err(InterpreterError::LoopHaltedForUiResponsiveness { location }); + break; + } + } + + self.in_loop = was_in_loop; + result + } + fn evaluate_break(&mut self, id: StmtId) -> IResult { if self.in_loop { Err(InterpreterError::Break) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs index f947cd761fcf..a6668eae1b0a 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -519,6 +519,11 @@ impl Value { Value::UnresolvedType(typ) => { Token::InternedUnresolvedTypeData(interner.push_unresolved_type_data(typ)) } + Value::TraitConstraint(trait_id, generics) => { + let name = Rc::new(interner.get_trait(trait_id).name.0.contents.clone()); + let typ = Type::TraitAsType(trait_id, name, generics); + Token::QuotedType(interner.push_quoted_type(typ)) + } Value::TypedExpr(TypedExpr::ExprId(expr_id)) => Token::UnquoteMarker(expr_id), Value::U1(bool) => Token::Bool(bool), Value::U8(value) => Token::Int((value as u128).into()), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 73c6c5a5dd22..9d8c32fbc12d 100644 
--- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -3,7 +3,7 @@ use super::errors::{DefCollectorErrorKind, DuplicateType}; use crate::elaborator::Elaborator; use crate::graph::CrateId; use crate::hir::comptime::InterpreterError; -use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId, ModuleId}; +use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleId}; use crate::hir::resolution::errors::ResolverError; use crate::hir::type_check::TypeCheckError; use crate::locations::ReferencesTracker; @@ -21,8 +21,8 @@ use crate::node_interner::{ }; use crate::ast::{ - ExpressionKind, GenericTypeArgs, Ident, ItemVisibility, LetStatement, Literal, NoirFunction, - NoirStruct, NoirTrait, NoirTypeAlias, Path, PathKind, PathSegment, UnresolvedGenerics, + ExpressionKind, Ident, ItemVisibility, LetStatement, Literal, NoirFunction, NoirStruct, + NoirTrait, NoirTypeAlias, Path, PathKind, PathSegment, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnsupportedNumericGenericType, }; @@ -83,8 +83,7 @@ pub struct UnresolvedTrait { pub struct UnresolvedTraitImpl { pub file_id: FileId, pub module_id: LocalModuleId, - pub trait_generics: GenericTypeArgs, - pub trait_path: Path, + pub r#trait: UnresolvedType, pub object_type: UnresolvedType, pub methods: UnresolvedFunctions, pub generics: UnresolvedGenerics, @@ -431,14 +430,12 @@ impl DefCollector { Some(defining_module), ); - if let ModuleDefId::TraitId(trait_id) = module_def_id { - context.def_interner.add_trait_reexport( - trait_id, - defining_module, - name.clone(), - visibility, - ); - } + context.def_interner.add_reexport( + module_def_id, + defining_module, + name.clone(), + visibility, + ); } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index f6f31638557c..59e1f2f6e329 100644 --- 
a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -198,8 +198,6 @@ impl<'a> ModCollector<'a> { let mut errors = Vec::new(); for mut trait_impl in impls { - let trait_name = trait_impl.trait_name.clone(); - let (mut unresolved_functions, associated_types, associated_constants) = collect_trait_impl_items( &mut context.def_interner, @@ -233,12 +231,11 @@ impl<'a> ModCollector<'a> { let unresolved_trait_impl = UnresolvedTraitImpl { file_id: self.file_id, module_id: self.module_id, - trait_path: trait_name, + r#trait: trait_impl.r#trait, methods: unresolved_functions, object_type: trait_impl.object_type, generics: trait_impl.impl_generics, where_clause: trait_impl.where_clause, - trait_generics: trait_impl.trait_generics, associated_constants, associated_types, @@ -961,7 +958,14 @@ fn push_child_module( ); if interner.is_in_lsp_mode() { - interner.register_module(mod_id, visibility, mod_name.0.contents.clone()); + let parent_module_id = ModuleId { krate: def_map.krate, local_id: parent }; + interner.register_module( + mod_id, + location, + visibility, + mod_name.0.contents.clone(), + parent_module_id, + ); } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/errors.rs index 68ac84d42c63..13c855c6fe7d 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -102,6 +102,8 @@ pub enum ResolverError { LoopInConstrainedFn { span: Span }, #[error("`loop` must have at least one `break` in it")] LoopWithoutBreak { span: Span }, + #[error("`while` is only allowed in unconstrained functions")] + WhileInConstrainedFn { span: Span }, #[error("break/continue are only allowed within loops")] JumpOutsideLoop { is_break: bool, span: Span }, #[error("Only `comptime` globals can be mutable")] @@ 
-178,6 +180,8 @@ pub enum ResolverError { }, #[error("`loop` statements are not yet implemented")] LoopNotYetSupported { span: Span }, + #[error("Expected a trait but found {found}")] + ExpectedTrait { found: String, span: Span }, } impl ResolverError { @@ -442,7 +446,7 @@ impl<'a> From<&'a ResolverError> for Diagnostic { }, ResolverError::LoopInConstrainedFn { span } => { Diagnostic::simple_error( - "loop is only allowed in unconstrained functions".into(), + "`loop` is only allowed in unconstrained functions".into(), "Constrained code must always have a known number of loop iterations".into(), *span, ) @@ -454,6 +458,13 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) }, + ResolverError::WhileInConstrainedFn { span } => { + Diagnostic::simple_error( + "`while` is only allowed in unconstrained functions".into(), + "Constrained code must always have a known number of loop iterations".into(), + *span, + ) + }, ResolverError::JumpOutsideLoop { is_break, span } => { let item = if *is_break { "break" } else { "continue" }; Diagnostic::simple_error( @@ -672,8 +683,12 @@ impl<'a> From<&'a ResolverError> for Diagnostic { diagnostic }, ResolverError::LoopNotYetSupported { span } => { + let msg = "`loop` statements are not yet implemented".to_string(); + Diagnostic::simple_error(msg, String::new(), *span) + } + ResolverError::ExpectedTrait { found, span } => { Diagnostic::simple_error( - "`loop` statements are not yet implemented".to_string(), + format!("Expected a trait, found {found}"), String::new(), *span) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs index 543c13fac9cd..74028fa38093 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs @@ -37,6 +37,7 @@ pub enum HirExpression { Constrain(HirConstrainExpression), Cast(HirCastExpression), If(HirIfExpression), + Match(HirMatch), Tuple(Vec), 
Lambda(HirLambda), Quote(Tokens), @@ -354,3 +355,160 @@ pub struct HirLambda { pub body: ExprId, pub captures: Vec, } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HirMatch { + /// Jump directly to ExprId + Success(ExprId), + + Failure, + + /// Run `body` if the given expression is true. + /// Otherwise continue with the given decision tree. + Guard { + cond: ExprId, + body: ExprId, + otherwise: Box, + }, + + /// Switch on the given variable with the given cases to test. + /// The final argument is an optional match-all case to take if + /// none of the cases matched. + Switch(DefinitionId, Vec, Option>), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Case { + pub constructor: Constructor, + pub arguments: Vec, + pub body: HirMatch, +} + +impl Case { + pub fn new(constructor: Constructor, arguments: Vec, body: HirMatch) -> Self { + Self { constructor, arguments, body } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct SignedField { + pub field: FieldElement, + pub is_negative: bool, +} + +impl SignedField { + pub fn new(field: FieldElement, is_negative: bool) -> Self { + Self { field, is_negative } + } +} + +impl std::ops::Neg for SignedField { + type Output = Self; + + fn neg(mut self) -> Self::Output { + self.is_negative = !self.is_negative; + self + } +} + +impl std::cmp::PartialOrd for SignedField { + fn partial_cmp(&self, other: &Self) -> Option { + if self.is_negative != other.is_negative { + if self.is_negative { + return Some(std::cmp::Ordering::Less); + } else { + return Some(std::cmp::Ordering::Greater); + } + } + self.field.partial_cmp(&other.field) + } +} + +impl std::fmt::Display for SignedField { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.is_negative { + write!(f, "-")?; + } + write!(f, "{}", self.field) + } +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub enum Constructor { + True, + False, + Unit, + Int(SignedField), + Tuple(Vec), + Variant(Type, usize), + 
Range(SignedField, SignedField), +} + +impl Constructor { + pub fn variant_index(&self) -> usize { + match self { + Constructor::False + | Constructor::Int(_) + | Constructor::Unit + | Constructor::Tuple(_) + | Constructor::Range(_, _) => 0, + Constructor::True => 1, + Constructor::Variant(_, index) => *index, + } + } + + /// True if this constructor constructs an enum value. + /// Enums contain a tag value and often have values to + /// unpack for each different variant index. + pub fn is_enum(&self) -> bool { + match self { + Constructor::Variant(typ, _) => match typ.follow_bindings_shallow().as_ref() { + Type::DataType(def, _) => def.borrow().is_enum(), + _ => false, + }, + _ => false, + } + } + + /// True if this constructor constructs a tuple or struct value. + /// Tuples or structs will still have values to unpack but do not + /// store a tag value internally. + pub fn is_tuple_or_struct(&self) -> bool { + match self { + Constructor::Tuple(_) => true, + Constructor::Variant(typ, _) => match typ.follow_bindings_shallow().as_ref() { + Type::DataType(def, _) => def.borrow().is_struct(), + _ => false, + }, + _ => false, + } + } +} + +impl std::fmt::Display for Constructor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Constructor::True => write!(f, "true"), + Constructor::False => write!(f, "false"), + Constructor::Unit => write!(f, "()"), + Constructor::Int(x) => write!(f, "{x}"), + // We already print the arguments of a constructor after this in the format of `(x, y)`. + // In that case it is already in the format of a tuple so there's nothing more we need + // to do here. This is implicitly assuming we never display a constructor without also + // displaying its arguments though. 
+ Constructor::Tuple(_) => Ok(()), + Constructor::Variant(typ, variant_index) => { + if let Type::DataType(def, _) = typ { + let def = def.borrow(); + if let Some(variant) = def.get_variant_as_written(*variant_index) { + write!(f, "{}", variant.name)?; + } else if def.is_struct() { + write!(f, "{}", def.name)?; + } + } + Ok(()) + } + Constructor::Range(start, end) => write!(f, "{start} .. {end}"), + } + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/stmt.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/stmt.rs index 96ef7161341a..b0e004349032 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -15,6 +15,7 @@ pub enum HirStatement { Assign(HirAssignStatement), For(HirForStatement), Loop(ExprId), + While(ExprId, ExprId), Break, Continue, Expression(ExprId), @@ -45,6 +46,11 @@ impl HirLetStatement { Self { pattern, r#type, expression, attributes, comptime, is_global_let } } + /// Creates a new 'basic' let statement with no attributes and is not comptime nor global. + pub fn basic(pattern: HirPattern, r#type: Type, expression: ExprId) -> HirLetStatement { + Self::new(pattern, r#type, expression, Vec::new(), false, false) + } + pub fn ident(&self) -> HirIdent { match &self.pattern { HirPattern::Identifier(ident) => ident.clone(), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs index a79af9a76305..2afaced32c71 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs @@ -473,6 +473,10 @@ impl DataType { matches!(&self.body, TypeBody::Struct(_)) } + pub fn is_enum(&self) -> bool { + matches!(&self.body, TypeBody::Enum(_)) + } + /// Retrieve the fields of this type with no modifications. /// Returns None if this is not a struct type. 
pub fn fields_raw(&self) -> Option<&[StructField]> { @@ -555,6 +559,21 @@ impl DataType { })) } + /// Retrieve the given variant at the given variant index of this type. + /// Returns None if this is not an enum type or `variant_index` is out of bounds. + pub fn get_variant( + &self, + variant_index: usize, + generic_args: &[Type], + ) -> Option<(String, Vec)> { + let substitutions = self.get_fields_substitutions(generic_args); + let variant = self.variants_raw()?.get(variant_index)?; + + let name = variant.name.to_string(); + let args = vecmap(&variant.params, |param| param.substitute(&substitutions)); + Some((name, args)) + } + fn get_fields_substitutions( &self, generic_args: &[Type], @@ -591,6 +610,15 @@ impl DataType { Some(self.variants_raw()?.to_vec()) } + /// Returns the name and raw parameters of the variant at the given variant index. + /// This will not substitute any generic arguments so a generic variant like `X` + /// in `enum Foo { X(T) }` will return a `("X", Vec)` pair. + /// + /// Returns None if this is not an enum type or the given variant index is out of bounds. + pub fn get_variant_as_written(&self, variant_index: usize) -> Option<&EnumVariant> { + self.variants_raw()?.get(variant_index) + } + /// Returns the field at the given index. Panics if no field exists at the given index or this /// is not a struct type. 
pub fn field_at(&self, index: usize) -> &StructField { @@ -2582,6 +2610,10 @@ impl Type { } Cow::Borrowed(self) } + Type::Alias(alias_def, generics) => { + let typ = alias_def.borrow().get_type(generics); + Cow::Owned(typ.follow_bindings_shallow().into_owned()) + } other => Cow::Borrowed(other), } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/locations.rs b/noir/noir-repo/compiler/noirc_frontend/src/locations.rs index 08100a3c3517..e68a5d8c5d88 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/locations.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/locations.rs @@ -55,6 +55,30 @@ impl<'a> ReferencesTracker<'a> { } } +/// A `ModuleDefId` captured to be offered in LSP's auto-import feature. +/// +/// The name of the item is stored in the key of the `auto_import_names` map in the `NodeInterner`. +#[derive(Debug, Copy, Clone)] +pub struct AutoImportEntry { + /// The item to import. + pub module_def_id: ModuleDefId, + /// The item's visibility. + pub visibility: ItemVisibility, + /// If the item is available via a re-export, this contains the module where it's defined. 
+ /// For example: + /// + /// ```noir + /// mod foo { // <- this is the defining module + /// mod bar { + /// pub struct Baz {} // This is the item + /// } + /// + /// pub use bar::Baz; // Here's the visibility + /// } + /// ``` + pub defining_module: Option, +} + impl NodeInterner { pub fn reference_location(&self, reference: ReferenceId) -> Location { match reference { @@ -309,9 +333,12 @@ impl NodeInterner { pub(crate) fn register_module( &mut self, id: ModuleId, + location: Location, visibility: ItemVisibility, name: String, + parent_module_id: ModuleId, ) { + self.add_definition_location(ReferenceId::Module(id), location, Some(parent_module_id)); self.register_name_for_auto_import(name, ModuleDefId::ModuleId(id), visibility, None); } @@ -381,13 +408,11 @@ impl NodeInterner { } let entry = self.auto_import_names.entry(name).or_default(); - entry.push((module_def_id, visibility, defining_module)); + entry.push(AutoImportEntry { module_def_id, visibility, defining_module }); } #[allow(clippy::type_complexity)] - pub fn get_auto_import_names( - &self, - ) -> &HashMap)>> { + pub fn get_auto_import_names(&self) -> &HashMap> { &self.auto_import_names } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/ast.rs b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/ast.rs index 621eb30e4f88..6cd5073aadb9 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/ast.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/ast.rs @@ -9,6 +9,7 @@ use noirc_errors::{ use crate::{ ast::{BinaryOpKind, IntegerBitSize, Signedness, Visibility}, + hir_def::expr::Constructor, token::{Attributes, FunctionAttribute}, }; use crate::{hir_def::function::FunctionSignature, token::FmtStrFragment}; @@ -37,7 +38,9 @@ pub enum Expression { Cast(Cast), For(For), Loop(Box), + While(While), If(If), + Match(Match), Tuple(Vec), ExtractTupleField(Box, usize), Call(Call), @@ -110,6 +113,12 @@ pub struct For { pub end_range_location: 
Location, } +#[derive(Debug, Clone, Hash)] +pub struct While { + pub condition: Box, + pub body: Box, +} + #[derive(Debug, Clone, Hash)] pub enum Literal { Array(ArrayLiteral), @@ -153,6 +162,21 @@ pub struct If { pub typ: Type, } +#[derive(Debug, Clone, Hash)] +pub struct Match { + pub variable_to_match: LocalId, + pub cases: Vec, + pub default_case: Option>, + pub typ: Type, +} + +#[derive(Debug, Clone, Hash)] +pub struct MatchCase { + pub constructor: Constructor, + pub arguments: Vec, + pub branch: Expression, +} + #[derive(Debug, Clone, Hash)] pub struct Cast { pub lhs: Box, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs index 5d81913f4ecb..d30229ce97d5 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -25,7 +25,7 @@ use crate::{ Kind, Type, TypeBinding, TypeBindings, }; use acvm::{acir::AcirField, FieldElement}; -use ast::GlobalId; +use ast::{GlobalId, While}; use fxhash::FxHashMap as HashMap; use iter_extended::{btree_map, try_vecmap, vecmap}; use noirc_errors::Location; @@ -591,6 +591,8 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::If(ast::If { condition, consequence, alternative: else_, typ }) } + HirExpression::Match(match_expr) => self.match_expr(match_expr, expr)?, + HirExpression::Tuple(fields) => { let fields = try_vecmap(fields, |id| self.expr(id))?; ast::Expression::Tuple(fields) @@ -703,6 +705,11 @@ impl<'interner> Monomorphizer<'interner> { let block = Box::new(self.expr(block)?); Ok(ast::Expression::Loop(block)) } + HirStatement::While(condition, body) => { + let condition = Box::new(self.expr(condition)?); + let body = Box::new(self.expr(body)?); + Ok(ast::Expression::While(While { condition, body })) + } HirStatement::Expression(expr) => self.expr(expr), HirStatement::Semi(expr) => { self.expr(expr).map(|expr| 
ast::Expression::Semi(Box::new(expr))) @@ -1982,6 +1989,68 @@ impl<'interner> Monomorphizer<'interner> { Ok((block_let_stmt, closure_ident)) } + fn match_expr( + &mut self, + match_expr: HirMatch, + expr_id: ExprId, + ) -> Result { + match match_expr { + HirMatch::Success(id) => self.expr(id), + HirMatch::Failure => { + let false_ = Box::new(ast::Expression::Literal(ast::Literal::Bool(false))); + let msg = "match failure"; + let msg_expr = ast::Expression::Literal(ast::Literal::Str(msg.to_string())); + + let u32_type = HirType::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo); + let length = (msg.len() as u128).into(); + let length = HirType::Constant(length, Kind::Numeric(Box::new(u32_type))); + let msg_type = HirType::String(Box::new(length)); + + let msg = Some(Box::new((msg_expr, msg_type))); + let location = self.interner.expr_location(&expr_id); + Ok(ast::Expression::Constrain(false_, location, msg)) + } + HirMatch::Guard { cond, body, otherwise } => { + let condition = Box::new(self.expr(cond)?); + let consequence = Box::new(self.expr(body)?); + let alternative = Some(Box::new(self.match_expr(*otherwise, expr_id)?)); + let location = self.interner.expr_location(&expr_id); + let typ = Self::convert_type(&self.interner.id_type(expr_id), location)?; + Ok(ast::Expression::If(ast::If { condition, consequence, alternative, typ })) + } + HirMatch::Switch(variable_to_match, cases, default) => { + let variable_to_match = match self.lookup_local(variable_to_match) { + Some(Definition::Local(id)) => id, + other => unreachable!("Expected match variable to be defined. 
Found {other:?}"), + }; + + let cases = try_vecmap(cases, |case| { + let arguments = vecmap(case.arguments, |arg| { + let new_id = self.next_local_id(); + self.define_local(arg, new_id); + new_id + }); + let branch = self.match_expr(case.body, expr_id)?; + Ok(ast::MatchCase { constructor: case.constructor, arguments, branch }) + })?; + + let default_case = match default { + Some(case) => Some(Box::new(self.match_expr(*case, expr_id)?)), + None => None, + }; + + let location = self.interner.expr_location(&expr_id); + let typ = Self::convert_type(&self.interner.id_type(expr_id), location)?; + Ok(ast::Expression::Match(ast::Match { + variable_to_match, + cases, + default_case, + typ, + })) + } + } + } + /// Implements std::unsafe_func::zeroed by returning an appropriate zeroed /// ast literal or collection node for the given type. Note that for functions /// there is no obvious zeroed value so this should be considered unsafe to use. diff --git a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/printer.rs b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/printer.rs index 665f4dcd371b..df4340b4e0d3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/printer.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/printer.rs @@ -1,6 +1,6 @@ //! This module implements printing of the monomorphized AST, for debugging purposes. 
-use super::ast::{Definition, Expression, Function, LValue}; +use super::ast::{Definition, Expression, Function, LValue, While}; use iter_extended::vecmap; use std::fmt::{Display, Formatter}; @@ -50,7 +50,9 @@ impl AstPrinter { } Expression::For(for_expr) => self.print_for(for_expr, f), Expression::Loop(block) => self.print_loop(block, f), + Expression::While(while_) => self.print_while(while_, f), Expression::If(if_expr) => self.print_if(if_expr, f), + Expression::Match(match_expr) => self.print_match(match_expr, f), Expression::Tuple(tuple) => self.print_tuple(tuple, f), Expression::ExtractTupleField(expr, index) => { self.print_expr(expr, f)?; @@ -105,7 +107,7 @@ impl AstPrinter { } super::ast::Literal::Integer(x, _, _, _) => x.fmt(f), super::ast::Literal::Bool(x) => x.fmt(f), - super::ast::Literal::Str(s) => s.fmt(f), + super::ast::Literal::Str(s) => write!(f, "\"{s}\""), super::ast::Literal::FmtStr(fragments, _, _) => { write!(f, "f\"")?; for fragment in fragments { @@ -219,6 +221,17 @@ impl AstPrinter { write!(f, "}}") } + fn print_while(&mut self, while_: &While, f: &mut Formatter) -> Result<(), std::fmt::Error> { + write!(f, "while ")?; + self.print_expr(&while_.condition, f)?; + write!(f, " {{")?; + self.indent_level += 1; + self.print_expr_expect_block(&while_.body, f)?; + self.indent_level -= 1; + self.next_line(f)?; + write!(f, "}}") + } + fn print_if( &mut self, if_expr: &super::ast::If, @@ -243,6 +256,44 @@ impl AstPrinter { write!(f, "}}") } + fn print_match( + &mut self, + match_expr: &super::ast::Match, + f: &mut Formatter, + ) -> Result<(), std::fmt::Error> { + write!(f, "match ${} {{", match_expr.variable_to_match.0)?; + self.indent_level += 1; + self.next_line(f)?; + + for (i, case) in match_expr.cases.iter().enumerate() { + write!(f, "{}", case.constructor)?; + let args = vecmap(&case.arguments, |arg| format!("${}", arg.0)).join(", "); + if !args.is_empty() { + write!(f, "({args})")?; + } + write!(f, " => ")?; + self.print_expr(&case.branch, 
f)?; + write!(f, ",")?; + + if i != match_expr.cases.len() - 1 { + self.next_line(f)?; + } + } + self.indent_level -= 1; + + if let Some(default) = &match_expr.default_case { + self.indent_level += 1; + self.next_line(f)?; + write!(f, "_ => ")?; + self.print_expr(default, f)?; + write!(f, ",")?; + self.indent_level -= 1; + } + + self.next_line(f)?; + write!(f, "}}") + } + fn print_comma_separated( &mut self, exprs: &[Expression], diff --git a/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs b/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs index 1ebcb6aff967..f995b20bc1a7 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs @@ -24,6 +24,7 @@ use crate::hir::def_map::{LocalModuleId, ModuleDefId, ModuleId}; use crate::hir::type_check::generics::TraitGenerics; use crate::hir_def::traits::NamedType; use crate::hir_def::traits::ResolvedTraitBound; +use crate::locations::AutoImportEntry; use crate::QuotedType; use crate::ast::{BinaryOpKind, FunctionDefinition, ItemVisibility}; @@ -254,8 +255,7 @@ pub struct NodeInterner { // The third value in the tuple is the module where the definition is (only for pub use). // These include top-level functions, global variables and types, but excludes // impl and trait-impl methods. - pub(crate) auto_import_names: - HashMap)>>, + pub(crate) auto_import_names: HashMap>, /// Each value currently in scope in the comptime interpreter. /// Each element of the Vec represents a scope with every scope together making @@ -268,10 +268,10 @@ pub struct NodeInterner { /// Captures the documentation comments for each module, struct, trait, function, etc. pub(crate) doc_comments: HashMap>, - /// Only for LSP: a map of trait ID to each module that pub or pub(crate) exports it. - /// In LSP this is used to offer importing the trait via one of these exports if - /// the trait is not visible where it's defined. 
- trait_reexports: HashMap>, + /// Only for LSP: a map of ModuleDefId to each module that pub or pub(crate) exports it. + /// In LSP this is used to offer importing the item via one of these exports if + /// the item is not visible where it's defined. + reexports: HashMap>, } /// A dependency in the dependency graph may be a type or a definition. @@ -637,6 +637,24 @@ pub struct InternedUnresolvedTypeData(noirc_arena::Index); #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct InternedPattern(noirc_arena::Index); +/// Captures a reexport that happens inside a module. For example: +/// +/// ```noir +/// mod moo { +/// // ^^^ module_id +/// +/// pub use foo::bar as baz; +/// //^^^ visibility ^^^ name +/// } +/// ``` +/// +#[derive(Debug, Clone)] +pub struct Reexport { + pub module_id: ModuleId, + pub name: Ident, + pub visibility: ItemVisibility, +} + impl Default for NodeInterner { fn default() -> Self { NodeInterner { @@ -686,7 +704,7 @@ impl Default for NodeInterner { comptime_scopes: vec![HashMap::default()], trait_impl_associated_types: HashMap::default(), doc_comments: HashMap::default(), - trait_reexports: HashMap::default(), + reexports: HashMap::default(), } } } @@ -2268,18 +2286,26 @@ impl NodeInterner { } } - pub fn add_trait_reexport( + pub fn add_reexport( &mut self, - trait_id: TraitId, + module_def_id: ModuleDefId, module_id: ModuleId, name: Ident, visibility: ItemVisibility, ) { - self.trait_reexports.entry(trait_id).or_default().push((module_id, name, visibility)); + self.reexports.entry(module_def_id).or_default().push(Reexport { + module_id, + name, + visibility, + }); + } + + pub fn get_reexports(&self, module_def_id: ModuleDefId) -> &[Reexport] { + self.reexports.get(&module_def_id).map_or(&[], |reexport| reexport) } - pub fn get_trait_reexports(&self, trait_id: TraitId) -> &[(ModuleId, Ident, ItemVisibility)] { - self.trait_reexports.get(&trait_id).map_or(&[], |exports| exports) + pub fn get_trait_reexports(&self, 
trait_id: TraitId) -> &[Reexport] { + self.get_reexports(ModuleDefId::TraitId(trait_id)) } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs index 189b880d45ea..8345b75dbabf 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs @@ -23,8 +23,6 @@ pub enum ParserErrorReason { ExpectedMutAfterAmpersand { found: Token }, #[error("Invalid left-hand side of assignment")] InvalidLeftHandSideOfAssignment, - #[error("Expected trait, found {found}")] - ExpectedTrait { found: String }, #[error("Visibility `{visibility}` is not followed by an item")] VisibilityNotFollowedByAnItem { visibility: ItemVisibility }, #[error("`unconstrained` is not followed by an item")] diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/impls.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/impls.rs index 278c20e1e27f..4b5054984d4f 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/impls.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/impls.rs @@ -2,9 +2,9 @@ use noirc_errors::Span; use crate::{ ast::{ - Documented, Expression, ExpressionKind, GenericTypeArgs, Ident, ItemVisibility, - NoirFunction, NoirTraitImpl, Path, TraitImplItem, TraitImplItemKind, TypeImpl, - UnresolvedGeneric, UnresolvedType, UnresolvedTypeData, + Documented, Expression, ExpressionKind, Ident, ItemVisibility, NoirFunction, NoirTraitImpl, + TraitImplItem, TraitImplItemKind, TypeImpl, UnresolvedGeneric, UnresolvedType, + UnresolvedTypeData, }, parser::{labels::ParsingRuleLabel, ParserErrorReason}, token::{Keyword, Token}, @@ -29,24 +29,10 @@ impl<'a> Parser<'a> { let type_span = self.span_since(type_span_start); if self.eat_keyword(Keyword::For) { - if let UnresolvedTypeData::Named(trait_name, trait_generics, _) = object_type.typ { - return Impl::TraitImpl(self.parse_trait_impl( - generics, 
- trait_generics, - trait_name, - )); - } else { - self.push_error( - ParserErrorReason::ExpectedTrait { found: object_type.typ.to_string() }, - self.current_token_span, - ); - - // Error, but we continue parsing the type and assume this is going to be a regular type impl - self.parse_type(); - }; + Impl::TraitImpl(self.parse_trait_impl(generics, object_type)) + } else { + Impl::Impl(self.parse_type_impl(object_type, type_span, generics)) } - - self.parse_type_impl(object_type, type_span, generics) } /// TypeImpl = 'impl' Generics Type TypeImplBody @@ -55,11 +41,10 @@ impl<'a> Parser<'a> { object_type: UnresolvedType, type_span: Span, generics: Vec, - ) -> Impl { + ) -> TypeImpl { let where_clause = self.parse_where_clause(); let methods = self.parse_type_impl_body(); - - Impl::Impl(TypeImpl { object_type, type_span, generics, where_clause, methods }) + TypeImpl { object_type, type_span, generics, where_clause, methods } } /// TypeImplBody = '{' TypeImplItem* '}' @@ -103,27 +88,18 @@ impl<'a> Parser<'a> { }) } - /// TraitImpl = 'impl' Generics Path GenericTypeArgs 'for' Type TraitImplBody + /// TraitImpl = 'impl' Generics Type 'for' Type TraitImplBody fn parse_trait_impl( &mut self, impl_generics: Vec, - trait_generics: GenericTypeArgs, - trait_name: Path, + r#trait: UnresolvedType, ) -> NoirTraitImpl { let object_type = self.parse_type_or_error(); let where_clause = self.parse_where_clause(); let items = self.parse_trait_impl_body(); let is_synthetic = false; - NoirTraitImpl { - impl_generics, - trait_name, - trait_generics, - object_type, - where_clause, - items, - is_synthetic, - } + NoirTraitImpl { impl_generics, r#trait, object_type, where_clause, items, is_synthetic } } /// TraitImplBody = '{' TraitImplItem* '}' @@ -454,7 +430,12 @@ mod tests { fn parse_empty_trait_impl() { let src = "impl Foo for Field {}"; let trait_impl = parse_trait_impl_no_errors(src); - assert_eq!(trait_impl.trait_name.to_string(), "Foo"); + + let UnresolvedTypeData::Named(trait_name, 
_, _) = trait_impl.r#trait.typ else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); assert!(matches!(trait_impl.object_type.typ, UnresolvedTypeData::FieldElement)); assert!(trait_impl.items.is_empty()); assert!(trait_impl.impl_generics.is_empty()); @@ -464,7 +445,12 @@ mod tests { fn parse_empty_trait_impl_with_generics() { let src = "impl Foo for Field {}"; let trait_impl = parse_trait_impl_no_errors(src); - assert_eq!(trait_impl.trait_name.to_string(), "Foo"); + + let UnresolvedTypeData::Named(trait_name, _, _) = trait_impl.r#trait.typ else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); assert!(matches!(trait_impl.object_type.typ, UnresolvedTypeData::FieldElement)); assert!(trait_impl.items.is_empty()); assert_eq!(trait_impl.impl_generics.len(), 1); @@ -474,7 +460,12 @@ mod tests { fn parse_trait_impl_with_function() { let src = "impl Foo for Field { fn foo() {} }"; let mut trait_impl = parse_trait_impl_no_errors(src); - assert_eq!(trait_impl.trait_name.to_string(), "Foo"); + + let UnresolvedTypeData::Named(trait_name, _, _) = trait_impl.r#trait.typ else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); assert_eq!(trait_impl.items.len(), 1); let item = trait_impl.items.remove(0).item; @@ -489,15 +480,26 @@ mod tests { fn parse_trait_impl_with_generic_type_args() { let src = "impl Foo for Field { }"; let trait_impl = parse_trait_impl_no_errors(src); - assert_eq!(trait_impl.trait_name.to_string(), "Foo"); - assert!(!trait_impl.trait_generics.is_empty()); + + let UnresolvedTypeData::Named(trait_name, trait_generics, _) = trait_impl.r#trait.typ + else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); + assert!(!trait_generics.is_empty()); } #[test] fn parse_trait_impl_with_type() { let src = "impl Foo for Field { type Foo = i32; }"; let mut trait_impl = parse_trait_impl_no_errors(src); - 
assert_eq!(trait_impl.trait_name.to_string(), "Foo"); + + let UnresolvedTypeData::Named(trait_name, _, _) = trait_impl.r#trait.typ else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); assert_eq!(trait_impl.items.len(), 1); let item = trait_impl.items.remove(0).item; @@ -512,7 +514,12 @@ mod tests { fn parse_trait_impl_with_let() { let src = "impl Foo for Field { let x: Field = 1; }"; let mut trait_impl = parse_trait_impl_no_errors(src); - assert_eq!(trait_impl.trait_name.to_string(), "Foo"); + + let UnresolvedTypeData::Named(trait_name, _, _) = trait_impl.r#trait.typ else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); assert_eq!(trait_impl.items.len(), 1); let item = trait_impl.items.remove(0).item; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs index d20f8c29e1ed..d8679da6ba88 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs @@ -4,7 +4,7 @@ use crate::{ ast::{ AssignStatement, BinaryOp, BinaryOpKind, Expression, ExpressionKind, ForBounds, ForLoopStatement, ForRange, Ident, InfixExpression, LValue, LetStatement, Statement, - StatementKind, + StatementKind, WhileStatement, }, parser::{labels::ParsingRuleLabel, ParserErrorReason}, token::{Attribute, Keyword, Token, TokenKind}, @@ -81,6 +81,7 @@ impl<'a> Parser<'a> { /// | ComptimeStatement /// | ForStatement /// | LoopStatement + /// | WhileStatement /// | IfStatement /// | BlockStatement /// | AssignStatement @@ -145,6 +146,10 @@ impl<'a> Parser<'a> { return Some(StatementKind::Loop(block, span)); } + if let Some(while_) = self.parse_while() { + return Some(StatementKind::While(while_)); + } + if let Some(kind) = self.parse_if_expr() { let span = self.span_since(start_span); return Some(StatementKind::Expression(Expression { 
kind, span })); @@ -302,6 +307,31 @@ impl<'a> Parser<'a> { Some((block, start_span)) } + /// WhileStatement = 'while' ExpressionExceptConstructor Block + fn parse_while(&mut self) -> Option { + let start_span = self.current_token_span; + if !self.eat_keyword(Keyword::While) { + return None; + } + + self.push_error(ParserErrorReason::ExperimentalFeature("while loops"), start_span); + + let condition = self.parse_expression_except_constructor_or_error(); + + let block_start_span = self.current_token_span; + let block = if let Some(block) = self.parse_block() { + Expression { + kind: ExpressionKind::Block(block), + span: self.span_since(block_start_span), + } + } else { + self.expected_token(Token::LeftBrace); + Expression { kind: ExpressionKind::Error, span: self.span_since(block_start_span) } + }; + + Some(WhileStatement { condition, body: block, while_keyword_span: start_span }) + } + /// ForRange /// = ExpressionExceptConstructor /// | ExpressionExceptConstructor '..' ExpressionExceptConstructor @@ -771,4 +801,36 @@ mod tests { }; assert!(matches!(let_statement.expression.kind, ExpressionKind::Constrain(..))); } + + #[test] + fn parses_empty_while() { + let src = "while true { }"; + let mut parser = Parser::for_str(src); + let statement = parser.parse_statement_or_error(); + let StatementKind::While(while_) = statement.kind else { + panic!("Expected while"); + }; + let ExpressionKind::Block(block) = while_.body.kind else { + panic!("Expected block"); + }; + assert!(block.statements.is_empty()); + assert_eq!(while_.while_keyword_span.start(), 0); + assert_eq!(while_.while_keyword_span.end(), 5); + + assert_eq!(while_.condition.to_string(), "true"); + } + + #[test] + fn parses_while_with_statements() { + let src = "while true { 1; 2 }"; + let mut parser = Parser::for_str(src); + let statement = parser.parse_statement_or_error(); + let StatementKind::While(while_) = statement.kind else { + panic!("Expected while"); + }; + let ExpressionKind::Block(block) = 
while_.body.kind else { + panic!("Expected block"); + }; + assert_eq!(block.statements.len(), 2); + } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/traits.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/traits.rs index 6f6a9bab9603..53df5cd00b1b 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/traits.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/traits.rs @@ -92,6 +92,11 @@ impl<'a> Parser<'a> { }) .into(); + let r#trait = UnresolvedType { + typ: UnresolvedTypeData::Named(trait_name, trait_generics, false), + span, + }; + // bounds from trait let mut where_clause = where_clause.clone(); for bound in bounds.clone() { @@ -104,15 +109,7 @@ impl<'a> Parser<'a> { let items = vec![]; let is_synthetic = true; - NoirTraitImpl { - impl_generics, - trait_name, - trait_generics, - object_type, - where_clause, - items, - is_synthetic, - } + NoirTraitImpl { impl_generics, r#trait, object_type, where_clause, items, is_synthetic } }); let noir_trait = NoirTrait { @@ -287,7 +284,7 @@ fn empty_trait( #[cfg(test)] mod tests { use crate::{ - ast::{NoirTrait, NoirTraitImpl, TraitItem}, + ast::{NoirTrait, NoirTraitImpl, TraitItem, UnresolvedTypeData}, parser::{ parser::{ parse_program, @@ -374,9 +371,14 @@ mod tests { assert!(noir_trait_alias.items.is_empty()); assert!(noir_trait_alias.is_alias); - assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + let UnresolvedTypeData::Named(trait_name, trait_generics, _) = noir_trait_impl.r#trait.typ + else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); assert_eq!(noir_trait_impl.impl_generics.len(), 3); - assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 2); + assert_eq!(trait_generics.ordered_args.len(), 2); assert_eq!(noir_trait_impl.where_clause.len(), 2); assert_eq!(noir_trait_alias.bounds.len(), 2); assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); @@ -428,9 +430,14 @@ mod tests { 
assert!(noir_trait_alias.items.is_empty()); assert!(noir_trait_alias.is_alias); - assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + let UnresolvedTypeData::Named(trait_name, trait_generics, _) = noir_trait_impl.r#trait.typ + else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); assert_eq!(noir_trait_impl.impl_generics.len(), 3); - assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 2); + assert_eq!(trait_generics.ordered_args.len(), 2); assert_eq!(noir_trait_impl.where_clause.len(), 3); assert_eq!(noir_trait_impl.where_clause[0].to_string(), "A: Z"); assert_eq!(noir_trait_impl.where_clause[1].to_string(), "#T: Bar"); @@ -557,9 +564,14 @@ mod tests { assert_eq!(noir_trait_alias.to_string(), "trait Foo = Bar + Baz;"); assert!(noir_trait_alias.is_alias); - assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + let UnresolvedTypeData::Named(trait_name, trait_generics, _) = noir_trait_impl.r#trait.typ + else { + panic!("Expected name type"); + }; + + assert_eq!(trait_name.to_string(), "Foo"); assert_eq!(noir_trait_impl.impl_generics.len(), 1); - assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 0); + assert_eq!(trait_generics.ordered_args.len(), 0); assert_eq!(noir_trait_impl.where_clause.len(), 2); assert_eq!(noir_trait_impl.where_clause[0].to_string(), "#T: Bar"); assert_eq!(noir_trait_impl.where_clause[1].to_string(), "#T: Baz"); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs index af1948f012e7..1b3a19a5cfc3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs @@ -903,6 +903,7 @@ fn find_lambda_captures(stmts: &[StmtId], interner: &NodeInterner, result: &mut HirStatement::Semi(semi_expr) => semi_expr, HirStatement::For(for_loop) => for_loop.block, HirStatement::Loop(block) => block, + HirStatement::While(_, block) => block, HirStatement::Error => 
panic!("Invalid HirStatement!"), HirStatement::Break => panic!("Unexpected break"), HirStatement::Continue => panic!("Unexpected continue"), diff --git a/noir/noir-repo/docs/docs/how_to/using-devcontainers.mdx b/noir/noir-repo/docs/docs/how_to/using-devcontainers.mdx index 4442f94dc4c7..dcd30d898015 100644 --- a/noir/noir-repo/docs/docs/how_to/using-devcontainers.mdx +++ b/noir/noir-repo/docs/docs/how_to/using-devcontainers.mdx @@ -68,7 +68,7 @@ Github comes with a default codespace and you can use it to code your own devcon "image": "mcr.microsoft.com/devcontainers/base:ubuntu", "features": { "ghcr.io/noir-lang/features/noir:latest": { - "version": "1.0.0-beta.1" + "version": "1.0.0-beta.2" } } } diff --git a/noir/noir-repo/docs/docs/noir/concepts/assert.md b/noir/noir-repo/docs/docs/noir/concepts/assert.md index 2132de42072d..c7bc42fa3228 100644 --- a/noir/noir-repo/docs/docs/noir/concepts/assert.md +++ b/noir/noir-repo/docs/docs/noir/concepts/assert.md @@ -10,7 +10,9 @@ sidebar_position: 4 Noir includes a special `assert` function which will explicitly constrain the predicate/comparison expression that follows to be true. If this expression is false at runtime, the program will fail to -be proven. Example: +be proven. As of v1.0.0-beta.2, assert statements are expressions and can be used in value contexts. + +Example: ```rust fn main(x : Field, y : Field) { @@ -75,4 +77,3 @@ fn main(x : Field, y : Field) { static_assert(example_slice.len() == 0, error_message); } ``` - diff --git a/noir/noir-repo/docs/docs/noir/concepts/control_flow.md b/noir/noir-repo/docs/docs/noir/concepts/control_flow.md index 3e2d913ec964..57816c38c575 100644 --- a/noir/noir-repo/docs/docs/noir/concepts/control_flow.md +++ b/noir/noir-repo/docs/docs/noir/concepts/control_flow.md @@ -79,13 +79,13 @@ The iteration variable `i` is still increased by one as normal when `continue` i ## Loops -In unconstrained code, `loop` is allowed for loops that end with a `break`. 
-A `loop` must have at least one `break` in it. +In unconstrained code, `loop` is allowed for loops that end with a `break`. +A `loop` must contain at least one `break` statement that is reachable during execution. This is only allowed in unconstrained code since normal constrained code requires that Noir knows exactly how many iterations a loop may have. ```rust -let mut i = 10 +let mut i = 10; loop { println(i); i -= 1; @@ -96,3 +96,16 @@ loop { } } ``` +## While loops + +In unconstrained code, `while` is allowed for loops that end when a given condition is met. +This is only allowed in unconstrained code since normal constrained code requires that Noir knows exactly how many iterations +a loop may have. + +```rust +let mut i = 0; +while i < 10 { + println(i); + i += 2; +} +``` diff --git a/noir/noir-repo/docs/docs/noir/standard_library/bigint.md b/noir/noir-repo/docs/docs/noir/standard_library/bigint.md deleted file mode 100644 index cc7d6e1c8de0..000000000000 --- a/noir/noir-repo/docs/docs/noir/standard_library/bigint.md +++ /dev/null @@ -1,110 +0,0 @@ ---- -title: Big Integers -description: How to use big integers from Noir standard library -keywords: - [ - Big Integer, - Noir programming language, - Noir libraries, - ] ---- - -The BigInt module in the standard library exposes some class of integers which do not fit (well) into a Noir native field. It implements modulo arithmetic, modulo a 'big' prime number. - -:::note - -The module can currently be considered as `Field`s with fixed modulo sizes used by a set of elliptic curves, in addition to just the native curve. [More work](https://github.com/noir-lang/noir/issues/510) is needed to achieve arbitrarily sized big integers. - -:::note - -`nargo` can be built with `--profile release-pedantic` to enable extra overflow checks which may affect `BigInt` results in some cases. -Consider the [`noir-bignum`](https://github.com/noir-lang/noir-bignum) library for an optimized alternative approach. 
- -::: - -Currently 6 classes of integers (i.e 'big' prime numbers) are available in the module, namely: - -- BN254 Fq: Bn254Fq -- BN254 Fr: Bn254Fr -- Secp256k1 Fq: Secpk1Fq -- Secp256k1 Fr: Secpk1Fr -- Secp256r1 Fr: Secpr1Fr -- Secp256r1 Fq: Secpr1Fq - -Where XXX Fq and XXX Fr denote respectively the order of the base and scalar field of the (usual) elliptic curve XXX. -For instance the big integer 'Secpk1Fq' in the standard library refers to integers modulo $2^{256}-2^{32}-977$. - -Feel free to explore the source code for the other primes: - -#include_code big_int_definition noir_stdlib/src/bigint.nr rust - -## Example usage - -A common use-case is when constructing a big integer from its bytes representation, and performing arithmetic operations on it: - -#include_code big_int_example test_programs/execution_success/bigint/src/main.nr rust - -## Methods - -The available operations for each big integer are: - -### from_le_bytes - -Construct a big integer from its little-endian bytes representation. Example: - -```rust - // Construct a big integer from a slice of bytes - let a = Secpk1Fq::from_le_bytes(&[x, y, 0, 45, 2]); - // Construct a big integer from an array of 32 bytes - let a = Secpk1Fq::from_le_bytes_32([1;32]); - ``` - -Sure, here's the formatted version of the remaining methods: - -### to_le_bytes - -Return the little-endian bytes representation of a big integer. Example: - -```rust -let bytes = a.to_le_bytes(); -``` - -### add - -Add two big integers. Example: - -```rust -let sum = a + b; -``` - -### sub - -Subtract two big integers. Example: - -```rust -let difference = a - b; -``` - -### mul - -Multiply two big integers. Example: - -```rust -let product = a * b; -``` - -### div - -Divide two big integers. Note that division is field division and not euclidean division. Example: - -```rust -let quotient = a / b; -``` - -### eq - -Compare two big integers. 
Example: - -```rust -let are_equal = a == b; -``` diff --git a/noir/noir-repo/docs/docs/tutorials/noirjs_app.md b/noir/noir-repo/docs/docs/tutorials/noirjs_app.md index d98cdf1ef56b..40e5d94180de 100644 --- a/noir/noir-repo/docs/docs/tutorials/noirjs_app.md +++ b/noir/noir-repo/docs/docs/tutorials/noirjs_app.md @@ -25,7 +25,7 @@ Let's go barebones. Doing the bare minimum is not only simple, but also allows y Barebones means we can immediately start with the dependencies even on an empty folder 😈: ```bash -bun i @noir-lang/noir_wasm@1.0.0-beta.1 @noir-lang/noir_js@1.0.0-beta.1 @aztec/bb.js@0.63.1 +bun i @noir-lang/noir_wasm@1.0.0-beta.2 @noir-lang/noir_js@1.0.0-beta.2 @aztec/bb.js@0.72.1 ``` Wait, what are these dependencies? @@ -36,7 +36,7 @@ Wait, what are these dependencies? :::info -In this guide, we will install versions pinned to 1.0.0-beta.1. These work with Barretenberg version 0.63.1, so we are using that one version too. Feel free to try with older or later versions, though! +In this guide, we will install versions pinned to 1.0.0-beta.2. These work with Barretenberg version 0.72.1, so we are using that one version too. Feel free to try with older or later versions, though! ::: @@ -50,7 +50,7 @@ It's not just you. We also enjoy syntax highlighting. [Check out the Language Se ::: -All you need is a `main.nr` and a `Nargo.toml` file. You can follow the [noirup](../getting_started/noir_installation.md) installation and just run `noirup -v 1.0.0-beta.1`, or just create them by hand: +All you need is a `main.nr` and a `Nargo.toml` file. 
You can follow the [noirup](../getting_started/noir_installation.md) installation and just run `noirup -v 1.0.0-beta.2`, or just create them by hand: ```bash mkdir -p circuit/src diff --git a/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/how_to/using-devcontainers.mdx b/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/how_to/using-devcontainers.mdx index 4442f94dc4c7..79f0c5119510 100644 --- a/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/how_to/using-devcontainers.mdx +++ b/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/how_to/using-devcontainers.mdx @@ -22,9 +22,9 @@ Enter Codespaces. ## Codespaces -If a devcontainer is just a Docker image, then what stops you from provisioning a `p3dn.24xlarge` AWS EC2 instance with 92 vCPUs and 768 GiB RAM and using it to prove your 10-gate SNARK proof? +If a devcontainer is just a Docker image, then what stops you from provisioning a `p3dn.24xlarge` AWS EC2 instance with 92 vCPUs and 768 GiB RAM and using it to prove your 10-gate SNARK proof? -Nothing! Except perhaps the 30-40$ per hour it will cost you. +Nothing! Except perhaps the 30-40$ per hour it will cost you. The problem is that provisioning takes time, and I bet you don't want to see the AWS console every time you want to code something real quick. @@ -68,7 +68,7 @@ Github comes with a default codespace and you can use it to code your own devcon "image": "mcr.microsoft.com/devcontainers/base:ubuntu", "features": { "ghcr.io/noir-lang/features/noir:latest": { - "version": "1.0.0-beta.1" + "version": "1.0.0-beta.2" } } } @@ -77,15 +77,15 @@ Github comes with a default codespace and you can use it to code your own devcon This will pull the new image and build it, so it could take a minute or so -#### 8. Done! +#### 8. Done! Just wait for the build to finish, and there's your easy Noir environment. 
Some examples of how to use it can be found in the [awesome-noir](https://github.com/noir-lang/awesome-noir?tab=readme-ov-file#boilerplates) repository. ## How do I use it? -Using the codespace is obviously much easier than setting it up. +Using the codespace is obviously much easier than setting it up. Just navigate to your repository and click "Code" -> "Open with Codespaces". It should take a few seconds to load, and you're ready to go. :::info -If you really like the experience, you can add a badge to your readme, links to existing codespaces, and more. +If you really like the experience, you can add a badge to your readme, links to existing codespaces, and more. Check out the [official docs](https://docs.github.com/en/codespaces/setting-up-your-project-for-codespaces/setting-up-your-repository/facilitating-quick-creation-and-resumption-of-codespaces) for more info. diff --git a/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/noir/concepts/assert.md b/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/noir/concepts/assert.md index 2132de42072d..c7bc42fa3228 100644 --- a/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/noir/concepts/assert.md +++ b/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/noir/concepts/assert.md @@ -10,7 +10,9 @@ sidebar_position: 4 Noir includes a special `assert` function which will explicitly constrain the predicate/comparison expression that follows to be true. If this expression is false at runtime, the program will fail to -be proven. Example: +be proven. As of v1.0.0-beta.2, assert statements are expressions and can be used in value contexts. 
+ +Example: ```rust fn main(x : Field, y : Field) { @@ -75,4 +77,3 @@ fn main(x : Field, y : Field) { static_assert(example_slice.len() == 0, error_message); } ``` - diff --git a/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/noir/concepts/control_flow.md b/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/noir/concepts/control_flow.md index 3e2d913ec964..fa9f17b6c8ec 100644 --- a/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/noir/concepts/control_flow.md +++ b/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/noir/concepts/control_flow.md @@ -79,13 +79,13 @@ The iteration variable `i` is still increased by one as normal when `continue` i ## Loops -In unconstrained code, `loop` is allowed for loops that end with a `break`. -A `loop` must have at least one `break` in it. +In unconstrained code, `loop` is allowed for loops that end with a `break`. +A `loop` must contain at least one `break` statement that is reachable during execution. This is only allowed in unconstrained code since normal constrained code requires that Noir knows exactly how many iterations a loop may have. ```rust -let mut i = 10 +let mut i = 10; loop { println(i); i -= 1; @@ -95,4 +95,3 @@ loop { } } ``` - diff --git a/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/tutorials/noirjs_app.md b/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/tutorials/noirjs_app.md index d98cdf1ef56b..40e5d94180de 100644 --- a/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/tutorials/noirjs_app.md +++ b/noir/noir-repo/docs/versioned_docs/version-v1.0.0-beta.2/tutorials/noirjs_app.md @@ -25,7 +25,7 @@ Let's go barebones. 
Doing the bare minimum is not only simple, but also allows y Barebones means we can immediately start with the dependencies even on an empty folder 😈: ```bash -bun i @noir-lang/noir_wasm@1.0.0-beta.1 @noir-lang/noir_js@1.0.0-beta.1 @aztec/bb.js@0.63.1 +bun i @noir-lang/noir_wasm@1.0.0-beta.2 @noir-lang/noir_js@1.0.0-beta.2 @aztec/bb.js@0.72.1 ``` Wait, what are these dependencies? @@ -36,7 +36,7 @@ Wait, what are these dependencies? :::info -In this guide, we will install versions pinned to 1.0.0-beta.1. These work with Barretenberg version 0.63.1, so we are using that one version too. Feel free to try with older or later versions, though! +In this guide, we will install versions pinned to 1.0.0-beta.2. These work with Barretenberg version 0.72.1, so we are using that one version too. Feel free to try with older or later versions, though! ::: @@ -50,7 +50,7 @@ It's not just you. We also enjoy syntax highlighting. [Check out the Language Se ::: -All you need is a `main.nr` and a `Nargo.toml` file. You can follow the [noirup](../getting_started/noir_installation.md) installation and just run `noirup -v 1.0.0-beta.1`, or just create them by hand: +All you need is a `main.nr` and a `Nargo.toml` file. 
You can follow the [noirup](../getting_started/noir_installation.md) installation and just run `noirup -v 1.0.0-beta.2`, or just create them by hand: ```bash mkdir -p circuit/src diff --git a/noir/noir-repo/noir_stdlib/src/bigint.nr b/noir/noir-repo/noir_stdlib/src/bigint.nr deleted file mode 100644 index 6ce3ea7d72a4..000000000000 --- a/noir/noir-repo/noir_stdlib/src/bigint.nr +++ /dev/null @@ -1,406 +0,0 @@ -use crate::cmp::Eq; -use crate::ops::{Add, Div, Mul, Sub}; - -global bn254_fq: [u8] = &[ - 0x47, 0xFD, 0x7C, 0xD8, 0x16, 0x8C, 0x20, 0x3C, 0x8d, 0xca, 0x71, 0x68, 0x91, 0x6a, 0x81, 0x97, - 0x5d, 0x58, 0x81, 0x81, 0xb6, 0x45, 0x50, 0xb8, 0x29, 0xa0, 0x31, 0xe1, 0x72, 0x4e, 0x64, 0x30, -]; -global bn254_fr: [u8] = &[ - 1, 0, 0, 240, 147, 245, 225, 67, 145, 112, 185, 121, 72, 232, 51, 40, 93, 88, 129, 129, 182, 69, - 80, 184, 41, 160, 49, 225, 114, 78, 100, 48, -]; -global secpk1_fr: [u8] = &[ - 0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA, - 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, -]; -global secpk1_fq: [u8] = &[ - 0x2F, 0xFC, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, -]; -global secpr1_fq: [u8] = &[ - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, -]; -global secpr1_fr: [u8] = &[ - 81, 37, 99, 252, 194, 202, 185, 243, 132, 158, 23, 167, 173, 250, 230, 188, 255, 255, 255, 255, - 255, 255, 255, 255, 0, 0, 0, 0, 255, 255, 255, 255, -]; -// docs:start:big_int_definition -pub struct BigInt { - pointer: u32, - modulus: u32, -} -// docs:end:big_int_definition - -impl BigInt { - #[foreign(bigint_add)] - fn bigint_add(self, other: BigInt) -> BigInt {} - #[foreign(bigint_sub)] - fn 
bigint_sub(self, other: BigInt) -> BigInt {} - #[foreign(bigint_mul)] - fn bigint_mul(self, other: BigInt) -> BigInt {} - #[foreign(bigint_div)] - fn bigint_div(self, other: BigInt) -> BigInt {} - #[foreign(bigint_from_le_bytes)] - fn from_le_bytes(bytes: [u8], modulus: [u8]) -> BigInt {} - #[foreign(bigint_to_le_bytes)] - fn to_le_bytes(self) -> [u8; 32] {} - - fn check_32_bytes(self: Self, other: BigInt) -> bool { - let bytes = self.to_le_bytes(); - let o_bytes = other.to_le_bytes(); - let mut result = true; - for i in 0..32 { - result = result & (bytes[i] == o_bytes[i]); - } - result - } -} - -pub trait BigField { - fn from_le_bytes(bytes: [u8]) -> Self; - fn from_le_bytes_32(bytes: [u8; 32]) -> Self; - fn to_le_bytes(self) -> [u8]; -} - -pub struct Secpk1Fq { - array: [u8; 32], -} - -impl BigField for Secpk1Fq { - fn from_le_bytes(bytes: [u8]) -> Secpk1Fq { - assert(bytes.len() <= 32); - let mut array = [0; 32]; - for i in 0..bytes.len() { - array[i] = bytes[i]; - } - Secpk1Fq { array } - } - - fn from_le_bytes_32(bytes: [u8; 32]) -> Secpk1Fq { - Secpk1Fq { array: bytes } - } - - fn to_le_bytes(self) -> [u8] { - self.array - } -} - -impl Add for Secpk1Fq { - fn add(self: Self, other: Secpk1Fq) -> Secpk1Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpk1_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpk1_fq); - Secpk1Fq { array: a.bigint_add(b).to_le_bytes() } - } -} -impl Sub for Secpk1Fq { - fn sub(self: Self, other: Secpk1Fq) -> Secpk1Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpk1_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpk1_fq); - Secpk1Fq { array: a.bigint_sub(b).to_le_bytes() } - } -} -impl Mul for Secpk1Fq { - fn mul(self: Self, other: Secpk1Fq) -> Secpk1Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpk1_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpk1_fq); - Secpk1Fq { array: a.bigint_mul(b).to_le_bytes() } - } -} -impl Div for Secpk1Fq { - fn 
div(self: Self, other: Secpk1Fq) -> Secpk1Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpk1_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpk1_fq); - Secpk1Fq { array: a.bigint_div(b).to_le_bytes() } - } -} -impl Eq for Secpk1Fq { - fn eq(self: Self, other: Secpk1Fq) -> bool { - self.array == other.array - } -} - -pub struct Secpk1Fr { - array: [u8; 32], -} - -impl BigField for Secpk1Fr { - fn from_le_bytes(bytes: [u8]) -> Secpk1Fr { - assert(bytes.len() <= 32); - let mut array = [0; 32]; - for i in 0..bytes.len() { - array[i] = bytes[i]; - } - Secpk1Fr { array } - } - - fn from_le_bytes_32(bytes: [u8; 32]) -> Secpk1Fr { - Secpk1Fr { array: bytes } - } - - fn to_le_bytes(self) -> [u8] { - self.array - } -} - -impl Add for Secpk1Fr { - fn add(self: Self, other: Secpk1Fr) -> Secpk1Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpk1_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpk1_fr); - Secpk1Fr { array: a.bigint_add(b).to_le_bytes() } - } -} -impl Sub for Secpk1Fr { - fn sub(self: Self, other: Secpk1Fr) -> Secpk1Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpk1_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpk1_fr); - Secpk1Fr { array: a.bigint_sub(b).to_le_bytes() } - } -} -impl Mul for Secpk1Fr { - fn mul(self: Self, other: Secpk1Fr) -> Secpk1Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpk1_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpk1_fr); - Secpk1Fr { array: a.bigint_mul(b).to_le_bytes() } - } -} -impl Div for Secpk1Fr { - fn div(self: Self, other: Secpk1Fr) -> Secpk1Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpk1_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpk1_fr); - Secpk1Fr { array: a.bigint_div(b).to_le_bytes() } - } -} -impl Eq for Secpk1Fr { - fn eq(self: Self, other: Secpk1Fr) -> bool { - self.array == other.array - } -} - -pub struct Bn254Fr { - array: [u8; 32], -} - -impl 
BigField for Bn254Fr { - fn from_le_bytes(bytes: [u8]) -> Bn254Fr { - assert(bytes.len() <= 32); - let mut array = [0; 32]; - for i in 0..bytes.len() { - array[i] = bytes[i]; - } - Bn254Fr { array } - } - - fn from_le_bytes_32(bytes: [u8; 32]) -> Bn254Fr { - Bn254Fr { array: bytes } - } - - fn to_le_bytes(self) -> [u8] { - self.array - } -} - -impl Add for Bn254Fr { - fn add(self: Self, other: Bn254Fr) -> Bn254Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), bn254_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), bn254_fr); - Bn254Fr { array: a.bigint_add(b).to_le_bytes() } - } -} -impl Sub for Bn254Fr { - fn sub(self: Self, other: Bn254Fr) -> Bn254Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), bn254_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), bn254_fr); - Bn254Fr { array: a.bigint_sub(b).to_le_bytes() } - } -} -impl Mul for Bn254Fr { - fn mul(self: Self, other: Bn254Fr) -> Bn254Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), bn254_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), bn254_fr); - Bn254Fr { array: a.bigint_mul(b).to_le_bytes() } - } -} -impl Div for Bn254Fr { - fn div(self: Self, other: Bn254Fr) -> Bn254Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), bn254_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), bn254_fr); - Bn254Fr { array: a.bigint_div(b).to_le_bytes() } - } -} -impl Eq for Bn254Fr { - fn eq(self: Self, other: Bn254Fr) -> bool { - self.array == other.array - } -} - -pub struct Bn254Fq { - array: [u8; 32], -} - -impl BigField for Bn254Fq { - fn from_le_bytes(bytes: [u8]) -> Bn254Fq { - assert(bytes.len() <= 32); - let mut array = [0; 32]; - for i in 0..bytes.len() { - array[i] = bytes[i]; - } - Bn254Fq { array } - } - - fn from_le_bytes_32(bytes: [u8; 32]) -> Bn254Fq { - Bn254Fq { array: bytes } - } - - fn to_le_bytes(self) -> [u8] { - self.array - } -} - -impl Add for Bn254Fq { - fn add(self: Self, other: Bn254Fq) -> Bn254Fq { - let a = 
BigInt::from_le_bytes(self.array.as_slice(), bn254_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), bn254_fq); - Bn254Fq { array: a.bigint_add(b).to_le_bytes() } - } -} -impl Sub for Bn254Fq { - fn sub(self: Self, other: Bn254Fq) -> Bn254Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), bn254_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), bn254_fq); - Bn254Fq { array: a.bigint_sub(b).to_le_bytes() } - } -} -impl Mul for Bn254Fq { - fn mul(self: Self, other: Bn254Fq) -> Bn254Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), bn254_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), bn254_fq); - Bn254Fq { array: a.bigint_mul(b).to_le_bytes() } - } -} -impl Div for Bn254Fq { - fn div(self: Self, other: Bn254Fq) -> Bn254Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), bn254_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), bn254_fq); - Bn254Fq { array: a.bigint_div(b).to_le_bytes() } - } -} -impl Eq for Bn254Fq { - fn eq(self: Self, other: Bn254Fq) -> bool { - self.array == other.array - } -} - -pub struct Secpr1Fq { - array: [u8; 32], -} - -impl BigField for Secpr1Fq { - fn from_le_bytes(bytes: [u8]) -> Secpr1Fq { - assert(bytes.len() <= 32); - let mut array = [0; 32]; - for i in 0..bytes.len() { - array[i] = bytes[i]; - } - Secpr1Fq { array } - } - - fn from_le_bytes_32(bytes: [u8; 32]) -> Secpr1Fq { - Secpr1Fq { array: bytes } - } - - fn to_le_bytes(self) -> [u8] { - self.array - } -} - -impl Add for Secpr1Fq { - fn add(self: Self, other: Secpr1Fq) -> Secpr1Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpr1_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpr1_fq); - Secpr1Fq { array: a.bigint_add(b).to_le_bytes() } - } -} -impl Sub for Secpr1Fq { - fn sub(self: Self, other: Secpr1Fq) -> Secpr1Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpr1_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpr1_fq); - Secpr1Fq { array: 
a.bigint_sub(b).to_le_bytes() } - } -} -impl Mul for Secpr1Fq { - fn mul(self: Self, other: Secpr1Fq) -> Secpr1Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpr1_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpr1_fq); - Secpr1Fq { array: a.bigint_mul(b).to_le_bytes() } - } -} -impl Div for Secpr1Fq { - fn div(self: Self, other: Secpr1Fq) -> Secpr1Fq { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpr1_fq); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpr1_fq); - Secpr1Fq { array: a.bigint_div(b).to_le_bytes() } - } -} -impl Eq for Secpr1Fq { - fn eq(self: Self, other: Secpr1Fq) -> bool { - self.array == other.array - } -} - -pub struct Secpr1Fr { - array: [u8; 32], -} - -impl BigField for Secpr1Fr { - fn from_le_bytes(bytes: [u8]) -> Secpr1Fr { - assert(bytes.len() <= 32); - let mut array = [0; 32]; - for i in 0..bytes.len() { - array[i] = bytes[i]; - } - Secpr1Fr { array } - } - - fn from_le_bytes_32(bytes: [u8; 32]) -> Secpr1Fr { - Secpr1Fr { array: bytes } - } - - fn to_le_bytes(self) -> [u8] { - self.array - } -} - -impl Add for Secpr1Fr { - fn add(self: Self, other: Secpr1Fr) -> Secpr1Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpr1_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpr1_fr); - Secpr1Fr { array: a.bigint_add(b).to_le_bytes() } - } -} -impl Sub for Secpr1Fr { - fn sub(self: Self, other: Secpr1Fr) -> Secpr1Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpr1_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpr1_fr); - Secpr1Fr { array: a.bigint_sub(b).to_le_bytes() } - } -} -impl Mul for Secpr1Fr { - fn mul(self: Self, other: Secpr1Fr) -> Secpr1Fr { - let a = BigInt::from_le_bytes(self.array.as_slice(), secpr1_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpr1_fr); - Secpr1Fr { array: a.bigint_mul(b).to_le_bytes() } - } -} -impl Div for Secpr1Fr { - fn div(self: Self, other: Secpr1Fr) -> Secpr1Fr { - let a = 
BigInt::from_le_bytes(self.array.as_slice(), secpr1_fr); - let b = BigInt::from_le_bytes(other.array.as_slice(), secpr1_fr); - Secpr1Fr { array: a.bigint_div(b).to_le_bytes() } - } -} -impl Eq for Secpr1Fr { - fn eq(self: Self, other: Secpr1Fr) -> bool { - self.array == other.array - } -} diff --git a/noir/noir-repo/noir_stdlib/src/lib.nr b/noir/noir-repo/noir_stdlib/src/lib.nr index 4074e5e6920f..d5c360792d91 100644 --- a/noir/noir-repo/noir_stdlib/src/lib.nr +++ b/noir/noir-repo/noir_stdlib/src/lib.nr @@ -20,7 +20,6 @@ pub mod ops; pub mod default; pub mod prelude; pub mod uint128; -pub mod bigint; pub mod runtime; pub mod meta; pub mod append; diff --git a/noir/noir-repo/noir_stdlib/src/merkle.nr b/noir/noir-repo/noir_stdlib/src/merkle.nr index f6806874444a..34cfcdb17877 100644 --- a/noir/noir-repo/noir_stdlib/src/merkle.nr +++ b/noir/noir-repo/noir_stdlib/src/merkle.nr @@ -2,6 +2,7 @@ // Currently we assume that it is a binary tree, so depth k implies a width of 2^k // XXX: In the future we can add an arity parameter // Returns the merkle root of the tree from the provided leaf, its hashpath, using a pedersen hash function. 
+#[deprecated("This function will be removed from the stdlib in version 1.0.0-beta.4")] pub fn compute_merkle_root(leaf: Field, index: Field, hash_path: [Field; N]) -> Field { let index_bits: [u1; N] = index.to_le_bits(); let mut current = leaf; diff --git a/noir/noir-repo/scripts/bump-aztec-packges-commit.sh b/noir/noir-repo/scripts/bump-aztec-packges-commit.sh new file mode 100755 index 000000000000..f3a9f5652d26 --- /dev/null +++ b/noir/noir-repo/scripts/bump-aztec-packges-commit.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +AZTEC_COMMIT=$(git ls-remote https://github.com/AztecProtocol/aztec-packages.git HEAD | grep -oE '^\b[0-9a-f]{40}\b') + +function bump_commit() { + FILE=$1 + AZTEC_COMMIT=$AZTEC_COMMIT yq -i '.define = env(AZTEC_COMMIT)' $FILE + +} + +bump_commit ./EXTERNAL_NOIR_LIBRARIES.yml +bump_commit ./.github/benchmark_projects.yml diff --git a/noir/noir-repo/test_programs/compile_success_empty/enums/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/enums/src/main.nr index 03a64d57dcf6..03fc10a1c35a 100644 --- a/noir/noir-repo/test_programs/compile_success_empty/enums/src/main.nr +++ b/noir/noir-repo/test_programs/compile_success_empty/enums/src/main.nr @@ -1,4 +1,35 @@ fn main() { + primitive_tests(); + foo_tests(); + option_tests(); + abc_tests(); +} + +fn primitive_tests() { + let x: i32 = -2; + match x { + -3 => fail(), + -2 => (), + 0 => fail(), + 2 => fail(), + _ => fail(), + } + + match true { + false => fail(), + true => (), + } +} + +enum Foo { + A(Field, Field), + B(u32), + C(T), + D(), + E, +} + +fn foo_tests() { let _a = Foo::A::(1, 2); let _b: Foo = Foo::B(3); let _c = Foo::C(4); @@ -10,12 +41,76 @@ fn main() { // Enum variants are functions and can be passed around as such let _many_cs = [1, 2, 3].map(Foo::C); + + match _b { + Foo::C(_) => fail(), + Foo::B(x) => { assert_eq(x, 3); }, + _ => fail(), + } + + match _c { + Foo::A(1, _) => fail(), + Foo::E => fail(), + Foo::C(4) => (), + Foo::C(_) => fail(), + _ => fail(), + } } -enum Foo 
{ - A(Field, Field), - B(u32), - C(T), - D(), +fn fail() { + assert(false); +} + +enum MyOption { + None, + Maybe, + Some(T), +} + +fn option_tests() { + let opt = MyOption::Some(ABC::C); + match opt { + MyOption::Some(ABC::D) => fail(), + MyOption::Some(x) => { assert_eq(x, ABC::C); }, + _ => (), + } +} + +enum ABC { + A, + B, + C, + D, E, + F, +} + +impl Eq for ABC { + fn eq(self, other: ABC) -> bool { + match (self, other) { + (ABC::A, ABC::A) => true, + (ABC::B, ABC::B) => true, + (ABC::C, ABC::C) => true, + (ABC::D, ABC::D) => true, + (ABC::E, ABC::E) => true, + (ABC::F, ABC::F) => true, + _ => false, + } + } +} + +fn abc_tests() { + // Mut is only to throw the optimizer off a bit so we can see + // the `eq`s that get generated before they're removed because each of these are constant + let mut tuple = (ABC::A, ABC::B); + match tuple { + (ABC::A, _) => 1, + (_, ABC::A) => 2, + (_, ABC::B) => 3, + (_, ABC::C) => 4, + (_, ABC::D) => 5, + (ABC::B, ABC::E) => 6, + (ABC::C, ABC::F) => 7, + _ => 0, + }; } diff --git a/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/Nargo.toml b/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/Nargo.toml new file mode 100644 index 000000000000..8651c5577a25 --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "unquote_trait_constraint_in_trait_impl_position" +version = "0.1.0" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/Prover.toml b/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/Prover.toml new file mode 100644 index 000000000000..9bff601c75a3 --- /dev/null +++ 
b/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/Prover.toml @@ -0,0 +1,3 @@ +x = "3" +y = "4" +z = "429981696" diff --git a/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/src/main.nr new file mode 100644 index 000000000000..92ddf99f1c0b --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/unquote_trait_constraint_in_trait_impl_position/src/main.nr @@ -0,0 +1,24 @@ +struct Struct {} + +trait Trait { + fn method(self) -> T; +} + +fn main() { + let st = Struct {}; + assert_eq(st.method(), 1); +} + +#[foo] +comptime fn foo(_: FunctionDefinition) -> Quoted { + let tr = quote { Trait }.as_trait_constraint(); + let st = quote { Struct }.as_type(); + quote { + impl $tr for $st { + fn method(self) -> Field { + 1 + } + } + } +} + diff --git a/noir/noir-repo/test_programs/execution_failure/bigint_from_too_many_le_bytes/Nargo.toml b/noir/noir-repo/test_programs/execution_failure/bigint_from_too_many_le_bytes/Nargo.toml deleted file mode 100644 index cbdfc2d83d91..000000000000 --- a/noir/noir-repo/test_programs/execution_failure/bigint_from_too_many_le_bytes/Nargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "bigint_from_too_many_le_bytes" -type = "bin" -authors = [""] -compiler_version = ">=0.31.0" - -[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/execution_failure/bigint_from_too_many_le_bytes/src/main.nr b/noir/noir-repo/test_programs/execution_failure/bigint_from_too_many_le_bytes/src/main.nr deleted file mode 100644 index 2d4587ee3d9b..000000000000 --- a/noir/noir-repo/test_programs/execution_failure/bigint_from_too_many_le_bytes/src/main.nr +++ /dev/null @@ -1,22 +0,0 @@ -use std::bigint::{bn254_fq, BigInt}; - -// TODO(https://github.com/noir-lang/noir/issues/5580): decide whether this is 
desired behavior -// -// Fails at execution time: -// -// error: Assertion failed: 'Index out of bounds' -// ┌─ std/cmp.nr:35:34 -// │ -// 35 │ result &= self[i].eq(other[i]); -// │ -------- -// │ -// = Call stack: -// 1. /Users/michaelklein/Coding/rust/noir/test_programs/compile_failure/bigint_from_too_many_le_bytes/src/main.nr:7:12 -// 2. std/cmp.nr:35:34 -// Failed assertion -fn main() { - let bytes: [u8] = bn254_fq.push_front(0x00); - let bigint = BigInt::from_le_bytes(bytes, bn254_fq); - let result_bytes = bigint.to_le_bytes(); - assert(bytes == result_bytes.as_slice()); -} diff --git a/noir/noir-repo/test_programs/execution_success/bigint/Prover.toml b/noir/noir-repo/test_programs/execution_success/bigint/Prover.toml deleted file mode 100644 index c50874a86134..000000000000 --- a/noir/noir-repo/test_programs/execution_success/bigint/Prover.toml +++ /dev/null @@ -1,2 +0,0 @@ -x = [34,3,5,8,4] -y = [44,7,1,8,8] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/execution_success/bigint/src/main.nr b/noir/noir-repo/test_programs/execution_success/bigint/src/main.nr deleted file mode 100644 index 2ccb446b8b4f..000000000000 --- a/noir/noir-repo/test_programs/execution_success/bigint/src/main.nr +++ /dev/null @@ -1,83 +0,0 @@ -use std::bigint; -use std::{bigint::Secpk1Fq, println}; - -fn main(mut x: [u8; 5], y: [u8; 5]) { - let a = bigint::Secpk1Fq::from_le_bytes(&[x[0], x[1], x[2], x[3], x[4]]); - let b = bigint::Secpk1Fq::from_le_bytes(&[y[0], y[1], y[2], y[3], y[4]]); - let mut a_be_bytes = [0; 32]; - let mut b_be_bytes = [0; 32]; - for i in 0..5 { - a_be_bytes[31 - i] = x[i]; - b_be_bytes[31 - i] = y[i]; - } - let a_field = std::field::bytes32_to_field(a_be_bytes); - let b_field = std::field::bytes32_to_field(b_be_bytes); - - // Regression for issue #4682 - let c = if x[0] != 0 { - test_unconstrained1(a, b) - } else { - // Safety: testing context - unsafe { - test_unconstrained2(a, b) - } - }; - assert(c.array[0] == std::wrapping_mul(x[0], 
y[0])); - - let a_bytes = a.to_le_bytes(); - let b_bytes = b.to_le_bytes(); - for i in 0..5 { - assert(a_bytes[i] == x[i]); - assert(b_bytes[i] == y[i]); - } - // Regression for issue #4578 - let d = a * b; - assert(d / b == a); - - let d = d - b; - let mut result = [0; 32]; - let result_slice: [u8; 32] = (a_field * b_field - b_field).to_le_bytes(); - for i in 0..32 { - result[i] = result_slice[i]; - } - let d1 = bigint::Secpk1Fq::from_le_bytes_32(result); - assert(d1 == d); - big_int_example(x[0], x[1]); - - // Regression for issue #4882 - let num_b: [u8; 32] = [ - 0, 0, 0, 240, 147, 245, 225, 67, 145, 112, 185, 121, 72, 232, 51, 40, 93, 88, 129, 129, 182, - 69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48, - ]; - let num2_b: [u8; 7] = [126, 193, 45, 39, 188, 84, 11]; - let num = bigint::Bn254Fr::from_le_bytes(num_b.as_slice()); - let num2 = bigint::Bn254Fr::from_le_bytes(num2_b.as_slice()); - - let ret_b: [u8; 32] = [ - 131, 62, 210, 200, 215, 160, 214, 67, 145, 112, 185, 121, 72, 232, 51, 40, 93, 88, 129, 129, - 182, 69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48, - ]; - let ret = bigint::Bn254Fr::from_le_bytes(ret_b.as_slice()); - assert(ret == num.mul(num2)); - let div = num.div(num2); - assert(div.mul(num2) == num); -} - -fn test_unconstrained1(a: Secpk1Fq, b: Secpk1Fq) -> Secpk1Fq { - let c = a * b; - c -} -unconstrained fn test_unconstrained2(a: Secpk1Fq, b: Secpk1Fq) -> Secpk1Fq { - let c = a + b; - test_unconstrained1(a, c) -} - -// docs:start:big_int_example -fn big_int_example(x: u8, y: u8) { - let a = Secpk1Fq::from_le_bytes(&[x, y, 0, 45, 2]); - let b = Secpk1Fq::from_le_bytes(&[y, x, 9]); - let c = (a + b) * b / a; - let d = c.to_le_bytes(); - println(d[0]); -} -// docs:end:big_int_example diff --git a/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/Nargo.toml b/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/Nargo.toml new file mode 100644 index 
000000000000..2327a6e10970 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "global_var_entry_point_used_in_another_entry" +type = "bin" +authors = [""] + +[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/Prover.toml b/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/Prover.toml new file mode 100644 index 000000000000..4c144083f005 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/Prover.toml @@ -0,0 +1,2 @@ +x = 0 +y = 1 diff --git a/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/src/main.nr b/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/src/main.nr new file mode 100644 index 000000000000..364c27cc4ab7 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_entry_point_used_in_another_entry/src/main.nr @@ -0,0 +1,44 @@ +global ONE: Field = 1; +global TWO: Field = 2; +global THREE: Field = 3; + +fn main(x: Field, y: pub Field) { + // Safety: testing context + unsafe { + entry_point_no_global(x, y); + entry_point_inner_func_globals(x, y); + entry_point_one_global(x, y); + entry_point_one_diff_global(x, y); + } +} + +unconstrained fn entry_point_no_global(x: Field, y: Field) { + assert(x + y != 100); +} + +unconstrained fn entry_point_one_global(x: Field, y: Field) { + let z = TWO + x + y; + assert(z == 3); +} + +unconstrained fn entry_point_inner_func_globals(x: Field, y: Field) { + wrapper(x, y); +} + +// Test that we duplicate Brillig entry points called within +// another entry point's inner calls +unconstrained fn wrapper(x: Field, y: Field) { + let z = ONE + x + y; + assert(z == 2); + entry_point_one_global(x, y); + // Test that we handle 
repeated entry point calls + // `entry_point_one_diff_global` should be duplicated and the duplicated function + // should use the globals from `entry_point_inner_func_globals` + entry_point_one_diff_global(y, x); +} + +unconstrained fn entry_point_one_diff_global(x: Field, y: Field) { + let z = THREE + x + y; + assert(z == 4); +} + diff --git a/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/Nargo.toml b/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/Nargo.toml new file mode 100644 index 000000000000..4d2bd864e8bb --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "global_var_func_with_multiple_entry_points" +type = "bin" +authors = [""] + +[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/Prover.toml b/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/Prover.toml new file mode 100644 index 000000000000..4c144083f005 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/Prover.toml @@ -0,0 +1,2 @@ +x = 0 +y = 1 diff --git a/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/src/main.nr b/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/src/main.nr new file mode 100644 index 000000000000..6ae68e21e793 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_func_with_multiple_entry_points/src/main.nr @@ -0,0 +1,28 @@ +global ONE: Field = 1; +global TWO: Field = 2; +global THREE: Field = 3; + +fn main(x: Field, y: pub Field) { + // Safety: testing context + unsafe { + entry_point_one(x, y); + entry_point_two(x, y); + } +} + +unconstrained fn entry_point_one(x: Field, y: Field) { + let z = ONE 
+ x + y; + assert(z == 2); + inner_func(x, y); +} + +unconstrained fn entry_point_two(x: Field, y: Field) { + let z = TWO + x + y; + assert(z == 3); + inner_func(x, y); +} + +unconstrained fn inner_func(x: Field, y: Field) { + let z = THREE + x + y; + assert(z == 4); +} diff --git a/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/Nargo.toml b/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/Nargo.toml new file mode 100644 index 000000000000..e7ef4a6f36d5 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "global_var_multiple_entry_points_nested" +type = "bin" +authors = [""] + +[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/Prover.toml b/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/Prover.toml new file mode 100644 index 000000000000..4c144083f005 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/Prover.toml @@ -0,0 +1,2 @@ +x = 0 +y = 1 diff --git a/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/src/main.nr b/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/src/main.nr new file mode 100644 index 000000000000..57e6b872ac32 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/global_var_multiple_entry_points_nested/src/main.nr @@ -0,0 +1,34 @@ +global TWO: Field = 2; +global THREE: Field = 3; + +fn main(x: Field, y: pub Field) { + // Safety: testing context + unsafe { + entry_point_one(x, y); + entry_point_two(x, y); + } +} + +unconstrained fn entry_point_one(x: Field, y: Field) { + let z = TWO + x + y; + assert(z == 3); + inner_func(x, y); +} + +// Identical to `entry_point_one` +unconstrained fn 
entry_point_two(x: Field, y: Field) { + let z = TWO + x + y; + assert(z == 3); + inner_func(x, y); +} + +unconstrained fn inner_func(x: Field, y: Field) { + let z = TWO + x + y; + assert(z == 3); + nested_inner_func(x, y); +} + +unconstrained fn nested_inner_func(x: Field, y: Field) { + let z = THREE + x + y; + assert(z == 4); +} diff --git a/noir/noir-repo/test_programs/execution_success/global_var_regression_entry_points/src/consts.nr b/noir/noir-repo/test_programs/execution_success/global_var_regression_entry_points/src/consts.nr index 7ad6a4a54d13..43b580cf7817 100644 --- a/noir/noir-repo/test_programs/execution_success/global_var_regression_entry_points/src/consts.nr +++ b/noir/noir-repo/test_programs/execution_success/global_var_regression_entry_points/src/consts.nr @@ -1,4 +1,4 @@ -global EXPONENTIATE: [[Field; 257]; 257] = [ +pub(crate) global EXPONENTIATE: [[Field; 257]; 257] = [ [ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/noir/noir-repo/test_programs/execution_success/global_var_regression_simple/src/main.nr b/noir/noir-repo/test_programs/execution_success/global_var_regression_simple/src/main.nr index b1bf753a73c2..8af1879b0425 100644 --- a/noir/noir-repo/test_programs/execution_success/global_var_regression_simple/src/main.nr +++ b/noir/noir-repo/test_programs/execution_success/global_var_regression_simple/src/main.nr @@ -20,6 +20,6 @@ fn dummy_again(x: Field, y: Field) { acc += EXPONENTIATE[i][j]; } } - assert(!acc.lt(x)); assert(x != y); + assert(!acc.lt(x)); } diff --git a/noir/noir-repo/test_programs/execution_success/loop_invariant_regression/src/main.nr b/noir/noir-repo/test_programs/execution_success/loop_invariant_regression/src/main.nr index 61f8b1bedba4..d126b7729c12 100644 --- a/noir/noir-repo/test_programs/execution_success/loop_invariant_regression/src/main.nr +++ 
b/noir/noir-repo/test_programs/execution_success/loop_invariant_regression/src/main.nr @@ -1,11 +1,14 @@ // Tests a simple loop where we expect loop invariant instructions // to be hoisted to the loop's pre-header block. +global U32_MAX: u32 = 4294967295; + fn main(x: u32, y: u32) { - loop_(4, x, y); + simple_loop(4, x, y); + loop_with_predicate(4, x, y); array_read_loop(4, x); } -fn loop_(upper_bound: u32, x: u32, y: u32) { +fn simple_loop(upper_bound: u32, x: u32, y: u32) { for _ in 0..upper_bound { let mut z = x * y; z = z * x; @@ -13,6 +16,15 @@ fn loop_(upper_bound: u32, x: u32, y: u32) { } } +fn loop_with_predicate(upper_bound: u32, x: u32, y: u32) { + for _ in 0..upper_bound { + if x == 5 { + let mut z = U32_MAX * y; + assert_eq(z, 12); + } + } +} + fn array_read_loop(upper_bound: u32, x: u32) { let arr = [2; 5]; for i in 0..upper_bound { diff --git a/noir/noir-repo/test_programs/execution_success/bigint/Nargo.toml b/noir/noir-repo/test_programs/execution_success/while_keyword/Nargo.toml similarity index 69% rename from noir/noir-repo/test_programs/execution_success/bigint/Nargo.toml rename to noir/noir-repo/test_programs/execution_success/while_keyword/Nargo.toml index eee0920f1884..9331fbb4c346 100644 --- a/noir/noir-repo/test_programs/execution_success/bigint/Nargo.toml +++ b/noir/noir-repo/test_programs/execution_success/while_keyword/Nargo.toml @@ -1,6 +1,5 @@ [package] -name = "bigint" +name = "while_keyword" type = "bin" authors = [""] - [dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/while_keyword/src/main.nr b/noir/noir-repo/test_programs/execution_success/while_keyword/src/main.nr new file mode 100644 index 000000000000..92c2cc4f0e0b --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/while_keyword/src/main.nr @@ -0,0 +1,44 @@ +fn main() { + // Safety: test code + unsafe { + check_while(); + } + + check_comptime_while(); +} + +unconstrained fn check_while() { + let mut i = 0; + let mut sum = 0; + + while i 
< 4 { + if i == 2 { + i += 1; + continue; + } + + sum += i; + i += 1; + } + + assert_eq(sum, 1 + 3); +} + +fn check_comptime_while() { + comptime { + let mut i = 0; + let mut sum = 0; + + while i < 4 { + if i == 2 { + i += 1; + continue; + } + + sum += i; + i += 1; + } + + assert_eq(sum, 1 + 3); + } +} diff --git a/noir/noir-repo/test_programs/noir_test_success/comptime_blackbox/src/main.nr b/noir/noir-repo/test_programs/noir_test_success/comptime_blackbox/src/main.nr index c5ca59c7afce..446692c485bb 100644 --- a/noir/noir-repo/test_programs/noir_test_success/comptime_blackbox/src/main.nr +++ b/noir/noir-repo/test_programs/noir_test_success/comptime_blackbox/src/main.nr @@ -1,33 +1,6 @@ //! Tests to show that the comptime interpreter implement blackbox functions. -use std::bigint; use std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul}; -/// Test that all bigint operations work in comptime. -#[test] -fn test_bigint() { - let result: [u8] = comptime { - let a = bigint::Secpk1Fq::from_le_bytes(&[0, 1, 2, 3, 4]); - let b = bigint::Secpk1Fq::from_le_bytes(&[5, 6, 7, 8, 9]); - let c = (a + b) * b / a - a; - c.to_le_bytes() - }; - // Do the same calculation outside comptime. - let a = bigint::Secpk1Fq::from_le_bytes(&[0, 1, 2, 3, 4]); - let b = bigint::Secpk1Fq::from_le_bytes(&[5, 6, 7, 8, 9]); - let c = bigint::Secpk1Fq::from_le_bytes(result); - assert_eq(c, (a + b) * b / a - a); -} - -/// Test that to_le_radix returns an array. 
-#[test] -fn test_to_le_radix() { - comptime { - let field = 2; - let bytes: [u8; 8] = field.to_le_radix(256); - let _num = bigint::BigInt::from_le_bytes(bytes, bigint::bn254_fq); - }; -} - #[test] fn test_bitshift() { let c = comptime { diff --git a/noir/noir-repo/tooling/lsp/Cargo.toml b/noir/noir-repo/tooling/lsp/Cargo.toml index 65b59552b32f..a055a9a6bcec 100644 --- a/noir/noir-repo/tooling/lsp/Cargo.toml +++ b/noir/noir-repo/tooling/lsp/Cargo.toml @@ -32,6 +32,7 @@ fm.workspace = true rayon.workspace = true fxhash.workspace = true convert_case = "0.6.0" +num-bigint.workspace = true [target.'cfg(all(target_arch = "wasm32", not(target_os = "wasi")))'.dependencies] wasm-bindgen.workspace = true diff --git a/noir/noir-repo/tooling/lsp/src/modules.rs b/noir/noir-repo/tooling/lsp/src/modules.rs index 758322fa4bcc..56529949b3ea 100644 --- a/noir/noir-repo/tooling/lsp/src/modules.rs +++ b/noir/noir-repo/tooling/lsp/src/modules.rs @@ -1,9 +1,14 @@ +use std::collections::BTreeMap; + use noirc_frontend::{ + ast::{Ident, ItemVisibility}, graph::{CrateId, Dependency}, - hir::def_map::{ModuleDefId, ModuleId}, - node_interner::{NodeInterner, ReferenceId}, + hir::def_map::{CrateDefMap, ModuleDefId, ModuleId}, + node_interner::{NodeInterner, Reexport, ReferenceId}, }; +use crate::visibility::module_def_id_is_visible; + pub(crate) fn get_parent_module( interner: &NodeInterner, module_def_id: ModuleDefId, @@ -36,7 +41,7 @@ pub(crate) fn relative_module_full_path( if let ModuleDefId::ModuleId(module_id) = module_def_id { full_path = relative_module_id_path( module_id, - ¤t_module_id, + current_module_id, current_module_parent_id, interner, ); @@ -45,7 +50,7 @@ pub(crate) fn relative_module_full_path( full_path = relative_module_id_path( parent_module, - ¤t_module_id, + current_module_id, current_module_parent_id, interner, ); @@ -57,7 +62,7 @@ pub(crate) fn relative_module_full_path( /// Returns a relative path if possible. 
pub(crate) fn relative_module_id_path( target_module_id: ModuleId, - current_module_id: &ModuleId, + current_module_id: ModuleId, current_module_parent_id: Option, interner: &NodeInterner, ) -> String { @@ -80,7 +85,7 @@ pub(crate) fn relative_module_id_path( let parent_module_id = &ModuleId { krate: target_module_id.krate, local_id: parent_local_id }; - if current_module_id == parent_module_id { + if current_module_id == *parent_module_id { is_relative = true; break; } @@ -160,3 +165,109 @@ pub(crate) fn module_full_path( segments.reverse(); segments.join("::") } + +/// Finds a visible reexport for any ancestor module of the given ModuleDefId, +pub(crate) fn get_ancestor_module_reexport( + module_def_id: ModuleDefId, + visibility: ItemVisibility, + current_module_id: ModuleId, + interner: &NodeInterner, + def_maps: &BTreeMap, + dependencies: &[Dependency], +) -> Option { + let parent_module = get_parent_module(interner, module_def_id)?; + let reexport = + interner.get_reexports(ModuleDefId::ModuleId(parent_module)).iter().find(|reexport| { + module_def_id_is_visible( + ModuleDefId::ModuleId(reexport.module_id), + current_module_id, + reexport.visibility, + None, + interner, + def_maps, + dependencies, + ) + }); + if let Some(reexport) = reexport { + return Some(reexport.clone()); + } + + // Try searching in the parent's parent module. + let mut grandparent_module_reexport = get_ancestor_module_reexport( + ModuleDefId::ModuleId(parent_module), + visibility, + current_module_id, + interner, + def_maps, + dependencies, + )?; + + // If we can find one, we need to check if ModuleDefId is actually visible from the grandparent module + if !module_def_id_is_visible( + module_def_id, + current_module_id, + visibility, + Some(grandparent_module_reexport.module_id), + interner, + def_maps, + dependencies, + ) { + return None; + } + + // If we can find one we need to adjust the exported name a bit. 
+ let parent_module_name = &interner.try_module_attributes(&parent_module)?.name; + grandparent_module_reexport.name.0.contents = + format!("{}::{}", grandparent_module_reexport.name.0.contents, parent_module_name); + + Some(grandparent_module_reexport) +} + +/// Returns the relative path to reach `module_def_id` named `name` starting from `current_module_id`. +/// +/// - `defining_module` might be `Some` if the item is reexported from another module +/// - `intermediate_name` might be `Some` if the item's parent module is reexport from another module +/// (this will be the name of the reexport) +/// +/// Returns `None` if `module_def_id` isn't visible from the current module, neither directly, neither via +/// any of its reexports (or parent module reexports). +pub(crate) fn module_def_id_relative_path( + module_def_id: ModuleDefId, + name: &str, + current_module_id: ModuleId, + current_module_parent_id: Option, + defining_module: Option, + intermediate_name: &Option, + interner: &NodeInterner, +) -> Option { + let module_path = if let Some(defining_module) = defining_module { + relative_module_id_path( + defining_module, + current_module_id, + current_module_parent_id, + interner, + ) + } else { + let Some(module_full_path) = relative_module_full_path( + module_def_id, + current_module_id, + current_module_parent_id, + interner, + ) else { + return None; + }; + module_full_path + }; + + let path = if defining_module.is_some() || !matches!(module_def_id, ModuleDefId::ModuleId(..)) { + if let Some(reexport_name) = &intermediate_name { + format!("{}::{}::{}", module_path, reexport_name, name) + } else { + format!("{}::{}", module_path, name) + } + } else { + module_path.clone() + }; + + Some(path) +} diff --git a/noir/noir-repo/tooling/lsp/src/requests/code_action.rs b/noir/noir-repo/tooling/lsp/src/requests/code_action.rs index 24ed327393de..d5512855b3be 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/code_action.rs +++ 
b/noir/noir-repo/tooling/lsp/src/requests/code_action.rs @@ -16,9 +16,9 @@ use noirc_frontend::{ CallExpression, ConstructorExpression, ItemVisibility, MethodCallExpression, NoirTraitImpl, Path, UseTree, Visitor, }, - graph::CrateId, - hir::def_map::{CrateDefMap, LocalModuleId, ModuleId}, - node_interner::NodeInterner, + graph::{CrateId, Dependency}, + hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId, ModuleId}, + node_interner::{NodeInterner, Reexport}, usage_tracker::UsageTracker, }; use noirc_frontend::{ @@ -26,7 +26,10 @@ use noirc_frontend::{ ParsedModule, }; -use crate::{use_segment_positions::UseSegmentPositions, utils, LspState}; +use crate::{ + modules::get_ancestor_module_reexport, use_segment_positions::UseSegmentPositions, utils, + visibility::module_def_id_is_visible, LspState, +}; use super::{process_request, to_lsp_location}; @@ -63,6 +66,7 @@ pub(crate) fn on_code_action_request( byte_range, args.crate_id, args.def_maps, + args.dependencies, args.interner, args.usage_tracker, ); @@ -84,6 +88,7 @@ struct CodeActionFinder<'a> { /// if we are analyzing something inside an inline module declaration. 
module_id: ModuleId, def_maps: &'a BTreeMap, + dependencies: &'a Vec, interner: &'a NodeInterner, usage_tracker: &'a UsageTracker, /// How many nested `mod` we are in deep @@ -106,6 +111,7 @@ impl<'a> CodeActionFinder<'a> { byte_range: Range, krate: CrateId, def_maps: &'a BTreeMap, + dependencies: &'a Vec, interner: &'a NodeInterner, usage_tracker: &'a UsageTracker, ) -> Self { @@ -128,6 +134,7 @@ impl<'a> CodeActionFinder<'a> { byte_range, module_id, def_maps, + dependencies, interner, usage_tracker, nesting: 0, @@ -188,6 +195,38 @@ impl<'a> CodeActionFinder<'a> { } } + fn module_def_id_is_visible( + &self, + module_def_id: ModuleDefId, + visibility: ItemVisibility, + defining_module: Option, + ) -> bool { + module_def_id_is_visible( + module_def_id, + self.module_id, + visibility, + defining_module, + self.interner, + self.def_maps, + self.dependencies, + ) + } + + fn get_ancestor_module_reexport( + &self, + module_def_id: ModuleDefId, + visibility: ItemVisibility, + ) -> Option { + get_ancestor_module_reexport( + module_def_id, + visibility, + self.module_id, + self.interner, + self.def_maps, + self.dependencies, + ) + } + fn includes_span(&self, span: Span) -> bool { let byte_range_span = Span::from(self.byte_range.start as u32..self.byte_range.end as u32); span.intersects(&byte_range_span) diff --git a/noir/noir-repo/tooling/lsp/src/requests/code_action/implement_missing_members.rs b/noir/noir-repo/tooling/lsp/src/requests/code_action/implement_missing_members.rs index 1cd181966a23..c29caf79848a 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/code_action/implement_missing_members.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/code_action/implement_missing_members.rs @@ -21,7 +21,11 @@ impl<'a> CodeActionFinder<'a> { return; } - let location = Location::new(noir_trait_impl.trait_name.span(), self.file); + let UnresolvedTypeData::Named(trait_name, _, _) = &noir_trait_impl.r#trait.typ else { + return; + }; + + let location = 
Location::new(trait_name.span(), self.file); let Some(ReferenceId::Trait(trait_id)) = self.interner.find_referenced(location) else { return; }; diff --git a/noir/noir-repo/tooling/lsp/src/requests/code_action/import_or_qualify.rs b/noir/noir-repo/tooling/lsp/src/requests/code_action/import_or_qualify.rs index 609a81bdfe7b..1141aca23d2d 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/code_action/import_or_qualify.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/code_action/import_or_qualify.rs @@ -1,17 +1,13 @@ use lsp_types::TextEdit; use noirc_errors::Location; -use noirc_frontend::{ - ast::{Ident, Path}, - hir::def_map::ModuleDefId, -}; +use noirc_frontend::ast::{Ident, Path}; use crate::{ byte_span_to_range, - modules::{relative_module_full_path, relative_module_id_path}, + modules::module_def_id_relative_path, use_segment_positions::{ use_completion_item_additional_text_edits, UseCompletionItemAdditionTextEditsRequest, }, - visibility::module_def_id_is_visible, }; use super::CodeActionFinder; @@ -41,55 +37,41 @@ impl<'a> CodeActionFinder<'a> { continue; } - for (module_def_id, visibility, defining_module) in entries { - if !module_def_id_is_visible( - *module_def_id, + for entry in entries { + let module_def_id = entry.module_def_id; + let visibility = entry.visibility; + let mut defining_module = entry.defining_module.as_ref().cloned(); + + // If the item is offered via a re-export of it's parent module, this holds the name of the reexport. 
+ let mut intermediate_name = None; + + let is_visible = + self.module_def_id_is_visible(module_def_id, visibility, defining_module); + if !is_visible { + if let Some(reexport) = + self.get_ancestor_module_reexport(module_def_id, visibility) + { + defining_module = Some(reexport.module_id); + intermediate_name = Some(reexport.name); + } else { + continue; + } + } + + let Some(full_path) = module_def_id_relative_path( + module_def_id, + name, self.module_id, - *visibility, - *defining_module, + current_module_parent_id, + defining_module, + &intermediate_name, self.interner, - self.def_maps, - ) { + ) else { continue; - } - - let module_full_path = if let Some(defining_module) = defining_module { - relative_module_id_path( - *defining_module, - &self.module_id, - current_module_parent_id, - self.interner, - ) - } else { - let Some(module_full_path) = relative_module_full_path( - *module_def_id, - self.module_id, - current_module_parent_id, - self.interner, - ) else { - continue; - }; - module_full_path - }; - - let full_path = if defining_module.is_some() - || !matches!(module_def_id, ModuleDefId::ModuleId(..)) - { - format!("{}::{}", module_full_path, name) - } else { - module_full_path.clone() - }; - - let qualify_prefix = if let ModuleDefId::ModuleId(..) 
= module_def_id { - let mut segments: Vec<_> = module_full_path.split("::").collect(); - segments.pop(); - segments.join("::") - } else { - module_full_path }; self.push_import_code_action(&full_path); - self.push_qualify_code_action(ident, &qualify_prefix, &full_path); + self.push_qualify_code_action(ident, &full_path); } } } @@ -113,7 +95,7 @@ impl<'a> CodeActionFinder<'a> { self.code_actions.push(code_action); } - fn push_qualify_code_action(&mut self, ident: &Ident, prefix: &str, full_path: &str) { + fn push_qualify_code_action(&mut self, ident: &Ident, full_path: &str) { let Some(range) = byte_span_to_range( self.files, self.file, @@ -122,6 +104,10 @@ impl<'a> CodeActionFinder<'a> { return; }; + let mut prefix = full_path.split("::").collect::>(); + prefix.pop(); + let prefix = prefix.join("::"); + let title = format!("Qualify as {}", full_path); let text_edit = TextEdit { range, new_text: format!("{}::", prefix) }; @@ -299,4 +285,76 @@ fn foo(x: SomeTypeInBar) {}"#; assert_code_action(title, src, expected).await; } + + #[test] + async fn test_import_via_reexport() { + let title = "Import aztec::protocol_types::SomeStruct"; + + let src = r#"mod aztec { + mod deps { + pub mod protocol_types { + pub struct SomeStruct {} + } + } + + pub use deps::protocol_types; +} + +fn main() { + SomeStr>|| CodeActionFinder<'a> { let module_def_id = ModuleDefId::TraitId(trait_id); let mut trait_reexport = None; - if !module_def_id_is_visible( - module_def_id, - self.module_id, - visibility, - None, - self.interner, - self.def_maps, - ) { + // If the item is offered via a re-export of it's parent module, this holds the name of the reexport. 
+ let mut intermediate_name = None; + + if !self.module_def_id_is_visible(module_def_id, visibility, None) { // If it's not, try to find a visible reexport of the trait // that is visible from the current module - let Some((visible_module_id, name, _)) = - self.interner.get_trait_reexports(trait_id).iter().find( - |(module_id, _, visibility)| { - module_def_id_is_visible( - module_def_id, - self.module_id, - *visibility, - Some(*module_id), - self.interner, - self.def_maps, - ) - }, - ) - else { + if let Some(reexport) = + self.interner.get_trait_reexports(trait_id).iter().find(|reexport| { + self.module_def_id_is_visible( + module_def_id, + reexport.visibility, + Some(reexport.module_id), + ) + }) + { + trait_reexport = Some(TraitReexport { + module_id: reexport.module_id, + name: reexport.name.clone(), + }); + } else if let Some(reexport) = + self.get_ancestor_module_reexport(module_def_id, visibility) + { + trait_reexport = Some(TraitReexport { + module_id: reexport.module_id, + name: trait_.name.clone(), + }); + intermediate_name = Some(reexport.name.clone()); + } else { return; - }; - trait_reexport = Some(TraitReexport { module_id: visible_module_id, name }); + } } let trait_name = if let Some(trait_reexport) = &trait_reexport { - trait_reexport.name + trait_reexport.name.clone() } else { - &trait_.name + trait_.name.clone() }; // Check if the trait is currently imported. 
If yes, no need to suggest anything let module_data = &self.def_maps[&self.module_id.krate].modules()[self.module_id.local_id.0]; - if !module_data.scope().find_name(trait_name).is_none() { + if !module_data.scope().find_name(&trait_name).is_none() { return; } let module_def_id = ModuleDefId::TraitId(trait_id); let current_module_parent_id = self.module_id.parent(self.def_maps); - let module_full_path = if let Some(trait_reexport) = &trait_reexport { - relative_module_id_path( - *trait_reexport.module_id, - &self.module_id, - current_module_parent_id, - self.interner, - ) - } else { - let Some(path) = relative_module_full_path( - module_def_id, - self.module_id, - current_module_parent_id, - self.interner, - ) else { - return; - }; - path - }; + let defining_module = trait_reexport.map(|reexport| reexport.module_id); - let full_path = format!("{}::{}", module_full_path, trait_name); + let Some(full_path) = module_def_id_relative_path( + module_def_id, + &trait_name.0.contents, + self.module_id, + current_module_parent_id, + defining_module, + &intermediate_name, + self.interner, + ) else { + return; + }; let title = format!("Import {}", full_path); @@ -401,6 +396,57 @@ mod moo { pub use nested::Foo as Qux; } +fn main() { + let x: Field = 1; + x.foobar(); +}"#; + + assert_code_action(title, src, expected).await; + } + + #[test] + async fn test_import_trait_via_module_reexport() { + let title = "Import moo::another::Foo"; + + let src = r#"mod moo { + mod nested { + pub mod another { + pub trait Foo { + fn foobar(self); + } + + impl Foo for Field { + fn foobar(self) {} + } + } + } + + pub use nested::another; +} + +fn main() { + let x: Field = 1; + x.foo>| NodeFinder<'a> { let modifiers = self.interner.function_modifiers(&func_id); let visibility = modifiers.visibility; let module_def_id = ModuleDefId::TraitId(trait_id); - if !module_def_id_is_visible( + let is_visible = self.module_def_id_is_visible( module_def_id, - self.module_id, visibility, None, // defining 
module - self.interner, - self.def_maps, - ) { + ); + if !is_visible { // Try to find a visible reexport of the trait // that is visible from the current module - let Some((visible_module_id, name, _)) = - self.interner.get_trait_reexports(trait_id).iter().find( - |(module_id, _, visibility)| { - module_def_id_is_visible( - module_def_id, - self.module_id, - *visibility, - Some(*module_id), - self.interner, - self.def_maps, - ) - }, - ) + let Some(reexport) = + self.interner.get_trait_reexports(trait_id).iter().find(|reexport| { + self.module_def_id_is_visible( + module_def_id, + reexport.visibility, + Some(reexport.module_id), + ) + }) else { continue; }; - trait_reexport = Some(TraitReexport { module_id: visible_module_id, name }); + trait_reexport = Some(TraitReexport { + module_id: reexport.module_id, + name: reexport.name.clone(), + }); } } @@ -1024,7 +1020,7 @@ impl<'a> NodeFinder<'a> { noir_function: &NoirFunction, ) { // First find the trait - let location = Location::new(noir_trait_impl.trait_name.span(), self.file); + let location = Location::new(noir_trait_impl.r#trait.span, self.file); let Some(ReferenceId::Trait(trait_id)) = self.interner.find_referenced(location) else { return; }; @@ -1181,6 +1177,23 @@ impl<'a> NodeFinder<'a> { Some(()) } + fn module_def_id_is_visible( + &self, + module_def_id: ModuleDefId, + visibility: ItemVisibility, + defining_module: Option, + ) -> bool { + module_def_id_is_visible( + module_def_id, + self.module_id, + visibility, + defining_module, + self.interner, + self.def_maps, + self.dependencies, + ) + } + fn includes_span(&self, span: Span) -> bool { span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize } @@ -1286,7 +1299,11 @@ impl<'a> Visitor for NodeFinder<'a> { } fn visit_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl, _: Span) -> bool { - self.find_in_path(&noir_trait_impl.trait_name, RequestedItems::OnlyTypes); + let UnresolvedTypeData::Named(trait_name, _, _) = 
&noir_trait_impl.r#trait.typ else { + return false; + }; + + self.find_in_path(trait_name, RequestedItems::OnlyTypes); noir_trait_impl.object_type.accept(self); self.type_parameters.clear(); diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion/auto_import.rs b/noir/noir-repo/tooling/lsp/src/requests/completion/auto_import.rs index 0e80b284f32e..08d155f333c8 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion/auto_import.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion/auto_import.rs @@ -1,11 +1,10 @@ -use noirc_frontend::hir::def_map::ModuleDefId; +use noirc_frontend::{ast::ItemVisibility, hir::def_map::ModuleDefId, node_interner::Reexport}; use crate::{ - modules::{relative_module_full_path, relative_module_id_path}, + modules::{get_ancestor_module_reexport, module_def_id_relative_path}, use_segment_positions::{ use_completion_item_additional_text_edits, UseCompletionItemAdditionTextEditsRequest, }, - visibility::module_def_id_is_visible, }; use super::{ @@ -29,24 +28,33 @@ impl<'a> NodeFinder<'a> { continue; } - for (module_def_id, visibility, defining_module) in entries { - if self.suggested_module_def_ids.contains(module_def_id) { + for entry in entries { + let module_def_id = entry.module_def_id; + if self.suggested_module_def_ids.contains(&module_def_id) { continue; } - if !module_def_id_is_visible( - *module_def_id, - self.module_id, - *visibility, - *defining_module, - self.interner, - self.def_maps, - ) { - continue; + let visibility = entry.visibility; + let mut defining_module = entry.defining_module.as_ref().cloned(); + + // If the item is offered via a re-export of it's parent module, this holds the name of the reexport. 
+ let mut intermediate_name = None; + + let is_visible = + self.module_def_id_is_visible(module_def_id, visibility, defining_module); + if !is_visible { + if let Some(reexport) = + self.get_ancestor_module_reexport(module_def_id, visibility) + { + defining_module = Some(reexport.module_id); + intermediate_name = Some(reexport.name); + } else { + continue; + } } let completion_items = self.module_def_id_completion_items( - *module_def_id, + module_def_id, name.clone(), function_completion_kind, FunctionKind::Any, @@ -57,34 +65,19 @@ impl<'a> NodeFinder<'a> { continue; }; - self.suggested_module_def_ids.insert(*module_def_id); + self.suggested_module_def_ids.insert(module_def_id); for mut completion_item in completion_items { - let module_full_path = if let Some(defining_module) = defining_module { - relative_module_id_path( - *defining_module, - &self.module_id, - current_module_parent_id, - self.interner, - ) - } else { - let Some(module_full_path) = relative_module_full_path( - *module_def_id, - self.module_id, - current_module_parent_id, - self.interner, - ) else { - continue; - }; - module_full_path - }; - - let full_path = if defining_module.is_some() - || !matches!(module_def_id, ModuleDefId::ModuleId(..)) - { - format!("{}::{}", module_full_path, name) - } else { - module_full_path + let Some(full_path) = module_def_id_relative_path( + module_def_id, + name, + self.module_id, + current_module_parent_id, + defining_module, + &intermediate_name, + self.interner, + ) else { + continue; }; let mut label_details = completion_item.label_details.unwrap(); @@ -109,4 +102,19 @@ impl<'a> NodeFinder<'a> { } } } + + fn get_ancestor_module_reexport( + &self, + module_def_id: ModuleDefId, + visibility: ItemVisibility, + ) -> Option { + get_ancestor_module_reexport( + module_def_id, + visibility, + self.module_id, + self.interner, + self.def_maps, + self.dependencies, + ) + } } diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion/completion_items.rs 
b/noir/noir-repo/tooling/lsp/src/requests/completion/completion_items.rs index b3367c287a0a..cfd11bfe1adf 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion/completion_items.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion/completion_items.rs @@ -395,15 +395,15 @@ impl<'a> NodeFinder<'a> { let (trait_id, trait_reexport) = trait_info?; let trait_name = if let Some(trait_reexport) = trait_reexport { - trait_reexport.name + trait_reexport.name.clone() } else { let trait_ = self.interner.get_trait(trait_id); - &trait_.name + trait_.name.clone() }; let module_data = &self.def_maps[&self.module_id.krate].modules()[self.module_id.local_id.0]; - if !module_data.scope().find_name(trait_name).is_none() { + if !module_data.scope().find_name(&trait_name).is_none() { return None; } @@ -411,8 +411,8 @@ impl<'a> NodeFinder<'a> { let current_module_parent_id = self.module_id.parent(self.def_maps); let module_full_path = if let Some(reexport_data) = trait_reexport { relative_module_id_path( - *reexport_data.module_id, - &self.module_id, + reexport_data.module_id, + self.module_id, current_module_parent_id, self.interner, ) diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs b/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs index eb513f37daf8..cbe1a93391a7 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs @@ -3157,4 +3157,234 @@ fn main() { let item = &items[0]; assert_eq!(item.kind, Some(CompletionItemKind::ENUM)); } + + #[test] + async fn autocompletes_via_parent_module_reexport() { + let src = r#"mod aztec { + mod deps { + pub mod protocol_types { + pub struct SomeStruct {} + } + } + + pub use deps::protocol_types; +} + +fn main() { + SomeStru>|< +}"#; + + let mut items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = items.remove(0); + assert_eq!( + item.label_details, + Some(CompletionItemLabelDetails { + 
detail: Some("(use aztec::protocol_types::SomeStruct)".to_string()), + description: Some("SomeStruct".to_string()), + }) + ); + + let expected = r#"use aztec::protocol_types::SomeStruct; + +mod aztec { + mod deps { + pub mod protocol_types { + pub struct SomeStruct {} + } + } + + pub use deps::protocol_types; +} + +fn main() { + SomeStru +}"#; + + let changed = + apply_text_edits(&src.replace(">|<", ""), &item.additional_text_edits.unwrap()); + assert_eq!(changed, expected); + } + + #[test] + async fn autocompletes_via_renamed_parent_module_reexport() { + let src = r#"mod aztec { + mod deps { + pub mod protocol_types { + pub struct SomeStruct {} + } + } + + pub use deps::protocol_types as export; +} + +fn main() { + SomeStru>|< +}"#; + + let mut items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = items.remove(0); + assert_eq!( + item.label_details, + Some(CompletionItemLabelDetails { + detail: Some("(use aztec::export::SomeStruct)".to_string()), + description: Some("SomeStruct".to_string()), + }) + ); + + let expected = r#"use aztec::export::SomeStruct; + +mod aztec { + mod deps { + pub mod protocol_types { + pub struct SomeStruct {} + } + } + + pub use deps::protocol_types as export; +} + +fn main() { + SomeStru +}"#; + + let changed = + apply_text_edits(&src.replace(">|<", ""), &item.additional_text_edits.unwrap()); + assert_eq!(changed, expected); + } + + #[test] + async fn autocompletes_nested_type_via_parent_module_reexport() { + let src = r#"mod aztec { + mod deps { + pub mod protocol_types { + pub mod nested { + pub struct SomeStruct {} + } + } + } + + pub use deps::protocol_types; +} + +fn main() { + SomeStru>|< +}"#; + + let mut items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = items.remove(0); + assert_eq!( + item.label_details, + Some(CompletionItemLabelDetails { + detail: Some("(use aztec::protocol_types::nested::SomeStruct)".to_string()), + description: Some("SomeStruct".to_string()), + }) + 
); + + let expected = r#"use aztec::protocol_types::nested::SomeStruct; + +mod aztec { + mod deps { + pub mod protocol_types { + pub mod nested { + pub struct SomeStruct {} + } + } + } + + pub use deps::protocol_types; +} + +fn main() { + SomeStru +}"#; + + let changed = + apply_text_edits(&src.replace(">|<", ""), &item.additional_text_edits.unwrap()); + assert_eq!(changed, expected); + } + + #[test] + async fn does_not_autocomplete_nested_type_via_parent_module_reexport_if_it_is_not_visible() { + let src = r#"mod aztec { + mod deps { + pub mod protocol_types { + pub mod nested { + struct SomeStruct {} + } + } + } + + pub use deps::protocol_types; +} + +fn main() { + SomeStru>|< +}"#; + + let items = get_completions(src).await; + assert_eq!(items.len(), 0); + } + + #[test] + async fn autocompletes_deeply_nested_type_via_parent_module_reexport() { + let src = r#"mod aztec { + mod deps { + pub mod protocol_types { + pub mod deeply { + pub mod nested { + pub struct SomeStruct {} + } + } + } + } + + pub use deps::protocol_types; +} + +fn main() { + SomeStru>|< +}"#; + + let mut items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = items.remove(0); + assert_eq!( + item.label_details, + Some(CompletionItemLabelDetails { + detail: Some("(use aztec::protocol_types::deeply::nested::SomeStruct)".to_string()), + description: Some("SomeStruct".to_string()), + }) + ); + + let expected = r#"use aztec::protocol_types::deeply::nested::SomeStruct; + +mod aztec { + mod deps { + pub mod protocol_types { + pub mod deeply { + pub mod nested { + pub struct SomeStruct {} + } + } + } + } + + pub use deps::protocol_types; +} + +fn main() { + SomeStru +}"#; + + let changed = + apply_text_edits(&src.replace(">|<", ""), &item.additional_text_edits.unwrap()); + assert_eq!(changed, expected); + } } diff --git a/noir/noir-repo/tooling/lsp/src/requests/document_symbol.rs b/noir/noir-repo/tooling/lsp/src/requests/document_symbol.rs index 0fc2dc4622e0..b32b2fc7ad7f 100644 
--- a/noir/noir-repo/tooling/lsp/src/requests/document_symbol.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/document_symbol.rs @@ -10,7 +10,7 @@ use noirc_errors::Span; use noirc_frontend::{ ast::{ Expression, FunctionReturnType, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, - NoirTraitImpl, TypeImpl, UnresolvedType, Visitor, + NoirTraitImpl, TypeImpl, UnresolvedType, UnresolvedTypeData, Visitor, }, parser::ParsedSubModule, ParsedModule, @@ -349,32 +349,18 @@ impl<'a> Visitor for DocumentSymbolCollector<'a> { return false; }; - let Some(name_location) = self.to_lsp_location(noir_trait_impl.trait_name.span) else { + let name_span = + if let UnresolvedTypeData::Named(trait_name, _, _) = &noir_trait_impl.r#trait.typ { + trait_name.span + } else { + noir_trait_impl.r#trait.span + }; + + let Some(name_location) = self.to_lsp_location(name_span) else { return false; }; - let mut trait_name = String::new(); - trait_name.push_str(&noir_trait_impl.trait_name.to_string()); - if !noir_trait_impl.trait_generics.is_empty() { - trait_name.push('<'); - for (index, generic) in noir_trait_impl.trait_generics.ordered_args.iter().enumerate() { - if index > 0 { - trait_name.push_str(", "); - } - trait_name.push_str(&generic.to_string()); - } - for (index, (name, generic)) in - noir_trait_impl.trait_generics.named_args.iter().enumerate() - { - if index > 0 { - trait_name.push_str(", "); - } - trait_name.push_str(&name.0.contents); - trait_name.push_str(" = "); - trait_name.push_str(&generic.to_string()); - } - trait_name.push('>'); - } + let trait_name = noir_trait_impl.r#trait.to_string(); let old_symbols = std::mem::take(&mut self.symbols); self.symbols = Vec::new(); diff --git a/noir/noir-repo/tooling/lsp/src/requests/hover.rs b/noir/noir-repo/tooling/lsp/src/requests/hover.rs index 60c2a686a627..cbb7dafd3c51 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/hover.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/hover.rs @@ -1,31 +1,17 @@ use std::future::{self, 
Future}; use async_lsp::ResponseError; -use fm::{FileMap, PathString}; -use lsp_types::{Hover, HoverContents, HoverParams, MarkupContent, MarkupKind}; -use noirc_frontend::{ - ast::{ItemVisibility, Visibility}, - hir::def_map::ModuleId, - hir_def::{ - expr::{HirArrayLiteral, HirExpression, HirLiteral}, - function::FuncMeta, - stmt::HirPattern, - traits::Trait, - }, - node_interner::{ - DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, NodeInterner, ReferenceId, TraitId, - TraitImplKind, TypeAliasId, TypeId, - }, - DataType, EnumVariant, Generics, Shared, StructField, Type, TypeAlias, TypeBinding, - TypeVariable, -}; +use fm::PathString; +use from_reference::hover_from_reference; +use from_visitor::hover_from_visitor; +use lsp_types::{Hover, HoverParams}; -use crate::{ - attribute_reference_finder::AttributeReferenceFinder, modules::module_full_path, utils, - LspState, -}; +use crate::LspState; -use super::{process_request, to_lsp_location, ProcessRequestCallbackArgs}; +use super::process_request; + +mod from_reference; +mod from_visitor; pub(crate) fn on_hover_request( state: &mut LspState, @@ -35,903 +21,22 @@ pub(crate) fn on_hover_request( let position = params.text_document_position_params.position; let result = process_request(state, params.text_document_position_params, |args| { let path = PathString::from_path(uri.to_file_path().unwrap()); - args.files - .get_file_id(&path) - .and_then(|file_id| { - utils::position_to_byte_index(args.files, file_id, &position).and_then( - |byte_index| { - let file = args.files.get_file(file_id).unwrap(); - let source = file.source(); - let (parsed_module, _errors) = noirc_frontend::parse_program(source); - - let mut finder = AttributeReferenceFinder::new( - file_id, - byte_index, - args.crate_id, - args.def_maps, - ); - finder.find(&parsed_module) - }, - ) - }) - .or_else(|| args.interner.reference_at_location(args.location)) - .and_then(|reference| { - let location = args.interner.reference_location(reference); - let 
lsp_location = to_lsp_location(args.files, location.file, location.span); - format_reference(reference, &args).map(|formatted| Hover { - range: lsp_location.map(|location| location.range), - contents: HoverContents::Markup(MarkupContent { - kind: MarkupKind::Markdown, - value: formatted, - }), - }) - }) + let file_id = args.files.get_file_id(&path); + hover_from_reference(file_id, position, &args) + .or_else(|| hover_from_visitor(file_id, position, &args)) }); future::ready(result) } -fn format_reference(reference: ReferenceId, args: &ProcessRequestCallbackArgs) -> Option { - match reference { - ReferenceId::Module(id) => format_module(id, args), - ReferenceId::Type(id) => Some(format_type(id, args)), - ReferenceId::StructMember(id, field_index) => { - Some(format_struct_member(id, field_index, args)) - } - ReferenceId::EnumVariant(id, variant_index) => { - Some(format_enum_variant(id, variant_index, args)) - } - ReferenceId::Trait(id) => Some(format_trait(id, args)), - ReferenceId::Global(id) => Some(format_global(id, args)), - ReferenceId::Function(id) => Some(format_function(id, args)), - ReferenceId::Alias(id) => Some(format_alias(id, args)), - ReferenceId::Local(id) => Some(format_local(id, args)), - ReferenceId::Reference(location, _) => { - format_reference(args.interner.find_referenced(location).unwrap(), args) - } - } -} -fn format_module(id: ModuleId, args: &ProcessRequestCallbackArgs) -> Option { - let crate_root = args.def_maps[&id.krate].root(); - - let mut string = String::new(); - - if id.local_id == crate_root { - let dep = args.dependencies.iter().find(|dep| dep.crate_id == id.krate)?; - string.push_str(" crate "); - string.push_str(&dep.name.to_string()); - } else { - // Note: it's not clear why `try_module_attributes` might return None here, but it happens. - // This is a workaround to avoid panicking in that case (which brings the LSP server down). 
- // Cases where this happens are related to generated code, so once that stops happening - // this won't be an issue anymore. - let module_attributes = args.interner.try_module_attributes(&id)?; - - if let Some(parent_local_id) = module_attributes.parent { - if format_parent_module_from_module_id( - &ModuleId { krate: id.krate, local_id: parent_local_id }, - args, - &mut string, - ) { - string.push('\n'); - } - } - string.push_str(" "); - string.push_str("mod "); - string.push_str(&module_attributes.name); - } - - append_doc_comments(args.interner, ReferenceId::Module(id), &mut string); - - Some(string) -} - -fn format_type(id: TypeId, args: &ProcessRequestCallbackArgs) -> String { - let typ = args.interner.get_type(id); - let typ = typ.borrow(); - if let Some(fields) = typ.get_fields_as_written() { - format_struct(&typ, fields, args) - } else if let Some(variants) = typ.get_variants_as_written() { - format_enum(&typ, variants, args) - } else { - unreachable!("Type should either be a struct or an enum") - } -} - -fn format_struct( - typ: &DataType, - fields: Vec, - args: &ProcessRequestCallbackArgs, -) -> String { - let mut string = String::new(); - if format_parent_module(ReferenceId::Type(typ.id), args, &mut string) { - string.push('\n'); - } - string.push_str(" "); - string.push_str("struct "); - string.push_str(&typ.name.0.contents); - format_generics(&typ.generics, &mut string); - string.push_str(" {\n"); - for field in fields { - string.push_str(" "); - string.push_str(&field.name.0.contents); - string.push_str(": "); - string.push_str(&format!("{}", field.typ)); - string.push_str(",\n"); - } - string.push_str(" }"); - - append_doc_comments(args.interner, ReferenceId::Type(typ.id), &mut string); - - string -} - -fn format_enum( - typ: &DataType, - variants: Vec, - args: &ProcessRequestCallbackArgs, -) -> String { - let mut string = String::new(); - if format_parent_module(ReferenceId::Type(typ.id), args, &mut string) { - string.push('\n'); - } - 
string.push_str(" "); - string.push_str("enum "); - string.push_str(&typ.name.0.contents); - format_generics(&typ.generics, &mut string); - string.push_str(" {\n"); - for field in variants { - string.push_str(" "); - string.push_str(&field.name.0.contents); - - if !field.params.is_empty() { - let types = field.params.iter().map(ToString::to_string).collect::>(); - string.push('('); - string.push_str(&types.join(", ")); - string.push(')'); - } - - string.push_str(",\n"); - } - string.push_str(" }"); - - append_doc_comments(args.interner, ReferenceId::Type(typ.id), &mut string); - - string -} - -fn format_struct_member( - id: TypeId, - field_index: usize, - args: &ProcessRequestCallbackArgs, -) -> String { - let struct_type = args.interner.get_type(id); - let struct_type = struct_type.borrow(); - let field = struct_type.field_at(field_index); - - let mut string = String::new(); - if format_parent_module(ReferenceId::Type(id), args, &mut string) { - string.push_str("::"); - } - string.push_str(&struct_type.name.0.contents); - string.push('\n'); - string.push_str(" "); - string.push_str(&field.name.0.contents); - string.push_str(": "); - string.push_str(&format!("{}", field.typ)); - string.push_str(&go_to_type_links(&field.typ, args.interner, args.files)); - - append_doc_comments(args.interner, ReferenceId::StructMember(id, field_index), &mut string); - - string -} - -fn format_enum_variant( - id: TypeId, - field_index: usize, - args: &ProcessRequestCallbackArgs, -) -> String { - let enum_type = args.interner.get_type(id); - let enum_type = enum_type.borrow(); - let variant = enum_type.variant_at(field_index); - - let mut string = String::new(); - if format_parent_module(ReferenceId::Type(id), args, &mut string) { - string.push_str("::"); - } - string.push_str(&enum_type.name.0.contents); - string.push('\n'); - string.push_str(" "); - string.push_str(&variant.name.0.contents); - if !variant.params.is_empty() { - let types = 
variant.params.iter().map(ToString::to_string).collect::>(); - string.push('('); - string.push_str(&types.join(", ")); - string.push(')'); - } - - for typ in variant.params.iter() { - string.push_str(&go_to_type_links(typ, args.interner, args.files)); - } - - append_doc_comments(args.interner, ReferenceId::EnumVariant(id, field_index), &mut string); - - string -} - -fn format_trait(id: TraitId, args: &ProcessRequestCallbackArgs) -> String { - let a_trait = args.interner.get_trait(id); - - let mut string = String::new(); - if format_parent_module(ReferenceId::Trait(id), args, &mut string) { - string.push('\n'); - } - string.push_str(" "); - string.push_str("trait "); - string.push_str(&a_trait.name.0.contents); - format_generics(&a_trait.generics, &mut string); - - append_doc_comments(args.interner, ReferenceId::Trait(id), &mut string); - - string -} - -fn format_global(id: GlobalId, args: &ProcessRequestCallbackArgs) -> String { - let global_info = args.interner.get_global(id); - let definition_id = global_info.definition_id; - let definition = args.interner.definition(definition_id); - let typ = args.interner.definition_type(definition_id); - - let mut string = String::new(); - if format_parent_module(ReferenceId::Global(id), args, &mut string) { - string.push('\n'); - } - - let mut print_comptime = definition.comptime; - let mut opt_value = None; - - // See if we can figure out what's the global's value - if let Some(stmt) = args.interner.get_global_let_statement(id) { - print_comptime = stmt.comptime; - opt_value = get_global_value(args.interner, stmt.expression); - } - - string.push_str(" "); - if print_comptime { - string.push_str("comptime "); - } - if definition.mutable { - string.push_str("mut "); - } - string.push_str("global "); - string.push_str(&global_info.ident.0.contents); - string.push_str(": "); - string.push_str(&format!("{}", typ)); - - if let Some(value) = opt_value { - string.push_str(" = "); - string.push_str(&value); - } - - 
string.push_str(&go_to_type_links(&typ, args.interner, args.files)); - - append_doc_comments(args.interner, ReferenceId::Global(id), &mut string); - - string -} - -fn get_global_value(interner: &NodeInterner, expr: ExprId) -> Option { - match interner.expression(&expr) { - HirExpression::Literal(literal) => match literal { - HirLiteral::Array(hir_array_literal) => { - get_global_array_value(interner, hir_array_literal, false) - } - HirLiteral::Slice(hir_array_literal) => { - get_global_array_value(interner, hir_array_literal, true) - } - HirLiteral::Bool(value) => Some(value.to_string()), - HirLiteral::Integer(field_element, _) => Some(field_element.to_string()), - HirLiteral::Str(string) => Some(format!("{:?}", string)), - HirLiteral::FmtStr(..) => None, - HirLiteral::Unit => Some("()".to_string()), - }, - HirExpression::Tuple(values) => { - get_exprs_global_value(interner, &values).map(|value| format!("({})", value)) - } - _ => None, - } -} - -fn get_global_array_value( - interner: &NodeInterner, - literal: HirArrayLiteral, - is_slice: bool, -) -> Option { - match literal { - HirArrayLiteral::Standard(values) => { - get_exprs_global_value(interner, &values).map(|value| { - if is_slice { - format!("&[{}]", value) - } else { - format!("[{}]", value) - } - }) - } - HirArrayLiteral::Repeated { repeated_element, length } => { - get_global_value(interner, repeated_element).map(|value| { - if is_slice { - format!("&[{}; {}]", value, length) - } else { - format!("[{}; {}]", value, length) - } - }) - } - } -} - -fn get_exprs_global_value(interner: &NodeInterner, exprs: &[ExprId]) -> Option { - let strings: Vec = - exprs.iter().filter_map(|value| get_global_value(interner, *value)).collect(); - if strings.len() == exprs.len() { - Some(strings.join(", ")) - } else { - None - } -} - -fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { - let func_meta = args.interner.function_meta(&id); - - // If this points to a trait method, see if we can figure 
out what's the concrete trait impl method - if let Some(func_id) = get_trait_impl_func_id(id, args, func_meta) { - return format_function(func_id, args); - } - - let func_modifiers = args.interner.function_modifiers(&id); - - let func_name_definition_id = args.interner.definition(func_meta.name.id); - - let enum_variant = match (func_meta.type_id, func_meta.enum_variant_index) { - (Some(type_id), Some(index)) => Some((type_id, index)), - _ => None, - }; - - let reference_id = if let Some((type_id, variant_index)) = enum_variant { - ReferenceId::EnumVariant(type_id, variant_index) - } else { - ReferenceId::Function(id) - }; - - let mut string = String::new(); - let formatted_parent_module = format_parent_module(reference_id, args, &mut string); - - let formatted_parent_type = if let Some(trait_impl_id) = func_meta.trait_impl { - let trait_impl = args.interner.get_trait_implementation(trait_impl_id); - let trait_impl = trait_impl.borrow(); - let trait_ = args.interner.get_trait(trait_impl.trait_id); - - let generics: Vec<_> = - trait_impl - .trait_generics - .iter() - .filter_map(|generic| { - if let Type::NamedGeneric(_, name) = generic { - Some(name) - } else { - None - } - }) - .collect(); - - string.push('\n'); - string.push_str(" impl"); - if !generics.is_empty() { - string.push('<'); - for (index, generic) in generics.into_iter().enumerate() { - if index > 0 { - string.push_str(", "); - } - string.push_str(generic); - } - string.push('>'); - } - - string.push(' '); - string.push_str(&trait_.name.0.contents); - if !trait_impl.trait_generics.is_empty() { - string.push('<'); - for (index, generic) in trait_impl.trait_generics.iter().enumerate() { - if index > 0 { - string.push_str(", "); - } - string.push_str(&generic.to_string()); - } - string.push('>'); - } - - string.push_str(" for "); - string.push_str(&trait_impl.typ.to_string()); - - true - } else if let Some(trait_id) = func_meta.trait_id { - let trait_ = args.interner.get_trait(trait_id); - 
string.push('\n'); - string.push_str(" trait "); - string.push_str(&trait_.name.0.contents); - format_generics(&trait_.generics, &mut string); - - true - } else if let Some(type_id) = func_meta.type_id { - let data_type = args.interner.get_type(type_id); - let data_type = data_type.borrow(); - if formatted_parent_module { - string.push_str("::"); - } - string.push_str(&data_type.name.0.contents); - if enum_variant.is_none() { - string.push('\n'); - string.push_str(" "); - string.push_str("impl"); - - let impl_generics: Vec<_> = func_meta - .all_generics - .iter() - .take(func_meta.all_generics.len() - func_meta.direct_generics.len()) - .cloned() - .collect(); - format_generics(&impl_generics, &mut string); - - string.push(' '); - string.push_str(&data_type.name.0.contents); - format_generic_names(&impl_generics, &mut string); - } - - true - } else { - false - }; - if formatted_parent_module || formatted_parent_type { - string.push('\n'); - } - string.push_str(" "); - - if func_modifiers.visibility != ItemVisibility::Private - && func_meta.trait_id.is_none() - && func_meta.trait_impl.is_none() - { - string.push_str(&func_modifiers.visibility.to_string()); - string.push(' '); - } - if func_modifiers.is_unconstrained { - string.push_str("unconstrained "); - } - if func_modifiers.is_comptime { - string.push_str("comptime "); - } - - let func_name = &func_name_definition_id.name; - - if enum_variant.is_none() { - string.push_str("fn "); - } - string.push_str(func_name); - format_generics(&func_meta.direct_generics, &mut string); - string.push('('); - let parameters = &func_meta.parameters; - for (index, (pattern, typ, visibility)) in parameters.iter().enumerate() { - let is_self = pattern_is_self(pattern, args.interner); - - // `&mut self` is represented as a mutable reference type, not as a mutable pattern - if is_self && matches!(typ, Type::MutableReference(..)) { - string.push_str("&mut "); - } - - if enum_variant.is_some() { - string.push_str(&format!("{}", typ)); - 
} else { - format_pattern(pattern, args.interner, &mut string); - - // Don't add type for `self` param - if !is_self { - string.push_str(": "); - if matches!(visibility, Visibility::Public) { - string.push_str("pub "); - } - string.push_str(&format!("{}", typ)); - } - } - - if index != parameters.len() - 1 { - string.push_str(", "); - } - } - - string.push(')'); - - if enum_variant.is_none() { - let return_type = func_meta.return_type(); - match return_type { - Type::Unit => (), - _ => { - string.push_str(" -> "); - string.push_str(&format!("{}", return_type)); - } - } - - string.push_str(&go_to_type_links(return_type, args.interner, args.files)); - } - - if enum_variant.is_some() { - append_doc_comments(args.interner, reference_id, &mut string); - } else { - let had_doc_comments = append_doc_comments(args.interner, reference_id, &mut string); - if !had_doc_comments { - // If this function doesn't have doc comments, but it's a trait impl method, - // use the trait method doc comments. - if let Some(trait_impl_id) = func_meta.trait_impl { - let trait_impl = args.interner.get_trait_implementation(trait_impl_id); - let trait_impl = trait_impl.borrow(); - let trait_ = args.interner.get_trait(trait_impl.trait_id); - if let Some(func_id) = trait_.method_ids.get(func_name) { - let reference_id = ReferenceId::Function(*func_id); - append_doc_comments(args.interner, reference_id, &mut string); - } - } - } - } - - string -} - -fn get_trait_impl_func_id( - id: FuncId, - args: &ProcessRequestCallbackArgs, - func_meta: &FuncMeta, -) -> Option { - func_meta.trait_id?; - - let index = args.interner.find_location_index(args.location)?; - let expr_id = args.interner.get_expr_id_from_index(index)?; - let Some(TraitImplKind::Normal(trait_impl_id)) = - args.interner.get_selected_impl_for_expression(expr_id) - else { - return None; - }; - - let trait_impl = args.interner.get_trait_implementation(trait_impl_id); - let trait_impl = trait_impl.borrow(); - - let function_name = 
args.interner.function_name(&id); - let mut trait_impl_methods = trait_impl.methods.iter(); - let func_id = - trait_impl_methods.find(|func_id| args.interner.function_name(func_id) == function_name)?; - Some(*func_id) -} - -fn format_alias(id: TypeAliasId, args: &ProcessRequestCallbackArgs) -> String { - let type_alias = args.interner.get_type_alias(id); - let type_alias = type_alias.borrow(); - - let mut string = String::new(); - format_parent_module(ReferenceId::Alias(id), args, &mut string); - string.push('\n'); - string.push_str(" "); - string.push_str("type "); - string.push_str(&type_alias.name.0.contents); - string.push_str(" = "); - string.push_str(&format!("{}", &type_alias.typ)); - - append_doc_comments(args.interner, ReferenceId::Alias(id), &mut string); - - string -} - -fn format_local(id: DefinitionId, args: &ProcessRequestCallbackArgs) -> String { - let definition_info = args.interner.definition(id); - if let DefinitionKind::Global(global_id) = &definition_info.kind { - return format_global(*global_id, args); - } - - let DefinitionKind::Local(expr_id) = definition_info.kind else { - panic!("Expected a local reference to reference a local definition") - }; - let typ = args.interner.definition_type(id); - - let mut string = String::new(); - string.push_str(" "); - if definition_info.comptime { - string.push_str("comptime "); - } - if expr_id.is_some() { - string.push_str("let "); - } - if definition_info.mutable { - if expr_id.is_none() { - string.push_str("let "); - } - string.push_str("mut "); - } - string.push_str(&definition_info.name); - if !matches!(typ, Type::Error) { - string.push_str(": "); - string.push_str(&format!("{}", typ)); - } - - string.push_str(&go_to_type_links(&typ, args.interner, args.files)); - - string -} - -fn format_generics(generics: &Generics, string: &mut String) { - format_generics_impl( - generics, false, // only show names - string, - ); -} - -fn format_generic_names(generics: &Generics, string: &mut String) { - 
format_generics_impl( - generics, true, // only show names - string, - ); -} - -fn format_generics_impl(generics: &Generics, only_show_names: bool, string: &mut String) { - if generics.is_empty() { - return; - } - - string.push('<'); - for (index, generic) in generics.iter().enumerate() { - if index > 0 { - string.push_str(", "); - } - - if only_show_names { - string.push_str(&generic.name); - } else { - match generic.kind() { - noirc_frontend::Kind::Any | noirc_frontend::Kind::Normal => { - string.push_str(&generic.name); - } - noirc_frontend::Kind::IntegerOrField | noirc_frontend::Kind::Integer => { - string.push_str("let "); - string.push_str(&generic.name); - string.push_str(": u32"); - } - noirc_frontend::Kind::Numeric(typ) => { - string.push_str("let "); - string.push_str(&generic.name); - string.push_str(": "); - string.push_str(&typ.to_string()); - } - } - } - } - string.push('>'); -} - -fn format_pattern(pattern: &HirPattern, interner: &NodeInterner, string: &mut String) { - match pattern { - HirPattern::Identifier(ident) => { - let definition = interner.definition(ident.id); - string.push_str(&definition.name); - } - HirPattern::Mutable(pattern, _) => { - string.push_str("mut "); - format_pattern(pattern, interner, string); - } - HirPattern::Tuple(..) | HirPattern::Struct(..) => { - string.push('_'); - } - } -} - -fn pattern_is_self(pattern: &HirPattern, interner: &NodeInterner) -> bool { - match pattern { - HirPattern::Identifier(ident) => { - let definition = interner.definition(ident.id); - definition.name == "self" - } - HirPattern::Mutable(pattern, _) => pattern_is_self(pattern, interner), - HirPattern::Tuple(..) | HirPattern::Struct(..) 
=> false, - } -} - -fn format_parent_module( - referenced: ReferenceId, - args: &ProcessRequestCallbackArgs, - string: &mut String, -) -> bool { - let Some(module) = args.interner.reference_module(referenced) else { - return false; - }; - - format_parent_module_from_module_id(module, args, string) -} - -fn format_parent_module_from_module_id( - module: &ModuleId, - args: &ProcessRequestCallbackArgs, - string: &mut String, -) -> bool { - let full_path = - module_full_path(module, args.interner, args.crate_id, &args.crate_name, args.dependencies); - if full_path.is_empty() { - return false; - } - - string.push_str(" "); - string.push_str(&full_path); - true -} - -fn go_to_type_links(typ: &Type, interner: &NodeInterner, files: &FileMap) -> String { - let mut gatherer = TypeLinksGatherer { interner, files, links: Vec::new() }; - gatherer.gather_type_links(typ); - - let links = gatherer.links; - if links.is_empty() { - "".to_string() - } else { - let mut string = String::new(); - string.push_str("\n\n"); - string.push_str("Go to "); - for (index, link) in links.iter().enumerate() { - if index > 0 { - string.push_str(" | "); - } - string.push_str(link); - } - string - } -} - -struct TypeLinksGatherer<'a> { - interner: &'a NodeInterner, - files: &'a FileMap, - links: Vec, -} - -impl<'a> TypeLinksGatherer<'a> { - fn gather_type_links(&mut self, typ: &Type) { - match typ { - Type::Array(typ, _) => self.gather_type_links(typ), - Type::Slice(typ) => self.gather_type_links(typ), - Type::Tuple(types) => { - for typ in types { - self.gather_type_links(typ); - } - } - Type::DataType(data_type, generics) => { - self.gather_struct_type_links(data_type); - for generic in generics { - self.gather_type_links(generic); - } - } - Type::Alias(type_alias, generics) => { - self.gather_type_alias_links(type_alias); - for generic in generics { - self.gather_type_links(generic); - } - } - Type::TypeVariable(var) => { - self.gather_type_variable_links(var); - } - Type::TraitAsType(trait_id, _, 
generics) => { - let some_trait = self.interner.get_trait(*trait_id); - self.gather_trait_links(some_trait); - for generic in &generics.ordered { - self.gather_type_links(generic); - } - for named_type in &generics.named { - self.gather_type_links(&named_type.typ); - } - } - Type::NamedGeneric(var, _) => { - self.gather_type_variable_links(var); - } - Type::Function(args, return_type, env, _) => { - for arg in args { - self.gather_type_links(arg); - } - self.gather_type_links(return_type); - self.gather_type_links(env); - } - Type::MutableReference(typ) => self.gather_type_links(typ), - Type::InfixExpr(lhs, _, rhs, _) => { - self.gather_type_links(lhs); - self.gather_type_links(rhs); - } - Type::CheckedCast { to, .. } => self.gather_type_links(to), - Type::FieldElement - | Type::Integer(..) - | Type::Bool - | Type::String(_) - | Type::FmtString(_, _) - | Type::Unit - | Type::Forall(_, _) - | Type::Constant(..) - | Type::Quoted(_) - | Type::Error => (), - } - } - - fn gather_struct_type_links(&mut self, struct_type: &Shared) { - let struct_type = struct_type.borrow(); - if let Some(lsp_location) = - to_lsp_location(self.files, struct_type.location.file, struct_type.name.span()) - { - self.push_link(format_link(struct_type.name.to_string(), lsp_location)); - } - } - - fn gather_type_alias_links(&mut self, type_alias: &Shared) { - let type_alias = type_alias.borrow(); - if let Some(lsp_location) = - to_lsp_location(self.files, type_alias.location.file, type_alias.name.span()) - { - self.push_link(format_link(type_alias.name.to_string(), lsp_location)); - } - } - - fn gather_trait_links(&mut self, some_trait: &Trait) { - if let Some(lsp_location) = - to_lsp_location(self.files, some_trait.location.file, some_trait.name.span()) - { - self.push_link(format_link(some_trait.name.to_string(), lsp_location)); - } - } - - fn gather_type_variable_links(&mut self, var: &TypeVariable) { - let var = &*var.borrow(); - match var { - TypeBinding::Bound(typ) => { - 
self.gather_type_links(typ); - } - TypeBinding::Unbound(..) => (), - } - } - - fn push_link(&mut self, link: String) { - if !self.links.contains(&link) { - self.links.push(link); - } - } -} - -fn format_link(name: String, location: lsp_types::Location) -> String { - format!( - "[{}]({}#L{},{}-{},{})", - name, - location.uri, - location.range.start.line + 1, - location.range.start.character + 1, - location.range.end.line + 1, - location.range.end.character + 1 - ) -} - -fn append_doc_comments(interner: &NodeInterner, id: ReferenceId, string: &mut String) -> bool { - if let Some(doc_comments) = interner.doc_comments(id) { - string.push_str("\n\n---\n\n"); - for comment in doc_comments { - string.push_str(comment); - string.push('\n'); - } - true - } else { - false - } -} - #[cfg(test)] mod hover_tests { use crate::test_utils; use super::*; use lsp_types::{ - Position, TextDocumentIdentifier, TextDocumentPositionParams, Url, WorkDoneProgressParams, + HoverContents, Position, TextDocumentIdentifier, TextDocumentPositionParams, Url, + WorkDoneProgressParams, }; use tokio::test; @@ -1357,4 +462,20 @@ mod hover_tests { Like a tomato" )); } + + #[test] + async fn hover_on_integer_literal() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 9, character: 69 }) + .await; + assert_eq!(&hover_text, " Field\n---\nvalue of literal: `123 (0x7b)`"); + } + + #[test] + async fn hover_on_negative_integer_literal() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 113, character: 5 }) + .await; + assert_eq!(&hover_text, " i32\n---\nvalue of literal: `-8 (-0x08)`"); + } } diff --git a/noir/noir-repo/tooling/lsp/src/requests/hover/from_reference.rs b/noir/noir-repo/tooling/lsp/src/requests/hover/from_reference.rs new file mode 100644 index 000000000000..7f589b9df70a --- /dev/null +++ b/noir/noir-repo/tooling/lsp/src/requests/hover/from_reference.rs @@ -0,0 +1,914 @@ +use fm::{FileId, FileMap}; +use lsp_types::{Hover, 
HoverContents, MarkupContent, MarkupKind, Position}; +use noirc_frontend::{ + ast::{ItemVisibility, Visibility}, + hir::def_map::ModuleId, + hir_def::{ + expr::{HirArrayLiteral, HirExpression, HirLiteral}, + function::FuncMeta, + stmt::HirPattern, + traits::Trait, + }, + node_interner::{ + DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, NodeInterner, ReferenceId, TraitId, + TraitImplKind, TypeAliasId, TypeId, + }, + DataType, EnumVariant, Generics, Shared, StructField, Type, TypeAlias, TypeBinding, + TypeVariable, +}; + +use crate::{ + attribute_reference_finder::AttributeReferenceFinder, + modules::module_full_path, + requests::{to_lsp_location, ProcessRequestCallbackArgs}, + utils, +}; + +pub(super) fn hover_from_reference( + file_id: Option, + position: Position, + args: &ProcessRequestCallbackArgs, +) -> Option { + file_id + .and_then(|file_id| { + utils::position_to_byte_index(args.files, file_id, &position).and_then(|byte_index| { + let file = args.files.get_file(file_id).unwrap(); + let source = file.source(); + let (parsed_module, _errors) = noirc_frontend::parse_program(source); + + let mut finder = AttributeReferenceFinder::new( + file_id, + byte_index, + args.crate_id, + args.def_maps, + ); + finder.find(&parsed_module) + }) + }) + .or_else(|| args.interner.reference_at_location(args.location)) + .and_then(|reference| { + let location = args.interner.reference_location(reference); + let lsp_location = to_lsp_location(args.files, location.file, location.span); + format_reference(reference, args).map(|formatted| Hover { + range: lsp_location.map(|location| location.range), + contents: HoverContents::Markup(MarkupContent { + kind: MarkupKind::Markdown, + value: formatted, + }), + }) + }) +} + +fn format_reference(reference: ReferenceId, args: &ProcessRequestCallbackArgs) -> Option { + match reference { + ReferenceId::Module(id) => format_module(id, args), + ReferenceId::Type(id) => Some(format_type(id, args)), + ReferenceId::StructMember(id, 
field_index) => { + Some(format_struct_member(id, field_index, args)) + } + ReferenceId::EnumVariant(id, variant_index) => { + Some(format_enum_variant(id, variant_index, args)) + } + ReferenceId::Trait(id) => Some(format_trait(id, args)), + ReferenceId::Global(id) => Some(format_global(id, args)), + ReferenceId::Function(id) => Some(format_function(id, args)), + ReferenceId::Alias(id) => Some(format_alias(id, args)), + ReferenceId::Local(id) => Some(format_local(id, args)), + ReferenceId::Reference(location, _) => { + format_reference(args.interner.find_referenced(location).unwrap(), args) + } + } +} +fn format_module(id: ModuleId, args: &ProcessRequestCallbackArgs) -> Option { + let crate_root = args.def_maps[&id.krate].root(); + + let mut string = String::new(); + + if id.local_id == crate_root { + let dep = args.dependencies.iter().find(|dep| dep.crate_id == id.krate)?; + string.push_str(" crate "); + string.push_str(&dep.name.to_string()); + } else { + // Note: it's not clear why `try_module_attributes` might return None here, but it happens. + // This is a workaround to avoid panicking in that case (which brings the LSP server down). + // Cases where this happens are related to generated code, so once that stops happening + // this won't be an issue anymore. 
+ let module_attributes = args.interner.try_module_attributes(&id)?; + + if let Some(parent_local_id) = module_attributes.parent { + if format_parent_module_from_module_id( + &ModuleId { krate: id.krate, local_id: parent_local_id }, + args, + &mut string, + ) { + string.push('\n'); + } + } + string.push_str(" "); + string.push_str("mod "); + string.push_str(&module_attributes.name); + } + + append_doc_comments(args.interner, ReferenceId::Module(id), &mut string); + + Some(string) +} + +fn format_type(id: TypeId, args: &ProcessRequestCallbackArgs) -> String { + let typ = args.interner.get_type(id); + let typ = typ.borrow(); + if let Some(fields) = typ.get_fields_as_written() { + format_struct(&typ, fields, args) + } else if let Some(variants) = typ.get_variants_as_written() { + format_enum(&typ, variants, args) + } else { + unreachable!("Type should either be a struct or an enum") + } +} + +fn format_struct( + typ: &DataType, + fields: Vec, + args: &ProcessRequestCallbackArgs, +) -> String { + let mut string = String::new(); + if format_parent_module(ReferenceId::Type(typ.id), args, &mut string) { + string.push('\n'); + } + string.push_str(" "); + string.push_str("struct "); + string.push_str(&typ.name.0.contents); + format_generics(&typ.generics, &mut string); + string.push_str(" {\n"); + for field in fields { + string.push_str(" "); + string.push_str(&field.name.0.contents); + string.push_str(": "); + string.push_str(&format!("{}", field.typ)); + string.push_str(",\n"); + } + string.push_str(" }"); + + append_doc_comments(args.interner, ReferenceId::Type(typ.id), &mut string); + + string +} + +fn format_enum( + typ: &DataType, + variants: Vec, + args: &ProcessRequestCallbackArgs, +) -> String { + let mut string = String::new(); + if format_parent_module(ReferenceId::Type(typ.id), args, &mut string) { + string.push('\n'); + } + string.push_str(" "); + string.push_str("enum "); + string.push_str(&typ.name.0.contents); + format_generics(&typ.generics, &mut string); + 
string.push_str(" {\n"); + for field in variants { + string.push_str(" "); + string.push_str(&field.name.0.contents); + + if !field.params.is_empty() { + let types = field.params.iter().map(ToString::to_string).collect::>(); + string.push('('); + string.push_str(&types.join(", ")); + string.push(')'); + } + + string.push_str(",\n"); + } + string.push_str(" }"); + + append_doc_comments(args.interner, ReferenceId::Type(typ.id), &mut string); + + string +} + +fn format_struct_member( + id: TypeId, + field_index: usize, + args: &ProcessRequestCallbackArgs, +) -> String { + let struct_type = args.interner.get_type(id); + let struct_type = struct_type.borrow(); + let field = struct_type.field_at(field_index); + + let mut string = String::new(); + if format_parent_module(ReferenceId::Type(id), args, &mut string) { + string.push_str("::"); + } + string.push_str(&struct_type.name.0.contents); + string.push('\n'); + string.push_str(" "); + string.push_str(&field.name.0.contents); + string.push_str(": "); + string.push_str(&format!("{}", field.typ)); + string.push_str(&go_to_type_links(&field.typ, args.interner, args.files)); + + append_doc_comments(args.interner, ReferenceId::StructMember(id, field_index), &mut string); + + string +} + +fn format_enum_variant( + id: TypeId, + field_index: usize, + args: &ProcessRequestCallbackArgs, +) -> String { + let enum_type = args.interner.get_type(id); + let enum_type = enum_type.borrow(); + let variant = enum_type.variant_at(field_index); + + let mut string = String::new(); + if format_parent_module(ReferenceId::Type(id), args, &mut string) { + string.push_str("::"); + } + string.push_str(&enum_type.name.0.contents); + string.push('\n'); + string.push_str(" "); + string.push_str(&variant.name.0.contents); + if !variant.params.is_empty() { + let types = variant.params.iter().map(ToString::to_string).collect::>(); + string.push('('); + string.push_str(&types.join(", ")); + string.push(')'); + } + + for typ in variant.params.iter() { + 
string.push_str(&go_to_type_links(typ, args.interner, args.files)); + } + + append_doc_comments(args.interner, ReferenceId::EnumVariant(id, field_index), &mut string); + + string +} + +fn format_trait(id: TraitId, args: &ProcessRequestCallbackArgs) -> String { + let a_trait = args.interner.get_trait(id); + + let mut string = String::new(); + if format_parent_module(ReferenceId::Trait(id), args, &mut string) { + string.push('\n'); + } + string.push_str(" "); + string.push_str("trait "); + string.push_str(&a_trait.name.0.contents); + format_generics(&a_trait.generics, &mut string); + + append_doc_comments(args.interner, ReferenceId::Trait(id), &mut string); + + string +} + +fn format_global(id: GlobalId, args: &ProcessRequestCallbackArgs) -> String { + let global_info = args.interner.get_global(id); + let definition_id = global_info.definition_id; + let definition = args.interner.definition(definition_id); + let typ = args.interner.definition_type(definition_id); + + let mut string = String::new(); + if format_parent_module(ReferenceId::Global(id), args, &mut string) { + string.push('\n'); + } + + let mut print_comptime = definition.comptime; + let mut opt_value = None; + + // See if we can figure out what's the global's value + if let Some(stmt) = args.interner.get_global_let_statement(id) { + print_comptime = stmt.comptime; + opt_value = get_global_value(args.interner, stmt.expression); + } + + string.push_str(" "); + if print_comptime { + string.push_str("comptime "); + } + if definition.mutable { + string.push_str("mut "); + } + string.push_str("global "); + string.push_str(&global_info.ident.0.contents); + string.push_str(": "); + string.push_str(&format!("{}", typ)); + + if let Some(value) = opt_value { + string.push_str(" = "); + string.push_str(&value); + } + + string.push_str(&go_to_type_links(&typ, args.interner, args.files)); + + append_doc_comments(args.interner, ReferenceId::Global(id), &mut string); + + string +} + +fn get_global_value(interner: 
&NodeInterner, expr: ExprId) -> Option { + match interner.expression(&expr) { + HirExpression::Literal(literal) => match literal { + HirLiteral::Array(hir_array_literal) => { + get_global_array_value(interner, hir_array_literal, false) + } + HirLiteral::Slice(hir_array_literal) => { + get_global_array_value(interner, hir_array_literal, true) + } + HirLiteral::Bool(value) => Some(value.to_string()), + HirLiteral::Integer(field_element, _) => Some(field_element.to_string()), + HirLiteral::Str(string) => Some(format!("{:?}", string)), + HirLiteral::FmtStr(..) => None, + HirLiteral::Unit => Some("()".to_string()), + }, + HirExpression::Tuple(values) => { + get_exprs_global_value(interner, &values).map(|value| format!("({})", value)) + } + _ => None, + } +} + +fn get_global_array_value( + interner: &NodeInterner, + literal: HirArrayLiteral, + is_slice: bool, +) -> Option { + match literal { + HirArrayLiteral::Standard(values) => { + get_exprs_global_value(interner, &values).map(|value| { + if is_slice { + format!("&[{}]", value) + } else { + format!("[{}]", value) + } + }) + } + HirArrayLiteral::Repeated { repeated_element, length } => { + get_global_value(interner, repeated_element).map(|value| { + if is_slice { + format!("&[{}; {}]", value, length) + } else { + format!("[{}; {}]", value, length) + } + }) + } + } +} + +fn get_exprs_global_value(interner: &NodeInterner, exprs: &[ExprId]) -> Option { + let strings: Vec = + exprs.iter().filter_map(|value| get_global_value(interner, *value)).collect(); + if strings.len() == exprs.len() { + Some(strings.join(", ")) + } else { + None + } +} + +fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { + let func_meta = args.interner.function_meta(&id); + + // If this points to a trait method, see if we can figure out what's the concrete trait impl method + if let Some(func_id) = get_trait_impl_func_id(id, args, func_meta) { + return format_function(func_id, args); + } + + let func_modifiers = 
args.interner.function_modifiers(&id); + + let func_name_definition_id = args.interner.definition(func_meta.name.id); + + let enum_variant = match (func_meta.type_id, func_meta.enum_variant_index) { + (Some(type_id), Some(index)) => Some((type_id, index)), + _ => None, + }; + + let reference_id = if let Some((type_id, variant_index)) = enum_variant { + ReferenceId::EnumVariant(type_id, variant_index) + } else { + ReferenceId::Function(id) + }; + + let mut string = String::new(); + let formatted_parent_module = format_parent_module(reference_id, args, &mut string); + + let formatted_parent_type = if let Some(trait_impl_id) = func_meta.trait_impl { + let trait_impl = args.interner.get_trait_implementation(trait_impl_id); + let trait_impl = trait_impl.borrow(); + let trait_ = args.interner.get_trait(trait_impl.trait_id); + + let generics: Vec<_> = + trait_impl + .trait_generics + .iter() + .filter_map(|generic| { + if let Type::NamedGeneric(_, name) = generic { + Some(name) + } else { + None + } + }) + .collect(); + + string.push('\n'); + string.push_str(" impl"); + if !generics.is_empty() { + string.push('<'); + for (index, generic) in generics.into_iter().enumerate() { + if index > 0 { + string.push_str(", "); + } + string.push_str(generic); + } + string.push('>'); + } + + string.push(' '); + string.push_str(&trait_.name.0.contents); + if !trait_impl.trait_generics.is_empty() { + string.push('<'); + for (index, generic) in trait_impl.trait_generics.iter().enumerate() { + if index > 0 { + string.push_str(", "); + } + string.push_str(&generic.to_string()); + } + string.push('>'); + } + + string.push_str(" for "); + string.push_str(&trait_impl.typ.to_string()); + + true + } else if let Some(trait_id) = func_meta.trait_id { + let trait_ = args.interner.get_trait(trait_id); + string.push('\n'); + string.push_str(" trait "); + string.push_str(&trait_.name.0.contents); + format_generics(&trait_.generics, &mut string); + + true + } else if let Some(type_id) = 
func_meta.type_id { + let data_type = args.interner.get_type(type_id); + let data_type = data_type.borrow(); + if formatted_parent_module { + string.push_str("::"); + } + string.push_str(&data_type.name.0.contents); + if enum_variant.is_none() { + string.push('\n'); + string.push_str(" "); + string.push_str("impl"); + + let impl_generics: Vec<_> = func_meta + .all_generics + .iter() + .take(func_meta.all_generics.len() - func_meta.direct_generics.len()) + .cloned() + .collect(); + format_generics(&impl_generics, &mut string); + + string.push(' '); + string.push_str(&data_type.name.0.contents); + format_generic_names(&impl_generics, &mut string); + } + + true + } else { + false + }; + if formatted_parent_module || formatted_parent_type { + string.push('\n'); + } + string.push_str(" "); + + if func_modifiers.visibility != ItemVisibility::Private + && func_meta.trait_id.is_none() + && func_meta.trait_impl.is_none() + { + string.push_str(&func_modifiers.visibility.to_string()); + string.push(' '); + } + if func_modifiers.is_unconstrained { + string.push_str("unconstrained "); + } + if func_modifiers.is_comptime { + string.push_str("comptime "); + } + + let func_name = &func_name_definition_id.name; + + if enum_variant.is_none() { + string.push_str("fn "); + } + string.push_str(func_name); + format_generics(&func_meta.direct_generics, &mut string); + string.push('('); + let parameters = &func_meta.parameters; + for (index, (pattern, typ, visibility)) in parameters.iter().enumerate() { + let is_self = pattern_is_self(pattern, args.interner); + + // `&mut self` is represented as a mutable reference type, not as a mutable pattern + if is_self && matches!(typ, Type::MutableReference(..)) { + string.push_str("&mut "); + } + + if enum_variant.is_some() { + string.push_str(&format!("{}", typ)); + } else { + format_pattern(pattern, args.interner, &mut string); + + // Don't add type for `self` param + if !is_self { + string.push_str(": "); + if matches!(visibility, 
Visibility::Public) { + string.push_str("pub "); + } + string.push_str(&format!("{}", typ)); + } + } + + if index != parameters.len() - 1 { + string.push_str(", "); + } + } + + string.push(')'); + + if enum_variant.is_none() { + let return_type = func_meta.return_type(); + match return_type { + Type::Unit => (), + _ => { + string.push_str(" -> "); + string.push_str(&format!("{}", return_type)); + } + } + + string.push_str(&go_to_type_links(return_type, args.interner, args.files)); + } + + if enum_variant.is_some() { + append_doc_comments(args.interner, reference_id, &mut string); + } else { + let had_doc_comments = append_doc_comments(args.interner, reference_id, &mut string); + if !had_doc_comments { + // If this function doesn't have doc comments, but it's a trait impl method, + // use the trait method doc comments. + if let Some(trait_impl_id) = func_meta.trait_impl { + let trait_impl = args.interner.get_trait_implementation(trait_impl_id); + let trait_impl = trait_impl.borrow(); + let trait_ = args.interner.get_trait(trait_impl.trait_id); + if let Some(func_id) = trait_.method_ids.get(func_name) { + let reference_id = ReferenceId::Function(*func_id); + append_doc_comments(args.interner, reference_id, &mut string); + } + } + } + } + + string +} + +fn get_trait_impl_func_id( + id: FuncId, + args: &ProcessRequestCallbackArgs, + func_meta: &FuncMeta, +) -> Option { + func_meta.trait_id?; + + let index = args.interner.find_location_index(args.location)?; + let expr_id = args.interner.get_expr_id_from_index(index)?; + let Some(TraitImplKind::Normal(trait_impl_id)) = + args.interner.get_selected_impl_for_expression(expr_id) + else { + return None; + }; + + let trait_impl = args.interner.get_trait_implementation(trait_impl_id); + let trait_impl = trait_impl.borrow(); + + let function_name = args.interner.function_name(&id); + let mut trait_impl_methods = trait_impl.methods.iter(); + let func_id = + trait_impl_methods.find(|func_id| args.interner.function_name(func_id) 
== function_name)?; + Some(*func_id) +} + +fn format_alias(id: TypeAliasId, args: &ProcessRequestCallbackArgs) -> String { + let type_alias = args.interner.get_type_alias(id); + let type_alias = type_alias.borrow(); + + let mut string = String::new(); + format_parent_module(ReferenceId::Alias(id), args, &mut string); + string.push('\n'); + string.push_str(" "); + string.push_str("type "); + string.push_str(&type_alias.name.0.contents); + string.push_str(" = "); + string.push_str(&format!("{}", &type_alias.typ)); + + append_doc_comments(args.interner, ReferenceId::Alias(id), &mut string); + + string +} + +fn format_local(id: DefinitionId, args: &ProcessRequestCallbackArgs) -> String { + let definition_info = args.interner.definition(id); + if let DefinitionKind::Global(global_id) = &definition_info.kind { + return format_global(*global_id, args); + } + + let DefinitionKind::Local(expr_id) = definition_info.kind else { + panic!("Expected a local reference to reference a local definition") + }; + let typ = args.interner.definition_type(id); + + let mut string = String::new(); + string.push_str(" "); + if definition_info.comptime { + string.push_str("comptime "); + } + if expr_id.is_some() { + string.push_str("let "); + } + if definition_info.mutable { + if expr_id.is_none() { + string.push_str("let "); + } + string.push_str("mut "); + } + string.push_str(&definition_info.name); + if !matches!(typ, Type::Error) { + string.push_str(": "); + string.push_str(&format!("{}", typ)); + } + + string.push_str(&go_to_type_links(&typ, args.interner, args.files)); + + string +} + +fn format_generics(generics: &Generics, string: &mut String) { + format_generics_impl( + generics, false, // only show names + string, + ); +} + +fn format_generic_names(generics: &Generics, string: &mut String) { + format_generics_impl( + generics, true, // only show names + string, + ); +} + +fn format_generics_impl(generics: &Generics, only_show_names: bool, string: &mut String) { + if 
generics.is_empty() { + return; + } + + string.push('<'); + for (index, generic) in generics.iter().enumerate() { + if index > 0 { + string.push_str(", "); + } + + if only_show_names { + string.push_str(&generic.name); + } else { + match generic.kind() { + noirc_frontend::Kind::Any | noirc_frontend::Kind::Normal => { + string.push_str(&generic.name); + } + noirc_frontend::Kind::IntegerOrField | noirc_frontend::Kind::Integer => { + string.push_str("let "); + string.push_str(&generic.name); + string.push_str(": u32"); + } + noirc_frontend::Kind::Numeric(typ) => { + string.push_str("let "); + string.push_str(&generic.name); + string.push_str(": "); + string.push_str(&typ.to_string()); + } + } + } + } + string.push('>'); +} + +fn format_pattern(pattern: &HirPattern, interner: &NodeInterner, string: &mut String) { + match pattern { + HirPattern::Identifier(ident) => { + let definition = interner.definition(ident.id); + string.push_str(&definition.name); + } + HirPattern::Mutable(pattern, _) => { + string.push_str("mut "); + format_pattern(pattern, interner, string); + } + HirPattern::Tuple(..) | HirPattern::Struct(..) => { + string.push('_'); + } + } +} + +fn pattern_is_self(pattern: &HirPattern, interner: &NodeInterner) -> bool { + match pattern { + HirPattern::Identifier(ident) => { + let definition = interner.definition(ident.id); + definition.name == "self" + } + HirPattern::Mutable(pattern, _) => pattern_is_self(pattern, interner), + HirPattern::Tuple(..) | HirPattern::Struct(..) 
=> false, + } +} + +fn format_parent_module( + referenced: ReferenceId, + args: &ProcessRequestCallbackArgs, + string: &mut String, +) -> bool { + let Some(module) = args.interner.reference_module(referenced) else { + return false; + }; + + format_parent_module_from_module_id(module, args, string) +} + +fn format_parent_module_from_module_id( + module: &ModuleId, + args: &ProcessRequestCallbackArgs, + string: &mut String, +) -> bool { + let full_path = + module_full_path(module, args.interner, args.crate_id, &args.crate_name, args.dependencies); + if full_path.is_empty() { + return false; + } + + string.push_str(" "); + string.push_str(&full_path); + true +} + +fn go_to_type_links(typ: &Type, interner: &NodeInterner, files: &FileMap) -> String { + let mut gatherer = TypeLinksGatherer { interner, files, links: Vec::new() }; + gatherer.gather_type_links(typ); + + let links = gatherer.links; + if links.is_empty() { + "".to_string() + } else { + let mut string = String::new(); + string.push_str("\n\n"); + string.push_str("Go to "); + for (index, link) in links.iter().enumerate() { + if index > 0 { + string.push_str(" | "); + } + string.push_str(link); + } + string + } +} + +struct TypeLinksGatherer<'a> { + interner: &'a NodeInterner, + files: &'a FileMap, + links: Vec<String>, +} + +impl<'a> TypeLinksGatherer<'a> { + fn gather_type_links(&mut self, typ: &Type) { + match typ { + Type::Array(typ, _) => self.gather_type_links(typ), + Type::Slice(typ) => self.gather_type_links(typ), + Type::Tuple(types) => { + for typ in types { + self.gather_type_links(typ); + } + } + Type::DataType(data_type, generics) => { + self.gather_struct_type_links(data_type); + for generic in generics { + self.gather_type_links(generic); + } + } + Type::Alias(type_alias, generics) => { + self.gather_type_alias_links(type_alias); + for generic in generics { + self.gather_type_links(generic); + } + } + Type::TypeVariable(var) => { + self.gather_type_variable_links(var); + } + Type::TraitAsType(trait_id, _,
generics) => { + let some_trait = self.interner.get_trait(*trait_id); + self.gather_trait_links(some_trait); + for generic in &generics.ordered { + self.gather_type_links(generic); + } + for named_type in &generics.named { + self.gather_type_links(&named_type.typ); + } + } + Type::NamedGeneric(var, _) => { + self.gather_type_variable_links(var); + } + Type::Function(args, return_type, env, _) => { + for arg in args { + self.gather_type_links(arg); + } + self.gather_type_links(return_type); + self.gather_type_links(env); + } + Type::MutableReference(typ) => self.gather_type_links(typ), + Type::InfixExpr(lhs, _, rhs, _) => { + self.gather_type_links(lhs); + self.gather_type_links(rhs); + } + Type::CheckedCast { to, .. } => self.gather_type_links(to), + Type::FieldElement + | Type::Integer(..) + | Type::Bool + | Type::String(_) + | Type::FmtString(_, _) + | Type::Unit + | Type::Forall(_, _) + | Type::Constant(..) + | Type::Quoted(_) + | Type::Error => (), + } + } + + fn gather_struct_type_links(&mut self, struct_type: &Shared<DataType>) { + let struct_type = struct_type.borrow(); + if let Some(lsp_location) = + to_lsp_location(self.files, struct_type.location.file, struct_type.name.span()) + { + self.push_link(format_link(struct_type.name.to_string(), lsp_location)); + } + } + + fn gather_type_alias_links(&mut self, type_alias: &Shared<TypeAlias>) { + let type_alias = type_alias.borrow(); + if let Some(lsp_location) = + to_lsp_location(self.files, type_alias.location.file, type_alias.name.span()) + { + self.push_link(format_link(type_alias.name.to_string(), lsp_location)); + } + } + + fn gather_trait_links(&mut self, some_trait: &Trait) { + if let Some(lsp_location) = + to_lsp_location(self.files, some_trait.location.file, some_trait.name.span()) + { + self.push_link(format_link(some_trait.name.to_string(), lsp_location)); + } + } + + fn gather_type_variable_links(&mut self, var: &TypeVariable) { + let var = &*var.borrow(); + match var { + TypeBinding::Bound(typ) => { +
self.gather_type_links(typ); + } + TypeBinding::Unbound(..) => (), + } + } + + fn push_link(&mut self, link: String) { + if !self.links.contains(&link) { + self.links.push(link); + } + } +} + +fn format_link(name: String, location: lsp_types::Location) -> String { + format!( + "[{}]({}#L{},{}-{},{})", + name, + location.uri, + location.range.start.line + 1, + location.range.start.character + 1, + location.range.end.line + 1, + location.range.end.character + 1 + ) +} + +fn append_doc_comments(interner: &NodeInterner, id: ReferenceId, string: &mut String) -> bool { + if let Some(doc_comments) = interner.doc_comments(id) { + string.push_str("\n\n---\n\n"); + for comment in doc_comments { + string.push_str(comment); + string.push('\n'); + } + true + } else { + false + } +} diff --git a/noir/noir-repo/tooling/lsp/src/requests/hover/from_visitor.rs b/noir/noir-repo/tooling/lsp/src/requests/hover/from_visitor.rs new file mode 100644 index 000000000000..2099d98a93f6 --- /dev/null +++ b/noir/noir-repo/tooling/lsp/src/requests/hover/from_visitor.rs @@ -0,0 +1,118 @@ +use std::str::FromStr; + +use acvm::FieldElement; +use fm::{FileId, FileMap}; +use lsp_types::{Hover, HoverContents, MarkupContent, MarkupKind, Position}; +use noirc_errors::{Location, Span}; +use noirc_frontend::{ast::Visitor, node_interner::NodeInterner, parse_program, Type}; +use num_bigint::BigInt; + +use crate::{ + requests::{to_lsp_location, ProcessRequestCallbackArgs}, + utils, +}; + +pub(super) fn hover_from_visitor( + file_id: Option<FileId>, + position: Position, + args: &ProcessRequestCallbackArgs, +) -> Option<Hover> { + let file_id = file_id?; + let file = args.files.get_file(file_id)?; + let source = file.source(); + let (parsed_module, _errors) = parse_program(source); + let byte_index = utils::position_to_byte_index(args.files, file_id, &position)?; + + let mut finder = HoverFinder::new(args.files, file_id, args.interner, byte_index); + parsed_module.accept(&mut finder); + finder.hover +} + +struct
HoverFinder<'a> { + files: &'a FileMap, + file: FileId, + interner: &'a NodeInterner, + byte_index: usize, + hover: Option<Hover>, +} +impl<'a> HoverFinder<'a> { + fn new( + files: &'a FileMap, + file: FileId, + interner: &'a NodeInterner, + byte_index: usize, + ) -> Self { + Self { files, file, interner, byte_index, hover: None } + } + + fn intersects_span(&self, span: Span) -> bool { + span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize + } +} + +impl<'a> Visitor for HoverFinder<'a> { + fn visit_literal_integer(&mut self, value: FieldElement, negative: bool, span: Span) { + if !self.intersects_span(span) { + return; + } + + let location = Location::new(span, self.file); + let lsp_location = to_lsp_location(self.files, location.file, location.span); + let range = lsp_location.map(|location| location.range); + let Some(typ) = self.interner.type_at_location(location) else { + return; + }; + + let value = format_integer(typ, value, negative); + let contents = HoverContents::Markup(MarkupContent { kind: MarkupKind::Markdown, value }); + self.hover = Some(Hover { contents, range }); + } +} + +fn format_integer(typ: Type, value: FieldElement, negative: bool) -> String { + let value_base_10 = value.to_string(); + + // For simplicity we parse the value as a BigInt to convert it to hex + // because `FieldElement::to_hex` will include many leading zeros.
+ let value_big_int = BigInt::from_str(&value_base_10).unwrap(); + let negative = if negative { "-" } else { "" }; + + format!(" {typ}\n---\nvalue of literal: `{negative}{value_base_10} ({negative}0x{value_big_int:02x})`") +} + +#[cfg(test)] +mod tests { + use noirc_frontend::{ + ast::{IntegerBitSize, Signedness}, + Type, + }; + + use super::format_integer; + + #[test] + fn format_integer_zero() { + let typ = Type::FieldElement; + let value = 0_u128.into(); + let negative = false; + let expected = " Field\n---\nvalue of literal: `0 (0x00)`"; + assert_eq!(format_integer(typ, value, negative), expected); + } + + #[test] + fn format_positive_integer() { + let typ = Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo); + let value = 123456_u128.into(); + let negative = false; + let expected = " u32\n---\nvalue of literal: `123456 (0x1e240)`"; + assert_eq!(format_integer(typ, value, negative), expected); + } + + #[test] + fn format_negative_integer() { + let typ = Type::Integer(Signedness::Signed, IntegerBitSize::SixtyFour); + let value = 987654_u128.into(); + let negative = true; + let expected = " i64\n---\nvalue of literal: `-987654 (-0xf1206)`"; + assert_eq!(format_integer(typ, value, negative), expected); + } +} diff --git a/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs b/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs index b9673755da6a..e2f793f06da8 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs @@ -40,14 +40,14 @@ pub(crate) fn on_inlay_hint_request( args.files.get_file_id(&path).map(|file_id| { let file = args.files.get_file(file_id).unwrap(); let source = file.source(); - let (parsed_moduled, _errors) = noirc_frontend::parse_program(source); + let (parsed_module, _errors) = noirc_frontend::parse_program(source); let span = utils::range_to_byte_span(args.files, file_id, &params.range) .map(|range| Span::from(range.start as u32..range.end as u32)); let mut
collector = InlayHintCollector::new(args.files, file_id, args.interner, span, options); - parsed_moduled.accept(&mut collector); + parsed_module.accept(&mut collector); collector.inlay_hints }) }); @@ -327,7 +327,7 @@ impl<'a> Visitor for InlayHintCollector<'a> { fn visit_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl, span: Span) -> bool { self.show_closing_brace_hint(span, || { - format!(" impl {} for {}", noir_trait_impl.trait_name, noir_trait_impl.object_type) + format!(" impl {} for {}", noir_trait_impl.r#trait, noir_trait_impl.object_type) }); true diff --git a/noir/noir-repo/tooling/lsp/src/requests/mod.rs b/noir/noir-repo/tooling/lsp/src/requests/mod.rs index 1789c3513f63..9bfe47bdaa55 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/mod.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/mod.rs @@ -657,9 +657,9 @@ pub(crate) fn find_all_references( } /// Represents a trait reexported from a given module with a name. -pub(crate) struct TraitReexport<'a> { - pub(super) module_id: &'a ModuleId, - pub(super) name: &'a Ident, +pub(crate) struct TraitReexport { + pub(super) module_id: ModuleId, + pub(super) name: Ident, } #[cfg(test)] diff --git a/noir/noir-repo/tooling/lsp/src/trait_impl_method_stub_generator.rs b/noir/noir-repo/tooling/lsp/src/trait_impl_method_stub_generator.rs index 4e505eb5e129..a7444f819935 100644 --- a/noir/noir-repo/tooling/lsp/src/trait_impl_method_stub_generator.rs +++ b/noir/noir-repo/tooling/lsp/src/trait_impl_method_stub_generator.rs @@ -1,11 +1,10 @@ use std::collections::BTreeMap; use noirc_frontend::{ - ast::NoirTraitImpl, + ast::{NoirTraitImpl, UnresolvedTypeData}, graph::CrateId, - hir::def_map::ModuleDefId, hir::{ - def_map::{CrateDefMap, ModuleId}, + def_map::{CrateDefMap, ModuleDefId, ModuleId}, type_check::generics::TraitGenerics, }, hir_def::{function::FuncMeta, stmt::HirPattern, traits::Trait}, @@ -204,7 +203,7 @@ impl<'a> TraitImplMethodStubGenerator<'a> { let relative_path = relative_module_id_path( 
parent_module_id, - &self.module_id, + self.module_id, current_module_parent_id, self.interner, ); @@ -241,7 +240,7 @@ impl<'a> TraitImplMethodStubGenerator<'a> { let relative_path = relative_module_id_path( *parent_module_id, - &self.module_id, + self.module_id, current_module_parent_id, self.interner, ); @@ -278,7 +277,7 @@ impl<'a> TraitImplMethodStubGenerator<'a> { let relative_path = relative_module_id_path( *parent_module_id, - &self.module_id, + self.module_id, current_module_parent_id, self.interner, ); @@ -300,7 +299,13 @@ impl<'a> TraitImplMethodStubGenerator<'a> { if let Some(index) = generics.iter().position(|generic| generic.type_var.id() == typevar.id()) { - if let Some(typ) = self.noir_trait_impl.trait_generics.ordered_args.get(index) { + let UnresolvedTypeData::Named(_, trait_generics, _) = + &self.noir_trait_impl.r#trait.typ + else { + return; + }; + + if let Some(typ) = trait_generics.ordered_args.get(index) { self.string.push_str(&typ.to_string()); return; } diff --git a/noir/noir-repo/tooling/lsp/src/visibility.rs b/noir/noir-repo/tooling/lsp/src/visibility.rs index 6724a0ba5052..7366684a859f 100644 --- a/noir/noir-repo/tooling/lsp/src/visibility.rs +++ b/noir/noir-repo/tooling/lsp/src/visibility.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use noirc_frontend::{ ast::ItemVisibility, - graph::CrateId, + graph::{CrateId, Dependency}, hir::{ def_map::{CrateDefMap, ModuleDefId, ModuleId}, resolution::visibility::item_in_module_is_visible, @@ -23,6 +23,7 @@ pub(super) fn module_def_id_is_visible( mut defining_module: Option, interner: &NodeInterner, def_maps: &BTreeMap, + dependencies: &[Dependency], ) -> bool { // First find out which module we need to check. // If a module is trying to be referenced, it's that module. Otherwise it's the module that contains the item. 
@@ -38,6 +39,14 @@ pub(super) fn module_def_id_is_visible( return false; } + // If the target module isn't in the same crate as `module_id` or isn't in one of its + // dependencies, then it's not visible. + if module_id.krate != current_module_id.krate + && dependencies.iter().all(|dep| dep.crate_id != module_id.krate) + { + return false; + } + target_module_id = std::mem::take(&mut defining_module).or_else(|| { let module_data = &def_maps[&module_id.krate].modules()[module_id.local_id.0]; let parent_local_id = module_data.parent; diff --git a/noir/noir-repo/tooling/lsp/test_programs/workspace/two/src/lib.nr b/noir/noir-repo/tooling/lsp/test_programs/workspace/two/src/lib.nr index 0baeb83d5c14..d067c523130b 100644 --- a/noir/noir-repo/tooling/lsp/test_programs/workspace/two/src/lib.nr +++ b/noir/noir-repo/tooling/lsp/test_programs/workspace/two/src/lib.nr @@ -7,7 +7,7 @@ pub fn function_two() { use one::subone; fn use_struct() { - let _ = subone::SubOneStruct { some_field: 0, some_other_field: 2 }; + let _ = subone::SubOneStruct { some_field: 0, some_other_field: 123 }; } use one::subone::SomeTrait; @@ -109,3 +109,7 @@ enum Color { fn test_enum() -> Color { Color::Red(1) } + +fn negative_integer() -> i32 { + -8 +} diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter/statement.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter/statement.rs index 7ce5801cf360..f31cb80b9fdc 100644 --- a/noir/noir-repo/tooling/nargo_fmt/src/formatter/statement.rs +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter/statement.rs @@ -1,7 +1,7 @@ use noirc_frontend::{ ast::{ AssignStatement, Expression, ExpressionKind, ForLoopStatement, ForRange, LetStatement, - Pattern, Statement, StatementKind, UnresolvedType, UnresolvedTypeData, + Pattern, Statement, StatementKind, UnresolvedType, UnresolvedTypeData, WhileStatement, }, token::{Keyword, SecondaryAttribute, Token, TokenKind}, }; @@ -74,6 +74,9 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { StatementKind::Loop(block, _) => { 
group.group(self.format_loop(block)); } + StatementKind::While(while_) => { + group.group(self.format_while(while_)); + } StatementKind::Break => { group.text(self.chunk(|formatter| { formatter.write_keyword(Keyword::Break); @@ -266,6 +269,36 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { group } + fn format_while(&mut self, while_: WhileStatement) -> ChunkGroup { + let mut group = ChunkGroup::new(); + + group.text(self.chunk(|formatter| { + formatter.write_keyword(Keyword::While); + })); + + group.space(self); + self.format_expression(while_.condition, &mut group); + group.space(self); + + let ExpressionKind::Block(block) = while_.body.kind else { + panic!("Expected a block expression for loop body"); + }; + + group.group(self.format_block_expression( + block, true, // force multiple lines + )); + + // If there's a trailing semicolon, remove it + group.text(self.chunk(|formatter| { + formatter.skip_whitespace_if_it_is_not_a_newline(); + if formatter.is_at(Token::Semicolon) { + formatter.bump(); + } + })); + + group + } + fn format_comptime_statement(&mut self, statement: Statement) -> ChunkGroup { let mut group = ChunkGroup::new(); @@ -375,7 +408,7 @@ mod tests { #[test] fn format_let_statement_with_unsafe_comment() { - let src = " fn foo() { + let src = " fn foo() { // Safety: some comment let x = unsafe { 1 } ; } "; let expected = "fn foo() { @@ -388,7 +421,7 @@ mod tests { #[test] fn format_let_statement_with_unsafe_doc_comment() { - let src = " fn foo() { + let src = " fn foo() { /// Safety: some comment let x = unsafe { 1 } ; } "; let expected = "fn foo() { @@ -514,7 +547,7 @@ mod tests { #[test] fn format_unsafe_statement() { - let src = " fn foo() { unsafe { + let src = " fn foo() { unsafe { 1 } } "; let expected = "fn foo() { unsafe { @@ -756,6 +789,42 @@ mod tests { 2 } } +"; + assert_format(src, expected); + } + + #[test] + fn format_empty_while() { + let src = " fn foo() { while condition { } } "; + let expected = "fn foo() { + while condition {} +} +"; + 
assert_format(src, expected); + } + + #[test] + fn format_non_empty_while() { + let src = " fn foo() { while condition { 1 ; 2 } } "; + let expected = "fn foo() { + while condition { + 1; + 2 + } +} +"; + assert_format(src, expected); + } + + #[test] + fn format_while_with_semicolon() { + let src = " fn foo() { while condition { 1 ; 2 }; } "; + let expected = "fn foo() { + while condition { + 1; + 2 + } +} "; assert_format(src, expected); } diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter/trait_impl.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter/trait_impl.rs index 5bb9a0d00254..896620c3bf80 100644 --- a/noir/noir-repo/tooling/nargo_fmt/src/formatter/trait_impl.rs +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter/trait_impl.rs @@ -18,8 +18,7 @@ impl<'a> Formatter<'a> { self.write_keyword(Keyword::Impl); self.format_generics(trait_impl.impl_generics); self.write_space(); - self.format_path(trait_impl.trait_name); - self.format_generic_type_args(trait_impl.trait_generics); + self.format_type(trait_impl.r#trait); self.write_space(); self.write_keyword(Keyword::For); self.write_space();