Skip to content

Commit 7cfa33d

Browse files
Allow users to request TLS client-side enforcement (#525)
1 parent 3f480db commit 7cfa33d

10 files changed

+136
-1
lines changed

include/cassandra.h

+20
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,12 @@ typedef enum CassSslVerifyFlags_ {
629629
CASS_SSL_VERIFY_PEER_IDENTITY_DNS = 0x04
630630
} CassSslVerifyFlags;
631631

632+
typedef enum CassSslTlsVersion_ {
633+
CASS_SSL_VERSION_TLS1 = 0x00,
634+
CASS_SSL_VERSION_TLS1_1 = 0x01,
635+
CASS_SSL_VERSION_TLS1_2 = 0x02
636+
} CassSslTlsVersion;
637+
632638
typedef enum CassProtocolVersion_ {
633639
CASS_PROTOCOL_VERSION_V1 = 0x01, /**< Deprecated */
634640
CASS_PROTOCOL_VERSION_V2 = 0x02, /**< Deprecated */
@@ -4686,6 +4692,20 @@ cass_ssl_set_private_key_n(CassSsl* ssl,
46864692
const char* password,
46874693
size_t password_length);
46884694

4695+
/**
4696+
* Set minimum supported client-side protocol version. This will prevent the
4697+
* connection using protocol versions earlier than the specified one. Useful
4698+
* for preventing TLS downgrade attacks.
4699+
*
4700+
* @public @memberof CassSsl
4701+
*
4702+
* @param[in] ssl
4703+
* @param[in] min_version
4704+
* @return CASS_OK if successful, otherwise an error occurred.
4705+
*/
4706+
CASS_EXPORT CassError
4707+
cass_ssl_set_min_protocol_version(CassSsl* ssl, CassSslTlsVersion min_version);
4708+
46894709
/***********************************************************************************
46904710
*
46914711
* Authenticator

src/ssl.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ CassError cass_ssl_set_private_key_n(CassSsl* ssl, const char* key, size_t key_l
6565
return ssl->set_private_key(key, key_length, password, password_length);
6666
}
6767

68+
CassError cass_ssl_set_min_protocol_version(CassSsl* ssl, CassSslTlsVersion min_version) {
69+
return ssl->set_min_protocol_version(min_version);
70+
}
71+
6872
} // extern "C"
6973

7074
template <class T>

src/ssl.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class SslContext : public RefCounted<SslContext> {
8787
virtual CassError set_cert(const char* cert, size_t cert_length) = 0;
8888
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
8989
size_t password_length) = 0;
90+
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version) = 0;
9091

9192
protected:
9293
int verify_flags_;

src/ssl/ssl_no_impl.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,8 @@ CassError NoSslContext::set_private_key(const char* key, size_t key_length, cons
4444
return CASS_ERROR_LIB_NOT_IMPLEMENTED;
4545
}
4646

47+
CassError NoSslContext::set_min_protocol_version(CassSslTlsVersion min_version) {
48+
return CASS_ERROR_LIB_NOT_IMPLEMENTED;
49+
}
50+
4751
SslContext::Ptr NoSslContextFactory::create() { return SslContext::Ptr(new NoSslContext()); }

src/ssl/ssl_no_impl.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class NoSslContext : public SslContext {
4040
virtual CassError set_cert(const char* cert, size_t cert_length);
4141
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
4242
size_t password_length);
43+
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version);
4344
};
4445

4546
class NoSslContextFactory : public SslContextFactoryBase<NoSslContextFactory> {

src/ssl/ssl_openssl_impl.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@
4141
!defined(LIBRESSL_VERSION_NUMBER) // Required as OPENSSL_VERSION_NUMBER for LibreSSL is defined
4242
// as 2.0.0
4343
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
44+
#define SSL_CAN_SET_MIN_VERSION
4445
#define SSL_CLIENT_METHOD TLS_client_method
4546
#else
4647
#define SSL_CLIENT_METHOD SSLv23_client_method
4748
#endif
4849
#else
4950
#if (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
51+
#define SSL_CAN_SET_MIN_VERSION
5052
#define SSL_CLIENT_METHOD TLS_client_method
5153
#else
5254
#define SSL_CLIENT_METHOD SSLv23_client_method
@@ -615,6 +617,48 @@ CassError OpenSslContext::set_private_key(const char* key, size_t key_length, co
615617
return CASS_OK;
616618
}
617619

620+
CassError OpenSslContext::set_min_protocol_version(CassSslTlsVersion min_version) {
621+
#ifdef SSL_CAN_SET_MIN_VERSION
622+
int method;
623+
switch (min_version) {
624+
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1:
625+
method = TLS1_VERSION;
626+
break;
627+
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_1:
628+
method = TLS1_1_VERSION;
629+
break;
630+
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_2:
631+
method = TLS1_2_VERSION;
632+
break;
633+
default:
634+
// unsupported version
635+
return CASS_ERROR_LIB_BAD_PARAMS;
636+
}
637+
SSL_CTX_set_min_proto_version(ssl_ctx_, method);
638+
return CASS_OK;
639+
#else
640+
// If we don't have the `set_min_proto_version` function then we do this via
641+
// the (deprecated in later versions) options function.
642+
int options = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
643+
switch (min_version) {
644+
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1:
645+
break;
646+
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_1:
647+
options |= SSL_OP_NO_TLSv1;
648+
break;
649+
case CassSslTlsVersion::CASS_SSL_VERSION_TLS1_2:
650+
options |= SSL_OP_NO_TLSv1;
651+
options |= SSL_OP_NO_TLSv1_1;
652+
break;
653+
default:
654+
// unsupported version
655+
return CASS_ERROR_LIB_BAD_PARAMS;
656+
}
657+
SSL_CTX_set_options(ssl_ctx_, options);
658+
return CASS_OK;
659+
#endif
660+
}
661+
618662
SslContext::Ptr OpenSslContextFactory::create() { return SslContext::Ptr(new OpenSslContext()); }
619663

620664
namespace openssl {

src/ssl/ssl_openssl_impl.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class OpenSslContext : public SslContext {
6161
virtual CassError set_cert(const char* cert, size_t cert_length);
6262
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,
6363
size_t password_length);
64+
virtual CassError set_min_protocol_version(CassSslTlsVersion min_version);
6465

6566
private:
6667
SSL_CTX* ssl_ctx_;

tests/src/unit/mockssandra.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,14 @@ using datastax::internal::core::UuidGen;
5252
!defined(LIBRESSL_VERSION_NUMBER) // Required as OPENSSL_VERSION_NUMBER for LibreSSL is defined
5353
// as 2.0.0
5454
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
55+
#define SSL_CAN_SET_MAX_VERSION
5556
#define SSL_SERVER_METHOD TLS_server_method
5657
#else
5758
#define SSL_SERVER_METHOD SSLv23_server_method
5859
#endif
5960
#else
6061
#if (LIBRESSL_VERSION_NUMBER >= 0x20302000L)
62+
#define SSL_CAN_SET_MAX_VERSION
6163
#define SSL_SERVER_METHOD TLS_server_method
6264
#else
6365
#define SSL_SERVER_METHOD SSLv23_server_method
@@ -555,6 +557,21 @@ bool ServerConnection::use_ssl(const String& key, const String& cert,
555557
return true;
556558
}
557559

560+
// Weaken the SSL connection, enforcing that it can only use TLS1.0 at max.
561+
// This is used for testing client-side enforcement of more secure TLS
562+
// protocols.
563+
void ServerConnection::weaken_ssl() {
564+
if (!ssl_context_) {
565+
return;
566+
}
567+
568+
#ifdef SSL_CAN_SET_MAX_VERSION
569+
SSL_CTX_set_max_proto_version(ssl_context_, TLS1_VERSION);
570+
#else
571+
SSL_CTX_set_options(ssl_context_, SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1_2);
572+
#endif
573+
}
574+
558575
using datastax::internal::core::Task;
559576

560577
class RunListen : public Task {

tests/src/unit/mockssandra.hpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ class ServerConnection : public RefCounted<ServerConnection> {
175175

176176
bool use_ssl(const String& key, const String& cert, const String& ca_cert = "",
177177
bool require_client_cert = false);
178+
void weaken_ssl();
178179

179180
void listen(EventLoopGroup* event_loop_group);
180181
int wait_listen();
@@ -1161,6 +1162,7 @@ class Cluster {
11611162
~Cluster();
11621163

11631164
String use_ssl(const String& cn = "");
1165+
void weaken_ssl();
11641166

11651167
int start_all(EventLoopGroup* event_loop_group);
11661168
void start_all_async(EventLoopGroup* event_loop_group);
@@ -1264,7 +1266,8 @@ class SimpleEchoServer {
12641266
public:
12651267
SimpleEchoServer()
12661268
: factory_(new EchoClientConnectionFactory())
1267-
, event_loop_group_(1) {}
1269+
, event_loop_group_(1)
1270+
, ssl_weaken_(false) {}
12681271

12691272
~SimpleEchoServer() { close(); }
12701273

@@ -1281,6 +1284,8 @@ class SimpleEchoServer {
12811284
return ssl_cert_;
12821285
}
12831286

1287+
void weaken_ssl() { ssl_weaken_ = true; }
1288+
12841289
void use_connection_factory(internal::ClientConnectionFactory* factory) {
12851290
factory_.reset(factory);
12861291
}
@@ -1290,6 +1295,11 @@ class SimpleEchoServer {
12901295
if (!ssl_key_.empty() && !ssl_cert_.empty() && !server_->use_ssl(ssl_key_, ssl_cert_)) {
12911296
return -1;
12921297
}
1298+
1299+
if (ssl_weaken_) {
1300+
server_->weaken_ssl();
1301+
}
1302+
12931303
server_->listen(&event_loop_group_);
12941304
return server_->wait_listen();
12951305
}
@@ -1316,6 +1326,7 @@ class SimpleEchoServer {
13161326
internal::ServerConnection::Ptr server_;
13171327
String ssl_key_;
13181328
String ssl_cert_;
1329+
bool ssl_weaken_;
13191330
};
13201331

13211332
} // namespace mockssandra

tests/src/unit/tests/test_socket.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class SocketUnitTest : public LoopTest {
135135
return settings;
136136
}
137137

138+
void weaken_ssl() { server_.weaken_ssl(); }
139+
138140
void listen(const Address& address = Address("127.0.0.1", 8888)) {
139141
ASSERT_EQ(server_.listen(address), 0);
140142
}
@@ -185,6 +187,17 @@ class SocketUnitTest : public LoopTest {
185187
}
186188
}
187189

190+
/* SSL handshake failures have different error codes on different versions of
191+
* OpenSSL - this accounts for both of them
192+
*/
193+
static void on_socket_ssl_error(SocketConnector* connector, bool* is_error) {
194+
SocketConnector::SocketError err = connector->error_code();
195+
if ((err == SocketConnector::SOCKET_ERROR_CLOSE) ||
196+
(err == SocketConnector::SOCKET_ERROR_SSL_HANDSHAKE)) {
197+
*is_error = true;
198+
}
199+
}
200+
188201
static void on_socket_canceled(SocketConnector* connector, bool* is_canceled) {
189202
if (connector->is_canceled()) {
190203
*is_canceled = true;
@@ -409,3 +422,22 @@ TEST_F(SocketUnitTest, SslVerifyIdentityDns) {
409422

410423
EXPECT_EQ(result, "The socket is successfully connected and wrote data - Closed");
411424
}
425+
426+
TEST_F(SocketUnitTest, SslEnforceTlsVersion) {
427+
SocketSettings settings(use_ssl("127.0.0.1"));
428+
weaken_ssl();
429+
430+
listen();
431+
432+
settings.ssl_context->set_min_protocol_version(CASS_SSL_VERSION_TLS1_2);
433+
434+
bool is_error;
435+
SocketConnector::Ptr connector(new SocketConnector(
436+
Address("127.0.0.1", 8888), bind_callback(on_socket_ssl_error, &is_error)));
437+
438+
connector->with_settings(settings)->connect(loop());
439+
440+
uv_run(loop(), UV_RUN_DEFAULT);
441+
442+
EXPECT_TRUE(is_error);
443+
}

0 commit comments

Comments
 (0)