From 5e9ae4526a5a669ab7caa98fb87e1e008e6d0d4c Mon Sep 17 00:00:00 2001 From: Balint Molnar Date: Tue, 9 Apr 2024 11:11:40 +0200 Subject: [PATCH] Add read/write buffer lock during send/recvmsg (#207) * Handle MSG_DONTWAIT in case of ktls_recmsg * Lock write and read buffers during send and recvmsg * Add mutex initialization to the client socket as well * Rename lock inside camblet poll --- socket.c | 51 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/socket.c b/socket.c index c1f823c8..196a9c06 100644 --- a/socket.c +++ b/socket.c @@ -101,7 +101,11 @@ struct camblet_socket i64 direction; char *alpn; - struct mutex lock; + // BearSSL is not thread safe so we need to lock every interaction with it + struct mutex bearssl_lock; + + struct mutex readbuffer_lock; + struct mutex writebuffer_lock; buffer_t *read_buffer; buffer_t *write_buffer; @@ -158,16 +162,24 @@ static int ktls_recvmsg(camblet_socket *s, void *buf, size_t size, int flags) iov_iter_kvec(&hdr.msg_iter, READ, &iov, 1, buf_len); +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 19, 0) + int nonblock = 0; + if (flags & MSG_DONTWAIT) + { + nonblock = MSG_DONTWAIT; + } +#endif + return s->ktls_recvmsg(s->sock, &hdr, size, #if LINUX_VERSION_CODE < KERNEL_VERSION(5, 19, 0) - 0, + nonblock, #endif flags, &addr_len); } static int bearssl_sendmsg(camblet_socket *s, void *src, size_t len) { - mutex_lock(&s->lock); + mutex_lock(&s->bearssl_lock); int err = br_sslio_write_all(&s->ioc, src, len); if (err < 0) @@ -191,7 +203,7 @@ static int bearssl_sendmsg(camblet_socket *s, void *src, size_t len) } } - mutex_unlock(&s->lock); + mutex_unlock(&s->bearssl_lock); return len; } @@ -231,9 +243,9 @@ br_sslio_read_with_flags(br_sslio_context *ctx, void *dst, size_t len, int flags static int bearssl_recvmsg(camblet_socket *s, void *dst, size_t len, int flags) { - mutex_lock(&s->lock); + mutex_lock(&s->bearssl_lock); int ret = br_sslio_read_with_flags(&s->ioc, dst, len, flags); - mutex_unlock(&s->lock); + mutex_unlock(&s->bearssl_lock); return ret; } @@ -448,7 +460,10 @@ static camblet_socket *camblet_new_server_socket(struct sock *sock, opa_socket_c s->direction = INPUT; - mutex_init(&s->lock); + mutex_init(&s->bearssl_lock); + + mutex_init(&s->readbuffer_lock); + mutex_init(&s->writebuffer_lock); proxywasm *p = this_cpu_proxywasm(); @@ -507,7 +522,10 @@ static camblet_socket *camblet_new_client_socket(struct sock *sock, opa_socket_c s->direction = OUTPUT; - mutex_init(&s->lock); + mutex_init(&s->bearssl_lock); + + mutex_init(&s->readbuffer_lock); + mutex_init(&s->writebuffer_lock); proxywasm *p = this_cpu_proxywasm(); @@ -580,7 +598,7 @@ static int ensure_tls_handshake(camblet_socket *s, struct msghdr *msg) pr_debug("TLS handshake # command[%s] sock[%p]", current->comm, s->sock); - mutex_lock(&s->lock); + mutex_lock(&s->bearssl_lock); // check if bearssl is already closed unsigned int state = br_ssl_engine_current_state(&s->cc->eng); @@ -676,7 +694,7 @@ static int ensure_tls_handshake(camblet_socket *s, struct msghdr *msg) } bail: - mutex_unlock(&s->lock); + mutex_unlock(&s->bearssl_lock); return ret; } @@ -741,6 +759,7 @@ int camblet_recvmsg(struct sock *sock, bool end_of_stream = false; int action = Pause; + mutex_lock(&s->readbuffer_lock); int prevbuflen = get_read_buffer_size(s); while (action != Continue) @@ -840,6 +859,7 @@ int camblet_recvmsg(struct sock *sock, ret = len; bail: + mutex_unlock(&s->readbuffer_lock); return ret; } @@ -860,6 +880,8 @@ int camblet_sendmsg(struct sock *sock, struct msghdr *msg, size_t size) goto bail; } + mutex_lock(&s->writebuffer_lock); + size_t prevbuflen = get_write_buffer_size(s); char *buf = get_write_buffer_for_write(s, size); @@ -925,6 +947,7 @@ int camblet_sendmsg(struct sock *sock, struct msghdr *msg, size_t size) ret = size; bail: + mutex_unlock(&s->writebuffer_lock); return ret; } @@ -952,7 +975,7 @@ void camblet_close(struct sock *sk, long timeout) if (s) { - mutex_lock(&s->lock); + mutex_lock(&s->bearssl_lock); if (is_ktls(s)) { close = s->ktls_close; @@ -969,7 +992,7 @@ void camblet_close(struct sock *sk, long timeout) camblet_socket_free(s); - mutex_unlock(&s->lock); + mutex_unlock(&s->bearssl_lock); } close(sk, timeout); @@ -1979,12 +2002,12 @@ __poll_t camblet_poll(struct file *file, struct socket *sock, { size_t left; - mutex_lock(&s->lock); + mutex_lock(&s->bearssl_lock); { br_ssl_engine_context *ctx = get_ssl_engine_context(s); br_ssl_engine_recvapp_buf(ctx, &left); } - mutex_unlock(&s->lock); + mutex_unlock(&s->bearssl_lock); if (left > 0) {