diff --git a/cmd/dev.go b/cmd/dev.go index 37d171cb..4ddc2f1b 100644 --- a/cmd/dev.go +++ b/cmd/dev.go @@ -126,24 +126,35 @@ Examples: log.Fatal("failed to find available port: %s", err) } + connectProxyPort, _ := cmd.Flags().GetInt("proxy-port") + if connectProxyPort < 0 || connectProxyPort > 65535 { + log.Fatal("invalid --proxy-port: %d (must be 0-65535)", connectProxyPort) + } + var connectProxyPortPtr *uint + if connectProxyPort > 0 { + port := uint(connectProxyPort) + connectProxyPortPtr = &port + } + server, err := dev.New(dev.ServerArgs{ APIURL: apiUrl, APIKey: apiKey, Hostname: endpoint.Hostname, Config: &gravity.Config{ - Context: ctx, - Logger: log, - Version: Version, - OrgID: orgId, - Project: theproject, - EndpointID: endpoint.ID, - URL: gravityUrl, - SDKKey: project.Secrets["AGENTUITY_SDK_KEY"], - ProxyPort: uint(proxyPort), - AgentPort: uint(agentPort), - Ephemeral: true, - ClientName: "cli/devmode", - DynamicHostname: true, + Context: ctx, + Logger: log, + Version: Version, + OrgID: orgId, + Project: theproject, + EndpointID: endpoint.ID, + URL: gravityUrl, + SDKKey: project.Secrets["AGENTUITY_SDK_KEY"], + ProxyPort: uint(proxyPort), + AgentPort: uint(agentPort), + ConnectProxyPort: connectProxyPortPtr, + Ephemeral: true, + ClientName: "cli/devmode", + DynamicHostname: true, }, }) if err != nil { @@ -308,6 +319,7 @@ func init() { rootCmd.AddCommand(devCmd) devCmd.Flags().StringP("dir", "d", ".", "The directory to run the development server in") devCmd.Flags().Int("port", 0, "The port to run the development server on (uses project default if not provided)") + devCmd.Flags().Int("proxy-port", 19081, "The port to run the HTTP CONNECT proxy server on (disabled if zero)") devCmd.Flags().Bool("no-build", false, "Do not build the project before running it (useful for debugging)") devCmd.Flags().MarkHidden("no-build") } diff --git a/internal/gravity/gravity.go b/internal/gravity/gravity.go index 8b53d2d5..87b8671a 100644 --- a/internal/gravity/gravity.go +++ b/internal/gravity/gravity.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -36,67 +37,71 @@ import ( const ( nicID = 1 - mtu = 1500 + mtu = 1280 // Reduced to IPv6 minimum to avoid MTU blackhole issues ) type Client struct { - context context.Context - logger logger.Logger - version string - orgID string - projectID string - project project.ProjectContext - endpointID string - url string - sdkKey string - proxyPort uint - agentPort uint - ephemeral bool - clientname string - dynamicHostname bool - dynamicProject bool - server *http.Server - client *gravity.GravityClient - once sync.Once - stack *stack.Stack - endpoint *channel.Endpoint - provider *cliProvider + context context.Context + logger logger.Logger + version string + orgID string + projectID string + project project.ProjectContext + endpointID string + url string + sdkKey string + proxyPort uint + agentPort uint + connectProxyPort *uint + ephemeral bool + clientname string + dynamicHostname bool + dynamicProject bool + server *http.Server + connectProxy *http.Server + client *gravity.GravityClient + once sync.Once + stack *stack.Stack + endpoint *channel.Endpoint + provider *cliProvider } type Config struct { - Context context.Context - Logger logger.Logger - Version string // of the cli - OrgID string - Project project.ProjectContext - EndpointID string - URL string - SDKKey string - ProxyPort uint - AgentPort uint - Ephemeral bool - ClientName string - DynamicHostname bool - DynamicProject bool + Context context.Context + Logger logger.Logger + Version string // of the cli + OrgID string + Project project.ProjectContext + EndpointID string + URL string + SDKKey string + ProxyPort uint + AgentPort uint + ConnectProxyPort *uint + Ephemeral bool + ClientName string + DynamicHostname bool + DynamicProject bool } func New(config Config) *Client { return &Client{ - context: config.Context, - logger: config.Logger, - version: config.Version, - orgID: config.OrgID, - projectID: config.Project.Project.ProjectId, - project: config.Project, - endpointID: config.EndpointID, - url: config.URL, - sdkKey: config.SDKKey, - ephemeral: config.Ephemeral, - proxyPort: config.ProxyPort, - agentPort: config.AgentPort, - clientname: config.ClientName, - dynamicHostname: config.DynamicHostname, - dynamicProject: config.DynamicProject, + context: config.Context, + logger: config.Logger, + version: config.Version, + orgID: config.OrgID, + projectID: config.Project.Project.ProjectId, + project: config.Project, + endpointID: config.EndpointID, + url: config.URL, + sdkKey: config.SDKKey, + ephemeral: config.Ephemeral, + proxyPort: config.ProxyPort, + agentPort: config.AgentPort, + connectProxyPort: config.ConnectProxyPort, + clientname: config.ClientName, + dynamicHostname: config.DynamicHostname, + dynamicProject: config.DynamicProject, } } @@ -162,6 +167,198 @@ func (c *Client) bridgeToLocalTLS(remote *gonet.TCPConn) { logger.Trace("bridgeToLocalTLS: local server -> netstack finished (copied %d bytes, err: %v)", n, err) } +// handleConnect handles HTTP CONNECT requests and bridges them through the netstack. +func (c *Client) handleConnect(w http.ResponseWriter, r *http.Request) { + logger := c.logger + logger.Trace("CONNECT request: %s", r.Host) + + if r.Method != http.MethodConnect { + logger.Debug("non-CONNECT request rejected: %s %s", r.Method, r.URL.Path) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse host and port + host, portStr, err := net.SplitHostPort(r.Host) + if err != nil { + logger.Error("invalid CONNECT target: %s (%v)", r.Host, err) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + port, err := net.LookupPort("tcp", portStr) + if err != nil { + logger.Error("invalid port in CONNECT target: %s (%v)", portStr, err) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Check if this is an agentuity domain + agentuityDomains := []string{".agentuity.io", ".agentuity.cloud", ".agentuity.run", ".agentuity.com", "agentuity.ai"} + isAgentuityDomain := false + for _, domain := range agentuityDomains { + if strings.HasSuffix(host, domain) || host == domain[1:] { + isAgentuityDomain = true + break + } + } + + var remoteConn net.Conn + + if isAgentuityDomain { + // Route through gravity tunnel + logger.Trace("CONNECT to %s via gravity tunnel", host) + + ip, ok := cnet.Addresses["catalyst"] + if !ok || ip == nil { + logger.Error("catalyst address not found in address map") + http.Error(w, "Service configuration error", http.StatusServiceUnavailable) + return + } + + if strings.HasSuffix(host, ".agentuity.cloud") { + part := strings.Split(host, ".agentuity.cloud")[0] + if customip, ok := cnet.Addresses[part]; ok && customip != nil { + ip = customip + } + } else if strings.HasSuffix(host, ".agentuity.io") { + part := strings.Split(host, ".agentuity.io")[0] + if customip, ok := cnet.Addresses[part]; ok && customip != nil { + ip = customip + } + } + + var protocolNumber tcpip.NetworkProtocolNumber + var addr tcpip.Address + + if ip4 := ip.To4(); ip4 != nil { + protocolNumber = ipv4.ProtocolNumber + var addr4 [4]byte + copy(addr4[:], ip4) + addr = tcpip.AddrFrom4(addr4) + } else { + protocolNumber = ipv6.ProtocolNumber + ip6 := ip.To16() + var addr6 [16]byte + copy(addr6[:], ip6) + addr = tcpip.AddrFrom16(addr6) + } + + // Dial through netstack + fullAddr := tcpip.FullAddress{ + Addr: addr, + Port: uint16(port), + } + + remoteConn, err = gonet.DialTCP(c.stack, fullAddr, protocolNumber) + if err != nil { + logger.Error("failed to dial %s: %v", host, err) + http.Error(w, "Connection failed", http.StatusBadGateway) + return + } + } else { + // Direct connection for non-agentuity domains + logger.Trace("CONNECT to %s directly", host) + remoteConn, err = net.Dial("tcp", net.JoinHostPort(host, portStr)) + if err != nil { + logger.Error("failed to dial %s: %v", host, err) + http.Error(w, "Connection failed", http.StatusBadGateway) + return + } + } + + // Hijack the client connection + hijacker, ok := w.(http.Hijacker) + if !ok { + logger.Error("ResponseWriter does not support hijacking") + remoteConn.Close() + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + clientConn, _, err := hijacker.Hijack() + if err != nil { + logger.Error("failed to hijack connection: %v", err) + remoteConn.Close() + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Send success response with dynamic HTTP version + response := fmt.Sprintf("%s 200 Connection Established\r\n\r\n", r.Proto) + if _, err := clientConn.Write([]byte(response)); err != nil { + logger.Error("failed to send 200 response: %v", err) + clientConn.Close() + remoteConn.Close() + return + } + + logger.Debug("proxying CONNECT for %s", r.Host) + + // Bidirectional copy + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + n, err := io.Copy(remoteConn, clientConn) + if err != nil { + logger.Trace("client -> remote error: %v", err) + } + logger.Trace("client -> remote: %d bytes", n) + // Close write side if possible + if tcpConn, ok := remoteConn.(interface{ CloseWrite() error }); ok { + tcpConn.CloseWrite() + } + }() + + go func() { + defer wg.Done() + n, err := io.Copy(clientConn, remoteConn) + if err != nil { + logger.Trace("remote -> client error: %v", err) + } + logger.Trace("remote -> client: %d bytes", n) + }() + + wg.Wait() + + clientConn.Close() + remoteConn.Close() + logger.Debug("CONNECT session completed for %s", r.Host) +} + +// startConnectProxy starts the HTTP CONNECT proxy server if configured. +func (c *Client) startConnectProxy() error { + if c.connectProxyPort == nil { + return nil + } + + logger := c.logger + port := *c.connectProxyPort + + logger.Debug("starting CONNECT proxy on port %d", port) + + server := &http.Server{ + Addr: fmt.Sprintf("127.0.0.1:%d", port), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.handleConnect(w, r) + }), + ReadTimeout: 0, + WriteTimeout: 0, + } + c.connectProxy = server + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Fatal("CONNECT proxy server failed: %v", err) + } + }() + + logger.Debug("CONNECT proxy listening on http://127.0.0.1:%d", port) + return nil +} + // Close will close the client and all the associated services. func (c *Client) Close() error { var err error @@ -186,6 +383,14 @@ func (c *Client) cleanup() error { err = c.server.Shutdown(ctx) c.server = nil } + if c.connectProxy != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if shutdownErr := c.connectProxy.Shutdown(ctx); shutdownErr != nil && err == nil { + err = shutdownErr + } + c.connectProxy = nil + } if c.endpoint != nil { c.endpoint.Close() c.endpoint = nil @@ -452,7 +657,7 @@ func (c *Client) Start() error { // Create netstack. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) c.stack = s @@ -463,30 +668,57 @@ func (c *Client) Start() error { return fmt.Errorf("failed to create virtual NIC: %s", err) } c.endpoint = linkEP + + // Add IPv6 address ipBytes := net.ParseIP(ipv6Address.String()).To16() - var addr [16]byte - copy(addr[:], ipBytes) + var addr6 [16]byte + copy(addr6[:], ipBytes) if err := s.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.AddrFrom16(addr), + Address: tcpip.AddrFrom16(addr6), PrefixLen: 64, }, }, stack.AddressProperties{}, ); err != nil { - return fmt.Errorf("failed to create protocol address: %s", err) + return fmt.Errorf("failed to create IPv6 protocol address: %s", err) } - // Add default route - subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 16)), tcpip.MaskFromBytes(make([]byte, 16))) + // Add IPv4 address + ipv4Bytes := net.ParseIP(ipv4addr).To4() + var addr4 [4]byte + copy(addr4[:], ipv4Bytes) + if err := s.AddProtocolAddress(nicID, + tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFrom4(addr4), + PrefixLen: 24, + }, + }, + stack.AddressProperties{}, + ); err != nil { + return fmt.Errorf("failed to create IPv4 protocol address: %s", err) + } + + // Add default routes for both IPv4 and IPv6 + subnet4, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 4)), tcpip.MaskFromBytes(make([]byte, 4))) + if err != nil { + return fmt.Errorf("failed to create IPv4 subnet: %w", err) + } + subnet6, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 16)), tcpip.MaskFromBytes(make([]byte, 16))) if err != nil { - return fmt.Errorf("failed to create subnet: %w", err) + return fmt.Errorf("failed to create IPv6 subnet: %w", err) } s.SetRouteTable([]tcpip.Route{ { - Destination: subnet, + Destination: subnet4, + NIC: nicID, + }, + { + Destination: subnet6, NIC: nicID, }, }) @@ -594,6 +826,11 @@ func (c *Client) Start() error { break } + // Start CONNECT proxy if configured + if err := c.startConnectProxy(); err != nil { + return fmt.Errorf("failed to start CONNECT proxy: %w", err) + } + go func() { log.Debug("waiting on provider disconnect") client.Disconnected(c.context) diff --git a/internal/gravity/provider.go b/internal/gravity/provider.go index b3450ff3..83df128c 100644 --- a/internal/gravity/provider.go +++ b/internal/gravity/provider.go @@ -7,7 +7,10 @@ import ( "github.com/agentuity/go-common/gravity/provider" "github.com/agentuity/go-common/logger" "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -88,9 +91,38 @@ func (p *cliProvider) ProcessInPacket(payload []byte) { if p.ep == nil { return } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{}) - view := buffer.NewView(len(payload)) - view.Write(payload) - pkt.Data().AppendView(view) - p.ep.InjectInbound(ipv6.ProtocolNumber, pkt) + + if len(payload) < 1 { + return + } + + // Detect IP version from the packet header + version := header.IPVersion(payload) + var protocol tcpip.NetworkProtocolNumber + + switch version { + case 4: + if len(payload) < header.IPv4MinimumSize { + p.logger.Trace("dropping IPv4 packet: too short (%d bytes, need at least %d)", len(payload), header.IPv4MinimumSize) + return + } + protocol = ipv4.ProtocolNumber + case 6: + if len(payload) < header.IPv6MinimumSize { + p.logger.Trace("dropping IPv6 packet: too short (%d bytes, need at least %d)", len(payload), header.IPv6MinimumSize) + return + } + protocol = ipv6.ProtocolNumber + default: + p.logger.Trace("dropping packet: unknown IP version %d", version) + return + } + + // Create packet buffer with proper payload and cleanup + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(append([]byte(nil), payload...)), + }) + defer pkt.DecRef() + + p.ep.InjectInbound(protocol, pkt) }