diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 2ae649029625..4e3f9487cd33 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -449,6 +449,7 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw } return contentResult, contentErr } + func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { outbounds := session.OutboundsFromContext(ctx) ob := outbounds[len(outbounds)-1] diff --git a/app/reverse/bridge.go b/app/reverse/bridge.go index fc83a7405dd7..56c1303fbf23 100644 --- a/app/reverse/bridge.go +++ b/app/reverse/bridge.go @@ -4,7 +4,6 @@ import ( "context" "time" - "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/mux" "github.com/xtls/xray-core/common/net" @@ -231,8 +230,14 @@ func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, l return w.Dispatcher.DispatchLink(ctx, dest, link) } - link = w.Dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link) + link = w.Dispatcher.WrapLink(ctx, link) w.handleInternalConn(link) return nil } + +// WrapLink this method will never be called; it's only used to implement the routing.Dispatcher interface. +func (w *BridgeWorker) WrapLink(ctx context.Context, link *transport.Link) *transport.Link { + // if this line of code is called, there may be duplicate calls. please remove it. + return w.Dispatcher.WrapLink(ctx, link) // noop, duplicate calls +} diff --git a/common/mux/server.go b/common/mux/server.go index f01c325d08dc..30909e43ae70 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -5,7 +5,6 @@ import ( "io" "time" - "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -64,7 +63,7 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t if dest.Address != muxCoolAddress { return s.dispatcher.DispatchLink(ctx, dest, link) } - link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link) + link = s.dispatcher.WrapLink(ctx, link) worker, err := NewServerWorker(ctx, s.dispatcher, link) if err != nil { return err @@ -76,6 +75,12 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t return nil } +// WrapLink this method will never be called; it's only used to implement the routing.Dispatcher interface. +func (s *Server) WrapLink(ctx context.Context, link *transport.Link) *transport.Link { + // if this line of code is called, there may be duplicate calls. please remove it. + return s.dispatcher.WrapLink(ctx, link) // noop, duplicate calls +} + // Start implements common.Runnable. func (s *Server) Start() error { return nil diff --git a/features/routing/dispatcher.go b/features/routing/dispatcher.go index 53d3bf900f15..77ea27c975fd 100644 --- a/features/routing/dispatcher.go +++ b/features/routing/dispatcher.go @@ -18,6 +18,8 @@ type Dispatcher interface { // Dispatch returns a Ray for transporting data for the given request. Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error + + WrapLink(ctx context.Context, link *transport.Link) *transport.Link } // DispatcherType returns the type of Dispatcher interface. Can be used to implement common.HasType. diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 223aade08b68..a7bf47170f5e 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -12,7 +12,6 @@ import ( "time" "unsafe" - "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/app/reverse" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" @@ -76,7 +75,7 @@ type Handler struct { validator vless.Validator decryption *encryption.ServerInstance outboundHandlerManager outbound.Manager - defaultDispatcher *dispatcher.DefaultDispatcher + dispatcher routing.Dispatcher ctx context.Context fallbacks map[string]map[string]map[string]*Fallback // or nil // regexps map[string]*regexp.Regexp // or nil @@ -90,7 +89,7 @@ func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Val policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), validator: validator, outboundHandlerManager: v.GetFeature(outbound.ManagerType()).(outbound.Manager), - defaultDispatcher: v.GetFeature(routing.DispatcherType()).(*dispatcher.DefaultDispatcher), + dispatcher: v.GetFeature(routing.DispatcherType()).(routing.Dispatcher), ctx: ctx, } @@ -619,7 +618,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if err != nil { return err } - return r.NewMux(ctx, h.defaultDispatcher.WrapLink(ctx, &transport.Link{Reader: clientReader, Writer: clientWriter})) + return r.NewMux(ctx, h.dispatcher.WrapLink(ctx, &transport.Link{Reader: clientReader, Writer: clientWriter})) } if err := dispatcher.DispatchLink(ctx, request.Destination(), &transport.Link{