Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve podman cli execution for models #1810

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import os from 'node:os';
import fs, { type Stats, type PathLike } from 'node:fs';
import path from 'node:path';
import { ModelsManager } from './modelsManager';
import { env, process as coreProcess } from '@podman-desktop/api';
import { env } from '@podman-desktop/api';
import type { RunResult, TelemetryLogger, Webview, ContainerProviderConnection } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
Expand All @@ -33,7 +33,6 @@ import type { GGUFParseOutput } from '@huggingface/gguf';
import { gguf } from '@huggingface/gguf';
import type { PodmanConnection } from './podmanConnection';
import { VMType } from '@shared/src/models/IPodman';
import { getPodmanMachineName } from '../utils/podman';
import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry';
import { Uploader } from '../utils/uploader';

Expand All @@ -47,7 +46,6 @@ const mocks = vi.hoisted(() => {
getTargetMock: vi.fn(),
getDownloaderCompleter: vi.fn(),
isCompletionEventMock: vi.fn(),
getPodmanCliMock: vi.fn(),
};
});

Expand All @@ -59,11 +57,6 @@ vi.mock('@huggingface/gguf', () => ({
gguf: vi.fn(),
}));

vi.mock('../utils/podman', () => ({
getPodmanCli: mocks.getPodmanCliMock,
getPodmanMachineName: vi.fn(),
}));

vi.mock('@podman-desktop/api', () => {
return {
Disposable: {
Expand All @@ -72,9 +65,6 @@ vi.mock('@podman-desktop/api', () => {
env: {
isWindows: false,
},
process: {
exec: vi.fn(),
},
fs: {
createFileSystemWatcher: (): unknown => ({
onDidCreate: vi.fn(),
Expand Down Expand Up @@ -102,6 +92,7 @@ vi.mock('../utils/downloader', () => ({

const podmanConnectionMock = {
getContainerProviderConnections: vi.fn(),
executeSSH: vi.fn(),
} as unknown as PodmanConnection;

const cancellationTokenRegistryMock = {
Expand Down Expand Up @@ -598,8 +589,7 @@ describe('deleting models', () => {
});

test('deleting on windows should check for all connections', async () => {
vi.mocked(coreProcess.exec).mockResolvedValue({} as RunResult);
mocks.getPodmanCliMock.mockReturnValue('dummyCli');
vi.mocked(podmanConnectionMock.executeSSH).mockResolvedValue({} as RunResult);
vi.mocked(env).isWindows = true;
const connections: ContainerProviderConnection[] = [
{
Expand All @@ -622,7 +612,6 @@ describe('deleting models', () => {
},
];
vi.mocked(podmanConnectionMock.getContainerProviderConnections).mockReturnValue(connections);
vi.mocked(getPodmanMachineName).mockReturnValue('machine-2');

const rmSpy = vi.spyOn(fs.promises, 'rm');
rmSpy.mockResolvedValue(undefined);
Expand Down Expand Up @@ -659,10 +648,7 @@ describe('deleting models', () => {

expect(podmanConnectionMock.getContainerProviderConnections).toHaveBeenCalledOnce();

expect(coreProcess.exec).toHaveBeenCalledWith('dummyCli', [
'machine',
'ssh',
'machine-2',
expect(podmanConnectionMock.executeSSH).toHaveBeenCalledWith(connections[1], [
'rm',
'-f',
'/home/user/ai-lab/models/dummyFile',
Expand Down
32 changes: 23 additions & 9 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ import type { Task } from '@shared/src/models/ITask';
import type { BaseEvent } from '../models/baseEvent';
import { isCompletionEvent, isProgressEvent } from '../models/baseEvent';
import { Uploader } from '../utils/uploader';
import { deleteRemoteModel, getLocalModelFile, isModelUploaded } from '../utils/modelsUtils';
import { getPodmanMachineName } from '../utils/podman';
import { getLocalModelFile, getRemoteModelFile } from '../utils/modelsUtils';
import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry';
import { getHash, hasValidSha } from '../utils/sha';
import type { GGUFParseOutput } from '@huggingface/gguf';
Expand Down Expand Up @@ -231,17 +230,32 @@ export class ModelsManager implements Disposable {
for (const connection of connections) {
// ignore non-wsl machines
if (connection.vmType !== VMType.WSL) continue;
// Get the corresponding machine name
const machineName = getPodmanMachineName(connection);

// check if model already loaded on the podman machine
const existsRemote = await isModelUploaded(machineName, modelInfo);
if (!existsRemote) return;
// check if remote model is
try {
await this.podmanConnection.executeSSH(connection, ['stat', getRemoteModelFile(modelInfo)]);
} catch (err: unknown) {
console.warn(err);
continue;
}

await deleteRemoteModel(machineName, modelInfo);
await this.deleteRemoteModelByConnection(connection, modelInfo);
}
}

/**
* Delete a model given a {@link ContainerProviderConnection}
* @param connection
* @param modelInfo
* @protected
*/
protected async deleteRemoteModelByConnection(
connection: ContainerProviderConnection,
modelInfo: ModelInfo,
): Promise<void> {
await this.podmanConnection.executeSSH(connection, ['rm', '-f', getRemoteModelFile(modelInfo)]);
}

/**
* This method will resolve when the provided model will be downloaded.
*
Expand Down Expand Up @@ -439,7 +453,7 @@ export class ModelsManager implements Disposable {
connection: connection.name,
});

const uploader = new Uploader(connection, model);
const uploader = new Uploader(this.podmanConnection, connection, model);
uploader.onEvent(event => this.onDownloadUploadEvent(event, 'upload'), this);

// perform download
Expand Down
61 changes: 1 addition & 60 deletions packages/backend/src/utils/modelsUtils.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,8 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import { beforeEach, describe, expect, test, vi } from 'vitest';
import { process as apiProcess } from '@podman-desktop/api';
import {
deleteRemoteModel,
getLocalModelFile,
getRemoteModelFile,
isModelUploaded,
MACHINE_BASE_FOLDER,
} from './modelsUtils';
import { getLocalModelFile, getRemoteModelFile, MACHINE_BASE_FOLDER } from './modelsUtils';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { getPodmanCli } from './podman';

vi.mock('@podman-desktop/api', () => {
return {
Expand All @@ -35,14 +27,8 @@ vi.mock('@podman-desktop/api', () => {
};
});

vi.mock('./podman', () => ({
getPodmanCli: vi.fn(),
}));

beforeEach(() => {
vi.resetAllMocks();

vi.mocked(getPodmanCli).mockReturnValue('dummyPodmanCli');
});

describe('getLocalModelFile', () => {
Expand Down Expand Up @@ -94,48 +80,3 @@ describe('getRemoteModelFile', () => {
expect(path).toBe(`${MACHINE_BASE_FOLDER}dummy.guff`);
});
});

describe('isModelUploaded', () => {
test('execute stat on targeted machine', async () => {
expect(
await isModelUploaded('dummyMachine', {
id: 'dummyModelId',
file: {
path: 'dummyPath',
file: 'dummy.guff',
},
} as unknown as ModelInfo),
).toBeTruthy();

expect(getPodmanCli).toHaveBeenCalled();
expect(apiProcess.exec).toHaveBeenCalledWith('dummyPodmanCli', [
'machine',
'ssh',
'dummyMachine',
'stat',
expect.anything(),
]);
});
});

describe('deleteRemoteModel', () => {
test('execute stat on targeted machine', async () => {
await deleteRemoteModel('dummyMachine', {
id: 'dummyModelId',
file: {
path: 'dummyPath',
file: 'dummy.guff',
},
} as unknown as ModelInfo);

expect(getPodmanCli).toHaveBeenCalled();
expect(apiProcess.exec).toHaveBeenCalledWith('dummyPodmanCli', [
'machine',
'ssh',
'dummyMachine',
'rm',
'-f',
expect.anything(),
]);
});
});
32 changes: 0 additions & 32 deletions packages/backend/src/utils/modelsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
***********************************************************************/
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { join, posix } from 'node:path';
import { getPodmanCli } from './podman';
import { process } from '@podman-desktop/api';

export const MACHINE_BASE_FOLDER = '/home/user/ai-lab/models/';

Expand All @@ -42,36 +40,6 @@ export function getRemoteModelFile(modelInfo: ModelInfo): string {
return posix.join(MACHINE_BASE_FOLDER, modelInfo.file.file);
}

/**
* utility method to determine if a model is already uploaded to the podman machine
* @param machine
* @param modelInfo
*/
export async function isModelUploaded(machine: string, modelInfo: ModelInfo): Promise<boolean> {
try {
const remotePath = getRemoteModelFile(modelInfo);
await process.exec(getPodmanCli(), ['machine', 'ssh', machine, 'stat', remotePath]);
return true;
} catch (err: unknown) {
console.error('Something went wrong while trying to stat remote model path', err);
return false;
}
}

/**
* Given a machine and a modelInfo, delete the corresponding file on the podman machine
* @param machine the machine to target
* @param modelInfo the model info
*/
export async function deleteRemoteModel(machine: string, modelInfo: ModelInfo): Promise<void> {
try {
const remotePath = getRemoteModelFile(modelInfo);
await process.exec(getPodmanCli(), ['machine', 'ssh', machine, 'rm', '-f', remotePath]);
} catch (err: unknown) {
console.error('Something went wrong while trying to stat remote model path', err);
}
}

export function getModelPropertiesForEnvironment(modelInfo: ModelInfo): string[] {
const envs: string[] = [];
if (modelInfo.properties) {
Expand Down
8 changes: 7 additions & 1 deletion packages/backend/src/utils/podman.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ export type MachineJSON = {
VMType?: string;
};

/**
* We should be using the {@link @podman-desktop/api#extensions.getExtensions} function to get podman
* exec method
* @deprecated
*/
export function getPodmanCli(): string {
// If we have a custom binary path regardless if we are running Windows or not
const customBinaryPath = getCustomBinaryPath();
Expand All @@ -52,8 +57,9 @@ export function getCustomBinaryPath(): string | undefined {
}

/**
* In the ${link ContainerProviderConnection.name} property the name is not usage, and we need to transform it
* In the {@link ContainerProviderConnection.name} property the name is not usage, and we need to transform it
* @param connection
* @deprecated
*/
export function getPodmanMachineName(connection: ContainerProviderConnection): string {
const runningConnectionName = connection.name;
Expand Down
10 changes: 6 additions & 4 deletions packages/backend/src/utils/uploader.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@ import { Uploader } from './uploader';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { ContainerProviderConnection } from '@podman-desktop/api';
import { VMType } from '@shared/src/models/IPodman';
import type { PodmanConnection } from '../managers/podmanConnection';

vi.mock('@podman-desktop/api', async () => {
return {
env: {
isWindows: false,
},
process: {
exec: vi.fn(),
},
EventEmitter: vi.fn().mockImplementation(() => {
return {
fire: vi.fn(),
Expand All @@ -41,6 +39,10 @@ vi.mock('@podman-desktop/api', async () => {
};
});

const podmanConnectionMock: PodmanConnection = {
executeSSH: vi.fn(),
} as unknown as PodmanConnection;

const connectionMock: ContainerProviderConnection = {
name: 'machine2',
type: 'podman',
Expand All @@ -51,7 +53,7 @@ const connectionMock: ContainerProviderConnection = {
},
};

const uploader = new Uploader(connectionMock, {
const uploader = new Uploader(podmanConnectionMock, connectionMock, {
id: 'dummyModelId',
file: {
file: 'dummyFile.guff',
Expand Down
4 changes: 3 additions & 1 deletion packages/backend/src/utils/uploader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,20 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { getLocalModelFile } from './modelsUtils';
import type { IWorker } from '../workers/IWorker';
import type { UploaderOptions } from '../workers/uploader/UploaderOptions';
import type { PodmanConnection } from '../managers/podmanConnection';

export class Uploader {
readonly #_onEvent = new EventEmitter<BaseEvent>();
readonly onEvent: Event<BaseEvent> = this.#_onEvent.event;
readonly #workers: IWorker<UploaderOptions, string>[] = [];

constructor(
podman: PodmanConnection,
private connection: ContainerProviderConnection,
private modelInfo: ModelInfo,
private abortSignal?: AbortSignal,
) {
this.#workers = [new WSLUploader()];
this.#workers = [new WSLUploader(podman)];
}

/**
Expand Down
Loading