Skip to content

Commit

Permalink
Merge pull request #836 from afbjorklund/reverse
Browse files Browse the repository at this point in the history
Enable reverse forwarding of portForwards
  • Loading branch information
AkihiroSuda authored May 10, 2022
2 parents 051fc82 + 3ab90da commit 9391621
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 14 deletions.
2 changes: 2 additions & 0 deletions examples/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,10 @@ networks:
#
# - guestSocket: "/run/user/{{.UID}}/my.sock"
# hostSocket: mysocket
# # default: reverse: false
# # "guestSocket" can include these template variables: {{.Home}}, {{.UID}}, and {{.User}}.
# # "hostSocket" can include {{.Home}}, {{.Dir}}, {{.Name}}, {{.UID}}, and {{.User}}.
# # "reverse" can only be used for unix sockets right now, not for tcp sockets.
# # Put sockets into "{{.Dir}}/sock" to avoid collision with Lima internal sockets!
# # Sockets can also be forwarded to ports and vice versa, but not to/from a range of ports.
# # Forwarding requires the lima user to have rw access to the "guestsocket",
Expand Down
34 changes: 26 additions & 8 deletions pkg/hostagent/hostagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) {
for _, rule := range a.y.PortForwards {
if rule.GuestSocket != "" {
local := hostAddress(rule, guestagentapi.IPPort{})
_ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, local, rule.GuestSocket, verbForward)
_ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, local, rule.GuestSocket, verbForward, rule.Reverse)
}
}

Expand All @@ -437,20 +437,20 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) {
if rule.GuestSocket != "" {
local := hostAddress(rule, guestagentapi.IPPort{})
// using ctx.Background() because ctx has already been cancelled
if err := forwardSSH(context.Background(), a.sshConfig, a.sshLocalPort, local, rule.GuestSocket, verbCancel); err != nil {
if err := forwardSSH(context.Background(), a.sshConfig, a.sshLocalPort, local, rule.GuestSocket, verbCancel, rule.Reverse); err != nil {
mErr = multierror.Append(mErr, err)
}
}
}
if err := forwardSSH(context.Background(), a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbCancel); err != nil {
if err := forwardSSH(context.Background(), a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbCancel, false); err != nil {
mErr = multierror.Append(mErr, err)
}
return mErr
})

for {
if !isGuestAgentSocketAccessible(ctx, localUnix) {
_ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbForward)
_ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbForward, false)
}
if err := a.processGuestAgentEvents(ctx, localUnix); err != nil {
if !errors.Is(err, context.Canceled) {
Expand Down Expand Up @@ -506,12 +506,22 @@ const (
verbCancel = "cancel"
)

func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote string, verb string) error {
func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote string, verb string, reverse bool) error {
args := sshConfig.Args()
args = append(args,
"-T",
"-O", verb,
"-L", local+":"+remote,
)
if reverse {
args = append(args,
"-R", remote+":"+local,
)
} else {
args = append(args,
"-L", local+":"+remote,
)
}
args = append(args,
"-N",
"-f",
"-p", strconv.Itoa(port),
Expand All @@ -521,15 +531,23 @@ func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
if strings.HasPrefix(local, "/") {
switch verb {
case verbForward:
logrus.Infof("Forwarding %q (guest) to %q (host)", remote, local)
if reverse {
logrus.Infof("Forwarding %q (host) to %q (guest)", local, remote)
} else {
logrus.Infof("Forwarding %q (guest) to %q (host)", remote, local)
}
if err := os.RemoveAll(local); err != nil {
logrus.WithError(err).Warnf("Failed to clean up %q (host) before setting up forwarding", local)
}
if err := os.MkdirAll(filepath.Dir(local), 0750); err != nil {
return fmt.Errorf("can't create directory for local socket %q: %w", local, err)
}
case verbCancel:
logrus.Infof("Stopping forwarding %q (guest) to %q (host)", remote, local)
if reverse {
logrus.Infof("Stopping forwarding %q (host) to %q (guest)", local, remote)
} else {
logrus.Infof("Stopping forwarding %q (guest) to %q (host)", remote, local)
}
defer func() {
if err := os.RemoveAll(local); err != nil {
logrus.WithError(err).Warnf("Failed to clean up %q (host) after stopping forwarding", local)
Expand Down
10 changes: 5 additions & 5 deletions pkg/hostagent/port_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
// forwardTCP is not thread-safe
func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote string, verb string) error {
if strings.HasPrefix(local, "/") {
return forwardSSH(ctx, sshConfig, port, local, remote, verb)
return forwardSSH(ctx, sshConfig, port, local, remote, verb, false)
}
localIPStr, localPortStr, err := net.SplitHostPort(local)
if err != nil {
Expand All @@ -31,7 +31,7 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
}

if !localIP.Equal(api.IPv4loopback1) || localPort >= 1024 {
return forwardSSH(ctx, sshConfig, port, local, remote, verb)
return forwardSSH(ctx, sshConfig, port, local, remote, verb, false)
}

// on macOS, listening on 127.0.0.1:80 requires root while 0.0.0.0:80 does not require root.
Expand All @@ -46,7 +46,7 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
localUnix := plf.unixAddr.Name
_ = plf.Close()
delete(pseudoLoopbackForwarders, local)
if err := forwardSSH(ctx, sshConfig, port, localUnix, remote, verb); err != nil {
if err := forwardSSH(ctx, sshConfig, port, localUnix, remote, verb, false); err != nil {
return err
}
} else {
Expand All @@ -61,12 +61,12 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
}
localUnix := filepath.Join(localUnixDir, "sock")
logrus.Debugf("forwarding %q to %q", localUnix, remote)
if err := forwardSSH(ctx, sshConfig, port, localUnix, remote, verb); err != nil {
if err := forwardSSH(ctx, sshConfig, port, localUnix, remote, verb, false); err != nil {
return err
}
plf, err := newPseudoLoopbackForwarder(localPort, localUnix)
if err != nil {
if cancelErr := forwardSSH(ctx, sshConfig, port, localUnix, remote, verbCancel); cancelErr != nil {
if cancelErr := forwardSSH(ctx, sshConfig, port, localUnix, remote, verbCancel, false); cancelErr != nil {
logrus.WithError(cancelErr).Warnf("failed to cancel forwarding %q to %q", localUnix, remote)
}
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/hostagent/port_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ import (
)

func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote string, verb string) error {
return forwardSSH(ctx, sshConfig, port, local, remote, verb)
return forwardSSH(ctx, sshConfig, port, local, remote, verb, false)
}
1 change: 1 addition & 0 deletions pkg/limayaml/defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func TestFillDefault(t *testing.T) {
HostIP: api.IPv4loopback1,
HostPortRange: [2]int{1, 65535},
Proto: TCP,
Reverse: false,
}

// ------------------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions pkg/limayaml/limayaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ type PortForward struct {
HostPortRange [2]int `yaml:"hostPortRange,omitempty" json:"hostPortRange,omitempty"`
HostSocket string `yaml:"hostSocket,omitempty" json:"hostSocket,omitempty"`
Proto Proto `yaml:"proto,omitempty" json:"proto,omitempty"`
Reverse bool `yaml:"reverse,omitempty" json:"reverse,omitempty"`
Ignore bool `yaml:"ignore,omitempty" json:"ignore,omitempty"`
}

Expand Down
6 changes: 6 additions & 0 deletions pkg/limayaml/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,12 @@ func Validate(y LimaYAML, warn bool) error {
if rule.Proto != TCP {
return fmt.Errorf("field `%s.proto` must be %q", field, TCP)
}
if rule.Reverse && rule.GuestSocket == "" {
return fmt.Errorf("field `%s.reverse` must be %t", field, false)
}
if rule.Reverse && rule.HostSocket == "" {
return fmt.Errorf("field `%s.reverse` must be %t", field, false)
}
// Not validating that the various GuestPortRanges and HostPortRanges are not overlapping. Rules will be
// processed sequentially and the first matching rule for a guest port determines forwarding behavior.
}
Expand Down

0 comments on commit 9391621

Please sign in to comment.