diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 838781970b9..eab0566f582 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -125,14 +125,14 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in if !d.config.FollowRedirect { writer = &buf.SequentialWriter{Writer: conn} } else { - tCtx := internet.ContextWithBindAddress(context.Background(), dest) - tCtx = internet.ContextWithStreamSettings(tCtx, &internet.MemoryStreamConfig{ - ProtocolName: "udp", - SocketSettings: &internet.SocketConfig{ - Tproxy: internet.SocketConfig_TProxy, - }, - }) - tConn, err := internet.DialSystem(tCtx, net.DestinationFromAddr(conn.RemoteAddr())) + sockopt := &internet.SocketConfig{ + Tproxy: internet.SocketConfig_TProxy, + } + if dest.Address.Family().IsIP() { + sockopt.BindAddress = dest.Address.IP() + sockopt.BindPort = uint32(dest.Port) + } + tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt) if err != nil { return err } diff --git a/testing/servers/tcp/tcp.go b/testing/servers/tcp/tcp.go index c5630253e23..7d21fb567a7 100644 --- a/testing/servers/tcp/tcp.go +++ b/testing/servers/tcp/tcp.go @@ -22,10 +22,10 @@ type Server struct { } func (server *Server) Start() (net.Destination, error) { - return server.StartContext(context.Background()) + return server.StartContext(context.Background(), nil) } -func (server *Server) StartContext(ctx context.Context) (net.Destination, error) { +func (server *Server) StartContext(ctx context.Context, sockopt *internet.SocketConfig) (net.Destination, error) { listenerAddr := server.Listen if listenerAddr == nil { listenerAddr = net.LocalHostIP @@ -33,7 +33,7 @@ func (server *Server) StartContext(ctx context.Context) (net.Destination, error) listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ IP: listenerAddr.IP(), Port: int(server.Port), - }) + }, sockopt) if err != nil { return net.Destination{}, err } diff --git a/transport/internet/config.pb.go b/transport/internet/config.pb.go index 105df0397cf..4ddef48996d 100644 --- a/transport/internet/config.pb.go +++ b/transport/internet/config.pb.go @@ -310,6 +310,8 @@ type SocketConfig struct { // ReceiveOriginalDestAddress is for enabling IP_RECVORIGDSTADDR socket option. // This option is for UDP only. ReceiveOriginalDestAddress bool `protobuf:"varint,4,opt,name=receive_original_dest_address,json=receiveOriginalDestAddress,proto3" json:"receive_original_dest_address,omitempty"` + BindAddress []byte `protobuf:"bytes,5,opt,name=bind_address,json=bindAddress,proto3" json:"bind_address,omitempty"` + BindPort uint32 `protobuf:"varint,6,opt,name=bind_port,json=bindPort,proto3" json:"bind_port,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -368,6 +370,20 @@ func (m *SocketConfig) GetReceiveOriginalDestAddress() bool { return false } +func (m *SocketConfig) GetBindAddress() []byte { + if m != nil { + return m.BindAddress + } + return nil +} + +func (m *SocketConfig) GetBindPort() uint32 { + if m != nil { + return m.BindPort + } + return 0 +} + func init() { proto.RegisterEnum("v2ray.core.transport.internet.TransportProtocol", TransportProtocol_name, TransportProtocol_value) proto.RegisterEnum("v2ray.core.transport.internet.SocketConfig_TCPFastOpenState", SocketConfig_TCPFastOpenState_name, SocketConfig_TCPFastOpenState_value) @@ -383,43 +399,45 @@ func init() { } var fileDescriptor_91dbc815c3d97a05 = []byte{ - // 607 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x53, 0xdd, 0x6e, 0xd3, 0x4c, - 0x10, 0xad, 0xed, 0x34, 0x4d, 0x27, 0x69, 0xea, 0xee, 0x55, 0x54, 0xa9, 0xfa, 0xfa, 0x05, 0x09, - 0x45, 0x20, 0xad, 0x2b, 0x23, 0xb8, 0xe2, 0xa6, 0x4d, 0x40, 0x54, 0xd0, 0xc6, 0x72, 0x0c, 0x48, - 0x95, 0x90, 0xb5, 0x75, 0x26, 0x91, 0xd5, 0xd8, 0x1b, 0xed, 0x2e, 0x15, 0x79, 0x25, 0xae, 0x79, - 0x08, 0x5e, 0x86, 0x77, 0x40, 0xbb, 0xfe, 0x21, 0x2a, 0x28, 0xb4, 0xe2, 0x6e, 0x3c, 0x73, 0xe6, - 0xcc, 0x39, 0x33, 0x5e, 0xa0, 0xb7, 0xbe, 0x60, 0x2b, 0x9a, 0xf0, 0xcc, 0x4b, 0xb8, 0x40, 0x4f, - 0x09, 0x96, 0xcb, 0x25, 0x17, 0xca, 0x4b, 0x73, 0x85, 0x22, 0x47, 0xe5, 0x25, 0x3c, 0x9f, 0xa5, - 0x73, 0xba, 0x14, 0x5c, 0x71, 0x72, 0x54, 0xe1, 0x05, 0xd2, 0x1a, 0x4b, 0x2b, 0xec, 0xe1, 0xc9, - 0x1d, 0xba, 0x84, 0x67, 0x19, 0xcf, 0x3d, 0x89, 0x22, 0x65, 0x0b, 0x4f, 0xad, 0x96, 0x38, 0x8d, - 0x33, 0x94, 0x92, 0xcd, 0xb1, 0x20, 0xec, 0x7f, 0xb7, 0x60, 0x3f, 0xaa, 0x88, 0x86, 0x66, 0x14, - 0x79, 0x07, 0x2d, 0x53, 0x4c, 0xf8, 0xa2, 0x67, 0x1d, 0x5b, 0x83, 0xae, 0x7f, 0x42, 0x37, 0xce, - 0xa5, 0x35, 0x43, 0x50, 0xf6, 0x85, 0x35, 0x03, 0x79, 0x04, 0x7b, 0x55, 0x1c, 0xe7, 0x2c, 0xc3, - 0x9e, 0x73, 0x6c, 0x0d, 0x76, 0xc3, 0x4e, 0x95, 0xbc, 0x64, 0x19, 0x92, 0x33, 0x68, 0x49, 0x54, - 0x2a, 0xcd, 0xe7, 0xb2, 0x67, 0x1f, 0x5b, 0x83, 0xb6, 0xff, 0x78, 0x7d, 0x64, 0xe1, 0x83, 0x16, - 0x3e, 0x68, 0xa4, 0x7d, 0x5c, 0x14, 0x36, 0xc2, 0xba, 0xaf, 0xff, 0xcd, 0x81, 0xce, 0x44, 0x09, - 0x64, 0x59, 0xe9, 0x23, 0xf8, 0x77, 0x1f, 0x67, 0x76, 0xcf, 0xda, 0xe4, 0x65, 0xfb, 0x0f, 0x5e, - 0x3e, 0x01, 0xa9, 0xa9, 0xe3, 0x35, 0x57, 0xce, 0xa0, 0xed, 0xd3, 0xfb, 0x0a, 0x28, 0x2c, 0x84, - 0x07, 0x35, 0x66, 0x52, 0x12, 0x69, 0x0d, 0x12, 0x93, 0xcf, 0x22, 0x55, 0xab, 0x58, 0x5f, 0xb4, - 0xda, 0x67, 0x95, 0xd4, 0xdb, 0x21, 0x13, 0x38, 0xa8, 0x41, 0xb5, 0x84, 0x86, 0x91, 0x70, 0xdf, - 0xc5, 0xba, 0x15, 0x41, 0x3d, 0x39, 0x82, 0x7d, 0xc9, 0x93, 0x1b, 0x5c, 0x73, 0xd5, 0x34, 0xb7, - 0x7a, 0xfa, 0x17, 0x57, 0x13, 0xd3, 0x55, 0x5a, 0xea, 0x16, 0x1c, 0x15, 0x6b, 0xff, 0x3f, 0x68, - 0x07, 0x82, 0x7f, 0x59, 0x95, 0x47, 0x73, 0xc1, 0x51, 0x6c, 0x6e, 0xee, 0xb5, 0x1b, 0xea, 0xb0, - 0xff, 0xc3, 0x86, 0xce, 0x3a, 0x03, 0x21, 0xd0, 0xc8, 0x98, 0xb8, 0x31, 0x98, 0xed, 0xd0, 0xc4, - 0xe4, 0x12, 0x1c, 0x35, 0xe3, 0xe6, 0xdf, 0xe9, 0xfa, 0x2f, 0x1f, 0xa0, 0x87, 0x46, 0xc3, 0xe0, - 0x35, 0x93, 0x6a, 0xbc, 0xc4, 0x7c, 0xa2, 0x98, 0xc2, 0x50, 0x13, 0x91, 0x4b, 0x68, 0xaa, 0xa5, - 0x96, 0x65, 0xd6, 0xdb, 0xf5, 0x5f, 0x3c, 0x88, 0xd2, 0x18, 0xba, 0xe0, 0x53, 0x0c, 0x4b, 0x16, - 0x72, 0x0a, 0x47, 0x02, 0x13, 0x4c, 0x6f, 0x31, 0xe6, 0x22, 0x9d, 0xa7, 0x39, 0x5b, 0xc4, 0x53, - 0x94, 0x2a, 0x66, 0xd3, 0xa9, 0x40, 0xa9, 0x8f, 0x63, 0x0d, 0x5a, 0xe1, 0x61, 0x09, 0x1a, 0x97, - 0x98, 0x11, 0x4a, 0x75, 0x5a, 0x20, 0xfa, 0xcf, 0xc1, 0xbd, 0xab, 0x95, 0xb4, 0xa0, 0x71, 0x2a, - 0xcf, 0xa5, 0xbb, 0x45, 0x00, 0x9a, 0xaf, 0x72, 0x76, 0xbd, 0x40, 0xd7, 0x22, 0x6d, 0xd8, 0x19, - 0xa5, 0xd2, 0x7c, 0xd8, 0x7d, 0x0f, 0xe0, 0x97, 0x1e, 0xb2, 0x03, 0xce, 0x78, 0x36, 0x2b, 0xf0, - 0x45, 0xda, 0xb5, 0x48, 0x07, 0x5a, 0x21, 0x4e, 0x53, 0x81, 0x89, 0x72, 0xed, 0x27, 0x57, 0x70, - 0xf0, 0xdb, 0x3b, 0xd0, 0x7d, 0xd1, 0x30, 0x70, 0xb7, 0x74, 0xf0, 0x7e, 0x14, 0xb8, 0x96, 0x1e, - 0x7d, 0xf1, 0x76, 0x18, 0xb8, 0x36, 0xd9, 0x83, 0xdd, 0x8f, 0x78, 0x5d, 0x6c, 0xc0, 0x75, 0x74, - 0xe1, 0x4d, 0x14, 0x05, 0x6e, 0x83, 0xb8, 0xd0, 0x19, 0xf1, 0x8c, 0xa5, 0x79, 0x59, 0xdb, 0x3e, - 0x1b, 0xc3, 0xff, 0x09, 0xcf, 0x36, 0xef, 0x32, 0xb0, 0xae, 0x5a, 0x55, 0xfc, 0xd5, 0x3e, 0xfa, - 0xe0, 0x87, 0x6c, 0x45, 0x87, 0x1a, 0x5b, 0xcb, 0xa2, 0xe7, 0x65, 0xfd, 0xba, 0x69, 0x9e, 0xde, - 0xb3, 0x9f, 0x01, 0x00, 0x00, 0xff, 0xff, 0xf2, 0x99, 0xd5, 0x37, 0x49, 0x05, 0x00, 0x00, + // 636 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x53, 0xdd, 0x6a, 0xdb, 0x4c, + 0x10, 0x8d, 0x2c, 0xc7, 0x91, 0xc7, 0x8e, 0xa3, 0xec, 0x95, 0xc9, 0x47, 0xf8, 0x12, 0x17, 0x8a, + 0x69, 0x41, 0x0a, 0x2a, 0xed, 0x55, 0x6f, 0x12, 0xbb, 0xa5, 0xa1, 0x4d, 0x2c, 0x64, 0xb5, 0x85, + 0x40, 0x11, 0x6b, 0x69, 0x6c, 0x44, 0x2c, 0xad, 0xd9, 0xdd, 0x86, 0xfa, 0x95, 0x0a, 0xbd, 0xeb, + 0x43, 0xf4, 0xb1, 0xca, 0xae, 0x7e, 0x6a, 0xd2, 0x92, 0x26, 0xf4, 0x6e, 0x34, 0x73, 0xe6, 0xcc, + 0x9c, 0x39, 0x5a, 0x70, 0x6e, 0x3c, 0x4e, 0xd7, 0x4e, 0xcc, 0x32, 0x37, 0x66, 0x1c, 0x5d, 0xc9, + 0x69, 0x2e, 0x56, 0x8c, 0x4b, 0x37, 0xcd, 0x25, 0xf2, 0x1c, 0xa5, 0x1b, 0xb3, 0x7c, 0x9e, 0x2e, + 0x9c, 0x15, 0x67, 0x92, 0x91, 0xc3, 0x0a, 0xcf, 0xd1, 0xa9, 0xb1, 0x4e, 0x85, 0x3d, 0x38, 0xb9, + 0x45, 0x17, 0xb3, 0x2c, 0x63, 0xb9, 0x2b, 0x90, 0xa7, 0x74, 0xe9, 0xca, 0xf5, 0x0a, 0x93, 0x28, + 0x43, 0x21, 0xe8, 0x02, 0x0b, 0xc2, 0xc1, 0x0f, 0x03, 0xf6, 0xc2, 0x8a, 0x68, 0xa4, 0x47, 0x91, + 0x77, 0x60, 0xe9, 0x62, 0xcc, 0x96, 0x7d, 0xe3, 0xc8, 0x18, 0xf6, 0xbc, 0x13, 0xe7, 0xce, 0xb9, + 0x4e, 0xcd, 0xe0, 0x97, 0x7d, 0x41, 0xcd, 0x40, 0x1e, 0xc1, 0x6e, 0x15, 0x47, 0x39, 0xcd, 0xb0, + 0x6f, 0x1e, 0x19, 0xc3, 0x76, 0xd0, 0xad, 0x92, 0x97, 0x34, 0x43, 0x72, 0x06, 0x96, 0x40, 0x29, + 0xd3, 0x7c, 0x21, 0xfa, 0x8d, 0x23, 0x63, 0xd8, 0xf1, 0x1e, 0x6f, 0x8e, 0x2c, 0x74, 0x38, 0x85, + 0x0e, 0x27, 0x54, 0x3a, 0x2e, 0x0a, 0x19, 0x41, 0xdd, 0x37, 0xf8, 0x6e, 0x42, 0x77, 0x2a, 0x39, + 0xd2, 0xac, 0xd4, 0xe1, 0xff, 0xbb, 0x8e, 0xb3, 0x46, 0xdf, 0xb8, 0x4b, 0xcb, 0xf6, 0x1f, 0xb4, + 0x7c, 0x02, 0x52, 0x53, 0x47, 0x1b, 0xaa, 0xcc, 0x61, 0xc7, 0x73, 0xee, 0xbb, 0x40, 0x21, 0x21, + 0xd8, 0xaf, 0x31, 0xd3, 0x92, 0x48, 0xed, 0x20, 0x30, 0xfe, 0xcc, 0x53, 0xb9, 0x8e, 0x94, 0xa3, + 0xd5, 0x3d, 0xab, 0xa4, 0xba, 0x0e, 0x99, 0xc2, 0x7e, 0x0d, 0xaa, 0x57, 0x68, 0xea, 0x15, 0xee, + 0x7b, 0x58, 0xbb, 0x22, 0xa8, 0x27, 0x87, 0xb0, 0x27, 0x58, 0x7c, 0x8d, 0x1b, 0xaa, 0x5a, 0xda, + 0xab, 0xa7, 0x7f, 0x51, 0x35, 0xd5, 0x5d, 0xa5, 0xa4, 0x5e, 0xc1, 0x51, 0xb1, 0x0e, 0xfe, 0x87, + 0x8e, 0xcf, 0xd9, 0x97, 0x75, 0x69, 0x9a, 0x0d, 0xa6, 0xa4, 0x0b, 0xed, 0x57, 0x3b, 0x50, 0xe1, + 0xe0, 0x9b, 0xf2, 0x75, 0x83, 0x81, 0x10, 0x68, 0x66, 0x94, 0x5f, 0x6b, 0xcc, 0x76, 0xa0, 0x63, + 0x72, 0x09, 0xa6, 0x9c, 0x33, 0xfd, 0xef, 0xf4, 0xbc, 0x97, 0x0f, 0xd8, 0xc7, 0x09, 0x47, 0xfe, + 0x6b, 0x2a, 0xe4, 0x64, 0x85, 0xf9, 0x54, 0x52, 0x89, 0x81, 0x22, 0x22, 0x97, 0xd0, 0x92, 0x2b, + 0xb5, 0x96, 0x3e, 0x6f, 0xcf, 0x7b, 0xf1, 0x20, 0x4a, 0x2d, 0xe8, 0x82, 0x25, 0x18, 0x94, 0x2c, + 0xe4, 0x14, 0x0e, 0x39, 0xc6, 0x98, 0xde, 0x60, 0xc4, 0x78, 0xba, 0x48, 0x73, 0xba, 0x8c, 0x12, + 0x14, 0x32, 0xa2, 0x49, 0xc2, 0x51, 0x28, 0x73, 0x8c, 0xa1, 0x15, 0x1c, 0x94, 0xa0, 0x49, 0x89, + 0x19, 0xa3, 0x90, 0xa7, 0x05, 0x82, 0x1c, 0x43, 0x77, 0x96, 0xe6, 0x49, 0xdd, 0xa1, 0xfe, 0xbd, + 0x6e, 0xd0, 0x51, 0xb9, 0x0a, 0xf2, 0x1f, 0xb4, 0x35, 0x44, 0xed, 0xa6, 0xbd, 0xd9, 0x0d, 0x2c, + 0x95, 0xf0, 0x19, 0x97, 0x83, 0xe7, 0x60, 0xdf, 0xd6, 0x4a, 0x2c, 0x68, 0x9e, 0x8a, 0x73, 0x61, + 0x6f, 0x11, 0x80, 0xd6, 0xab, 0x9c, 0xce, 0x96, 0x68, 0x1b, 0xa4, 0x03, 0x3b, 0xe3, 0x54, 0xe8, + 0x8f, 0xc6, 0xc0, 0x05, 0xf8, 0xa5, 0x87, 0xec, 0x80, 0x39, 0x99, 0xcf, 0x0b, 0x7c, 0x91, 0xb6, + 0x0d, 0xd2, 0x05, 0x2b, 0xc0, 0x24, 0xe5, 0x18, 0x4b, 0xbb, 0xf1, 0xe4, 0x0a, 0xf6, 0x7f, 0x7b, + 0x47, 0xaa, 0x2f, 0x1c, 0xf9, 0xf6, 0x96, 0x0a, 0xde, 0x8f, 0x7d, 0xdb, 0x50, 0xa3, 0x2f, 0xde, + 0x8e, 0x7c, 0xbb, 0x41, 0x76, 0xa1, 0xfd, 0x11, 0x67, 0xc5, 0x05, 0x6d, 0x53, 0x15, 0xde, 0x84, + 0xa1, 0x6f, 0x37, 0x89, 0x0d, 0xdd, 0x31, 0xcb, 0x68, 0x9a, 0x97, 0xb5, 0xed, 0xb3, 0x09, 0x1c, + 0xc7, 0x2c, 0xbb, 0xdb, 0x0b, 0xdf, 0xb8, 0xb2, 0xaa, 0xf8, 0x6b, 0xe3, 0xf0, 0x83, 0x17, 0xd0, + 0xb5, 0x33, 0x52, 0xd8, 0x7a, 0x2d, 0xe7, 0xbc, 0xac, 0xcf, 0x5a, 0xfa, 0xe9, 0x3e, 0xfb, 0x19, + 0x00, 0x00, 0xff, 0xff, 0xce, 0xe9, 0xc8, 0x20, 0x89, 0x05, 0x00, 0x00, } diff --git a/transport/internet/config.proto b/transport/internet/config.proto index 27d5a772af1..d650c35a021 100644 --- a/transport/internet/config.proto +++ b/transport/internet/config.proto @@ -83,4 +83,8 @@ message SocketConfig { // ReceiveOriginalDestAddress is for enabling IP_RECVORIGDSTADDR socket option. // This option is for UDP only. bool receive_original_dest_address = 4; + + bytes bind_address = 5; + + uint32 bind_port = 6; } diff --git a/transport/internet/context.go b/transport/internet/context.go deleted file mode 100644 index afd5ccd5cbd..00000000000 --- a/transport/internet/context.go +++ /dev/null @@ -1,37 +0,0 @@ -package internet - -import ( - "context" - - "v2ray.com/core/common/net" -) - -type key int - -const ( - streamSettingsKey key = iota - bindAddrKey -) - -func ContextWithStreamSettings(ctx context.Context, streamSettings *MemoryStreamConfig) context.Context { - return context.WithValue(ctx, streamSettingsKey, streamSettings) -} - -func StreamSettingsFromContext(ctx context.Context) *MemoryStreamConfig { - ss := ctx.Value(streamSettingsKey) - if ss == nil { - return nil - } - return ss.(*MemoryStreamConfig) -} - -func ContextWithBindAddress(ctx context.Context, dest net.Destination) context.Context { - return context.WithValue(ctx, bindAddrKey, dest) -} - -func BindAddressFromContext(ctx context.Context) net.Destination { - if addr, ok := ctx.Value(bindAddrKey).(net.Destination); ok { - return addr - } - return net.Destination{} -} diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 4e3002ff08d..9ac0b3a6a19 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -17,7 +17,7 @@ type Dialer interface { } // dialFunc is an interface to dial network connection to a specific destination. -type dialFunc func(ctx context.Context, dest net.Destination) (Connection, error) +type dialFunc func(ctx context.Context, dest net.Destination, streamSettings *MemoryStreamConfig) (Connection, error) var ( transportDialerCache = make(map[string]dialFunc) @@ -33,16 +33,14 @@ func RegisterTransportDialer(protocol string, dialer dialFunc) error { } // Dial dials a internet connection towards the given destination. -func Dial(ctx context.Context, dest net.Destination) (Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *MemoryStreamConfig) (Connection, error) { if dest.Network == net.Network_TCP { - streamSettings := StreamSettingsFromContext(ctx) if streamSettings == nil { s, err := ToMemoryStreamConfig(nil) if err != nil { return nil, newError("failed to create default stream settings").Base(err) } streamSettings = s - ctx = ContextWithStreamSettings(ctx, streamSettings) } protocol := streamSettings.ProtocolName @@ -50,7 +48,7 @@ func Dial(ctx context.Context, dest net.Destination) (Connection, error) { if dialer == nil { return nil, newError(protocol, " dialer not registered").AtError() } - return dialer(ctx, dest) + return dialer(ctx, dest, streamSettings) } if dest.Network == net.Network_UDP { @@ -58,17 +56,17 @@ func Dial(ctx context.Context, dest net.Destination) (Connection, error) { if udpDialer == nil { return nil, newError("UDP dialer not registered").AtError() } - return udpDialer(ctx, dest) + return udpDialer(ctx, dest, streamSettings) } return nil, newError("unknown network ", dest.Network) } // DialSystem calls system dialer to create a network connection. -func DialSystem(ctx context.Context, dest net.Destination) (net.Conn, error) { +func DialSystem(ctx context.Context, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { var src net.Address if outbound := session.OutboundFromContext(ctx); outbound != nil { src = outbound.Gateway } - return effectiveSystemDialer.Dial(ctx, src, dest) + return effectiveSystemDialer.Dial(ctx, src, dest, sockopt) } diff --git a/transport/internet/dialer_test.go b/transport/internet/dialer_test.go index 42a02c9bc5f..f0cc52049c6 100644 --- a/transport/internet/dialer_test.go +++ b/transport/internet/dialer_test.go @@ -18,7 +18,7 @@ func TestDialWithLocalAddr(t *testing.T) { assert(err, IsNil) defer server.Close() - conn, err := DialSystem(context.Background(), net.TCPDestination(net.LocalHostIP, dest.Port)) + conn, err := DialSystem(context.Background(), net.TCPDestination(net.LocalHostIP, dest.Port), nil) assert(err, IsNil) assert(conn.RemoteAddr().String(), Equals, "127.0.0.1:"+dest.Port.String()) conn.Close() diff --git a/transport/internet/domainsocket/dial.go b/transport/internet/domainsocket/dial.go index c4e7a301438..e069c7e5137 100644 --- a/transport/internet/domainsocket/dial.go +++ b/transport/internet/domainsocket/dial.go @@ -12,20 +12,8 @@ import ( "v2ray.com/core/transport/internet/tls" ) -func getSettingsFromContext(ctx context.Context) *Config { - rawSettings := internet.StreamSettingsFromContext(ctx) - if rawSettings == nil { - return nil - } - return rawSettings.ProtocolSettings.(*Config) -} - -func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { - settings := getSettingsFromContext(ctx) - if settings == nil { - return nil, newError("domain socket settings is not specified.").AtError() - } - +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { + settings := streamSettings.ProtocolSettings.(*Config) addr, err := settings.GetUnixAddr() if err != nil { return nil, err @@ -36,7 +24,7 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error return nil, newError("failed to dial unix: ", settings.Path).Base(err).AtWarning() } - if config := tls.ConfigFromContext(ctx); config != nil { + if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { return tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest))), nil } diff --git a/transport/internet/domainsocket/listener.go b/transport/internet/domainsocket/listener.go index 0bfd60203f2..f93402c46cf 100644 --- a/transport/internet/domainsocket/listener.go +++ b/transport/internet/domainsocket/listener.go @@ -25,12 +25,8 @@ type Listener struct { locker *fileLocker } -func Listen(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) { - settings := getSettingsFromContext(ctx) - if settings == nil { - return nil, newError("domain socket settings not specified.") - } - +func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { + settings := streamSettings.ProtocolSettings.(*Config) addr, err := settings.GetUnixAddr() if err != nil { return nil, err @@ -58,7 +54,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, handler int } } - if config := tls.ConfigFromContext(ctx); config != nil { + if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { ln.tlsConfig = config.GetTLSConfig() } diff --git a/transport/internet/domainsocket/listener_test.go b/transport/internet/domainsocket/listener_test.go index b6c64c06643..0095844944c 100644 --- a/transport/internet/domainsocket/listener_test.go +++ b/transport/internet/domainsocket/listener_test.go @@ -18,13 +18,14 @@ import ( func TestListen(t *testing.T) { assert := With(t) - ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + ctx := context.Background() + streamSettings := &internet.MemoryStreamConfig{ ProtocolName: "domainsocket", ProtocolSettings: &Config{ Path: "/tmp/ts3", }, - }) - listener, err := Listen(ctx, nil, net.Port(0), func(conn internet.Connection) { + } + listener, err := Listen(ctx, nil, net.Port(0), streamSettings, func(conn internet.Connection) { defer conn.Close() b := buf.New() @@ -36,7 +37,7 @@ func TestListen(t *testing.T) { assert(err, IsNil) defer listener.Close() - conn, err := Dial(ctx, net.Destination{}) + conn, err := Dial(ctx, net.Destination{}, streamSettings) assert(err, IsNil) defer conn.Close() @@ -56,14 +57,15 @@ func TestListenAbstract(t *testing.T) { assert := With(t) - ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + ctx := context.Background() + streamSettings := &internet.MemoryStreamConfig{ ProtocolName: "domainsocket", ProtocolSettings: &Config{ Path: "/tmp/ts3", Abstract: true, }, - }) - listener, err := Listen(ctx, nil, net.Port(0), func(conn internet.Connection) { + } + listener, err := Listen(ctx, nil, net.Port(0), streamSettings, func(conn internet.Connection) { defer conn.Close() b := buf.New() @@ -75,7 +77,7 @@ func TestListenAbstract(t *testing.T) { assert(err, IsNil) defer listener.Close() - conn, err := Dial(ctx, net.Destination{}) + conn, err := Dial(ctx, net.Destination{}, streamSettings) assert(err, IsNil) defer conn.Close() diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index 7cdd5bf14f5..81147dae160 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -21,7 +21,7 @@ var ( globalDailerAccess sync.Mutex ) -func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, error) { +func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.Config) (*http.Client, error) { globalDailerAccess.Lock() defer globalDailerAccess.Unlock() @@ -33,11 +33,6 @@ func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, err return client, nil } - config := tls.ConfigFromContext(ctx) - if config == nil { - return nil, newError("TLS must be enabled for http transport.").AtWarning() - } - transport := &http2.Transport{ DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { rawHost, rawPort, err := net.SplitHostPort(addr) @@ -53,13 +48,13 @@ func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, err } address := net.ParseAddress(rawHost) - pconn, err := internet.DialSystem(context.Background(), net.TCPDestination(address, port)) + pconn, err := internet.DialSystem(context.Background(), net.TCPDestination(address, port), nil) if err != nil { return nil, err } return gotls.Client(pconn, tlsConfig), nil }, - TLSClientConfig: config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2")), + TLSClientConfig: tlsSettings.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2")), } client := &http.Client{ @@ -71,14 +66,13 @@ func getHTTPClient(ctx context.Context, dest net.Destination) (*http.Client, err } // Dial dials a new TCP connection to the given destination. -func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { - rawSettings := internet.StreamSettingsFromContext(ctx) - httpSettings, ok := rawSettings.ProtocolSettings.(*Config) - if !ok { - return nil, newError("HTTP config is not set.").AtError() +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { + httpSettings := streamSettings.ProtocolSettings.(*Config) + tlsConfig := tls.ConfigFromStreamSettings(streamSettings) + if tlsConfig == nil { + return nil, newError("TLS must be enabled for http transport.").AtWarning() } - - client, err := getHTTPClient(ctx, dest) + client, err := getHTTPClient(ctx, dest, tlsConfig) if err != nil { return nil, err } diff --git a/transport/internet/http/http_test.go b/transport/internet/http/http_test.go index 9d54b375f3f..d6a4972c451 100644 --- a/transport/internet/http/http_test.go +++ b/transport/internet/http/http_test.go @@ -22,16 +22,14 @@ func TestHTTPConnection(t *testing.T) { port := tcp.PickPort() - lctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{ ProtocolName: "http", ProtocolSettings: &Config{}, SecurityType: "tls", SecuritySettings: &tls.Config{ Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.v2ray.com")))}, }, - }) - - listener, err := Listen(lctx, net.LocalHostIP, port, func(conn internet.Connection) { + }, func(conn internet.Connection) { go func() { defer conn.Close() @@ -54,7 +52,8 @@ func TestHTTPConnection(t *testing.T) { time.Sleep(time.Second) - dctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + dctx := context.Background() + conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{ ProtocolName: "http", ProtocolSettings: &Config{}, SecurityType: "tls", @@ -63,7 +62,6 @@ func TestHTTPConnection(t *testing.T) { AllowInsecure: true, }, }) - conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port)) assert(err, IsNil) defer conn.Close() diff --git a/transport/internet/http/hub.go b/transport/internet/http/hub.go index c70458693f6..665c5f2ca2f 100644 --- a/transport/internet/http/hub.go +++ b/transport/internet/http/hub.go @@ -88,13 +88,8 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request) <-done.Wait() } -func Listen(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) { - rawSettings := internet.StreamSettingsFromContext(ctx) - httpSettings, ok := rawSettings.ProtocolSettings.(*Config) - if !ok { - return nil, newError("HTTP config is not set.").AtError() - } - +func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { + httpSettings := streamSettings.ProtocolSettings.(*Config) listener := &Listener{ handler: handler, local: &net.TCPAddr{ @@ -104,7 +99,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, handler int config: *httpSettings, } - config := tls.ConfigFromContext(ctx) + config := tls.ConfigFromStreamSettings(streamSettings) if config == nil { return nil, newError("TLS must be enabled for http transport.").AtWarning() } @@ -120,7 +115,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, handler int tcpListener, err := internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), - }) + }, streamSettings.SocketSettings) if err != nil { newError("failed to listen on", address, ":", port).Base(err).WriteToLog(session.ExportIDToError(ctx)) return diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 5d1547996c0..03df1fbe952 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -45,16 +45,16 @@ func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn } } -func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, error) { +func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { dest.Network = net.Network_UDP newError("dialing mKCP to ", dest).WriteToLog() - rawConn, err := internet.DialSystem(ctx, dest) + rawConn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) if err != nil { return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err) } - kcpSettings := internet.StreamSettingsFromContext(ctx).ProtocolSettings.(*Config) + kcpSettings := streamSettings.ProtocolSettings.(*Config) header, err := kcpSettings.GetPackerHeader() if err != nil { @@ -85,7 +85,7 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er var iConn internet.Connection = session - if config := v2tls.ConfigFromContext(ctx); config != nil { + if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { tlsConn := tls.Client(iConn, config.GetTLSConfig(v2tls.WithDestination(dest))) iConn = tlsConn } diff --git a/transport/internet/kcp/kcp_test.go b/transport/internet/kcp/kcp_test.go index 4e512b471ac..e655299d156 100644 --- a/transport/internet/kcp/kcp_test.go +++ b/transport/internet/kcp/kcp_test.go @@ -17,11 +17,10 @@ import ( func TestDialAndListen(t *testing.T) { assert := With(t) - lctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + listerner, err := NewListener(context.Background(), net.LocalHostIP, net.Port(0), &internet.MemoryStreamConfig{ ProtocolName: "mkcp", ProtocolSettings: &Config{}, - }) - listerner, err := NewListener(lctx, net.LocalHostIP, net.Port(0), func(conn internet.Connection) { + }, func(conn internet.Connection) { go func(c internet.Connection) { payload := make([]byte, 4096) for { @@ -40,13 +39,12 @@ func TestDialAndListen(t *testing.T) { assert(err, IsNil) port := net.Port(listerner.Addr().(*net.UDPAddr).Port) - ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ - ProtocolName: "mkcp", - ProtocolSettings: &Config{}, - }) wg := new(sync.WaitGroup) for i := 0; i < 10; i++ { - clientConn, err := DialKCP(ctx, net.UDPDestination(net.LocalHostIP, port)) + clientConn, err := DialKCP(context.Background(), net.UDPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{ + ProtocolName: "mkcp", + ProtocolSettings: &Config{}, + }) assert(err, IsNil) wg.Add(1) diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index e9c22c8f646..2416979fcff 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -33,10 +33,8 @@ type Listener struct { addConn internet.ConnHandler } -func NewListener(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (*Listener, error) { - networkSettings := internet.StreamSettingsFromContext(ctx) - kcpSettings := networkSettings.ProtocolSettings.(*Config) - +func NewListener(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (*Listener, error) { + kcpSettings := streamSettings.ProtocolSettings.(*Config) header, err := kcpSettings.GetPackerHeader() if err != nil { return nil, newError("failed to create packet header").Base(err).AtError() @@ -57,11 +55,11 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon addConn: addConn, } - if config := v2tls.ConfigFromContext(ctx); config != nil { + if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { l.tlsConfig = config.GetTLSConfig() } - hub, err := udp.ListenUDP(ctx, address, port, udp.HubCapacity(1024)) + hub, err := udp.ListenUDP(ctx, address, port, streamSettings, udp.HubCapacity(1024)) if err != nil { return nil, err } @@ -189,8 +187,8 @@ func (w *Writer) Close() error { return nil } -func ListenKCP(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (internet.Listener, error) { - return NewListener(ctx, address, port, addConn) +func ListenKCP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { + return NewListener(ctx, address, port, streamSettings, addConn) } func init() { diff --git a/transport/internet/sockopt_darwin.go b/transport/internet/sockopt_darwin.go index 150692e016d..ccac238d978 100644 --- a/transport/internet/sockopt_darwin.go +++ b/transport/internet/sockopt_darwin.go @@ -2,8 +2,6 @@ package internet import ( "syscall" - - "v2ray.com/core/common/net" ) const ( @@ -49,6 +47,6 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig) return nil } -func bindAddr(fd uintptr, address net.Address, port net.Port) error { +func bindAddr(fd uintptr, address []byte, port uint32) error { return nil } diff --git a/transport/internet/sockopt_linux.go b/transport/internet/sockopt_linux.go index b3f81b80138..6c5afddd030 100644 --- a/transport/internet/sockopt_linux.go +++ b/transport/internet/sockopt_linux.go @@ -1,9 +1,8 @@ package internet import ( + "net" "syscall" - - "v2ray.com/core/common/net" ) const ( @@ -13,7 +12,7 @@ const ( TCP_FASTOPEN_CONNECT = 30 ) -func bindAddr(fd uintptr, address net.Address, port net.Port) error { +func bindAddr(fd uintptr, ip []byte, port uint32) error { err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) if err != nil { return newError("failed to set resuse_addr").Base(err).AtWarning() @@ -21,21 +20,21 @@ func bindAddr(fd uintptr, address net.Address, port net.Port) error { var sockaddr syscall.Sockaddr - switch address.Family() { - case net.AddressFamilyIPv4: + switch len(ip) { + case net.IPv4len: a4 := &syscall.SockaddrInet4{ Port: int(port), } - copy(a4.Addr[:], address.IP()) + copy(a4.Addr[:], ip) sockaddr = a4 - case net.AddressFamilyIPv6: + case net.IPv6len: a6 := &syscall.SockaddrInet6{ Port: int(port), } - copy(a6.Addr[:], address.IP()) + copy(a6.Addr[:], ip) sockaddr = a6 default: - return newError("unsupported address family: ", address.Family()) + return newError("unexpected length of ip") } return syscall.Bind(int(fd), sockaddr) diff --git a/transport/internet/sockopt_other.go b/transport/internet/sockopt_other.go index acdd6b4220d..eca55cc3a82 100644 --- a/transport/internet/sockopt_other.go +++ b/transport/internet/sockopt_other.go @@ -2,8 +2,6 @@ package internet -import "v2ray.com/core/common/net" - func applyOutboundSocketOptions(network string, address string, fd uintptr, config *SocketConfig) error { return nil } @@ -12,6 +10,6 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig) return nil } -func bindAddr(fd uintptr, address net.Address, port net.Port) error { +func bindAddr(fd uintptr, ip []byte, port uint32) error { return nil } diff --git a/transport/internet/sockopt_test.go b/transport/internet/sockopt_test.go index 2943277548f..11af684d558 100644 --- a/transport/internet/sockopt_test.go +++ b/transport/internet/sockopt_test.go @@ -17,22 +17,15 @@ func TestTCPFastOpen(t *testing.T) { return b }, } - dest, err := tcpServer.StartContext(ContextWithStreamSettings(context.Background(), &MemoryStreamConfig{ - SocketSettings: &SocketConfig{ - Tfo: SocketConfig_Enable, - }, - })) + dest, err := tcpServer.StartContext(context.Background(), &SocketConfig{Tfo: SocketConfig_Enable}) common.Must(err) defer tcpServer.Close() ctx := context.Background() - ctx = ContextWithStreamSettings(ctx, &MemoryStreamConfig{ - SocketSettings: &SocketConfig{ - Tfo: SocketConfig_Enable, - }, - }) dialer := DefaultSystemDialer{} - conn, err := dialer.Dial(ctx, nil, dest) + conn, err := dialer.Dial(ctx, nil, dest, &SocketConfig{ + Tfo: SocketConfig_Enable, + }) common.Must(err) defer conn.Close() diff --git a/transport/internet/sockopt_windows.go b/transport/internet/sockopt_windows.go index 28ba91f4eb5..f47aa07a7b6 100644 --- a/transport/internet/sockopt_windows.go +++ b/transport/internet/sockopt_windows.go @@ -2,8 +2,6 @@ package internet import ( "syscall" - - "v2ray.com/core/common/net" ) const ( @@ -45,6 +43,6 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig) return nil } -func bindAddr(fd uintptr, address net.Address, port net.Port) error { +func bindAddr(fd uintptr, ip []byte, port uint32) error { return nil } diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index bff0e7e168c..17af58e4ff6 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -14,38 +14,27 @@ var ( ) type SystemDialer interface { - Dial(ctx context.Context, source net.Address, destination net.Destination) (net.Conn, error) + Dial(ctx context.Context, source net.Address, destination net.Destination, sockopt *SocketConfig) (net.Conn, error) } type DefaultSystemDialer struct { } -func getSocketSettings(ctx context.Context) *SocketConfig { - streamSettings := StreamSettingsFromContext(ctx) - if streamSettings != nil && streamSettings.SocketSettings != nil { - return streamSettings.SocketSettings - } - - return nil -} - -func (DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination) (net.Conn, error) { +func (DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { dialer := &net.Dialer{ Timeout: time.Second * 60, DualStack: true, } - sockopts := getSocketSettings(ctx) - if sockopts != nil { - bindAddress := BindAddressFromContext(ctx) + if sockopt != nil { dialer.Control = func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { - if err := applyOutboundSocketOptions(network, address, fd, sockopts); err != nil { + if err := applyOutboundSocketOptions(network, address, fd, sockopt); err != nil { newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx)) } - if dest.Network == net.Network_UDP && bindAddress.IsValid() { - if err := bindAddr(fd, bindAddress.Address, bindAddress.Port); err != nil { - newError("failed to bind source address to ", bindAddress).Base(err).WriteToLog(session.ExportIDToError(ctx)) + if dest.Network == net.Network_UDP && len(sockopt.BindAddress) > 0 && sockopt.BindPort > 0 { + if err := bindAddr(fd, sockopt.BindAddress, sockopt.BindPort); err != nil { + newError("failed to bind source address to ", sockopt.BindAddress).Base(err).WriteToLog(session.ExportIDToError(ctx)) } } }) @@ -84,7 +73,7 @@ func WithAdapter(dialer SystemDialerAdapter) SystemDialer { } } -func (v *SimpleSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination) (net.Conn, error) { +func (v *SimpleSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr()) } diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index 9c894072a32..364d2d401fd 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -14,10 +14,9 @@ var ( type DefaultListener struct{} -func (*DefaultListener) Listen(ctx context.Context, addr net.Addr) (net.Listener, error) { +func (*DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) { var lc net.ListenConfig - sockopt := getSocketSettings(ctx) if sockopt != nil { lc.Control = func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { @@ -31,10 +30,9 @@ func (*DefaultListener) Listen(ctx context.Context, addr net.Addr) (net.Listener return lc.Listen(ctx, addr.Network(), addr.String()) } -func (*DefaultListener) ListenPacket(ctx context.Context, addr net.Addr) (net.PacketConn, error) { +func (*DefaultListener) ListenPacket(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.PacketConn, error) { var lc net.ListenConfig - sockopt := getSocketSettings(ctx) if sockopt != nil { lc.Control = func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index c445dbf5cb4..f29fc1e4cb7 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -10,28 +10,20 @@ import ( "v2ray.com/core/transport/internet/tls" ) -func getTCPSettingsFromContext(ctx context.Context) *Config { - rawTCPSettings := internet.StreamSettingsFromContext(ctx) - if rawTCPSettings == nil { - return nil - } - return rawTCPSettings.ProtocolSettings.(*Config) -} - // Dial dials a new TCP connection to the given destination. -func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { newError("dialing TCP to ", dest).WriteToLog(session.ExportIDToError(ctx)) - conn, err := internet.DialSystem(ctx, dest) + conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) if err != nil { return nil, err } - if config := tls.ConfigFromContext(ctx); config != nil { + if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { conn = tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2"))) } - tcpSettings := getTCPSettingsFromContext(ctx) - if tcpSettings != nil && tcpSettings.HeaderSettings != nil { + tcpSettings := streamSettings.ProtocolSettings.(*Config) + if tcpSettings.HeaderSettings != nil { headerConfig, err := tcpSettings.HeaderSettings.GetInstance() if err != nil { return nil, newError("failed to get header settings").Base(err).AtError() diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 0221803f24b..1df608414ad 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -22,25 +22,24 @@ type Listener struct { } // ListenTCP creates a new Listener based on configurations. -func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) { +func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), - }) + }, streamSettings.SocketSettings) if err != nil { return nil, err } newError("listening TCP on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) - tcpSettings := getTCPSettingsFromContext(ctx) - + tcpSettings := streamSettings.ProtocolSettings.(*Config) l := &Listener{ listener: listener, config: tcpSettings, addConn: handler, } - if config := tls.ConfigFromContext(ctx); config != nil { + if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { l.tlsConfig = config.GetTLSConfig(tls.WithNextProto("h2")) } diff --git a/transport/internet/tcp_hub.go b/transport/internet/tcp_hub.go index 598a8034f2d..888dec82490 100644 --- a/transport/internet/tcp_hub.go +++ b/transport/internet/tcp_hub.go @@ -20,22 +20,20 @@ func RegisterTransportListener(protocol string, listener ListenFunc) error { type ConnHandler func(Connection) -type ListenFunc func(ctx context.Context, address net.Address, port net.Port, handler ConnHandler) (Listener, error) +type ListenFunc func(ctx context.Context, address net.Address, port net.Port, settings *MemoryStreamConfig, handler ConnHandler) (Listener, error) type Listener interface { Close() error Addr() net.Addr } -func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler ConnHandler) (Listener, error) { - settings := StreamSettingsFromContext(ctx) +func ListenTCP(ctx context.Context, address net.Address, port net.Port, settings *MemoryStreamConfig, handler ConnHandler) (Listener, error) { if settings == nil { s, err := ToMemoryStreamConfig(nil) if err != nil { return nil, newError("failed to create default stream settings").Base(err) } settings = s - ctx = ContextWithStreamSettings(ctx, settings) } if address.Family().IsDomain() && address.Domain() == "localhost" { @@ -47,17 +45,17 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler if listenFunc == nil { return nil, newError(protocol, " listener not registered.").AtError() } - listener, err := listenFunc(ctx, address, port, handler) + listener, err := listenFunc(ctx, address, port, settings, handler) if err != nil { return nil, newError("failed to listen on address: ", address, ":", port).Base(err) } return listener, nil } -func ListenSystem(ctx context.Context, addr net.Addr) (net.Listener, error) { - return effectiveListener.Listen(ctx, addr) +func ListenSystem(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) { + return effectiveListener.Listen(ctx, addr, sockopt) } -func ListenSystemPacket(ctx context.Context, addr net.Addr) (net.PacketConn, error) { - return effectiveListener.ListenPacket(ctx, addr) +func ListenSystemPacket(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.PacketConn, error) { + return effectiveListener.ListenPacket(ctx, addr, sockopt) } diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 9cf08a77615..8125476f499 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -1,7 +1,6 @@ package tls import ( - "context" "crypto/tls" "crypto/x509" "sync" @@ -215,13 +214,12 @@ func WithNextProto(protocol ...string) Option { } } -// ConfigFromContext fetches Config from context. Nil if not found. -func ConfigFromContext(ctx context.Context) *Config { - streamSettings := internet.StreamSettingsFromContext(ctx) - if streamSettings == nil { +// ConfigFromStreamSettings fetches Config from stream settings. Nil if not found. +func ConfigFromStreamSettings(settings *internet.MemoryStreamConfig) *Config { + if settings == nil { return nil } - config, ok := streamSettings.SecuritySettings.(*Config) + config, ok := settings.SecuritySettings.(*Config) if !ok { return nil } diff --git a/transport/internet/udp/dialer.go b/transport/internet/udp/dialer.go index 0fdb3bdc71c..5ef5a4c0ba5 100644 --- a/transport/internet/udp/dialer.go +++ b/transport/internet/udp/dialer.go @@ -10,8 +10,12 @@ import ( func init() { common.Must(internet.RegisterTransportDialer(protocolName, - func(ctx context.Context, dest net.Destination) (internet.Connection, error) { - conn, err := internet.DialSystem(ctx, dest) + func(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { + var sockopt *internet.SocketConfig + if streamSettings != nil { + sockopt = streamSettings.SocketSettings + } + conn, err := internet.DialSystem(ctx, dest, sockopt) if err != nil { return nil, err } diff --git a/transport/internet/udp/hub.go b/transport/internet/udp/hub.go index f62f598e792..676d99d5aea 100644 --- a/transport/internet/udp/hub.go +++ b/transport/internet/udp/hub.go @@ -36,7 +36,7 @@ type Hub struct { recvOrigDest bool } -func ListenUDP(ctx context.Context, address net.Address, port net.Port, options ...HubOption) (*Hub, error) { +func ListenUDP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, options ...HubOption) (*Hub, error) { hub := &Hub{ capacity: 256, recvOrigDest: false, @@ -45,15 +45,18 @@ func ListenUDP(ctx context.Context, address net.Address, port net.Port, options opt(hub) } - streamSettings := internet.StreamSettingsFromContext(ctx) - if streamSettings != nil && streamSettings.SocketSettings != nil && streamSettings.SocketSettings.ReceiveOriginalDestAddress { + var sockopt *internet.SocketConfig + if streamSettings != nil { + sockopt = streamSettings.SocketSettings + } + if sockopt != nil && sockopt.ReceiveOriginalDestAddress { hub.recvOrigDest = true } udpConn, err := internet.ListenSystemPacket(ctx, &net.UDPAddr{ IP: address.IP(), Port: int(port), - }) + }, sockopt) if err != nil { return nil, err } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 0cb2d06908f..fc2f7e6caed 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -14,10 +14,10 @@ import ( ) // Dial dials a WebSocket connection to the given destination. -func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) { +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) - conn, err := dialWebsocket(ctx, dest) + conn, err := dialWebsocket(ctx, dest, streamSettings) if err != nil { return nil, newError("failed to dial WebSocket").Base(err) } @@ -28,12 +28,12 @@ func init() { common.Must(internet.RegisterTransportDialer(protocolName, Dial)) } -func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) { - wsSettings := internet.StreamSettingsFromContext(ctx).ProtocolSettings.(*Config) +func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) { + wsSettings := streamSettings.ProtocolSettings.(*Config) dialer := &websocket.Dialer{ NetDial: func(network, addr string) (net.Conn, error) { - return internet.DialSystem(ctx, dest) + return internet.DialSystem(ctx, dest, streamSettings.SocketSettings) }, ReadBufferSize: 4 * 1024, WriteBufferSize: 4 * 1024, @@ -42,7 +42,7 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) protocol := "ws" - if config := tls.ConfigFromContext(ctx); config != nil { + if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { protocol = "wss" dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest)) } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 5dcb12364df..762c960bcfb 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -55,16 +55,15 @@ type Listener struct { addConn internet.ConnHandler } -func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (internet.Listener, error) { - networkSettings := internet.StreamSettingsFromContext(ctx) - wsSettings := networkSettings.ProtocolSettings.(*Config) +func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { + wsSettings := streamSettings.ProtocolSettings.(*Config) var tlsConfig *tls.Config - if config := v2tls.ConfigFromContext(ctx); config != nil { + if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { tlsConfig = config.GetTLSConfig() } - listener, err := listenTCP(ctx, address, port, tlsConfig) + listener, err := listenTCP(ctx, address, port, tlsConfig, streamSettings.SocketSettings) if err != nil { return nil, err } @@ -84,11 +83,11 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn i return l, err } -func listenTCP(ctx context.Context, address net.Address, port net.Port, tlsConfig *tls.Config) (net.Listener, error) { +func listenTCP(ctx context.Context, address net.Address, port net.Port, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (net.Listener, error) { listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), - }) + }, sockopt) if err != nil { return nil, newError("failed to listen TCP on", address, ":", port).Base(err) } diff --git a/transport/internet/websocket/ws_test.go b/transport/internet/websocket/ws_test.go index 1a4b406edb0..f0793f80f07 100644 --- a/transport/internet/websocket/ws_test.go +++ b/transport/internet/websocket/ws_test.go @@ -18,13 +18,12 @@ import ( func Test_listenWSAndDial(t *testing.T) { assert := With(t) - lctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + listen, err := ListenWS(context.Background(), net.LocalHostIP, 13146, &internet.MemoryStreamConfig{ ProtocolName: "websocket", ProtocolSettings: &Config{ Path: "ws", }, - }) - listen, err := ListenWS(lctx, net.LocalHostIP, 13146, func(conn internet.Connection) { + }, func(conn internet.Connection) { go func(c internet.Connection) { defer c.Close() @@ -42,11 +41,12 @@ func Test_listenWSAndDial(t *testing.T) { }) assert(err, IsNil) - ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + ctx := context.Background() + streamSettings := &internet.MemoryStreamConfig{ ProtocolName: "websocket", ProtocolSettings: &Config{Path: "ws"}, - }) - conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146)) + } + conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings) assert(err, IsNil) _, err = conn.Write([]byte("Test connection 1")) @@ -59,7 +59,7 @@ func Test_listenWSAndDial(t *testing.T) { assert(conn.Close(), IsNil) <-time.After(time.Second * 5) - conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146)) + conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings) assert(err, IsNil) _, err = conn.Write([]byte("Test connection 2")) assert(err, IsNil) @@ -73,13 +73,12 @@ func Test_listenWSAndDial(t *testing.T) { func TestDialWithRemoteAddr(t *testing.T) { assert := With(t) - lctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + listen, err := ListenWS(context.Background(), net.LocalHostIP, 13148, &internet.MemoryStreamConfig{ ProtocolName: "websocket", ProtocolSettings: &Config{ Path: "ws", }, - }) - listen, err := ListenWS(lctx, net.LocalHostIP, 13148, func(conn internet.Connection) { + }, func(conn internet.Connection) { go func(c internet.Connection) { defer c.Close() @@ -99,11 +98,10 @@ func TestDialWithRemoteAddr(t *testing.T) { }) assert(err, IsNil) - ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13148), &internet.MemoryStreamConfig{ ProtocolName: "websocket", ProtocolSettings: &Config{Path: "ws", Header: []*Header{{Key: "X-Forwarded-For", Value: "1.1.1.1"}}}, }) - conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13148)) assert(err, IsNil) _, err = conn.Write([]byte("Test connection 1")) @@ -126,7 +124,7 @@ func Test_listenWSAndDial_TLS(t *testing.T) { start := time.Now() - ctx := internet.ContextWithStreamSettings(context.Background(), &internet.MemoryStreamConfig{ + streamSettings := &internet.MemoryStreamConfig{ ProtocolName: "websocket", ProtocolSettings: &Config{ Path: "wss", @@ -136,9 +134,8 @@ func Test_listenWSAndDial_TLS(t *testing.T) { AllowInsecure: true, Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))}, }, - }) - - listen, err := ListenWS(ctx, net.LocalHostIP, 13143, func(conn internet.Connection) { + } + listen, err := ListenWS(context.Background(), net.LocalHostIP, 13143, streamSettings, func(conn internet.Connection) { go func() { _ = conn.Close() }() @@ -146,7 +143,7 @@ func Test_listenWSAndDial_TLS(t *testing.T) { assert(err, IsNil) defer listen.Close() - conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13143)) + conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13143), streamSettings) assert(err, IsNil) _ = conn.Close()