diff --git a/spec/std/socket_spec.cr b/spec/std/socket_spec.cr index 0b0ee66e5269..59276cb706b0 100644 --- a/spec/std/socket_spec.cr +++ b/spec/std/socket_spec.cr @@ -386,6 +386,27 @@ describe TCPSocket do TCPSocket.new("localhostttttt", 12345) end end + + it "binds to a local port" do + port = TCPServer.open("::", 0) do |server| + server.local_address.port + end + port.should be > 0 + + TCPServer.open("::", port) do |server| + server.local_address.family.should eq(Socket::Family::INET6) + server.local_address.port.should eq(port) + server.local_address.address.should eq("::") + + # test sync flag propagation after accept + server.sync = !server.sync? + + TCPSocket.open("::", server.local_address.port, "::", 54321) do |client| + sock = server.accept + sock.sync?.should eq(server.sync?) + end + end + end end describe UDPSocket do diff --git a/src/socket/tcp_socket.cr b/src/socket/tcp_socket.cr index 36d9344a5ef3..293423d12189 100644 --- a/src/socket/tcp_socket.cr +++ b/src/socket/tcp_socket.cr @@ -17,7 +17,36 @@ class TCPSocket < IPSocket # must be in seconds (integers or floats). # # Note that `dns_timeout` is currently ignored. - def initialize(host, port, dns_timeout = nil, connect_timeout = nil) + def initialize(host, port, local_host, local_port, dns_timeout : Time::Span | Number | Nil = nil, connect_timeout : Time::Span | Number | Nil = nil) + getaddrinfo(host, port, nil, Type::STREAM, Protocol::TCP, timeout: dns_timeout) do |addrinfo| + super create_socket(addrinfo.ai_family, addrinfo.ai_socktype, addrinfo.ai_protocol) + + getaddrinfo(local_host, local_port, nil, Type::STREAM, Protocol::TCP, timeout: dns_timeout) do |localaddrinfo| + ret = + {% if flag?(:freebsd) || flag?(:openbsd) %} + LibC.bind(fd, localaddrinfo.ai_addr.as(LibC::Sockaddr*), localaddrinfo.ai_addrlen) + {% else %} + LibC.bind(fd, localaddrinfo.ai_addr, localaddrinfo.ai_addrlen) + {% end %} + unless ret == 0 + next false if localaddrinfo.ai_next + raise Errno.new("Error binding TCP socket at #{local_host}:#{local_port}") + end + + if err = nonblocking_connect(host, port, addrinfo, timeout: connect_timeout) + close + next false if addrinfo.ai_next + raise err + end + + true + end + + true + end + end + + def initialize(host, port, dns_timeout : Time::Span | Number | Nil = nil, connect_timeout : Time::Span | Number | Nil = nil) getaddrinfo(host, port, nil, Type::STREAM, Protocol::TCP, timeout: dns_timeout) do |addrinfo| super create_socket(addrinfo.ai_family, addrinfo.ai_socktype, addrinfo.ai_protocol) @@ -38,9 +67,10 @@ class TCPSocket < IPSocket # Opens a TCP socket to a remote TCP server, yields it to the block, then # eventually closes the socket when the block returns. # + # Forwards all params to TCPSocket#new. # Returns the value of the block. - def self.open(host, port) - sock = new(host, port) + def self.open(*args, **kwargs) + sock = new(*args, **kwargs) begin yield sock ensure