From dc69af9f3887141b15ea893b0cb97746450f94b7 Mon Sep 17 00:00:00 2001 From: Damian Bednarczyk Date: Wed, 22 Nov 2023 21:26:44 -0600 Subject: [PATCH 1/2] idiomatize (?) ruleset package and lint --- Makefile | 2 +- cmd/main.go | 19 +++++++++-- handlers/cli/cli.go | 64 +++++++++++++++++++++---------------- handlers/proxy.go | 11 +++---- handlers/proxy.test.go | 3 +- pkg/ruleset/ruleset.go | 52 ++++++++++++++++++++++-------- pkg/ruleset/ruleset_test.go | 30 ++++++++++++++--- 7 files changed, 124 insertions(+), 57 deletions(-) diff --git a/Makefile b/Makefile index fdce449..98f3097 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ lint: gofumpt -l -w . - golangci-lint run -c .golangci-lint.yaml + golangci-lint run -c .golangci-lint.yaml --fix go mod tidy go clean diff --git a/cmd/main.go b/cmd/main.go index e9ead89..8ed3b1a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -29,6 +29,7 @@ func main() { if os.Getenv("PORT") == "" { portEnv = "8080" } + port := parser.String("p", "port", &argparse.Options{ Required: false, Default: portEnv, @@ -49,10 +50,12 @@ func main() { Required: false, Help: "Compiles a directory of yaml files into a single ruleset.yaml. Requires --ruleset arg.", }) + mergeRulesetsGzip := parser.Flag("", "merge-rulesets-gzip", &argparse.Options{ Required: false, Help: "Compiles a directory of yaml files into a single ruleset.gz Requires --ruleset arg.", }) + mergeRulesetsOutput := parser.String("", "merge-rulesets-output", &argparse.Options{ Required: false, Help: "Specify output file for --merge-rulesets and --merge-rulesets-gzip. Requires --ruleset and --merge-rulesets args.", @@ -65,7 +68,13 @@ func main() { // utility cli flag to compile ruleset directory into single ruleset.yaml if *mergeRulesets || *mergeRulesetsGzip { - err = cli.HandleRulesetMerge(ruleset, mergeRulesets, mergeRulesetsGzip, mergeRulesetsOutput) + output, err := os.Create(*mergeRulesetsOutput) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + err = cli.HandleRulesetMerge(*ruleset, *mergeRulesets, *mergeRulesetsGzip, output) if err != nil { fmt.Println(err) os.Exit(1) @@ -87,6 +96,7 @@ func main() { userpass := os.Getenv("USERPASS") if userpass != "" { userpass := strings.Split(userpass, ":") + app.Use(basicauth.New(basicauth.Config{ Users: map[string]string{ userpass[0]: userpass[1], @@ -102,23 +112,28 @@ func main() { if os.Getenv("NOLOGS") != "true" { app.Use(func(c *fiber.Ctx) error { log.Println(c.Method(), c.Path()) + return c.Next() }) } app.Get("/", handlers.Form) + app.Get("/styles.css", func(c *fiber.Ctx) error { cssData, err := cssData.ReadFile("styles.css") if err != nil { return c.Status(fiber.StatusInternalServerError).SendString("Internal Server Error") } + c.Set("Content-Type", "text/css") + return c.Send(cssData) }) - app.Get("ruleset", handlers.Ruleset) + app.Get("ruleset", handlers.Ruleset) app.Get("raw/*", handlers.Raw) app.Get("api/*", handlers.Api) app.Get("/*", handlers.ProxySite(*ruleset)) + log.Fatal(app.Listen(":" + *port)) } diff --git a/handlers/cli/cli.go b/handlers/cli/cli.go index ab355d7..ca7b7e0 100644 --- a/handlers/cli/cli.go +++ b/handlers/cli/cli.go @@ -3,10 +3,10 @@ package cli import ( "fmt" "io" - "io/fs" - "ladder/pkg/ruleset" "os" + "ladder/pkg/ruleset" + "golang.org/x/term" ) @@ -14,32 +14,38 @@ import ( // Exits the program with an error message if the ruleset path is not provided or if loading the ruleset fails. // // Parameters: -// - rulesetPath: A pointer to a string specifying the path to the ruleset file. -// - mergeRulesets: A pointer to a boolean indicating if a merge operation should be performed. -// - mergeRulesetsGzip: A pointer to a boolean indicating if the merge should be in Gzip format. -// - mergeRulesetsOutput: A pointer to a string specifying the output file path. If empty, the output is printed to stdout. +// - rulesetPath: Specifies the path to the ruleset file. +// - mergeRulesets: Indicates if a merge operation should be performed. +// - useGzip: Indicates if the merged rulesets should be gzip-ped. +// - output: Specifies the output file. If nil, stdout will be used. // // Returns: // - An error if the ruleset loading or merging process fails, otherwise nil. -func HandleRulesetMerge(rulesetPath *string, mergeRulesets *bool, mergeRulesetsGzip *bool, mergeRulesetsOutput *string) error { - if *rulesetPath == "" { - *rulesetPath = os.Getenv("RULESET") +func HandleRulesetMerge(rulesetPath string, mergeRulesets bool, useGzip bool, output *os.File) error { + if !mergeRulesets { + return nil } - if *rulesetPath == "" { - fmt.Println("ERROR: no ruleset provided. Try again with --ruleset ") + + if rulesetPath == "" { + rulesetPath = os.Getenv("RULESET") + } + + if rulesetPath == "" { + fmt.Println("error: no ruleset provided. Try again with --ruleset ") os.Exit(1) } - rs, err := ruleset.NewRuleset(*rulesetPath) + rs, err := ruleset.NewRuleset(rulesetPath) if err != nil { fmt.Println(err) os.Exit(1) } - if *mergeRulesetsGzip { - return gzipMerge(rs, mergeRulesetsOutput) + if useGzip { + return gzipMerge(rs, output) } - return yamlMerge(rs, mergeRulesetsOutput) + + return yamlMerge(rs, output) } // gzipMerge takes a RuleSet and an optional output file path pointer. It compresses the RuleSet into Gzip format. @@ -48,33 +54,33 @@ func HandleRulesetMerge(rulesetPath *string, mergeRulesets *bool, mergeRulesetsG // // Parameters: // - rs: The ruleset.RuleSet to be compressed. -// - mergeRulesetsOutput: A pointer to a string specifying the output file path. If empty, the output is directed to stdout. +// - output: The output for the gzip data. If nil, stdout will be used. // // Returns: // - An error if compression or file writing fails, otherwise nil. -func gzipMerge(rs ruleset.RuleSet, mergeRulesetsOutput *string) error { +func gzipMerge(rs ruleset.RuleSet, output io.Writer) error { gzip, err := rs.GzipYaml() if err != nil { return err } - if *mergeRulesetsOutput != "" { - out, err := os.Create(*mergeRulesetsOutput) - defer out.Close() - _, err = io.Copy(out, gzip) + if output != nil { + _, err = io.Copy(output, gzip) if err != nil { return err } } if term.IsTerminal(int(os.Stdout.Fd())) { - println("WARNING: binary output can mess up your terminal. Use '--merge-rulesets-output ' or pipe it to a file.") + println("warning: binary output can mess up your terminal. Use '--merge-rulesets-output ' or pipe it to a file.") os.Exit(1) } + _, err = io.Copy(os.Stdout, gzip) if err != nil { return err } + return nil } @@ -83,23 +89,25 @@ func gzipMerge(rs ruleset.RuleSet, mergeRulesetsOutput *string) error { // // Parameters: // - rs: The ruleset.RuleSet to be converted to YAML. -// - mergeRulesetsOutput: A pointer to a string specifying the output file path. If empty, the output is printed to stdout. +// - output: The output for the merged data. If nil, stdout will be used. // // Returns: // - An error if YAML conversion or file writing fails, otherwise nil. -func yamlMerge(rs ruleset.RuleSet, mergeRulesetsOutput *string) error { +func yamlMerge(rs ruleset.RuleSet, output io.Writer) error { yaml, err := rs.Yaml() if err != nil { return err } - if *mergeRulesetsOutput == "" { - fmt.Printf(yaml) + + if output == nil { + fmt.Println(yaml) os.Exit(0) } - err = os.WriteFile(*mergeRulesetsOutput, []byte(yaml), fs.FileMode(os.O_RDWR)) + _, err = io.WriteString(output, yaml) if err != nil { - return fmt.Errorf("ERROR: failed to write merged YAML ruleset to '%s'\n", *mergeRulesetsOutput) + return fmt.Errorf("failed to write merged YAML ruleset: %v", err) } + return nil } diff --git a/handlers/proxy.go b/handlers/proxy.go index 128314e..bedfbde 100644 --- a/handlers/proxy.go +++ b/handlers/proxy.go @@ -80,7 +80,6 @@ func extractUrl(c *fiber.Ctx) (string, error) { // default behavior: // eg: https://localhost:8080/https://realsite.com/images/foobar.jpg -> https://realsite.com/images/foobar.jpg return urlQuery.String(), nil - } func ProxySite(rulesetPath string) fiber.Handler { @@ -121,18 +120,18 @@ func modifyURL(uri string, rule ruleset.Rule) (string, error) { return "", err } - for _, urlMod := range rule.UrlMods.Domain { + for _, urlMod := range rule.URLMods.Domain { re := regexp.MustCompile(urlMod.Match) newUrl.Host = re.ReplaceAllString(newUrl.Host, urlMod.Replace) } - for _, urlMod := range rule.UrlMods.Path { + for _, urlMod := range rule.URLMods.Path { re := regexp.MustCompile(urlMod.Match) newUrl.Path = re.ReplaceAllString(newUrl.Path, urlMod.Replace) } v := newUrl.Query() - for _, query := range rule.UrlMods.Query { + for _, query := range rule.URLMods.Query { if query.Value == "" { v.Del(query.Key) continue @@ -223,11 +222,11 @@ func fetchSite(urlpath string, queries map[string]string) (string, *http.Request } if rule.Headers.CSP != "" { - //log.Println(rule.Headers.CSP) + // log.Println(rule.Headers.CSP) resp.Header.Set("Content-Security-Policy", rule.Headers.CSP) } - //log.Print("rule", rule) TODO: Add a debug mode to print the rule + // log.Print("rule", rule) TODO: Add a debug mode to print the rule body := rewriteHtml(bodyB, u, rule) return body, req, resp, nil } diff --git a/handlers/proxy.test.go b/handlers/proxy.test.go index 07f72bd..0ed2c4b 100644 --- a/handlers/proxy.test.go +++ b/handlers/proxy.test.go @@ -2,12 +2,13 @@ package handlers import ( - "ladder/pkg/ruleset" "net/http" "net/http/httptest" "net/url" "testing" + "ladder/pkg/ruleset" + "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/assert" ) diff --git a/pkg/ruleset/ruleset.go b/pkg/ruleset/ruleset.go index 1c289e0..7f3079f 100644 --- a/pkg/ruleset/ruleset.go +++ b/pkg/ruleset/ruleset.go @@ -1,6 +1,7 @@ package ruleset import ( + "compress/gzip" "errors" "fmt" "io" @@ -11,8 +12,6 @@ import ( "regexp" "strings" - "compress/gzip" - "gopkg.in/yaml.v3" ) @@ -41,7 +40,7 @@ type Rule struct { GoogleCache bool `yaml:"googleCache,omitempty"` RegexRules []Regex `yaml:"regexRules,omitempty"` - UrlMods struct { + URLMods struct { Domain []Regex `yaml:"domain,omitempty"` Path []Regex `yaml:"path,omitempty"` Query []KV `yaml:"query,omitempty"` @@ -55,6 +54,8 @@ type Rule struct { } `yaml:"injections,omitempty"` } +var remoteRegex = regexp.MustCompile(`^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()!@:%_\+.~#?&\/\/=]*)`) + // NewRulesetFromEnv creates a new RuleSet based on the RULESET environment variable. // It logs a warning and returns an empty RuleSet if the RULESET environment variable is not set. // If the RULESET is set but the rules cannot be loaded, it panics. @@ -64,10 +65,12 @@ func NewRulesetFromEnv() RuleSet { log.Printf("WARN: No ruleset specified. Set the `RULESET` environment variable to load one for a better success rate.") return RuleSet{} } + ruleSet, err := NewRuleset(rulesPath) if err != nil { log.Println(err) } + return ruleSet } @@ -75,16 +78,17 @@ func NewRulesetFromEnv() RuleSet { // It supports loading rules from both local file paths and remote URLs. // Returns a RuleSet and an error if any issues occur during loading. func NewRuleset(rulePaths string) (RuleSet, error) { - ruleSet := RuleSet{} - errs := []error{} + var ruleSet RuleSet + + var errs []error rp := strings.Split(rulePaths, ";") - var remoteRegex = regexp.MustCompile(`^https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()!@:%_\+.~#?&\/\/=]*)`) for _, rule := range rp { - rulePath := strings.Trim(rule, " ") var err error + rulePath := strings.Trim(rule, " ") isRemote := remoteRegex.MatchString(rulePath) + if isRemote { err = ruleSet.loadRulesFromRemoteFile(rulePath) } else { @@ -94,6 +98,7 @@ func NewRuleset(rulePaths string) (RuleSet, error) { if err != nil { e := fmt.Errorf("WARN: failed to load ruleset from '%s'", rulePath) errs = append(errs, errors.Join(e, err)) + continue } } @@ -101,6 +106,7 @@ func NewRuleset(rulePaths string) (RuleSet, error) { if len(errs) != 0 { e := fmt.Errorf("WARN: failed to load %d rulesets", len(rp)) errs = append(errs, e) + // panic if the user specified a local ruleset, but it wasn't found on disk // don't fail silently for _, err := range errs { @@ -109,10 +115,13 @@ func NewRuleset(rulePaths string) (RuleSet, error) { panic(errors.Join(e, err)) } } + // else, bubble up any errors, such as syntax or remote host issues return ruleSet, errors.Join(errs...) } + ruleSet.PrintStats() + return ruleSet, nil } @@ -146,13 +155,16 @@ func (rs *RuleSet) loadRulesFromLocalDir(path string) error { log.Printf("WARN: failed to load directory ruleset '%s': %s, skipping", path, err) return nil } + log.Printf("INFO: loaded ruleset %s\n", path) + return nil }) if err != nil { return err } + return nil } @@ -167,42 +179,51 @@ func (rs *RuleSet) loadRulesFromLocalFile(path string) error { var r RuleSet err = yaml.Unmarshal(yamlFile, &r) + if err != nil { e := fmt.Errorf("failed to load rules from local file, possible syntax error in '%s'", path) ee := errors.Join(e, err) + if _, ok := os.LookupEnv("DEBUG"); ok { debugPrintRule(string(yamlFile), ee) } + return ee } + *rs = append(*rs, r...) + return nil } // loadRulesFromRemoteFile loads rules from a remote URL. // It supports plain and gzip compressed content. // Returns an error if there's an issue accessing the URL or if there's a syntax error in the YAML. -func (rs *RuleSet) loadRulesFromRemoteFile(rulesUrl string) error { +func (rs *RuleSet) loadRulesFromRemoteFile(rulesURL string) error { var r RuleSet - resp, err := http.Get(rulesUrl) + + resp, err := http.Get(rulesURL) if err != nil { - e := fmt.Errorf("failed to load rules from remote url '%s'", rulesUrl) + e := fmt.Errorf("failed to load rules from remote url '%s'", rulesURL) return errors.Join(e, err) } + defer resp.Body.Close() if resp.StatusCode >= 400 { - e := fmt.Errorf("failed to load rules from remote url (%s) on '%s'", resp.Status, rulesUrl) + e := fmt.Errorf("failed to load rules from remote url (%s) on '%s'", resp.Status, rulesURL) return errors.Join(e, err) } var reader io.Reader - isGzip := strings.HasSuffix(rulesUrl, ".gz") || strings.HasSuffix(rulesUrl, ".gzip") || resp.Header.Get("content-encoding") == "gzip" + + isGzip := strings.HasSuffix(rulesURL, ".gz") || strings.HasSuffix(rulesURL, ".gzip") || resp.Header.Get("content-encoding") == "gzip" if isGzip { reader, err = gzip.NewReader(resp.Body) + if err != nil { - return fmt.Errorf("failed to create gzip reader for URL '%s' with status code '%s': %w", rulesUrl, resp.Status, err) + return fmt.Errorf("failed to create gzip reader for URL '%s' with status code '%s': %w", rulesURL, resp.Status, err) } } else { reader = resp.Body @@ -211,12 +232,14 @@ func (rs *RuleSet) loadRulesFromRemoteFile(rulesUrl string) error { err = yaml.NewDecoder(reader).Decode(&r) if err != nil { - e := fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error", rulesUrl, resp.Status) + e := fmt.Errorf("failed to load rules from remote url '%s' with status code '%s' and possible syntax error", rulesURL, resp.Status) ee := errors.Join(e, err) + return ee } *rs = append(*rs, r...) + return nil } @@ -228,6 +251,7 @@ func (rs *RuleSet) Yaml() (string, error) { if err != nil { return "", err } + return string(y), nil } diff --git a/pkg/ruleset/ruleset_test.go b/pkg/ruleset/ruleset_test.go index 85c6f33..42f115c 100644 --- a/pkg/ruleset/ruleset_test.go +++ b/pkg/ruleset/ruleset_test.go @@ -33,6 +33,7 @@ func TestLoadRulesFromRemoteFile(t *testing.T) { c.SendString(validYAML) return nil }) + app.Get("/invalid-config.yml", func(c *fiber.Ctx) error { c.SendString(invalidYAML) return nil @@ -40,10 +41,12 @@ func TestLoadRulesFromRemoteFile(t *testing.T) { app.Get("/valid-config.gz", func(c *fiber.Ctx) error { c.Set("Content-Type", "application/octet-stream") + rs, err := loadRuleFromString(validYAML) if err != nil { t.Errorf("failed to load valid yaml from string: %s", err.Error()) } + s, err := rs.GzipYaml() if err != nil { t.Errorf("failed to load gzip serialize yaml: %s", err.Error()) @@ -70,15 +73,18 @@ func TestLoadRulesFromRemoteFile(t *testing.T) { if err != nil { t.Errorf("failed to load plaintext ruleset from http server: %s", err.Error()) } + assert.Equal(t, rs[0].Domain, "example.com") rs, err = NewRuleset("http://127.0.0.1:9999/valid-config.gz") if err != nil { t.Errorf("failed to load gzipped ruleset from http server: %s", err.Error()) } + assert.Equal(t, rs[0].Domain, "example.com") os.Setenv("RULESET", "http://127.0.0.1:9999/valid-config.gz") + rs = NewRulesetFromEnv() if !assert.Equal(t, rs[0].Domain, "example.com") { t.Error("expected no errors loading ruleset from gzip url using environment variable, but got one") @@ -88,10 +94,14 @@ func TestLoadRulesFromRemoteFile(t *testing.T) { func loadRuleFromString(yaml string) (RuleSet, error) { // Create a temporary file and load it tmpFile, _ := os.CreateTemp("", "ruleset*.yaml") + defer os.Remove(tmpFile.Name()) + tmpFile.WriteString(yaml) + rs := RuleSet{} err := rs.loadRulesFromLocalFile(tmpFile.Name()) + return rs, err } @@ -101,6 +111,7 @@ func TestLoadRulesFromLocalFile(t *testing.T) { if err != nil { t.Errorf("Failed to load rules from valid YAML: %s", err) } + assert.Equal(t, rs[0].Domain, "example.com") assert.Equal(t, rs[0].RegexRules[0].Match, "^http:") assert.Equal(t, rs[0].RegexRules[0].Replace, "https:") @@ -118,30 +129,39 @@ func TestLoadRulesFromLocalDir(t *testing.T) { if err != nil { t.Fatalf("Failed to create temporary directory: %s", err) } + defer os.RemoveAll(baseDir) // Create a nested subdirectory nestedDir := filepath.Join(baseDir, "nested") - err = os.Mkdir(nestedDir, 0755) + err = os.Mkdir(nestedDir, 0o755) + if err != nil { t.Fatalf("Failed to create nested directory: %s", err) } // Create a nested subdirectory nestedTwiceDir := filepath.Join(nestedDir, "nestedTwice") - err = os.Mkdir(nestedTwiceDir, 0755) + err = os.Mkdir(nestedTwiceDir, 0o755) + if err != nil { + t.Fatalf("Failed to create twice-nested directory: %s", err) + } testCases := []string{"test.yaml", "test2.yaml", "test-3.yaml", "test 4.yaml", "1987.test.yaml.yml", "foobar.example.com.yaml", "foobar.com.yml"} for _, fileName := range testCases { filePath := filepath.Join(nestedDir, "2x-"+fileName) - os.WriteFile(filePath, []byte(validYAML), 0644) + os.WriteFile(filePath, []byte(validYAML), 0o644) + filePath = filepath.Join(nestedDir, fileName) - os.WriteFile(filePath, []byte(validYAML), 0644) + os.WriteFile(filePath, []byte(validYAML), 0o644) + filePath = filepath.Join(baseDir, "base-"+fileName) - os.WriteFile(filePath, []byte(validYAML), 0644) + os.WriteFile(filePath, []byte(validYAML), 0o644) } + rs := RuleSet{} err = rs.loadRulesFromLocalDir(baseDir) + assert.NoError(t, err) assert.Equal(t, rs.Count(), len(testCases)*3) From dc19c4c813f019afc3d73b29abbb283d344a5ff1 Mon Sep 17 00:00:00 2001 From: Damian Date: Fri, 24 Nov 2023 21:25:12 +0000 Subject: [PATCH 2/2] output to stdout by default --- cmd/main.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index 8ed3b1a..d28787d 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -68,10 +68,15 @@ func main() { // utility cli flag to compile ruleset directory into single ruleset.yaml if *mergeRulesets || *mergeRulesetsGzip { - output, err := os.Create(*mergeRulesetsOutput) - if err != nil { - fmt.Println(err) - os.Exit(1) + output := os.Stdout + + if *mergeRulesetsOutput != "" { + output, err = os.Create(*mergeRulesetsOutput) + + if err != nil { + fmt.Println(err) + os.Exit(1) + } } err = cli.HandleRulesetMerge(*ruleset, *mergeRulesets, *mergeRulesetsGzip, output)