diff --git a/dev-tools/mage/downloads/utils.go b/dev-tools/mage/downloads/utils.go index 96ec73e9ed0..4235e522276 100644 --- a/dev-tools/mage/downloads/utils.go +++ b/dev-tools/mage/downloads/utils.go @@ -5,8 +5,10 @@ package downloads import ( + "context" "fmt" "io" + "net/http" "os" "path/filepath" "regexp" @@ -15,67 +17,51 @@ import ( devtools "github.com/elastic/elastic-agent/dev-tools/mage" "github.com/cenkalti/backoff/v4" - "github.com/gofrs/uuid/v5" ) var checksumFileRegex = regexp.MustCompile(`^([0-9a-f]{128})\s+(\w.*)$`) // downloadRequest struct contains download details ad path and URL type downloadRequest struct { - URL string - DownloadPath string - UnsanitizedFilePath string + URL string + TargetPath string } // downloadFile will download a url and store it in a temporary path. // It writes to the destination file as it downloads it, without // loading the entire file into memory. func downloadFile(downloadRequest *downloadRequest) error { - var filePath string - if downloadRequest.DownloadPath == "" { - u, err := uuid.NewV4() - if err != nil { - return fmt.Errorf("failed to create UUID: %w", err) - } - tempParentDir := filepath.Join(os.TempDir(), u.String()) - err = mkdirAll(tempParentDir) - if err != nil { - return fmt.Errorf("creating directory: %w", err) - } - u, err = uuid.NewV4() - if err != nil { - return fmt.Errorf("failed to create UUID: %w", err) - } - filePath = filepath.Join(tempParentDir, u.String()) - downloadRequest.DownloadPath = filePath - } else { - u, err := uuid.NewV4() - if err != nil { - return fmt.Errorf("failed to create UUID: %w", err) - } - filePath = filepath.Join(downloadRequest.DownloadPath, u.String()) - } - - tempFile, err := os.Create(filePath) + targetFile, err := os.Create(downloadRequest.TargetPath) if err != nil { return fmt.Errorf("creating file: %w", err) } - defer tempFile.Close() + defer func() { + _ = targetFile.Close() + }() - downloadRequest.UnsanitizedFilePath = tempFile.Name() exp := getExponentialBackoff(3) retryCount := 1 - var fileReader io.Reader download := func() error { - r := httpRequest{URL: downloadRequest.URL} - bodyStr, err := get(r) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, downloadRequest.URL, nil) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + resp, err := http.DefaultClient.Do(req) if err != nil { retryCount++ return fmt.Errorf("downloading file %s: %w", downloadRequest.URL, err) } + defer func() { + _ = resp.Body.Close() + }() + _, err = io.Copy(targetFile, resp.Body) + if err != nil { + // try to drain the body before returning to ensure the connection can be reused + _, _ = io.Copy(io.Discard, resp.Body) + return fmt.Errorf("writing file %s: %w", targetFile.Name(), err) + } - fileReader = strings.NewReader(bodyStr) return nil } @@ -84,12 +70,7 @@ func downloadFile(downloadRequest *downloadRequest) error { return err } - _, err = io.Copy(tempFile, fileReader) - if err != nil { - return fmt.Errorf("writing file %s: %w", tempFile.Name(), err) - } - - _ = os.Chmod(tempFile.Name(), 0666) + _ = os.Chmod(targetFile.Name(), 0666) return nil } diff --git a/dev-tools/mage/downloads/utils_test.go b/dev-tools/mage/downloads/utils_test.go index bf6fe4cd736..43c84df0fc0 100644 --- a/dev-tools/mage/downloads/utils_test.go +++ b/dev-tools/mage/downloads/utils_test.go @@ -26,13 +26,12 @@ func TestDownloadFile(t *testing.T) { var dRequest = downloadRequest{ URL: fmt.Sprintf("http://%s/some-file.txt", s.Listener.Addr().String()), - DownloadPath: "", + TargetPath: filepath.Join(t.TempDir(), "some-file.txt"), } err := downloadFile(&dRequest) assert.Nil(t, err) - assert.NotEmpty(t, dRequest.UnsanitizedFilePath) - defer os.Remove(filepath.Dir(dRequest.UnsanitizedFilePath)) + assert.FileExistsf(t, dRequest.TargetPath, "file should exist") } func TestVerifyChecksum(t *testing.T) { diff --git a/dev-tools/mage/downloads/versions.go b/dev-tools/mage/downloads/versions.go index 21e30923f70..217667b930e 100644 --- a/dev-tools/mage/downloads/versions.go +++ b/dev-tools/mage/downloads/versions.go @@ -10,7 +10,6 @@ import ( "fmt" "log/slog" "os" - "path" "path/filepath" "regexp" "strconv" @@ -107,51 +106,11 @@ func CheckPRVersion(version string, fallbackVersion string) string { return version } -// FetchElasticArtifact fetches an artifact from the right repository, returning binary name, path and error -func FetchElasticArtifact(ctx context.Context, artifact string, version string, os string, arch string, extension string, isDocker bool, xpack bool) (string, string, error) { - useCISnapshots := GithubCommitSha1 != "" - - return FetchElasticArtifactForSnapshots(ctx, useCISnapshots, artifact, version, os, arch, extension, isDocker, xpack) -} - -// FetchElasticArtifactForSnapshots fetches an artifact from the right repository, returning binary name, path and error -func FetchElasticArtifactForSnapshots(ctx context.Context, useCISnapshots bool, artifact string, version string, os string, arch string, extension string, isDocker bool, xpack bool) (string, string, error) { - binaryName := buildArtifactName(artifact, version, os, arch, extension, isDocker) - binaryPath, err := FetchProjectBinaryForSnapshots(ctx, useCISnapshots, artifact, binaryName, artifact, version, timeoutFactor, xpack, "", false) - if err != nil { - logger.Error("Could not download the binary for the Elastic artifact", - slog.String("artifact", artifact), - slog.String("version", version), - slog.String("os", os), - slog.String("arch", arch), - slog.String("extension", extension), - slog.String("error", err.Error()), - ) - return "", "", err - } - - return binaryName, binaryPath, nil -} - // GetCommitVersion returns a version including the version and the git commit, if it exists func GetCommitVersion(version string) string { return newElasticVersion(version).HashedVersion } -// GetElasticArtifactURL returns the URL of a released artifact, which its full name is defined in the first argument, -// from Elastic's artifact repository, building the JSON path query based on the full name -// It also returns the URL of the sha512 file of the released artifact. -// i.e. GetElasticArtifactURL("elastic-agent-$VERSION-$ARCH.deb", "elastic-agent", "$VERSION") -// i.e. GetElasticArtifactURL("elastic-agent-$VERSION-x86_64.rpm", "elastic-agent","$VERSION") -// i.e. GetElasticArtifactURL("elastic-agent-$VERSION-linux-$ARCH.tar.gz", "elastic-agent","$VERSION") -func GetElasticArtifactURL(artifactName string, artifact string, version string) (string, string, error) { - resolver := NewArtifactURLResolver(artifactName, artifact, version) - if resolver == nil { - return "", "", errors.New("nil resolver returned") - } - return resolver.Resolve() -} - // GetElasticArtifactVersion returns the current version: // 1. Elastic's artifact repository, building the JSON path query based // If the version is a SNAPSHOT including a commit, then it will directly use the version without checking the artifacts API @@ -354,11 +313,19 @@ func FetchProjectBinaryForSnapshots(ctx context.Context, useCISnapshots bool, pr return "", fmt.Errorf("⚠️ Beats local path usage is deprecated and not used to fetch the binaries. Please use the packaging job to generate the artifacts to be consumed by these tests") } + if downloadPath == "" { + return "", errors.New("downloadPath cannot be empty") + } + handleDownload := func(URL string) (string, error) { name := artifactName + if strings.HasSuffix(URL, ".sha512") { + name = fmt.Sprintf("%s.sha512", name) + } + downloadFilePath := filepath.Join(downloadPath, name) downloadRequest := downloadRequest{ - DownloadPath: downloadPath, - URL: URL, + TargetPath: downloadFilePath, + URL: URL, } span, _ := apm.StartSpanOptions(ctx, "Fetching Project binary", "project.url.fetch-binary", apm.SpanOptions{ Parent: apm.SpanFromContext(ctx).TraceContext(), @@ -379,28 +346,14 @@ func FetchProjectBinaryForSnapshots(ctx context.Context, useCISnapshots bool, pr err := downloadFile(&downloadRequest) if err != nil { - return downloadRequest.UnsanitizedFilePath, err - } - - if strings.HasSuffix(URL, ".sha512") { - name = fmt.Sprintf("%s.sha512", name) - } - // use artifact name as file name to avoid having URL params in the name - sanitizedFilePath := filepath.Join(path.Dir(downloadRequest.UnsanitizedFilePath), name) - err = os.Rename(downloadRequest.UnsanitizedFilePath, sanitizedFilePath) - if err != nil { - logger.Warn("Could not sanitize downloaded file name. Keeping old name", - slog.String("fileName", downloadRequest.UnsanitizedFilePath), - slog.String("sanitizedFileName", sanitizedFilePath), - ) - sanitizedFilePath = downloadRequest.UnsanitizedFilePath + return "", err } binariesMutex.Lock() - binariesCache[URL] = sanitizedFilePath + binariesCache[URL] = downloadFilePath binariesMutex.Unlock() - return sanitizedFilePath, nil + return downloadFilePath, nil } var downloadURL, downloadShaURL string