From ba020b214f8fa41f4278621b4c5f1c3360dd4d1f Mon Sep 17 00:00:00 2001 From: Ryan Petrich Date: Sat, 29 May 2021 12:23:36 -0400 Subject: [PATCH] Cleanup volume when mounting fails --- main.go | 105 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 54 insertions(+), 51 deletions(-) diff --git a/main.go b/main.go index 746962f..761f9fa 100644 --- a/main.go +++ b/main.go @@ -341,64 +341,67 @@ wait_for_volume_completion: // mount the device _ = os.Mkdir("./snapshot", 0644) err = syscall.Mount(newestDevice, "./snapshot", "ext4", syscall.MS_RDONLY|syscall.MS_NOEXEC, "") + var errorToReturn error if err != nil { - return fmt.Errorf("failed to mount device: %v", err) - } - log.Printf("mounted successfully, scanning") - // set up the background scanners - var wg sync.WaitGroup - pathsToScan := make(chan string, 1024) - for i := 0; i < 64; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for recursivePath := range pathsToScan { - // Actually scan the file - var m yara.MatchRules - if err := rules.ScanFile(recursivePath, 0, 0, &m); err != nil { - errorText := err.Error() - // Ignore bad symlinks and unlink race conditions. Have to - // compare by string since go-yara doesn't use structured error - // types :( - if errorText != "could not open file" && errorText != "could not map file" { - log.Printf("could not scan file in volume %v at path %q: %v", volumeInfo, strings.TrimPrefix(recursivePath, "snapshot"), err) - } - } else { - // If we have matches, dispatch an alert - if len(m) != 0 { - for _, match := range m { - log.Printf("file in volume %v at path %q violated rule %q from %q", volumeInfo, strings.TrimPrefix(recursivePath, "snapshot"), match.Rule, match.Namespace) + errorToReturn = fmt.Errorf("failed to mount device: %v", err) + log.Printf("failed to mount, detaching") + } else { + log.Printf("mounted successfully, scanning") + // set up the background scanners + var wg sync.WaitGroup + pathsToScan := make(chan string, 1024) + for i := 0; i < 64; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for recursivePath := range pathsToScan { + // Actually scan the file + var m yara.MatchRules + if err := rules.ScanFile(recursivePath, 0, 0, &m); err != nil { + errorText := err.Error() + // Ignore bad symlinks and unlink race conditions. Have to + // compare by string since go-yara doesn't use structured error + // types :( + if errorText != "could not open file" && errorText != "could not map file" { + log.Printf("could not scan file in volume %v at path %q: %v", volumeInfo, strings.TrimPrefix(recursivePath, "snapshot"), err) + } + } else { + // If we have matches, dispatch an alert + if len(m) != 0 { + for _, match := range m { + log.Printf("file in volume %v at path %q violated rule %q from %q", volumeInfo, strings.TrimPrefix(recursivePath, "snapshot"), match.Rule, match.Namespace) + } } } } - } - }() - } - // search for files - err = filepath.Walk("./snapshot", func(recursivePath string, info os.FileInfo, err error) error { - if err != nil { - return nil + }() } - // Precheck to make sure we're dealing with an actual file - if info.IsDir() || info.Size() == 0 { - return nil - } - if info.Mode()&(os.ModeDevice|os.ModeNamedPipe|os.ModeSocket|os.ModeCharDevice|os.ModeSymlink) != 0 { + // search for files + err = filepath.Walk("./snapshot", func(recursivePath string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + // Precheck to make sure we're dealing with an actual file + if info.IsDir() || info.Size() == 0 { + return nil + } + if info.Mode()&(os.ModeDevice|os.ModeNamedPipe|os.ModeSocket|os.ModeCharDevice|os.ModeSymlink) != 0 { + return nil + } + pathsToScan <- recursivePath return nil + }) + if err != nil { + log.Printf("error scanning: %v", err) } - pathsToScan <- recursivePath - return nil - }) - if err != nil { - log.Printf("error scanning: %v", err) + close(pathsToScan) + wg.Wait() + log.Printf("finished scanning, unmounting") + // unmount the volume + syscall.Unmount("./snapshot", 0) + time.Sleep(1 * time.Second) + log.Printf("finished unmounting, detaching") } - close(pathsToScan) - wg.Wait() - log.Printf("finished scanning, unmounting") - // unmount the volume - syscall.Unmount("./snapshot", 0) - time.Sleep(1 * time.Second) - log.Printf("finished unmounting, detaching") // detach the snapshot volume _, err = client.DetachVolume(ctx, &ec2.DetachVolumeInput{ InstanceId: &instanceId, @@ -452,7 +455,7 @@ wait_for_volume_detachment: if err != nil { return fmt.Errorf("delete volume request failed: %v", err) } - return nil + return errorToReturn } func main() {