diff --git a/LibOS/shim/test/regression/.gitignore b/LibOS/shim/test/regression/.gitignore index c6f5bd96b8..6efc901295 100644 --- a/LibOS/shim/test/regression/.gitignore +++ b/LibOS/shim/test/regression/.gitignore @@ -65,6 +65,7 @@ /multi_pthread_exitless /openmp /pipe +/pipe_multithread /pipe_nonblocking /pipe_ocloexec /poll diff --git a/LibOS/shim/test/regression/Makefile b/LibOS/shim/test/regression/Makefile index d74204b67e..aea73792dd 100644 --- a/LibOS/shim/test/regression/Makefile +++ b/LibOS/shim/test/regression/Makefile @@ -53,6 +53,7 @@ c_executables = \ multi_pthread \ openmp \ pipe \ + pipe_multithread \ pipe_nonblocking \ pipe_ocloexec \ poll \ @@ -151,26 +152,27 @@ extra_rules = \ include ../../../../Scripts/Makefile.manifest include ../../../../Scripts/Makefile.Test -CFLAGS-bootstrap_static = -static +CFLAGS-abort_multithread = -pthread CFLAGS-bootstrap_pie = -fPIC -pie +CFLAGS-bootstrap_static = -static CFLAGS-debug = -g3 CFLAGS-debug_regs-x86_64 = -g3 +CFLAGS-eventfd = -pthread CFLAGS-exec_same = -pthread -CFLAGS-shared_object = -fPIC -pie -CFLAGS-syscall += -I$(PALDIR)/../include -I$(PALDIR)/host/$(PAL_HOST) -I$(PALDIR)/../include/arch/$(ARCH)/Linux -CFLAGS-openmp = -fopenmp -CFLAGS-multi_pthread = -pthread CFLAGS-exit_group = -pthread -CFLAGS-abort_multithread = -pthread -CFLAGS-eventfd = -pthread CFLAGS-futex_bitset = -pthread CFLAGS-futex_requeue = -pthread CFLAGS-futex_wake_op = -pthread +CFLAGS-multi_pthread = -pthread +CFLAGS-openmp = -fopenmp +CFLAGS-pipe_multithread = -pthread CFLAGS-proc_common = -pthread -CFLAGS-spinlock += -I$(PALDIR)/../include/lib -I$(PALDIR)/../include/arch/$(ARCH) -pthread +CFLAGS-pthread_set_get_affinity += -pthread +CFLAGS-shared_object = -fPIC -pie CFLAGS-sigaction_per_process += -pthread CFLAGS-signal_multithread += -pthread -CFLAGS-pthread_set_get_affinity += -pthread +CFLAGS-spinlock += -I$(PALDIR)/../include/lib -I$(PALDIR)/../include/arch/$(ARCH) -pthread +CFLAGS-syscall += -I$(PALDIR)/../include -I$(PALDIR)/host/$(PAL_HOST) -I$(PALDIR)/../include/arch/$(ARCH)/Linux CFLAGS-attestation += -I$(PALDIR)/../lib/crypto/mbedtls/crypto/include \ -I$(PALDIR)/host/Linux-SGX \ diff --git a/LibOS/shim/test/regression/pipe_multithread.c b/LibOS/shim/test/regression/pipe_multithread.c new file mode 100644 index 0000000000..3084227748 --- /dev/null +++ b/LibOS/shim/test/regression/pipe_multithread.c @@ -0,0 +1,81 @@ +/* test creates two threads simulteneously writing on the same pipe */ + +#include +#include +#include +#include +#include +#include +#include + +#define ITERATIONS 100000 + +int fds[2]; + +static void* thread_run(void* arg) { + char c = (char)(uintptr_t)arg; + for (int i = 0; i < ITERATIONS; i++) { + ssize_t bytes = 0; + while (bytes < sizeof(c)) { + bytes = send(fds[1], &c, sizeof(c), /*flags=*/0); + if (bytes < 0) { + if (errno == EAGAIN || errno == EINTR) + continue; + err(1, "send"); + } + } + } + return NULL; +} + +int main(int argc, char** argv) { + int ret; + pthread_t threads[2]; + char thread_ids[2] = {42, 24}; + int thread_bytes[2] = {0, 0}; + + ret = socketpair(AF_UNIX, SOCK_STREAM, 0, fds); + if (ret) { + err(1, "socketpair"); + } + + ret = pthread_create(&threads[0], NULL, &thread_run, (void*)(uintptr_t)thread_ids[0]); + if (ret) { + errno = ret; + err(1, "pthread_create"); + } + + ret = pthread_create(&threads[1], NULL, &thread_run, (void*)(uintptr_t)thread_ids[1]); + if (ret) { + errno = ret; + err(1, "pthread_create"); + } + + for (int i = 0; i < 2 * ITERATIONS; i++) { + char c = 0; + ssize_t bytes = 0; + while (bytes < sizeof(c)) { + bytes = recv(fds[0], &c, sizeof(c), /*flags=*/0); + if (bytes < 0) { + if (errno == EAGAIN || errno == EINTR) + continue; + err(1, "recv"); + } + } + + if (c == thread_ids[0]) + thread_bytes[0] += bytes; + else if (c == thread_ids[1]) + thread_bytes[1] += bytes; + else + errx(1, "received unrecognized thread ID"); + } + + printf("received total bytes from threads: %d and %d\n", thread_bytes[0], thread_bytes[1]); + + if (thread_bytes[0] != ITERATIONS || thread_bytes[1] != ITERATIONS) + errx(1, "received wrong number of bytes from threads"); + + puts("TEST OK"); + return 0; +} diff --git a/Pal/include/lib/pal_crypto.h b/Pal/include/lib/pal_crypto.h index 2bd9fa3958..7faf32b4e1 100644 --- a/Pal/include/lib/pal_crypto.h +++ b/Pal/include/lib/pal_crypto.h @@ -15,6 +15,8 @@ #include #include +#include "spinlock.h" + #define SHA256_DIGEST_LEN 32 #ifdef CRYPTO_USE_MBEDTLS @@ -51,6 +53,7 @@ typedef struct { ssize_t (*pal_recv_cb)(int fd, void* buf, size_t buf_size); ssize_t (*pal_send_cb)(int fd, const void* buf, size_t buf_size); int stream_fd; + spinlock_t lock; } LIB_SSL_CONTEXT; #endif /* CRYPTO_USE_MBEDTLS */ diff --git a/Pal/include/lib/spinlock.h b/Pal/include/lib/spinlock.h index 52e0e6156a..766c986523 100644 --- a/Pal/include/lib/spinlock.h +++ b/Pal/include/lib/spinlock.h @@ -7,7 +7,9 @@ #ifndef _SPINLOCK_H #define _SPINLOCK_H -#include "api.h" +#include + +#include "assert.h" #include "cpu.h" #ifdef DEBUG @@ -60,11 +62,11 @@ static inline void debug_spinlock_giveup_ownership(spinlock_t* lock) { } #else static inline void debug_spinlock_take_ownership(spinlock_t* lock) { - __UNUSED(lock); + (void)lock; } static inline void debug_spinlock_giveup_ownership(spinlock_t* lock) { - __UNUSED(lock); + (void)lock; } #endif // DEBUG_SPINLOCKS_SHIM @@ -158,7 +160,8 @@ static inline int spinlock_lock_timeout(spinlock_t* lock, unsigned long iteratio * returned. */ static inline int spinlock_cmpxchg(spinlock_t* lock, int* expected, int desired) { - static_assert(SAME_TYPE(&lock->lock, expected), "spinlock is not implemented as int*"); + static_assert(__builtin_types_compatible_p(__typeof__(&lock->lock), __typeof__(expected)), + "spinlock is not implemented as int*"); return __atomic_compare_exchange_n(&lock->lock, expected, desired, /*weak=*/false, __ATOMIC_ACQUIRE, __ATOMIC_RELAXED); } diff --git a/Pal/lib/crypto/adapters/mbedtls_adapter.c b/Pal/lib/crypto/adapters/mbedtls_adapter.c index 8cc291a101..e4e0bd4b5b 100644 --- a/Pal/lib/crypto/adapters/mbedtls_adapter.c +++ b/Pal/lib/crypto/adapters/mbedtls_adapter.c @@ -22,6 +22,7 @@ #include "pal_debug.h" #include "pal_error.h" #include "rng-arch.h" +#include "spinlock.h" int mbedtls_to_pal_error(int error) { switch (error) { @@ -380,6 +381,11 @@ static int recv_cb(void* ctx, uint8_t* buf, size_t buf_size) { /* pal_recv_cb cannot receive more than 32-bit limit, trim buf_size to fit in 32-bit */ buf_size = INT_MAX; } + + /* NOTE: If two threads recv on the same SSL context simultaneously, one of them may block on + * recv() and the other will spin and burn CPU cycles. We consider "shared SSL context" + * a rare case and use simple spinlocks instead of mutexes. */ + assert(spinlock_is_locked(&ssl_ctx->lock)); ssize_t ret = ssl_ctx->pal_recv_cb(fd, buf, buf_size); if (ret < 0) { @@ -403,7 +409,13 @@ static int send_cb(void* ctx, uint8_t const* buf, size_t buf_size) { /* pal_send_cb cannot send more than 32-bit limit, trim buf_size to fit in 32-bit */ buf_size = INT_MAX; } + + /* NOTE: If two threads send on the same SSL context simultaneously, one of them may block on + * send() and the other will spin and burn CPU cycles. We consider "shared SSL context" + * a rare case and use simple spinlocks instead of mutexes. */ + assert(spinlock_is_locked(&ssl_ctx->lock)); ssize_t ret = ssl_ctx->pal_send_cb(fd, buf, buf_size); + if (ret < 0) { if (ret == -EINTR || ret == -EAGAIN || ret == -EWOULDBLOCK) return MBEDTLS_ERR_SSL_WANT_WRITE; @@ -430,6 +442,7 @@ int lib_SSLInit(LIB_SSL_CONTEXT* ssl_ctx, int stream_fd, bool is_server, const u ssl_ctx->pal_recv_cb = pal_recv_cb; ssl_ctx->pal_send_cb = pal_send_cb; ssl_ctx->stream_fd = stream_fd; + spinlock_init(&ssl_ctx->lock); mbedtls_entropy_init(&ssl_ctx->entropy); mbedtls_ctr_drbg_init(&ssl_ctx->ctr_drbg); @@ -482,10 +495,14 @@ int lib_SSLFree(LIB_SSL_CONTEXT* ssl_ctx) { int lib_SSLHandshake(LIB_SSL_CONTEXT* ssl_ctx) { int ret; + + spinlock_lock(&ssl_ctx->lock); while ((ret = mbedtls_ssl_handshake(&ssl_ctx->ssl)) != 0) { if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) break; } + spinlock_unlock(&ssl_ctx->lock); + if (ret != 0) return mbedtls_to_pal_error(ret); @@ -493,7 +510,9 @@ int lib_SSLHandshake(LIB_SSL_CONTEXT* ssl_ctx) { } int lib_SSLRead(LIB_SSL_CONTEXT* ssl_ctx, uint8_t* buf, size_t buf_size) { + spinlock_lock(&ssl_ctx->lock); int ret = mbedtls_ssl_read(&ssl_ctx->ssl, buf, buf_size); + spinlock_unlock(&ssl_ctx->lock); if (ret == 0) return -PAL_ERROR_ENDOFSTREAM; if (ret < 0) @@ -502,14 +521,18 @@ int lib_SSLRead(LIB_SSL_CONTEXT* ssl_ctx, uint8_t* buf, size_t buf_size) { } int lib_SSLWrite(LIB_SSL_CONTEXT* ssl_ctx, const uint8_t* buf, size_t buf_size) { + spinlock_lock(&ssl_ctx->lock); int ret = mbedtls_ssl_write(&ssl_ctx->ssl, buf, buf_size); + spinlock_unlock(&ssl_ctx->lock); if (ret <= 0) return mbedtls_to_pal_error(ret); return ret; } int lib_SSLSave(LIB_SSL_CONTEXT* ssl_ctx, uint8_t* buf, size_t buf_size, size_t* out_size) { + spinlock_lock(&ssl_ctx->lock); int ret = mbedtls_ssl_context_save(&ssl_ctx->ssl, buf, buf_size, out_size); + spinlock_unlock(&ssl_ctx->lock); if (ret == MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL) { return -PAL_ERROR_NOMEM; } else if (ret < 0) { diff --git a/Pal/src/host/Linux-SGX/tools/common/Makefile b/Pal/src/host/Linux-SGX/tools/common/Makefile index 02944a762d..c95b69aa0a 100644 --- a/Pal/src/host/Linux-SGX/tools/common/Makefile +++ b/Pal/src/host/Linux-SGX/tools/common/Makefile @@ -5,6 +5,8 @@ include ../../../../../../Scripts/Makefile.rules CFLAGS := $(filter-out -DIN_PAL, $(CFLAGS)) CFLAGS += -I../.. \ -I../../../../../include/lib \ + -I../../../../../include/arch/$(ARCH) \ + -I../../../../../include/arch/$(ARCH)/Linux \ -I../../../../../lib/crypto/mbedtls/install/include \ -I../../../../../lib/crypto/mbedtls/crypto/include \ -I../../protected-files \ diff --git a/Pal/src/pal_internal.h b/Pal/src/pal_internal.h index 1b7e6cd9e6..d504678ae6 100644 --- a/Pal/src/pal_internal.h +++ b/Pal/src/pal_internal.h @@ -9,6 +9,7 @@ #define PAL_INTERNAL_H #include +#include #include "pal.h" #include "pal_defs.h"