Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use concurrency for grep and replace commands #80

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ compile-releases: clean ## Compile vsh binaries for multiple platforms and archi
compile: ## Compile vsh for platform based on uname
go build -ldflags "-X main.vshVersion=$(VERSION)" -o build/${APP_NAME}_$(shell uname | tr '[:upper:]' '[:lower:]')_$(ARCH)

compile-debug: clean
go build -ldflags "-X main.vshVersion=$(VERSION)" -o build/${APP_NAME}_$(shell uname | tr '[:upper:]' '[:lower:]')_amd64 -gcflags="all=-N -l"

get-bats: ## Download bats dependencies to test directory
rm -rf test/bin/
mkdir -p test/bin/core
Expand Down
4 changes: 2 additions & 2 deletions cli/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (cmd *AppendCommand) createDummySecret(target string) error {

dummy := make(map[string]interface{})
dummy["placeholder"] = struct{}{}
dummySecret := client.NewSecret(&api.Secret{Data: dummy})
dummySecret := client.NewSecret(&api.Secret{Data: dummy}, target)
if targetSecret == nil {
if err = cmd.client.Write(target, dummySecret); err != nil {
return err
Expand Down Expand Up @@ -150,7 +150,7 @@ func (cmd *AppendCommand) mergeSecrets(source string, target string) error {
}

// write
resultSecret := client.NewSecret(&api.Secret{Data: merged})
resultSecret := client.NewSecret(&api.Secret{Data: merged}, target)
if err := cmd.client.Write(target, resultSecret); err != nil {
fmt.Println(err)
return err
Expand Down
31 changes: 31 additions & 0 deletions cli/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package cli

import (
"path/filepath"
"sort"
"strings"
"sync"

"github.com/fatih/structs"
"github.com/fishi0x01/vsh/client"
Expand Down Expand Up @@ -103,3 +105,32 @@ func transportSecrets(c *client.Client, source string, target string, transport

return 0
}

func funcOnPaths(c *client.Client, paths []string, f func(s *client.Secret) (matches []*Match)) (matches []*Match, err error) {
secrets, err := c.BatchRead(c.FilterPaths(paths, client.LEAF))
if err != nil {
return nil, err
}

var wg sync.WaitGroup
queue := make(chan *client.Secret, len(paths))
recv := make(chan []*Match, len(paths))
for _, secret := range secrets {
queue <- secret
}
for range secrets {
wg.Add(1)
go func() {
recv <- f(<-queue)
wg.Done()
}()
}
wg.Wait()
close(recv)

for m := range recv {
matches = append(matches, m...)
}
sort.Slice(matches, func(i, j int) bool { return matches[i].path < matches[j].path })
return matches, nil
}
31 changes: 6 additions & 25 deletions cli/grep.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,12 @@ func (cmd *GrepCommand) Run() int {
return 1
}

for _, curPath := range filePaths {
matches, err := cmd.grepFile(cmd.args.Search, curPath)
if err != nil {
return 1
}
for _, match := range matches {
match.print(os.Stdout, MatchOutputHighlight)
}
matches, err := cmd.searcher.grepPaths(cmd.client, cmd.args.Search, filePaths)
if err != nil {
return 1
}
for _, match := range matches {
match.print(os.Stdout, MatchOutputHighlight)
}
return 0
}
Expand All @@ -119,20 +117,3 @@ func (cmd *GrepCommand) GetSearchParams() SearchParameters {
IsRegexp: cmd.args.Regexp,
}
}

func (cmd *GrepCommand) grepFile(search string, path string) (matches []*Match, err error) {
matches = []*Match{}

if cmd.client.GetType(path) == client.LEAF {
secret, err := cmd.client.Read(path)
if err != nil {
return matches, err
}

for k, v := range secret.GetData() {
matches = append(matches, cmd.searcher.DoSearch(path, k, fmt.Sprintf("%v", v))...)
}
}

return matches, nil
}
74 changes: 33 additions & 41 deletions cli/replace.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,33 +120,36 @@ func (cmd *ReplaceCommand) Run() int {
return 1
}

allMatches, err := cmd.findMatches(filePaths)
allMatches, err := cmd.FindMatches(filePaths)
if err != nil {
log.UserError(fmt.Sprintf("%s", err))
return 1
}
return cmd.commitMatches(allMatches)
}

func (cmd *ReplaceCommand) findMatches(filePaths []string) (matchesByPath map[string][]*Match, err error) {
// FindMatches will return a map of files sorted by path in which the search occurs
func (cmd *ReplaceCommand) FindMatches(filePaths []string) (matchesByPath map[string][]*Match, err error) {
matches, err := cmd.searcher.grepPaths(cmd.client, cmd.args.Search, filePaths)
if err != nil {
return matchesByPath, err
}
for _, match := range matches {
match.print(os.Stdout, cmd.args.Output.Value)
}
return cmd.groupMatchesByPath(matches), nil
}

func (cmd *ReplaceCommand) groupMatchesByPath(matches []*Match) (matchesByPath map[string][]*Match) {
matchesByPath = make(map[string][]*Match, 0)
for _, curPath := range filePaths {
matches, err := cmd.FindReplacements(cmd.args.Search, cmd.args.Replacement, curPath)
if err != nil {
return matchesByPath, err
}
for _, match := range matches {
match.print(os.Stdout, cmd.args.Output.Value)
}
if len(matches) > 0 {
_, ok := matchesByPath[curPath]
if ok == false {
matchesByPath[curPath] = make([]*Match, 0)
}
matchesByPath[curPath] = append(matchesByPath[curPath], matches...)
for _, m := range matches {
_, ok := matchesByPath[m.path]
if ok == false {
matchesByPath[m.path] = make([]*Match, 0)
}
matchesByPath[m.path] = append(matchesByPath[m.path], matches...)
}
return matchesByPath, nil
return matchesByPath
}

func (cmd *ReplaceCommand) commitMatches(matchesByPath map[string][]*Match) int {
Expand All @@ -171,34 +174,23 @@ func (cmd *ReplaceCommand) commitMatches(matchesByPath map[string][]*Match) int
return 0
}

// FindReplacements will find the matches for a given search string to be replaced
func (cmd *ReplaceCommand) FindReplacements(search string, replacement string, path string) (matches []*Match, err error) {
if cmd.client.GetType(path) == client.LEAF {
secret, err := cmd.client.Read(path)
if err != nil {
return matches, err
}

for k, v := range secret.GetData() {
match := cmd.searcher.DoSearch(path, k, fmt.Sprintf("%v", v))
matches = append(matches, match...)
}
}
return matches, nil
}

// WriteReplacements will write replacement data back to Vault
func (cmd *ReplaceCommand) WriteReplacements(groupedMatches map[string][]*Match) error {
// process matches by vault path
for path, matches := range groupedMatches {
secret, err := cmd.client.Read(path)
if err != nil {
return err
}
data := secret.GetData()
// Re-read paths because they could've gone stale
paths := make([]string, 0)
for path := range groupedMatches {
paths = append(paths, path)
}
secrets, err := cmd.client.BatchRead(paths)
if err != nil {
return err
}

for _, secret := range secrets {
data, path := secret.GetData(), secret.Path

// update secret with changes. remove key w/ prior names, add renamed keys, update values.
for _, match := range matches {
for _, match := range groupedMatches[path] {
if path != match.path {
return fmt.Errorf("match path does not equal group path")
}
Expand Down
10 changes: 10 additions & 0 deletions cli/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/andreyvit/diff"
"github.com/fatih/color"
"github.com/fishi0x01/vsh/client"
)

// SearchingCommand interface to describe a command that performs a search operation
Expand Down Expand Up @@ -258,3 +259,12 @@ func (s *Searcher) matchData(subject string) (matchPairs [][]int, replaced strin

return matchPairs, replaced
}

func (s *Searcher) grepPaths(c *client.Client, search string, paths []string) (matches []*Match, err error) {
return funcOnPaths(c, paths, func(secret *client.Secret) []*Match {
for k, v := range secret.GetData() {
matches = append(matches, s.DoSearch(secret.Path, k, fmt.Sprintf("%v", v))...)
}
return matches
})
}
95 changes: 95 additions & 0 deletions client/batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package client

import (
"fmt"
)

// BatchOperation is a kind of operation to perform
type BatchOperation int

// types of operations
const (
OpRead BatchOperation = 0
OpWrite BatchOperation = 1
)

// how many worker threads to use for batch operations
const (
VaultConcurency = 5
)

// BatchOperation can perform reads or writes with concurrency
func (client *Client) BatchOperation(absolutePaths []string, op BatchOperation, secretsIn []*Secret) (secrets []*Secret, err error) {
readQueue := make(chan string, len(absolutePaths))
writeQueue := make(chan *Secret, len(absolutePaths))
results := make(chan *secretOperation, len(absolutePaths))

// load up queue for operation
switch op {
case OpRead:
for _, path := range absolutePaths {
readQueue <- path
}
case OpWrite:
for _, secret := range secretsIn {
writeQueue <- secret
}
default:
return nil, fmt.Errorf("invalid batch operation")
}

// fire off goroutines for operation
for i := 0; i < VaultConcurency; i++ {
client.waitGroup.Add(1)
switch op {
case OpRead:
go client.readWorker(readQueue, results)
case OpWrite:
go client.writeWorker(writeQueue, results)
}
}
client.waitGroup.Wait()
close(results)

// read results from the queue and return as array
for result := range results {
err = result.Error
if err != nil {
return secrets, err
}
if result.Result != nil {
secrets = append(secrets, result.Result)
}
}
return secrets, nil
}

// readWorker fetches paths to be read from the queue until empty
func (client *Client) readWorker(queue chan string, out chan *secretOperation) {
defer client.waitGroup.Done()
readFromQueue:
for {
select {
case path := <-queue:
s, err := client.Read(path)
out <- &secretOperation{Result: s, Path: path, Error: err}
default:
break readFromQueue
}
}
}

// writeWorker writes secrets to Vault in parallel
func (client *Client) writeWorker(queue chan *Secret, out chan *secretOperation) {
defer client.waitGroup.Done()
readFromQueue:
for {
select {
case secret := <-queue:
err := client.Write(secret.Path, secret)
out <- &secretOperation{Result: nil, Path: secret.Path, Error: err}
default:
break readFromQueue
}
}
}
21 changes: 20 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"

"github.com/fishi0x01/vsh/log"
"github.com/hashicorp/vault/api"
Expand All @@ -18,6 +19,7 @@ type Client struct {
Pwd string
KVBackends map[string]int
cache *Cache
waitGroup sync.WaitGroup
}

// VaultConfig container to keep parameters for Client configuration
Expand All @@ -27,6 +29,12 @@ type VaultConfig struct {
StartPath string
}

type secretOperation struct {
Result *Secret
Path string
Error error
}

func verifyClientPwd(client *Client) (*Client, error) {
if client.Pwd == "" {
client.Pwd = "/"
Expand Down Expand Up @@ -105,11 +113,16 @@ func (client *Client) Read(absolutePath string) (secret *Secret, err error) {
apiSecret, err = client.lowLevelRead(normalizedVaultPath(absolutePath))
}
if apiSecret != nil {
secret = NewSecret(apiSecret)
secret = NewSecret(apiSecret, absolutePath)
}
return secret, err
}

// BatchRead returns secrets for given paths
func (client *Client) BatchRead(absolutePaths []string) (secrets []*Secret, err error) {
return client.BatchOperation(absolutePaths, OpRead, make([]*Secret, 0))
}

// Write writes secret to given path, using given Client
func (client *Client) Write(absolutePath string, secret *Secret) (err error) {
if client.isTopLevelPath(absolutePath) {
Expand All @@ -121,6 +134,12 @@ func (client *Client) Write(absolutePath string, secret *Secret) (err error) {
return err
}

// BatchWrite writes provided secrets to Vault
func (client *Client) BatchWrite(absolutePaths []string, secrets []*Secret) (err error) {
_, err = client.BatchOperation(absolutePaths, OpWrite, secrets)
return err
}

// Delete deletes secret at given absolutePath, using given client
func (client *Client) Delete(absolutePath string) (err error) {
if client.isTopLevelPath(absolutePath) {
Expand Down
Loading