Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions packages/sdk/client/api/classify.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import {
classifyResponseSchema,
type ClassifyRequest,
type ClassifyClientParams,
type ClassificationResult,
} from "@/schemas";
import { stream as streamRpc } from "@/client/rpc/rpc-client";
import { encodeBase64 } from "@/utils/encoding";

/**
 * Runs image classification against an already-loaded classification model.
 *
 * The image bytes are base64-encoded and sent as a `"classify"` RPC request.
 * The response stream is read until a `"classify"` response flagged `done`
 * arrives, and that response's results are returned.
 *
 * The bundled MobileNetV3-Small model produces 3 labels: `"food"`, `"report"`, `"other"`.
 * Custom models may emit different labels sourced from the GGUF metadata.
 *
 * @param params.modelId - The identifier of the loaded classification model
 * @param params.image - JPEG or PNG buffer; raw RGB bytes also accepted with `width`, `height`, `channels`
 * @param params.topK - Limit results to top-K classes (default: all)
 * @param params.width - Raw RGB image width; only sent when provided
 * @param params.height - Raw RGB image height; only sent when provided
 * @param params.channels - Raw RGB channel count; only sent when provided
 * @returns Classification results from the final response (documented as sorted,
 * highest confidence first); an empty array if the stream ends without a
 * `done` response
 * @throws If a `"classify"` response does not match `classifyResponseSchema`
 *
 * @example
 * ```typescript
 * const modelId = await loadModel({ modelType: "ggml-classification" });
 * const jpeg = fs.readFileSync("photo.jpg");
 * const results = await classify({ modelId, image: jpeg });
 * // [ { label: "food", confidence: 0.93 }, { label: "other", confidence: 0.05 }, ... ]
 * await unloadModel({ modelId });
 * ```
 */
export async function classify(
  params: ClassifyClientParams,
): Promise<ClassificationResult[]> {
  const { modelId, image, topK, width, height, channels } = params;

  const request: ClassifyRequest = {
    type: "classify",
    modelId,
    image: encodeBase64(image),
  };
  // Optional fields are attached only when the caller supplied them, so the
  // request never carries explicit `undefined` values.
  if (topK !== undefined) request.topK = topK;
  if (width !== undefined) request.width = width;
  if (height !== undefined) request.height = height;
  if (channels !== undefined) request.channels = channels;

  for await (const message of streamRpc(request)) {
    // Ignore anything on the stream that is not a "classify" response.
    const isClassifyMessage =
      message &&
      typeof message === "object" &&
      "type" in message &&
      message.type === "classify";
    if (!isClassifyMessage) continue;

    const { done, results } = classifyResponseSchema.parse(message);
    if (done) return results;
  }

  // Stream finished without a response marked `done`.
  return [];
}
1 change: 1 addition & 0 deletions packages/sdk/client/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export { getLoadedModelInfo } from "./get-loaded-model-info";
export { ocr } from "./ocr";
export { invokePlugin, invokePluginStream } from "./invoke-plugin";
export { diffusion, type DiffusionProgressTick } from "./diffusion";
export { classify } from "./classify";
export { upscale } from "./upscale";
export {
modelRegistryList,
Expand Down
29 changes: 29 additions & 0 deletions packages/sdk/examples/classification/classify-image.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import fs from "fs";
import { startQVACProvider, stopQVACProvider, loadModel, classify, unloadModel } from "@qvac/sdk";

/**
 * Classify an image using the bundled MobileNetV3-Small model.
 *
 * The bundled model produces three classes: "food", "report", "other".
 * No modelSrc is needed — the model ships inside @qvac/classification-ggml.
 *
 * Cleanup runs in `finally` blocks so the model is unloaded and the provider
 * is stopped even when reading the image or classifying it throws.
 */
async function main() {
  await startQVACProvider({});

  try {
    const modelId = await loadModel({
      modelType: "ggml-classification",
    });

    try {
      const image = fs.readFileSync("image.jpg");
      const results = await classify({ modelId, image });

      console.log("Classification results:");
      for (const { label, confidence } of results) {
        console.log(` ${label}: ${(confidence * 100).toFixed(1)}%`);
      }
    } finally {
      // Always release the model once it has been loaded.
      await unloadModel({ modelId });
    }
  } finally {
    // Always stop the provider once it has been started.
    await stopQVACProvider();
  }
}

// Entry point: run the example and log any rejection from main() via console.error.
main().catch(console.error);
4 changes: 4 additions & 0 deletions packages/sdk/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export {
invokePluginStream,
diffusion,
type DiffusionProgressTick,
classify,
upscale,
modelRegistryList,
modelRegistrySearch,
Expand Down Expand Up @@ -109,6 +110,8 @@ export {
type OCRClientParams,
type OCRTextBlock,
type OCROptions,
type ClassifyClientParams,
type ClassificationResult,
type DiffusionClientParams,
type DiffusionStreamResponse,
type DiffusionStats,
Expand Down Expand Up @@ -136,6 +139,7 @@ export {
PLUGIN_OCR,
PLUGIN_DIFFUSION,
PLUGIN_VLA,
PLUGIN_CLASSIFICATION,
SDK_DEFAULT_PLUGINS,
type BuiltinPlugin,
type ProfilerMode,
Expand Down
5 changes: 5 additions & 0 deletions packages/sdk/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@
"types": "./dist/server/bare/plugins/ggml-vla/plugin.d.ts",
"import": "./dist/server/bare/plugins/ggml-vla/plugin.js"
},
"./ggml-classification/plugin": {
"types": "./dist/server/bare/plugins/ggml-classification/plugin.d.ts",
"import": "./dist/server/bare/plugins/ggml-classification/plugin.js"
},
"./plugin-utils": {
"types": "./dist/schemas/plugin.d.ts",
"import": "./dist/schemas/plugin.js"
Expand Down Expand Up @@ -176,6 +180,7 @@
"changelog:generate": "node ../../scripts/sdk/generate-changelog-sdk-pod.cjs --package=sdk && prettier --write changelog"
},
"dependencies": {
"@qvac/classification-ggml": "^0.2.0",
"@qvac/decoder-audio": "^0.3.7",
"@qvac/diffusion-cpp": "^0.8.0",
"@qvac/embed-llamacpp": "^0.16.0",
Expand Down
2 changes: 2 additions & 0 deletions packages/sdk/pear/pre.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const BUILTIN_PLUGINS = [
"@qvac/sdk/onnx-ocr/plugin",
"@qvac/sdk/sdcpp-generation/plugin",
"@qvac/sdk/ggml-vla/plugin",
"@qvac/sdk/ggml-classification/plugin",
];

const BUILTIN_PLUGIN_EXPORTS: Record<string, string> = {
Expand All @@ -70,6 +71,7 @@ const BUILTIN_PLUGIN_EXPORTS: Record<string, string> = {
"onnx-ocr": "ocrPlugin",
"sdcpp-generation": "diffusionPlugin",
"ggml-vla": "vlaPlugin",
"ggml-classification": "classificationPlugin",
};

const SDK_NAME = "@qvac/sdk";
Expand Down
57 changes: 57 additions & 0 deletions packages/sdk/schemas/classification.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { z } from "zod";

/**
* Schema for the `modelConfig` object of a classification model.
* Every field is optional.
*/
export const classificationConfigSchema = z.object({
/** Absolute path to the GGUF weights file. Defaults to the bundled model inside @qvac/classification-ggml. */
modelPath: z.string().optional(),
/** Limit returned results to the top-K classes (integer). Default: all classes. */
topK: z.number().int().optional(),
/** Forward native C++ log lines through the SDK logger. Off by default. */
nativeLogger: z.boolean().optional(),
});

/**
* Wire-level parameters of a classify call. Here `image` is a string, whereas
* the caller-facing `ClassifyClientParams` takes the image as raw bytes.
*/
export const classifyParamsSchema = z.object({
modelId: z.string(),
/** Image bytes (JPEG, PNG, or raw RGB) carried as a string; the SDK `classify` client base64-encodes them before sending. */
image: z.string(),
/** Limit returned results to the top-K classes (integer). */
topK: z.number().int().optional(),
/** Raw RGB image width (required for raw bytes). */
width: z.number().int().optional(),
/** Raw RGB image height (required for raw bytes). */
height: z.number().int().optional(),
/** Channel count — must be 3 for raw RGB. */
channels: z.literal(3).optional(),
});

/** RPC request for classification: the classify params plus the `"classify"` type tag. */
export const classifyRequestSchema = classifyParamsSchema.extend({
type: z.literal("classify"),
});

/** A single classification outcome: a class label and its numeric confidence. */
export const classificationResultSchema = z.object({
label: z.string(),
confidence: z.number(),
});

/**
* RPC response for classification, tagged with the `"classify"` type.
* `done` is optional; the SDK `classify` client returns `results` from the
* first response in which it is truthy.
*/
export const classifyResponseSchema = z.object({
type: z.literal("classify"),
results: z.array(classificationResultSchema),
done: z.boolean().optional(),
});

// Static TypeScript types inferred from the zod schemas declared in this file.
export type ClassificationConfig = z.infer<typeof classificationConfigSchema>;
export type ClassifyParams = z.infer<typeof classifyParamsSchema>;
export type ClassifyRequest = z.infer<typeof classifyRequestSchema>;
export type ClassificationResult = z.infer<typeof classificationResultSchema>;
export type ClassifyResponse = z.infer<typeof classifyResponseSchema>;

/**
* Caller-facing parameters for `classify()`. Unlike `ClassifyParams`, the
* image is supplied as raw bytes rather than as a string.
*/
export interface ClassifyClientParams {
/** Identifier of the loaded classification model. */
modelId: string;
/**
* JPEG or PNG buffer.
* NOTE(review): the `classify` docs also mention raw RGB bytes when `width`,
* `height`, `channels` are given — confirm and align this comment.
*/
image: Uint8Array;
/** Limit results to the top-K classes. */
topK?: number;
/** Raw RGB image width (required for raw bytes). */
width?: number;
/** Raw RGB image height (required for raw bytes). */
height?: number;
/** Channel count — must be 3 for raw RGB. */
channels?: 3;
}
6 changes: 6 additions & 0 deletions packages/sdk/schemas/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ import {
import { suspendRequestSchema, suspendResponseSchema } from "./suspend";
import { resumeRequestSchema, resumeResponseSchema } from "./resume";
import { stateRequestSchema, stateResponseSchema } from "./state";
import {
classifyRequestSchema,
classifyResponseSchema,
} from "./classification";

export const requestSchema = z.union([
heartbeatRequestSchema,
Expand Down Expand Up @@ -122,6 +126,7 @@ export const requestSchema = z.union([
suspendRequestSchema,
resumeRequestSchema,
stateRequestSchema,
classifyRequestSchema,
]);

export const responseSchema = z.discriminatedUnion("type", [
Expand Down Expand Up @@ -160,6 +165,7 @@ export const responseSchema = z.discriminatedUnion("type", [
suspendResponseSchema,
resumeResponseSchema,
stateResponseSchema,
classifyResponseSchema,
]);

export const rpcOptionsSchema = z.object({
Expand Down
4 changes: 4 additions & 0 deletions packages/sdk/schemas/engine-addon-map.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { ModelType } from "./model-types";
import {
ADDON_CLASSIFICATION,
ADDON_DIFFUSION,
ADDON_EMBEDDING,
ADDON_LLM,
Expand Down Expand Up @@ -28,6 +29,7 @@ export const ENGINE_TO_ADDON = {
[ModelType.parakeetTranscription]: "parakeet",
[ModelType.sdcppGeneration]: "diffusion",
[ModelType.ggmlVla]: "vla",
[ModelType.ggmlClassification]: "classification",
"onnx-vad": "vad",
} as const satisfies Record<ModelRegistryEngine, ModelRegistryEntryAddon>;

Expand Down Expand Up @@ -56,6 +58,8 @@ const LEGACY_ENGINE_TO_CANONICAL: Record<string, ModelRegistryEngine> = {
diffusion: ModelType.sdcppGeneration,
[ADDON_VLA]: ModelType.ggmlVla,
vla: ModelType.ggmlVla,
[ADDON_CLASSIFICATION]: ModelType.ggmlClassification,
classification: ModelType.ggmlClassification,
};

// Resolves any engine string (legacy or canonical) to a validated canonical engine.
Expand Down
1 change: 1 addition & 0 deletions packages/sdk/schemas/get-model-info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export const modelInfoSchema = z.object({
"ocr",
"diffusion",
"vla",
"classification",
"other",
])
.describe("Inference addon / capability category this model belongs to."),
Expand Down
1 change: 1 addition & 0 deletions packages/sdk/schemas/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export * from "./rag";
export * from "./ocr";
export * from "./sdcpp-config";
export * from "./vla";
export * from "./classification";
export * from "./shard";
export * from "./suspend";
export * from "./resume";
Expand Down
37 changes: 37 additions & 0 deletions packages/sdk/schemas/load-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ import {
ocrModelTypeSchema,
diffusionModelTypeSchema,
vlaModelTypeSchema,
classificationModelTypeSchema,
ModelType,
ModelTypeAliases,
type CanonicalModelType,
type ModelTypeInput,
} from "./model-types";
import { sdcppConfigSchema } from "./sdcpp-config";
import { vlaConfigSchema } from "./vla";
import { classificationConfigSchema } from "./classification";

// Set of all built-in model types (canonical + aliases) for catch-all exclusion
const builtInModelTypes = new Set([
Expand Down Expand Up @@ -124,6 +126,14 @@ export const loadBuiltinModelOptionsBaseSchema = z.union([
modelConfig: vlaConfigSchema.strict().optional(),
})
.strict(),
z
.object({
...loadModelCommonFields,
modelSrc: modelSrcInputSchema.optional(),
modelType: classificationModelTypeSchema,
modelConfig: classificationConfigSchema.strict().optional(),
})
.strict(),
]);

// Custom plugin catch-all: any modelType string EXCEPT built-ins.
Expand Down Expand Up @@ -321,6 +331,25 @@ const loadModelOptionsToRequestBaseSchema = z.union([
delegate: data.delegate,
...(data.requestId !== undefined && { requestId: data.requestId }),
})),
z
.object({
...loadModelRequestCommonFields,
modelSrc: modelSrcInputSchema.optional(),
modelType: classificationModelTypeSchema,
modelConfig: classificationConfigSchema.strict().optional(),
})
.strict()
.transform((data) => ({
type: "loadModel" as const,
modelType: ModelType.ggmlClassification,
modelSrc: data.modelSrc ? modelInputToSrcSchema.parse(data.modelSrc) : "",
modelName: data.modelSrc ? modelInputToNameSchema.parse(data.modelSrc) : undefined,
modelConfig: data.modelConfig ?? {},
seed: data.seed ?? false,
withProgress: data.withProgress ?? !!data.onProgress,
delegate: data.delegate,
...(data.requestId !== undefined && { requestId: data.requestId }),
})),
z
.object({
...loadModelRequestCommonFields,
Expand Down Expand Up @@ -427,6 +456,13 @@ export const loadVlaModelRequestSchema = commonModelConfigSchema
})
.strict();

// Load request for the built-in classification model: the common model config
// extended with a fixed `modelType` literal and an optional classification
// `modelConfig`. `.strict()` makes unknown keys a validation error.
export const loadClassificationModelRequestSchema = commonModelConfigSchema
.extend({
modelType: z.literal(ModelType.ggmlClassification),
modelConfig: classificationConfigSchema.optional(),
})
.strict();

// Custom plugin catch-all: accepts any modelType string EXCEPT built-ins
export const loadCustomPluginModelRequestSchema =
commonModelConfigSchema.extend({
Expand All @@ -448,6 +484,7 @@ export const loadModelSrcRequestSchema = z
loadOcrModelRequestSchema,
loadDiffusionModelRequestSchema,
loadVlaModelRequestSchema,
loadClassificationModelRequestSchema,
loadCustomPluginModelRequestSchema,
])
.transform((data) => ({
Expand Down
17 changes: 17 additions & 0 deletions packages/sdk/schemas/model-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export const ModelType = {
onnxOcr: "onnx-ocr",
sdcppGeneration: "sdcpp-generation",
ggmlVla: "ggml-vla",
ggmlClassification: "ggml-classification",
} as const;

// === INTERNAL: Alias keys (backward compat names) ===
Expand All @@ -29,6 +30,7 @@ const AliasKeys = {
ocr: "ocr",
diffusion: "diffusion",
vla: "vla",
classification: "classification",
} as const;

// === INTERNAL: Aliases (backward compat mapping) ===
Expand All @@ -46,6 +48,7 @@ export const ModelTypeAliases = {
[AliasKeys.ocr]: ModelType.onnxOcr,
[AliasKeys.diffusion]: ModelType.sdcppGeneration,
[AliasKeys.vla]: ModelType.ggmlVla,
[AliasKeys.classification]: ModelType.ggmlClassification,
} as const;

// === TYPES ===
Expand Down Expand Up @@ -243,3 +246,17 @@ export const vlaModelTypeSchema = modelTypeInputSchema
.extract([AliasKeys.vla, ModelType.ggmlVla])
.describe('VLA model type: "vla" (alias) or "ggml-vla" (canonical)');
export type VlaModelTypeInput = z.infer<typeof vlaModelTypeSchema>;

/**
* Image Classification model type schema.
* - Alias: `"classification"` → resolves to `"ggml-classification"`
* - Canonical: `"ggml-classification"`
*/
export const classificationModelTypeSchema = modelTypeInputSchema
.extract([AliasKeys.classification, ModelType.ggmlClassification])
.describe(
'Classification model type: "classification" (alias) or "ggml-classification" (canonical)',
);
/** Union of the model type strings accepted by `classificationModelTypeSchema`. */
export type ClassificationModelTypeInput = z.infer<
typeof classificationModelTypeSchema
>;
Loading
Loading