Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions contrib/cdisetup/nvidia/nvidia.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ func (s *setup) Run(ctx context.Context) (err error) {
closeProgress(err)
}()

isDistro, _ := isDebianOrUbuntu()
if !isDistro {
osr, err := getOSRelease()
if err != nil {
return err
} else if osr.ID != "debian" && osr.ID != "ubuntu" {
return errors.Errorf("NVIDIA setup is currently only supported on Debian/Ubuntu")
}

Expand Down Expand Up @@ -131,7 +133,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
return err
}

if err := installPackages(ctx, dv, pw, dgst); err != nil {
if err := installPackages(ctx, osr, dv, pw, dgst); err != nil {
return err
}

Expand Down Expand Up @@ -167,8 +169,20 @@ func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Dig
return cmd.Run()
}

func installPackages(ctx context.Context, dv string, pw progress.Writer, dgst digest.Digest) error {
const aptDistro = "ubuntu2404"
func installPackages(ctx context.Context, osr *osrelease, dv string, pw progress.Writer, dgst digest.Digest) error {
aptDistro := "ubuntu2404"
switch osr.ID {
case "debian":
if osr.VersionID == "" {
aptDistro = "debian12"
} else {
aptDistro = "debian" + osr.VersionID
}
case "ubuntu":
if osr.VersionID != "" {
aptDistro = "ubuntu" + strings.ReplaceAll(osr.VersionID, ".", "")
}
}

var arch string
switch runtime.GOARCH {
Expand Down Expand Up @@ -274,35 +288,38 @@ func hasNvidiaDevices() (bool, error) {
return found, nil
}

func getOSID() (string, error) {
type osrelease struct {
ID string
VersionID string
}

func getOSRelease() (*osrelease, error) {
file, err := os.Open("/etc/os-release")
if err != nil {
return "", err
return nil, err
}
defer file.Close()

var id, versionID string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if id, ok := strings.CutPrefix(line, "ID="); ok {
return strings.Trim(id, `"`), nil // Remove potential quotes
if v, ok := strings.CutPrefix(line, "ID="); ok {
id = strings.Trim(v, `"`)
} else if v, ok := strings.CutPrefix(line, "VERSION_ID="); ok {
versionID = strings.Trim(v, `"`)
}
}

if err := scanner.Err(); err != nil {
return "", err
return nil, err
}

return "", errors.Errorf("ID not found in /etc/os-release")
}

func isDebianOrUbuntu() (bool, error) {
id, err := getOSID()
if err != nil {
return false, err
if id == "" {
return nil, errors.Errorf("ID not found in /etc/os-release")
}

return id == "debian" || id == "ubuntu", nil
return &osrelease{ID: id, VersionID: versionID}, nil
}

func hasWSLGPU() bool {
Expand Down