Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tests/Dialect/Field/field_runner.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#mont = #mod_arith.montgomery<7:i32>
!PF = !field.pf<7:i32>
!PFm = !field.pf<7:i32, true>
!PF_exp = !field.pf<7:i64>

#beta = #field.pf.elem<6:i32> : !PF
#beta_mont = #field.pf.elem<3:i32> : !PFm
Expand All @@ -17,6 +18,7 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interf

func.func @test_power() {
%exp = arith.constant 51 : i64
%exp_pf = field.encapsulate %exp : i64 -> !PF_exp
%base = arith.constant 3 : i32
%base_pf = field.encapsulate %base: i32 -> !PF

Expand Down Expand Up @@ -53,10 +55,27 @@ func.func @test_power() {
%U4 = memref.cast %16 : memref<2xi32> to memref<*xi32>
func.call @printMemrefI32(%U4) : (memref<*xi32>) -> ()

%res1_pf = field.powpf %base_pf, %exp_pf : !PF, !PF_exp
%17 = field.extract %res1_pf : !PF -> i32
%18 = tensor.from_elements %17 : tensor<1xi32>
%19 = bufferization.to_buffer %18 : tensor<1xi32> to memref<1xi32>
%U5 = memref.cast %19 : memref<1xi32> to memref<*xi32>
func.call @printMemrefI32(%U5) : (memref<*xi32>) -> ()

%res2_pf_mont = field.powpf %base_f2_mont, %exp_pf : !QFm, !PF_exp
%res2_pf_standard = field.from_mont %res2_pf_mont : !QF
%20, %21 = field.extract %res2_pf_standard : !QF -> i32, i32
%22 = tensor.from_elements %20, %21 : tensor<2xi32>
%23 = bufferization.to_buffer %22 : tensor<2xi32> to memref<2xi32>
%U6 = memref.cast %23 : memref<2xi32> to memref<*xi32>
func.call @printMemrefI32(%U6) : (memref<*xi32>) -> ()

return
}

// CHECK_TEST_POWER: [6]
// CHECK_TEST_POWER: [6]
// CHECK_TEST_POWER: [2, 5]
// CHECK_TEST_POWER: [2, 5]
// CHECK_TEST_POWER: [6]
// CHECK_TEST_POWER: [2, 5]
253 changes: 147 additions & 106 deletions zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,124 @@ struct ConvertSquare : public OpConversionPattern<SquareOp> {
}
};

namespace {

template <typename OpType>
LogicalResult computePower(ImplicitLocOpBuilder &b, OpType &op, Value base,
Value exp, Value &result) {
APInt modulus;
Value init;
auto fieldType = getElementTypeOrSelf(base);

if (auto pfType = dyn_cast<PrimeFieldType>(fieldType)) {
modulus = pfType.getModulus().getValue();
init =
pfType.isMontgomery()
? b.create<field::ToMontOp>(
pfType,
b.create<field::ConstantOp>(
cast<PrimeFieldType>(field::getStandardFormType(pfType)),
1))
.getResult()
: b.create<field::ConstantOp>(pfType, 1);
} else if (auto f2Type = dyn_cast<QuadraticExtFieldType>(fieldType)) {
modulus = f2Type.getBaseField().getModulus().getValue();
init = f2Type.isMontgomery()
? b.create<field::ToMontOp>(
f2Type, b.create<field::ConstantOp>(
cast<QuadraticExtFieldType>(
field::getStandardFormType(f2Type)),
1, 0))
.getResult()
: b.create<field::ConstantOp>(f2Type, 1, 0);
} else {
op.emitOpError("unsupported output type");
return failure();
}

unsigned expBitWidth = cast<IntegerType>(exp.getType()).getWidth();
unsigned modBitWidth = modulus.getBitWidth();
if (modBitWidth > expBitWidth) {
exp = b.create<arith::ExtUIOp>(
IntegerType::get(exp.getContext(), modBitWidth), exp);
} else {
modulus = modulus.zext(expBitWidth);
modBitWidth = expBitWidth;
}
IntegerType intType = cast<IntegerType>(exp.getType());

// For prime field, x^(p-1) ≡ 1 mod p, so x^n ≡ x^(n mod (p-1)) mod p
// For quadratic extension field, x^(p²-1) ≡ 1 mod p², so
// x^n ≡ x^(n mod (p²-1)) mod p²
if (isa<PrimeFieldType>(fieldType)) {
exp = b.create<arith::RemUIOp>(
exp, b.create<arith::ConstantIntOp>(intType, modulus - 1));
} else if (isa<QuadraticExtFieldType>(fieldType)) {
modulus = modulus.zext(modBitWidth * 2);
modulus = modulus * modulus - 1;
exp = b.create<arith::ExtUIOp>(
IntegerType::get(exp.getContext(), modulus.getBitWidth()), exp);
intType = IntegerType::get(exp.getContext(), modulus.getBitWidth());
exp = b.create<arith::RemUIOp>(
exp, b.create<arith::ConstantIntOp>(intType, modulus));
Comment on lines +581 to +587
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The modulus variable is reused to store the group order (p^2 - 1), which can be confusing as it originally holds the prime modulus p. Using a separate, more descriptively named variable like groupOrder would make the logic for exponent reduction clearer and improve maintainability.

Suggested change
modulus = modulus.zext(modBitWidth * 2);
modulus = modulus * modulus - 1;
exp = b.create<arith::ExtUIOp>(
IntegerType::get(exp.getContext(), modulus.getBitWidth()), exp);
intType = IntegerType::get(exp.getContext(), modulus.getBitWidth());
exp = b.create<arith::RemUIOp>(
exp, b.create<arith::ConstantIntOp>(intType, modulus));
APInt groupOrder = modulus.zext(modBitWidth * 2);
groupOrder = groupOrder * groupOrder - 1;
exp = b.create<arith::ExtUIOp>(
IntegerType::get(exp.getContext(), groupOrder.getBitWidth()), exp);
intType = IntegerType::get(exp.getContext(), groupOrder.getBitWidth());
exp = b.create<arith::RemUIOp>(
exp, b.create<arith::ConstantIntOp>(intType, groupOrder));

}

Value zero = b.create<arith::ConstantIntOp>(intType, 0);
Value one = b.create<arith::ConstantIntOp>(intType, 1);
Value powerOfP = base;
auto ifOp = b.create<scf::IfOp>(
b.create<arith::CmpIOp>(arith::CmpIPredicate::ne,
b.create<arith::AndIOp>(exp, one), zero),
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
auto newResult = b.create<field::MulOp>(init, powerOfP);
b.create<scf::YieldOp>(ValueRange{newResult});
},
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
b.create<scf::YieldOp>(ValueRange{init});
});
exp = b.create<arith::ShRUIOp>(exp, one);
init = ifOp.getResult(0);
auto whileOp = b.create<scf::WhileOp>(
TypeRange{intType, fieldType, fieldType}, ValueRange{exp, powerOfP, init},
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);
auto cond =
b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, args[0], zero);
b.create<scf::ConditionOp>(cond, args);
},
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);
auto currExp = args[0];
auto currPowerOfP = args[1];
auto currResult = args[2];

auto newPowerOfP = b.create<field::SquareOp>(currPowerOfP);
auto masked = b.create<arith::AndIOp>(currExp, one);
auto isOdd =
b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, masked, zero);
auto ifOp = b.create<scf::IfOp>(
isOdd,
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
auto newResult = b.create<field::MulOp>(currResult, newPowerOfP);
b.create<scf::YieldOp>(ValueRange{newResult});
},
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
b.create<scf::YieldOp>(ValueRange{currResult});
});
auto shifted = b.create<arith::ShRUIOp>(currExp, one);
b.create<scf::YieldOp>(
ValueRange{shifted, newPowerOfP, ifOp.getResult(0)});
});
result = whileOp.getResult(2);
return success();
}

} // namespace

struct ConvertPowUI : public OpConversionPattern<PowUIOp> {
explicit ConvertPowUI(MLIRContext *context)
: OpConversionPattern<PowUIOp>(context) {}
Expand All @@ -535,117 +653,39 @@ struct ConvertPowUI : public OpConversionPattern<PowUIOp> {
matchAndRewrite(PowUIOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto base = op.getBase();
auto exp = op.getExp();
auto fieldType = getElementTypeOrSelf(base);

APInt modulus;
Value init;
if (auto pfType = dyn_cast<PrimeFieldType>(fieldType)) {
modulus = pfType.getModulus().getValue();
init = pfType.isMontgomery()
? b.create<field::ToMontOp>(
pfType, b.create<field::ConstantOp>(
cast<PrimeFieldType>(
field::getStandardFormType(pfType)),
1))
.getResult()
: b.create<field::ConstantOp>(pfType, 1);
} else if (auto f2Type = dyn_cast<QuadraticExtFieldType>(fieldType)) {
modulus = f2Type.getBaseField().getModulus().getValue();
init = f2Type.isMontgomery()
? b.create<field::ToMontOp>(
f2Type, b.create<field::ConstantOp>(
cast<QuadraticExtFieldType>(
field::getStandardFormType(f2Type)),
1, 0))
.getResult()
: b.create<field::ConstantOp>(f2Type, 1, 0);
} else {
op.emitOpError("unsupported output type");

Value result;
if (failed(computePower(b, op, op.getBase(), op.getExp(), result))) {
return failure();
}
rewriter.replaceOp(op, result);
return success();
}
};

unsigned expBitWidth = cast<IntegerType>(exp.getType()).getWidth();
unsigned modBitWidth = modulus.getBitWidth();
if (modBitWidth > expBitWidth) {
exp = b.create<arith::ExtUIOp>(
IntegerType::get(exp.getContext(), modBitWidth), exp);
} else {
modulus = modulus.zext(expBitWidth);
modBitWidth = expBitWidth;
}
IntegerType intType = cast<IntegerType>(exp.getType());
struct ConvertPowPF : public OpConversionPattern<PowPFOp> {
explicit ConvertPowPF(MLIRContext *context)
: OpConversionPattern<PowPFOp>(context) {}

// For prime field, x^(p-1) ≡ 1 mod p, so x^n ≡ x^(n mod (p-1)) mod p
// For quadratic extension field, x^(p²-1) ≡ 1 mod p², so
// x^n ≡ x^(n mod (p²-1)) mod p²
if (isa<PrimeFieldType>(fieldType)) {
exp = b.create<arith::RemUIOp>(
exp, b.create<arith::ConstantIntOp>(intType, modulus - 1));
} else if (isa<QuadraticExtFieldType>(fieldType)) {
modulus = modulus.zext(modBitWidth * 2);
modulus = modulus * modulus - 1;
exp = b.create<arith::ExtUIOp>(
IntegerType::get(exp.getContext(), modulus.getBitWidth()), exp);
intType = IntegerType::get(exp.getContext(), modulus.getBitWidth());
exp = b.create<arith::RemUIOp>(
exp, b.create<arith::ConstantIntOp>(intType, modulus));
}
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
PowPFOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Value zero = b.create<arith::ConstantIntOp>(intType, 0);
Value one = b.create<arith::ConstantIntOp>(intType, 1);
Value powerOfP = base;
auto ifOp = b.create<scf::IfOp>(
b.create<arith::CmpIOp>(arith::CmpIPredicate::ne,
b.create<arith::AndIOp>(exp, one), zero),
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
auto newResult = b.create<field::MulOp>(init, powerOfP);
b.create<scf::YieldOp>(ValueRange{newResult});
},
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
b.create<scf::YieldOp>(ValueRange{init});
});
exp = b.create<arith::ShRUIOp>(exp, one);
init = ifOp.getResult(0);
auto whileOp = b.create<scf::WhileOp>(
TypeRange{intType, fieldType, fieldType},
ValueRange{exp, powerOfP, init},
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);
auto cond =
b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, args[0], zero);
b.create<scf::ConditionOp>(cond, args);
},
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);
auto currExp = args[0];
auto currPowerOfP = args[1];
auto currResult = args[2];

auto newPowerOfP = b.create<field::SquareOp>(currPowerOfP);
auto masked = b.create<arith::AndIOp>(currExp, one);
auto isOdd =
b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, masked, zero);
auto ifOp = b.create<scf::IfOp>(
isOdd,
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
auto newResult =
b.create<field::MulOp>(currResult, newPowerOfP);
b.create<scf::YieldOp>(ValueRange{newResult});
},
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);
b.create<scf::YieldOp>(ValueRange{currResult});
});
auto shifted = b.create<arith::ShRUIOp>(currExp, one);
b.create<scf::YieldOp>(
ValueRange{shifted, newPowerOfP, ifOp.getResult(0)});
});
rewriter.replaceOp(op, whileOp.getResult(2));
Value exp = op.getExp();
auto expFieldType = cast<field::PrimeFieldType>(exp.getType());
unsigned expBitWidth = expFieldType.getModulus().getValue().getBitWidth();
auto expIntType = IntegerType::get(b.getContext(), expBitWidth);
Value expInt =
b.create<field::ExtractOp>(TypeRange{expIntType}, exp).getResult(0);

Value result;
if (failed(computePower(b, op, op.getBase(), expInt, result))) {
return failure();
}
rewriter.replaceOp(op, result);
return success();
}
};
Expand Down Expand Up @@ -738,6 +778,7 @@ void FieldToModArith::runOnOperation() {
ConvertInverse,
ConvertNegate,
ConvertMul,
ConvertPowPF,
ConvertPowUI,
ConvertSquare,
ConvertSub,
Expand Down
1 change: 1 addition & 0 deletions zkir/Dialect/Field/IR/FieldOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ LogicalResult AddOp::verify() { return disallowShapedTypeOfExtField(*this); }
LogicalResult SubOp::verify() { return disallowShapedTypeOfExtField(*this); }
LogicalResult MulOp::verify() { return disallowShapedTypeOfExtField(*this); }
LogicalResult PowUIOp::verify() { return disallowShapedTypeOfExtField(*this); }
LogicalResult PowPFOp::verify() { return disallowShapedTypeOfExtField(*this); }
LogicalResult InverseOp::verify() {
return disallowShapedTypeOfExtField(*this);
}
Expand Down
22 changes: 20 additions & 2 deletions zkir/Dialect/Field/IR/FieldOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def Field_MulOp : Field_BinaryOp<"mul", [Commutative]> {
let hasCanonicalizer = 1;
}

// Field power.
// Field power with unsigned integer exponent.
def Field_PowUIOp : Field_Op<"powui", [TypesMatchWith<
"base and output must have same type", "base", "output", "$_self">]> {
let summary = "Field power with unsigned integer exponent";
Expand All @@ -317,7 +317,7 @@ def Field_PowUIOp : Field_Op<"powui", [TypesMatchWith<

Example:
```
%power = field.pow %a, %b : field.pf<primeModulus>, i32
%power = field.powui %a, %b : field.pf<primeModulus>, i32
```
}];
let arguments = (ins FieldLike:$base, SignlessIntegerLike:$exp);
Expand All @@ -326,4 +326,22 @@ def Field_PowUIOp : Field_Op<"powui", [TypesMatchWith<
let hasVerifier = 1;
}

// Field power with prime field exponent.
def Field_PowPFOp : Field_Op<"powpf", [TypesMatchWith<
"base and output must have same type", "base", "output", "$_self">]> {
let summary = "Field power with prime field exponent";
let description = [{
Computes the field element a raised to the power of prime field b.

Example:
```
%power = field.powpf %a, %b : field.pf<primeModulus>, field.pf<primeModulus>
```
}];
let arguments = (ins FieldLike:$base, Field_PrimeFieldType:$exp);
let results = (outs FieldLike:$output);
let assemblyFormat = "operands attr-dict `:` type($base) `,` type($exp)";
let hasVerifier = 1;
}

#endif // ZKIR_DIALECT_FIELD_IR_FIELDOPS_TD_
Loading