Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ const SdGenHandlersMap SD_GEN_HANDLERS = {
[](SdGenConfig& c, const picojson::value& v) {
c.negativePrompt = requireStr(v, "negative_prompt");
}},
{"lora",
[](SdGenConfig& c, const picojson::value& v) {
c.loraPath = requireStr(v, "lora");
}},

// ── Image dimensions
// ────────────────────────────────────────────────────────
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct SdGenConfig {
// ── Prompt ────────────────────────────────────────────────────────────────
std::string prompt;
std::string negativePrompt;
std::string loraPath;

// ── Image dimensions ─────────────────────────────────────────────────────
int width = 512; // must be a positive multiple of 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,31 @@ class SdImageBatch {
const int count_;
};

struct PreparedLoras {
std::vector<std::string> paths;
std::vector<sd_lora_t> items;
};

// Mirrors the pinned fork's CLI flow in examples/common/common.hpp:
// build owned path storage first, then build sd_lora_t entries that point
// at that stable storage for the lifetime of generate_image().
PreparedLoras prepareLoras(const std::string& loraPath) {
PreparedLoras prepared;
if (loraPath.empty()) {
return prepared;
}

prepared.paths.push_back(loraPath);

sd_lora_t item{};
item.is_high_noise = false;
item.multiplier = 1.0f;
item.path = prepared.paths.back().c_str();
prepared.items.push_back(item);

return prepared;
}

} // namespace

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -484,6 +509,10 @@ std::any SdModel::process(const std::any& input) {
sd_img_gen_params_t genParams{};
sd_img_gen_params_init(&genParams);

PreparedLoras loras = prepareLoras(gen.loraPath);

genParams.loras = loras.items.empty() ? nullptr : loras.items.data();
genParams.lora_count = static_cast<uint32_t>(loras.items.size());
genParams.prompt = gen.prompt.c_str();
genParams.negative_prompt = gen.negativePrompt.c_str();
genParams.width = gen.width;
Expand Down
1 change: 1 addition & 0 deletions packages/lib-infer-diffusion/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ export interface ImgStableDiffusionArgs {
export interface GenerationParams {
prompt: string
negative_prompt?: string
lora?: string
width?: number
height?: number
steps?: number
Expand Down
1 change: 1 addition & 0 deletions packages/lib-infer-diffusion/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class ImgStableDiffusion {
* @param {number} [params.batch_count=1] - Images per call
* @param {boolean} [params.vae_tiling=false] - Enable VAE tiling (for large images)
* @param {string} [params.cache_preset] - Cache preset: slow/medium/fast/ultra
* @param {string} [params.lora] - Absolute path to a LoRA adapter (.safetensors, etc.)
* @param {Uint8Array} [params.init_image] - Source image bytes for img2img (PNG/JPEG).
* FLUX2: in-context conditioning (ref_images).
* Others: SDEdit (init_image + strength).
Expand Down
151 changes: 151 additions & 0 deletions packages/lib-infer-diffusion/test/integration/lora-bridge.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
'use strict'

const fs = require('bare-fs')
const path = require('bare-path')
const os = require('bare-os')
const proc = require('bare-process')
const test = require('brittle')
const binding = require('../../binding')
const ImgStableDiffusion = require('../../index')
const {
ensureModel,
detectPlatform,
setupJsLogger,
isPng
} = require('./utils')

const platform = detectPlatform()
const isDarwinX64 = os.platform() === 'darwin' && os.arch() === 'x64'
const isLinuxArm64 = os.platform() === 'linux' && os.arch() === 'arm64'
const isMobile = os.platform() === 'ios' || os.platform() === 'android'
const noGpu = proc.env && proc.env.NO_GPU === 'true'
const useCpu = isDarwinX64 || isLinuxArm64 || noGpu
const skip = isMobile || noGpu

const DEFAULT_MODEL = {
name: 'stable-diffusion-v2-1-Q8_0.gguf',
url: 'https://huggingface.co/gpustack/stable-diffusion-v2-1-GGUF/resolve/main/stable-diffusion-v2-1-Q8_0.gguf'
}

const LORA_ADAPTER = {
name: 'pytorch_lora_weights-sd21-comfyui.safetensors',
url: 'https://huggingface.co/radames/sd-21-DPO-LoRA/resolve/main/pytorch_lora_weights-sd21-comfyui.safetensors'
}

test('SD2.1 txt2img with LoRA — generates a valid PNG image', { timeout: 600000, skip }, async (t) => {
setupJsLogger(binding)

const [downloadedModelName, modelDir] = await ensureModel({
modelName: DEFAULT_MODEL.name,
downloadUrl: DEFAULT_MODEL.url
})

const [downloadedLoraName] = await ensureModel({
modelName: LORA_ADAPTER.name,
downloadUrl: LORA_ADAPTER.url
})

console.log('\n' + '='.repeat(60))
console.log('STABLE DIFFUSION 2.1 — LORA INTEGRATION TEST')
console.log('='.repeat(60))
console.log(` Platform : ${platform}`)
console.log(` Model : ${downloadedModelName}`)
console.log(` LoRA : ${downloadedLoraName}`)
console.log(` Models dir: ${modelDir}`)

const modelPath = path.join(modelDir, downloadedModelName)
const loraPath = path.join(modelDir, downloadedLoraName)
t.ok(fs.existsSync(modelPath), 'Model file exists on disk')
t.ok(fs.existsSync(loraPath), 'LoRA adapter exists on disk')

const model = new ImgStableDiffusion({
files: {
model: modelPath
},
config: {
threads: 4,
device: useCpu ? 'cpu' : 'gpu',
prediction: 'v' // SD2.1 uses v-prediction
},
logger: console
})

const images = []
const progressTicks = []

try {
// ── Load ─────────────────────────────────────────────────────────────────
console.log('\n=== Loading model ===')
const tLoad = Date.now()
await model.load()
const loadMs = Date.now() - tLoad
console.log(`Loaded in ${(loadMs / 1000).toFixed(1)}s`)
t.ok(loadMs < 120000, `Model loaded within 120s (took ${(loadMs / 1000).toFixed(1)}s)`)

// ── Generate ──────────────────────────────────────────────────────────────
console.log('\n=== Generating image with LoRA ===')
const tGen = Date.now()

const response = await model.run({
prompt: 'a bright red sports car parked on a street, clean background, high detail, studio lighting',
negative_prompt: 'blurry, low quality, watermark',
lora: loraPath,
steps: 1,
width: 512,
height: 512,
cfg_scale: 7.5,
seed: 42 // fixed seed for reproducibility
})

await response
.onUpdate((data) => {
if (data instanceof Uint8Array) {
images.push(data)
} else if (typeof data === 'string') {
try {
const tick = JSON.parse(data)
if ('step' in tick && 'total' in tick) {
progressTicks.push(tick)
}
} catch (_) {}
}
})
.await()

const genMs = Date.now() - tGen
console.log(`\nGenerated in ${(genMs / 1000).toFixed(1)}s`)

// ── Assertions ────────────────────────────────────────────────────────────
t.ok(progressTicks.length > 0, `Received progress ticks (got ${progressTicks.length})`)
t.is(images.length, 1, 'Received exactly 1 image')

const img = images[0]
t.ok(img instanceof Uint8Array, 'Image is a Uint8Array')
t.ok(img.length > 0, `Image is non-empty (${img.length} bytes)`)
t.ok(isPng(img), 'Image has valid PNG magic bytes')

// Save output for CI artifact upload — filename encodes test origin.
// Saved to modelDir so mobile has write permission to the same path.
const outPath = path.join(modelDir, 'generate-image--sd2-lora-txt2img-seed42.png')
fs.writeFileSync(outPath, img)
console.log(`\nSaved → ${outPath}`)

// ── Summary ───────────────────────────────────────────────────────────────
console.log('\n' + '='.repeat(60))
console.log('TEST SUMMARY')
console.log('='.repeat(60))
console.log(` Load time : ${(loadMs / 1000).toFixed(1)}s`)
console.log(` Gen time : ${(genMs / 1000).toFixed(1)}s`)
console.log(` Steps ticks : ${progressTicks.length}`)
console.log(` Image size : ${img.length} bytes`)
console.log(' PNG valid : true')
console.log('='.repeat(60))
} finally {
console.log('\n=== Cleanup ===')
await model.unload()
try {
binding.releaseLogger()
} catch (_) {}
console.log('Done.')
}
})
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ TEST(SdGenHandlers_Scheduler, UnknownSchedulerThrows) {
StatusError);
}

TEST(SdGenHandlers_Prompt, LoraPathMapsCorrectly) {
auto cfg = applyOne("lora", str("/tmp/test-lora.safetensors"));
Comment thread
jesusmb1995 marked this conversation as resolved.
EXPECT_EQ(cfg.loraPath, "/tmp/test-lora.safetensors");
}

// ─────────────────────────────────────────────────────────────────────────────
// 3. parseCacheMode
// ─────────────────────────────────────────────────────────────────────────────
Expand Down
Loading