diff --git a/src/constants.ts b/src/constants.ts index 3eb743a72..ae6f796d6 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -201,6 +201,14 @@ export const AMD_TORCH_PACKAGES: string[] = [ 'https://repo.radeon.com/rocm/windows/rocm-rel-7.1.1/torchvision-0.24.0+rocmsdk20251116-cp312-cp312-win_amd64.whl', ]; +export const NVIDIA_TORCH_VERSION = '2.9.1+cu130'; +export const NVIDIA_TORCHVISION_VERSION = '0.24.1+cu130'; +export const NVIDIA_TORCH_PACKAGES: string[] = [ + `torch==${NVIDIA_TORCH_VERSION}`, + `torchaudio==${NVIDIA_TORCH_VERSION}`, + `torchvision==${NVIDIA_TORCHVISION_VERSION}`, +]; + /** The log files used by the desktop process. */ export enum LogFile { /** The ComfyUI server log file. */ diff --git a/src/install/installationManager.ts b/src/install/installationManager.ts index 20e66e71e..1f951d751 100644 --- a/src/install/installationManager.ts +++ b/src/install/installationManager.ts @@ -1,5 +1,7 @@ import { Notification, app, dialog, shell } from 'electron'; import log from 'electron-log/main'; +import { exec } from 'node:child_process'; +import { promisify } from 'node:util'; import { strictIpcMain as ipcMain } from '@/infrastructure/ipcChannels'; @@ -14,12 +16,37 @@ import { CmCli } from '../services/cmCli'; import { captureSentryException } from '../services/sentry'; import { type HasTelemetry, ITelemetry, trackEvent } from '../services/telemetry'; import { type DesktopConfig, useDesktopConfig } from '../store/desktopConfig'; -import { canExecuteShellCommand, validateHardware } from '../utils'; +import { canExecuteShellCommand, compareVersions, validateHardware } from '../utils'; import type { ProcessCallbacks, VirtualEnvironment } from '../virtualEnvironment'; import { createProcessCallbacks } from './createProcessCallbacks'; import { InstallWizard } from './installWizard'; import { Troubleshooting } from './troubleshooting'; +const execAsync = promisify(exec); +const NVIDIA_DRIVER_MIN_VERSION = '580'; + +/** + * Extracts the NVIDIA driver version from `nvidia-smi` output. + * @param output The `nvidia-smi` output to parse. + * @returns The driver version, if present. + */ +export function parseNvidiaDriverVersionFromSmiOutput(output: string): string | undefined { + const match = output.match(/driver version\s*:\s*([\d.]+)/i); + return match?.[1]; +} + +/** + * Returns `true` when the NVIDIA driver version is below the minimum. + * @param driverVersion The detected driver version. + * @param minimumVersion The minimum required driver version. + */ +export function isNvidiaDriverBelowMinimum( + driverVersion: string, + minimumVersion: string = NVIDIA_DRIVER_MIN_VERSION +): boolean { + return compareVersions(driverVersion, minimumVersion) < 0; +} + /** High-level / UI control over the installation of ComfyUI server. */ export class InstallationManager implements HasTelemetry { constructor( @@ -354,6 +381,8 @@ export class InstallationManager implements HasTelemetry { try { await installation.virtualEnvironment.installComfyUIRequirements(callbacks); await installation.virtualEnvironment.installComfyUIManagerRequirements(callbacks); + await this.warnIfNvidiaDriverTooOld(installation); + await installation.virtualEnvironment.ensureRecommendedNvidiaTorch(callbacks); await installation.validate(); } catch (error) { log.error('Error auto-updating packages:', error); @@ -361,6 +390,59 @@ export class InstallationManager implements HasTelemetry { } } + /** + * Warns the user if their NVIDIA driver is too old for the required CUDA build. + * @param installation The current installation. + */ + private async warnIfNvidiaDriverTooOld(installation: ComfyInstallation): Promise { + if (process.platform !== 'win32') return; + if (installation.virtualEnvironment.selectedDevice !== 'nvidia') return; + + const driverVersion = + (await this.getNvidiaDriverVersionFromSmi()) ?? (await this.getNvidiaDriverVersionFromSmiFallback()); + if (!driverVersion) return; + + if (!isNvidiaDriverBelowMinimum(driverVersion)) return; + + await this.appWindow.showMessageBox({ + type: 'warning', + title: 'Update NVIDIA Driver', + message: 'Your NVIDIA driver may be too old for PyTorch 2.9.1 + cu130.', + detail: `Detected driver version: ${driverVersion}\nRecommended minimum: ${NVIDIA_DRIVER_MIN_VERSION}\n\nPlease consider updating your NVIDIA drivers and retrying if you run into issues.`, + buttons: ['OK'], + }); + } + + /** + * Reads the NVIDIA driver version from nvidia-smi query output. + */ + private async getNvidiaDriverVersionFromSmi(): Promise { + try { + const { stdout } = await execAsync('nvidia-smi --query-gpu=driver_version --format=csv,noheader'); + const version = stdout + .split(/\r?\n/) + .map((line) => line.trim()) + .find(Boolean); + return version || undefined; + } catch (error) { + log.debug('Failed to read NVIDIA driver version via nvidia-smi query.', error); + return undefined; + } + } + + /** + * Reads the NVIDIA driver version from nvidia-smi standard output. + */ + private async getNvidiaDriverVersionFromSmiFallback(): Promise { + try { + const { stdout } = await execAsync('nvidia-smi'); + return parseNvidiaDriverVersionFromSmiOutput(stdout); + } catch (error) { + log.debug('Failed to read NVIDIA driver version via nvidia-smi output.', error); + return undefined; + } + } + static setReinstallHandler(installation: ComfyInstallation) { ipcMain.removeHandler(IPC_CHANNELS.REINSTALL); ipcMain.handle(IPC_CHANNELS.REINSTALL, async () => await InstallationManager.reinstall(installation)); diff --git a/src/virtualEnvironment.ts b/src/virtualEnvironment.ts index a4a8d0fbc..b9629c5d7 100644 --- a/src/virtualEnvironment.ts +++ b/src/virtualEnvironment.ts @@ -7,7 +7,15 @@ import { readdir, rm } from 'node:fs/promises'; import os, { EOL } from 'node:os'; import path from 'node:path'; -import { AMD_ROCM_SDK_PACKAGES, AMD_TORCH_PACKAGES, InstallStage, TorchMirrorUrl } from './constants'; +import { + AMD_ROCM_SDK_PACKAGES, + AMD_TORCH_PACKAGES, + InstallStage, + NVIDIA_TORCHVISION_VERSION, + NVIDIA_TORCH_PACKAGES, + NVIDIA_TORCH_VERSION, + TorchMirrorUrl, +} from './constants'; import { PythonImportVerificationError } from './infrastructure/pythonImportVerificationError'; import { useAppState } from './main-process/appState'; import { createInstallStageInfo } from './main-process/installStages'; @@ -16,7 +24,7 @@ import { runPythonImportVerifyScript } from './services/pythonImportVerifier'; import { captureSentryException } from './services/sentry'; import { HasTelemetry, ITelemetry, trackEvent } from './services/telemetry'; import { getDefaultShell, getDefaultShellArgs } from './shell/util'; -import { pathAccessible } from './utils'; +import { compareVersions, pathAccessible } from './utils'; export type ProcessCallbacks = { onStdout?: (data: string) => void; @@ -43,6 +51,11 @@ interface PipInstallConfig { indexStrategy?: 'compatible' | 'unsafe-best-match'; } +type TorchPackageName = 'torch' | 'torchaudio' | 'torchvision'; +type TorchPackageVersions = Record; + +const TORCH_PACKAGE_NAMES: TorchPackageName[] = ['torch', 'torchaudio', 'torchvision']; + function getPipInstallArgs(config: PipInstallConfig): string[] { const installArgs = ['pip', 'install']; @@ -629,6 +642,118 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { } } + /** + * Ensures NVIDIA installs use the recommended PyTorch packages. + * @param callbacks The callbacks to use for the command. + */ + async ensureRecommendedNvidiaTorch(callbacks?: ProcessCallbacks): Promise { + if (this.selectedDevice !== 'nvidia') return; + + const installedVersions = await this.getInstalledTorchPackageVersions(); + if (installedVersions && this.meetsMinimumNvidiaTorchVersions(installedVersions)) { + log.info('NVIDIA PyTorch packages already satisfy minimum recommended versions.', installedVersions); + return; + } + + const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice); + const config: PipInstallConfig = { + packages: NVIDIA_TORCH_PACKAGES, + indexUrl: torchMirror, + prerelease: torchMirror.includes('nightly'), + }; + + const installArgs = getPipInstallArgs(config); + log.info('Installing recommended NVIDIA PyTorch packages.', { installedVersions }); + const { exitCode: pinnedExitCode } = await this.runUvCommandAsync(installArgs, callbacks); + + if (pinnedExitCode === 0) return; + + log.warn('Failed to install recommended NVIDIA PyTorch packages. Falling back to unpinned install.', { + exitCode: pinnedExitCode, + }); + + const fallbackConfig: PipInstallConfig = { + packages: ['torch', 'torchvision', 'torchaudio'], + indexUrl: torchMirror, + prerelease: torchMirror.includes('nightly'), + upgradePackages: true, + }; + const fallbackArgs = getPipInstallArgs(fallbackConfig); + const { exitCode: fallbackExitCode } = await this.runUvCommandAsync(fallbackArgs, callbacks); + if (fallbackExitCode !== 0) { + throw new Error( + `Failed to install NVIDIA PyTorch packages (pinned exit ${pinnedExitCode}, fallback exit ${fallbackExitCode})` + ); + } + } + + /** + * Reads installed torch package versions using `uv pip list --format=json`. + * @returns The torch package versions when available, otherwise `undefined`. + */ + private async getInstalledTorchPackageVersions(): Promise { + let stdout = ''; + let stderr = ''; + const callbacks: ProcessCallbacks = { + onStdout: (data) => { + stdout += data; + }, + onStderr: (data) => { + stderr += data; + }, + }; + + const { exitCode } = await this.runUvAsync(['pip', 'list', '--format=json', '--color=never'], callbacks); + + if (exitCode !== 0) { + log.warn('Failed to read torch package versions.', { exitCode, stderr }); + return undefined; + } + + if (!stdout.trim()) { + log.warn('Torch package list output was empty.', { stderr }); + return undefined; + } + + let parsed: unknown; + try { + parsed = JSON.parse(stdout); + } catch (error) { + log.warn('Failed to parse torch package list output.', { error, stdout, stderr }); + return undefined; + } + + if (!Array.isArray(parsed)) { + log.warn('Torch package list output was not an array.', { stdout, stderr }); + return undefined; + } + + const versions: TorchPackageVersions = { + torch: undefined, + torchaudio: undefined, + torchvision: undefined, + }; + let matched = 0; + const isRecord = (value: unknown): value is Record => typeof value === 'object' && value !== null; + + for (const entry of parsed) { + if (!isRecord(entry)) continue; + const name = entry.name; + const version = entry.version; + if (typeof name !== 'string' || typeof version !== 'string') continue; + const packageName = TORCH_PACKAGE_NAMES.find((pkg) => pkg === name.trim().toLowerCase()); + if (!packageName) continue; + matched += 1; + versions[packageName] = version.trim() || undefined; + } + if (matched === 0) { + log.warn('Torch package list did not contain expected packages.', { stdout, stderr }); + return undefined; + } + + return versions; + } + /** * Installs AMD ROCm SDK packages on Windows. * @param callbacks The callbacks to use for the command. @@ -838,10 +963,52 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor { return 'package-upgrade'; } + if (await this.needsNvidiaTorchUpgrade()) { + log.info('NVIDIA PyTorch version out of date. Treating as package upgrade.'); + return 'package-upgrade'; + } + log.debug('hasRequirements result:', 'OK'); return 'OK'; } + /** + * Returns `true` when NVIDIA PyTorch should be upgraded to the recommended version. + * @returns `true` when NVIDIA PyTorch is out of date, otherwise `false`. + */ + private async needsNvidiaTorchUpgrade(): Promise { + if (this.selectedDevice !== 'nvidia') return false; + + const installedVersions = await this.getInstalledTorchPackageVersions(); + if (!installedVersions) { + log.warn('Unable to read NVIDIA torch package versions. Skipping NVIDIA torch upgrade check.'); + return false; + } + + return !this.meetsMinimumNvidiaTorchVersions(installedVersions); + } + + private meetsMinimumNvidiaTorchVersions(installedVersions: TorchPackageVersions): boolean { + const torch = installedVersions.torch; + const torchaudio = installedVersions.torchaudio; + const torchvision = installedVersions.torchvision; + if (!torch || !torchaudio || !torchvision) return false; + + const requiredCudaTag = '+cu130'; + if ( + !torch.includes(requiredCudaTag) || + !torchaudio.includes(requiredCudaTag) || + !torchvision.includes(requiredCudaTag) + ) + return false; + + if (compareVersions(torch, NVIDIA_TORCH_VERSION) < 0) return false; + if (compareVersions(torchaudio, NVIDIA_TORCH_VERSION) < 0) return false; + if (compareVersions(torchvision, NVIDIA_TORCHVISION_VERSION) < 0) return false; + + return true; + } + /** * Verifies that the Python environment can import modules that frequently show up in errors. * @returns `true` if the Python environment successfully imports the modules, otherwise `false`. diff --git a/tests/unit/install/installationManager.test.ts b/tests/unit/install/installationManager.test.ts index 651a95e30..af9f92969 100644 --- a/tests/unit/install/installationManager.test.ts +++ b/tests/unit/install/installationManager.test.ts @@ -5,7 +5,11 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; import { ComfyServerConfig } from '@/config/comfyServerConfig'; import { ComfySettings } from '@/config/comfySettings'; import { IPC_CHANNELS } from '@/constants'; -import { InstallationManager } from '@/install/installationManager'; +import { + InstallationManager, + isNvidiaDriverBelowMinimum, + parseNvidiaDriverVersionFromSmiOutput, +} from '@/install/installationManager'; import type { AppWindow } from '@/main-process/appWindow'; import { ComfyInstallation } from '@/main-process/comfyInstallation'; import type { InstallValidation } from '@/preload'; @@ -248,3 +252,69 @@ describe('InstallationManager', () => { }); }); }); + +describe('parseNvidiaDriverVersionFromSmiOutput', () => { + it('parses driver version from nvidia-smi output', () => { + const output = String.raw` ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 591.59 Driver Version: 591.59 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Driver-Model | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 5090 WDDM | 00000000:01:00.0 On | N/A | +| 30% 31C P3 49W / 575W | 3248MiB / 32607MiB | 1% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 3528 C+G ...5n1h2txyewy\TextInputHost.exe N/A | +| 0 N/A N/A 3964 C+G ...indows\System32\ShellHost.exe N/A | +| 0 N/A N/A 5044 C+G ...al\Programs\Notion\Notion.exe N/A | +| 0 N/A N/A 6924 C+G ...ntrolPanel\SystemSettings.exe N/A | +| 0 N/A N/A 7552 C+G ...l\slack\app-4.47.69\slack.exe N/A | +| 0 N/A N/A 9464 C+G C:\Windows\explorer.exe N/A | +| 0 N/A N/A 9900 C+G ...\Figma\app-125.11.6\Figma.exe N/A | +| 0 N/A N/A 10536 C+G ...2txyewy\CrossDeviceResume.exe N/A | +| 0 N/A N/A 12164 C+G ...y\StartMenuExperienceHost.exe N/A | +| 0 N/A N/A 12172 C+G ..._cw5n1h2txyewy\SearchHost.exe N/A | +| 0 N/A N/A 14404 C+G ...xyewy\ShellExperienceHost.exe N/A | +| 0 N/A N/A 14880 C+G ...em32\ApplicationFrameHost.exe N/A | +| 0 N/A N/A 15872 C+G ....0.3650.96\msedgewebview2.exe N/A | +| 0 N/A N/A 17980 C+G ...Chrome\Application\chrome.exe N/A | +| 0 N/A N/A 20912 C+G ...8bbwe\PhoneExperienceHost.exe N/A | +| 0 N/A N/A 21392 C ...indows-x86_64-none\python.exe N/A | +| 0 N/A N/A 21876 C+G ...x40ttqa\iCloud\iCloudHome.exe N/A | +| 0 N/A N/A 23364 C+G ... Insiders\Code - Insiders.exe N/A | +| 0 N/A N/A 23384 C+G ...Chrome\Application\chrome.exe N/A | +| 0 N/A N/A 24092 C+G ...0ttqa\iCloud\iCloudPhotos.exe N/A | +| 0 N/A N/A 26908 C+G ...l\slack\app-4.47.69\slack.exe N/A | +| 0 N/A N/A 32944 C+G ....0.3650.96\msedgewebview2.exe N/A | +| 0 N/A N/A 35384 C+G ...kyb3d8bbwe\EdgeGameAssist.exe N/A | +| 0 N/A N/A 38088 C+G ...\Programs\ComfyUI\ComfyUI.exe N/A | +| 0 N/A N/A 38248 C+G ...cord\app-1.0.9219\Discord.exe N/A | +| 0 N/A N/A 42816 C+G ...a\Roaming\Spotify\Spotify.exe N/A | +| 0 N/A N/A 45164 C+G ...t\Edge\Application\msedge.exe N/A | +| 0 N/A N/A 45312 C+G ...yb3d8bbwe\WindowsTerminal.exe N/A | ++-----------------------------------------------------------------------------------------+ +`; + + expect(parseNvidiaDriverVersionFromSmiOutput(output)).toBe('591.59'); + }); +}); + +describe('isNvidiaDriverBelowMinimum', () => { + it.each([ + { version: '579.0.0', expected: true, label: 'below 580' }, + { version: '580.0.0', expected: false, label: 'at 580' }, + { version: '580.0.1', expected: false, label: 'at a version of 580' }, + { version: '581.0', expected: false, label: 'above 580' }, + ])('returns $expected when $label', ({ version, expected }) => { + expect(isNvidiaDriverBelowMinimum(version, '580')).toBe(expected); + }); +});