Skip to content

Commit

Permalink
Add copy
Browse files Browse the repository at this point in the history
  • Loading branch information
d3witt committed Aug 24, 2024
1 parent 8ab1c78 commit 4889a20
Show file tree
Hide file tree
Showing 5 changed files with 365 additions and 4 deletions.
179 changes: 179 additions & 0 deletions archive/tar.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package archive

import (
"archive/tar"
"fmt"
"io"
"os"
"path/filepath"

"github.com/d3witt/viking/sshexec"
)

func Tar(source string) (io.Reader, error) {
pr, pw := io.Pipe()
tw := tar.NewWriter(pw)

go func() {
defer func() {
if err := tw.Close(); err != nil {
fmt.Println("Error closing tar writer:", err)
}
pw.Close() // Close the pipe writer when done
}()

fi, err := os.Stat(source)
if err != nil {
pw.CloseWithError(err)
return
}

if fi.IsDir() {
err = filepath.Walk(source, func(filePath string, fi os.FileInfo, err error) error {
if err != nil {
return err
}

// Construct the header
relPath, err := filepath.Rel(source, filePath)
if err != nil {
return err
}
header, err := tar.FileInfoHeader(fi, "")
if err != nil {
return err
}

// Use relative path to avoid including the entire source directory structure
header.Name = relPath

// Write the header
if err := tw.WriteHeader(header); err != nil {
return err
}

// If it's a regular file, write its content to the tar writer
if fi.Mode().IsRegular() {
file, err := os.Open(filePath)
if err != nil {
return err
}
defer file.Close()

if _, err := io.Copy(tw, file); err != nil {
return err
}
}

return nil
})
} else {
// Handle the case where source is a single file
header, err := tar.FileInfoHeader(fi, "")
if err != nil {
pw.CloseWithError(err)
return
}

// Only set the base name for a single file
header.Name = filepath.Base(source)

// Write the header
if err := tw.WriteHeader(header); err != nil {
pw.CloseWithError(err)
return
}

// Write the file content
file, err := os.Open(source)
if err != nil {
pw.CloseWithError(err)
return
}
defer file.Close()

if _, err := io.Copy(tw, file); err != nil {
pw.CloseWithError(err)
return
}
}

if err != nil {
pw.CloseWithError(err) // Close the pipe with an error if it occurs
}
}()

return pr, nil
}

func Untar(r io.Reader, dest string) error {
tr := tar.NewReader(r)

for {
header, err := tr.Next()
if err == io.EOF {
break // End of tar archive
}
if err != nil {
return err
}

// Create the file or directory
target := filepath.Join(dest, header.Name)
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil {
return err
}
case tar.TypeReg:
dir := filepath.Dir(target)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
file, err := os.OpenFile(target, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return err
}
defer file.Close()
if _, err := io.Copy(file, tr); err != nil {
return err
}
default:
continue
}
}

return nil
}

// TarRemote creates a tar archive for the given file/directory on the remote server.
func TarRemote(exec sshexec.Executor, source string) (io.Reader, error) {
outPipe, inPipe := io.Pipe()

go func() {
defer inPipe.Close()
cmd := sshexec.Command(exec, "tar", "-cf", "-", "-C", source, ".")
cmd.Stdout = inPipe
if err := cmd.Run(); err != nil {
inPipe.CloseWithError(err)
}
}()

return outPipe, nil
}

func UntarRemote(exec sshexec.Executor, dest string, in io.Reader) error {
folderPath := filepath.Dir(dest)

// Ensure the destination directory exists
cmd := sshexec.Command(exec, "mkdir", "-p", folderPath)
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}

// Untar the contents to the destination directory, replacing existing files
cmd = sshexec.Command(exec, "tar", "--overwrite", "-xf", "-", "-C", folderPath)
cmd.Stdin = in

return cmd.Run()
}
181 changes: 181 additions & 0 deletions cli/command/machine/copy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package machine

import (
"bytes"
"fmt"
"io"
"os"
"path"
"strings"
"sync"

"github.com/d3witt/viking/archive"
"github.com/d3witt/viking/cli/command"
"github.com/d3witt/viking/sshexec"
"github.com/urfave/cli/v2"
)

func NewCopyCmd(vikingCli *command.Cli) *cli.Command {
return &cli.Command{
Name: "copy",
Usage: "Copy files/folders between local and remote machine",
Args: true,
ArgsUsage: "MACHINE:SRC_PATH DEST_PATH | SRC_PATH MACHINE:DEST_PATH",
Action: func(ctx *cli.Context) error {
if ctx.NArg() != 2 {
return fmt.Errorf("expected 2 arguments, got %d", ctx.NArg())
}

return runCopy(vikingCli, ctx.Args().Get(0), ctx.Args().Get(1))
},
}
}

func parseMachinePath(fullPath string) (machine, path string) {
if strings.Contains(fullPath, ":") {
parts := strings.SplitN(fullPath, ":", 2)
return parts[0], parts[1]
}

return "", fullPath
}

func runCopy(vikingCli *command.Cli, from, to string) error {
fromMachine, fromPath := parseMachinePath(from)
toMachine, toPath := parseMachinePath(to)

if fromMachine == "" && toMachine == "" {
return fmt.Errorf("at least one path must contain machine name")
}

if fromMachine != "" && toMachine != "" {
return fmt.Errorf("cannot copy between two remote machines")
}

machine := fromMachine + toMachine

execs, err := vikingCli.MachineExecuters(machine)
defer func() {
for _, exec := range execs {
exec.Close()
}
}()

if err != nil {
return err
}

if fromMachine != "" {
return copyFromRemote(vikingCli, execs, fromPath, toPath)
}

return copyToRemote(vikingCli, execs, fromPath, toPath)
}

func copyToRemote(vikingCli *command.Cli, execs []sshexec.Executor, from, to string) error {
fmt.Fprintln(vikingCli.Out, "Archiving files...")

data, err := archive.Tar(from)
if err != nil {
return err
}

// Create a temporary file to store the tar archive
tmpFile, err := os.CreateTemp("", "archive-*.tar")
if err != nil {
return err
}
defer os.Remove(tmpFile.Name())

// Write the tar archive to the temporary file
if _, err := io.Copy(tmpFile, data); err != nil {
return err
}

// Close the temporary file to flush the data
if err := tmpFile.Close(); err != nil {
return err
}

var wg sync.WaitGroup
wg.Add(len(execs))

for _, exec := range execs {
go func(exec sshexec.Executor) {
defer wg.Done()

out := vikingCli.Out
errOut := vikingCli.Err
if len(execs) > 1 {
prefix := fmt.Sprintf("%s: ", exec.Addr())
out = out.WithPrefix(prefix)
errOut = errOut.WithPrefix(prefix + "error: ")
}

fmt.Fprintf(out, "Copying to %s...\n", exec.Addr()+":"+to)

// Open the temporary file for reading
tmpFile, err := os.Open(tmpFile.Name())
if err != nil {
fmt.Fprintln(errOut, err)
return
}
defer tmpFile.Close()

if err := archive.UntarRemote(exec, to, tmpFile); err != nil {
fmt.Fprintln(errOut, err)
return
}

fmt.Fprintf(out, "Successfully copied to %s\n", exec.Addr()+":"+to)
}(exec)
}

wg.Wait()

return nil
}

func copyFromRemote(vikingCli *command.Cli, execs []sshexec.Executor, from, to string) error {
var wg sync.WaitGroup
wg.Add(len(execs))

for _, exec := range execs {
go func(exec sshexec.Executor) {
defer wg.Done()

dest := to

out := vikingCli.Out
errOut := vikingCli.Err
if len(execs) > 1 {
dest = path.Join(to, exec.Addr())

prefix := fmt.Sprintf("%s: ", exec.Addr())
out = out.WithPrefix(prefix)
errOut = errOut.WithPrefix(prefix + "error: ")
}

fmt.Fprintf(out, "Copying from %s...\n", from)
data, err := archive.TarRemote(exec, from)
if err != nil {
fmt.Fprintln(errOut, err)
return
}

buf := new(bytes.Buffer)
buf.ReadFrom(data)

if err := archive.Untar(buf, dest); err != nil {
fmt.Fprintln(errOut, err)

return
}

fmt.Fprintf(out, "Successfully copied to %s\n", dest)
}(exec)
}

wg.Wait()
return nil
}
1 change: 1 addition & 0 deletions cli/command/machine/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ func NewCmd(vikingCli *command.Cli) *cli.Command {
NewListCmd(vikingCli),
NewRmCmd(vikingCli),
NewExecuteCmd(vikingCli),
NewCopyCmd(vikingCli),
},
}
}
4 changes: 2 additions & 2 deletions sshexec/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c *Cmd) Run() error {
}

if err := c.Wait(); err != nil {
return err
return fmt.Errorf("%w.\n%s", err, b.String())
}

return nil
Expand Down Expand Up @@ -90,7 +90,7 @@ func (c *Cmd) CombinedOutput() (string, error) {
if c.Stdout != nil {
return "", errors.New("stdout already set")
}
if c.Stdin != nil {
if c.Stderr != nil {
return "", errors.New("stderr already set")
}

Expand Down
Loading

0 comments on commit 4889a20

Please sign in to comment.