Skip to content

Commit 4096fa0

Browse files
authored
Merge pull request #344 from tlsfuzzer/non-prime-order-curve
handle non-prime order curves more gracefully
2 parents eebe016 + de9141c commit 4096fa0

File tree

4 files changed

+287
-30
lines changed

4 files changed

+287
-30
lines changed

src/ecdsa/ellipticcurve.py

+34-23
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ def __eq__(self, other):
633633
"""
634634
x1, y1, z1 = self.__coords
635635
if other is INFINITY:
636-
return not y1 or not z1
636+
return not z1
637637
if isinstance(other, Point):
638638
x2, y2, z2 = other.x(), other.y(), 1
639639
elif isinstance(other, PointJacobi):
@@ -723,11 +723,13 @@ def scale(self):
723723

724724
def to_affine(self):
725725
"""Return point in affine form."""
726-
_, y, z = self.__coords
727-
if not y or not z:
726+
_, _, z = self.__coords
727+
p = self.__curve.p()
728+
if not (z % p):
728729
return INFINITY
729730
self.scale()
730731
x, y, z = self.__coords
732+
assert z == 1
731733
return Point(self.__curve, x, y, self.__order)
732734

733735
@staticmethod
@@ -759,7 +761,7 @@ def _double_with_z_1(self, X1, Y1, p, a):
759761
# http://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#doubling-mdbl-2007-bl
760762
XX, YY = X1 * X1 % p, Y1 * Y1 % p
761763
if not YY:
762-
return 0, 0, 1
764+
return 0, 0, 0
763765
YYYY = YY * YY % p
764766
S = 2 * ((X1 + YY) ** 2 - XX - YYYY) % p
765767
M = 3 * XX + a
@@ -773,13 +775,13 @@ def _double(self, X1, Y1, Z1, p, a):
773775
"""Add a point to itself, arbitrary z."""
774776
if Z1 == 1:
775777
return self._double_with_z_1(X1, Y1, p, a)
776-
if not Y1 or not Z1:
777-
return 0, 0, 1
778+
if not Z1:
779+
return 0, 0, 0
778780
# after:
779781
# http://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#doubling-dbl-2007-bl
780782
XX, YY = X1 * X1 % p, Y1 * Y1 % p
781783
if not YY:
782-
return 0, 0, 1
784+
return 0, 0, 0
783785
YYYY = YY * YY % p
784786
ZZ = Z1 * Z1 % p
785787
S = 2 * ((X1 + YY) ** 2 - XX - YYYY) % p
@@ -795,14 +797,14 @@ def double(self):
795797
"""Add a point to itself."""
796798
X1, Y1, Z1 = self.__coords
797799

798-
if not Y1:
800+
if not Z1:
799801
return INFINITY
800802

801803
p, a = self.__curve.p(), self.__curve.a()
802804

803805
X3, Y3, Z3 = self._double(X1, Y1, Z1, p, a)
804806

805-
if not Y3 or not Z3:
807+
if not Z3:
806808
return INFINITY
807809
return PointJacobi(self.__curve, X3, Y3, Z3, self.__order)
808810

@@ -886,10 +888,10 @@ def __radd__(self, other):
886888

887889
def _add(self, X1, Y1, Z1, X2, Y2, Z2, p):
888890
"""add two points, select fastest method."""
889-
if not Y1 or not Z1:
890-
return X2, Y2, Z2
891-
if not Y2 or not Z2:
892-
return X1, Y1, Z1
891+
if not Z1:
892+
return X2 % p, Y2 % p, Z2 % p
893+
if not Z2:
894+
return X1 % p, Y1 % p, Z1 % p
893895
if Z1 == Z2:
894896
if Z1 == 1:
895897
return self._add_with_z_1(X1, Y1, X2, Y2, p)
@@ -917,7 +919,7 @@ def __add__(self, other):
917919

918920
X3, Y3, Z3 = self._add(X1, Y1, Z1, X2, Y2, Z2, p)
919921

920-
if not Y3 or not Z3:
922+
if not Z3:
921923
return INFINITY
922924
return PointJacobi(self.__curve, X3, Y3, Z3, self.__order)
923925

@@ -927,7 +929,7 @@ def __rmul__(self, other):
927929

928930
def _mul_precompute(self, other):
929931
"""Multiply point by integer with precomputation table."""
930-
X3, Y3, Z3, p = 0, 0, 1, self.__curve.p()
932+
X3, Y3, Z3, p = 0, 0, 0, self.__curve.p()
931933
_add = self._add
932934
for X2, Y2 in self.__precompute:
933935
if other % 2:
@@ -940,7 +942,7 @@ def _mul_precompute(self, other):
940942
else:
941943
other //= 2
942944

943-
if not Y3 or not Z3:
945+
if not Z3:
944946
return INFINITY
945947
return PointJacobi(self.__curve, X3, Y3, Z3, self.__order)
946948

@@ -959,7 +961,7 @@ def __mul__(self, other):
959961

960962
self = self.scale()
961963
X2, Y2, _ = self.__coords
962-
X3, Y3, Z3 = 0, 0, 1
964+
X3, Y3, Z3 = 0, 0, 0
963965
p, a = self.__curve.p(), self.__curve.a()
964966
_double = self._double
965967
_add = self._add
@@ -972,7 +974,7 @@ def __mul__(self, other):
972974
elif i > 0:
973975
X3, Y3, Z3 = _add(X3, Y3, Z3, X2, Y2, 1, p)
974976

975-
if not Y3 or not Z3:
977+
if not Z3:
976978
return INFINITY
977979

978980
return PointJacobi(self.__curve, X3, Y3, Z3, self.__order)
@@ -1001,7 +1003,7 @@ def mul_add(self, self_mul, other, other_mul):
10011003
other_mul = other_mul % self.__order
10021004

10031005
# (X3, Y3, Z3) is the accumulator
1004-
X3, Y3, Z3 = 0, 0, 1
1006+
X3, Y3, Z3 = 0, 0, 0
10051007
p, a = self.__curve.p(), self.__curve.a()
10061008

10071009
# as we have 6 unique points to work with, we can't scale all of them,
@@ -1025,7 +1027,7 @@ def mul_add(self, self_mul, other, other_mul):
10251027
# when the self and other sum to infinity, we need to add them
10261028
# one by one to get correct result but as that's very unlikely to
10271029
# happen in regular operation, we don't need to optimise this case
1028-
if not pApB_Y or not pApB_Z:
1030+
if not pApB_Z:
10291031
return self * self_mul + other * other_mul
10301032

10311033
# gmp object creation has cumulatively higher overhead than the
@@ -1070,7 +1072,7 @@ def mul_add(self, self_mul, other, other_mul):
10701072
assert B > 0
10711073
X3, Y3, Z3 = _add(X3, Y3, Z3, pApB_X, pApB_Y, pApB_Z, p)
10721074

1073-
if not Y3 or not Z3:
1075+
if not Z3:
10741076
return INFINITY
10751077

10761078
return PointJacobi(self.__curve, X3, Y3, Z3, self.__order)
@@ -1154,6 +1156,8 @@ def __eq__(self, other):
11541156
11551157
Note: only points that lay on the same curve can be equal.
11561158
"""
1159+
if other is INFINITY:
1160+
return self.__x is None or self.__y is None
11571161
if isinstance(other, Point):
11581162
return (
11591163
self.__curve == other.__curve
@@ -1220,7 +1224,12 @@ def leftmost_bit(x):
12201224
# From X9.62 D.3.2:
12211225

12221226
e3 = 3 * e
1223-
negative_self = Point(self.__curve, self.__x, -self.__y, self.__order)
1227+
negative_self = Point(
1228+
self.__curve,
1229+
self.__x,
1230+
(-self.__y) % self.__curve.p(),
1231+
self.__order,
1232+
)
12241233
i = leftmost_bit(e3) // 2
12251234
result = self
12261235
# print("Multiplying %s by %d (e3 = %d):" % (self, other, e3))
@@ -1247,7 +1256,6 @@ def __str__(self):
12471256

12481257
def double(self):
12491258
"""Return a new point that is twice the old."""
1250-
12511259
if self == INFINITY:
12521260
return INFINITY
12531261

@@ -1261,6 +1269,9 @@ def double(self):
12611269
* numbertheory.inverse_mod(2 * self.__y, p)
12621270
) % p
12631271

1272+
if not l:
1273+
return INFINITY
1274+
12641275
x3 = (l * l - 2 * self.__x) % p
12651276
y3 = (l * (self.__x - x3) - self.__y) % p
12661277

src/ecdsa/test_ellipticcurve.py

+38
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ def test_inequality_curves(self):
8383
c192 = CurveFp(p, -3, b)
8484
self.assertNotEqual(self.c_23, c192)
8585

86+
def test_inequality_curves_by_b_only(self):
87+
a = CurveFp(23, 1, 0)
88+
b = CurveFp(23, 1, 1)
89+
self.assertNotEqual(a, b)
90+
8691
def test_usability_in_a_hashed_collection_curves(self):
8792
{self.c_23: None}
8893

@@ -184,6 +189,33 @@ def test_double(self):
184189
self.assertEqual(p3.x(), x3)
185190
self.assertEqual(p3.y(), y3)
186191

192+
def test_double_to_infinity(self):
193+
p1 = Point(self.c_23, 11, 20)
194+
p2 = p1.double()
195+
self.assertEqual((p2.x(), p2.y()), (4, 0))
196+
self.assertNotEqual(p2, INFINITY)
197+
p3 = p2.double()
198+
self.assertEqual(p3, INFINITY)
199+
self.assertIs(p3, INFINITY)
200+
201+
def test_add_self_to_infinity(self):
202+
p1 = Point(self.c_23, 11, 20)
203+
p2 = p1 + p1
204+
self.assertEqual((p2.x(), p2.y()), (4, 0))
205+
self.assertNotEqual(p2, INFINITY)
206+
p3 = p2 + p2
207+
self.assertEqual(p3, INFINITY)
208+
self.assertIs(p3, INFINITY)
209+
210+
def test_mul_to_infinity(self):
211+
p1 = Point(self.c_23, 11, 20)
212+
p2 = p1 * 2
213+
self.assertEqual((p2.x(), p2.y()), (4, 0))
214+
self.assertNotEqual(p2, INFINITY)
215+
p3 = p2 * 2
216+
self.assertEqual(p3, INFINITY)
217+
self.assertIs(p3, INFINITY)
218+
187219
def test_multiply(self):
188220
x1, y1, m, x3, y3 = (3, 10, 2, 7, 12)
189221
p1 = Point(self.c_23, x1, y1)
@@ -224,6 +256,12 @@ def test_inequality_points_diff_types(self):
224256
c = CurveFp(100, -3, 100)
225257
self.assertNotEqual(self.g_23, c)
226258

259+
def test_inequality_diff_y(self):
260+
p1 = Point(self.c_23, 6, 4)
261+
p2 = Point(self.c_23, 6, 19)
262+
263+
self.assertNotEqual(p1, p2)
264+
227265
def test_to_bytes_from_bytes(self):
228266
p = Point(self.c_23, 3, 10)
229267

0 commit comments

Comments
 (0)