Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spec/std/openssl/ssl/context_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ describe OpenSSL::SSL::Context do
(context.options & OpenSSL::SSL::Options::ALL).should eq(OpenSSL::SSL::Options::ALL)
(context.options & OpenSSL::SSL::Options::NO_SESSION_RESUMPTION_ON_RENEGOTIATION).should eq(OpenSSL::SSL::Options::NO_SESSION_RESUMPTION_ON_RENEGOTIATION)

context.modes.should eq(OpenSSL::SSL::Modes.flags(AUTO_RETRY, RELEASE_BUFFERS))
context.modes.should eq(OpenSSL::SSL::Modes.flags(AUTO_RETRY, RELEASE_BUFFERS, ENABLE_PARTIAL_WRITE))
context.verify_mode.should eq(OpenSSL::SSL::VerifyMode::PEER)

OpenSSL::SSL::Context::Client.new(LibSSL.tlsv1_method)
Expand All @@ -26,7 +26,7 @@ describe OpenSSL::SSL::Context do
(context.options & OpenSSL::SSL::Options::NO_SESSION_RESUMPTION_ON_RENEGOTIATION).should eq(OpenSSL::SSL::Options::NO_SESSION_RESUMPTION_ON_RENEGOTIATION)
(context.options & OpenSSL::SSL::Options::NO_RENEGOTIATION).should eq(OpenSSL::SSL::Options::NO_RENEGOTIATION)

context.modes.should eq(OpenSSL::SSL::Modes.flags(AUTO_RETRY, RELEASE_BUFFERS))
context.modes.should eq(OpenSSL::SSL::Modes.flags(AUTO_RETRY, RELEASE_BUFFERS, ENABLE_PARTIAL_WRITE))
context.verify_mode.should eq(OpenSSL::SSL::VerifyMode::NONE)

OpenSSL::SSL::Context::Server.new(LibSSL.tlsv1_method)
Expand Down
2 changes: 2 additions & 0 deletions src/openssl/lib_crypto.cr
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ lib LibCrypto

type BioMethod = Void

fun BIO_ctrl(bio : Bio*, cmd : Int, larg : Long, parg : Void*) : Long

fun BIO_new(BioMethod*) : Bio*
fun BIO_free(Bio*) : Int

Expand Down
3 changes: 3 additions & 0 deletions src/openssl/lib_ssl.cr
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ lib LibSSL
fun ssl_get_error = SSL_get_error(handle : SSL, ret : Int) : SSLError
fun ssl_get_servername = SSL_get_servername(ssl : SSL, host_type : TLSExt) : UInt8*
fun ssl_set_bio = SSL_set_bio(handle : SSL, rbio : LibCrypto::Bio*, wbio : LibCrypto::Bio*)
fun ssl_set_fd = SSL_set_fd(handle : SSL, fd : Int) : Int
fun ssl_get_rbio = SSL_get_rbio(handle : SSL) : LibCrypto::Bio*
fun ssl_get_wbio = SSL_get_wbio(handle : SSL) : LibCrypto::Bio*
fun ssl_select_next_proto = SSL_select_next_proto(output : Char**, output_len : Char*, input : Char*, input_len : Int, client : Char*, client_len : Int) : Int
fun ssl_ctrl = SSL_ctrl(handle : SSL, cmd : Int, larg : Long, parg : Void*) : Long
fun ssl_free = SSL_free(handle : SSL)
Expand Down
2 changes: 1 addition & 1 deletion src/openssl/ssl/context.cr
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ abstract class OpenSSL::SSL::Context
NO_SESSION_RESUMPTION_ON_RENEGOTIATION,
NO_RENEGOTIATION,
))
add_modes(OpenSSL::SSL::Modes.flags(AUTO_RETRY, RELEASE_BUFFERS))
add_modes(OpenSSL::SSL::Modes.flags(AUTO_RETRY, RELEASE_BUFFERS, ENABLE_PARTIAL_WRITE))

# OpenSSL does not support reading from the system root certificate store on
# Windows, so we have to import them ourselves
Expand Down
6 changes: 3 additions & 3 deletions src/openssl/ssl/server.cr
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,19 @@ class OpenSSL::SSL::Server
#
# This method calls `@wrapped.accept` and wraps the resulting IO in a SSL socket (`OpenSSL::SSL::Socket::Server`) with `context` configuration.
def accept : OpenSSL::SSL::Socket::Server
new_ssl_socket(@wrapped.accept)
new_ssl_socket(@wrapped.accept.as(::Socket))
end

# Implements `::Socket::Server#accept?`.
#
# This method calls `@wrapped.accept?` and wraps the resulting IO in a SSL socket (`OpenSSL::SSL::Socket::Server`) with `context` configuration.
def accept? : OpenSSL::SSL::Socket::Server?
if socket = @wrapped.accept?
new_ssl_socket(socket)
new_ssl_socket(socket.as(::Socket))
end
end

private def new_ssl_socket(io)
private def new_ssl_socket(io : ::Socket)
OpenSSL::SSL::Socket::Server.new(io, @context, sync_close: @sync_close, accept: @start_immediately)
end

Expand Down
162 changes: 98 additions & 64 deletions src/openssl/ssl/socket.cr
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
abstract class OpenSSL::SSL::Socket < IO
class Client < Socket
def initialize(io, context : Context::Client = Context::Client.new, sync_close : Bool = false, hostname : String? = nil)
def initialize(io : ::Socket, context : Context::Client = Context::Client.new, sync_close : Bool = false, hostname : String? = nil)
super(io, context, sync_close)
begin
if hostname
Expand All @@ -25,9 +25,15 @@ abstract class OpenSSL::SSL::Socket < IO
end
end

ret = LibSSL.ssl_connect(@ssl)
unless ret == 1
raise OpenSSL::SSL::Error.new(@ssl, ret, "SSL_connect")
loop do
ret = LibSSL.ssl_connect(@ssl)
break if ret == 1
error = LibSSL.ssl_get_error(@ssl, ret)
case error
when .want_read? then wait_readable
when .want_write? then wait_writable

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: Does it make sense to handle want_write? for a read operation? And similarly want_read? for write operations?

else raise OpenSSL::SSL::Error.new(@ssl, ret, "SSL_connect")
end
end
rescue ex
LibSSL.ssl_free(@ssl) # GC never calls finalize, avoid mem leak
Expand All @@ -52,7 +58,7 @@ abstract class OpenSSL::SSL::Socket < IO
end

class Server < Socket
def initialize(io, context : Context::Server = Context::Server.new,
def initialize(io : ::Socket, context : Context::Server = Context::Server.new,
sync_close : Bool = false, accept : Bool = true)
super(io, context, sync_close)

Expand All @@ -67,10 +73,17 @@ abstract class OpenSSL::SSL::Socket < IO
end

def accept : Nil
ret = LibSSL.ssl_accept(@ssl)
unless ret == 1
@bio.io.close if @sync_close
raise OpenSSL::SSL::Error.new(@ssl, ret, "SSL_accept")
loop do
ret = LibSSL.ssl_accept(@ssl)
break if ret == 1
error = LibSSL.ssl_get_error(@ssl, ret)
case error
when .want_read? then wait_readable
when .want_write? then wait_writable
else
@io.close if @sync_close
raise OpenSSL::SSL::Error.new(@ssl, ret, "SSL_accept")
end
end
end

Expand All @@ -93,7 +106,10 @@ abstract class OpenSSL::SSL::Socket < IO

getter? closed : Bool

protected def initialize(io, context : Context, @sync_close : Bool = false)
# Returns the underlying `::Socket`.
getter io : ::Socket

protected def initialize(@io : ::Socket, context : Context, @sync_close : Bool = false)
@closed = false

@ssl = LibSSL.ssl_new(context)
Expand All @@ -103,13 +119,14 @@ abstract class OpenSSL::SSL::Socket < IO

# Since OpenSSL::SSL::Socket is buffered it makes no
# sense to wrap a IO::Buffered with buffering activated.
if io.is_a?(IO::Buffered)
io.sync = true
io.read_buffering = false
if @io.is_a?(IO::Buffered)
@io.sync = true
@io.read_buffering = false
Comment on lines +122 to +124

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: If we restrict @io to ::Socket, the restriction to IO::Buffered is unnecessary.

end

@bio = BIO.new(io)
LibSSL.ssl_set_bio(@ssl, @bio, @bio)
unless LibSSL.ssl_set_fd(@ssl, @io.fd) == 1
raise OpenSSL::Error.new("SSL_set_fd")
end
end

def finalize
Expand All @@ -122,11 +139,21 @@ abstract class OpenSSL::SSL::Socket < IO
count = slice.size
return 0 if count == 0

LibSSL.ssl_read(@ssl, slice.to_unsafe, count).tap do |bytes|
if bytes <= 0 && !LibSSL.ssl_get_error(@ssl, bytes).zero_return?
ex = OpenSSL::SSL::Error.new(@ssl, bytes, "SSL_read")
loop do
ret = LibSSL.ssl_read(@ssl, slice.to_unsafe, count)
if ret > 0
return ret
end

error = LibSSL.ssl_get_error(@ssl, ret)
case error
when .want_read? then wait_readable
when .want_write? then wait_writable
when .zero_return? then return 0
else
ex = OpenSSL::SSL::Error.new(@ssl, ret, "SSL_read")
if ex.underlying_eof?
# underlying BIO terminated gracefully, without terminating SSL aspect gracefully first
# underlying socket terminated gracefully, without terminating SSL aspect gracefully first
# some misbehaving servers "do this" so treat as EOF even though it's a protocol error
return 0
end
Expand All @@ -140,15 +167,23 @@ abstract class OpenSSL::SSL::Socket < IO

return if slice.empty?

count = slice.size
bytes = LibSSL.ssl_write(@ssl, slice.to_unsafe, count)
unless bytes > 0
raise OpenSSL::SSL::Error.new(@ssl, bytes, "SSL_write")
while slice.size > 0
ret = LibSSL.ssl_write(@ssl, slice.to_unsafe, slice.size)
if ret > 0
slice += ret
else
error = LibSSL.ssl_get_error(@ssl, ret)
case error
when .want_read? then wait_readable
when .want_write? then wait_writable
else raise OpenSSL::SSL::Error.new(@ssl, ret, "SSL_write")
end
end
end
end

def unbuffered_flush : Nil
@bio.io.flush
@io.flush
end

# Returns the negotiated ALPN protocol (eg: `"h2"`) of `nil` if no protocol was
Expand All @@ -167,24 +202,21 @@ abstract class OpenSSL::SSL::Socket < IO
ret = LibSSL.ssl_shutdown(@ssl)
break if ret == 1 # done bidirectional
break if ret == 0 && sync_close? # done unidirectional, "this first successful call to SSL_shutdown() is sufficient"
raise OpenSSL::SSL::Error.new(@ssl, ret, "SSL_shutdown") if ret < 0
rescue e : OpenSSL::SSL::Error
case e.error
when .want_read?, .want_write?
# Ignore, shutdown did not complete yet
when .syscall?
# OpenSSL claimed an underlying syscall failed, but that didn't set any error state,
# assume we're done
break
else
raise e
if ret < 0
error = LibSSL.ssl_get_error(@ssl, ret)
case error
when .want_read? then wait_readable
when .want_write? then wait_writable
when .syscall? then break # underlying syscall failed without error state, assume done
else raise OpenSSL::SSL::Error.new(@ssl, ret, "SSL_shutdown")
end
end

# ret == 0, retry, shutdown is not complete yet
end
rescue IO::Error
ensure
@bio.io.close if @sync_close
@io.close if @sync_close
end
end

Expand All @@ -210,49 +242,43 @@ abstract class OpenSSL::SSL::Socket < IO
end

def local_address
io = @bio.io
io.responds_to?(:local_address) ? io.local_address : nil
io = @io
if io.responds_to?(:local_address)
io.local_address
end
end

def remote_address
io = @bio.io
io.responds_to?(:remote_address) ? io.remote_address : nil
io = @io
if io.responds_to?(:remote_address)
io.remote_address
end
end

def read_timeout
io = @bio.io
if io.responds_to? :read_timeout
io.read_timeout
else
raise NotImplementedError.new("#{io.class}#read_timeout")
end
@io.read_timeout
end

def read_timeout=(value)
io = @bio.io
if io.responds_to? :read_timeout=
io.read_timeout = value
else
raise NotImplementedError.new("#{io.class}#read_timeout=")
end
@io.read_timeout = value
end

def write_timeout
io = @bio.io
if io.responds_to? :write_timeout
io.write_timeout
else
raise NotImplementedError.new("#{io.class}#write_timeout")
end
@io.write_timeout
end

def write_timeout=(value)
io = @bio.io
if io.responds_to? :write_timeout=
io.write_timeout = value
else
raise NotImplementedError.new("#{io.class}#write_timeout=")
end
@io.write_timeout = value
end

# Returns `true` if kTLS is being used for sending data.
def ktls_send? : Bool
LibCrypto.BIO_ctrl(LibSSL.ssl_get_wbio(@ssl), LibCrypto::CTRL_GET_KTLS_SEND, 0, Pointer(Void).null) != 0
end

# Returns `true` if kTLS is being used for receiving data.
def ktls_recv? : Bool
LibCrypto.BIO_ctrl(LibSSL.ssl_get_rbio(@ssl), LibCrypto::CTRL_GET_KTLS_RECV, 0, Pointer(Void).null) != 0
end

# Returns the `OpenSSL::X509::Certificate` the peer presented, if a
Expand All @@ -273,4 +299,12 @@ abstract class OpenSSL::SSL::Socket < IO
end
end
end

private def wait_readable : Nil
Crystal::EventLoop.current.wait_readable(@io)
end

private def wait_writable : Nil
Crystal::EventLoop.current.wait_writable(@io)
end
end
Loading