Skip to content

Commit 6bc6143

Browse files
committed
feat: save vhosts to disk and load on startup
1 parent 698577a commit 6bc6143

File tree

2 files changed

+78
-20
lines changed

2 files changed

+78
-20
lines changed

daemon.go

+53-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package vproxy
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"log"
78
"net"
89
"net/http"
910
"os"
11+
"path"
1012
"strconv"
1113
"strings"
1214
"sync"
@@ -43,7 +45,9 @@ type Daemon struct {
4345

4446
// NewDaemon
4547
func NewDaemon(lh *LoggedHandler, listen string, httpPort int, httpsPort int) *Daemon {
46-
return &Daemon{loggedHandler: lh, listenHost: listen, httpPort: httpPort, httpsPort: httpsPort}
48+
d := &Daemon{loggedHandler: lh, listenHost: listen, httpPort: httpPort, httpsPort: httpsPort}
49+
d.loadVhosts()
50+
return d
4751
}
4852

4953
func rerunWithSudo(addr string) {
@@ -275,6 +279,52 @@ func (d *Daemon) doRemoveVhost(vhost *Vhost, w http.ResponseWriter) {
275279
fmt.Fprintf(w, "removing vhost: %s -> %d\n", vhost.Host, vhost.ServicePort)
276280
vhost.Close()
277281
d.loggedHandler.RemoveVhost(vhost.Host)
282+
d.saveVhosts()
283+
}
284+
285+
// load saved vhosts from disk
286+
func (d *Daemon) loadVhosts() {
287+
c := path.Join(CertPath(), "vhosts.json")
288+
j, err := os.ReadFile(c)
289+
if err != nil {
290+
fmt.Println("[*] warning: failed to load vhosts from disk: ", err)
291+
return
292+
}
293+
servers := map[string]*Vhost{}
294+
err = json.Unmarshal(j, &servers)
295+
if err != nil {
296+
fmt.Println("[*] warning: failed to load vhosts from disk: ", err)
297+
return
298+
}
299+
if len(servers) == 0 {
300+
return
301+
}
302+
for _, vhost := range servers {
303+
vhost.Init()
304+
d.loggedHandler.AddVhost(vhost)
305+
err = addToHosts(vhost.Host)
306+
if err != nil {
307+
msg := fmt.Sprintf("[*] warning: failed to add %s to system hosts file: %s\n", vhost.Host, err)
308+
fmt.Println(msg)
309+
}
310+
// fmt.Printf("[*] loaded vhost: %s\n", vhost.String())
311+
}
312+
fmt.Printf("[*] loaded %d vhost(s)\n", len(servers))
313+
}
314+
315+
// save vhosts to disk
316+
func (d *Daemon) saveVhosts() {
317+
c := path.Join(CertPath(), "vhosts.json")
318+
j, err := json.Marshal(d.loggedHandler.vhostMux.Servers)
319+
if err != nil {
320+
fmt.Println("[*] warning: failed to save vhosts to disk: ", err)
321+
return
322+
}
323+
err = os.WriteFile(c, j, 0644)
324+
if err != nil {
325+
fmt.Println("[*] warning: failed to save vhosts to disk: ", err)
326+
return
327+
}
278328
}
279329

280330
// addVhost for the given binding to the LoggedHandler
@@ -303,6 +353,8 @@ func (d *Daemon) addVhost(binding string, w http.ResponseWriter) *Vhost {
303353
w.Header().Set("Access-Control-Allow-Origin", "*")
304354

305355
d.loggedHandler.AddVhost(vhost)
356+
d.saveVhosts()
357+
306358
if d.enableTLS() {
307359
d.restartTLS()
308360
}

vhost.go

+25-19
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@ import (
1515

1616
// Vhost represents a single backend service
1717
type Vhost struct {
18-
Host string // virtual host name
18+
Host string `json:"host"` // virtual host name
1919

2020
ServiceHost string // service host or IP
2121
ServicePort int // service port
2222

23-
Handler http.Handler
24-
Cert string // TLS Certificate
25-
Key string // TLS Private Key
23+
Handler http.Handler `json:"-"`
24+
Cert string // TLS Certificate
25+
Key string // TLS Private Key
2626

27-
logRing *deque.Deque[string]
28-
logChan LogListener
29-
listeners []LogListener
27+
logRing *deque.Deque[string] `json:"-"`
28+
logChan LogListener `json:"-"`
29+
listeners []LogListener `json:"-"`
3030
}
3131

3232
type LogListener chan string
@@ -107,22 +107,13 @@ func CreateVhost(input string, useTLS bool) (*Vhost, error) {
107107
return nil, fmt.Errorf("failed to parse target port: %s", err)
108108
}
109109
targetHost := "127.0.0.1"
110-
targetURL := url.URL{Scheme: "http", Host: fmt.Sprintf("%s:%d", targetHost, targetPort)}
111-
112-
proxy := CreateProxy(targetURL, hostname)
113110

114111
vhost := &Vhost{
115-
Host: hostname, ServiceHost: targetHost, ServicePort: targetPort, Handler: proxy,
116-
logRing: &deque.Deque[string]{},
117-
logChan: make(LogListener, 10),
112+
Host: hostname,
113+
ServiceHost: targetHost,
114+
ServicePort: targetPort,
118115
}
119116

120-
// set fixed capacity at 16
121-
vhost.logRing.Grow(16)
122-
vhost.logRing.SetBaseCap(16)
123-
124-
go vhost.populateLogBuffer()
125-
126117
if useTLS {
127118
vhost.Cert, vhost.Key, err = MakeCert(hostname)
128119
if err != nil {
@@ -133,6 +124,17 @@ func CreateVhost(input string, useTLS bool) (*Vhost, error) {
133124
return vhost, nil
134125
}
135126

127+
func (v *Vhost) Init() {
128+
targetURL := url.URL{Scheme: "http", Host: fmt.Sprintf("%s:%d", v.ServiceHost, v.ServicePort)}
129+
v.Handler = CreateProxy(targetURL, v.Host)
130+
v.logChan = make(LogListener, 10)
131+
// set fixed capacity at 16
132+
v.logRing = &deque.Deque[string]{}
133+
v.logRing.Grow(16)
134+
v.logRing.SetBaseCap(16)
135+
go v.populateLogBuffer()
136+
}
137+
136138
func (v *Vhost) NewLogListener() LogListener {
137139
logChan := make(LogListener, 100)
138140
v.listeners = append(v.listeners, logChan)
@@ -180,6 +182,10 @@ func (v *Vhost) populateLogBuffer() {
180182
}
181183
}
182184

185+
func (v Vhost) String() string {
186+
return fmt.Sprintf("%s -> %s:%d", v.Host, v.ServiceHost, v.ServicePort)
187+
}
188+
183189
// Map given host to 127.0.0.1 in system hosts file (usually /etc/hosts)
184190
func addToHosts(host string) error {
185191
hosts, err := txeh.NewHostsDefault()

0 commit comments

Comments
 (0)