diff --git a/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov.h b/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov.h index 85c45324cd..5b79d241dd 100644 --- a/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov.h +++ b/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov.h @@ -20,8 +20,11 @@ #define DEFAULT_SERVERS "localhost:4433" +/* must be treated as an opaque object by users */ struct ra_tls_ctx { void* ssl; + void* net; + void* conf; }; typedef int (*verify_measurements_cb_t)(const char* mrenclave, const char* mrsigner, @@ -29,10 +32,6 @@ typedef int (*verify_measurements_cb_t)(const char* mrenclave, const char* mrsig typedef int (*secret_provision_cb_t)(struct ra_tls_ctx* ctx); -/* internally used functions, not exported */ -__attribute__ ((visibility("hidden"))) -void secret_provision_free_resources(void); - /*! * \brief Write arbitrary data in an established RA-TLS session. * diff --git a/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_attest.c b/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_attest.c index cb4008bf17..9040c18cf5 100644 --- a/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_attest.c +++ b/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_attest.c @@ -33,31 +33,9 @@ #include "secret_prov.h" #include "util.h" -/* these are globals because the user may continue using the SSL session even after invoking - * secret_provision_start() (in the user-supplied callback) */ -static mbedtls_ctr_drbg_context g_ctr_drbg; -static mbedtls_entropy_context g_entropy; -static mbedtls_ssl_config g_conf; -static mbedtls_x509_crt g_verifier_ca_chain; -static mbedtls_net_context g_verifier_fd; -static mbedtls_ssl_context g_ssl; -static mbedtls_pk_context g_my_ratls_key; -static mbedtls_x509_crt g_my_ratls_cert; - static uint8_t* g_provisioned_secret = NULL; static size_t g_provisioned_secret_size = 0; -void secret_provision_free_resources(void) { - mbedtls_x509_crt_free(&g_my_ratls_cert); - mbedtls_pk_free(&g_my_ratls_key); - mbedtls_net_free(&g_verifier_fd); - mbedtls_ssl_free(&g_ssl); - mbedtls_ssl_config_free(&g_conf); - mbedtls_x509_crt_free(&g_verifier_ca_chain); - mbedtls_ctr_drbg_free(&g_ctr_drbg); - mbedtls_entropy_free(&g_entropy); -} - int secret_provision_get(uint8_t** out_secret, size_t* out_secret_size) { if (!out_secret || !out_secret_size) return -EINVAL; @@ -79,6 +57,36 @@ void secret_provision_destroy(void) { g_provisioned_secret_size = 0; } +extern int common_write(mbedtls_ssl_context* ssl, const uint8_t* buf, size_t size); +extern int common_read(mbedtls_ssl_context* ssl, uint8_t* buf, size_t size); +extern int common_close(mbedtls_ssl_context* ssl); + +int secret_provision_write(struct ra_tls_ctx* ctx, const uint8_t* buf, size_t size) { + mbedtls_ssl_context* ssl = (mbedtls_ssl_context*)ctx->ssl; + return common_write(ssl, buf, size); +} + +int secret_provision_read(struct ra_tls_ctx* ctx, uint8_t* buf, size_t size) { + mbedtls_ssl_context* ssl = (mbedtls_ssl_context*)ctx->ssl; + return common_read(ssl, buf, size); +} + +int secret_provision_close(struct ra_tls_ctx* ctx) { + mbedtls_ssl_context* ssl = (mbedtls_ssl_context*)ctx->ssl; + mbedtls_ssl_config* conf = (mbedtls_ssl_config*)ctx->conf; + mbedtls_net_context* net = (mbedtls_net_context*)ctx->net; + + int ret = common_close(ssl); + + mbedtls_ssl_free(ssl); + mbedtls_ssl_config_free(conf); + mbedtls_net_free(net); + free(ssl); + free(conf); + free(net); + return ret; +} + int secret_provision_start(const char* in_servers, const char* in_ca_chain_path, struct ra_tls_ctx* out_ctx) { int ret; @@ -89,19 +97,31 @@ int secret_provision_start(const char* in_servers, const char* in_ca_chain_path, char* connected_addr = NULL; char* connected_port = NULL; - mbedtls_ctr_drbg_init(&g_ctr_drbg); - mbedtls_entropy_init(&g_entropy); - mbedtls_x509_crt_init(&g_verifier_ca_chain); + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_entropy_context entropy; + mbedtls_x509_crt verifier_ca_chain; + mbedtls_pk_context my_ratls_key; + mbedtls_x509_crt my_ratls_cert; + + mbedtls_net_context* verifier_fd = malloc(sizeof(*verifier_fd)); + mbedtls_ssl_config* conf = malloc(sizeof(*conf)); + mbedtls_ssl_context* ssl = malloc(sizeof(*ssl)); + if (!verifier_fd || !conf || !ssl) { + return -ENOMEM; + } - mbedtls_pk_init(&g_my_ratls_key); - mbedtls_x509_crt_init(&g_my_ratls_cert); + mbedtls_ctr_drbg_init(&ctr_drbg); + mbedtls_entropy_init(&entropy); + mbedtls_x509_crt_init(&verifier_ca_chain); + mbedtls_pk_init(&my_ratls_key); + mbedtls_x509_crt_init(&my_ratls_cert); + mbedtls_ssl_config_init(conf); - mbedtls_net_init(&g_verifier_fd); - mbedtls_ssl_config_init(&g_conf); - mbedtls_ssl_init(&g_ssl); + mbedtls_net_init(verifier_fd); + mbedtls_ssl_init(ssl); const char* pers = "secret-provisioning"; - ret = mbedtls_ctr_drbg_seed(&g_ctr_drbg, mbedtls_entropy_func, &g_entropy, + ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const uint8_t*)pers, strlen(pers)); if (ret < 0) { goto out; @@ -148,7 +168,7 @@ int secret_provision_start(const char* in_servers, const char* in_ca_chain_path, if (!connected_port) continue; - ret = mbedtls_net_connect(&g_verifier_fd, connected_addr, connected_port, + ret = mbedtls_net_connect(verifier_fd, connected_addr, connected_port, MBEDTLS_NET_PROTO_TCP); if (!ret) break; @@ -158,53 +178,53 @@ int secret_provision_start(const char* in_servers, const char* in_ca_chain_path, goto out; } - ret = mbedtls_ssl_config_defaults(&g_conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, + ret = mbedtls_ssl_config_defaults(conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); if (ret < 0) { goto out; } - ret = mbedtls_x509_crt_parse_file(&g_verifier_ca_chain, ca_chain_path); + ret = mbedtls_x509_crt_parse_file(&verifier_ca_chain, ca_chain_path); if (ret != 0) { goto out; } char crt_issuer[256]; - ret = mbedtls_x509_dn_gets(crt_issuer, sizeof(crt_issuer), &g_verifier_ca_chain.issuer); + ret = mbedtls_x509_dn_gets(crt_issuer, sizeof(crt_issuer), &verifier_ca_chain.issuer); if (ret < 0) { goto out; } - mbedtls_ssl_conf_authmode(&g_conf, MBEDTLS_SSL_VERIFY_REQUIRED); - mbedtls_ssl_conf_ca_chain(&g_conf, &g_verifier_ca_chain, NULL); + mbedtls_ssl_conf_authmode(conf, MBEDTLS_SSL_VERIFY_REQUIRED); + mbedtls_ssl_conf_ca_chain(conf, &verifier_ca_chain, NULL); - ret = ra_tls_create_key_and_crt(&g_my_ratls_key, &g_my_ratls_cert); + ret = ra_tls_create_key_and_crt(&my_ratls_key, &my_ratls_cert); if (ret < 0) { goto out; } - mbedtls_ssl_conf_rng(&g_conf, mbedtls_ctr_drbg_random, &g_ctr_drbg); + mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, &ctr_drbg); - ret = mbedtls_ssl_conf_own_cert(&g_conf, &g_my_ratls_cert, &g_my_ratls_key); + ret = mbedtls_ssl_conf_own_cert(conf, &my_ratls_cert, &my_ratls_key); if (ret < 0) { goto out; } - ret = mbedtls_ssl_setup(&g_ssl, &g_conf); + ret = mbedtls_ssl_setup(ssl, conf); if (ret < 0) { goto out; } - ret = mbedtls_ssl_set_hostname(&g_ssl, connected_addr); + ret = mbedtls_ssl_set_hostname(ssl, connected_addr); if (ret < 0) { goto out; } - mbedtls_ssl_set_bio(&g_ssl, &g_verifier_fd, mbedtls_net_send, mbedtls_net_recv, NULL); + mbedtls_ssl_set_bio(ssl, verifier_fd, mbedtls_net_send, mbedtls_net_recv, NULL); ret = -1; while (ret < 0) { - ret = mbedtls_ssl_handshake(&g_ssl); + ret = mbedtls_ssl_handshake(ssl); if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { continue; } @@ -213,13 +233,12 @@ int secret_provision_start(const char* in_servers, const char* in_ca_chain_path, } } - uint32_t flags = mbedtls_ssl_get_verify_result(&g_ssl); + uint32_t flags = mbedtls_ssl_get_verify_result(ssl); if (flags != 0) { ret = MBEDTLS_ERR_X509_CERT_VERIFY_FAILED; goto out; } - struct ra_tls_ctx ctx = {.ssl = &g_ssl}; uint8_t buf[128] = {0}; size_t size; @@ -228,7 +247,7 @@ int secret_provision_start(const char* in_servers, const char* in_ca_chain_path, size = sprintf((char*)buf, SECRET_PROVISION_REQUEST); size += 1; /* include null byte */ - ret = secret_provision_write(&ctx, buf, size); + ret = common_write(ssl, buf, size); if (ret < 0) { goto out; } @@ -239,8 +258,7 @@ int secret_provision_start(const char* in_servers, const char* in_ca_chain_path, "buffer must be sufficiently large to hold SECRET_PROVISION_RESPONSE + int32"); memset(buf, 0, sizeof(buf)); - ret = secret_provision_read(&ctx, buf, - sizeof(SECRET_PROVISION_RESPONSE) + sizeof(received_secret_size)); + ret = common_read(ssl, buf, sizeof(SECRET_PROVISION_RESPONSE) + sizeof(received_secret_size)); if (ret < 0) { goto out; } @@ -269,24 +287,42 @@ int secret_provision_start(const char* in_servers, const char* in_ca_chain_path, } g_provisioned_secret_size = received_secret_size; - ret = secret_provision_read(&ctx, g_provisioned_secret, g_provisioned_secret_size); + ret = common_read(ssl, g_provisioned_secret, g_provisioned_secret_size); if (ret < 0) { goto out; } if (out_ctx) { - out_ctx->ssl = ctx.ssl; + /* pass ownership of SSL session to the caller; it is caller's responsibility to gracefuly + * terminate the session using secret_provision_close() */ + out_ctx->ssl = ssl; + out_ctx->conf = conf; + out_ctx->net = verifier_fd; } else { - secret_provision_close(&ctx); + common_close(ssl); } ret = 0; out: + if (!out_ctx) { + mbedtls_ssl_free(ssl); + mbedtls_ssl_config_free(conf); + mbedtls_net_free(verifier_fd); + free(ssl); + free(conf); + free(verifier_fd); + } + if (ret < 0) { secret_provision_destroy(); - secret_provision_free_resources(); } + mbedtls_x509_crt_free(&my_ratls_cert); + mbedtls_pk_free(&my_ratls_key); + mbedtls_x509_crt_free(&verifier_ca_chain); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + free(servers); free(ca_chain_path); diff --git a/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_common.c b/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_common.c index 8b254c47f6..d3f20b658d 100644 --- a/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_common.c +++ b/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_common.c @@ -16,17 +16,19 @@ #include "secret_prov.h" -int secret_provision_write(struct ra_tls_ctx* ctx, const uint8_t* buf, size_t size) { +int common_write(mbedtls_ssl_context* ssl, const uint8_t* buf, size_t size); +int common_read(mbedtls_ssl_context* ssl, uint8_t* buf, size_t size); +int common_close(mbedtls_ssl_context* ssl); + +int common_write(mbedtls_ssl_context* ssl, const uint8_t* buf, size_t size) { int ret; - if (!ctx || !ctx->ssl || size > INT_MAX) + if (!ssl || size > INT_MAX) return -EINVAL; - mbedtls_ssl_context* _ssl = (mbedtls_ssl_context*)ctx->ssl; - size_t written = 0; while (written < size) { - ret = mbedtls_ssl_write(_ssl, buf + written, size - written); + ret = mbedtls_ssl_write(ssl, buf + written, size - written); if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) continue; if (ret < 0) { @@ -39,17 +41,15 @@ int secret_provision_write(struct ra_tls_ctx* ctx, const uint8_t* buf, size_t si return (int)written; } -int secret_provision_read(struct ra_tls_ctx* ctx, uint8_t* buf, size_t size) { +int common_read(mbedtls_ssl_context* ssl, uint8_t* buf, size_t size) { int ret; - if (!ctx || !ctx->ssl || size > INT_MAX) + if (!ssl || size > INT_MAX) return -EINVAL; - mbedtls_ssl_context* _ssl = (mbedtls_ssl_context*)ctx->ssl; - size_t read = 0; while (read < size) { - ret = mbedtls_ssl_read(_ssl, buf + read, size - read); + ret = mbedtls_ssl_read(ssl, buf + read, size - read); if (!ret) return -ECONNRESET; if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) @@ -65,29 +65,20 @@ int secret_provision_read(struct ra_tls_ctx* ctx, uint8_t* buf, size_t size) { return (int)read; } -int secret_provision_close(struct ra_tls_ctx* ctx) { - int ret; - if (!ctx || !ctx->ssl) { - ret = 0; - goto out; - } - - mbedtls_ssl_context* _ssl = (mbedtls_ssl_context*)ctx->ssl; +int common_close(mbedtls_ssl_context* ssl) { + if (!ssl) + return 0; - ret = -1; + int ret = -1; while (ret < 0) { - ret = mbedtls_ssl_close_notify(_ssl); + ret = mbedtls_ssl_close_notify(ssl); if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { continue; } if (ret < 0) { /* use well-known error code for a typical case when remote party closes connection */ - ret = ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ? -ECONNRESET : -EPERM; - goto out; + return ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ? -ECONNRESET : -EPERM; } } - ret = 0; -out: - secret_provision_free_resources(); - return ret; + return 0; } diff --git a/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_verify.c b/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_verify.c index 64f49d90fd..a7c46b58e2 100644 --- a/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_verify.c +++ b/Pal/src/host/Linux-SGX/tools/ra-tls/secret_prov_verify.c @@ -43,8 +43,24 @@ struct thread_info { /* SSL/TLS + RA-TLS handshake is not thread-safe, use coarse-grained lock */ static pthread_mutex_t g_handshake_lock; -void secret_provision_free_resources(void) { - /* nothing to be freed */ +extern int common_write(mbedtls_ssl_context* ssl, const uint8_t* buf, size_t size); +extern int common_read(mbedtls_ssl_context* ssl, uint8_t* buf, size_t size); +extern int common_close(mbedtls_ssl_context* ssl); + +int secret_provision_write(struct ra_tls_ctx* ctx, const uint8_t* buf, size_t size) { + mbedtls_ssl_context* ssl = (mbedtls_ssl_context*)ctx->ssl; + return common_write(ssl, buf, size); +} + +int secret_provision_read(struct ra_tls_ctx* ctx, uint8_t* buf, size_t size) { + mbedtls_ssl_context* ssl = (mbedtls_ssl_context*)ctx->ssl; + return common_read(ssl, buf, size); +} + +int secret_provision_close(struct ra_tls_ctx* ctx) { + /* no need to free the ctx resources, this will be done in client_connection() */ + mbedtls_ssl_context* ssl = (mbedtls_ssl_context*)ctx->ssl; + return common_close(ssl); } static void* client_connection(void* data) { @@ -82,12 +98,11 @@ static void* client_connection(void* data) { goto out; } - struct ra_tls_ctx ctx = {.ssl = &ssl}; uint8_t buf[128] = {0}; static_assert(sizeof(buf) >= sizeof(SECRET_PROVISION_REQUEST), "buffer must be sufficiently large to hold SECRET_PROVISION_REQUEST"); - ret = secret_provision_read(&ctx, buf, sizeof(SECRET_PROVISION_REQUEST)); + ret = common_read(&ssl, buf, sizeof(SECRET_PROVISION_REQUEST)); if (ret < 0) { goto out; } @@ -110,13 +125,12 @@ static void* client_connection(void* data) { memcpy(buf, SECRET_PROVISION_RESPONSE, sizeof(SECRET_PROVISION_RESPONSE)); memcpy(buf + sizeof(SECRET_PROVISION_RESPONSE), &send_secret_size, sizeof(send_secret_size)); - ret = secret_provision_write(&ctx, buf, - sizeof(SECRET_PROVISION_RESPONSE) + sizeof(send_secret_size)); + ret = common_write(&ssl, buf, sizeof(SECRET_PROVISION_RESPONSE) + sizeof(send_secret_size)); if (ret < 0) { goto out; } - ret = secret_provision_write(&ctx, ti->secret, ti->secret_size); + ret = common_write(&ssl, ti->secret, ti->secret_size); if (ret < 0) { goto out; } @@ -124,9 +138,10 @@ static void* client_connection(void* data) { if (ti->f_cb) { /* pass ownership of SSL session with client to the caller; it is caller's responsibility * to gracefuly terminate the session using secret_provision_close() */ + struct ra_tls_ctx ctx = {.ssl = &ssl, .net = /*unused*/NULL, .conf = /*unused*/NULL}; ti->f_cb(&ctx); } else { - secret_provision_close(&ctx); + common_close(&ssl); } out: