Skip to content

Commit

Permalink
Adds polynomials and Lagrange polynomials.
Browse files Browse the repository at this point in the history
  • Loading branch information
armfazh committed Aug 1, 2022
1 parent dbf8547 commit dc30325
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 3 deletions.
2 changes: 2 additions & 0 deletions group/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ type Scalar interface {
Set(x Scalar) Scalar
// Copy returns a new scalar equal to the receiver.
Copy() Scalar
// IsZero returns true if the receiver is equal to zero.
IsZero() bool
// IsEqual returns true if the receiver is equal to x.
IsEqual(x Scalar) bool
// SetUint64 set the receiver to x, and returns the receiver.
Expand Down
3 changes: 2 additions & 1 deletion group/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ func testScalar(t *testing.T, testTimes int, g group.Group) {

c.Inv(a)
c.Mul(c, a)
if !one.IsEqual(c) {
c.Sub(c, one)
if !c.IsZero() {
test.ReportError(t, c, one, a)
}
}
Expand Down
5 changes: 3 additions & 2 deletions group/ristretto255.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

r255 "github.com/bwesterb/go-ristretto"
"github.com/cloudflare/circl/expander"
"github.com/cloudflare/circl/internal/conv"
)

// Ristretto255 is a quotient group generated from the edwards25519 curve.
Expand Down Expand Up @@ -203,9 +204,9 @@ func (e *ristrettoElement) UnmarshalBinary(data []byte) error {
}

func (s *ristrettoScalar) Group() Group { return Ristretto255 }
func (s *ristrettoScalar) String() string { return fmt.Sprintf("0x%x", s.s.Bytes()) }
func (s *ristrettoScalar) String() string { return conv.BytesLe2Hex(s.s.Bytes()) }
func (s *ristrettoScalar) SetUint64(n uint64) Scalar { s.s.SetUint64(n); return s }

func (s *ristrettoScalar) IsZero() bool { return s.s.IsNonZeroI() == 0 }
func (s *ristrettoScalar) IsEqual(x Scalar) bool {
return s.s.Equals(&x.(*ristrettoScalar).s)
}
Expand Down
4 changes: 4 additions & 0 deletions group/short.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ type wScl struct {
func (s *wScl) Group() Group { return s.wG }
func (s *wScl) String() string { return fmt.Sprintf("0x%x", s.k) }
func (s *wScl) SetUint64(n uint64) Scalar { s.fromBig(new(big.Int).SetUint64(n)); return s }
func (s *wScl) IsZero() bool {
return subtle.ConstantTimeCompare(s.k, make([]byte, (s.wG.c.Params().BitSize+7)/8)) == 1
}

func (s *wScl) IsEqual(a Scalar) bool {
aa := s.cvtScl(a)
return subtle.ConstantTimeCompare(s.k, aa.k) == 1
Expand Down
147 changes: 147 additions & 0 deletions math/polynomial/polynomial.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Package polynomial provides representations of polynomials over the scalars
// of a group.
package polynomial

import "github.com/cloudflare/circl/group"

// Polynomial stores a polynomial over the set of scalars of a group.
type Polynomial struct {
// Internal representation is in polynomial basis:
// Thus,
// p(x) = \sum_i^k c[i] x^i,
// where k = len(c)-1 is the degree of the polynomial.
c []group.Scalar
}

// New creates a new polynomial given its coefficients in ascending order.
// Thus,
// p(x) = \sum_i^k c[i] x^i,
// where k = len(c)-1 is the degree of the polynomial.
//
// The zero polynomial has degree equal to -1 and can be instantiated passing
// nil to New.
func New(coeffs []group.Scalar) (p Polynomial) {
if l := len(coeffs); l != 0 {
p.c = make([]group.Scalar, l)
for i := range coeffs {
p.c[i] = coeffs[i].Copy()
}
}

return
}

func (p Polynomial) Degree() int {
i := len(p.c) - 1
for i > 0 && p.c[i].IsZero() {
i--
}
return i
}

func (p Polynomial) Evaluate(x group.Scalar) group.Scalar {
px := x.Group().NewScalar()
if l := len(p.c); l != 0 {
px.Set(p.c[l-1])
for i := l - 2; i >= 0; i-- {
px.Mul(px, x)
px.Add(px, p.c[i])
}
}
return px
}

// LagrangePolynomial stores a Lagrange polynomial over the set of scalars of a group.
type LagrangePolynomial struct {
// Internal representation is in Lagrange basis:
// Thus,
// p(x) = \sum_i^k y[i] L_j(x), where k is the degree of the polynomial,
// L_j(x) = \prod_i^k (x-x[i])/(x[j]-x[i]),
// y[i] = p(x[i]), and
// all x[i] are different.
x, y []group.Scalar
}

// NewLagrangePolynomial creates a polynomial in Lagrange basis given a list
// of nodes (x) and values (y), such that:
// p(x) = \sum_i^k y[i] L_j(x), where k is the degree of the polynomial,
// L_j(x) = \prod_i^k (x-x[i])/(x[j]-x[i]),
// y[i] = p(x[i]), and
// all x[i] are different.
// It panics if one of these conditions does not hold.
//
// The zero polynomial has degree equal to -1 and can be instantiated passing
// (nil,nil) to NewLagrangePolynomial.
func NewLagrangePolynomial(x, y []group.Scalar) (l LagrangePolynomial) {
if len(x) != len(y) {
panic("lagrange: invalid length")
}

if !areAllDifferent(x) {
panic("lagrange: x[i] must be different")
}

if n := len(x); n != 0 {
l.x, l.y = make([]group.Scalar, n), make([]group.Scalar, n)
for i := range x {
l.x[i], l.y[i] = x[i].Copy(), y[i].Copy()
}
}

return
}

func (l LagrangePolynomial) Degree() int { return len(l.x) - 1 }

func (l LagrangePolynomial) Evaluate(x group.Scalar) group.Scalar {
px := x.Group().NewScalar()
tmp := x.Group().NewScalar()
for i := range l.x {
LjX := baseRatio(uint(i), l.x, x)
tmp.Mul(l.y[i], LjX)
px.Add(px, tmp)
}

return px
}

// LagrangeBase returns the j-th Lagrange polynomial base evaluated at x.
// Thus, L_j(x) = \prod (x - x[i]) / (x[j] - x[i]) for 0 <= i < k, and i != j.
func LagrangeBase(jth uint, xi []group.Scalar, x group.Scalar) group.Scalar {
if jth >= uint(len(xi)) {
panic("lagrange: invalid index")
}
return baseRatio(jth, xi, x)
}

func baseRatio(jth uint, xi []group.Scalar, x group.Scalar) group.Scalar {
num := x.Copy()
num.SetUint64(1)
den := x.Copy()
den.SetUint64(1)

tmp := x.Copy()
for i := range xi {
if uint(i) != jth {
num.Mul(num, tmp.Sub(x, xi[i]))
den.Mul(den, tmp.Sub(xi[jth], xi[i]))
}
}

return num.Mul(num, den.Inv(den))
}

func areAllDifferent(x []group.Scalar) bool {
m := make(map[string]struct{})
for i := range x {
k, err := x[i].MarshalBinary()
if err != nil {
panic(err)
}
if _, exists := m[string(k)]; exists {
return false
}
m[string(k)] = struct{}{}
}
return true
}
131 changes: 131 additions & 0 deletions math/polynomial/polynomial_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package polynomial_test

import (
"testing"

"github.com/cloudflare/circl/group"
"github.com/cloudflare/circl/internal/test"
"github.com/cloudflare/circl/math/polynomial"
)

func TestPolyDegree(t *testing.T) {
g := group.P256

t.Run("zeroPoly", func(t *testing.T) {
p := polynomial.New(nil)
test.CheckOk(p.Degree() == -1, "it should be -1", t)
p = polynomial.New([]group.Scalar{})
test.CheckOk(p.Degree() == -1, "it should be -1", t)
})

t.Run("constantPoly", func(t *testing.T) {
c := []group.Scalar{
g.NewScalar().SetUint64(0),
g.NewScalar().SetUint64(0),
}
p := polynomial.New(c)
test.CheckOk(p.Degree() == 0, "it should be 0", t)
})

t.Run("linearPoly", func(t *testing.T) {
c := []group.Scalar{
g.NewScalar().SetUint64(0),
g.NewScalar().SetUint64(1),
g.NewScalar().SetUint64(0),
}
p := polynomial.New(c)
test.CheckOk(p.Degree() == 1, "it should be 1", t)
})
}

func TestPolyEval(t *testing.T) {
g := group.P256
c := []group.Scalar{
g.NewScalar(),
g.NewScalar(),
g.NewScalar(),
}
c[0].SetUint64(5)
c[1].SetUint64(5)
c[2].SetUint64(2)
p := polynomial.New(c)

x := g.NewScalar()
x.SetUint64(10)

got := p.Evaluate(x)
want := g.NewScalar()
want.SetUint64(255)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}
}

func TestLagrange(t *testing.T) {
g := group.P256
c := []group.Scalar{
g.NewScalar(),
g.NewScalar(),
g.NewScalar(),
}
c[0].SetUint64(1234)
c[1].SetUint64(166)
c[2].SetUint64(94)
p := polynomial.New(c)

x := []group.Scalar{g.NewScalar(), g.NewScalar(), g.NewScalar()}
x[0].SetUint64(2)
x[1].SetUint64(4)
x[2].SetUint64(5)

y := []group.Scalar{}
for i := range x {
y = append(y, p.Evaluate(x[i]))
}

zero := g.NewScalar()
l := polynomial.NewLagrangePolynomial(x, y)
test.CheckOk(l.Degree() == p.Degree(), "bad degree", t)

got := l.Evaluate(zero)
want := p.Evaluate(zero)

if !got.IsEqual(want) {
test.ReportError(t, got, want)
}

// Test Kronecker's delta of LagrangeBase.
// Thus:
// L_j(x[i]) = { 1, if i == j;
// { 0, otherwise.
one := g.NewScalar()
one.SetUint64(1)
for j := range x {
for i := range x {
got := polynomial.LagrangeBase(uint(j), x, x[i])

if i == j {
want = one
} else {
want = zero
}

if !got.IsEqual(want) {
test.ReportError(t, got, want)
}
}
}

// Test that inputs are different length
err := test.CheckPanic(func() { polynomial.NewLagrangePolynomial(x, y[:1]) })
test.CheckNoErr(t, err, "should panic")

// Test that nodes must be different.
x[0].Set(x[1])
err = test.CheckPanic(func() { polynomial.NewLagrangePolynomial(x, y) })
test.CheckNoErr(t, err, "should panic")

// Test LagrangeBase wrong index
err = test.CheckPanic(func() { polynomial.LagrangeBase(10, x, zero) })
test.CheckNoErr(t, err, "should panic")
}

0 comments on commit dc30325

Please sign in to comment.