Skip to content
This repository has been archived by the owner on Jan 20, 2022. It is now read-only.

[Pal/lib] Add spinlocks to mbedTLS-specific SSL recv/send #2059

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LibOS/shim/test/regression/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
/multi_pthread_exitless
/openmp
/pipe
/pipe_multithread
/pipe_nonblocking
/pipe_ocloexec
/poll
Expand Down
20 changes: 11 additions & 9 deletions LibOS/shim/test/regression/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ c_executables = \
multi_pthread \
openmp \
pipe \
pipe_multithread \
pipe_nonblocking \
pipe_ocloexec \
poll \
Expand Down Expand Up @@ -151,26 +152,27 @@ extra_rules = \
include ../../../../Scripts/Makefile.manifest
include ../../../../Scripts/Makefile.Test

CFLAGS-bootstrap_static = -static
CFLAGS-abort_multithread = -pthread
CFLAGS-bootstrap_pie = -fPIC -pie
CFLAGS-bootstrap_static = -static
CFLAGS-debug = -g3
CFLAGS-debug_regs-x86_64 = -g3
CFLAGS-eventfd = -pthread
CFLAGS-exec_same = -pthread
CFLAGS-shared_object = -fPIC -pie
CFLAGS-syscall += -I$(PALDIR)/../include -I$(PALDIR)/host/$(PAL_HOST) -I$(PALDIR)/../include/arch/$(ARCH)/Linux
CFLAGS-openmp = -fopenmp
CFLAGS-multi_pthread = -pthread
CFLAGS-exit_group = -pthread
CFLAGS-abort_multithread = -pthread
CFLAGS-eventfd = -pthread
CFLAGS-futex_bitset = -pthread
CFLAGS-futex_requeue = -pthread
CFLAGS-futex_wake_op = -pthread
CFLAGS-multi_pthread = -pthread
CFLAGS-openmp = -fopenmp
CFLAGS-pipe_multithread = -pthread
CFLAGS-proc_common = -pthread
CFLAGS-spinlock += -I$(PALDIR)/../include/lib -I$(PALDIR)/../include/arch/$(ARCH) -pthread
CFLAGS-pthread_set_get_affinity += -pthread
CFLAGS-shared_object = -fPIC -pie
CFLAGS-sigaction_per_process += -pthread
CFLAGS-signal_multithread += -pthread
CFLAGS-pthread_set_get_affinity += -pthread
CFLAGS-spinlock += -I$(PALDIR)/../include/lib -I$(PALDIR)/../include/arch/$(ARCH) -pthread
CFLAGS-syscall += -I$(PALDIR)/../include -I$(PALDIR)/host/$(PAL_HOST) -I$(PALDIR)/../include/arch/$(ARCH)/Linux

CFLAGS-attestation += -I$(PALDIR)/../lib/crypto/mbedtls/crypto/include \
-I$(PALDIR)/host/Linux-SGX \
Expand Down
81 changes: 81 additions & 0 deletions LibOS/shim/test/regression/pipe_multithread.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* test creates two threads simulteneously writing on the same pipe */

#include <err.h>
#include <errno.h>
#include <pthread.h>
#include <stdint.h>
#include <stdio.h>
#include <sys/socket.h>
#include <sys/types.h>

#define ITERATIONS 100000

int fds[2];

static void* thread_run(void* arg) {
char c = (char)(uintptr_t)arg;
for (int i = 0; i < ITERATIONS; i++) {
ssize_t bytes = 0;
while (bytes < sizeof(c)) {
bytes = send(fds[1], &c, sizeof(c), /*flags=*/0);
if (bytes < 0) {
if (errno == EAGAIN || errno == EINTR)
continue;
err(1, "send");
}
}
}
return NULL;
}

int main(int argc, char** argv) {
int ret;
pthread_t threads[2];
char thread_ids[2] = {42, 24};
int thread_bytes[2] = {0, 0};

ret = socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
if (ret) {
err(1, "socketpair");
}

ret = pthread_create(&threads[0], NULL, &thread_run, (void*)(uintptr_t)thread_ids[0]);
if (ret) {
errno = ret;
err(1, "pthread_create");
}

ret = pthread_create(&threads[1], NULL, &thread_run, (void*)(uintptr_t)thread_ids[1]);
if (ret) {
errno = ret;
err(1, "pthread_create");
}

for (int i = 0; i < 2 * ITERATIONS; i++) {
char c = 0;
ssize_t bytes = 0;
while (bytes < sizeof(c)) {
bytes = recv(fds[0], &c, sizeof(c), /*flags=*/0);
if (bytes < 0) {
if (errno == EAGAIN || errno == EINTR)
continue;
err(1, "recv");
}
}

if (c == thread_ids[0])
thread_bytes[0] += bytes;
else if (c == thread_ids[1])
thread_bytes[1] += bytes;
else
errx(1, "received unrecognized thread ID");
}

printf("received total bytes from threads: %d and %d\n", thread_bytes[0], thread_bytes[1]);

if (thread_bytes[0] != ITERATIONS || thread_bytes[1] != ITERATIONS)
errx(1, "received wrong number of bytes from threads");

puts("TEST OK");
return 0;
}
3 changes: 3 additions & 0 deletions Pal/include/lib/pal_crypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include <stdint.h>
#include <unistd.h>

#include "spinlock.h"

#define SHA256_DIGEST_LEN 32

#ifdef CRYPTO_USE_MBEDTLS
Expand Down Expand Up @@ -51,6 +53,7 @@ typedef struct {
ssize_t (*pal_recv_cb)(int fd, void* buf, size_t buf_size);
ssize_t (*pal_send_cb)(int fd, const void* buf, size_t buf_size);
int stream_fd;
spinlock_t lock;
} LIB_SSL_CONTEXT;

#endif /* CRYPTO_USE_MBEDTLS */
Expand Down
11 changes: 7 additions & 4 deletions Pal/include/lib/spinlock.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#ifndef _SPINLOCK_H
#define _SPINLOCK_H

#include "api.h"
#include <stdbool.h>

#include "assert.h"
#include "cpu.h"

#ifdef DEBUG
Expand Down Expand Up @@ -60,11 +62,11 @@ static inline void debug_spinlock_giveup_ownership(spinlock_t* lock) {
}
#else
static inline void debug_spinlock_take_ownership(spinlock_t* lock) {
__UNUSED(lock);
(void)lock;
}

static inline void debug_spinlock_giveup_ownership(spinlock_t* lock) {
__UNUSED(lock);
(void)lock;
}
#endif // DEBUG_SPINLOCKS_SHIM

Expand Down Expand Up @@ -158,7 +160,8 @@ static inline int spinlock_lock_timeout(spinlock_t* lock, unsigned long iteratio
* returned.
*/
static inline int spinlock_cmpxchg(spinlock_t* lock, int* expected, int desired) {
static_assert(SAME_TYPE(&lock->lock, expected), "spinlock is not implemented as int*");
static_assert(__builtin_types_compatible_p(__typeof__(&lock->lock), __typeof__(expected)),
"spinlock is not implemented as int*");
return __atomic_compare_exchange_n(&lock->lock, expected, desired, /*weak=*/false,
__ATOMIC_ACQUIRE, __ATOMIC_RELAXED);
}
Expand Down
23 changes: 23 additions & 0 deletions Pal/lib/crypto/adapters/mbedtls_adapter.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "pal_debug.h"
#include "pal_error.h"
#include "rng-arch.h"
#include "spinlock.h"

int mbedtls_to_pal_error(int error) {
switch (error) {
Expand Down Expand Up @@ -380,6 +381,11 @@ static int recv_cb(void* ctx, uint8_t* buf, size_t buf_size) {
/* pal_recv_cb cannot receive more than 32-bit limit, trim buf_size to fit in 32-bit */
buf_size = INT_MAX;
}

/* NOTE: If two threads recv on the same SSL context simultaneously, one of them may block on
* recv() and the other will spin and burn CPU cycles. We consider "shared SSL context"
* a rare case and use simple spinlocks instead of mutexes. */
assert(spinlock_is_locked(&ssl_ctx->lock));
ssize_t ret = ssl_ctx->pal_recv_cb(fd, buf, buf_size);

if (ret < 0) {
Expand All @@ -403,7 +409,13 @@ static int send_cb(void* ctx, uint8_t const* buf, size_t buf_size) {
/* pal_send_cb cannot send more than 32-bit limit, trim buf_size to fit in 32-bit */
buf_size = INT_MAX;
}

/* NOTE: If two threads send on the same SSL context simultaneously, one of them may block on
* send() and the other will spin and burn CPU cycles. We consider "shared SSL context"
* a rare case and use simple spinlocks instead of mutexes. */
assert(spinlock_is_locked(&ssl_ctx->lock));
ssize_t ret = ssl_ctx->pal_send_cb(fd, buf, buf_size);

if (ret < 0) {
if (ret == -EINTR || ret == -EAGAIN || ret == -EWOULDBLOCK)
return MBEDTLS_ERR_SSL_WANT_WRITE;
Expand All @@ -430,6 +442,7 @@ int lib_SSLInit(LIB_SSL_CONTEXT* ssl_ctx, int stream_fd, bool is_server, const u
ssl_ctx->pal_recv_cb = pal_recv_cb;
ssl_ctx->pal_send_cb = pal_send_cb;
ssl_ctx->stream_fd = stream_fd;
spinlock_init(&ssl_ctx->lock);

mbedtls_entropy_init(&ssl_ctx->entropy);
mbedtls_ctr_drbg_init(&ssl_ctx->ctr_drbg);
Expand Down Expand Up @@ -482,18 +495,24 @@ int lib_SSLFree(LIB_SSL_CONTEXT* ssl_ctx) {

int lib_SSLHandshake(LIB_SSL_CONTEXT* ssl_ctx) {
int ret;

spinlock_lock(&ssl_ctx->lock);
while ((ret = mbedtls_ssl_handshake(&ssl_ctx->ssl)) != 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
break;
}
spinlock_unlock(&ssl_ctx->lock);

if (ret != 0)
return mbedtls_to_pal_error(ret);

return 0;
}

int lib_SSLRead(LIB_SSL_CONTEXT* ssl_ctx, uint8_t* buf, size_t buf_size) {
spinlock_lock(&ssl_ctx->lock);
int ret = mbedtls_ssl_read(&ssl_ctx->ssl, buf, buf_size);
spinlock_unlock(&ssl_ctx->lock);
if (ret == 0)
return -PAL_ERROR_ENDOFSTREAM;
if (ret < 0)
Expand All @@ -502,14 +521,18 @@ int lib_SSLRead(LIB_SSL_CONTEXT* ssl_ctx, uint8_t* buf, size_t buf_size) {
}

int lib_SSLWrite(LIB_SSL_CONTEXT* ssl_ctx, const uint8_t* buf, size_t buf_size) {
spinlock_lock(&ssl_ctx->lock);
int ret = mbedtls_ssl_write(&ssl_ctx->ssl, buf, buf_size);
spinlock_unlock(&ssl_ctx->lock);
if (ret <= 0)
return mbedtls_to_pal_error(ret);
return ret;
}

int lib_SSLSave(LIB_SSL_CONTEXT* ssl_ctx, uint8_t* buf, size_t buf_size, size_t* out_size) {
spinlock_lock(&ssl_ctx->lock);
int ret = mbedtls_ssl_context_save(&ssl_ctx->ssl, buf, buf_size, out_size);
spinlock_unlock(&ssl_ctx->lock);
if (ret == MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL) {
return -PAL_ERROR_NOMEM;
} else if (ret < 0) {
Expand Down
2 changes: 2 additions & 0 deletions Pal/src/host/Linux-SGX/tools/common/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ include ../../../../../../Scripts/Makefile.rules
CFLAGS := $(filter-out -DIN_PAL, $(CFLAGS))
CFLAGS += -I../.. \
-I../../../../../include/lib \
-I../../../../../include/arch/$(ARCH) \
-I../../../../../include/arch/$(ARCH)/Linux \
-I../../../../../lib/crypto/mbedtls/install/include \
-I../../../../../lib/crypto/mbedtls/crypto/include \
-I../../protected-files \
Expand Down
1 change: 1 addition & 0 deletions Pal/src/pal_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define PAL_INTERNAL_H

#include <stdarg.h>
#include <sys/types.h>

#include "pal.h"
#include "pal_defs.h"
Expand Down