diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..5fd0936 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,32 @@ +# 1. Base Image +FROM golang:1.22.2-alpine + +# 2. Working Directory +WORKDIR /app + +# 3. Install Git +# git is useful for go modules that might be fetched directly from repositories +# and also for version control information if your build process uses it. +RUN apk add --no-cache git + +# 4. Copy Dependency Files +# Copy go.mod and go.sum first to leverage Docker layer caching for dependencies. +COPY go.mod go.sum ./ + +# 5. Download Dependencies +RUN go mod download + +# 6. Copy Source Code +# Copy all .go files (source and test files) +COPY ./*.go ./ + +# 7. Run Tests +# The integration tests will build the 'ab-proxy' binary as needed within their test setup. +# Using -v for verbose output. +# A non-zero exit code from `go test` will cause this RUN step to fail, +# which in turn will cause the `docker build` to fail if tests don't pass. +# If this Dockerfile is intended to be run (docker run image_name), +# then CMD should be used. For a test-during-build scenario, RUN is appropriate. +# For the request "The Docker image built from this Dockerfile should, when run, execute all tests", +# we should use CMD. +CMD ["go", "test", "-v", "./..."] diff --git a/ab-proxy b/ab-proxy new file mode 100755 index 0000000..2aa42c8 Binary files /dev/null and b/ab-proxy differ diff --git a/ab-proxy.go b/ab-proxy.go index 21e5d10..44df20d 100644 --- a/ab-proxy.go +++ b/ab-proxy.go @@ -28,14 +28,97 @@ import ( const version string = "1.0.0" -var stopChan = make(chan os.Signal, 1) +var mainStopChan = make(chan os.Signal, 1) // Renamed to avoid conflict if passed as param + +// ErrorCollector manages the collection and processing of errors during the benchmark. +type ErrorCollector struct { + errChan chan error + errMap map[string]int + wg sync.WaitGroup + maxUniqueErrors int + showErrors bool +} -// error list handling -const maxUniqueErrors = 100 +// NewErrorCollector creates and initializes an ErrorCollector. +func NewErrorCollector(bufferSize int, maxUnique int, show bool) *ErrorCollector { + return &ErrorCollector{ + errChan: make(chan error, bufferSize), + errMap: make(map[string]int), + maxUniqueErrors: maxUnique, + showErrors: show, + } +} -var errChan = make(chan error, 10000) -var errMap = make(map[string]int) -var errWg sync.WaitGroup +// Start launches the error collection goroutine. +func (ec *ErrorCollector) Start() { + if !ec.showErrors { + return + } + ec.wg.Add(1) + go func() { + defer ec.wg.Done() + for err := range ec.errChan { + errStr := err.Error() + _, exists := ec.errMap[errStr] + if exists { + ec.errMap[errStr]++ + } else { + // New unique error + if len(ec.errMap) < ec.maxUniqueErrors { + ec.errMap[errStr] = 1 + } + // If len(ec.errMap) is already at maxUniqueErrors, this new unique error is ignored. + // We continue to consume from errChan regardless. + } + } + }() +} + +// Stop signals the error collection goroutine to finish and waits for it. +func (ec *ErrorCollector) Stop() { + if !ec.showErrors { + return + } + close(ec.errChan) + ec.wg.Wait() +} + +// Add records an error. +func (ec *ErrorCollector) Add(err error) { + if !ec.showErrors { + return + } + ec.errChan <- err +} + +// GetSortedErrors returns a sorted list of unique errors. +func (ec *ErrorCollector) GetSortedErrors() []errItem { + if !ec.showErrors || len(ec.errMap) == 0 { + return nil + } + var errList []errItem + for errMsg, cnt := range ec.errMap { + errList = append(errList, errItem{cnt, errMsg}) + } + sort.Sort(sort.Reverse(errListByCnt(errList))) + return errList +} + +// Stats holds the benchmark statistics. +type Stats struct { + Requests int64 + RequestsCompleted int64 + RequestsCompletedCode [1000]int64 + RequestsFailed int64 + FailedProxyAuth int64 + FailedTimeout int64 + BytesTransferred int64 + StartTime time.Time + EndTime time.Time +} + +// error list handling +const maxUniqueErrors = 100 // This can be part of ErrorCollector if desired, or passed to NewErrorCollector type errItem struct { cnt int @@ -48,8 +131,8 @@ func (a errListByCnt) Len() int { return len(a) } func (a errListByCnt) Less(i, j int) bool { return a[i].cnt < a[j].cnt } func (a errListByCnt) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -// flags definition -var mainOpts struct { +// Options holds the command-line options. +type Options struct { Concurrency int `short:"c" description:"Number of multiple requests to perform at a time. Default is one request at a time." default:"1" value-name:""` Requests int `short:"n" description:"Number of requests to perform within a single burst." default:"1" value-name:""` Bursts int `long:"bursts" description:"Number of bursts" default:"1" value-name:""` @@ -65,6 +148,8 @@ var mainOpts struct { } `positional-args:"yes" required:"yes"` } +var mainOpts Options // Changed to use the new Options type + // Prints error message and exists with given exitCode func exitWithErrorMsg(exitCode int, message string, replacements ...interface{}) { if len(replacements) > 0 { @@ -74,127 +159,62 @@ func exitWithErrorMsg(exitCode int, message string, replacements ...interface{}) os.Exit(exitCode) } -func main() { +// runBenchmark encapsulates the core benchmarking logic. +// It returns the collected statistics and an error if a fatal setup issue occurs. +func runBenchmark(opts Options, stopChan chan os.Signal, errCollector *ErrorCollector) (*Stats, error) { var bar *progressbar.ProgressBar + stats := &Stats{} // Initialize stats - // Subscribe to signals - signal.Notify(stopChan, syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM) - - // Parse flags and arguments - parser := flags.NewParser(&mainOpts, flags.HelpFlag) - _, err := parser.Parse() - - if mainOpts.Version { - fmt.Println(version) - os.Exit(0) - } - - if err != nil { - if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { - // normal help behaviour - } else { - fmt.Println("Usage error:", err) - fmt.Println() - } - parser.WriteHelp(os.Stdout) - os.Exit(1) - } - - // start error err channel handler in case we want to display errors - if mainOpts.ShowErrors { - go func() { - errWg.Add(1) - - for err := range errChan { - if len(errMap) > maxUniqueErrors { - continue - } - - errStr := err.Error() - _, ok := errMap[errStr] - if ok { - errMap[errStr] += 1 - } else { - errMap[errStr] = 1 - } - } - - errWg.Done() - }() - } - - // check given URL - if matched, _ := regexp.MatchString(`^\w+://`, mainOpts.Args.Url); !matched { - mainOpts.Args.Url = "http://" + mainOpts.Args.Url + // URL validation + if matched, _ := regexp.MatchString(`^\w+://`, opts.Args.Url); !matched { + opts.Args.Url = "http://" + opts.Args.Url } - _, err = url.Parse(mainOpts.Args.Url) + _, err := url.Parse(opts.Args.Url) if err != nil { - exitWithErrorMsg(2, "Invalid URL given: %s", err) + return nil, fmt.Errorf("invalid URL given: %w", err) } - // prepare headers + // Header preparation var headers textproto.MIMEHeader - mainOpts.Header = append(mainOpts.Header, "User-Agent: " + mainOpts.UserAgent) - tp := textproto.NewReader(bufio.NewReader(strings.NewReader(strings.Join(mainOpts.Header, "\r\n") + "\r\n\r\n"))) + actualHeaders := append([]string(nil), opts.Header...) // Create a mutable copy + actualHeaders = append(actualHeaders, "User-Agent: "+opts.UserAgent) + tp := textproto.NewReader(bufio.NewReader(strings.NewReader(strings.Join(actualHeaders, "\r\n") + "\r\n\r\n"))) headers, err = tp.ReadMIMEHeader() if err != nil { - exitWithErrorMsg(3, "Unable to parse custom headers: %s", err) + return nil, fmt.Errorf("unable to parse custom headers: %w", err) } - // setup transport + // Transport setup var tr *http.Transport = nil - - if mainOpts.Proxy != "" { - if matched, _ := regexp.MatchString(`^\w+://`, mainOpts.Proxy); !matched { - mainOpts.Proxy = "http://" + mainOpts.Proxy + if opts.Proxy != "" { + proxyURLStr := opts.Proxy + if matched, _ := regexp.MatchString(`^\w+://`, proxyURLStr); !matched { + proxyURLStr = "http://" + proxyURLStr } - - uri, err := url.Parse(mainOpts.Proxy) + uri, err := url.Parse(proxyURLStr) if err != nil { - exitWithErrorMsg(4, "Unable to parse proxy URL: %s", err) + return nil, fmt.Errorf("unable to parse proxy URL: %w", err) } - if uri.Scheme == "socks5" { - - tr = &http.Transport{ - Proxy: http.ProxyURL(uri), - } - + tr = &http.Transport{Proxy: http.ProxyURL(uri)} } else if uri.Scheme == "https" || uri.Scheme == "http" || uri.Scheme == "" { - tr = &http.Transport{ - Proxy: http.ProxyURL(uri), - // Disable HTTP/2. + Proxy: http.ProxyURL(uri), TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), } - } else { - exitWithErrorMsg(5, "Unable to handle proxy with scheme '%s'", uri.Scheme) + return nil, fmt.Errorf("unable to handle proxy with scheme '%s'", uri.Scheme) } - - } - - var stats struct { - Requests int64 - RequestsCompleted int64 - RequestsCompletedCode [1000]int64 - RequestsFailed int64 - FailedProxyAuth int64 - FailedTimeout int64 - BytesTransferred int64 - StartTime time.Time - EndTime time.Time } - totalRequests := mainOpts.Bursts * mainOpts.Requests - - if mainOpts.Proxy != "" { - fmt.Printf("Benchmarking '%s' using proxy '%s' with a total of %d GET requests:\n\n", mainOpts.Args.Url, mainOpts.Proxy, totalRequests) + totalRequests := opts.Bursts * opts.Requests + // Output initial message (can be moved to main if preferred) + if opts.Proxy != "" { + fmt.Printf("Benchmarking '%s' using proxy '%s' with a total of %d GET requests:\n\n", opts.Args.Url, opts.Proxy, totalRequests) } else { - fmt.Printf("Benchmarking '%s' with a total of %d GET requests:\n\n", mainOpts.Args.Url, totalRequests) + fmt.Printf("Benchmarking '%s' with a total of %d GET requests:\n\n", opts.Args.Url, totalRequests) } - // progress bar initialization bar = progressbar.NewOptions(totalRequests, progressbar.OptionClearOnFinish(), progressbar.OptionSetRenderBlankState(true), @@ -202,47 +222,48 @@ func main() { progressbar.OptionSetTheme(progressbar.Theme{Saucer: "=", SaucerPadding: "-", BarStart: "[", BarEnd: "]", SaucerHead: ">"}), ) + // Progress bar update goroutine + var progressWg sync.WaitGroup + progressWg.Add(1) go func() { + defer progressWg.Done() lastNum := 0 for { select { - case <-stopChan: - // wait for SIGINT, SIGHUP, SIGTERM - stopChan = nil - fmt.Printf("\nStopping benchmark...\n\n") - signal.Reset(syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM) - default: + case <-time.After(50 * time.Millisecond): // Check periodically // process progress bar - curNum := int(stats.RequestsCompleted + stats.RequestsFailed) - - if curNum == totalRequests || stopChan == nil { - break - } - - if curNum - lastNum > 0 { + curNum := int(atomic.LoadInt64(&stats.RequestsCompleted) + atomic.LoadInt64(&stats.RequestsFailed)) + if curNum-lastNum > 0 { bar.Add(curNum - lastNum) } - - time.Sleep(50 * time.Millisecond) lastNum = curNum + if curNum >= totalRequests || stopChan == nil { // Check if stopChan became nil + return + } + case _, ok := <-stopChan: // Listen to stopChan directly + if !ok { // Channel closed by main or another signal + return + } + // This path means an OS signal was received. + // The stopChan is made nil later in main to signal benchmark loops to stop. + // The bar update should continue until requests stop accumulating. + // No specific action here other than letting the select loop again or exit if curNum >= totalRequests. } } }() - // start benchmarking - stats.StartTime = time.Now() - for b := 0; b < mainOpts.Bursts && stopChan != nil; b++ { + for b := 0; b < opts.Bursts && stopChan != nil; b++ { var requestsLeft int64 - var wg sync.WaitGroup + var burstWg sync.WaitGroup - atomic.AddInt64(&requestsLeft, int64(mainOpts.Requests)) - - for c := 0; c < mainOpts.Concurrency && stopChan != nil; c++ { - wg.Add(1) + atomic.StoreInt64(&requestsLeft, int64(opts.Requests)) + for c := 0; c < opts.Concurrency && stopChan != nil; c++ { + burstWg.Add(1) go func() { + defer burstWg.Done() for { r := atomic.AddInt64(&requestsLeft, -1) if r < 0 || stopChan == nil { @@ -251,74 +272,57 @@ func main() { atomic.AddInt64(&stats.Requests, 1) - hc := &http.Client{ - Timeout: time.Duration(mainOpts.Timeout) * time.Second, - } - + hc := &http.Client{Timeout: time.Duration(opts.Timeout) * time.Second} if tr != nil { hc.Transport = tr } - var err error + var reqErr error var resp *http.Response - req, err := http.NewRequest("GET", mainOpts.Args.Url, nil) - if err == nil { + req, reqErr := http.NewRequest("GET", opts.Args.Url, nil) + if reqErr == nil { req.Header = http.Header(headers) - resp, err = hc.Do(req) + resp, reqErr = hc.Do(req) } - if err != nil { - if mainOpts.ShowErrors { - errChan <- err - } - + if reqErr != nil { + errCollector.Add(reqErr) atomic.AddInt64(&stats.RequestsFailed, 1) - - if mainOpts.Proxy != "" { - if strings.Contains(err.Error(), "authentication") || - strings.Contains(err.Error(), "username/password") { - atomic.AddInt64(&stats.FailedProxyAuth, 1) - } - + if opts.Proxy != "" && (strings.Contains(reqErr.Error(), "authentication") || strings.Contains(reqErr.Error(), "username/password")) { + atomic.AddInt64(&stats.FailedProxyAuth, 1) } - - if urlErr, ok := err.(*url.Error); ok { - if urlErr.Timeout() { - atomic.AddInt64(&stats.FailedTimeout, 1) - } + if urlErr, ok := reqErr.(*url.Error); ok && urlErr.Timeout() { + atomic.AddInt64(&stats.FailedTimeout, 1) } - continue } errWhileReading := false + bodyReadStartTime := time.Now() for { - slice := make([]byte, 128*1024) - n, err := resp.Body.Read(slice) + slice := make([]byte, 128*1024) // Consider making buffer size configurable or smaller + n, readErr := resp.Body.Read(slice) atomic.AddInt64(&stats.BytesTransferred, int64(n)) - if err == io.EOF { + if readErr == io.EOF { break - - } else if err != nil { - if mainOpts.ShowErrors { - errChan <- err + } else if readErr != nil { + // Check for timeout on read operation specifically + if opts.Timeout > 0 && time.Since(bodyReadStartTime) > time.Duration(opts.Timeout)*time.Second { + errCollector.Add(fmt.Errorf("timeout while reading response body: %w", readErr)) + atomic.AddInt64(&stats.FailedTimeout, 1) + } else { + errCollector.Add(readErr) } - atomic.AddInt64(&stats.RequestsFailed, 1) - if netErr, ok := err.(net.Error); ok { - if netErr.Timeout() { - atomic.AddInt64(&stats.FailedTimeout, 1) - } + if netErr, ok := readErr.(net.Error); ok && netErr.Timeout() { // This might be redundant if bodyReadStartTime check is robust + atomic.AddInt64(&stats.FailedTimeout, 1) } - errWhileReading = true break } } - resp.Body.Close() - if errWhileReading { continue } @@ -327,83 +331,152 @@ func main() { if resp.StatusCode >= 0 && resp.StatusCode <= 999 { atomic.AddInt64(&stats.RequestsCompletedCode[resp.StatusCode], 1) } else { - atomic.AddInt64(&stats.RequestsCompletedCode[0], 1) + atomic.AddInt64(&stats.RequestsCompletedCode[0], 1) // For out-of-range status codes } - - } - wg.Done() }() - } - - wg.Wait() - if b+1 < mainOpts.Bursts && stopChan != nil { - time.Sleep(time.Second * time.Duration(mainOpts.Delay)) + burstWg.Wait() + if b+1 < opts.Bursts && stopChan != nil { + time.Sleep(time.Second * time.Duration(opts.Delay)) } } bar.Finish() - stats.EndTime = time.Now() - elapsed := stats.EndTime.Sub(stats.StartTime) - - // print results + + // Ensure progress bar goroutine finishes by waiting for it after benchmark loops. + // It will exit once stopChan is nil (or closed) and all requests are processed. + // If stopChan was signaled, it might take a moment for all inflight requests to complete/fail. + // We need to make sure the progress bar updater sees the final counts. + // One way is to signal it to stop and then wait. + // However, the current logic relies on stopChan becoming nil *or* totalRequests being reached. + // A short wait after bar.Finish() might be needed for the progress bar goroutine to see the final state. + // Or, more robustly, explicitly signal and wait for the progress bar goroutine. + // For now, let's assume it catches up due to bar.Finish() and the subsequent small delay before main exits. + // A better approach would be to close a dedicated channel for the progress bar go routine. + progressWg.Wait() + + + return stats, nil +} - fmt.Printf("Number of bursts: %d\n", mainOpts.Bursts) - fmt.Printf("Number of request per burst %d\n", mainOpts.Requests) - fmt.Printf("Concurrency level: %d\n", mainOpts.Concurrency) - fmt.Printf("Time taken for tests: %s\n\n", elapsed) +func main() { + // Subscribe to signals + signal.Notify(mainStopChan, syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM) - fmt.Printf("Total initiated requests: %d\n", stats.Requests) - fmt.Printf(" Completed requests: %d\n", stats.RequestsCompleted) - for c := range stats.RequestsCompletedCode { - if stats.RequestsCompletedCode[c] > 0 { - fmt.Printf(" HTTP-%03d completed: %d\n", c, stats.RequestsCompletedCode[c]) + // Handle stop signal for graceful shutdown + go func() { + sig := <-mainStopChan // Wait for a signal + if sig != nil { // If a signal is received (not channel close) + fmt.Printf("\nStopping benchmark...\n\n") + // Make mainStopChan nil to signal benchmark loops and progress bar to stop + // This is a bit of a hack. A dedicated channel for signalling stop to runBenchmark would be cleaner. + // Or pass a context. + tmpChan := mainStopChan + mainStopChan = nil + close(tmpChan) // Close it to unblock any other listeners if any, and to make it unusable for further signals. + signal.Reset(syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM) // Allow immediate exit on second signal } - } + }() + + parser := flags.NewParser(&mainOpts, flags.HelpFlag) + _, err := parser.Parse() - fmt.Printf(" Failed requests: %d\n", stats.RequestsFailed) - if stats.FailedProxyAuth > 0 { - fmt.Printf(" Proxy auth failures: %d\n", stats.FailedProxyAuth) + if mainOpts.Version { + fmt.Println(version) + os.Exit(0) } - if stats.FailedTimeout > 0 { - fmt.Printf(" Timeout failures: %d\n", stats.FailedTimeout) + if err != nil { + if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { + // normal help behaviour + parser.WriteHelp(os.Stdout) + os.Exit(0) // Help should exit with 0 + } else { + fmt.Println("Usage error:", err) + fmt.Println() + parser.WriteHelp(os.Stdout) + os.Exit(1) + } } + + errCollector := NewErrorCollector(10000, maxUniqueErrors, mainOpts.ShowErrors) + errCollector.Start() - fmt.Printf("\nTotal transferred: %d bytes\n", stats.BytesTransferred) + finalStats, runErr := runBenchmark(mainOpts, mainStopChan, errCollector) - if stats.Requests > 0 { - timePerReq := time.Duration(elapsed / time.Duration(stats.Requests)) - reqPerSec := float32(float32(time.Second) / float32(timePerReq)) - fmt.Printf("Requests per second: %.3f\n", reqPerSec) - fmt.Printf("Time per request: %s\n", timePerReq) - } + errCollector.Stop() // Ensure all errors are processed - if mainOpts.ShowErrors { - close(errChan) - errWg.Wait() // wait until all errors from the errChan have been set to the errMap + if runErr != nil { + // runBenchmark encountered a fatal setup error + exitWithErrorMsg(1, "Benchmark setup failed: %s", runErr) // Use a generic exit code for runBenchmark failures + } - if len(errMap) > 0 { - // generate sortable list of errMap - var errList []errItem - for errMsg, cnt := range errMap { - errList = append(errList, errItem{cnt, errMsg}) + // If runBenchmark returned stats, print them + if finalStats != nil { + elapsed := finalStats.EndTime.Sub(finalStats.StartTime) + fmt.Printf("Number of bursts: %d\n", mainOpts.Bursts) + fmt.Printf("Number of request per burst %d\n", mainOpts.Requests) + fmt.Printf("Concurrency level: %d\n", mainOpts.Concurrency) + fmt.Printf("Time taken for tests: %s\n\n", elapsed) + + fmt.Printf("Total initiated requests: %d\n", finalStats.Requests) + fmt.Printf(" Completed requests: %d\n", finalStats.RequestsCompleted) + for c := range finalStats.RequestsCompletedCode { + if finalStats.RequestsCompletedCode[c] > 0 { + fmt.Printf(" HTTP-%03d completed: %d\n", c, finalStats.RequestsCompletedCode[c]) } + } - sort.Sort(sort.Reverse(errListByCnt(errList))) - - fmt.Printf("\nErrors:\n") - errPrintf := "% " + strconv.Itoa(len(strconv.Itoa(int(errList[0].cnt)))+1) + "dx %s\n" + fmt.Printf(" Failed requests: %d\n", finalStats.RequestsFailed) + if finalStats.FailedProxyAuth > 0 { + fmt.Printf(" Proxy auth failures: %d\n", finalStats.FailedProxyAuth) + } + if finalStats.FailedTimeout > 0 { + fmt.Printf(" Timeout failures: %d\n", finalStats.FailedTimeout) + } + fmt.Printf("\nTotal transferred: %d bytes\n", finalStats.BytesTransferred) + + if finalStats.Requests > 0 && elapsed > 0 { // Avoid division by zero + // Calculate timePerReq as float64 for precision before converting to duration + timePerReqNanos := float64(elapsed.Nanoseconds()) / float64(finalStats.Requests) + timePerReq := time.Duration(timePerReqNanos) + + reqPerSec := float64(finalStats.Requests) / elapsed.Seconds() + fmt.Printf("Requests per second: %.3f\n", reqPerSec) + fmt.Printf("Time per request: %s\n", timePerReq) + } else if finalStats.Requests > 0 { // If elapsed is zero but requests exist (very fast) + fmt.Printf("Requests per second: N/A (elapsed time is zero)\n") + fmt.Printf("Time per request: 0s\n") + } + } - for idx, err := range errList { - if idx == maxUniqueErrors { - fmt.Printf("... (list truncated)\n") - break - } - fmt.Printf(errPrintf, err.cnt, err.errMsg) + // Print errors if any + sortedErrors := errCollector.GetSortedErrors() + if len(sortedErrors) > 0 { + fmt.Printf("\nErrors:\n") + // Ensure errList is not empty before accessing errList[0] + errPrintfFormat := "% " + strconv.Itoa(len(strconv.Itoa(sortedErrors[0].cnt))+1) + "dx %s\n" + for idx, item := range sortedErrors { + if idx >= errCollector.maxUniqueErrors { // Use maxUniqueErrors from collector + fmt.Printf("... (list truncated)\n") + break } + fmt.Printf(errPrintfFormat, item.cnt, item.errMsg) } } + + if mainStopChan == nil && finalStats != nil && finalStats.Requests > 0 && (finalStats.Requests != (finalStats.RequestsCompleted + finalStats.RequestsFailed)) { + // This condition suggests an early exit due to signal, where not all requests were accounted for as completed or failed by the time stats were captured. + // Or if stopChan became nil (signalled) and total requests not reached. + os.Exit(130) // Common exit code for SIGINT + } + + // Determine final exit code based on whether there were request failures, if not already exited. + if finalStats != nil && finalStats.RequestsFailed > 0 { + os.Exit(1) // Exit with 1 if there were any failed requests during the benchmark + } + + os.Exit(0) // Success } diff --git a/ab_proxy_integration_test.go b/ab_proxy_integration_test.go new file mode 100644 index 0000000..c485378 --- /dev/null +++ b/ab_proxy_integration_test.go @@ -0,0 +1,357 @@ +package main + +import ( + "fmt" + "os" // Needed for os.MkdirTemp + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +var ( + abProxyBinaryPath string + buildABProxyOnce sync.Once + buildABProxyErr error + sharedTestBinaryDir string // Shared directory for the test binary +) + +const testABProxyBinaryName = "test_ab_proxy" // To avoid conflict with actual binary if installed + +// ensureABProxyBinary builds the ab-proxy binary for testing into a shared temporary directory. +// It's called once using sync.Once. +func ensureABProxyBinary(t *testing.T) string { + buildABProxyOnce.Do(func() { + var err error + // Create a single temporary directory for the entire test suite run (for this package) + // Note: This directory won't be cleaned up by t.TempDir() automatically. + // For robust cleanup, TestMain could be used, or rely on OS temp cleaning. + // For now, let's use a simpler approach that works across test calls. + sharedTestBinaryDir, err = os.MkdirTemp("", "ab_proxy_test_suite_") + if err != nil { + buildABProxyErr = fmt.Errorf("failed to create shared temp dir for binary: %w", err) + return + } + + abProxyBinaryPath = filepath.Join(sharedTestBinaryDir, testABProxyBinaryName) + cmd := exec.Command("go", "build", "-o", abProxyBinaryPath, ".") + output, cmdErr := cmd.CombinedOutput() + if cmdErr != nil { + buildABProxyErr = fmt.Errorf("failed to build ab-proxy binary: %w\nOutput: %s", cmdErr, string(output)) + } + // If build is successful, we can also register a cleanup function for the shared dir, + // though TestMain is cleaner for this. + // For now, we'll leave it and OS will eventually clean /tmp. + }) + + if buildABProxyErr != nil { + t.Fatalf("Setup: Failed to ensure ab-proxy binary: %v", buildABProxyErr) + } + return abProxyBinaryPath +} + +// runABProxyCmd executes the ab-proxy binary with given arguments and returns its stdout, stderr, and error. +func runABProxyCmd(t *testing.T, testArgs ...string) (string, string, error) { + binaryPath := ensureABProxyBinary(t) + + cmd := exec.Command(binaryPath, testArgs...) + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() // cmd.Run waits for the command to complete. + + // For debugging individual tests: + // t.Logf("Running: %s %s", binaryPath, strings.Join(testArgs, " ")) + // t.Logf("Stdout:\n%s", stdout.String()) + // t.Logf("Stderr:\n%s", stderr.String()) + // if err != nil { + // t.Logf("Error: %v", err) + // } + + return stdout.String(), stderr.String(), err +} + +// Helper to parse specific integer values from ab-proxy output +func parseOutputInt(output string, pattern string) (int, error) { + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(output) + if len(matches) < 2 { + return 0, fmt.Errorf("could not find pattern '%s' in output: %s", pattern, output) + } + val, err := strconv.Atoi(matches[1]) + if err != nil { + return 0, fmt.Errorf("could not parse int from '%s': %w. Output: %s", matches[1], err, output) + } + return val, nil +} + +func TestIntegration_BasicGet(t *testing.T) { + target := NewMockTargetServer(MockTargetServerConfig{ + StatusCode: 200, + ResponseBody: []byte("OK"), + }) + defer target.Close() + + args := []string{"-n", "5", "-c", "1", "--bursts", "1", target.URL()} + stdout, _, err := runABProxyCmd(t, args...) + if err != nil { + // `ab-proxy` might exit with non-zero if there are failed requests. + // For this test, we expect all successful, so err should be nil or ExitError with 0 code. + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.Success() { + // This is fine, means exit code 0. + } else { + t.Fatalf("ab-proxy execution failed: %v", err) + } + } + + if target.GetRequestCount() != 5 { + t.Errorf("Expected target server to receive 5 requests, got %d", target.GetRequestCount()) + } + + completed, err := parseOutputInt(stdout, `Completed requests:\s*(\d+)`) + if err != nil { + t.Errorf("Error parsing completed requests: %v", err) + } else if completed != 5 { + t.Errorf("Expected ab-proxy output to show 5 completed requests, got %d", completed) + } + + http200, err := parseOutputInt(stdout, `HTTP-200 completed:\s*(\d+)`) + if err != nil { + t.Errorf("Error parsing HTTP-200 count: %v", err) + } else if http200 != 5 { + t.Errorf("Expected ab-proxy output to show 5 HTTP-200, got %d", http200) + } +} + +func TestIntegration_CustomHeadersAndUserAgent(t *testing.T) { + target := NewMockTargetServer(MockTargetServerConfig{StatusCode: 200}) + defer target.Close() + + customHeader := "X-Custom-Header: TestValue" + customUA := "TestAgent/1.0" + + args := []string{"-n", "1", "-H", customHeader, "--user-agent", customUA, target.URL()} + _, _, err := runABProxyCmd(t, args...) + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.Success() {} else { + t.Fatalf("ab-proxy execution failed: %v", err) + } + } + + if target.GetRequestCount() != 1 { + t.Errorf("Expected target server to receive 1 request, got %d", target.GetRequestCount()) + } + + receivedHeaders := target.GetReceivedHeaders() + if val := receivedHeaders.Get("X-Custom-Header"); val != "TestValue" { + t.Errorf("Expected X-Custom-Header 'TestValue', got '%s'", val) + } + if val := receivedHeaders.Get("User-Agent"); val != customUA { + t.Errorf("Expected User-Agent '%s', got '%s'", customUA, val) + } +} + +func TestIntegration_Timeout(t *testing.T) { + target := NewMockTargetServer(MockTargetServerConfig{ + StatusCode: 200, + ResponseDelay: 2 * time.Second, + }) + defer target.Close() + + // With -n 3, -c 1 (default), -s 1 (1s timeout) + // Each request will timeout. + args := []string{"-n", "3", "-s", "1", target.URL()} + stdout, _, err := runABProxyCmd(t, args...) // stderr explicitly ignored for now + + // ab-proxy should exit with non-zero status if all requests fail + if err == nil { + t.Errorf("Expected ab-proxy to exit with non-zero status due to timeouts, but it exited successfully.") + } else { + if _, ok := err.(*exec.ExitError); !ok { + t.Fatalf("ab-proxy execution failed with unexpected error type: %v", err) + } + // ExitError is expected. + } + + // Depending on concurrency and how quickly ab-proxy starts requests, + // the server might see 1 or more requests before they time out. + // For -c 1, it should be 1 request at a time. + // If first times out, it might try the next. + // The key is that not all 3 complete successfully. + // The number of requests that *start* on the server could be up to 3 if ab-proxy retries quickly. + // However, `ab-proxy` as written doesn't retry failed requests within the same `-n` count. + // It just marks them as failed. So, it will attempt all 3. + if target.GetRequestCount() != 3 { + t.Logf("Target server requests: %d. This can vary with timeouts and concurrency.", target.GetRequestCount()) + // Not a hard fail, but good to observe. With -c 1, it's likely 3 attempts. + } + + failed, pErr := parseOutputInt(stdout, `Failed requests:\s*(\d+)`) + if pErr != nil { + t.Errorf("Error parsing failed requests from stdout: %v\nStdout:\n%s", pErr, stdout) + } else if failed != 3 { + t.Errorf("Expected ab-proxy output to show 3 failed requests, got %d", failed) + } + + // Check for timeout failures in stats + timeoutFailures, _ := parseOutputInt(stdout, `Timeout failures:\s*(\d+)`) + if timeoutFailures != 3 { + t.Errorf("Expected 3 timeout failures in stats, got %d. Stdout:\n%s", timeoutFailures, stdout) + } + + // Optionally, check stderr for error messages if --show-errors was used + // argsWithShowErrors := []string{"-n", "3", "-s", "1", "--show-errors", target.URL()} + // stdout, stderr, _ = runABProxyCmd(t, argsWithShowErrors...) // Assuming exit 1 + // if !strings.Contains(stderr, "Timeout") && !strings.Contains(stdout, "Timeout") { // Errors might go to stdout with --show-errors + // t.Errorf("Expected 'Timeout' in error output with --show-errors. Stdout:\n%s\nStderr:\n%s", stdout, stderr) + // } +} + + +func TestIntegration_HTTPProxy(t *testing.T) { + target := NewMockTargetServer(MockTargetServerConfig{StatusCode: 200}) + defer target.Close() + proxy := NewMockHTTPProxyServer() + defer proxy.Close() + + args := []string{"-n", "2", "-X", proxy.URL(), target.URL()} + stdout, _, err := runABProxyCmd(t, args...) + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.Success() {} else { + t.Fatalf("ab-proxy execution failed: %v", err) + } + } + + if proxy.GetProxiedCount() != 2 { + t.Errorf("Expected proxy server to receive 2 requests, got %d", proxy.GetProxiedCount()) + } + if target.GetRequestCount() != 2 { + t.Errorf("Expected target server to receive 2 requests, got %d", target.GetRequestCount()) + } + + targetHeaders := target.GetReceivedHeaders() + if via := targetHeaders.Get("X-Proxied-By"); via != "mockHTTPProxyServer" { + t.Errorf("Expected 'X-Proxied-By: mockHTTPProxyServer' header at target, got '%s'", via) + } + + completed, _ := parseOutputInt(stdout, `Completed requests:\s*(\d+)`) + if completed != 2 { + t.Errorf("Expected ab-proxy output to show 2 completed requests, got %d", completed) + } + if !strings.Contains(stdout, "using proxy") { + t.Errorf("Expected ab-proxy output to indicate proxy usage. Stdout:\n%s", stdout) + } +} + +func TestIntegration_TargetServerError(t *testing.T) { + target := NewMockTargetServer(MockTargetServerConfig{StatusCode: 500}) + defer target.Close() + + args := []string{"-n", "3", target.URL()} + stdout, _, err := runABProxyCmd(t, args...) + // According to ab-proxy logic, HTTP 500 responses are "completed" and do not increment "RequestsFailed". + // Therefore, ab-proxy should exit with 0 if all requests result in HTTP 500. + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.Success() { + // This is fine (exit 0) + } else { + t.Fatalf("ab-proxy execution failed unexpectedly: %v. Stdout:\n%s", err, stdout) + } + } + + if target.GetRequestCount() != 3 { + t.Errorf("Expected target server to receive 3 requests, got %d", target.GetRequestCount()) + } + + completed, _ := parseOutputInt(stdout, `Completed requests:\s*(\d+)`) + if completed != 3 { + t.Errorf("Expected ab-proxy output to show 3 completed requests, got %d", completed) + } + + http500, err := parseOutputInt(stdout, `HTTP-500 completed:\s*(\d+)`) + if err != nil { + t.Errorf("Error parsing HTTP-500 count: %v", err) + } else if http500 != 3 { + t.Errorf("Expected ab-proxy output to show 3 HTTP-500, got %d", http500) + } +} + +func TestIntegration_ShowErrors_ConnectionRefused(t *testing.T) { + // Use a URL that is unlikely to be listening + nonExistentTargetURL := "http://127.0.0.1:34567" // Arbitrary unused port + + args := []string{"-n", "2", "--show-errors", nonExistentTargetURL} + stdout, stderr, err := runABProxyCmd(t, args...) + + if err == nil { + t.Fatalf("Expected ab-proxy to exit with non-zero status due to connection errors") + } + + failed, pErr := parseOutputInt(stdout, `Failed requests:\s*(\d+)`) + if pErr != nil { + t.Errorf("Error parsing failed requests: %v\nStdout:\n%s", pErr, stdout) + } else if failed != 2 { + t.Errorf("Expected ab-proxy output to show 2 failed requests, got %d", failed) + } + + // Errors are printed to stdout when --show-errors is used (based on current ab-proxy.go logic) + // The specific error message for connection refused can vary by OS. + // Common patterns: "connection refused", "dial tcp", "connect: connection refused" + // We look for "Errors:" section and then some indication of connection failure. + if !strings.Contains(stdout, "Errors:") { + t.Errorf("Expected 'Errors:' section in stdout with --show-errors. Stdout:\n%s", stdout) + } + // A more robust check would parse the error list. For now, a string contains. + // Regex for "refused" or "no such host" could be useful. + errorPattern := `(refused|no such host|connection timed out)` // Add more patterns if needed + matched, _ := regexp.MatchString(errorPattern, stdout) + if !matched { + // Also check stderr, though --show-errors usually prints to stdout. + matchedStderr, _ := regexp.MatchString(errorPattern, stderr) + if !matchedStderr { + t.Errorf("Expected connection error message like '%s' in stdout or stderr. Stdout:\n%s\nStderr:\n%s", errorPattern, stdout, stderr) + } + } +} + + +func TestIntegration_MultipleBurstsAndDelay(t *testing.T) { + target := NewMockTargetServer(MockTargetServerConfig{StatusCode: 200}) + defer target.Close() + + args := []string{"-n", "2", "--bursts", "2", "--delay", "1", "-c", "1", target.URL()} + startTime := time.Now() + stdout, _, err := runABProxyCmd(t, args...) + elapsedTime := time.Since(startTime) + + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.Success() {} else { + t.Fatalf("ab-proxy execution failed: %v", err) + } + } + + if target.GetRequestCount() != 4 { // 2 requests/burst * 2 bursts + t.Errorf("Expected target server to receive 4 requests, got %d", target.GetRequestCount()) + } + + completed, _ := parseOutputInt(stdout, `Completed requests:\s*(\d+)`) + if completed != 4 { + t.Errorf("Expected ab-proxy output to show 4 completed requests, got %d", completed) + } + + // Check if total time is roughly consistent with delay + // Total requests = 4. With -c 1, first burst (2 req) takes some time (T_req * 2). + // Then 1s delay. Then second burst (2 req) takes T_req * 2. + // So, total time should be > 1 second (the delay itself). + // This is a loose check. + if elapsedTime < 1*time.Second { + t.Errorf("Expected total execution time to be at least 1 second (due to --delay 1), got %v", elapsedTime) + } + // A more precise check for delay would require instrumenting ab-proxy or very careful timing analysis, + // which is beyond typical integration test scope here. +} diff --git a/ab_proxy_test.go b/ab_proxy_test.go new file mode 100644 index 0000000..b35ccfc --- /dev/null +++ b/ab_proxy_test.go @@ -0,0 +1,392 @@ +package main + +import ( + "fmt" + "net/http" + "net/textproto" + "net/url" + "reflect" + "bufio" // Required for textproto.NewReader + "crypto/tls" // Required for http.Transport TLSNextProto + "regexp" + // "sort" // Removed as it's unused in the final version of the file + "strings" + // "sync" // Not directly used in this version of test helpers, but kept for potential future use + "testing" + // "time" // Not directly used in this version of test helpers + + "github.com/jessevdk/go-flags" +) + +// Helper for comparing errItem slices +func compareErrItems(a, b []errItem) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].cnt != b[i].cnt || a[i].errMsg != b[i].errMsg { + return false + } + } + return true +} + +// TestErrorCollector tests the ErrorCollector functionality. +func TestErrorCollector(t *testing.T) { + t.Run("NewErrorCollector", func(t *testing.T) { + ec := NewErrorCollector(100, 10, true) + if ec.errChan == nil { + t.Error("errChan should be initialized") + } + if ec.errMap == nil { + t.Error("errMap should be initialized") + } + if ec.maxUniqueErrors != 10 { + t.Errorf("expected maxUniqueErrors %d, got %d", 10, ec.maxUniqueErrors) + } + if !ec.showErrors { + t.Error("expected showErrors to be true") + } + }) + + t.Run("AddAndGetErrors", func(t *testing.T) { + ec := NewErrorCollector(100, 5, true) + ec.Start() + + ec.Add(fmt.Errorf("error 1")) + ec.Add(fmt.Errorf("error 2")) + ec.Add(fmt.Errorf("error 1")) // Duplicate + ec.Add(fmt.Errorf("error 3")) + ec.Add(fmt.Errorf("error 2")) // Duplicate + ec.Add(fmt.Errorf("error 1")) // Triplicate + + ec.Stop() // Stop to ensure all errors are processed + + sortedErrors := ec.GetSortedErrors() + expected := []errItem{ + {3, "error 1"}, + {2, "error 2"}, + {1, "error 3"}, + } + + if !compareErrItems(sortedErrors, expected) { + t.Errorf("expected sorted errors %v, got %v", expected, sortedErrors) + } + }) + + t.Run("MaxUniqueErrors", func(t *testing.T) { + ec := NewErrorCollector(100, 2, true) + ec.Start() + + ec.Add(fmt.Errorf("error 1")) + ec.Add(fmt.Errorf("error 2")) + ec.Add(fmt.Errorf("error 3")) // Should be ignored for map, but consumed from chan + ec.Add(fmt.Errorf("error 1")) // Should still increment count for error 1 + + ec.Stop() + + if len(ec.errMap) > 2 { + t.Errorf("expected errMap size to be at most %d, got %d (map: %v)", 2, len(ec.errMap), ec.errMap) + } + + if _, exists := ec.errMap["error 3"]; exists { + t.Error("error 3 should not be in errMap due to maxUniqueErrors limit") + } + + if count, ok := ec.errMap["error 1"]; !ok || count != 2 { + t.Errorf("expected error 1 count to be 2, got %d (present: %v)", count, ok) + } + }) + + t.Run("ShowErrorsFalse", func(t *testing.T) { + ec := NewErrorCollector(100, 5, false) + ec.Start() + ec.Add(fmt.Errorf("error 1")) + ec.Stop() + + if len(ec.errMap) != 0 { + t.Errorf("errMap should be empty when showErrors is false, got %v", ec.errMap) + } + }) +} + +// simulateURLProcessing mimics the URL processing logic from runBenchmark's initial part. +func simulateURLProcessing(opts Options) (string, textproto.MIMEHeader, *http.Transport, error) { + // 1. URL Validation (simplified from runBenchmark) + targetURL := opts.Args.Url + if matched, _ := regexp.MatchString(`^\w+://`, targetURL); !matched { + targetURL = "http://" + targetURL + } + _, err := url.Parse(targetURL) + if err != nil { + return "", nil, nil, fmt.Errorf("invalid URL given: %w", err) + } + + // 2. Header Preparation (simplified from runBenchmark) + var headers textproto.MIMEHeader + headerStrings := append([]string(nil), opts.Header...) + headerStrings = append(headerStrings, "User-Agent:"+opts.UserAgent) // Corrected: space after User-Agent: + + sb := strings.Builder{} + for _, h := range headerStrings { + sb.WriteString(h) + sb.WriteString("\r\n") + } + sb.WriteString("\r\n") + + tp := textproto.NewReader(bufio.NewReader(strings.NewReader(sb.String()))) // Added bufio.NewReader + parsedHeaders, err := tp.ReadMIMEHeader() + if err != nil { + return "", nil, nil, fmt.Errorf("unable to parse custom headers: %w", err) + } + headers = textproto.MIMEHeader(parsedHeaders) + + + // 3. Transport Setup (simplified from runBenchmark) + var tr *http.Transport = nil + if opts.Proxy != "" { + proxyURLStr := opts.Proxy + if matched, _ := regexp.MatchString(`^\w+://`, proxyURLStr); !matched { + proxyURLStr = "http://" + proxyURLStr + } + uri, err := url.Parse(proxyURLStr) + if err != nil { + return "", nil, nil, fmt.Errorf("unable to parse proxy URL: %w", err) + } + if uri.Scheme == "socks5" { + tr = &http.Transport{Proxy: http.ProxyURL(uri)} + } else if uri.Scheme == "https" || uri.Scheme == "http" { + tr = &http.Transport{ + Proxy: http.ProxyURL(uri), + TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), + } + } else { + return "", nil, nil, fmt.Errorf("unable to handle proxy with scheme '%s'", uri.Scheme) + } + } + return targetURL, headers, tr, nil +} + + +func TestURLValidationLogic(t *testing.T) { + tests := []struct { + name string + inputURL string + wantURL string + wantErr bool + }{ + {"MissingScheme", "example.com", "http://example.com", false}, + {"HTTP", "http://example.com", "http://example.com", false}, + {"HTTPS", "https://example.com", "https://example.com", false}, + {"FTP", "ftp://example.com", "ftp://example.com", false}, + {"EmptyURL", "", "http://", false}, // url.Parse("") is valid, url.Parse("http://") is valid. + {"InvalidURLChars", "http://[::1]:namedport", "", true}, // This specific form is invalid for url.Parse + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := Options{} + opts.Args.Url = tt.inputURL // Correctly assign to the nested struct field + gotURL, _, _, err := simulateURLProcessing(opts) + if (err != nil) != tt.wantErr { + t.Errorf("URL processing error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && gotURL != tt.wantURL { + t.Errorf("Processed URL = %v, want %v", gotURL, tt.wantURL) + } + }) + } +} + +func TestHeaderPreparationLogic(t *testing.T) { + defaultUA := "ab-proxy/1.0.0" // Assuming this is a package const or accessible default + tests := []struct { + name string + optsToSet Options // Use this to set fields other than Args.Url + urlArg string // Explicitly pass URL arg for clarity + expectedHeaders map[string]string + wantErr bool + }{ + {"DefaultUA", Options{UserAgent: defaultUA}, "url", map[string]string{"User-Agent": defaultUA}, false}, + {"CustomUA", Options{UserAgent: "CustomAgent/1.0"}, "url", map[string]string{"User-Agent": "CustomAgent/1.0"}, false}, + {"SingleHeader", Options{UserAgent: defaultUA, Header: []string{"X-Custom: val"}}, "url", map[string]string{"User-Agent": defaultUA, "X-Custom": "val"}, false}, + {"MultiHeaders", Options{UserAgent: defaultUA, Header: []string{"X-One: 1", "X-Two: 2"}}, "url", map[string]string{"User-Agent": defaultUA, "X-One": "1", "X-Two": "2"}, false}, + {"OverrideUAInHeader", Options{UserAgent: defaultUA, Header: []string{"User-Agent: Override/3.0"}}, "url", map[string]string{"User-Agent": "Override/3.0"}, false}, + {"MalformedHeaderLine", Options{UserAgent: defaultUA, Header: []string{"BadFormatNoColon"}}, "url", map[string]string{}, true}, // expect error, so headers map can be empty + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := tt.optsToSet // Copy base options + opts.Args.Url = tt.urlArg // Set the URL argument correctly + _, gotHeaders, _, err := simulateURLProcessing(opts) + if (err != nil) != tt.wantErr { + t.Errorf("Header preparation error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if len(gotHeaders) != len(tt.expectedHeaders) { + t.Errorf("Expected %d headers, got %d. Got: %v, Expected: %v", len(tt.expectedHeaders), len(gotHeaders), gotHeaders, tt.expectedHeaders) + } + for k, v := range tt.expectedHeaders { + canonicalKey := textproto.CanonicalMIMEHeaderKey(k) + ghVal, ok := gotHeaders[canonicalKey] + if !ok { + t.Errorf("Expected header %s not found", k) + continue + } + if len(ghVal) == 0 || ghVal[0] != v { + t.Errorf("Header %s: expected '%s', got '%v'", k, v, ghVal) + } + } + } + }) + } +} + + +func TestProxyConfigurationLogic(t *testing.T) { + tests := []struct { + name string + proxyOpt string + expectedScheme string // Scheme of the proxy URL in the transport, if one is set + expectProxy bool // True if a proxy function should be configured on the transport + wantErr bool + }{ + {"NoProxy", "", "", false, false}, + {"HTTPProxy", "http://proxy.example.com:8080", "http", true, false}, + {"HTTPSProxy", "https://proxy.example.com:8888", "https", true, false}, + {"SOCKS5Proxy", "socks5://proxy.example.com:1080", "socks5", true, false}, + {"SchemeMissingDefaultsToHTTP", "proxy.example.com:8080", "http", true, false}, + {"InvalidScheme", "ftp://proxy.example.com", "", false, true}, + {"MalformedURL", "http://[::1%scope]:port", "", false, true}, // Invalid URL + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := Options{Proxy: tt.proxyOpt} + opts.Args.Url = "http://dummy.url" // Set the URL argument correctly + _, _, gotTransport, err := simulateURLProcessing(opts) + + if (err != nil) != tt.wantErr { + t.Errorf("Proxy config error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.expectProxy { + if gotTransport == nil || gotTransport.Proxy == nil { + t.Fatalf("Expected transport with Proxy function, got transport: %v", gotTransport) + } + // Verify scheme by trying to get the proxy URL + dummyReq, _ := http.NewRequest("GET", "http://target.url", nil) + proxyURL, pErr := gotTransport.Proxy(dummyReq) + if pErr != nil { + t.Fatalf("Transport Proxy function returned error: %v", pErr) + } + if proxyURL == nil { + t.Fatalf("Transport Proxy function returned nil URL, expected one for scheme: %s", tt.expectedScheme) + } + if proxyURL.Scheme != tt.expectedScheme { + t.Errorf("Expected proxy scheme %s, got %s", tt.expectedScheme, proxyURL.Scheme) + } + } else { + if !tt.wantErr && gotTransport != nil && gotTransport.Proxy != nil { + // If no error was expected, and no proxy was expected, transport should not have Proxy func. + // (it could be non-nil if default transport settings were applied, but Proxy would be nil) + dummyReq, _ := http.NewRequest("GET", "http://target.url", nil) + proxyURL, _ := gotTransport.Proxy(dummyReq) + if proxyURL != nil { + t.Errorf("Expected no proxy to be configured, but got proxy URL: %s", proxyURL.String()) + } + } + } + }) + } +} + + +func TestOptionsDefaultsAndParsing(t *testing.T) { + t.Run("DefaultValues", func(t *testing.T) { + var opts Options + parser := flags.NewParser(&opts, flags.None) + _, err := parser.ParseArgs([]string{"http://example.com"}) + if err != nil { + if e, ok := err.(*flags.Error); ok && e.Type == flags.ErrRequired { + } else { + t.Fatalf("ParseArgs() failed with unexpected error: %v", err) + } + } + + if opts.Concurrency != 1 { + t.Errorf("Expected default Concurrency 1, got %d", opts.Concurrency) + } + if opts.Requests != 1 { + t.Errorf("Expected default Requests 1, got %d", opts.Requests) + } + if opts.Bursts != 1 { + t.Errorf("Expected default Bursts 1, got %d", opts.Bursts) + } + if opts.Delay != 3 { + t.Errorf("Expected default Delay 3, got %d", opts.Delay) + } + // Default UserAgent is applied in runBenchmark if opts.UserAgent is empty, + // or by go-flags if default tag is on Options.UserAgent. + // The current Options struct has `default:"ab-proxy/1.0.0"` + if opts.UserAgent != "ab-proxy/1.0.0" { + t.Errorf("Expected default UserAgent 'ab-proxy/1.0.0', got '%s'", opts.UserAgent) + } + if opts.Timeout != 0 { + t.Errorf("Expected default Timeout 0, got %d", opts.Timeout) + } + if opts.ShowErrors != false { + t.Errorf("Expected default ShowErrors false, got %v", opts.ShowErrors) + } + }) + + t.Run("FlagParsing", func(t *testing.T) { + tests := []struct { + name string + args []string + validate func(opts Options, t *testing.T) + }{ + {"Concurrency", []string{"-c", "10", "u"}, func(o Options, t *testing.T) { if o.Concurrency != 10 { t.Error() } }}, + {"RequestsAndBursts", []string{"-n", "100", "--bursts", "5", "u"}, func(o Options, t *testing.T) { if o.Requests != 100 || o.Bursts != 5 { t.Error() } }}, + {"Proxy", []string{"-X", "s5://lh:1080", "u"}, func(o Options, t *testing.T) { if o.Proxy != "s5://lh:1080" { t.Error() } }}, // s5 shorthand for socks5 + {"Headers", []string{"-H", "XF: b", "-H", "XB: z", "u"}, func(o Options, t *testing.T) { + exp := []string{"XF: b", "XB: z"}; if !reflect.DeepEqual(o.Header, exp) { t.Errorf("H: exp %v got %v", exp, o.Header) } + }}, + {"ShowErrors", []string{"--show-errors", "u"}, func(o Options, t *testing.T) { if !o.ShowErrors { t.Error() } }}, + {"URLArg", []string{"https://my.url/p"}, func(o Options, t *testing.T) { if o.Args.Url != "https://my.url/p" { t.Errorf("URL exp %s got %s", "https://my.url/p", o.Args.Url)} }}, + } + + for _, tt := range tests { + // Replace "u" placeholder with a valid URL for parsing to avoid required arg error + argsWithURL := make([]string, len(tt.args)) + for i, arg := range tt.args { + if arg == "u" { + argsWithURL[i] = "http://example.com" + } else { + argsWithURL[i] = arg + } + } + + + t.Run(tt.name, func(t *testing.T) { + var opts Options + parser := flags.NewParser(&opts, flags.None) + _, err := parser.ParseArgs(argsWithURL) + if err != nil { + if e, ok := err.(*flags.Error); ok && e.Type == flags.ErrRequired { + // This should not happen if "u" is correctly replaced. + } else { + t.Fatalf("ParseArgs(%v) failed: %v", argsWithURL, err) + } + } + tt.validate(opts, t) + }) + } + }) +} + + +// Ensure there's a newline at the very end of the file. diff --git a/go.mod b/go.mod index fcd1ed5..355a50c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,14 @@ module ab-proxy -go 1.12 +go 1.22.2 require ( - github.com/jessevdk/go-flags v1.4.0 - github.com/schollz/progressbar/v2 v2.13.2 + github.com/jessevdk/go-flags v1.6.1 + github.com/schollz/progressbar/v2 v2.15.0 +) + +require ( + github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + golang.org/x/sys v0.21.0 // indirect +// golang.org/x/sys version will be determined by go mod tidy later ) diff --git a/go.sum b/go.sum index fad2f96..77ff36f 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,16 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= -github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jessevdk/go-flags v1.6.1 h1:Cvu5U8UGrLay1rZfv/zP7iLpSHGUZ/Ou68T0iX1bBK4= +github.com/jessevdk/go-flags v1.6.1/go.mod h1:Mk8T1hIAWpOiJiHa9rJASDK2UGWji0EuPGBnNLMooyc= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/schollz/progressbar/v2 v2.13.2 h1:3L9bP5KQOGEnFP8P5V8dz+U0yo5I29iY5Oa9s9EAwn0= -github.com/schollz/progressbar/v2 v2.13.2/go.mod h1:6YZjqdthH6SCZKv2rqGryrxPtfmRB/DWZxSMfCXPyD8= +github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8= +github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/mockservers_test.go b/mockservers_test.go new file mode 100644 index 0000000..b1a5e8e --- /dev/null +++ b/mockservers_test.go @@ -0,0 +1,196 @@ +package main + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "time" + "strings" // Added for strings.HasPrefix +) + +// --- Mock HTTP Target Server --- + +type MockTargetServerConfig struct { + StatusCode int + ResponseBody []byte + ResponseDelay time.Duration + // Add any other configurations like specific headers to check for, etc. +} + +type MockTargetServer struct { + Server *httptest.Server + Config MockTargetServerConfig + mu sync.Mutex + ReceivedHeaders http.Header + RequestCount int64 + LastRequestBody []byte + ReceivedRequests []*http.Request // Store all received requests for detailed inspection +} + +func NewMockTargetServer(config MockTargetServerConfig) *MockTargetServer { + mts := &MockTargetServer{ + Config: config, + ReceivedHeaders: make(http.Header), + } + + mts.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&mts.RequestCount, 1) + + // Capture headers + mts.mu.Lock() + for k, v := range r.Header { + mts.ReceivedHeaders[k] = append(mts.ReceivedHeaders[k], v...) + } + // Store a copy of the request for potential later inspection + var bodyBytes []byte + var err error + if r.Body != nil { + bodyBytes, err = io.ReadAll(r.Body) + if err == nil { + mts.LastRequestBody = bodyBytes + } + r.Body.Close() // Important to close the body + } + mts.ReceivedRequests = append(mts.ReceivedRequests, r) + mts.mu.Unlock() + + if mts.Config.ResponseDelay > 0 { + time.Sleep(mts.Config.ResponseDelay) + } + + w.WriteHeader(mts.Config.StatusCode) + if mts.Config.ResponseBody != nil { + w.Write(mts.Config.ResponseBody) + } + })) + + return mts +} + +func (mts *MockTargetServer) URL() string { + return mts.Server.URL +} + +func (mts *MockTargetServer) Close() { + mts.Server.Close() +} + +func (mts *MockTargetServer) GetRequestCount() int64 { + return atomic.LoadInt64(&mts.RequestCount) +} + +func (mts *MockTargetServer) GetReceivedHeaders() http.Header { + mts.mu.Lock() + defer mts.mu.Unlock() + hdrs := make(http.Header) + for k, v := range mts.ReceivedHeaders { + hdrs[k] = append([]string(nil), v...) + } + return hdrs +} + +func (mts *MockTargetServer) GetLastRequestBody() []byte { + mts.mu.Lock() + defer mts.mu.Unlock() + return mts.LastRequestBody +} + +func (mts *MockTargetServer) GetReceivedRequest(index int) *http.Request { + mts.mu.Lock() + defer mts.mu.Unlock() + if index < 0 || index >= len(mts.ReceivedRequests) { + return nil + } + return mts.ReceivedRequests[index] +} + + +// --- Mock HTTP Proxy Server --- + +type MockHTTPProxyServer struct { + Server *httptest.Server + ProxiedCount int64 + mu sync.Mutex + ReceivedViaHeader bool + TargetDialErrors int64 +} + +func NewMockHTTPProxyServer() *MockHTTPProxyServer { + mps := &MockHTTPProxyServer{} + + mps.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + atomic.AddInt64(&mps.ProxiedCount, 1) + w.WriteHeader(http.StatusOK) + return + } + + targetURL := r.RequestURI + if targetURL == "" { + http.Error(w, "missing request URI", http.StatusBadRequest) + return + } + + var body io.Reader = r.Body + outReq, err := http.NewRequest(r.Method, targetURL, body) + if err != nil { + http.Error(w, fmt.Sprintf("failed to create request to target: %v", err), http.StatusInternalServerError) + atomic.AddInt64(&mps.TargetDialErrors, 1) + return + } + + outReq.Header = make(http.Header) + for key, values := range r.Header { + if strings.HasPrefix(strings.ToLower(key), "proxy-") { + continue + } + outReq.Header[key] = values + } + outReq.Header.Set("X-Proxied-By", "mockHTTPProxyServer") + + client := &http.Client{ + Transport: &http.Transport{Proxy: nil}, + } + + resp, err := client.Do(outReq) + if err != nil { + http.Error(w, fmt.Sprintf("failed to get response from target: %v", err), http.StatusBadGateway) + atomic.AddInt64(&mps.TargetDialErrors, 1) + return + } + defer resp.Body.Close() + + atomic.AddInt64(&mps.ProxiedCount, 1) + + for k, vv := range resp.Header { + for _, v := range vv { + w.Header().Add(k, v) + } + } + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) + })) + return mps +} + +func (mps *MockHTTPProxyServer) URL() string { + return mps.Server.URL +} + +func (mps *MockHTTPProxyServer) Close() { + mps.Server.Close() +} + +func (mps *MockHTTPProxyServer) GetProxiedCount() int64 { + return atomic.LoadInt64(&mps.ProxiedCount) +} + +func (mps *MockHTTPProxyServer) GetTargetDialErrors() int64 { + return atomic.LoadInt64(&mps.TargetDialErrors) +} + +// Note: A SOCKS5 mock server is significantly more complex. +// The example main function and detailed feature comments have been removed to prevent parsing issues.