From fa075fbbebbfd747c4d9507fbc5ced0452172239 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Thu, 17 Oct 2024 12:59:43 +0200 Subject: [PATCH 1/3] fixing the round, adding the test --- .../compiler/innerproduct/compiler.go | 27 ++++- .../protocol/compiler/innerproduct/context.go | 47 +++++--- .../innerproduct/innerproduct_test.go | 107 ++++++++++++++++++ 3 files changed, 161 insertions(+), 20 deletions(-) create mode 100644 prover/protocol/compiler/innerproduct/innerproduct_test.go diff --git a/prover/protocol/compiler/innerproduct/compiler.go b/prover/protocol/compiler/innerproduct/compiler.go index 2e5291d522c..bfe76159f0c 100644 --- a/prover/protocol/compiler/innerproduct/compiler.go +++ b/prover/protocol/compiler/innerproduct/compiler.go @@ -31,9 +31,11 @@ func Compile(comp *wizard.CompiledIOP) { // same protocol if the same step of compilation are applied in the same // order. sizes = []int{} - // contextsForSize list all the sub-compilation context in the same - // order as `sizes` - proverTask proverTask + // contextsForSize list all the sub-compilation context + // in the same order as `sizes`. + // proverTaskCollaps indicates when we have more than one pair of inner-product with the same size + // and thus collapsing all pairs to a single column is required. + proverTaskNoCollaps, proverTaskCollpas proverTask ) for _, qName := range comp.QueriesParams.AllUnignoredKeys() { @@ -60,9 +62,24 @@ func Compile(comp *wizard.CompiledIOP) { } for _, size := range sizes { - proverTask = append(proverTask, compileForSize(comp, round, queryMap[size])) + ctx := compileForSize(comp, round, queryMap[size]) + switch ctx.round { + case round: + proverTaskNoCollaps = append(proverTaskNoCollaps, ctx) + case round + 1: + proverTaskCollpas = append(proverTaskCollpas, ctx) + default: + utils.Panic("round before compilation was %v and after compilation %v", round, ctx.round) + } + + } + // run the prover of the relevant round + if len(proverTaskNoCollaps) >= 1 { + comp.RegisterProverAction(round, proverTaskNoCollaps) } - comp.RegisterProverAction(round+1, proverTask) + if len(proverTaskCollpas) >= 1 { + comp.RegisterProverAction(round+1, proverTaskCollpas) + } } diff --git a/prover/protocol/compiler/innerproduct/context.go b/prover/protocol/compiler/innerproduct/context.go index 789b1d82fe7..924aeb4382a 100644 --- a/prover/protocol/compiler/innerproduct/context.go +++ b/prover/protocol/compiler/innerproduct/context.go @@ -34,6 +34,9 @@ type contextForSize struct { // entry of [Summation]. It is compared to the alleged inner-product values // by the verifier to finalize the compilation step.s SummationOpening query.LocalOpening + + // round after compilation + round int } // compileForSize applies the compilation step on a range of queries such that @@ -41,6 +44,7 @@ type contextForSize struct { // list of queries. // // It returns the compilation context of the query +// the round indicate the round of the last inner-product query, independent of its size. func compileForSize( comp *wizard.CompiledIOP, round int, @@ -60,33 +64,45 @@ func compileForSize( if hasMoreThan1Pair { round = round + 1 } + //set the round + ctx.round = round ctx.Summation = comp.InsertCommit( - round+1, - deriveName[ifaces.ColID]("SUMMATION", comp.SelfRecursionCount), + round, + deriveName[ifaces.ColID]("SUMMATION", size, comp.SelfRecursionCount), size, ) if hasMoreThan1Pair { var ( - pairProduct = []*symbolic.Expression{} + pairProduct = []ifaces.Column{} ) batchingCoin = comp.InsertCoin( - round+1, - deriveName[coin.Name]("BATCHING_COIN", comp.SelfRecursionCount), + round, + deriveName[coin.Name]("BATCHING_COIN", size, comp.SelfRecursionCount), coin.Field, ) for _, q := range queries { for _, b := range q.Bs { - pairProduct = append(pairProduct, symbolic.Mul(q.A, b)) + pairProduct = append(pairProduct, q.A, b) } } - ctx.Collapsed = symbolic.NewPolyEval(batchingCoin.AsVariable(), pairProduct) - ctx.Collapsed.Board() + // @Azam the following function is commented out due to the issue https://github.com/Consensys/linea-monorepo/issues/192 + // ctx.Collapsed = symbolic.NewPolyEval(batchingCoin.AsVariable(), pairProduct) + res := symbolic.NewConstant(0) + for i := len(pairProduct) - 1; i >= 0; i-- { + res2 := symbolic.Mul(pairProduct[i], pairProduct[i-1]) + res = symbolic.Mul(res, batchingCoin) + res = symbolic.Add(res, res2) + i-- + } + + ctx.Collapsed = res + ctx.CollapsedBoard = ctx.Collapsed.Board() } if !hasMoreThan1Pair { @@ -96,8 +112,8 @@ func compileForSize( // This constraints set the recurrent property of summation comp.InsertGlobal( - round+1, - deriveName[ifaces.QueryID]("SUMMATION_CONSISTENCY", comp.SelfRecursionCount), + round, + deriveName[ifaces.QueryID]("SUMMATION_CONSISTENCY", size, comp.SelfRecursionCount), symbolic.Sub( ctx.Summation, column.Shift(ctx.Summation, -1), @@ -107,20 +123,21 @@ func compileForSize( // This constraint ensures that summation has the correct initial value comp.InsertLocal( - round+1, - deriveName[ifaces.QueryID]("SUMMATION_INIT", comp.SelfRecursionCount), + round, + deriveName[ifaces.QueryID]("SUMMATION_INIT", size, comp.SelfRecursionCount), symbolic.Sub(ctx.Collapsed, ctx.Summation), ) // The opening of the final position of ctx.Summation should be equal to // the linear combinations of the alleged openings of the inner-products. ctx.SummationOpening = comp.InsertLocalOpening( - round+1, - deriveName[ifaces.QueryID]("SUMMATION_END", comp.SelfRecursionCount), + round, + deriveName[ifaces.QueryID]("SUMMATION_END", size, comp.SelfRecursionCount), column.Shift(ctx.Summation, -1), ) - comp.RegisterVerifierAction(round+1, &verifierForSize{ + lastRound := comp.NumRounds() - 1 + comp.RegisterVerifierAction(lastRound, &verifierForSize{ Queries: queries, SummationOpening: ctx.SummationOpening, BatchOpening: batchingCoin, diff --git a/prover/protocol/compiler/innerproduct/innerproduct_test.go b/prover/protocol/compiler/innerproduct/innerproduct_test.go new file mode 100644 index 00000000000..0a9afb8998d --- /dev/null +++ b/prover/protocol/compiler/innerproduct/innerproduct_test.go @@ -0,0 +1,107 @@ +package innerproduct + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/stretchr/testify/assert" +) + +func TestInnerProduct(t *testing.T) { + define := func(b *wizard.Builder) { + for i, c := range testCases { + bs := make([]ifaces.Column, len(c.bName)) + a := b.RegisterCommit(c.aName, c.size) + for i, name := range c.bName { + bs[i] = b.RegisterCommit(name, c.size) + } + b.InnerProduct(c.qName, a, bs...) + // go to the next round + _ = b.RegisterRandomCoin(coin.Namef("Coin_%v", i), coin.Field) + } + } + prover := func(run *wizard.ProverRuntime) { + for j, c := range testCases { + run.AssignColumn(c.aName, c.a) + for i, name := range c.bName { + run.AssignColumn(name, c.b[i]) + } + run.AssignInnerProduct(c.qName, c.expected...) + run.GetRandomCoinField(coin.Namef("Coin_%v", j)) + } + } + + comp := wizard.Compile(define, Compile, dummy.Compile) + proof := wizard.Prove(comp, prover) + assert.NoErrorf(t, wizard.Verify(comp, proof), "invalid proof") +} + +var testCases = []struct { + qName ifaces.QueryID + aName ifaces.ColID + bName []ifaces.ColID + size int + a smartvectors.SmartVector + b []smartvectors.SmartVector + expected []field.Element +}{ + {qName: "Quey1", + aName: "ColA1", + bName: []ifaces.ColID{"ColB1"}, + size: 4, + a: smartvectors.ForTest(1, 1, 1, 1), + b: []smartvectors.SmartVector{ + smartvectors.ForTest(0, 3, 0, 2), + }, + expected: []field.Element{field.NewElement(5)}, + }, + {qName: "Quey2", + aName: "ColA2", + bName: []ifaces.ColID{"ColB2_0", "ColB2_1"}, + size: 4, + a: smartvectors.ForTest(1, 1, 1, 1), + b: []smartvectors.SmartVector{ + smartvectors.ForTest(0, 3, 0, 2), + smartvectors.ForTest(1, 0, 0, 2), + }, + expected: []field.Element{field.NewElement(5), field.NewElement(3)}, + }, + {qName: "Quey3", + aName: "ColA3", + bName: []ifaces.ColID{"ColB3_0", "ColB3_1"}, + size: 8, + a: smartvectors.ForTest(1, 1, 1, 1, 2, 0, 2, 0), + b: []smartvectors.SmartVector{ + smartvectors.ForTest(0, 3, 0, 2, 1, 0, 0, 0), + smartvectors.ForTest(1, 0, 0, 2, 1, 0, 0, 0), + }, + expected: []field.Element{field.NewElement(7), field.NewElement(5)}, + }, + {qName: "Quey4", + aName: "ColA4", + bName: []ifaces.ColID{"ColB4"}, + size: 16, + a: smartvectors.ForTest(1, 1, 1, 1, 2, 0, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1), + b: []smartvectors.SmartVector{ + smartvectors.ForTest(0, 3, 0, 2, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1), + }, + expected: []field.Element{field.NewElement(15)}, + }, + + {qName: "Quey", + + aName: "ColA", + bName: []ifaces.ColID{"ColB"}, + size: 32, + a: smartvectors.ForTest(1, 1, 1, 1, 2, 0, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1), + b: []smartvectors.SmartVector{ + smartvectors.ForTest(0, 3, 0, 2, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 3, 0, 2, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1), + }, + expected: []field.Element{field.NewElement(30)}, + }, +} From 7aac9c7a26acbb86509ffcbd61086b71c1436a55 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Thu, 17 Oct 2024 14:42:48 +0200 Subject: [PATCH 2/3] minor --- prover/protocol/compiler/innerproduct/context.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/prover/protocol/compiler/innerproduct/context.go b/prover/protocol/compiler/innerproduct/context.go index 924aeb4382a..81922c0bdf3 100644 --- a/prover/protocol/compiler/innerproduct/context.go +++ b/prover/protocol/compiler/innerproduct/context.go @@ -76,7 +76,7 @@ func compileForSize( if hasMoreThan1Pair { var ( - pairProduct = []ifaces.Column{} + pairProduct = []*symbolic.Expression{} ) batchingCoin = comp.InsertCoin( @@ -87,7 +87,7 @@ func compileForSize( for _, q := range queries { for _, b := range q.Bs { - pairProduct = append(pairProduct, q.A, b) + pairProduct = append(pairProduct, symbolic.Mul(q.A, b)) } } @@ -95,10 +95,9 @@ func compileForSize( // ctx.Collapsed = symbolic.NewPolyEval(batchingCoin.AsVariable(), pairProduct) res := symbolic.NewConstant(0) for i := len(pairProduct) - 1; i >= 0; i-- { - res2 := symbolic.Mul(pairProduct[i], pairProduct[i-1]) + res2 := symbolic.Mul(pairProduct[i]) res = symbolic.Mul(res, batchingCoin) res = symbolic.Add(res, res2) - i-- } ctx.Collapsed = res From 89abc33fbeb7430f37dbabfe8af573b3e51d8a78 Mon Sep 17 00:00:00 2001 From: Soleimani193 Date: Thu, 17 Oct 2024 14:46:55 +0200 Subject: [PATCH 3/3] minor --- prover/protocol/compiler/innerproduct/context.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/prover/protocol/compiler/innerproduct/context.go b/prover/protocol/compiler/innerproduct/context.go index 81922c0bdf3..14eb787375d 100644 --- a/prover/protocol/compiler/innerproduct/context.go +++ b/prover/protocol/compiler/innerproduct/context.go @@ -95,9 +95,8 @@ func compileForSize( // ctx.Collapsed = symbolic.NewPolyEval(batchingCoin.AsVariable(), pairProduct) res := symbolic.NewConstant(0) for i := len(pairProduct) - 1; i >= 0; i-- { - res2 := symbolic.Mul(pairProduct[i]) res = symbolic.Mul(res, batchingCoin) - res = symbolic.Add(res, res2) + res = symbolic.Add(res, pairProduct[i]) } ctx.Collapsed = res