-
-
Notifications
You must be signed in to change notification settings - Fork 37
/
system.go
113 lines (98 loc) · 3.22 KB
/
system.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package nameserver
import (
"errors"
"fmt"
"net/netip"
"os"
"path/filepath"
"strings"
"github.com/qdm12/gosettings"
)
// SettingsSystemDNS holds the settings used by UseDNSSystemWide to point
// the system resolv configuration file at a given DNS server.
type SettingsSystemDNS struct {
	// IP is the IP address written as the nameserver in the resolv
	// configuration file. It defaults to 127.0.0.1 when left unset
	// (the zero, invalid netip.Addr value).
	IP netip.Addr
	// ResolvPath is the path to the resolv configuration file.
	// It defaults to /etc/resolv.conf when empty.
	ResolvPath string
	// KeepNameserver can be set to true to preserve the existing nameserver
	// lines of the resolv configuration file, placed after the line for IP.
	// It defaults to false when nil, in which case existing nameserver
	// lines are removed.
	KeepNameserver *bool
}
// SetDefaults fills the unset fields of s with their default values:
// /etc/resolv.conf for ResolvPath, 127.0.0.1 for IP and false for
// KeepNameserver. Fields already holding a value are kept.
func (s *SettingsSystemDNS) SetDefaults() {
	const defaultResolvPath = "/etc/resolv.conf"
	loopback := netip.AddrFrom4([4]byte{127, 0, 0, 1})

	s.ResolvPath = gosettings.DefaultString(s.ResolvPath, defaultResolvPath)
	s.IP = gosettings.DefaultValidator(s.IP, loopback)
	s.KeepNameserver = gosettings.DefaultPointer(s.KeepNameserver, false)
}
// ErrResolvPathIsDirectory is returned, wrapped together with the offending
// path, when the resolv configuration path points to a directory instead
// of a file.
var ErrResolvPathIsDirectory = errors.New("resolv path is a directory")
// Validate checks that ResolvPath, when it exists, is not a directory.
// A missing file is accepted since it gets created later on.
// It returns an error wrapping ErrResolvPathIsDirectory if the path is a
// directory, or the wrapped os.Stat error for any other stat failure.
func (s *SettingsSystemDNS) Validate() (err error) {
	info, statErr := os.Stat(s.ResolvPath)
	if statErr != nil {
		if errors.Is(statErr, os.ErrNotExist) {
			return nil
		}
		return fmt.Errorf("stating resolv path: %w", statErr)
	}

	if info.IsDir() {
		return fmt.Errorf("%w: %s", ErrResolvPathIsDirectory, s.ResolvPath)
	}
	return nil
}
// UseDNSSystemWide configures settings.IP as the system-wide DNS server by
// writing it as the first nameserver line of the resolv configuration file
// at settings.ResolvPath.
//
// Unset fields of settings are defaulted first (see
// SettingsSystemDNS.SetDefaults), so an empty ResolvPath means
// /etc/resolv.conf and an unset IP means 127.0.0.1. If the file does not
// exist, it is created, together with its parent directories, containing
// only that nameserver line. Otherwise it is patched in place, and its
// existing nameserver lines are removed unless settings.KeepNameserver is
// true. It returns an error wrapping ErrResolvPathIsDirectory if the path
// is a directory.
func UseDNSSystemWide(settings SettingsSystemDNS) (err error) {
	// settings is a copy, so defaulting it does not modify the caller's value.
	settings.SetDefaults()

	stat, err := os.Stat(settings.ResolvPath)
	switch {
	case errors.Is(err, os.ErrNotExist):
		return createResolvFile(settings.ResolvPath, settings.IP)
	case err != nil:
		return fmt.Errorf("stating resolv path: %w", err)
	case stat.IsDir():
		return fmt.Errorf("%w: %s", ErrResolvPathIsDirectory, settings.ResolvPath)
	}

	// KeepNameserver is non-nil here thanks to SetDefaults.
	return patchResolvFile(settings.ResolvPath, settings.IP, *settings.KeepNameserver)
}
// createResolvFile creates the resolv configuration file at resolvPath,
// creating any missing parent directories first, and writes a single
// "nameserver <ip>" line (newline-terminated) to it.
func createResolvFile(resolvPath string, ip netip.Addr) (err error) {
	parentDirectory := filepath.Dir(resolvPath)
	const directoryPermissions os.FileMode = 0o755
	err = os.MkdirAll(parentDirectory, directoryPermissions)
	if err != nil {
		return fmt.Errorf("creating resolv path parent directory: %w", err)
	}

	// The resolver routines read this file from within every process doing
	// name resolution, including unprivileged ones. It must therefore be
	// world-readable, like a standard /etc/resolv.conf (0644), rather than
	// owner-only.
	const filePermissions os.FileMode = 0o644
	data := []byte("nameserver " + ip.String() + "\n")
	err = os.WriteFile(resolvPath, data, filePermissions)
	if err != nil {
		return fmt.Errorf("creating resolv file: %w", err)
	}
	return nil
}
// patchResolvFile rewrites the resolv configuration file at resolvPath so
// that "nameserver <ip>" is its first line.
//
// When keepNameserver is false, every existing nameserver directive is
// removed. When it is true, existing nameserver directives are kept after
// the new line, except a directive naming ip itself, which would otherwise
// be duplicated on every call. All other lines keep their order. If the
// original file ended with a newline (or was empty), trailing newlines are
// collapsed to a single one; otherwise no trailing newline is written.
func patchResolvFile(resolvPath string, ip netip.Addr,
	keepNameserver bool) (err error) {
	data, err := os.ReadFile(resolvPath)
	if err != nil {
		return fmt.Errorf("reading file: %w", err)
	}

	content := string(data)
	// Decide on the trailing newline from the ORIGINAL content. Checking the
	// filtered lines instead misfires when a removed nameserver directive
	// was the file's unterminated last line. An empty file gets a newline
	// so that the added nameserver line is terminated.
	endsWithNewline := content == "" || strings.HasSuffix(content, "\n")

	lines := strings.Split(content, "\n")
	patchedLines := make([]string, 0, len(lines)+1)
	patchedLines = append(patchedLines, "nameserver "+ip.String())
	for _, line := range lines {
		if isRemovableNameserverLine(line, ip, keepNameserver) {
			continue
		}
		patchedLines = append(patchedLines, line)
	}

	patchedString := strings.TrimRight(strings.Join(patchedLines, "\n"), "\n")
	if endsWithNewline {
		patchedString += "\n"
	}

	// os.WriteFile only applies these permissions when it has to create the
	// file (e.g. if it vanished after being read); an existing file keeps
	// its mode. Use the world-readable mode of a standard /etc/resolv.conf
	// so that unprivileged processes can still read it.
	const permissions os.FileMode = 0o644
	err = os.WriteFile(resolvPath, []byte(patchedString), permissions)
	if err != nil {
		return fmt.Errorf("writing resolv file: %w", err)
	}
	return nil
}

// isRemovableNameserverLine reports whether line is a nameserver directive
// (the "nameserver" keyword followed by a space or a tab) that
// patchResolvFile must drop: any such directive when keepNameserver is
// false, and only a directive naming ip when keepNameserver is true.
func isRemovableNameserverLine(line string, ip netip.Addr, keepNameserver bool) bool {
	if !strings.HasPrefix(line, "nameserver ") &&
		!strings.HasPrefix(line, "nameserver\t") {
		return false
	}
	if !keepNameserver {
		return true
	}
	fields := strings.Fields(line)
	if len(fields) < 2 {
		return false
	}
	address, err := netip.ParseAddr(fields[1])
	return err == nil && address == ip
}