Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions prover/protocol/compiler/innerproduct/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ func Compile(comp *wizard.CompiledIOP) {
// same protocol if the same step of compilation are applied in the same
// order.
sizes = []int{}
// contextsForSize list all the sub-compilation context in the same
// order as `sizes`
proverTask proverTask
// contextsForSize list all the sub-compilation context
// in the same order as `sizes`.
// proverTaskCollaps indicates when we have more than one pair of inner-product with the same size
// and thus collapsing all pairs to a single column is required.
proverTaskNoCollaps, proverTaskCollpas proverTask
)

for _, qName := range comp.QueriesParams.AllUnignoredKeys() {
Expand All @@ -60,9 +62,24 @@ func Compile(comp *wizard.CompiledIOP) {
}

for _, size := range sizes {
proverTask = append(proverTask, compileForSize(comp, round, queryMap[size]))
ctx := compileForSize(comp, round, queryMap[size])
switch ctx.round {
case round:
proverTaskNoCollaps = append(proverTaskNoCollaps, ctx)
case round + 1:
proverTaskCollpas = append(proverTaskCollpas, ctx)
default:
utils.Panic("round before compilation was %v and after compilation %v", round, ctx.round)
}

}
// run the prover of the relevant round
if len(proverTaskNoCollaps) >= 1 {
comp.RegisterProverAction(round, proverTaskNoCollaps)
}

comp.RegisterProverAction(round+1, proverTask)
if len(proverTaskCollpas) >= 1 {
comp.RegisterProverAction(round+1, proverTaskCollpas)
}

}
41 changes: 28 additions & 13 deletions prover/protocol/compiler/innerproduct/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,17 @@ type contextForSize struct {
// entry of [Summation]. It is compared to the alleged inner-product values
// by the verifier to finalize the compilation step.s
SummationOpening query.LocalOpening

// round after compilation
round int
}

// compileForSize applies the compilation step on a range of queries such that
// they all relate to column of the same size. The function expects a non-empty
// list of queries.
//
// It returns the compilation context of the query
// the round indicate the round of the last inner-product query, independent of its size.
func compileForSize(
comp *wizard.CompiledIOP,
round int,
Expand All @@ -60,10 +64,12 @@ func compileForSize(
if hasMoreThan1Pair {
round = round + 1
}
//set the round
ctx.round = round

ctx.Summation = comp.InsertCommit(
round+1,
deriveName[ifaces.ColID]("SUMMATION", comp.SelfRecursionCount),
round,
deriveName[ifaces.ColID]("SUMMATION", size, comp.SelfRecursionCount),
size,
)

Expand All @@ -74,8 +80,8 @@ func compileForSize(
)

batchingCoin = comp.InsertCoin(
round+1,
deriveName[coin.Name]("BATCHING_COIN", comp.SelfRecursionCount),
round,
deriveName[coin.Name]("BATCHING_COIN", size, comp.SelfRecursionCount),
coin.Field,
)

Expand All @@ -85,8 +91,16 @@ func compileForSize(
}
}

ctx.Collapsed = symbolic.NewPolyEval(batchingCoin.AsVariable(), pairProduct)
ctx.Collapsed.Board()
// @Azam the following function is commented out due to the issue https://github.com/Consensys/linea-monorepo/issues/192
// ctx.Collapsed = symbolic.NewPolyEval(batchingCoin.AsVariable(), pairProduct)
res := symbolic.NewConstant(0)
for i := len(pairProduct) - 1; i >= 0; i-- {
res = symbolic.Mul(res, batchingCoin)
res = symbolic.Add(res, pairProduct[i])
}

ctx.Collapsed = res
ctx.CollapsedBoard = ctx.Collapsed.Board()
}

if !hasMoreThan1Pair {
Expand All @@ -96,8 +110,8 @@ func compileForSize(

// This constraints set the recurrent property of summation
comp.InsertGlobal(
round+1,
deriveName[ifaces.QueryID]("SUMMATION_CONSISTENCY", comp.SelfRecursionCount),
round,
deriveName[ifaces.QueryID]("SUMMATION_CONSISTENCY", size, comp.SelfRecursionCount),
symbolic.Sub(
ctx.Summation,
column.Shift(ctx.Summation, -1),
Expand All @@ -107,20 +121,21 @@ func compileForSize(

// This constraint ensures that summation has the correct initial value
comp.InsertLocal(
round+1,
deriveName[ifaces.QueryID]("SUMMATION_INIT", comp.SelfRecursionCount),
round,
deriveName[ifaces.QueryID]("SUMMATION_INIT", size, comp.SelfRecursionCount),
symbolic.Sub(ctx.Collapsed, ctx.Summation),
)

// The opening of the final position of ctx.Summation should be equal to
// the linear combinations of the alleged openings of the inner-products.
ctx.SummationOpening = comp.InsertLocalOpening(
round+1,
deriveName[ifaces.QueryID]("SUMMATION_END", comp.SelfRecursionCount),
round,
deriveName[ifaces.QueryID]("SUMMATION_END", size, comp.SelfRecursionCount),
column.Shift(ctx.Summation, -1),
)

comp.RegisterVerifierAction(round+1, &verifierForSize{
lastRound := comp.NumRounds() - 1
comp.RegisterVerifierAction(lastRound, &verifierForSize{
Queries: queries,
SummationOpening: ctx.SummationOpening,
BatchOpening: batchingCoin,
Expand Down
107 changes: 107 additions & 0 deletions prover/protocol/compiler/innerproduct/innerproduct_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package innerproduct

import (
"testing"

"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/protocol/coin"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/stretchr/testify/assert"
)

func TestInnerProduct(t *testing.T) {
define := func(b *wizard.Builder) {
for i, c := range testCases {
bs := make([]ifaces.Column, len(c.bName))
a := b.RegisterCommit(c.aName, c.size)
for i, name := range c.bName {
bs[i] = b.RegisterCommit(name, c.size)
}
b.InnerProduct(c.qName, a, bs...)
// go to the next round
_ = b.RegisterRandomCoin(coin.Namef("Coin_%v", i), coin.Field)
}
}
prover := func(run *wizard.ProverRuntime) {
for j, c := range testCases {
run.AssignColumn(c.aName, c.a)
for i, name := range c.bName {
run.AssignColumn(name, c.b[i])
}
run.AssignInnerProduct(c.qName, c.expected...)
run.GetRandomCoinField(coin.Namef("Coin_%v", j))
}
}

comp := wizard.Compile(define, Compile, dummy.Compile)
proof := wizard.Prove(comp, prover)
assert.NoErrorf(t, wizard.Verify(comp, proof), "invalid proof")
}

var testCases = []struct {
qName ifaces.QueryID
aName ifaces.ColID
bName []ifaces.ColID
size int
a smartvectors.SmartVector
b []smartvectors.SmartVector
expected []field.Element
}{
{qName: "Quey1",
aName: "ColA1",
bName: []ifaces.ColID{"ColB1"},
size: 4,
a: smartvectors.ForTest(1, 1, 1, 1),
b: []smartvectors.SmartVector{
smartvectors.ForTest(0, 3, 0, 2),
},
expected: []field.Element{field.NewElement(5)},
},
{qName: "Quey2",
aName: "ColA2",
bName: []ifaces.ColID{"ColB2_0", "ColB2_1"},
size: 4,
a: smartvectors.ForTest(1, 1, 1, 1),
b: []smartvectors.SmartVector{
smartvectors.ForTest(0, 3, 0, 2),
smartvectors.ForTest(1, 0, 0, 2),
},
expected: []field.Element{field.NewElement(5), field.NewElement(3)},
},
{qName: "Quey3",
aName: "ColA3",
bName: []ifaces.ColID{"ColB3_0", "ColB3_1"},
size: 8,
a: smartvectors.ForTest(1, 1, 1, 1, 2, 0, 2, 0),
b: []smartvectors.SmartVector{
smartvectors.ForTest(0, 3, 0, 2, 1, 0, 0, 0),
smartvectors.ForTest(1, 0, 0, 2, 1, 0, 0, 0),
},
expected: []field.Element{field.NewElement(7), field.NewElement(5)},
},
{qName: "Quey4",
aName: "ColA4",
bName: []ifaces.ColID{"ColB4"},
size: 16,
a: smartvectors.ForTest(1, 1, 1, 1, 2, 0, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1),
b: []smartvectors.SmartVector{
smartvectors.ForTest(0, 3, 0, 2, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1),
},
expected: []field.Element{field.NewElement(15)},
},

{qName: "Quey",

aName: "ColA",
bName: []ifaces.ColID{"ColB"},
size: 32,
a: smartvectors.ForTest(1, 1, 1, 1, 2, 0, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1),
b: []smartvectors.SmartVector{
smartvectors.ForTest(0, 3, 0, 2, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 3, 0, 2, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1),
},
expected: []field.Element{field.NewElement(30)},
},
}