Skip to content

Commit

Permalink
Merge pull request #29 from starkbank/fix/x9.62-review
Browse files Browse the repository at this point in the history
Fix point at infinity verification in signature and public key
  • Loading branch information
cdottori-stark authored Nov 9, 2021
2 parents cb6d807 + 2ea31b0 commit bc59e5d
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 24 deletions.
19 changes: 11 additions & 8 deletions ellipticcurve/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#
# y^2 = x^3 + A*x + B (mod P)
#

from .point import Point


Expand All @@ -26,7 +25,13 @@ def contains(self, p):
:param p: Point p = Point(x, y)
:return: boolean
"""
return (p.y**2 - (p.x**3 + self.A * p.x + self.B)) % self.P == 0
if not 0 <= p.x <= self.P - 1:
return False
if not 0 <= p.y <= self.P - 1:
return False
if (p.y**2 - (p.x**3 + self.A * p.x + self.B)) % self.P != 0:
return False
return True

def length(self):
return (1 + len("%x" % self.N)) // 2
Expand Down Expand Up @@ -67,10 +72,8 @@ def length(self):

def getCurveByOid(oid):
if oid not in _curvesByOid:
raise Exception(
"Unknown curve with oid %s; The following are registered: %s" % (
".".join(oid),
", ".join([curve.name for curve in supportedCurves])
)
)
raise Exception("Unknown curve with oid {oid}; The following are registered: {names}".format(
oid=".".join(oid),
names=", ".join([curve.name for curve in supportedCurves]),
))
return _curvesByOid[oid]
7 changes: 4 additions & 3 deletions ellipticcurve/ecdsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def verify(cls, message, signature, publicKey, hashfunc=sha256):
inv = Math.inv(s, curve.N)
u1 = Math.multiply(curve.G, n=(numberMessage * inv) % curve.N, N=curve.N, A=curve.A, P=curve.P)
u2 = Math.multiply(publicKey.point, n=(r * inv) % curve.N, N=curve.N, A=curve.A, P=curve.P)
add = Math.add(u1, u2, A=curve.A, P=curve.P)
modX = add.x % curve.N
return r == modX
v = Math.add(u1, u2, A=curve.A, P=curve.P)
if v.isAtInfinity():
return False
return v.x % curve.N == r
7 changes: 3 additions & 4 deletions ellipticcurve/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _jacobianDouble(cls, p, A, P):
:param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
:return: Point that represents the sum of First and Second Point
"""
if not p.y:
if p.y == 0:
return Point(0, 0, 0)

ysq = (p.y ** 2) % P
Expand All @@ -120,10 +120,9 @@ def _jacobianAdd(cls, p, q, A, P):
:param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
:return: Point that represents the sum of First and Second Point
"""
if not p.y:
if p.y == 0:
return q

if not q.y:
if q.y == 0:
return p

U1 = (p.x * q.z ** 2) % P
Expand Down
6 changes: 6 additions & 0 deletions ellipticcurve/point.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,9 @@ def __init__(self, x=0, y=0, z=0):
self.x = x
self.y = y
self.z = z

def __str__(self):
return "({x}, {y}, {z})".format(x=self.x, y=self.y, z=self.z)

def isAtInfinity(self):
return self.y == 0
21 changes: 13 additions & 8 deletions ellipticcurve/publicKey.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .math import Math
from .point import Point
from .curve import secp256k1, getCurveByOid
from .utils.pem import getPemContent, createPem
from .utils.binary import hexFromByteString, byteStringFromHex, intFromHex, base64FromByteString, byteStringFromBase64
from .utils.der import hexFromInt, parse, DerFieldType, encodeConstructed, encodePrimitive
from .curve import secp256k1, getCurveByOid
from .utils.binary import hexFromByteString, byteStringFromHex, intFromHex, base64FromByteString, byteStringFromBase64


class PublicKey:
Expand Down Expand Up @@ -65,12 +66,16 @@ def fromString(cls, string, curve=secp256k1, validatePoint=True):
x=intFromHex(xs),
y=intFromHex(ys),
)
if validatePoint and not curve.contains(p):
raise Exception(
"Point ({x},{y}) is not valid for curve {name}".format(x=p.x, y=p.y, name=curve.name)
)

return PublicKey(point=p, curve=curve)
publicKey = PublicKey(point=p, curve=curve)
if not validatePoint:
return publicKey
if p.isAtInfinity():
raise Exception("Public Key point is at infinity")
if not curve.contains(p):
raise Exception("Point ({x},{y}) is not valid for curve {name}".format(x=p.x, y=p.y, name=curve.name))
if not Math.multiply(p=p, n=curve.N, N=curve.N, A=curve.A, P=curve.P).isAtInfinity():
raise Exception("Point ({x},{y}) * {name}.N is not at infinity".format(x=p.x, y=p.y, name=curve.name))
return publicKey


_ecdsaPublicKeyOid = (1, 2, 840, 10045, 2, 1)
Expand Down
10 changes: 9 additions & 1 deletion tests/testEcdsa.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest.case import TestCase
from ellipticcurve import Ecdsa, PrivateKey
from ellipticcurve import Ecdsa, PrivateKey, Signature


class EcdsaTest(TestCase):
Expand All @@ -24,3 +24,11 @@ def testVerifyWrongMessage(self):
signature = Ecdsa.sign(message1, privateKey)

self.assertFalse(Ecdsa.verify(message2, signature, publicKey))

def testZeroSignature(self):
privateKey = PrivateKey()
publicKey = privateKey.publicKey()

message2 = "This is the wrong message"

self.assertFalse(Ecdsa.verify(message2, Signature(0, 0), publicKey))

0 comments on commit bc59e5d

Please sign in to comment.