Skip to content

Commit

Permalink
Merge pull request #26 from starkbank/refactor/der
Browse files Browse the repository at this point in the history
Refactor DER and binary handling structures
  • Loading branch information
cdottori-stark authored Oct 7, 2021
2 parents e215ded + 1dabf70 commit 6f6b680
Show file tree
Hide file tree
Showing 20 changed files with 450 additions and 547 deletions.
7 changes: 6 additions & 1 deletion ellipticcurve/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from ellipticcurve.utils.compatibility import *
from ellipticcurve.utils.compatibility import *
from ellipticcurve.privateKey import PrivateKey
from ellipticcurve.publicKey import PublicKey
from ellipticcurve.signature import Signature
from ellipticcurve.utils.file import File
from ellipticcurve.ecdsa import Ecdsa
20 changes: 16 additions & 4 deletions ellipticcurve/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, A, B, P, N, Gx, Gy, name, oid, nistName=None):
self.G = Point(Gx, Gy)
self.name = name
self.nistName = nistName
self.oid = oid
self.oid = oid # ASN.1 Object Identifier

def contains(self, p):
"""
Expand All @@ -40,7 +40,7 @@ def length(self):
N=0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141,
Gx=0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798,
Gy=0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8,
oid=(1, 3, 132, 0, 10)
oid=[1, 3, 132, 0, 10]
)

prime256v1 = CurveFp(
Expand All @@ -52,13 +52,25 @@ def length(self):
N=0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551,
Gx=0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296,
Gy=0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5,
oid=(1, 2, 840, 10045, 3, 1, 7),
oid=[1, 2, 840, 10045, 3, 1, 7],
)

p256 = prime256v1

supportedCurves = [
secp256k1,
prime256v1,
]

curvesByOid = {curve.oid: curve for curve in supportedCurves}
_curvesByOid = {tuple(curve.oid): curve for curve in supportedCurves}


def getCurveByOid(oid):
if oid not in _curvesByOid:
raise Exception(
"Unknown curve with oid %s; The following are registered: %s" % (
".".join(oid),
", ".join([curve.name for curve in supportedCurves])
)
)
return _curvesByOid[oid]
24 changes: 12 additions & 12 deletions ellipticcurve/ecdsa.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from hashlib import sha256
from .signature import Signature
from .math import Math
from .utils.binary import BinaryAscii
from .utils.integer import RandomInteger
from .utils.binary import numberFromByteString
from .utils.compatibility import *


class Ecdsa:

@classmethod
def sign(cls, message, privateKey, hashfunc=sha256):
hashMessage = hashfunc(toBytes(message)).digest()
numberMessage = BinaryAscii.numberFromString(hashMessage)
byteMessage = hashfunc(toBytes(message)).digest()
numberMessage = numberFromByteString(byteMessage)
curve = privateKey.curve

r, s, randSignPoint = 0, 0, None
Expand All @@ -28,14 +28,14 @@ def sign(cls, message, privateKey, hashfunc=sha256):

@classmethod
def verify(cls, message, signature, publicKey, hashfunc=sha256):
hashMessage = hashfunc(toBytes(message)).digest()
numberMessage = BinaryAscii.numberFromString(hashMessage)
byteMessage = hashfunc(toBytes(message)).digest()
numberMessage = numberFromByteString(byteMessage)
curve = publicKey.curve
sigR = signature.r
sigS = signature.s
inv = Math.inv(sigS, curve.N)
u1 = Math.multiply(curve.G, n=(numberMessage * inv) % curve.N, A=curve.A, P=curve.P, N=curve.N)
u2 = Math.multiply(publicKey.point, n=(sigR * inv) % curve.N, A=curve.A, P=curve.P, N=curve.N)
add = Math.add(u1, u2, P=curve.P, A=curve.A)
r = signature.r
s = signature.s
inv = Math.inv(s, curve.N)
u1 = Math.multiply(curve.G, n=(numberMessage * inv) % curve.N, N=curve.N, A=curve.A, P=curve.P)
u2 = Math.multiply(publicKey.point, n=(r * inv) % curve.N, N=curve.N, A=curve.A, P=curve.P)
add = Math.add(u1, u2, A=curve.A, P=curve.P)
modX = add.x % curve.N
return sigR == modX
return r == modX
68 changes: 20 additions & 48 deletions ellipticcurve/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,7 @@ def multiply(cls, p, n, N, A, P):
:return: Point that represents the sum of First and Second Point
"""
return cls._fromJacobian(
cls._jacobianMultiply(
cls._toJacobian(p),
n,
N,
A,
P,
),
P,
cls._jacobianMultiply(cls._toJacobian(p), n, N, A, P), P
)

@classmethod
Expand All @@ -38,13 +31,7 @@ def add(cls, p, q, A, P):
:return: Point that represents the sum of First and Second Point
"""
return cls._fromJacobian(
cls._jacobianAdd(
cls._toJacobian(p),
cls._toJacobian(q),
A,
P,
),
P,
cls._jacobianAdd(cls._toJacobian(p), cls._toJacobian(q), A, P), P,
)

@classmethod
Expand All @@ -59,12 +46,19 @@ def inv(cls, x, n):
if x == 0:
return 0

lm, hm = 1, 0
low, high = x % n, n
lm = 1
hm = 0
low = x % n
high = n

while low > 1:
r = high // low
nm, new = hm - lm * r, high - low * r
lm, low, hm, high = nm, new, lm, low
nm = hm - lm * r
nw = high - low * r
high = low
hm = lm
low = nw
lm = nm

return lm % n

Expand All @@ -88,11 +82,10 @@ def _fromJacobian(cls, p, P):
:return: Point in default coordinates
"""
z = cls.inv(p.z, P)
x = (p.x * z ** 2) % P
y = (p.y * z ** 3) % P

return Point(
(p.x * z ** 2) % P,
(p.y * z ** 3) % P,
)
return Point(x, y, 0)

@classmethod
def _jacobianDouble(cls, p, A, P):
Expand All @@ -113,6 +106,7 @@ def _jacobianDouble(cls, p, A, P):
nx = (M**2 - 2 * S) % P
ny = (M * (S - nx) - 8 * ysq ** 2) % P
nz = (2 * p.y * p.z) % P

return Point(nx, ny, nz)

@classmethod
Expand All @@ -126,9 +120,9 @@ def _jacobianAdd(cls, p, q, A, P):
:param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p)
:return: Point that represents the sum of First and Second Point
"""

if not p.y:
return q

if not q.y:
return p

Expand Down Expand Up @@ -176,31 +170,9 @@ def _jacobianMultiply(cls, p, n, N, A, P):

if (n % 2) == 0:
return cls._jacobianDouble(
cls._jacobianMultiply(
p,
n // 2,
N,
A,
P
),
A,
P,
cls._jacobianMultiply(p, n // 2, N, A, P), A, P
)

# (n % 2) == 1:
return cls._jacobianAdd(
cls._jacobianDouble(
cls._jacobianMultiply(
p,
n // 2,
N,
A,
P,
),
A,
P,
),
p,
A,
P,
cls._jacobianDouble(cls._jacobianMultiply(p, n // 2, N, A, P), A, P), p, A, P
)
95 changes: 36 additions & 59 deletions ellipticcurve/privateKey.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from .math import Math
from .utils.integer import RandomInteger
from .utils.compatibility import *
from .utils.binary import BinaryAscii
from .utils.der import fromPem, removeSequence, removeInteger, removeObject, removeOctetString, removeConstructed, toPem, encodeSequence, encodeInteger, encodeBitString, encodeOid, encodeOctetString, encodeConstructed
from .utils.pem import getPemContent, createPem
from .utils.binary import hexFromByteString, byteStringFromHex, intFromHex, base64FromByteString, byteStringFromBase64
from .utils.der import hexFromInt, parse, encodeConstructed, DerFieldType, encodePrimitive
from .curve import secp256k1, getCurveByOid
from .publicKey import PublicKey
from .curve import secp256k1, curvesByOid, supportedCurves
from .math import Math

hexAt = "\x00"


class PrivateKey:
Expand All @@ -27,69 +25,48 @@ def publicKey(self):
return PublicKey(point=publicPoint, curve=curve)

def toString(self):
return BinaryAscii.stringFromNumber(number=self.secret, length=self.curve.length())
return hexFromInt(self.secret)

def toDer(self):
encodedPublicKey = self.publicKey().toString(encoded=True)

return encodeSequence(
encodeInteger(1),
encodeOctetString(self.toString()),
encodeConstructed(0, encodeOid(*self.curve.oid)),
encodeConstructed(1, encodeBitString(encodedPublicKey)),
publicKeyString = self.publicKey().toString(encoded=True)
hexadecimal = encodeConstructed(
encodePrimitive(DerFieldType.integer, 1),
encodePrimitive(DerFieldType.octetString, hexFromInt(self.secret)),
encodePrimitive(DerFieldType.oidContainer, encodePrimitive(DerFieldType.object, self.curve.oid)),
encodePrimitive(DerFieldType.publicKeyPointContainer, encodePrimitive(DerFieldType.bitString, publicKeyString))
)
return byteStringFromHex(hexadecimal)

def toPem(self):
return toPem(der=toBytes(self.toDer()), name="EC PRIVATE KEY")
der = self.toDer()
return createPem(content=base64FromByteString(der), template=_pemTemplate)

@classmethod
def fromPem(cls, string):
privateKeyPem = string[string.index("-----BEGIN EC PRIVATE KEY-----"):]
return cls.fromDer(fromPem(privateKeyPem))
privateKeyPem = getPemContent(pem=string, template=_pemTemplate)
return cls.fromDer(byteStringFromBase64(privateKeyPem))

@classmethod
def fromDer(cls, string):
t, empty = removeSequence(string)
if len(empty) != 0:
raise Exception(
"trailing junk after DER private key: " +
BinaryAscii.hexFromBinary(empty)
)

one, t = removeInteger(t)
if one != 1:
raise Exception(
"expected '1' at start of DER private key, got %d" % one
)

privateKeyStr, t = removeOctetString(t)
tag, curveOidStr, t = removeConstructed(t)
if tag != 0:
raise Exception("expected tag 0 in DER private key, got %d" % tag)

oidCurve, empty = removeObject(curveOidStr)

if len(empty) != 0:
raise Exception(
"trailing junk after DER private key curve_oid: %s" %
BinaryAscii.hexFromBinary(empty)
)

if oidCurve not in curvesByOid:
raise Exception(
"unknown curve with oid %s; The following are registered: %s" % (
oidCurve,
", ".join([curve.name for curve in supportedCurves])
)
)

curve = curvesByOid[oidCurve]

if len(privateKeyStr) < curve.length():
privateKeyStr = hexAt * (curve.lenght() - len(privateKeyStr)) + privateKeyStr

return cls.fromString(privateKeyStr, curve)
hexadecimal = hexFromByteString(string)
privateKeyFlag, secretHex, curveData, publicKeyString = parse(hexadecimal)[0]
if privateKeyFlag != 1:
raise Exception("Private keys should start with a '1' flag, but a '{flag}' was found instead".format(
flag=privateKeyFlag
))
curve = getCurveByOid(curveData[0])
privateKey = cls.fromString(string=secretHex, curve=curve)
if privateKey.publicKey().toString(encoded=True) != publicKeyString[0]:
raise Exception("The public key described inside the private key file doesn't match the actual public key of the pair")
return privateKey

@classmethod
def fromString(cls, string, curve=secp256k1):
return PrivateKey(secret=BinaryAscii.numberFromString(string), curve=curve)
return PrivateKey(secret=intFromHex(string), curve=curve)


_pemTemplate = """
-----BEGIN EC PRIVATE KEY-----
{content}
-----END EC PRIVATE KEY-----
"""
Loading

0 comments on commit 6f6b680

Please sign in to comment.