Skip to content
This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit

Permalink
Add read/write buffer lock during send/recvmsg (#207)
Browse files Browse the repository at this point in the history
* Handle MSG_DONTWAIT in case of ktls_recmsg

* Lock write and read buffers during send and recvmsg

* Add mutex initialization to the client socket as well

* Rename lock inside camblet poll
  • Loading branch information
baluchicken authored Apr 9, 2024
1 parent 19377ca commit 5e9ae45
Showing 1 changed file with 37 additions and 14 deletions.
51 changes: 37 additions & 14 deletions socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ struct camblet_socket
i64 direction;
char *alpn;

struct mutex lock;
// BearSSL is not thread safe so we need to lock every interaction with it
struct mutex bearssl_lock;

struct mutex readbuffer_lock;
struct mutex writebuffer_lock;

buffer_t *read_buffer;
buffer_t *write_buffer;
Expand Down Expand Up @@ -158,16 +162,24 @@ static int ktls_recvmsg(camblet_socket *s, void *buf, size_t size, int flags)

iov_iter_kvec(&hdr.msg_iter, READ, &iov, 1, buf_len);

#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 19, 0)
int nonblock = 0;
if (flags & MSG_DONTWAIT)
{
nonblock = MSG_DONTWAIT;
}
#endif

return s->ktls_recvmsg(s->sock, &hdr, size,
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 19, 0)
0,
nonblock,
#endif
flags, &addr_len);
}

static int bearssl_sendmsg(camblet_socket *s, void *src, size_t len)
{
mutex_lock(&s->lock);
mutex_lock(&s->bearssl_lock);

int err = br_sslio_write_all(&s->ioc, src, len);
if (err < 0)
Expand All @@ -191,7 +203,7 @@ static int bearssl_sendmsg(camblet_socket *s, void *src, size_t len)
}
}

mutex_unlock(&s->lock);
mutex_unlock(&s->bearssl_lock);

return len;
}
Expand Down Expand Up @@ -231,9 +243,9 @@ br_sslio_read_with_flags(br_sslio_context *ctx, void *dst, size_t len, int flags

static int bearssl_recvmsg(camblet_socket *s, void *dst, size_t len, int flags)
{
mutex_lock(&s->lock);
mutex_lock(&s->bearssl_lock);
int ret = br_sslio_read_with_flags(&s->ioc, dst, len, flags);
mutex_unlock(&s->lock);
mutex_unlock(&s->bearssl_lock);
return ret;
}

Expand Down Expand Up @@ -448,7 +460,10 @@ static camblet_socket *camblet_new_server_socket(struct sock *sock, opa_socket_c

s->direction = INPUT;

mutex_init(&s->lock);
mutex_init(&s->bearssl_lock);

mutex_init(&s->readbuffer_lock);
mutex_init(&s->writebuffer_lock);

proxywasm *p = this_cpu_proxywasm();

Expand Down Expand Up @@ -507,7 +522,10 @@ static camblet_socket *camblet_new_client_socket(struct sock *sock, opa_socket_c

s->direction = OUTPUT;

mutex_init(&s->lock);
mutex_init(&s->bearssl_lock);

mutex_init(&s->readbuffer_lock);
mutex_init(&s->writebuffer_lock);

proxywasm *p = this_cpu_proxywasm();

Expand Down Expand Up @@ -580,7 +598,7 @@ static int ensure_tls_handshake(camblet_socket *s, struct msghdr *msg)

pr_debug("TLS handshake # command[%s] sock[%p]", current->comm, s->sock);

mutex_lock(&s->lock);
mutex_lock(&s->bearssl_lock);

// check if bearssl is already closed
unsigned int state = br_ssl_engine_current_state(&s->cc->eng);
Expand Down Expand Up @@ -676,7 +694,7 @@ static int ensure_tls_handshake(camblet_socket *s, struct msghdr *msg)
}

bail:
mutex_unlock(&s->lock);
mutex_unlock(&s->bearssl_lock);

return ret;
}
Expand Down Expand Up @@ -741,6 +759,7 @@ int camblet_recvmsg(struct sock *sock,
bool end_of_stream = false;
int action = Pause;

mutex_lock(&s->readbuffer_lock);
int prevbuflen = get_read_buffer_size(s);

while (action != Continue)
Expand Down Expand Up @@ -840,6 +859,7 @@ int camblet_recvmsg(struct sock *sock,
ret = len;

bail:
mutex_unlock(&s->readbuffer_lock);
return ret;
}

Expand All @@ -860,6 +880,8 @@ int camblet_sendmsg(struct sock *sock, struct msghdr *msg, size_t size)
goto bail;
}

mutex_lock(&s->writebuffer_lock);

size_t prevbuflen = get_write_buffer_size(s);

char *buf = get_write_buffer_for_write(s, size);
Expand Down Expand Up @@ -925,6 +947,7 @@ int camblet_sendmsg(struct sock *sock, struct msghdr *msg, size_t size)
ret = size;

bail:
mutex_unlock(&s->writebuffer_lock);
return ret;
}

Expand Down Expand Up @@ -952,7 +975,7 @@ void camblet_close(struct sock *sk, long timeout)

if (s)
{
mutex_lock(&s->lock);
mutex_lock(&s->bearssl_lock);
if (is_ktls(s))
{
close = s->ktls_close;
Expand All @@ -969,7 +992,7 @@ void camblet_close(struct sock *sk, long timeout)

camblet_socket_free(s);

mutex_unlock(&s->lock);
mutex_unlock(&s->bearssl_lock);
}

close(sk, timeout);
Expand Down Expand Up @@ -1979,12 +2002,12 @@ __poll_t camblet_poll(struct file *file, struct socket *sock,
{
size_t left;

mutex_lock(&s->lock);
mutex_lock(&s->bearssl_lock);
{
br_ssl_engine_context *ctx = get_ssl_engine_context(s);
br_ssl_engine_recvapp_buf(ctx, &left);
}
mutex_unlock(&s->lock);
mutex_unlock(&s->bearssl_lock);

if (left > 0)
{
Expand Down

0 comments on commit 5e9ae45

Please sign in to comment.