diff --git a/tests/Dialect/Field/field_runner.mlir b/tests/Dialect/Field/field_runner.mlir index 76bcd23ba..df701b3a3 100644 --- a/tests/Dialect/Field/field_runner.mlir +++ b/tests/Dialect/Field/field_runner.mlir @@ -6,6 +6,7 @@ #mont = #mod_arith.montgomery<7:i32> !PF = !field.pf<7:i32> !PFm = !field.pf<7:i32, true> +!PF_exp = !field.pf<7:i64> #beta = #field.pf.elem<6:i32> : !PF #beta_mont = #field.pf.elem<3:i32> : !PFm @@ -17,6 +18,7 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interf func.func @test_power() { %exp = arith.constant 51 : i64 + %exp_pf = field.encapsulate %exp : i64 -> !PF_exp %base = arith.constant 3 : i32 %base_pf = field.encapsulate %base: i32 -> !PF @@ -53,6 +55,21 @@ func.func @test_power() { %U4 = memref.cast %16 : memref<2xi32> to memref<*xi32> func.call @printMemrefI32(%U4) : (memref<*xi32>) -> () + %res1_pf = field.powpf %base_pf, %exp_pf : !PF, !PF_exp + %17 = field.extract %res1_pf : !PF -> i32 + %18 = tensor.from_elements %17 : tensor<1xi32> + %19 = bufferization.to_buffer %18 : tensor<1xi32> to memref<1xi32> + %U5 = memref.cast %19 : memref<1xi32> to memref<*xi32> + func.call @printMemrefI32(%U5) : (memref<*xi32>) -> () + + %res2_pf_mont = field.powpf %base_f2_mont, %exp_pf : !QFm, !PF_exp + %res2_pf_standard = field.from_mont %res2_pf_mont : !QF + %20, %21 = field.extract %res2_pf_standard : !QF -> i32, i32 + %22 = tensor.from_elements %20, %21 : tensor<2xi32> + %23 = bufferization.to_buffer %22 : tensor<2xi32> to memref<2xi32> + %U6 = memref.cast %23 : memref<2xi32> to memref<*xi32> + func.call @printMemrefI32(%U6) : (memref<*xi32>) -> () + return } @@ -60,3 +77,5 @@ func.func @test_power() { // CHECK_TEST_POWER: [6] // CHECK_TEST_POWER: [2, 5] // CHECK_TEST_POWER: [2, 5] +// CHECK_TEST_POWER: [6] +// CHECK_TEST_POWER: [2, 5] diff --git a/zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp b/zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp index 16883d8df..742d5a2af 100644 --- a/zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp +++ b/zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp @@ -525,6 +525,124 @@ struct ConvertSquare : public OpConversionPattern { } }; +namespace { + +template +LogicalResult computePower(ImplicitLocOpBuilder &b, OpType &op, Value base, + Value exp, Value &result) { + APInt modulus; + Value init; + auto fieldType = getElementTypeOrSelf(base); + + if (auto pfType = dyn_cast(fieldType)) { + modulus = pfType.getModulus().getValue(); + init = + pfType.isMontgomery() + ? b.create( + pfType, + b.create( + cast(field::getStandardFormType(pfType)), + 1)) + .getResult() + : b.create(pfType, 1); + } else if (auto f2Type = dyn_cast(fieldType)) { + modulus = f2Type.getBaseField().getModulus().getValue(); + init = f2Type.isMontgomery() + ? b.create( + f2Type, b.create( + cast( + field::getStandardFormType(f2Type)), + 1, 0)) + .getResult() + : b.create(f2Type, 1, 0); + } else { + op.emitOpError("unsupported output type"); + return failure(); + } + + unsigned expBitWidth = cast(exp.getType()).getWidth(); + unsigned modBitWidth = modulus.getBitWidth(); + if (modBitWidth > expBitWidth) { + exp = b.create( + IntegerType::get(exp.getContext(), modBitWidth), exp); + } else { + modulus = modulus.zext(expBitWidth); + modBitWidth = expBitWidth; + } + IntegerType intType = cast(exp.getType()); + + // For prime field, x^(p-1) ≡ 1 mod p, so x^n ≡ x^(n mod (p-1)) mod p + // For quadratic extension field, x^(p²-1) ≡ 1 mod p², so + // x^n ≡ x^(n mod (p²-1)) mod p² + if (isa(fieldType)) { + exp = b.create( + exp, b.create(intType, modulus - 1)); + } else if (isa(fieldType)) { + modulus = modulus.zext(modBitWidth * 2); + modulus = modulus * modulus - 1; + exp = b.create( + IntegerType::get(exp.getContext(), modulus.getBitWidth()), exp); + intType = IntegerType::get(exp.getContext(), modulus.getBitWidth()); + exp = b.create( + exp, b.create(intType, modulus)); + } + + Value zero = b.create(intType, 0); + Value one = b.create(intType, 1); + Value powerOfP = base; + auto ifOp = b.create( + b.create(arith::CmpIPredicate::ne, + b.create(exp, one), zero), + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + auto newResult = b.create(init, powerOfP); + b.create(ValueRange{newResult}); + }, + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + b.create(ValueRange{init}); + }); + exp = b.create(exp, one); + init = ifOp.getResult(0); + auto whileOp = b.create( + TypeRange{intType, fieldType, fieldType}, ValueRange{exp, powerOfP, init}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + ImplicitLocOpBuilder b(loc, builder); + auto cond = + b.create(arith::CmpIPredicate::ugt, args[0], zero); + b.create(cond, args); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + ImplicitLocOpBuilder b(loc, builder); + auto currExp = args[0]; + auto currPowerOfP = args[1]; + auto currResult = args[2]; + + auto newPowerOfP = b.create(currPowerOfP); + auto masked = b.create(currExp, one); + auto isOdd = + b.create(arith::CmpIPredicate::ne, masked, zero); + auto ifOp = b.create( + isOdd, + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + auto newResult = b.create(currResult, newPowerOfP); + b.create(ValueRange{newResult}); + }, + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + b.create(ValueRange{currResult}); + }); + auto shifted = b.create(currExp, one); + b.create( + ValueRange{shifted, newPowerOfP, ifOp.getResult(0)}); + }); + result = whileOp.getResult(2); + return success(); +} + +} // namespace + struct ConvertPowUI : public OpConversionPattern { explicit ConvertPowUI(MLIRContext *context) : OpConversionPattern(context) {} @@ -535,117 +653,39 @@ struct ConvertPowUI : public OpConversionPattern { matchAndRewrite(PowUIOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto base = op.getBase(); - auto exp = op.getExp(); - auto fieldType = getElementTypeOrSelf(base); - - APInt modulus; - Value init; - if (auto pfType = dyn_cast(fieldType)) { - modulus = pfType.getModulus().getValue(); - init = pfType.isMontgomery() - ? b.create( - pfType, b.create( - cast( - field::getStandardFormType(pfType)), - 1)) - .getResult() - : b.create(pfType, 1); - } else if (auto f2Type = dyn_cast(fieldType)) { - modulus = f2Type.getBaseField().getModulus().getValue(); - init = f2Type.isMontgomery() - ? b.create( - f2Type, b.create( - cast( - field::getStandardFormType(f2Type)), - 1, 0)) - .getResult() - : b.create(f2Type, 1, 0); - } else { - op.emitOpError("unsupported output type"); + + Value result; + if (failed(computePower(b, op, op.getBase(), op.getExp(), result))) { return failure(); } + rewriter.replaceOp(op, result); + return success(); + } +}; - unsigned expBitWidth = cast(exp.getType()).getWidth(); - unsigned modBitWidth = modulus.getBitWidth(); - if (modBitWidth > expBitWidth) { - exp = b.create( - IntegerType::get(exp.getContext(), modBitWidth), exp); - } else { - modulus = modulus.zext(expBitWidth); - modBitWidth = expBitWidth; - } - IntegerType intType = cast(exp.getType()); +struct ConvertPowPF : public OpConversionPattern { + explicit ConvertPowPF(MLIRContext *context) + : OpConversionPattern(context) {} - // For prime field, x^(p-1) ≡ 1 mod p, so x^n ≡ x^(n mod (p-1)) mod p - // For quadratic extension field, x^(p²-1) ≡ 1 mod p², so - // x^n ≡ x^(n mod (p²-1)) mod p² - if (isa(fieldType)) { - exp = b.create( - exp, b.create(intType, modulus - 1)); - } else if (isa(fieldType)) { - modulus = modulus.zext(modBitWidth * 2); - modulus = modulus * modulus - 1; - exp = b.create( - IntegerType::get(exp.getContext(), modulus.getBitWidth()), exp); - intType = IntegerType::get(exp.getContext(), modulus.getBitWidth()); - exp = b.create( - exp, b.create(intType, modulus)); - } + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + PowPFOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value zero = b.create(intType, 0); - Value one = b.create(intType, 1); - Value powerOfP = base; - auto ifOp = b.create( - b.create(arith::CmpIPredicate::ne, - b.create(exp, one), zero), - [&](OpBuilder &builder, Location loc) { - ImplicitLocOpBuilder b(loc, builder); - auto newResult = b.create(init, powerOfP); - b.create(ValueRange{newResult}); - }, - [&](OpBuilder &builder, Location loc) { - ImplicitLocOpBuilder b(loc, builder); - b.create(ValueRange{init}); - }); - exp = b.create(exp, one); - init = ifOp.getResult(0); - auto whileOp = b.create( - TypeRange{intType, fieldType, fieldType}, - ValueRange{exp, powerOfP, init}, - [&](OpBuilder &builder, Location loc, ValueRange args) { - ImplicitLocOpBuilder b(loc, builder); - auto cond = - b.create(arith::CmpIPredicate::ugt, args[0], zero); - b.create(cond, args); - }, - [&](OpBuilder &builder, Location loc, ValueRange args) { - ImplicitLocOpBuilder b(loc, builder); - auto currExp = args[0]; - auto currPowerOfP = args[1]; - auto currResult = args[2]; - - auto newPowerOfP = b.create(currPowerOfP); - auto masked = b.create(currExp, one); - auto isOdd = - b.create(arith::CmpIPredicate::ne, masked, zero); - auto ifOp = b.create( - isOdd, - [&](OpBuilder &builder, Location loc) { - ImplicitLocOpBuilder b(loc, builder); - auto newResult = - b.create(currResult, newPowerOfP); - b.create(ValueRange{newResult}); - }, - [&](OpBuilder &builder, Location loc) { - ImplicitLocOpBuilder b(loc, builder); - b.create(ValueRange{currResult}); - }); - auto shifted = b.create(currExp, one); - b.create( - ValueRange{shifted, newPowerOfP, ifOp.getResult(0)}); - }); - rewriter.replaceOp(op, whileOp.getResult(2)); + Value exp = op.getExp(); + auto expFieldType = cast(exp.getType()); + unsigned expBitWidth = expFieldType.getModulus().getValue().getBitWidth(); + auto expIntType = IntegerType::get(b.getContext(), expBitWidth); + Value expInt = + b.create(TypeRange{expIntType}, exp).getResult(0); + + Value result; + if (failed(computePower(b, op, op.getBase(), expInt, result))) { + return failure(); + } + rewriter.replaceOp(op, result); return success(); } }; @@ -738,6 +778,7 @@ void FieldToModArith::runOnOperation() { ConvertInverse, ConvertNegate, ConvertMul, + ConvertPowPF, ConvertPowUI, ConvertSquare, ConvertSub, diff --git a/zkir/Dialect/Field/IR/FieldOps.cpp b/zkir/Dialect/Field/IR/FieldOps.cpp index 02ea73c49..dee2f02d2 100644 --- a/zkir/Dialect/Field/IR/FieldOps.cpp +++ b/zkir/Dialect/Field/IR/FieldOps.cpp @@ -202,6 +202,7 @@ LogicalResult AddOp::verify() { return disallowShapedTypeOfExtField(*this); } LogicalResult SubOp::verify() { return disallowShapedTypeOfExtField(*this); } LogicalResult MulOp::verify() { return disallowShapedTypeOfExtField(*this); } LogicalResult PowUIOp::verify() { return disallowShapedTypeOfExtField(*this); } +LogicalResult PowPFOp::verify() { return disallowShapedTypeOfExtField(*this); } LogicalResult InverseOp::verify() { return disallowShapedTypeOfExtField(*this); } diff --git a/zkir/Dialect/Field/IR/FieldOps.td b/zkir/Dialect/Field/IR/FieldOps.td index 2f423e9da..ba32cd1df 100644 --- a/zkir/Dialect/Field/IR/FieldOps.td +++ b/zkir/Dialect/Field/IR/FieldOps.td @@ -308,7 +308,7 @@ def Field_MulOp : Field_BinaryOp<"mul", [Commutative]> { let hasCanonicalizer = 1; } -// Field power. +// Field power with unsigned integer exponent. def Field_PowUIOp : Field_Op<"powui", [TypesMatchWith< "base and output must have same type", "base", "output", "$_self">]> { let summary = "Field power with unsigned integer exponent"; @@ -317,7 +317,7 @@ def Field_PowUIOp : Field_Op<"powui", [TypesMatchWith< Example: ``` - %power = field.pow %a, %b : field.pf, i32 + %power = field.powui %a, %b : field.pf, i32 ``` }]; let arguments = (ins FieldLike:$base, SignlessIntegerLike:$exp); @@ -326,4 +326,22 @@ def Field_PowUIOp : Field_Op<"powui", [TypesMatchWith< let hasVerifier = 1; } +// Field power with prime field exponent. +def Field_PowPFOp : Field_Op<"powpf", [TypesMatchWith< + "base and output must have same type", "base", "output", "$_self">]> { + let summary = "Field power with prime field exponent"; + let description = [{ + Computes the field element a raised to the power of prime field b. + + Example: + ``` + %power = field.powpf %a, %b : field.pf, field.pf + ``` + }]; + let arguments = (ins FieldLike:$base, Field_PrimeFieldType:$exp); + let results = (outs FieldLike:$output); + let assemblyFormat = "operands attr-dict `:` type($base) `,` type($exp)"; + let hasVerifier = 1; +} + #endif // ZKIR_DIALECT_FIELD_IR_FIELDOPS_TD_