diff --git a/samples/qir/oracle-generator/oracle-generator/read_qir.hpp b/samples/qir/oracle-generator/oracle-generator/read_qir.hpp index 58f626144cfd..e24a96bfd549 100644 --- a/samples/qir/oracle-generator/oracle-generator/read_qir.hpp +++ b/samples/qir/oracle-generator/oracle-generator/read_qir.hpp @@ -72,9 +72,9 @@ namespace detail } } - // for now, the sample only supports Boolean functions (i.e., all - // input and output types are either of type `Bool` or a tuple of - // type `Bool`) + // for now, the sample only supports Boolean and Integer functions (i.e., all + // input and output types are either of type `Bool`, `Int` or a tuple of + // type `Bool` `Int`) if (!analyze_function_signature(function)) { fmt::print("[e] function signature not supported: inputs must be Bool and return type must be Bool or " @@ -256,10 +256,138 @@ namespace detail } break; + // Signed remainder operation + case llvm::Instruction::SRem: + { + // Get the previous instruction + const llvm::Instruction* prevInst = inst.getPrevNode(); + if (prevInst) { + // Check the opcode of the previous instruction + unsigned int prevOpcode = prevInst->getOpcode(); + + if (prevOpcode == llvm::Instruction::Add) + { + // Perform modular addition inplace + auto const* op0 = prevInst -> getOperand(0u); + auto const* op1 = prevInst -> getOperand(1u); + auto const* ty0 = op0->getType(); + auto const* ty1 = op1->getType(); + auto const* op2 = inst.getOperand(1u); + + value_signals[&inst] = value_signals[op0]; + + // Check if the operand is a constant integer + if (const llvm::ConstantInt* constantInt = llvm::dyn_cast(op2)) { + // Get the unsigned value of the constant + llvm::APInt intValue = constantInt->getValue(); + uint64_t value = intValue.getZExtValue(); + + // The op2 value is converted to uint64_t in the 'value' variable + modular_adder_inplace(ntk, value_signals[&inst], value_signals[op1], value); + } else { + // Handle the case when op2 is not a constant integer + fmt::print("op2 is not a constant integer\n"); + std::abort(); + } + + + } + + else if (prevOpcode == llvm::Instruction::Mul) { + // Get the operands from the previous instruction + auto const* op0 = prevInst->getOperand(0u); + auto const* op1 = prevInst->getOperand(1u); + + // Get the operands from the current instruction + auto const* op2 = inst.getOperand(1u); + auto const* ty0 = op0->getType(); + auto const* ty1 = op1->getType(); + auto const* ty2 = op2->getType(); + + if (ty0->isIntegerTy(64) && ty1->isIntegerTy(64) && ty2->isIntegerTy(64)) + { + auto signal0 = get_signal(ntk, op0); + auto signal1 = get_signal(ntk, op1); + const auto size = std::max(signal0.size(), signal1.size()); + value_signals[&inst] = value_signals[op0]; + + if (const llvm::ConstantInt* constantInt = llvm::dyn_cast(op2)) + { + // Get the unsigned value of the constant + llvm::APInt intValue = constantInt->getValue(); + uint64_t value = intValue.getZExtValue(); + + // Now you have the op2 value converted to uint64_t in the 'value' variable + // You can use it as needed in your code + modular_multiplication_inplace(ntk, value_signals[&inst], value_signals[op1], value); + } + else + { + // Handle the case when op2 is not a constant integer + fmt::print("op2 is not a constant integer\n"); + std::abort(); + } + } + } + + else + { + fmt::print("Unsupported previous opcode: {}\n", prevOpcode); + std::abort(); + } + } + else { + fmt::print("No previous instruction found\n"); + std::abort(); + } + } + break; + + // Multiplication operation + case llvm::Instruction::Mul: + { + auto const* op0 = inst.getOperand(0u); + auto const* op1 = inst.getOperand(1u); + auto const* ty0 = op0->getType(); + auto const* ty1 = op1->getType(); + + if (ty0->isIntegerTy(64) && ty1->isIntegerTy(64)) + { + auto signal0 = get_signal(ntk, op0); + auto signal1 = get_signal(ntk, op1); + const auto size = std::max(signal0.size(), signal1.size()); + value_signals[&inst] = value_signals[inst.getOperand(0u)]; + modular_multiplication_inplace(ntk, value_signals[&inst], value_signals[inst.getOperand(1u)], 11); + } + else + { + fmt::print("Not Implemented"); + std::abort(); + } + + } + break; + + // Addition operation case llvm::Instruction::Add: + { + auto const* op0 = inst.getOperand(0u); + auto const* op1 = inst.getOperand(1u); + auto const* ty0 = op0->getType(); + auto const* ty1 = op1->getType(); value_signals[&inst] = value_signals[inst.getOperand(0u)]; - modular_adder_inplace(ntk, value_signals[&inst], value_signals[inst.getOperand(1u)]); - break; + + if (ty0->isIntegerTy(64) && ty1->isIntegerTy(64)) + { + modular_adder_inplace(ntk, value_signals[&inst], value_signals[inst.getOperand(1u)], 11); + + } + else + { + modular_adder_inplace(ntk, value_signals[&inst], value_signals[inst.getOperand(1u)]); + } + } + break; case llvm::Instruction::Br: { @@ -474,7 +602,7 @@ namespace detail /* input type */ for (const auto& arg : f.args()) { - if (arg.getType()->isIntegerTy(1u)) + if (arg.getType()->isIntegerTy(64) || arg.getType()->isIntegerTy(1u)) { continue; } @@ -484,7 +612,7 @@ namespace detail /* output type */ auto const* retTy = f.getReturnType(); - return retTy->isIntegerTy(1u) || is_valid_tuple_pointer_type(retTy); + return retTy->isIntegerTy(64) || retTy->isIntegerTy(1u) || is_valid_tuple_pointer_type(retTy); } private: