Skip to content

Commit

Permalink
Enable support for batched evaluations.
Browse files Browse the repository at this point in the history
  • Loading branch information
armfazh committed Apr 6, 2022
1 parent 2641657 commit 87a2177
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 111 deletions.
4 changes: 2 additions & 2 deletions bench/oprf.bench.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ export async function benchOPRF(bs: Benchmark.Suite) {
break
}

const [finData, evalReq] = await client.blind(input)
const [finData, evalReq] = await client.blind([input])
const evaluatedElement = await server.evaluate(evalReq)
const prefix = mode + '/' + suite + '/'

bs.add(
prefix + 'blind ',
asyncFn(() => client.blind(input))
asyncFn(() => client.blind([input]))
)
bs.add(
prefix + 'evaluate',
Expand Down
69 changes: 48 additions & 21 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import { Elt, Scalar } from './group.js'
import { Evaluation, EvaluationRequest, FinalizeData, ModeID, Oprf, SuiteID } from './oprf.js'

import { zip } from './util.js'

class baseClient extends Oprf {
constructor(mode: ModeID, suite: SuiteID) {
super(mode, suite)
Expand All @@ -15,35 +17,49 @@ class baseClient extends Oprf {
return this.gg.randomScalar()
}

async blind(input: Uint8Array): Promise<[FinalizeData, EvaluationRequest]> {
const scalar = await this.randomBlinder()
const P = await this.gg.hashToGroup(input, this.getDST(Oprf.LABELS.HashToGroupDST))
if (P.isIdentity()) {
throw new Error('InvalidInputError')
async blind(inputs: Uint8Array[]): Promise<[FinalizeData, EvaluationRequest]> {
const eltList = []
const blinds = []
for (const input of inputs) {
const scalar = await this.randomBlinder()
const P = await this.gg.hashToGroup(input, this.getDST(Oprf.LABELS.HashToGroupDST))
if (P.isIdentity()) {
throw new Error('InvalidInputError')
}
eltList.push(P.mul(scalar))
blinds.push(scalar)
}
const Q = P.mul(scalar)
const evalReq = new EvaluationRequest(Q)
const finData = new FinalizeData(input, scalar, evalReq)
const evalReq = new EvaluationRequest(eltList)
const finData = new FinalizeData(inputs, blinds, evalReq)
return [finData, evalReq]
}

doFinalize(
async doFinalize(
finData: FinalizeData,
evaluation: Evaluation,
info = new Uint8Array(0)
): Promise<Uint8Array> {
const blindInv = finData.blind.inv()
const N = evaluation.evaluated.mul(blindInv)
const unblinded = N.serialize()
return this.coreFinalize(finData.input, unblinded, info)
): Promise<Uint8Array[]> {
const n = finData.inputs.length
if (finData.blinds.length !== n || evaluation.evaluated.length !== n) {
throw new Error('mismatched lengths')
}

const outputList = []
for (let i = 0; i < n; i++) {
const blindInv = finData.blinds[i as number].inv()
const N = evaluation.evaluated[i as number].mul(blindInv)
const unblinded = N.serialize()
outputList.push(await this.coreFinalize(finData.inputs[i as number], unblinded, info))
}
return outputList
}
}

export class OPRFClient extends baseClient {
constructor(suite: SuiteID) {
super(Oprf.Mode.OPRF, suite)
}
finalize(finData: FinalizeData, evaluation: Evaluation): Promise<Uint8Array> {
finalize(finData: FinalizeData, evaluation: Evaluation): Promise<Array<Uint8Array>> {
return super.doFinalize(finData, evaluation)
}
}
Expand All @@ -53,15 +69,21 @@ export class VOPRFClient extends baseClient {
super(Oprf.Mode.VOPRF, suite)
}

finalize(finData: FinalizeData, evaluation: Evaluation): Promise<Uint8Array> {
finalize(finData: FinalizeData, evaluation: Evaluation): Promise<Array<Uint8Array>> {
if (!evaluation.proof) {
throw new Error('no proof provided')
}
const pkS = Elt.deserialize(this.gg, this.pubKeyServer)

const n = finData.inputs.length
if (evaluation.evaluated.length !== n) {
throw new Error('mismatched lengths')
}

if (
!evaluation.proof.verify(
!evaluation.proof.verify_batch(
[this.gg.generator(), pkS],
[finData.evalReq.blinded, evaluation.evaluated]
zip(finData.evalReq.blinded, evaluation.evaluated)
)
) {
throw new Error('proof failed')
Expand Down Expand Up @@ -91,15 +113,20 @@ export class POPRFClient extends baseClient {
finData: FinalizeData,
evaluation: Evaluation,
info = new Uint8Array(0)
): Promise<Uint8Array> {
): Promise<Array<Uint8Array>> {
if (!evaluation.proof) {
throw new Error('no proof provided')
}
const tw = await this.pointFromInfo(info)
const n = finData.inputs.length
if (evaluation.evaluated.length !== n) {
throw new Error('mismatched lengths')
}

if (
!evaluation.proof.verify(
!evaluation.proof.verify_batch(
[this.gg.generator(), tw],
[evaluation.evaluated, finData.evalReq.blinded]
zip(evaluation.evaluated, finData.evalReq.blinded)
)
) {
throw new Error('proof failed')
Expand Down
70 changes: 31 additions & 39 deletions src/oprf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@

import { DLEQParams, DLEQProof } from './dleq.js'
import { Elt, Group, GroupID, Scalar } from './group.js'
import { checkSize, fromU16LenPrefix, joinAll, toU16LenPrefix } from './util.js'
import {
fromU16LenPrefixClass,
fromU16LenPrefixUint8Array,
joinAll,
toU16LenPrefix,
toU16LenPrefixClass,
toU16LenPrefixUint8Array
} from './util.js'

export type ModeID = typeof Oprf.Mode[keyof typeof Oprf.Mode]
export type SuiteID = typeof Oprf.Suite[keyof typeof Oprf.Suite]
Expand Down Expand Up @@ -123,11 +130,11 @@ export abstract class Oprf {
}

export class Evaluation {
constructor(public readonly evaluated: Elt, public readonly proof?: DLEQProof) {}
constructor(public readonly evaluated: Array<Elt>, public readonly proof?: DLEQProof) {}

serialize(): Uint8Array {
return joinAll([
this.evaluated.serialize(true),
...toU16LenPrefixClass(this.evaluated),
Uint8Array.from([this.proof ? 1 : 0]),
this.proof ? this.proof.serialize() : new Uint8Array()
])
Expand All @@ -137,83 +144,68 @@ export class Evaluation {
if ((this.proof && !e.proof) || (!this.proof && e.proof)) {
return false
}
let res = this.evaluated.isEqual(e.evaluated)
let res = this.evaluated.every((x, i) => x.isEqual(e.evaluated[i as number]))
if (this.proof && e.proof) {
res &&= this.proof.isEqual(e.proof)
}
return res
}

static size(params: DLEQParams): number {
return Elt.size(params.gg) + 1
}

static deserialize(params: DLEQParams, bytes: Uint8Array): Evaluation {
checkSize(bytes, Evaluation, params)
const eltSize = Elt.size(params.gg)
const evaluated = Elt.deserialize(params.gg, bytes.subarray(0, eltSize))
const { head: evalList, tail } = fromU16LenPrefixClass(Elt, params.gg, bytes)
let proof: DLEQProof | undefined
if (bytes[eltSize as number] === 1) {
if (tail[0] === 1) {
const prSize = DLEQProof.size(params)
proof = DLEQProof.deserialize(params, bytes.subarray(1 + eltSize, 1 + eltSize + prSize))
proof = DLEQProof.deserialize(params, tail.subarray(1, 1 + prSize))
}
return new Evaluation(evaluated, proof)
return new Evaluation(evalList, proof)
}
}

export class EvaluationRequest {
constructor(public readonly blinded: Elt) {}
constructor(public readonly blinded: Array<Elt>) {}

serialize(): Uint8Array {
return this.blinded.serialize(true)
return joinAll(toU16LenPrefixClass(this.blinded))
}

isEqual(e: EvaluationRequest): boolean {
return this.blinded.isEqual(e.blinded)
}

static size(g: Group): number {
return Elt.size(g)
return this.blinded.every((x, i) => x.isEqual(e.blinded[i as number]))
}

static deserialize(g: Group, bytes: Uint8Array): EvaluationRequest {
checkSize(bytes, EvaluationRequest, g)
return new EvaluationRequest(Elt.deserialize(g, bytes))
const { head: blindedList } = fromU16LenPrefixClass(Elt, g, bytes)
return new EvaluationRequest(blindedList)
}
}

export class FinalizeData {
constructor(
public readonly input: Uint8Array,
public readonly blind: Scalar,
public readonly inputs: Array<Uint8Array>,
public readonly blinds: Array<Scalar>,
public readonly evalReq: EvaluationRequest
) {}

serialize(): Uint8Array {
return joinAll([
...toU16LenPrefix(this.input),
this.blind.serialize(),
...toU16LenPrefixUint8Array(this.inputs),
...toU16LenPrefixClass(this.blinds),
this.evalReq.serialize()
])
}

isEqual(f: FinalizeData): boolean {
return (
this.input.toString() === f.input.toString() &&
this.blind.isEqual(f.blind) &&
this.inputs.every((x, i) => x.toString() === f.inputs[i as number].toString()) &&
this.blinds.every((x, i) => x.isEqual(f.blinds[i as number])) &&
this.evalReq.isEqual(f.evalReq)
)
}
static size(g: Group): number {
return 2 + Scalar.size(g) + EvaluationRequest.size(g)
}

static deserialize(g: Group, bytes: Uint8Array): FinalizeData {
checkSize(bytes, FinalizeData, g)
const { head: input, tail } = fromU16LenPrefix(bytes)
const scSize = Scalar.size(g)
const erSize = EvaluationRequest.size(g)
const blind = Scalar.deserialize(g, tail.subarray(0, scSize))
const evalReq = EvaluationRequest.deserialize(g, tail.subarray(scSize, scSize + erSize))
return new FinalizeData(input, blind, evalReq)
const { head: inputs, tail: t0 } = fromU16LenPrefixUint8Array(bytes)
const { head: blinds, tail: t1 } = fromU16LenPrefixClass(Scalar, g, t0)
const evalReq = EvaluationRequest.deserialize(g, t1)
return new FinalizeData(inputs, blinds, evalReq)
}
}
29 changes: 20 additions & 9 deletions src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import { DLEQParams, DLEQProver } from './dleq.js'
import { Elt, Scalar } from './group.js'
import { Evaluation, EvaluationRequest, ModeID, Oprf, SuiteID } from './oprf.js'

import { ctEqual } from './util.js'
import { ctEqual, zip } from './util.js'

class baseServer extends Oprf {
protected privateKey: Uint8Array
Expand Down Expand Up @@ -86,7 +85,9 @@ export class OPRFServer extends baseServer {
}

async evaluate(req: EvaluationRequest): Promise<Evaluation> {
return new Evaluation(await this.doEvaluation(req.blinded, this.privateKey))
return new Evaluation(
await Promise.all(req.blinded.map((b) => this.doEvaluation(b, this.privateKey)))
)
}
async fullEvaluate(input: Uint8Array): Promise<Uint8Array> {
return this.doFullEvaluate(input)
Expand All @@ -101,12 +102,18 @@ export class VOPRFServer extends baseServer {
super(Oprf.Mode.VOPRF, suite, privateKey)
}
async evaluate(req: EvaluationRequest): Promise<Evaluation> {
const e = await this.doEvaluation(req.blinded, this.privateKey)
const evalList = await Promise.all(
req.blinded.map((b) => this.doEvaluation(b, this.privateKey))
)
const prover = new DLEQProver(this.constructDLEQParams())
const skS = Scalar.deserialize(this.gg, this.privateKey)
const pkS = this.gg.mulGen(skS)
const proof = await prover.prove(skS, [this.gg.generator(), pkS], [req.blinded, e])
return new Evaluation(e, proof)
const proof = await prover.prove_batch(
skS,
[this.gg.generator(), pkS],
zip(req.blinded, evalList)
)
return new Evaluation(evalList, proof)
}
async fullEvaluate(input: Uint8Array): Promise<Uint8Array> {
return this.doFullEvaluate(input)
Expand All @@ -123,11 +130,15 @@ export class POPRFServer extends baseServer {
async evaluate(req: EvaluationRequest, info = new Uint8Array(0)): Promise<Evaluation> {
const [keyProof, evalSecret] = await this.secretFromInfo(info)
const secret = evalSecret.serialize()
const e = await this.doEvaluation(req.blinded, secret)
const evalList = await Promise.all(req.blinded.map((b) => this.doEvaluation(b, secret)))
const prover = new DLEQProver(this.constructDLEQParams())
const kG = this.gg.mulGen(keyProof)
const proof = await prover.prove(keyProof, [this.gg.generator(), kG], [e, req.blinded])
return new Evaluation(e, proof)
const proof = await prover.prove_batch(
keyProof,
[this.gg.generator(), kG],
zip(evalList, req.blinded)
)
return new Evaluation(evalList, proof)
}
async fullEvaluate(input: Uint8Array, info = new Uint8Array(0)): Promise<Uint8Array> {
return this.doFullEvaluate(input, info)
Expand Down
Loading

0 comments on commit 87a2177

Please sign in to comment.