diff --git a/src/ssl.jl b/src/ssl.jl index 0f0b78a..2d87448 100644 --- a/src/ssl.jl +++ b/src/ssl.jl @@ -140,6 +140,7 @@ end True unless: - `close(::SSLContext)` is called, or + - `closewrite(::SSLContext)` is called, or - the peer closed the connection. """ Base.iswritable(ctx::SSLContext) = !ctx.close_notify_sent && isopen(ctx.bio) @@ -187,7 +188,12 @@ function Base.close(ctx::SSLContext) if iswritable(ctx) closewrite(ctx) end - @assert !iswritable(ctx) + @static if Sys.iswindows() && VERSION < v"1.9.0" + # work-around for a libuv regression where we check the wrong flags during closing + # introduced by https://github.com/libuv/libuv/pull/3036 in v1.42.0 + # fixed by https://github.com/libuv/libuv/pull/3584 in v1.44.2 + ctx.bio isa TCPSocket && isreadable(ctx.bio) && Base.start_reading(ctx.bio) + end close(ctx.bio) nothing end @@ -344,9 +350,16 @@ function ssl_unsafe_read(ctx::SSLContext, buf::Ptr{UInt8}, nbytes::UInt) n == MBEDTLS_ERR_NET_CONN_RESET ? "(CONN_RESET)" : n == MBEDTLS_ERR_SSL_WANT_READ ? "(WANT_READ)" : "")" if n == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || - n == MBEDTLS_ERR_SSL_CONN_EOF - ctx.isreadable = false - if ctx.close_notify_sent + n == MBEDTLS_ERR_SSL_CONN_EOF || + n == 0 + if n == nbytes == 0 + # caller just wanted us to update bytesavilable + ctx.bytesavailable = ssl_get_bytes_avail(ctx) ;@🤖 "ssl_read ⬅️ $nread, 📥 $(ctx.bytesavailable)" + else + ctx.bytesavailable = 0 + end + ctx.isreadable = ctx.bytesavailable > 0 + if !ctx.isreadable && ctx.close_notify_sent # already called closewrite, so we can go ahead and destroy this fully immediately close(ctx.bio) end @@ -387,15 +400,24 @@ end Copy at most `nbytes` of encrypted data to `buf` from the `bio` connection. If no encrypted bytes are available return: - `MBEDTLS_ERR_SSL_WANT_READ` if the connection is still open, or - - `MBEDTLS_ERR_NET_RECV_FAILED` if it is closed. + - `MBEDTLS_ERR_SSL_CONN_EOF` if it is closed. + - `MBEDTLS_ERR_NET_RECV_FAILED` if it is errored. """ function f_recv(c_bio, buf, nbytes) # (Ptr{Cvoid}, Ptr{UInt8}, Csize_t) @assert nbytes > 0 bio = unsafe_pointer_to_objref(c_bio) n = bytesavailable(bio) - if n == 0 ;@🤖 "f_recv $(isopen(bio) ? "WANT_READ" : "RECV_FAILED")" - return isreadable(bio) ? Cint(MBEDTLS_ERR_SSL_WANT_READ) : - Cint(MBEDTLS_ERR_NET_RECV_FAILED) + if n == 0 + # TODO: we should be able to forward this value directly from wait_for_encrypted_data + isreadable(bio) && ( @🤖 "f_recv WANT_READ"; + return Cint(MBEDTLS_ERR_SSL_WANT_READ)) + try + eof(bio) && ( @🤖 "f_recv CONN_EOF"; + return Cint(MBEDTLS_ERR_SSL_CONN_EOF)) + catch ex ;@🤖 "f_recv RECV_FAILED" + ex isa IOError && return Cint(MBEDTLS_ERR_NET_RECV_FAILED) + rethrow() + end end n = min(nbytes, n) ;@🤖 "f_recv ⬅️ $n" unsafe_read(bio, buf, n) diff --git a/test/clntsrvr/clntsrvr.jl b/test/clntsrvr/clntsrvr.jl index 63a2921..bed7406 100644 --- a/test/clntsrvr/clntsrvr.jl +++ b/test/clntsrvr/clntsrvr.jl @@ -26,27 +26,47 @@ function testclntsrvr(certfile, keyfile) outbuff = ones(UInt8, 100) * UInt8(65) trigger = Channel{Bool}(1) port = UInt16(0) - - @async begin - (port, server) = listenany(8000) - @info("listening on port $port") - put!(trigger, true) + local clntconn, srvrconn + + # setup a watchdog kill-switch + t = Timer(10) do t + @isdefined(clntconn) && close(clntconn) + @isdefined(srvrconn) && close(srvrconn) + close(trigger) + @test "test failed to complete within timeout" + end + + (port, server) = listenany(8000) + @info("listening on port $port") + + r = @async begin srvrconn = sslaccept(server, certfile, keyfile) + close(server) inbuff = read(srvrconn, 100) @test inbuff == outbuff put!(trigger, true) + inbuff2 = read(srvrconn, 1000) + @test inbuff2 == outbuff + put!(trigger, true) + close(srvrconn) end + bind(trigger, r) - take!(trigger) @info("connecting to port $port") clntconn = sslconnect("127.0.0.1", port) @test write(clntconn, outbuff) == 100 - @async begin - sleep(10) - put!(trigger, false) - end - @test take!(trigger) + outbuff .*= 2 + @test write(clntconn, outbuff) == 100 + close(clntconn) + @test take!(trigger) + wait(r) + + close(t) end -testclntsrvr(joinpath(dirname(@__FILE__), "test.cert"), joinpath(dirname(@__FILE__), "test.key")) +@testset "testclntsrvr" begin + testclntsrvr( + joinpath(@__DIR__, "test.cert"), + joinpath(@__DIR__, "test.key")) +end