Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 55 additions & 23 deletions disk/disk_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,25 @@ func UsageWithContext(_ context.Context, path string) (*UsageStat, error) {
// PartitionsWithContext returns disk partitions.
// It uses procGetLogicalDriveStringsW to get drives with drive letters and procFindFirstVolumeW to get volumes without drive letters.
// Since the api calls don't have a timeout, this method uses context to set deadline by users.
func PartitionsWithContext(_ context.Context, _ bool) ([]PartitionStat, error) {
func PartitionsWithContext(ctx context.Context, _ bool) ([]PartitionStat, error) {
warnings := Warnings{Verbose: true}
processedPaths := make(map[string]struct{})
partitionStats := []PartitionStat{}

// Get drives with drive letters (including remote drives, ex: SMB shares)
drives, err := getLogicalDrives()
drives, err := getLogicalDrives(ctx)
if err != nil {
return partitionStats, err
}

partitionStats = processLogicalDrives(drives, processedPaths, partitionStats, warnings)
partitionStats = processLogicalDrives(ctx, drives, processedPaths, partitionStats, warnings)

// Get volumes without drive letters (ex: mounted folders with no drive letter)
partitionStats = processVolumesMountedAsFolders(partitionStats, warnings, processedPaths)
partitionStats = processVolumesMountedAsFolders(ctx, partitionStats, warnings, processedPaths)
return partitionStats, warnings.Reference()
}

func processVolumesMountedAsFolders(partitionStats []PartitionStat, warnings Warnings, processedPaths map[string]struct{}) []PartitionStat {
func processVolumesMountedAsFolders(ctx context.Context, partitionStats []PartitionStat, warnings Warnings, processedPaths map[string]struct{}) []PartitionStat {
volNameBuf := make([]uint16, maxVolumeNameLength)
nextVolHandle, _, err := procFindFirstVolumeW.Call(
uintptr(unsafe.Pointer(&volNameBuf[0])),
Expand All @@ -136,23 +136,26 @@ func processVolumesMountedAsFolders(partitionStats []PartitionStat, warnings War
return partitionStats
}
defer procFindVolumeClose.Call(nextVolHandle)
partitionStats = processVolumeLoop(ctx, nextVolHandle, volNameBuf, processedPaths, partitionStats, warnings)
return partitionStats
}

func processVolumeLoop(ctx context.Context, nextVolHandle uintptr, volNameBuf []uint16, processedPaths map[string]struct{}, partitionStats []PartitionStat, warnings Warnings) []PartitionStat {
for {
select {
case <-ctx.Done():
warnings.Add(fmt.Errorf("context cancelled while processing volumes: %w", ctx.Err()))
return partitionStats
default:
}

mounts, err := getVolumePaths(volNameBuf)
if err != nil {
warnings.Add(fmt.Errorf("failed to find paths for volume %s", windows.UTF16ToString(volNameBuf)))
continue
}

for _, mount := range mounts {
if _, ok := processedPaths[mount]; ok {
continue
}
if partitionStat, warning := buildPartitionStat(mount); warning == nil {
partitionStats = append(partitionStats, partitionStat)
} else {
warnings.Add(warning)
}
}
partitionStats = processMountsForVolume(ctx, mounts, processedPaths, partitionStats, warnings)

volNameBuf = make([]uint16, maxVolumeNameLength)
if volRet, _, err := procFindNextVolumeW.Call(
Expand All @@ -172,13 +175,17 @@ func processVolumesMountedAsFolders(partitionStats []PartitionStat, warnings War
return partitionStats
}

func processLogicalDrives(drives []string, processedPaths map[string]struct{}, partitionStats []PartitionStat, warnings Warnings) []PartitionStat {
for _, drive := range drives {
if drive != "" && drive[0] >= firstPossibleDriveLetter && drive[0] <= lastPossibleDriveLetter {
v := drive[0]
path := string(v) + ":"
if partitionStat, warning := buildPartitionStat(path); warning == nil {
processedPaths[partitionStat.Mountpoint+"\\"] = struct{}{}
func processMountsForVolume(ctx context.Context, mounts []string, processedPaths map[string]struct{}, partitionStats []PartitionStat, warnings Warnings) []PartitionStat {
for _, mount := range mounts {
select {
case <-ctx.Done():
warnings.Add(fmt.Errorf("context cancelled while processing mount points per volume: %w", ctx.Err()))
return partitionStats
default:
if _, ok := processedPaths[mount]; ok {
continue
}
if partitionStat, warning := buildPartitionStat(mount); warning == nil {
partitionStats = append(partitionStats, partitionStat)
} else {
warnings.Add(warning)
Expand All @@ -188,16 +195,41 @@ func processLogicalDrives(drives []string, processedPaths map[string]struct{}, p
return partitionStats
}

func processLogicalDrives(ctx context.Context, drives []string, processedPaths map[string]struct{}, partitionStats []PartitionStat, warnings Warnings) []PartitionStat {
for _, drive := range drives {
select {
case <-ctx.Done():
warnings.Add(fmt.Errorf("context cancelled while processing logical drives: %w", ctx.Err()))
return partitionStats
default:
if drive != "" && drive[0] >= firstPossibleDriveLetter && drive[0] <= lastPossibleDriveLetter {
v := drive[0]
path := string(v) + ":"
if partitionStat, warning := buildPartitionStat(path); warning == nil {
processedPaths[partitionStat.Mountpoint+"\\"] = struct{}{}
partitionStats = append(partitionStats, partitionStat)
} else {
warnings.Add(warning)
}
}
}
}
return partitionStats
}

// getLogicalDrives retrieves all logical drives using GetLogicalDriveStringsW.
// We first call GetLogicalDriveStringsW with a buffer length of 0 to get the required buffer size.
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getlogicaldrivestringsw
func getLogicalDrives() ([]string, error) {
func getLogicalDrives(ctx context.Context) ([]string, error) {
bufferLen, _, err := procGetLogicalDriveStringsW.Call(
uintptr(0),
uintptr(0))
if !errors.Is(err, windows.ERROR_SUCCESS) {
return nil, err // The call failed with an unexpected error
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
lpBuffer := make([]uint16, bufferLen)
// buffer can be longer than MAX_PATH
_, _, err = procGetLogicalDriveStringsW.Call(
Expand Down
5 changes: 3 additions & 2 deletions disk/disk_windows_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package disk

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetLogicalDrives(t *testing.T) {
drives, err := getLogicalDrives()
drives, err := getLogicalDrives(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, drives)
for _, d := range drives {
Expand All @@ -33,7 +34,7 @@ func TestProcessLogicalDrives(t *testing.T) {
processedPaths := map[string]struct{}{}
warnings := Warnings{}

parts := processLogicalDrives(drives, processedPaths, partitionStats, warnings)
parts := processLogicalDrives(context.Background(), drives, processedPaths, partitionStats, warnings)
assert.Len(t, parts, 1)
assert.Equal(t, "C:", parts[0].Mountpoint)
assert.Equal(t, "C:", parts[0].Device)
Expand Down