From b1c2457733c11f30a396ef16345ead88e2f4ead5 Mon Sep 17 00:00:00 2001 From: Max Gurela Date: Wed, 2 Nov 2016 15:05:30 -0700 Subject: [PATCH 1/2] Allow TCPSocket to bind to specific local address --- spec/std/socket_spec.cr | 21 +++++++++++++++++++++ src/socket/tcp_socket.cr | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/spec/std/socket_spec.cr b/spec/std/socket_spec.cr index 0b0ee66e5269..59276cb706b0 100644 --- a/spec/std/socket_spec.cr +++ b/spec/std/socket_spec.cr @@ -386,6 +386,27 @@ describe TCPSocket do TCPSocket.new("localhostttttt", 12345) end end + + it "binds to a local port" do + port = TCPServer.open("::", 0) do |server| + server.local_address.port + end + port.should be > 0 + + TCPServer.open("::", port) do |server| + server.local_address.family.should eq(Socket::Family::INET6) + server.local_address.port.should eq(port) + server.local_address.address.should eq("::") + + # test sync flag propagation after accept + server.sync = !server.sync? + + TCPSocket.open("::", server.local_address.port, "::", 54321) do |client| + sock = server.accept + sock.sync?.should eq(server.sync?) + end + end + end end describe UDPSocket do diff --git a/src/socket/tcp_socket.cr b/src/socket/tcp_socket.cr index 36d9344a5ef3..71e12d4d3cd6 100644 --- a/src/socket/tcp_socket.cr +++ b/src/socket/tcp_socket.cr @@ -17,7 +17,36 @@ class TCPSocket < IPSocket # must be in seconds (integers or floats). # # Note that `dns_timeout` is currently ignored. - def initialize(host, port, dns_timeout = nil, connect_timeout = nil) + def initialize(host, port, local_host, local_port, dns_timeout : Time::Span | Number | Nil = nil, connect_timeout : Time::Span | Number | Nil = nil) + getaddrinfo(host, port, nil, Type::STREAM, Protocol::TCP, timeout: dns_timeout) do |addrinfo| + super create_socket(addrinfo.ai_family, addrinfo.ai_socktype, addrinfo.ai_protocol) + + getaddrinfo(local_host, local_port, nil, Type::STREAM, Protocol::TCP, timeout: dns_timeout) do |localaddrinfo| + ret = + {% if flag?(:freebsd) || flag?(:openbsd) %} + LibC.bind(fd, localaddrinfo.ai_addr.as(LibC::Sockaddr*), localaddrinfo.ai_addrlen) + {% else %} + LibC.bind(fd, localaddrinfo.ai_addr, localaddrinfo.ai_addrlen) + {% end %} + unless ret == 0 + next false if localaddrinfo.ai_next + raise Errno.new("Error binding TCP socket at #{local_host}:#{local_port}") + end + + if err = nonblocking_connect(host, port, addrinfo, timeout: connect_timeout) + close + next false if addrinfo.ai_next + raise err + end + + true + end + + true + end + end + + def initialize(host, port, dns_timeout : Time::Span | Number | Nil = nil, connect_timeout : Time::Span | Number | Nil = nil) getaddrinfo(host, port, nil, Type::STREAM, Protocol::TCP, timeout: dns_timeout) do |addrinfo| super create_socket(addrinfo.ai_family, addrinfo.ai_socktype, addrinfo.ai_protocol) @@ -39,6 +68,15 @@ class TCPSocket < IPSocket # eventually closes the socket when the block returns. # # Returns the value of the block. + def self.open(host, port, local_host, local_port) + sock = new(host, port, local_host: local_host, local_port: local_port) + begin + yield sock + ensure + sock.close + end + end + def self.open(host, port) sock = new(host, port) begin From 53c85a30f8c143e28376482fcb98bdb13f89ebce Mon Sep 17 00:00:00 2001 From: Max Gurela Date: Wed, 2 Nov 2016 17:32:08 -0700 Subject: [PATCH 2/2] Forward TCPSocket#open params to TCPSocket#new --- src/socket/tcp_socket.cr | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/socket/tcp_socket.cr b/src/socket/tcp_socket.cr index 71e12d4d3cd6..293423d12189 100644 --- a/src/socket/tcp_socket.cr +++ b/src/socket/tcp_socket.cr @@ -67,18 +67,10 @@ class TCPSocket < IPSocket # Opens a TCP socket to a remote TCP server, yields it to the block, then # eventually closes the socket when the block returns. # + # Forwards all params to TCPSocket#new. # Returns the value of the block. - def self.open(host, port, local_host, local_port) - sock = new(host, port, local_host: local_host, local_port: local_port) - begin - yield sock - ensure - sock.close - end - end - - def self.open(host, port) - sock = new(host, port) + def self.open(*args, **kwargs) + sock = new(*args, **kwargs) begin yield sock ensure