Skip to content

Commit

Permalink
Fix side-channel leakage in RSA decryption
Browse files Browse the repository at this point in the history
  • Loading branch information
Legrandin committed Dec 25, 2023
1 parent ee91c67 commit afb5e27
Show file tree
Hide file tree
Showing 17 changed files with 350 additions and 35 deletions.
6 changes: 2 additions & 4 deletions lib/Crypto/Cipher/PKCS1_OAEP.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,8 @@ def decrypt(self, ciphertext):
raise ValueError("Ciphertext with incorrect length.")
# Step 2a (O2SIP)
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
m_int = self._key._decrypt(ct_int)
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 2b (RSADP) and step 2c (I2OSP)
em = self._key._decrypt(ct_int)
# Step 3a
lHash = self._hashObj.new(self._label).digest()
# Step 3b
Expand Down
7 changes: 2 additions & 5 deletions lib/Crypto/Cipher/PKCS1_v1_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,8 @@ def decrypt(self, ciphertext, sentinel, expected_pt_len=0):
# Step 2a (O2SIP)
ct_int = bytes_to_long(ciphertext)

# Step 2b (RSADP)
m_int = self._key._decrypt(ct_int)

# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 2b (RSADP) and Step 2c (I2OSP)
em = self._key._decrypt(ct_int)

# Step 3 (not constant time when the sentinel is not a byte string)
output = bytes(bytearray(k))
Expand Down
20 changes: 20 additions & 0 deletions lib/Crypto/Math/_IntegerBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,23 @@ def random_range(cls, **kwargs):
)
return norm_candidate + min_inclusive

@staticmethod
@abc.abstractmethod
def _mult_modulo_bytes(term1, term2, modulus):
"""Multiply two integers, take the modulo, and encode as big endian.
This specialized method is used for RSA decryption.
Args:
term1 : integer
The first term of the multiplication, non-negative.
term2 : integer
The second term of the multiplication, non-negative.
modulus: integer
The modulus, a positive odd number.
:Returns:
A byte string, with the result of the modular multiplication
encoded in big endian mode.
It is as long as the modulus would be, with zero padding
on the left if needed.
"""
pass
4 changes: 4 additions & 0 deletions lib/Crypto/Math/_IntegerBase.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,8 @@ class IntegerBase:
def random(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ...
@classmethod
def random_range(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ...
@staticmethod
def _mult_modulo_bytes(term1: Union[IntegerBase, int],
term2: Union[IntegerBase, int],
modulus: Union[IntegerBase, int]) -> bytes: ...

56 changes: 50 additions & 6 deletions lib/Crypto/Math/_IntegerCustom.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,18 @@
from Crypto.Random.random import getrandbits

c_defs = """
int monty_pow(const uint8_t *base,
const uint8_t *exp,
const uint8_t *modulus,
uint8_t *out,
size_t len,
uint64_t seed);
int monty_pow(uint8_t *out,
const uint8_t *base,
const uint8_t *exp,
const uint8_t *modulus,
size_t len,
uint64_t seed);
int monty_multiply(uint8_t *out,
const uint8_t *term1,
const uint8_t *term2,
const uint8_t *modulus,
size_t len);
"""


Expand Down Expand Up @@ -116,3 +122,41 @@ def inplace_pow(self, exponent, modulus=None):
result = bytes_to_long(get_raw_buffer(out))
self._value = result
return self

@staticmethod
def _mult_modulo_bytes(term1, term2, modulus):

# With modular reduction
mod_value = int(modulus)
if mod_value < 0:
raise ValueError("Modulus must be positive")
if mod_value == 0:
raise ZeroDivisionError("Modulus cannot be zero")

# C extension only works with odd moduli
if (mod_value & 1) == 0:
raise ValueError("Odd modulus is required")

# C extension only works with non-negative terms smaller than modulus
if term1 >= mod_value or term1 < 0:
term1 %= mod_value
if term2 >= mod_value or term2 < 0:
term2 %= mod_value

modulus_b = long_to_bytes(mod_value)
numbers_len = len(modulus_b)
term1_b = long_to_bytes(term1, numbers_len)
term2_b = long_to_bytes(term2, numbers_len)
out = create_string_buffer(numbers_len)

error = _raw_montgomery.monty_multiply(
out,
term1_b,
term2_b,
modulus_b,
c_size_t(numbers_len)
)
if error:
raise ValueError("monty_multiply failed with error: %d" % error)

return get_raw_buffer(out)
20 changes: 20 additions & 0 deletions lib/Crypto/Math/_IntegerGMP.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,26 @@ def jacobi_symbol(a, n):
raise ValueError("n must be positive odd for the Jacobi symbol")
return _gmp.mpz_jacobi(a._mpz_p, n._mpz_p)

@staticmethod
def _mult_modulo_bytes(term1, term2, modulus):
if not isinstance(term1, IntegerGMP):
term1 = IntegerGMP(term1)
if not isinstance(term2, IntegerGMP):
term2 = IntegerGMP(term2)
if not isinstance(modulus, IntegerGMP):
modulus = IntegerGMP(modulus)

if modulus < 0:
raise ValueError("Modulus must be positive")
if modulus == 0:
raise ZeroDivisionError("Modulus cannot be zero")
if (modulus & 1) == 0:
raise ValueError("Odd modulus is required")

numbers_len = len(modulus.to_bytes())
result = ((term1 * term2) % modulus).to_bytes(numbers_len)
return result

# Clean-up
def __del__(self):

Expand Down
12 changes: 12 additions & 0 deletions lib/Crypto/Math/_IntegerNative.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,15 @@ def jacobi_symbol(a, n):
n1 = n % a1
# Step 8
return s * IntegerNative.jacobi_symbol(n1, a1)

@staticmethod
def _mult_modulo_bytes(term1, term2, modulus):
if modulus < 0:
raise ValueError("Modulus must be positive")
if modulus == 0:
raise ZeroDivisionError("Modulus cannot be zero")
if (modulus & 1) == 0:
raise ValueError("Odd modulus is required")

number_len = len(long_to_bytes(modulus))
return long_to_bytes((term1 * term2) % modulus, number_len)
10 changes: 6 additions & 4 deletions lib/Crypto/PublicKey/RSA.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from Crypto import Random
from Crypto.Util.py3compat import tobytes, bord, tostr
from Crypto.Util.asn1 import DerSequence, DerNull
from Crypto.Util.number import bytes_to_long

from Crypto.Math.Numbers import Integer
from Crypto.Math.Primality import (test_probable_prime,
Expand Down Expand Up @@ -198,10 +199,11 @@ def _decrypt(self, ciphertext):
h = ((m2 - m1) * self._u) % self._q
mp = h * self._p + m1
# Step 4: Compute m = m' * (r**(-1)) mod n
result = (r.inverse(self._n) * mp) % self._n
# Verify no faults occurred
if ciphertext != pow(result, self._e, self._n):
raise ValueError("Fault detected in RSA decryption")
# then encode into a big endian byte string
result = Integer._mult_modulo_bytes(
r.inverse(self._n),
mp,
self._n)
return result

def has_private(self):
Expand Down
2 changes: 2 additions & 0 deletions lib/Crypto/SelfTest/Math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def get_tests(config={}):
from Crypto.SelfTest.Math import test_Numbers
from Crypto.SelfTest.Math import test_Primality
from Crypto.SelfTest.Math import test_modexp
from Crypto.SelfTest.Math import test_modmult
tests += test_Numbers.get_tests(config=config)
tests += test_Primality.get_tests(config=config)
tests += test_modexp.get_tests(config=config)
tests += test_modmult.get_tests(config=config)
return tests

if __name__ == '__main__':
Expand Down
28 changes: 28 additions & 0 deletions lib/Crypto/SelfTest/Math/test_Numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,34 @@ def test_hex(self):
v1, = self.Integers(0x10)
self.assertEqual(hex(v1), "0x10")

def test_mult_modulo_bytes(self):
modmult = self.Integer._mult_modulo_bytes

res = modmult(4, 5, 19)
self.assertEqual(res, b'\x01')

res = modmult(4 - 19, 5, 19)
self.assertEqual(res, b'\x01')

res = modmult(4, 5 - 19, 19)
self.assertEqual(res, b'\x01')

res = modmult(4 + 19, 5, 19)
self.assertEqual(res, b'\x01')

res = modmult(4, 5 + 19, 19)
self.assertEqual(res, b'\x01')

modulus = 2**512 - 1 # 64 bytes
t1 = 13**100
t2 = 17**100
expect = b"\xfa\xb2\x11\x87\xc3(y\x07\xf8\xf1n\xdepq\x0b\xca\xf3\xd3B,\xef\xf2\xfbf\xcc)\x8dZ*\x95\x98r\x96\xa8\xd5\xc3}\xe2q:\xa2'z\xf48\xde%\xef\t\x07\xbc\xc4[C\x8bUE2\x90\xef\x81\xaa:\x08"
self.assertEqual(expect, modmult(t1, t2, modulus))

self.assertRaises(ZeroDivisionError, modmult, 4, 5, 0)
self.assertRaises(ValueError, modmult, 4, 5, -1)
self.assertRaises(ValueError, modmult, 4, 5, 4)


class TestIntegerInt(TestIntegerBase):

Expand Down
120 changes: 120 additions & 0 deletions lib/Crypto/SelfTest/Math/test_modmult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#
# SelfTest/Math/test_modmult.py: Self-test for custom modular multiplication
#
# ===================================================================
#
# Copyright (c) 2023, Helder Eijs <[email protected]>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# ===================================================================

"""Self-test for the custom modular multiplication"""

import unittest

from Crypto.SelfTest.st_common import list_test_cases

from Crypto.Util.number import long_to_bytes, bytes_to_long

from Crypto.Util._raw_api import (create_string_buffer,
get_raw_buffer,
c_size_t)

from Crypto.Math._IntegerCustom import _raw_montgomery


class ExceptionModulus(ValueError):
pass


def monty_mult(term1, term2, modulus):

if term1 >= modulus:
term1 %= modulus
if term2 >= modulus:
term2 %= modulus

modulus_b = long_to_bytes(modulus)
numbers_len = len(modulus_b)
term1_b = long_to_bytes(term1, numbers_len)
term2_b = long_to_bytes(term2, numbers_len)

out = create_string_buffer(numbers_len)
error = _raw_montgomery.monty_multiply(
out,
term1_b,
term2_b,
modulus_b,
c_size_t(numbers_len)
)

if error == 17:
raise ExceptionModulus()
if error:
raise ValueError("monty_multiply() failed with error: %d" % error)

return get_raw_buffer(out)


modulus1 = 0xd66691b20071be4d66d4b71032b37fa007cfabf579fcb91e50bfc2753b3f0ce7be74e216aef7e26d4ae180bc20d7bd3ea88a6cbf6f87380e613c8979b5b043b200a8ff8856a3b12875e36e98a7569f3852d028e967551000b02c19e9fa52e83115b89309aabb1e1cf1e2cb6369d637d46775ce4523ea31f64ad2794cbc365dd8a35e007ed3b57695877fbf102dbeb8b3212491398e494314e93726926e1383f8abb5889bea954eb8c0ca1c62c8e9d83f41888095c5e645ed6d32515fe0c58c1368cad84694e18da43668c6f43e61d7c9bca633ddcda7aef5b79bc396d4a9f48e2a9abe0836cc455e435305357228e93d25aaed46b952defae0f57339bf26f5a9


class TestModMultiply(unittest.TestCase):

def test_small(self):
self.assertEqual(b"\x01", monty_mult(5, 6, 29))

def test_large(self):
numbers_len = (modulus1.bit_length() + 7) // 8

t1 = modulus1 // 2
t2 = modulus1 - 90
expect = b'\x00' * (numbers_len - 1) + b'\x2d'
self.assertEqual(expect, monty_mult(t1, t2, modulus1))

def test_zero_term(self):
numbers_len = (modulus1.bit_length() + 7) // 8
expect = b'\x00' * numbers_len
self.assertEqual(expect, monty_mult(0x100, 0, modulus1))
self.assertEqual(expect, monty_mult(0, 0x100, modulus1))

def test_larger_term(self):
t1 = 2**2047
expect_int = 0x8edf4071f78e3d7ba622cdbbbef74612e301d69186776ae6bf87ff38c320d9aebaa64889c2f67de2324e6bccd2b10ad89e91fd21ba4bb523904d033eff5e70e62f01a84f41fa90a4f248ef249b82e1d2729253fdfc2a3b5b740198123df8bfbf7057d03e15244ad5f26eb9a099763b5c5972121ec076b0bf899f59bd95f7cc129abddccf24217bce52ca0f3a44c9ccc504765dbb89734205f3ae6a8cc560494a60ea84b27d8e00fa24bdd5b4f1d4232edb61e47d3d984c1fa50a3820a2e580fbc3fc8bc11e99df53b9efadf5a40ac75d384e400905aa6f1d88950cd53b1c54dc2222115ad84a27260fa4d978155c1434c551de1ee7361a17a2f79d4388f78a5d
res = bytes_to_long(monty_mult(t1, t1, modulus1))
self.assertEqual(res, expect_int)


def get_tests(config={}):
tests = []
tests += list_test_cases(TestModMultiply)
return tests


if __name__ == '__main__':
def suite():
return unittest.TestSuite(get_tests())
unittest.main(defaultTest='suite')
4 changes: 2 additions & 2 deletions lib/Crypto/SelfTest/PublicKey/test_RSA.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def _exercise_primitive(self, rsaObj):
ciphertext = bytes_to_long(a2b_hex(self.ciphertext))

# Test decryption
plaintext = rsaObj._decrypt(ciphertext)
plaintext = bytes_to_long(rsaObj._decrypt(ciphertext))

# Test encryption (2 arguments)
new_ciphertext2 = rsaObj._encrypt(plaintext)
Expand All @@ -304,7 +304,7 @@ def _check_decryption(self, rsaObj):
ciphertext = bytes_to_long(a2b_hex(self.ciphertext))

# Test plain decryption
new_plaintext = rsaObj._decrypt(ciphertext)
new_plaintext = bytes_to_long(rsaObj._decrypt(ciphertext))
self.assertEqual(plaintext, new_plaintext)


Expand Down
Loading

0 comments on commit afb5e27

Please sign in to comment.