Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(web): add file transfer support #62

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 276 additions & 10 deletions cmd/wasm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import (
"bytes"
"context"
"fmt"
"io"
"log"
"log/slog"
"net"
"net/http"
"strings"
"syscall/js"
"time"

Expand Down Expand Up @@ -52,13 +55,8 @@ func main() {
<-make(chan struct{}, 0)
}

func newWush(jsConfig js.Value) map[string]any {
func newWush(cfg js.Value) map[string]any {
ctx := context.Background()
var authKey string
if jsAuthKey := jsConfig.Get("authKey"); jsAuthKey.Type() == js.TypeString {
authKey = jsAuthKey.String()
}

logger := slog.New(slog.NewTextHandler(jsConsoleWriter{}, nil))
hlog := func(format string, args ...any) {
fmt.Printf(format+"\n", args...)
Expand All @@ -68,18 +66,19 @@ func newWush(jsConfig js.Value) map[string]any {
panic(err)
}

send := overlay.NewSendOverlay(logger, dm)
err = send.Auth.Parse(authKey)
ov := overlay.NewWasmOverlay(log.Printf, dm, cfg.Get("onNewPeer"))

err = ov.PickDERPHome(ctx)
if err != nil {
panic(err)
}

s, err := tsserver.NewServer(ctx, logger, send, dm)
s, err := tsserver.NewServer(ctx, logger, ov, dm)
if err != nil {
panic(err)
}

go send.ListenOverlayDERP(ctx)
go ov.ListenOverlayDERP(ctx)
go s.ListenAndServe(ctx)
netns.SetDialerOverride(s.Dialer())

Expand All @@ -94,12 +93,40 @@ func newWush(jsConfig js.Value) map[string]any {
}
hlog("WireGuard is ready")

cpListener, err := ts.Listen("tcp", ":4444")
if err != nil {
panic(err)
}

go func() {
err := http.Serve(cpListener, http.HandlerFunc(cpH(
cfg.Get("onIncomingFile"),
cfg.Get("downloadFile"),
)))
if err != nil {
hlog("File transfer server exited: " + err.Error())
}
}()

return map[string]any{
"auth_info": js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) != 0 {
log.Printf("Usage: auth_info()")
return nil
}

return map[string]any{
"derp_id": ov.DerpRegionID,
"derp_name": ov.DerpMap.Regions[int(ov.DerpRegionID)].RegionName,
"auth_key": ov.ClientAuth().AuthKey(),
}
}),
"stop": js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) != 0 {
log.Printf("Usage: stop()")
return nil
}
cpListener.Close()
ts.Close()
return nil
}),
Expand Down Expand Up @@ -127,6 +154,157 @@ func newWush(jsConfig js.Value) map[string]any {
}),
}
}),
"connect": js.FuncOf(func(this js.Value, args []js.Value) any {
handler := js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]

go func() {
if len(args) != 1 {
errorConstructor := js.Global().Get("Error")
errorObject := errorConstructor.New("Usage: connect(authKey)")
reject.Invoke(errorObject)
return
}

var authKey string
if args[0].Type() == js.TypeString {
authKey = args[0].String()
} else {
errorConstructor := js.Global().Get("Error")
errorObject := errorConstructor.New("Usage: connect(authKey)")
reject.Invoke(errorObject)
return
}

var ca overlay.ClientAuth
err := ca.Parse(authKey)
if err != nil {
errorConstructor := js.Global().Get("Error")
errorObject := errorConstructor.New(fmt.Errorf("parse authkey: %w", err).Error())
reject.Invoke(errorObject)
return
}

ctx, cancel := context.WithCancel(context.Background())
peer, err := ov.Connect(ctx, ca)
if err != nil {
cancel()
errorConstructor := js.Global().Get("Error")
errorObject := errorConstructor.New(fmt.Errorf("parse authkey: %w", err).Error())
reject.Invoke(errorObject)
return
}

resolve.Invoke(map[string]any{
"id": js.ValueOf(peer.ID),
"name": js.ValueOf(peer.Name),
"ip": js.ValueOf(peer.IP.String()),
"cancel": js.FuncOf(func(this js.Value, args []js.Value) any {
cancel()
return nil
}),
})
}()

return nil
})

promiseConstructor := js.Global().Get("Promise")
return promiseConstructor.New(handler)
}),
"transfer": js.FuncOf(func(this js.Value, args []js.Value) any {
handler := js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]

if len(args) != 5 {
errorConstructor := js.Global().Get("Error")
errorObject := errorConstructor.New("Usage: transfer(peer, file)")
reject.Invoke(errorObject)
return nil
}

peer := args[0]
ip := peer.Get("ip").String()
fileName := args[1].String()
sizeBytes := args[2].Int()
stream := args[3]
streamHelper := args[4]

pr, pw := io.Pipe()

goCallback := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
promiseConstructor := js.Global().Get("Promise")
return promiseConstructor.New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
_ = promiseArgs[1]
go func() {
if len(args) == 0 || args[0].IsNull() || args[0].IsUndefined() {
pw.Close()
resolve.Invoke()
return
}

fmt.Println("in go callback")
// Convert the JavaScript Uint8Array to a Go byte slice
uint8Array := args[0]
fmt.Println("type is", uint8Array.Type().String())
length := uint8Array.Get("length").Int()
buf := make([]byte, length)
js.CopyBytesToGo(buf, uint8Array)

fmt.Println("sending data to channel")
// Send the data to the channel
if _, err := pw.Write(buf); err != nil {
pw.CloseWithError(err)
}
fmt.Println("callback finished")

// Resolve the promise
resolve.Invoke()
}()
return nil
}))
})

go func() {
defer goCallback.Release()

streamHelper.Invoke(stream, goCallback)

hc := ts.HTTPClient()
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://%s:4444/%s", ip, fileName), pr)
if err != nil {
errorConstructor := js.Global().Get("Error")
errorObject := errorConstructor.New(err.Error())
reject.Invoke(errorObject)
return
}
req.ContentLength = int64(sizeBytes)

res, err := hc.Do(req)
if err != nil {
errorConstructor := js.Global().Get("Error")
errorObject := errorConstructor.New(err.Error())
reject.Invoke(errorObject)
return
}
defer res.Body.Close()

bod := bytes.NewBuffer(nil)
_, _ = io.Copy(bod, res.Body)

fmt.Println(bod.String())
resolve.Invoke()
}()

return nil
})

promiseConstructor := js.Global().Get("Promise")
return promiseConstructor.New(handler)
}),
}
}

Expand Down Expand Up @@ -306,3 +484,91 @@ func newTSNet(direction string) (*tsnet.Server, error) {

return srv, nil
}

func cpH(onIncomingFile js.Value, downloadFile js.Value) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
return
}

fiName := strings.TrimPrefix(r.URL.Path, "/")

// TODO: impl
peer := map[string]any{
"id": js.ValueOf(0),
"name": js.ValueOf(""),
"ip": js.ValueOf(""),
"cancel": js.FuncOf(func(this js.Value, args []js.Value) any {
return nil
}),
}

allow := onIncomingFile.Invoke(peer, fiName, r.ContentLength).Bool()
if !allow {
w.WriteHeader(http.StatusForbidden)
w.Write([]byte("File transfer was denied"))
r.Body.Close()
return
}

underlyingSource := map[string]interface{}{
// start method
"start": js.FuncOf(func(this js.Value, args []js.Value) interface{} {
// The first and only arg is the controller object
controller := args[0]

// Process the stream in yet another background goroutine,
// because we can't block on a goroutine invoked by JS in Wasm
// that is dealing with HTTP requests
go func() {
// Close the response body at the end of this method
defer r.Body.Close()

// Read the entire stream and pass it to JavaScript
for {
// Read up to 16KB at a time
buf := make([]byte, 16384)
n, err := r.Body.Read(buf)
if err != nil && err != io.EOF {
// Tell the controller we have an error
// We're ignoring "EOF" however, which means the stream was done
errorConstructor := js.Global().Get("Error")
errorObject := errorConstructor.New(err.Error())
controller.Call("error", errorObject)
return
}
if n > 0 {
// If we read anything, send it to JavaScript using the "enqueue" method on the controller
// We need to convert it to a Uint8Array first
arrayConstructor := js.Global().Get("Uint8Array")
dataJS := arrayConstructor.New(n)
js.CopyBytesToJS(dataJS, buf[0:n])
controller.Call("enqueue", dataJS)
}
if err == io.EOF {
// Stream is done, so call the "close" method on the controller
controller.Call("close")
return
}
}
}()

return nil
}),
// cancel method
"cancel": js.FuncOf(func(this js.Value, args []js.Value) interface{} {
// If the request is canceled, just close the body
r.Body.Close()

return nil
}),
}

readableStreamConstructor := js.Global().Get("ReadableStream")
readableStream := readableStreamConstructor.New(underlyingSource)

downloadFile.Invoke(peer, fiName, r.ContentLength, readableStream)
}
}
Loading