diff --git a/internal/portforward/interfaces.go b/internal/portforward/interfaces.go index fb442d5eb..d56e5390a 100644 --- a/internal/portforward/interfaces.go +++ b/internal/portforward/interfaces.go @@ -10,6 +10,7 @@ type Service interface { Start(ctx context.Context) (runError <-chan error, err error) Stop() (err error) GetPortsForwarded() (ports []uint16) + SetPortsForwarded(ctx context.Context, ports []uint16) (err error) } type Routing interface { diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index 8f9431456..2e19a4bea 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -2,6 +2,7 @@ package portforward import ( "context" + "errors" "fmt" "net/http" "sync" @@ -166,6 +167,16 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) { return l.service.GetPortsForwarded() } +var ErrServiceNotStarted = errors.New("port forwarding service not started") + +func (l *Loop) SetPortsForwarded(ports []uint16) (err error) { + if l.service == nil { + return fmt.Errorf("%w", ErrServiceNotStarted) + } + + return l.service.SetPortsForwarded(l.runCtx, ports) +} + func ptrTo[T any](value T) *T { return &value } diff --git a/internal/portforward/service/service.go b/internal/portforward/service/service.go index 579b37397..10b0b2187 100644 --- a/internal/portforward/service/service.go +++ b/internal/portforward/service/service.go @@ -2,7 +2,9 @@ package service import ( "context" + "fmt" "net/http" + "slices" "sync" ) @@ -50,3 +52,29 @@ func (s *Service) GetPortsForwarded() (ports []uint16) { copy(ports, s.ports) return ports } + +func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err error) { + s.startStopMutex.Lock() + defer s.startStopMutex.Unlock() + + s.portMutex.Lock() + defer s.portMutex.Unlock() + slices.Sort(ports) + if slices.Equal(s.ports, ports) { + return nil + } + + err = s.cleanup() + if err != nil { + return fmt.Errorf("cleaning up: %w", err) + } + + err = s.onNewPorts(ctx, ports) + if err != nil { + return fmt.Errorf("handling new ports: %w", err) + } + + s.logger.Info("updated: " + portsToString(s.ports)) + + return nil +} diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index a13ac0e41..2eefee622 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -3,6 +3,7 @@ package service import ( "context" "fmt" + "slices" "github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/provider/utils" @@ -47,18 +48,54 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) return nil, fmt.Errorf("port forwarding for the first time: %w", err) } + err = s.onNewPorts(ctx, ports) + if err != nil { + return nil, err + } + + keepPortCtx, keepPortCancel := context.WithCancel(context.Background()) + s.keepPortCancel = keepPortCancel + runErrorCh := make(chan error) + keepPortDoneCh := make(chan struct{}) + s.keepPortDoneCh = keepPortDoneCh + + readyCh := make(chan struct{}) + go func(ctx context.Context, portForwarder PortForwarder, + obj utils.PortForwardObjects, readyCh chan<- struct{}, + runError chan<- error, doneCh chan<- struct{}, + ) { + defer close(doneCh) + close(readyCh) + err = portForwarder.KeepPortForward(ctx, obj) + crashed := ctx.Err() == nil + if !crashed { // stopped by Stop call + return + } + s.startStopMutex.Lock() + defer s.startStopMutex.Unlock() + _ = s.cleanup() + runError <- err + }(keepPortCtx, s.settings.PortForwarder, obj, readyCh, runErrorCh, keepPortDoneCh) + <-readyCh + + return runErrorCh, nil +} + +func (s *Service) onNewPorts(ctx context.Context, ports []uint16) (err error) { + slices.Sort(ports) + s.logger.Info(portsToString(ports)) for _, port := range ports { err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface) if err != nil { - return nil, fmt.Errorf("allowing port in firewall: %w", err) + return fmt.Errorf("allowing port in firewall: %w", err) } if s.settings.ListeningPort != 0 { err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort) if err != nil { - return nil, fmt.Errorf("redirecting port in firewall: %w", err) + return fmt.Errorf("redirecting port in firewall: %w", err) } } } @@ -66,11 +103,12 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) err = s.writePortForwardedFile(ports) if err != nil { _ = s.cleanup() - return nil, fmt.Errorf("writing port file: %w", err) + return fmt.Errorf("writing port file: %w", err) } s.portMutex.Lock() - s.ports = ports + s.ports = make([]uint16, len(ports)) + copy(s.ports, ports) s.portMutex.Unlock() if s.settings.UpCommand != "" { @@ -81,30 +119,5 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) } } - keepPortCtx, keepPortCancel := context.WithCancel(context.Background()) - s.keepPortCancel = keepPortCancel - runErrorCh := make(chan error) - keepPortDoneCh := make(chan struct{}) - s.keepPortDoneCh = keepPortDoneCh - - readyCh := make(chan struct{}) - go func(ctx context.Context, portForwarder PortForwarder, - obj utils.PortForwardObjects, readyCh chan<- struct{}, - runError chan<- error, doneCh chan<- struct{}, - ) { - defer close(doneCh) - close(readyCh) - err = portForwarder.KeepPortForward(ctx, obj) - crashed := ctx.Err() == nil - if !crashed { // stopped by Stop call - return - } - s.startStopMutex.Lock() - defer s.startStopMutex.Unlock() - _ = s.cleanup() - runError <- err - }(keepPortCtx, s.settings.PortForwarder, obj, readyCh, runErrorCh, keepPortDoneCh) - <-readyCh - - return runErrorCh, nil + return nil } diff --git a/internal/server/handler.go b/internal/server/handler.go index 92cdf7b90..9508b9ca5 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -15,7 +15,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool, authSettings auth.Settings, buildInfo models.BuildInformation, vpnLooper VPNLooper, - pfGetter PortForwardedGetter, + pf PortForwarding, dnsLooper DNSLoop, updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, @@ -25,7 +25,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool, handler := &handler{} vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger) - openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger) + openvpn := newOpenvpnHandler(ctx, vpnLooper, pf, logger) dns := newDNSHandler(ctx, dnsLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger) publicip := newPublicIPHandler(publicIPLooper, logger) diff --git a/internal/server/interfaces.go b/internal/server/interfaces.go index 5b4702510..b49d28373 100644 --- a/internal/server/interfaces.go +++ b/internal/server/interfaces.go @@ -21,8 +21,9 @@ type DNSLoop interface { GetStatus() (status models.LoopStatus) } -type PortForwardedGetter interface { +type PortForwarding interface { GetPortsForwarded() (ports []uint16) + SetPortsForwarded(ports []uint16) (err error) } type PublicIPLoop interface { diff --git a/internal/server/openvpn.go b/internal/server/openvpn.go index 3e25c5feb..746c41694 100644 --- a/internal/server/openvpn.go +++ b/internal/server/openvpn.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "fmt" "net/http" "strings" @@ -11,12 +12,12 @@ import ( ) func newOpenvpnHandler(ctx context.Context, looper VPNLooper, - pfGetter PortForwardedGetter, w warner, + portForwarding PortForwarding, w warner, ) http.Handler { return &openvpnHandler{ ctx: ctx, looper: looper, - pf: pfGetter, + pf: portForwarding, warner: w, } } @@ -24,7 +25,7 @@ func newOpenvpnHandler(ctx context.Context, looper VPNLooper, type openvpnHandler struct { ctx context.Context //nolint:containedctx looper VPNLooper - pf PortForwardedGetter + pf PortForwarding warner warner } @@ -51,6 +52,8 @@ func (h *openvpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: h.getPortForwarded(w) + case http.MethodPut: + h.setPortForwarded(w, r) default: errMethodNotSupported(w, r.Method) } @@ -142,3 +145,24 @@ func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) { w.WriteHeader(http.StatusInternalServerError) } } + +func (h *openvpnHandler) setPortForwarded(w http.ResponseWriter, r *http.Request) { + var data portsWrapper + + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&data) + if err != nil { + h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err)) + http.Error(w, "failed setting forwarded ports", http.StatusBadRequest) + return + } + + err = h.pf.SetPortsForwarded(data.Ports) + if err != nil { + h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err)) + http.Error(w, "failed setting forwarded ports", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/internal/server/server.go b/internal/server/server.go index 3f50717b2..e296cf313 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -11,7 +11,7 @@ import ( func New(ctx context.Context, address string, logEnabled bool, logger Logger, authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper, - pfGetter PortForwardedGetter, dnsLooper DNSLoop, + pf PortForwarding, dnsLooper DNSLoop, updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage, ipv6Supported bool) ( server *httpserver.Server, err error, @@ -27,7 +27,7 @@ func New(ctx context.Context, address string, logEnabled bool, logger Logger, } handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo, - openvpnLooper, pfGetter, dnsLooper, updaterLooper, publicIPLooper, + openvpnLooper, pf, dnsLooper, updaterLooper, publicIPLooper, storage, ipv6Supported) if err != nil { return nil, fmt.Errorf("creating handler: %w", err)