8 changes: 8 additions & 0 deletions src/constants.ts
@@ -201,6 +201,14 @@ export const AMD_TORCH_PACKAGES: string[] = [
'https://repo.radeon.com/rocm/windows/rocm-rel-7.1.1/torchvision-0.24.0+rocmsdk20251116-cp312-cp312-win_amd64.whl',
];

export const NVIDIA_TORCH_VERSION = '2.9.1+cu130';
export const NVIDIA_TORCHVISION_VERSION = '0.24.1+cu130';
export const NVIDIA_TORCH_PACKAGES: string[] = [
`torch==${NVIDIA_TORCH_VERSION}`,
`torchaudio==${NVIDIA_TORCH_VERSION}`,
`torchvision==${NVIDIA_TORCHVISION_VERSION}`,
];

/** The log files used by the desktop process. */
export enum LogFile {
  /** The ComfyUI server log file. */
84 changes: 83 additions & 1 deletion src/install/installationManager.ts
@@ -1,5 +1,7 @@
import { Notification, app, dialog, shell } from 'electron';
import log from 'electron-log/main';
import { exec } from 'node:child_process';
import { promisify } from 'node:util';

import { strictIpcMain as ipcMain } from '@/infrastructure/ipcChannels';

@@ -14,12 +16,37 @@ import { CmCli } from '../services/cmCli';
import { captureSentryException } from '../services/sentry';
import { type HasTelemetry, ITelemetry, trackEvent } from '../services/telemetry';
import { type DesktopConfig, useDesktopConfig } from '../store/desktopConfig';
import { canExecuteShellCommand, validateHardware } from '../utils';
import { canExecuteShellCommand, compareVersions, validateHardware } from '../utils';
import type { ProcessCallbacks, VirtualEnvironment } from '../virtualEnvironment';
import { createProcessCallbacks } from './createProcessCallbacks';
import { InstallWizard } from './installWizard';
import { Troubleshooting } from './troubleshooting';

const execAsync = promisify(exec);
const NVIDIA_DRIVER_MIN_VERSION = '580';
🧹 Nitpick | 🔵 Trivial

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for NVIDIA-related constants in the codebase

echo "=== Searching for NVIDIA constants ==="
rg -n --type=ts 'NVIDIA.*VERSION|NVIDIA.*PACKAGE|NVIDIA.*VERSION|NVIDIA_DRIVER' -g '!*.test.ts' -C2

Repository: Comfy-Org/desktop



🏁 Script executed:

#!/bin/bash
# Check if NVIDIA_DRIVER_MIN_VERSION is exported or used in other files

echo "=== Checking for NVIDIA_DRIVER_MIN_VERSION exports ==="
rg -n 'export.*NVIDIA_DRIVER_MIN_VERSION' --type=ts

echo ""
echo "=== Checking for imports of NVIDIA_DRIVER_MIN_VERSION ==="
rg -n 'import.*NVIDIA_DRIVER_MIN_VERSION|from.*installationManager' --type=ts

echo ""
echo "=== All references to NVIDIA_DRIVER_MIN_VERSION ==="
rg -n 'NVIDIA_DRIVER_MIN_VERSION' --type=ts -g '!*.test.ts'



Consider centralizing NVIDIA_DRIVER_MIN_VERSION with other NVIDIA constants.

Other NVIDIA-related constants (NVIDIA_TORCH_VERSION, NVIDIA_TORCHVISION_VERSION, NVIDIA_TORCH_PACKAGES) are centralized in src/constants.ts. Moving NVIDIA_DRIVER_MIN_VERSION there would improve consistency and discoverability across the codebase, even though it is currently only used within this file.

🤖 Prompt for AI Agents
In @src/install/installationManager.ts at line 26, NVIDIA_DRIVER_MIN_VERSION is
declared locally but should be centralized with the other NVIDIA constants; move
the constant into the shared constants module by adding and exporting
NVIDIA_DRIVER_MIN_VERSION alongside
NVIDIA_TORCH_VERSION/NVIDIA_TORCHVISION_VERSION/NVIDIA_TORCH_PACKAGES, remove
the local const from installationManager.ts, and import the exported
NVIDIA_DRIVER_MIN_VERSION symbol where it’s used (in installationManager) so all
NVIDIA-related constants live in one place.


/**
 * Extracts the NVIDIA driver version from `nvidia-smi` output.
 * @param output The `nvidia-smi` output to parse.
 * @returns The driver version, if present.
 */
export function parseNvidiaDriverVersionFromSmiOutput(output: string): string | undefined {
  const match = output.match(/driver version\s*:\s*([\d.]+)/i);
  return match?.[1];
}

/**
 * Returns `true` when the NVIDIA driver version is below the minimum.
 * @param driverVersion The detected driver version.
 * @param minimumVersion The minimum required driver version.
 */
export function isNvidiaDriverBelowMinimum(
  driverVersion: string,
  minimumVersion: string = NVIDIA_DRIVER_MIN_VERSION
): boolean {
  return compareVersions(driverVersion, minimumVersion) < 0;
}
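For illustration, the two exported helpers above can be exercised against sample `nvidia-smi` output. The `compareVersionsSketch` below is an assumption standing in for the real `compareVersions` from `src/utils`, which may handle edge cases differently; the sample output line is likewise made up:

```typescript
// Hypothetical stand-in for src/utils compareVersions: numeric dotted comparison,
// ignoring any '+local' suffix. Illustration only — not the project's implementation.
function compareVersionsSketch(a: string, b: string): number {
  const parse = (v: string) => v.split('+')[0].split('.').map(Number);
  const [pa, pb] = [parse(a), parse(b)];
  for (let i = 0; i < Math.max(pa.length, pb.length); i++) {
    const diff = (pa[i] ?? 0) - (pb[i] ?? 0);
    if (diff !== 0) return Math.sign(diff);
  }
  return 0;
}

// Same regex as parseNvidiaDriverVersionFromSmiOutput in the diff above.
function parseNvidiaDriverVersionFromSmiOutput(output: string): string | undefined {
  const match = output.match(/driver version\s*:\s*([\d.]+)/i);
  return match?.[1];
}

// Fabricated nvidia-smi banner line for demonstration purposes.
const sample = 'NVIDIA-SMI 572.16    Driver Version: 572.16    CUDA Version: 12.8';
const version = parseNvidiaDriverVersionFromSmiOutput(sample); // '572.16'
const tooOld = version !== undefined && compareVersionsSketch(version, '580') < 0; // true
```

With `NVIDIA_DRIVER_MIN_VERSION = '580'`, a detected `572.16` driver compares below the minimum and would trigger the warning dialog.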

/** High-level / UI control over the installation of ComfyUI server. */
export class InstallationManager implements HasTelemetry {
  constructor(
@@ -354,13 +381,68 @@ export class InstallationManager implements HasTelemetry {
    try {
      await installation.virtualEnvironment.installComfyUIRequirements(callbacks);
      await installation.virtualEnvironment.installComfyUIManagerRequirements(callbacks);
      await this.warnIfNvidiaDriverTooOld(installation);
      await installation.virtualEnvironment.ensureRecommendedNvidiaTorch(callbacks);
      await installation.validate();
    } catch (error) {
      log.error('Error auto-updating packages:', error);
      await this.appWindow.loadPage('server-start');
    }
  }

  /**
   * Warns the user if their NVIDIA driver is too old for the required CUDA build.
   * @param installation The current installation.
   */
  private async warnIfNvidiaDriverTooOld(installation: ComfyInstallation): Promise<void> {
    if (process.platform !== 'win32') return;
    if (installation.virtualEnvironment.selectedDevice !== 'nvidia') return;

    const driverVersion =
      (await this.getNvidiaDriverVersionFromSmi()) ?? (await this.getNvidiaDriverVersionFromSmiFallback());
    if (!driverVersion) return;

    if (!isNvidiaDriverBelowMinimum(driverVersion)) return;

    await this.appWindow.showMessageBox({
      type: 'warning',
      title: 'Update NVIDIA Driver',
      message: 'Your NVIDIA driver may be too old for PyTorch 2.9.1 + cu130.',
      detail: `Detected driver version: ${driverVersion}\nRecommended minimum: ${NVIDIA_DRIVER_MIN_VERSION}\n\nPlease consider updating your NVIDIA drivers and retrying if you run into issues.`,
      buttons: ['OK'],
    });
  }
Comment on lines +393 to +414
🧹 Nitpick | 🔵 Trivial

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find the NVIDIA_TORCH_VERSION constant definition
rg -n --type=ts 'NVIDIA_TORCH_VERSION' src/constants.ts -A2

# Check how it's used elsewhere
echo "---"
rg -n --type=ts 'NVIDIA_TORCH_VERSION' -g '!*.test.ts' -C2



🏁 Script executed:

#!/bin/bash
# Check if NVIDIA_TORCH_VERSION is imported in installationManager.ts
rg -n 'NVIDIA_TORCH_VERSION' src/install/installationManager.ts

# Also check the imports at the top of installationManager.ts
head -50 src/install/installationManager.ts | grep -n "import\|from"



Consider referencing NVIDIA_TORCH_VERSION constant in warning message.

The warning message (line 410) contains a hardcoded PyTorch version string "PyTorch 2.9.1 + cu130" that may become outdated as versions are upgraded. The NVIDIA_TORCH_VERSION constant is already defined in src/constants.ts (as '2.9.1+cu130') and is the source of truth for NVIDIA torch versions used elsewhere in the codebase. Referencing this constant would ensure the message stays synchronized with actual package versions and reduce maintenance burden.

🤖 Prompt for AI Agents
In @src/install/installationManager.ts around lines 393 - 414, The warning in
warnIfNvidiaDriverTooOld hardcodes "PyTorch 2.9.1 + cu130"; update the message
to reference the canonical NVIDIA_TORCH_VERSION constant instead so it stays in
sync. Import NVIDIA_TORCH_VERSION (and keep NVIDIA_DRIVER_MIN_VERSION usage) and
interpolate that constant into the message/detail text where the hardcoded
string appears in the appWindow.showMessageBox call.


  /**
   * Reads the NVIDIA driver version from nvidia-smi query output.
   */
  private async getNvidiaDriverVersionFromSmi(): Promise<string | undefined> {
    try {
      const { stdout } = await execAsync('nvidia-smi --query-gpu=driver_version --format=csv,noheader');
🧹 Nitpick | 🔵 Trivial

Consider adding timeout to execAsync calls for defensive programming.

While nvidia-smi is typically a fast and well-behaved command, adding a timeout parameter to the execAsync calls would prevent the installation flow from hanging indefinitely in edge cases (e.g., GPU driver issues, system instability).

⏱️ Example with timeout
-      const { stdout } = await execAsync('nvidia-smi --query-gpu=driver_version --format=csv,noheader');
+      const { stdout } = await execAsync('nvidia-smi --query-gpu=driver_version --format=csv,noheader', { timeout: 5000 });

And similarly for the fallback:

-      const { stdout } = await execAsync('nvidia-smi');
+      const { stdout } = await execAsync('nvidia-smi', { timeout: 5000 });

Also applies to: 438-438

🤖 Prompt for AI Agents
In @src/install/installationManager.ts at line 421, Add a timeout to the
execAsync calls that query nvidia-smi to avoid hangs: when calling execAsync
(the lines with "const { stdout } = await execAsync('nvidia-smi
--query-gpu=driver_version --format=csv,noheader')" and the fallback execAsync
at the other site), pass a reasonable timeout option (e.g., 5–15s) in the
options argument and ensure the call is wrapped in try/catch so timeout errors
are handled deterministically (log/throw a clear error or fall back as
appropriate).

      const version = stdout
        .split(/\r?\n/)
        .map((line) => line.trim())
        .find(Boolean);
      return version || undefined;
    } catch (error) {
      log.debug('Failed to read NVIDIA driver version via nvidia-smi query.', error);
      return undefined;
    }
  }

  /**
   * Reads the NVIDIA driver version from nvidia-smi standard output.
   */
  private async getNvidiaDriverVersionFromSmiFallback(): Promise<string | undefined> {
    try {
      const { stdout } = await execAsync('nvidia-smi');
      return parseNvidiaDriverVersionFromSmiOutput(stdout);
    } catch (error) {
      log.debug('Failed to read NVIDIA driver version via nvidia-smi output.', error);
      return undefined;
    }
  }

  static setReinstallHandler(installation: ComfyInstallation) {
    ipcMain.removeHandler(IPC_CHANNELS.REINSTALL);
    ipcMain.handle(IPC_CHANNELS.REINSTALL, async () => await InstallationManager.reinstall(installation));
171 changes: 169 additions & 2 deletions src/virtualEnvironment.ts
@@ -7,7 +7,15 @@ import { readdir, rm } from 'node:fs/promises';
import os, { EOL } from 'node:os';
import path from 'node:path';

import { AMD_ROCM_SDK_PACKAGES, AMD_TORCH_PACKAGES, InstallStage, TorchMirrorUrl } from './constants';
import {
AMD_ROCM_SDK_PACKAGES,
AMD_TORCH_PACKAGES,
InstallStage,
NVIDIA_TORCHVISION_VERSION,
NVIDIA_TORCH_PACKAGES,
NVIDIA_TORCH_VERSION,
TorchMirrorUrl,
} from './constants';
import { PythonImportVerificationError } from './infrastructure/pythonImportVerificationError';
import { useAppState } from './main-process/appState';
import { createInstallStageInfo } from './main-process/installStages';
@@ -16,7 +24,7 @@ import { runPythonImportVerifyScript } from './services/pythonImportVerifier';
import { captureSentryException } from './services/sentry';
import { HasTelemetry, ITelemetry, trackEvent } from './services/telemetry';
import { getDefaultShell, getDefaultShellArgs } from './shell/util';
import { pathAccessible } from './utils';
import { compareVersions, pathAccessible } from './utils';

export type ProcessCallbacks = {
onStdout?: (data: string) => void;
@@ -43,6 +51,11 @@ interface PipInstallConfig {
indexStrategy?: 'compatible' | 'unsafe-best-match';
}

type TorchPackageName = 'torch' | 'torchaudio' | 'torchvision';
type TorchPackageVersions = Record<TorchPackageName, string | undefined>;

const TORCH_PACKAGE_NAMES: TorchPackageName[] = ['torch', 'torchaudio', 'torchvision'];

Comment on lines +54 to +58
🧹 Nitpick | 🔵 Trivial

Consider making TorchPackageVersions readonly for immutability.

The types and constants are well-defined. However, to align with the coding guideline to "prefer readonly properties and ReadonlyArray<T> where appropriate," consider making the Record values readonly.

♻️ Proposed refinement
-type TorchPackageVersions = Record<TorchPackageName, string | undefined>;
+type TorchPackageVersions = Readonly<Record<TorchPackageName, string | undefined>>;
📝 Committable suggestion


Suggested change
type TorchPackageName = 'torch' | 'torchaudio' | 'torchvision';
type TorchPackageVersions = Record<TorchPackageName, string | undefined>;
const TORCH_PACKAGE_NAMES: TorchPackageName[] = ['torch', 'torchaudio', 'torchvision'];
type TorchPackageName = 'torch' | 'torchaudio' | 'torchvision';
type TorchPackageVersions = Readonly<Record<TorchPackageName, string | undefined>>;
const TORCH_PACKAGE_NAMES: TorchPackageName[] = ['torch', 'torchaudio', 'torchvision'];
🤖 Prompt for AI Agents
In @src/virtualEnvironment.ts around lines 54 - 58, The TorchPackageVersions
type should be made immutable: change type TorchPackageVersions =
Record<TorchPackageName, string | undefined> to a readonly form (e.g.,
Readonly<Record<TorchPackageName, string | undefined>> or { readonly [K in
TorchPackageName]: string | undefined }) and also mark the package names array
as readonly by changing TORCH_PACKAGE_NAMES: TorchPackageName[] = [...] to a
ReadonlyArray<TorchPackageName> (or use as const) so both the mapping type and
the array are immutable.

function getPipInstallArgs(config: PipInstallConfig): string[] {
  const installArgs = ['pip', 'install'];

@@ -629,6 +642,118 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor {
}
}

  /**
   * Ensures NVIDIA installs use the recommended PyTorch packages.
   * @param callbacks The callbacks to use for the command.
   */
  async ensureRecommendedNvidiaTorch(callbacks?: ProcessCallbacks): Promise<void> {
    if (this.selectedDevice !== 'nvidia') return;

    const installedVersions = await this.getInstalledTorchPackageVersions();
    if (installedVersions && this.meetsMinimumNvidiaTorchVersions(installedVersions)) {
      log.info('NVIDIA PyTorch packages already satisfy minimum recommended versions.', installedVersions);
      return;
    }

    const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice);
    const config: PipInstallConfig = {
      packages: NVIDIA_TORCH_PACKAGES,
      indexUrl: torchMirror,
      prerelease: torchMirror.includes('nightly'),
    };

    const installArgs = getPipInstallArgs(config);
    log.info('Installing recommended NVIDIA PyTorch packages.', { installedVersions });
    const { exitCode: pinnedExitCode } = await this.runUvCommandAsync(installArgs, callbacks);

    if (pinnedExitCode === 0) return;

    log.warn('Failed to install recommended NVIDIA PyTorch packages. Falling back to unpinned install.', {
      exitCode: pinnedExitCode,
    });

    const fallbackConfig: PipInstallConfig = {
      packages: ['torch', 'torchvision', 'torchaudio'],
      indexUrl: torchMirror,
      prerelease: torchMirror.includes('nightly'),
      upgradePackages: true,
    };
    const fallbackArgs = getPipInstallArgs(fallbackConfig);
    const { exitCode: fallbackExitCode } = await this.runUvCommandAsync(fallbackArgs, callbacks);
    if (fallbackExitCode !== 0) {
      throw new Error(
        `Failed to install NVIDIA PyTorch packages (pinned exit ${pinnedExitCode}, fallback exit ${fallbackExitCode})`
      );
    }
Copilot AI Jan 8, 2026
The method ensureRecommendedNvidiaTorch doesn't verify that the installation succeeded by checking the torch version after installation. Consider adding a verification step after the installation completes to confirm that getTorchVersion() returns NVIDIA_TORCH_VERSION. This would help catch installation failures where the exit code is 0 but the actual version installed doesn't match expectations.

Suggested change
}
}
const verifiedVersion = await this.getTorchVersion();
if (verifiedVersion !== NVIDIA_TORCH_VERSION) {
log.error('Recommended NVIDIA PyTorch version verification failed.', {
expectedVersion: NVIDIA_TORCH_VERSION,
actualVersion: verifiedVersion,
});
throw new Error(
`Failed to verify recommended NVIDIA PyTorch version. Expected ${NVIDIA_TORCH_VERSION}, got ${verifiedVersion ?? 'unknown'}.`
);
}
log.info('Successfully installed recommended NVIDIA PyTorch version.', {
version: verifiedVersion,
});

}
Comment on lines +645 to +688
🛠️ Refactor suggestion | 🟠 Major

Add telemetry tracking for NVIDIA torch upgrade flow.

This new public method lacks telemetry events to track success/failure of NVIDIA torch upgrades. Given that other installation methods use @trackEvent decorators or explicit telemetry calls, this should follow the same pattern.

📊 Proposed telemetry additions
 async ensureRecommendedNvidiaTorch(callbacks?: ProcessCallbacks): Promise<void> {
   if (this.selectedDevice !== 'nvidia') return;

+  this.telemetry.track('install_flow:nvidia_torch_upgrade_start');
+
   const installedVersions = await this.getInstalledTorchPackageVersions();
   if (installedVersions && this.meetsMinimumNvidiaTorchVersions(installedVersions)) {
     log.info('NVIDIA PyTorch packages already satisfy minimum recommended versions.', installedVersions);
+    this.telemetry.track('install_flow:nvidia_torch_upgrade_end', { reason: 'already_satisfied' });
     return;
   }

   const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice);
   const config: PipInstallConfig = {
     packages: NVIDIA_TORCH_PACKAGES,
     indexUrl: torchMirror,
     prerelease: torchMirror.includes('nightly'),
   };

   const installArgs = getPipInstallArgs(config);
   log.info('Installing recommended NVIDIA PyTorch packages.', { installedVersions });
   const { exitCode: pinnedExitCode } = await this.runUvCommandAsync(installArgs, callbacks);

-  if (pinnedExitCode === 0) return;
+  if (pinnedExitCode === 0) {
+    this.telemetry.track('install_flow:nvidia_torch_upgrade_end', { reason: 'pinned_install_success' });
+    return;
+  }

   log.warn('Failed to install recommended NVIDIA PyTorch packages. Falling back to unpinned install.', {
     exitCode: pinnedExitCode,
   });

   const fallbackConfig: PipInstallConfig = {
     packages: ['torch', 'torchvision', 'torchaudio'],
     indexUrl: torchMirror,
     prerelease: torchMirror.includes('nightly'),
     upgradePackages: true,
   };
   const fallbackArgs = getPipInstallArgs(fallbackConfig);
   const { exitCode: fallbackExitCode } = await this.runUvCommandAsync(fallbackArgs, callbacks);
   if (fallbackExitCode !== 0) {
+    this.telemetry.track('install_flow:nvidia_torch_upgrade_error', {
+      pinnedExitCode,
+      fallbackExitCode,
+    });
     throw new Error(
       `Failed to install NVIDIA PyTorch packages (pinned exit ${pinnedExitCode}, fallback exit ${fallbackExitCode})`
     );
   }
+  this.telemetry.track('install_flow:nvidia_torch_upgrade_end', { reason: 'fallback_success' });
 }
📝 Committable suggestion


Suggested change
/**
* Ensures NVIDIA installs use the recommended PyTorch packages.
* @param callbacks The callbacks to use for the command.
*/
async ensureRecommendedNvidiaTorch(callbacks?: ProcessCallbacks): Promise<void> {
if (this.selectedDevice !== 'nvidia') return;
const installedVersions = await this.getInstalledTorchPackageVersions();
if (installedVersions && this.meetsMinimumNvidiaTorchVersions(installedVersions)) {
log.info('NVIDIA PyTorch packages already satisfy minimum recommended versions.', installedVersions);
return;
}
const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice);
const config: PipInstallConfig = {
packages: NVIDIA_TORCH_PACKAGES,
indexUrl: torchMirror,
prerelease: torchMirror.includes('nightly'),
};
const installArgs = getPipInstallArgs(config);
log.info('Installing recommended NVIDIA PyTorch packages.', { installedVersions });
const { exitCode: pinnedExitCode } = await this.runUvCommandAsync(installArgs, callbacks);
if (pinnedExitCode === 0) return;
log.warn('Failed to install recommended NVIDIA PyTorch packages. Falling back to unpinned install.', {
exitCode: pinnedExitCode,
});
const fallbackConfig: PipInstallConfig = {
packages: ['torch', 'torchvision', 'torchaudio'],
indexUrl: torchMirror,
prerelease: torchMirror.includes('nightly'),
upgradePackages: true,
};
const fallbackArgs = getPipInstallArgs(fallbackConfig);
const { exitCode: fallbackExitCode } = await this.runUvCommandAsync(fallbackArgs, callbacks);
if (fallbackExitCode !== 0) {
throw new Error(
`Failed to install NVIDIA PyTorch packages (pinned exit ${pinnedExitCode}, fallback exit ${fallbackExitCode})`
);
}
}
/**
* Ensures NVIDIA installs use the recommended PyTorch packages.
* @param callbacks The callbacks to use for the command.
*/
async ensureRecommendedNvidiaTorch(callbacks?: ProcessCallbacks): Promise<void> {
if (this.selectedDevice !== 'nvidia') return;
this.telemetry.track('install_flow:nvidia_torch_upgrade_start');
const installedVersions = await this.getInstalledTorchPackageVersions();
if (installedVersions && this.meetsMinimumNvidiaTorchVersions(installedVersions)) {
log.info('NVIDIA PyTorch packages already satisfy minimum recommended versions.', installedVersions);
this.telemetry.track('install_flow:nvidia_torch_upgrade_end', { reason: 'already_satisfied' });
return;
}
const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice);
const config: PipInstallConfig = {
packages: NVIDIA_TORCH_PACKAGES,
indexUrl: torchMirror,
prerelease: torchMirror.includes('nightly'),
};
const installArgs = getPipInstallArgs(config);
log.info('Installing recommended NVIDIA PyTorch packages.', { installedVersions });
const { exitCode: pinnedExitCode } = await this.runUvCommandAsync(installArgs, callbacks);
if (pinnedExitCode === 0) {
this.telemetry.track('install_flow:nvidia_torch_upgrade_end', { reason: 'pinned_install_success' });
return;
}
log.warn('Failed to install recommended NVIDIA PyTorch packages. Falling back to unpinned install.', {
exitCode: pinnedExitCode,
});
const fallbackConfig: PipInstallConfig = {
packages: ['torch', 'torchvision', 'torchaudio'],
indexUrl: torchMirror,
prerelease: torchMirror.includes('nightly'),
upgradePackages: true,
};
const fallbackArgs = getPipInstallArgs(fallbackConfig);
const { exitCode: fallbackExitCode } = await this.runUvCommandAsync(fallbackArgs, callbacks);
if (fallbackExitCode !== 0) {
this.telemetry.track('install_flow:nvidia_torch_upgrade_error', {
pinnedExitCode,
fallbackExitCode,
});
throw new Error(
`Failed to install NVIDIA PyTorch packages (pinned exit ${pinnedExitCode}, fallback exit ${fallbackExitCode})`
);
}
this.telemetry.track('install_flow:nvidia_torch_upgrade_end', { reason: 'fallback_success' });
}
🤖 Prompt for AI Agents
In @src/virtualEnvironment.ts around lines 645 - 688,
ensureRecommendedNvidiaTorch is missing telemetry around the NVIDIA torch
upgrade flow; add telemetry events to record attempts, successes and failures
for both the pinned install and the fallback install. Specifically, fire a
"nvidia_torch_install_started" event at the start of the method (before
runUvCommandAsync), record a "nvidia_torch_install_succeeded" event when
pinnedExitCode === 0, and if pinned fails record a
"nvidia_torch_install_pinned_failed" with pinnedExitCode before attempting
fallback; after the fallback runUvCommandAsync emit
"nvidia_torch_install_fallback_succeeded" on fallbackExitCode === 0 or
"nvidia_torch_install_fallback_failed" with both exit codes if it ultimately
throws. Use the same telemetry helper/decorator pattern used elsewhere in the
repo (the @trackEvent wrapper or telemetry client) and include contextual
properties like installedVersions, torchMirror, pinnedExitCode and
fallbackExitCode; locate changes around ensureRecommendedNvidiaTorch,
runUvCommandAsync calls, and the NVIDIA_TORCH_PACKAGES/getPipInstallArgs usage.


  /**
   * Reads installed torch package versions using `uv pip list --format=json`.
   * @returns The torch package versions when available, otherwise `undefined`.
   */
  private async getInstalledTorchPackageVersions(): Promise<TorchPackageVersions | undefined> {
    let stdout = '';
    let stderr = '';
    const callbacks: ProcessCallbacks = {
      onStdout: (data) => {
        stdout += data;
      },
      onStderr: (data) => {
        stderr += data;
      },
    };

    const { exitCode } = await this.runUvAsync(['pip', 'list', '--format=json', '--color=never'], callbacks);

    if (exitCode !== 0) {
      log.warn('Failed to read torch package versions.', { exitCode, stderr });
      return undefined;
    }

    if (!stdout.trim()) {
      log.warn('Torch package list output was empty.', { stderr });
      return undefined;
    }

    let parsed: unknown;
    try {
      parsed = JSON.parse(stdout);
    } catch (error) {
      log.warn('Failed to parse torch package list output.', { error, stdout, stderr });
      return undefined;
    }

    if (!Array.isArray(parsed)) {
      log.warn('Torch package list output was not an array.', { stdout, stderr });
      return undefined;
    }

    const versions: TorchPackageVersions = {
      torch: undefined,
      torchaudio: undefined,
      torchvision: undefined,
    };
    let matched = 0;
    const isRecord = (value: unknown): value is Record<string, unknown> => typeof value === 'object' && value !== null;

    for (const entry of parsed) {
      if (!isRecord(entry)) continue;
      const name = entry.name;
      const version = entry.version;
      if (typeof name !== 'string' || typeof version !== 'string') continue;
      const packageName = TORCH_PACKAGE_NAMES.find((pkg) => pkg === name.trim().toLowerCase());
      if (!packageName) continue;
      matched += 1;
      versions[packageName] = version.trim() || undefined;
    }
    if (matched === 0) {
      log.warn('Torch package list did not contain expected packages.', { stdout, stderr });
      return undefined;
    }
Comment on lines +749 to +752
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's look at the getInstalledTorchPackageVersions method implementation
sed -n '690,755p' src/virtualEnvironment.ts



🏁 Script executed:

#!/bin/bash
# Check the type definition for TorchPackageVersions
rg -B5 -A10 'type TorchPackageVersions' src/virtualEnvironment.ts



🏁 Script executed:

#!/bin/bash
# Find all usages of getInstalledTorchPackageVersions
rg -n 'getInstalledTorchPackageVersions' --type=ts



🏁 Script executed:

#!/bin/bash
# Look for the meetsMinimumNvidiaTorchVersions function that likely uses the result
rg -B5 -A20 'meetsMinimumNvidiaTorchVersions' src/virtualEnvironment.ts



Consider returning undefined for all partial torch package installations, not just when no packages are found.

Currently, the method returns undefined only if matched === 0, but allows partial installations (1 or 2 packages found) to return an object with some undefined fields. The caller meetsMinimumNvidiaTorchVersions() then validates that all three packages (torch, torchaudio, torchvision) are present, returning false if any are missing. This pushes the partial-installation check to the caller rather than handling it at the source.

For clearer intent and consistency, consider returning undefined whenever fewer than 3 packages are found, since callers expect either undefined (packages not properly installed) or a complete set of all three versions.

🤖 Prompt for AI Agents
In @src/virtualEnvironment.ts around lines 749 - 752, The block that currently
only returns undefined when matched === 0 should instead treat any partial
installation as invalid: change the conditional to return undefined when matched
< 3 (i.e., fewer than the three packages torch, torchaudio, torchvision were
found) and update the log.warn message to reflect a partial or missing torch
package installation; this ensures callers like
meetsMinimumNvidiaTorchVersions() receive either a complete set of versions or
undefined.


    return versions;
  }
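As a standalone sketch, the parsing logic above can be reproduced against a sample of the JSON that `uv pip list --format=json` emits. The helper name, the sample package list, and the use of a `Partial` record are assumptions for illustration only:

```typescript
// Hypothetical, trimmed-down version of the parsing loop above.
type TorchPackageName = 'torch' | 'torchaudio' | 'torchvision';

const TORCH_PACKAGE_NAMES: readonly TorchPackageName[] = ['torch', 'torchaudio', 'torchvision'];

function collectTorchVersions(stdout: string): Partial<Record<TorchPackageName, string>> | undefined {
  let parsed: unknown;
  try {
    parsed = JSON.parse(stdout);
  } catch {
    return undefined;
  }
  if (!Array.isArray(parsed)) return undefined;

  const versions: Partial<Record<TorchPackageName, string>> = {};
  for (const entry of parsed) {
    if (typeof entry !== 'object' || entry === null) continue;
    const { name, version } = entry as Record<string, unknown>;
    if (typeof name !== 'string' || typeof version !== 'string') continue;
    const pkg = TORCH_PACKAGE_NAMES.find((p) => p === name.trim().toLowerCase());
    if (pkg) versions[pkg] = version.trim();
  }
  // Returning undefined when nothing matched mirrors the `matched === 0` guard above.
  return Object.keys(versions).length === 0 ? undefined : versions;
}

// Fabricated sample of `uv pip list --format=json` output.
const sampleStdout = JSON.stringify([
  { name: 'torch', version: '2.9.1+cu130' },
  { name: 'torchvision', version: '0.24.1+cu130' },
  { name: 'numpy', version: '2.1.0' },
]);
const found = collectTorchVersions(sampleStdout);
```

Note the sketch still returns a partial result (here `torchaudio` is missing); per the review suggestion above, rejecting any set with fewer than all three packages at the source would simplify callers.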

/**
* Installs AMD ROCm SDK packages on Windows.
* @param callbacks The callbacks to use for the command.
@@ -838,10 +963,52 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor {
      return 'package-upgrade';
    }

    if (await this.needsNvidiaTorchUpgrade()) {
      log.info('NVIDIA PyTorch version out of date. Treating as package upgrade.');
      return 'package-upgrade';
    }

    log.debug('hasRequirements result:', 'OK');
    return 'OK';
  }

  /**
   * Returns `true` when NVIDIA PyTorch should be upgraded to the recommended version.
   * @returns `true` when NVIDIA PyTorch is out of date, otherwise `false`.
   */
  private async needsNvidiaTorchUpgrade(): Promise<boolean> {
    if (this.selectedDevice !== 'nvidia') return false;

    const installedVersions = await this.getInstalledTorchPackageVersions();
    if (!installedVersions) {
      log.warn('Unable to read NVIDIA torch package versions. Skipping NVIDIA torch upgrade check.');
      return false;
    }

    return !this.meetsMinimumNvidiaTorchVersions(installedVersions);
  }

  private meetsMinimumNvidiaTorchVersions(installedVersions: TorchPackageVersions): boolean {
    const torch = installedVersions.torch;
    const torchaudio = installedVersions.torchaudio;
    const torchvision = installedVersions.torchvision;
    if (!torch || !torchaudio || !torchvision) return false;

    const requiredCudaTag = '+cu130';
🧹 Nitpick | 🔵 Trivial

Extract hardcoded CUDA tag to a constant.

The CUDA tag '+cu130' is hardcoded here but appears in multiple constants (NVIDIA_TORCH_VERSION, NVIDIA_TORCHVISION_VERSION). Extract it to a named constant for maintainability.

♻️ Proposed refactor

In src/constants.ts, add:

export const REQUIRED_CUDA_TAG = '+cu130';
export const NVIDIA_TORCH_VERSION = `2.9.1${REQUIRED_CUDA_TAG}`;
export const NVIDIA_TORCHVISION_VERSION = `0.24.1${REQUIRED_CUDA_TAG}`;

Then in this file:

+import {
+  AMD_ROCM_SDK_PACKAGES,
+  AMD_TORCH_PACKAGES,
+  InstallStage,
+  NVIDIA_TORCHVISION_VERSION,
+  NVIDIA_TORCH_PACKAGES,
+  NVIDIA_TORCH_VERSION,
+  REQUIRED_CUDA_TAG,
+  TorchMirrorUrl,
+} from './constants';

...

-    const requiredCudaTag = '+cu130';
+    const requiredCudaTag = REQUIRED_CUDA_TAG;

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In @src/virtualEnvironment.ts at line 997, Extract the hardcoded CUDA tag
'+cu130' into a shared constant and use it everywhere: add REQUIRED_CUDA_TAG =
'+cu130' to your constants module (e.g., export const REQUIRED_CUDA_TAG) and
update NVIDIA_TORCH_VERSION and NVIDIA_TORCHVISION_VERSION to interpolate that
constant; then remove the local const requiredCudaTag in virtualEnvironment.ts
and import and use REQUIRED_CUDA_TAG instead (update any other occurrences that
reference requiredCudaTag to use the new constant).

    if (
      !torch.includes(requiredCudaTag) ||
      !torchaudio.includes(requiredCudaTag) ||
      !torchvision.includes(requiredCudaTag)
    )
      return false;

    if (compareVersions(torch, NVIDIA_TORCH_VERSION) < 0) return false;
    if (compareVersions(torchaudio, NVIDIA_TORCH_VERSION) < 0) return false;
    if (compareVersions(torchvision, NVIDIA_TORCHVISION_VERSION) < 0) return false;

    return true;
  }
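The version gate above can be sketched in isolation. The `cmp` helper here is an assumption standing in for the real `compareVersions` in `src/utils` — it simply ignores local tags like `+cu130` when comparing numerically, which the real implementation may or may not do:

```typescript
// Hypothetical stand-in for compareVersions; strips the '+local' tag before comparing.
function cmp(a: string, b: string): number {
  const nums = (v: string) => v.split('+')[0].split('.').map(Number);
  const [pa, pb] = [nums(a), nums(b)];
  for (let i = 0; i < Math.max(pa.length, pb.length); i++) {
    const d = (pa[i] ?? 0) - (pb[i] ?? 0);
    if (d !== 0) return Math.sign(d);
  }
  return 0;
}

// Mirrors the two-stage check above: correct CUDA tag first, then minimum versions.
function meetsMinimumSketch(torch: string, torchaudio: string, torchvision: string): boolean {
  const tag = '+cu130';
  if (![torch, torchaudio, torchvision].every((v) => v.includes(tag))) return false;
  if (cmp(torch, '2.9.1+cu130') < 0) return false;
  if (cmp(torchaudio, '2.9.1+cu130') < 0) return false;
  return cmp(torchvision, '0.24.1+cu130') >= 0;
}

const upToDate = meetsMinimumSketch('2.9.1+cu130', '2.9.1+cu130', '0.24.1+cu130');
const wrongCuda = meetsMinimumSketch('2.9.1+cu128', '2.9.1+cu128', '0.24.1+cu128');
const tooOld = meetsMinimumSketch('2.5.1+cu130', '2.5.1+cu130', '0.20.1+cu130');
```

A cu128 build fails the tag check even at a newer numeric version, while an older cu130 build fails the version comparison — matching the two distinct `return false` paths in the method.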

  /**
   * Verifies that the Python environment can import modules that frequently show up in errors.
   * @returns `true` if the Python environment successfully imports the modules, otherwise `false`.