Merged
@@ -32,6 +32,7 @@ import {
} from '@shared/models/llama-stack/LlamaStackContainerInfo';
import type { ConfigurationRegistry } from '../../registries/ConfigurationRegistry';
import type { ExtensionConfiguration } from '@shared/models/IExtensionConfiguration';
import type { ModelsManager } from '../modelsManager';

vi.mock('@podman-desktop/api', () => {
return {
@@ -58,11 +59,13 @@ class TestLlamaStackManager extends LlamaStackManager {
const podmanConnection: PodmanConnection = {
onPodmanConnectionEvent: vi.fn(),
findRunningContainerProviderConnection: vi.fn(),
execute: vi.fn(),
} as unknown as PodmanConnection;

const containerRegistry = {
onStartContainerEvent: vi.fn(),
onStopContainerEvent: vi.fn(),
onHealthyContainerEvent: vi.fn(),
} as unknown as ContainerRegistry;

const configurationRegistry = {
@@ -74,6 +77,10 @@ const telemetryMock = {
logError: vi.fn(),
} as unknown as TelemetryLogger;

const modelsManagerMock = {
getModelsInfo: vi.fn(),
} as unknown as ModelsManager;

let taskRegistry: TaskRegistry;

let llamaStackManager: TestLlamaStackManager;
@@ -108,6 +115,7 @@ beforeEach(() => {
containerRegistry,
configurationRegistry,
telemetryMock,
modelsManagerMock,
);
});

@@ -181,7 +189,7 @@ test('requestCreateLlamaStackContainer returns id and error if listImage returns
expect(tasks.some(task => task.state === 'error')).toBeTruthy();
});

test('requestCreateLlamaStackContainer returns id and no error if createContainer returns id', async () => {
test('requestCreateLlamaStackContainer returns no error if createContainer returns id and container becomes healthy', async () => {
vi.mocked(podmanConnection.findRunningContainerProviderConnection).mockReturnValue({
name: 'Podman Machine',
vmType: VMType.UNKNOWN,
@@ -200,9 +208,83 @@ test('requestCreateLlamaStackContainer returns id and no error if createContaine
vi.mocked(configurationRegistry.getExtensionConfiguration).mockReturnValue({
apiPort: 10000,
} as ExtensionConfiguration);
vi.mocked(containerRegistry.onHealthyContainerEvent).mockReturnValue(NO_OP_DISPOSABLE);
await llamaStackManager.requestCreateLlamaStackContainer({});
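// simulate the health check passing: grab the listener registered through onHealthyContainerEvent and fire it for our container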
const tasks = await waitTasks(LLAMA_STACK_CONTAINER_TRACKINGID, 3);
await vi.waitFor(() => {
const healthyListener = vi.mocked(containerRegistry.onHealthyContainerEvent).mock.calls[0][0];
expect(healthyListener).toBeDefined();
healthyListener({ id: 'containerId' });
});
const tasks = await waitTasks(LLAMA_STACK_CONTAINER_TRACKINGID, 4);
expect(tasks.some(task => task.state === 'error')).toBeFalsy();
});

test('requestCreateLlamaStackContainer registers all local models', async () => {
vi.mocked(podmanConnection.findRunningContainerProviderConnection).mockReturnValue({
name: 'Podman Machine',
vmType: VMType.UNKNOWN,
type: 'podman',
status: () => 'started',
endpoint: {
socketPath: 'socket.sock',
},
});
vi.mocked(containerEngine.listImages).mockResolvedValue([
{ RepoTags: [llama_stack_images.default] } as unknown as ImageInfo,
]);
vi.mocked(containerEngine.createContainer).mockResolvedValue({
id: 'containerId',
} as unknown as ContainerCreateResult);
vi.mocked(configurationRegistry.getExtensionConfiguration).mockReturnValue({
apiPort: 10000,
} as ExtensionConfiguration);
vi.mocked(containerRegistry.onHealthyContainerEvent).mockReturnValue(NO_OP_DISPOSABLE);
vi.mocked(modelsManagerMock.getModelsInfo).mockReturnValue([
{
id: 'model1',
name: 'Model 1',
description: '',
file: { file: 'model1', path: '/path/to' },
},
{
id: 'model2',
name: 'Model 2',
description: '',
file: { file: 'model2', path: '/path/to' },
},
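// model3 has no downloaded file, so the manager is expected to skip it during registration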
{
id: 'model3',
name: 'Model 3',
description: '',
},
]);
await llamaStackManager.requestCreateLlamaStackContainer({});
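// fire the registered healthy-container listener so the manager proceeds to model registration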
await vi.waitFor(() => {
const healthyListener = vi.mocked(containerRegistry.onHealthyContainerEvent).mock.calls[0][0];
expect(healthyListener).toBeDefined();
healthyListener({ id: 'containerId' });
});
const tasks = await waitTasks(LLAMA_STACK_CONTAINER_TRACKINGID, 4);
expect(tasks.some(task => task.state === 'error')).toBeFalsy();
await vi.waitFor(() => {
expect(podmanConnection.execute).toHaveBeenCalledTimes(2);
});
expect(podmanConnection.execute).toHaveBeenCalledWith(expect.anything(), [
'exec',
'containerId',
'llama-stack-client',
'models',
'register',
'Model 1',
]);
expect(podmanConnection.execute).toHaveBeenCalledWith(expect.anything(), [
'exec',
'containerId',
'llama-stack-client',
'models',
'register',
'Model 2',
]);
});

test('onPodmanConnectionEvent start event should call refreshLlamaStackContainer and set containerInfo', async () => {
80 changes: 79 additions & 1 deletion packages/backend/src/managers/llama-stack/llamaStackManager.ts
@@ -28,7 +28,7 @@ import {
import type { PodmanConnection, PodmanConnectionEvent } from '../podmanConnection';
import llama_stack_images from '../../assets/llama-stack-images.json';
import { getImageInfo } from '../../utils/inferenceUtils';
import type { ContainerRegistry, ContainerEvent } from '../../registries/ContainerRegistry';
import type { ContainerRegistry, ContainerEvent, ContainerHealthy } from '../../registries/ContainerRegistry';
import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
import { getRandomName } from '../../utils/randomUtils';
import type { LlamaStackContainerInfo } from '@shared/models/llama-stack/LlamaStackContainerInfo';
@@ -39,9 +39,11 @@ import fs from 'node:fs/promises';
import type { ConfigurationRegistry } from '../../registries/ConfigurationRegistry';
import { getFreeRandomPort } from '../../utils/ports';
import { TaskRunner } from '../TaskRunner';
import type { ModelsManager } from '../modelsManager';

export const LLAMA_STACK_CONTAINER_LABEL = 'ai-lab-llama-stack-container';
export const LLAMA_STACK_API_PORT_LABEL = 'ai-lab-llama-stack-api-port';
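// one second expressed in nanoseconds, the unit container engines use for health-check intervals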
export const SECOND: number = 1_000_000_000;

export class LlamaStackManager implements Disposable {
#initialized: boolean;
@@ -56,6 +58,7 @@ export class LlamaStackManager implements Disposable {
private containerRegistry: ContainerRegistry,
private configurationRegistry: ConfigurationRegistry,
private telemetryLogger: TelemetryLogger,
private modelsManager: ModelsManager,
) {
this.#initialized = false;
this.#disposables = [];
@@ -185,6 +188,16 @@ export class LlamaStackManager implements Disposable {
() => getImageInfo(connection, image, () => {}),
);

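// three phases: create the container, wait for its health check to report healthy, then register the local models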
let containerInfo = await this.createContainer(image, imageInfo, labels);
containerInfo = await this.waitLlamaStackContainerHealthy(containerInfo, labels);
return this.registerModels(containerInfo, labels, connection);
}

private async createContainer(
image: string,
imageInfo: ImageInfo,
labels: { [p: string]: string },
): Promise<LlamaStackContainerInfo> {
const folder = await this.getLlamaStackContainerFolder();

const aiLabApiPort = this.configurationRegistry.getExtensionConfiguration().apiPort;
@@ -218,6 +231,12 @@ export class LlamaStackManager implements Disposable {
Env: [`PODMAN_AI_LAB_URL=http://host.containers.internal:${aiLabApiPort}`],
OpenStdin: true,
start: true,
HealthCheck: {
// must be the port INSIDE the container, not the exposed one
Test: ['CMD-SHELL', `curl -sSf localhost:8321/v1/models > /dev/null`],
Interval: SECOND * 5,
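// 4 * 5 = 20 attempts at 5-second intervals, i.e. up to ~100 s before the container is flagged unhealthy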
Retries: 4 * 5,
},
};

return this.#taskRunner.runAsTask<LlamaStackContainerInfo>(
@@ -236,6 +255,65 @@ export class LlamaStackManager implements Disposable {
);
}

async waitLlamaStackContainerHealthy(
containerInfo: LlamaStackContainerInfo,
labels: { [p: string]: string },
): Promise<LlamaStackContainerInfo> {
return this.#taskRunner.runAsTask<LlamaStackContainerInfo>(
labels,
{
loadingLabel: 'Waiting for Llama Stack to start',
errorMsg: err => `Something went wrong while waiting for the container health check: ${String(err)}.`,
},
async ({ updateLabels }) => {
let disposable: Disposable;
return new Promise(resolve => {
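// resolve only once the registry reports this container healthy; the listener disposes itself on the first matching event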
disposable = this.containerRegistry.onHealthyContainerEvent((event: ContainerHealthy) => {
if (event.id !== containerInfo.containerId) {
return;
}
disposable.dispose();
// eslint-disable-next-line sonarjs/no-nested-functions
updateLabels(labels => ({
...labels,
containerId: containerInfo.containerId,
port: `${containerInfo.port}`,
}));
this.telemetryLogger.logUsage('llamaStack.startContainer');
resolve(containerInfo);
});
});
},
);
}

async registerModels(
containerInfo: LlamaStackContainerInfo,
labels: { [p: string]: string },
connection: ContainerProviderConnection,
): Promise<LlamaStackContainerInfo> {
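// only models with a file downloaded locally can be served, so skip the others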
for (const model of this.modelsManager.getModelsInfo().filter(model => model.file)) {
await this.#taskRunner.runAsTask(
labels,
{
loadingLabel: `Registering model ${model.name}`,
errorMsg: err => `Something went wrong while registering model: ${String(err)}.`,
},
async () => {
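// run the llama-stack-client CLI inside the container to register the model under its name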
await this.podmanConnection.execute(connection, [
'exec',
containerInfo.containerId,
'llama-stack-client',
'models',
'register',
model.name,
]);
},
);
}
return containerInfo;
}

private async getLlamaStackContainerFolder(): Promise<string> {
const llamaStackPath = path.join(this.appUserDirectory, 'llama-stack', 'container');
await fs.mkdir(path.join(llamaStackPath, '.llama'), { recursive: true });
1 change: 1 addition & 0 deletions packages/backend/src/studio.ts
@@ -316,6 +316,7 @@ export class Studio {
this.#containerRegistry,
this.#configurationRegistry,
this.#telemetry,
this.#modelsManager,
);
this.#extensionContext.subscriptions.push(this.#llamaStackManager);
this.#llamaStackManager.init();