diff --git a/disk/disk_windows.go b/disk/disk_windows.go index a19a2c58d4..2d3a878f45 100644 --- a/disk/disk_windows.go +++ b/disk/disk_windows.go @@ -108,25 +108,25 @@ func UsageWithContext(_ context.Context, path string) (*UsageStat, error) { // PartitionsWithContext returns disk partitions. // It uses procGetLogicalDriveStringsW to get drives with drive letters and procFindFirstVolumeW to get volumes without drive letters. // Since the api calls don't have a timeout, this method uses context to set deadline by users. -func PartitionsWithContext(_ context.Context, _ bool) ([]PartitionStat, error) { +func PartitionsWithContext(ctx context.Context, _ bool) ([]PartitionStat, error) { warnings := Warnings{Verbose: true} processedPaths := make(map[string]struct{}) partitionStats := []PartitionStat{} // Get drives with drive letters (including remote drives, ex: SMB shares) - drives, err := getLogicalDrives() + drives, err := getLogicalDrives(ctx) if err != nil { return partitionStats, err } - partitionStats = processLogicalDrives(drives, processedPaths, partitionStats, warnings) + partitionStats = processLogicalDrives(ctx, drives, processedPaths, partitionStats, warnings) // Get volumes without drive letters (ex: mounted folders with no drive letter) - partitionStats = processVolumesMountedAsFolders(partitionStats, warnings, processedPaths) + partitionStats = processVolumesMountedAsFolders(ctx, partitionStats, warnings, processedPaths) return partitionStats, warnings.Reference() } -func processVolumesMountedAsFolders(partitionStats []PartitionStat, warnings Warnings, processedPaths map[string]struct{}) []PartitionStat { +func processVolumesMountedAsFolders(ctx context.Context, partitionStats []PartitionStat, warnings Warnings, processedPaths map[string]struct{}) []PartitionStat { volNameBuf := make([]uint16, maxVolumeNameLength) nextVolHandle, _, err := procFindFirstVolumeW.Call( uintptr(unsafe.Pointer(&volNameBuf[0])), @@ -136,23 +136,26 @@ func processVolumesMountedAsFolders(partitionStats []PartitionStat, warnings War return partitionStats } defer procFindVolumeClose.Call(nextVolHandle) + partitionStats = processVolumeLoop(ctx, nextVolHandle, volNameBuf, processedPaths, partitionStats, warnings) + return partitionStats +} + +func processVolumeLoop(ctx context.Context, nextVolHandle uintptr, volNameBuf []uint16, processedPaths map[string]struct{}, partitionStats []PartitionStat, warnings Warnings) []PartitionStat { for { + select { + case <-ctx.Done(): + warnings.Add(fmt.Errorf("context cancelled while processing volumes: %w", ctx.Err())) + return partitionStats + default: + } + mounts, err := getVolumePaths(volNameBuf) if err != nil { warnings.Add(fmt.Errorf("failed to find paths for volume %s", windows.UTF16ToString(volNameBuf))) continue } - for _, mount := range mounts { - if _, ok := processedPaths[mount]; ok { - continue - } - if partitionStat, warning := buildPartitionStat(mount); warning == nil { - partitionStats = append(partitionStats, partitionStat) - } else { - warnings.Add(warning) - } - } + partitionStats = processMountsForVolume(ctx, mounts, processedPaths, partitionStats, warnings) volNameBuf = make([]uint16, maxVolumeNameLength) if volRet, _, err := procFindNextVolumeW.Call( @@ -172,13 +175,17 @@ func processVolumesMountedAsFolders(partitionStats []PartitionStat, warnings War return partitionStats } -func processLogicalDrives(drives []string, processedPaths map[string]struct{}, partitionStats []PartitionStat, warnings Warnings) []PartitionStat { - for _, drive := range drives { - if drive != "" && drive[0] >= firstPossibleDriveLetter && drive[0] <= lastPossibleDriveLetter { - v := drive[0] - path := string(v) + ":" - if partitionStat, warning := buildPartitionStat(path); warning == nil { - processedPaths[partitionStat.Mountpoint+"\\"] = struct{}{} +func processMountsForVolume(ctx context.Context, mounts []string, processedPaths map[string]struct{}, partitionStats []PartitionStat, warnings Warnings) []PartitionStat { + for _, mount := range mounts { + select { + case <-ctx.Done(): + warnings.Add(fmt.Errorf("context cancelled while processing mount points per volume: %w", ctx.Err())) + return partitionStats + default: + if _, ok := processedPaths[mount]; ok { + continue + } + if partitionStat, warning := buildPartitionStat(mount); warning == nil { partitionStats = append(partitionStats, partitionStat) } else { warnings.Add(warning) @@ -188,16 +195,41 @@ func processLogicalDrives(drives []string, processedPaths map[string]struct{}, p return partitionStats } +func processLogicalDrives(ctx context.Context, drives []string, processedPaths map[string]struct{}, partitionStats []PartitionStat, warnings Warnings) []PartitionStat { + for _, drive := range drives { + select { + case <-ctx.Done(): + warnings.Add(fmt.Errorf("context cancelled while processing logical drives: %w", ctx.Err())) + return partitionStats + default: + if drive != "" && drive[0] >= firstPossibleDriveLetter && drive[0] <= lastPossibleDriveLetter { + v := drive[0] + path := string(v) + ":" + if partitionStat, warning := buildPartitionStat(path); warning == nil { + processedPaths[partitionStat.Mountpoint+"\\"] = struct{}{} + partitionStats = append(partitionStats, partitionStat) + } else { + warnings.Add(warning) + } + } + } + } + return partitionStats +} + // getLogicalDrives retrieves all logical drives using GetLogicalDriveStringsW. // We first call GetLogicalDriveStringsW with a buffer length of 0 to get the required buffer size. // https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getlogicaldrivestringsw -func getLogicalDrives() ([]string, error) { +func getLogicalDrives(ctx context.Context) ([]string, error) { bufferLen, _, err := procGetLogicalDriveStringsW.Call( uintptr(0), uintptr(0)) if !errors.Is(err, windows.ERROR_SUCCESS) { return nil, err // The call failed with an unexpected error } + if ctx.Err() != nil { + return nil, ctx.Err() + } lpBuffer := make([]uint16, bufferLen) // buffer can be longer than MAX_PATH _, _, err = procGetLogicalDriveStringsW.Call( diff --git a/disk/disk_windows_test.go b/disk/disk_windows_test.go index a9262dde20..3e6160a1a0 100644 --- a/disk/disk_windows_test.go +++ b/disk/disk_windows_test.go @@ -1,6 +1,7 @@ package disk import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -8,7 +9,7 @@ import ( ) func TestGetLogicalDrives(t *testing.T) { - drives, err := getLogicalDrives() + drives, err := getLogicalDrives(context.Background()) require.NoError(t, err) assert.NotEmpty(t, drives) for _, d := range drives { @@ -33,7 +34,7 @@ func TestProcessLogicalDrives(t *testing.T) { processedPaths := map[string]struct{}{} warnings := Warnings{} - parts := processLogicalDrives(drives, processedPaths, partitionStats, warnings) + parts := processLogicalDrives(context.Background(), drives, processedPaths, partitionStats, warnings) assert.Len(t, parts, 1) assert.Equal(t, "C:", parts[0].Mountpoint) assert.Equal(t, "C:", parts[0].Device)