Skip to content

Commit

Permalink
Merge pull request #821 from rhenium/ky/ssl-read-write-check-cb-state
Browse files Browse the repository at this point in the history
ssl: handle callback exceptions in SSLSocket#sysread and #syswrite
  • Loading branch information
rhenium authored Dec 7, 2024
2 parents 81409ee + aac9ce1 commit f8937a6
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 8 deletions.
20 changes: 18 additions & 2 deletions ext/openssl/ossl_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -1916,7 +1916,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
{
SSL *ssl;
int ilen;
VALUE len, str;
VALUE len, str, cb_state;
VALUE opts = Qnil;

if (nonblock) {
Expand Down Expand Up @@ -1949,6 +1949,14 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
rb_str_locktmp(str);
for (;;) {
int nread = SSL_read(ssl, RSTRING_PTR(str), ilen);

cb_state = rb_attr_get(self, ID_callback_state);
if (!NIL_P(cb_state)) {
rb_ivar_set(self, ID_callback_state, Qnil);
ossl_clear_error();
rb_jump_tag(NUM2INT(cb_state));
}

switch (ssl_get_error(ssl, nread)) {
case SSL_ERROR_NONE:
rb_str_unlocktmp(str);
Expand Down Expand Up @@ -2038,7 +2046,7 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)
SSL *ssl;
rb_io_t *fptr;
int num, nonblock = opts != Qfalse;
VALUE tmp;
VALUE tmp, cb_state;

GetSSL(self, ssl);
if (!ssl_started(ssl))
Expand All @@ -2055,6 +2063,14 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)

for (;;) {
int nwritten = SSL_write(ssl, RSTRING_PTR(tmp), num);

cb_state = rb_attr_get(self, ID_callback_state);
if (!NIL_P(cb_state)) {
rb_ivar_set(self, ID_callback_state, Qnil);
ossl_clear_error();
rb_jump_tag(NUM2INT(cb_state));
}

switch (ssl_get_error(ssl, nwritten)) {
case SSL_ERROR_NONE:
return INT2NUM(nwritten);
Expand Down
55 changes: 49 additions & 6 deletions test/openssl/test_ssl_session.rb
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ def test_server_session_cache
# deadlock.
TEST_SESSION_REMOVE_CB = ENV["OSSL_TEST_ALL"] == "1"

def test_ctx_client_session_cb
ctx_proc = proc { |ctx| ctx.ssl_version = :TLSv1_2 }
start_server(ctx_proc: ctx_proc) do |port|
def test_ctx_client_session_cb_tls12
start_server do |port|
called = {}
ctx = OpenSSL::SSL::SSLContext.new
ctx.min_version = ctx.max_version = :TLS1_2
ctx.session_cache_mode = OpenSSL::SSL::SSLContext::SESSION_CACHE_CLIENT
ctx.session_new_cb = lambda { |ary|
sock, sess = ary
Expand All @@ -233,23 +233,66 @@ def test_ctx_client_session_cb
ctx.session_remove_cb = lambda { |ary|
ctx, sess = ary
called[:remove] = [ctx, sess]
# any resulting value is OK (ignored)
}
end

server_connect_with_session(port, ctx, nil) { |ssl|
assert_equal(1, ctx.session_cache_stats[:cache_num])
assert_equal(1, ctx.session_cache_stats[:connect_good])
assert_equal([ssl, ssl.session], called[:new])
assert(ctx.session_remove(ssl.session))
assert(!ctx.session_remove(ssl.session))
assert_equal(true, ctx.session_remove(ssl.session))
assert_equal(false, ctx.session_remove(ssl.session))
if TEST_SESSION_REMOVE_CB
assert_equal([ctx, ssl.session], called[:remove])
end
}
end
end

def test_ctx_client_session_cb_tls13
omit "TLS 1.3 not supported" unless tls13_supported?
omit "LibreSSL does not call session_new_cb in TLS 1.3" if libressl?

start_server do |port|
called = {}
ctx = OpenSSL::SSL::SSLContext.new
ctx.min_version = :TLS1_3
ctx.session_cache_mode = OpenSSL::SSL::SSLContext::SESSION_CACHE_CLIENT
ctx.session_new_cb = lambda { |ary|
sock, sess = ary
called[:new] = [sock, sess]
}

server_connect_with_session(port, ctx, nil) { |ssl|
ssl.puts("abc"); assert_equal("abc\n", ssl.gets)

assert_operator(1, :<=, ctx.session_cache_stats[:cache_num])
assert_operator(1, :<=, ctx.session_cache_stats[:connect_good])
assert_equal([ssl, ssl.session], called[:new])
}
end
end

def test_ctx_client_session_cb_tls13_exception
omit "TLS 1.3 not supported" unless tls13_supported?
omit "LibreSSL does not call session_new_cb in TLS 1.3" if libressl?

start_server do |port|
ctx = OpenSSL::SSL::SSLContext.new
ctx.min_version = :TLS1_3
ctx.session_cache_mode = OpenSSL::SSL::SSLContext::SESSION_CACHE_CLIENT
ctx.session_new_cb = lambda { |ary|
raise "in session_new_cb"
}

server_connect_with_session(port, ctx, nil) { |ssl|
assert_raise_with_message(RuntimeError, /in session_new_cb/) {
ssl.puts("abc"); assert_equal("abc\n", ssl.gets)
}
}
end
end

def test_ctx_server_session_cb
connections = nil
called = {}
Expand Down

0 comments on commit f8937a6

Please sign in to comment.