Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func newJobCommand() *cobra.Command {
cmd.AddCommand(newJobShowCommand())
cmd.AddCommand(newJobDeleteCommand())
cmd.AddCommand(newJobCancelCommand())
cmd.AddCommand(newJobDownloadCommand())

return cmd
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package cmd

import (
"fmt"
"os"
"path/filepath"

"azure.ai.customtraining/internal/azcopy"
"azure.ai.customtraining/internal/utils"
"azure.ai.customtraining/pkg/client"
"azure.ai.customtraining/pkg/models"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/azure/azure-dev/cli/azd/pkg/azdext"
"github.com/spf13/cobra"
)

func newJobDownloadCommand() *cobra.Command {
var name string
var downloadPath string

cmd := &cobra.Command{
Use: "download",
Short: "Download job output artifacts to a local directory",
Long: "Download output artifacts from a completed training job to a local directory.\n\n" +
"Example:\n" +
" azd ai training job download --name llama-sft\n" +
" azd ai training job download --name llama-sft --path ./downloads",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
ctx := azdext.WithAccessToken(cmd.Context())

if name == "" {
return fmt.Errorf("--name is required")
}

// Default download path to current directory
if downloadPath == "" {
downloadPath = "./"
}

azdClient, err := azdext.NewAzdClient()
if err != nil {
return fmt.Errorf("failed to create azd client: %w", err)
}
defer azdClient.Close()

envValues, err := utils.GetEnvironmentValues(ctx, azdClient)
if err != nil {
return fmt.Errorf("failed to get environment values: %w", err)
}

accountName := envValues[utils.EnvAzureAccountName]
projectName := envValues[utils.EnvAzureProjectName]
tenantID := envValues[utils.EnvAzureTenantID]

if accountName == "" || projectName == "" {
return fmt.Errorf("environment not configured. Run 'azd ai training init' first")
}

credential, err := azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{
TenantID: tenantID,
AdditionallyAllowedTenants: []string{"*"},
})
if err != nil {
return fmt.Errorf("failed to create azure credential: %w", err)
}

endpoint := buildProjectEndpoint(accountName, projectName)
apiClient, err := client.NewClient(endpoint, credential)
if err != nil {
return fmt.Errorf("failed to create API client: %w", err)
}

// Step 1: Verify job exists and is in a terminal state
fmt.Printf("Downloading artifacts for job '%s'...\n\n", name)

job, err := apiClient.GetJob(ctx, name)
if err != nil {
return fmt.Errorf("failed to get job: %w", err)
}

if !models.TerminalStatuses[job.Properties.Status] {
return fmt.Errorf(
"job '%s' is in status '%s'. Download is only available for jobs in a terminal state "+
"(Completed, Failed, Canceled, NotResponding, Paused)",
name, job.Properties.Status,
)
}

// Step 2: List all artifacts to discover output paths/prefixes
fmt.Println("| Listing artifacts...")

allArtifacts, err := apiClient.ListAllArtifacts(ctx, name)
if err != nil {
return fmt.Errorf("failed to list artifacts: %w", err)
}

if len(allArtifacts) == 0 {
fmt.Println(" No artifacts found for this job.")
return nil
}

// Collect unique first-level folder prefixes for batch SAS URI retrieval
prefixes := utils.CollectArtifactPrefixes(allArtifacts)

fmt.Printf("✓ Found %d artifacts\n\n", len(allArtifacts))

// Step 3: Get SAS URIs for all artifacts using prefix/contentinfo (batch)
var allSASItems []models.ArtifactContentInfo
for _, prefix := range prefixes {
items, err := apiClient.GetAllArtifactSASForPath(ctx, name, prefix)
if err != nil {
return fmt.Errorf("failed to get SAS URIs for prefix '%s': %w", prefix, err)
}
allSASItems = append(allSASItems, items...)
}

if len(allSASItems) == 0 {
return fmt.Errorf("no downloadable SAS URIs returned for job artifacts")
}

// Compute total download size from SAS content info
var totalSize int64
for _, item := range allSASItems {
totalSize += item.ContentLength
}
fmt.Printf(" Total download size: %s\n\n", utils.FormatSize(totalSize))

// Initialize azcopy runner
azRunner, err := azcopy.NewRunner(ctx, "")
if err != nil {
return fmt.Errorf("failed to initialize azcopy: %w", err)
}

// Resolve absolute download path
absPath, err := filepath.Abs(downloadPath)
if err != nil {
return fmt.Errorf("failed to resolve download path: %w", err)
}

// Step 4: Download each artifact via azcopy
fmt.Println("| Downloading...")

for i, item := range allSASItems {
// Preserve directory structure from artifact path
localFilePath := filepath.Join(absPath, filepath.FromSlash(item.Path))
localDir := filepath.Dir(localFilePath)

if err := os.MkdirAll(localDir, 0750); err != nil {
return fmt.Errorf("failed to create directory %s: %w", localDir, err)
}

// Display progress tree
connector := "├─"
if i == len(allSASItems)-1 {
connector = "└─"
}
fmt.Printf(" %s %s (%s)\n", connector, item.Path, utils.FormatSize(item.ContentLength))

// Download: SAS URI → local file
if err := azRunner.Copy(ctx, item.ContentURI, localFilePath); err != nil {
return fmt.Errorf("failed to download %s: %w", item.Path, err)
}
}

fmt.Printf("\n✓ Downloaded to %s\n", absPath)

return nil
},
}

cmd.Flags().StringVar(&name, "name", "", "Job name/ID (required)")
cmd.Flags().StringVar(&downloadPath, "path", "./",
"Local directory to download into")

return cmd
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package utils

import (
"strings"

"azure.ai.customtraining/pkg/models"
)

// CollectArtifactPrefixes extracts unique first-level folder prefixes from artifact paths.
func CollectArtifactPrefixes(artifacts []models.Artifact) []string {
seen := make(map[string]bool)
var prefixes []string

for _, a := range artifacts {
parts := strings.SplitN(a.Path, "/", 2)
prefix := parts[0] + "/"

if !seen[prefix] {
seen[prefix] = true
prefixes = append(prefixes, prefix)
}
}

return prefixes
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,27 @@ func formatFieldValue(v reflect.Value) string {
return fmt.Sprintf("%v", v.Interface())
}
}

// FormatSize formats a byte count as a human-readable string.
func FormatSize(bytes int64) string {
const (
KB = 1024
MB = KB * 1024
GB = MB * 1024
)

if bytes <= 0 {
return "unknown size"
}

switch {
case bytes >= GB:
return fmt.Sprintf("%.1f GB", float64(bytes)/float64(GB))
case bytes >= MB:
return fmt.Sprintf("%.1f MB", float64(bytes)/float64(MB))
case bytes >= KB:
return fmt.Sprintf("%.1f KB", float64(bytes)/float64(KB))
default:
return fmt.Sprintf("%d B", bytes)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ func (c *Client) ListArtifacts(ctx context.Context, jobID string) (*models.Artif
return &result, nil
}

// ListAllArtifacts pages through all artifacts for a job.
func (c *Client) ListAllArtifacts(ctx context.Context, jobID string) ([]models.Artifact, error) {
result, err := c.ListArtifacts(ctx, jobID)
if err != nil {
return nil, err
}
if result == nil {
return nil, nil
}
return result.Value, nil
}

// ListArtifactsInPath lists artifacts under a specific path prefix.
// GET .../jobs/{id}/artifacts/path?path={prefix}
func (c *Client) ListArtifactsInPath(
Expand Down Expand Up @@ -198,3 +210,17 @@ func (c *Client) GetArtifactSASForPath(

return &result, nil
}

// GetAllArtifactSASForPath pages through all SAS URIs for artifacts under a path prefix.
func (c *Client) GetAllArtifactSASForPath(
ctx context.Context, jobID string, pathPrefix string,
) ([]models.ArtifactContentInfo, error) {
result, err := c.GetArtifactSASForPath(ctx, jobID, pathPrefix)
if err != nil {
return nil, err
}
if result == nil {
return nil, nil
}
return result.Value, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,12 @@ type CommandJob struct {
CreatedDateTime string `json:"createdDateTime,omitempty"`
Services map[string]interface{} `json:"services,omitempty"`
}

// TerminalStatuses contains job statuses that indicate the job has finished.
var TerminalStatuses = map[string]bool{
"Completed": true,
"Failed": true,
"Canceled": true,
"NotResponding": true,
"Paused": true,
}