Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion infra/conf/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
var err error
config.SecretKey, err = ParseWireGuardKey(c.SecretKey)
if err != nil {
return nil, err
return nil, errors.New("invalid WireGuard secret key: %w", err)
}

if c.Address == nil {
Expand Down Expand Up @@ -126,6 +126,10 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
func ParseWireGuardKey(str string) (string, error) {
var err error

if str == "" {
return "", errors.New("key must not be empty")
}

if len(str)%2 == 0 {
_, err = hex.DecodeString(str)
if err == nil {
Expand Down
40 changes: 20 additions & 20 deletions infra/conf/xray.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,14 @@ func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) {
}
rawConfig, err := inboundConfigLoader.LoadWithID(settings, c.Protocol)
if err != nil {
return nil, errors.New("failed to load inbound detour config.").Base(err)
return nil, errors.New("failed to load inbound detour config for protocol ", c.Protocol).Base(err)
}
if dokodemoConfig, ok := rawConfig.(*DokodemoConfig); ok {
receiverSettings.ReceiveOriginalDestination = dokodemoConfig.Redirect
}
ts, err := rawConfig.(Buildable).Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build inbound handler for protocol ", c.Protocol).Base(err)
}

return &core.InboundHandlerConfig{
Expand Down Expand Up @@ -303,15 +303,15 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
if c.StreamSetting != nil {
ss, err := c.StreamSetting.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build stream settings for outbound detour").Base(err)
}
senderSettings.StreamSettings = ss
}

if c.ProxySettings != nil {
ps, err := c.ProxySettings.Build()
if err != nil {
return nil, errors.New("invalid outbound detour proxy settings.").Base(err)
return nil, errors.New("invalid outbound detour proxy settings").Base(err)
}
if ps.TransportLayerProxy {
if senderSettings.StreamSettings != nil {
Expand All @@ -331,7 +331,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
if c.MuxSettings != nil {
ms, err := c.MuxSettings.Build()
if err != nil {
return nil, errors.New("failed to build Mux config.").Base(err)
return nil, errors.New("failed to build Mux config").Base(err)
}
senderSettings.MultiplexSettings = ms
}
Expand All @@ -342,11 +342,11 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
}
rawConfig, err := outboundConfigLoader.LoadWithID(settings, c.Protocol)
if err != nil {
return nil, errors.New("failed to parse to outbound detour config.").Base(err)
return nil, errors.New("failed to load outbound detour config for protocol ", c.Protocol).Base(err)
}
ts, err := rawConfig.(Buildable).Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build outbound handler for protocol ", c.Protocol).Base(err)
}

return &core.OutboundHandlerConfig{
Expand Down Expand Up @@ -490,7 +490,7 @@ func (c *Config) Override(o *Config, fn string) {
// Build implements Buildable.
func (c *Config) Build() (*core.Config, error) {
if err := PostProcessConfigureFile(c); err != nil {
return nil, err
return nil, errors.New("failed to post-process configuration file").Base(err)
}

config := &core.Config{
Expand All @@ -504,21 +504,21 @@ func (c *Config) Build() (*core.Config, error) {
if c.API != nil {
apiConf, err := c.API.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build API configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(apiConf))
}
if c.Metrics != nil {
metricsConf, err := c.Metrics.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build metrics configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(metricsConf))
}
if c.Stats != nil {
statsConf, err := c.Stats.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build stats configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(statsConf))
}
Expand All @@ -536,55 +536,55 @@ func (c *Config) Build() (*core.Config, error) {
if c.RouterConfig != nil {
routerConfig, err := c.RouterConfig.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build routing configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(routerConfig))
}

if c.DNSConfig != nil {
dnsApp, err := c.DNSConfig.Build()
if err != nil {
return nil, errors.New("failed to parse DNS config").Base(err)
return nil, errors.New("failed to build DNS configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(dnsApp))
}

if c.Policy != nil {
pc, err := c.Policy.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build policy configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(pc))
}

if c.Reverse != nil {
r, err := c.Reverse.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build reverse configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(r))
}

if c.FakeDNS != nil {
r, err := c.FakeDNS.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build fake DNS configuration").Base(err)
}
config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...)
}

if c.Observatory != nil {
r, err := c.Observatory.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build observatory configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(r))
}

if c.BurstObservatory != nil {
r, err := c.BurstObservatory.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build burst observatory configuration").Base(err)
}
config.App = append(config.App, serial.ToTypedMessage(r))
}
Expand All @@ -602,7 +602,7 @@ func (c *Config) Build() (*core.Config, error) {
for _, rawInboundConfig := range inbounds {
ic, err := rawInboundConfig.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build inbound config with tag ", rawInboundConfig.Tag).Base(err)
}
config.Inbound = append(config.Inbound, ic)
}
Expand All @@ -616,7 +616,7 @@ func (c *Config) Build() (*core.Config, error) {
for _, rawOutboundConfig := range outbounds {
oc, err := rawOutboundConfig.Build()
if err != nil {
return nil, err
return nil, errors.New("failed to build outbound config with tag ", rawOutboundConfig.Tag).Base(err)
}
config.Outbound = append(config.Outbound, oc)
}
Expand Down
13 changes: 6 additions & 7 deletions proxy/wireguard/gvisortun/tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"net/netip"
"os"
"sync"
"syscall"

"golang.zx2c4.com/wireguard/tun"
Expand All @@ -33,6 +34,7 @@ type netTun struct {
incomingPacket chan *buffer.View
mtu int
hasV4, hasV6 bool
closeOnce sync.Once
}

type Net netTun
Expand Down Expand Up @@ -174,18 +176,15 @@ func (tun *netTun) Flush() error {

// Close implements tun.Device
func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1)
tun.closeOnce.Do(func() {
tun.stack.RemoveNIC(1)

if tun.events != nil {
close(tun.events)
}

tun.ep.Close()
tun.ep.Close()

if tun.incomingPacket != nil {
close(tun.incomingPacket)
}

})
return nil
}

Expand Down
52 changes: 52 additions & 0 deletions proxy/wireguard/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package wireguard_test

import (
"context"
"github.com/stretchr/testify/assert"
"runtime/debug"
"testing"

"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/proxy/wireguard"
)

// TestWireGuardServerInitializationError verifies that an error during TUN initialization
// (triggered by an empty SecretKey) in the WireGuard server does not cause a panic and returns an error instead.
func TestWireGuardServerInitializationError(t *testing.T) {
// Create a minimal core instance with default features
config := &core.Config{}
instance, err := core.New(config)
if err != nil {
t.Fatalf("Failed to create core instance: %v", err)
}
// Set the Xray instance in the context
ctx := context.WithValue(context.Background(), core.XrayKey(1), instance)

// Define the server configuration with an empty SecretKey to trigger error
conf := &wireguard.DeviceConfig{
IsClient: false,
Endpoint: []string{"10.0.0.1/32"},
Mtu: 1420,
SecretKey: "", // Empty SecretKey to trigger error
Peers: []*wireguard.PeerConfig{
{
PublicKey: "some_public_key",
AllowedIps: []string{"10.0.0.2/32"},
},
},
}

// Use defer to catch any panic and fail the test explicitly
defer func() {
if r := recover(); r != nil {
t.Errorf("TUN initialization panicked: %v", r)
debug.PrintStack()
}
}()

// Attempt to initialize the WireGuard server
_, err = wireguard.NewServer(ctx, conf)

// Check that an error is returned
assert.ErrorContains(t, err, "failed to set private_key: hex string does not fit the slice")
}